From 2f9dfce1284f2b8e6b7d6c33c2556761393d2640 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 21 May 2026 00:48:00 -0700 Subject: [PATCH] laurel: wire old() to Core two-state semantics Modifies-clause frame conditions now flow through Core's `old`-prefixed identifiers via Core's native two-state semantics, replacing the previous synthetic `$heap_in` parameter. 1. HeapParameterization: heap-writing procedures take `$heap` as a true inout parameter (same name in inputs and outputs) rather than a `$heap_in` / `$heap` pair with a synthesized assignment prelude. 2. ModifiesClauses: frame condition references `old($heap)` via `StmtExpr.Old` instead of a separate `$heap_in` variable. 3. New PushOldInward Laurel-to-Laurel pass: distributes `StmtExpr.Old` through its sub-expressions until each `Old` immediately wraps a Local Var. Warns if `old(...)` does not mention any inout parameter. 4. LaurelToCoreTranslator: `Old (Var (Local n))` translates directly to `fvar (mkOld n)`. Inout parameters at call sites are detected and emit `.inoutArg` rather than paired `.inArg` + `.outArg`. Call-arg construction is shared between the two StaticCall sites via a new buildCallArgs helper. Tests: PushOldInwardTest exercises the new pass on hand-built ASTs covering distribution over operators, calls, conditionals, fields, quantifiers, and the warning behaviour. Existing modifies-clause coverage in T2_ModifiesClauses already drives the synthetic `Old` path end-to-end. --- .../Laurel/HeapParameterization.lean | 24 +- .../Laurel/LaurelCompilationPipeline.lean | 5 + .../Laurel/LaurelToCoreTranslator.lean | 72 ++- Strata/Languages/Laurel/ModifiesClauses.lean | 23 +- Strata/Languages/Laurel/PushOldInward.lean | 137 ++++++ .../Languages/Laurel/PushOldInwardTest.lean | 454 ++++++++++++++++++ 6 files changed, 670 insertions(+), 45 deletions(-) create mode 100644 Strata/Languages/Laurel/PushOldInward.lean create mode 100644 StrataTest/Languages/Laurel/PushOldInwardTest.lean diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index fecaf5350c..7cde1383fb 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -455,37 +455,31 @@ where def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : TransformM Procedure := do let heapName : Identifier := "$heap" - let heapInName : Identifier := "$heap_in" let readsHeap := (← get).heapReaders.contains proc.name let writesHeap := (← get).heapWriters.contains proc.name if writesHeap then - -- This procedure writes the heap - add $heap_in as input and $heap as output - -- At the start, assign $heap_in to $heap, then use $heap throughout - let heapInParam : Parameter := { name := heapInName, type := ⟨.THeap, none⟩ } - let heapOutParam : Parameter := { name := heapName, type := ⟨.THeap, none⟩ } + -- This procedure writes the heap — $heap appears in both inputs and outputs + -- (true inout). Core's two-state semantics provide `old $heap` automatically. + let heapParam : Parameter := { name := heapName, type := ⟨.THeap, none⟩ } - let inputs' := heapInParam :: proc.inputs - let outputs' := heapOutParam :: proc.outputs + let inputs' := heapParam :: proc.inputs + let outputs' := heapParam :: proc.outputs - -- Preconditions use $heap_in (the input state) - let preconditions' ← proc.preconditions.mapM (·.mapM (heapTransformExpr heapInName model)) + -- Preconditions reference $heap (evaluated at entry before any mutation) + let preconditions' ← proc.preconditions.mapM (·.mapM (heapTransformExpr heapName model)) let bodyValueIsUsed := !proc.outputs.isEmpty let body' ← match proc.body with | .Transparent bodyExpr => - -- First assign $heap_in to $heap, then transform body using $heap - let assignHeap := mkMd (.Assign [mkVarMd (.Local heapName)] (mkMd (.Var (.Local heapInName)))) let bodyExpr' ← heapTransformExpr heapName model bodyExpr bodyValueIsUsed - pure (.Transparent (mkMd (.Block [assignHeap, bodyExpr'] none))) + pure (.Transparent bodyExpr') | .Opaque postconds impl modif => - -- Postconditions use $heap (the output state) let postconds' ← postconds.mapM (·.mapM (heapTransformExpr heapName model)) let impl' ← match impl with | some implExpr => - let assignHeap := mkMd (.Assign [mkVarMd (.Local heapName)] (mkMd (.Var (.Local heapInName)))) let implExpr' ← heapTransformExpr heapName model implExpr bodyValueIsUsed - pure (some (mkMd (.Block [assignHeap, implExpr'] none))) + pure (some implExpr') | none => pure none let modif' ← modif.mapM (heapTransformExpr heapName model ·) pure (.Opaque postconds' impl' modif') diff --git a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean index c6984120fe..fd993105c6 100644 --- a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean +++ b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean @@ -10,6 +10,7 @@ import Strata.Languages.Laurel.DesugarShortCircuit import Strata.Languages.Laurel.EliminateReturnsInExpression import Strata.Languages.Laurel.EliminateValueReturns import Strata.Languages.Laurel.ConstrainedTypeElim +import Strata.Languages.Laurel.PushOldInward import Strata.Languages.Laurel.TypeAliasElim import Strata.Languages.Core.Verifier import Strata.Util.Profile @@ -111,6 +112,10 @@ private def laurelPipeline : Array LaurelPass := #[ run := fun p m => let (p', diags) := modifiesClausesTransform m p (p', diags, {}) }, + { name := "PushOldInward" + run := fun p _m => + let (p', diags) := pushOldInward p + (p', diags, {}) }, { name := "InferHoleTypes" run := fun p m => let (p', diags, stats) := inferHoleTypes m p diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index beb50ad9b0..e54cff86ae 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -300,7 +300,16 @@ def translateExpr (expr : StmtExprMd) | .AsType target _ => throwExprDiagnostic $ diagnosticFromSource expr.source "AsType expression translation" DiagnosticType.NotYetImplemented | .Assigned _ => throwExprDiagnostic $ diagnosticFromSource expr.source "assigned expression translation" DiagnosticType.NotYetImplemented - | .Old value => throwExprDiagnostic $ diagnosticFromSource expr.source "old expression translation" DiagnosticType.NotYetImplemented + | .Old value => + -- After PushOldInward, every `Old` immediately wraps a Local Var of an inout parameter. + match value.val with + | .Var (.Local name) => + let coreTy ← translateType (model.get name).getType + return .fvar () (Core.CoreIdent.mkOld name.text) (some coreTy) + | _ => + throwExprDiagnostic $ diagnosticFromSource expr.source + "old(...) should have been pushed inward to a variable reference" + DiagnosticType.StrataBug | .Fresh _ => throwExprDiagnostic $ diagnosticFromSource expr.source "fresh expression translation" DiagnosticType.NotYetImplemented | .Assert _ => throwExprDiagnostic $ diagnosticFromSource expr.source "assert expression translation" DiagnosticType.NotYetImplemented | .Assume _ => throwExprDiagnostic $ diagnosticFromSource expr.source "assume expression translation" DiagnosticType.NotYetImplemented @@ -351,6 +360,35 @@ def throwStmtDiagnostic (d : DiagnosticModel): TranslateM (List Core.Statement) modify fun s => { s with coreProgramHasSuperfluousErrors := true } return [] +/-- +Look up the callee's signature and convert positional `coreArgs` into Core +`CallArg`s, emitting `.inoutArg ident` for parameters that appear in both +inputs and outputs (true inout) and `.inArg` otherwise. Returns the call args +along with the callee's outputs and inout names so the caller can build the +matching `.outArg` list. +-/ +private def buildCallArgs (calleeId : Identifier) (coreArgs : List Core.Expression.Expr) + : TranslateM (List (Core.CallArg Core.Expression) × List Parameter × List String) := do + let s ← get + let (calleeInputs, calleeOutputs) := match s.model.get calleeId with + | .staticProcedure proc => (proc.inputs, proc.outputs) + | .instanceProcedure _ proc => (proc.inputs, proc.outputs) + | _ => ([], []) + let calleeInputNames := calleeInputs.map (·.name.text) + let calleeOutputNames := calleeOutputs.map (·.name.text) + let calleeInoutNames := calleeInputNames.filter (calleeOutputNames.contains ·) + let inoutInputIndices := calleeInputNames.zipIdx.filterMap fun (name, i) => + if calleeInoutNames.contains name then some i else none + let mut callArgs : List (Core.CallArg Core.Expression) := [] + for (arg, i) in coreArgs.zipIdx do + if inoutInputIndices.contains i then + match arg with + | .fvar _ ident _ => callArgs := callArgs ++ [.inoutArg ident] + | _ => callArgs := callArgs ++ [.inArg arg] + else + callArgs := callArgs ++ [.inArg arg] + return (callArgs, calleeOutputs, calleeInoutNames) + /-- Translate Laurel StmtExpr to Core Statements using the `TranslateM` monad. Diagnostics are emitted into the monad state. @@ -420,12 +458,14 @@ def translateStmt (stmt : StmtExprMd) lhs := lhs ++ [ident] | .Field _ _ => pure () -- already handled above return (inits, lhs) - -- Translate a procedure/instance call: init Declare targets with nondet, then emit call - let translateCallTargets (calleeName : String) (args : List StmtExprMd) : TranslateM (List Core.Statement) := do + -- Translate a procedure/instance call: init Declare targets with nondet, then emit call. + let translateCallTargets (calleeId : Identifier) (args : List StmtExprMd) : TranslateM (List Core.Statement) := do let coreArgs ← args.mapM (fun a => translateExpr a) let (inits, lhs) ← initTargetsNondet - let outArgs : List (Core.CallArg Core.Expression) := lhs.map .outArg - return inits ++ [Core.Statement.call calleeName (coreArgs.map .inArg ++ outArgs) md] + let (callArgs, _, calleeInoutNames) ← buildCallArgs calleeId coreArgs + let outArgs : List (Core.CallArg Core.Expression) := + lhs.filter (fun id => !calleeInoutNames.contains id.name) |>.map .outArg + return inits ++ [Core.Statement.call calleeId.text (callArgs ++ outArgs) md] -- Match on the value to decide how to translate match _hv : value.val with | .StaticCall callee args => @@ -441,9 +481,9 @@ def translateStmt (stmt : StmtExprMd) | _ => throwStmtDiagnostic $ md.toDiagnostic "function call without a single target" DiagnosticType.StrataBug else - translateCallTargets callee.text args + translateCallTargets callee args | .InstanceCall _target callee args => - translateCallTargets callee.text args + translateCallTargets callee args | .Hole _ _ => -- Hole RHS: havoc all targets (unmodeled call side-effect). dispatchTargets @@ -472,22 +512,18 @@ def translateStmt (stmt : StmtExprMd) exprAsUnusedInit stmt md else let coreArgs ← args.mapM (fun a => translateExpr a) - -- Generate throwaway LHS variables for all outputs so Core arity - -- checking passes (lhs.length == outputs.length). - let outputs := match model.get callee with - | .staticProcedure proc => proc.outputs - | .instanceProcedure _ proc => proc.outputs - | _ => [] + let (callArgs, calleeOutputs, calleeInoutNames) ← buildCallArgs callee coreArgs + -- Generate throwaway LHS for output-only params so Core arity checking passes. let mut inits : List Core.Statement := [] - let mut lhs : List Core.CoreIdent := [] - for out in outputs do + let mut outArgs : List (Core.CallArg Core.Expression) := [] + for out in calleeOutputs do + if calleeInoutNames.contains out.name.text then continue let id ← freshId let ident : Core.CoreIdent := ⟨s!"$unused_{id}", ()⟩ let coreType := LTy.forAll [] (← translateType out.type) inits := inits ++ [Core.Statement.init ident coreType .nondet md] - lhs := lhs ++ [ident] - let outArgs : List (Core.CallArg Core.Expression) := lhs.map .outArg - return inits ++ [Core.Statement.call callee.text (coreArgs.map .inArg ++ outArgs) md] + outArgs := outArgs ++ [.outArg ident] + return inits ++ [Core.Statement.call callee.text (callArgs ++ outArgs) md] | .InstanceCall .. => -- Instance method call as statement: no return value, treated as no-op return ([]) diff --git a/Strata/Languages/Laurel/ModifiesClauses.lean b/Strata/Languages/Laurel/ModifiesClauses.lean index 20fd01445d..315d895e57 100644 --- a/Strata/Languages/Laurel/ModifiesClauses.lean +++ b/Strata/Languages/Laurel/ModifiesClauses.lean @@ -18,7 +18,7 @@ and conjoining it with the postcondition. After this pass, the modifies list is cleared since its semantics have been absorbed into the postcondition. This pass should run after heap parameterization, which has already: -- Added explicit heap parameters ($heap_in, $heap) +- Added explicit heap parameter ($heap as inout) - Transformed field accesses to readField/updateField calls - Collected field constants @@ -27,7 +27,7 @@ all field values are preserved between the input and output heaps. Generates: forall $obj: Composite, $fld: Field => - $obj < $heap_in.nextReference && notModified($obj) ==> readField($heap_in, $obj, $fld) == readField($heap, $obj, $fld) + $obj < old($heap).nextReference && notModified($obj) ==> readField(old($heap), $obj, $fld) == readField($heap, $obj, $fld) where `notModified($obj)` is the conjunction of `$obj != e` for each single entry `e`, and `!(select(s, $obj))` for each set entry `s`. @@ -94,20 +94,20 @@ Build the modifies frame condition as a Laurel StmtExpr. Generates a single quantified formula: forall $obj: Composite, $fld: Field => - notModified($obj) && $obj < $heap_in.nextReference ==> readField($heap_in, $obj, $fld) == readField($heap, $obj, $fld) + notModified($obj) && $obj < old($heap).nextReference ==> readField(old($heap), $obj, $fld) == readField($heap, $obj, $fld) Returns `none` if there are no entries. -/ def buildModifiesEnsures (proc: Procedure) (model: SemanticModel) (modifiesExprs : List StmtExprMd) - (heapInName heapOutName : Identifier) : Option StmtExprMd := + (heapName : Identifier) : Option StmtExprMd := let entries := extractModifiesEntries model modifiesExprs let objName : Identifier := "$modifies_obj" let fldName : Identifier := "$modifies_fld" let obj := mkMd <| .Var (.Local objName) let fld := mkMd <| .Var (.Local fldName) - let heapIn := mkMd <| .Var (.Local heapInName) - let heapOut := mkMd <| .Var (.Local heapOutName) - -- Build the "obj is allocated" condition: Composite..ref($obj) < $heap_in.nextReference + let heapIn := mkMd <| .Old (mkMd (.Var (.Local heapName))) + let heapOut := mkMd <| .Var (.Local heapName) + -- Build the "obj is allocated" condition: Composite..ref($obj) < old($heap).nextReference let heapCounter := mkMd <| .StaticCall "Heap..nextReference!" [heapIn] let objRef := mkMd <| .StaticCall "Composite..ref!" [obj] let objAllocated := mkMd <| .PrimitiveOp .Lt [objRef, heapCounter] @@ -115,10 +115,10 @@ def buildModifiesEnsures (proc: Procedure) (model: SemanticModel) (modifiesExprs then objAllocated else -- Build the "not modified" precondition from all entries - -- Combine: $obj < $heap_in.nextReference && notModified($obj) + -- Combine: $obj < old($heap).nextReference && notModified($obj) let notModified := conjoinAll (entries.map (buildNotModifiedForEntry obj)) mkMd <| .PrimitiveOp .And [objAllocated, notModified] - -- Build: readField($heap_in, $obj, $fld) == readField($heap, $obj, $fld) + -- Build: readField(old($heap), $obj, $fld) == readField($heap, $obj, $fld) let readIn := mkMd <| .StaticCall "readField" [heapIn, obj, fld] let readOut := mkMd <| .StaticCall "readField" [heapOut, obj, fld] let heapUnchanged := mkMd <| .PrimitiveOp .Eq [readIn, readOut] @@ -145,7 +145,7 @@ may modify anything on the heap), and the modifies list is simply cleared. If the procedure has a `$heap` but no modifies clause, adds a postcondition that all allocated objects are preserved between heaps: - `forall $obj: Composite, $fld: Field => $obj < $heap_in.nextReference ==> readField($heap_in, $obj, $fld) == readField($heap, $obj, $fld)` + `forall $obj: Composite, $fld: Field => $obj < old($heap).nextReference ==> readField(old($heap), $obj, $fld) == readField($heap, $obj, $fld)` If the modifies clause uses a wildcard (`*`), the frame condition is skipped entirely — the procedure may modify anything. @@ -159,9 +159,8 @@ def transformModifiesClauses (model: SemanticModel) -- modifies * means the procedure can modify anything; no frame condition .ok { proc with body := .Opaque postconds impl [] } else if hasHeapOut proc then - let heapInName : Identifier := "$heap_in" let heapName : Identifier := "$heap" - let frameCondition := buildModifiesEnsures proc model modifiesExprs heapInName heapName + let frameCondition := buildModifiesEnsures proc model modifiesExprs heapName let postconds' := match frameCondition with | some frame => postconds ++ [{ condition := frame, summary := "modifies clause" }] | none => postconds diff --git a/Strata/Languages/Laurel/PushOldInward.lean b/Strata/Languages/Laurel/PushOldInward.lean new file mode 100644 index 0000000000..fc5eeba524 --- /dev/null +++ b/Strata/Languages/Laurel/PushOldInward.lean @@ -0,0 +1,137 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.Languages.Laurel.MapStmtExpr + +/-! +# Push `Old` Inward + +Distribute `StmtExpr.Old` over its sub-expressions until each `Old` immediately +wraps a variable reference. After this pass, the Laurel-to-Core translator only +needs to handle `Old (Var (Local n))`: every other `Old` shape has been pushed +deeper or eliminated. + +If an `Old e` does not contain any inout parameter of the enclosing procedure, +`old(...)` has no effect and we emit a warning. The wrapper is then dropped. +-/ + +namespace Strata +namespace Laurel + +public section + +structure PushOldState where + inoutNames : List String := [] + diagnostics : List DiagnosticModel := [] + +abbrev PushOldM := StateM PushOldState + +private def warn (source : Option FileRange) (msg : String) : PushOldM Unit := do + modify fun s => { s with diagnostics := s.diagnostics ++ [diagnosticFromSource source msg .Warning] } + +/-- Does `expr` reference any variable named in `inoutNames`? -/ +partial def mentionsAnyInout (inoutNames : List String) (expr : StmtExprMd) : Bool := + match expr.val with + | .Var (.Local name) => inoutNames.contains name.text + | .Var (.Field target _) => mentionsAnyInout inoutNames target + | .PrimitiveOp _ args | .StaticCall _ args => + args.any (mentionsAnyInout inoutNames) + | .InstanceCall target _ args => + mentionsAnyInout inoutNames target || args.any (mentionsAnyInout inoutNames) + | .ReferenceEquals l r => mentionsAnyInout inoutNames l || mentionsAnyInout inoutNames r + | .IfThenElse c t e => + mentionsAnyInout inoutNames c || mentionsAnyInout inoutNames t + || (e.elim false (mentionsAnyInout inoutNames)) + | .AsType target _ | .IsType target _ => mentionsAnyInout inoutNames target + | .Quantifier _ _ _ body => mentionsAnyInout inoutNames body + | .Old inner | .Fresh inner | .Assigned inner => mentionsAnyInout inoutNames inner + | _ => false + +/-- Distribute `Old` over the structure of `expr`, stopping once each `Old` + immediately wraps an inout `Var`. Variables that are not inout lose the + surrounding `Old` (no two-state difference applies to them). -/ +partial def pushOld (expr : StmtExprMd) : PushOldM StmtExprMd := do + let source := expr.source + let wrap (v : StmtExpr) : StmtExprMd := ⟨v, source⟩ + match expr.val with + | .Var (.Local name) => + if (← get).inoutNames.contains name.text then + return wrap (.Old expr) + else + return expr + | .Var (.Field target fieldName) => + return wrap (.Var (.Field (← pushOld target) fieldName)) + | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ + | .This | .Abstract | .All | .New _ => + return expr + | .PrimitiveOp op args => + return wrap (.PrimitiveOp op (← args.mapM pushOld)) + | .StaticCall callee args => + return wrap (.StaticCall callee (← args.mapM pushOld)) + | .InstanceCall target callee args => + return wrap (.InstanceCall (← pushOld target) callee (← args.mapM pushOld)) + | .ReferenceEquals lhs rhs => + return wrap (.ReferenceEquals (← pushOld lhs) (← pushOld rhs)) + | .IfThenElse cond th el => + let el' ← el.mapM pushOld + return wrap (.IfThenElse (← pushOld cond) (← pushOld th) el') + | .AsType target ty => return wrap (.AsType (← pushOld target) ty) + | .IsType target ty => return wrap (.IsType (← pushOld target) ty) + | .Quantifier mode param trigger body => + let trigger' ← trigger.mapM pushOld + return wrap (.Quantifier mode param trigger' (← pushOld body)) + | .Old inner => pushOld inner -- old(old(e)) ≡ old(e) + | _ => return expr + +/-- Top-down rewrite: every `Old e` is replaced by the result of distributing + `Old` through `e`. Warns once per user-written `Old` that does not mention + any inout parameter. -/ +partial def pushOldInwardExpr (expr : StmtExprMd) : PushOldM StmtExprMd := do + match expr.val with + | .Old inner => + let inner' ← pushOldInwardExpr inner + if mentionsAnyInout (← get).inoutNames inner' then + pushOld inner' + else + warn expr.source "`old(...)` has no effect: expression contains no inout parameter" + return inner' + | _ => + let rewriteOld (e : StmtExprMd) : PushOldM StmtExprMd := do + match e.val with + | .Old inner => + if mentionsAnyInout (← get).inoutNames inner then + pushOld inner + else + warn e.source "`old(...)` has no effect: expression contains no inout parameter" + return inner + | _ => return e + mapStmtExprM (m := PushOldM) rewriteOld expr + +/-- Inout names of a procedure: parameters that appear in both inputs and outputs. -/ +private def procInoutNames (proc : Procedure) : List String := + proc.inputs.filterMap fun inp => + if proc.outputs.any (·.name == inp.name) then some inp.name.text else none + +/-- Apply `pushOldInward` to every expression in a procedure. -/ +private def transformProcedure (proc : Procedure) : PushOldM Procedure := do + modify fun s => { s with inoutNames := procInoutNames proc } + mapProcedureM pushOldInwardExpr proc + +/-- +Push every `StmtExpr.Old` inward until it immediately wraps a variable. +Returns the rewritten program along with any warnings emitted (e.g. for +`old(...)` over an expression with no inout variable). +-/ +def pushOldInward (program : Program) : Program × List DiagnosticModel := + let initState : PushOldState := {} + let (program', finalState) := + (program.staticProcedures.mapM transformProcedure).run initState + ({ program with staticProcedures := program' }, finalState.diagnostics) + +end -- public section +end Laurel +end Strata diff --git a/StrataTest/Languages/Laurel/PushOldInwardTest.lean b/StrataTest/Languages/Laurel/PushOldInwardTest.lean new file mode 100644 index 0000000000..9de5e712b3 --- /dev/null +++ b/StrataTest/Languages/Laurel/PushOldInwardTest.lean @@ -0,0 +1,454 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +/- +Unit tests for `pushOldInward`. Builds StmtExpr AST nodes directly (since +`old(...)` has no Laurel surface syntax) and asserts the rewritten shape and +emitted diagnostics. +-/ + +import Strata.Languages.Laurel.PushOldInward +import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator + +namespace Strata.Laurel + +private def mk (e : StmtExpr) : StmtExprMd := ⟨e, none⟩ +private def localVar (name : String) : StmtExprMd := mk (.Var (.Local { text := name })) +private def fieldAccess (obj : StmtExprMd) (field : String) : StmtExprMd := + mk (.Var (.Field obj { text := field })) +private def litInt (n : Int) : StmtExprMd := mk (.LiteralInt n) +private def litBool (b : Bool) : StmtExprMd := mk (.LiteralBool b) +private def add (a b : StmtExprMd) : StmtExprMd := mk (.PrimitiveOp .Add [a, b]) +private def eqOp (a b : StmtExprMd) : StmtExprMd := mk (.PrimitiveOp .Eq [a, b]) +private def call (callee : String) (args : List StmtExprMd) : StmtExprMd := + mk (.StaticCall { text := callee } args) +private def old (e : StmtExprMd) : StmtExprMd := mk (.Old e) +private def intTy : AstNode HighType := ⟨ .TInt, none ⟩ + +/-- Run `pushOldInwardExpr` and return `(rewritten, warningCount)`. -/ +private def runPush (inout : List String) (expr : StmtExprMd) : StmtExprMd × Nat := + let (out, st) := (pushOldInwardExpr expr).run { inoutNames := inout } + (out, st.diagnostics.length) + +private def fmt (expr : StmtExprMd) : String := + toString (Std.Format.pretty (Std.ToFormat.format expr)) + +private def report (label : String) (inout : List String) (input : StmtExprMd) : IO Unit := do + let (out, warns) := runPush inout input + IO.println s!"{label}: {fmt out} warnings={warns}" + +/-! ## Leaf cases -/ + +/-- +info: inoutVar: old(h) warnings=0 +-/ +#guard_msgs in +#eval report "inoutVar" ["h"] (old (localVar "h")) + +/-- +info: nonInout: x warnings=1 +-/ +#guard_msgs in +#eval report "nonInout" ["h"] (old (localVar "x")) + +/-- +info: literal: 42 warnings=1 +-/ +#guard_msgs in +#eval report "literal" ["h"] (old (litInt 42)) + +/-- +info: bareVar: h warnings=0 +-/ +#guard_msgs in +#eval report "bareVar" ["h"] (localVar "h") + +/-! ## Distribution over operators -/ + +/-- +info: distribute: x + old(h) warnings=0 +-/ +#guard_msgs in +#eval report "distribute" ["h"] (old (add (localVar "x") (localVar "h"))) + +/-- +info: bothInout: old(a) + old(b) warnings=0 +-/ +#guard_msgs in +#eval report "bothInout" ["a", "b"] (old (add (localVar "a") (localVar "b"))) + +/-- +info: nestedAdd: x + old(h) + old(k) warnings=0 +-/ +#guard_msgs in +#eval report "nestedAdd" ["h", "k"] + (old (add (add (localVar "x") (localVar "h")) (localVar "k"))) + +/-- +info: eq: old(h) == y warnings=0 +-/ +#guard_msgs in +#eval report "eq" ["h"] (old (eqOp (localVar "h") (localVar "y"))) + +/-! ## Calls -/ + +/-- +info: staticCall: f(old(h), x) warnings=0 +-/ +#guard_msgs in +#eval report "staticCall" ["h"] (old (call "f" [localVar "h", localVar "x"])) + +/-- +info: callNoInout: f(x, y) warnings=1 +-/ +#guard_msgs in +#eval report "callNoInout" ["h"] (old (call "f" [localVar "x", localVar "y"])) + +/-! ## Conditionals -/ + +/-- +info: iteWithInout: if b then old(h) else x warnings=0 +-/ +#guard_msgs in +#eval report "iteWithInout" ["h"] + (old (mk (.IfThenElse (localVar "b") (localVar "h") (some (localVar "x"))))) + +/-- +info: iteNoElse: if b then old(h) warnings=0 +-/ +#guard_msgs in +#eval report "iteNoElse" ["h"] + (old (mk (.IfThenElse (localVar "b") (localVar "h") none))) + +/-! ## Idempotence and nesting -/ + +/-- +info: nested: old(h) warnings=0 +-/ +#guard_msgs in +#eval report "nested" ["h"] (old (old (localVar "h"))) + +/-- +info: tripleNested: old(h) warnings=0 +-/ +#guard_msgs in +#eval report "tripleNested" ["h"] (old (old (old (localVar "h")))) + +/-! ## Field access (mentionsAnyInout descends into target) -/ + +/-- +info: fieldOnInout: old(h)#field warnings=0 +-/ +#guard_msgs in +#eval report "fieldOnInout" ["h"] (old (fieldAccess (localVar "h") "field")) + +/-- +info: fieldOnNonInout: x#field warnings=1 +-/ +#guard_msgs in +#eval report "fieldOnNonInout" ["h"] (old (fieldAccess (localVar "x") "field")) + +/-! ## Old in a sub-expression (top-down rewrite path) -/ + +/-- +info: oldInSubexpr: x + old(h) warnings=0 +-/ +#guard_msgs in +#eval report "oldInSubexpr" ["h"] + (add (localVar "x") (old (localVar "h"))) + +/-- +info: oldInCallArg: f(x, old(h)) warnings=0 +-/ +#guard_msgs in +#eval report "oldInCallArg" ["h"] + (call "f" [localVar "x", old (localVar "h")]) + +/-! ## Quantifier body -/ + +/-- +info: quantifier: forall(i: int) => old(h) == i warnings=0 +-/ +#guard_msgs in +#eval report "quantifier" ["h"] + (mk (.Quantifier .Forall ⟨{ text := "i" }, intTy⟩ none + (old (eqOp (localVar "h") (localVar "i"))))) + +/-! ## Tricky cases -/ + +-- Outer Old wraps an expression containing an inner Old. +/-- +info: oldOfOldPlus: x + old(h) warnings=0 +-/ +#guard_msgs in +#eval report "oldOfOldPlus" ["h"] + (old (add (localVar "x") (old (localVar "h")))) + +-- Chained field access on an inout root. +/-- +info: chainedField: (old(a)#b)#c warnings=0 +-/ +#guard_msgs in +#eval report "chainedField" ["a"] + (old (fieldAccess (fieldAccess (localVar "a") "b") "c")) + +-- Old of a ReferenceEquals: distributes to both sides. +/-- +info: refEq: old(a) == b warnings=0 +-/ +#guard_msgs in +#eval report "refEq" ["a"] + (old (mk (.ReferenceEquals (localVar "a") (localVar "b")))) + +-- Empty inoutNames: every Old wrapper warns and unwraps. +/-- +info: emptyInout: x warnings=1 +-/ +#guard_msgs in +#eval report "emptyInout" [] (old (localVar "x")) + +-- Multiple Olds in the same expression: each is independently processed. +/-- +info: twoOlds: old(h) + (old(k) + x) warnings=0 +-/ +#guard_msgs in +#eval report "twoOlds" ["h", "k"] + (add (old (localVar "h")) (add (old (localVar "k")) (localVar "x"))) + +-- Old wrapping an expression composed of inout vars: distributes everywhere. +/-- +info: oldOfEqInout: old(a) == old(b) warnings=0 +-/ +#guard_msgs in +#eval report "oldOfEqInout" ["a", "b"] + (old (eqOp (localVar "a") (localVar "b"))) + +-- Quantifier whose body has both an Old(inout) and a non-inout reference. +/-- +info: quantBoth: exists(j: int) => old(h) + j == y warnings=0 +-/ +#guard_msgs in +#eval report "quantBoth" ["h"] + (mk (.Quantifier .Exists ⟨{ text := "j" }, intTy⟩ none + (eqOp (add (old (localVar "h")) (localVar "j")) (localVar "y")))) + +-- Bool literal under Old: warns and unwraps regardless of inout context. +/-- +info: oldOfBool: true warnings=1 +-/ +#guard_msgs in +#eval report "oldOfBool" ["h"] (old (litBool true)) + +-- Old of a PrimitiveOp where one arg is already old(inout): collapses fine. +/-- +info: redundantInner: old(h) + old(k) warnings=0 +-/ +#guard_msgs in +#eval report "redundantInner" ["h", "k"] + (old (add (old (localVar "h")) (localVar "k"))) + +-- Field path where only an inner segment is inout. +/-- +info: fieldOnInnerInout: f(a, old(h))#field warnings=0 +-/ +#guard_msgs in +#eval report "fieldOnInnerInout" ["h"] + (old (fieldAccess (call "f" [localVar "a", localVar "h"]) "field")) + +/-! ## Stress cases -/ + +-- Quantifier param shadows an inout name. The body's reference to `h` +-- *should* still be rewritten to old(h) because shadowing is a scoping +-- concern that pushOld does not currently track. This documents the +-- current behavior so any future shadow handling is intentional. +/-- +info: shadowingQuantifier: forall(h: int) => old(h) + 1 warnings=0 +-/ +#guard_msgs in +#eval report "shadowingQuantifier" ["h"] + (mk (.Quantifier .Forall ⟨{ text := "h" }, intTy⟩ none + (old (add (localVar "h") (litInt 1))))) + +-- Old whose argument is a primitive op of two literals: no inout, warn. +/-- +info: oldOfLitOp: 1 + 2 warnings=1 +-/ +#guard_msgs in +#eval report "oldOfLitOp" ["h"] (old (add (litInt 1) (litInt 2))) + +-- Mix of nested Old and non-Old subexpressions in a call's args. +/-- +info: callMixedOld: g(old(h), old(k), x, y) warnings=0 +-/ +#guard_msgs in +#eval report "callMixedOld" ["h", "k"] + (call "g" [old (localVar "h"), old (localVar "k"), localVar "x", localVar "y"]) + +-- ITE wrapped in Old where only the condition contains an inout: distributes. +/-- +info: iteCondInout: if old(h) then 1 else 2 warnings=0 +-/ +#guard_msgs in +#eval report "iteCondInout" ["h"] + (old (mk (.IfThenElse (localVar "h") (litInt 1) (some (litInt 2))))) + +-- Reference equality with inout on both sides + a deeply nested old. +/-- +info: deepRefEq: old(a) == f(old(b), old(b)) warnings=0 +-/ +#guard_msgs in +#eval report "deepRefEq" ["a", "b"] + (old (mk (.ReferenceEquals (localVar "a") (call "f" [localVar "b", old (localVar "b")])))) + +-- Old over a quantifier whose body uses the bound variable. The bound variable +-- is not inout, so it stays as is; the inout `h` referenced in the body is +-- rewritten. +/-- +info: oldOverQuant: forall(i: int) => old(h) > i warnings=0 +-/ +#guard_msgs in +#eval report "oldOverQuant" ["h"] + (old (mk (.Quantifier .Forall ⟨{ text := "i" }, intTy⟩ none + (mk (.PrimitiveOp .Gt [localVar "h", localVar "i"]))))) + +-- Old wrapping a call whose only argument is itself an old(inout). +-- Top-down, the inner old is processed first, then the outer distributes. +/-- +info: oldOverOldArg: f(old(h)) warnings=0 +-/ +#guard_msgs in +#eval report "oldOverOldArg" ["h"] (old (call "f" [old (localVar "h")])) + +-- Empty inout list and an expression with no Old: passthrough, no warnings. +/-- +info: noOpExpr: a + b + c warnings=0 +-/ +#guard_msgs in +#eval report "noOpExpr" [] + (add (add (localVar "a") (localVar "b")) (localVar "c")) + +-- Old wrapping a Field where the inout is the field's owner (chain on inout). +/-- +info: oldFieldInout: old(h)#count warnings=0 +-/ +#guard_msgs in +#eval report "oldFieldInout" ["h"] (old (fieldAccess (localVar "h") "count")) + +-- Multiple non-inout Olds in the same expression each warn independently. +/-- +info: twoBadOlds: x + y warnings=2 +-/ +#guard_msgs in +#eval report "twoBadOlds" ["h"] + (add (old (localVar "x")) (old (localVar "y"))) + +/-! ## Edge cases on the public surface -/ + +-- pushOld direct invocation: confirms the helper, when called on an +-- inout var, wraps it in Old (used internally when Old(inout) is at root). +private def runPushOldOnly (inout : List String) (expr : StmtExprMd) : StmtExprMd × Nat := + let (out, st) := (pushOld expr).run { inoutNames := inout } + (out, st.diagnostics.length) + +private def reportPushOld (label : String) (inout : List String) (input : StmtExprMd) : IO Unit := do + let (out, warns) := runPushOldOnly inout input + IO.println s!"{label}: {fmt out} warnings={warns}" + +-- pushOld on a bare inout var directly emits old(h). +/-- +info: pushOldOnVar: old(h) warnings=0 +-/ +#guard_msgs in +#eval reportPushOld "pushOldOnVar" ["h"] (localVar "h") + +-- pushOld on a non-inout var passes through (no Old wrapper, no warning; +-- the warning is the responsibility of pushOldInwardExpr's outer check). +/-- +info: pushOldOnNonInout: x warnings=0 +-/ +#guard_msgs in +#eval reportPushOld "pushOldOnNonInout" ["h"] (localVar "x") + +-- mentionsAnyInout direct: confirms it descends into nested structure. +/-- +info: true +-/ +#guard_msgs in +#eval mentionsAnyInout ["h"] + (call "f" [add (litInt 1) (fieldAccess (localVar "h") "x")]) + +/-- +info: false +-/ +#guard_msgs in +#eval mentionsAnyInout ["h"] (add (localVar "a") (litInt 0)) + +-- An assignment LHS field target shouldn't be reached by pushOldInwardExpr +-- via a top-level `Old`, but if an `Old` happens to wrap an `Assign`'s value +-- expression we should still produce something sensible. Synthetic check: +-- old wrapping a deeply nested call whose root contains the inout. +/-- +info: deepCall: f(g(h(old(a)), x), 7) warnings=0 +-/ +#guard_msgs in +#eval report "deepCall" ["a"] + (call "f" [call "g" [call "h" [old (localVar "a")], localVar "x"], litInt 7]) + +/-! ## Pathological cases -/ + +-- Old over a quantifier whose trigger references inout: trigger and body +-- both get the Old distributed. +/-- +info: quantWithTrigger: forall(i: int){old(h)} => old(h) > i warnings=0 +-/ +#guard_msgs in +#eval report "quantWithTrigger" ["h"] + (old (mk (.Quantifier .Forall ⟨{ text := "i" }, intTy⟩ + (some (localVar "h")) + (mk (.PrimitiveOp .Gt [localVar "h", localVar "i"]))))) + +-- Empty call with no args + inout context: distributes trivially over zero args, +-- the result still has no inout reference, so it warns. +/-- +info: emptyCall: f() warnings=1 +-/ +#guard_msgs in +#eval report "emptyCall" ["h"] (old (call "f" [])) + +-- Top-down: an Old buried under a Quantifier (not the quantifier itself). +/-- +info: oldUnderQuant: forall(i: int) => old(h) + i == 0 warnings=0 +-/ +#guard_msgs in +#eval report "oldUnderQuant" ["h"] + (mk (.Quantifier .Forall ⟨{ text := "i" }, intTy⟩ none + (eqOp (add (old (localVar "h")) (localVar "i")) (litInt 0)))) + +-- Two layers of quantifiers, Old in the inner body referencing the outer-bound +-- variable (which is not inout) plus an actual inout. +/-- +info: nestedQuant: forall(i: int) => exists(j: int) => old(h) + i + j == 0 warnings=0 +-/ +#guard_msgs in +#eval report "nestedQuant" ["h"] + (mk (.Quantifier .Forall ⟨{ text := "i" }, intTy⟩ none + (mk (.Quantifier .Exists ⟨{ text := "j" }, intTy⟩ none + (eqOp (add (add (old (localVar "h")) (localVar "i")) + (localVar "j")) + (litInt 0)))))) + +-- Old on the field itself when the chain mixes inout and non-inout. +-- old(x.h.field): the variable here is `x` (non-inout, treated as the root), +-- so the wrapper warns and unwraps; the `h` segment is just a field name, +-- not a Var reference, so it stays untouched. +/-- +info: mixedFieldChain: (x#h)#field warnings=1 +-/ +#guard_msgs in +#eval report "mixedFieldChain" ["h"] + (old (fieldAccess (fieldAccess (localVar "x") "h") "field")) + +end Strata.Laurel