diff --git a/Strata/Languages/Laurel/EliminateMultipleOutputs.lean b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean new file mode 100644 index 0000000000..4945e6ddb8 --- /dev/null +++ b/Strata/Languages/Laurel/EliminateMultipleOutputs.lean @@ -0,0 +1,181 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.Languages.Laurel.MapStmtExpr + +/-! +# Eliminate Multiple Outputs + +Transforms functional procedures (`.isFunctional = true`) with multiple outputs +into procedures that return a single synthesized result datatype. Call sites are +rewritten to destructure the result using the generated accessors. + +Emits an error when the number of LHS assignment targets does not match the +number of outputs of the called function. + +This pass operates on `Program → Program × List DiagnosticModel`. +-/ + +namespace Strata.Laurel + +public section + +private def mkMd (e : StmtExpr) : StmtExprMd := { val := e, source := none } +private def mkTy (t : HighType) : HighTypeMd := { val := t, source := none } + +/-- Info about a function whose multiple outputs have been collapsed into a result datatype. -/ +private structure MultiOutInfo where + funcName : String + resultTypeName : String + constructorName : String + outputs : List Parameter + +/-- Identify functional procedures with multiple outputs. -/ +private def collectMultiOutFunctions (procs : List Procedure) : List MultiOutInfo := + procs.filterMap fun f => + if f.isFunctional && f.outputs.length > 1 then + some { + funcName := f.name.text + resultTypeName := s!"{f.name.text}$result" + constructorName := s!"{f.name.text}$result$mk" + outputs := f.outputs + } + else none + +/-- Generate a result datatype for a multi-output function. -/ +private def mkResultDatatype (info : MultiOutInfo) : DatatypeDefinition := + let args := info.outputs.zipIdx.map fun (p, i) => + { name := mkId s!"out{i}", type := p.type : Parameter } + { name := mkId info.resultTypeName + typeArgs := [] + constructors := [{ name := mkId info.constructorName, args := args }] } + +/-- Transform a multi-output function to return the result datatype. -/ +private def transformFunction (info : MultiOutInfo) (proc : Procedure) : Procedure := + let resultOutput : Parameter := + { name := mkId "$result", type := mkTy (.UserDefined (mkId info.resultTypeName)) } + { proc with outputs := [resultOutput] } + +/-- Destructor name for field `outN` of the result datatype. -/ +private def destructorName (info : MultiOutInfo) (idx : Nat) : String := + s!"{info.resultTypeName}..out{idx}" + +private def mismatchError (source : Option FileRange) (callee : String) + (expected actual : Nat) : DiagnosticModel := + let msg := s!"call to '{callee}' has {actual} assignment target(s), but the function returns {expected} output(s)" + match source with + | some fr => DiagnosticModel.withRange fr msg + | none => DiagnosticModel.fromMessage msg + +/-- Scan statements for mismatched multi-output call sites, returning diagnostics. -/ +private def validateStmts (infoMap : Std.HashMap String MultiOutInfo) + (stmts : List StmtExprMd) : List DiagnosticModel := + stmts.filterMap fun stmt => + match stmt.val with + | .Assign targets ⟨.StaticCall callee _args, callSrc, _⟩ => + match infoMap.get? callee.text with + | some info => + if targets.length != info.outputs.length then + some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length targets.length) + else none + | none => none + | .LocalVariable _name _ty (some ⟨.StaticCall callee _args, callSrc, _⟩) => + match infoMap.get? callee.text with + | some info => + some (mismatchError (callSrc.orElse fun _ => stmt.source) callee.text info.outputs.length 1) + | none => none + | _ => none + +/-- Validate all procedure bodies for mismatched call sites. -/ +private def validateExpr (infoMap : Std.HashMap String MultiOutInfo) + (expr : StmtExprMd) : List DiagnosticModel := + StateT.run (s := ([] : List DiagnosticModel)) ( + mapStmtExprM (m := StateM (List DiagnosticModel)) (fun e => do + match e.val with + | .Block stmts _label => + modify (· ++ validateStmts infoMap stmts) + | _ => pure () + return e) expr) |>.2 + +/-- Validate a procedure body. -/ +private def validateProcedure (infoMap : Std.HashMap String MultiOutInfo) + (proc : Procedure) : List DiagnosticModel := + match proc.body with + | .Transparent b => validateExpr infoMap (mkMd (.Block [b] none)) + | .Opaque _posts (some impl) _mods => validateExpr infoMap (mkMd (.Block [impl] none)) + | _ => [] + +/-- Rewrite a statement list, replacing multi-output call patterns. -/ +private def rewriteStmts (infoMap : Std.HashMap String MultiOutInfo) + (stmts : List StmtExprMd) : List StmtExprMd := + let rec go (remaining : List StmtExprMd) (acc : List StmtExprMd) : List StmtExprMd := + match remaining with + | [] => acc.reverse + | stmt :: rest => + match stmt.val with + | .Assign targets ⟨.StaticCall callee args, callSrc, callMd⟩ => + match infoMap.get? callee.text with + | some info => + if targets.length != info.outputs.length then + go rest (stmt :: acc) + else + let tempName := mkId s!"${callee.text}$temp" + let tempTy := mkTy (.UserDefined (mkId info.resultTypeName)) + let tempDecl := mkMd (.LocalVariable tempName tempTy + (some ⟨.StaticCall callee args, callSrc, callMd⟩)) + let assigns := targets.zipIdx.map fun (tgt, i) => + mkMd (.Assign [tgt] + (mkMd (.StaticCall (mkId (destructorName info i)) + [mkMd (.Identifier tempName)]))) + go rest (assigns.reverse ++ (tempDecl :: acc)) + | none => go rest (stmt :: acc) + | _ => go rest (stmt :: acc) + go stmts [] + +/-- Rewrite blocks in a StmtExprMd tree to handle multi-output calls. -/ +private def rewriteExpr (infoMap : Std.HashMap String MultiOutInfo) + (expr : StmtExprMd) : StmtExprMd := + mapStmtExpr (fun e => + match e.val with + | .Block stmts label => ⟨.Block (rewriteStmts infoMap stmts) label, e.source, e.md⟩ + | _ => e) expr + +/-- Rewrite all procedure bodies. -/ +private def rewriteProcedure (infoMap : Std.HashMap String MultiOutInfo) + (proc : Procedure) : Procedure := + match proc.body with + | .Transparent b => + let wrapped := mkMd (.Block [b] none) + let rewritten := rewriteExpr infoMap wrapped + { proc with body := .Transparent rewritten } + | .Opaque posts (some impl) mods => + let wrapped := mkMd (.Block [impl] none) + let rewritten := rewriteExpr infoMap wrapped + { proc with body := .Opaque posts (some rewritten) mods } + | _ => proc + +/-- Eliminate multiple outputs from a Program. Only applies to functional procedures. -/ +def eliminateMultipleOutputs (program : Program) : Program × List DiagnosticModel := + let infos := collectMultiOutFunctions program.staticProcedures + if infos.isEmpty then (program, []) else + let infoMap : Std.HashMap String MultiOutInfo := + infos.foldl (fun m info => m.insert info.funcName info) {} + -- Validate all call sites first + let diags := program.staticProcedures.flatMap (validateProcedure infoMap) + -- If there are errors, return the program unchanged + if !diags.isEmpty then (program, diags) else + let newDatatypes := infos.map mkResultDatatype + let procs := program.staticProcedures.map fun f => + match infoMap.get? f.name.text with + | some info => rewriteProcedure infoMap (transformFunction info f) + | none => rewriteProcedure infoMap f + ({ program with + staticProcedures := procs + types := program.types ++ newDatatypes.map TypeDefinition.Datatype }, []) + +end -- public section +end Strata.Laurel diff --git a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean index 81d34ba1d6..570855bc66 100644 --- a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean +++ b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean @@ -9,6 +9,7 @@ public import Strata.Languages.Laurel.LaurelToCoreTranslator import Strata.Languages.Laurel.DesugarShortCircuit import Strata.Languages.Laurel.EliminateReturnsInExpression import Strata.Languages.Laurel.EliminateValueReturns +import Strata.Languages.Laurel.EliminateMultipleOutputs import Strata.Languages.Laurel.ConstrainedTypeElim import Strata.Languages.Core.Verifier @@ -110,8 +111,13 @@ private def runLaurelPasses (options : LaurelTranslateOptions) (program : Progra let (program, model) := (result.program, result.model) emit "ConstrainedTypeElim" program + let (program, multiOutDiags) := eliminateMultipleOutputs program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + emit "EliminateMultipleOutputs" program + let allDiags := resolutionErrors ++ diamondErrors ++ nonCompositeDiags ++ - valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags + valueReturnDiags.toList ++ modifiesDiags ++ constrainedTypeDiags ++ multiOutDiags return (program, model, allDiags) /-- diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean new file mode 100644 index 0000000000..2a1b039db7 --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputs.lean @@ -0,0 +1,36 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util +open Strata + +namespace Strata.Laurel + +/-! # Multiple Output Functions + +Tests that functions with multiple output parameters are correctly handled +by the EliminateMultipleOutputs pass, which synthesizes a result datatype +and rewrites call sites. +-/ + +def multiOutputProgram := r" +function twoOutputs(x: int) + returns (a: int, b: int); + +procedure testMultiOut() { + var a: int; + a := twoOutputs(5) +// ^^^^^^^^^^^^^ error: call to 'twoOutputs' has 1 assignment target(s), but the function returns 2 output(s) +}; +" + +#guard_msgs (drop info, error) in +#eval testInputWithOffset "MultiOutput" multiOutputProgram 14 processLaurelFile + +end Strata.Laurel diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean new file mode 100644 index 0000000000..29fea30c2d --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T22_MultipleOutputsNoError.lean @@ -0,0 +1,29 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util +open Strata + +namespace Strata.Laurel + +/-! # Multiple Output Functions (No Error) + +Tests that a function with multiple output parameters is accepted +by the EliminateMultipleOutputs pass when there is no mismatched call site. +-/ + +def multiOutputNoErrorProgram := r" +function twoOutputs(x: int) + returns (a: int, b: int); +" + +#guard_msgs (drop info, error) in +#eval testInputWithOffset "MultiOutputNoError" multiOutputNoErrorProgram 20 processLaurelFile + +end Strata.Laurel