diff --git a/include/Ark/Compiler/Lowerer/ASTLowerer.hpp b/include/Ark/Compiler/Lowerer/ASTLowerer.hpp index ad3d897c8..26c08889b 100644 --- a/include/Ark/Compiler/Lowerer/ASTLowerer.hpp +++ b/include/Ark/Compiler/Lowerer/ASTLowerer.hpp @@ -255,20 +255,21 @@ namespace Ark::internal * @param p the current page number we're on * @param is_result_unused * @param is_terminal + * @param can_use_ref */ - void compileExpression(Node& x, Page p, bool is_result_unused, bool is_terminal); + void compileExpression(Node& x, Page p, bool is_result_unused, bool is_terminal, bool can_use_ref); void compileSymbol(const Node& x, Page p, bool is_result_unused, bool can_use_ref); void compileListInstruction(Node& x, Page p, bool is_result_unused); void compileApplyInstruction(Node& x, Page p, bool is_result_unused); - void compileIf(Node& x, Page p, bool is_result_unused, bool is_terminal); + void compileIf(Node& x, Page p, bool is_result_unused, bool is_terminal, bool can_use_ref); void compileFunction(Node& x, Page p, bool is_result_unused); void compileLetMutSet(Keyword n, Node& x, Page p, bool is_result_unused); void compileWhile(Node& x, Page p); void compilePluginImport(const Node& x, Page p); void pushFunctionCallArguments(Node& call, Page p, bool is_tail_call); - void handleCalls(Node& x, Page p, bool is_result_unused, bool is_terminal); - void handleShortcircuit(Node& x, Page p); + void handleCalls(Node& x, Page p, bool is_result_unused, bool is_terminal, bool can_use_ref); + void handleShortcircuit(Node& x, Page p, bool can_use_ref); void handleOperator(Node& x, Page p, Instruction op); bool handleFunctionCall(Node& x, Page p, bool is_terminal); diff --git a/src/arkreactor/Compiler/Lowerer/ASTLowerer.cpp b/src/arkreactor/Compiler/Lowerer/ASTLowerer.cpp index 396d60fc5..2cde165c9 100644 --- a/src/arkreactor/Compiler/Lowerer/ASTLowerer.cpp +++ b/src/arkreactor/Compiler/Lowerer/ASTLowerer.cpp @@ -50,7 +50,8 @@ namespace Ark::internal ast, /* current_page */ global, /* is_result_unused= */ true, - /* is_terminal= */ false); + /* is_terminal= */ false, + /* can_use_ref= */ true); m_logger.traceEnd(); } @@ -208,11 +209,11 @@ namespace Ark::internal } } - void ASTLowerer::compileExpression(Node& x, const Page p, const bool is_result_unused, const bool is_terminal) + void ASTLowerer::compileExpression(Node& x, const Page p, const bool is_result_unused, const bool is_terminal, const bool can_use_ref) { // register symbols if (x.nodeType() == NodeType::Symbol) - compileSymbol(x, p, is_result_unused, /* can_use_ref= */ true); + compileSymbol(x, p, is_result_unused, /* can_use_ref= */ can_use_ref); else if (x.nodeType() == NodeType::Field) { // the parser guarantees us that there is at least 2 elements (eg: a.b) @@ -234,7 +235,7 @@ namespace Ark::internal } // namespace nodes else if (x.nodeType() == NodeType::Namespace) - compileExpression(*x.constArkNamespace().ast, p, is_result_unused, is_terminal); + compileExpression(*x.constArkNamespace().ast, p, is_result_unused, is_terminal, can_use_ref); else if (x.nodeType() == NodeType::List) { // empty code block should be nil @@ -257,7 +258,7 @@ namespace Ark::internal switch (const Keyword keyword = head.keyword()) { case Keyword::If: - compileIf(x, p, is_result_unused, is_terminal); + compileIf(x, p, is_result_unused, is_terminal, can_use_ref); break; case Keyword::Set: @@ -281,7 +282,8 @@ namespace Ark::internal // All the nodes in a 'begin' (except for the last one) are producing a result that we want to drop. /* is_result_unused= */ (i != size - 1) || is_result_unused, // If the 'begin' is a terminal node, only its last node is terminal. - /* is_terminal= */ is_terminal && (i == size - 1)); + /* is_terminal= */ is_terminal && (i == size - 1), + /* can_use_ref= */ can_use_ref); break; } @@ -303,7 +305,7 @@ namespace Ark::internal { // If we are here, we should have a function name via the m_opened_vars. // Push arguments first, then function name, then call it. - handleCalls(x, p, is_result_unused, is_terminal); + handleCalls(x, p, is_result_unused, is_terminal, can_use_ref); } } else if (x.nodeType() != NodeType::Unused) @@ -375,7 +377,7 @@ namespace Ark::internal { Node& node = x.list()[i]; if (nodeProducesOutput(node)) - compileExpression(node, p, false, false); + compileExpression(node, p, false, false, true); else makeError(ErrorKind::InvalidNodeNoReturnValue, node, name); } @@ -420,7 +422,7 @@ namespace Ark::internal { // Load the first argument which should be a symbol (or field), // that append!/concat! write to, so that we have its new value available. - compileExpression(x.list()[1], p, false, false); + compileExpression(x.list()[1], p, false, false, true); } // append!, concat!, pop!, @= and @@= can push to the stack, but not using its returned value isn't an error @@ -445,7 +447,7 @@ namespace Ark::internal for (Node& node : x.list() | std::ranges::views::drop(1)) { if (nodeProducesOutput(node)) - compileExpression(node, p, false, false); + compileExpression(node, p, false, false, true); else makeError(ErrorKind::InvalidNodeNoReturnValue, node, "apply"); } @@ -457,7 +459,7 @@ namespace Ark::internal page(p).emplace_back(POP); } - void ASTLowerer::compileIf(Node& x, const Page p, const bool is_result_unused, const bool is_terminal) + void ASTLowerer::compileIf(Node& x, const Page p, const bool is_result_unused, const bool is_terminal, const bool can_use_ref) { if (x.constList().size() == 1) buildAndThrowError("Invalid condition: missing 'cond' and 'then' nodes, expected (if cond then)", x); @@ -465,7 +467,7 @@ namespace Ark::internal buildAndThrowError(fmt::format("Invalid condition: missing 'then' node, expected (if {} then)", x.constList()[1].repr()), x); // compile condition - compileExpression(x.list()[1], p, false, false); + compileExpression(x.list()[1], p, false, false, true); page(p).back().setSourceLocation(x.constList()[1].filename(), x.constList()[1].position().start.line); // jump only if needed to the "true" branch @@ -476,14 +478,14 @@ namespace Ark::internal if (x.constList().size() == 4) // we have an else clause { m_locals_locator.saveScopeLengthForBranch(); - compileExpression(x.list()[3], p, is_result_unused, is_terminal); + compileExpression(x.list()[3], p, is_result_unused, is_terminal, can_use_ref); page(p).back().setSourceLocation(x.constList()[3].filename(), x.constList()[3].position().start.line); m_locals_locator.dropVarsForBranch(); } else { Node tmp = Node(NodeType::List); - compileExpression(tmp, p, is_result_unused, is_terminal); + compileExpression(tmp, p, is_result_unused, is_terminal, true); } // when else is finished, jump to end @@ -494,7 +496,7 @@ namespace Ark::internal page(p).emplace_back(label_then); // if code m_locals_locator.saveScopeLengthForBranch(); - compileExpression(x.list()[2], p, is_result_unused, is_terminal); + compileExpression(x.list()[2], p, is_result_unused, is_terminal, can_use_ref); page(p).back().setSourceLocation(x.constList()[2].filename(), x.constList()[2].position().start.line); m_locals_locator.dropVarsForBranch(); // set jump to end pos @@ -568,7 +570,7 @@ namespace Ark::internal if (x.isAnonymousFunction()) m_opened_vars.emplace("#anonymous", arg_count); // push body of the function - compileExpression(x.list()[2], function_body_page, false, true); + compileExpression(x.list()[2], function_body_page, false, true, true); if (x.isAnonymousFunction()) m_opened_vars.pop(); @@ -616,7 +618,7 @@ namespace Ark::internal // put value before symbol id // starting at index = 2 because x is a (let|mut|set variable ...) node - compileExpression(x.list()[2], p, false, false); + compileExpression(x.list()[2], p, false, false, true); if (n == Keyword::Let || n == Keyword::Mut) { @@ -647,12 +649,12 @@ namespace Ark::internal const auto label_loop = IR::Entity::Label(m_current_label++); page(p).emplace_back(label_loop); // push condition - compileExpression(x.list()[1], p, false, false); + compileExpression(x.list()[1], p, false, false, true); // absolute jump to end of block if condition is false const auto label_end = IR::Entity::Label(m_current_label++); page(p).emplace_back(IR::Entity::GotoIf(label_end, false)); // push code to page - compileExpression(x.list()[2], p, true, false); + compileExpression(x.list()[2], p, true, false, true); // reset the scope at the end of the loop so that indices are still valid // otherwise, (while true { (let a 5) (print a) (let b 6) (print b) }) @@ -694,18 +696,23 @@ namespace Ark::internal { if (nodeProducesOutput(value) || isBreakpoint(value)) { - // we have to disallow usage of references in tail calls, because if we shuffle arguments around while using refs, they will end up with the same value - if (value.nodeType() == NodeType::Symbol && is_tail_call) - compileSymbol(value, p, /* is_result_unused= */ false, /* can_use_ref= */ false); + if (value.nodeType() == NodeType::Symbol) + { + // we have to disallow usage of references in tail calls, because if we shuffle arguments around while using refs, they will end up with the same value + if (is_tail_call) + compileSymbol(value, p, /* is_result_unused= */ false, /* can_use_ref= */ false); + else + compileSymbol(value, p, /* is_result_unused= */ false, /* can_use_ref= */ true); + } else - compileExpression(value, p, /* is_result_unused= */ false, /* is_terminal= */ false); + compileExpression(value, p, /* is_result_unused= */ false, /* is_terminal= */ false, /* can_use_ref= */ false); } else makeError(is_tail_call ? ErrorKind::InvalidNodeInTailCallNoReturnValue : ErrorKind::InvalidNodeNoReturnValue, value, node.repr()); } } - void ASTLowerer::handleCalls(Node& x, const Page p, bool is_result_unused, const bool is_terminal) + void ASTLowerer::handleCalls(Node& x, const Page p, bool is_result_unused, const bool is_terminal, const bool can_use_ref) { const Node& node = x.constList()[0]; bool matched = false; @@ -715,7 +722,7 @@ namespace Ark::internal if (node.string() == Language::And || node.string() == Language::Or) { matched = true; - handleShortcircuit(x, p); + handleShortcircuit(x, p, can_use_ref); } if (const auto maybe_operator = getOperator(node.string()); maybe_operator.has_value()) { @@ -738,7 +745,7 @@ namespace Ark::internal page(p).emplace_back(POP); } - void ASTLowerer::handleShortcircuit(Node& x, const Page p) + void ASTLowerer::handleShortcircuit(Node& x, const Page p, const bool can_use_ref) { const Node& node = x.constList()[0]; const auto name = node.string(); // and / or @@ -759,7 +766,7 @@ namespace Ark::internal "Can not use `{}' inside a `{}' expression, as it doesn't return a value", x.list()[1].repr(), name), x.list()[1]); - compileExpression(x.list()[1], p, false, false); + compileExpression(x.list()[1], p, false, false, can_use_ref); const auto label_shortcircuit = IR::Entity::Label(m_current_label++); auto shortcircuit_entity = IR::Entity::Goto(label_shortcircuit, inst); @@ -773,7 +780,7 @@ namespace Ark::internal "Can not use `{}' inside a `{}' expression, as it doesn't return a value", x.list()[i].repr(), name), x.list()[i]); - compileExpression(x.list()[i], p, false, false); + compileExpression(x.list()[i], p, false, false, can_use_ref); if (i + 1 != end) page(p).emplace_back(shortcircuit_entity); } @@ -794,7 +801,7 @@ namespace Ark::internal { const bool is_breakpoint = isBreakpoint(x.constList()[index]); if (nodeProducesOutput(x.constList()[index]) || is_breakpoint) - compileExpression(x.list()[index], p, false, false); + compileExpression(x.list()[index], p, false, false, true); else makeError(ErrorKind::InvalidNodeInOperatorNoReturnValue, x.constList()[index], node.repr()); @@ -905,7 +912,7 @@ namespace Ark::internal else { // closure chains have been handled (eg: closure.field.field.function) - compileExpression(node, proc_page, false, false); + compileExpression(node, proc_page, false, false, true); if (page(proc_page).empty()) buildAndThrowError(fmt::format("Can not call {}", x.constList()[0].repr()), x); @@ -964,7 +971,7 @@ namespace Ark::internal case CallType::SymbolByIndex: { const Page temp_page = createNewCodePage(/* temp= */ true); - compileExpression(node, temp_page, false, false); + compileExpression(node, temp_page, false, false, true); assert(page(temp_page).size() == 1 && page(temp_page).back().inst() == LOAD_FAST_BY_INDEX); page(p).emplace_back(CALL_SYMBOL_BY_INDEX, page(temp_page).back().primaryArg(), args_count); m_temp_pages.pop_back(); diff --git a/tests/unittests/resources/CompilerSuite/ir/ref_args.ark b/tests/unittests/resources/CompilerSuite/ir/ref_args.ark new file mode 100644 index 000000000..4b1f221d3 --- /dev/null +++ b/tests/unittests/resources/CompilerSuite/ir/ref_args.ark @@ -0,0 +1,14 @@ +# in tail calls, we need all arguments to be loaded by value and not by ref +# otherwise they get mixed up and bad things happen + +(let f (fun (a b) + (f b (- a 1)))) +(f 9 9) + +(let g (fun (a b) + (g { b } (- a 1)))) +(g 9 9) + +(let h (fun (a b) + (h (or 0 b) (- a 1)))) +(h 9 9) diff --git a/tests/unittests/resources/CompilerSuite/ir/ref_args.expected b/tests/unittests/resources/CompilerSuite/ir/ref_args.expected new file mode 100644 index 000000000..b5cfea74c --- /dev/null +++ b/tests/unittests/resources/CompilerSuite/ir/ref_args.expected @@ -0,0 +1,62 @@ +page_0 + LOAD_CONST 0 + STORE 0 + PUSH_RETURN_ADDRESS L0 + LOAD_CONST 2 + LOAD_CONST 2 + CALL_SYMBOL_BY_INDEX 0, 2 +.L0: + POP 0 + LOAD_CONST 3 + STORE 3 + PUSH_RETURN_ADDRESS L1 + LOAD_CONST 2 + LOAD_CONST 2 + CALL_SYMBOL_BY_INDEX 0, 2 +.L1: + POP 0 + LOAD_CONST 4 + STORE 4 + PUSH_RETURN_ADDRESS L3 + LOAD_CONST 2 + LOAD_CONST 2 + CALL_SYMBOL_BY_INDEX 0, 2 +.L3: + POP 0 + HALT 0 + +page_1 + STORE 1 + STORE 2 + LOAD_SYMBOL 1 + LOAD_FAST_BY_INDEX 0 + LOAD_CONST 1 + SUB 0 + TAIL_CALL_SELF 0 + RET 0 + HALT 0 + +page_2 + STORE 1 + STORE 2 + LOAD_SYMBOL 1 + LOAD_FAST_BY_INDEX 0 + LOAD_CONST 1 + SUB 0 + TAIL_CALL_SELF 0 + RET 0 + HALT 0 + +page_3 + STORE 1 + STORE 2 + LOAD_CONST 5 + SHORTCIRCUIT_OR L2 + LOAD_SYMBOL 1 +.L2: + LOAD_FAST_BY_INDEX 0 + LOAD_CONST 1 + SUB 0 + TAIL_CALL_SELF 0 + RET 0 + HALT 0 diff --git a/tests/unittests/resources/LangSuite/weird-tests.ark b/tests/unittests/resources/LangSuite/weird-tests.ark index e798b61f6..417439697 100644 --- a/tests/unittests/resources/LangSuite/weird-tests.ark +++ b/tests/unittests/resources/LangSuite/weird-tests.ark @@ -76,4 +76,26 @@ (test:eq (type p) "CProc") (test:eq (a data 4) [1 2 3 4]) (test:eq (c data [4 5]) [1 2 3 4 5]) - (test:eq (p data 1) [1 3]) }) }) + (test:eq (p data 1) [1 3]) }) + + (test:case "begin as arg" { + (mut output []) + + (let h (fun (a b) + (if a { + (append! output [a b]) + (h { b } (- a 1)) }))) + (h 2 2) + + (test:eq output [[2 2] [2 1] [1 1] [1 0]]) }) + + (test:case "(or 0 b) as arg" { + (mut output []) + + (let i (fun (a b) + (if a { + (append! output [a b]) + (i (or 0 b) (- a 1)) }))) + (i 2 2) + + (test:eq output [[2 2] [2 1] [1 1] [1 0]]) }) })