From 3a2f2592167c26aa248ac314d3c6cb4e2270082f Mon Sep 17 00:00:00 2001 From: keyboardDrummer-bot Date: Tue, 21 Apr 2026 13:55:30 +0000 Subject: [PATCH 1/4] Support functions with multiple outputs - Change LocalVariable from (name, type, init) to (List Parameter, init) to support multi-output destructuring - Add EliminateMultipleOutputs pass that transforms functional procedures with multiple outputs into procedures returning a synthesized result datatype, rewriting call sites to destructure via generated accessors - The pass checks .isFunctional to determine which procedures to handle - Add pass call in runLaurelPasses after constrained type elimination - Add test for functions with multiple outputs - Update all pattern matches across the codebase (~18 files) --- .../Languages/Laurel/ConstrainedTypeElim.lean | 25 +-- .../Laurel/CoreGroupingAndOrdering.lean | 2 +- .../Laurel/EliminateMultipleOutputs.lean | 144 ++++++++++++++++++ Strata/Languages/Laurel/FilterPrelude.lean | 4 +- .../AbstractToConcreteTreeTranslator.lean | 8 +- .../ConcreteToAbstractTreeTranslator.lean | 2 +- .../Laurel/HeapParameterization.lean | 8 +- Strata/Languages/Laurel/InferHoleTypes.lean | 5 +- Strata/Languages/Laurel/Laurel.lean | 4 +- .../Laurel/LaurelCompilationPipeline.lean | 6 + .../Laurel/LaurelToCoreTranslator.lean | 83 ++++++---- Strata/Languages/Laurel/LaurelTypes.lean | 2 +- .../Laurel/LiftImperativeExpressions.lean | 32 ++-- Strata/Languages/Laurel/MapStmtExpr.lean | 4 +- Strata/Languages/Laurel/Resolution.lean | 15 +- Strata/Languages/Laurel/TypeHierarchy.lean | 4 +- Strata/Languages/Python/PythonToLaurel.lean | 26 ++-- .../Fundamentals/T22_MultipleOutputs.lean | 38 +++++ 18 files changed, 316 insertions(+), 96 deletions(-) create mode 100644 Strata/Languages/Laurel/EliminateMultipleOutputs.lean create mode 100644 StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 4658d5876b..cdc418f83b 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -92,8 +92,8 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM let source := expr.source let md := expr.md match expr.val with - | .LocalVariable n ty init => - ⟨.LocalVariable n (resolveType ptMap ty) init, source, md⟩ + | .LocalVariable params init => + ⟨.LocalVariable (params.map fun p => { p with type := resolveType ptMap p.type }) init, source, md⟩ | .Forall param trigger body => let param' := { param with type := resolveType ptMap param.type } -- With bottom-up traversal, `body` is already recursed into. The newly @@ -127,15 +127,18 @@ def elimStmt (ptMap : ConstrainedTypeMap) let source := stmt.source let md := stmt.md match _h : stmt.val with - | .LocalVariable name ty init => - let callOpt := constraintCallFor ptMap ty.val name md (src := source) - if callOpt.isSome then modify fun pv => pv.insert name.text ty.val + | .LocalVariable params init => + for p in params do + let callOpt := constraintCallFor ptMap p.type.val p.name md (src := source) + if callOpt.isSome then modify fun pv => pv.insert p.name.text p.type.val let (init', check) : Option StmtExprMd × List StmtExprMd := match init with - | none => match callOpt with - | some c => (none, [⟨.Assume c, source, md⟩]) - | none => (none, []) - | some _ => (init, callOpt.toList.map fun c => ⟨.Assert c, source, md⟩) - pure ([⟨.LocalVariable name ty init', source, md⟩] ++ check) + | none => + let calls := params.filterMap fun p => constraintCallFor ptMap p.type.val p.name md (src := source) + (none, calls.map fun c => ⟨.Assume c, source, md⟩) + | some _ => + let calls := params.filterMap fun p => constraintCallFor ptMap p.type.val p.name md (src := source) + (init, calls.map fun c => ⟨.Assert c, source, md⟩) + pure ([⟨.LocalVariable params init', source, md⟩] ++ check) | .Assign [target] _ => match target.val with | .Identifier name => do @@ -209,7 +212,7 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : let md := ct.witness.md let witnessId : Identifier := mkId "$witness" let witnessInit : StmtExprMd := - ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), src, md⟩ + ⟨.LocalVariable [{ name := witnessId, type := resolveType ptMap ct.base }] (some ct.witness), src, md⟩ let assert : StmtExprMd := ⟨.Assert (constraintCallFor ptMap (.UserDefined ct.name) witnessId md (src := src)).get!, src, md⟩ { name := mkId s!"$witness_{ct.name.text}" diff --git a/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean b/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean index 1d8596235a..fdd367cca2 100644 --- a/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean +++ b/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean @@ -65,7 +65,7 @@ def collectStaticCallNames (expr : StmtExprMd) : List String := | .Assign targets v => targets.flatMap (fun t => collectStaticCallNames t) ++ collectStaticCallNames v - | .LocalVariable _ _ initOption => + | .LocalVariable _ initOption => match initOption with | some init => collectStaticCallNames init | none => [] diff --git a/Strata/Languages/Laurel/EliminateMultipleOutputs.lean b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean new file mode 100644 index 0000000000..ad31101e2d --- /dev/null +++ b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean @@ -0,0 +1,144 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.Languages.Laurel.MapStmtExpr + +/-! +# Eliminate Multiple Outputs + +Transforms functional procedures (`.isFunctional = true`) with multiple outputs +into procedures that return a single synthesized result datatype. Call sites are +rewritten to destructure the result using the generated accessors. + +This pass operates on `Program → Program`. +-/ + +namespace Strata.Laurel + +public section + +private def mkMd (e : StmtExpr) : StmtExprMd := { val := e, source := none } +private def mkTy (t : HighType) : HighTypeMd := { val := t, source := none } + +/-- Info about a function whose multiple outputs have been collapsed into a result datatype. -/ +private structure MultiOutInfo where + funcName : String + resultTypeName : String + constructorName : String + outputs : List Parameter + +/-- Identify functional procedures with multiple outputs. -/ +private def collectMultiOutFunctions (procs : List Procedure) : List MultiOutInfo := + procs.filterMap fun f => + if f.isFunctional && f.outputs.length > 1 then + some { + funcName := f.name.text + resultTypeName := s!"{f.name.text}$result" + constructorName := s!"{f.name.text}$result$mk" + outputs := f.outputs + } + else none + +/-- Generate a result datatype for a multi-output function. -/ +private def mkResultDatatype (info : MultiOutInfo) : DatatypeDefinition := + let args := info.outputs.zipIdx.map fun (p, i) => + { name := mkId s!"out{i}", type := p.type : Parameter } + { name := mkId info.resultTypeName + typeArgs := [] + constructors := [{ name := mkId info.constructorName, args := args }] } + +/-- Transform a multi-output function to return the result datatype. -/ +private def transformFunction (info : MultiOutInfo) (proc : Procedure) : Procedure := + let resultOutput : Parameter := + { name := mkId "$result", type := mkTy (.UserDefined (mkId info.resultTypeName)) } + { proc with outputs := [resultOutput] } + +/-- Destructor name for field `outN` of the result datatype. -/ +private def destructorName (info : MultiOutInfo) (idx : Nat) : String := + s!"{info.resultTypeName}..out{idx}" + +/-- Rewrite a statement list, replacing multi-output call patterns. -/ +private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo) + (stmts : List StmtExprMd) : List StmtExprMd := + let rec go (remaining : List StmtExprMd) (acc : List StmtExprMd) : List StmtExprMd := + match remaining with + | [] => acc.reverse + | stmt :: rest => + match stmt.val with + | .Assign targets ⟨.StaticCall callee args, callSrc, callMd⟩ => + match infoMap.get? callee.text with + | some info => + if targets.length == info.outputs.length then + let tempName := s!"${callee.text}$temp" + let tempParam : Parameter := { name := mkId tempName, type := mkTy (.UserDefined (mkId info.resultTypeName)) } + let tempDecl := mkMd (.LocalVariable [tempParam] + (some ⟨.StaticCall callee args, callSrc, callMd⟩)) + let assigns := targets.zipIdx.map fun (tgt, i) => + mkMd (.Assign [tgt] + (mkMd (.StaticCall (mkId (destructorName info i)) + [mkMd (.Identifier (mkId tempName))]))) + go rest (assigns.reverse ++ (tempDecl :: acc)) + else go rest (stmt :: acc) + | none => go rest (stmt :: acc) + | .LocalVariable params (some ⟨.StaticCall callee args, callSrc, callMd⟩) => + match infoMap.get? callee.text with + | some info => + if info.outputs.length > 1 then + let tempName := s!"${callee.text}$temp" + let tempParam : Parameter := { name := mkId tempName, type := mkTy (.UserDefined (mkId info.resultTypeName)) } + let tempDecl := mkMd (.LocalVariable [tempParam] + (some ⟨.StaticCall callee args, callSrc, callMd⟩)) + let localDecls := params.zipIdx.map fun (p, i) => + mkMd (.LocalVariable [p] + (some (mkMd (.StaticCall (mkId (destructorName info i)) + [mkMd (.Identifier (mkId tempName))])))) + go rest (localDecls.reverse ++ (tempDecl :: acc)) + else go rest (stmt :: acc) + | none => go rest (stmt :: acc) + | _ => go rest (stmt :: acc) + termination_by remaining.length + go stmts [] + +/-- Rewrite blocks in a StmtExprMd tree to handle multi-output calls. -/ +private def rewriteExpr (infoMap : Std.HashMap String MultiOutInfo) + (expr : StmtExprMd) : StmtExprMd := + mapStmtExpr (fun e => + match e.val with + | .Block stmts label => ⟨.Block (rewriteStmts infoMap stmts) label, e.source, e.md⟩ + | _ => e) expr + +/-- Rewrite all procedure bodies. -/ +private def rewriteProcedure (infoMap : Std.HashMap String MultiOutInfo) + (proc : Procedure) : Procedure := + match proc.body with + | .Transparent b => + let wrapped := mkMd (.Block [b] none) + let rewritten := rewriteExpr infoMap wrapped + { proc with body := .Transparent rewritten } + | .Opaque posts (some impl) mods => + let wrapped := mkMd (.Block [impl] none) + let rewritten := rewriteExpr infoMap wrapped + { proc with body := .Opaque posts (some rewritten) mods } + | _ => proc + +/-- Eliminate multiple outputs from a Program. Only applies to functional procedures. -/ +def eliminateMultipleOutputs (program : Program) : Program := + let infos := collectMultiOutFunctions program.staticProcedures + if infos.isEmpty then program else + let infoMap : Std.HashMap String MultiOutInfo := + infos.foldl (fun m info => m.insert info.funcName info) {} + let newDatatypes := infos.map mkResultDatatype + let procs := program.staticProcedures.map fun f => + match infoMap.get? f.name.text with + | some info => rewriteProcedure infoMap (transformFunction info f) + | none => rewriteProcedure infoMap f + { program with + staticProcedures := procs + types := program.types ++ newDatatypes.map TypeDefinition.Datatype } + +end -- public section +end Strata.Laurel diff --git a/Strata/Languages/Laurel/FilterPrelude.lean b/Strata/Languages/Laurel/FilterPrelude.lean index ce7c6a3656..2bf7f8c459 100644 --- a/Strata/Languages/Laurel/FilterPrelude.lean +++ b/Strata/Languages/Laurel/FilterPrelude.lean @@ -92,8 +92,8 @@ private partial def collectExprNames (expr : StmtExprMd) : CollectM Unit := do collectExprNames cond; collectExprNames thenB elseB.forM collectExprNames | .Block stmts _ => stmts.forM collectExprNames - | .LocalVariable _ ty init => - collectHighTypeNames ty + | .LocalVariable params init => + params.forM fun p => collectHighTypeNames p.type init.forM collectExprNames | .While cond invs dec body => collectExprNames cond; invs.forM collectExprNames diff --git a/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean index 44012f3960..6369f36cbc 100644 --- a/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean @@ -95,10 +95,14 @@ where match label with | none => laurelOp "block" #[semicolonSep stmtArgs] | some l => laurelOp "labelledBlock" #[semicolonSep stmtArgs, ident l] - | .LocalVariable name ty init => + | .LocalVariable params init => + -- Grammar only supports single-target varDecl; use first parameter or placeholder + let (nameText, ty) := match params with + | p :: _ => (p.name.text, p.type) + | [] => ("_", ⟨.TVoid, none, #[]⟩) let typeOpt := optionArg (some (laurelOp "typeAnnotation" #[highTypeToArg ty])) let initOpt := optionArg (init.map fun e => laurelOp "initializer" #[stmtExprToArg e]) - laurelOp "varDecl" #[ident name.text, typeOpt, initOpt] + laurelOp "varDecl" #[ident nameText, typeOpt, initOpt] | .Assign targets value => -- Grammar only supports single-target assign; use first target or placeholder let targetArg := match targets with diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index 77f223a66c..4180eddd06 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -240,7 +240,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do | _ => TransM.error s!"assignArg {repr assignArg} didn't match expected pattern for variable {name}" | .option _ none => pure none | _ => TransM.error s!"assignArg {repr assignArg} didn't match expected pattern for variable {name}" - return mkStmtExprMd (.LocalVariable name varType value) src + return mkStmtExprMd (.LocalVariable [{ name := name, type := varType }] value) src | q`Laurel.identifier, #[arg0] => let name ← translateIdent arg0 return mkStmtExprMd (.Identifier name) src diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index 0a3b9bf029..7fef7e03e3 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -63,7 +63,7 @@ def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .StaticCall callee args => modify fun s => { s with callees := callee :: s.callees }; for a in args do collectExprMd a | .IfThenElse c t e => collectExprMd c; collectExprMd t; if let some x := e then collectExprMd x | .Block stmts _ => for s in stmts do collectExprMd s - | .LocalVariable _ _ i => if let some x := i then collectExprMd x + | .LocalVariable _ i => if let some x := i then collectExprMd x | .While c invs d b => collectExprMd c; collectExprMd b; for inv in invs do collectExprMd inv; if let some x := d then collectExprMd x | .Return v => if let some x := v then collectExprMd x | .Assign assignTargets v => @@ -277,7 +277,7 @@ where if calleeWritesHeap then if valueUsed then let freshVar ← freshVarName - let varDecl := mkMd (.LocalVariable freshVar (computeExprType model exprMd) none) + let varDecl := mkMd (.LocalVariable [{ name := freshVar, type := computeExprType model exprMd }] none) let callWithHeap := ⟨ .Assign [mkMd (.Identifier heapVar), mkMd (.Identifier freshVar)] (⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source, md ⟩), source, md ⟩ @@ -308,9 +308,9 @@ where termination_by sizeOf remaining let stmts' ← processStmts 0 stmts return ⟨ .Block stmts' label, source, md ⟩ - | .LocalVariable n ty i => + | .LocalVariable params i => let i' ← match i with | some x => some <$> recurse x | none => pure none - return ⟨ .LocalVariable n ty i', source, md ⟩ + return ⟨ .LocalVariable params i', source, md ⟩ | .While c invs d b => let invs' ← invs.mapM (recurse ·) return ⟨ .While (← recurse c) invs' d (← recurse b false), source, md ⟩ diff --git a/Strata/Languages/Laurel/InferHoleTypes.lean b/Strata/Languages/Laurel/InferHoleTypes.lean index 7070e991a2..590ef858da 100644 --- a/Strata/Languages/Laurel/InferHoleTypes.lean +++ b/Strata/Languages/Laurel/InferHoleTypes.lean @@ -117,9 +117,10 @@ private def inferExpr (expr : StmtExprMd) (expectedType : HighTypeMd) : InferHol | target :: _ => computeExprType model target | _ => defaultHoleType return ⟨.Assign targets (← inferExpr value targetType), source, md⟩ - | .LocalVariable name ty init => + | .LocalVariable params init => + let ty := match params with | p :: _ => p.type | [] => defaultHoleType match init with - | some initExpr => return ⟨.LocalVariable name ty (some (← inferExpr initExpr ty)), source, md⟩ + | some initExpr => return ⟨.LocalVariable params (some (← inferExpr initExpr ty)), source, md⟩ | none => return expr | .While cond invs dec body => let dec' ← match dec with diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index f7f8129322..c88f873834 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -238,8 +238,8 @@ inductive StmtExpr : Type where | IfThenElse (cond : AstNode StmtExpr) (thenBranch : AstNode StmtExpr) (elseBranch : Option (AstNode StmtExpr)) /-- A sequence of statements with an optional label for `Exit`. -/ | Block (statements : List (AstNode StmtExpr)) (label : Option String) - /-- A local variable declaration with a type and optional initializer. The initializer must be set if this `StmtExpr` is pure. -/ - | LocalVariable (name : Identifier) (type : AstNode HighType) (initializer : Option (AstNode StmtExpr)) + /-- A local variable declaration with typed parameters and optional initializer. The initializer must be set if this `StmtExpr` is pure. Multiple parameters are only allowed when the initializer is a `StaticCall` to a procedure with multiple outputs. -/ + | LocalVariable (parameters : List Parameter) (initializer : Option (AstNode StmtExpr)) /-- A while loop with a condition, invariants, optional termination measure, and body. Only allowed in impure contexts. -/ | While (cond : AstNode StmtExpr) (invariants : List (AstNode StmtExpr)) (decreases : Option (AstNode StmtExpr)) diff --git a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean index 81d34ba1d6..4b02635458 100644 --- a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean +++ b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean @@ -9,6 +9,7 @@ public import Strata.Languages.Laurel.LaurelToCoreTranslator import Strata.Languages.Laurel.DesugarShortCircuit import Strata.Languages.Laurel.EliminateReturnsInExpression import Strata.Languages.Laurel.EliminateValueReturns +import Strata.Languages.Laurel.EliminateMultipleOutputs import Strata.Languages.Laurel.ConstrainedTypeElim import Strata.Languages.Core.Verifier @@ -110,6 +111,11 @@ private def runLaurelPasses (options : LaurelTranslateOptions) (program : Progra let (program, model) := (result.program, result.model) emit "ConstrainedTypeElim" program + let program := eliminateMultipleOutputs program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + emit "EliminateMultipleOutputs" program + let allDiags := resolutionErrors ++ diamondErrors ++ nonCompositeDiags ++ valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags return (program, model, allDiags) diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index f916dce425..efa9b78e22 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -277,14 +277,14 @@ def translateExpr (expr : StmtExprMd) | .Block (⟨ .Assume _, innerSrc, innerMd⟩ :: rest) label => _ ← disallowed (fileRangeToCoreMd innerSrc innerMd) "assumes are not YET supported in functions or contracts" translateExpr { val := StmtExpr.Block rest label, source := innerSrc, md := innerMd } boundVars isPureContext - | .Block (⟨ .LocalVariable name ty (some initializer), innerSrc, innerMd⟩ :: rest) label => do + | .Block (⟨ .LocalVariable [⟨ name, ty ⟩] (some initializer), innerSrc, innerMd⟩ :: rest) label => do let valueExpr ← translateExpr initializer boundVars isPureContext let bodyExpr ← translateExpr { val := StmtExpr.Block rest label, source := innerSrc, md := innerMd } (name :: boundVars) isPureContext disallowed (fileRangeToCoreMd innerSrc innerMd) "local variables in functions are not YET supported" -- This doesn't work because of a limitation in Core. -- let coreMonoType := translateType ty -- return .app () (.abs () (some coreMonoType) bodyExpr) valueExpr - | .Block (⟨ .LocalVariable name ty none, innerSrc, innerMd⟩ :: rest) label => + | .Block (⟨ .LocalVariable _params none, innerSrc, innerMd⟩ :: rest) label => disallowed (fileRangeToCoreMd innerSrc innerMd) "local variables in functions must have initializers" | .Block (⟨ .IfThenElse cond thenBranch (some elseBranch), innerSrc, innerMd⟩ :: rest) label => disallowed (fileRangeToCoreMd innerSrc innerMd) "if-then-else only supported as the last statement in a block" @@ -298,7 +298,7 @@ def translateExpr (expr : StmtExprMd) throwExprDiagnostic $ md.toDiagnostic s!"FieldSelect should have been eliminated by heap parameterization: {Std.ToFormat.format target}#{fieldId.text}" DiagnosticType.StrataBug | .Block _ _ => throwExprDiagnostic $ md.toDiagnostic "block expression should have been lowered in a separate pass" DiagnosticType.StrataBug - | .LocalVariable _ _ _ => + | .LocalVariable _ _ => throwExprDiagnostic $ md.toDiagnostic "local variable expression should be lowered in a separate pass" DiagnosticType.StrataBug | .Return _ => disallowed md "return expression should be lowered in a separate pass" @@ -370,37 +370,56 @@ def translateStmt (stmt : StmtExprMd) match label with | some l => return [Imperative.Stmt.block l innerStmts md] | none => return innerStmts - | .LocalVariable id ty initializer => - let coreMonoType ← translateType ty - let coreType := LTy.forAll [] coreMonoType - let ident := ⟨id.text, ()⟩ - match initializer with - | some (⟨ .StaticCall callee args, callSrc, callMd⟩) => - -- Check if this is a function or a procedure call - if model.isFunction callee then - -- Translate as expression (function application) - let coreExpr ← translateExpr { val := .StaticCall callee args, source := callSrc, md := callMd } + | .LocalVariable params initializer => + match params with + | [⟨id, ty⟩] => + let coreMonoType ← translateType ty + let coreType := LTy.forAll [] coreMonoType + let ident := ⟨id.text, ()⟩ + match initializer with + | some (⟨ .StaticCall callee args, callSrc, callMd⟩) => + -- Check if this is a function or a procedure call + if model.isFunction callee then + -- Translate as expression (function application) + let coreExpr ← translateExpr { val := .StaticCall callee args, source := callSrc, md := callMd } + return [Core.Statement.init ident coreType (.det coreExpr) md] + else + -- Translate as: var name; call name := callee(args) + let coreArgs ← args.mapM (fun a => translateExpr a) + let defaultExpr ← defaultExprForType ty + let initStmt := Core.Statement.init ident coreType (.det defaultExpr) md + let callStmt := Core.Statement.call [ident] callee.text coreArgs md + return [initStmt, callStmt] + | some (⟨ .InstanceCall .., _, _⟩) => + -- Instance method call as initializer: var name := target.method(args) + -- Havoc the result since instance methods may be on unmodeled types + let initStmt := Core.Statement.init ident coreType .nondet md + return [initStmt] + | some (⟨ .Hole _ _, _, _⟩) => + -- Hole initializer: treat as havoc (init without value) + return [Core.Statement.init ident coreType .nondet md] + | some initExpr => + let coreExpr ← translateExpr initExpr return [Core.Statement.init ident coreType (.det coreExpr) md] - else - -- Translate as: var name; call name := callee(args) + | none => + return [Core.Statement.init ident coreType .nondet md] + | _ => + -- Multi-parameter LocalVariable: should have been eliminated by EliminateMultipleOutputs + match initializer with + | some (⟨ .StaticCall callee args, callSrc, callMd⟩) => let coreArgs ← args.mapM (fun a => translateExpr a) - let defaultExpr ← defaultExprForType ty - let initStmt := Core.Statement.init ident coreType (.det defaultExpr) md - let callStmt := Core.Statement.call [ident] callee.text coreArgs md - return [initStmt, callStmt] - | some (⟨ .InstanceCall .., _, _⟩) => - -- Instance method call as initializer: var name := target.method(args) - -- Havoc the result since instance methods may be on unmodeled types - let initStmt := Core.Statement.init ident coreType .nondet md - return [initStmt] - | some (⟨ .Hole _ _, _, _⟩) => - -- Hole initializer: treat as havoc (init without value) - return [Core.Statement.init ident coreType .nondet md] - | some initExpr => - let coreExpr ← translateExpr initExpr - return [Core.Statement.init ident coreType (.det coreExpr) md] - | none => - return [Core.Statement.init ident coreType .nondet md] + let initStmts ← params.mapM fun p => do + let coreType := LTy.forAll [] (← translateType p.type) + let defaultExpr ← defaultExprForType p.type + pure (Core.Statement.init ⟨p.name.text, ()⟩ coreType (.det defaultExpr) md) + let idents := params.map fun p => ⟨p.name.text, ()⟩ + let callStmt := Core.Statement.call idents callee.text coreArgs md + return initStmts ++ [callStmt] + | _ => + let stmts ← params.mapM fun p => do + let coreType := LTy.forAll [] (← translateType p.type) + pure (Core.Statement.init ⟨p.name.text, ()⟩ coreType .nondet md) + return stmts | .Assign targets value => match targets with | [⟨ .Identifier targetId, _, _ ⟩] => diff --git a/Strata/Languages/Laurel/LaurelTypes.lean b/Strata/Languages/Laurel/LaurelTypes.lean index 0abc0cdcc2..7530fc8888 100644 --- a/Strata/Languages/Laurel/LaurelTypes.lean +++ b/Strata/Languages/Laurel/LaurelTypes.lean @@ -75,7 +75,7 @@ def computeExprType (model : SemanticModel) (expr : StmtExprMd) : HighTypeMd := computeExprType model last | none => ⟨ .TVoid, source, md ⟩ -- Statements - | .LocalVariable _ _ _ => ⟨ .TVoid, source, md ⟩ + | .LocalVariable _ _ => ⟨ .TVoid, source, md ⟩ | .While _ _ _ _ => ⟨ .TVoid, source, md ⟩ | .Exit _ => ⟨ .TVoid, source, md ⟩ | .Return _ => ⟨ .TVoid, source, md ⟩ diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index 9aa9045606..c4035d6438 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -212,7 +212,7 @@ private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd) let snapshotName ← freshTempFor varName let varType ← computeType target -- Snapshot goes before the assignment (cons pushes to front) - prepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, source, md⟩)), source, md⟩) + prepend (⟨.LocalVariable [{ name := snapshotName, type := varType }] (some (⟨.Identifier varName, source, md⟩)), source, md⟩) setSubst varName snapshotName | _ => pure () @@ -233,7 +233,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | .Hole false (some holeType) => -- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc) let holeVar ← freshCondVar - prepend (bare (.LocalVariable holeVar holeType none)) + prepend ⟨ .LocalVariable [{ name := holeVar, type := holeType }] none, source, md⟩ return bare (.Identifier holeVar) | .Assign targets value => @@ -271,7 +271,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let callResultVar ← freshCondVar let callResultType ← computeType expr let liftedCall := [ - ⟨ (.LocalVariable callResultVar callResultType none), source, md ⟩, + ⟨ (.LocalVariable [{ name := callResultVar, type := callResultType }] none), source, md ⟩, ⟨.Assign [bare (.Identifier callResultVar)] seqCall, source, md⟩ ] modify fun s => { s with prependedStmts := s.prependedStmts ++ liftedCall} @@ -312,7 +312,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do -- IfThenElse added first (cons puts it deeper), then declaration (cons puts it on top) -- Output order: declaration, then if-then-else prepend (⟨.IfThenElse seqCond thenBlock seqElse, source, md⟩) - prepend (bare (.LocalVariable condVar condType none)) + prepend ⟨ .LocalVariable [{ name := condVar, type := condType }] none, source, md ⟩ return bare (.Identifier condVar) else -- No assignments in branches — recurse normally @@ -327,19 +327,23 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let newStmts := (← stmts.reverse.mapM transformExpr).reverse return ⟨ .Block (← onlyKeepSideEffectStmtsAndLast newStmts) labelOption, source, md ⟩ - | .LocalVariable name ty initializer => - -- If the substitution map has an entry for this variable, it was + | .LocalVariable params initializer => + -- If the substitution map has an entry for any of these variables, it was -- assigned to the right and we need to lift this declaration so it -- appears before the snapshot that references it. - let hasSubst := (← get).subst.lookup name |>.isSome + let subst := (← get).subst + let hasSubst := params.any fun p => subst.lookup p.name |>.isSome if hasSubst then match initializer with | some initExpr => let seqInit ← transformExpr initExpr - prepend (⟨.LocalVariable name ty (some seqInit), expr.source, expr.md⟩) + prepend (⟨.LocalVariable params (some seqInit), expr.source, expr.md⟩) | none => - prepend (⟨.LocalVariable name ty none, expr.source, expr.md⟩) - return ⟨.Identifier (← getSubst name), expr.source, expr.md⟩ + prepend (⟨.LocalVariable params none, expr.source, expr.md⟩) + -- Return substitution for the first name + match params with + | p :: _ => return ⟨.Identifier (← getSubst p.name), expr.source, expr.md⟩ + | [] => return expr else return expr @@ -380,7 +384,7 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do let seqStmts ← stmts.mapM transformStmt return [bare (.Block seqStmts.flatten metadata)] - | .LocalVariable name ty initializer => + | .LocalVariable params initializer => match _ : initializer with | some initExprMd => -- If the initializer is a direct imperative StaticCall, don't lift it — @@ -394,18 +398,18 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do let seqInit ← transformExpr initExprMd let prepends ← takePrepends modify fun s => { s with subst := [] } - return prepends ++ [⟨.LocalVariable name ty (some seqInit), source, md⟩] + return prepends ++ [⟨.LocalVariable params (some seqInit), source, md⟩] else -- Pass through as-is; translateStmt will emit init + call let seqArgs ← args.mapM transformExpr let argPrepends ← takePrepends modify fun s => { s with subst := [] } - return argPrepends ++ [⟨.LocalVariable name ty (some ⟨.StaticCall callee seqArgs, initExprMd.source, initExprMd.md⟩), source, md⟩] + return argPrepends ++ [⟨.LocalVariable params (some ⟨.StaticCall callee seqArgs, initExprMd.source, initExprMd.md⟩), source, md⟩] | _ => let seqInit ← transformExpr initExprMd let prepends ← takePrepends modify fun s => { s with subst := [] } - return prepends ++ [⟨.LocalVariable name ty (some seqInit), source, md⟩] + return prepends ++ [⟨.LocalVariable params (some seqInit), source, md⟩] | none => return [stmt] diff --git a/Strata/Languages/Laurel/MapStmtExpr.lean b/Strata/Languages/Laurel/MapStmtExpr.lean index 3ca5fd7beb..9f83444775 100644 --- a/Strata/Languages/Laurel/MapStmtExpr.lean +++ b/Strata/Languages/Laurel/MapStmtExpr.lean @@ -39,8 +39,8 @@ def mapStmtExprM [Monad m] (f : StmtExprMd → m StmtExprMd) (expr : StmtExprMd) (← el.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source, md⟩ | .Block stmts label => pure ⟨.Block (← stmts.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e) label, source, md⟩ - | .LocalVariable name ty init => - pure ⟨.LocalVariable name ty (← init.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source, md⟩ + | .LocalVariable params init => + pure ⟨.LocalVariable params (← init.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source, md⟩ | .While cond invs dec body => pure ⟨.While (← mapStmtExprM f cond) (← invs.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e) diff --git a/Strata/Languages/Laurel/Resolution.lean b/Strata/Languages/Laurel/Resolution.lean index 4b0b1217e9..395ef02ae4 100644 --- a/Strata/Languages/Laurel/Resolution.lean +++ b/Strata/Languages/Laurel/Resolution.lean @@ -307,11 +307,13 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do withScope do let stmts' ← stmts.mapM resolveStmtExpr pure (.Block stmts' label) - | .LocalVariable name ty init => - let ty' ← resolveHighType ty + | .LocalVariable params init => let init' ← init.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) - let name' ← defineNameCheckDup name (.var name ty') - pure (.LocalVariable name' ty' init') + let params' ← params.mapM fun p => do + let ty' ← resolveHighType p.type + let name' ← defineNameCheckDup p.name (.var p.name ty') + pure { name := name', type := ty' } + pure (.LocalVariable params' init') | .While cond invs dec body => let cond' ← resolveStmtExpr cond let invs' ← invs.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) @@ -579,9 +581,8 @@ private def collectStmtExpr (map : Std.HashMap Nat ResolvedNode) (expr : StmtExp | some e => collectStmtExpr map e | none => map | .Block stmts _ => stmts.foldl collectStmtExpr map - | .LocalVariable name ty init => - let map := register map name (.var name ty) - let map := collectHighType map ty + | .LocalVariable params init => + let map := params.foldl (fun m p => register (collectHighType m p.type) p.name (.var p.name p.type)) map match init with | some i => collectStmtExpr map i | none => map diff --git a/Strata/Languages/Laurel/TypeHierarchy.lean b/Strata/Languages/Laurel/TypeHierarchy.lean index 30c3602393..5d972263fe 100644 --- a/Strata/Languages/Laurel/TypeHierarchy.lean +++ b/Strata/Languages/Laurel/TypeHierarchy.lean @@ -143,7 +143,7 @@ def validateDiamondFieldAccessesForStmtExpr (model : SemanticModel) match e with | some eb => errs ++ validateDiamondFieldAccessesForStmtExpr model eb | none => errs - | .LocalVariable _ _ (some init) => + | .LocalVariable _ (some init) => validateDiamondFieldAccessesForStmtExpr model init | .While c invs _ b => let errs := validateDiamondFieldAccessesForStmtExpr model c ++ @@ -214,7 +214,7 @@ def lowerNew (name : Identifier) (source : Option FileRange) (md : Imperative.Me let heapVar : Identifier := "$heap" let freshVar ← freshVarName let getCounter := mkMd (.StaticCall "Heap..nextReference!" [mkMd (.Identifier heapVar)]) - let saveCounter := mkMd (.LocalVariable freshVar ⟨.TInt, none, #[]⟩ (some getCounter)) + let saveCounter := mkMd (.LocalVariable [{ name := freshVar, type := ⟨.TInt, none, #[]⟩ }] (some getCounter)) let newHeap := mkMd (.StaticCall "increment" [mkMd (.Identifier heapVar)]) let updateHeap := mkMd (.Assign [mkMd (.Identifier heapVar)] newHeap) let compositeResult := mkMd (.StaticCall "MkComposite" [mkMd (.Identifier freshVar), mkMd (.StaticCall (name.text ++ "_TypeTag") [])]) diff --git a/Strata/Languages/Python/PythonToLaurel.lean b/Strata/Languages/Python/PythonToLaurel.lean index 7cb878359f..9364956b82 100644 --- a/Strata/Languages/Python/PythonToLaurel.lean +++ b/Strata/Languages/Python/PythonToLaurel.lean @@ -1153,7 +1153,7 @@ partial def translateAssign (ctx : TranslationContext) | some _ => throw (.userPythonError lhs.ann s!"'{annType}' is not a type") | _ => pure (AnyTy, "Any") - let initStmt := mkStmtExprMd (StmtExpr.LocalVariable n.val varTy (mkStmtExprMd .Hole)) + let initStmt := mkStmtExprMd (StmtExpr.LocalVariable [{ name := n.val, type := varTy }] (mkStmtExprMd .Hole)) let newctx := {ctx with variableTypes:=(n.val, trackType)::ctx.variableTypes} return (newctx, [initStmt] ++ exceptHavoc, true) | _ => return (ctx, [mkStmtExprMd .Hole] ++ exceptHavoc, false) @@ -1174,7 +1174,7 @@ partial def translateAssign (ctx : TranslationContext) let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] newExpr) md [assignStmt, initStmt] else - let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable n.val varType (some newExpr)) md + let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable [{ name := n.val, type := varType }] (some newExpr)) md [newStmt, initStmt] else if withException ctx fnname.text then [mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr, maybeExceptVar] rhs_trans) md] @@ -1185,7 +1185,7 @@ partial def translateAssign (ctx : TranslationContext) [mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] rhs_trans) md] else let varType := mkHighTypeMd (.UserDefined className) - let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable n.val varType (some rhs_trans)) md + let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable [{ name := n.val, type := varType }] (some rhs_trans)) md [newStmt] | _ => [mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] rhs_trans) md] newctx := match rhs_trans.val with @@ -1207,7 +1207,7 @@ partial def translateAssign (ctx : TranslationContext) -- If the annotation isn't a recognized type, prefer the -- inferred type from the RHS (e.g., overload dispatch). if isKnownType ctx annStr then annStr else inferType - let initStmt := mkStmtExprMd (StmtExpr.LocalVariable n.val AnyTy AnyNone) + let initStmt := mkStmtExprMd (StmtExpr.LocalVariable [{ name := n.val, type := AnyTy }] AnyNone) newctx := {ctx with variableTypes:=(n.val, type)::ctx.variableTypes} return (newctx, initStmt :: assignStmts, true) | .Subscript _ _ _ _ => @@ -1311,7 +1311,7 @@ def createVarDeclStmtsAndCtx (ctx : TranslationContext) (newDecls : List (String then acc else acc ++ [(n, ty)]) [] let hoistedDecls : List StmtExprMd ← newDecls.mapM fun (name, tyStr) => do let ty ← translateType ctx tyStr - pure $ mkStmtExprMd (StmtExpr.LocalVariable (name : String) ty (some (mkStmtExprMd .Hole))) + pure $ mkStmtExprMd (StmtExpr.LocalVariable [{ name := (name : String), type := ty }] (some (mkStmtExprMd .Hole))) let hoistedCtx := { ctx with variableTypes := ctx.variableTypes ++ (newDecls.map fun (n, ty) => if isCompositeType ctx ty then (n, ty) else (n, PyLauType.Any)) } @@ -1405,7 +1405,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang return (ctx, []) let newctx := {ctx with variableTypes:=(varName, varType)::ctx.variableTypes} let varType ← translateType ctx varType - let declStmt := mkStmtExprMd (StmtExpr.LocalVariable varName varType AnyNone) + let declStmt := mkStmtExprMd (StmtExpr.LocalVariable [{ name := varName, type := varType }] AnyNone) return (newctx, [declStmt]) -- If statement @@ -1462,7 +1462,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang | .Hole => let freshVar := s!"assert_cond_{test.toAst.ann.start.byteIdx}" let varType := mkHighTypeMd .TBool - let varDecl := mkStmtExprMd (StmtExpr.LocalVariable freshVar varType (some condExpr)) + let varDecl := mkStmtExprMd (StmtExpr.LocalVariable [{ name := freshVar, type := varType }] (some condExpr)) let varRef := mkStmtExprMd (StmtExpr.Identifier freshVar) ([varDecl], varRef, { ctx with variableTypes := ctx.variableTypes ++ [(freshVar, "bool")] }) | _ => ([], condExpr, ctx) @@ -1565,7 +1565,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let mgrExpr ← translateExpr currentCtx ctxExpr let mgrTy ← inferExprType currentCtx ctxExpr let mgrLauTy ← translateType currentCtx mgrTy - let mgrDecl := mkStmtExprMd (StmtExpr.LocalVariable mgrName mgrLauTy (some mgrExpr)) + let mgrDecl := mkStmtExprMd (StmtExpr.LocalVariable [{ name := mgrName, type := mgrLauTy }] (some mgrExpr)) let mgrRef := mkStmtExprMd (StmtExpr.Identifier mgrName) currentCtx := {currentCtx with variableTypes := currentCtx.variableTypes ++ [(mgrName, mgrTy)]} let enterCall := mkInstanceMethodCall mgrTy "__enter__" mgrRef [] md @@ -1579,7 +1579,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang setupStmts := setupStmts ++ [mgrDecl, assignStmt] else -- New variable — declare outside the block so it's visible after - let varDecl := mkStmtExprMd (StmtExpr.LocalVariable varName AnyTy (some enterCall)) + let varDecl := mkStmtExprMd (StmtExpr.LocalVariable [{ name := varName, type := AnyTy }] (some enterCall)) currentCtx := {currentCtx with variableTypes := currentCtx.variableTypes ++ [(varName, PyLauType.Any)]} setupStmts := setupStmts ++ [mgrDecl, varDecl] | none => @@ -1665,7 +1665,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang | _ => (target, [], []) let slices ← slices.mapM (translateExpr ctx) let tempVarDecls := (tempVars.zip slices).map λ (var, slice) => - mkStmtExprMd (.LocalVariable { text := var, md := default } AnyTy slice) + mkStmtExprMd (.LocalVariable [{ name := { text := var, md := default }, type := AnyTy }] slice) let rhs : Python.expr SourceRange := .BinOp sr target op value let pyNormalAssign : Python.stmt SourceRange := .Assign sr {val:= #[target], ann:= target.ann} rhs {val:= none, ann:= sr} @@ -1687,7 +1687,7 @@ partial def translateStmtList (ctx : TranslationContext) (stmts : List (Python.s end def prependExceptHandlingHelper (l: List StmtExprMd) : List StmtExprMd := - mkStmtExprMd (.LocalVariable "maybe_except" (mkCoreType "Error") (some NoError)) :: l + mkStmtExprMd (.LocalVariable [{ name := "maybe_except", type := mkCoreType "Error" }] (some NoError)) :: l partial def getNestedSubscripts (expr: Python.expr SourceRange) : List ( Python.expr SourceRange) := match expr with @@ -1825,7 +1825,7 @@ def translateFunctionBody (ctx : TranslationContext) (inputTypes : List (String let (varDecls, ctx) ← createVarDeclStmtsAndCtx ctx newDecls let (newctx, bodyStmts) ← translateStmtList ctx body let bodyStmts := prependExceptHandlingHelper (varDecls ++ bodyStmts) - let bodyStmts := (mkStmtExprMd (.LocalVariable "nullcall_ret" AnyTy (some AnyNone))) :: bodyStmts + let bodyStmts := (mkStmtExprMd (.LocalVariable [{ name := "nullcall_ret", type := AnyTy }] (some AnyNone))) :: bodyStmts return (mkStmtExprMd (StmtExpr.Block bodyStmts none), newctx) /-- Translate Python function to Laurel Procedure -/ @@ -2000,7 +2000,7 @@ def translateMethod (ctx : TranslationContext) (className : String) let paramCopies := nonSelfParams.map fun p => let origName := p.name.text let renamedName := "$in_" ++ origName - mkStmtExprMd (StmtExpr.LocalVariable origName p.type + mkStmtExprMd (StmtExpr.LocalVariable [{ name := origName, type := p.type }] (some (mkStmtExprMd (StmtExpr.Identifier renamedName)))) let bodyStmts := paramCopies ++ bodyStmts let bodyBlock := mkStmtExprMd (StmtExpr.Block bodyStmts none) diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean new file mode 100644 index 0000000000..183cb21b8b --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean @@ -0,0 +1,38 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util +open Strata + +namespace Strata.Laurel + +/-! # Multiple Output Functions + +Tests that functions with multiple output parameters are correctly handled +by the EliminateMultipleOutputs pass, which synthesizes a result datatype +and rewrites call sites. +-/ + +def multiOutputProgram := r" +function twoOutputs(x: int) + returns (a: int, b: int); + +procedure testMultiOut() { + var a: int; + var b: int; + a, b := twoOutputs(5); + assert a > 0 +//^^^^^^^^^^ error: assertion could not be proved +}; +" + +#guard_msgs (drop info, error) in +#eval testInputWithOffset "MultiOutput" multiOutputProgram 14 processLaurelFile + +end Strata.Laurel From d8fa4d409798f6dc5558d3a2d7fb722937f4eb23 Mon Sep 17 00:00:00 2001 From: keyboardDrummer-bot Date: Tue, 21 Apr 2026 14:18:02 +0000 Subject: [PATCH 2/4] Revert LocalVariable AST change, keep original (name, type, initializer) form Reverts the LocalVariable AST change from (parameters : List Parameter) back to (name : Identifier) (type : AstNode HighType) (initializer : ...). The EliminateMultipleOutputs pass now works with the original LocalVariable form. It rewrites both multi-target Assign and single-target LocalVariable/ Assign call sites to use the synthesized result datatype destructors. The test uses a destructuring assignment (single-target assign) as the LHS for the multi-output function call. --- .../Languages/Laurel/ConstrainedTypeElim.lean | 25 +++--- .../Laurel/CoreGroupingAndOrdering.lean | 2 +- .../Laurel/EliminateMultipleOutputs.lean | 43 +++++----- Strata/Languages/Laurel/FilterPrelude.lean | 4 +- .../AbstractToConcreteTreeTranslator.lean | 8 +- .../ConcreteToAbstractTreeTranslator.lean | 2 +- .../Laurel/HeapParameterization.lean | 8 +- Strata/Languages/Laurel/InferHoleTypes.lean | 5 +- Strata/Languages/Laurel/Laurel.lean | 4 +- .../Laurel/LaurelToCoreTranslator.lean | 83 +++++++------------ Strata/Languages/Laurel/LaurelTypes.lean | 2 +- .../Laurel/LiftImperativeExpressions.lean | 32 ++++--- Strata/Languages/Laurel/MapStmtExpr.lean | 4 +- Strata/Languages/Laurel/Resolution.lean | 15 ++-- Strata/Languages/Laurel/TypeHierarchy.lean | 4 +- Strata/Languages/Python/PythonToLaurel.lean | 26 +++--- .../Fundamentals/T22_MultipleOutputs.lean | 5 +- 17 files changed, 118 insertions(+), 154 deletions(-) diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index cdc418f83b..4658d5876b 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -92,8 +92,8 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM let source := expr.source let md := expr.md match expr.val with - | .LocalVariable params init => - ⟨.LocalVariable (params.map fun p => { p with type := resolveType ptMap p.type }) init, source, md⟩ + | .LocalVariable n ty init => + ⟨.LocalVariable n (resolveType ptMap ty) init, source, md⟩ | .Forall param trigger body => let param' := { param with type := resolveType ptMap param.type } -- With bottom-up traversal, `body` is already recursed into. The newly @@ -127,18 +127,15 @@ def elimStmt (ptMap : ConstrainedTypeMap) let source := stmt.source let md := stmt.md match _h : stmt.val with - | .LocalVariable params init => - for p in params do - let callOpt := constraintCallFor ptMap p.type.val p.name md (src := source) - if callOpt.isSome then modify fun pv => pv.insert p.name.text p.type.val + | .LocalVariable name ty init => + let callOpt := constraintCallFor ptMap ty.val name md (src := source) + if callOpt.isSome then modify fun pv => pv.insert name.text ty.val let (init', check) : Option StmtExprMd × List StmtExprMd := match init with - | none => - let calls := params.filterMap fun p => constraintCallFor ptMap p.type.val p.name md (src := source) - (none, calls.map fun c => ⟨.Assume c, source, md⟩) - | some _ => - let calls := params.filterMap fun p => constraintCallFor ptMap p.type.val p.name md (src := source) - (init, calls.map fun c => ⟨.Assert c, source, md⟩) - pure ([⟨.LocalVariable params init', source, md⟩] ++ check) + | none => match callOpt with + | some c => (none, [⟨.Assume c, source, md⟩]) + | none => (none, []) + | some _ => (init, callOpt.toList.map fun c => ⟨.Assert c, source, md⟩) + pure ([⟨.LocalVariable name ty init', source, md⟩] ++ check) | .Assign [target] _ => match target.val with | .Identifier name => do @@ -212,7 +209,7 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : let md := ct.witness.md let witnessId : Identifier := mkId "$witness" let witnessInit : StmtExprMd := - ⟨.LocalVariable [{ name := witnessId, type := resolveType ptMap ct.base }] (some ct.witness), src, md⟩ + ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), src, md⟩ let assert : StmtExprMd := ⟨.Assert (constraintCallFor ptMap (.UserDefined ct.name) witnessId md (src := src)).get!, src, md⟩ { name := mkId s!"$witness_{ct.name.text}" diff --git a/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean b/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean index fdd367cca2..1d8596235a 100644 --- a/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean +++ b/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean @@ -65,7 +65,7 @@ def collectStaticCallNames (expr : StmtExprMd) : List String := | .Assign targets v => targets.flatMap (fun t => collectStaticCallNames t) ++ collectStaticCallNames v - | .LocalVariable _ initOption => + | .LocalVariable _ _ initOption => match initOption with | some init => collectStaticCallNames init | none => [] diff --git a/Strata/Languages/Laurel/EliminateMultipleOutputs.lean b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean index ad31101e2d..582bb8ee69 100644 --- a/Strata/Languages/Laurel/EliminateMultipleOutputs.lean +++ b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean @@ -72,35 +72,32 @@ private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo) | .Assign targets ⟨.StaticCall callee args, callSrc, callMd⟩ => match infoMap.get? callee.text with | some info => - if targets.length == info.outputs.length then - let tempName := s!"${callee.text}$temp" - let tempParam : Parameter := { name := mkId tempName, type := mkTy (.UserDefined (mkId info.resultTypeName)) } - let tempDecl := mkMd (.LocalVariable [tempParam] - (some ⟨.StaticCall callee args, callSrc, callMd⟩)) - let assigns := targets.zipIdx.map fun (tgt, i) => - mkMd (.Assign [tgt] - (mkMd (.StaticCall (mkId (destructorName info i)) - [mkMd (.Identifier (mkId tempName))]))) - go rest (assigns.reverse ++ (tempDecl :: acc)) - else go rest (stmt :: acc) + let tempName := mkId s!"${callee.text}$temp" + let tempTy := mkTy (.UserDefined (mkId info.resultTypeName)) + let tempDecl := mkMd (.LocalVariable tempName tempTy + (some ⟨.StaticCall callee args, callSrc, callMd⟩)) + let assigns := targets.zipIdx.map fun (tgt, i) => + mkMd (.Assign [tgt] + (mkMd (.StaticCall (mkId (destructorName info i)) + [mkMd (.Identifier tempName)]))) + go rest (assigns.reverse ++ (tempDecl :: acc)) | none => go rest (stmt :: acc) - | .LocalVariable params (some ⟨.StaticCall callee args, callSrc, callMd⟩) => + | .LocalVariable name _ty (some ⟨.StaticCall callee args, callSrc, callMd⟩) => match infoMap.get? callee.text with | some info => - if info.outputs.length > 1 then - let tempName := s!"${callee.text}$temp" - let tempParam : Parameter := { name := mkId tempName, type := mkTy (.UserDefined (mkId info.resultTypeName)) } - let tempDecl := mkMd (.LocalVariable [tempParam] + match info.outputs with + | firstOut :: _ => + let tempName := mkId s!"${callee.text}$temp" + let tempTy := mkTy (.UserDefined (mkId info.resultTypeName)) + let tempDecl := mkMd (.LocalVariable tempName tempTy (some ⟨.StaticCall callee args, callSrc, callMd⟩)) - let localDecls := params.zipIdx.map fun (p, i) => - mkMd (.LocalVariable [p] - (some (mkMd (.StaticCall (mkId (destructorName info i)) - [mkMd (.Identifier (mkId tempName))])))) - go rest (localDecls.reverse ++ (tempDecl :: acc)) - else go rest (stmt :: acc) + let localDecl := mkMd (.LocalVariable name firstOut.type + (some (mkMd (.StaticCall (mkId (destructorName info 0)) + [mkMd (.Identifier tempName)])))) + go rest (localDecl :: tempDecl :: acc) + | [] => go rest (stmt :: acc) | none => go rest (stmt :: acc) | _ => go rest (stmt :: acc) - termination_by remaining.length go stmts [] /-- Rewrite blocks in a StmtExprMd tree to handle multi-output calls. -/ diff --git a/Strata/Languages/Laurel/FilterPrelude.lean b/Strata/Languages/Laurel/FilterPrelude.lean index 2bf7f8c459..ce7c6a3656 100644 --- a/Strata/Languages/Laurel/FilterPrelude.lean +++ b/Strata/Languages/Laurel/FilterPrelude.lean @@ -92,8 +92,8 @@ private partial def collectExprNames (expr : StmtExprMd) : CollectM Unit := do collectExprNames cond; collectExprNames thenB elseB.forM collectExprNames | .Block stmts _ => stmts.forM collectExprNames - | .LocalVariable params init => - params.forM fun p => collectHighTypeNames p.type + | .LocalVariable _ ty init => + collectHighTypeNames ty init.forM collectExprNames | .While cond invs dec body => collectExprNames cond; invs.forM collectExprNames diff --git a/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean index 6369f36cbc..44012f3960 100644 --- a/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean @@ -95,14 +95,10 @@ where match label with | none => laurelOp "block" #[semicolonSep stmtArgs] | some l => laurelOp "labelledBlock" #[semicolonSep stmtArgs, ident l] - | .LocalVariable params init => - -- Grammar only supports single-target varDecl; use first parameter or placeholder - let (nameText, ty) := match params with - | p :: _ => (p.name.text, p.type) - | [] => ("_", ⟨.TVoid, none, #[]⟩) + | .LocalVariable name ty init => let typeOpt := optionArg (some (laurelOp "typeAnnotation" #[highTypeToArg ty])) let initOpt := optionArg (init.map fun e => laurelOp "initializer" #[stmtExprToArg e]) - laurelOp "varDecl" #[ident nameText, typeOpt, initOpt] + laurelOp "varDecl" #[ident name.text, typeOpt, initOpt] | .Assign targets value => -- Grammar only supports single-target assign; use first target or placeholder let targetArg := match targets with diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index 4180eddd06..77f223a66c 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -240,7 +240,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do | _ => TransM.error s!"assignArg {repr assignArg} didn't match expected pattern for variable {name}" | .option _ none => pure none | _ => TransM.error s!"assignArg {repr assignArg} didn't match expected pattern for variable {name}" - return mkStmtExprMd (.LocalVariable [{ name := name, type := varType }] value) src + return mkStmtExprMd (.LocalVariable name varType value) src | q`Laurel.identifier, #[arg0] => let name ← translateIdent arg0 return mkStmtExprMd (.Identifier name) src diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index 7fef7e03e3..0a3b9bf029 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -63,7 +63,7 @@ def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .StaticCall callee args => modify fun s => { s with callees := callee :: s.callees }; for a in args do collectExprMd a | .IfThenElse c t e => collectExprMd c; collectExprMd t; if let some x := e then collectExprMd x | .Block stmts _ => for s in stmts do collectExprMd s - | .LocalVariable _ i => if let some x := i then collectExprMd x + | .LocalVariable _ _ i => if let some x := i then collectExprMd x | .While c invs d b => collectExprMd c; collectExprMd b; for inv in invs do collectExprMd inv; if let some x := d then collectExprMd x | .Return v => if let some x := v then collectExprMd x | .Assign assignTargets v => @@ -277,7 +277,7 @@ where if calleeWritesHeap then if valueUsed then let freshVar ← freshVarName - let varDecl := mkMd (.LocalVariable [{ name := freshVar, type := computeExprType model exprMd }] none) + let varDecl := mkMd (.LocalVariable freshVar (computeExprType model exprMd) none) let callWithHeap := ⟨ .Assign [mkMd (.Identifier heapVar), mkMd (.Identifier freshVar)] (⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source, md ⟩), source, md ⟩ @@ -308,9 +308,9 @@ where termination_by sizeOf remaining let stmts' ← processStmts 0 stmts return ⟨ .Block stmts' label, source, md ⟩ - | .LocalVariable params i => + | .LocalVariable n ty i => let i' ← match i with | some x => some <$> recurse x | none => pure none - return ⟨ .LocalVariable params i', source, md ⟩ + return ⟨ .LocalVariable n ty i', source, md ⟩ | .While c invs d b => let invs' ← invs.mapM (recurse ·) return ⟨ .While (← recurse c) invs' d (← recurse b false), source, md ⟩ diff --git a/Strata/Languages/Laurel/InferHoleTypes.lean b/Strata/Languages/Laurel/InferHoleTypes.lean index 590ef858da..7070e991a2 100644 --- a/Strata/Languages/Laurel/InferHoleTypes.lean +++ b/Strata/Languages/Laurel/InferHoleTypes.lean @@ -117,10 +117,9 @@ private def inferExpr (expr : StmtExprMd) (expectedType : HighTypeMd) : InferHol | target :: _ => computeExprType model target | _ => defaultHoleType return ⟨.Assign targets (← inferExpr value targetType), source, md⟩ - | .LocalVariable params init => - let ty := match params with | p :: _ => p.type | [] => defaultHoleType + | .LocalVariable name ty init => match init with - | some initExpr => return ⟨.LocalVariable params (some (← inferExpr initExpr ty)), source, md⟩ + | some initExpr => return ⟨.LocalVariable name ty (some (← inferExpr initExpr ty)), source, md⟩ | none => return expr | .While cond invs dec body => let dec' ← match dec with diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index c88f873834..f7f8129322 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -238,8 +238,8 @@ inductive StmtExpr : Type where | IfThenElse (cond : AstNode StmtExpr) (thenBranch : AstNode StmtExpr) (elseBranch : Option (AstNode StmtExpr)) /-- A sequence of statements with an optional label for `Exit`. -/ | Block (statements : List (AstNode StmtExpr)) (label : Option String) - /-- A local variable declaration with typed parameters and optional initializer. The initializer must be set if this `StmtExpr` is pure. Multiple parameters are only allowed when the initializer is a `StaticCall` to a procedure with multiple outputs. -/ - | LocalVariable (parameters : List Parameter) (initializer : Option (AstNode StmtExpr)) + /-- A local variable declaration with a type and optional initializer. The initializer must be set if this `StmtExpr` is pure. -/ + | LocalVariable (name : Identifier) (type : AstNode HighType) (initializer : Option (AstNode StmtExpr)) /-- A while loop with a condition, invariants, optional termination measure, and body. Only allowed in impure contexts. -/ | While (cond : AstNode StmtExpr) (invariants : List (AstNode StmtExpr)) (decreases : Option (AstNode StmtExpr)) diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index efa9b78e22..f916dce425 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -277,14 +277,14 @@ def translateExpr (expr : StmtExprMd) | .Block (⟨ .Assume _, innerSrc, innerMd⟩ :: rest) label => _ ← disallowed (fileRangeToCoreMd innerSrc innerMd) "assumes are not YET supported in functions or contracts" translateExpr { val := StmtExpr.Block rest label, source := innerSrc, md := innerMd } boundVars isPureContext - | .Block (⟨ .LocalVariable [⟨ name, ty ⟩] (some initializer), innerSrc, innerMd⟩ :: rest) label => do + | .Block (⟨ .LocalVariable name ty (some initializer), innerSrc, innerMd⟩ :: rest) label => do let valueExpr ← translateExpr initializer boundVars isPureContext let bodyExpr ← translateExpr { val := StmtExpr.Block rest label, source := innerSrc, md := innerMd } (name :: boundVars) isPureContext disallowed (fileRangeToCoreMd innerSrc innerMd) "local variables in functions are not YET supported" -- This doesn't work because of a limitation in Core. -- let coreMonoType := translateType ty -- return .app () (.abs () (some coreMonoType) bodyExpr) valueExpr - | .Block (⟨ .LocalVariable _params none, innerSrc, innerMd⟩ :: rest) label => + | .Block (⟨ .LocalVariable name ty none, innerSrc, innerMd⟩ :: rest) label => disallowed (fileRangeToCoreMd innerSrc innerMd) "local variables in functions must have initializers" | .Block (⟨ .IfThenElse cond thenBranch (some elseBranch), innerSrc, innerMd⟩ :: rest) label => disallowed (fileRangeToCoreMd innerSrc innerMd) "if-then-else only supported as the last statement in a block" @@ -298,7 +298,7 @@ def translateExpr (expr : StmtExprMd) throwExprDiagnostic $ md.toDiagnostic s!"FieldSelect should have been eliminated by heap parameterization: {Std.ToFormat.format target}#{fieldId.text}" DiagnosticType.StrataBug | .Block _ _ => throwExprDiagnostic $ md.toDiagnostic "block expression should have been lowered in a separate pass" DiagnosticType.StrataBug - | .LocalVariable _ _ => + | .LocalVariable _ _ _ => throwExprDiagnostic $ md.toDiagnostic "local variable expression should be lowered in a separate pass" DiagnosticType.StrataBug | .Return _ => disallowed md "return expression should be lowered in a separate pass" @@ -370,56 +370,37 @@ def translateStmt (stmt : StmtExprMd) match label with | some l => return [Imperative.Stmt.block l innerStmts md] | none => return innerStmts - | .LocalVariable params initializer => - match params with - | [⟨id, ty⟩] => - let coreMonoType ← translateType ty - let coreType := LTy.forAll [] coreMonoType - let ident := ⟨id.text, ()⟩ - match initializer with - | some (⟨ .StaticCall callee args, callSrc, callMd⟩) => - -- Check if this is a function or a procedure call - if model.isFunction callee then - -- Translate as expression (function application) - let coreExpr ← translateExpr { val := .StaticCall callee args, source := callSrc, md := callMd } - return [Core.Statement.init ident coreType (.det coreExpr) md] - else - -- Translate as: var name; call name := callee(args) - let coreArgs ← args.mapM (fun a => translateExpr a) - let defaultExpr ← defaultExprForType ty - let initStmt := Core.Statement.init ident coreType (.det defaultExpr) md - let callStmt := Core.Statement.call [ident] callee.text coreArgs md - return [initStmt, callStmt] - | some (⟨ .InstanceCall .., _, _⟩) => - -- Instance method call as initializer: var name := target.method(args) - -- Havoc the result since instance methods may be on unmodeled types - let initStmt := Core.Statement.init ident coreType .nondet md - return [initStmt] - | some (⟨ .Hole _ _, _, _⟩) => - -- Hole initializer: treat as havoc (init without value) - return [Core.Statement.init ident coreType .nondet md] - | some initExpr => - let coreExpr ← translateExpr initExpr + | .LocalVariable id ty initializer => + let coreMonoType ← translateType ty + let coreType := LTy.forAll [] coreMonoType + let ident := ⟨id.text, ()⟩ + match initializer with + | some (⟨ .StaticCall callee args, callSrc, callMd⟩) => + -- Check if this is a function or a procedure call + if model.isFunction callee then + -- Translate as expression (function application) + let coreExpr ← translateExpr { val := .StaticCall callee args, source := callSrc, md := callMd } return [Core.Statement.init ident coreType (.det coreExpr) md] - | none => - return [Core.Statement.init ident coreType .nondet md] - | _ => - -- Multi-parameter LocalVariable: should have been eliminated by EliminateMultipleOutputs - match initializer with - | some (⟨ .StaticCall callee args, callSrc, callMd⟩) => + else + -- Translate as: var name; call name := callee(args) let coreArgs ← args.mapM (fun a => translateExpr a) - let initStmts ← params.mapM fun p => do - let coreType := LTy.forAll [] (← translateType p.type) - let defaultExpr ← defaultExprForType p.type - pure (Core.Statement.init ⟨p.name.text, ()⟩ coreType (.det defaultExpr) md) - let idents := params.map fun p => ⟨p.name.text, ()⟩ - let callStmt := Core.Statement.call idents callee.text coreArgs md - return initStmts ++ [callStmt] - | _ => - let stmts ← params.mapM fun p => do - let coreType := LTy.forAll [] (← translateType p.type) - pure (Core.Statement.init ⟨p.name.text, ()⟩ coreType .nondet md) - return stmts + let defaultExpr ← defaultExprForType ty + let initStmt := Core.Statement.init ident coreType (.det defaultExpr) md + let callStmt := Core.Statement.call [ident] callee.text coreArgs md + return [initStmt, callStmt] + | some (⟨ .InstanceCall .., _, _⟩) => + -- Instance method call as initializer: var name := target.method(args) + -- Havoc the result since instance methods may be on unmodeled types + let initStmt := Core.Statement.init ident coreType .nondet md + return [initStmt] + | some (⟨ .Hole _ _, _, _⟩) => + -- Hole initializer: treat as havoc (init without value) + return [Core.Statement.init ident coreType .nondet md] + | some initExpr => + let coreExpr ← translateExpr initExpr + return [Core.Statement.init ident coreType (.det coreExpr) md] + | none => + return [Core.Statement.init ident coreType .nondet md] | .Assign targets value => match targets with | [⟨ .Identifier targetId, _, _ ⟩] => diff --git a/Strata/Languages/Laurel/LaurelTypes.lean b/Strata/Languages/Laurel/LaurelTypes.lean index 7530fc8888..0abc0cdcc2 100644 --- a/Strata/Languages/Laurel/LaurelTypes.lean +++ b/Strata/Languages/Laurel/LaurelTypes.lean @@ -75,7 +75,7 @@ def computeExprType (model : SemanticModel) (expr : StmtExprMd) : HighTypeMd := computeExprType model last | none => ⟨ .TVoid, source, md ⟩ -- Statements - | .LocalVariable _ _ => ⟨ .TVoid, source, md ⟩ + | .LocalVariable _ _ _ => ⟨ .TVoid, source, md ⟩ | .While _ _ _ _ => ⟨ .TVoid, source, md ⟩ | .Exit _ => ⟨ .TVoid, source, md ⟩ | .Return _ => ⟨ .TVoid, source, md ⟩ diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index c4035d6438..9aa9045606 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -212,7 +212,7 @@ private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd) let snapshotName ← freshTempFor varName let varType ← computeType target -- Snapshot goes before the assignment (cons pushes to front) - prepend (⟨.LocalVariable [{ name := snapshotName, type := varType }] (some (⟨.Identifier varName, source, md⟩)), source, md⟩) + prepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, source, md⟩)), source, md⟩) setSubst varName snapshotName | _ => pure () @@ -233,7 +233,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | .Hole false (some holeType) => -- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc) let holeVar ← freshCondVar - prepend ⟨ .LocalVariable [{ name := holeVar, type := holeType }] none, source, md⟩ + prepend (bare (.LocalVariable holeVar holeType none)) return bare (.Identifier holeVar) | .Assign targets value => @@ -271,7 +271,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let callResultVar ← freshCondVar let callResultType ← computeType expr let liftedCall := [ - ⟨ (.LocalVariable [{ name := callResultVar, type := callResultType }] none), source, md ⟩, + ⟨ (.LocalVariable callResultVar callResultType none), source, md ⟩, ⟨.Assign [bare (.Identifier callResultVar)] seqCall, source, md⟩ ] modify fun s => { s with prependedStmts := s.prependedStmts ++ liftedCall} @@ -312,7 +312,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do -- IfThenElse added first (cons puts it deeper), then declaration (cons puts it on top) -- Output order: declaration, then if-then-else prepend (⟨.IfThenElse seqCond thenBlock seqElse, source, md⟩) - prepend ⟨ .LocalVariable [{ name := condVar, type := condType }] none, source, md ⟩ + prepend (bare (.LocalVariable condVar condType none)) return bare (.Identifier condVar) else -- No assignments in branches — recurse normally @@ -327,23 +327,19 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let newStmts := (← stmts.reverse.mapM transformExpr).reverse return ⟨ .Block (← onlyKeepSideEffectStmtsAndLast newStmts) labelOption, source, md ⟩ - | .LocalVariable params initializer => - -- If the substitution map has an entry for any of these variables, it was + | .LocalVariable name ty initializer => + -- If the substitution map has an entry for this variable, it was -- assigned to the right and we need to lift this declaration so it -- appears before the snapshot that references it. - let subst := (← get).subst - let hasSubst := params.any fun p => subst.lookup p.name |>.isSome + let hasSubst := (← get).subst.lookup name |>.isSome if hasSubst then match initializer with | some initExpr => let seqInit ← transformExpr initExpr - prepend (⟨.LocalVariable params (some seqInit), expr.source, expr.md⟩) + prepend (⟨.LocalVariable name ty (some seqInit), expr.source, expr.md⟩) | none => - prepend (⟨.LocalVariable params none, expr.source, expr.md⟩) - -- Return substitution for the first name - match params with - | p :: _ => return ⟨.Identifier (← getSubst p.name), expr.source, expr.md⟩ - | [] => return expr + prepend (⟨.LocalVariable name ty none, expr.source, expr.md⟩) + return ⟨.Identifier (← getSubst name), expr.source, expr.md⟩ else return expr @@ -384,7 +380,7 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do let seqStmts ← stmts.mapM transformStmt return [bare (.Block seqStmts.flatten metadata)] - | .LocalVariable params initializer => + | .LocalVariable name ty initializer => match _ : initializer with | some initExprMd => -- If the initializer is a direct imperative StaticCall, don't lift it — @@ -398,18 +394,18 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do let seqInit ← transformExpr initExprMd let prepends ← takePrepends modify fun s => { s with subst := [] } - return prepends ++ [⟨.LocalVariable params (some seqInit), source, md⟩] + return prepends ++ [⟨.LocalVariable name ty (some seqInit), source, md⟩] else -- Pass through as-is; translateStmt will emit init + call let seqArgs ← args.mapM transformExpr let argPrepends ← takePrepends modify fun s => { s with subst := [] } - return argPrepends ++ [⟨.LocalVariable params (some ⟨.StaticCall callee seqArgs, initExprMd.source, initExprMd.md⟩), source, md⟩] + return argPrepends ++ [⟨.LocalVariable name ty (some ⟨.StaticCall callee seqArgs, initExprMd.source, initExprMd.md⟩), source, md⟩] | _ => let seqInit ← transformExpr initExprMd let prepends ← takePrepends modify fun s => { s with subst := [] } - return prepends ++ [⟨.LocalVariable params (some seqInit), source, md⟩] + return prepends ++ [⟨.LocalVariable name ty (some seqInit), source, md⟩] | none => return [stmt] diff --git a/Strata/Languages/Laurel/MapStmtExpr.lean b/Strata/Languages/Laurel/MapStmtExpr.lean index 9f83444775..3ca5fd7beb 100644 --- a/Strata/Languages/Laurel/MapStmtExpr.lean +++ b/Strata/Languages/Laurel/MapStmtExpr.lean @@ -39,8 +39,8 @@ def mapStmtExprM [Monad m] (f : StmtExprMd → m StmtExprMd) (expr : StmtExprMd) (← el.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source, md⟩ | .Block stmts label => pure ⟨.Block (← stmts.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e) label, source, md⟩ - | .LocalVariable params init => - pure ⟨.LocalVariable params (← init.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source, md⟩ + | .LocalVariable name ty init => + pure ⟨.LocalVariable name ty (← init.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source, md⟩ | .While cond invs dec body => pure ⟨.While (← mapStmtExprM f cond) (← invs.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e) diff --git a/Strata/Languages/Laurel/Resolution.lean b/Strata/Languages/Laurel/Resolution.lean index 395ef02ae4..4b0b1217e9 100644 --- a/Strata/Languages/Laurel/Resolution.lean +++ b/Strata/Languages/Laurel/Resolution.lean @@ -307,13 +307,11 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do withScope do let stmts' ← stmts.mapM resolveStmtExpr pure (.Block stmts' label) - | .LocalVariable params init => + | .LocalVariable name ty init => + let ty' ← resolveHighType ty let init' ← init.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) - let params' ← params.mapM fun p => do - let ty' ← resolveHighType p.type - let name' ← defineNameCheckDup p.name (.var p.name ty') - pure { name := name', type := ty' } - pure (.LocalVariable params' init') + let name' ← defineNameCheckDup name (.var name ty') + pure (.LocalVariable name' ty' init') | .While cond invs dec body => let cond' ← resolveStmtExpr cond let invs' ← invs.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) @@ -581,8 +579,9 @@ private def collectStmtExpr (map : Std.HashMap Nat ResolvedNode) (expr : StmtExp | some e => collectStmtExpr map e | none => map | .Block stmts _ => stmts.foldl collectStmtExpr map - | .LocalVariable params init => - let map := params.foldl (fun m p => register (collectHighType m p.type) p.name (.var p.name p.type)) map + | .LocalVariable name ty init => + let map := register map name (.var name ty) + let map := collectHighType map ty match init with | some i => collectStmtExpr map i | none => map diff --git a/Strata/Languages/Laurel/TypeHierarchy.lean b/Strata/Languages/Laurel/TypeHierarchy.lean index 5d972263fe..30c3602393 100644 --- a/Strata/Languages/Laurel/TypeHierarchy.lean +++ b/Strata/Languages/Laurel/TypeHierarchy.lean @@ -143,7 +143,7 @@ def validateDiamondFieldAccessesForStmtExpr (model : SemanticModel) match e with | some eb => errs ++ validateDiamondFieldAccessesForStmtExpr model eb | none => errs - | .LocalVariable _ (some init) => + | .LocalVariable _ _ (some init) => validateDiamondFieldAccessesForStmtExpr model init | .While c invs _ b => let errs := validateDiamondFieldAccessesForStmtExpr model c ++ @@ -214,7 +214,7 @@ def lowerNew (name : Identifier) (source : Option FileRange) (md : Imperative.Me let heapVar : Identifier := "$heap" let freshVar ← freshVarName let getCounter := mkMd (.StaticCall "Heap..nextReference!" [mkMd (.Identifier heapVar)]) - let saveCounter := mkMd (.LocalVariable [{ name := freshVar, type := ⟨.TInt, none, #[]⟩ }] (some getCounter)) + let saveCounter := mkMd (.LocalVariable freshVar ⟨.TInt, none, #[]⟩ (some getCounter)) let newHeap := mkMd (.StaticCall "increment" [mkMd (.Identifier heapVar)]) let updateHeap := mkMd (.Assign [mkMd (.Identifier heapVar)] newHeap) let compositeResult := mkMd (.StaticCall "MkComposite" [mkMd (.Identifier freshVar), mkMd (.StaticCall (name.text ++ "_TypeTag") [])]) diff --git a/Strata/Languages/Python/PythonToLaurel.lean b/Strata/Languages/Python/PythonToLaurel.lean index 9364956b82..7cb878359f 100644 --- a/Strata/Languages/Python/PythonToLaurel.lean +++ b/Strata/Languages/Python/PythonToLaurel.lean @@ -1153,7 +1153,7 @@ partial def translateAssign (ctx : TranslationContext) | some _ => throw (.userPythonError lhs.ann s!"'{annType}' is not a type") | _ => pure (AnyTy, "Any") - let initStmt := mkStmtExprMd (StmtExpr.LocalVariable [{ name := n.val, type := varTy }] (mkStmtExprMd .Hole)) + let initStmt := mkStmtExprMd (StmtExpr.LocalVariable n.val varTy (mkStmtExprMd .Hole)) let newctx := {ctx with variableTypes:=(n.val, trackType)::ctx.variableTypes} return (newctx, [initStmt] ++ exceptHavoc, true) | _ => return (ctx, [mkStmtExprMd .Hole] ++ exceptHavoc, false) @@ -1174,7 +1174,7 @@ partial def translateAssign (ctx : TranslationContext) let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] newExpr) md [assignStmt, initStmt] else - let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable [{ name := n.val, type := varType }] (some newExpr)) md + let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable n.val varType (some newExpr)) md [newStmt, initStmt] else if withException ctx fnname.text then [mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr, maybeExceptVar] rhs_trans) md] @@ -1185,7 +1185,7 @@ partial def translateAssign (ctx : TranslationContext) [mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] rhs_trans) md] else let varType := mkHighTypeMd (.UserDefined className) - let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable [{ name := n.val, type := varType }] (some rhs_trans)) md + let newStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable n.val varType (some rhs_trans)) md [newStmt] | _ => [mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] rhs_trans) md] newctx := match rhs_trans.val with @@ -1207,7 +1207,7 @@ partial def translateAssign (ctx : TranslationContext) -- If the annotation isn't a recognized type, prefer the -- inferred type from the RHS (e.g., overload dispatch). if isKnownType ctx annStr then annStr else inferType - let initStmt := mkStmtExprMd (StmtExpr.LocalVariable [{ name := n.val, type := AnyTy }] AnyNone) + let initStmt := mkStmtExprMd (StmtExpr.LocalVariable n.val AnyTy AnyNone) newctx := {ctx with variableTypes:=(n.val, type)::ctx.variableTypes} return (newctx, initStmt :: assignStmts, true) | .Subscript _ _ _ _ => @@ -1311,7 +1311,7 @@ def createVarDeclStmtsAndCtx (ctx : TranslationContext) (newDecls : List (String then acc else acc ++ [(n, ty)]) [] let hoistedDecls : List StmtExprMd ← newDecls.mapM fun (name, tyStr) => do let ty ← translateType ctx tyStr - pure $ mkStmtExprMd (StmtExpr.LocalVariable [{ name := (name : String), type := ty }] (some (mkStmtExprMd .Hole))) + pure $ mkStmtExprMd (StmtExpr.LocalVariable (name : String) ty (some (mkStmtExprMd .Hole))) let hoistedCtx := { ctx with variableTypes := ctx.variableTypes ++ (newDecls.map fun (n, ty) => if isCompositeType ctx ty then (n, ty) else (n, PyLauType.Any)) } @@ -1405,7 +1405,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang return (ctx, []) let newctx := {ctx with variableTypes:=(varName, varType)::ctx.variableTypes} let varType ← translateType ctx varType - let declStmt := mkStmtExprMd (StmtExpr.LocalVariable [{ name := varName, type := varType }] AnyNone) + let declStmt := mkStmtExprMd (StmtExpr.LocalVariable varName varType AnyNone) return (newctx, [declStmt]) -- If statement @@ -1462,7 +1462,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang | .Hole => let freshVar := s!"assert_cond_{test.toAst.ann.start.byteIdx}" let varType := mkHighTypeMd .TBool - let varDecl := mkStmtExprMd (StmtExpr.LocalVariable [{ name := freshVar, type := varType }] (some condExpr)) + let varDecl := mkStmtExprMd (StmtExpr.LocalVariable freshVar varType (some condExpr)) let varRef := mkStmtExprMd (StmtExpr.Identifier freshVar) ([varDecl], varRef, { ctx with variableTypes := ctx.variableTypes ++ [(freshVar, "bool")] }) | _ => ([], condExpr, ctx) @@ -1565,7 +1565,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let mgrExpr ← translateExpr currentCtx ctxExpr let mgrTy ← inferExprType currentCtx ctxExpr let mgrLauTy ← translateType currentCtx mgrTy - let mgrDecl := mkStmtExprMd (StmtExpr.LocalVariable [{ name := mgrName, type := mgrLauTy }] (some mgrExpr)) + let mgrDecl := mkStmtExprMd (StmtExpr.LocalVariable mgrName mgrLauTy (some mgrExpr)) let mgrRef := mkStmtExprMd (StmtExpr.Identifier mgrName) currentCtx := {currentCtx with variableTypes := currentCtx.variableTypes ++ [(mgrName, mgrTy)]} let enterCall := mkInstanceMethodCall mgrTy "__enter__" mgrRef [] md @@ -1579,7 +1579,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang setupStmts := setupStmts ++ [mgrDecl, assignStmt] else -- New variable — declare outside the block so it's visible after - let varDecl := mkStmtExprMd (StmtExpr.LocalVariable [{ name := varName, type := AnyTy }] (some enterCall)) + let varDecl := mkStmtExprMd (StmtExpr.LocalVariable varName AnyTy (some enterCall)) currentCtx := {currentCtx with variableTypes := currentCtx.variableTypes ++ [(varName, PyLauType.Any)]} setupStmts := setupStmts ++ [mgrDecl, varDecl] | none => @@ -1665,7 +1665,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang | _ => (target, [], []) let slices ← slices.mapM (translateExpr ctx) let tempVarDecls := (tempVars.zip slices).map λ (var, slice) => - mkStmtExprMd (.LocalVariable [{ name := { text := var, md := default }, type := AnyTy }] slice) + mkStmtExprMd (.LocalVariable { text := var, md := default } AnyTy slice) let rhs : Python.expr SourceRange := .BinOp sr target op value let pyNormalAssign : Python.stmt SourceRange := .Assign sr {val:= #[target], ann:= target.ann} rhs {val:= none, ann:= sr} @@ -1687,7 +1687,7 @@ partial def translateStmtList (ctx : TranslationContext) (stmts : List (Python.s end def prependExceptHandlingHelper (l: List StmtExprMd) : List StmtExprMd := - mkStmtExprMd (.LocalVariable [{ name := "maybe_except", type := mkCoreType "Error" }] (some NoError)) :: l + mkStmtExprMd (.LocalVariable "maybe_except" (mkCoreType "Error") (some NoError)) :: l partial def getNestedSubscripts (expr: Python.expr SourceRange) : List ( Python.expr SourceRange) := match expr with @@ -1825,7 +1825,7 @@ def translateFunctionBody (ctx : TranslationContext) (inputTypes : List (String let (varDecls, ctx) ← createVarDeclStmtsAndCtx ctx newDecls let (newctx, bodyStmts) ← translateStmtList ctx body let bodyStmts := prependExceptHandlingHelper (varDecls ++ bodyStmts) - let bodyStmts := (mkStmtExprMd (.LocalVariable [{ name := "nullcall_ret", type := AnyTy }] (some AnyNone))) :: bodyStmts + let bodyStmts := (mkStmtExprMd (.LocalVariable "nullcall_ret" AnyTy (some AnyNone))) :: bodyStmts return (mkStmtExprMd (StmtExpr.Block bodyStmts none), newctx) /-- Translate Python function to Laurel Procedure -/ @@ -2000,7 +2000,7 @@ def translateMethod (ctx : TranslationContext) (className : String) let paramCopies := nonSelfParams.map fun p => let origName := p.name.text let renamedName := "$in_" ++ origName - mkStmtExprMd (StmtExpr.LocalVariable [{ name := origName, type := p.type }] + mkStmtExprMd (StmtExpr.LocalVariable origName p.type (some (mkStmtExprMd (StmtExpr.Identifier renamedName)))) let bodyStmts := paramCopies ++ bodyStmts let bodyBlock := mkStmtExprMd (StmtExpr.Block bodyStmts none) diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean index 183cb21b8b..18c05d1b19 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean @@ -25,10 +25,9 @@ function twoOutputs(x: int) procedure testMultiOut() { var a: int; - var b: int; - a, b := twoOutputs(5); + a := twoOutputs(5); assert a > 0 -//^^^^^^^^^^ error: assertion could not be proved +//^^^^^^^^^^^^ error: assertion does not hold }; " From f18bcceabf5893f3944b3bfa919ae5f283441198 Mon Sep 17 00:00:00 2001 From: keyboardDrummer-bot Date: Wed, 22 Apr 2026 12:10:06 +0000 Subject: [PATCH 3/4] Emit error when LHS target count mismatches function output count MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The EliminateMultipleOutputs pass now validates that the number of assignment targets matches the number of outputs for multi-output function calls. When there is a mismatch (e.g., assigning a 2-output function to a single variable), the pass emits a diagnostic error and skips the transformation entirely. Changes: - Added validation phase that scans all call sites before transforming - eliminateMultipleOutputs now returns Program × List DiagnosticModel - Pipeline collects and reports the new diagnostics - Updated test to expect the mismatch error --- .../Laurel/EliminateMultipleOutputs.lean | 90 +++++++++++++------ .../Laurel/LaurelCompilationPipeline.lean | 4 +- .../Fundamentals/T22_MultipleOutputs.lean | 5 +- 3 files changed, 69 insertions(+), 30 deletions(-) diff --git a/Strata/Languages/Laurel/EliminateMultipleOutputs.lean b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean index 582bb8ee69..4945e6ddb8 100644 --- a/Strata/Languages/Laurel/EliminateMultipleOutputs.lean +++ b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean @@ -14,7 +14,10 @@ Transforms functional procedures (`.isFunctional = true`) with multiple outputs into procedures that return a single synthesized result datatype. Call sites are rewritten to destructure the result using the generated accessors. -This pass operates on `Program → Program`. +Emits an error when the number of LHS assignment targets does not match the +number of outputs of the called function. + +This pass operates on `Program → Program × List DiagnosticModel`. -/ namespace Strata.Laurel @@ -61,6 +64,51 @@ private def transformFunction (info : MultiOutInfo) (proc : Procedure) : Procedu private def destructorName (info : MultiOutInfo) (idx : Nat) : String := s!"{info.resultTypeName}..out{idx}" +private def mismatchError (source : Option FileRange) (callee : String) + (expected actual : Nat) : DiagnosticModel := + let msg := s!"call to '{callee}' has {actual} assignment target(s), but the function returns {expected} output(s)" + match source with + | some fr => DiagnosticModel.withRange fr msg + | none => DiagnosticModel.fromMessage msg + +/-- Scan statements for mismatched multi-output call sites, returning diagnostics. -/ +private def validateStmts (infoMap : Std.HashMap String MultiOutInfo) + (stmts : List StmtExprMd) : List DiagnosticModel := + stmts.filterMap fun stmt => + match stmt.val with + | .Assign targets ⟨.StaticCall callee _args, callSrc, _⟩ => + match infoMap.get? callee.text with + | some info => + if targets.length != info.outputs.length then + some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length targets.length) + else none + | none => none + | .LocalVariable _name _ty (some ⟨.StaticCall callee _args, callSrc, _⟩) => + match infoMap.get? callee.text with + | some info => + some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length 1) + | none => none + | _ => none + +/-- Validate all procedure bodies for mismatched call sites. -/ +private def validateExpr (infoMap : Std.HashMap String MultiOutInfo) + (expr : StmtExprMd) : List DiagnosticModel := + StateT.run (s := ([] : List DiagnosticModel)) ( + mapStmtExprM (m := StateM (List DiagnosticModel)) (fun e => do + match e.val with + | .Block stmts _label => + modify (· ++ validateStmts infoMap stmts) + | _ => pure () + return e) expr) |>.2 + +/-- Validate a procedure body. -/ +private def validateProcedure (infoMap : Std.HashMap String MultiOutInfo) + (proc : Procedure) : List DiagnosticModel := + match proc.body with + | .Transparent b => validateExpr infoMap (mkMd (.Block [b] none)) + | .Opaque _posts (some impl) _mods => validateExpr infoMap (mkMd (.Block [impl] none)) + | _ => [] + /-- Rewrite a statement list, replacing multi-output call patterns. -/ private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo) (stmts : List StmtExprMd) : List StmtExprMd := @@ -72,30 +120,18 @@ private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo) | .Assign targets ⟨.StaticCall callee args, callSrc, callMd⟩ => match infoMap.get? callee.text with | some info => - let tempName := mkId s!"${callee.text}$temp" - let tempTy := mkTy (.UserDefined (mkId info.resultTypeName)) - let tempDecl := mkMd (.LocalVariable tempName tempTy - (some ⟨.StaticCall callee args, callSrc, callMd⟩)) - let assigns := targets.zipIdx.map fun (tgt, i) => - mkMd (.Assign [tgt] - (mkMd (.StaticCall (mkId (destructorName info i)) - [mkMd (.Identifier tempName)]))) - go rest (assigns.reverse ++ (tempDecl :: acc)) - | none => go rest (stmt :: acc) - | .LocalVariable name _ty (some ⟨.StaticCall callee args, callSrc, callMd⟩) => - match infoMap.get? callee.text with - | some info => - match info.outputs with - | firstOut :: _ => + if targets.length != info.outputs.length then + go rest (stmt :: acc) + else let tempName := mkId s!"${callee.text}$temp" let tempTy := mkTy (.UserDefined (mkId info.resultTypeName)) let tempDecl := mkMd (.LocalVariable tempName tempTy (some ⟨.StaticCall callee args, callSrc, callMd⟩)) - let localDecl := mkMd (.LocalVariable name firstOut.type - (some (mkMd (.StaticCall (mkId (destructorName info 0)) - [mkMd (.Identifier tempName)])))) - go rest (localDecl :: tempDecl :: acc) - | [] => go rest (stmt :: acc) + let assigns := targets.zipIdx.map fun (tgt, i) => + mkMd (.Assign [tgt] + (mkMd (.StaticCall (mkId (destructorName info i)) + [mkMd (.Identifier tempName)]))) + go rest (assigns.reverse ++ (tempDecl :: acc)) | none => go rest (stmt :: acc) | _ => go rest (stmt :: acc) go stmts [] @@ -123,19 +159,23 @@ private def rewriteProcedure (infoMap : Std.HashMap String MultiOutInfo) | _ => proc /-- Eliminate multiple outputs from a Program. Only applies to functional procedures. -/ -def eliminateMultipleOutputs (program : Program) : Program := +def eliminateMultipleOutputs (program : Program) : Program × List DiagnosticModel := let infos := collectMultiOutFunctions program.staticProcedures - if infos.isEmpty then program else + if infos.isEmpty then (program, []) else let infoMap : Std.HashMap String MultiOutInfo := infos.foldl (fun m info => m.insert info.funcName info) {} + -- Validate all call sites first + let diags := program.staticProcedures.flatMap (validateProcedure infoMap) + -- If there are errors, return the program unchanged + if !diags.isEmpty then (program, diags) else let newDatatypes := infos.map mkResultDatatype let procs := program.staticProcedures.map fun f => match infoMap.get? f.name.text with | some info => rewriteProcedure infoMap (transformFunction info f) | none => rewriteProcedure infoMap f - { program with + ({ program with staticProcedures := procs - types := program.types ++ newDatatypes.map TypeDefinition.Datatype } + types := program.types ++ newDatatypes.map TypeDefinition.Datatype }, []) end -- public section end Strata.Laurel diff --git a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean index 4b02635458..570855bc66 100644 --- a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean +++ b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean @@ -111,13 +111,13 @@ private def runLaurelPasses (options : LaurelTranslateOptions) (program : Progra let (program, model) := (result.program, result.model) emit "ConstrainedTypeElim" program - let program := eliminateMultipleOutputs program + let (program, multiOutDiags) := eliminateMultipleOutputs program let result := resolve program (some model) let (program, model) := (result.program, result.model) emit "EliminateMultipleOutputs" program let allDiags := resolutionErrors ++ diamondErrors ++ nonCompositeDiags ++ - valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags + valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags ++ multiOutDiags return (program, model, allDiags) /-- diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean index 18c05d1b19..2a1b039db7 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean @@ -25,9 +25,8 @@ function twoOutputs(x: int) procedure testMultiOut() { var a: int; - a := twoOutputs(5); - assert a > 0 -//^^^^^^^^^^^^ error: assertion does not hold + a := twoOutputs(5) +// ^^^^^^^^^^^^^ error: call to 'twoOutputs' has 1 assignment target(s), but the function returns 2 output(s) }; " From 489a5be189d237f1a03fa9a8ea9f826fd2b6e333 Mon Sep 17 00:00:00 2001 From: keyboardDrummer-bot Date: Wed, 22 Apr 2026 12:23:30 +0000 Subject: [PATCH 4/4] Add no-error test for multi-output functions --- .../T22_MultipleOutputsNoError.lean | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean new file mode 100644 index 0000000000..29fea30c2d --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean @@ -0,0 +1,29 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util +open Strata + +namespace Strata.Laurel + +/-! # Multiple Output Functions (No Error) + +Tests that a function with multiple output parameters is accepted +by the EliminateMultipleOutputs pass when there is no mismatched call site. +-/ + +def multiOutputNoErrorProgram := r" +function twoOutputs(x: int) + returns (a: int, b: int); +" + +#guard_msgs (drop info, error) in +#eval testInputWithOffset "MultiOutputNoError" multiOutputNoErrorProgram 20 processLaurelFile + +end Strata.Laurel