Skip to content
Open
3 changes: 2 additions & 1 deletion .kiro/settings/mcp.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"lean_loogle",
"lean_leanfinder",
"lean_state_search",
"lean_hammer_premise"
"lean_hammer_premise",
"lean_file_outline"
]
}
}
Expand Down
3 changes: 2 additions & 1 deletion Strata.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import Strata.DL.Imperative.Imperative
/- Utilities -/
import Strata.Util.Sarif

/- Strata Core -/
/- Strata Languages -/
import Strata.Languages.Core.StatementSemantics
import Strata.Languages.Laurel.LaurelToCoreTranslator
import Strata.Languages.Core.SarifOutput

/- Backends -/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExpr := do
let target ← translateStmtExpr arg0
let value ← translateStmtExpr arg1
let md ← getArgMetaData (.op op)
return .Assign target value md
return .Assign [target] value md
| q`Laurel.call, #[arg0, argsSeq] =>
let callee ← translateStmtExpr arg0
let calleeName := match callee with
Expand Down
18 changes: 9 additions & 9 deletions Strata/Languages/Laurel/Grammar/LaurelGrammar.st
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ op parenthesis (inner: StmtExpr): StmtExpr => "(" inner ")";
op assign (target: StmtExpr, value: StmtExpr): StmtExpr => @[prec(10)] target ":=" value ";";

// Binary operators
op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60)] lhs "+" rhs;
op eq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "==" rhs;
op neq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "!=" rhs;
op gt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs ">" rhs;
op lt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "<" rhs;
op le (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "<=" rhs;
op ge (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs ">=" rhs;
op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60), leftassoc] lhs " + " rhs;
op eq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " == " rhs;
op neq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " != " rhs;
op gt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " > " rhs;
op lt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " < " rhs;
op le (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " <= " rhs;
op ge (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40), leftassoc] lhs " >= " rhs;

// Logical operators
op and (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(30)] lhs "&&" rhs;
op or (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(25)] lhs "||" rhs;
op and (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(30), leftassoc] lhs " && " rhs;
op or (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(25), leftassoc] lhs " || " rhs;

// If-else
category OptionalElse;
Expand Down
154 changes: 112 additions & 42 deletions Strata/Languages/Laurel/HeapParameterization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,24 @@ import Strata.Languages.Laurel.LaurelFormat
/-
Heap Parameterization Pass

Transforms procedures that interact with the heap using a global `$heap` variable:
Transforms procedures that interact with the heap by adding explicit heap parameters:

1. All procedures that read or write fields use the global `$heap` variable
- Field reads are translated to calls to `heapRead($heap, <fieldConstant>)`
- Field writes are translated to assignments to `$heap` via `heapStore`
1. Procedures that write the heap get an inout heap parameter
- Input: `heap : THeap`
- Output: `heap : THeap`
- Field writes become: `heap := heapStore(heap, obj, field, value)`

2. No heap parameters are added to procedure signatures
- The heap is accessed as a global variable
- Procedure calls don't pass or receive heap values
2. Procedures that only read the heap get an in heap parameter
- Input: `heap : THeap`
- Field reads become: `heapRead(heap, obj, field)`

3. Procedure calls are transformed:
- Calls to heap-writing procedures in expressions:
`f()` => `(var freshVar: type; freshVar, heap := f(heap); freshVar)`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on what you wrote below, should this be heap, freshVar := f(heap)?

- Calls to heap-writing procedures as statements:
`f()` => `heap := f(heap)`
- Calls to heap-reading procedures:
`f()` => `f(heap)`

The analysis is transitive: if procedure A calls procedure B, and B reads/writes the heap,
then A is also considered to read/write the heap.
Expand All @@ -42,13 +51,14 @@ partial def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do
| .LocalVariable _ _ i => if let some x := i then collectExpr x
| .While c i d b => collectExpr c; collectExpr b; if let some x := i then collectExpr x; if let some x := d then collectExpr x
| .Return v => if let some x := v then collectExpr x
| .Assign t v _ =>
-- Check if this is a field assignment (heap write)
match t with
| .FieldSelect target _ =>
modify fun s => { s with writesHeapDirectly := true }
collectExpr target
| _ => collectExpr t
| .Assign targets v _ =>
-- Check if any target is a field assignment (heap write)
for t in targets do
match t with
| .FieldSelect target _ =>
modify fun s => { s with writesHeapDirectly := true }
collectExpr target
| _ => collectExpr t
collectExpr v
| .PureFieldUpdate t _ v => collectExpr t; collectExpr v
| .PrimitiveOp _ args => for a in args do collectExpr a
Expand Down Expand Up @@ -117,6 +127,7 @@ structure TransformState where
heapReaders : List Identifier
heapWriters : List Identifier
fieldTypes : List (Identifier × HighType) := [] -- Maps field names to their value types
freshCounter : Nat := 0 -- Counter for generating fresh variable names

abbrev TransformM := StateM TransformState

Expand All @@ -133,6 +144,11 @@ def readsHeap (name : Identifier) : TransformM Bool := do
def writesHeap (name : Identifier) : TransformM Bool := do
return (← get).heapWriters.contains name

def freshVarName : TransformM Identifier := do
let s ← get
set { s with freshCounter := s.freshCounter + 1 }
return s!"$tmp{s.freshCounter}"

partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : TransformM StmtExpr := do
match expr with
| .FieldSelect target fieldName =>
Expand All @@ -144,8 +160,26 @@ partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : Transfo
return .StaticCall "heapRead" [.Identifier heapVar, t, .Identifier fieldName]
| .StaticCall callee args =>
let args' ← args.mapM (heapTransformExpr heapVar)
-- Heap is global, so no need to pass it as parameter
return .StaticCall callee args'
let calleeReadsHeap ← readsHeap callee
let calleeWritesHeap ← writesHeap callee
if calleeWritesHeap then
-- Heap-writing procedure call in expression context:
-- f(args) => (var freshVar: type; heapVar, freshVar := f(heapVar, args); freshVar)
-- The callee takes heap_in and returns (heap_out, result), we pass our heapVar and receive back into heapVar
let freshVar ← freshVarName
let varDecl := StmtExpr.LocalVariable freshVar .TInt none
-- Call with heapVar as first argument, receives (heap_out, result) which we assign to [heapVar, freshVar]
let callWithHeap := StmtExpr.Assign
[.Identifier heapVar, .Identifier freshVar]
(.StaticCall callee (StmtExpr.Identifier heapVar :: args'))
.empty
return .Block [varDecl, callWithHeap, .Identifier freshVar] none
else if calleeReadsHeap then
-- Heap-reading procedure: add heapVar as first argument (callee expects heap_in)
return .StaticCall callee (StmtExpr.Identifier heapVar :: args')
else
-- Non-heap procedure: no change
return .StaticCall callee args'
| .InstanceCall target callee args =>
let t ← heapTransformExpr heapVar target
let args' ← args.mapM (heapTransformExpr heapVar)
Expand All @@ -155,18 +189,25 @@ partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : Transfo
| .LocalVariable n ty i => return .LocalVariable n ty (← i.mapM (heapTransformExpr heapVar))
| .While c i d b => return .While (← heapTransformExpr heapVar c) (← i.mapM (heapTransformExpr heapVar)) (← d.mapM (heapTransformExpr heapVar)) (← heapTransformExpr heapVar b)
| .Return v => return .Return (← v.mapM (heapTransformExpr heapVar))
| .Assign t v md =>
match t with
| .FieldSelect target fieldName =>
| .Assign targets v md =>
-- Check if first target is a field select (heap write)
match targets with
| [StmtExpr.FieldSelect target fieldName] =>
let fieldType ← lookupFieldType fieldName
match fieldType with
| some ty => addFieldConstant fieldName ty
| none => addFieldConstant fieldName .TInt -- Fallback to int if type unknown
let target' ← heapTransformExpr heapVar target
let v' ← heapTransformExpr heapVar v
-- Assign to global heap variable
return .Assign (.Identifier heapVar) (.StaticCall "heapStore" [.Identifier heapVar, target', .Identifier fieldName, v']) md
| _ => return .Assign (← heapTransformExpr heapVar t) (← heapTransformExpr heapVar v) md
-- Assign to heap variable, but wrap in a block that returns the stored value
-- This ensures that when used in expression context, the value is the stored value, not the heap
let heapAssign := StmtExpr.Assign [StmtExpr.Identifier heapVar] (.StaticCall "heapStore" [.Identifier heapVar, target', .Identifier fieldName, v']) md
return .Block [heapAssign, v'] none
| _ =>
-- Transform all targets and value
let targets' ← targets.mapM (heapTransformExpr heapVar)
let v' ← heapTransformExpr heapVar v
return .Assign targets' v' md
| .PureFieldUpdate t f v => return .PureFieldUpdate (← heapTransformExpr heapVar t) f (← heapTransformExpr heapVar v)
| .PrimitiveOp op args => return .PrimitiveOp op (← args.mapM (heapTransformExpr heapVar))
| .ReferenceEquals l r => return .ReferenceEquals (← heapTransformExpr heapVar l) (← heapTransformExpr heapVar r)
Expand All @@ -184,51 +225,80 @@ partial def heapTransformExpr (heapVar : Identifier) (expr : StmtExpr) : Transfo
| other => return other

def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do
let heapName := "$heap"
let heapInName := "heap_in"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to use some names that aren't legal variable names for user-written code.

let heapOutName := "heap_out"
let readsHeap := (← get).heapReaders.contains proc.name
let writesHeap := (← get).heapWriters.contains proc.name

if readsHeap || writesHeap then
-- This procedure reads or writes the heap - transform to use global $heap
let precondition' ← heapTransformExpr heapName proc.precondition
if writesHeap then
-- This procedure writes the heap - add heap_in as input and heap_out as output
-- At the start, assign heap_in to heap_out, then use heap_out throughout
let heapInParam : Parameter := { name := heapInName, type := .THeap }
let heapOutParam : Parameter := { name := heapOutName, type := .THeap }

let inputs' := heapInParam :: proc.inputs
let outputs' := heapOutParam :: proc.outputs

-- Precondition uses heap_in (the input state)
let precondition' ← heapTransformExpr heapInName proc.precondition

let body' ← match proc.body with
| .Transparent bodyExpr =>
let bodyExpr' ← heapTransformExpr heapName bodyExpr
pure (.Transparent bodyExpr')
-- First assign heap_in to heap_out, then transform body using heap_out
let assignHeapOut := StmtExpr.Assign [StmtExpr.Identifier heapOutName] (StmtExpr.Identifier heapInName) .empty
let bodyExpr' ← heapTransformExpr heapOutName bodyExpr
pure (.Transparent (.Block [assignHeapOut, bodyExpr'] none))
| .Opaque postcond impl modif =>
let postcond' ← heapTransformExpr heapName postcond
let impl' ← impl.mapM (heapTransformExpr heapName)
let modif' ← modif.mapM (heapTransformExpr heapName)
-- Postcondition uses heap_out (the output state)
let postcond' ← heapTransformExpr heapOutName postcond
let impl' ← match impl with
| some implExpr =>
let assignHeapOut := StmtExpr.Assign [StmtExpr.Identifier heapOutName] (StmtExpr.Identifier heapInName) .empty
let implExpr' ← heapTransformExpr heapOutName implExpr
pure (some (.Block [assignHeapOut, implExpr'] none))
| none => pure none
let modif' ← modif.mapM (heapTransformExpr heapOutName)
pure (.Opaque postcond' impl' modif')
| .Abstract postcond =>
let postcond' ← heapTransformExpr heapName postcond
let postcond' ← heapTransformExpr heapOutName postcond
pure (.Abstract postcond')

return { proc with
inputs := inputs',
outputs := outputs',
precondition := precondition',
body := body' }

else
-- This procedure doesn't read or write the heap
-- Still transform contracts in case they reference fields
let precondition' ← heapTransformExpr heapName proc.precondition
else if readsHeap then
-- This procedure only reads the heap - add heap_in as input only
let heapInParam : Parameter := { name := heapInName, type := .THeap }
let inputs' := heapInParam :: proc.inputs

let precondition' ← heapTransformExpr heapInName proc.precondition

let body' ← match proc.body with
| .Transparent bodyExpr =>
pure (.Transparent bodyExpr)
let bodyExpr' ← heapTransformExpr heapInName bodyExpr
pure (.Transparent bodyExpr')
| .Opaque postcond impl modif =>
let postcond' ← heapTransformExpr heapName postcond
pure (.Opaque postcond' impl modif)
let postcond' ← heapTransformExpr heapInName postcond
let impl' ← impl.mapM (heapTransformExpr heapInName)
let modif' ← modif.mapM (heapTransformExpr heapInName)
pure (.Opaque postcond' impl' modif')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you ever wanted to avoid parentheses, you could use the following idiom

Suggested change
pure (.Opaque postcond' impl' modif')
return .Opaque postcond' impl' modif'

| .Abstract postcond =>
let postcond' ← heapTransformExpr heapName postcond
let postcond' ← heapTransformExpr heapInName postcond
pure (.Abstract postcond')

return { proc with
inputs := inputs',
precondition := precondition',
body := body' }

def heapParameterization (program : Program) : Program × List Identifier :=
else
-- This procedure doesn't read or write the heap - no changes needed
return proc

def heapParameterization (program : Program) : Program :=
let heapReaders := computeReadsHeap program.staticProcedures
let heapWriters := computeWritesHeap program.staticProcedures
-- Extract field types from composite type definitions
Expand All @@ -240,6 +310,6 @@ def heapParameterization (program : Program) : Program × List Identifier :=
dbg_trace s!"Heap readers: {heapReaders}"
dbg_trace s!"Heap writers: {heapWriters}"
let (procs', finalState) := (program.staticProcedures.mapM heapTransformProcedure).run { heapReaders, heapWriters, fieldTypes }
({ program with staticProcedures := procs', constants := program.constants ++ finalState.fieldConstants }, heapWriters)
{ program with staticProcedures := procs', constants := program.constants ++ finalState.fieldConstants }

end Strata.Laurel
7 changes: 5 additions & 2 deletions Strata/Languages/Laurel/Laurel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,11 @@ inductive StmtExpr : Type where
| LiteralInt (value: Int)
| LiteralBool (value: Bool)
| Identifier (name : Identifier)
/- Assign is only allowed in an impure context -/
| Assign (target : StmtExpr) (value : StmtExpr) (md : Imperative.MetaData Core.Expression)
/- Assign is only allowed in an impure context.
For single target assignments, use a single-element list.
Multiple targets are only allowed when the value is a StaticCall to a procedure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually, this should be enforced by a WF predicate or similar.

with multiple outputs, and the number of targets must match the number of outputs. -/
| Assign (targets : List StmtExpr) (value : StmtExpr) (md : Imperative.MetaData Core.Expression)
/- Used by itself for fields reads and in combination with Assign for field writes -/
| FieldSelect (target : StmtExpr) (fieldName : Identifier)
/- PureFieldUpdate is the only way to assign values to fields of pure types -/
Expand Down
21 changes: 10 additions & 11 deletions Strata/Languages/Laurel/LaurelFormat.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def formatHighType : HighType → Format
Format.joinSep (types.map formatHighType) " & "

def formatStmtExpr (s:StmtExpr) : Format :=
match h: s with
match s with
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I don't know if it applies here but if you ever need to do termination proofs and naming the matches causes a spurious "unused variable" despite the variable being used in automatic proofs, you could also use '_' for the variable name.

Suggested change
match s with
match _: s with

| .IfThenElse cond thenBr elseBr =>
"if " ++ formatStmtExpr cond ++ " then " ++ formatStmtExpr thenBr ++
match elseBr with
Expand All @@ -69,19 +69,22 @@ def formatStmtExpr (s:StmtExpr) : Format :=
| .LiteralInt n => Format.text (toString n)
| .LiteralBool b => if b then "true" else "false"
| .Identifier name => Format.text name
| .Assign target value _ =>
formatStmtExpr target ++ " := " ++ formatStmtExpr value
| .Assign [single] value _ =>
formatStmtExpr single ++ " := " ++ formatStmtExpr value
| .Assign targets value _ =>
"(" ++ Format.joinSep (targets.map formatStmtExpr) ", " ++ ")" ++ " := " ++ formatStmtExpr value
| .FieldSelect target field =>
formatStmtExpr target ++ "#" ++ Format.text field
| .PureFieldUpdate target field value =>
formatStmtExpr target ++ " with { " ++ Format.text field ++ " := " ++ formatStmtExpr value ++ " }"
| .StaticCall name args =>
Format.text name ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")"
| .PrimitiveOp op [a] =>
formatOperation op ++ formatStmtExpr a
| .PrimitiveOp op [a, b] =>
formatStmtExpr a ++ " " ++ formatOperation op ++ " " ++ formatStmtExpr b
| .PrimitiveOp op args =>
match args with
| [a] => formatOperation op ++ formatStmtExpr a
| [a, b] => formatStmtExpr a ++ " " ++ formatOperation op ++ " " ++ formatStmtExpr b
| _ => formatOperation op ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")"
formatOperation op ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")"
| .This => "this"
| .ReferenceEquals lhs rhs =>
formatStmtExpr lhs ++ " === " ++ formatStmtExpr rhs
Expand All @@ -107,10 +110,6 @@ def formatStmtExpr (s:StmtExpr) : Format :=
| .Abstract => "abstract"
| .All => "all"
| .Hole => "<?>"
decreasing_by
all_goals (simp_wf; try omega)
any_goals (rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega)
subst_vars; cases h; rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega

def formatParameter (p : Parameter) : Format :=
Format.text p.name ++ ": " ++ formatHighType p.type
Expand Down
Loading
Loading