diff --git a/Strata/DL/Imperative/PureExpr.lean b/Strata/DL/Imperative/PureExpr.lean index 9428839c..6006546e 100644 --- a/Strata/DL/Imperative/PureExpr.lean +++ b/Strata/DL/Imperative/PureExpr.lean @@ -21,6 +21,8 @@ structure PureExpr : Type 1 where Expr : Type /-- Types -/ Ty : Type + /-- Expression metadata type (for use in function declarations, etc.) -/ + ExprMetadata : Type /-- Typing environment, expected to contain a map of variables to their types, type substitution, etc. -/ diff --git a/Strata/DL/Imperative/Stmt.lean b/Strata/DL/Imperative/Stmt.lean index 795171fb..71978957 100644 --- a/Strata/DL/Imperative/Stmt.lean +++ b/Strata/DL/Imperative/Stmt.lean @@ -7,6 +7,7 @@ import Strata.DL.Imperative.Cmd +import Strata.DL.Lambda.Factory namespace Imperative --------------------------------------------------------------------- @@ -38,6 +39,8 @@ inductive Stmt (P : PureExpr) (Cmd : Type) : Type where This will likely be removed, in favor of an alternative view of imperative programs that is purely untructured. -/ | goto (label : String) (md : MetaData P := .empty) + /-- A function declaration within a statement block. -/ + | funcDecl (decl : Lambda.PureFunc P) (md : MetaData P := .empty) deriving Inhabited /-- A block is simply an abbreviation for a list of commands. -/ @@ -69,19 +72,22 @@ def Stmt.inductionOn {P : PureExpr} {Cmd : Type} motive (Stmt.loop guard measure invariant body md)) (goto_case : ∀ (label : String) (md : MetaData P), motive (Stmt.goto label md)) + (funcDecl_case : ∀ (decl : Lambda.PureFunc P) (md : MetaData P), + motive (Stmt.funcDecl decl md)) (s : Stmt P Cmd) : motive s := match s with | Stmt.cmd c => cmd_case c | Stmt.block label b md => - block_case label b md (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case s) + block_case label b md (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case funcDecl_case s) | Stmt.ite cond thenb elseb md => ite_case cond thenb elseb md - (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case s) - (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case s) + (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case funcDecl_case s) + (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case funcDecl_case s) | Stmt.loop guard measure invariant body md => loop_case guard measure invariant body md - (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case s) + (fun s _ => inductionOn cmd_case block_case ite_case loop_case goto_case funcDecl_case s) | Stmt.goto label md => goto_case label md + | Stmt.funcDecl decl md => funcDecl_case decl md termination_by s --------------------------------------------------------------------- @@ -97,6 +103,7 @@ def Stmt.sizeOf (s : Imperative.Stmt P C) : Nat := | .ite c tss ess _ => 3 + sizeOf c + Block.sizeOf tss + Block.sizeOf ess | .loop g _ _ bss _ => 3 + sizeOf g + Block.sizeOf bss | .goto _ _ => 1 + | .funcDecl _ _ => 1 @[simp] def Block.sizeOf (ss : Imperative.Block P C) : Nat := @@ -159,6 +166,11 @@ def Stmt.getVars [HasVarsPure P P.Expr] [HasVarsPure P C] (s : Stmt P C) : List | .ite _ tbss ebss _ => Block.getVars tbss ++ Block.getVars ebss | .loop _ _ _ bss _ => Block.getVars bss | .goto _ _ => [] + | .funcDecl decl _ => + -- Get variables from function body (including parameters for simplicity) + match decl.body with + | none => [] + | some body => HasVarsPure.getVars body termination_by (Stmt.sizeOf s) def Block.getVars [HasVarsPure P P.Expr] [HasVarsPure P C] (ss : Block P C) : List P.Ident := @@ -183,6 +195,7 @@ def Stmt.definedVars [HasVarsImp P C] (s : Stmt P C) : List P.Ident := | .cmd cmd => HasVarsImp.definedVars cmd | .block _ bss _ => Block.definedVars bss | .ite _ tbss ebss _ => Block.definedVars tbss ++ Block.definedVars ebss + | .funcDecl decl _ => [decl.name] -- Function declaration defines the function name | _ => [] termination_by (Stmt.sizeOf s) @@ -202,6 +215,7 @@ def Stmt.modifiedVars [HasVarsImp P C] (s : Stmt P C) : List P.Ident := | .block _ bss _ => Block.modifiedVars bss | .ite _ tbss ebss _ => Block.modifiedVars tbss ++ Block.modifiedVars ebss | .loop _ _ _ bss _ => Block.modifiedVars bss + | .funcDecl _ _ => [] -- Function declarations don't modify variables termination_by (Stmt.sizeOf s) def Block.modifiedVars [HasVarsImp P C] (ss : Block P C) : List P.Ident := @@ -262,6 +276,7 @@ def formatStmt (P : PureExpr) (s : Stmt P C) | .loop guard measure invariant body md => f!"{md}while ({guard}) ({measure}) ({invariant}) " ++ Format.bracket "{" f!"{formatBlock P body}" "}" | .goto label md => f!"{md}goto {label}" + | .funcDecl _ md => f!"{md}funcDecl " termination_by s.sizeOf def formatBlock (P : PureExpr) (ss : List (Stmt P C)) diff --git a/Strata/DL/Lambda/Factory.lean b/Strata/DL/Lambda/Factory.lean index cf5af43e..321bc092 100644 --- a/Strata/DL/Lambda/Factory.lean +++ b/Strata/DL/Lambda/Factory.lean @@ -10,6 +10,8 @@ import Strata.DDM.AST import Strata.DDM.Util.Array import Strata.DL.Util.List import Strata.DL.Util.ListMap +import Strata.DL.Imperative.PureExpr +import Strata.DL.Imperative.MetaData /-! ## Lambda's Factory @@ -57,9 +59,11 @@ abbrev LTySignature := Signature IDMeta LTy /-- -A Lambda factory function, where the body can be optional. Universally -quantified type identifiers, if any, appear before this signature and can -quantify over the type identifiers in it. +A generic function structure, parameterized by identifier, expression, type, and metadata types. + +This structure can be instantiated for different expression languages. +For Lambda expressions, use `LFunc`. For other expression systems, instantiate +with appropriate types. A optional evaluation function can be provided in the `concreteEval` field for each factory function to allow the partial evaluator to do constant propagation @@ -85,44 +89,74 @@ concrete/constants, this fails and it returns .none. (TODO) Use `.bvar`s in the body to correspond to the formals instead of using `.fvar`s. -/ -structure LFunc (T : LExprParams) where - name : T.Identifier +structure Func (IdentT : Type) (ExprT : Type) (TyT : Type) (MetadataT : Type) where + name : IdentT typeArgs : List TyIdentifier := [] isConstr : Bool := false --whether function is datatype constructor - inputs : @LMonoTySignature T.IDMeta - output : LMonoTy - body : Option (LExpr T.mono) := .none + inputs : ListMap IdentT TyT + output : TyT + body : Option ExprT := .none -- (TODO): Add support for a fixed set of attributes (e.g., whether to inline -- a function, etc.). attr : Array String := #[] - -- The T.Metadata argument is the metadata that will be attached to the + -- The MetadataT argument is the metadata that will be attached to the -- resulting expression of concreteEval if evaluation was successful. - concreteEval : Option (T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono)) := .none - axioms : List (LExpr T.mono) := [] -- For axiomatic definitions + concreteEval : Option (MetadataT → List ExprT → Option ExprT) := .none + axioms : List ExprT := [] -- For axiomatic definitions + +/-- +A Lambda factory function - instantiation of `Func` for Lambda expressions. + +Universally quantified type identifiers, if any, appear before this signature and can +quantify over the type identifiers in it. +-/ +abbrev LFunc (T : LExprParams) := Func (T.Identifier) (LExpr T.mono) LMonoTy T.Metadata /-- -Well-formedness properties of LFunc. These are split from LFunc because -otherwise it becomes impossible to create a 'temporary' LFunc object whose +A function declaration for use with `PureExpr` - instantiation of `Func` for +any expression system that implements the `PureExpr` interface. +-/ +abbrev PureFunc (P : Imperative.PureExpr) := Func P.Ident P.Expr P.Ty P.ExprMetadata + +/-- +Helper constructor for LFunc to maintain backward compatibility. +-/ +def LFunc.mk {T : LExprParams} (name : T.Identifier) (typeArgs : List TyIdentifier := []) + (isConstr : Bool := false) (inputs : ListMap T.Identifier LMonoTy) (output : LMonoTy) + (body : Option (LExpr T.mono) := .none) (attr : Array String := #[]) + (concreteEval : Option (T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono)) := .none) + (axioms : List (LExpr T.mono) := []) : LFunc T := + Func.mk name typeArgs isConstr inputs output body attr concreteEval axioms + +/-- +Well-formedness properties of Func. These are split from Func because +otherwise it becomes impossible to create a 'temporary' Func object whose wellformedness might not hold yet. -/ -structure LFuncWF {T : LExprParams} (f : LFunc T) where +structure FuncWF {IdentT ExprT TyT MetadataT : Type} (f : Func IdentT ExprT TyT MetadataT) where -- No args have same name. arg_nodup: - List.Nodup (f.inputs.map (·.1.name)) + List.Nodup (f.inputs.map (·.1)) + -- concreteEval does not succeed if the length of args is incorrect. + concreteEval_argmatch: + ∀ fn md args res, f.concreteEval = .some fn + → fn md args = .some res + → args.length = f.inputs.length + +/-- +Well-formedness properties of LFunc - extends FuncWF with Lambda-specific properties. +-/ +structure LFuncWF {T : LExprParams} (f : LFunc T) extends FuncWF f where -- Free variables of body must be arguments. body_freevars: ∀ b freevars, f.body = .some b → freevars = LExpr.freeVars b → (∀ fv, fv ∈ freevars → ∃ arg, List.Mem arg f.inputs ∧ fv.1.name = arg.1.name) - -- concreteEval does not succeed if the length of args is incorrect. - concreteEval_argmatch: - ∀ fn md args res, f.concreteEval = .some fn - → fn md args = .some res - → args.length = f.inputs.length -instance LFuncWF.arg_nodup_decidable {T : LExprParams} (f : LFunc T): - Decidable (List.Nodup (f.inputs.map (·.1.name))) := by +instance FuncWF.arg_nodup_decidable {IdentT ExprT TyT MetadataT : Type} [DecidableEq IdentT] + (f : Func IdentT ExprT TyT MetadataT): + Decidable (List.Nodup (f.inputs.map (·.1))) := by apply List.nodupDecidable instance LFuncWF.body_freevars_decidable {T : LExprParams} (f : LFunc T): @@ -151,24 +185,36 @@ instance LFuncWF.body_freevars_decidable {T : LExprParams} (f : LFunc T): | .none => by apply isTrue; grind --- LFuncWF.concreteEval_argmatch is not decidable. +-- FuncWF.concreteEval_argmatch and LFuncWF.concreteEval_argmatch are not decidable.-- FuncWF.body_freevars is commented out as it's expression-type specific +-- FuncWF.concreteEval_argmatch is not decidable. instance [Inhabited T.Metadata] [Inhabited T.IDMeta] : Inhabited (LFunc T) where default := { name := Inhabited.default, inputs := [], output := LMonoTy.bool } -instance : ToFormat (LFunc T) where +instance {IdentT ExprT TyT MetadataT : Type} [ToFormat IdentT] [ToFormat ExprT] [ToFormat TyT] [Inhabited ExprT] : ToFormat (Func IdentT ExprT TyT MetadataT) where format f := let attr := if f.attr.isEmpty then f!"" else f!"@[{f.attr}]{Format.line}" let typeArgs := if f.typeArgs.isEmpty then f!"" else f!"∀{f.typeArgs}." - let type := f!"{typeArgs} ({Signature.format f.inputs}) → {f.output}" + -- Format inputs recursively like Signature.format + let rec formatInputs (inputs : List (IdentT × TyT)) : Format := + match inputs with + | [] => f!"" + | [(k, v)] => f!"({k} : {v})" + | (k, v) :: rest => f!"({k} : {v}) " ++ formatInputs rest + let type := f!"{typeArgs} ({formatInputs f.inputs}) → {f.output}" let sep := if f.body.isNone then f!";" else f!" :=" let body := if f.body.isNone then f!"" else Std.Format.indentD f!"({f.body.get!})" f!"{attr}\ func {f.name} : {type}{sep}\ {body}" +-- Provide explicit instance for LFunc to ensure proper resolution +-- Requires ToFormat for T.IDMeta (for identifiers in expressions) and T.Metadata (for Inhabited LExpr) +instance [ToFormat T.IDMeta] [Inhabited T.Metadata] : ToFormat (LFunc T) where + format := format + def LFunc.type [DecidableEq T.IDMeta] (f : (LFunc T)) : Except Format LTy := do if !(decide f.inputs.keys.Nodup) then .error f!"[{f.name}] Duplicates found in the formals!\ diff --git a/Strata/Languages/C_Simp/C_Simp.lean b/Strata/Languages/C_Simp/C_Simp.lean index e292aa69..d9ce17ab 100644 --- a/Strata/Languages/C_Simp/C_Simp.lean +++ b/Strata/Languages/C_Simp/C_Simp.lean @@ -28,6 +28,7 @@ abbrev Expression : Imperative.PureExpr := { Ident := Lambda.Identifier Unit, Expr := Lambda.LExpr CSimpLParams.mono, Ty := Lambda.LTy, + ExprMetadata := CSimpLParams.Metadata, TyEnv := Lambda.TEnv Unit, TyContext := Lambda.LContext ⟨Unit, Unit⟩, EvalEnv := Lambda.LState ⟨Unit, String⟩, diff --git a/Strata/Languages/C_Simp/Verify.lean b/Strata/Languages/C_Simp/Verify.lean index e979428b..7f3adb55 100644 --- a/Strata/Languages/C_Simp/Verify.lean +++ b/Strata/Languages/C_Simp/Verify.lean @@ -48,6 +48,7 @@ def translate_stmt (s: Imperative.Stmt C_Simp.Expression C_Simp.Command) : Core. | .block l b _md => .block l (b.map translate_stmt) {} | .ite cond thenb elseb _md => .ite (translate_expr cond) (thenb.map translate_stmt) (elseb.map translate_stmt) {} | .loop guard measure invariant body _md => .loop (translate_expr guard) (translate_opt_expr measure) (translate_opt_expr invariant) (body.map translate_stmt) {} + | .funcDecl _ _ => panic! "C_Simp does not support function declarations" | .goto label _md => .goto label {} termination_by s.sizeOf decreasing_by diff --git a/Strata/Languages/Core/CallGraph.lean b/Strata/Languages/Core/CallGraph.lean index 26ebda9a..c23dd5f0 100644 --- a/Strata/Languages/Core/CallGraph.lean +++ b/Strata/Languages/Core/CallGraph.lean @@ -102,6 +102,7 @@ partial def extractCallsFromStatement (stmt : Statement) : List String := extractCallsFromStatements elseBody | .loop _ _ _ body _ => extractCallsFromStatements body | .goto _ _ => [] + | .funcDecl _ _ => [] /-- Extract procedure calls from a list of statements -/ partial def extractCallsFromStatements (stmts : List Statement) : List String := diff --git a/Strata/Languages/Core/Expressions.lean b/Strata/Languages/Core/Expressions.lean index 60d37b6f..32bd6c7d 100644 --- a/Strata/Languages/Core/Expressions.lean +++ b/Strata/Languages/Core/Expressions.lean @@ -22,6 +22,7 @@ abbrev Expression : Imperative.PureExpr := EqIdent := inferInstanceAs (DecidableEq (Lambda.Identifier _)) Expr := Lambda.LExpr ⟨⟨ExpressionMetadata, Visibility⟩, Lambda.LMonoTy⟩, Ty := Lambda.LTy, + ExprMetadata := ExpressionMetadata, TyEnv := @Lambda.TEnv Visibility, TyContext := @Lambda.LContext ⟨ExpressionMetadata, Visibility⟩, EvalEnv := Lambda.LState ⟨ExpressionMetadata, Visibility⟩ } diff --git a/Strata/Languages/Core/FunctionType.lean b/Strata/Languages/Core/FunctionType.lean index 76d4af93..415070c2 100644 --- a/Strata/Languages/Core/FunctionType.lean +++ b/Strata/Languages/Core/FunctionType.lean @@ -40,12 +40,12 @@ def typeCheck (C: Core.Expression.TyContext) (Env : Core.Expression.TyEnv) (func | some body => -- Temporarily add formals in the context. let Env := Env.pushEmptyContext - let Env := Env.addToContext func.inputPolyTypes + let Env := Env.addToContext (LFunc.inputPolyTypes func) -- Type check and annotate the body, and ensure that it unifies with the -- return type. let (bodya, Env) ← LExpr.resolve C Env body let bodyty := bodya.toLMonoTy - let (retty, Env) ← func.outputPolyType.instantiateWithCheck C Env + let (retty, Env) ← (LFunc.outputPolyType func).instantiateWithCheck C Env let S ← Constraints.unify [(retty, bodyty)] Env.stateSubstInfo |>.mapError format let Env := Env.updateSubst S let Env := Env.popContext diff --git a/Strata/Languages/Core/Statement.lean b/Strata/Languages/Core/Statement.lean index eaa84339..40da26c2 100644 --- a/Strata/Languages/Core/Statement.lean +++ b/Strata/Languages/Core/Statement.lean @@ -126,6 +126,11 @@ def Statement.eraseTypes (s : Statement) : Statement := let body' := Statements.eraseTypes bss .loop guard measure invariant body' md | .goto l md => .goto l md + | .funcDecl decl md => + let decl' := { decl with + body := decl.body.map Lambda.LExpr.eraseTypes, + axioms := decl.axioms.map Lambda.LExpr.eraseTypes } + .funcDecl decl' md termination_by (Stmt.sizeOf s) decreasing_by all_goals simp_wf <;> simp [sizeOf] <;> omega @@ -203,6 +208,7 @@ def Statement.modifiedVarsTrans Statements.modifiedVarsTrans π tbss ++ Statements.modifiedVarsTrans π ebss | .loop _ _ _ bss _ => Statements.modifiedVarsTrans π bss + | .funcDecl _ _ => [] -- Function declarations don't modify variables termination_by (Stmt.sizeOf s) def Statements.modifiedVarsTrans @@ -241,6 +247,11 @@ def Statement.getVarsTrans Statements.getVarsTrans π tbss ++ Statements.getVarsTrans π ebss | .loop _ _ _ bss _ => Statements.getVarsTrans π bss + | .funcDecl decl _ => + -- Get variables from function body (including parameters for simplicity) + match decl.body with + | none => [] + | some body => HasVarsPure.getVars body termination_by (Stmt.sizeOf s) def Statements.getVarsTrans @@ -284,6 +295,7 @@ def Statement.touchedVarsTrans | .block _ bss _ => Statements.touchedVarsTrans π bss | .ite _ tbss ebss _ => Statements.touchedVarsTrans π tbss ++ Statements.touchedVarsTrans π ebss | .loop _ _ _ bss _ => Statements.touchedVarsTrans π bss + | .funcDecl decl _ => [decl.name] -- Function declaration touches (defines) the function name termination_by (Stmt.sizeOf s) def Statements.touchedVarsTrans @@ -347,6 +359,12 @@ def Statement.substFvar (s : Core.Statement) (Block.substFvar body fr to) metadata | .goto _ _ => s + | .funcDecl decl md => + -- Substitute in function body and axioms + let decl' := { decl with + body := decl.body.map (Lambda.LExpr.substFvar · fr to), + axioms := decl.axioms.map (Lambda.LExpr.substFvar · fr to) } + .funcDecl decl' md termination_by s.sizeOf decreasing_by all_goals(simp_wf; try omega) end @@ -373,6 +391,10 @@ def Statement.renameLhs (s : Core.Statement) if l.name == fr then to else l)) pname args metadata | .block lbl b metadata => .block lbl (Block.renameLhs b fr to) metadata + | .funcDecl decl md => + -- Rename function name if it matches + let decl' := if decl.name == fr then { decl with name := to } else decl + .funcDecl decl' md | .havoc _ _ | .assert _ _ _ | .assume _ _ _ | .ite _ _ _ _ | .loop _ _ _ _ _ | .goto _ _ | .cover _ _ _ => s termination_by s.sizeOf diff --git a/Strata/Languages/Core/StatementEval.lean b/Strata/Languages/Core/StatementEval.lean index 16ed8b5b..17430455 100644 --- a/Strata/Languages/Core/StatementEval.lean +++ b/Strata/Languages/Core/StatementEval.lean @@ -186,7 +186,7 @@ def Statement.containsCmd (predicate : Imperative.Cmd Expression → Bool) (s : | .ite _ then_ss else_ss _ => Statements.containsCmds predicate then_ss || Statements.containsCmds predicate else_ss | .loop _ _ _ body_ss _ => Statements.containsCmds predicate body_ss - | .goto _ _ => false + | .funcDecl _ _ | .goto _ _ => false -- Function declarations and gotos don't contain commands termination_by Imperative.Stmt.sizeOf s /-- @@ -225,7 +225,7 @@ def Statement.collectCovers (s : Statement) : List (String × Imperative.MetaDat | .block _ inner_ss _ => Statements.collectCovers inner_ss | .ite _ then_ss else_ss _ => Statements.collectCovers then_ss ++ Statements.collectCovers else_ss | .loop _ _ _ body_ss _ => Statements.collectCovers body_ss - | .goto _ _ => [] + | .funcDecl _ _ | .goto _ _ => [] -- Function declarations and gotos don't contain cover commands termination_by Imperative.Stmt.sizeOf s /-- Collect all `cover` commands from statements `ss` with their labels and metadata. @@ -249,7 +249,7 @@ def Statement.collectAsserts (s : Statement) : List (String × Imperative.MetaDa | .block _ inner_ss _ => Statements.collectAsserts inner_ss | .ite _ then_ss else_ss _ => Statements.collectAsserts then_ss ++ Statements.collectAsserts else_ss | .loop _ _ _ body_ss _ => Statements.collectAsserts body_ss - | .goto _ _ => [] + | .funcDecl _ _ | .goto _ _ => [] -- Function declarations and gotos don't contain assert commands termination_by Imperative.Stmt.sizeOf s /-- Collect all `assert` commands from statements `ss` with their labels and metadata. @@ -400,6 +400,34 @@ def evalAuxGo (steps : Nat) (old_var_subst : SubstMap) (Ewn : EnvWithNext) (ss : Please transform your program to eliminate loops before \ calling Core.Statement.evalAux" + | .funcDecl decl _ => + -- Add function to factory with value capture semantics + -- Substitute current values of free variables into function body + let func : Lambda.LFunc CoreLParams := { + name := decl.name, + typeArgs := decl.typeArgs, + isConstr := decl.isConstr, + inputs := decl.inputs.map (fun (id, ty) => (id, Lambda.LTy.toMonoTypeUnsafe ty)), + output := Lambda.LTy.toMonoTypeUnsafe decl.output, + body := decl.body.map (fun e => + -- Substitute free variables with their current values from the environment + let freeVars := Lambda.LExpr.freeVars e + freeVars.foldl (fun body fv => + match Ewn.env.exprEnv.state.find? fv.fst with + | some (_, val) => Lambda.LExpr.substFvar body fv.fst val + | none => body + ) e + ), + attr := decl.attr, + concreteEval := decl.concreteEval, + axioms := decl.axioms + } + match Ewn.env.addFactoryFunc func with + | .ok env' => [{ Ewn with env := env' }] + | .error e => + -- If adding fails, set error but continue + [{ Ewn with env := { Ewn.env with error := some (.Misc f!"{e}") } }] + | .goto l md => [{ Ewn with stk := Ewn.stk.appendToTop [.goto l md], nextLabel := (some l)}] List.flatMap (fun (ewn : EnvWithNext) => go' ewn rest ewn.nextLabel) EAndNexts diff --git a/Strata/Languages/Core/StatementType.lean b/Strata/Languages/Core/StatementType.lean index feb1f012..a60d9d9d 100644 --- a/Strata/Languages/Core/StatementType.lean +++ b/Strata/Languages/Core/StatementType.lean @@ -10,6 +10,7 @@ import Strata.Languages.Core.Statement import Strata.Languages.Core.CmdType import Strata.Languages.Core.Program import Strata.Languages.Core.OldExpressions +import Strata.Languages.Core.FunctionType import Strata.DL.Imperative.CmdType namespace Core @@ -165,6 +166,37 @@ where -- Add source location to error messages. .error (errorWithSourceLoc e md) + | .funcDecl decl md => do try + -- Type check the function declaration + -- Manually convert PureFunc Expression to Function for type checking + let func : Function := { + name := decl.name, + typeArgs := decl.typeArgs, + isConstr := decl.isConstr, + inputs := decl.inputs.map (fun (id, ty) => (id, Lambda.LTy.toMonoTypeUnsafe ty)), + output := Lambda.LTy.toMonoTypeUnsafe decl.output, + body := decl.body, + attr := decl.attr, + concreteEval := none, -- Can't convert concreteEval safely + axioms := decl.axioms + } + let (func', Env) ← Function.typeCheck C Env func |>.mapError DiagnosticModel.fromFormat + -- Convert back by wrapping monotypes in trivial polytypes + let decl' : PureFunc Expression := { + name := func'.name, + typeArgs := func'.typeArgs, + isConstr := func'.isConstr, + inputs := func'.inputs.map (fun (id, mty) => (id, .forAll [] mty)), + output := .forAll [] func'.output, + body := func'.body, + attr := func'.attr, + concreteEval := decl.concreteEval, -- Preserve original + axioms := func'.axioms + } + .ok (.funcDecl decl' md, Env) + catch e => + .error (errorWithSourceLoc e md) + go Env srest (s' :: acc) termination_by Block.sizeOf ss decreasing_by @@ -207,6 +239,13 @@ def Statement.subst (S : Subst) (s : Statement) : Statement := | .loop guard m i bss md => .loop (guard.applySubst S) (substOptionExpr S m) (substOptionExpr S i) (go S bss []) md | .goto _ _ => s + | .funcDecl decl md => + let decl' := { decl with + inputs := decl.inputs.map (fun (id, ty) => (id, Lambda.LTy.subst S ty)), + output := Lambda.LTy.subst S decl.output, + body := decl.body.map (·.applySubst S), + axioms := decl.axioms.map (·.applySubst S) } + .funcDecl decl' md where go S ss acc : List Statement := match ss with diff --git a/Strata/Languages/Core/StatementWF.lean b/Strata/Languages/Core/StatementWF.lean index 0b79fc32..2e7e2069 100644 --- a/Strata/Languages/Core/StatementWF.lean +++ b/Strata/Languages/Core/StatementWF.lean @@ -125,6 +125,14 @@ theorem Statement.typeCheckAux_go_WF : any_goals (try simp [WFStatementsProp] at *; try simp [List.Forall_append, *]; try constructor) any_goals (simp [Forall]) any_goals constructor + | funcDecl decl md => + simp [Except.bind] at tcok + repeat (split at tcok <;> try contradiction) + have tcok := Statement.typeCheckAux_elim_singleton tcok + rw[List.append_cons]; + apply ih tcok <;> try assumption + simp [WFStatementsProp] at * + simp [List.Forall_append, Forall, *] /-- A list of Statement `ss` that passes type checking is well formed with respect diff --git a/Strata/Languages/Core/WF.lean b/Strata/Languages/Core/WF.lean index 4a79db2d..3cee9cb5 100644 --- a/Strata/Languages/Core/WF.lean +++ b/Strata/Languages/Core/WF.lean @@ -79,6 +79,7 @@ def WFStatementProp (p : Program) (stmt : Statement) : Prop := match stmt with | .loop (guard : Expression.Expr) (measure : Option Expression.Expr) (invariant : Option Expression.Expr) (body : Block) _ => WFloopProp (CmdExt Expression) p guard measure invariant body | .goto (label : String) _ => WFgotoProp p label + | .funcDecl _ _ => True -- Function declarations are always well-formed at this level abbrev WFStatementsProp (p : Program) := Forall (WFStatementProp p) diff --git a/Strata/Transform/CoreTransform.lean b/Strata/Transform/CoreTransform.lean index d4bd7664..74bd2ef3 100644 --- a/Strata/Transform/CoreTransform.lean +++ b/Strata/Transform/CoreTransform.lean @@ -224,6 +224,7 @@ def runStmtsRec (f : Command → Program → CoreTransformM (List Statement)) | .loop guard measure invariant body md => do let body' ← runStmtsRec f body inputProg return [.loop guard measure invariant body' md] + | .funcDecl _ _ => return [s] -- Function declarations pass through unchanged | .goto _lbl _md => return [s]) return (sres ++ ss'') diff --git a/Strata/Transform/DetToNondet.lean b/Strata/Transform/DetToNondet.lean index ec163de4..6bf37da8 100644 --- a/Strata/Transform/DetToNondet.lean +++ b/Strata/Transform/DetToNondet.lean @@ -29,6 +29,7 @@ def StmtToNondetStmt {P : PureExpr} [Imperative.HasBool P] [HasNot P] | .loop guard _measure _inv bss md => .loop (.seq (.assume "guard" guard md) (BlockToNondetStmt bss)) | .goto _ _ => (.assume "skip" Imperative.HasBool.tt) + | .funcDecl _ _ => (.assume "skip" Imperative.HasBool.tt) /-- Deterministic-to-nondeterministic transformation for multiple (deterministic) statements -/ diff --git a/Strata/Transform/LoopElim.lean b/Strata/Transform/LoopElim.lean index 69d8e577..a35f423d 100644 --- a/Strata/Transform/LoopElim.lean +++ b/Strata/Transform/LoopElim.lean @@ -56,6 +56,7 @@ def Stmt.removeLoopsM pure (.block label bss md) | .cmd _ => pure s | .goto _ _ => pure s + | .funcDecl _ _ => pure s -- Function declarations pass through unchanged def Block.removeLoopsM [HasNot P] [HasVarsImp P C] [HasHavoc P C] [HasPassiveCmds P C] diff --git a/Strata/Transform/ProcedureInlining.lean b/Strata/Transform/ProcedureInlining.lean index e722fb14..e5036762 100644 --- a/Strata/Transform/ProcedureInlining.lean +++ b/Strata/Transform/ProcedureInlining.lean @@ -55,6 +55,12 @@ def Statement.substFvar (s : Core.Statement) (Option.map (Lambda.LExpr.substFvar · fr to) invariant) (Block.substFvar body fr to) metadata + | .funcDecl decl md => + -- Substitute in function body and axioms + let decl' := { decl with + body := decl.body.map (Lambda.LExpr.substFvar · fr to), + axioms := decl.axioms.map (Lambda.LExpr.substFvar · fr to) } + .funcDecl decl' md | .goto _ _ => s termination_by s.sizeOf end @@ -82,6 +88,10 @@ def Statement.renameLhs (s : Core.Statement) (fr: Lambda.Identifier Visibility) | .loop m g i b md => .loop m g i (Block.renameLhs b fr to) md | .havoc l md => .havoc (if l.name == fr then to else l) md + | .funcDecl decl md => + -- Rename function name if it matches + let decl' := if decl.name == fr then { decl with name := to } else decl + .funcDecl decl' md | .assert _ _ _ | .assume _ _ _ | .cover _ _ _ | .goto _ _ => s termination_by s.sizeOf end @@ -106,6 +116,7 @@ def Statement.labels (s : Core.Statement) : List String := | .goto _ _ => [] -- No other labeled commands. | .cmd _ => [] + | .funcDecl _ _ => [] termination_by s.sizeOf end @@ -133,6 +144,7 @@ def Statement.replaceLabels | .assert lbl e m => .assert (app lbl) e m | .cover lbl e m => .cover (app lbl) e m | .cmd _ => s + | .funcDecl _ _ => s termination_by s.sizeOf end diff --git a/StrataTest/Backends/CBMC/CoreToCProverGOTO.lean b/StrataTest/Backends/CBMC/CoreToCProverGOTO.lean index 06f254c5..f8c5dd79 100644 --- a/StrataTest/Backends/CBMC/CoreToCProverGOTO.lean +++ b/StrataTest/Backends/CBMC/CoreToCProverGOTO.lean @@ -27,6 +27,7 @@ abbrev Core.ExprStr : Imperative.PureExpr := { Ident := CoreParams.Identifier, Expr := Lambda.LExpr CoreParams.mono, Ty := Lambda.LTy, + ExprMetadata := CoreParams.Metadata, TyEnv := @Lambda.TEnv CoreParams.IDMeta, TyContext := @Lambda.LContext CoreParams, EvalEnv := Lambda.LState CoreParams diff --git a/StrataTest/Backends/CBMC/ToCProverGOTO.lean b/StrataTest/Backends/CBMC/ToCProverGOTO.lean index 4ba14958..bbde1c8b 100644 --- a/StrataTest/Backends/CBMC/ToCProverGOTO.lean +++ b/StrataTest/Backends/CBMC/ToCProverGOTO.lean @@ -21,6 +21,7 @@ private abbrev LExprTP : Imperative.PureExpr := { Ident := TestParams.Identifier, Expr := Lambda.LExprT TestParams.mono, Ty := Lambda.LMonoTy, + ExprMetadata := TestParams.Metadata, TyEnv := @Lambda.TEnv TestParams.IDMeta, TyContext := @Lambda.LContext TestParams, EvalEnv := Lambda.LState TestParams diff --git a/StrataTest/DL/Imperative/ArithExpr.lean b/StrataTest/DL/Imperative/ArithExpr.lean index 8c98cda4..81143fd5 100644 --- a/StrataTest/DL/Imperative/ArithExpr.lean +++ b/StrataTest/DL/Imperative/ArithExpr.lean @@ -96,6 +96,7 @@ abbrev PureExpr : PureExpr := { Ident := String, Expr := Expr, Ty := Ty, + ExprMetadata := Unit, TyEnv := TEnv, TyContext := Unit, EvalEnv := Env, diff --git a/StrataTest/Languages/Core/FactoryWF.lean b/StrataTest/Languages/Core/FactoryWF.lean index 0a6f02af..4d2c0b98 100644 --- a/StrataTest/Languages/Core/FactoryWF.lean +++ b/StrataTest/Languages/Core/FactoryWF.lean @@ -37,21 +37,24 @@ theorem Factory_wf : repeat ( rcases Hmem with _ | ⟨ a', Hmem ⟩ · apply LFuncWF.mk - · decide -- LFuncWF.arg_nodup + rotate_left · decide -- LFuncWF.body_freevars - · -- LFuncWf.concreteEval_argmatch - simp (config := { ground := true }) - try ( - try unfold unOpCeval - try unfold binOpCeval - try unfold cevalIntDiv - try unfold cevalIntMod - try unfold bvUnaryOp - try unfold bvBinaryOp - try unfold bvShiftOp - try unfold bvBinaryPred - intros lf md args res - repeat (rcases args with _ | ⟨ args0, args ⟩ <;> try grind))) + rotate_left + · apply FuncWF.mk + · decide -- LFuncWF.arg_nodup + · -- LFuncWf.concreteEval_argmatch + simp (config := { ground := true }) + try ( + try unfold unOpCeval + try unfold binOpCeval + try unfold cevalIntDiv + try unfold cevalIntMod + try unfold bvUnaryOp + try unfold bvBinaryOp + try unfold bvShiftOp + try unfold bvBinaryPred + intros lf md args res + repeat (rcases args with _ | ⟨ args0, args ⟩ <;> try grind))) contradiction end Core diff --git a/StrataTest/Languages/Core/StatementEvalTests.lean b/StrataTest/Languages/Core/StatementEvalTests.lean index 8281bec1..f3130cec 100644 --- a/StrataTest/Languages/Core/StatementEvalTests.lean +++ b/StrataTest/Languages/Core/StatementEvalTests.lean @@ -377,6 +377,122 @@ Proof Obligation: #guard_msgs in #eval (evalOne ∅ ∅ prog2) |>.snd |> format +/-- +Test funcDecl: declare a helper function and use it +-/ +def testFuncDecl : List Statement := + let doubleFunc : PureFunc Expression := { + name := CoreIdent.unres "double", + typeArgs := [], + isConstr := false, + inputs := [(CoreIdent.unres "x", .forAll [] .int)], + output := .forAll [] .int, + body := some eb[((~Int.Add x) x)], + attr := #[], + concreteEval := none, + axioms := [] + } + [ + .funcDecl doubleFunc, + .init "y" t[int] eb[(~double #5)], + .assert "y_eq_10" eb[y == #10] + ] + +/-- +info: Error: +none +Subst Map: + +Expression Env: +State: +[(y : int) → (~double #5)] + +Evaluation Config: +Eval Depth: 200 +Variable Prefix: $__ +Variable gen count: 0 +Factory Functions: +func double : ((x : int)) → int := + (((~Int.Add x) x)) + + +Datatypes: + +Path Conditions: + + +Warnings: +[] +Deferred Proof Obligations: +Label: y_eq_10 +Property: assert +Assumptions: +Proof Obligation: +((~double #5) == #10) +-/ +#guard_msgs in +#eval (evalOne ∅ ∅ testFuncDecl) |>.snd |> format + +/-- +Test funcDecl with variable capture: function captures variable value at declaration time, +not affected by subsequent mutations +-/ +def testFuncDeclSymbolic : List Statement := + let addNFunc : PureFunc Expression := { + name := CoreIdent.unres "addN", + typeArgs := [], + isConstr := false, + inputs := [(CoreIdent.unres "x", .forAll [] .int)], + output := .forAll [] .int, + body := some eb[((~Int.Add x) n)], -- Captures 'n' at declaration time + attr := #[], + concreteEval := none, + axioms := [] + } + [ + .init "n" t[int] eb[#10], -- Initialize n to 10 + .funcDecl addNFunc, -- Function captures n = 10 at declaration time + .set "n" eb[#20], -- Mutate n to 20 + .init "result" t[int] eb[(~addN #5)], -- Call function + .assert "result_eq_15" eb[result == #15] -- Result is 5 + 10 = 15 (uses captured value) + ] + +/-- +info: Error: +none +Subst Map: + +Expression Env: +State: +[(n : int) → #20 +(result : int) → (~addN #5)] + +Evaluation Config: +Eval Depth: 200 +Variable Prefix: $__ +Variable gen count: 0 +Factory Functions: +func addN : ((x : int)) → int := + (((~Int.Add x) #10)) + + +Datatypes: + +Path Conditions: + + +Warnings: +[] +Deferred Proof Obligations: +Label: result_eq_15 +Property: assert +Assumptions: +Proof Obligation: +((~addN #5) == #15) +-/ +#guard_msgs in +#eval (evalOne ∅ ∅ testFuncDeclSymbolic) |>.snd |> format + end Tests --------------------------------------------------------------------- end Core diff --git a/StrataTest/Transform/CallElimCorrect.lean b/StrataTest/Transform/CallElimCorrect.lean index 2f32c816..093b064e 100644 --- a/StrataTest/Transform/CallElimCorrect.lean +++ b/StrataTest/Transform/CallElimCorrect.lean @@ -179,6 +179,7 @@ theorem callElimBlockNoExcept : | ite cd tb eb md => exists [.ite cd tb eb md] | goto l b => exists [.goto l b] | loop g m i b md => exists [.loop g m i b md] + | funcDecl f md => exists [.funcDecl f md] | cmd c => cases c with | cmd c' => exists [Imperative.Stmt.cmd (CmdExt.cmd c')] @@ -3298,6 +3299,7 @@ theorem callElimStatementCorrect [LawfulBEq Expression.Expr] : case ite => exact ⟨σ', Inits.init InitVars.init_none, Heval⟩ case goto => exact ⟨σ', Inits.init InitVars.init_none, Heval⟩ case loop => exact ⟨σ', Inits.init InitVars.init_none, Heval⟩ + case funcDecl => exact ⟨σ', Inits.init InitVars.init_none, Heval⟩ case cmd c => cases c with | cmd c' =>