Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 94 additions & 21 deletions Strata/DDM/AST.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ inductive TypeExprF (α : Type) where
/-- A polymorphic type variable (universally quantified).
Used for polymorphic function type parameters -/
| tvar (ann : α) (name : String)
/-- A reference to a global variable along with any arguments to ensure it is well-typed. -/
| fvar (ann : α) (fvar : FreeVarIndex) (args : Array (TypeExprF α))
/-- A reference to a global variable along with any arguments to ensure it is
well-typed. The name field stores the original type name for lookup in
mutual blocks. -/
| fvar (ann : α) (fvar : FreeVarIndex) (name : Option String)
(args : Array (TypeExprF α))
/-- A function type. -/
| arrow (ann : α) (arg : TypeExprF α) (res : TypeExprF α)
deriving BEq, Inhabited, Repr
Expand All @@ -106,7 +109,7 @@ def ann {α} : TypeExprF α → α
| .ident ann _ _ => ann
| .bvar ann _ => ann
| .tvar ann _ => ann
| .fvar ann _ _ => ann
| .fvar ann _ _ _ => ann
| .arrow ann _ _ => ann

def mkFunType {α} (n : α) (bindings : Array (String × TypeExprF α)) (res : TypeExprF α) : TypeExprF α :=
Expand All @@ -115,15 +118,15 @@ def mkFunType {α} (n : α) (bindings : Array (String × TypeExprF α)) (res : T
protected def incIndices {α} (tp : TypeExprF α) (count : Nat) : TypeExprF α :=
match tp with
| .ident n i args => .ident n i (args.attach.map fun ⟨e, _⟩ => e.incIndices count)
| .fvar n f args => .fvar n f (args.attach.map fun ⟨e, _⟩ => e.incIndices count)
| .fvar n f name args => .fvar n f name (args.attach.map fun ⟨e, _⟩ => e.incIndices count)
| .bvar n idx => .bvar n (idx + count)
| .tvar n name => .tvar n name -- tvar doesn't use indices
| .arrow n a r => .arrow n (a.incIndices count) (r.incIndices count)

/-- Return true if type expression has a bound variable. -/
protected def hasUnboundVar {α} (bindingCount : Nat := 0) : TypeExprF α → Bool
| .ident _ _ args => args.attach.any (fun ⟨e, _⟩ => e.hasUnboundVar bindingCount)
| .fvar _ _ args => args.attach.any (fun ⟨e, _⟩ => e.hasUnboundVar bindingCount)
| .fvar _ _ _ args => args.attach.any (fun ⟨e, _⟩ => e.hasUnboundVar bindingCount)
| .bvar _ idx => idx ≥ bindingCount
| .tvar _ _ => true
| .arrow _ a r => a.hasUnboundVar bindingCount || r.hasUnboundVar bindingCount
Expand All @@ -142,8 +145,8 @@ protected def instTypeM {m α} [Monad m] (d : TypeExprF α) (bindings : α → N
| .ident n i a =>
.ident n i <$> a.attach.mapM (fun ⟨e, _⟩ => e.instTypeM bindings)
| .bvar n idx => bindings n idx
| .fvar n idx name a => .fvar n idx name <$> a.attach.mapM (fun ⟨e, _⟩ => e.instTypeM bindings)
| .tvar n name => pure (.tvar n name)
| .fvar n idx a => .fvar n idx <$> a.attach.mapM (fun ⟨e, _⟩ => e.instTypeM bindings)
| .arrow n a b => .arrow n <$> a.instTypeM bindings <*> b.instTypeM bindings
termination_by d

Expand Down Expand Up @@ -624,7 +627,7 @@ inductive PreType where
Used for polymorphic function type parameters -/
| tvar (ann : SourceRange) (name : String)
/-- A reference to a global variable along with any arguments to ensure it is well-typed. -/
| fvar (ann : SourceRange) (fvar : FreeVarIndex) (args : Array PreType)
| fvar (ann : SourceRange) (fvar : FreeVarIndex) (name : Option String) (args : Array PreType)
/-- A function type. -/
| arrow (ann : SourceRange) (arg : PreType) (res : PreType)
/-- A function created from a reference to bindings and a result type. -/
Expand All @@ -638,15 +641,15 @@ def ann : PreType → SourceRange
| .ident ann _ _ => ann
| .bvar ann _ => ann
| .tvar ann _ => ann
| .fvar ann _ _ => ann
| .fvar ann _ _ _ => ann
| .arrow ann _ _ => ann
| .funMacro ann _ _ => ann

def ofType : TypeExprF SourceRange → PreType
| .ident loc name args => .ident loc name (args.map fun a => .ofType a)
| .bvar loc idx => .bvar loc idx
| .fvar loc idx name args => .fvar loc idx name (args.map fun a => .ofType a)
| .tvar loc name => .tvar loc name
| .fvar loc idx args => .fvar loc idx (args.map fun a => .ofType a)
| .arrow loc a r => .arrow loc (.ofType a) (.ofType r)
termination_by tp => tp

Expand Down Expand Up @@ -869,9 +872,9 @@ structure ConstructorInfo where
Build a TypeExpr reference to the datatype with type parameters, using
`.fvar` for the datatype's GlobalContext index and `.tvar` for type parameters.
-/
def mkDatatypeTypeRef (ann : SourceRange) (datatypeIndex : FreeVarIndex) (typeParams : Array String) : TypeExpr :=
def mkDatatypeTypeRef (ann : SourceRange) (datatypeIndex : FreeVarIndex) (typeParams : Array String) (datatypeName : Option String := none) : TypeExpr :=
let typeArgs := typeParams.map fun name => TypeExprF.tvar ann name
TypeExprF.fvar ann datatypeIndex typeArgs
TypeExprF.fvar ann datatypeIndex datatypeName typeArgs

/--
Build an arrow type from field types to the datatype type. E.g. for cons,
Expand Down Expand Up @@ -1043,6 +1046,7 @@ A spec for introducing a new binding into a type context.
inductive BindingSpec (argDecls : ArgDecls) where
| value (_ : ValueBindingSpec argDecls)
| type (_ : TypeBindingSpec argDecls)
| typeForward (_ : TypeBindingSpec argDecls) -- Forward declaration (no AST node)
| datatype (_ : DatatypeBindingSpec argDecls)
| tvar (_ : TvarBindingSpec argDecls)
deriving Repr
Expand All @@ -1052,6 +1056,7 @@ namespace BindingSpec
def nameIndex {argDecls} : BindingSpec argDecls → DebruijnIndex argDecls.size
| .value v => v.nameIndex
| .type v => v.nameIndex
| .typeForward v => v.nameIndex
| .datatype v => v.nameIndex
| .tvar v => v.nameIndex

Expand Down Expand Up @@ -1135,6 +1140,22 @@ def parseNewBindings (md : Metadata) (argDecls : ArgDecls) : Array (BindingSpec
pure <| some ⟨idx, argsP⟩
| _ => newBindingErr "declareType args invalid."; return none
some <$> .type <$> pure { nameIndex, argsIndex, defIndex := none }
| q`StrataDDL.declareTypeForward => do
let #[.catbvar nameIndex, .option mArgsArg ] := attr.args
| newBindingErr s!"declareTypeForward has bad arguments {repr attr.args}."; return none
let .isTrue nameP := inferInstanceAs (Decidable (nameIndex < argDecls.size))
| return panic! "Invalid name index"
let nameIndex := ⟨nameIndex, nameP⟩
checkNameIndexIsIdent argDecls nameIndex
let argsIndex ←
match mArgsArg with
| none => pure none
| some (.catbvar idx) =>
let .isTrue argsP := inferInstanceAs (Decidable (idx < argDecls.size))
| return panic! "Invalid arg index"
pure <| some ⟨idx, argsP⟩
| _ => newBindingErr "declareTypeForward args invalid."; return none
some <$> .typeForward <$> pure { nameIndex, argsIndex, defIndex := none }
| q`StrataDDL.aliasType => do
let #[.catbvar nameIndex, .option mArgsArg, .catbvar defIndex] := attr.args
| newBindingErr "aliasType missing arguments."; return none
Expand Down Expand Up @@ -1767,6 +1788,12 @@ inductive GlobalKind where
| type (params : List String) (definition : Option TypeExpr)
deriving BEq, Inhabited, Repr

/-- State of a symbol in the GlobalContext -/
inductive DeclState where
| forward -- Symbol is forward-declared (no AST node will be generated)
| defined -- Symbol has a complete definition
deriving BEq, DecidableEq, Repr, Inhabited

/-- Resolves a binding spec into a global kind. -/
partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src : SourceRange) (b : BindingSpec argDecls) (args : Vector Arg argDecls.size) : Option GlobalKind :=
match b with
Expand All @@ -1792,7 +1819,7 @@ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src
panic! s!"Expected new binding to be Type instead of {repr c}."
| a =>
panic! s!"Expected new binding to be bound to type instead of {repr a}."
| .type b =>
| .type b | .typeForward b =>
let params : Array String :=
match b.argsIndex with
| none => #[]
Expand Down Expand Up @@ -1831,7 +1858,7 @@ Typing environment created from declarations in an environment.
-/
structure GlobalContext where
nameMap : Std.HashMap Var FreeVarIndex
vars : Array (Var × GlobalKind)
vars : Array (Var × GlobalKind × DeclState)
deriving Repr

namespace GlobalContext
Expand All @@ -1851,12 +1878,45 @@ instance : Membership Var GlobalContext where
def instDecidableMem (v : Var) (ctx : GlobalContext) : Decidable (v ∈ ctx) :=
inferInstanceAs (Decidable (v ∈ ctx.nameMap))

def push (ctx : GlobalContext) (v : Var) (k : GlobalKind) : GlobalContext :=
/-- Add a forward declaration (must not exist). Used by @[declareTypeForward].
This adds to GlobalContext for name resolution but will NOT generate an AST node. -/
def declareForward (ctx : GlobalContext) (v : Var) (k : GlobalKind) : Except String GlobalContext :=
if v ∈ ctx then
panic! s!"Var {v} already defined"
.error s!"Symbol '{v}' is already in scope"
else
let idx := ctx.vars.size
{ nameMap := ctx.nameMap.insert v idx, vars := ctx.vars.push (v, k) }
.ok { nameMap := ctx.nameMap.insert v idx,
vars := ctx.vars.push (v, k, .forward) }

/-- Define a symbol. Used by @[declareDatatype], @[declareFn] with body, etc.
Replaces forward declaration, or adds new as defined. -/
def define (ctx : GlobalContext) (v : Var) (k : GlobalKind) : Except String GlobalContext :=
match ctx.nameMap.get? v with
| none =>
-- Not declared, add as defined directly
let idx := ctx.vars.size
.ok { nameMap := ctx.nameMap.insert v idx,
vars := ctx.vars.push (v, k, .defined) }
| some idx =>
let (name, _, state) := ctx.vars[idx]!
match state with
| .forward =>
-- Replace forward declaration with definition (update in place)
.ok { ctx with vars := ctx.vars.set! idx (name, k, .defined) }
| .defined =>
.error s!"Symbol '{v}' is already defined"

/-- Check if a symbol is forward-declared (not yet defined). -/
def isForward (ctx : GlobalContext) (idx : FreeVarIndex) : Bool :=
match ctx.vars[idx]? with
| some (_, _, .forward) => true
| _ => false

/-- Add a symbol as defined. -/
def push (ctx : GlobalContext) (v : Var) (k : GlobalKind) : GlobalContext :=
match ctx.define v k with
| .ok ctx' => ctx'
| .error msg => panic! msg

/-- Return the index of the variable with the given name. -/
def findIndex? (ctx : GlobalContext) (v : Var) : Option FreeVarIndex := ctx.nameMap.get? v
Expand All @@ -1865,7 +1925,7 @@ def nameOf? (ctx : GlobalContext) (idx : FreeVarIndex) : Option String := ctx.va

def kindOf! (ctx : GlobalContext) (idx : FreeVarIndex) : GlobalKind :=
assert! idx < ctx.vars.size
ctx.vars[idx]!.snd
ctx.vars[idx]!.2.1

/-!
## Annotation-based Constructor Info Extraction
Expand Down Expand Up @@ -2066,10 +2126,12 @@ private def addDatatypeBindings

let constructorInfo := extractConstructorInfo dialects args[b.constructorsIndex.toLevel]

-- Step 1: Add datatype type
let gctx := gctx.push datatypeName (GlobalKind.type typeParams.toList none)
let datatypeIndex := gctx.vars.size - 1
let datatypeType := mkDatatypeTypeRef src datatypeIndex typeParams
-- Step 1: Add datatype type (or update forward declaration)
let gctx := match gctx.define datatypeName (GlobalKind.type typeParams.toList none) with
| .ok gctx' => gctx'
| .error msg => panic! s!"addDatatypeBindings: {msg}"
let datatypeIndex := gctx.findIndex? datatypeName |>.getD (gctx.vars.size - 1)
let datatypeType := mkDatatypeTypeRef src datatypeIndex typeParams (some datatypeName)

-- Step 2: Add constructor signatures
let gctx := constructorInfo.foldl (init := gctx) fun gctx constr =>
Expand All @@ -2093,6 +2155,17 @@ def addCommand (dialects : DialectMap) (init : GlobalContext) (op : Operation) :
match b with
| .datatype datatypeSpec =>
addDatatypeBindings dialects gctx l dialectName datatypeSpec args
| .typeForward typeSpec =>
let name :=
match args[typeSpec.nameIndex.toLevel] with
| .ident _ e => e
| a => panic! s!"Expected ident at {typeSpec.nameIndex.toLevel} {repr a}"
match resolveBindingIndices dialects l b args with
| some kind =>
match gctx.declareForward name kind with
| .ok gctx' => gctx'
| .error msg => panic! msg
| none => gctx
| _ =>
let name :=
match args[b.nameIndex.toLevel] with
Expand Down
1 change: 1 addition & 0 deletions Strata/DDM/BuiltinDialects/StrataDDL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def StrataDDL : Dialect := BuiltinM.create! "StrataDDL" #[initDialect] do
-- Metadata for marking an operation as a constructor list push (list followed by constructor)
declareMetadata { name := "constructorListPush", args := #[.mk "list" .ident, .mk "constructor" .ident] }
declareMetadata { name := "declareType", args := #[.mk "name" .ident, .mk "args" (.opt .ident)] }
declareMetadata { name := "declareTypeForward", args := #[.mk "name" .ident, .mk "args" (.opt .ident)] }
declareMetadata { name := "aliasType", args := #[.mk "name" .ident, .mk "args" (.opt .ident), .mk "def" .ident] }
declareMetadata { name := "declare", args := #[.mk "name" .ident, .mk "type" .ident] }
declareMetadata { name := "declareFn", args := #[.mk "name" .ident, .mk "args" .ident, .mk "type" .ident] }
Expand Down
Loading
Loading