Monads: What Is?
All code for this post can be found at monad-dan.lean or the live viewer
This is an attempt at re-doing this blog post in Lean instead of Haskell syntax. My big takeaway message is that monadic programming is just a way to bring composition back when there are side effects. So we start with an instance of composable computation.
Composable Problem
We start with a happy world where everything composes.
Given some floating point number x we wish to compute (x + 1)*2.
We can write
def f (x : Float) : Float := x + 1.0
def g (x : Float) : Float := x * 2.0
#eval (g ∘ f) 5.0 -- 12.0
All good so far.
Now say we want to have a global variable called debug, and each time we call f and g we append a string to that global variable.
Imperative Python code for that looks something like this:
debug = ""

def f(x):
    global debug
    debug += "f was called."
    return x + 1.0

def g(x):
    global debug
    debug += "g was called."
    return x * 2.0

result = g(f(5.0))
print(result)  # 12.0
print(debug)   # f was called.g was called.
In Python we can still write the computation as a composition, because appending to a global is a side effect of the Python functions f and g.
Lean is a pure functional language, so we can't do this: our functions cannot have side effects.
The big question: How can we modify f and g to admit the side effect of appending to a global string while still writing functional code?
By "side effect" here we will more generally mean returning something extra alongside the main result — in this case, a debugging string. This extra thing will prevent us from composing easily.
The first step is to update the function definitions of f and g to accommodate the string appending.
The updated versions are called f' and g' respectively, and they produce strings as well as real numbers as output. In other words, we need f' and g' to be of type ℝ → ℝ × String.
We'll call these debuggable functions: functions that take a normal input but return a pair, the result and a string.
This is just a name we made up; it's not standard terminology.
Returning the extra output is the only way for a pure Lean function to have this kind of effect. In a pure language, the only way to produce extra output is to include it in the return type.
def f' (x : ℝ) : ℝ × String := (x + 1.0, "f was called.")
def g' (x : ℝ) : ℝ × String := (x * 2.0, "g was called.")

  x                 x
  |                 |
+---+             +---+
| f | normal      | f'| debuggable
+---+             +---+
  |                 |  \
 f(x)              f(x) "f was called."

Now we'd like to compose f' and g': apply g', then f', concatenating the debug strings. But we can't just write f' (g' x), because g' returns ℝ × String, not ℝ. We have to unwrap manually.
Composition instead requires a block of code that takes the output of g', peels off the actual value, feeds it to f', and sends the two strings to a concatenator:
def compose (x : ℝ) : ℝ × String :=
  let (y, s) := g' x
  let (z, t) := f' y
  (z, String.append s t)
This looks nothing like composition. Here's the plumbing drawn out:
    x
    |
  +---+
  | g'|
  +---+
    |   \
  +---+  |  "g was called."
  | f'|  |
  +---+  |
    |  \ |
    |   \|
    |  +---------------+
    |  | String.append |
    |  +---------------+
    |          |
  f(g x)   "g was called.f was called."
The result of g' is split: the real goes into f', and the two strings are concatenated with String.append.
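To see the plumbing actually run, here is a minimal sketch under one assumption: we swap ℝ for Float so that #eval can execute it. The names fF, gF and composeF are stand-ins made up for this sketch; they mirror f', g' and compose.

```lean
namespace ManualSketch

-- Float stand-ins for f' and g' (assumption: Float instead of ℝ so #eval works).
def fF (x : Float) : Float × String := (x + 1.0, "f was called.")
def gF (x : Float) : Float × String := (x * 2.0, "g was called.")

-- Same plumbing as `compose`: run gF, unwrap, run fF, concatenate the logs.
def composeF (x : Float) : Float × String :=
  let (y, s) := gF x
  let (z, t) := fF y
  (z, s ++ t)

-- gF 5.0 gives the value 10, then fF gives 11;
-- the log is "g was called.f was called."
#eval composeF 5.0

end ManualSketch
```

Note that the value is 11 here, not the 12 from the pure example, because this composition applies g first and f second.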
Now imagine we add a third debuggable function and want to compose all three:
def h' (x : ℝ) : ℝ × String := (x - 3.0, "h was called.")
def compose3 (x : ℝ) : ℝ × String :=
  let (y, s1) := g' x
  let (z, s2) := f' y
  let (w, s3) := h' z
  (w, String.append (String.append s1 s2) s3)
Every time we add another function, we add another let to unwrap and another String.append to glue the logs. Four functions:
def k' (x : ℝ) : ℝ × String := (x / 2.0, "k was called.")
def compose4 (x : ℝ) : ℝ × String :=
  let (y, s1) := g' x
  let (z, s2) := f' y
  let (w, s3) := h' z
  let (v, s4) := k' w
  (v, String.append (String.append (String.append s1 s2) s3) s4)
Notice that without the debugging strings, composing four functions (writing h and k for the plain versions of h' and k') is just:
def compose4_pure (x : ℝ) : ℝ := k (h (f (g x)))
Or equivalently k ∘ h ∘ f ∘ g.
The side effect of outputting strings is what breaks this natural composition: we can no longer just plug one function into the next because the types don't line up (ℝ × String vs ℝ).
In an imperative language, as we showed above with Python, this isn't a problem — side effects happen invisibly and composition still works.
But note that we have a pattern that is always the same:
- unwrap the pair,
- feed the value to the next function,
- accumulate the strings.
We're copying and pasting the same plumbing every time. So we do the most obvious thing possible: we factor that bit out and call it bind.
What bind does is bring back composition.
It's a function that handles the plumbing so we can chain the functions as easily as we chain pure ones.
bind takes two things:
- a pair of the result and the debug string, i.e. the new data type that encapsulates our side effect.
- the next function we wish to apply.
And it does the unwrap-apply-concat step for you:
def bind (pair : ℝ × String) (next : ℝ → ℝ × String) : ℝ × String :=
  let (value, log1) := pair          -- 1. unwrap the pair
  let (result, log2) := next value   -- 2. feed value to next function
  (result, String.append log1 log2)  -- 3. glue the logs together
So instead of all that manual plumbing:
def compose4 (x : ℝ) : ℝ × String :=
  let (y, s1) := g' x
  let (z, s2) := f' y
  let (w, s3) := h' z
  let (v, s4) := k' w
  (v, String.append (String.append (String.append s1 s2) s3) s4)
We write:
def compose4 (x : ℝ) : ℝ × String :=
  bind (bind (bind (g' x) f') h') k'
Intuitively, bind plays the role of a middle-man router: it takes apart the output of one function and hands the next function exactly the input it needs.
So, at the cost of the ugly syntax above (which we will also clean up), we essentially get composition-like syntax back.
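As a sanity check that bind really handles any number of steps, here is a small sketch. It assumes Float in place of ℝ (so that #eval runs), and the names bindDbg, chain, gF, fF, hF and kF are made up for illustration. chain folds bind over a list of debuggable functions, starting from the input paired with an empty log.

```lean
namespace BindSketch

-- Same unwrap / apply / concatenate step as `bind`, on Float (an assumption).
def bindDbg (pair : Float × String) (next : Float → Float × String) : Float × String :=
  let (value, log1) := pair
  let (result, log2) := next value
  (result, log1 ++ log2)

def gF (x : Float) : Float × String := (x * 2.0, "g was called.")
def fF (x : Float) : Float × String := (x + 1.0, "f was called.")
def hF (x : Float) : Float × String := (x - 3.0, "h was called.")
def kF (x : Float) : Float × String := (x / 2.0, "k was called.")

-- Chain any number of steps: start from (x, "") and bind each step in turn.
def chain (steps : List (Float → Float × String)) (x : Float) : Float × String :=
  steps.foldl bindDbg (x, "")

-- 5 → 10 → 11 → 8 → 4, with the log
-- "g was called.f was called.h was called.k was called."
#eval chain [gF, fF, hF, kF] 5.0

end BindSketch
```

Adding a fifth function now means adding one list element, not another let and another String.append.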
Lift and Unit
What we have shown is that we can compose "debuggable" functions using bind.
Write this composition as f' * g', defined by (f' * g') x = bind (g' x) f'.
Even though the output of g' is incompatible with the input of f', we have a nice way to chain them.
We use * instead of ∘ because this is a different operation from ordinary composition. f' ∘ g' is a type error — g' returns ℝ × String but f' expects ℝ. The * composition goes through bind to handle the plumbing.
bind lets us compose things already in the debuggable world. But how do we enter the debuggable world from outside? If we have a plain value 5 or a plain function f : ℝ → ℝ, how do we use them with bind?
This suggests a question: is there an identity for this composition?
Ordinary composition has id (where def id (x : α) : α := x), satisfying f ∘ id = f and id ∘ f = f. So we're looking for a debuggable function unit such that unit * f' = f' * unit = f'. It should do nothing — return the value unchanged and produce the empty string:
def unit (x : ℝ) : ℝ × String := (x, "")
unit also lets us lift any ordinary function into a debuggable one. If f : ℝ → ℝ is a plain function, then lift f = unit ∘ f — apply f first, then wrap the result with unit:
def lift (f : ℝ → ℝ) (x : ℝ) : ℝ × String := unit (f x)
The lifted version does the same as the original and produces the empty string as its side effect.
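The identity claims can be checked by evaluation. This is a sketch, not a proof: it assumes Float instead of ℝ, and bindDbg, unitDbg, liftDbg, starDbg and fF are hypothetical names standing in for bind, unit, lift, the * composition and f'.

```lean
namespace UnitSketch

def bindDbg (pair : Float × String) (next : Float → Float × String) : Float × String :=
  let (value, log1) := pair
  let (result, log2) := next value
  (result, log1 ++ log2)

def unitDbg (x : Float) : Float × String := (x, "")

def liftDbg (f : Float → Float) (x : Float) : Float × String := unitDbg (f x)

-- The `*` composition from the text: run g first, then route through bind into f.
def starDbg (f g : Float → Float × String) : Float → Float × String :=
  fun x => bindDbg (g x) f

def fF (x : Float) : Float × String := (x + 1.0, "f was called.")

-- All three produce the same pair: value 6 and log "f was called."
#eval fF 5.0
#eval starDbg unitDbg fF 5.0
#eval starDbg fF unitDbg 5.0

-- A lifted pure function contributes an empty log:
-- value 11, log "f was called."
#eval starDbg fF (liftDbg (fun x => x * 2.0)) 5.0

end UnitSketch
```

The identity works because the empty string is the identity for string concatenation, so unit adds nothing to the log on either side.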
Believe it or not, we wrote a monad. But which part was the monad?
Before I point to the above code and say, "Hey this bit is the monad", we will work through two other examples. The hope is that by working through different use cases, you will see a common pattern. That common pattern is what we will call a monad.
The Container: Multivalued Functions
Now a completely different problem. Consider square roots and cube roots of complex numbers. Over the complex numbers, every nonzero number has two square roots and three cube roots. For example, $-1$ has two square roots: $i$ and $-i$. So these functions naturally return lists:
def sqrt' (x : ℂ) : List ℂ := sorry
def cbrt' (x : ℂ) : List ℂ := sorry
We'll call these multivalued functions — functions of type ℂ → List ℂ.
Now suppose we want to find all six sixth roots of a number. A sixth root is a square root of a cube root, or equivalently a cube root of a square root. With ordinary functions we'd just compose: sixthroot x = sqrt (cbrt x), or in mathematical notation $x^{1/6} = f(g(x))$ where $f(x) = x^{1/2}$ and $g(x) = x^{1/3}$.
But sqrt' and cbrt' can't be composed with ∘ — cbrt' returns List ℂ, not ℂ. We need to apply the next function to every element of the list, then flatten the results into one list.
We want to write cbrt' once but have it applied to both sqrt' values. Sound familiar? We need another bind — a function that handles the plumbing for multivalued composition:
def bindList (x : List ℂ) (f : ℂ → List ℂ) : List ℂ :=
  x.flatMap f
That's it. bind for lists is just flatMap: apply the function to each element and concatenate all the results. Now composing is easy:
def sixthRoots (x : ℂ) : List ℂ :=
  bindList (sqrt' x) cbrt'
This says: first compute the square roots of x; bind then passes each of them individually to cbrt' and concatenates the answers.
Route inputs, append outputs, again!
And the identity? How do we lift a plain input into this context-heavy data type, i.e. a list?
def unitList (x : ℂ) : List ℂ := [x]
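Since sqrt' and cbrt' are left as sorry, here is a runnable sketch of the same routing on integers. The function neighbours is a made-up multivalued function, and bindInts and unitInts are illustrative names. It assumes List.flatMap is available, as in the definition of bindList above (older toolchains called it List.bind).

```lean
namespace ListSketch

-- Hypothetical multivalued function: both integer neighbours of n.
def neighbours (n : Int) : List Int := [n - 1, n + 1]

-- bind for lists: apply f to every element, then flatten.
def bindInts (xs : List Int) (f : Int → List Int) : List Int :=
  xs.flatMap f

-- unit for lists: the one-element list.
def unitInts (x : Int) : List Int := [x]

-- Neighbours of neighbours of 0.
#eval bindInts (neighbours 0) neighbours   -- [-2, 0, 0, 2]

-- unit acts as an identity on either side.
#eval bindInts (unitInts 5) neighbours     -- [4, 6]
#eval bindInts (neighbours 5) unitInts     -- [4, 6]

end ListSketch
```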
You've defined your second monad. Two entirely different problems, same structure:
Data Type: α × String
bind:
def bind (pair : ℝ × String)
    (next : ℝ → ℝ × String) : ℝ × String :=
  let (value, log1) := pair
  let (result, log2) := next value
  (result, String.append log1 log2)
unit:
def unit (x : ℝ) : ℝ × String := (x, "")
In words: unwrap the pair, apply next, concat strings.

Data Type: List α
bind:
def bindList (x : List ℂ)
    (next : ℂ → List ℂ) : List ℂ :=
  x.flatMap next
unit:
def unitList (x : ℂ) : List ℂ := [x]
In words: apply next to each element, flatten the lists.
A More Complex Side Effect: State
A third problem. Suppose we want to generate random numbers in a pure language. A random number generator needs a seed — you pass one in, get a value and a new seed out:
-- RandomGen.next : StdGen → Nat × StdGen
Note the similarity to the debuggable case: we're returning extra data alongside the result. But this time we're passing in extra data too. The seed goes in and comes out. A function that adds a random digit to a number looks like this:
def addDigit (n : Nat) (seed : StdGen) : Nat × StdGen :=
  let (rand, seed') := RandomGen.next seed
  (n + rand % 10, seed')
And a plain function that multiplies by 10:
def shift (n : Nat) : Nat := n * 10
We want to build a random 2-digit number:
- add a digit,
- shift by 10,
- add another digit.
Without the seed (our side effect) this would just be addDigit (shift (addDigit 0)).
But shift takes no seed, while the second addDigit needs one.
So shift will have to be lifted into the seed-passing context, and it must hand the seed on unchanged to the second addDigit.
If we had not thought of routing/binding and lifting, we'd write the following block of code, which does the routing and lifting by hand.
def twoDigitManual (seed : StdGen) : Nat × StdGen :=
  let (n1, seed1) := addDigit 0 seed
  let n2 := shift n1                    -- shift needs no seed, so seed1 is set aside (like unwrapping f' and g' above)
  let (n3, seed2) := addDigit n2 seed1  -- pass seed1 along
  (n3, seed2)
Same pattern again: unwrap, apply, thread the extra stuff. So we write bind:
bindState takes as input a function rng that takes a seed and outputs a value and an updated seed.
It also takes a next function, through which the value produced by rng should go; next in turn produces some usable output and an updated seed.
The types of the actual values are α and β.
def bindState (rng : StdGen → α × StdGen) (next : α → StdGen → β × StdGen)
    : StdGen → β × StdGen :=
  fun seed =>
    let (a, seed') := rng seed   -- run first computation, get value and new seed
    next a seed'                 -- pass both to the next computation
The seed from the first computation is fed into the second. And unit:
def unitState (x : α) : StdGen → α × StdGen :=
  fun seed => (x, seed)
Returns the value, passes the seed through unchanged. We can lift plain functions too:
def liftState (f : α → β) (x : α) : StdGen → β × StdGen :=
  unitState (f x)
Now composing without manual plumbing:
def twoDigit : StdGen → Nat × StdGen :=
  bindState (addDigit 0) (fun n1 =>
    bindState (liftState shift n1) (fun n2 =>
      addDigit n2))
Let's parse this. The outermost bindState takes:
- addDigit 0 is a function StdGen → Nat × StdGen (the first random digit computation, starting from 0)
- fun n1 => ... is a function that takes the result n1 : Nat and returns another StdGen → Nat × StdGen
Inside that, the inner bindState takes:
- liftState shift n1 lifts the pure function shift into seed-land, giving StdGen → Nat × StdGen (multiplies n1 by 10, passes the seed through unchanged)
- fun n2 => addDigit n2 takes the shifted result n2 and adds a second random digit
If the lambdas are confusing, here's the same thing with named functions:
-- Step 1: add a random digit to 0
def step1 : StdGen → Nat × StdGen := addDigit 0
-- Step 2: given n1 from step 1, shift it and add another digit
def step2 (n1 : Nat) : StdGen → Nat × StdGen :=
  bindState (liftState shift n1) (fun n2 => addDigit n2)
-- Compose step1 and step2
def twoDigit' : StdGen → Nat × StdGen :=
  bindState step1 step2
bindState step1 step2 says: "run step1, feed its result into step2, thread the seed through." Same as the lambda version, just easier to read.
Same pattern, third time.
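Random output is hard to check by eye, so here is a deterministic sketch of the same threading. It makes a loud assumption: the seed is replaced by a plain counter (Seed := Nat), and nextSeed is a made-up stand-in for RandomGen.next that returns the counter and then increments it. Everything else mirrors addDigit, shift, bindState, unitState, liftState and twoDigit.

```lean
namespace CounterSketch

-- Hypothetical stand-in for StdGen: the "seed" is just a counter.
abbrev Seed := Nat

-- Stand-in for RandomGen.next: return the current counter, then advance it.
def nextSeed (s : Seed) : Nat × Seed := (s, s + 1)

def addDigit (n : Nat) (seed : Seed) : Nat × Seed :=
  let (r, seed') := nextSeed seed
  (n + r % 10, seed')

def shift (n : Nat) : Nat := n * 10

def bindState {α β : Type} (m : Seed → α × Seed)
    (next : α → Seed → β × Seed) : Seed → β × Seed :=
  fun seed =>
    let (a, seed') := m seed   -- run the first computation
    next a seed'               -- thread the new seed into the next one

def unitState {α : Type} (x : α) : Seed → α × Seed :=
  fun seed => (x, seed)

def liftState {α β : Type} (f : α → β) (x : α) : Seed → β × Seed :=
  unitState (f x)

def twoDigit : Seed → Nat × Seed :=
  bindState (addDigit 0) (fun n1 =>
    bindState (liftState shift n1) (fun n2 =>
      addDigit n2))

-- Starting at counter 3: first digit 3, shifted to 30, second digit 4.
#eval twoDigit 3   -- (34, 5)

end CounterSketch
```

The final counter 5 shows that the seed was advanced exactly twice, once per addDigit, and that the lifted shift passed it through untouched.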
Monads
It's now time to step back and discern the common structure. In all three cases we had:
- A type that wraps values with some extra context, call it m α.
- A function α → m β that we couldn't compose directly because the types didn't line up (like f' and g' in example 1, sqrt' and cbrt' in example 2, addDigit in example 3).
- A bind : m α → (α → m β) → m β that restored composition (bind, bindList, bindState) by routing things properly (I guess they prefer to say "by binding the right inputs to the right functions"; I don't know who came up with the name).
- A unit : α → m α that acted as the identity (unit, unitList, unitState), aka the lifter.

"Extra context" means different things in each case. For α × String, the context is a log string. For StdGen → α × StdGen, it's a seed threaded through. For List α, it's not really extra data attached to the value; it's the fact that there are multiple possible values. Piponi notes: "monads let you do more than handle side-effects, in particular many types of container object can be viewed as monads. Some of the introductions to monads find it hard to reconcile these two different uses."
-- Debuggable: m α = α × String
-- Multivalued: m α = List α
-- Randomised: m α = StdGen → α × StdGen
In each case we were faced with the same problem: we had a function α → m β but needed to apply it to something of type m α instead of α. In each case we solved it by writing bind and unit.
The triple (m, unit, bind) is a monad — where m is the context/multiplicity, unit is the identity (how you enter the monadic world doing nothing), and bind is the composer (how you chain operations inside it). lift f = unit ∘ f is derived from unit and lets you upgrade any ordinary function into a monadic one.
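One way to make "the triple (m, unit, bind)" concrete is to package it as a record. The sketch below is only an illustration: MonadTriple, Debuggable, debuggable and multivalued are names invented here, and this is not how Lean's own Monad class is declared. It shows two of our examples as values of the same record type, with lift derived once from unit.

```lean
namespace TripleSketch

-- A monad packaged as data: `unit` and `bind` for a type former m.
structure MonadTriple (m : Type → Type) where
  unit : {α : Type} → α → m α
  bind : {α β : Type} → m α → (α → m β) → m β

-- lift f = unit ∘ f, derived once for every triple.
def MonadTriple.lift {m : Type → Type} {α β : Type}
    (M : MonadTriple m) (f : α → β) : α → m β :=
  fun x => M.unit (f x)

-- m α = α × String
abbrev Debuggable (α : Type) : Type := α × String

def debuggable : MonadTriple Debuggable where
  unit := fun x => (x, "")
  bind := fun pair next =>
    let result := next pair.1
    (result.1, pair.2 ++ result.2)

-- m α = List α
def multivalued : MonadTriple List where
  unit := fun x => [x]
  bind := fun xs next => xs.flatMap next

#eval multivalued.bind ([1, 2, 3] : List Nat)
  (multivalued.lift (fun (n : Nat) => n * 10))                    -- [10, 20, 30]
#eval (debuggable.lift (fun (n : Nat) => n + 1) 41 : Debuggable Nat)  -- (42, "")

end TripleSketch
```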
A monadic function is any function of type α → m β. Before we had the word, we were calling them "debuggable functions" (ℝ → ℝ × String), "multivalued functions" (ℂ → List ℂ), and "randomised functions" (α → StdGen → β × StdGen). They are all the same thing — functions whose return type is wrapped in a monadic context m.
In Lean, unit is called pure and bind is written >>= (as a binary operator, we will see this more).
The operator >>= is referred to as "bind" in Lean. I don't know why it's called that.
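To connect this to Lean's built-in names, here is a hedged sketch that registers the debuggable type as an instance of Lean's Monad class, supplying only pure and bind. The structure Logged and the functions fM and gM are made up for this example, and Float is again assumed so that #eval runs.

```lean
namespace InstanceSketch

-- The debuggable type as a structure: a value plus a log.
structure Logged (α : Type) where
  value : α
  log : String

-- `pure` plays the role of unit, `bind` is our bind.
instance : Monad Logged where
  pure x := { value := x, log := "" }
  bind m next :=
    let r := next m.value
    { value := r.value, log := m.log ++ r.log }

def fM (x : Float) : Logged Float := { value := x + 1.0, log := "f was called." }
def gM (x : Float) : Logged Float := { value := x * 2.0, log := "g was called." }

-- `gM x >>= fM` is `bind (gM x) fM`.
def viaBind (x : Float) : Logged Float := gM x >>= fM

-- Entering through `pure` first changes nothing.
def viaPure (x : Float) : Logged Float := (pure x : Logged Float) >>= gM >>= fM

#eval (viaBind 5.0).log     -- "g was called.f was called."
#eval (viaPure 5.0).log     -- "g was called.f was called."
#eval (viaBind 5.0).value   -- the Float 11

end InstanceSketch
```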