The Proof
Before we start the proof, we provide a brief overview of the general structure.
The rest of the document will focus on the concrete proof for ADDW - but the details specific to this instruction will be restricted to just one lemma.
The rest will be plumbing.
Although we focus on the ADDW instruction in this post, the general proof outline will hold for all FormatR and FormatI instructions.
The crux of the idea is that all these instructions simply read from registers and an immediate.
Reads do not change state.
After reading values in the source registers, the final outcome is to update a single register rd.
No other field in the state machine apart from rd changes.
So what all the proofs really boil down to is one check: the jolt sequence and the original RISCV instruction must write the same value to the same rd.
Nothing else should be touched ever.
However, to formally prove two imperative programs are equivalent, we will have to explicitly prove that every field of the state machine and the return values are the same. Given what we said above, this means a lot of the proof will constitute machinery to get rid of the monadic plumbing, saying both LHS and RHS sequences do not touch memory, so they must be equal. Lean cannot infer this on its own. At the same time, this is not interesting, and we do not want to spend time and energy doing this for every proof. As a result, we will try to "templatise" this and "generalise" this as much as possible (see below, as to how).
All the real work will happen in a helper lemma with the suffix _concrete in its name.
Here we do the main "math" of the proof via a bridge lemma (the concrete example below will further elucidate this).
This _concrete theorem statement will say: look, this sequence of jolt instructions only affects register rd.
And the value written to rd is exactly the value written by the RISCV instruction.
Once we prove this lemma, in the main proof we just substitute this lemma in place of the jolt-update.
The net effect is that the main-theorem is a sequence of unfolding the monadic binds and re-writing, with one application of our helper lemma -- and bang the proof passes.
The hope then is for each FormatR and FormatI instruction, we'll just have to write 1 helper lemma, and the rest of the proof machinery, once written will just follow.
We now walk through the proof of theorem jolt_addw_eq_sail one tactic at a time. At each step, we show the goal before the tactic, the tactic itself, and how the goal changes.
The full proof:
theorem jolt_addw_eq_sail (rs2 rs1 rd : regidx) (hrd : rd ≠ regidx.Regidx 0)
    (js : SailJoltState) (hwf : WellFormed js) :
    projectResult ((jolt_addw rs2 rs1 rd).run js) =
      (execute_RTYPEW rs2 rs1 rd ropw.ADDW).run js.sail := by
  obtain ⟨js', v1, v2, hj_rx1, hj_rx2, hj, hj_sail⟩ :=
    jolt_addw_concrete rs2 rs1 rd hrd js hwf
  -- Sail side: unfold and rewrite with the same v1, v2
  rw [execute_RTYPEW_ADDW_factored]
  simp only [EStateM.run, bind, EStateM.bind, pure, EStateM.pure, hj_rx1, hj_rx2]
  -- Jolt side: rewrite with concrete result
  show projectResult ((jolt_addw rs2 rs1 rd).run js) = _
  rw [hj]
  simp only [projectResult, project]
  rw [hj_sail]
  -- Both sides now write the same value.
  obtain ⟨s', hw⟩ := wX_shape rd _ js.sail
  rw [hw]
  congr 1
  exact (wX_bits_eq_stateAfterWrite rd _ js.sail s' hw).symm
Step 0: The initial goal
After by, Lean presents us with the goal we need to prove:
⊢ projectResult ((jolt_addw rs2 rs1 rd).run js) =
(execute_RTYPEW rs2 rs1 rd ropw.ADDW).run js.sail
This is exactly the theorem statement. Now we begin.
Helpers
The proof uses several helper lemmas for reasoning about register operations. These are documented in detail in the Register Operation Helpers post. The key ones are:
- stateAfterWrite: returns a new SailState identical to the original except that register rd now holds the given value. It is a specification-level way of describing a register write, without going through the monadic wX_bits. It also guarantees that the only change in state is the write to rd.
- wX_shape: writing to a register via wX_bits always succeeds. Unlike readReg (which can throw Error.Unreachable if a key is missing), wX_bits always produces an .ok result. This is why we don't need a WellFormed-like precondition for writes.
- wX_bits_eq_stateAfterWrite: running the monadic register write wX_bits produces exactly the same state as the pure stateAfterWrite. This lets us move between the monadic world (used in the actual computation) and the specification world (used in the proof).
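To make the relationship between the three helpers concrete, here is a minimal toy sketch in Lean. It is not the code from the Jolt/Sail development: `ToyState`, `wX` and the `String` error type are hypothetical stand-ins, and the real `wX_bits` is more involved. The sketch only illustrates the shape of the statements: a pure write, a monadic write that cannot fail, and a lemma connecting the two.

```lean
-- Toy stand-ins only; none of these are the definitions from the codebase.
structure ToyState where
  regs : Nat → BitVec 64

/-- Pure write: the result agrees with `s` everywhere except at `rd`. -/
def stateAfterWrite (s : ToyState) (rd : Nat) (v : BitVec 64) : ToyState :=
  { regs := fun r => if r = rd then v else s.regs r }

/-- Monadic write that never throws (hypothetical analogue of `wX_bits`). -/
def wX (rd : Nat) (v : BitVec 64) : EStateM String ToyState Unit :=
  fun s => .ok () (stateAfterWrite s rd v)

/-- Analogue of `wX_shape`: the write always returns `.ok`. -/
theorem wX_shape (rd : Nat) (v : BitVec 64) (s : ToyState) :
    ∃ s' : ToyState, wX rd v s = .ok () s' :=
  ⟨stateAfterWrite s rd v, rfl⟩

/-- Analogue of `wX_bits_eq_stateAfterWrite`: the state produced by the
    monadic write is exactly the pure specification. -/
theorem wX_eq_stateAfterWrite (rd : Nat) (v : BitVec 64) (s s' : ToyState)
    (hw : wX rd v s = .ok () s') : s' = stateAfterWrite s rd v := by
  cases hw
  rfl
```

In this toy the second lemma is almost trivial because `wX` is defined directly in terms of `stateAfterWrite`. In the real development the two are defined independently, which is why the connecting lemma carries actual content.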
jolt_addw_concrete
This is the main worker lemma. All the heavy lifting happens here.
theorem jolt_addw_concrete (rs2 rs1 rd : regidx)
    (hrd : rd ≠ regidx.Regidx 0) (js : SailJoltState) (hwf : WellFormed js) :
    ∃ (js' : SailJoltState) (v1 v2 : BitVec 64),
      rX_bits rs1 js.sail = .ok v1 js.sail ∧                      -- item 1
      rX_bits rs2 js.sail = .ok v2 js.sail ∧                      -- item 2
      (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js' ∧    -- item 3
      -- item 4
      js'.sail = stateAfterWrite js.sail rd
        (sign_extend (m := 64) (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0))
The above Lean lemma in plain English:
It says: there exist a final Jolt state js' and values v1, v2 such that (the item comments in the code block map each point below to the Lean statement):
- Reading rs1 succeeds: rX_bits rs1 js.sail = .ok v1 js.sail. Reading register rs1 from the initial state returns value v1 without modifying the state.
- Reading rs2 succeeds: rX_bits rs2 js.sail = .ok v2 js.sail. Reading register rs2 returns value v2 without modifying the state.
- Execution succeeds: (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js'. Running Jolt's ADDW on the initial state succeeds with RETIRE_SUCCESS and produces final state js'.
- The Sail state is characterised: js'.sail = stateAfterWrite js.sail rd (sign_extend (m := 64) (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)). The Sail component of the final state is the original Sail state with register rd overwritten by sign_extend(v1[31:0] + v2[31:0]).
What this theorem really allows us to do is the following.
Let X = sign_extend(v1[31:0] +₃₂ v2[31:0]).
Note that we know by staring at the execute_RTYPEW code that Sail's native ADDW writes X to rd — this is visible directly from its definition.
But jolt_addw does two separate steps:
- ADD writes Y = v1 +₆₄ v2 to rd (the full 64-bit sum)
- sign-extend-word reads rd back, gets Y, extracts the lower 32 bits, sign-extends to 64, and writes back sext₆₄(Y[31:0])
The claim in jolt_addw_concrete is that after both steps, rd holds X.
This requires the bridge lemma extractLsb_add:
theorem extractLsb_add (a b : BitVec 64) :
    Sail.BitVec.extractLsb (a + b) 31 0 =
      Sail.BitVec.extractLsb a 31 0 + Sail.BitVec.extractLsb b 31 0
which says truncation distributes over addition: (a +₆₄ b)[31:0] = a[31:0] +₃₂ b[31:0].
This rewrites the Jolt result sext₆₄((v1 + v2)[31:0]) into sext₆₄(v1[31:0] + v2[31:0]) = X.
When we invoke this in the main theorem, we will observe that all fields of SailState will be definitionally equal.
The RHS by definition writes X to rd and, we just showed that the LHS also writes X to rd.
Furthermore, the stateAfterWrite business tells us that nothing else could have changed!
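The arithmetic core of the bridge step can be checked in isolation. The following is a minimal sketch, assuming a recent Lean 4 toolchain in which `BitVec.setWidth`, `BitVec.signExtend` and the `bv_decide` tactic are available. It uses core `BitVec.setWidth 32` and `BitVec.signExtend 64` as stand-ins for `Sail.BitVec.extractLsb _ 31 0` and `sign_extend (m := 64)`, so it is not the actual `extractLsb_add` proof. The names `addwDirect`, `addwTwoStep`, `setWidth_add_32` and `addw_two_step_eq` are hypothetical.

```lean
import Std.Tactic.BVDecide

/-- X in the text: sign-extend the 32-bit sum of the low words. -/
def addwDirect (v1 v2 : BitVec 64) : BitVec 64 :=
  (v1.setWidth 32 + v2.setWidth 32).signExtend 64

/-- Jolt's two steps: full 64-bit ADD, then sign-extend the low word of the sum. -/
def addwTwoStep (v1 v2 : BitVec 64) : BitVec 64 :=
  ((v1 + v2).setWidth 32).signExtend 64

/-- Stand-in for the bridge lemma: truncation distributes over addition. -/
theorem setWidth_add_32 (a b : BitVec 64) :
    (a + b).setWidth 32 = a.setWidth 32 + b.setWidth 32 := by
  bv_decide

/-- After both Jolt steps, the value equals what ADDW writes directly. -/
theorem addw_two_step_eq (v1 v2 : BitVec 64) :
    addwTwoStep v1 v2 = addwDirect v1 v2 := by
  unfold addwTwoStep addwDirect
  rw [setWidth_add_32]

-- A carry out of bit 31 is discarded on both paths:
-- 0xFFFFFFFF + 1 wraps to zero in the low word, so both results are zero.
#eval addwTwoStep (BitVec.ofNat 64 0xFFFFFFFF) (BitVec.ofNat 64 1)
#eval addwDirect  (BitVec.ofNat 64 0xFFFFFFFF) (BitVec.ofNat 64 1)
```

The point of the split is the same as in the real proof: the only instruction-specific fact is the one-line bridge lemma, and the equality of the two computations follows by rewriting with it.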
The Proof
With the helpers defined, we now show how the main proof just decomposes into mechanical unwinding of the monadic machinery.
Step 1: obtain
obtain ⟨js', v1, v2, hj_rx1, hj_rx2, hj, hj_sail⟩ :=
jolt_addw_concrete rs2 rs1 rd hrd js hwf
This applies jolt_addw_concrete (which we just described above) to our hypotheses rs2, rs1, rd, hrd, js, and hwf. The result is an existential — obtain destructures it, introducing seven new facts into the context:
- js' : SailJoltState is the final Jolt state after running jolt_addw.
- v1 : BitVec 64 is the value read from register rs1.
- v2 : BitVec 64 is the value read from register rs2.
- hj_rx1 : rX_bits rs1 js.sail = .ok v1 js.sail says reading rs1 succeeds (and state is unchanged).
- hj_rx2 : rX_bits rs2 js.sail = .ok v2 js.sail says reading rs2 succeeds (and state is unchanged).
- hj : (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js' says jolt_addw succeeds.
- hj_sail : js'.sail = stateAfterWrite js.sail rd (sign_extend ...) says the final Sail state is characterised with rd set to the value given by (sign_extend ...). Nothing else about js' is different from js.
Applying the lemma only adds hypotheses to the context; it does not change our goal, which is still:
⊢ projectResult ((jolt_addw rs2 rs1 rd).run js) =
(execute_RTYPEW rs2 rs1 rd ropw.ADDW).run js.sail
But now we have concrete facts about what rX_bits rs1, rX_bits rs2, and jolt_addw compute, which we can use to rewrite both sides.
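For readers unfamiliar with `obtain`, here is a small self-contained sketch of the same destructuring pattern on a toy existential. The statement is made up for illustration and has nothing to do with the Jolt state; it only shows how witnesses and facts land in the context and can then be used.

```lean
example (h : ∃ (n m : Nat), n = 2 ∧ n + m = 5) : ∃ m, 2 + m = 5 := by
  -- Destructure: two witnesses (n, m) and two facts (hn, hsum) enter the context.
  obtain ⟨n, m, hn, hsum⟩ := h
  -- The goal is untouched so far; now use the new facts.
  rw [hn] at hsum
  exact ⟨m, hsum⟩
```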
Step 2: Simplify the RHS
rw [execute_RTYPEW_ADDW_factored]
simp only [EStateM.run, bind, EStateM.bind, pure, EStateM.pure, hj_rx1, hj_rx2]
First, rw [execute_RTYPEW_ADDW_factored] replaces execute_RTYPEW rs2 rs1 rd ropw.ADDW with the factored form that directly reads rs1 and rs2, adds their lower 32 bits, sign-extends, and writes to rd.
This eliminates the match on ropw, and expresses the sail side in pure monadic form.
Then simp only targets the RHS. It:
- Unfolds the monadic plumbing (bind, EStateM.bind, pure, EStateM.pure, EStateM.run), turning the do block into concrete function applications on the state. For a gentle introduction to monads and bind, see this post.
- Rewrites with hj_rx1: replaces rX_bits rs1 js.sail with .ok v1 js.sail.
- Rewrites with hj_rx2: replaces rX_bits rs2 js.sail with .ok v2 js.sail.
After this, the RHS is fully evaluated: the monadic chain is gone, v1 and v2 are substituted in, and what remains is the concrete result of writing to rd and returning RETIRE_SUCCESS.
Note that this simp only also unfolded EStateM.run on the LHS, turning (jolt_addw rs2 rs1 rd).run js into jolt_addw rs2 rs1 rd js (direct function application).
The goal is now:
⊢ projectResult (jolt_addw rs2 rs1 rd js)
=
match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
The RHS is now a match on the result of wX_bits — if the write succeeds (.ok), return RETIRE_SUCCESS with the new state; if it errors, propagate the error.
Next we need to simplify the LHS to match.
Step 3: show + rw [hj]
show projectResult ((jolt_addw rs2 rs1 rd).run js) = _
rw [hj]
The show tactic re-folds the LHS back to the .run form so that it matches the shape of hj : (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js'. This is allowed because the two forms are definitionally equal.
Then rw [hj] replaces (jolt_addw rs2 rs1 rd).run js with .ok RETIRE_SUCCESS js' on the LHS. The goal is now:
⊢ projectResult (EStateM.Result.ok RETIRE_SUCCESS js')
=
match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
The LHS now has a concrete Result value — projectResult applied to .ok RETIRE_SUCCESS js'.
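The show-then-rw pattern can be reproduced on a tiny example. This is a sketch with a hypothetical computation `x` over a `Nat` state, not the Jolt code: the goal mentions `x s`, the hypothesis mentions `x.run s`, and `rw` needs the `.run` form to appear literally in the goal.

```lean
example (x : EStateM String Nat Nat) (s : Nat)
    (h : x.run s = .ok 1 s) : x s = .ok 1 s := by
  -- `x s` and `x.run s` are definitionally equal, so `show` may restate the goal.
  show x.run s = _
  -- Now the left-hand side of `h` occurs literally, and `rw` can fire.
  rw [h]
```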
Step 4: simp only [projectResult, project]
simp only [projectResult, project]
Now we unfold projectResult and project on the LHS. projectResult (.ok RETIRE_SUCCESS js') reduces to .ok RETIRE_SUCCESS js'.sail, and project js' reduces to js'.sail. The goal becomes:
⊢ EStateM.Result.ok RETIRE_SUCCESS js'.sail
=
match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
The LHS is now fully concrete: .ok RETIRE_SUCCESS js'.sail.
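The definitions of projectResult and project are not shown in this post, so the following is only a guess at their shape, written as a self-contained toy: `SailToy`, `JoltToy` and the `String` error type are hypothetical. It illustrates why the unfolding in this step is purely computational: on an `.ok` result the projection just swaps the state for its Sail component.

```lean
-- Hypothetical miniature of the two state types.
structure SailToy where
  regs : Nat → BitVec 64

structure JoltToy where
  sail  : SailToy
  vregs : Nat → BitVec 64

/-- Forget the Jolt-only part of the state. -/
def project (js : JoltToy) : SailToy := js.sail

/-- Apply `project` to the state inside a result, keeping the value or error. -/
def projectResult {α : Type} :
    EStateM.Result String JoltToy α → EStateM.Result String SailToy α
  | .ok a js    => .ok a (project js)
  | .error e js => .error e (project js)

-- On a successful result, the projection reduces by computation alone.
example (js' : JoltToy) :
    projectResult (.ok 0 js' : EStateM.Result String JoltToy Nat) = .ok 0 js'.sail :=
  rfl
```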
Step 5: rw [hj_sail]
rw [hj_sail]
We rewrite js'.sail using hj_sail : js'.sail = stateAfterWrite js.sail rd (sign_extend ...). The goal becomes:
⊢ EStateM.Result.ok RETIRE_SUCCESS
(stateAfterWrite js.sail rd
(sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)))
=
match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
The LHS now says: .ok RETIRE_SUCCESS with the state after writing the computed value to rd. The RHS still has the match on wX_bits. We need to collapse that match.
Step 6: obtain ⟨s', hw⟩ := wX_shape rd _ js.sail + rw [hw]
obtain ⟨s', hw⟩ := wX_shape rd _ js.sail
rw [hw]
We use the helper lemma wX_shape.
wX_shape says wX_bits always succeeds: there exists some state s' such that wX_bits rd val js.sail = .ok () s'. The obtain gives us s' and hw : wX_bits rd (...) js.sail = .ok () s'.
Then rw [hw] substitutes this into the RHS match, collapsing it to the .ok branch. The goal becomes:
⊢ EStateM.Result.ok RETIRE_SUCCESS
(stateAfterWrite js.sail rd
(sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)))
=
match EStateM.Result.ok () s' with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
The match on .ok () s' reduces to .ok RETIRE_SUCCESS s', so the goal is really:
⊢ EStateM.Result.ok RETIRE_SUCCESS
(stateAfterWrite js.sail rd
(sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)))
=
EStateM.Result.ok RETIRE_SUCCESS s'
Step 7: congr 1
congr 1
Both sides are .ok RETIRE_SUCCESS _. The congr 1 tactic says: since the constructors and the first argument (RETIRE_SUCCESS) match, it suffices to show the second arguments are equal. The goal reduces to:
⊢ stateAfterWrite js.sail rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) = s'
Step 8: exact
exact (wX_bits_eq_stateAfterWrite rd _ js.sail s' hw).symm
wX_bits_eq_stateAfterWrite is a lemma that connects wX_bits and stateAfterWrite: if wX_bits rd val s = .ok () s', then s' = stateAfterWrite s rd val. We have hw : wX_bits rd (...) js.sail = .ok () s', so this gives us s' = stateAfterWrite js.sail rd (...). The .symm flips it to stateAfterWrite js.sail rd (...) = s', which is exactly our goal. QED.
Other _concrete lemmas
For every other FormatR and FormatI instruction that we have proven so far, we have a concrete lemma and, where one is needed, a bridge lemma.
Every instruction follows this same pattern. Here are all the _concrete lemmas and their bridge lemmas across the codebase:
| Instruction | Concrete Lemma | Bridge Lemma | File |
|---|---|---|---|
| ADDW | jolt_addw_concrete (line 67) | extractLsb_add | Addw.lean |
| ADDIW | jolt_addiw_concrete (line 71) | — | Addiw.lean |
| SUBW | jolt_subw_concrete (line 79) | extractLsb_sub | Subw.lean |
| MULW | jolt_mulw_concrete (line 49) | truncate distributes over multiply | Mulw.lean |
| MULH | jolt_mulh_concrete (line 67) | 7-step computation = mult_to_bits_half | Mulh.lean |
| SLL | jolt_sll_concrete (line 63) | multiply by 2^(rs2[5:0]) = shift_bits_left | Sll.lean |
| SLLI | jolt_slli_concrete (line 56) | multiply by 2^shamt = shift_bits_left | Slli.lean |
| SLLIW | jolt_slliw_concrete (line 54) | multiply + truncate = 32-bit shift | Slliw.lean |
| SLLW | jolt_sllw_concrete (line 103) | — | Sllw.lean |
| SRA | jolt_sra_concrete (line 49) | sshiftRight via ctz(bitmask) = shift_bits_right_arith | Sra.lean |
| SRAI | jolt_srai_concrete (line 69) | — | Srai.lean |
| SRAIW | jolt_sraiw_concrete (line 133) | 3-step Jolt value = arithmetic right shift | Sraiw.lean |
| SRAW | jolt_sraw_concrete (line 115) | — | Sraw.lean |
| SRL | jolt_srl_concrete (line 48) | logical shift via ctz(bitmask) = shift_bits_right | Srl.lean |
| SRLI | jolt_srli_concrete (line 63) | logical shift via ctz(bitmask) = shift_bits_right | Srli.lean |
| SRLIW | jolt_srliw_concrete (line 84) | slli-32 + srli via bitmask = 32-bit logical right shift | Srliw.lean |
| SRLW | jolt_srlw_concrete (line 145) | — | Srlw.lean |
All files are under JoltBytecode/EmbeddedSailJoltState/InstructionEquivalence/. Each _concrete lemma says: Jolt's decomposition succeeds with RETIRE_SUCCESS and the final Sail state equals stateAfterWrite js.sail rd (value) — where value is the instruction-specific computation. The bridge lemma (where needed) connects the Jolt computation to the Sail computation. The main _eq_sail theorem then shows the Sail instruction writes the same value via the uniform plumbing pattern described above.
Did We Need This Modularisation?
We made a huge deal about having this concrete lemma.
Here's a glimpse of the goal state for the same ADDW theorem when everything is done in the main proof. The theorem still passes cleanly, but this is what we have to work through:
(match
match
match
match
match
rX_bits rs1
{ regs := js.regs, choiceState := js.choiceState, mem := js.mem, tags := js.tags,
cycleCount := js.cycleCount, sailOutput := js.sailOutput } with
| EStateM.Result.ok a s =>
match rX_bits rs2 s with
| EStateM.Result.ok a_1 s => EStateM.Result.ok (a + a_1) s
| EStateM.Result.error e s => EStateM.Result.error e s
| EStateM.Result.error e s => EStateM.Result.error e s with
| EStateM.Result.ok a s =>
match wX_bits rd a s with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
| EStateM.Result.error e s => EStateM.Result.error e s with
| EStateM.Result.ok a ss' =>
EStateM.Result.ok a
{ regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := js.vregs }
| EStateM.Result.error e ss' =>
EStateM.Result.error e
{ regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := js.vregs } with
| EStateM.Result.ok a s =>
match
match
match
rX_bits rd
{ regs := s.regs, choiceState := s.choiceState, mem := s.mem, tags := s.tags, cycleCount := s.cycleCount,
sailOutput := s.sailOutput } with
| EStateM.Result.ok a ss' =>
EStateM.Result.ok a
{ regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs }
| EStateM.Result.error e ss' =>
EStateM.Result.error e
{ regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs } with
| EStateM.Result.ok a s =>
match
wX_bits rd (sign_extend (Sail.BitVec.extractLsb a 31 0))
{ regs := s.regs, choiceState := s.choiceState, mem := s.mem, tags := s.tags, cycleCount := s.cycleCount,
sailOutput := s.sailOutput } with
| EStateM.Result.ok a ss' =>
EStateM.Result.ok a
{ regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs }
| EStateM.Result.error e ss' =>
EStateM.Result.error e
{ regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs }
| EStateM.Result.error e s => EStateM.Result.error e s with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
| EStateM.Result.error e s => EStateM.Result.error e s with
| EStateM.Result.ok a js' =>
EStateM.Result.ok a
{ regs := js'.regs, choiceState := js'.choiceState, mem := js'.mem, tags := js'.tags,
cycleCount := js'.cycleCount, sailOutput := js'.sailOutput }
| EStateM.Result.error e js' =>
EStateM.Result.error e
{ regs := js'.regs, choiceState := js'.choiceState, mem := js'.mem, tags := js'.tags,
cycleCount := js'.cycleCount, sailOutput := js'.sailOutput }) =
match
rX_bits rs1
{ regs := js.regs, choiceState := js.choiceState, mem := js.mem, tags := js.tags, cycleCount := js.cycleCount,
sailOutput := js.sailOutput } with
| EStateM.Result.ok a s =>
match rX_bits rs2 s with
| EStateM.Result.ok a_1 s =>
match wX_bits rd (sign_extend (Sail.BitVec.extractLsb a 31 0 + Sail.BitVec.extractLsb a_1 31 0)) s with
| EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
| EStateM.Result.error e s => EStateM.Result.error e s
| EStateM.Result.error e s => EStateM.Result.error e s
| EStateM.Result.error e s => EStateM.Result.error e s
Remember that Lean is a functional programming language.
So every time we have multiple lines of code in a do block, each extra line leads to another bind call, which unfolds to another match.
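A three-line do block is enough to see this. The sketch below uses a hypothetical computation `bump` over a `Nat` state with a `String` error type (neither comes from the codebase) and checks that the do block is the same term as the version with the binds written out by hand.

```lean
/-- Three lines in a `do` block. -/
def bump : EStateM String Nat Nat := do
  let a ← get
  set (a + 1)
  pure (a * 2)

/-- The same program with the binds written out explicitly. -/
def bumpExplicit : EStateM String Nat Nat :=
  bind get fun a =>
    bind (set (a + 1)) fun _ =>
      pure (a * 2)

-- The two definitions are definitionally equal.
example : bump = bumpExplicit := rfl

-- Running it threads the state through each bind.
example (s : Nat) : bump.run s = .ok (s * 2) (s + 1) := rfl
```

Here every step is known to succeed, so the matches introduced by each bind reduce away. In the ADDW goal above, rX_bits and wX_bits are opaque to the simplifier until we supply lemmas about them, so every match stays in the goal. That is exactly the clutter the _concrete lemma hides.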