Jolt Implementation Of CSRRW

Jolt supports writes to six control and status registers (CSRs): mstatus, mtvec, mepc, mtval, mscratch, and mcause. It treats writes to all six uniformly.

The assembly signature is csrrw rd, csr, rs1, and the semantics of the Jolt expansion are roughly the following pseudocode:

if rd = x0 then
    [vreg_csr] <- [rs1] + 0 -- implemented via ADDI
else if rd = rs1 then
    tmp <- [rs1] + 0
    [rd] <- [vreg_csr] + 0
    [vreg_csr] <- tmp + 0
else
    [rd] <- [vreg_csr]
    [vreg_csr] <- [rs1]

This can be verified against the Jolt source code.
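
The three cases above can be modelled in a few lines of Python. This is a minimal sketch, not Jolt's code: the register file is a plain dict, and the names jolt_csrrw, regs, and vreg_csr are hypothetical stand-ins for the expansion and its virtual register.

```python
def jolt_csrrw(regs, vreg_csr, rd, rs1):
    """Model the Jolt CSRRW expansion on a dict-based register file.

    regs maps register names to 64-bit values; vreg_csr names the virtual
    register backing the CSR (both are illustrative, not Jolt's types).
    """
    if rd == "x0":
        # The old CSR value is discarded; only the write happens.
        regs[vreg_csr] = regs[rs1]
    elif rd == rs1:
        # Swap through a temporary so the old rs1 value is not clobbered.
        tmp = regs[rs1]
        regs[rd] = regs[vreg_csr]
        regs[vreg_csr] = tmp
    else:
        regs[rd] = regs[vreg_csr]
        regs[vreg_csr] = regs[rs1]
    return regs


regs = {"x0": 0, "x5": 0xAB, "x6": 0, "vcsr": 0x11}
jolt_csrrw(regs, "vcsr", rd="x6", rs1="x5")
# x6 now holds the old CSR value, and the CSR holds rs1 verbatim.
```

The point to notice is that in every branch the value stored in the virtual CSR register is rs1 unchanged; no CSR-specific legalization is applied.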

Problem

The issue is that the RISC-V specification does not treat writes to control status registers uniformly. Depending on the status register being written to, the logic differs.

For mscratch, mcause, mtval, the above logic agrees with the RISC-V specification. For the remaining three, the spec asks that we first legalize the value in rs1 before writing it to the control status register.

More details below.

mtvec

As per the Sail spec, which we checked to be correct [1], the value written to the mtvec control status register is given by

writeReg mtvec (← (legalize_tvec (← readReg mtvec) value))

In plain terms: first apply legalize_tvec to the old mtvec value and the rs1 value, then write the result to the control status register. The legalize_tvec function is given by

def legalize_tvec (o : (BitVec 64)) (v : (BitVec 64)) : SailM (BitVec 64) := do
  let v := (Mk_Mtvec v) -- identity function.
  match (trapVectorMode_of_bits (_get_Mtvec_Mode v)) with
  | TV_Direct => (pure v) -- we have to prove that we are in TV_Direct
  | TV_Vector => (pure v) -- or TV_Vector 
  | _ =>
    (do
      match xtvec_mode_reserved_behavior with
      | Xtvec_Fatal =>
        (reserved_behavior
          (HAppend.hAppend "Tried to write a reserved value ("
            (HAppend.hAppend (Int.repr (BitVec.toNatInt (_get_Mtvec_Mode v)))
              ") to the MODE field of xtvec.")))
      | Xtvec_Ignore => (pure (_update_Mtvec_Mode v (_get_Mtvec_Mode o))))

Given the functional-language nature of Lean, the above might seem hard to parse, but all that's going on is the following:

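The match statement extracts the lowest two bits of rs1_val (the MODE field). If they are 00 (Direct) or 01 (Vectored), rs1_val is written unchanged, [vreg_csr] <- [rs1_val], as in Jolt. Otherwise the encoding is reserved: under Xtvec_Fatal the model reports reserved behavior, and under Xtvec_Ignore it overwrites rs1_val's lowest two bits with the two least-significant bits of the old mtvec value and writes this updated value into mtvec.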

Now for the Jolt code to be correct, we must prove that the lowest two bits of rs1_val are always either 00 or 01. But Jolt does not promise this.
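
To make the divergence concrete, here is a small Python sketch of the legalization just described, with a counterexample. It is an illustration under assumptions, not the Sail model itself: the function mirrors legalize_tvec, the reserved parameter is a hypothetical stand-in for xtvec_mode_reserved_behavior, and the input values are made up.

```python
MASK64 = (1 << 64) - 1


def legalize_tvec(old, value, reserved="ignore"):
    """Return the value the Sail model would store in mtvec (sketch)."""
    mode = value & 0b11
    if mode in (0b00, 0b01):          # Direct or Vectored: accepted as-is
        return value & MASK64
    if reserved == "fatal":           # stands in for Xtvec_Fatal
        raise ValueError(f"reserved MODE {mode} written to xtvec")
    # Xtvec_Ignore: keep the old MODE, take every other bit from value.
    return ((value & ~0b11) | (old & 0b11)) & MASK64


old_mtvec = 0x8000_0001               # currently Vectored (illustrative)
rs1_val = 0x1000_0002                 # MODE = 0b10, a reserved encoding
spec_result = legalize_tvec(old_mtvec, rs1_val)
jolt_result = rs1_val                 # Jolt copies rs1 verbatim
```

With these inputs the model keeps the old MODE bits while Jolt stores the reserved encoding, so the two results differ whenever rs1_val's low bits are 10 or 11.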

mepc

The first step is to read the old mepc value.

if ((bne access_type CSRWrite) : Bool)
            then (read_CSR csr) -- read csr (we take this branch)
            else (pure (zeros (n := 64))) ) : SailM xlenbits )

which for mepc — whose 12-bit address is 0x341 — results in the call

  | 0x341 => (get_xepc Machine)

which in turn returns

(align_pc (← readReg mepc))

In plain terms, the read does not return the raw contents of the mepc register; it returns them only after aligning:

def align_pc (addr : (BitVec 64)) : SailM (BitVec 64) := do
  if ((← (currentlyEnabled Ext_Zca)) : Bool)
  then (pure (BitVec.update addr 0 0#1))
  else (pure (Sail.BitVec.updateSubrange addr 1 0 (zeros (n := (1 -i (0 -i 1))))))

So the old value read from mepc is not the value before the update (as Jolt assumes), but that value with its low bit cleared. Because this hart enables Ext_Zca (hartSupports Ext_Zca => true), the active branch is BitVec.update addr 0 0#1, which clears bit 0 only; the else branch, which clears bits [1:0], is reached only when Zca is disabled. [2]

Similarly, before writing to mepc, we compute:

let target := (legalize_xepc rs1_value)

where

def legalize_xepc (v : (BitVec 64)) : (BitVec 64) :=
  if ((hartSupports Ext_Zca) : Bool)
  then (BitVec.update v 0 0#1)
  else (Sail.BitVec.updateSubrange v 1 0 (zeros (n := (1 -i (0 -i 1)))))

So both the read (align_pc) and the write (legalize_xepc) clear mepc's low bits, whereas Jolt copies rs1_val into the virtual register verbatim ([vreg_csr] <- [rs1] + 0, via ADDI) and later reads it back verbatim.

Now for the Jolt code to be correct, we must prove that every value written to mepc is already aligned — i.e. bit 0 of rs1_val is always 0 (or bits [1:0] are 00 when Zca is disabled). But this guard does not exist in Jolt.
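
The masking on both sides can be sketched in Python as follows. This is a simplified illustration, not the Sail model: a single hypothetical zca flag stands in for both hartSupports Ext_Zca and currentlyEnabled Ext_Zca, and the sample value is made up.

```python
MASK64 = (1 << 64) - 1


def legalize_xepc(value, zca=True):
    """Write side: clear bit 0 with Zca, bits [1:0] without it."""
    low_mask = 0b1 if zca else 0b11
    return value & ~low_mask & MASK64


def align_pc(stored, zca=True):
    """Read side: the same masking, applied to the stored mepc value."""
    low_mask = 0b1 if zca else 0b11
    return stored & ~low_mask & MASK64


def spec_roundtrip(rs1_val, zca=True):
    """Value a later CSR read observes after writing rs1_val to mepc."""
    return align_pc(legalize_xepc(rs1_val, zca), zca)


rs1_val = 0x8000_0003                 # misaligned on purpose (illustrative)
spec_with_zca = spec_roundtrip(rs1_val, zca=True)
spec_without_zca = spec_roundtrip(rs1_val, zca=False)
jolt_value = rs1_val                  # Jolt stores and reads back verbatim
```

Any rs1_val with bit 0 set (or with either of bits [1:0] set when Zca is disabled) reads back differently under the model than under Jolt's verbatim copy.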

mstatus

This is also a bug, but the write is more complex: I still have to parse the following Lean code to extract a clean algorithm. I'll do this in a couple of days.

-- v = [rs1]
-- o = [mstatus]
def legalize_mstatus (o : (BitVec 64)) (v : (BitVec 64)) : SailM (BitVec 64) := do
  let v := (Mk_Mstatus v) -- overwrite v 
  let o ← do
    (pure (_update_Mstatus_SIE
        (_update_Mstatus_MIE
          (_update_Mstatus_SPIE
            (_update_Mstatus_MPIE
              (_update_Mstatus_VS
                (_update_Mstatus_SPP
                  (_update_Mstatus_MPP
                    (_update_Mstatus_FS
                      (_update_Mstatus_XS
                        (_update_Mstatus_MPRV
                          (_update_Mstatus_SUM
                            (_update_Mstatus_MXR
                              (_update_Mstatus_TVM
                                (_update_Mstatus_TW
                                  (_update_Mstatus_TSR
                                    (_update_Mstatus_SPELP
                                      (_update_Mstatus_MPELP o
                                        (if ((hartSupports Ext_Zicfilp) : Bool)
                                        then (_get_Mstatus_MPELP v)
                                        else 0#1))
                                      (if ((hartSupports Ext_Zicfilp) : Bool)
                                      then (_get_Mstatus_SPELP v)
                                      else 0#1))
                                    (← do
                                      if ((← (currentlyEnabled Ext_S)) : Bool)
                                      then (pure (_get_Mstatus_TSR v))
                                      else (pure 0#1)))
                                  (← do
                                    if ((← (currentlyEnabled Ext_U)) : Bool)
                                    then (pure (_get_Mstatus_TW v))
                                    else (pure 0#1)))
                                (← do
                                  if ((← (currentlyEnabled Ext_S)) : Bool)
                                  then (pure (_get_Mstatus_TVM v))
                                  else (pure 0#1)))
                              (← do
                                if ((← (currentlyEnabled Ext_S)) : Bool)
                                then (pure (_get_Mstatus_MXR v))
                                else (pure 0#1)))
                            (← do
                              if ((← (virtual_memory_supported ())) : Bool)
                              then (pure (_get_Mstatus_SUM v))
                              else (pure 0#1)))
                          (← do
                            if ((← (currentlyEnabled Ext_U)) : Bool)
                            then (pure (_get_Mstatus_MPRV v))
                            else (pure 0#1))) (extStatus_to_bits Off))
                      (if ((hartSupports Ext_Zfinx) : Bool)
                      then (extStatus_to_bits Off)
                      else (_get_Mstatus_FS v)))
                    (← do
                      if ((← (have_nominal_privLevel (_get_Mstatus_MPP v))) : Bool)
                      then (pure (_get_Mstatus_MPP v))
                      else (pure (privLevel_to_bits (← (lowest_supported_privLevel ()))))))
                  (← do
                    if ((← (currentlyEnabled Ext_S)) : Bool)
                    then (pure (_get_Mstatus_SPP v))
                    else (pure 0#1)))
                (if ((hartSupports Ext_Zve32x) : Bool)
                then (_get_Mstatus_VS v)
                else 0b00#2)) (_get_Mstatus_MPIE v))
            (← do
              if ((← (currentlyEnabled Ext_S)) : Bool)
              then (pure (_get_Mstatus_SPIE v))
              else (pure 0#1))) (_get_Mstatus_MIE v))
        (← do
          if ((← (currentlyEnabled Ext_S)) : Bool)
          then (pure (_get_Mstatus_SIE v))
          else (pure 0#1))))
  let dirty :=
    (((extStatus_of_bits (_get_Mstatus_FS o)) == Dirty) || (((extStatus_of_bits (_get_Mstatus_XS o)) == Dirty) || ((extStatus_of_bits
            (_get_Mstatus_VS o)) == Dirty)))
  (pure (_update_Mstatus_SD o (bool_to_bit dirty)))
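
As a preliminary reading only, the nested updates above can be flattened into a field-by-field table. The Python sketch below rests on several assumptions that are not in the listing: the bit positions are the RV64 mstatus layout from the privileged spec, each keyword flag collapses the corresponding hartSupports / currentlyEnabled check, have_nominal_privLevel is assumed to accept M always and S or U only when enabled, and lowest_supported_privLevel is assumed to be U when present and M otherwise. The helper names get, put, and gated are hypothetical.

```python
MASK64 = (1 << 64) - 1
DIRTY = 0b11

# (low bit, width) of each field; assumed RV64 layout, not taken from the listing.
FIELDS = {
    "SIE": (1, 1), "MIE": (3, 1), "SPIE": (5, 1), "MPIE": (7, 1),
    "SPP": (8, 1), "VS": (9, 2), "MPP": (11, 2), "FS": (13, 2),
    "XS": (15, 2), "MPRV": (17, 1), "SUM": (18, 1), "MXR": (19, 1),
    "TVM": (20, 1), "TW": (21, 1), "TSR": (22, 1), "SPELP": (23, 1),
    "MPELP": (41, 1), "SD": (63, 1),
}


def get(word, name):
    lo, width = FIELDS[name]
    return (word >> lo) & ((1 << width) - 1)


def put(word, name, value):
    lo, width = FIELDS[name]
    mask = ((1 << width) - 1) << lo
    return (word & ~mask & MASK64) | ((value << lo) & mask)


def legalize_mstatus(o, v, *, ext_s=False, ext_u=False, virt_mem=False,
                     zicfilp=False, zfinx=False, zve32x=False):
    """Start from the old value o and overwrite the listed fields from v."""
    def gated(name, enabled):
        # Take the field from v only when its extension is available.
        return get(v, name) if enabled else 0

    mpp = get(v, "MPP")
    nominal = {0b11} | ({0b01} if ext_s else set()) | ({0b00} if ext_u else set())
    if mpp not in nominal:
        mpp = 0b00 if ext_u else 0b11   # assumed lowest supported level

    updates = {
        "MPELP": gated("MPELP", zicfilp), "SPELP": gated("SPELP", zicfilp),
        "TSR": gated("TSR", ext_s), "TW": gated("TW", ext_u),
        "TVM": gated("TVM", ext_s), "MXR": gated("MXR", ext_s),
        "SUM": gated("SUM", virt_mem), "MPRV": gated("MPRV", ext_u),
        "XS": 0, "FS": 0 if zfinx else get(v, "FS"),
        "MPP": mpp, "SPP": gated("SPP", ext_s),
        "VS": gated("VS", zve32x), "MPIE": get(v, "MPIE"),
        "SPIE": gated("SPIE", ext_s), "MIE": get(v, "MIE"),
        "SIE": gated("SIE", ext_s),
    }
    for name, value in updates.items():
        o = put(o, name, value)
    dirty = DIRTY in (get(o, "FS"), get(o, "XS"), get(o, "VS"))
    return put(o, "SD", int(dirty))


# Machine-mode-only hart (all flags off), writing all ones over a zero mstatus.
spec_value = legalize_mstatus(0, MASK64)
jolt_value = MASK64                     # Jolt would store rs1 verbatim
```

Under these assumptions only MIE, MPIE, MPP, FS, and the derived SD bit survive the write on a machine-mode-only hart, whereas a verbatim copy keeps every bit of rs1. Fields not named in the listing keep their old values because the fold starts from o.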

[1] Checked against the official RISC-V privileged ISA, which defines mtvec as a WARL register whose MODE field has only the Direct (0b00) and Vectored (0b01) encodings.

[2] Zca enables 16-bit (compressed) instructions. See the official Zc* specification.

References

All claims above can also be checked against the official RISC-V ISA docs.