Before we start the proof, we give a brief overview of its general structure. The rest of the document focuses on the concrete proof for ADDW, but the details specific to this instruction are confined to a single lemma. Everything else is plumbing.

A Note On The General Proof Technique

Although this post focuses on the ADDW instruction, the same proof outline holds for all FormatR and FormatI instructions. These instructions only read their source registers (and, for FormatI, an immediate), and reads do not change state. After the reads, the only effect is an update to a single destination register rd. No other field of the state machine changes. Every proof therefore comes down to checking that the Jolt sequence and the original RISC-V instruction write the same value to the same rd, and that neither touches anything else.
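To make the "only rd changes" idea concrete, here is a minimal Lean 4 sketch on a hypothetical toy register file. Regs, writeReg and the lemma names are illustrative only, not the project's SailState API. A write is modelled as a function update, and the sketch proves the two facts every such proof relies on: rd receives the new value, and every other register keeps its old one.

```lean
-- Hypothetical toy model, not the project's SailState: 32 registers of 64 bits each.
abbrev Regs := Fin 32 → BitVec 64

-- Pure register write: register `rd` reads as `v`, every other register as before.
def writeReg (r : Regs) (rd : Fin 32) (v : BitVec 64) (i : Fin 32) : BitVec 64 :=
  if i = rd then v else r i

-- After the write, `rd` holds `v`.
theorem writeReg_same (r : Regs) (rd : Fin 32) (v : BitVec 64) :
    writeReg r rd v rd = v := by
  simp [writeReg]

-- Frame property: any register other than `rd` keeps its old value.
theorem writeReg_other (r : Regs) (rd i : Fin 32) (v : BitVec 64) (h : i ≠ rd) :
    writeReg r rd v i = r i := by
  simp [writeReg, h]
```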

However, to formally prove that two imperative programs are equivalent, we have to show explicitly that every field of the state machine, as well as the return value, is the same on both sides. Given the observation above, much of the proof is therefore machinery that strips away the monadic plumbing and establishes that neither the LHS nor the RHS touches anything other than rd (memory included), so the two must agree. Lean cannot infer this on its own. The work is also uninteresting, and we do not want to spend time and energy redoing it for every proof. We therefore try to "templatise" and "generalise" this machinery as much as possible, as described below.

All the real work happens in a helper lemma whose name ends in _concrete. This is where the main "math" of the proof is done, via a bridge lemma (the ADDW example below makes this concrete). The _concrete lemma states that the sequence of Jolt instructions affects only register rd, and that the value it writes to rd is exactly the value the RISC-V instruction writes. Once this lemma is proved, the main proof simply substitutes it for the Jolt-side update. The main theorem then becomes a sequence of unfolding monadic binds and rewriting, with a single application of the helper lemma, and the proof goes through.

The hope, then, is that each FormatR and FormatI instruction only needs one new helper lemma. The rest of the proof machinery, once written, carries over unchanged.

We now walk through the proof of theorem jolt_addw_eq_sail one tactic at a time. At each step, we show the goal before the tactic, the tactic itself, and how the goal changes.

The full proof:

theorem jolt_addw_eq_sail (rs2 rs1 rd : regidx) (hrd : rd ≠ regidx.Regidx 0)
    (js : SailJoltState) (hwf : WellFormed js) :
    projectResult ((jolt_addw rs2 rs1 rd).run js) =
    (execute_RTYPEW rs2 rs1 rd ropw.ADDW).run js.sail := by
  obtain ⟨js', v1, v2, hj_rx1, hj_rx2, hj, hj_sail⟩ :=
    jolt_addw_concrete rs2 rs1 rd hrd js hwf
  -- Sail side: unfold and rewrite with the same v1, v2
  rw [execute_RTYPEW_ADDW_factored]
  simp only [EStateM.run, bind, EStateM.bind, pure, EStateM.pure, hj_rx1, hj_rx2]
  -- Jolt side: rewrite with concrete result
  show projectResult ((jolt_addw rs2 rs1 rd).run js) = _
  rw [hj]
  simp only [projectResult, project]
  rw [hj_sail]
  -- Both sides now write the same value.
  obtain ⟨s', hw⟩ := wX_shape rd _ js.sail
  rw [hw]
  congr 1
  exact (wX_bits_eq_stateAfterWrite rd _ js.sail s' hw).symm

Step 0: The initial goal

After by, Lean presents us with the goal we need to prove:

⊢ projectResult ((jolt_addw rs2 rs1 rd).run js) =
  (execute_RTYPEW rs2 rs1 rd ropw.ADDW).run js.sail

This is exactly the theorem statement. Now we begin.

Helpers

The proof uses several helper lemmas for reasoning about register operations. These are documented in detail in the Register Operation Helpers post. The key ones are:

  • stateAfterWrite — returns a new SailState identical to the original except that register rd now holds the given value. It is a specification-level way of describing a register write — without going through the monadic wX_bits. It also guarantees that the only change in state is the write to rd.
  • wX_shape — writing to a register via wX_bits always succeeds. Unlike readReg (which can throw Error.Unreachable if a key is missing), wX_bits always produces an .ok result. This is why we don't need a WellFormed-like precondition for writes.
  • wX_bits_eq_stateAfterWrite — running the monadic register write wX_bits produces exactly the same state as the pure stateAfterWrite. This lets us move between the monadic world (used in the actual computation) and the specification world (used in the proof).
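The relationship between these three helpers can be sketched in a self-contained toy model. The sketch assumes a hypothetical register file Regs, a pure write writeReg standing in for stateAfterWrite, and a monadic write wXToy standing in for wX_bits. None of these are the project's definitions. The two theorems mirror the shapes of wX_shape and wX_bits_eq_stateAfterWrite as described above.

```lean
-- Hypothetical toy model, not the project's API: 32 registers of 64 bits each.
abbrev Regs := Fin 32 → BitVec 64

-- Pure write, playing the role of stateAfterWrite.
def writeReg (r : Regs) (rd : Fin 32) (v : BitVec 64) (i : Fin 32) : BitVec 64 :=
  if i = rd then v else r i

-- Monadic write, playing the role of wX_bits: it never produces `.error`.
def wXToy (rd : Fin 32) (v : BitVec 64) : EStateM String Regs Unit :=
  fun r => EStateM.Result.ok () (writeReg r rd v)

-- Analogue of wX_shape: the write always returns `.ok ()` with some new state.
theorem wXToy_shape (rd : Fin 32) (v : BitVec 64) (r : Regs) :
    ∃ r', (wXToy rd v).run r = EStateM.Result.ok () r' :=
  ⟨writeReg r rd v, rfl⟩

-- Analogue of wX_bits_eq_stateAfterWrite: that new state is exactly the pure write.
theorem wXToy_eq_writeReg (rd : Fin 32) (v : BitVec 64) (r r' : Regs)
    (h : (wXToy rd v).run r = EStateM.Result.ok () r') : r' = writeReg r rd v := by
  have hrun : (wXToy rd v).run r = EStateM.Result.ok () (writeReg r rd v) := rfl
  rw [hrun] at h
  -- h : .ok () (writeReg r rd v) = .ok () r'; constructor injectivity substitutes r'.
  cases h
  rfl
```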

jolt_addw_concrete

This is the main worker lemma. All the heavy lifting happens here.

theorem jolt_addw_concrete (rs2 rs1 rd : regidx)
    (hrd : rd ≠ regidx.Regidx 0) (js : SailJoltState) (hwf : WellFormed js) :
    ∃ (js' : SailJoltState) (v1 v2 : BitVec 64),
      rX_bits rs1 js.sail = .ok v1 js.sail ∧ -- item 1 
      rX_bits rs2 js.sail = .ok v2 js.sail ∧ -- item 2
      (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js' ∧ -- item 3 
      -- item 4
      js'.sail = stateAfterWrite js.sail rd 
        (sign_extend (m := 64) (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0))

The above Lean lemma in plain English:

(Jolt ADDW Concrete)

It says: there exist a final Jolt state js' and values v1, v2 such that the following hold (the item numbers match the comments in the code block above):

  1. Reading rs1 succeeds: rX_bits rs1 js.sail = .ok v1 js.sail — reading register rs1 from the initial state returns value v1 without modifying the state.
  2. Reading rs2 succeeds: rX_bits rs2 js.sail = .ok v2 js.sail — reading register rs2 returns value v2 without modifying the state.
  3. Execution succeeds: (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js' — running Jolt's ADDW on the initial state succeeds with RETIRE_SUCCESS and produces final state js'.
  4. The Sail state is characterised: js'.sail = stateAfterWrite js.sail rd (sign_extend (m := 64) (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) — the Sail component of the final state is the original Sail state with register rd overwritten by sign_extend(v1[31:0] + v2[31:0]).

What this theorem really allows us to do is the following.

Let X = sign_extend(v1[31:0] +₃₂ v2[31:0]). The definition of execute_RTYPEW shows directly that Sail's native ADDW writes X to rd.

But jolt_addw does two separate steps:

  1. ADD writes Y = v1 +₆₄ v2 to rd (the full 64-bit sum)
  2. sign-extend-word reads rd back, gets Y, extracts the lower 32 bits, sign-extends to 64, and writes back sext₆₄(Y[31:0])
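The way the two Jolt writes to the same rd collapse into one can be sketched in a hypothetical toy model (Regs, writeReg, addThenSext and the lemma names are illustrative, not the project's code). Reading rd right after the first write returns Y, and the second write to rd hides the first. Together these show the net effect is a single write of sext₆₄(Y[31:0]). The bridge lemma is what then turns Y[31:0] into v1[31:0] + v2[31:0].

```lean
-- Hypothetical toy model, not the project's SailState: 32 registers of 64 bits each.
abbrev Regs := Fin 32 → BitVec 64

def writeReg (r : Regs) (rd : Fin 32) (v : BitVec 64) (i : Fin 32) : BitVec 64 :=
  if i = rd then v else r i

-- Reading rd right after writing Y returns Y (the "read back" in step 2).
theorem read_after_write (r : Regs) (rd : Fin 32) (y : BitVec 64) :
    writeReg r rd y rd = y := by
  simp [writeReg]

-- A second write to the same register hides the first one.
theorem write_write (r : Regs) (rd : Fin 32) (y z : BitVec 64) :
    writeReg (writeReg r rd y) rd z = writeReg r rd z := by
  funext i
  by_cases h : i = rd <;> simp [writeReg, h]

-- Toy ADD-then-sign-extend-word: write Y = v1 + v2, read it back,
-- then write the sign-extended low 32 bits of Y.
def addThenSext (r : Regs) (rs1 rs2 rd : Fin 32) : Regs :=
  let r1 := writeReg r rd (r rs1 + r rs2)
  writeReg r1 rd (((r1 rd).truncate 32).signExtend 64)

-- Net effect: a single write of sext₆₄(Y[31:0]) to rd.
theorem addThenSext_eq (r : Regs) (rs1 rs2 rd : Fin 32) :
    addThenSext r rs1 rs2 rd =
      writeReg r rd (((r rs1 + r rs2).truncate 32).signExtend 64) := by
  simp only [addThenSext, read_after_write, write_write]
```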

The claim in jolt_addw_concrete is that after both steps, rd holds X. This requires the bridge lemma extractLsb_add:

theorem extractLsb_add (a b : BitVec 64) :
    Sail.BitVec.extractLsb (a + b) 31 0 =
    Sail.BitVec.extractLsb a 31 0 + Sail.BitVec.extractLsb b 31 0

which says truncation distributes over addition: (a +₆₄ b)[31:0] = a[31:0] +₃₂ b[31:0]. This rewrites the Jolt result sext₆₄((v1 + v2)[31:0]) into sext₆₄(v1[31:0] + v2[31:0]) = X.
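For experimenting outside the project, the same identity can be stated in core Lean 4. This sketch assumes a recent toolchain that provides BitVec.truncate, BitVec.signExtend and the bv_decide tactic from Std.Tactic.BVDecide. It also assumes that Sail.BitVec.extractLsb x 31 0 corresponds to keeping the low 32 bits, as the post describes. truncate_add_low32 is a hypothetical name, not the project's extractLsb_add. The second check is a worked ADDW value in which the 32-bit sum overflows into bit 31, so the sign extension produces a negative 64-bit result.

```lean
import Std.Tactic.BVDecide

-- Core-Lean analogue of extractLsb_add, phrased with `truncate`:
-- keeping the low 32 bits commutes with 64-bit addition.
theorem truncate_add_low32 (a b : BitVec 64) :
    (a + b).truncate 32 = a.truncate 32 + b.truncate 32 := by
  bv_decide

-- Worked ADDW value: 0x7FFFFFFF + 1 sets bit 31 of the 32-bit sum,
-- so sign-extending it to 64 bits yields 0xFFFFFFFF80000000.
example : ((0x7FFFFFFF#64 + 1#64).truncate 32).signExtend 64 = 0xFFFFFFFF80000000#64 := by
  decide
```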

The Main Benefit Of This Work

When we invoke this lemma in the main theorem, every field of SailState on the two sides turns out to be equal. The RHS writes X to rd by definition, and we have just shown that the LHS also writes X to rd. Furthermore, because the final state is stated as stateAfterWrite, we know that nothing else could have changed.

The Proof

With the helpers in place, we now show how the main proof decomposes into a mechanical unwinding of the monadic machinery.

Step 1: obtain

obtain ⟨js', v1, v2, hj_rx1, hj_rx2, hj, hj_sail⟩ :=
    jolt_addw_concrete rs2 rs1 rd hrd js hwf

This applies jolt_addw_concrete (described above) to rs2, rs1, rd, hrd, js, and hwf. The result is an existential, which obtain destructures. This introduces seven new names into the context, three witnesses and four hypotheses:

  • js' : SailJoltState — the final Jolt state after running jolt_addw
  • v1 : BitVec 64 — the value read from register rs1
  • v2 : BitVec 64 — the value read from register rs2
  • hj_rx1 : rX_bits rs1 js.sail = .ok v1 js.sail — reading rs1 succeeds (and state is unchanged).
  • hj_rx2 : rX_bits rs2 js.sail = .ok v2 js.sail — reading rs2 succeeds (and state is unchanged).
  • hj : (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js' (running jolt_addw succeeds and returns RETIRE_SUCCESS).
  • hj_sail : js'.sail = stateAfterWrite js.sail rd (sign_extend ...), i.e. the final Sail state is the initial one with rd set to the sign-extended sum. Nothing else in js'.sail differs from js.sail.
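obtain is the standard way to unpack this kind of existential: a single anonymous-constructor pattern names every witness and every conjunct at once. Here is a minimal, self-contained illustration of the same pattern on natural numbers, unrelated to the project's types.

```lean
-- An existential with the same shape: witnesses followed by a conjunction.
example (h : ∃ (a b : Nat), a + b = 3 ∧ a = 1) : ∃ b : Nat, 1 + b = 3 := by
  -- One pattern introduces both witnesses and both conjuncts.
  obtain ⟨a, b, hsum, ha⟩ := h
  subst ha          -- replace a by 1, so hsum : 1 + b = 3
  exact ⟨b, hsum⟩
```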

Introducing these facts does not change the goal, which is still:

⊢ projectResult ((jolt_addw rs2 rs1 rd).run js) =
  (execute_RTYPEW rs2 rs1 rd ropw.ADDW).run js.sail

But now we have concrete facts about what rX_bits rs1, rX_bits rs2, and jolt_addw compute, which we can use to rewrite both sides.

Step 2: Simplify the RHS

rw [execute_RTYPEW_ADDW_factored]
simp only [EStateM.run, bind, EStateM.bind, pure, EStateM.pure, hj_rx1, hj_rx2]

First, rw [execute_RTYPEW_ADDW_factored] replaces execute_RTYPEW rs2 rs1 rd ropw.ADDW with a factored form that reads rs1 and rs2, adds their lower 32 bits, sign-extends the sum, and writes it to rd. This eliminates the match on ropw and expresses the Sail side as a plain chain of monadic binds.

Then simp only does most of its work on the RHS. It:

  1. Unfolds the monadic plumbing (bind, EStateM.bind, pure, EStateM.pure, EStateM.run), turning the do block into concrete function applications on the state. For a gentle introduction to monads and bind, see this post.
  2. Rewrites with hj_rx1, replacing rX_bits rs1 js.sail with .ok v1 js.sail.
  3. Rewrites with hj_rx2, replacing rX_bits rs2 js.sail with .ok v2 js.sail.
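The core mechanism in these steps can be seen in isolation. If a computation x is known to return .ok a s' from state s, unfolding EStateM.bind exposes a match on x s. Rewriting with that fact lets the match reduce to the continuation. This is a generic sketch over arbitrary types, not the project's lemma.

```lean
-- If `x` is known to succeed with `.ok a s'`, then `x >>= f` run from `s`
-- is just `f a` run from `s'`.
example {ε σ α β : Type} (x : EStateM ε σ α) (f : α → EStateM ε σ β)
    (s s' : σ) (a : α) (h : x s = EStateM.Result.ok a s') :
    (x >>= f).run s = f a s' := by
  -- Restate the goal with the bind unfolded to EStateM.bind.
  show EStateM.bind x f s = f a s'
  -- Unfold the bind into a match on `x s`, rewrite the discriminant with h,
  -- and let simp reduce the match to its `.ok` branch.
  simp only [EStateM.bind, h]
```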

After this simp only call, the monadic chain on the RHS is gone and v1 and v2 are substituted in. What remains is a match on the result of writing to rd, which returns RETIRE_SUCCESS on success. Because EStateM.run is in the simp set, the same call also unfolds it on the LHS, turning (jolt_addw rs2 rs1 rd).run js into jolt_addw rs2 rs1 rd js (direct function application). The goal is now:

⊢ projectResult (jolt_addw rs2 rs1 rd js)
  =
  match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
  | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
  | EStateM.Result.error e s => EStateM.Result.error e s

The RHS is now a match on the result of wX_bits — if the write succeeds (.ok), return RETIRE_SUCCESS with the new state; if it errors, propagate the error. Next we need to simplify the LHS to match.

Step 3: show + rw [hj]

show projectResult ((jolt_addw rs2 rs1 rd).run js) = _
rw [hj]

The show tactic restates the LHS in its definitionally equal .run form, so that it syntactically matches the left side of hj : (jolt_addw rs2 rs1 rd).run js = .ok RETIRE_SUCCESS js'.

Then rw [hj] replaces (jolt_addw rs2 rs1 rd).run js with .ok RETIRE_SUCCESS js' on the LHS. The goal is now:

⊢ projectResult (EStateM.Result.ok RETIRE_SUCCESS js')
  =
  match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
  | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
  | EStateM.Result.error e s => EStateM.Result.error e s

The LHS now has a concrete Result value — projectResult applied to .ok RETIRE_SUCCESS js'.
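Why show is needed before rw can be seen in a small generic example. rw matches terms syntactically, so a hypothesis about x.run s cannot rewrite a goal that mentions the unfolded x s, even though the two are definitionally equal. show restates the goal in the form the hypothesis expects.

```lean
-- `h` mentions `x.run s`, but the goal is stated with the unfolded `x s`.
-- `rw [h]` on its own fails because `x.run s` does not occur syntactically.
example {ε σ α : Type} (x : EStateM ε σ α) (s s' : σ) (a : α)
    (h : x.run s = EStateM.Result.ok a s') :
    x s = EStateM.Result.ok a s' := by
  show x.run s = _   -- definitionally equal restatement of the goal
  rw [h]             -- now the pattern matches; rw closes the goal with rfl
```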

Step 4: simp only [projectResult, project]

simp only [projectResult, project]

Now we unfold projectResult and project on the LHS. projectResult (.ok RETIRE_SUCCESS js') reduces to .ok RETIRE_SUCCESS js'.sail, and project js' reduces to js'.sail. The goal becomes:

⊢ EStateM.Result.ok RETIRE_SUCCESS js'.sail
  =
  match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
  | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
  | EStateM.Result.error e s => EStateM.Result.error e s

The LHS is now fully concrete: .ok RETIRE_SUCCESS js'.sail.
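The projection step can be mimicked with a hypothetical toy state. Here ToyJoltState and toyProjectResult are illustrative stand-ins, not the project's SailJoltState and projectResult. Because the projection is defined by pattern matching on the result constructor, simp only with its name rewrites a concrete .ok value in one step.

```lean
-- Hypothetical toy extended state: the Sail part plus one extra field.
structure ToyJoltState (S : Type) where
  sail  : S
  vregs : Nat  -- stand-in for Jolt-only state that Sail does not see

-- Map a result over the extended state to a result over the Sail state.
def toyProjectResult {ε S α : Type} :
    EStateM.Result ε (ToyJoltState S) α → EStateM.Result ε S α
  | .ok a js    => .ok a js.sail
  | .error e js => .error e js.sail

-- On a concrete `.ok` value, unfolding the projection is a single simp step.
example {ε S α : Type} (a : α) (js : ToyJoltState S) :
    toyProjectResult (EStateM.Result.ok a js : EStateM.Result ε (ToyJoltState S) α) =
      EStateM.Result.ok a js.sail := by
  simp only [toyProjectResult]
```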

Step 5: rw [hj_sail]

rw [hj_sail]

We rewrite js'.sail using hj_sail : js'.sail = stateAfterWrite js.sail rd (sign_extend ...). The goal becomes:

⊢ EStateM.Result.ok RETIRE_SUCCESS
    (stateAfterWrite js.sail rd
      (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)))
  =
  match wX_bits rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) js.sail with
  | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
  | EStateM.Result.error e s => EStateM.Result.error e s

The LHS now says: .ok RETIRE_SUCCESS with the state after writing the computed value to rd. The RHS still has the match on wX_bits. We need to collapse that match.

Step 6: obtain ⟨s', hw⟩ := wX_shape rd _ js.sail + rw [hw]

obtain ⟨s', hw⟩ := wX_shape rd _ js.sail
rw [hw]

We use the helper lemma wX_shape, which says that wX_bits always succeeds: there exists some state s' such that wX_bits rd val js.sail = .ok () s'. The obtain gives us s' and hw : wX_bits rd (...) js.sail = .ok () s'.

Then rw [hw] substitutes this into the RHS match, collapsing it to the .ok branch. The goal becomes:

⊢ EStateM.Result.ok RETIRE_SUCCESS
    (stateAfterWrite js.sail rd
      (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)))
  =
  match EStateM.Result.ok () s' with
  | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
  | EStateM.Result.error e s => EStateM.Result.error e s

The match on .ok () s' reduces to .ok RETIRE_SUCCESS s', so the goal is really:

⊢ EStateM.Result.ok RETIRE_SUCCESS
    (stateAfterWrite js.sail rd
      (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)))
  =
  EStateM.Result.ok RETIRE_SUCCESS s'
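The collapse of the match can be reproduced generically. Once the discriminant is known to be .ok () s', the match reduces to its .ok branch by computation. In this toy the discriminant is a plain variable r, so subst is used in place of the proof's rw [hw]. The effect on the match is the same.

```lean
-- Once the discriminant is known to be `.ok () s'`, the match reduces to its
-- `.ok` branch.
example {ε σ α : Type} (x : α) (r : EStateM.Result ε σ Unit) (s' : σ)
    (hw : r = EStateM.Result.ok () s') :
    ((match r with
      | .ok _ s    => .ok x s
      | .error e s => .error e s) : EStateM.Result ε σ α) = EStateM.Result.ok x s' := by
  subst hw  -- replace r by .ok () s' everywhere in the goal
  rfl       -- a match on a known constructor reduces by computation
```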

Step 7: congr 1

congr 1

Both sides are .ok RETIRE_SUCCESS _. The congr 1 tactic says: since the constructors and the first argument (RETIRE_SUCCESS) match, it suffices to show the second arguments are equal. The goal reduces to:

⊢ stateAfterWrite js.sail rd (sign_extend (Sail.BitVec.extractLsb v1 31 0 + Sail.BitVec.extractLsb v2 31 0)) = s'

Step 8: exact

exact (wX_bits_eq_stateAfterWrite rd _ js.sail s' hw).symm

wX_bits_eq_stateAfterWrite is a lemma that connects wX_bits and stateAfterWrite: if wX_bits rd val s = .ok () s', then s' = stateAfterWrite s rd val. We have hw : wX_bits rd (...) js.sail = .ok () s', so this gives us s' = stateAfterWrite js.sail rd (...). The .symm flips it to stateAfterWrite js.sail rd (...) = s', which is exactly our goal. QED.
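Steps 7 and 8 follow a common pattern that can be shown with arbitrary functions f and g on states. The example is generic, not taken from the project. congr 1 peels off the shared .ok constructor and the identical first argument, which leaves only the state equation. A lemma stated in the opposite direction is then flipped with .symm.

```lean
-- Both sides are `.ok a _`, so `congr 1` leaves only `f s = g s`;
-- the available fact points the other way, so `.symm` flips it.
example {ε σ α : Type} (a : α) (f g : σ → σ) (s : σ)
    (hgf : ∀ t, g t = f t) :
    (EStateM.Result.ok a (f s) : EStateM.Result ε σ α) = EStateM.Result.ok a (g s) := by
  congr 1
  exact (hgf s).symm
```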

Other _concrete lemmas

Every other FormatR and FormatI instruction proven so far follows the same pattern: a _concrete lemma and, where needed, a bridge lemma. Here are all the _concrete lemmas and their bridge lemmas across the codebase:

Instruction | Concrete lemma | Bridge lemma | File
ADDW | jolt_addw_concrete (line 67) | extractLsb_add | Addw.lean
ADDIW | jolt_addiw_concrete (line 71) | | Addiw.lean
SUBW | jolt_subw_concrete (line 79) | extractLsb_sub | Subw.lean
MULW | jolt_mulw_concrete (line 49) | truncate distributes over multiply | Mulw.lean
MULH | jolt_mulh_concrete (line 67) | 7-step computation = mult_to_bits_half | Mulh.lean
SLL | jolt_sll_concrete (line 63) | multiply by 2^(rs2[5:0]) = shift_bits_left | Sll.lean
SLLI | jolt_slli_concrete (line 56) | multiply by 2^shamt = shift_bits_left | Slli.lean
SLLIW | jolt_slliw_concrete (line 54) | multiply + truncate = 32-bit shift | Slliw.lean
SLLW | jolt_sllw_concrete (line 103) | | Sllw.lean
SRA | jolt_sra_concrete (line 49) | sshiftRight via ctz(bitmask) = shift_bits_right_arith | Sra.lean
SRAI | jolt_srai_concrete (line 69) | | Srai.lean
SRAIW | jolt_sraiw_concrete (line 133) | 3-step Jolt value = arithmetic right shift | Sraiw.lean
SRAW | jolt_sraw_concrete (line 115) | | Sraw.lean
SRL | jolt_srl_concrete (line 48) | logical shift via ctz(bitmask) = shift_bits_right | Srl.lean
SRLI | jolt_srli_concrete (line 63) | logical shift via ctz(bitmask) = shift_bits_right | Srli.lean
SRLIW | jolt_srliw_concrete (line 84) | slli-32 + srli via bitmask = 32-bit logical right shift | Srliw.lean
SRLW | jolt_srlw_concrete (line 145) | | Srlw.lean

All files are under JoltBytecode/EmbeddedSailJoltState/InstructionEquivalence/. Each _concrete lemma says: Jolt's decomposition succeeds with RETIRE_SUCCESS and the final Sail state equals stateAfterWrite js.sail rd (value) — where value is the instruction-specific computation. The bridge lemma (where needed) connects the Jolt computation to the Sail computation. The main _eq_sail theorem then shows the Sail instruction writes the same value via the uniform plumbing pattern described above.

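Did We Need This Modularisation?

We made a big deal about having this concrete lemma. For comparison, here is a glimpse of the goal state for the same ADDW theorem when everything is done inside the main proof, without the _concrete lemma. The theorem still passes cleanly, but this is what one has to work with: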

(match
    match
      match
        match
          match
            rX_bits rs1
              { regs := js.regs, choiceState := js.choiceState, mem := js.mem, tags := js.tags,
                cycleCount := js.cycleCount, sailOutput := js.sailOutput } with
          | EStateM.Result.ok a s =>
            match rX_bits rs2 s with
            | EStateM.Result.ok a_1 s => EStateM.Result.ok (a + a_1) s
            | EStateM.Result.error e s => EStateM.Result.error e s
          | EStateM.Result.error e s => EStateM.Result.error e s with
        | EStateM.Result.ok a s =>
          match wX_bits rd a s with
          | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
          | EStateM.Result.error e s => EStateM.Result.error e s
        | EStateM.Result.error e s => EStateM.Result.error e s with
      | EStateM.Result.ok a ss' =>
        EStateM.Result.ok a
          { regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
            cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := js.vregs }
      | EStateM.Result.error e ss' =>
        EStateM.Result.error e
          { regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
            cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := js.vregs } with
    | EStateM.Result.ok a s =>
      match
        match
          match
            rX_bits rd
              { regs := s.regs, choiceState := s.choiceState, mem := s.mem, tags := s.tags, cycleCount := s.cycleCount,
                sailOutput := s.sailOutput } with
          | EStateM.Result.ok a ss' =>
            EStateM.Result.ok a
              { regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
                cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs }
          | EStateM.Result.error e ss' =>
            EStateM.Result.error e
              { regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
                cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs } with
        | EStateM.Result.ok a s =>
          match
            wX_bits rd (sign_extend (Sail.BitVec.extractLsb a 31 0))
              { regs := s.regs, choiceState := s.choiceState, mem := s.mem, tags := s.tags, cycleCount := s.cycleCount,
                sailOutput := s.sailOutput } with
          | EStateM.Result.ok a ss' =>
            EStateM.Result.ok a
              { regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
                cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs }
          | EStateM.Result.error e ss' =>
            EStateM.Result.error e
              { regs := ss'.regs, choiceState := ss'.choiceState, mem := ss'.mem, tags := ss'.tags,
                cycleCount := ss'.cycleCount, sailOutput := ss'.sailOutput, vregs := s.vregs }
        | EStateM.Result.error e s => EStateM.Result.error e s with
      | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
      | EStateM.Result.error e s => EStateM.Result.error e s
    | EStateM.Result.error e s => EStateM.Result.error e s with
  | EStateM.Result.ok a js' =>
    EStateM.Result.ok a
      { regs := js'.regs, choiceState := js'.choiceState, mem := js'.mem, tags := js'.tags,
        cycleCount := js'.cycleCount, sailOutput := js'.sailOutput }
  | EStateM.Result.error e js' =>
    EStateM.Result.error e
      { regs := js'.regs, choiceState := js'.choiceState, mem := js'.mem, tags := js'.tags,
        cycleCount := js'.cycleCount, sailOutput := js'.sailOutput }) =
  match
    rX_bits rs1
      { regs := js.regs, choiceState := js.choiceState, mem := js.mem, tags := js.tags, cycleCount := js.cycleCount,
        sailOutput := js.sailOutput } with
  | EStateM.Result.ok a s =>
    match rX_bits rs2 s with
    | EStateM.Result.ok a_1 s =>
      match wX_bits rd (sign_extend (Sail.BitVec.extractLsb a 31 0 + Sail.BitVec.extractLsb a_1 31 0)) s with
      | EStateM.Result.ok a s => EStateM.Result.ok RETIRE_SUCCESS s
      | EStateM.Result.error e s => EStateM.Result.error e s
    | EStateM.Result.error e s => EStateM.Result.error e s
  | EStateM.Result.error e s => EStateM.Result.error e s

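Remember that Lean is a functional programming language. Every monadic line in a do block desugars to a bind, and unfolding EStateM.bind turns each bind into a match on the result of the previous step. This is why a handful of lines of do code produces the deeply nested match expressions above.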
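The desugaring itself can be checked directly. In this hypothetical example (addTwo is illustrative, not from the project), a two-line do block is definitionally the corresponding chain of binds, so rfl proves the equation. Unfolding each EStateM.bind in that chain would then produce one nested match per line, as in the goal above.

```lean
-- A hypothetical two-line do block over EStateM.
def addTwo (x y : EStateM String Nat Nat) : EStateM String Nat Nat := do
  let a ← x
  let b ← y
  pure (a + b)

-- The do block is exactly a chain of binds; each bind, once EStateM.bind is
-- unfolded, becomes a `match` on the previous result.
example (x y : EStateM String Nat Nat) :
    addTwo x y = (x >>= fun a => y >>= fun b => pure (a + b)) := rfl
```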