diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 6b10099e4e..663eedcdeb 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -55,7 +55,7 @@ def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType) (varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd := match ty with | .UserDefined name => if ptMap.contains name.text then - some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Identifier varName, src⟩], src⟩ + some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Variable (.Local varName), src⟩], src⟩ else none | _ => none @@ -68,7 +68,7 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce if ptMap.contains parent.text then let paramId := { ct.valueName with uniqueId := none } let paramRef : StmtExprMd := - { val := .Identifier paramId, source := none } + { val := .Variable (.Local paramId), source := none } let parentCall : StmtExprMd := { val := .StaticCall (mkId s!"{parent.text}$constraint") [paramRef], source := none } { val := .PrimitiveOp .And [ct.constraint, parentCall], source := none } @@ -133,7 +133,7 @@ def elimStmt (ptMap : ConstrainedTypeMap) pure ([⟨.LocalVariable name ty init', source⟩] ++ check) | .Assign [target] _ => match target.val with - | .Identifier name => do + | .Local name => do match (← get).get? name.text with | some ty => let assert := (constraintCallFor ptMap ty name (src := source)).toList.map diff --git a/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean b/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean index 2a9287e605..d8899d2be4 100644 --- a/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean +++ b/Strata/Languages/Laurel/CoreGroupingAndOrdering.lean @@ -63,7 +63,11 @@ def collectStaticCallNames (expr : StmtExprMd) : List String := | none => [] | .Block stmts _ => stmts.flatMap (fun s => collectStaticCallNames s) | .Assign targets v => - targets.flatMap (fun t => collectStaticCallNames t) ++ + targets.flatMap (fun t => match ht : t.val with + | .Local _ => [] + | .Field target _ => + have : sizeOf target < sizeOf t := Variable.sizeOf_field_target t ht + collectStaticCallNames target) ++ collectStaticCallNames v | .LocalVariable _ _ initOption => match initOption with @@ -85,7 +89,7 @@ def collectStaticCallNames (expr : StmtExprMd) : List String := | some t => collectStaticCallNames t | none => []) ++ collectStaticCallNames body - | .FieldSelect t _ => collectStaticCallNames t + | .Variable (.Field t _) => collectStaticCallNames t | .PureFieldUpdate t _ v => collectStaticCallNames t ++ collectStaticCallNames v | .InstanceCall t _ args => collectStaticCallNames t ++ args.flatMap (fun a => collectStaticCallNames a) diff --git a/Strata/Languages/Laurel/EliminateHoles.lean b/Strata/Languages/Laurel/EliminateHoles.lean index a10cacd9d0..6710a99caa 100644 --- a/Strata/Languages/Laurel/EliminateHoles.lean +++ b/Strata/Languages/Laurel/EliminateHoles.lean @@ -52,7 +52,7 @@ private def mkHoleCall (holeType : HighTypeMd) : ElimHoleM StmtExprMd := do body := .Opaque [] none [] } modify fun s => { s with generatedFunctions := s.generatedFunctions ++ [holeProc] } - return bare (.StaticCall holeName (inputs.map (fun p => bare (.Identifier p.name)))) + return bare (.StaticCall holeName (inputs.map (fun p => bare (.Variable (.Local p.name))))) /-- Replace a deterministic `.Hole` with a call to a fresh uninterpreted function. Non-hole nodes pass through unchanged; recursion is handled by `mapStmtExprM`. -/ diff --git a/Strata/Languages/Laurel/EliminateValueReturns.lean b/Strata/Languages/Laurel/EliminateValueReturns.lean index b057c0d8c9..f465c6055c 100644 --- a/Strata/Languages/Laurel/EliminateValueReturns.lean +++ b/Strata/Languages/Laurel/EliminateValueReturns.lean @@ -27,7 +27,7 @@ private def eliminateValueReturnNode (outParam : Identifier) (stmt : StmtExprMd) match stmt.val with | .Return (some value) => -- Synthesized nodes use default metadata since no diagnostics should be reported on them - let target : StmtExprMd := { val := .Identifier outParam, source := none } + let target : VariableMd := { val := .Local outParam, source := none } let assign : StmtExprMd := { val := .Assign [target] value, source := none } let ret : StmtExprMd := { val := .Return none, source := stmt.source } { val := .Block [assign, ret] none, source := none } diff --git a/Strata/Languages/Laurel/FilterPrelude.lean b/Strata/Languages/Laurel/FilterPrelude.lean index 569dde80ad..8311937393 100644 --- a/Strata/Languages/Laurel/FilterPrelude.lean +++ b/Strata/Languages/Laurel/FilterPrelude.lean @@ -100,8 +100,11 @@ private partial def collectExprNames (expr : StmtExprMd) : CollectM Unit := do dec.forM collectExprNames collectExprNames body | .Assign targets value => - collectExprNames value; targets.forM collectExprNames - | .FieldSelect target _ => collectExprNames target + collectExprNames value + targets.forM fun t => match t.val with + | .Local _ => pure () + | .Field target _ => collectExprNames target + | .Variable (.Field target _) => collectExprNames target | .PureFieldUpdate target _ newVal => collectExprNames target; collectExprNames newVal | .PrimitiveOp _ args => args.forM collectExprNames @@ -120,7 +123,7 @@ private partial def collectExprNames (expr : StmtExprMd) : CollectM Unit := do | .ReferenceEquals lhs rhs => collectExprNames lhs; collectExprNames rhs | .Hole _ ty => ty.forM collectHighTypeNames | .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ - | .Identifier _ | .This | .Abstract | .All => pure () + | .Variable (.Local _) | .This | .Abstract | .All => pure () /-- Collect names from a procedure body. -/ private def collectBodyNames (body : Body) : CollectM Unit := do @@ -177,7 +180,7 @@ private partial def collectInvokeOnTargets (expr : StmtExprMd) | .StaticCall callee args => let rest ← args.flatMapM collectInvokeOnTargets return callee.text :: rest - | .Identifier _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ + | .Variable (.Local _) | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ => return [] | _ => throw s!"FilterPrelude.collectInvokeOnTargets: unexpected node in invokeOn expression" diff --git a/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean index 5532883ef4..ef0b7603cd 100644 --- a/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/AbstractToConcreteTreeTranslator.lean @@ -77,6 +77,12 @@ private def operationName : Operation → String | .Gt => "gt" | .Geq => "ge" | .StrConcat => "strConcat" -- Internal-only: public because `partial` prevents `private` in this section +mutual +partial def variableToArg (v : VariableMd) : Arg := + match v.val with + | .Local name => laurelOp "identifier" #[ident name.text] + | .Field target field => laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text] + partial def stmtExprToArg (s : StmtExprMd) : Arg := stmtExprValToArg s.val where @@ -90,7 +96,9 @@ where | .LiteralString s => laurelOp "string" #[.strlit sr s] | .Hole true _ => laurelOp "hole" | .Hole false _ => laurelOp "nondetHole" - | .Identifier name => laurelOp "identifier" #[ident name.text] + | .Variable (.Local name) => laurelOp "identifier" #[ident name.text] + | .Variable (.Field target field) => + laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text] | .Block stmts label => let stmtArgs := stmts.map stmtExprToArg |>.toArray match label with @@ -101,13 +109,16 @@ where let initOpt := optionArg (init.map fun e => laurelOp "initializer" #[stmtExprToArg e]) laurelOp "varDecl" #[ident name.text, typeOpt, initOpt] | .Assign targets value => - -- Grammar only supports single-target assign; use first target or placeholder - let targetArg := match targets with - | t :: _ => stmtExprToArg t - | [] => laurelOp "identifier" #[ident "_"] - laurelOp "assign" #[targetArg, stmtExprToArg value] - | .FieldSelect target field => - laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text] + if targets.length > 1 then + let targetArgs := targets.map fun t => match t.val with + | .Local name => ident name.text + | .Field _ _ => ident "_" + laurelOp "multiAssign" #[commaSep targetArgs.toArray, stmtExprToArg value] + else + let targetArg := match targets with + | t :: _ => variableToArg t + | [] => laurelOp "identifier" #[ident "_"] + laurelOp "assign" #[targetArg, stmtExprToArg value] | .StaticCall callee args => let calleeArg := laurelOp "identifier" #[ident callee.text] let argsArr := args.map stmtExprToArg |>.toArray @@ -165,9 +176,10 @@ where | .PureFieldUpdate target field value => -- Not directly in grammar; emit as assignment to field laurelOp "assign" #[ - laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text], + variableToArg ⟨.Field target field, none⟩, stmtExprToArg value ] +end private def parameterToArg (p : Parameter) : Arg := laurelOp "parameter" #[ident p.name.text, highTypeToArg p.type] diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index e3ea1f61c9..a366bc44f1 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -87,6 +87,12 @@ instance : Inhabited Parameter where def mkHighTypeMd (t : HighType) (source : Option FileRange) : HighTypeMd := { val := t, source := source } def mkStmtExprMd (e : StmtExpr) (source : Option FileRange) : StmtExprMd := { val := e, source := source } +/-- Convert a parsed StmtExprMd (from the assign target position) into a VariableMd. -/ +def stmtExprToVariable (e : StmtExprMd) : VariableMd := + match e.val with + | .Variable v => ⟨v, e.source⟩ + | _ => ⟨.Local { text := "_invalid_" }, e.source⟩ + def translateNat (arg : Arg) : TransM Nat := do let .num _ n := arg | TransM.error s!"translateNat expects num literal" @@ -243,12 +249,20 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do return mkStmtExprMd (.LocalVariable name varType value) src | q`Laurel.identifier, #[arg0] => let name ← translateIdent arg0 - return mkStmtExprMd (.Identifier name) src + return mkStmtExprMd (.Variable (.Local name)) src | q`Laurel.parenthesis, #[arg0] => translateStmtExpr arg0 | q`Laurel.assign, #[arg0, arg1] => let target ← translateStmtExpr arg0 + let varTarget := stmtExprToVariable target + let value ← translateStmtExpr arg1 + return mkStmtExprMd (.Assign [varTarget] value) src + | q`Laurel.multiAssign, #[targetsSeq, arg1] => + let targetIdents ← match targetsSeq with + | .seq _ .comma args => args.toList.mapM translateIdent + | _ => pure [] + let variables := targetIdents.map fun name => (⟨.Local name, name.source⟩ : VariableMd) let value ← translateStmtExpr arg1 - return mkStmtExprMd (.Assign [target] value) src + return mkStmtExprMd (.Assign variables value) src | q`Laurel.new, #[nameArg] => let name ← translateIdent nameArg return mkStmtExprMd (.New name) src @@ -263,7 +277,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do | q`Laurel.call, #[arg0, argsSeq] => let callee ← translateStmtExpr arg0 let calleeName := match callee.val with - | .Identifier name => name + | .Variable (.Local name) => name | _ => "" let argsList ← match argsSeq with | .seq _ .comma args => args.toList.mapM translateStmtExpr @@ -285,7 +299,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do let obj ← translateStmtExpr objArg let field ← translateIdent fieldArg let fieldSrc ← getArgFileRange fieldArg - return mkStmtExprMd (.FieldSelect obj field) fieldSrc + return mkStmtExprMd (.Variable (.Field obj field)) fieldSrc | q`Laurel.while, #[condArg, invSeqArg, bodyArg] => let cond ← translateStmtExpr condArg let invariants ← match invSeqArg with diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean index b36b1b54ea..710f2c8db8 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean @@ -9,7 +9,7 @@ module -- Laurel dialect definition, loaded from LaurelGrammar.st -- NOTE: Changes to LaurelGrammar.st are not automatically tracked by the build system. -- Update this file (e.g. this comment) to trigger a recompile after modifying LaurelGrammar.st. --- Last grammar change: added modifiesWildcard for `modifies *` and opaque keyword +-- Last grammar change: added multiAssign with parenthesized syntax public import Strata.DDM.Integration.Lean public meta import Strata.DDM.Integration.Lean diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st index cc2cc46083..f71e191909 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st @@ -47,6 +47,7 @@ op parenthesis (inner: StmtExpr): StmtExpr => "(" inner ")"; // Assignment op assign (target: StmtExpr, value: StmtExpr): StmtExpr => @[prec(10)] target " := " value; +op multiAssign (targets: CommaSepBy Ident, value: StmtExpr): StmtExpr => @[prec(10)] "(" targets ") := " value; // Binary operators op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60), leftassoc] lhs " + " rhs; diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index 02cf41f45b..4a1314fa3b 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -57,7 +57,7 @@ def collectExprMd (expr : StmtExprMd) : StateM AnalysisResult Unit := collectExp def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do match _: expr with - | .FieldSelect target _ => + | .Variable (.Field target _) => modify fun s => { s with readsHeapDirectly := true }; collectExprMd target | .InstanceCall target _ args => collectExprMd target; for a in args do collectExprMd a | .StaticCall callee args => modify fun s => { s with callees := callee :: s.callees }; for a in args do collectExprMd a @@ -69,11 +69,12 @@ def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .Assign assignTargets v => -- Check if any target is a field assignment (heap write) for ⟨assignTarget, _⟩ in assignTargets.attach do - match assignTarget.val with - | .FieldSelect _ _ => + match ht : assignTarget.val with + | .Field target _ => + have : sizeOf target < sizeOf assignTarget := Variable.sizeOf_field_target assignTarget ht modify fun s => { s with writesHeapDirectly := true } - | _ => pure () - collectExprMd assignTarget + collectExprMd target + | .Local _ => pure () collectExprMd v | .PureFieldUpdate t _ v => collectExprMd t; collectExprMd v | .PrimitiveOp _ args => for a in args do collectExprMd a @@ -237,6 +238,8 @@ def freshVarName : TransformM Identifier := do /-- Helper to wrap a StmtExpr into StmtExprMd with empty metadata -/ private def mkMd (e : StmtExpr) : StmtExprMd := { val := e, source := none } +/-- Helper to wrap a Variable into VariableMd with empty metadata -/ +private def mkMd' (v : Variable) : VariableMd := { val := v, source := none } /-- Resolve the owning composite type name for a field access by computing the target expression's type. @@ -260,12 +263,12 @@ where recurse (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := do let ⟨expr, source⟩ := exprMd match _h : expr with - | .FieldSelect selectTarget fieldName => do + | .Variable (.Field selectTarget fieldName) => do let some qualifiedName := resolveQualifiedFieldName model fieldName | return ⟨ .Hole, source ⟩ let valTy := (model.get fieldName).getType - let readExpr := ⟨ .StaticCall "readField" [mkMd (.Identifier heapVar), selectTarget, mkMd (.StaticCall qualifiedName [])], source ⟩ + let readExpr := ⟨ .StaticCall "readField" [mkMd (.Variable (.Local heapVar)), selectTarget, mkMd (.StaticCall qualifiedName [])], source ⟩ -- Unwrap Box: apply the appropriate destructor recordBoxConstructor model valTy.val return mkMd <| .StaticCall (boxDestructorName model valTy.val) [readExpr] @@ -278,13 +281,13 @@ where let freshVar ← freshVarName let varDecl := mkMd (.LocalVariable freshVar (computeExprType model exprMd) none) let callWithHeap := ⟨ .Assign - [mkMd (.Identifier heapVar), mkMd (.Identifier freshVar)] - (⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source ⟩), source ⟩ - return ⟨ .Block [varDecl, callWithHeap, mkMd (.Identifier freshVar)] none, source ⟩ + [mkMd' (.Local heapVar), mkMd' (.Local freshVar)] + (⟨ .StaticCall callee (mkMd (.Variable (.Local heapVar)) :: args'), source ⟩), source ⟩ + return ⟨ .Block [varDecl, callWithHeap, mkMd (.Variable (.Local freshVar))] none, source ⟩ else - return ⟨ .Assign [mkMd (.Identifier heapVar)] (⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source ⟩), source ⟩ + return ⟨ .Assign [mkMd' (.Local heapVar)] (⟨ .StaticCall callee (mkMd (.Variable (.Local heapVar)) :: args'), source ⟩), source ⟩ else if calleeReadsHeap then - return ⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source ⟩ + return ⟨ .StaticCall callee (mkMd (.Variable (.Local heapVar)) :: args'), source ⟩ else return ⟨ .StaticCall callee args', source ⟩ | .InstanceCall callTarget callee args => @@ -318,7 +321,7 @@ where return ⟨ .Return v', source ⟩ | .Assign targets v => match targets with - | [⟨.FieldSelect target fieldName, _⟩] => + | [⟨.Field target fieldName, _⟩] => let some qualifiedName := resolveQualifiedFieldName model fieldName | return ⟨ .Hole, source ⟩ let valTy := (model.get fieldName).getType @@ -327,21 +330,21 @@ where -- Wrap value in Box constructor recordBoxConstructor model valTy.val let boxedVal := mkMd <| .StaticCall (boxConstructorName model valTy.val) [v'] - let heapAssign := ⟨ .Assign [mkMd (.Identifier heapVar)] - (mkMd (.StaticCall "updateField" [mkMd (.Identifier heapVar), target', mkMd (.StaticCall qualifiedName []), boxedVal])), source ⟩ + let heapAssign := ⟨ .Assign [mkMd' (.Local heapVar)] + (mkMd (.StaticCall "updateField" [mkMd (.Variable (.Local heapVar)), target', mkMd (.StaticCall qualifiedName []), boxedVal])), source ⟩ if valueUsed then return ⟨ .Block [heapAssign, v'] none, source ⟩ else return heapAssign - | [fieldSelectMd] => - let tgt' ← recurse fieldSelectMd - return ⟨ .Assign [tgt'] (← recurse v), source ⟩ - | [] => - return ⟨ .Assign [] (← recurse v), source ⟩ - | tgt :: rest => - let tgt' ← recurse tgt - let targets' ← rest.mapM (recurse ·) - return ⟨ .Assign (tgt' :: targets') (← recurse v), source ⟩ + | _ => + let targets' ← targets.attach.mapM fun ⟨vm, hmem⟩ => do + match hvm : vm.val with + | .Local _ => pure vm + | .Field target fieldName => + have _h1 : sizeOf target < sizeOf vm := Variable.sizeOf_field_target vm hvm + have _h2 : sizeOf vm < sizeOf targets := List.sizeOf_lt_of_mem hmem + pure ⟨.Field (← recurse target) fieldName, vm.source⟩ + return ⟨ .Assign targets' (← recurse v), source ⟩ | .PureFieldUpdate t f v => return ⟨ .PureFieldUpdate (← recurse t) f (← recurse v), source ⟩ | .PrimitiveOp op args => let args' ← args.mapM (recurse ·) @@ -385,6 +388,7 @@ where | .ContractOf ty f => return ⟨ .ContractOf ty (← recurse f), source ⟩ | _ => return exprMd termination_by sizeOf exprMd + decreasing_by all_goals simp_wf; all_goals (try omega); all_goals (try term_by_mem); all_goals (try simp_all); all_goals omega def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : TransformM Procedure := do let heapName : Identifier := "$heap" @@ -408,7 +412,7 @@ def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : Transform let body' ← match proc.body with | .Transparent bodyExpr => -- First assign $heap_in to $heap, then transform body using $heap - let assignHeap := mkMd (.Assign [mkMd (.Identifier heapName)] (mkMd (.Identifier heapInName))) + let assignHeap := mkMd (.Assign [mkMd' (.Local heapName)] (mkMd (.Variable (.Local heapInName)))) let bodyExpr' ← heapTransformExpr heapName model bodyExpr bodyValueIsUsed pure (.Transparent (mkMd (.Block [assignHeap, bodyExpr'] none))) | .Opaque postconds impl modif => @@ -416,7 +420,7 @@ def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : Transform let postconds' ← postconds.mapM (·.mapM (heapTransformExpr heapName model)) let impl' ← match impl with | some implExpr => - let assignHeap := mkMd (.Assign [mkMd (.Identifier heapName)] (mkMd (.Identifier heapInName))) + let assignHeap := mkMd (.Assign [mkMd' (.Local heapName)] (mkMd (.Variable (.Local heapInName)))) let implExpr' ← heapTransformExpr heapName model implExpr bodyValueIsUsed pure (some (mkMd (.Block [assignHeap, implExpr'] none))) | none => pure none diff --git a/Strata/Languages/Laurel/InferHoleTypes.lean b/Strata/Languages/Laurel/InferHoleTypes.lean index 76244e3c7b..24770fbe3b 100644 --- a/Strata/Languages/Laurel/InferHoleTypes.lean +++ b/Strata/Languages/Laurel/InferHoleTypes.lean @@ -126,7 +126,9 @@ private def inferExpr (expr : StmtExprMd) (expectedType : HighTypeMd) : InferHol return ⟨.Block (← inferBlockStmts stmts expectedType) label, source⟩ | .Assign targets value => let targetType := match targets with - | target :: _ => computeExprType model target + | target :: _ => match target.val with + | .Local id => (model.get id).getType + | .Field _ fieldName => (model.get fieldName).getType | _ => defaultHoleType return ⟨.Assign targets (← inferExpr value targetType), source⟩ | .LocalVariable name ty init => diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index a8d6aa9860..4aeef0033f 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -235,6 +235,16 @@ inductive Body where /-- An external body for procedures that are not translated to Core (e.g., built-in primitives). -/ | External +/-- +A variable reference in Laurel. Used as the target of assignments and +as a general variable/field access expression. +-/ +inductive Variable where + /-- A local variable reference by name. -/ + | Local (name : Identifier) + /-- Read a field from a target expression. Combined with `Assign` for field writes. -/ + | Field (target : AstNode StmtExpr) (fieldName : Identifier) + /-- The unified statement-expression type for Laurel programs. @@ -266,12 +276,10 @@ inductive StmtExpr : Type where | LiteralString (value : String) /-- A decimal literal. -/ | LiteralDecimal (value : Decimal) - /-- A variable reference by name. -/ - | Identifier (name : Identifier) - /-- Assignment to one or more targets. Multiple targets are only allowed when the value is a `StaticCall` to a procedure with multiple outputs. -/ - | Assign (targets : List (AstNode StmtExpr)) (value : AstNode StmtExpr) - /-- Read a field from a target expression. Combined with `Assign` for field writes. -/ - | FieldSelect (target : AstNode StmtExpr) (fieldName : Identifier) + /-- A variable reference by name or field access. -/ + | Variable (ref : Variable) + /-- Assignment to one or more targets. Multiple targets are only allowed with identifier targets and when the value is a `StaticCall` to a procedure with multiple outputs. -/ + | Assign (targets : List (AstNode Variable)) (value : AstNode StmtExpr) /-- Update a field on a pure (value) type, producing a new value. -/ | PureFieldUpdate (target : AstNode StmtExpr) (fieldName : Identifier) (newValue : AstNode StmtExpr) /-- Call a static procedure by name with the given arguments. -/ @@ -324,10 +332,16 @@ end @[expose] abbrev HighTypeMd := AstNode HighType @[expose] abbrev StmtExprMd := AstNode StmtExpr +@[expose] abbrev VariableMd := AstNode Variable theorem AstNode.sizeOf_val_lt {t : Type} [SizeOf t] (e : AstNode t) : sizeOf e.val < sizeOf e := by cases e; grind +theorem Variable.sizeOf_field_target (v : AstNode Variable) + (h : v.val = .Field target fieldName) : sizeOf target < sizeOf v := by + have h1 := AstNode.sizeOf_val_lt v + rw [h] at h1; grind + theorem Condition.sizeOf_condition_lt (c : Condition) : sizeOf c.condition < 1 + sizeOf c := by cases c; grind diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index b24c41ea23..9de2154a8a 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -162,7 +162,7 @@ def translateExpr (expr : StmtExprMd) | .LiteralInt i => return .const () (.intConst i) | .LiteralString s => return .const () (.strConst s) | .LiteralDecimal d => return .const () (.realConst (Strata.Decimal.toRat d)) - | .Identifier name => + | .Variable (.Local name) => -- First check if this name is bound by an enclosing quantifier match boundVars.findIdx? (· == name) with | some idx => @@ -285,7 +285,7 @@ def translateExpr (expr : StmtExprMd) | .IsType _ _ => throwExprDiagnostic $ diagnosticFromSource expr.source "IsType should have been lowered" DiagnosticType.StrataBug | .New _ => throwExprDiagnostic $ diagnosticFromSource expr.source s!"New should have been eliminated by typeHierarchyTransform" DiagnosticType.StrataBug - | .FieldSelect target fieldId => + | .Variable (.Field target fieldId) => -- Field selects should have been eliminated by heap parameterization -- If we see one here, it's an error in the pipeline throwExprDiagnostic $ diagnosticFromSource expr.source s!"FieldSelect should have been eliminated by heap parameterization: {Std.ToFormat.format target}#{fieldId.text}" DiagnosticType.StrataBug @@ -399,7 +399,7 @@ def translateStmt (stmt : StmtExprMd) return [Core.Statement.init ident coreType .nondet md] | .Assign targets value => match targets with - | [⟨ .Identifier targetId, _ ⟩] => + | [⟨ .Local targetId, _ ⟩] => let ident := ⟨targetId.text, ()⟩ -- Check if RHS is a procedure call (not a function) match value.val with @@ -444,7 +444,7 @@ def translateStmt (stmt : StmtExprMd) let coreArgs ← args.mapM (fun a => translateExpr a) let lhsIdents := targets.filterMap fun t => match t.val with - | .Identifier name => some (⟨name.text, ()⟩) + | .Local name => some (⟨name.text, ()⟩) | _ => none let outArgs : List (Core.CallArg Core.Expression) := lhsIdents.map .outArg return [Core.Statement.call callee.text (coreArgs.map .inArg ++ outArgs) (astNodeToCoreMd value)] @@ -452,7 +452,7 @@ def translateStmt (stmt : StmtExprMd) -- Instance method call: havoc all target variables let havocStmts := targets.filterMap fun t => match t.val with - | .Identifier name => some (Core.Statement.havoc ⟨name.text, ()⟩ md) + | .Local name => some (Core.Statement.havoc ⟨name.text, ()⟩ md) | _ => none return (havocStmts) | _ => diff --git a/Strata/Languages/Laurel/LaurelTypes.lean b/Strata/Languages/Laurel/LaurelTypes.lean index 6f42948de3..208300878a 100644 --- a/Strata/Languages/Laurel/LaurelTypes.lean +++ b/Strata/Languages/Laurel/LaurelTypes.lean @@ -36,9 +36,9 @@ def computeExprType (model : SemanticModel) (expr : StmtExprMd) : HighTypeMd := | .LiteralString _ => ⟨ .TString, source ⟩ | .LiteralDecimal _ => ⟨ .TReal, source ⟩ -- Variables - | .Identifier id => (model.get id).getType + | .Variable (.Local id) => (model.get id).getType -- Field access - | .FieldSelect _ fieldName => (model.get fieldName).getType + | .Variable (.Field _ fieldName) => (model.get fieldName).getType -- Pure field update returns the same type as the target | .PureFieldUpdate target _ _ => computeExprType model target -- Calls — return the declared output type when available, fall back to Unknown otherwise diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index 1a32241e93..2256f311d2 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -90,6 +90,9 @@ private def emptyMd : Option String := none /-- Wrap a StmtExpr value with empty metadata -/ private def bare (v : StmtExpr) : StmtExprMd := ⟨v, none⟩ +/-- Wrap a Variable value with empty metadata -/ +private def bare' (v : Variable) : VariableMd := ⟨v, none⟩ + /-- Wrap a HighType value with empty metadata -/ private def bareType (v : HighType) : HighTypeMd := ⟨v, none⟩ @@ -157,6 +160,12 @@ private def computeType (expr : StmtExprMd) : LiftM HighTypeMd := do let s ← get return computeExprType s.model expr +private def computeVariableType (v : VariableMd) : LiftM HighTypeMd := do + let s ← get + match v.val with + | .Local id => return (s.model.get id).getType + | .Field _ fieldName => return (s.model.get fieldName).getType + /-- Check if an expression contains any assignments or imperative calls (recursively). -/ def containsAssignmentOrImperativeCall (model: SemanticModel) (expr : StmtExprMd) : Bool := match expr with @@ -201,18 +210,18 @@ Shared logic for lifting an assignment in expression position: prepends the assignment, creates before-snapshots for all targets, and updates substitutions. The value should already be transformed by the caller. -/ -private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd) +private def liftAssignExpr (targets : List (AstNode Variable)) (seqValue : StmtExprMd) (source : Option FileRange) : LiftM Unit := do -- Prepend the assignment itself prepend (⟨.Assign targets seqValue, source⟩) -- Create a before-snapshot for each target and update substitutions for target in targets do match target.val with - | .Identifier varName => + | .Local varName => let snapshotName ← freshTempFor varName - let varType ← computeType target + let varType ← computeVariableType target -- Snapshot goes before the assignment (cons pushes to front) - prepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, source⟩)), source⟩) + prepend (⟨.LocalVariable snapshotName varType (some (⟨.Variable (.Local varName), source⟩)), source⟩) setSubst varName snapshotName | _ => pure () @@ -225,8 +234,8 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do match expr with | AstNode.mk val source => match val with - | .Identifier name => - return ⟨.Identifier (← getSubst name), source⟩ + | .Variable (.Local name) => + return ⟨.Variable (.Local (← getSubst name)), source⟩ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ => return expr @@ -234,7 +243,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do -- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc) let holeVar ← freshCondVar prepend (bare (.LocalVariable holeVar holeType none)) - return bare (.Identifier holeVar) + return bare (.Variable (.Local holeVar)) | .Assign targets value => -- The expression result is the current substitution for the first target @@ -244,7 +253,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | _ => return expr let resultExpr ← match firstTarget.val with - | .Identifier varName => pure (⟨.Identifier (← getSubst varName), source⟩) + | .Local varName => pure (⟨.Variable (.Local (← getSubst varName)), source⟩) | _ => dbg_trace "Strata bug: non-identifier targets should have been removed before the lift expression phase"; return expr @@ -272,10 +281,10 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let callResultType ← computeType expr let liftedCall := [ ⟨ (.LocalVariable callResultVar callResultType none), source⟩, - ⟨.Assign [bare (.Identifier callResultVar)] seqCall, source⟩ + ⟨.Assign [bare' (.Local callResultVar)] seqCall, source⟩ ] modify fun s => { s with prependedStmts := s.prependedStmts ++ liftedCall} - return bare (.Identifier callResultVar) + return bare (.Variable (.Local callResultVar)) | .IfThenElse cond thenBranch elseBranch => let model := (← get).model @@ -294,14 +303,14 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do modify fun s => { s with prependedStmts := [], subst := [] } let seqThen ← transformExpr thenBranch let thenPrepends ← takePrepends - let thenBlock := bare (.Block (thenPrepends ++ [⟨.Assign [bare (.Identifier condVar)] seqThen, source⟩]) none) + let thenBlock := bare (.Block (thenPrepends ++ [⟨.Assign [bare' (.Local condVar)] seqThen, source⟩]) none) -- Process else-branch from scratch modify fun s => { s with prependedStmts := [], subst := [] } let seqElse ← match elseBranch with | some e => do let se ← transformExpr e let elsePrepends ← takePrepends - pure (some (bare (.Block (elsePrepends ++ [⟨.Assign [bare (.Identifier condVar)] se, source⟩]) none))) + pure (some (bare (.Block (elsePrepends ++ [⟨.Assign [bare' (.Local condVar)] se, source⟩]) none))) | none => pure none -- Restore outer state modify fun s => { s with subst := savedSubst, prependedStmts := savedPrepends } @@ -313,7 +322,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do -- Output order: declaration, then if-then-else prepend (⟨.IfThenElse seqCond thenBlock seqElse, source⟩) prepend (bare (.LocalVariable condVar condType none)) - return bare (.Identifier condVar) + return bare (.Variable (.Local condVar)) else -- No assignments in branches — recurse normally let seqCond ← transformExpr cond @@ -339,7 +348,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do prepend (⟨.LocalVariable name ty (some seqInit), expr.source⟩) | none => prepend (⟨.LocalVariable name ty none, expr.source⟩) - return ⟨.Identifier (← getSubst name), expr.source⟩ + return ⟨.Variable (.Local (← getSubst name)), expr.source⟩ else return expr diff --git a/Strata/Languages/Laurel/MapStmtExpr.lean b/Strata/Languages/Laurel/MapStmtExpr.lean index b236c37317..3d3dac952c 100644 --- a/Strata/Languages/Laurel/MapStmtExpr.lean +++ b/Strata/Languages/Laurel/MapStmtExpr.lean @@ -48,9 +48,15 @@ def mapStmtExprM [Monad m] (f : StmtExprMd → m StmtExprMd) (expr : StmtExprMd) | .Return v => pure ⟨.Return (← v.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e), source⟩ | .Assign targets value => - pure ⟨.Assign (← targets.attach.mapM fun ⟨e, _⟩ => mapStmtExprM f e) (← mapStmtExprM f value), source⟩ - | .FieldSelect target fieldName => - pure ⟨.FieldSelect (← mapStmtExprM f target) fieldName, source⟩ + let mappedTargets ← targets.attach.mapM fun ⟨v, _⟩ => do + match hv : v.val with + | .Local _ => pure v + | .Field target fieldName => + have : sizeOf target < sizeOf v := Variable.sizeOf_field_target v hv + pure ⟨.Field (← mapStmtExprM f target) fieldName, v.source⟩ + pure ⟨.Assign mappedTargets (← mapStmtExprM f value), source⟩ + | .Variable (.Field target fieldName) => + pure ⟨.Variable (.Field (← mapStmtExprM f target) fieldName), source⟩ | .PureFieldUpdate target fieldName newValue => pure ⟨.PureFieldUpdate (← mapStmtExprM f target) fieldName (← mapStmtExprM f newValue), source⟩ | .StaticCall callee args => @@ -88,7 +94,7 @@ def mapStmtExprM [Monad m] (f : StmtExprMd → m StmtExprMd) (expr : StmtExprMd) -- it must get its own arm above; otherwise all passes will silently -- skip recursion into those children. | .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ - | .Identifier _ | .New _ | .This | .Abstract | .All | .Hole .. => pure expr + | .Variable (.Local _) | .New _ | .This | .Abstract | .All | .Hole .. => pure expr f rebuilt termination_by sizeOf expr decreasing_by diff --git a/Strata/Languages/Laurel/ModifiesClauses.lean b/Strata/Languages/Laurel/ModifiesClauses.lean index ac9309a5de..34db4b4339 100644 --- a/Strata/Languages/Laurel/ModifiesClauses.lean +++ b/Strata/Languages/Laurel/ModifiesClauses.lean @@ -103,10 +103,10 @@ def buildModifiesEnsures (proc: Procedure) (model: SemanticModel) (modifiesExprs let entries := extractModifiesEntries model modifiesExprs let objName : Identifier := "$modifies_obj" let fldName : Identifier := "$modifies_fld" - let obj := mkMd <| .Identifier objName - let fld := mkMd <| .Identifier fldName - let heapIn := mkMd <| .Identifier heapInName - let heapOut := mkMd <| .Identifier heapOutName + let obj := mkMd <| .Variable (.Local objName) + let fld := mkMd <| .Variable (.Local fldName) + let heapIn := mkMd <| .Variable (.Local heapInName) + let heapOut := mkMd <| .Variable (.Local heapOutName) -- Build the "obj is allocated" condition: Composite..ref($obj) < $heap_in.nextReference let heapCounter := mkMd <| .StaticCall "Heap..nextReference!" [heapIn] let objRef := mkMd <| .StaticCall "Composite..ref!" [obj] diff --git a/Strata/Languages/Laurel/Resolution.lean b/Strata/Languages/Laurel/Resolution.lean index 00914cfd74..e84ed6475a 100644 --- a/Strata/Languages/Laurel/Resolution.lean +++ b/Strata/Languages/Laurel/Resolution.lean @@ -43,10 +43,10 @@ resolved sub-trees (e.g. a procedure's parameters already have their IDs). - `Constant` — named constant ### Reference nodes (use a name) -- `StmtExpr.Identifier` — variable reference +- `StmtExpr.Variable (.Local _)` — variable reference - `StmtExpr.StaticCall` — static procedure call - `StmtExpr.InstanceCall` — instance method call -- `StmtExpr.FieldSelect` — field access +- `StmtExpr.Variable (.Field _ _)` — field access - `StmtExpr.New` — object creation (references a type) - `StmtExpr.Exit` — exit a labelled block - `HighType.UserDefined` — type reference @@ -274,7 +274,7 @@ def resolveRef (name : Identifier) (source : Option FileRange := none) private def targetTypeName (target : StmtExprMd) : ResolveM (Option String) := do let s ← get match target.val with - | .Identifier ref => + | .Variable (.Local ref) => match s.scope.get? ref.text with | some (_, node) => match node.getType.val with @@ -384,17 +384,26 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do | .LiteralBool v => pure (.LiteralBool v) | .LiteralString v => pure (.LiteralString v) | .LiteralDecimal v => pure (.LiteralDecimal v) - | .Identifier ref => + | .Variable (.Local ref) => let ref' ← resolveRef ref source - pure (.Identifier ref') + pure (.Variable (.Local ref')) | .Assign targets value => - let targets' ← targets.mapM resolveStmtExpr + let targets' ← targets.attach.mapM fun ⟨v, _⟩ => do + match hv : v.val with + | .Local ref => + let ref' ← resolveRef ref source + pure ⟨.Local ref', v.source⟩ + | .Field target fieldName => + have : sizeOf target < sizeOf v := Variable.sizeOf_field_target v hv + let target' ← resolveStmtExpr target + let fieldName' ← resolveFieldRef target' fieldName source + pure ⟨.Field target' fieldName', v.source⟩ let value' ← resolveStmtExpr value pure (.Assign targets' value') - | .FieldSelect target fieldName => + | .Variable (.Field target fieldName) => let target' ← resolveStmtExpr target let fieldName' ← resolveFieldRef target' fieldName source - pure (.FieldSelect target' fieldName') + pure (.Variable (.Field target' fieldName')) | .PureFieldUpdate target fieldName newVal => let target' ← resolveStmtExpr target let fieldName' ← resolveFieldRef target' fieldName source @@ -649,11 +658,15 @@ private def collectStmtExpr (map : Std.HashMap Nat ResolvedNode) (expr : StmtExp let map := match dec with | some d => collectStmtExpr map d | none => map collectStmtExpr map body | .Return val => match val with | some v => collectStmtExpr map v | none => map - | .Identifier _ => map + | .Variable (.Local _) => map + | .Variable (.Field target _) => collectStmtExpr map target | .Assign targets value => - let map := targets.foldl collectStmtExpr map + let map := targets.foldl (fun m t => match ht : t.val with + | .Local _ => m + | .Field target _ => + have : sizeOf target < sizeOf t := Variable.sizeOf_field_target t ht + collectStmtExpr m target) map collectStmtExpr map value - | .FieldSelect target _ => collectStmtExpr map target | .PureFieldUpdate target _ newVal => let map := collectStmtExpr map target collectStmtExpr map newVal @@ -687,6 +700,8 @@ private def collectStmtExpr (map : Std.HashMap Nat ResolvedNode) (expr : StmtExp | .ContractOf _ fn => collectStmtExpr map fn | .New _ | .This | .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ | .Abstract | .All | .Hole _ _ => map +termination_by expr +decreasing_by all_goals (try have := AstNode.sizeOf_val_lt expr); all_goals term_by_mem private def collectBody (map : Std.HashMap Nat ResolvedNode) (body : Body) : Std.HashMap Nat ResolvedNode := diff --git a/Strata/Languages/Laurel/TypeHierarchy.lean b/Strata/Languages/Laurel/TypeHierarchy.lean index 8b15aaa76e..be9c860c02 100644 --- a/Strata/Languages/Laurel/TypeHierarchy.lean +++ b/Strata/Languages/Laurel/TypeHierarchy.lean @@ -122,7 +122,7 @@ Walk a StmtExpr AST and collect DiagnosticModel errors for diamond-inherited fie def validateDiamondFieldAccessesForStmtExpr (model : SemanticModel) (expr : StmtExprMd) : List DiagnosticModel := match _h : expr.val with - | .FieldSelect target fieldName => + | .Variable (.Field target fieldName) => let targetErrors := validateDiamondFieldAccessesForStmtExpr model target let fieldError := match (computeExprType model target).val with | .UserDefined typeName => @@ -135,7 +135,19 @@ def validateDiamondFieldAccessesForStmtExpr (model : SemanticModel) | .Block stmts _ => stmts.flatMap (fun s => validateDiamondFieldAccessesForStmtExpr model s) | .Assign targets value => - let targetErrors := targets.attach.foldl (fun acc ⟨t, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr model t) [] + let targetErrors := targets.attach.foldl (fun acc ⟨t, _⟩ => acc ++ match ht : t.val with + | .Local _ => [] + | .Field target fieldName => + have : sizeOf target < sizeOf t := Variable.sizeOf_field_target t ht + let innerErrors := validateDiamondFieldAccessesForStmtExpr model target + let fieldError := match (computeExprType model target).val with + | .UserDefined typeName => + if isDiamondInheritedField model typeName fieldName then + let fileRange := t.source.getD FileRange.unknown + [DiagnosticModel.withRange fileRange s!"fields that are inherited multiple times can not be accessed."] + else [] + | _ => [] + innerErrors ++ fieldError) [] targetErrors ++ validateDiamondFieldAccessesForStmtExpr model value | .IfThenElse c t e => let errs := validateDiamondFieldAccessesForStmtExpr model c ++ @@ -218,11 +230,11 @@ Lower `New name` to a block that: def lowerNew (name : Identifier) (source : Option FileRange) : THM StmtExprMd := do let heapVar : Identifier := "$heap" let freshVar ← freshVarName - let getCounter := mkMd (.StaticCall "Heap..nextReference!" [mkMd (.Identifier heapVar)]) + let getCounter := mkMd (.StaticCall "Heap..nextReference!" [mkMd (.Variable (.Local heapVar))]) let saveCounter := mkMd (.LocalVariable freshVar ⟨.TInt, none⟩ (some getCounter)) - let newHeap := mkMd (.StaticCall "increment" [mkMd (.Identifier heapVar)]) - let updateHeap := mkMd (.Assign [mkMd (.Identifier heapVar)] newHeap) - let compositeResult := mkMd (.StaticCall "MkComposite" [mkMd (.Identifier freshVar), mkMd (.StaticCall (name.text ++ "_TypeTag") [])]) + let newHeap := mkMd (.StaticCall "increment" [mkMd (.Variable (.Local heapVar))]) + let updateHeap := mkMd (.Assign [⟨.Local heapVar, none⟩] newHeap) + let compositeResult := mkMd (.StaticCall "MkComposite" [mkMd (.Variable (.Local freshVar)), mkMd (.StaticCall (name.text ++ "_TypeTag") [])]) return { val := .Block [saveCounter, updateHeap, compositeResult] none, source := source } /-- Local rewrite of `IsType` and `New` nodes. Recursion is handled by `mapStmtExprM`. -/ diff --git a/Strata/Languages/Python/PythonLaurelTypedExpr.lean b/Strata/Languages/Python/PythonLaurelTypedExpr.lean index 5b5e9549ba..6fdff13f64 100644 --- a/Strata/Languages/Python/PythonLaurelTypedExpr.lean +++ b/Strata/Languages/Python/PythonLaurelTypedExpr.lean @@ -48,7 +48,7 @@ def ofStmt {tp} (s : StmtExpr) (source : Option FileRange := none) : TypedStmtEx def identifier (v : String) (tp : HighType) (source : Option FileRange := none) : TypedStmtExpr tp := - .ofStmt (.Identifier (mkId v)) source + .ofStmt (.Variable (.Local (mkId v))) source def literalBool (v : Bool) (source : Option FileRange := none) : TypedStmtExpr .TBool := diff --git a/Strata/Languages/Python/PythonToLaurel.lean b/Strata/Languages/Python/PythonToLaurel.lean index 5336b929a1..bab330ebda 100644 --- a/Strata/Languages/Python/PythonToLaurel.lean +++ b/Strata/Languages/Python/PythonToLaurel.lean @@ -574,7 +574,7 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang -- Variable references | .Name _ name _ => - return mkStmtExprMd (StmtExpr.Identifier name.val) + return mkStmtExprMd (StmtExpr.Variable (.Local name.val)) -- Binary operations | .BinOp _ left op right => do @@ -663,7 +663,7 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang let freshVar := s!"$cmp_tmp_{e.toAst.ann.start.byteIdx}_{i}" let varDecl := mkStmtExprMd (StmtExpr.LocalVariable freshVar AnyTy (some comp)) tempDecls := tempDecls.push varDecl - operandRefs := operandRefs.push (mkStmtExprMd (StmtExpr.Identifier freshVar)) + operandRefs := operandRefs.push (mkStmtExprMd (StmtExpr.Variable (.Local freshVar))) else operandRefs := operandRefs.push comp let ⟨hOpSize⟩ ← guardProp (p := operandRefs.size ≥ n + 1) "operandRefs size < n+1" @@ -728,7 +728,7 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang let dict ← fields.foldlM (fun acc (fname, fty) => return mkStmtExprMd (.StaticCall "DictStrAny_cons" [mkStmtExprMd (.LiteralString fname), - ← wrapFieldInAny fty (mkStmtExprMd (.FieldSelect inner fname)), acc])) + ← wrapFieldInAny fty (mkStmtExprMd (.Variable (.Field inner fname))), acc])) (mkStmtExprMd (.StaticCall "DictStrAny_empty" [])) pure <| mkStmtExprMd (.StaticCall "from_ClassInstance" [mkStmtExprMd (.LiteralString ty), dict]) @@ -804,9 +804,9 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang | .Name _ name _ => if name.val == "self" && ctx.currentClassName.isSome then -- self.field in a method - field type is Any (builtins) or Composite (classes) - let fieldExpr := mkStmtExprMd (StmtExpr.FieldSelect - (mkStmtExprMd (StmtExpr.Identifier "self")) - attr.val) + let fieldExpr := mkStmtExprMd (StmtExpr.Variable (.Field + (mkStmtExprMd (StmtExpr.Variable (.Local "self"))) + attr.val)) let className := ctx.currentClassName.get! match tryLookupFieldHighType ctx className attr.val with | some (.UserDefined _) => @@ -829,7 +829,7 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang else -- Regular object.field access let objExpr ← translateExpr ctx obj - let fieldExpr := mkStmtExprMd (StmtExpr.FieldSelect objExpr attr.val) + let fieldExpr := mkStmtExprMd (StmtExpr.Variable (.Field objExpr attr.val)) let objType ← inferExprType ctx obj match tryLookupFieldHighType ctx objType attr.val with | some ty => wrapFieldInAny ty fieldExpr @@ -837,7 +837,7 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang | _ => -- Complex object expression - translate and access field let objExpr ← translateExpr ctx obj - let fieldExpr := mkStmtExprMd (StmtExpr.FieldSelect objExpr attr.val) + let fieldExpr := mkStmtExprMd (StmtExpr.Variable (.Field objExpr attr.val)) let objType ← inferExprType ctx obj match tryLookupFieldHighType ctx objType attr.val with | some ty => wrapFieldInAny ty fieldExpr @@ -1121,7 +1121,7 @@ partial def translateExprAsReceiver (ctx : TranslationContext) match tryLookupFieldHighType ctx objType fieldAttr.val with | some (.UserDefined _) => let objExpr ← translateExprAsReceiver ctx obj - pure <| mkStmtExprMd (StmtExpr.FieldSelect objExpr fieldAttr.val) + pure <| mkStmtExprMd (StmtExpr.Variable (.Field objExpr fieldAttr.val)) | _ => translateExpr ctx e | _ => translateExpr ctx e @@ -1165,7 +1165,7 @@ partial def translateCall (ctx : TranslationContext) | .Attribute _ (.Name _ receiverName _) _ _ => if receiverName.val ∈ ctx.variableTypes.unzip.1 then [mkStmtExprMd (StmtExpr.Assign - [mkStmtExprMd (StmtExpr.Identifier receiverName.val)] + [⟨.Local receiverName.val, none⟩] (mkStmtExprMd .Hole))] else [] | _ => [] @@ -1175,7 +1175,7 @@ partial def translateCall (ctx : TranslationContext) | some (varName, ty) => if ty == PyLauType.Any then [mkStmtExprMd (StmtExpr.Assign - [mkStmtExprMd (StmtExpr.Identifier varName)] + [⟨.Local varName, none⟩] (mkStmtExprMd (.Hole false none)))] else [] | _ => [] @@ -1310,9 +1310,15 @@ def withException (ctx : TranslationContext) (funcname: String) : Bool := | some sig => hasErrorOutput sig | none => false -def freeVar (name: String) := mkStmtExprMd (.Identifier name) -def maybeExceptVar := freeVar "maybe_except" -def nullcall_var := freeVar "nullcall_ret" +def freeVar (name: String) := mkStmtExprMd (.Variable (.Local name)) +def freeVarTarget (name: String) : VariableMd := ⟨.Local name, none⟩ +/-- Convert a StmtExprMd (expected to be a Variable) to a VariableMd for use as an assign target. -/ +def toVarTarget (e : StmtExprMd) : VariableMd := + match e.val with + | .Variable v => ⟨v, e.source⟩ + | _ => ⟨.Local { text := "_invalid_" }, e.source⟩ +def maybeExceptVar := freeVarTarget "maybe_except" +def nullcall_var := freeVarTarget "nullcall_ret" partial def translateAssign (ctx : TranslationContext) (lhs: Python.expr SourceRange) @@ -1362,7 +1368,7 @@ partial def translateAssign (ctx : TranslationContext) match lhs with | .Name _ n _ => if n.val ∈ ctx.variableTypes.unzip.1 then - let targetExpr := mkStmtExprMd (StmtExpr.Identifier n.val) + let targetExpr : VariableMd := ⟨.Local n.val, none⟩ return (ctx, [mkStmtExprMd (StmtExpr.Assign [targetExpr] rhs_trans)] ++ exceptHavoc, true) else -- Use type annotation if it matches a known composite type @@ -1381,14 +1387,14 @@ partial def translateAssign (ctx : TranslationContext) let mut newctx := ctx match lhs with | .Name _ n _ => - let targetExpr := mkStmtExprMd (StmtExpr.Identifier n.val) + let targetExpr : VariableMd := ⟨.Local n.val, none⟩ let assignStmts := match rhs_trans.val with | .StaticCall fnname args => if let some (ImportedSymbol.compositeType laurelName) := ctx.importedSymbols[fnname.text]? then let resolvedId := mkId laurelName let newExpr := mkStmtExprMd (StmtExpr.New resolvedId) let varType := mkHighTypeMd (.UserDefined resolvedId) - let selfRef := mkStmtExprMd (StmtExpr.Identifier n.val) + let selfRef := mkStmtExprMd (StmtExpr.Variable (.Local n.val)) let initStmt := mkInstanceMethodCall laurelName "__init__" selfRef args source if n.val ∈ ctx.variableTypes.unzip.1 then let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] newExpr) source @@ -1437,7 +1443,7 @@ partial def translateAssign (ctx : TranslationContext) let slices ← slices.mapM (translateExpr ctx) let source := sourceRangeToSource ctx.filePath lhs.toAst.ann let anySetsExpr := mkStmtExprMdWithLoc (StmtExpr.StaticCall "Any_sets!" [ListAny_mk slices, target, rhs_trans]) source - let assignStmts := [mkStmtExprMdWithLoc (StmtExpr.Assign [target] anySetsExpr) source] + let assignStmts := [mkStmtExprMdWithLoc (StmtExpr.Assign [toVarTarget target] anySetsExpr) source] return (ctx,assignStmts, false) | _ => throw (.internalError "Invalid Subscript Expr") | .Attribute _ obj attr _ => @@ -1445,9 +1451,9 @@ partial def translateAssign (ctx : TranslationContext) | .Name _ name _ => if name.val == "self" && ctx.currentClassName.isSome then -- self.field : type = value in a method - let fieldAccess := mkStmtExprMd (StmtExpr.FieldSelect - (mkStmtExprMd (StmtExpr.Identifier "self")) - attr.val) + let fieldAccess := mkStmtExprMd (StmtExpr.Variable (.Field + (mkStmtExprMd (StmtExpr.Variable (.Local "self"))) + attr.val)) -- When the annotation is a composite type, the RHS (which is Any) -- cannot be assigned directly; use New to initialize the field. let rhs' ← match annotation with @@ -1457,11 +1463,11 @@ partial def translateAssign (ctx : TranslationContext) pure (mkStmtExprMd (StmtExpr.New (mkId laurelName))) else pure rhs_trans | none => pure rhs_trans - let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [fieldAccess] rhs') source + let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [toVarTarget fieldAccess] rhs') source return (ctx, [assignStmt], true) else let targetExpr ← translateExpr ctx lhs -- This will handle self.field via translateExpr - let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] rhs_trans) source + let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [toVarTarget targetExpr] rhs_trans) source return (ctx, [assignStmt], true) | _ => throw (.unsupportedConstruct "Assignment targets not yet supported" (toString (repr lhs))) | _ => throw (.unsupportedConstruct "Assignment targets not yet supported" (toString (repr lhs))) @@ -1590,7 +1596,7 @@ def getExceptionCheckPreamble (ctx : TranslationContext) (e : StmtExprMd) (varNa ([], e) else if containsUserCall ctx e then let varDecl := mkStmtExprMd (StmtExpr.LocalVariable varName AnyTy (some e)) - let varRef := mkStmtExprMd (StmtExpr.Identifier varName) + let varRef := mkStmtExprMd (StmtExpr.Variable (.Local varName)) ([varDecl, mkExceptionCheckAssert varRef "Check exception"], varRef) else (getExceptionAssertions ctx e, e) @@ -1647,7 +1653,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let annStr := pyExprToString annotation match typeTester? annStr with | some testerName => - let varExpr := mkStmtExprMd (StmtExpr.Identifier n.val) + let varExpr := mkStmtExprMd (StmtExpr.Variable (.Local n.val)) let cond := mkStmtExprMd (StmtExpr.StaticCall testerName [varExpr]) [mkStmtExprMdWithLoc (StmtExpr.Assert { condition := cond }) md] | none => [] @@ -1702,7 +1708,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let (preamble, eRef) := getExceptionCheckPreamble ctx e s!"$ret_exc_{expr.toAst.ann.start.byteIdx}" -- Coerce Composite return values to Any for LaurelResult : Any let eRef ← coerceToAny ctx expr eRef - let assign := mkStmtExprMdWithLoc (StmtExpr.Assign [mkStmtExprMd (StmtExpr.Identifier PyLauFuncReturnVar)] eRef) md + let assign := mkStmtExprMdWithLoc (StmtExpr.Assign [⟨.Local PyLauFuncReturnVar, none⟩] eRef) md .ok $ preamble ++ [assign, mkStmtExprMdWithLoc (StmtExpr.Exit "$body") md] | none => .ok [mkStmtExprMdWithLoc (StmtExpr.Exit "$body") md] return (ctx, stmts) @@ -1721,7 +1727,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let freshVar := s!"assert_cond_{test.toAst.ann.start.byteIdx}" let varType := mkHighTypeMd .TBool let varDecl := mkStmtExprMd (StmtExpr.LocalVariable freshVar varType (some condExpr)) - let varRef := mkStmtExprMd (StmtExpr.Identifier freshVar) + let varRef := mkStmtExprMd (StmtExpr.Variable (.Local freshVar)) ([varDecl], varRef, { ctx with variableTypes := ctx.variableTypes ++ [(freshVar, "bool")] }) | _ => ([], condExpr, ctx) @@ -1791,7 +1797,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang -- Insert exception checks after each statement in try body let bodyStmtsWithChecks := bodyStmts.flatMap fun stmt => let isException := mkStmtExprMd (StmtExpr.StaticCall "isError" - [mkStmtExprMd (StmtExpr.Identifier "maybe_except")]) + [mkStmtExprMd (StmtExpr.Variable (.Local "maybe_except"))]) let exitToHandler := mkStmtExprMd (StmtExpr.IfThenElse isException (mkStmtExprMd (StmtExpr.Exit catchersLabel)) none) [stmt, exitToHandler] @@ -1830,7 +1836,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let mgrTy ← inferExprType currentCtx ctxExpr let mgrLauTy ← translateType currentCtx mgrTy let mgrDecl := mkStmtExprMd (StmtExpr.LocalVariable mgrName mgrLauTy (some mgrExpr)) - let mgrRef := mkStmtExprMd (StmtExpr.Identifier mgrName) + let mgrRef := mkStmtExprMd (StmtExpr.Variable (.Local mgrName)) currentCtx := {currentCtx with variableTypes := currentCtx.variableTypes ++ [(mgrName, mgrTy)]} let enterCall := mkInstanceMethodCall mgrTy "__enter__" mgrRef [] md let exitCall := mkInstanceMethodCall mgrTy "__exit__" mgrRef [] md @@ -1839,7 +1845,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let varName := pyExprToString varExpr if varName ∈ currentCtx.variableTypes.unzip.fst then let assignStmt := mkStmtExprMd (StmtExpr.Assign - [mkStmtExprMd (StmtExpr.Identifier varName)] enterCall) + [⟨.Local varName, none⟩] enterCall) setupStmts := setupStmts ++ [mgrDecl, assignStmt] else -- New variable — declare outside the block so it's visible after @@ -1879,7 +1885,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang | .Block (_ :: _ :: _) _ => let varName := s!"$for_iter_{iter.toAst.ann.start.byteIdx}" let varDecl := mkStmtExprMd (StmtExpr.LocalVariable varName AnyTy (some iterRaw)) - let varRef := mkStmtExprMd (StmtExpr.Identifier varName) + let varRef := mkStmtExprMd (StmtExpr.Variable (.Local varName)) ([varDecl], varRef) | _ => ([], iterRaw) if let .Call _ (.Name _ {val:= "range",..} _) _ _ := iter then @@ -1893,9 +1899,9 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang -- Havoc the target(s) (Ellipsis always translates to Hole) let sr := target.ann let counterName := s!"@for_loop_counter_{s.toAst.ann.start.byteIdx}" - let counterVar := freeVar counterName + let counterVar := freeVarTarget counterName let counterDecl := mkStmtExprMd $ .LocalVariable counterName (mkHighTypeMd $ .TInt) (mkStmtExprMd $ .LiteralInt 0) - let counterIncrease := mkStmtExprMd $ .Assign [counterVar] (mkStmtExprMd $ .PrimitiveOp .Add [counterVar, mkStmtExprMd $ .LiteralInt 1]) + let counterIncrease := mkStmtExprMd $ .Assign [counterVar] (mkStmtExprMd $ .PrimitiveOp .Add [freeVar counterName, mkStmtExprMd $ .LiteralInt 1]) let indexRhs := expr.Call sr (.Name sr {val:= "Any_iter_index", ann:= sr} default) {val:= #[iter, .Name sr {val:= counterName, ann:= sr} default], ann:= sr} {val:= #[], ann:= sr} -- Any_iter_index is defined in PythonRuntimeLaurelPart, so indexRhs would be translated into .StaticCall "Any_iter_index" ..., hot .Hole @@ -1907,7 +1913,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let (finalCtx, bodyStmts) ← translateStmtList bodyCtx body.val.toList let assumeStmts : List StmtExprMd ← do match target with | .Name _ n _ => - let targetVar := mkStmtExprMd (StmtExpr.Identifier n.val) + let targetVar := mkStmtExprMd (StmtExpr.Variable (.Local n.val)) let isAnyNone (s: StmtExprMd) := match s.val with | .StaticCall constructor _ => constructor == AnyConstructor.None | _ => false match iterExpr.val with @@ -1929,10 +1935,10 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang | _ => pure [] let counterLtLen := match iterExpr.val with | .StaticCall "range" (boundExpr::_) => - mkStmtExprMd $ .PrimitiveOp .Lt [counterVar, + mkStmtExprMd $ .PrimitiveOp .Lt [freeVar counterName, mkStmtExprMd $ .StaticCall "Any..as_int!" [boundExpr]] | _ => - mkStmtExprMd $ .PrimitiveOp .Lt [counterVar, + mkStmtExprMd $ .PrimitiveOp .Lt [freeVar counterName, mkStmtExprMd $ .StaticCall "Any_len" [iterExpr]] let bodyStmts := targetDecls ++ assumeStmts ++ bodyStmts ++ [counterIncrease] let innerBlock := mkStmtExprMd (StmtExpr.Block bodyStmts (some continueLabel)) @@ -2150,7 +2156,7 @@ def renameInputParams (inputs : List Parameter) (exclude : String → Bool := fu let orig : String := p.name.text let prefixed : String := paramInputPrefix ++ orig mkStmtExprMd (StmtExpr.LocalVariable (mkId orig) p.type - (some (mkStmtExprMd (StmtExpr.Identifier prefixed)))) + (some (mkStmtExprMd (StmtExpr.Variable (.Local prefixed))))) (renamed, copies) /-- Translate a Python function body: collect all variable declarations, hoist them @@ -2164,7 +2170,7 @@ def translateFunctionBody (ctx : TranslationContext) (kwargsName : Option String let nonSelfParams := inputs.filter (fun p => p.name.text != "self") let (_, paramCopies) := renameInputParams nonSelfParams (match kwargsName with | some kw => (· == kw) | none => fun _ => false) - let noneReturn := mkStmtExprMd (.Assign [mkStmtExprMd (.Identifier PyLauFuncReturnVar)] AnyNone) + let noneReturn := mkStmtExprMd (.Assign [⟨.Local PyLauFuncReturnVar, none⟩] AnyNone) let bodyStmts := noneReturn::paramCopies ++ bodyStmts let bodyStmts := (mkStmtExprMd (.LocalVariable "nullcall_ret" AnyTy (some AnyNone))) :: bodyStmts return (mkStmtExprMd (StmtExpr.Block bodyStmts none), newctx) diff --git a/Strata/Languages/Python/Specs/ToLaurel.lean b/Strata/Languages/Python/Specs/ToLaurel.lean index 217a5fd053..7326aaf774 100644 --- a/Strata/Languages/Python/Specs/ToLaurel.lean +++ b/Strata/Languages/Python/Specs/ToLaurel.lean @@ -419,12 +419,12 @@ def buildSpecBody (allArgs : Array Arg) let mut stmts : Array StmtExprMd := #[] -- 1. Havoc the result: result := Hole(nondet) let holeExpr : StmtExprMd := { val := .Hole (deterministic := false), source := source } - let resultId : StmtExprMd := { val := .Identifier (mkId "result"), source := source } + let resultId : VariableMd := { val := .Local (mkId "result"), source := source } let assignStmt ← mkStmtWithLoc (.Assign [resultId] holeExpr) default stmts := stmts.push assignStmt -- 2. Assert type / required-param preconditions for arg in allArgs do - let paramId : StmtExprMd := { val := .Identifier (mkId arg.name), source := source } + let paramId : StmtExprMd := { val := .Variable (.Local (mkId arg.name)), source := source } match ← typeAssertion? arg.type paramId source with | some assertion => if arg.default.isSome then @@ -471,7 +471,7 @@ def buildSpecBody (allArgs : Array Arg) -- NOTE. Skip NoneType: generated stubs currently declare `-> None` even for methods -- that return values. Assuming isfrom_None would make callers unreachable. if returnType.asIdent != some .noneType then - let resultRef : StmtExprMd := { val := .Identifier (mkId "result"), source := source } + let resultRef : StmtExprMd := { val := .Variable (.Local (mkId "result")), source := source } if let some retAssertion ← typeAssertion? returnType resultRef source then let assumeStmt ← mkStmtWithLoc (.Assume retAssertion) default stmts := stmts.push assumeStmt diff --git a/StrataTest/Languages/Laurel/TypeAliasElimTest.lean b/StrataTest/Languages/Laurel/TypeAliasElimTest.lean index 4ca58d611c..a60f62a33d 100644 --- a/StrataTest/Languages/Laurel/TypeAliasElimTest.lean +++ b/StrataTest/Languages/Laurel/TypeAliasElimTest.lean @@ -49,7 +49,7 @@ private def chainedProgram : Program := mkProc "test" [{ name := mkId "x", type := mkTy (.UserDefined (mkId "B")) }] [{ name := mkId "r", type := mkTy (.UserDefined (mkId "A")) }] - (.Transparent ⟨.Return (some ⟨.Identifier (mkId "x"), none⟩), none⟩) + (.Transparent ⟨.Return (some ⟨.Variable (.Local (mkId "x")), none⟩), none⟩) ] staticFields := [] types := [ @@ -111,7 +111,7 @@ private def procSigProgram : Program := [{ name := mkId "a", type := mkTy (.UserDefined (mkId "MyInt")) }, { name := mkId "b", type := mkTy (.UserDefined (mkId "MyBool")) }] [{ name := mkId "r", type := mkTy (.UserDefined (mkId "MyInt")) }] - (.Transparent ⟨.Return (some ⟨.Identifier (mkId "a"), none⟩), none⟩) + (.Transparent ⟨.Return (some ⟨.Variable (.Local (mkId "a")), none⟩), none⟩) ] staticFields := [] types := [