Motivating Problem

As a motivating problem, we look at how division is implemented in Jolt. In a typical RISC CPU, the instruction DIV rd, rs1, rs2 is executed in a single fetch-decode-execute cycle: the quotient of the contents of rs1 divided by the contents of rs2 is written into the destination register rd.

In Jolt, division is implemented very differently. An untrusted oracle leaves two 64-bit values on the advice tape, calls them the quotient and the remainder, and claims they were computed exactly as the real RISC division instruction would compute them. To validate the oracle's claim we perform 4 tests, which are designed so that the following completeness and soundness theorems hold.

Completeness tells us that if the claimed quotient and remainder were indeed computed correctly, then none of the 4 tests fails, and the final state of the RISC CPU is identical to that of a RISC CPU that executed the DIV instruction directly.

Soundness tells us that for any (quotient, remainder) different from the correct pair, at least one of the tests will fail.

Based on how we have defined CPU state so far in the series, this is what Jolt's version of division looks like in Lean. Each line is one Jolt ISA instruction, which behind the scenes is an imperative-style program that advances the Jolt CPU by one step.

def jolt_div (rs2 rs1 rd : regidx)
    (quotient rem_abs : BitVec 64) : JoltMonad ExecutionResult := do
  let _ ← vreg_advice 0 quotient                      -- VirtualAdvice v0, 0    (quotient)
  let _ ← vreg_advice 1 rem_abs                       -- VirtualAdvice v1, 0    (|remainder|)
  let _ ← vreg_assert_valid_div0 rs2 0                -- VirtualAssertValidDiv0 rs2, v0, 0
  let _ ← vreg_change_divisor 2 rs1 rs2               -- VirtualChangeDivisor   v2, rs1, rs2
  let _ ← vreg_MULH 3 0 2                             -- MULH                   v3, v0, v2
  let _ ← vreg_MUL 4 0 2                              -- MUL                    v4, v0, v2
  let _ ← vreg_SRAI 5 4 63                            -- SRAI                   v5, v4, 63
  let _ ← vreg_assert_eq 3 5                          -- VirtualAssertEQ        v3, v5, 0
  let _ ← vreg_SRAI_from_real 3 rs1 63                -- SRAI                   v3, rs1, 63
  let _ ← vreg_XOR 5 1 3                              -- XOR                    v5, v1, v3
  let _ ← vreg_SUB 5 5 3                              -- SUB                    v5, v5, v3
  let _ ← vreg_ADD 4 4 5                              -- ADD                    v4, v4, v5
  let _ ← vreg_assert_eq_real 4 rs1                   -- VirtualAssertEQ        v4, rs1, 0
  let _ ← vreg_SRAI 3 2 63                            -- SRAI                   v3, v2, 63
  let _ ← vreg_XOR 5 2 3                              -- XOR                    v5, v2, v3
  let _ ← vreg_SUB 5 5 3                              -- SUB                    v5, v5, v3
  let _ ← vreg_assert_valid_unsigned_remainder 1 5    -- VirtualAssertValidUnsignedRemainder v1, v5, 0
  vreg_ADDI_to_real rd 0 0                            -- ADDI                   rd, v0, 0   (move quotient)

The completeness theorem against the trusted transpilation looks like the following:

theorem jolt_div_concrete (rs2 rs1 rd : regidx)
    (hrd : rd ≠ regidx.Regidx 0) (js : SailJoltState) (hwf : WellFormed js)
    (dividend divisor : BitVec 64)
    (hrs1 : rX_bits rs1 js.sail = .ok dividend js.sail)
    (hrs2 : rX_bits rs2 js.sail = .ok divisor js.sail) :
    projectResult
      ((jolt_div rs2 rs1 rd
          (sail_div_value dividend divisor false)            -- correct quotient
          (bv_abs (sail_rem_value dividend divisor false))   -- correct |remainder|
        ).run js) =
    (execute_DIV rs2 rs1 rd false).run js.sail

The theorem statement simply says: run the jolt_div sequence with the correct advice, and the resulting state, once projected down to its Sail component by projectResult, is identical to the state we would reach by running the original DIV instruction.

Now, if we try to reason about the LHS and RHS top-down as monolithic blocks, we very quickly drive ourselves into a corner of monadic monstrosity. Concretely, the most obvious move is to unfold jolt_div and peel one layer of EStateM.run off each bind:

  unfold jolt_div
  simp only [EStateM.run_bind]

Two lines of tactic, and the goal state Lean hands back is this:

⊢ projectResult
    (match EStateM.run (vreg_advice 0 (sail_div_value dividend divisor false)) js with
    | EStateM.Result.ok x s =>
      match EStateM.run (vreg_advice 1 (bv_abs (sail_rem_value dividend divisor false))) s with
      | EStateM.Result.ok x s =>
        match EStateM.run (vreg_assert_valid_div0 rs2 0) s with
        | EStateM.Result.ok x s =>
          match EStateM.run (vreg_change_divisor 2 rs1 rs2) s with
          | EStateM.Result.ok x s =>
            match EStateM.run (vreg_MULH 3 0 2) s with
            | EStateM.Result.ok x s =>
              match EStateM.run (vreg_MUL 4 0 2) s with
              | EStateM.Result.ok x s =>
                match EStateM.run (vreg_SRAI 5 4 63) s with
                | EStateM.Result.ok x s =>
                  match EStateM.run (vreg_assert_eq 3 5) s with
                  | EStateM.Result.ok x s =>
                    match EStateM.run (vreg_SRAI_from_real 3 rs1 63) s with
                    | EStateM.Result.ok x s =>
                      match EStateM.run (vreg_XOR 5 1 3) s with
                      | EStateM.Result.ok x s =>
                        match EStateM.run (vreg_SUB 5 5 3) s with
                        | EStateM.Result.ok x s =>
                          match EStateM.run (vreg_ADD 4 4 5) s with
                          | EStateM.Result.ok x s =>
                            match EStateM.run (vreg_assert_eq_real 4 rs1) s with
                            | EStateM.Result.ok x s =>
                              match EStateM.run (vreg_SRAI 3 2 63) s with
                              | EStateM.Result.ok x s =>
                                match EStateM.run (vreg_XOR 5 2 3) s with
                                | EStateM.Result.ok x s =>
                                  match EStateM.run (vreg_SUB 5 5 3) s with
                                  | EStateM.Result.ok x s =>
                                    match EStateM.run (vreg_assert_valid_unsigned_remainder 1 5) s with
                                    | EStateM.Result.ok x s => EStateM.run (vreg_ADDI_to_real rd 0 0) s
                                    | EStateM.Result.error e s => EStateM.Result.error e s
                                  | EStateM.Result.error e s => EStateM.Result.error e s
                                | EStateM.Result.error e s => EStateM.Result.error e s
                              | EStateM.Result.error e s => EStateM.Result.error e s
                            | EStateM.Result.error e s => EStateM.Result.error e s
                          | EStateM.Result.error e s => EStateM.Result.error e s
                        | EStateM.Result.error e s => EStateM.Result.error e s
                      | EStateM.Result.error e s => EStateM.Result.error e s
                    | EStateM.Result.error e s => EStateM.Result.error e s
                  | EStateM.Result.error e s => EStateM.Result.error e s
                | EStateM.Result.error e s => EStateM.Result.error e s
              | EStateM.Result.error e s => EStateM.Result.error e s
            | EStateM.Result.error e s => EStateM.Result.error e s
          | EStateM.Result.error e s => EStateM.Result.error e s
        | EStateM.Result.error e s => EStateM.Result.error e s
      | EStateM.Result.error e s => EStateM.Result.error e s
    | EStateM.Result.error e s => EStateM.Result.error e s) =
    EStateM.run (execute_DIV rs2 rs1 rd false) js.sail

Seventeen levels of nested match, one for every instruction except the last (the final ADDI sits in tail position), and every one of them carries an .error e s => .error e s arm whose only job is to propagate an error up the stack. Note that this blowup comes only from the outer unfold: we have not yet unfolded any individual instruction, and doing so would make the goal far worse. Nobody wants to mechanically case-split on each match and unwind this plumbing. Not only is it intractable and error-prone, it is also not how anyone thinks about the proof of this theorem. So in this post we outline a general strategy that mimics how humans write proofs, without letting the goal state get out of control.

A Simpler Program

To better explain how we deal with this problem, we first look at a simpler program that exhibits the same problem at a much smaller scale. Although significantly simpler, it captures exactly the core of the complexity above.

Consider the following minimal machine, whose state is just two natural-number registers.

structure State where
  x : Nat
  y : Nat
  deriving Repr

inductive Err where
  | assertFailed : String → Err
  deriving Repr

abbrev M (α : Type) : Type := EStateM Err State α

We use Lean's built-in error-state monad EStateM to model the state transitions of this machine, and abbreviate it as M. Nothing special so far: this is exactly what we did in the post introducing Jolt State, just with a much simpler definition of state and error.

This CPU has two instructions:

  • doubleX reads x, writes back 2 * x, and returns the new value. It never fails.
  • divYbyX reads the state, throws assertFailed "division by zero" if x = 0, and otherwise returns y / x without ever updating the state.
def doubleX : M Nat := do
  let s ← get
  let new := 2 * s.x
  set { s with x := new }
  pure new

def divYbyX : M Nat := do
  let s ← get
  if s.x = 0 then throw (Err.assertFailed "division by zero")
  else pure (s.y / s.x)

Here is a toy program. Starting from an initial state s, it doubles x three times and then divides y by the final value of x.

def prog1 : M Nat := do
  let _ ← doubleX
  let _ ← doubleX
  let _ ← doubleX
  divYbyX

Evaluated on the initial state { x := 5, y := 320 }, everything succeeds: x goes 5 → 10 → 20 → 40, and the final divYbyX computes 320 / 40 = 8:

#eval prog1.run { x := 5, y := 320 }
-- EStateM.Result.ok 8

If we instead start with x = 0, doubling preserves zero, so when we reach divYbyX the divisor is still 0 and we throw:

#eval prog1.run { x := 0, y := 320 }
-- EStateM.Result.error (Err.assertFailed "division by zero")

Now we can write a couple of theorems.

/-- If we start with `x = 0`, `prog1` errors on the final `divYbyX`. -/
theorem prog1_errors_on_zero (s : State) (h : s.x = 0) :
    ∃ s', prog1.run s = .error (Err.assertFailed "division by zero") s' := by
  sorry

/-- Starting from `x = 5, y = 320`, three doublings take `x` to 40 and
`divYbyX` returns 320 / 40 = 8. -/
theorem prog1_on_5_320 :
    prog1.run { x := 5, y := 320 } = .ok 8 { x := 40, y := 320 } := by
  sorry

If we try the same unfold + simp only [EStateM.run_bind] move on prog1_on_5_320,

  unfold prog1
  simp only [EStateM.run_bind]

we get back the same monster, at a smaller scale:

⊢ (match EStateM.run doubleX { x := 5, y := 320 } with
    | EStateM.Result.ok x s =>
      match EStateM.run doubleX s with
      | EStateM.Result.ok x s =>
        match EStateM.run doubleX s with
        | EStateM.Result.ok x s => EStateM.run divYbyX s
        | EStateM.Result.error e s => EStateM.Result.error e s
      | EStateM.Result.error e s => EStateM.Result.error e s
    | EStateM.Result.error e s => EStateM.Result.error e s) =
    EStateM.Result.ok 8 { x := 40, y := 320 }

Same problem, just 3 levels of match instead of 17. The crux is that we do not want to deal with this at all.

How Would You Solve This On An Exam?

Before getting into the Lean-proving malarkey, ask yourself: if you were tasked with solving the above problem on paper, how would you go about it? You would probably say: we start with s := { x := 5, y := 320 }.

  1. After one invocation of doubleX we would have { x := 10, y := 320 }.
  2. Then one more and we have { x := 20, y := 320 }.
  3. Then one more and we have { x := 40, y := 320 }.
  4. Now x ≠ 0, so divYbyX succeeds — it returns 320 / 40 = 8, and the final state is still { x := 40, y := 320 } (since divYbyX only reads, never writes).

That would qualify as a perfectly acceptable proof. Essentially, we would like to do the same in Lean.

Understanding The Do Notation

If we unfold prog1 in the goal (stated over a general init_state and final_state, as in the version of the theorem we prove below), we see this:

EStateM.run
    (do
      let __discr ← doubleX
      have x : Nat := __discr
      let __discr ← doubleX
      have x : Nat := __discr
      let __discr ← doubleX
      have x : Nat := __discr
      divYbyX)
    init_state =
  EStateM.Result.ok 8 final_state

But Lean is a functional programming language, and the do-block above is syntactic sugar that makes the code read nicely. What Lean's type-checker actually sees is this chain of binds:

def prog1 : M Nat :=
  doubleX >>= (fun _ =>
    doubleX >>= (fun _ =>
      doubleX >>= (fun _ =>
        divYbyX)))

How do we know this? We can ask the type-checker whether this functional version really is the same as the goal. Writing the following as the first line of the proof is accepted:

  show (doubleX >>= fun _ =>
           doubleX >>= fun _ =>
           doubleX >>= fun _ =>
           divYbyX).run init_state = .ok 8 final_state 

For a well-typed Lean expression G, the tactic show G does exactly two things:

  1. Lean confirms that G is definitionally equal to the current goal. If it is not, show fails with an error.
  2. It replaces the goal with G, so all later tactics see G as the target.
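A minimal sketch of show on the toy (the example is ours, not part of the original development): it restates the goal of doubleX_run with .run stripped, since x.run s is definitionally just x s, and with the record update written out field by field. Lean accepts the restatement because both forms are definitionally equal, and rfl then closes it.

```lean
-- Assumes `State`, `M`, and `doubleX` from above are in scope.
example (s : State) :
    doubleX.run s = .ok (2 * s.x) { s with x := 2 * s.x } := by
  -- Lean first checks this is definitionally equal to the goal,
  -- then displays the goal in this form from here on.
  show doubleX s = .ok (2 * s.x) { x := 2 * s.x, y := s.y }
  rfl
```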

Given what the type-checker really sees, we write the following helper lemma:

theorem bind_run_of_ok {α β : Type} {m : M α} {f : α → M β}
    {s s₁ : State} {a : α}
    (h : m.run s = .ok a s₁) :
    (m >>= f).run s = (f a).run s₁ := by
  show (m >>= f) s = (f a) s₁
  simp only [bind, EStateM.bind]
  rw [show m s = .ok a s₁ from h]

In Plain English

It says: suppose that, starting in state s, the single step m succeeds, reaching state s₁ with return value a. Then running the whole chain m >>= f from the original state s is the same as running the continuation f a from s₁.

We also prove run lemmas that characterise the behaviour of each instruction:

theorem doubleX_run (s : State) :
    doubleX.run s = .ok (2 * s.x) { s with x := 2 * s.x } := rfl

theorem divYbyX_run_ok (s : State) (h : s.x ≠ 0) :
    divYbyX.run s = .ok (s.y / s.x) s := by
  unfold divYbyX
  simp only [EStateM.run_bind, EStateM.run_get]
  rw [if_neg h]
  rfl 

theorem divYbyX_run_err (s : State) (h : s.x = 0) :
    divYbyX.run s = .error (Err.assertFailed "division by zero") s := by
  unfold divYbyX
  simp only [EStateM.run_bind, EStateM.run_get]
  rw [if_pos h]
  rfl 

theorem divYbyX_run  (s: State):
    divYbyX.run s = 
    if s.x = 0 then .error (Err.assertFailed "division by zero") s 
    else .ok (s.y/ s.x) s := by 
  by_cases h : s.x = 0
  · rw [if_pos h] 
    exact divYbyX_run_err _ h 
  · rw [if_neg h] 
    exact divYbyX_run_ok s h  

Four short lemmas, one per instruction behaviour:

  • doubleX_run: doubleX always succeeds, overwrites x with 2 * s.x, and returns the new value. The proof is rfl, because everything in the do-block reduces by unfolding definitions.
  • divYbyX_run_ok: the success branch. When s.x ≠ 0, divYbyX returns s.y / s.x and leaves the state untouched.
  • divYbyX_run_err: the failure branch. When s.x = 0, divYbyX throws assertFailed "division by zero", again leaving the state untouched.
  • divYbyX_run: the combined if-then-else statement, derived by by_cases on s.x = 0 and dispatching to the two branch lemmas.

The point of splitting _ok and _err out is that downstream callers usually already know which side of the guard they're on, and would rather feed bind_run_of_ok a clean equation than an if.

Note that we still go through the match plumbing, but only once per instruction, inside that instruction's run lemma, rather than once per nesting level of the whole program.
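A minimal sketch of bind_run_of_ok on a single step, assuming the toy definitions and lemmas above are in scope (the example itself is ours): doubleX_run supplies the hypothesis h, and the lemma strips the leading doubleX, leaving only the continuation to run on the doubled state. The 2 * 5 that doubleX_run produces is accepted where 10 is expected because Lean reduces literal arithmetic definitionally.

```lean
-- Assumes `State`, `M`, `doubleX`, `divYbyX`, `bind_run_of_ok`,
-- and `doubleX_run` from above are in scope.
example :
    (doubleX >>= fun _ => divYbyX).run { x := 5, y := 320 }
      = divYbyX.run { x := 10, y := 320 } :=
  -- Here `m := doubleX`, `f := fun _ => divYbyX`, `a := 2 * 5`,
  -- and `s₁ := { x := 2 * 5, y := 320 }`, all found by unification.
  bind_run_of_ok (doubleX_run _)
```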

Closing Out the Theorem

We now show how that one lemma frees us entirely from reasoning about the nested match. To start, we do exactly what we did in the paper proof: we step through each instruction and reason about the state before and after it.


theorem prog1_on_5_320 (init_state final_state : State) 
  (h_start : init_state = { x := 5, y := 320 }) 
  (h_end : final_state = { x := 40, y := 320 }) 
  :
    prog1.run init_state = .ok 8 final_state := by
  let s1 : State := { x := 10, y := 320 }
  have h1 : doubleX.run init_state = .ok 10 s1 := by
    rw [h_start]
    exact doubleX_run _
  let s2 : State := { x := 20, y := 320 }
  have h2 : doubleX.run s1 = .ok 20 s2 := doubleX_run _ 
  have h3 : doubleX.run s2 = .ok 40 final_state := by 
    rw [h_end]
    exact doubleX_run _
  have hs3ne : final_state.x ≠ 0 := by
    rw [h_end]
    decide 
  have h4 : divYbyX.run final_state = .ok 8 final_state := by
    rw [h_end]
    exact divYbyX_run_ok _ (by rw [h_end] at hs3ne; exact hs3ne)
  -- Nothing has happened so far to the goal

Reading the haves in order, each one is a line straight out of the pen-and-paper proof:

  1. h1: After one doubleX from init_state = { x := 5, y := 320 }, we are at s1 = { x := 10, y := 320 } with return value 10.
  2. h2: Another doubleX from s1 lands us at s2 = { x := 20, y := 320 } with return value 20.
  3. h3: One more doubleX from s2 takes us to final_state = { x := 40, y := 320 } with return value 40.
  4. hs3ne: final_state.x = 40, so final_state.x ≠ 0 — this is the guard divYbyX_run_ok needs.
  5. h4: With the nonzero guard in hand, divYbyX on final_state succeeds: it returns 320 / 40 = 8 and leaves the state untouched.

We close each of these small haves with the per-instruction run lemmas. So far the goal has not changed; we have just built up a sequence of hypotheses. This is where bind_run_of_ok comes in: we can now simply chain them.

To close out the proof we just write

  unfold prog1
  rw [bind_run_of_ok h1]
  rw [bind_run_of_ok h2]
  rw [bind_run_of_ok h3]
  exact h4

Reading the tactic script above in plain English: the goal starts out as

prog1.run init_state = .ok 8 final_state

and each line shrinks the left-hand side by one instruction, without ever touching the right-hand side:

  • unfold prog1: replace the name prog1 with its body. The LHS is now the explicit bind chain (doubleX >>= fun _ => doubleX >>= fun _ => doubleX >>= fun _ => divYbyX).run init_state.
  • rw [bind_run_of_ok h1]: h1 tells us the leading doubleX takes init_state to .ok 10 s1. The lemma turns that into "the whole chain run from init_state" = "the tail chain run from s1". The first doubleX is gone; the LHS is now (doubleX >>= fun _ => doubleX >>= fun _ => divYbyX).run s1.
  • rw [bind_run_of_ok h2]: the same move with h2 peels off the next doubleX and advances the state to s2.
  • rw [bind_run_of_ok h3]: peel off the third doubleX and advance to final_state. The LHS is now just divYbyX.run final_state.
  • exact h4: h4 is exactly divYbyX.run final_state = .ok 8 final_state, which is the goal.

Each rewrite says the same thing: "I know this step succeeded, and here is the proof, so replace the whole m >>= f chain with f a run on the new state s₁."

To make that concrete, look at the very first rewrite, rw [bind_run_of_ok h1]. The generic names in the lemma get instantiated as:

  • m := doubleX — the head instruction of the chain.
  • f := fun _ => doubleX >>= fun _ => doubleX >>= fun _ => divYbyX — the tail of the chain, i.e. everything after the first doubleX.
  • s := init_state — the state we're running from.
  • a := 10, s₁ := s1 — what h1 tells us doubleX produced.

With those in hand, bind_run_of_ok h1 converts

(doubleX >>= (fun _ => doubleX >>= fun _ => doubleX >>= fun _ => divYbyX)).run init_state

into

((fun _ => doubleX >>= fun _ => doubleX >>= fun _ => divYbyX) 10).run s1

Since the continuation discards its argument (fun _ => ...), the 10 just evaporates, leaving

(doubleX >>= fun _ => doubleX >>= fun _ => divYbyX).run s1

The same chain, one step shorter, now running from s1 instead of init_state. The next two rewrites do the same thing with h2 and h3: each peels the leading doubleX off and bumps the state forward, until the LHS is just divYbyX.run final_state — which is exactly h4.

Back to jolt_div

The general strategy is to prove two lemmas:

  1. The RHS, when run, produces some state s. As this is just one step of the trusted RISC CPU, we can prove this directly.
  2. The LHS, the Jolt sequence, when run, produces the same state s (after projection). Because the LHS is a sequence of instructions, the same bind_run_of_ok lemma handles the 18-step jolt_div sequence. The real proof, however, does not apply bind_run_of_ok once per instruction at the top level. It uses the same lemma at two granularities, phases and individual instructions, with a phasing layer in between.

Once we have the RHS and LHS lemmas, the rest is 4-5 mechanical lines of rewrites. The remainder of this post is about item 2.

Phasing

The 18 instructions group into five phases: each of the first four ends in one of the four tests (the setup phase loads the advice and checks division by zero), and the fifth writes the quotient back to rd. The grouping is given below.

Phase                     Instructions
phase_setup               advice, advice, assert_valid_div0
phase_overflow_check      change_divisor, MULH, MUL, SRAI, assert_eq
phase_quotient_product    SRAI_from_real, XOR, SUB, ADD, assert_eq_real
phase_remainder_bound     SRAI, XOR, SUB, assert_valid_unsigned_remainder
phase_writeback           ADDI_to_real

Each phase is a small monadic program, its own do-block of one to five instructions. The first two phases, for example, are:

/-- Phase 1 — advice loads + div-by-zero assert. -/
def phase_setup (rs2 : regidx) (quotient rem_abs : BitVec 64) :
    JoltMonad ExecutionResult := do
  let _ ← vreg_advice 0 quotient
  let _ ← vreg_advice 1 rem_abs
  vreg_assert_valid_div0 rs2 0

/-- Phase 2 — adjusted divisor, MULH/MUL/SRAI, overflow-check assert. -/
def phase_overflow_check (rs1 rs2 : regidx) : JoltMonad ExecutionResult := do
  let _ ← vreg_change_divisor 2 rs1 rs2
  let _ ← vreg_MULH 3 0 2
  let _ ← vreg_MUL 4 0 2
  let _ ← vreg_SRAI 5 4 63
  vreg_assert_eq 3 5 

Next, we prove that this phased decomposition is equal to jolt_div. The proof just unfolds every phase and reassociates the binds with bind_assoc.

theorem jolt_div_phased (rs2 rs1 rd : regidx)
    (quotient rem_abs : BitVec 64) :
    jolt_div rs2 rs1 rd quotient rem_abs = (do
      let _ ← phase_setup rs2 quotient rem_abs
      let _ ← phase_overflow_check rs1 rs2
      let _ ← phase_quotient_product rs1
      let _ ← phase_remainder_bound
      phase_writeback rd) := by
  simp [jolt_div, phase_setup, phase_overflow_check,
        phase_quotient_product, phase_remainder_bound,
        phase_writeback, bind_assoc]

With this theorem in hand, we can reason about the phased program instead of jolt_div.
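The same reshaping can be tried on the toy. The following sketch groups the first two doublings of prog1 into a hypothetical phase doubleTwice (our own name, not part of the original development) and proves the phased program equal to prog1 the same way jolt_div_phased does: unfold everything and reassociate with bind_assoc. Like jolt_div_phased, it relies on bind_assoc, and therefore on a LawfulMonad instance for EStateM being available.

```lean
-- Assumes `M`, `doubleX`, `divYbyX`, and `prog1` from above are in scope.
-- Hypothetical phase: the first two doublings of `prog1`.
def doubleTwice : M Nat := do
  let _ ← doubleX
  doubleX

-- `prog1` rewritten as a chain of phases, mirroring `jolt_div_phased`.
theorem prog1_phased :
    prog1 = (do
      let _ ← doubleTwice
      let _ ← doubleX
      divYbyX) := by
  simp [prog1, doubleTwice, bind_assoc]
```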

_run lemmas, one per phase

Each phase gets a characterisation lemma, analogous to doubleX_run and divYbyX_run_ok from the toy. It states: "given these preconditions and guards, this phase succeeds and establishes the following facts about the new state js'." For example:

theorem phase_setup_run
    (rs2 : regidx) (q rem : BitVec 64) (js : SailJoltState)
    (divisor : BitVec 64)
    (hrs2 : rX_bits rs2 js.sail = .ok divisor js.sail)
    (hguard_div0 : ¬ (divisor = 0#64 ∧ q ≠ (-1 : BitVec 64))) :
    ∃ js',
      (phase_setup rs2 q rem).run js = .ok RETIRE_SUCCESS js' ∧
      js'.vregs 0 = q ∧
      js'.vregs 1 = rem ∧
      js'.sail = js.sail

Reading: if we enter phase_setup with divisor pinned in rs2 and the div-by-zero guard satisfied, the phase succeeds; afterwards v0 holds the quotient q, v1 holds |r|, and the Sail state is untouched.

Its proof works just like the toy proof of prog1_on_5_320: each instruction in the phase has its own run lemma, and these are chained with bind_run_of_ok. We do not show it here for brevity; there is nothing insightful in it, just the same technique applied at a second level.
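For a feel of what such a phase-level lemma looks like, here is a toy analogue, a sketch under our own naming: the two chained doublings from the toy treated as one phase, with the post-state hidden behind an existential in the same style as phase_setup_run. Its proof is exactly the instruction-level chaining from prog1_on_5_320.

```lean
-- Assumes `State`, `M`, `doubleX`, `bind_run_of_ok`, and `doubleX_run`
-- from above are in scope. `two_doublings_run` is a hypothetical name.
theorem two_doublings_run (s : State) :
    ∃ s', (doubleX >>= fun _ => doubleX).run s = .ok (2 * (2 * s.x)) s' ∧
      s'.x = 2 * (2 * s.x) ∧
      s'.y = s.y := by
  -- Supply the post-state explicitly; the last two conjuncts hold by `rfl`.
  refine ⟨{ x := 2 * (2 * s.x), y := s.y }, ?_, rfl, rfl⟩
  -- Peel off the first doubling, as in `prog1_on_5_320` ...
  rw [bind_run_of_ok (doubleX_run s)]
  -- ... and the remaining single doubling is `doubleX_run` again.
  exact doubleX_run _
```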

Stitching the phases

With the five phase-run lemmas in hand, the top-level jolt_div_concrete proof is the toy proof, again, one layer up:

theorem jolt_div_concrete ... := by
  -- Four guards derived from honest advice (pure math lemmas).
  have hguard_div0             := hguard_div0_of_honest dividend divisor
  have hguard_overflow         := hguard_overflow_of_honest dividend divisor
  have hguard_quotient_product := hguard_quotient_product_of_honest dividend divisor
  have hguard_rem_bound        := hguard_rem_bound_of_honest dividend divisor

  -- Extract the five phase-run results in order, threading state.
  obtain ⟨js₁, hrun1, ...⟩ := phase_setup_run             ... hguard_div0
  obtain ⟨js₂, hrun2, ...⟩ := phase_overflow_check_run    ... hguard_overflow
  obtain ⟨js₃, hrun3, ...⟩ := phase_quotient_product_run  ... hguard_quotient_product
  obtain ⟨js₄, hrun4, ...⟩ := phase_remainder_bound_run   ... hguard_rem_bound
  obtain ⟨js₅, hrun5, h5⟩  := phase_writeback_run          ...

  -- Reshape to a phase-chain, then peel four times.
  refine ⟨js₅, ?_, h5⟩
  rw [jolt_div_phased]
  rw [bind_run_of_ok hrun1]
  rw [bind_run_of_ok hrun2]
  rw [bind_run_of_ok hrun3]
  rw [bind_run_of_ok hrun4]
  exact hrun5

Four bind_run_of_ok calls, one exact. The shape is identical to prog1_on_5_320 — just with phase-run lemmas where the toy had instruction-run lemmas.

Exceptions

The strategy above transfers cleanly from the toy prog1 to jolt_div at the phase level. But when we drop one level down, to chaining instructions inside a single phase via bind_run_of_ok, the toy's recipe stops working at scale and we have to adapt. This section explains the failure mode and the workaround.

Need for _ex lemmas

The toy's State has named fields:

structure State where
  x : Nat
  y : Nat

So when doubleX_run says doubleX.run s = .ok (2 * s.x) { s with x := 2 * s.x }, every successor state is a concrete record literal. Looking up s1.x is rfl — one constructor projection, no work for the kernel.

Our Jolt state is structurally similar but has one critical difference:

structure SailJoltState where
  sail  : SailState
  vregs : BitVec 7 → BitVec 64

The vregs field is a function, not an enumerated set of named fields. A per-instruction _run lemma like vreg_MULH_run therefore has to describe the post-state's vregs as an explicit lambda:

theorem vreg_MULH_run (vd vs1 vs2 : BitVec 7) (js : SailJoltState) :
    vreg_MULH vd vs1 vs2 js = .ok RETIRE_SUCCESS
      { sail := js.sail
        vregs := fun r => if r = vd then mulhs (js.vregs vs1) (js.vregs vs2)
                          else js.vregs r }

That fun r => if r = vd then v else js.vregs r is the crux of the problem. When we chain four such writes by let-binding intermediate states s1, s2, s3, s4, each s_{i+1} references s_i.vregs r in the else branch of its own lambda. Reading s4.vregs 0 then forces the kernel to walk the chain s4 → s3 → s2 → s1 → js, peeling one if 0 = vd_i then … else … per level.

For a phase like phase_overflow_check (four chained writes, six different vregs-index lookups in the proof body) the kernel ends up re-walking that chain at every lookup. The cumulative term explodes:

error: (kernel) deep recursion detected

The toy never sees this because s1.x = 10, s2.x = 20, s3.x = 40 are flat record projections. Our chain is "flat" only up to symbolic equality — every concrete index lookup forces the kernel to commit to which if-branch fires, and that decision propagates through every lambda below it.

The fix is structural. Instead of describing the post-state by what it equals literally, describe it by what facts hold about it:

theorem vreg_MULH_run_ex (vd vs1 vs2 : BitVec 7) (js : SailJoltState) :
    ∃ js',
      (vreg_MULH vd vs1 vs2).run js = .ok RETIRE_SUCCESS js' ∧
      js'.vregs vd = mulhs (js.vregs vs1) (js.vregs vs2) ∧
      (∀ k, k ≠ vd → js'.vregs k = js.vregs k) ∧
      js'.sail = js.sail

Four conjuncts: the run equation, the new value at the written index, preservation at every other index, and sail untouched. We call this the _ex form of the run lemma — same content as the concrete _run, but the post-state is hidden behind an existential.

Phase proofs then destructure each step:

obtain ⟨s2, h2, h2_v3, h2_pres, h2_sail⟩ := vreg_MULH_run_ex 3 0 2 s1

Now s2 is an opaque, fresh free variable (fvar) introduced by obtain, so the kernel literally cannot unfold it. There is no lambda body to walk through, so the chain never forms. Cross-state lookups become one-line consequences of h2_pres:

have hs2_v0 : s2.vregs 0 = q := (h2_pres 0 (by decide)).trans hs1_v0

One application of preservation, one transitivity step: the cost is linear in the number of writes, not exponential.

The toy's let s1 : State := { x := 10, y := 320 } binds a state we can read directly. The _ex form's obtain ⟨s1, h1, …⟩ binds a state we can only read through its facts. The proof skeleton — bind_run_of_ok chain, refine ⟨s_n, ?_, …⟩, four rewrites, one exact — is identical between the two; the only difference is how the intermediate states are queried.
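The same pattern can be reproduced on a toy machine whose registers form a function, like vregs. The sketch below is self-contained, and every name in it (FState, FM, writeReg, writeReg_run_ex) is hypothetical, chosen for illustration. writeReg_run_ex is the _ex form of a single register write; the example after it chains two writes with obtain and answers cross-state lookups purely through the preservation facts, never by unfolding a state.

```lean
namespace ExToy

-- Hypothetical machine whose register file is a function, like `vregs`.
structure FState where
  regs : Nat → Nat

abbrev FM (α : Type) : Type := EStateM String FState α

-- Hypothetical instruction: write `v` into register `rd`.
def writeReg (rd v : Nat) : FM Unit :=
  modify fun (s : FState) => { regs := fun r => if r = rd then v else s.regs r }

-- `_ex` form: the post-state sits behind an existential and is
-- described only by the facts that hold about it.
theorem writeReg_run_ex (rd v : Nat) (s : FState) :
    ∃ s', (writeReg rd v).run s = .ok () s' ∧
      s'.regs rd = v ∧
      (∀ k, k ≠ rd → s'.regs k = s.regs k) := by
  refine ⟨{ regs := fun r => if r = rd then v else s.regs r }, rfl, ?_, ?_⟩
  · show (if rd = rd then v else s.regs rd) = v
    rw [if_pos rfl]
  · intro k hk
    show (if k = rd then v else s.regs k) = s.regs k
    rw [if_neg hk]

-- Chaining two writes: `s1` and `s2` stay opaque, and each lookup is
-- answered with preservation facts plus `Eq.trans`.
example (s : FState) (h0 : s.regs 0 = 7) :
    ∃ s1 s2 : FState, (writeReg 1 10).run s = .ok () s1 ∧
      (writeReg 2 20).run s1 = .ok () s2 ∧
      s2.regs 0 = 7 ∧ s2.regs 1 = 10 ∧ s2.regs 2 = 20 := by
  obtain ⟨s1, h1, h1_v, h1_pres⟩ := writeReg_run_ex 1 10 s
  obtain ⟨s2, h2, h2_v, h2_pres⟩ := writeReg_run_ex 2 20 s1
  exact ⟨s1, s2, h1, h2,
    (h2_pres 0 (by decide)).trans ((h1_pres 0 (by decide)).trans h0),
    (h2_pres 1 (by decide)).trans h1_v,
    h2_v⟩

end ExToy
```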

Closing the _ex lemmas

There is one additional gotcha when proving the _ex lemmas themselves. The natural-looking proof — reuse the concrete _run lemma, then fill in the four conjuncts — also blows the kernel:

theorem vreg_MULH_run_ex:= by
  refine ⟨_, vreg_MULH_run vd vs1 vs2 js, ?_, ?_, rfl⟩  -- BAD: kernel deep recursion

The reason is the same chain-walking phenomenon, one level up. vreg_MULH_run's conclusion type carries the literal fun r => if r = vd then v else js.vregs r lambda. The existential's witness has to unify with that lambda, and the kernel re-reduces it during type-checking. With a compound v like mulhs (js.vregs vs1) (js.vregs vs2), the reduction stalls.

The strategy that works is to prove _ex from scratch, never referring to the concrete _run form:

theorem vreg_MULH_run_ex (vd vs1 vs2 : BitVec 7) (js : SailJoltState) :
    ∃ js', … := by
  unfold vreg_MULH
  simp only [bind, EStateM.bind, pure, EStateM.pure,
             readVReg, writeVReg, get, modify, modifyGet,
             getThe, MonadStateOf.get, MonadStateOf.modifyGet,
             EStateM.get, EStateM.modifyGet]
  refine ⟨_, rfl, ?_, ?_, rfl⟩
  · show (if vd = vd then _ else js.vregs vd) = _
    rw [if_pos rfl]
  · intro k h
    show (if k = vd then _ else js.vregs k) = _
    rw [if_neg h]

The recipe in plain English:

  1. Unfold the instruction so the do-block is exposed.
  2. simp only with the monadic-plumbing simp set (bind, EStateM.bind, pure, readVReg, writeVReg, the various modifyGet / get projections) to flatten the do-block down to a single .ok RETIRE_SUCCESS { … } result.
  3. refine ⟨_, rfl, ?_, ?_, rfl⟩ in one shot:
    • the underscore witness is inferred from the run equation,
    • rfl closes the run equation (now a literal .ok = .ok),
    • rfl closes the sail conjunct (the lambda doesn't touch sail),
    • two ?_ goals remain for the new value at vd and preservation elsewhere.
  4. The vd goal reduces to (if vd = vd then v else js.vregs vd) = v after show. Close with rw [if_pos rfl].
  5. The preservation goal introduces the index k and the hypothesis k ≠ vd, then reduces to (if k = vd then v else js.vregs k) = js.vregs k. Close with rw [if_neg h].
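On a toy machine the recipe shrinks to its core. There is no liftSail or Sail register file to unfold, and in a case this small rfl can see through the whole do-block by definitional unfolding, so steps 1 and 2 are skipped here. The sketch below is self-contained and uses hypothetical names (FState, FM, addRegs, addRegs_run_ex). It shows step 3 with an underscore witness that rfl infers from the run equation, and steps 4 and 5 closing the two remaining goals with if_pos and if_neg.

```lean
namespace ExRecipe

-- Hypothetical machine with a function-valued register file.
structure FState where
  regs : Nat → Nat

abbrev FM (α : Type) : Type := EStateM String FState α

-- Hypothetical read-modify-write instruction:
-- regs[rd] := regs[rs1] + regs[rs2].
def addRegs (rd rs1 rs2 : Nat) : FM Unit := do
  let s ← get
  set ({ regs := fun r => if r = rd then s.regs rs1 + s.regs rs2 else s.regs r } : FState)

-- Proved from scratch: the witness is `_`, assigned by `rfl`.
theorem addRegs_run_ex (rd rs1 rs2 : Nat) (s : FState) :
    ∃ s', (addRegs rd rs1 rs2).run s = .ok () s' ∧
      s'.regs rd = s.regs rs1 + s.regs rs2 ∧
      (∀ k, k ≠ rd → s'.regs k = s.regs k) := by
  refine ⟨_, rfl, ?_, ?_⟩
  · -- New value at the written index.
    show (if rd = rd then s.regs rs1 + s.regs rs2 else s.regs rd) = _
    rw [if_pos rfl]
  · -- Preservation at every other index.
    intro k h
    show (if k = rd then s.regs rs1 + s.regs rs2 else s.regs k) = _
    rw [if_neg h]

end ExRecipe
```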

For instructions that touch the Sail register file (e.g. vreg_change_divisor, vreg_SRAI_from_real), the recipe needs two small additions:

  • Also unfold liftSail, since the rX_bits / wX_bits reads are wrapped in it.
  • After the simp, rw [hrs1] (and rw [hrs2] if there are two source registers) to specialise the underlying rX_bits calls to their honest outcomes. A brief simp only [] between the two rewrites performs the iota reduction of the match on .ok … that the first rewrite leaves behind.

One last detail: the existential's (vreg_X …).run js form sometimes blocks the bind reduction, because the outer .run keeps the function abstract. A leading show ∃ js', vreg_X … js = .ok RETIRE_SUCCESS js' ∧ _ strips the .run (a definitionally trivial step, since .run is just fun m s => m s) and lets the subsequent simp see the actual function body.

Once the _ex lemmas are in place, the per-instruction reasoning inside a phase proof matches the toy's pattern exactly — the only adaptation is reading state facts through h_pres and h_at_vd rather than projecting named record fields.