Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/Ark/Compiler/Lowerer/ASTLowerer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,21 @@ namespace Ark::internal
* @param p the current page number we're on
* @param is_result_unused
* @param is_terminal
* @param can_use_ref
*/
void compileExpression(Node& x, Page p, bool is_result_unused, bool is_terminal);
void compileExpression(Node& x, Page p, bool is_result_unused, bool is_terminal, bool can_use_ref);

void compileSymbol(const Node& x, Page p, bool is_result_unused, bool can_use_ref);
void compileListInstruction(Node& x, Page p, bool is_result_unused);
void compileApplyInstruction(Node& x, Page p, bool is_result_unused);
void compileIf(Node& x, Page p, bool is_result_unused, bool is_terminal);
void compileIf(Node& x, Page p, bool is_result_unused, bool is_terminal, bool can_use_ref);
void compileFunction(Node& x, Page p, bool is_result_unused);
void compileLetMutSet(Keyword n, Node& x, Page p, bool is_result_unused);
void compileWhile(Node& x, Page p);
void compilePluginImport(const Node& x, Page p);
void pushFunctionCallArguments(Node& call, Page p, bool is_tail_call);
void handleCalls(Node& x, Page p, bool is_result_unused, bool is_terminal);
void handleShortcircuit(Node& x, Page p);
void handleCalls(Node& x, Page p, bool is_result_unused, bool is_terminal, bool can_use_ref);
void handleShortcircuit(Node& x, Page p, bool can_use_ref);
void handleOperator(Node& x, Page p, Instruction op);
bool handleFunctionCall(Node& x, Page p, bool is_terminal);

Expand Down
69 changes: 38 additions & 31 deletions src/arkreactor/Compiler/Lowerer/ASTLowerer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ namespace Ark::internal
ast,
/* current_page */ global,
/* is_result_unused= */ true,
/* is_terminal= */ false);
/* is_terminal= */ false,
/* can_use_ref= */ true);
m_logger.traceEnd();
}

Expand Down Expand Up @@ -208,11 +209,11 @@ namespace Ark::internal
}
}

void ASTLowerer::compileExpression(Node& x, const Page p, const bool is_result_unused, const bool is_terminal)
void ASTLowerer::compileExpression(Node& x, const Page p, const bool is_result_unused, const bool is_terminal, const bool can_use_ref)
{
// register symbols
if (x.nodeType() == NodeType::Symbol)
compileSymbol(x, p, is_result_unused, /* can_use_ref= */ true);
compileSymbol(x, p, is_result_unused, /* can_use_ref= */ can_use_ref);
else if (x.nodeType() == NodeType::Field)
{
// the parser guarantees us that there is at least 2 elements (eg: a.b)
Expand All @@ -234,7 +235,7 @@ namespace Ark::internal
}
// namespace nodes
else if (x.nodeType() == NodeType::Namespace)
compileExpression(*x.constArkNamespace().ast, p, is_result_unused, is_terminal);
compileExpression(*x.constArkNamespace().ast, p, is_result_unused, is_terminal, can_use_ref);
else if (x.nodeType() == NodeType::List)
{
// empty code block should be nil
Expand All @@ -257,7 +258,7 @@ namespace Ark::internal
switch (const Keyword keyword = head.keyword())
{
case Keyword::If:
compileIf(x, p, is_result_unused, is_terminal);
compileIf(x, p, is_result_unused, is_terminal, can_use_ref);
break;

case Keyword::Set:
Expand All @@ -281,7 +282,8 @@ namespace Ark::internal
// All the nodes in a 'begin' (except for the last one) are producing a result that we want to drop.
/* is_result_unused= */ (i != size - 1) || is_result_unused,
// If the 'begin' is a terminal node, only its last node is terminal.
/* is_terminal= */ is_terminal && (i == size - 1));
/* is_terminal= */ is_terminal && (i == size - 1),
/* can_use_ref= */ can_use_ref);
break;
}

Expand All @@ -303,7 +305,7 @@ namespace Ark::internal
{
// If we are here, we should have a function name via the m_opened_vars.
// Push arguments first, then function name, then call it.
handleCalls(x, p, is_result_unused, is_terminal);
handleCalls(x, p, is_result_unused, is_terminal, can_use_ref);
}
}
else if (x.nodeType() != NodeType::Unused)
Expand Down Expand Up @@ -375,7 +377,7 @@ namespace Ark::internal
{
Node& node = x.list()[i];
if (nodeProducesOutput(node))
compileExpression(node, p, false, false);
compileExpression(node, p, false, false, true);
else
makeError(ErrorKind::InvalidNodeNoReturnValue, node, name);
}
Expand Down Expand Up @@ -420,7 +422,7 @@ namespace Ark::internal
{
// Load the first argument which should be a symbol (or field),
// that append!/concat! write to, so that we have its new value available.
compileExpression(x.list()[1], p, false, false);
compileExpression(x.list()[1], p, false, false, true);
}

// append!, concat!, pop!, @= and @@= can push to the stack, but not using its returned value isn't an error
Expand All @@ -445,7 +447,7 @@ namespace Ark::internal
for (Node& node : x.list() | std::ranges::views::drop(1))
{
if (nodeProducesOutput(node))
compileExpression(node, p, false, false);
compileExpression(node, p, false, false, true);
else
makeError(ErrorKind::InvalidNodeNoReturnValue, node, "apply");
}
Expand All @@ -457,15 +459,15 @@ namespace Ark::internal
page(p).emplace_back(POP);
}

void ASTLowerer::compileIf(Node& x, const Page p, const bool is_result_unused, const bool is_terminal)
void ASTLowerer::compileIf(Node& x, const Page p, const bool is_result_unused, const bool is_terminal, const bool can_use_ref)
{
if (x.constList().size() == 1)
buildAndThrowError("Invalid condition: missing 'cond' and 'then' nodes, expected (if cond then)", x);
if (x.constList().size() == 2)
buildAndThrowError(fmt::format("Invalid condition: missing 'then' node, expected (if {} then)", x.constList()[1].repr()), x);

// compile condition
compileExpression(x.list()[1], p, false, false);
compileExpression(x.list()[1], p, false, false, true);
page(p).back().setSourceLocation(x.constList()[1].filename(), x.constList()[1].position().start.line);

// jump only if needed to the "true" branch
Expand All @@ -476,14 +478,14 @@ namespace Ark::internal
if (x.constList().size() == 4) // we have an else clause
{
m_locals_locator.saveScopeLengthForBranch();
compileExpression(x.list()[3], p, is_result_unused, is_terminal);
compileExpression(x.list()[3], p, is_result_unused, is_terminal, can_use_ref);
page(p).back().setSourceLocation(x.constList()[3].filename(), x.constList()[3].position().start.line);
m_locals_locator.dropVarsForBranch();
}
else
{
Node tmp = Node(NodeType::List);
compileExpression(tmp, p, is_result_unused, is_terminal);
compileExpression(tmp, p, is_result_unused, is_terminal, true);
}

// when else is finished, jump to end
Expand All @@ -494,7 +496,7 @@ namespace Ark::internal
page(p).emplace_back(label_then);
// if code
m_locals_locator.saveScopeLengthForBranch();
compileExpression(x.list()[2], p, is_result_unused, is_terminal);
compileExpression(x.list()[2], p, is_result_unused, is_terminal, can_use_ref);
page(p).back().setSourceLocation(x.constList()[2].filename(), x.constList()[2].position().start.line);
m_locals_locator.dropVarsForBranch();
// set jump to end pos
Expand Down Expand Up @@ -568,7 +570,7 @@ namespace Ark::internal
if (x.isAnonymousFunction())
m_opened_vars.emplace("#anonymous", arg_count);
// push body of the function
compileExpression(x.list()[2], function_body_page, false, true);
compileExpression(x.list()[2], function_body_page, false, true, true);
if (x.isAnonymousFunction())
m_opened_vars.pop();

Expand Down Expand Up @@ -616,7 +618,7 @@ namespace Ark::internal

// put value before symbol id
// starting at index = 2 because x is a (let|mut|set variable ...) node
compileExpression(x.list()[2], p, false, false);
compileExpression(x.list()[2], p, false, false, true);

if (n == Keyword::Let || n == Keyword::Mut)
{
Expand Down Expand Up @@ -647,12 +649,12 @@ namespace Ark::internal
const auto label_loop = IR::Entity::Label(m_current_label++);
page(p).emplace_back(label_loop);
// push condition
compileExpression(x.list()[1], p, false, false);
compileExpression(x.list()[1], p, false, false, true);
// absolute jump to end of block if condition is false
const auto label_end = IR::Entity::Label(m_current_label++);
page(p).emplace_back(IR::Entity::GotoIf(label_end, false));
// push code to page
compileExpression(x.list()[2], p, true, false);
compileExpression(x.list()[2], p, true, false, true);

// reset the scope at the end of the loop so that indices are still valid
// otherwise, (while true { (let a 5) (print a) (let b 6) (print b) })
Expand Down Expand Up @@ -694,18 +696,23 @@ namespace Ark::internal
{
if (nodeProducesOutput(value) || isBreakpoint(value))
{
// we have to disallow usage of references in tail calls, because if we shuffle arguments around while using refs, they will end up with the same value
if (value.nodeType() == NodeType::Symbol && is_tail_call)
compileSymbol(value, p, /* is_result_unused= */ false, /* can_use_ref= */ false);
if (value.nodeType() == NodeType::Symbol)
{
// we have to disallow usage of references in tail calls, because if we shuffle arguments around while using refs, they will end up with the same value
if (is_tail_call)
compileSymbol(value, p, /* is_result_unused= */ false, /* can_use_ref= */ false);
else
compileSymbol(value, p, /* is_result_unused= */ false, /* can_use_ref= */ true);
}
else
compileExpression(value, p, /* is_result_unused= */ false, /* is_terminal= */ false);
compileExpression(value, p, /* is_result_unused= */ false, /* is_terminal= */ false, /* can_use_ref= */ false);
}
else
makeError(is_tail_call ? ErrorKind::InvalidNodeInTailCallNoReturnValue : ErrorKind::InvalidNodeNoReturnValue, value, node.repr());
}
}

void ASTLowerer::handleCalls(Node& x, const Page p, bool is_result_unused, const bool is_terminal)
void ASTLowerer::handleCalls(Node& x, const Page p, bool is_result_unused, const bool is_terminal, const bool can_use_ref)
{
const Node& node = x.constList()[0];
bool matched = false;
Expand All @@ -715,7 +722,7 @@ namespace Ark::internal
if (node.string() == Language::And || node.string() == Language::Or)
{
matched = true;
handleShortcircuit(x, p);
handleShortcircuit(x, p, can_use_ref);
}
if (const auto maybe_operator = getOperator(node.string()); maybe_operator.has_value())
{
Expand All @@ -738,7 +745,7 @@ namespace Ark::internal
page(p).emplace_back(POP);
}

void ASTLowerer::handleShortcircuit(Node& x, const Page p)
void ASTLowerer::handleShortcircuit(Node& x, const Page p, const bool can_use_ref)
{
const Node& node = x.constList()[0];
const auto name = node.string(); // and / or
Expand All @@ -759,7 +766,7 @@ namespace Ark::internal
"Can not use `{}' inside a `{}' expression, as it doesn't return a value",
x.list()[1].repr(), name),
x.list()[1]);
compileExpression(x.list()[1], p, false, false);
compileExpression(x.list()[1], p, false, false, can_use_ref);

const auto label_shortcircuit = IR::Entity::Label(m_current_label++);
auto shortcircuit_entity = IR::Entity::Goto(label_shortcircuit, inst);
Expand All @@ -773,7 +780,7 @@ namespace Ark::internal
"Can not use `{}' inside a `{}' expression, as it doesn't return a value",
x.list()[i].repr(), name),
x.list()[i]);
compileExpression(x.list()[i], p, false, false);
compileExpression(x.list()[i], p, false, false, can_use_ref);
if (i + 1 != end)
page(p).emplace_back(shortcircuit_entity);
}
Expand All @@ -794,7 +801,7 @@ namespace Ark::internal
{
const bool is_breakpoint = isBreakpoint(x.constList()[index]);
if (nodeProducesOutput(x.constList()[index]) || is_breakpoint)
compileExpression(x.list()[index], p, false, false);
compileExpression(x.list()[index], p, false, false, true);
else
makeError(ErrorKind::InvalidNodeInOperatorNoReturnValue, x.constList()[index], node.repr());

Expand Down Expand Up @@ -905,7 +912,7 @@ namespace Ark::internal
else
{
// closure chains have been handled (eg: closure.field.field.function)
compileExpression(node, proc_page, false, false);
compileExpression(node, proc_page, false, false, true);

if (page(proc_page).empty())
buildAndThrowError(fmt::format("Can not call {}", x.constList()[0].repr()), x);
Expand Down Expand Up @@ -964,7 +971,7 @@ namespace Ark::internal
case CallType::SymbolByIndex:
{
const Page temp_page = createNewCodePage(/* temp= */ true);
compileExpression(node, temp_page, false, false);
compileExpression(node, temp_page, false, false, true);
assert(page(temp_page).size() == 1 && page(temp_page).back().inst() == LOAD_FAST_BY_INDEX);
page(p).emplace_back(CALL_SYMBOL_BY_INDEX, page(temp_page).back().primaryArg(), args_count);
m_temp_pages.pop_back();
Expand Down
14 changes: 14 additions & 0 deletions tests/unittests/resources/CompilerSuite/ir/ref_args.ark
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# in tail calls, we need all arguments to be loaded by value and not by ref
# otherwise they get mixed up and bad things happen

(let f (fun (a b)
(f b (- a 1))))
(f 9 9)

(let g (fun (a b)
(g { b } (- a 1))))
(g 9 9)

(let h (fun (a b)
(h (or 0 b) (- a 1))))
(h 9 9)
62 changes: 62 additions & 0 deletions tests/unittests/resources/CompilerSuite/ir/ref_args.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
page_0
LOAD_CONST 0
STORE 0
PUSH_RETURN_ADDRESS L0
LOAD_CONST 2
LOAD_CONST 2
CALL_SYMBOL_BY_INDEX 0, 2
.L0:
POP 0
LOAD_CONST 3
STORE 3
PUSH_RETURN_ADDRESS L1
LOAD_CONST 2
LOAD_CONST 2
CALL_SYMBOL_BY_INDEX 0, 2
.L1:
POP 0
LOAD_CONST 4
STORE 4
PUSH_RETURN_ADDRESS L3
LOAD_CONST 2
LOAD_CONST 2
CALL_SYMBOL_BY_INDEX 0, 2
.L3:
POP 0
HALT 0

page_1
STORE 1
STORE 2
LOAD_SYMBOL 1
LOAD_FAST_BY_INDEX 0
LOAD_CONST 1
SUB 0
TAIL_CALL_SELF 0
RET 0
HALT 0

page_2
STORE 1
STORE 2
LOAD_SYMBOL 1
LOAD_FAST_BY_INDEX 0
LOAD_CONST 1
SUB 0
TAIL_CALL_SELF 0
RET 0
HALT 0

page_3
STORE 1
STORE 2
LOAD_CONST 5
SHORTCIRCUIT_OR L2
LOAD_SYMBOL 1
.L2:
LOAD_FAST_BY_INDEX 0
LOAD_CONST 1
SUB 0
TAIL_CALL_SELF 0
RET 0
HALT 0
24 changes: 23 additions & 1 deletion tests/unittests/resources/LangSuite/weird-tests.ark
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,26 @@
(test:eq (type p) "CProc")
(test:eq (a data 4) [1 2 3 4])
(test:eq (c data [4 5]) [1 2 3 4 5])
(test:eq (p data 1) [1 3]) }) })
(test:eq (p data 1) [1 3]) })

(test:case "begin as arg" {
(mut output [])

(let h (fun (a b)
(if a {
(append! output [a b])
(h { b } (- a 1)) })))
(h 2 2)

(test:eq output [[2 2] [2 1] [1 1] [1 0]]) })

(test:case "(or 0 b) as arg" {
(mut output [])

(let i (fun (a b)
(if a {
(append! output [a b])
(i (or 0 b) (- a 1)) })))
(i 2 2)

(test:eq output [[2 2] [2 1] [1 1] [1 0]]) }) })
Loading