All code for this post can be found in monad-do.lean.

Notes on Chapter 4.4: do-Notation for Monads from Functional Programming in Lean.

In the previous post we defined monads as types that provide the operations bind and pure, packaged as instances of a type class. As we have discussed ad nauseam by now, monads enable composition. They do so via the bind operator, but chaining binds leads to heavily nested code that can be hard to reason about: recall from the first post how composing four debuggable functions with bind gave us bind (bind (bind (g' x) f') h') k'. The infix >>= flattens this somewhat, but writing action >>= fun x => ... everywhere is still sort of annoying. Lean offers do notation, syntactic sugar that lets us write these nested function calls as imperative-looking code, which is much easier to read and possibly easier to reason about. We give examples below.

Before we get into translation rules, let's see what do buys us with examples from the earlier posts. Recall the code that gets the first, third, fifth, and seventh elements of a list. We wrap the output in a monad (like Option or Except ε) because the list might not be large enough — the monad handles the failure plumbing. Even with >>=, reading the chain takes effort:

def firstThirdFifthSeventh_v1
    [Monad m]
    (lookup : List α → Nat → m α)
    (xs : List α)
    : m (α × α × α × α) :=
  lookup xs 0 >>= fun first =>
  lookup xs 2 >>= fun third =>
  lookup xs 4 >>= fun fifth =>
  lookup xs 6 >>= fun seventh =>
  pure (first, third, fifth, seventh)

Above: monadic (>>=) version.

def firstThirdFifthSeventh
    [Monad m]
    (lookup : List α → Nat → m α)
    (xs : List α)
    : m (α × α × α × α) := do
  let first ← lookup xs 0
  let third ← lookup xs 2
  let fifth ← lookup xs 4
  let seventh ← lookup xs 6
  pure (first, third, fifth, seventh)

Above: do-notation version.
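To see the do version in action, here is a quick usage sketch. It assumes firstThirdFifthSeventh from above is in scope and uses a hypothetical helper, lookupOpt, built on Lean's built-in xs[i]? indexing; the earlier posts define their own lookup functions, so treat this one as an illustrative stand-in. With Option as the monad, a list that is too short simply yields none.

```lean
-- Hypothetical helper: an Option-valued lookup built on Lean's `xs[i]?`.
def lookupOpt (xs : List α) (i : Nat) : Option α :=
  xs[i]?

-- Indices 0, 2, 4, 6 all exist, so every bind succeeds.
#eval firstThirdFifthSeventh lookupOpt [1, 2, 3, 4, 5, 6, 7]
-- some (1, 3, 5, 7)

-- Index 4 is missing, so the chain short-circuits.
#eval firstThirdFifthSeventh lookupOpt [1, 2, 3]
-- none
```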

You're right to look at these two blocks of code and say: meh, hardly any difference. I could live with the bind notation, especially after we sugared it into the infix operator >>=. The difference becomes noticeable when the logic involves loops. Take a program that sums the elements of a list while also counting how many are even:

def sumAndCountEvens (xs : List Nat) : Nat × Nat :=
  go xs 0 0
where
  go : List Nat → Nat → Nat → Nat × Nat
    | [], sum, count => (sum, count)
    | x :: rest, sum, count =>
      if x % 2 == 0
      then go rest (sum + x) (count + 1)
      else go rest (sum + x) count

Above: functional version.

def sumAndCountEvens (xs : List Nat) : Nat × Nat := Id.run do
  let mut sum := 0
  let mut count := 0
  for x in xs do
    sum := sum + x
    if x % 2 == 0 then
      count := count + 1
  return (sum, count)

Above: do-notation version.

#eval sumAndCountEvens [1, 2, 3, 4, 5]  -- (15, 2)
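To make "it's still pure and bind" a little more concrete, here is a rough hand-written approximation of what the loop above becomes. It is a sketch, not Lean's literal elaborator output: the name sumAndCountEvensForIn is hypothetical, and the real translation packages the mutable variables in its own way. The idea is the same, though: the two let mut variables travel together as a (sum, count) state that forIn passes from one iteration to the next.

```lean
-- Hypothetical name. Rough approximation of the do loop above: the mutable
-- variables become one (sum, count) state value threaded through `forIn`.
def sumAndCountEvensForIn (xs : List Nat) : Nat × Nat := Id.run do
  forIn xs ((0, 0) : Nat × Nat) fun x acc =>
    -- `ForInStep.yield` means "continue the loop with this new state"
    pure (ForInStep.yield
      (acc.1 + x, if x % 2 == 0 then acc.2 + 1 else acc.2))

#eval sumAndCountEvensForIn [1, 2, 3, 4, 5]  -- (15, 2)
```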

The functional version needs a recursive helper that threads both accumulators through every call. The do version is a flat loop with two mutable variables, the same code you'd write in Python. Yet the do version is still just syntactic sugar. Under the hood we are still using pure and bind, here for the identity monad Id, which is why the block is wrapped in Id.run do: Id.run unwraps the final Id (Nat × Nat) into a plain Nat × Nat. Even let mut is not real mutation; Lean translates the mutable variables into values passed from one loop iteration to the next. If the return type is already monadic, then do uses that monad directly and Id.run is not needed. Better yet, we can make the function generic over any monad by taking a lookup parameter, as we did with firstThirdFifthSeventh. The do version barely changes, but the functional version gets much worse, since it now threads >>= through every recursive call:

def sumAndCountEvens [Monad m]
    (lookup : List Nat → Nat → m Nat)
    (xs : List Nat) (n : Nat)
    : m (Nat × Nat) :=
  go 0 0 0
where
  go (i sum count : Nat) : m (Nat × Nat) :=
    if i >= n then pure (sum, count)
    else lookup xs i >>= fun x =>
      go (i + 1) (sum + x)
        (if x % 2 == 0 then count + 1
         else count)

Above: functional (>>=) version.

def sumAndCountEvens [Monad m]
    (lookup : List Nat → Nat → m Nat)
    (xs : List Nat) (n : Nat)
    : m (Nat × Nat) := do
  let mut sum := 0
  let mut count := 0
  for i in [:n] do
    let x ← lookup xs i
    sum := sum + x
    if x % 2 == 0 then
      count := count + 1
  pure (sum, count)

Above: do-notation version.

#eval sumAndCountEvens lookupOption [1, 2, 3, 4, 5] 5
-- some (15, 2)
#eval sumAndCountEvens lookupOption [1, 2, 3] 5
-- none
#eval sumAndCountEvens lookupExcept [1, 2, 3, 4, 5] 5
-- Except.ok (15, 2)
#eval sumAndCountEvens lookupExcept [1, 2, 3] 5
-- Except.error "Index 3 not found (maximum is 2)"

Same function, same do block — swap the lookup and it works with Option, Except, or any other monad. The functional version has to manually call >>= and thread the result through the recursion.
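As a sketch of "any other monad", here is a hypothetical lookup in the state monad StateM Nat that counts how many lookups were performed. It assumes the generic sumAndCountEvens above is in scope; lookupCounting is an illustrative name, and it falls back to 0 for a missing index via List.getD, so it never fails (and is only meaningful when n does not exceed the list length).

```lean
-- Hypothetical lookup: increments a counter in the state on every call,
-- then returns the element (or 0 if the index is out of range).
def lookupCounting (xs : List Nat) (i : Nat) : StateM Nat Nat := do
  modify (· + 1)
  pure (xs.getD i 0)

-- Start the counter at 0; the result pairs the answer with the final count.
#eval Id.run (StateT.run (sumAndCountEvens lookupCounting [1, 2, 3, 4, 5] 5) 0)
-- ((15, 2), 5)
```

The do block inside sumAndCountEvens did not change at all; only the lookup, and therefore the monad, did.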

mapM

Here's another example where do makes life easier. We want a monadic map: apply f : α → m β to every element of a list, collecting the results. In the >>= version each result is named by a lambda that the next bind feeds into. In the do version we just say: hd is the result of applying f to the head, and tl is the result of the recursive call on the tail. Much cleaner.

def mapM_v1 [Monad m]
    (f : α → m β)
    : List α → m (List β)
  | [] => pure []
  | x :: xs =>
    f x >>= fun hd =>
    mapM_v1 f xs >>= fun tl =>
    pure (hd :: tl)

Above: monadic (>>=) version.

def mapM' [Monad m]
    (f : α → m β)
    : List α → m (List β)
  | [] => pure []
  | x :: xs => do
    let hd ← f x
    let tl ← mapM' f xs
    pure (hd :: tl)

Above: do-notation version.
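A quick usage sketch for mapM', using a hypothetical helper halveEven that fails on odd numbers. A single failing element makes the whole result none, because Option's bind short-circuits. For comparison, Lean's core library already ships List.mapM, which gives the same answer here.

```lean
-- Hypothetical helper: halve a number, failing (none) if it is odd.
def halveEven (n : Nat) : Option Nat :=
  if n % 2 == 0 then some (n / 2) else none

#eval mapM' halveEven [2, 4, 6]  -- some [1, 2, 3]
#eval mapM' halveEven [2, 3, 6]  -- none

-- Lean's built-in List.mapM on the same input:
#eval [2, 4, 6].mapM halveEven   -- some [1, 2, 3]
```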

With nested actions, writing (← E) inside an expression means "bind E and insert the result here." This makes the definition almost as short as the non-monadic List.map:

def mapM'' [Monad m] (f : α → m β) : List α → m (List β)
  | [] => pure []
  | x :: xs => do
    pure ((← f x) :: (← mapM'' f xs))
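Nested actions are lifted in left-to-right order, so (← f x) runs before (← mapM'' f xs). With Except, where the first error wins, that order shows up in the output. The validator checkPos below is a hypothetical example, assuming mapM'' from above is in scope: the error names -2, not -3, because the head is checked before the recursive call on the tail.

```lean
-- Hypothetical validator: accept positive integers, reject the rest.
def checkPos (n : Int) : Except String Int :=
  if n > 0 then .ok n else .error s!"{n} is not positive"

#eval mapM'' checkPos [1, 2, 3]    -- Except.ok [1, 2, 3]
#eval mapM'' checkPos [1, -2, -3]  -- Except.error "-2 is not positive"
```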

Translation Rules

do notation is just syntactic sugar. Under the hood, every do block becomes ordinary functional code built from the monadic primitives pure and >>= (plus a few helpers, such as forIn for loops), and Lean performs this rewriting mechanically. The full translation rules are given in Chapter 4.4: do-Notation for Monads.
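For reference, here are the four core rules as the book presents them, written informally as Lean comments, followed by a check of two of them on concrete Option values. The definition names are hypothetical, and real do blocks rely on further rules for let mut, for loops, return, and nested actions. Both examples should check by rfl, since each pair of definitions reduces to the same value.

```lean
-- Core do-translation rules (informal summary):
--   do E                  ~>  E
--   do E₁; Es             ~>  E₁ >>= fun () => do Es
--   do let x := E₁; Es    ~>  let x := E₁; do Es
--   do let x ← E₁; Es     ~>  E₁ >>= fun x => do Es

-- Rule 4: binding a result with `let x ← ...`
def bindViaDo : Option Nat := do
  let x ← some 3
  pure (x + 1)

def bindViaOp : Option Nat :=
  some 3 >>= fun x => pure (x + 1)

example : bindViaDo = bindViaOp := rfl

-- Rule 2: running an action and ignoring its unit result
def seqViaDo : Option Nat := do
  some ()
  pure 7

def seqViaOp : Option Nat :=
  some () >>= fun () => pure 7

example : seqViaDo = seqViaOp := rfl
```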