diff --git a/Strata/DDM/AST.lean b/Strata/DDM/AST.lean index dc5c2c337..5b01ddfdf 100644 --- a/Strata/DDM/AST.lean +++ b/Strata/DDM/AST.lean @@ -94,8 +94,11 @@ inductive TypeExprF (α : Type) where /-- A polymorphic type variable (universally quantified). Used for polymorphic function type parameters -/ | tvar (ann : α) (name : String) - /-- A reference to a global variable along with any arguments to ensure it is well-typed. -/ -| fvar (ann : α) (fvar : FreeVarIndex) (args : Array (TypeExprF α)) + /-- A reference to a global variable along with any arguments to ensure it is + well-typed. The name field stores the original type name for lookup in + mutual blocks. -/ +| fvar (ann : α) (fvar : FreeVarIndex) (name : Option String) + (args : Array (TypeExprF α)) /-- A function type. -/ | arrow (ann : α) (arg : TypeExprF α) (res : TypeExprF α) deriving BEq, Inhabited, Repr @@ -106,7 +109,7 @@ def ann {α} : TypeExprF α → α | .ident ann _ _ => ann | .bvar ann _ => ann | .tvar ann _ => ann -| .fvar ann _ _ => ann +| .fvar ann _ _ _ => ann | .arrow ann _ _ => ann def mkFunType {α} (n : α) (bindings : Array (String × TypeExprF α)) (res : TypeExprF α) : TypeExprF α := @@ -115,7 +118,7 @@ def mkFunType {α} (n : α) (bindings : Array (String × TypeExprF α)) (res : T protected def incIndices {α} (tp : TypeExprF α) (count : Nat) : TypeExprF α := match tp with | .ident n i args => .ident n i (args.attach.map fun ⟨e, _⟩ => e.incIndices count) - | .fvar n f args => .fvar n f (args.attach.map fun ⟨e, _⟩ => e.incIndices count) + | .fvar n f name args => .fvar n f name (args.attach.map fun ⟨e, _⟩ => e.incIndices count) | .bvar n idx => .bvar n (idx + count) | .tvar n name => .tvar n name -- tvar doesn't use indices | .arrow n a r => .arrow n (a.incIndices count) (r.incIndices count) @@ -123,7 +126,7 @@ protected def incIndices {α} (tp : TypeExprF α) (count : Nat) : TypeExprF α : /-- Return true if type expression has a bound variable. -/ protected def hasUnboundVar {α} (bindingCount : Nat := 0) : TypeExprF α → Bool | .ident _ _ args => args.attach.any (fun ⟨e, _⟩ => e.hasUnboundVar bindingCount) -| .fvar _ _ args => args.attach.any (fun ⟨e, _⟩ => e.hasUnboundVar bindingCount) +| .fvar _ _ _ args => args.attach.any (fun ⟨e, _⟩ => e.hasUnboundVar bindingCount) | .bvar _ idx => idx ≥ bindingCount | .tvar _ _ => true | .arrow _ a r => a.hasUnboundVar bindingCount || r.hasUnboundVar bindingCount @@ -142,8 +145,8 @@ protected def instTypeM {m α} [Monad m] (d : TypeExprF α) (bindings : α → N | .ident n i a => .ident n i <$> a.attach.mapM (fun ⟨e, _⟩ => e.instTypeM bindings) | .bvar n idx => bindings n idx + | .fvar n idx name a => .fvar n idx name <$> a.attach.mapM (fun ⟨e, _⟩ => e.instTypeM bindings) | .tvar n name => pure (.tvar n name) - | .fvar n idx a => .fvar n idx <$> a.attach.mapM (fun ⟨e, _⟩ => e.instTypeM bindings) | .arrow n a b => .arrow n <$> a.instTypeM bindings <*> b.instTypeM bindings termination_by d @@ -624,7 +627,7 @@ inductive PreType where Used for polymorphic function type parameters -/ | tvar (ann : SourceRange) (name : String) /-- A reference to a global variable along with any arguments to ensure it is well-typed. -/ -| fvar (ann : SourceRange) (fvar : FreeVarIndex) (args : Array PreType) +| fvar (ann : SourceRange) (fvar : FreeVarIndex) (name : Option String) (args : Array PreType) /-- A function type. -/ | arrow (ann : SourceRange) (arg : PreType) (res : PreType) /-- A function created from a reference to bindings and a result type. -/ @@ -638,15 +641,15 @@ def ann : PreType → SourceRange | .ident ann _ _ => ann | .bvar ann _ => ann | .tvar ann _ => ann -| .fvar ann _ _ => ann +| .fvar ann _ _ _ => ann | .arrow ann _ _ => ann | .funMacro ann _ _ => ann def ofType : TypeExprF SourceRange → PreType | .ident loc name args => .ident loc name (args.map fun a => .ofType a) | .bvar loc idx => .bvar loc idx +| .fvar loc idx name args => .fvar loc idx name (args.map fun a => .ofType a) | .tvar loc name => .tvar loc name -| .fvar loc idx args => .fvar loc idx (args.map fun a => .ofType a) | .arrow loc a r => .arrow loc (.ofType a) (.ofType r) termination_by tp => tp @@ -869,9 +872,9 @@ structure ConstructorInfo where Build a TypeExpr reference to the datatype with type parameters, using `.fvar` for the datatype's GlobalContext index and `.tvar` for type parameters. -/ -def mkDatatypeTypeRef (ann : SourceRange) (datatypeIndex : FreeVarIndex) (typeParams : Array String) : TypeExpr := +def mkDatatypeTypeRef (ann : SourceRange) (datatypeIndex : FreeVarIndex) (typeParams : Array String) (datatypeName : Option String := none) : TypeExpr := let typeArgs := typeParams.map fun name => TypeExprF.tvar ann name - TypeExprF.fvar ann datatypeIndex typeArgs + TypeExprF.fvar ann datatypeIndex datatypeName typeArgs /-- Build an arrow type from field types to the datatype type. E.g. for cons, @@ -1043,6 +1046,7 @@ A spec for introducing a new binding into a type context. inductive BindingSpec (argDecls : ArgDecls) where | value (_ : ValueBindingSpec argDecls) | type (_ : TypeBindingSpec argDecls) +| typeForward (_ : TypeBindingSpec argDecls) -- Forward declaration (no AST node) | datatype (_ : DatatypeBindingSpec argDecls) | tvar (_ : TvarBindingSpec argDecls) deriving Repr @@ -1052,6 +1056,7 @@ namespace BindingSpec def nameIndex {argDecls} : BindingSpec argDecls → DebruijnIndex argDecls.size | .value v => v.nameIndex | .type v => v.nameIndex +| .typeForward v => v.nameIndex | .datatype v => v.nameIndex | .tvar v => v.nameIndex @@ -1135,6 +1140,22 @@ def parseNewBindings (md : Metadata) (argDecls : ArgDecls) : Array (BindingSpec pure <| some ⟨idx, argsP⟩ | _ => newBindingErr "declareType args invalid."; return none some <$> .type <$> pure { nameIndex, argsIndex, defIndex := none } + | q`StrataDDL.declareTypeForward => do + let #[.catbvar nameIndex, .option mArgsArg ] := attr.args + | newBindingErr s!"declareTypeForward has bad arguments {repr attr.args}."; return none + let .isTrue nameP := inferInstanceAs (Decidable (nameIndex < argDecls.size)) + | return panic! "Invalid name index" + let nameIndex := ⟨nameIndex, nameP⟩ + checkNameIndexIsIdent argDecls nameIndex + let argsIndex ← + match mArgsArg with + | none => pure none + | some (.catbvar idx) => + let .isTrue argsP := inferInstanceAs (Decidable (idx < argDecls.size)) + | return panic! "Invalid arg index" + pure <| some ⟨idx, argsP⟩ + | _ => newBindingErr "declareTypeForward args invalid."; return none + some <$> .typeForward <$> pure { nameIndex, argsIndex, defIndex := none } | q`StrataDDL.aliasType => do let #[.catbvar nameIndex, .option mArgsArg, .catbvar defIndex] := attr.args | newBindingErr "aliasType missing arguments."; return none @@ -1767,6 +1788,12 @@ inductive GlobalKind where | type (params : List String) (definition : Option TypeExpr) deriving BEq, Inhabited, Repr +/-- State of a symbol in the GlobalContext -/ +inductive DeclState where + | forward -- Symbol is forward-declared (no AST node will be generated) + | defined -- Symbol has a complete definition +deriving BEq, DecidableEq, Repr, Inhabited + /-- Resolves a binding spec into a global kind. -/ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src : SourceRange) (b : BindingSpec argDecls) (args : Vector Arg argDecls.size) : Option GlobalKind := match b with @@ -1792,7 +1819,7 @@ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src panic! s!"Expected new binding to be Type instead of {repr c}." | a => panic! s!"Expected new binding to be bound to type instead of {repr a}." - | .type b => + | .type b | .typeForward b => let params : Array String := match b.argsIndex with | none => #[] @@ -1831,7 +1858,7 @@ Typing environment created from declarations in an environment. -/ structure GlobalContext where nameMap : Std.HashMap Var FreeVarIndex - vars : Array (Var × GlobalKind) + vars : Array (Var × GlobalKind × DeclState) deriving Repr namespace GlobalContext @@ -1851,12 +1878,45 @@ instance : Membership Var GlobalContext where def instDecidableMem (v : Var) (ctx : GlobalContext) : Decidable (v ∈ ctx) := inferInstanceAs (Decidable (v ∈ ctx.nameMap)) -def push (ctx : GlobalContext) (v : Var) (k : GlobalKind) : GlobalContext := +/-- Add a forward declaration (must not exist). Used by @[declareTypeForward]. + This adds to GlobalContext for name resolution but will NOT generate an AST node. -/ +def declareForward (ctx : GlobalContext) (v : Var) (k : GlobalKind) : Except String GlobalContext := if v ∈ ctx then - panic! s!"Var {v} already defined" + .error s!"Symbol '{v}' is already in scope" else let idx := ctx.vars.size - { nameMap := ctx.nameMap.insert v idx, vars := ctx.vars.push (v, k) } + .ok { nameMap := ctx.nameMap.insert v idx, + vars := ctx.vars.push (v, k, .forward) } + +/-- Define a symbol. Used by @[declareDatatype], @[declareFn] with body, etc. + Replaces forward declaration, or adds new as defined. -/ +def define (ctx : GlobalContext) (v : Var) (k : GlobalKind) : Except String GlobalContext := + match ctx.nameMap.get? v with + | none => + -- Not declared, add as defined directly + let idx := ctx.vars.size + .ok { nameMap := ctx.nameMap.insert v idx, + vars := ctx.vars.push (v, k, .defined) } + | some idx => + let (name, _, state) := ctx.vars[idx]! + match state with + | .forward => + -- Replace forward declaration with definition (update in place) + .ok { ctx with vars := ctx.vars.set! idx (name, k, .defined) } + | .defined => + .error s!"Symbol '{v}' is already defined" + +/-- Check if a symbol is forward-declared (not yet defined). -/ +def isForward (ctx : GlobalContext) (idx : FreeVarIndex) : Bool := + match ctx.vars[idx]? with + | some (_, _, .forward) => true + | _ => false + +/-- Add a symbol as defined. -/ +def push (ctx : GlobalContext) (v : Var) (k : GlobalKind) : GlobalContext := + match ctx.define v k with + | .ok ctx' => ctx' + | .error msg => panic! msg /-- Return the index of the variable with the given name. -/ def findIndex? (ctx : GlobalContext) (v : Var) : Option FreeVarIndex := ctx.nameMap.get? v @@ -1865,7 +1925,7 @@ def nameOf? (ctx : GlobalContext) (idx : FreeVarIndex) : Option String := ctx.va def kindOf! (ctx : GlobalContext) (idx : FreeVarIndex) : GlobalKind := assert! idx < ctx.vars.size - ctx.vars[idx]!.snd + ctx.vars[idx]!.2.1 /-! ## Annotation-based Constructor Info Extraction @@ -2066,10 +2126,12 @@ private def addDatatypeBindings let constructorInfo := extractConstructorInfo dialects args[b.constructorsIndex.toLevel] - -- Step 1: Add datatype type - let gctx := gctx.push datatypeName (GlobalKind.type typeParams.toList none) - let datatypeIndex := gctx.vars.size - 1 - let datatypeType := mkDatatypeTypeRef src datatypeIndex typeParams + -- Step 1: Add datatype type (or update forward declaration) + let gctx := match gctx.define datatypeName (GlobalKind.type typeParams.toList none) with + | .ok gctx' => gctx' + | .error msg => panic! s!"addDatatypeBindings: {msg}" + let datatypeIndex := gctx.findIndex? datatypeName |>.getD (gctx.vars.size - 1) + let datatypeType := mkDatatypeTypeRef src datatypeIndex typeParams (some datatypeName) -- Step 2: Add constructor signatures let gctx := constructorInfo.foldl (init := gctx) fun gctx constr => @@ -2093,6 +2155,17 @@ def addCommand (dialects : DialectMap) (init : GlobalContext) (op : Operation) : match b with | .datatype datatypeSpec => addDatatypeBindings dialects gctx l dialectName datatypeSpec args + | .typeForward typeSpec => + let name := + match args[typeSpec.nameIndex.toLevel] with + | .ident _ e => e + | a => panic! s!"Expected ident at {typeSpec.nameIndex.toLevel} {repr a}" + match resolveBindingIndices dialects l b args with + | some kind => + match gctx.declareForward name kind with + | .ok gctx' => gctx' + | .error msg => panic! msg + | none => gctx | _ => let name := match args[b.nameIndex.toLevel] with diff --git a/Strata/DDM/BuiltinDialects/StrataDDL.lean b/Strata/DDM/BuiltinDialects/StrataDDL.lean index 33c8e1bbf..6e2313dfa 100644 --- a/Strata/DDM/BuiltinDialects/StrataDDL.lean +++ b/Strata/DDM/BuiltinDialects/StrataDDL.lean @@ -164,6 +164,7 @@ def StrataDDL : Dialect := BuiltinM.create! "StrataDDL" #[initDialect] do -- Metadata for marking an operation as a constructor list push (list followed by constructor) declareMetadata { name := "constructorListPush", args := #[.mk "list" .ident, .mk "constructor" .ident] } declareMetadata { name := "declareType", args := #[.mk "name" .ident, .mk "args" (.opt .ident)] } + declareMetadata { name := "declareTypeForward", args := #[.mk "name" .ident, .mk "args" (.opt .ident)] } declareMetadata { name := "aliasType", args := #[.mk "name" .ident, .mk "args" (.opt .ident), .mk "def" .ident] } declareMetadata { name := "declare", args := #[.mk "name" .ident, .mk "type" .ident] } declareMetadata { name := "declareFn", args := #[.mk "name" .ident, .mk "args" .ident, .mk "type" .ident] } diff --git a/Strata/DDM/Elab/Core.lean b/Strata/DDM/Elab/Core.lean index f3d242efd..9a924882b 100644 --- a/Strata/DDM/Elab/Core.lean +++ b/Strata/DDM/Elab/Core.lean @@ -53,7 +53,7 @@ partial def expandMacros (m : DialectMap) (f : PreType) (args : Nat → Option A match f with | .ident loc i a => .ident loc i <$> a.mapM fun e => expandMacros m e args | .arrow loc a b => .arrow loc <$> expandMacros m a args <*> expandMacros m b args - | .fvar loc i a => .fvar loc i <$> a.mapM fun e => expandMacros m e args + | .fvar loc i name a => .fvar loc i name <$> a.mapM fun e => expandMacros m e args | .bvar loc idx => pure (.bvar loc idx) | .tvar loc name => pure (.tvar loc name) | .funMacro loc i r => do @@ -81,7 +81,7 @@ the head is in a normal form. partial def hnf (tctx : TypingContext) (e : TypeExpr) : TypeExpr := match e with | .arrow .. | .ident .. | .tvar .. => e - | .fvar _ idx args => + | .fvar _ idx _ args => let gctx := tctx.globalContext match gctx.kindOf! idx with | .expr _ => panic! "Type free variable bound to expression." @@ -203,7 +203,7 @@ def resolveTypeBinding (tctx : TypingContext) (loc : SourceRange) (name : String | logErrorMF c.info.loc mf!"Expected type" tpArgs := tpArgs.push cinfo.typeExpr children := children.push c - let tp := .fvar loc fidx tpArgs + let tp := .fvar loc fidx (some name) tpArgs let info : TypeInfo := { inputCtx := tctx, loc := loc, typeExpr := tp, isInferred := false } return .node (.ofTypeInfo info) children else if let some a := args[params.size]? then @@ -334,7 +334,7 @@ N.B. This expects that macros have already been expanded in e. partial def headExpandTypeAlias (gctx : GlobalContext) (e : TypeExpr) : TypeExpr := match e with | .arrow .. | .ident .. | .bvar .. | .tvar .. => e - | .fvar _ idx args => + | .fvar _ idx _ args => match gctx.kindOf! idx with | .expr _ => panic! "Type free variable bound to expression." | .type params (some d) => @@ -368,7 +368,7 @@ partial def checkExpressionType (tctx : TypingContext) (itype rtype : TypeExpr) return false | .bvar _ ii, .bvar _ ri => return ii = ri - | .fvar _ ii ia, .fvar _ ri ra => + | .fvar _ ii _ ia, .fvar _ ri _ ra => if p : ii = ri ∧ ia.size = ra.size then do for i in Fin.range ia.size do if !(← checkExpressionType tctx ia[i] ra[i]) then @@ -446,9 +446,9 @@ partial def unifyTypes | _ => logErrorMF exprLoc mf!"Encountered {inferredHead} expression when {expectedType} expected." return args - | .fvar _ eid ea => + | .fvar _ eid _ ea => match tctx.hnf inferredType with - | .fvar _ iid ia => + | .fvar _ iid _ ia => if eid != iid then logErrorMF exprLoc mf!"Encountered {inferredType} expression when {expectedType} expected." return args @@ -749,7 +749,7 @@ def translateBindingKind (tree : Tree) : ElabM BindingKind := do | .type params _ => let params := params.toArray if params.size = tpArgs.size then - return .expr (.fvar nameLoc fidx tpArgs) + return .expr (.fvar nameLoc fidx (some name) tpArgs) else if let some a := tpArgs[params.size]? then logErrorMF a.ann mf!"Unexpected argument to {name}." return default @@ -799,6 +799,18 @@ def translateBindingKind (tree : Tree) : ElabM BindingKind := do logInternalError argInfo.loc s!"translateArgDeclKind given invalid kind {opInfo.op.name}" return default +/-- Extract type parameter names from a bindings argument. -/ +def elabTypeParams {n} (initSize : Nat) (args : Vector Tree n) + (idx : Option (DebruijnIndex n)) : ElabM (List String) := do + let params ← elabArgIndex initSize args idx fun argLoc b => do + match b.kind with + | .type _ [] _ => pure () + | .tvar _ _ => pure () + | .type .. | .expr _ | .cat _ => + logError argLoc s!"{b.ident} must have type Type instead of {repr b.kind}." + return b.ident + pure params.toList + /-- Construct a binding from a binding spec and the arguments to an operation. -/ @@ -844,17 +856,9 @@ def evalBindingSpec panic! s!"Cannot bind {ident}: Type at {b.typeIndex.val} has unexpected arg {repr arg}" -- TODO: Decide if new bindings for Type and Expr (or other categories) and should not be allowed? pure { ident, kind } - | .type b => + | .type b | .typeForward b => let ident := evalBindingNameIndex args b.nameIndex - let params ← elabArgIndex initSize args b.argsIndex fun argLoc b => do - match b.kind with - | .type _ [] _ => - pure () - | .tvar _ _ => - pure () - | .type .. | .expr _ | .cat _ => do - logError argLoc s!"{b.ident} must be have type Type instead of {repr b.kind}." - return b.ident + let params ← elabTypeParams initSize args b.argsIndex let value : Option TypeExpr := match b.defIndex with | none => none @@ -864,10 +868,11 @@ def evalBindingSpec some info.typeExpr | _ => panic! "Bad arg" - pure { ident, kind := .type loc params.toList value } + pure { ident, kind := .type loc params value } | .datatype b => let ident := evalBindingNameIndex args b.nameIndex - pure { ident, kind := .type loc [] none } + let params ← elabTypeParams initSize args (some b.typeParamsIndex) + pure { ident, kind := .type loc params none } | .tvar b => let ident := evalBindingNameIndex args b.nameIndex pure { ident, kind := .tvar loc ident } @@ -1045,11 +1050,15 @@ partial def runSyntaxElaborator | .ofIdentInfo info => info.val | _ => panic! "Expected identifier for datatype name" let baseCtx := typeParamsT.resultContext - -- Extract type parameter names from the bindings - let typeParamNames := baseCtx.bindings.toArray.filterMap fun b => - match b.kind with - | .type _ [] _ => some b.ident - | _ => none + /- Extract type parameter names only from NEW bindings added by + typeParams, not inherited bindings (which may include datatypes from + previous commands) -/ + let inheritedCount := tctx0.bindings.size + let typeParamNames := baseCtx.bindings.toArray.extract inheritedCount baseCtx.bindings.size + |>.filterMap fun b => + match b.kind with + | .type _ [] _ => some b.ident + | _ => none -- Add the datatype name to the GlobalContext as a type let gctx := baseCtx.globalContext let gctx := @@ -1057,7 +1066,9 @@ partial def runSyntaxElaborator else gctx.push datatypeName (GlobalKind.type typeParamNames.toList none) -- Add .tvar bindings for type parameters let loc := typeParamsT.info.loc - let tctx := typeParamNames.foldl (init := baseCtx.withGlobalContext gctx) fun ctx name => + -- Start with empty local bindings - don't inherit from baseCtx + -- This prevents datatype names from leaking between mutual block entries + let tctx := typeParamNames.foldl (init := TypingContext.empty gctx) fun ctx name => ctx.push { ident := name, kind := .tvar loc name } pure tctx | _, _ => continue diff --git a/Strata/DDM/Elab/DialectM.lean b/Strata/DDM/Elab/DialectM.lean index 0db0868eb..a8655b9a6 100644 --- a/Strata/DDM/Elab/DialectM.lean +++ b/Strata/DDM/Elab/DialectM.lean @@ -27,7 +27,7 @@ Note this does not return variables referenced by .funMacro. private def foldBoundTypeVars {α} (tp : PreType) (init : α) (f : α → Nat → α) : α := match tp with | .ident _ _ a => a.attach.foldl (init := init) fun r ⟨e, _⟩ => e.foldBoundTypeVars r f - | .fvar _ _ a => a.attach.foldl (init := init) fun r ⟨e, _⟩ => e.foldBoundTypeVars r f + | .fvar _ _ _ a => a.attach.foldl (init := init) fun r ⟨e, _⟩ => e.foldBoundTypeVars r f | .bvar _ i => f init i | .tvar _ _ => init | .arrow _ a r => r.foldBoundTypeVars (a.foldBoundTypeVars init f) f diff --git a/Strata/DDM/Format.lean b/Strata/DDM/Format.lean index d5c08d22b..7f7614d6b 100644 --- a/Strata/DDM/Format.lean +++ b/Strata/DDM/Format.lean @@ -272,9 +272,9 @@ private protected def mformat : TypeExprF α → StrataFormat | .ident _ tp a => a.attach.foldl (init := mformat tp) fun m ⟨e, _⟩ => mf!"{m} {e.mformat.ensurePrec (appPrec + 1)}".setPrec appPrec | .bvar _ idx => .bvar idx -| .tvar _ name => mf!"tvar!{name}" -| .fvar _ idx a => a.attach.foldl (init := .fvar idx) fun m ⟨e, _⟩ => +| .fvar _ idx _ a => a.attach.foldl (init := .fvar idx) fun m ⟨e, _⟩ => mf!"{m} {e.mformat.ensurePrec (appPrec + 1)}".setPrec appPrec +| .tvar _ name => mf!"tvar!{name}" | .arrow _ a r => mf!"{a.mformat.ensurePrec (arrowPrec+1)} -> {r.mformat.ensurePrec arrowPrec}" instance {α} : ToStrataFormat (TypeExprF α) where @@ -287,8 +287,8 @@ namespace PreType private protected def mformat : PreType → StrataFormat | .ident _ tp a => a.attach.foldl (init := mformat tp) (fun m ⟨e, _⟩ => mf!"{m} {e.mformat}") | .bvar _ idx => .bvar idx +| .fvar _ idx _ a => a.attach.foldl (init := .fvar idx) (fun m ⟨e, _⟩ => mf!"{m} {e.mformat}") | .tvar _ name => mf!"tvar!{name}" -| .fvar _ idx a => a.attach.foldl (init := .fvar idx) (fun m ⟨e, _⟩ => mf!"{m} {e.mformat}") | .arrow _ a r => mf!"{a.mformat} -> {r.mformat}" | .funMacro _ idx r => mf!"fnOf({StrataFormat.bvar idx}, {r.mformat})" diff --git a/Strata/DDM/Integration/Lean/ToExpr.lean b/Strata/DDM/Integration/Lean/ToExpr.lean index d86c3c09b..313f20c7e 100644 --- a/Strata/DDM/Integration/Lean/ToExpr.lean +++ b/Strata/DDM/Integration/Lean/ToExpr.lean @@ -128,9 +128,9 @@ private protected def toExpr {α} [ToExpr α] : TypeExprF α → Lean.Expr astAnnExpr! bvar ann (toExpr idx) | .tvar ann name => astAnnExpr! tvar ann (toExpr name) -| .fvar ann idx a => +| .fvar ann idx name a => let ae := arrayToExpr levelZero (TypeExprF.typeExpr (toTypeExpr α)) (a.map (·.toExpr)) - astAnnExpr! fvar ann (toExpr idx) ae + astAnnExpr! fvar ann (toExpr idx) (toExpr name) ae | .arrow ann a r => astAnnExpr! arrow ann a.toExpr r.toExpr @@ -221,9 +221,9 @@ private protected def toExpr : PreType → Lean.Expr astExpr! ident (toExpr loc) (toExpr nm) args | .bvar loc idx => astExpr! bvar (toExpr loc) (toExpr idx) | .tvar loc name => astExpr! tvar (toExpr loc) (toExpr name) -| .fvar loc idx a => +| .fvar loc idx name a => let args := arrayToExpr .zero PreType.typeExpr (a.map (·.toExpr)) - astExpr! fvar (toExpr loc) (toExpr idx) args + astExpr! fvar (toExpr loc) (toExpr idx) (toExpr name) args | .arrow loc a r => astExpr! arrow (toExpr loc) a.toExpr r.toExpr | .funMacro loc i r => @@ -422,6 +422,7 @@ private def toExpr {argDecls} (bi : BindingSpec argDecls) (argDeclsExpr : Lean.E match bi with | .value b => astExpr! value argDeclsExpr (b.toExpr argDeclsExpr) | .type b => astExpr! type argDeclsExpr (b.toExpr argDeclsExpr) + | .typeForward b => astExpr! typeForward argDeclsExpr (b.toExpr argDeclsExpr) | .datatype b => astExpr! datatype argDeclsExpr (b.toExpr argDeclsExpr) | .tvar b => astExpr! tvar argDeclsExpr (b.toExpr argDeclsExpr) diff --git a/Strata/DDM/Ion.lean b/Strata/DDM/Ion.lean index 805022b03..7d7b55031 100644 --- a/Strata/DDM/Ion.lean +++ b/Strata/DDM/Ion.lean @@ -309,6 +309,16 @@ private def deserializeValue {α} (bs : ByteArray) (act : Ion SymbolId → FromI | .ok res => pure res +protected def fromIonString (v : Ion SymbolId) : FromIonM String := + match v.app with + | .string s => pure s + | _ => throw s!"Expected string, got {repr v}" + +protected def fromIonOption {α} (f : Ion SymbolId → FromIonM α) (v : Ion SymbolId) : FromIonM (Option α) := + match v.app with + | .null => pure none + | _ => some <$> f v + end FromIonM class FromIon (α : Type) where @@ -316,6 +326,12 @@ class FromIon (α : Type) where export FromIon (fromIon) +instance : FromIon String where + fromIon := FromIonM.fromIonString + +instance {α} [FromIon α] : FromIon (Option α) where + fromIon := FromIonM.fromIonOption fromIon + namespace FromIon def deserialize {α} [FromIon α] (bs : ByteArray) : Except String α := @@ -351,6 +367,14 @@ private class ToIon (α : Type) where private abbrev toIon := @ToIon.toIon +private instance : ToIon String where + toIon s := pure (.string s) + +private instance {α} [ToIon α] : ToIon (Option α) where + toIon + | some a => toIon a + | none => pure .null + namespace SyntaxCatF private protected def toIon {α} [ToIon α] (cat : SyntaxCatF α) : Ion.InternM (Ion SymbolId) := do @@ -392,14 +416,14 @@ private protected def toIon {α} [ToIon α] (refs : SymbolIdCache) (tpe : TypeEx let args : Array (Ion SymbolId) := #[ionSymbol! "ident", ← toIon ann, v] Ion.sexp <$> a.attach.mapM_off (init := args) fun ⟨e, _⟩ => e.toIon refs - -- A bound type variable with the given index. - | .bvar ann vidx => - return Ion.sexp #[ionSymbol! "bvar", ← toIon ann, .int vidx] -- A polymorphic type variable with the given name. | .tvar ann name => return Ion.sexp #[ionSymbol! "tvar", ← toIon ann, .string name] - | .fvar ann idx a => do - let s : Array (Ion SymbolId) := #[ionSymbol! "fvar", ← toIon ann, .int idx] + -- A bound type variable with the given index. + | .bvar ann vidx => + return Ion.sexp #[ionSymbol! "bvar", ← toIon ann, .int vidx] + | .fvar ann idx name a => do + let s : Array (Ion SymbolId) := #[ionSymbol! "fvar", ← toIon ann, .int idx, ← toIon name] let s ← a.attach.mapM_off (init := s) fun ⟨e, _⟩ => e.toIon refs return Ion.sexp s @@ -435,12 +459,13 @@ private protected def fromIon {α} [FromIon α] (v : Ion SymbolId) : FromIonM (T (← FromIon.fromIon args[1]) (← .asString "Type expression tvar name" args[2]) | "fvar" => - let ⟨p⟩ ← .checkArgMin "Type expression free variable" args 3 + let ⟨p⟩ ← .checkArgMin "Type expression free variable" args 4 let ann ← FromIon.fromIon args[1] let idx ← .asNat "Type expression free variable index" args[2] - let a ← args.attach.mapM_off (start := 3) fun ⟨e, _⟩ => + let name ← FromIon.fromIon args[3] + let a ← args.attach.mapM_off (start := 4) fun ⟨e, _⟩ => TypeExprF.fromIon e - pure <| .fvar ann idx a + pure <| .fvar ann idx name a | "ident" => let ⟨p⟩ ← .checkArgMin "TypeExpr identifier" args 3 let ann ← FromIon.fromIon args[1] @@ -949,13 +974,13 @@ private protected def toIon (refs : SymbolIdCache) (tpe : PreType) : InternM (Io -- A bound type variable with the given index. | .bvar loc vidx => return Ion.sexp #[ionSymbol! "bvar", ← toIon loc, .int vidx] + | .fvar loc idx name a => do + let s : Array (Ion SymbolId) := #[ionSymbol! "fvar", ← toIon loc, .int idx, ← toIon name] + let s ← a.attach.mapM_off (init := s) fun ⟨e, _⟩ => e.toIon refs + return Ion.sexp s -- A polymorphic type variable with the given name. | .tvar loc name => return Ion.sexp #[ionSymbol! "tvar", ← toIon loc, .string name] - | .fvar loc idx a => do - let s : Array (Ion SymbolId) := #[ionSymbol! "fvar", ← toIon loc, .int idx] - let s ← a.attach.mapM_off (init := s) fun ⟨e, _⟩ => e.toIon refs - return Ion.sexp s | .arrow loc l r => do return Ion.sexp #[ionSymbol! "arrow", ← toIon loc, ← l.toIon refs, ← r.toIon refs] | .funMacro loc i r => @@ -985,12 +1010,13 @@ private protected def fromIon (v : Ion SymbolId) : FromIonM PreType := do (← fromIon args[1]) (← .asString "PreType tvar name" args[2]) | "fvar" => - let ⟨p⟩ ← .checkArgMin "fvar" args 3 + let ⟨p⟩ ← .checkArgMin "fvar" args 4 let ann ← fromIon args[1] let idx ← .asNat "fvar" args[2] - let a ← args.attach.mapM_off (start := 3) fun ⟨e, _⟩ => + let name ← fromIon args[3] + let a ← args.attach.mapM_off (start := 4) fun ⟨e, _⟩ => PreType.fromIon e - pure <| .fvar ann idx a + pure <| .fvar ann idx name a | "ident" => let ⟨p⟩ ← .checkArgMin "ident" args 3 let ann ← fromIon args[1] diff --git a/Strata/DL/Lambda/TypeFactory.lean b/Strata/DL/Lambda/TypeFactory.lean index a897d62d2..632f7d5a1 100644 --- a/Strata/DL/Lambda/TypeFactory.lean +++ b/Strata/DL/Lambda/TypeFactory.lean @@ -49,6 +49,9 @@ structure LConstr (IDMeta : Type) where testerName : String := "is" ++ name.name deriving Repr, DecidableEq +instance [Inhabited IDMeta] : Inhabited (LConstr IDMeta) where + default := { name := default, args := [] } + instance: ToFormat (LConstr IDMeta) where format c := f!"Name:{Format.line}{c.name}{Format.line}\ Args:{Format.line}{c.args}{Format.line}\ @@ -65,6 +68,9 @@ structure LDatatype (IDMeta : Type) where constrs_ne : constrs.length != 0 deriving Repr, DecidableEq +instance [Inhabited IDMeta] : Inhabited (LDatatype IDMeta) where + default := { name := "", typeArgs := [], constrs := [default], constrs_ne := rfl } + instance : ToFormat (LDatatype IDMeta) where format d := f!"type:{Format.line}{d.name}{Format.line}\ Type Arguments:{Format.line}{d.typeArgs}{Format.line}\ diff --git a/Strata/Languages/Core/DDMTransform/Parse.lean b/Strata/Languages/Core/DDMTransform/Parse.lean index 690d7de27..afe98f492 100644 --- a/Strata/Languages/Core/DDMTransform/Parse.lean +++ b/Strata/Languages/Core/DDMTransform/Parse.lean @@ -254,6 +254,10 @@ op command_procedure (name : Ident, op command_typedecl (name : Ident, args : Option Bindings) : Command => "type " name args ";\n"; +@[declareTypeForward(name, some args)] +op command_forward_typedecl (name : Ident, args : Option Bindings) : Command => + "forward type " name args ";\n"; + @[aliasType(name, some args, rhs)] op command_typesynonym (name : Ident, args : Option Bindings, @@ -328,6 +332,12 @@ op command_datatype (name : Ident, @[scopeDatatype(name, typeParams)] constructors : ConstructorList) : Command => "datatype " name typeParams " {" constructors "}" ";\n"; +// Mutual block for defining mutually recursive types +// Types should be forward-declared before the mutual block +@[scope(commands)] +op command_mutual (commands : SpacePrefixSepBy Command) : Command => + "mutual\n" indent(2, commands) "end;\n"; + #end namespace CoreDDM diff --git a/Strata/Languages/Core/DDMTransform/Translate.lean b/Strata/Languages/Core/DDMTransform/Translate.lean index 6d8b5b326..673e378d8 100644 --- a/Strata/Languages/Core/DDMTransform/Translate.lean +++ b/Strata/Languages/Core/DDMTransform/Translate.lean @@ -235,7 +235,7 @@ partial def translateLMonoTy (bindings : TransBindings) (arg : Arg) : | .ident _ i argst => let argst' ← translateLMonoTys bindings (argst.map ArgF.type) pure <| (.tcons i.name argst'.toList.reverse) - | .fvar _ i argst => + | .fvar _ i typeName argst => assert! i < bindings.freeVars.size let decl := bindings.freeVars[i]! let ty_core ← match decl with @@ -249,13 +249,18 @@ partial def translateLMonoTy (bindings : TransBindings) (arg : Arg) : | .type (.syn syn) _md => let ty := syn.toLHSLMonoTy pure ty - | .type (.data (ldatatype :: _)) _md => - -- Datatype Declaration - -- TODO: Handle mutual blocks, need to find the specific datatype by name + | .type (.data block) _md => + -- Datatype Declaration (possibly mutual) + -- Use the type name stored in the fvar to find the right datatype in the block + let ldatatype := match typeName, block with + | some name, _ => + match block.find? (fun (d : LDatatype Visibility) => d.name == name) with + | some d => d + | none => panic! "Error: datatype {name} not found in block {block}" + | none, d :: _ => d + | none, [] => panic! "Empty datatype block" let args := ldatatype.typeArgs.map LMonoTy.ftvar pure (.tcons ldatatype.name args) - | .type (.data []) _md => - TransM.error "Empty mutual datatype block" | _ => TransM.error s!"translateLMonoTy not yet implemented for this declaration: \ @@ -345,6 +350,31 @@ def translateTypeDecl (bindings : TransBindings) (op : Operation) : let decl := Core.Decl.type (.con { name := name, numargs := numargs }) md return (decl, { bindings with freeVars := bindings.freeVars.push decl }) +/-- +Translate a forward type declaration. This creates a placeholder entry that will +be replaced when the actual datatype definition is encountered in a mutual block. +-/ +def translateForwardTypeDecl (bindings : TransBindings) (op : Operation) : + TransM (Core.Decl × TransBindings) := do + let _ ← @checkOp (Core.Decl × TransBindings) op q`Core.command_forward_typedecl 2 + let name ← translateIdent TyIdentifier op.args[0]! + let numargs ← + translateOption + (fun maybearg => + do match maybearg with + | none => pure 0 + | some arg => + let bargs ← checkOpArg arg q`Core.mkBindings 1 + let numargs ← + match bargs[0]! with + | .seq _ .comma args => pure args.size + | _ => TransM.error + s!"translateForwardTypeDecl expects a comma separated list: {repr bargs[0]!}") + op.args[1]! + let md ← getOpMetaData op + let decl := Core.Decl.type (.con { name := name, numargs := numargs }) md + return (decl, { bindings with freeVars := bindings.freeVars.push decl }) + --------------------------------------------------------------------- def translateLhs (arg : Arg) : TransM CoreIdent := do @@ -425,7 +455,7 @@ def translateOptionMonoDeclList (bindings : TransBindings) (arg : Arg) : partial def dealiasTypeExpr (p : Program) (te : TypeExpr) : TypeExpr := match te with - | (.fvar _ idx #[]) => + | (.fvar _ idx _ #[]) => match p.globalContext.kindOf! idx with | .expr te => te | .type [] (.some te) => te @@ -1311,6 +1341,88 @@ def translateConstructorList (p : Program) (bindings : TransBindings) (arg : Arg let constructorInfos := GlobalContext.extractConstructorInfo p.dialects arg constructorInfos.mapM (translateConstructorInfo bindings) +--------------------------------------------------------------------- +-- Common helpers for datatype translation + +/-- +Extract type arguments from a datatype's optional bindings argument. +-/ +def translateDatatypeTypeArgs (bindings : TransBindings) (arg : Arg) (errorContext : String) : + TransM (List TyIdentifier × TransBindings) := + translateOption + (fun maybearg => do + match maybearg with + | none => pure ([], bindings) + | some arg => + let bargs ← checkOpArg arg q`Core.mkBindings 1 + match bargs[0]! with + | .seq _ .comma args => + let (arr, bindings) ← translateTypeBindings bindings args + return (arr.toList, bindings) + | _ => TransM.error s!"{errorContext} expects a comma separated list: {repr bargs[0]!}") + arg + +/-- +Create a placeholder LDatatype for recursive type references. +-/ +def mkPlaceholderLDatatype (name : String) (typeArgs : List TyIdentifier) : LDatatype Visibility := + { name := name + typeArgs := typeArgs + constrs := [{ name := name, args := [], testerName := "" }] + constrs_ne := by simp } + +/-- +Filter factory function declarations to extract constructor, tester, and field accessor decls +for a single datatype. +-/ +def filterDatatypeDecls (ldatatype : LDatatype Visibility) (funcDecls : List Core.Decl) : + List Core.Decl × List Core.Decl × List Core.Decl := + let constructorNames := ldatatype.constrs.map fun c => c.name.name + let testerNames := ldatatype.constrs.map fun c => c.testerName + let fieldAccessorNames := ldatatype.constrs.foldl (fun acc c => + acc ++ (c.args.map fun (fieldName, _) => ldatatype.name ++ ".." ++ fieldName.name)) [] + + let constructorDecls := funcDecls.filter fun decl => + match decl with + | .func f => constructorNames.contains f.name.name + | _ => false + + let testerDecls := funcDecls.filter fun decl => + match decl with + | .func f => testerNames.contains f.name.name + | _ => false + + let fieldAccessorDecls := funcDecls.filter fun decl => + match decl with + | .func f => fieldAccessorNames.contains f.name.name + | _ => false + + (constructorDecls, testerDecls, fieldAccessorDecls) + +/-- +Build LConstr list from TransConstructorInfo array. +-/ +def buildLConstrs (datatypeName : String) (constructors : Array TransConstructorInfo) : + List (LConstr Visibility) := + let testerPattern : Array NamePatternPart := #[.datatype, .literal "..is", .constructor] + constructors.toList.map fun constr => + let testerName := expandNamePattern testerPattern datatypeName (some constr.name.name) + { name := constr.name + args := constr.fields.toList.map fun (fieldName, fieldType) => (fieldName, fieldType) + testerName := testerName } + +/-- +Generate factory function declarations from a list of LDatatypes. +-/ +def genDatatypeFactory (ldatatypes : List (LDatatype Visibility)) : + TransM (List Core.Decl) := do + let factory ← match genBlockFactory ldatatypes (T := CoreLParams) with + | .ok f => pure f + | .error e => TransM.error s!"Failed to generate datatype factory: {e}" + return factory.toList.map fun func => Core.Decl.func func + +--------------------------------------------------------------------- + /-- Translate a datatype declaration to Boogie declarations, updating bindings appropriately. @@ -1326,38 +1438,19 @@ duplicates. - `op`: The `command_datatype` operation to translate -/ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) : - TransM (Core.Decls × TransBindings) := do + TransM (Core.Decl × TransBindings) := do -- Check operation has correct name and argument count let _ ← @checkOp (Core.Decls × TransBindings) op q`Core.command_datatype 3 let datatypeName ← translateIdent String op.args[0]! -- Extract type arguments (optional bindings) - let (typeArgs, bindings) ← - translateOption - (fun maybearg => - do match maybearg with - | none => pure ([], bindings) - | some arg => - let bargs ← checkOpArg arg q`Core.mkBindings 1 - let args ← - match bargs[0]! with - | .seq _ .comma args => - let (arr, bindings) ← translateTypeBindings bindings args - return (arr.toList, bindings) - | _ => TransM.error - s!"translateDatatype expects a comma separated list: {repr bargs[0]!}") - op.args[1]! + let (typeArgs, bindings) ← translateDatatypeTypeArgs bindings op.args[1]! "translateDatatype" /- Note: Add a placeholder for the datatype type BEFORE translating constructors, for recursive constructors. Replaced with actual declaration later. -/ - let placeholderLDatatype : LDatatype Visibility := - { name := datatypeName - typeArgs := typeArgs - constrs := [{ name := datatypeName, args := [], testerName := "" }] - constrs_ne := by simp } - let placeholderDecl := Core.Decl.type (.data [placeholderLDatatype]) + let placeholderDecl := Core.Decl.type (.data [mkPlaceholderLDatatype datatypeName typeArgs]) let bindingsWithPlaceholder := { bindings with freeVars := bindings.freeVars.push placeholderDecl } -- Extract constructor information (possibly recursive) @@ -1367,15 +1460,10 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) TransM.error s!"Datatype {datatypeName} must have at least one constructor" else -- Build LConstr list from TransConstructorInfo - let testerPattern : Array NamePatternPart := #[.datatype, .literal "..is", .constructor] - let lConstrs : List (LConstr Visibility) := constructors.toList.map fun constr => - let testerName := expandNamePattern testerPattern datatypeName (some constr.name.name) - { name := constr.name - args := constr.fields.toList.map fun (fieldName, fieldType) => (fieldName, fieldType) - testerName := testerName } + let lConstrs := buildLConstrs datatypeName constructors have constrs_ne : lConstrs.length != 0 := by - simp [lConstrs] + simp [lConstrs, buildLConstrs] intro heq; subst_vars; apply h; rfl let ldatatype : LDatatype Visibility := @@ -1384,57 +1472,105 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) constrs := lConstrs constrs_ne := constrs_ne } - -- Generate factory from LDatatype and convert to Core.Decl - -- (used only for bindings.freeVars, not for allDecls) - let factory ← match genBlockFactory [ldatatype] (T := CoreLParams) with - | .ok f => pure f - | .error e => TransM.error s!"Failed to generate datatype factory: {e}" - let funcDecls : List Core.Decl := factory.toList.map fun func => - Core.Decl.func func + -- Generate factory from LDatatype + let funcDecls ← genDatatypeFactory [ldatatype] - -- Only includes typeDecl, factory functions generated later let md ← getOpMetaData op let typeDecl := Core.Decl.type (.data [ldatatype]) md - let allDecls := [typeDecl] - - /- - We must add to bindings.freeVars in the same order as the DDM's - `addDatatypeBindings`: type, constructors, template functions. We do NOT - include eliminators here because the DDM does not (yet) produce them. - -/ - - let constructorNames : List String := lConstrs.map fun c => c.name.name - let testerNames : List String := lConstrs.map fun c => c.testerName - - -- Extract all field accessor names across all constructors for field projections - -- Note: DDM validates that field names are unique across constructors - -- Field accessors are named as "Datatype..fieldName" - let fieldAccessorNames : List String := lConstrs.foldl (fun acc c => - acc ++ (c.args.map fun (fieldName, _) => datatypeName ++ ".." ++ fieldName.name)) [] - - -- Filter factory functions to get constructors, testers, projections - -- TODO: this could be more efficient via `LDatatype.genFunctionMaps` - let constructorDecls := funcDecls.filter fun decl => - match decl with - | .func f => constructorNames.contains f.name.name - | _ => false - - let testerDecls := funcDecls.filter fun decl => - match decl with - | .func f => testerNames.contains f.name.name - | _ => false - - let fieldAccessorDecls := funcDecls.filter fun decl => - match decl with - | .func f => fieldAccessorNames.contains f.name.name - | _ => false + -- Filter and add declarations to bindings + let (constructorDecls, testerDecls, fieldAccessorDecls) := filterDatatypeDecls ldatatype funcDecls let bindingDecls := typeDecl :: constructorDecls ++ testerDecls ++ fieldAccessorDecls let bindings := bindingDecls.foldl (fun b d => { b with freeVars := b.freeVars.push d } ) bindings - return (allDecls, bindings) + return (typeDecl, bindings) + +/-- +Translate a mutual block containing mutually recursive datatype definitions. +This collects all datatypes, creates a single TypeDecl.data with all of them, +and updates the forward-declared entries in bindings.freeVars. +-/ +def translateMutualBlock (p : Program) (bindings : TransBindings) (op : Operation) : + TransM (Core.Decl × TransBindings) := do + let _ ← @checkOp (Core.Decls × TransBindings) op q`Core.command_mutual 1 + + -- Extract commands from the SpacePrefixSepBy + let .seq _ _ commands := op.args[0]! + | TransM.error s!"translateMutualBlock expected sequence: {repr op.args[0]!}" + + -- Filter to only datatype commands + let datatypeOps := commands.filterMap fun arg => + match arg with + | .op op => if op.name == q`Core.command_datatype then some op else none + | _ => none + + if datatypeOps.size == 0 then + TransM.error "Mutual block must contain at least one datatype" + else + -- First pass: collect all datatype names, type args, and their indices in freeVars + -- Forward declarations MUST already be in bindings.freeVars + let mut datatypeInfos : Array (String × List TyIdentifier × Nat) := #[] + let mut bindingsWithPlaceholders := bindings + + for dtOp in datatypeOps do + let datatypeName ← translateIdent String dtOp.args[0]! + let (typeArgs, _) ← translateDatatypeTypeArgs bindings dtOp.args[1]! "translateMutualBlock" + + -- Find the index of this datatype in freeVars (from forward declaration) + let existingIdx := bindings.freeVars.findIdx? fun decl => + match decl with + | .type t _ => t.names.contains datatypeName + | _ => false + + match existingIdx with + | some i => + let placeholderDecl := Core.Decl.type (.data [mkPlaceholderLDatatype datatypeName typeArgs]) + datatypeInfos := datatypeInfos.push (datatypeName, typeArgs, i) + bindingsWithPlaceholders := { bindingsWithPlaceholders with + freeVars := bindingsWithPlaceholders.freeVars.set! i placeholderDecl } + | none => + TransM.error s!"Mutual datatype {datatypeName} requires a forward declaration" + + -- Second pass: translate all constructors with all placeholders in scope + let ldatatypes ← (datatypeOps.zip datatypeInfos).toList.mapM fun (dtOp, (datatypeName, typeArgs, _idx)) => do + let constructors ← translateConstructorList p bindingsWithPlaceholders dtOp.args[2]! + if h : constructors.size == 0 then + TransM.error s!"Datatype {datatypeName} must have at least one constructor" + else + let lConstrs := buildLConstrs datatypeName constructors + have constrs_ne : lConstrs.length != 0 := by + simp [lConstrs, buildLConstrs] + intro heq; subst_vars; apply h; rfl + pure { name := datatypeName, typeArgs := typeArgs, constrs := lConstrs, constrs_ne := constrs_ne } + + -- Generate factory functions for the ENTIRE mutual block at once + let allFuncDecls ← genDatatypeFactory ldatatypes + + -- Create the mutual TypeDecl with all datatypes + let md ← getOpMetaData op + let mutualTypeDecl := Core.Decl.type (.data ldatatypes) md + + -- Update bindings.freeVars: replace forward-declared entries with the mutual block + -- For each datatype, update its entry to point to the mutual TypeDecl + let mut finalBindings := bindings + + for (_datatypeName, _typeArgs, idx) in datatypeInfos do + if idx < finalBindings.freeVars.size then + finalBindings := { finalBindings with + freeVars := finalBindings.freeVars.set! idx mutualTypeDecl } + else + finalBindings := { finalBindings with + freeVars := finalBindings.freeVars.push mutualTypeDecl } + + -- Add constructor, tester, and accessor functions for each datatype + for ldatatype in ldatatypes do + let (constructorDecls, testerDecls, fieldAccessorDecls) := filterDatatypeDecls ldatatype allFuncDecls + for d in constructorDecls ++ testerDecls ++ fieldAccessorDecls do + finalBindings := { finalBindings with freeVars := finalBindings.freeVars.push d } + + return (mutualTypeDecl, finalBindings) --------------------------------------------------------------------- @@ -1459,15 +1595,19 @@ partial def translateCoreDecls (p : Program) (bindings : TransBindings) : | 0 => return ([], bindings) | _ + 1 => let op := ops[count]! - -- Commands that produce multiple declarations let (newDecls, bindings) ← match op.name with - | q`Core.command_datatype => - translateDatatype p bindings op + | q`Core.command_forward_typedecl => + -- Forward declarations do NOT produce AST nodes - they only update bindings + let (_, bindings) ← translateForwardTypeDecl bindings op + pure ([], bindings) | _ => - -- All other commands produce a single declaration let (decl, bindings) ← match op.name with + | q`Core.command_datatype => + translateDatatype p bindings op + | q`Core.command_mutual => + translateMutualBlock p bindings op | q`Core.command_var => translateGlobalVar bindings op | q`Core.command_constdecl => diff --git a/StrataTest/DDM/MutualDatatypes.lean b/StrataTest/DDM/MutualDatatypes.lean new file mode 100644 index 000000000..4d4a2f071 --- /dev/null +++ b/StrataTest/DDM/MutualDatatypes.lean @@ -0,0 +1,140 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.DDM.Integration.Lean + +/-! +# Tests for mutual datatype blocks in DDM + +Tests that mutually recursive datatypes can be declared via forward +declarations and mutual blocks. +-/ + +#dialect +dialect TestMutual; + +metadata declareDatatype (name : Ident, typeParams : Ident, constructors : Ident); + +type int; + +category Binding; +@[declare(name, tp)] +op mkBinding (name : Ident, tp : TypeP) : Binding => @[prec(40)] name ":" tp; + +category Bindings; +@[scope(bindings)] +op mkBindings (bindings : CommaSepBy Binding) : Bindings => "(" bindings ")"; + +category Constructor; +category ConstructorList; + +@[constructor(name, fields)] +op constructor_mk (name : Ident, fields : Option (CommaSepBy Binding)) : Constructor => + name "(" fields ")"; + +@[constructorListAtom(c)] +op constructorListAtom (c : Constructor) : ConstructorList => c; + +@[constructorListPush(cl, c)] +op constructorListPush (cl : ConstructorList, c : Constructor) : ConstructorList => + cl ", " c; + +@[declareTypeForward(name, none)] +op command_forward (name : Ident) : Command => + "forward type " name ";\n"; + +@[declareDatatype(name, typeParams, constructors)] +op command_datatype (name : Ident, + typeParams : Option Bindings, + @[scopeDatatype(name, typeParams)] constructors : ConstructorList) : Command => + "datatype " name typeParams " { " constructors " };\n"; + +@[scope(commands)] +op command_mutual (commands : SpacePrefixSepBy Command) : Command => + "mutual\n" indent(2, commands) "end;\n"; + +#end + +--------------------------------------------------------------------- +-- Test 1: Basic mutual recursion (Tree/Forest) +--------------------------------------------------------------------- + +def mutualBasicPgm := +#strata +program TestMutual; +forward type Tree; +forward type Forest; +mutual + datatype Tree { Node(val: int, children: Forest) }; + datatype Forest { FNil(), FCons(head: Tree, tail: Forest) }; +end; +#end + +/-- +info: program TestMutual; +forward type Tree; +forward type Forest; +mutual + datatype Tree { (Node)(val:int, children:Forest) }; + datatype Forest { ((FNil)()), ((FCons)(head:Tree, tail:Forest)) }; +end; +-/ +#guard_msgs in +#eval IO.println mutualBasicPgm + +--------------------------------------------------------------------- +-- Test 2: Single datatype in mutual block (should not actually be used) +--------------------------------------------------------------------- + +def mutualSinglePgm := +#strata +program TestMutual; +forward type List; +mutual + datatype List { Nil(), Cons(head: int, tail: List) }; +end; +#end + +/-- +info: program TestMutual; +forward type List; +mutual + datatype List { ((Nil)()), ((Cons)(head:int, tail:List)) }; +end; +-/ +#guard_msgs in +#eval IO.println mutualSinglePgm + +--------------------------------------------------------------------- +-- Test 3: Three-way mutual recursion +--------------------------------------------------------------------- + +def mutualThreeWayPgm := +#strata +program TestMutual; +forward type A; +forward type B; +forward type C; +mutual + datatype A { MkA(toB: B) }; + datatype B { MkB(toC: C) }; + datatype C { MkC(toA: A), CBase() }; +end; +#end + +/-- +info: program TestMutual; +forward type A; +forward type B; +forward type C; +mutual + datatype A { (MkA)(toB:B) }; + datatype B { (MkB)(toC:C) }; + datatype C { ((MkC)(toA:A)), ((CBase)()) }; +end; +-/ +#guard_msgs in +#eval IO.println mutualThreeWayPgm diff --git a/StrataTest/Languages/B3/DDMFormatTests.lean b/StrataTest/Languages/B3/DDMFormatTests.lean index be25b5036..0588a8bae 100644 --- a/StrataTest/Languages/B3/DDMFormatTests.lean +++ b/StrataTest/Languages/B3/DDMFormatTests.lean @@ -120,7 +120,7 @@ mutual | .ident () tp a => .ident default tp (a.map typeExprFUnitToSourceRange) | .bvar () idx => .bvar default idx | .tvar () name => .tvar default name - | .fvar () idx a => .fvar default idx (a.map typeExprFUnitToSourceRange) + | .fvar () idx n a => .fvar default idx n (a.map typeExprFUnitToSourceRange) | .arrow () a r => .arrow default (typeExprFUnitToSourceRange a) (typeExprFUnitToSourceRange r) partial def syntaxCatFUnitToSourceRange : SyntaxCatF Unit → SyntaxCatF SourceRange diff --git a/StrataTest/Languages/Core/Examples/DDMAxiomsExtraction.lean b/StrataTest/Languages/Core/Examples/DDMAxiomsExtraction.lean index f01e4ebe7..b9c5df81b 100644 --- a/StrataTest/Languages/Core/Examples/DDMAxiomsExtraction.lean +++ b/StrataTest/Languages/Core/Examples/DDMAxiomsExtraction.lean @@ -157,11 +157,11 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (TypeExprF.fvar { start := { byteIdx := 350 }, stop := { byteIdx := 351 } } - 1 (Array.mkEmpty 0))).push + 1 (some "v") (Array.mkEmpty 0))).push (TypeExprF.fvar { start := { byteIdx := 348 }, stop := { byteIdx := 349 } } - 0 (Array.mkEmpty 0))))) }) })).push + 0 (some "k") (Array.mkEmpty 0))))) }) })).push (ArgF.op { ann := { start := { byteIdx := 353 }, stop := { byteIdx := 358 } }, name := { dialect := "Core", name := "bind_mk" }, @@ -174,7 +174,7 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, none)).push (ArgF.type (TypeExprF.fvar { start := { byteIdx := 357 }, stop := { byteIdx := 358 } } - 0 (Array.mkEmpty 0))) }) })).push + 0 (some "k") (Array.mkEmpty 0))) }) })).push (ArgF.op { ann := { start := { byteIdx := 360 }, stop := { byteIdx := 365 } }, name := { dialect := "Core", name := "bind_mk" }, @@ -184,7 +184,7 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, "vv")).push (ArgF.option { start := { byteIdx := 364 }, stop := { byteIdx := 364 } } none)).push (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 364 }, stop := { byteIdx := 365 } } 1 + (TypeExprF.fvar { start := { byteIdx := 364 }, stop := { byteIdx := 365 } } 1 (some "v") (Array.mkEmpty 0))) }) })) (ArgF.expr (ExprF.app { start := { byteIdx := 369 }, stop := { byteIdx := 390 } } @@ -193,7 +193,8 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (ExprF.fn { start := { byteIdx := 369 }, stop := { byteIdx := 390 } } { dialect := "Core", name := "equal" }) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 350 }, stop := { byteIdx := 351 } } 1 (Array.mkEmpty 0)))) + (TypeExprF.fvar { start := { byteIdx := 350 }, stop := { byteIdx := 351 } } 1 (some "v") + (Array.mkEmpty 0)))) (ArgF.expr (ExprF.app { start := { byteIdx := 369 }, stop := { byteIdx := 384 } } (ExprF.app { start := { byteIdx := 369 }, stop := { byteIdx := 384 } } @@ -202,10 +203,10 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (ExprF.fn { start := { byteIdx := 369 }, stop := { byteIdx := 384 } } { dialect := "Core", name := "map_get" }) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 348 }, stop := { byteIdx := 349 } } 0 + (TypeExprF.fvar { start := { byteIdx := 348 }, stop := { byteIdx := 349 } } 0 (some "k") (Array.mkEmpty 0)))) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 350 }, stop := { byteIdx := 351 } } 1 + (TypeExprF.fvar { start := { byteIdx := 350 }, stop := { byteIdx := 351 } } 1 (some "v") (Array.mkEmpty 0)))) (ArgF.expr (ExprF.app { start := { byteIdx := 369 }, stop := { byteIdx := 380 } } @@ -217,10 +218,10 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, { dialect := "Core", name := "map_set" }) (ArgF.type (TypeExprF.fvar { start := { byteIdx := 348 }, stop := { byteIdx := 349 } } 0 - (Array.mkEmpty 0)))) + (some "k") (Array.mkEmpty 0)))) (ArgF.type (TypeExprF.fvar { start := { byteIdx := 350 }, stop := { byteIdx := 351 } } 1 - (Array.mkEmpty 0)))) + (some "v") (Array.mkEmpty 0)))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 369 }, stop := { byteIdx := 370 } } 2))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 371 }, stop := { byteIdx := 373 } } 1))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 377 }, stop := { byteIdx := 379 } } 0))))) @@ -291,11 +292,11 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } - 1 (Array.mkEmpty 0))).push + 1 (some "v") (Array.mkEmpty 0))).push (TypeExprF.fvar { start := { byteIdx := 431 }, stop := { byteIdx := 432 } } - 0 (Array.mkEmpty 0))))) }) })).push + 0 (some "k") (Array.mkEmpty 0))))) }) })).push (ArgF.op { ann := { start := { byteIdx := 436 }, stop := { byteIdx := 442 } }, name := { dialect := "Core", name := "bind_mk" }, @@ -311,7 +312,7 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (ArgF.type (TypeExprF.fvar { start := { byteIdx := 441 }, stop := { byteIdx := 442 } } 0 - (Array.mkEmpty 0))) }) })).push + (some "k") (Array.mkEmpty 0))) }) })).push (ArgF.op { ann := { start := { byteIdx := 444 }, stop := { byteIdx := 449 } }, name := { dialect := "Core", name := "bind_mk" }, @@ -324,7 +325,7 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, none)).push (ArgF.type (TypeExprF.fvar { start := { byteIdx := 448 }, stop := { byteIdx := 449 } } - 0 (Array.mkEmpty 0))) }) })).push + 0 (some "k") (Array.mkEmpty 0))) }) })).push (ArgF.op { ann := { start := { byteIdx := 451 }, stop := { byteIdx := 456 } }, name := { dialect := "Core", name := "bind_mk" }, @@ -334,7 +335,7 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, "vv")).push (ArgF.option { start := { byteIdx := 455 }, stop := { byteIdx := 455 } } none)).push (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 455 }, stop := { byteIdx := 456 } } 1 + (TypeExprF.fvar { start := { byteIdx := 455 }, stop := { byteIdx := 456 } } 1 (some "v") (Array.mkEmpty 0))) }) })) (ArgF.expr (ExprF.app { start := { byteIdx := 460 }, stop := { byteIdx := 486 } } @@ -343,7 +344,8 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (ExprF.fn { start := { byteIdx := 460 }, stop := { byteIdx := 486 } } { dialect := "Core", name := "equal" }) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 (Array.mkEmpty 0)))) + (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 (some "v") + (Array.mkEmpty 0)))) (ArgF.expr (ExprF.app { start := { byteIdx := 460 }, stop := { byteIdx := 476 } } (ExprF.app { start := { byteIdx := 460 }, stop := { byteIdx := 476 } } @@ -352,10 +354,10 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (ExprF.fn { start := { byteIdx := 460 }, stop := { byteIdx := 476 } } { dialect := "Core", name := "map_get" }) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 431 }, stop := { byteIdx := 432 } } 0 + (TypeExprF.fvar { start := { byteIdx := 431 }, stop := { byteIdx := 432 } } 0 (some "k") (Array.mkEmpty 0)))) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 + (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 (some "v") (Array.mkEmpty 0)))) (ArgF.expr (ExprF.app { start := { byteIdx := 460 }, stop := { byteIdx := 471 } } @@ -367,10 +369,10 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, { dialect := "Core", name := "map_set" }) (ArgF.type (TypeExprF.fvar { start := { byteIdx := 431 }, stop := { byteIdx := 432 } } 0 - (Array.mkEmpty 0)))) + (some "k") (Array.mkEmpty 0)))) (ArgF.type (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 - (Array.mkEmpty 0)))) + (some "v") (Array.mkEmpty 0)))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 460 }, stop := { byteIdx := 461 } } 3))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 462 }, stop := { byteIdx := 464 } } 1))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 468 }, stop := { byteIdx := 470 } } 0))))) @@ -383,10 +385,10 @@ info: #[{ ann := { start := { byteIdx := 296 }, stop := { byteIdx := 303 } }, (ExprF.fn { start := { byteIdx := 480 }, stop := { byteIdx := 486 } } { dialect := "Core", name := "map_get" }) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 431 }, stop := { byteIdx := 432 } } 0 + (TypeExprF.fvar { start := { byteIdx := 431 }, stop := { byteIdx := 432 } } 0 (some "k") (Array.mkEmpty 0)))) (ArgF.type - (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 + (TypeExprF.fvar { start := { byteIdx := 433 }, stop := { byteIdx := 434 } } 1 (some "v") (Array.mkEmpty 0)))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 480 }, stop := { byteIdx := 481 } } 3))) (ArgF.expr (ExprF.bvar { start := { byteIdx := 482 }, stop := { byteIdx := 485 } } 2)))))))) }] diff --git a/StrataTest/Languages/Core/Examples/MutualDatatypes.lean b/StrataTest/Languages/Core/Examples/MutualDatatypes.lean new file mode 100644 index 000000000..e8b8365ff --- /dev/null +++ b/StrataTest/Languages/Core/Examples/MutualDatatypes.lean @@ -0,0 +1,360 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.Languages.Core.Verifier + +/-! +# Mutual Datatype Integration Tests + +Tests mutually recursive datatypes using the DDM datatype declaration syntax. +-/ + +namespace Strata.MutualDatatypeTest + +--------------------------------------------------------------------- +-- Test 1: Basic Rose Tree Datatype Declaration and Tester Functions +--------------------------------------------------------------------- + +def roseTreeTesterPgm : Program := +#strata +program Core; + +forward type RoseTree; +forward type Forest; +mutual + datatype Forest { FNil(), FCons(head: RoseTree, tail: Forest) }; + datatype RoseTree { Node(val: int, children: Forest) }; +end; + +procedure TestRoseTreeTesters() returns () +spec { + ensures true; +} +{ + var t : RoseTree; + var f : Forest; + + f := FNil(); + assert [isFNil]: Forest..isFNil(f); + assert [notFCons]: !Forest..isFCons(f); + + t := Node(42, FNil()); + assert [isNode]: RoseTree..isNode(t); + + f := FCons(Node(1, FNil()), FNil()); + assert [isFCons]: Forest..isFCons(f); + assert [notFNil]: !Forest..isFNil(f); +}; +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram roseTreeTesterPgm) |>.snd |>.isEmpty + +/-- +info: +Obligation: isFNil +Property: assert +Result: ✅ pass + +Obligation: notFCons +Property: assert +Result: ✅ pass + +Obligation: isNode +Property: assert +Result: ✅ pass + +Obligation: isFCons +Property: assert +Result: ✅ pass + +Obligation: notFNil +Property: assert +Result: ✅ pass + +Obligation: TestRoseTreeTesters_ensures_0 +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify "cvc5" roseTreeTesterPgm Inhabited.default Options.quiet + +--------------------------------------------------------------------- +-- Test 2: Rose Tree Destructor Functions +--------------------------------------------------------------------- + +def roseTreeDestructorPgm : Program := +#strata +program Core; + +forward type RoseTree; +forward type Forest; +mutual + datatype Forest { FNil(), FCons(head: RoseTree, tail: Forest) }; + datatype RoseTree { Node(val: int, children: Forest) }; +end; + +procedure TestRoseTreeDestructor() returns () +spec { + ensures true; +} +{ + var t : RoseTree; + var f : Forest; + var v : int; + var c : Forest; + + t := Node(42, FNil()); + + v := RoseTree..val(t); + assert [valIs42]: v == 42; + + c := RoseTree..children(t); + assert [childrenIsNil]: Forest..isFNil(c); + + f := FCons(Node(10, FNil()), FNil()); + + t := Forest..head(f); + assert [headIsNode]: RoseTree..isNode(t); + assert [headVal]: RoseTree..val(t) == 10; + + f := Forest..tail(f); + assert [tailIsNil]: Forest..isFNil(f); +}; +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram roseTreeDestructorPgm) |>.snd |>.isEmpty + +/-- +info: +Obligation: valIs42 +Property: assert +Result: ✅ pass + +Obligation: childrenIsNil +Property: assert +Result: ✅ pass + +Obligation: headIsNode +Property: assert +Result: ✅ pass + +Obligation: headVal +Property: assert +Result: ✅ pass + +Obligation: tailIsNil +Property: assert +Result: ✅ pass + +Obligation: TestRoseTreeDestructor_ensures_0 +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify "cvc5" roseTreeDestructorPgm Inhabited.default Options.quiet + +--------------------------------------------------------------------- +-- Test 3: Rose Tree Equality +--------------------------------------------------------------------- + +def roseTreeEqualityPgm : Program := +#strata +program Core; + +forward type RoseTree; +forward type Forest; +mutual + datatype Forest { FNil(), FCons(head: RoseTree, tail: Forest) }; + datatype RoseTree { Node(val: int, children: Forest) }; +end; + +procedure TestRoseTreeEquality() returns () +spec { + ensures true; +} +{ + var t1 : RoseTree; + var t2 : RoseTree; + var f1 : Forest; + var f2 : Forest; + + t1 := Node(42, FNil()); + t2 := Node(42, FNil()); + assert [leafEquality]: t1 == t2; + + f1 := FNil(); + f2 := FNil(); + assert [emptyForestEquality]: f1 == f2; + + f1 := FCons(Node(1, FNil()), FNil()); + f2 := FCons(Node(1, FNil()), FNil()); + assert [forestEquality]: f1 == f2; +}; +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram roseTreeEqualityPgm) |>.snd |>.isEmpty + +/-- +info: +Obligation: leafEquality +Property: assert +Result: ✅ pass + +Obligation: emptyForestEquality +Property: assert +Result: ✅ pass + +Obligation: forestEquality +Property: assert +Result: ✅ pass + +Obligation: TestRoseTreeEquality_ensures_0 +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify "cvc5" roseTreeEqualityPgm Inhabited.default Options.quiet + +--------------------------------------------------------------------- +-- Test 4: Polymorphic Rose Tree with Havoc (SMT verification) +--------------------------------------------------------------------- + +def polyRoseTreeHavocPgm : Program := +#strata +program Core; + +forward type RoseTree (a : Type); +forward type Forest (a : Type); +mutual + datatype Forest (a : Type) { FNil(), FCons(head: RoseTree a, tail: Forest a) }; + datatype RoseTree (a : Type) { Node(val: a, children: Forest a) }; +end; + +procedure TestPolyRoseTreeHavoc() returns () +spec { + ensures true; +} +{ + var t : RoseTree int; + var f : Forest int; + + havoc t; + havoc f; + + assume t == Node(42, FNil()); + assume f == FCons(t, FNil()); + + assert [valIs42]: RoseTree..val(t) == 42; + assert [headIsT]: Forest..head(f) == t; + assert [headVal]: RoseTree..val(Forest..head(f)) == 42; +}; +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram polyRoseTreeHavocPgm) |>.snd |>.isEmpty + +/-- +info: +Obligation: valIs42 +Property: assert +Result: ✅ pass + +Obligation: headIsT +Property: assert +Result: ✅ pass + +Obligation: headVal +Property: assert +Result: ✅ pass + +Obligation: TestPolyRoseTreeHavoc_ensures_0 +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify "cvc5" polyRoseTreeHavocPgm Inhabited.default Options.quiet + +--------------------------------------------------------------------- +-- Test 5: Imperative Stmt/StmtList with Havoc (SMT verification) +--------------------------------------------------------------------- + +/-- Mutually recursive Stmt/StmtList modeling Imperative.Stmt -/ +def stmtListHavocPgm : Program := +#strata +program Core; + +forward type Stmt (e : Type, c : Type); +forward type StmtList (e : Type, c : Type); +mutual + datatype StmtList (e : Type, c : Type) { SNil(), SCons(hd: Stmt e c, tl: StmtList e c) }; + datatype Stmt (e : Type, c : Type) { + Cmd(cmd: c), + Block(label: int, blockBody: StmtList e c), + Ite(cond: e, thenB: StmtList e c, elseB: StmtList e c), + Loop(guard: e, loopBody: StmtList e c), + Goto(target: int) + }; +end; + +procedure TestStmtListHavoc() returns () +spec { + ensures true; +} +{ + var s : Stmt bool int; + var ss : StmtList bool int; + + havoc s; + havoc ss; + + // A block containing a single command + assume s == Block(1, SCons(Cmd(42), SNil())); + assume ss == SCons(s, SCons(Goto(1), SNil())); + + assert [isBlock]: Stmt..isBlock(s); + assert [bodyHd]: Stmt..isCmd(StmtList..hd(Stmt..blockBody(s))); + assert [cmdVal]: Stmt..cmd(StmtList..hd(Stmt..blockBody(s))) == 42; + assert [secondIsGoto]: Stmt..isGoto(StmtList..hd(StmtList..tl(ss))); +}; +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram stmtListHavocPgm) |>.snd |>.isEmpty + +/-- +info: +Obligation: isBlock +Property: assert +Result: ✅ pass + +Obligation: bodyHd +Property: assert +Result: ✅ pass + +Obligation: cmdVal +Property: assert +Result: ✅ pass + +Obligation: secondIsGoto +Property: assert +Result: ✅ pass + +Obligation: TestStmtListHavoc_ensures_0 +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify "cvc5" stmtListHavocPgm Inhabited.default Options.quiet + +end Strata.MutualDatatypeTest diff --git a/StrataTest/Languages/Core/Examples/RemoveIrrelevantAxioms.lean b/StrataTest/Languages/Core/Examples/RemoveIrrelevantAxioms.lean index 90a10b922..220aa1436 100644 --- a/StrataTest/Languages/Core/Examples/RemoveIrrelevantAxioms.lean +++ b/StrataTest/Languages/Core/Examples/RemoveIrrelevantAxioms.lean @@ -39,7 +39,7 @@ procedure P() returns () assert f(23); assert f(-(5)); } - end : {} + _exit : {} }; procedure Q0(x : int) returns () @@ -49,7 +49,7 @@ procedure Q0(x : int) returns () assert (x == 2); assert (x == 2); } - end : {} + _exit : {} }; procedure Q1(x : int) returns () @@ -59,7 +59,7 @@ procedure Q1(x : int) returns () assert (x == 2); assert (x == 2); } - end : {} + _exit : {} }; procedure Q2(x : int) returns () @@ -69,7 +69,7 @@ procedure Q2(x : int) returns () assert (x == 2); assert (x == 2); } - end : {} + _exit : {} }; procedure Q3(x : int) returns () @@ -79,7 +79,7 @@ procedure Q3(x : int) returns () assert (x == 2); assert (x == 2); } - end : {} + _exit : {} }; #end diff --git a/StrataTest/Transform/ProcedureInlining.lean b/StrataTest/Transform/ProcedureInlining.lean index 91d859de3..c83b45b57 100644 --- a/StrataTest/Transform/ProcedureInlining.lean +++ b/StrataTest/Transform/ProcedureInlining.lean @@ -294,17 +294,17 @@ def Test2 := program Core; procedure f(x : bool) returns (y : bool) { if (x) { - goto end; + goto _exit; } else { y := false; } - end: {} + _exit: {} }; procedure h() returns () { var b_in : bool; var b_out : bool; call b_out := f(b_in); - end: {} + _exit: {} }; #end @@ -313,10 +313,10 @@ def Test2Ans := program Core; procedure f(x : bool) returns (y : bool) { if (x) { - goto end; + goto _exit; } else { y := false; } - end: {} + _exit: {} }; procedure h() returns () { @@ -334,7 +334,7 @@ procedure h() returns () { f_end: {} b_out := f_y; } - end: {} + _exit: {} }; #end diff --git a/Tools/BoogieToStrata/Source/StrataGenerator.cs b/Tools/BoogieToStrata/Source/StrataGenerator.cs index 1bd7448fe..b854f8ea3 100644 --- a/Tools/BoogieToStrata/Source/StrataGenerator.cs +++ b/Tools/BoogieToStrata/Source/StrataGenerator.cs @@ -846,7 +846,7 @@ public override Cmd VisitAssignCmd(AssignCmd node) { } public override ReturnCmd VisitReturnCmd(ReturnCmd node) { - IndentLine("goto end;"); + IndentLine("goto _exit;"); return node; } @@ -1333,7 +1333,7 @@ public override Implementation VisitImplementation(Implementation node) { } } - IndentLine("end : {}"); + IndentLine("_exit : {}"); DecIndent(); WriteLine("};"); diff --git a/Tools/BoogieToStrata/Tests/B.bpl b/Tools/BoogieToStrata/Tests/B.bpl index 00d8c37b8..0465acefb 100644 --- a/Tools/BoogieToStrata/Tests/B.bpl +++ b/Tools/BoogieToStrata/Tests/B.bpl @@ -14,13 +14,13 @@ procedure Q0() Then: h := 15; - goto end; + goto _exit; Else: assume h == 0; - goto end; + goto _exit; - end: + _exit: assert 0 <= h; return; } @@ -34,13 +34,13 @@ procedure Q1() Then: h := -15; - goto end; + goto _exit; Else: assume h == 0; - goto end; + goto _exit; - end: + _exit: h := -h; assert 0 <= h; return; @@ -54,13 +54,13 @@ procedure P0(this: ref) Then: Heap[this, N] := 15; - goto end; + goto _exit; Else: assume Heap[this, N] == 0; - goto end; + goto _exit; - end: + _exit: assert 0 <= Heap[this, N]; return; } @@ -73,13 +73,13 @@ procedure P1(this: ref) Then: Heap[this, N] := -15; - goto end; + goto _exit; Else: assume Heap[this, N] == 0; - goto end; + goto _exit; - end: + _exit: Heap[this, N] := -Heap[this, N]; assert 0 <= Heap[this, N]; return;