From 226e8593bb0131f559cdfcfe201334a42554e5cb Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Wed, 18 Mar 2026 05:26:35 +0000 Subject: [PATCH 1/2] fix: missing `pure`s in the Id monad --- .../Control/Monad/Free/Effects.lean | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/Cslib/Foundations/Control/Monad/Free/Effects.lean b/Cslib/Foundations/Control/Monad/Free/Effects.lean index e1675c01a..8903a4c23 100644 --- a/Cslib/Foundations/Control/Monad/Free/Effects.lean +++ b/Cslib/Foundations/Control/Monad/Free/Effects.lean @@ -100,9 +100,8 @@ The canonical interpreter `toStateM` derived from `liftM` agrees with the hand-w recursive interpreter `run` for `FreeState`. -/ @[simp] -theorem run_toStateM {α : Type u} (comp : FreeState σ α) : - (toStateM comp).run = run comp := by - ext s₀ : 1 +theorem run_toStateM {α : Type u} (comp : FreeState σ α) (s₀ : σ) : + (toStateM comp).run s₀ = pure (run comp s₀) := by induction comp generalizing s₀ with | pure a => rfl | liftBind op cont ih => @@ -124,10 +123,9 @@ lemma run_set (s' : σ) (k : PUnit → FreeState σ α) (s₀ : σ) : def run' (c : FreeState σ α) (s₀ : σ) : α := (run c s₀).1 @[simp] -theorem run'_toStateM {α : Type u} (comp : FreeState σ α) : - (toStateM comp).run' = run' comp := by - ext s₀ : 1 - rw [run', ← run_toStateM] +theorem run'_toStateM {α : Type u} (comp : FreeState σ α) (s₀ : σ) : + (toStateM comp).run' s₀ = pure (run' comp s₀) := by + rw [run', StateT.run'_eq, run_toStateM] rfl @[simp] @@ -204,18 +202,40 @@ lemma run_pure [Monoid ω] (a : α) : lemma run_liftBind_tell [Monoid ω] (w : ω) (k : PUnit → FreeWriter ω α) : run (liftBind (.tell w) k) = (let (a, w') := run (k .unit); (a, w * w')) := rfl + +-- https://github.com/leanprover-community/mathlib4/pull/36497 +section missing_from_mathlib + +@[simp] +theorem _root_.WriterT.run_pure [Monoid ω] [Monad M] (a : α) : + WriterT.run (pure a : WriterT ω M α) = pure (a, 1) := rfl + +@[simp] +theorem _root_.WriterT.run_bind [Monoid ω] [Monad M] (x : WriterT ω M α) (f : α → WriterT ω M β) : + WriterT.run (x >>= f) = x.run >>= fun (a, w₁) => (fun (b, w₂) => (b, w₁ * w₂)) <$> (f a).run := + rfl + +@[simp] +theorem _root_.WriterT.run_tell [Monad M] (w : ω) : + WriterT.run (MonadWriter.tell w : WriterT ω M PUnit) = pure (.unit, w) := rfl + +end missing_from_mathlib + /-- The canonical interpreter `toWriterT` derived from `liftM` agrees with the hand-written recursive interpreter `run` for `FreeWriter`. -/ @[simp] -theorem run_toWriterT {α : Type u} [Monoid ω] : - ∀ comp : FreeWriter ω α, (toWriterT comp).run = run comp - | .pure _ => by simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run] - | liftBind (.tell w) cont => by - simp only [toWriterT, liftM_liftBind, run_liftBind_tell] at * - rw [← run_toWriterT] - congr +theorem run_toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) : + (toWriterT comp).run = pure (run comp) := by + ext : 1 + induction comp with + | pure _ => simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run] + | liftBind op cont ih => + cases op + simp only [toWriterT, liftM_liftBind, run_liftBind_tell, Id.run_pure] at * + rw [ ← ih] + simp [WriterT.run_bind, writerInterp] /-- `listen` captures the log produced by a subcomputation incrementally. It traverses the computation, @@ -301,9 +321,8 @@ The canonical interpreter `toContT` derived from `liftM` agrees with the hand-wr recursive interpreter `run` for `FreeCont`. -/ @[simp] -theorem run_toContT {α : Type u} (comp : FreeCont r α) : - (toContT comp).run = run comp := by - ext k +theorem run_toContT {α : Type u} (comp : FreeCont r α) (k : α → r) : + (toContT comp).run k = pure (run comp k) := by induction comp with | pure a => rfl | liftBind op cont ih => @@ -387,7 +406,7 @@ The canonical interpreter `toReaderM` derived from `liftM` agrees with the hand- recursive interpreter `run` for `FreeReader` -/ @[simp] theorem run_toReaderM {α : Type u} (comp : FreeReader σ α) (s : σ) : - (toReaderM comp).run s = run comp s := by + (toReaderM comp).run s = pure (run comp s) := by induction comp generalizing s with | pure a => rfl | liftBind op cont ih => From 8417ae7a98c7ac66f1b3ddb930a94806c4d79a39 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Wed, 18 Mar 2026 05:37:52 +0000 Subject: [PATCH 2/2] refactor: use simpnormal-forms for standard monadic operations --- .../Control/Monad/Free/Effects.lean | 55 +++++++------------ 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/Cslib/Foundations/Control/Monad/Free/Effects.lean b/Cslib/Foundations/Control/Monad/Free/Effects.lean index 8903a4c23..c5307dc65 100644 --- a/Cslib/Foundations/Control/Monad/Free/Effects.lean +++ b/Cslib/Foundations/Control/Monad/Free/Effects.lean @@ -174,16 +174,6 @@ def toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) : WriterT ω I theorem toWriterT_unique {α : Type u} [Monoid ω] (g : FreeWriter ω α → WriterT ω Id α) (h : Interprets writerInterp g) : g = toWriterT := h.eq -/-- -Writes a log entry. This creates an effectful node in the computation tree. --/ -abbrev tell (w : ω) : FreeWriter ω PUnit := - lift (.tell w) - -@[simp] -lemma tell_def (w : ω) : - tell w = .lift (.tell w) := rfl - /-- Interprets a `FreeWriter` computation by recursively traversing the tree, accumulating log entries with the monoid operation, and returns the final value paired with the accumulated log. @@ -237,17 +227,24 @@ theorem run_toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) : rw [ ← ih] simp [WriterT.run_bind, writerInterp] -/-- -`listen` captures the log produced by a subcomputation incrementally. It traverses the computation, -emitting log entries as encountered, and returns the accumulated log as a result. --/ -def listen [Monoid ω] : FreeWriter ω α → FreeWriter ω (α × ω) +/-- Implementation of `MonadWriter.listen`. -/ +protected def listen [Monoid ω] : FreeWriter ω α → FreeWriter ω (α × ω) | .pure a => .pure (a, 1) | .liftBind (.tell w) k => liftBind (.tell w) fun _ => - listen (k .unit) >>= fun (a, w') => + FreeWriter.listen (k .unit) >>= fun (a, w') => pure (a, w * w') +/-- Implementation of `MonadWriter.pass`. -/ +protected def pass [Monoid ω] (m : FreeWriter ω (α × (ω → ω))) : FreeWriter ω α := + let ((a, f), w) := run m + liftBind (.tell (f w)) (fun _ => .pure a) + +instance [Monoid ω] : MonadWriter ω (FreeWriter ω) where + tell w := lift (.tell w) + listen := FreeWriter.listen + pass := FreeWriter.pass + @[simp] lemma listen_pure [Monoid ω] (a : α) : listen (.pure a : FreeWriter ω α) = .pure (a, 1) := rfl @@ -261,24 +258,14 @@ lemma listen_liftBind_tell [Monoid ω] (w : ω) pure (a, w * w')) := by rfl -/-- -`pass` allows a subcomputation to modify its own log. After traversing the computation and -accumulating its log, the resulting function is applied to rewrite the accumulated log -before re-emission. --/ -def pass [Monoid ω] (m : FreeWriter ω (α × (ω → ω))) : FreeWriter ω α := - let ((a, f), w) := run m - liftBind (.tell (f w)) (fun _ => .pure a) +@[simp] +lemma tell_def [Monoid ω] (w : ω) : + (tell w : FreeWriter ω _) = .lift (.tell w) := rfl @[simp] lemma pass_def [Monoid ω] (m : FreeWriter ω (α × (ω → ω))) : pass m = let ((a, f), w) := run m; liftBind (.tell (f w)) fun _ => .pure a := rfl -instance [Monoid ω] : MonadWriter ω (FreeWriter ω) where - tell := tell - listen := listen - pass := pass - end FreeWriter /-! ### Continuation Monad via `FreeM` -/ @@ -341,7 +328,7 @@ lemma run_liftBind_callCC (g : (α → r) → r) (cont : α → FreeCont r β) (k : β → r) : run (liftBind (.callCC g) cont) k = g (fun a => run (cont a) k) := rfl -/-- Call with current continuation for the Free continuation monad. -/ +/-- Universe-generic version of `MonadCont.callCC` -/ def callCC (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) : FreeCont r α := liftBind (.callCC fun k => run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k) pure @@ -352,15 +339,15 @@ lemma callCC_def (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) : liftBind (.callCC fun k => run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k) pure := rfl -instance : MonadCont (FreeCont r) where - callCC := .callCC - /-- `run` of a `callCC` node simplifies to running the handler with the current continuation. -/ @[simp] lemma run_callCC (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) (k : α → r) : run (callCC f) k = run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k := by simp [callCC, run_liftBind_callCC] +instance : MonadCont (FreeCont r) where + callCC := .callCC + end FreeCont /-- Type constructor for reader operations. -/ @@ -388,7 +375,7 @@ def readInterp {α : Type u} : ReaderF σ α → ReaderM σ α /-- Convert a `FreeReader` computation into a `ReaderM` computation. This is the canonical interpreter derived from `liftM`. -/ -def toReaderM {α : Type u} (comp : FreeReader σ α) : ReaderM σ α := +abbrev toReaderM {α : Type u} (comp : FreeReader σ α) : ReaderM σ α := comp.liftM readInterp /-- `toReaderM` is the unique interpreter extending `readInterp`. -/