diff --git a/Cslib/Foundations/Control/Monad/Free/Effects.lean b/Cslib/Foundations/Control/Monad/Free/Effects.lean index e1675c01a..e5c163221 100644 --- a/Cslib/Foundations/Control/Monad/Free/Effects.lean +++ b/Cslib/Foundations/Control/Monad/Free/Effects.lean @@ -100,9 +100,8 @@ The canonical interpreter `toStateM` derived from `liftM` agrees with the hand-w recursive interpreter `run` for `FreeState`. -/ @[simp] -theorem run_toStateM {α : Type u} (comp : FreeState σ α) : - (toStateM comp).run = run comp := by - ext s₀ : 1 +theorem run_toStateM {α : Type u} (comp : FreeState σ α) (s₀ : σ) : + (toStateM comp).run s₀ = pure (run comp s₀) := by induction comp generalizing s₀ with | pure a => rfl | liftBind op cont ih => @@ -123,11 +122,10 @@ lemma run_set (s' : σ) (k : PUnit → FreeState σ α) (s₀ : σ) : /-- Run a state computation, returning only the result. -/ def run' (c : FreeState σ α) (s₀ : σ) : α := (run c s₀).1 -@[simp] -theorem run'_toStateM {α : Type u} (comp : FreeState σ α) : - (toStateM comp).run' = run' comp := by - ext s₀ : 1 - rw [run', ← run_toStateM] +-- not `simp` since `StateT.run'` is unfolded by `simp` +theorem run'_toStateM {α : Type u} (comp : FreeState σ α) (s₀ : σ) : + (toStateM comp).run' s₀ = pure (run' comp s₀) := by + rw [run', StateT.run'_eq, run_toStateM] rfl @[simp] @@ -204,18 +202,40 @@ lemma run_pure [Monoid ω] (a : α) : lemma run_liftBind_tell [Monoid ω] (w : ω) (k : PUnit → FreeWriter ω α) : run (liftBind (.tell w) k) = (let (a, w') := run (k .unit); (a, w * w')) := rfl + +-- https://github.com/leanprover-community/mathlib4/pull/36497 +section missing_from_mathlib + +@[simp] +theorem _root_.WriterT.run_pure [Monoid ω] [Monad M] (a : α) : + WriterT.run (pure a : WriterT ω M α) = pure (a, 1) := rfl + +@[simp] +theorem _root_.WriterT.run_bind [Monoid ω] [Monad M] (x : WriterT ω M α) (f : α → WriterT ω M β) : + WriterT.run (x >>= f) = x.run >>= fun (a, w₁) => (fun (b, w₂) => (b, w₁ * w₂)) <$> (f a).run := + rfl + +@[simp] +theorem _root_.WriterT.run_tell [Monad M] (w : ω) : + WriterT.run (MonadWriter.tell w : WriterT ω M PUnit) = pure (.unit, w) := rfl + +end missing_from_mathlib + /-- The canonical interpreter `toWriterT` derived from `liftM` agrees with the hand-written recursive interpreter `run` for `FreeWriter`. -/ @[simp] -theorem run_toWriterT {α : Type u} [Monoid ω] : - ∀ comp : FreeWriter ω α, (toWriterT comp).run = run comp - | .pure _ => by simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run] - | liftBind (.tell w) cont => by - simp only [toWriterT, liftM_liftBind, run_liftBind_tell] at * - rw [← run_toWriterT] - congr +theorem run_toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) : + (toWriterT comp).run = pure (run comp) := by + ext : 1 + induction comp with + | pure _ => simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run] + | liftBind op cont ih => + cases op + simp only [toWriterT, liftM_liftBind, run_liftBind_tell, Id.run_pure] at * + rw [ ← ih] + simp [WriterT.run_bind, writerInterp] /-- `listen` captures the log produced by a subcomputation incrementally. It traverses the computation, @@ -301,9 +321,8 @@ The canonical interpreter `toContT` derived from `liftM` agrees with the hand-wr recursive interpreter `run` for `FreeCont`. -/ @[simp] -theorem run_toContT {α : Type u} (comp : FreeCont r α) : - (toContT comp).run = run comp := by - ext k +theorem run_toContT {α : Type u} (comp : FreeCont r α) (k : α → r) : + (toContT comp).run k = pure (run comp k) := by induction comp with | pure a => rfl | liftBind op cont ih => @@ -387,7 +406,7 @@ The canonical interpreter `toReaderM` derived from `liftM` agrees with the hand- recursive interpreter `run` for `FreeReader` -/ @[simp] theorem run_toReaderM {α : Type u} (comp : FreeReader σ α) (s : σ) : - (toReaderM comp).run s = run comp s := by + (toReaderM comp).run s = pure (run comp s) := by induction comp generalizing s with | pure a => rfl | liftBind op cont ih =>