diff --git a/.kiro/settings/mcp.json b/.kiro/settings/mcp.json index 96d5ca525..c5ede2192 100644 --- a/.kiro/settings/mcp.json +++ b/.kiro/settings/mcp.json @@ -22,7 +22,8 @@ "lean_loogle", "lean_leanfinder", "lean_state_search", - "lean_hammer_premise" + "lean_hammer_premise", + "lean_file_outline" ] } } diff --git a/Strata.lean b/Strata.lean index 4115531d8..0b8785eca 100644 --- a/Strata.lean +++ b/Strata.lean @@ -18,8 +18,9 @@ import Strata.DL.Imperative.Imperative /- Utilities -/ import Strata.Util.Sarif -/- Strata Core -/ +/- Strata Languages -/ import Strata.Languages.Core.StatementSemantics +import Strata.Languages.Laurel.LaurelToCoreTranslator import Strata.Languages.Core.SarifOutput /- Backends -/ diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index 1267de667..e6a84fe5f 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -179,7 +179,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExpr := do let target ← translateStmtExpr arg0 let value ← translateStmtExpr arg1 let md ← getArgMetaData (.op op) - return .Assign target value md + return .Assign [target] value md | q`Laurel.call, #[arg0, argsSeq] => let callee ← translateStmtExpr arg0 let calleeName := match callee with diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st index 88c654f55..89d6cd995 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st @@ -33,17 +33,17 @@ op parenthesis (inner: StmtExpr): StmtExpr => "(" inner ")"; op assign (target: StmtExpr, value: StmtExpr): StmtExpr => @[prec(10)] target ":=" value ";"; // Binary operators -op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60)] lhs "+" rhs; -op eq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "==" rhs; -op neq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "!=" rhs; -op gt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs ">" rhs; -op lt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "<" rhs; -op le (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "<=" rhs; -op ge (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs ">=" rhs; +op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60), leftassoc] lhs " + " rhs; +op eq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " == " rhs; +op neq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " != " rhs; +op gt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " > " rhs; +op lt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " < " rhs; +op le (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " <= " rhs; +op ge (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " >= " rhs; // Logical operators -op and (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(30)] lhs "&&" rhs; -op or (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(25)] lhs "||" rhs; +op and (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(30), leftassoc] lhs " && " rhs; +op or (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(25), leftassoc] lhs " || " rhs; // If-else category OptionalElse; diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index c7052b57a..ef8794440 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -10,15 +10,24 @@ import Strata.Languages.Laurel.LaurelFormat /- Heap Parameterization Pass -Transforms procedures that interact with the heap using a global `$heap` variable: +Transforms procedures that interact with the heap by adding explicit heap parameters: -1. All procedures that read or write fields use the global `$heap` variable - - Field reads are translated to calls to `heapRead($heap, )` - - Field writes are translated to assignments to `$heap` via `heapStore` +1. Procedures that write the heap get an inout heap parameter + - Input: `heap : THeap` + - Output: `heap : THeap` + - Field writes become: `heap := heapStore(heap, obj, field, value)` -2. No heap parameters are added to procedure signatures - - The heap is accessed as a global variable - - Procedure calls don't pass or receive heap values +2. Procedures that only read the heap get an in heap parameter + - Input: `heap : THeap` + - Field reads become: `heapRead(heap, obj, field)` + +3. Procedure calls are transformed: + - Calls to heap-writing procedures in expressions: + `f()` => `(var freshVar: type; freshVar, heap := f(heap); freshVar)` + - Calls to heap-writing procedures as statements: + `f()` => `heap := f(heap)` + - Calls to heap-reading procedures: + `f()` => `f(heap)` The analysis is transitive: if procedure A calls procedure B, and B reads/writes the heap, then A is also considered to read/write the heap. @@ -42,13 +51,14 @@ partial def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .LocalVariable _ _ i => if let some x := i then collectExpr x | .While c i d b => collectExpr c; collectExpr b; if let some x := i then collectExpr x; if let some x := d then collectExpr x | .Return v => if let some x := v then collectExpr x - | .Assign t v _ => - -- Check if this is a field assignment (heap write) - match t with - | .FieldSelect target _ => - modify fun s => { s with writesHeapDirectly := true } - collectExpr target - | _ => collectExpr t + | .Assign targets v _ => + -- Check if any target is a field assignment (heap write) + for t in targets do + match t with + | .FieldSelect target _ => + modify fun s => { s with writesHeapDirectly := true } + collectExpr target + | _ => collectExpr t collectExpr v | .PureFieldUpdate t _ v => collectExpr t; collectExpr v | .PrimitiveOp _ args => for a in args do collectExpr a @@ -117,6 +127,7 @@ structure TransformState where heapReaders : List Identifier heapWriters : List Identifier fieldTypes : List (Identifier × HighType) := [] -- Maps field names to their value types + freshCounter : Nat := 0 -- Counter for generating fresh variable names abbrev TransformM := StateM TransformState @@ -133,6 +144,11 @@ def readsHeap (name : Identifier) : TransformM Bool := do def writesHeap (name : Identifier) : TransformM Bool := do return (← get).heapWriters.contains name +def freshVarName : TransformM Identifier := do + let s ← get + set { s with freshCounter := s.freshCounter + 1 } + return s!"$tmp{s.freshCounter}" + partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : TransformM StmtExpr := do match expr with | .FieldSelect target fieldName => @@ -144,8 +160,26 @@ partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : Transfo return .StaticCall "heapRead" [.Identifier heapVar, t, .Identifier fieldName] | .StaticCall callee args => let args' ← args.mapM (heapTransformExpr heapVar) - -- Heap is global, so no need to pass it as parameter - return .StaticCall callee args' + let calleeReadsHeap ← readsHeap callee + let calleeWritesHeap ← writesHeap callee + if calleeWritesHeap then + -- Heap-writing procedure call in expression context: + -- f(args) => (var freshVar: type; heapVar, freshVar := f(heapVar, args); freshVar) + -- The callee takes heap_in and returns (heap_out, result), we pass our heapVar and receive back into heapVar + let freshVar ← freshVarName + let varDecl := StmtExpr.LocalVariable freshVar .TInt none + -- Call with heapVar as first argument, receives (heap_out, result) which we assign to [heapVar, freshVar] + let callWithHeap := StmtExpr.Assign + [.Identifier heapVar, .Identifier freshVar] + (.StaticCall callee (StmtExpr.Identifier heapVar :: args')) + .empty + return .Block [varDecl, callWithHeap, .Identifier freshVar] none + else if calleeReadsHeap then + -- Heap-reading procedure: add heapVar as first argument (callee expects heap_in) + return .StaticCall callee (StmtExpr.Identifier heapVar :: args') + else + -- Non-heap procedure: no change + return .StaticCall callee args' | .InstanceCall target callee args => let t ← heapTransformExpr heapVar target let args' ← args.mapM (heapTransformExpr heapVar) @@ -155,18 +189,25 @@ partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : Transfo | .LocalVariable n ty i => return .LocalVariable n ty (← i.mapM (heapTransformExpr heapVar)) | .While c i d b => return .While (← heapTransformExpr heapVar c) (← i.mapM (heapTransformExpr heapVar)) (← d.mapM (heapTransformExpr heapVar)) (← heapTransformExpr heapVar b) | .Return v => return .Return (← v.mapM (heapTransformExpr heapVar)) - | .Assign t v md => - match t with - | .FieldSelect target fieldName => + | .Assign targets v md => + -- Check if first target is a field select (heap write) + match targets with + | [StmtExpr.FieldSelect target fieldName] => let fieldType ← lookupFieldType fieldName match fieldType with | some ty => addFieldConstant fieldName ty | none => addFieldConstant fieldName .TInt -- Fallback to int if type unknown let target' ← heapTransformExpr heapVar target let v' ← heapTransformExpr heapVar v - -- Assign to global heap variable - return .Assign (.Identifier heapVar) (.StaticCall "heapStore" [.Identifier heapVar, target', .Identifier fieldName, v']) md - | _ => return .Assign (← heapTransformExpr heapVar t) (← heapTransformExpr heapVar v) md + -- Assign to heap variable, but wrap in a block that returns the stored value + -- This ensures that when used in expression context, the value is the stored value, not the heap + let heapAssign := StmtExpr.Assign [StmtExpr.Identifier heapVar] (.StaticCall "heapStore" [.Identifier heapVar, target', .Identifier fieldName, v']) md + return .Block [heapAssign, v'] none + | _ => + -- Transform all targets and value + let targets' ← targets.mapM (heapTransformExpr heapVar) + let v' ← heapTransformExpr heapVar v + return .Assign targets' v' md | .PureFieldUpdate t f v => return .PureFieldUpdate (← heapTransformExpr heapVar t) f (← heapTransformExpr heapVar v) | .PrimitiveOp op args => return .PrimitiveOp op (← args.mapM (heapTransformExpr heapVar)) | .ReferenceEquals l r => return .ReferenceEquals (← heapTransformExpr heapVar l) (← heapTransformExpr heapVar r) @@ -184,51 +225,80 @@ partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : Transfo | other => return other def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do - let heapName := "$heap" + let heapInName := "heap_in" + let heapOutName := "heap_out" let readsHeap := (← get).heapReaders.contains proc.name let writesHeap := (← get).heapWriters.contains proc.name - if readsHeap || writesHeap then - -- This procedure reads or writes the heap - transform to use global $heap - let precondition' ← heapTransformExpr heapName proc.precondition + if writesHeap then + -- This procedure writes the heap - add heap_in as input and heap_out as output + -- At the start, assign heap_in to heap_out, then use heap_out throughout + let heapInParam : Parameter := { name := heapInName, type := .THeap } + let heapOutParam : Parameter := { name := heapOutName, type := .THeap } + + let inputs' := heapInParam :: proc.inputs + let outputs' := heapOutParam :: proc.outputs + + -- Precondition uses heap_in (the input state) + let precondition' ← heapTransformExpr heapInName proc.precondition let body' ← match proc.body with | .Transparent bodyExpr => - let bodyExpr' ← heapTransformExpr heapName bodyExpr - pure (.Transparent bodyExpr') + -- First assign heap_in to heap_out, then transform body using heap_out + let assignHeapOut := StmtExpr.Assign [StmtExpr.Identifier heapOutName] (StmtExpr.Identifier heapInName) .empty + let bodyExpr' ← heapTransformExpr heapOutName bodyExpr + pure (.Transparent (.Block [assignHeapOut, bodyExpr'] none)) | .Opaque postcond impl modif => - let postcond' ← heapTransformExpr heapName postcond - let impl' ← impl.mapM (heapTransformExpr heapName) - let modif' ← modif.mapM (heapTransformExpr heapName) + -- Postcondition uses heap_out (the output state) + let postcond' ← heapTransformExpr heapOutName postcond + let impl' ← match impl with + | some implExpr => + let assignHeapOut := StmtExpr.Assign [StmtExpr.Identifier heapOutName] (StmtExpr.Identifier heapInName) .empty + let implExpr' ← heapTransformExpr heapOutName implExpr + pure (some (.Block [assignHeapOut, implExpr'] none)) + | none => pure none + let modif' ← modif.mapM (heapTransformExpr heapOutName) pure (.Opaque postcond' impl' modif') | .Abstract postcond => - let postcond' ← heapTransformExpr heapName postcond + let postcond' ← heapTransformExpr heapOutName postcond pure (.Abstract postcond') return { proc with + inputs := inputs', + outputs := outputs', precondition := precondition', body := body' } - else - -- This procedure doesn't read or write the heap - -- Still transform contracts in case they reference fields - let precondition' ← heapTransformExpr heapName proc.precondition + else if readsHeap then + -- This procedure only reads the heap - add heap_in as input only + let heapInParam : Parameter := { name := heapInName, type := .THeap } + let inputs' := heapInParam :: proc.inputs + + let precondition' ← heapTransformExpr heapInName proc.precondition let body' ← match proc.body with | .Transparent bodyExpr => - pure (.Transparent bodyExpr) + let bodyExpr' ← heapTransformExpr heapInName bodyExpr + pure (.Transparent bodyExpr') | .Opaque postcond impl modif => - let postcond' ← heapTransformExpr heapName postcond - pure (.Opaque postcond' impl modif) + let postcond' ← heapTransformExpr heapInName postcond + let impl' ← impl.mapM (heapTransformExpr heapInName) + let modif' ← modif.mapM (heapTransformExpr heapInName) + pure (.Opaque postcond' impl' modif') | .Abstract postcond => - let postcond' ← heapTransformExpr heapName postcond + let postcond' ← heapTransformExpr heapInName postcond pure (.Abstract postcond') return { proc with + inputs := inputs', precondition := precondition', body := body' } -def heapParameterization (program : Program) : Program × List Identifier := + else + -- This procedure doesn't read or write the heap - no changes needed + return proc + +def heapParameterization (program : Program) : Program := let heapReaders := computeReadsHeap program.staticProcedures let heapWriters := computeWritesHeap program.staticProcedures -- Extract field types from composite type definitions @@ -240,6 +310,6 @@ def heapParameterization (program : Program) : Program × List Identifier := dbg_trace s!"Heap readers: {heapReaders}" dbg_trace s!"Heap writers: {heapWriters}" let (procs', finalState) := (program.staticProcedures.mapM heapTransformProcedure).run { heapReaders, heapWriters, fieldTypes } - ({ program with staticProcedures := procs', constants := program.constants ++ finalState.fieldConstants }, heapWriters) + { program with staticProcedures := procs', constants := program.constants ++ finalState.fieldConstants } end Strata.Laurel diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index bd188e0b7..679661846 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -132,8 +132,11 @@ inductive StmtExpr : Type where | LiteralInt (value: Int) | LiteralBool (value: Bool) | Identifier (name : Identifier) - /- Assign is only allowed in an impure context -/ - | Assign (target : StmtExpr) (value : StmtExpr) (md : Imperative.MetaData Core.Expression) + /- Assign is only allowed in an impure context. + For single target assignments, use a single-element list. + Multiple targets are only allowed when the value is a StaticCall to a procedure + with multiple outputs, and the number of targets must match the number of outputs. -/ + | Assign (targets : List StmtExpr) (value : StmtExpr) (md : Imperative.MetaData Core.Expression) /- Used by itself for fields reads and in combination with Assign for field writes -/ | FieldSelect (target : StmtExpr) (fieldName : Identifier) /- PureFieldUpdate is the only way to assign values to fields of pure types -/ diff --git a/Strata/Languages/Laurel/LaurelFormat.lean b/Strata/Languages/Laurel/LaurelFormat.lean index 2fb137be3..88eaf5cb0 100644 --- a/Strata/Languages/Laurel/LaurelFormat.lean +++ b/Strata/Languages/Laurel/LaurelFormat.lean @@ -45,7 +45,7 @@ def formatHighType : HighType → Format Format.joinSep (types.map formatHighType) " & " def formatStmtExpr (s:StmtExpr) : Format := - match h: s with + match s with | .IfThenElse cond thenBr elseBr => "if " ++ formatStmtExpr cond ++ " then " ++ formatStmtExpr thenBr ++ match elseBr with @@ -69,19 +69,22 @@ def formatStmtExpr (s:StmtExpr) : Format := | .LiteralInt n => Format.text (toString n) | .LiteralBool b => if b then "true" else "false" | .Identifier name => Format.text name - | .Assign target value _ => - formatStmtExpr target ++ " := " ++ formatStmtExpr value + | .Assign [single] value _ => + formatStmtExpr single ++ " := " ++ formatStmtExpr value + | .Assign targets value _ => + "(" ++ Format.joinSep (targets.map formatStmtExpr) ", " ++ ")" ++ " := " ++ formatStmtExpr value | .FieldSelect target field => formatStmtExpr target ++ "#" ++ Format.text field | .PureFieldUpdate target field value => formatStmtExpr target ++ " with { " ++ Format.text field ++ " := " ++ formatStmtExpr value ++ " }" | .StaticCall name args => Format.text name ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" + | .PrimitiveOp op [a] => + formatOperation op ++ formatStmtExpr a + | .PrimitiveOp op [a, b] => + formatStmtExpr a ++ " " ++ formatOperation op ++ " " ++ formatStmtExpr b | .PrimitiveOp op args => - match args with - | [a] => formatOperation op ++ formatStmtExpr a - | [a, b] => formatStmtExpr a ++ " " ++ formatOperation op ++ " " ++ formatStmtExpr b - | _ => formatOperation op ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" + formatOperation op ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" | .This => "this" | .ReferenceEquals lhs rhs => formatStmtExpr lhs ++ " === " ++ formatStmtExpr rhs @@ -107,10 +110,6 @@ def formatStmtExpr (s:StmtExpr) : Format := | .Abstract => "abstract" | .All => "all" | .Hole => "" - decreasing_by - all_goals (simp_wf; try omega) - any_goals (rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega) - subst_vars; cases h; rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega def formatParameter (p : Parameter) : Format := Format.text p.name ++ ": " ++ formatHighType p.type diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index 19606899b..7b3f766ad 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -56,17 +56,13 @@ def translateExpr (constants : List Constant) (env : TypeEnv) (expr : StmtExpr) | .LiteralBool b => .const () (.boolConst b) | .LiteralInt i => .const () (.intConst i) | .Identifier name => - -- Check if this is a constant (field constant), global variable, or local variable + -- Check if this is a constant (field constant) or local variable if isConstant constants name then -- Constants are global identifiers (functions with no arguments) let ident := Core.CoreIdent.glob name -- Field constants are declared as functions () → Field T -- We just reference them as operations without application .op () ident none - else if name == "$heap" then - -- Global heap variable - let ident := Core.CoreIdent.glob name - .fvar () ident (some (.tcons "Heap" [])) else -- Regular variables are local identifiers let ident := Core.CoreIdent.locl name @@ -126,7 +122,8 @@ def getNameFromMd (md : Imperative.MetaData Core.Expression): String := Translate Laurel StmtExpr to Core Statements Takes the constants list, type environment and output parameter names -/ -def translateStmt (constants : List Constant) (env : TypeEnv) (outputParams : List Parameter) (stmt : StmtExpr) : TypeEnv × List Core.Statement := +def translateStmt (constants : List Constant) (env : TypeEnv) + (outputParams : List Parameter) (stmt : StmtExpr) : TypeEnv × List Core.Statement := match stmt with | @StmtExpr.Assert cond md => let boogieExpr := translateExpr constants env cond @@ -171,19 +168,25 @@ def translateStmt (constants : List Constant) (env : TypeEnv) (outputParams : Li | .TBool => .const () (.boolConst false) | _ => .const () (.intConst 0) (env', [Core.Statement.init ident boogieType defaultExpr]) - | .Assign target value _ => - match target with - | .Identifier name => - -- Check if this is the global heap variable - if name == "$heap" then - let heapIdent := Core.CoreIdent.glob "$heap" - let boogieExpr := translateExpr constants env value - (env, [Core.Statement.set heapIdent boogieExpr]) - else - let ident := Core.CoreIdent.locl name - let boogieExpr := translateExpr constants env value - (env, [Core.Statement.set ident boogieExpr]) - | _ => (env, []) + | .Assign targets value _ => + match targets with + | [.Identifier name] => + let ident := Core.CoreIdent.locl name + let boogieExpr := translateExpr constants env value + (env, [Core.Statement.set ident boogieExpr]) + | _ => + -- Parallel assignment: (var1, var2, ...) := expr + -- Example use is heap-modifying procedure calls: (result, heap) := f(heap, args) + match value with + | .StaticCall callee args => + let boogieArgs := args.map (translateExpr constants env) + let lhsIdents := targets.filterMap fun t => + match t with + | .Identifier name => some (Core.CoreIdent.locl name) + | _ => none + (env, [Core.Statement.call lhsIdents callee boogieArgs]) + | _ => + panic "Assignments with multiple target but without a RHS call should not be constructed" | .IfThenElse cond thenBranch elseBranch => let bcond := translateExpr constants env cond let (_, bthen) := translateStmt constants env outputParams thenBranch @@ -226,7 +229,7 @@ def translateParameterToCore (param : Parameter) : (Core.CoreIdent × LMonoTy) : /-- Translate Laurel Procedure to Core Procedure -/ -def translateProcedure (constants : List Constant) (heapWriters : List Identifier) (proc : Procedure) : Core.Procedure := +def translateProcedure (constants : List Constant) (proc : Procedure) : Core.Procedure := let inputPairs := proc.inputs.map translateParameterToCore let inputs := inputPairs @@ -255,7 +258,7 @@ def translateProcedure (constants : List Constant) (heapWriters : List Identifie let check : Core.Procedure.Check := { expr := translateExpr constants initEnv postcond } [("ensures", check)] | _ => [] - let modifies := if heapWriters.contains proc.name then [Core.CoreIdent.glob "$heap"] else [] + let modifies : List Core.Expression.Ident := [] let spec : Core.Procedure.Spec := { modifies, preconditions, @@ -442,29 +445,21 @@ def translateProcedureToFunction (constants : List Constant) (proc : Procedure) Translate Laurel Program to Core Program -/ def translate (program : Program) : Except (Array DiagnosticModel) Core.Program := do - let sequencedProgram ← liftExpressionAssignments program - let (heapProgram, heapWriters) := heapParameterization sequencedProgram - dbg_trace "=== Program after heapParameterization===" - dbg_trace (toString (Std.Format.pretty (Std.ToFormat.format heapProgram))) + let program := heapParameterization program + let program ← liftExpressionAssignments program + dbg_trace "=== Program after heapParameterization + liftExpressionAssignments ===" + dbg_trace (toString (Std.Format.pretty (Std.ToFormat.format program))) dbg_trace "=================================" -- Separate procedures that can be functions from those that must be procedures - let (funcProcs, procProcs) := heapProgram.staticProcedures.partition canBeBoogieFunction - let procedures := procProcs.map (translateProcedure heapProgram.constants heapWriters) + let (funcProcs, procProcs) := program.staticProcedures.partition canBeBoogieFunction + let procedures := procProcs.map (translateProcedure program.constants) let procDecls := procedures.map (fun p => Core.Decl.proc p .empty) - let laurelFuncDecls := funcProcs.map (translateProcedureToFunction heapProgram.constants) - let constDecls := heapProgram.constants.map translateConstant + let laurelFuncDecls := funcProcs.map (translateProcedureToFunction program.constants) + let constDecls := program.constants.map translateConstant let typeDecls := [heapTypeDecl, fieldTypeDecl, compositeTypeDecl] let funcDecls := [readFunction, updateFunction] let axiomDecls := [readUpdateSameAxiom, readUpdateDiffAxiom] - -- Add global heap variable declaration with a free variable as initializer - let heapTy := LMonoTy.tcons "Heap" [] - let heapInitVar := LExpr.fvar () (Core.CoreIdent.glob "$heap_init") (some heapTy) - let heapVarDecl := Core.Decl.var - (Core.CoreIdent.glob "$heap") - (LTy.forAll [] heapTy) - heapInitVar - .empty - return { decls := typeDecls ++ funcDecls ++ axiomDecls ++ [heapVarDecl] ++ constDecls ++ laurelFuncDecls ++ procDecls } + return { decls := typeDecls ++ funcDecls ++ axiomDecls ++ constDecls ++ laurelFuncDecls ++ procDecls } /-- Verify a Laurel program using an SMT solver diff --git a/Strata/Languages/Laurel/LiftExpressionAssignments.lean b/Strata/Languages/Laurel/LiftExpressionAssignments.lean index eb6668110..0f7ca2dcc 100644 --- a/Strata/Languages/Laurel/LiftExpressionAssignments.lean +++ b/Strata/Languages/Laurel/LiftExpressionAssignments.lean @@ -8,7 +8,6 @@ import Strata.Languages.Laurel.Laurel import Strata.Languages.Laurel.LaurelFormat import Strata.Languages.Core.Verifier - namespace Strata namespace Laurel @@ -26,11 +25,25 @@ Becomes: if (x1 == y1) { ... } -/ +private abbrev TypeEnv := List (Identifier × HighType) + +private def lookupType (env : TypeEnv) (name : Identifier) : HighType := + match env.find? (fun (n, _) => n == name) with + | some (_, ty) => ty + | none => .TInt -- Default fallback + structure SequenceState where insideCondition : Bool prependedStmts : List StmtExpr := [] diagnostics : List DiagnosticModel - tempCounter : Nat := 0 + -- Maps variable names to their counter for generating unique temp names + varCounters : List (Identifier × Nat) := [] + -- Maps variable names to their current snapshot variable name + -- When an assignment is lifted, we create a snapshot and record it here + -- Subsequent references to the variable should use the snapshot + varSnapshots : List (Identifier × Identifier) := [] + -- Type environment mapping variable names to their types + env : TypeEnv := [] abbrev SequenceM := StateM SequenceState @@ -64,30 +77,68 @@ def SequenceM.takePrependedStmts : SequenceM (List StmtExpr) := do modify fun s => { s with prependedStmts := [] } return stmts.reverse -def SequenceM.freshTemp : SequenceM Identifier := do - let counter := (← get).tempCounter - modify fun s => { s with tempCounter := s.tempCounter + 1 } - return s!"__t{counter}" +def SequenceM.freshTempFor (varName : Identifier) : SequenceM Identifier := do + let counters := (← get).varCounters + let counter := counters.find? (·.1 == varName) |>.map (·.2) |>.getD 0 + modify fun s => { s with varCounters := (varName, counter + 1) :: s.varCounters.filter (·.1 != varName) } + return s!"{varName}_{counter}" + +def SequenceM.getSnapshot (varName : Identifier) : SequenceM (Option Identifier) := do + return (← get).varSnapshots.find? (·.1 == varName) |>.map (·.2) + +def SequenceM.setSnapshot (varName : Identifier) (snapshotName : Identifier) : SequenceM Unit := do + modify fun s => { s with varSnapshots := (varName, snapshotName) :: s.varSnapshots.filter (·.1 != varName) } + +def SequenceM.getVarType (varName : Identifier) : SequenceM HighType := do + return lookupType (← get).env varName + +def SequenceM.addToEnv (varName : Identifier) (ty : HighType) : SequenceM Unit := do + modify fun s => { s with env := (varName, ty) :: s.env } + +partial def transformTarget (expr : StmtExpr) : SequenceM StmtExpr := do + match expr with + | .PrimitiveOp op args => + let seqArgs ← args.mapM transformTarget + return .PrimitiveOp op seqArgs + | .StaticCall name args => + let seqArgs ← args.mapM transformTarget + return .StaticCall name seqArgs + | _ => return expr -- Identifiers and other targets stay as-is (no snapshot substitution) mutual /- Process an expression, extracting any assignments to preceding statements. Returns the transformed expression with assignments replaced by variable references. -/ -def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do +partial def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do match expr with - | .Assign target value md => + | .Assign targets value md => checkOutsideCondition md -- This is an assignment in expression context -- We need to: 1) execute the assignment, 2) capture the value in a temporary -- This prevents subsequent assignments to the same variable from changing the value let seqValue ← transformExpr value - let assignStmt := StmtExpr.Assign target seqValue md + let assignStmt := StmtExpr.Assign targets seqValue md SequenceM.addPrependedStmt assignStmt - -- Create a temporary variable to capture the assigned value + -- For each target, create a snapshot variable so subsequent references + -- to that variable will see the value after this assignment + for target in targets do + match target with + | .Identifier varName => + let snapshotName ← SequenceM.freshTempFor varName + let snapshotType ← SequenceM.getVarType varName + let snapshotDecl := StmtExpr.LocalVariable snapshotName snapshotType (some (.Identifier varName)) + SequenceM.addPrependedStmt snapshotDecl + SequenceM.setSnapshot varName snapshotName + | _ => pure () + -- Create a temporary variable to capture the assigned value (for expression result) -- Use TInt as the type (could be refined with type inference) - let tempName ← SequenceM.freshTemp - let tempDecl := StmtExpr.LocalVariable tempName .TInt (some target) + -- For multi-target assigns, use the first target + let firstTarget := targets.head?.getD (.Identifier "__unknown") + let tempName ← match firstTarget with + | .Identifier name => SequenceM.freshTempFor name + | _ => SequenceM.freshTempFor "__expr" + let tempDecl := StmtExpr.LocalVariable tempName .TInt (some firstTarget) SequenceM.addPrependedStmt tempDecl -- Return the temporary variable as the expression value return .Identifier tempName @@ -111,21 +162,51 @@ def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do | .Block stmts metadata => -- Block in expression position: move all but last statement to prepended - let rec next (remStmts: List StmtExpr) := match remStmts with - | [last] => transformExpr last - | head :: tail => do - let seqStmt ← transformStmt head - for s in seqStmt do - SequenceM.addPrependedStmt s - next tail - | [] => return .Block [] metadata - - next stmts + -- Process statements in order, handling assignments specially to set snapshots + match stmts with + | [] => return .Block [] metadata + | [last] => transformExpr last + | _ => + -- Process all but the last statement + let allButLast := stmts.dropLast + let last := stmts.getLast! + for stmt in allButLast do + match stmt with + | .Assign targets value md => + -- For assignments in block context, we need to set snapshots + -- so that subsequent expressions see the correct values + -- IMPORTANT: Use transformTarget for targets (no snapshot substitution) + -- and transformExpr for values (with snapshot substitution) + let seqTargets ← targets.mapM transformTarget + let seqValue ← transformExpr value + let assignStmt := StmtExpr.Assign seqTargets seqValue md + SequenceM.addPrependedStmt assignStmt + -- Create snapshot for variables so subsequent reads + -- see the value after this assignment (not after later assignments) + for target in seqTargets do + match target with + | .Identifier varName => + let snapshotName ← SequenceM.freshTempFor varName + let snapshotType ← SequenceM.getVarType varName + let snapshotDecl := StmtExpr.LocalVariable snapshotName snapshotType (some (.Identifier varName)) + SequenceM.addPrependedStmt snapshotDecl + SequenceM.setSnapshot varName snapshotName + | _ => pure () + | _ => + let seqStmt ← transformStmt stmt + for s in seqStmt do + SequenceM.addPrependedStmt s + -- Process the last statement as an expression + transformExpr last -- Base cases: no assignments to extract | .LiteralBool _ => return expr | .LiteralInt _ => return expr - | .Identifier _ => return expr + | .Identifier varName => do + -- If this variable has a snapshot (from a lifted assignment), use the snapshot + match ← SequenceM.getSnapshot varName with + | some snapshotName => return .Identifier snapshotName + | none => return expr | .LocalVariable _ _ _ => return expr | _ => return expr -- Other cases @@ -133,7 +214,7 @@ def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do Process a statement, handling any assignments in its sub-expressions. Returns a list of statements (the original one may be split into multiple). -/ -def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do +partial def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do match stmt with | @StmtExpr.Assert cond md => -- Process the condition, extracting any assignments @@ -151,6 +232,7 @@ def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do return [.Block (seqStmts.flatten) metadata] | .LocalVariable name ty initializer => + SequenceM.addToEnv name ty match initializer with | some initExpr => do let seqInit ← transformExpr initExpr @@ -159,10 +241,10 @@ def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do | none => return [stmt] - | .Assign target value md => - let seqTarget ← transformExpr target + | .Assign targets value md => + let seqTargets ← targets.mapM transformTarget let seqValue ← transformExpr value - SequenceM.addPrependedStmt <| .Assign seqTarget seqValue md + SequenceM.addPrependedStmt <| .Assign seqTargets seqValue md SequenceM.takePrependedStmts | .IfThenElse cond thenBranch elseBranch => @@ -197,8 +279,11 @@ def transformProcedureBody (body : StmtExpr) : SequenceM StmtExpr := do | multiple => pure <| .Block multiple.reverse none def transformProcedure (proc : Procedure) : SequenceM Procedure := do - -- Reset insideCondition for each procedure to avoid cross-procedure contamination - modify fun s => { s with insideCondition := false } + -- Initialize environment with procedure parameters + let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ + proc.outputs.map (fun p => (p.name, p.type)) + -- Reset state for each procedure to avoid cross-procedure contamination + modify fun s => { s with insideCondition := false, varSnapshots := [], varCounters := [], env := initEnv } match proc.body with | .Transparent bodyExpr => let seqBody ← transformProcedureBody bodyExpr diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean index 7e299b75c..14cd81c47 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean @@ -21,6 +21,12 @@ procedure NestedImpureStatements() { // ^^^^^^^^^^^^^^ error: assertion does not hold assert z == y; } + +procedure multipleAssignments() { + var x: int; + var y: int := ((x := 1;) + x) + (x := 2;); + assert y == 4; +} " #guard_msgs (error, drop all) in diff --git a/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean b/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean index 7c04768c3..8cdecd812 100644 --- a/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean +++ b/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean @@ -44,6 +44,18 @@ procedure caller(c: Container, d: Container) { assert d#intValue == 3; } +procedure allowHeapMutatingCallerInExpression(c: Container, d: Container) { + assume d#intValue == 1; + var x: int := foo(c, d) + 1; + assert d#intValue == 3; +} + +procedure subsequentHeapMutations(c: Container) { + // The additional parenthesis on the next line are needed to let the parser succeed. Joe, any idea why this is needed? + var sum: int := ((c#intValue := 1;) + c#intValue) + (c#intValue := 2;); + assert sum == 4; +} + procedure implicitEquality(c: Container, d: Container) { c#intValue := 1; d#intValue := 2;