Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions Strata/Languages/Laurel/EliminateMultipleOutputs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/-
Copyright Strata Contributors

SPDX-License-Identifier: Apache-2.0 OR MIT
-/
module

public import Strata.Languages.Laurel.MapStmtExpr

/-!
# Eliminate Multiple Outputs

Transforms functional procedures (`.isFunctional = true`) with multiple outputs
into procedures that return a single synthesized result datatype. Call sites are
rewritten to destructure the result using the generated accessors.

Emits an error when the number of LHS assignment targets does not match the
number of outputs of the called function.

This pass operates on `Program → Program × List DiagnosticModel`.
-/

namespace Strata.Laurel

public section

private def mkMd (e : StmtExpr) : StmtExprMd := { val := e, source := none }
private def mkTy (t : HighType) : HighTypeMd := { val := t, source := none }

/-- Info about a function whose multiple outputs have been collapsed into a result datatype. -/
private structure MultiOutInfo where
funcName : String
resultTypeName : String
constructorName : String
outputs : List Parameter

/-- Identify functional procedures with multiple outputs. -/
private def collectMultiOutFunctions (procs : List Procedure) : List MultiOutInfo :=
procs.filterMap fun f =>
if f.isFunctional && f.outputs.length > 1 then
some {
funcName := f.name.text
resultTypeName := s!"{f.name.text}$result"
constructorName := s!"{f.name.text}$result$mk"
outputs := f.outputs
}
else none

/-- Generate a result datatype for a multi-output function. -/
private def mkResultDatatype (info : MultiOutInfo) : DatatypeDefinition :=
let args := info.outputs.zipIdx.map fun (p, i) =>
{ name := mkId s!"out{i}", type := p.type : Parameter }
{ name := mkId info.resultTypeName
typeArgs := []
constructors := [{ name := mkId info.constructorName, args := args }] }

/-- Transform a multi-output function to return the result datatype. -/
private def transformFunction (info : MultiOutInfo) (proc : Procedure) : Procedure :=
let resultOutput : Parameter :=
{ name := mkId "$result", type := mkTy (.UserDefined (mkId info.resultTypeName)) }
{ proc with outputs := [resultOutput] }

/-- Destructor name for field `outN` of the result datatype. -/
private def destructorName (info : MultiOutInfo) (idx : Nat) : String :=
s!"{info.resultTypeName}..out{idx}"

private def mismatchError (source : Option FileRange) (callee : String)
(expected actual : Nat) : DiagnosticModel :=
let msg := s!"call to '{callee}' has {actual} assignment target(s), but the function returns {expected} output(s)"
match source with
| some fr => DiagnosticModel.withRange fr msg
| none => DiagnosticModel.fromMessage msg

/-- Scan statements for mismatched multi-output call sites, returning diagnostics. -/
private def validateStmts (infoMap : Std.HashMap String MultiOutInfo)
(stmts : List StmtExprMd) : List DiagnosticModel :=
stmts.filterMap fun stmt =>
match stmt.val with
| .Assign targets ⟨.StaticCall callee _args, callSrc, _⟩ =>
match infoMap.get? callee.text with
| some info =>
if targets.length != info.outputs.length then
some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length targets.length)
else none
| none => none
| .LocalVariable _name _ty (some ⟨.StaticCall callee _args, callSrc, _⟩) =>
match infoMap.get? callee.text with
| some info =>
some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length 1)
| none => none
| _ => none

/-- Validate all procedure bodies for mismatched call sites. -/
private def validateExpr (infoMap : Std.HashMap String MultiOutInfo)
(expr : StmtExprMd) : List DiagnosticModel :=
StateT.run (s := ([] : List DiagnosticModel)) (
mapStmtExprM (m := StateM (List DiagnosticModel)) (fun e => do
match e.val with
| .Block stmts _label =>
modify (· ++ validateStmts infoMap stmts)
| _ => pure ()
return e) expr) |>.2

/-- Validate a procedure body. -/
private def validateProcedure (infoMap : Std.HashMap String MultiOutInfo)
(proc : Procedure) : List DiagnosticModel :=
match proc.body with
| .Transparent b => validateExpr infoMap (mkMd (.Block [b] none))
| .Opaque _posts (some impl) _mods => validateExpr infoMap (mkMd (.Block [impl] none))
| _ => []

/-- Rewrite a statement list, replacing multi-output call patterns. -/
private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo)
(stmts : List StmtExprMd) : List StmtExprMd :=
let rec go (remaining : List StmtExprMd) (acc : List StmtExprMd) : List StmtExprMd :=
match remaining with
| [] => acc.reverse
| stmt :: rest =>
match stmt.val with
| .Assign targets ⟨.StaticCall callee args, callSrc, callMd⟩ =>
match infoMap.get? callee.text with
| some info =>
if targets.length != info.outputs.length then
go rest (stmt :: acc)
else
let tempName := mkId s!"${callee.text}$temp"
let tempTy := mkTy (.UserDefined (mkId info.resultTypeName))
let tempDecl := mkMd (.LocalVariable tempName tempTy
(some ⟨.StaticCall callee args, callSrc, callMd⟩))
let assigns := targets.zipIdx.map fun (tgt, i) =>
mkMd (.Assign [tgt]
(mkMd (.StaticCall (mkId (destructorName info i))
[mkMd (.Identifier tempName)])))
go rest (assigns.reverse ++ (tempDecl :: acc))
| none => go rest (stmt :: acc)
| _ => go rest (stmt :: acc)
go stmts []

/-- Rewrite blocks in a StmtExprMd tree to handle multi-output calls. -/
private def rewriteExpr (infoMap : Std.HashMap String MultiOutInfo)
(expr : StmtExprMd) : StmtExprMd :=
mapStmtExpr (fun e =>
match e.val with
| .Block stmts label => ⟨.Block (rewriteStmts infoMap stmts) label, e.source, e.md⟩
| _ => e) expr

/-- Rewrite all procedure bodies. -/
private def rewriteProcedure (infoMap : Std.HashMap String MultiOutInfo)
(proc : Procedure) : Procedure :=
match proc.body with
| .Transparent b =>
let wrapped := mkMd (.Block [b] none)
let rewritten := rewriteExpr infoMap wrapped
{ proc with body := .Transparent rewritten }
| .Opaque posts (some impl) mods =>
let wrapped := mkMd (.Block [impl] none)
let rewritten := rewriteExpr infoMap wrapped
{ proc with body := .Opaque posts (some rewritten) mods }
| _ => proc

/-- Eliminate multiple outputs from a Program. Only applies to functional procedures. -/
def eliminateMultipleOutputs (program : Program) : Program × List DiagnosticModel :=
let infos := collectMultiOutFunctions program.staticProcedures
if infos.isEmpty then (program, []) else
let infoMap : Std.HashMap String MultiOutInfo :=
infos.foldl (fun m info => m.insert info.funcName info) {}
-- Validate all call sites first
let diags := program.staticProcedures.flatMap (validateProcedure infoMap)
-- If there are errors, return the program unchanged
if !diags.isEmpty then (program, diags) else
let newDatatypes := infos.map mkResultDatatype
let procs := program.staticProcedures.map fun f =>
match infoMap.get? f.name.text with
| some info => rewriteProcedure infoMap (transformFunction info f)
| none => rewriteProcedure infoMap f
({ program with
staticProcedures := procs
types := program.types ++ newDatatypes.map TypeDefinition.Datatype }, [])

end -- public section
end Strata.Laurel
8 changes: 7 additions & 1 deletion Strata/Languages/Laurel/LaurelCompilationPipeline.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public import Strata.Languages.Laurel.LaurelToCoreTranslator
import Strata.Languages.Laurel.DesugarShortCircuit
import Strata.Languages.Laurel.EliminateReturnsInExpression
import Strata.Languages.Laurel.EliminateValueReturns
import Strata.Languages.Laurel.EliminateMultipleOutputs
import Strata.Languages.Laurel.ConstrainedTypeElim
import Strata.Languages.Core.Verifier

Expand Down Expand Up @@ -110,8 +111,13 @@ private def runLaurelPasses (options : LaurelTranslateOptions) (program : Progra
let (program, model) := (result.program, result.model)
emit "ConstrainedTypeElim" program

let (program, multiOutDiags) := eliminateMultipleOutputs program
let result := resolve program (some model)
let (program, model) := (result.program, result.model)
emit "EliminateMultipleOutputs" program

let allDiags := resolutionErrors ++ diamondErrors ++ nonCompositeDiags ++
valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags
valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags ++ multiOutDiags
return (program, model, allDiags)

/--
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/-
Copyright Strata Contributors

SPDX-License-Identifier: Apache-2.0 OR MIT
-/

import StrataTest.Util.TestDiagnostics
import StrataTest.Languages.Laurel.TestExamples

open StrataTest.Util
open Strata

namespace Strata.Laurel

/-! # Multiple Output Functions

Tests that functions with multiple output parameters are correctly handled
by the EliminateMultipleOutputs pass, which synthesizes a result datatype
and rewrites call sites.
-/

def multiOutputProgram := r"
function twoOutputs(x: int)
returns (a: int, b: int);

procedure testMultiOut() {
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@keyboardDrummer-bot Can you also add a test, which will need to be in another program, where there is no error?

var a: int;
a := twoOutputs(5)
// ^^^^^^^^^^^^^ error: call to 'twoOutputs' has 1 assignment target(s), but the function returns 2 output(s)
};
"

#guard_msgs (drop info, error) in
#eval testInputWithOffset "MultiOutput" multiOutputProgram 14 processLaurelFile

end Strata.Laurel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/-
Copyright Strata Contributors

SPDX-License-Identifier: Apache-2.0 OR MIT
-/

import StrataTest.Util.TestDiagnostics
import StrataTest.Languages.Laurel.TestExamples

open StrataTest.Util
open Strata

namespace Strata.Laurel

/-! # Multiple Output Functions (No Error)

Tests that a function with multiple output parameters is accepted
by the EliminateMultipleOutputs pass when there is no mismatched call site.
-/

def multiOutputNoErrorProgram := r"
function twoOutputs(x: int)
returns (a: int, b: int);
"

#guard_msgs (drop info, error) in
#eval testInputWithOffset "MultiOutputNoError" multiOutputNoErrorProgram 20 processLaurelFile

end Strata.Laurel
Loading