diff --git a/Strata/Languages/Laurel/Resolution.lean b/Strata/Languages/Laurel/Resolution.lean index 16bcf1333f..71c1510d91 100644 --- a/Strata/Languages/Laurel/Resolution.lean +++ b/Strata/Languages/Laurel/Resolution.lean @@ -213,6 +213,10 @@ structure ResolveState where /-- When resolving inside an instance procedure, the owning composite type name. Used by `resolveFieldRef` to resolve `self.field` when `self` has type `Any`. -/ instanceTypeName : Option String := none + /-- True when resolving inside an expression where the value is used (e.g., as an + argument to another call or operator). Multi-output calls are only diagnosed + in value context, not in statement position or direct assignment RHS. -/ + inValueContext : Bool := false @[expose] abbrev ResolveM := StateM ResolveState @@ -358,7 +362,10 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do | AstNode.mk expr source => let val' ← match _: expr with | .IfThenElse cond thenBr elseBr => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr cond + modify fun s => { s with inValueContext := saved } let thenBr' ← resolveStmtExpr thenBr let elseBr' ← elseBr.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) pure (.IfThenElse cond' thenBr' elseBr') @@ -367,7 +374,10 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do let stmts' ← stmts.mapM resolveStmtExpr pure (.Block stmts' label) | .While cond invs dec body => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr cond + modify fun s => { s with inValueContext := saved } let invs' ← invs.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) let dec' ← dec.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) let body' ← resolveStmtExpr body @@ -437,10 +447,30 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do | .StaticCall callee args => let callee' ← resolveRef callee source (expected := #[.parameter, .staticProcedure, .datatypeConstructor, .constant]) + -- Resolve arguments in value context (their results are used as values) + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let args' ← args.mapM resolveStmtExpr + modify fun s => { s with inValueContext := saved } + -- Multi-output procedures must not appear in value context: the extra + -- outputs (e.g. error channels) would be silently discarded. + let s ← get + if s.inValueContext then + let outputCount := match s.scope.get? callee'.text with + | some (_, .staticProcedure proc) => proc.outputs.length + | some (_, .instanceProcedure _ proc) => proc.outputs.length + | _ => 0 + if outputCount > 1 then + let diag := diagnosticFromSource source + s!"Multi-output procedure '{callee'.text}' used in expression position; it returns {outputCount} values but only one can be used here. Use a multi-target assignment instead." + modify fun s => { s with errors := s.errors.push diag } pure (.StaticCall callee' args') | .PrimitiveOp op args => + -- Resolve arguments in value context + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let args' ← args.mapM resolveStmtExpr + modify fun s => { s with inValueContext := saved } pure (.PrimitiveOp op args') | .New ref => let ref' ← resolveRef ref source @@ -482,10 +512,16 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do let val' ← resolveStmtExpr val pure (.Fresh val') | .Assert ⟨condExpr, summary⟩ => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr condExpr + modify fun s => { s with inValueContext := saved } pure (.Assert { condition := cond', summary }) | .Assume cond => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr cond + modify fun s => { s with inValueContext := saved } pure (.Assume cond') | .ProveBy val proof => let val' ← resolveStmtExpr val diff --git a/StrataTest/Languages/Laurel/ResolutionKindTests.lean b/StrataTest/Languages/Laurel/ResolutionKindTests.lean index acbef556b6..6c58bcd573 100644 --- a/StrataTest/Languages/Laurel/ResolutionKindTests.lean +++ b/StrataTest/Languages/Laurel/ResolutionKindTests.lean @@ -97,4 +97,17 @@ composite Foo extends nat { } #guard_msgs (error, drop all) in #eval testInputWithOffset "ExtendConstrained" extendConstrained 90 processResolution +/-! ## Multi-output procedure used in expression position -/ + +def multiOutputInExpr := r" +procedure multi(x: int) returns (a: int, b: int) opaque; +procedure test() opaque { + assert multi(1) == 1 +// ^^^^^^^^ error: Multi-output procedure 'multi' used in expression position +}; +" + +#guard_msgs (error, drop all) in +#eval testInputWithOffset "MultiOutputInExpr" multiOutputInExpr 100 processResolution + end Laurel