forked from strata-org/Strata
-
Notifications
You must be signed in to change notification settings - Fork 0
Support functions with multiple outputs (EliminateMultipleOutputs pass) #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
keyboardDrummer-bot
wants to merge
4
commits into
main
Choose a base branch
from
extract-eliminate-multiple-outputs
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
3a2f259
Support functions with multiple outputs
keyboardDrummer-bot d8fa4d4
Revert LocalVariable AST change, keep original (name, type, initializ…
keyboardDrummer-bot f18bcce
Emit error when LHS target count mismatches function output count
keyboardDrummer-bot 489a5be
Add no-error test for multi-output functions
keyboardDrummer-bot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| /- | ||
| Copyright Strata Contributors | ||
|
|
||
| SPDX-License-Identifier: Apache-2.0 OR MIT | ||
| -/ | ||
| module | ||
|
|
||
| public import Strata.Languages.Laurel.MapStmtExpr | ||
|
|
||
| /-! | ||
| # Eliminate Multiple Outputs | ||
|
|
||
| Transforms functional procedures (`.isFunctional = true`) with multiple outputs | ||
| into procedures that return a single synthesized result datatype. Call sites are | ||
| rewritten to destructure the result using the generated accessors. | ||
|
|
||
| Emits an error when the number of LHS assignment targets does not match the | ||
| number of outputs of the called function. | ||
|
|
||
| This pass operates on `Program → Program × List DiagnosticModel`. | ||
| -/ | ||
|
|
||
| namespace Strata.Laurel | ||
|
|
||
| public section | ||
|
|
||
| private def mkMd (e : StmtExpr) : StmtExprMd := { val := e, source := none } | ||
| private def mkTy (t : HighType) : HighTypeMd := { val := t, source := none } | ||
|
|
||
| /-- Info about a function whose multiple outputs have been collapsed into a result datatype. -/ | ||
| private structure MultiOutInfo where | ||
| funcName : String | ||
| resultTypeName : String | ||
| constructorName : String | ||
| outputs : List Parameter | ||
|
|
||
| /-- Identify functional procedures with multiple outputs. -/ | ||
| private def collectMultiOutFunctions (procs : List Procedure) : List MultiOutInfo := | ||
| procs.filterMap fun f => | ||
| if f.isFunctional && f.outputs.length > 1 then | ||
| some { | ||
| funcName := f.name.text | ||
| resultTypeName := s!"{f.name.text}$result" | ||
| constructorName := s!"{f.name.text}$result$mk" | ||
| outputs := f.outputs | ||
| } | ||
| else none | ||
|
|
||
| /-- Generate a result datatype for a multi-output function. -/ | ||
| private def mkResultDatatype (info : MultiOutInfo) : DatatypeDefinition := | ||
| let args := info.outputs.zipIdx.map fun (p, i) => | ||
| { name := mkId s!"out{i}", type := p.type : Parameter } | ||
| { name := mkId info.resultTypeName | ||
| typeArgs := [] | ||
| constructors := [{ name := mkId info.constructorName, args := args }] } | ||
|
|
||
| /-- Transform a multi-output function to return the result datatype. -/ | ||
| private def transformFunction (info : MultiOutInfo) (proc : Procedure) : Procedure := | ||
| let resultOutput : Parameter := | ||
| { name := mkId "$result", type := mkTy (.UserDefined (mkId info.resultTypeName)) } | ||
| { proc with outputs := [resultOutput] } | ||
|
|
||
| /-- Destructor name for field `outN` of the result datatype. -/ | ||
| private def destructorName (info : MultiOutInfo) (idx : Nat) : String := | ||
| s!"{info.resultTypeName}..out{idx}" | ||
|
|
||
| private def mismatchError (source : Option FileRange) (callee : String) | ||
| (expected actual : Nat) : DiagnosticModel := | ||
| let msg := s!"call to '{callee}' has {actual} assignment target(s), but the function returns {expected} output(s)" | ||
| match source with | ||
| | some fr => DiagnosticModel.withRange fr msg | ||
| | none => DiagnosticModel.fromMessage msg | ||
|
|
||
| /-- Scan statements for mismatched multi-output call sites, returning diagnostics. -/ | ||
| private def validateStmts (infoMap : Std.HashMap String MultiOutInfo) | ||
| (stmts : List StmtExprMd) : List DiagnosticModel := | ||
| stmts.filterMap fun stmt => | ||
| match stmt.val with | ||
| | .Assign targets ⟨.StaticCall callee _args, callSrc, _⟩ => | ||
| match infoMap.get? callee.text with | ||
| | some info => | ||
| if targets.length != info.outputs.length then | ||
| some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length targets.length) | ||
| else none | ||
| | none => none | ||
| | .LocalVariable _name _ty (some ⟨.StaticCall callee _args, callSrc, _⟩) => | ||
| match infoMap.get? callee.text with | ||
| | some info => | ||
| some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length 1) | ||
| | none => none | ||
| | _ => none | ||
|
|
||
| /-- Validate all procedure bodies for mismatched call sites. -/ | ||
| private def validateExpr (infoMap : Std.HashMap String MultiOutInfo) | ||
| (expr : StmtExprMd) : List DiagnosticModel := | ||
| StateT.run (s := ([] : List DiagnosticModel)) ( | ||
| mapStmtExprM (m := StateM (List DiagnosticModel)) (fun e => do | ||
| match e.val with | ||
| | .Block stmts _label => | ||
| modify (· ++ validateStmts infoMap stmts) | ||
| | _ => pure () | ||
| return e) expr) |>.2 | ||
|
|
||
| /-- Validate a procedure body. -/ | ||
| private def validateProcedure (infoMap : Std.HashMap String MultiOutInfo) | ||
| (proc : Procedure) : List DiagnosticModel := | ||
| match proc.body with | ||
| | .Transparent b => validateExpr infoMap (mkMd (.Block [b] none)) | ||
| | .Opaque _posts (some impl) _mods => validateExpr infoMap (mkMd (.Block [impl] none)) | ||
| | _ => [] | ||
|
|
||
| /-- Rewrite a statement list, replacing multi-output call patterns. -/ | ||
| private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo) | ||
| (stmts : List StmtExprMd) : List StmtExprMd := | ||
| let rec go (remaining : List StmtExprMd) (acc : List StmtExprMd) : List StmtExprMd := | ||
| match remaining with | ||
| | [] => acc.reverse | ||
| | stmt :: rest => | ||
| match stmt.val with | ||
| | .Assign targets ⟨.StaticCall callee args, callSrc, callMd⟩ => | ||
| match infoMap.get? callee.text with | ||
| | some info => | ||
| if targets.length != info.outputs.length then | ||
| go rest (stmt :: acc) | ||
| else | ||
| let tempName := mkId s!"${callee.text}$temp" | ||
| let tempTy := mkTy (.UserDefined (mkId info.resultTypeName)) | ||
| let tempDecl := mkMd (.LocalVariable tempName tempTy | ||
| (some ⟨.StaticCall callee args, callSrc, callMd⟩)) | ||
| let assigns := targets.zipIdx.map fun (tgt, i) => | ||
| mkMd (.Assign [tgt] | ||
| (mkMd (.StaticCall (mkId (destructorName info i)) | ||
| [mkMd (.Identifier tempName)]))) | ||
| go rest (assigns.reverse ++ (tempDecl :: acc)) | ||
| | none => go rest (stmt :: acc) | ||
| | _ => go rest (stmt :: acc) | ||
| go stmts [] | ||
|
|
||
| /-- Rewrite blocks in a StmtExprMd tree to handle multi-output calls. -/ | ||
| private def rewriteExpr (infoMap : Std.HashMap String MultiOutInfo) | ||
| (expr : StmtExprMd) : StmtExprMd := | ||
| mapStmtExpr (fun e => | ||
| match e.val with | ||
| | .Block stmts label => ⟨.Block (rewriteStmts infoMap stmts) label, e.source, e.md⟩ | ||
| | _ => e) expr | ||
|
|
||
| /-- Rewrite all procedure bodies. -/ | ||
| private def rewriteProcedure (infoMap : Std.HashMap String MultiOutInfo) | ||
| (proc : Procedure) : Procedure := | ||
| match proc.body with | ||
| | .Transparent b => | ||
| let wrapped := mkMd (.Block [b] none) | ||
| let rewritten := rewriteExpr infoMap wrapped | ||
| { proc with body := .Transparent rewritten } | ||
| | .Opaque posts (some impl) mods => | ||
| let wrapped := mkMd (.Block [impl] none) | ||
| let rewritten := rewriteExpr infoMap wrapped | ||
| { proc with body := .Opaque posts (some rewritten) mods } | ||
| | _ => proc | ||
|
|
||
| /-- Eliminate multiple outputs from a Program. Only applies to functional procedures. -/ | ||
| def eliminateMultipleOutputs (program : Program) : Program × List DiagnosticModel := | ||
| let infos := collectMultiOutFunctions program.staticProcedures | ||
| if infos.isEmpty then (program, []) else | ||
| let infoMap : Std.HashMap String MultiOutInfo := | ||
| infos.foldl (fun m info => m.insert info.funcName info) {} | ||
| -- Validate all call sites first | ||
| let diags := program.staticProcedures.flatMap (validateProcedure infoMap) | ||
| -- If there are errors, return the program unchanged | ||
| if !diags.isEmpty then (program, diags) else | ||
| let newDatatypes := infos.map mkResultDatatype | ||
| let procs := program.staticProcedures.map fun f => | ||
| match infoMap.get? f.name.text with | ||
| | some info => rewriteProcedure infoMap (transformFunction info f) | ||
| | none => rewriteProcedure infoMap f | ||
| ({ program with | ||
| staticProcedures := procs | ||
| types := program.types ++ newDatatypes.map TypeDefinition.Datatype }, []) | ||
|
|
||
| end -- public section | ||
| end Strata.Laurel |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| /- | ||
| Copyright Strata Contributors | ||
|
|
||
| SPDX-License-Identifier: Apache-2.0 OR MIT | ||
| -/ | ||
|
|
||
| import StrataTest.Util.TestDiagnostics | ||
| import StrataTest.Languages.Laurel.TestExamples | ||
|
|
||
| open StrataTest.Util | ||
| open Strata | ||
|
|
||
| namespace Strata.Laurel | ||
|
|
||
| /-! # Multiple Output Functions | ||
|
|
||
| Tests that functions with multiple output parameters are correctly handled | ||
| by the EliminateMultipleOutputs pass, which synthesizes a result datatype | ||
| and rewrites call sites. | ||
| -/ | ||
|
|
||
| def multiOutputProgram := r" | ||
| function twoOutputs(x: int) | ||
| returns (a: int, b: int); | ||
|
|
||
| procedure testMultiOut() { | ||
| var a: int; | ||
| a := twoOutputs(5) | ||
| // ^^^^^^^^^^^^^ error: call to 'twoOutputs' has 1 assignment target(s), but the function returns 2 output(s) | ||
| }; | ||
| " | ||
|
|
||
| #guard_msgs (drop info, error) in | ||
| #eval testInputWithOffset "MultiOutput" multiOutputProgram 14 processLaurelFile | ||
|
|
||
| end Strata.Laurel | ||
29 changes: 29 additions & 0 deletions
29
StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| /- | ||
| Copyright Strata Contributors | ||
|
|
||
| SPDX-License-Identifier: Apache-2.0 OR MIT | ||
| -/ | ||
|
|
||
| import StrataTest.Util.TestDiagnostics | ||
| import StrataTest.Languages.Laurel.TestExamples | ||
|
|
||
| open StrataTest.Util | ||
| open Strata | ||
|
|
||
| namespace Strata.Laurel | ||
|
|
||
| /-! # Multiple Output Functions (No Error) | ||
|
|
||
| Tests that a function with multiple output parameters is accepted | ||
| by the EliminateMultipleOutputs pass when there is no mismatched call site. | ||
| -/ | ||
|
|
||
| def multiOutputNoErrorProgram := r" | ||
| function twoOutputs(x: int) | ||
| returns (a: int, b: int); | ||
| " | ||
|
|
||
| #guard_msgs (drop info, error) in | ||
| #eval testInputWithOffset "MultiOutputNoError" multiOutputNoErrorProgram 20 processLaurelFile | ||
|
|
||
| end Strata.Laurel |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@keyboardDrummer-bot Can you also add a test, which will need to be in another program, where there is no error?