Skip to content

Commit 8676823

Browse files
committed
refactor(ast lowerer): split the calls lowering in multiple methods, to separate shortcircuit, operators and function calls
1 parent b34f75f commit 8676823

2 files changed

Lines changed: 167 additions & 146 deletions

File tree

include/Ark/Compiler/Lowerer/ASTLowerer.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ namespace Ark::internal
246246
void compilePluginImport(const Node& x, Page p);
247247
void pushFunctionCallArguments(Node& call, Page p, bool is_tail_call);
248248
void handleCalls(Node& x, Page p, bool is_result_unused, bool is_terminal);
249+
void handleShortcircuit(Node& x, Page p);
250+
void handleOperator(Node& x, Page p, Instruction op);
251+
bool handleFunctionCall(Node& x, Page p, bool is_terminal);
249252

250253
/**
251254
* @brief Register a given node in the symbol table

src/arkreactor/Compiler/Lowerer/ASTLowerer.cpp

Lines changed: 164 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,9 @@ namespace Ark::internal
239239
x.list()[i],
240240
p,
241241
// All the nodes in a 'begin' (except for the last one) are producing a result that we want to drop.
242-
(i != size - 1) || is_result_unused,
242+
/* is_result_unused= */ (i != size - 1) || is_result_unused,
243243
// If the 'begin' is a terminal node, only its last node is terminal.
244-
is_terminal && (i == size - 1));
244+
/* is_terminal= */ is_terminal && (i == size - 1));
245245
break;
246246
}
247247

@@ -638,179 +638,197 @@ namespace Ark::internal
638638

639639
void ASTLowerer::handleCalls(Node& x, const Page p, bool is_result_unused, const bool is_terminal)
640640
{
641-
constexpr std::size_t start_index = 1;
642-
643-
Node& node = x.list()[0];
644-
const std::optional<Instruction> maybe_operator = node.nodeType() == NodeType::Symbol ? getOperator(node.string()) : std::nullopt;
641+
const Node& node = x.constList()[0];
642+
bool matched = false;
645643

646-
const std::optional<Instruction> maybe_shortcircuit =
647-
node.nodeType() == NodeType::Symbol
648-
? (node.string() == Language::And
649-
? std::make_optional(Instruction::SHORTCIRCUIT_AND)
650-
: (node.string() == Language::Or
651-
? std::make_optional(Instruction::SHORTCIRCUIT_OR)
652-
: std::nullopt))
653-
: std::nullopt;
644+
if (node.nodeType() == NodeType::Symbol)
645+
{
646+
if (node.string() == Language::And || node.string() == Language::Or)
647+
{
648+
matched = true;
649+
handleShortcircuit(x, p);
650+
}
651+
if (const auto maybe_operator = getOperator(node.string()); maybe_operator.has_value())
652+
{
653+
matched = true;
654+
if (maybe_operator.value() == BREAKPOINT)
655+
is_result_unused = false;
656+
handleOperator(x, p, maybe_operator.value());
657+
}
658+
}
654659

655-
if (maybe_shortcircuit.has_value())
660+
if (!matched)
656661
{
657-
// short circuit implementation
658-
if (x.constList().size() < 3)
659-
buildAndThrowError(
660-
fmt::format(
661-
"Expected at least 2 arguments while compiling '{}', got {}",
662-
node.string(),
663-
x.constList().size() - 1),
664-
x);
665-
const auto name = node.string(); // and / or
662+
// if nothing else matched, then compile a function call
663+
if (handleFunctionCall(x, p, is_terminal))
664+
// if it returned true, we compiled a tail call, skip the POP at the end
665+
return;
666+
}
666667

667-
if (!nodeProducesOutput(x.list()[1]))
668+
if (is_result_unused)
669+
page(p).emplace_back(POP);
670+
}
671+
672+
void ASTLowerer::handleShortcircuit(Node& x, const Page p)
673+
{
674+
const Node& node = x.constList()[0];
675+
const auto name = node.string(); // and / or
676+
const Instruction inst = name == Language::And ? SHORTCIRCUIT_AND : SHORTCIRCUIT_OR;
677+
678+
// short circuit implementation
679+
if (x.constList().size() < 3)
680+
buildAndThrowError(
681+
fmt::format(
682+
"Expected at least 2 arguments while compiling '{}', got {}",
683+
name,
684+
x.constList().size() - 1),
685+
x);
686+
687+
if (!nodeProducesOutput(x.list()[1]))
688+
buildAndThrowError(
689+
fmt::format(
690+
"Can not use `{}' inside a `{}' expression, as it doesn't return a value",
691+
x.list()[1].repr(), name),
692+
x.list()[1]);
693+
compileExpression(x.list()[1], p, false, false);
694+
695+
const auto label_shortcircuit = IR::Entity::Label(m_current_label++);
696+
auto shortcircuit_entity = IR::Entity::Goto(label_shortcircuit, inst);
697+
page(p).emplace_back(shortcircuit_entity);
698+
699+
for (std::size_t i = 2, end = x.constList().size(); i < end; ++i)
700+
{
701+
if (!nodeProducesOutput(x.list()[i]))
668702
buildAndThrowError(
669703
fmt::format(
670704
"Can not use `{}' inside a `{}' expression, as it doesn't return a value",
671-
x.list()[1].repr(), name),
672-
x.list()[1]);
673-
compileExpression(x.list()[1], p, false, false);
705+
x.list()[i].repr(), name),
706+
x.list()[i]);
707+
compileExpression(x.list()[i], p, false, false);
708+
if (i + 1 != end)
709+
page(p).emplace_back(shortcircuit_entity);
710+
}
674711

675-
const auto label_shortcircuit = IR::Entity::Label(m_current_label++);
676-
auto shortcircuit_entity = IR::Entity::Goto(label_shortcircuit, maybe_shortcircuit.value());
677-
page(p).emplace_back(shortcircuit_entity);
712+
page(p).emplace_back(label_shortcircuit);
713+
}
678714

679-
for (std::size_t i = 2, end = x.constList().size(); i < end; ++i)
680-
{
681-
if (!nodeProducesOutput(x.list()[i]))
682-
buildAndThrowError(
683-
fmt::format(
684-
"Can not use `{}' inside a `{}' expression, as it doesn't return a value",
685-
x.list()[i].repr(), name),
686-
x.list()[i]);
687-
compileExpression(x.list()[i], p, false, false);
688-
if (i + 1 != end)
689-
page(p).emplace_back(shortcircuit_entity);
690-
}
715+
void ASTLowerer::handleOperator(Node& x, const Page p, const Instruction op)
716+
{
717+
constexpr std::size_t start_index = 1;
718+
const Node& node = x.constList()[0];
719+
const auto op_name = Language::operators[static_cast<std::size_t>(op - FIRST_OPERATOR)];
691720

692-
page(p).emplace_back(label_shortcircuit);
693-
}
694-
else if (!maybe_operator.has_value())
695-
{
696-
if (is_terminal && node.nodeType() == NodeType::Symbol && isFunctionCallingItself(node.string()))
697-
{
698-
pushFunctionCallArguments(x, p, /* is_tail_call= */ true);
699721

700-
// jump to the top of the function
701-
page(p).emplace_back(JUMP, 0_u16);
702-
page(p).back().setSourceLocation(node.filename(), node.position().start.line);
703-
return; // skip the potential Instruction::POP at the end
704-
}
722+
// push arguments on current page
723+
std::size_t exp_count = 0;
724+
for (std::size_t index = start_index, size = x.constList().size(); index < size; ++index)
725+
{
726+
const bool is_breakpoint = isBreakpoint(x.constList()[index]);
727+
if (nodeProducesOutput(x.constList()[index]) || is_breakpoint)
728+
compileExpression(x.list()[index], p, false, false);
705729
else
706-
{
707-
if (!nodeProducesOutput(node))
708-
buildAndThrowError(fmt::format("Can not call `{}', as it doesn't return a value", node.repr()), node);
730+
buildAndThrowError(fmt::format("Invalid node inside call to operator `{}'", node.repr()), x.constList()[index]);
709731

710-
const auto proc_page = createNewCodePage(/* temp= */ true);
732+
if (!is_breakpoint)
733+
exp_count++;
711734

712-
// compile the function resolution to a separate page
713-
if (node.nodeType() == NodeType::Symbol && isFunctionCallingItself(node.string()))
714-
{
715-
// The function is trying to call itself, but this isn't a tail call.
716-
// We can skip the LOAD_FAST function_name and directly push the current
717-
// function page, which will be quicker than a local variable resolution.
718-
// We set its argument to the symbol id of the function we are calling,
719-
// so that the VM knows the name of the last called function.
720-
page(proc_page).emplace_back(GET_CURRENT_PAGE_ADDR, addSymbol(node));
721-
}
722-
else
723-
{
724-
// closure chains have been handled (eg: closure.field.field.function)
725-
compileExpression(node, proc_page, false, false); // storing proc
726-
}
735+
// in order to be able to handle things like (op A B C D...)
736+
// which should be transformed into A B op C op D op...
737+
if (exp_count >= 2 && !isTernaryInst(op) && !is_breakpoint)
738+
page(p).emplace_back(op);
739+
}
727740

728-
if (m_temp_pages.back().empty())
729-
buildAndThrowError(fmt::format("Can not call {}", x.constList()[0].repr()), x);
741+
if (isBreakpoint(x))
742+
{
743+
if (exp_count > 1)
744+
buildAndThrowError(fmt::format("`{}' expected at most one argument, but was called with {}", op_name, exp_count), x.constList()[0]);
745+
page(p).emplace_back(op, exp_count);
746+
}
747+
else if (isUnaryInst(op))
748+
{
749+
if (exp_count != 1)
750+
buildAndThrowError(fmt::format("`{}' expected one argument, but was called with {}", op_name, exp_count), x.constList()[0]);
751+
page(p).emplace_back(op);
752+
}
753+
else if (isTernaryInst(op))
754+
{
755+
if (exp_count != 3)
756+
buildAndThrowError(fmt::format("`{}' expected three arguments, but was called with {}", op_name, exp_count), x.constList()[0]);
757+
page(p).emplace_back(op);
758+
}
759+
else if (exp_count <= 1)
760+
buildAndThrowError(fmt::format("`{}' expected two arguments, but was called with {}", op_name, exp_count), x.constList()[0]);
730761

731-
const auto label_return = IR::Entity::Label(m_current_label++);
732-
page(p).emplace_back(IR::Entity::Goto(label_return, PUSH_RETURN_ADDRESS));
733-
page(p).back().setSourceLocation(x.filename(), x.position().start.line);
762+
// need to check we didn't push the (op A B C D...) things for operators not supporting it
763+
if (exp_count > 2 && !isRepeatableOperation(op) && !isTernaryInst(op))
764+
buildAndThrowError(fmt::format("`{}' requires 2 arguments, but got {}.", op_name, exp_count), x);
734765

735-
pushFunctionCallArguments(x, p, /* is_tail_call= */ false);
736-
// push proc from temp page
737-
for (const auto& inst : m_temp_pages.back())
738-
page(p).push_back(inst);
739-
m_temp_pages.pop_back();
766+
page(p).back().setSourceLocation(x.filename(), x.position().start.line);
767+
}
740768

741-
// number of arguments
742-
std::size_t args_count = 0;
743-
for (auto it = x.constList().begin() + start_index, it_end = x.constList().end(); it != it_end; ++it)
744-
{
745-
if (it->nodeType() != NodeType::Capture && !isBreakpoint(*it))
746-
args_count++;
747-
}
748-
// call the procedure
749-
page(p).emplace_back(CALL, args_count);
750-
page(p).back().setSourceLocation(node.filename(), node.position().start.line);
769+
bool ASTLowerer::handleFunctionCall(Node& x, const Page p, const bool is_terminal)
770+
{
771+
constexpr std::size_t start_index = 1;
772+
Node& node = x.list()[0];
751773

752-
// patch the PUSH_RETURN_ADDRESS instruction with the return location (IP=CALL instruction IP)
753-
page(p).emplace_back(label_return);
754-
}
755-
}
756-
else // operator
774+
if (is_terminal && node.nodeType() == NodeType::Symbol && isFunctionCallingItself(node.string()))
757775
{
758-
// retrieve operator
759-
const auto op = maybe_operator.value();
760-
const auto op_name = Language::operators[static_cast<std::size_t>(op - FIRST_OPERATOR)];
776+
pushFunctionCallArguments(x, p, /* is_tail_call= */ true);
761777

762-
if (op == BREAKPOINT)
763-
is_result_unused = false;
778+
// jump to the top of the function
779+
page(p).emplace_back(JUMP, 0_u16);
780+
page(p).back().setSourceLocation(node.filename(), node.position().start.line);
781+
return true; // skip the potential Instruction::POP at the end
782+
}
764783

765-
// push arguments on current page
766-
std::size_t exp_count = 0;
767-
for (std::size_t index = start_index, size = x.constList().size(); index < size; ++index)
768-
{
769-
const bool is_breakpoint = isBreakpoint(x.constList()[index]);
770-
if (nodeProducesOutput(x.constList()[index]) || is_breakpoint)
771-
compileExpression(x.list()[index], p, false, false);
772-
else
773-
buildAndThrowError(fmt::format("Invalid node inside call to operator `{}'", node.repr()), x.constList()[index]);
784+
if (!nodeProducesOutput(node))
785+
buildAndThrowError(fmt::format("Can not call `{}', as it doesn't return a value", node.repr()), node);
774786

775-
if (!is_breakpoint)
776-
exp_count++;
787+
const auto proc_page = createNewCodePage(/* temp= */ true);
777788

778-
// in order to be able to handle things like (op A B C D...)
779-
// which should be transformed into A B op C op D op...
780-
if (exp_count >= 2 && !isTernaryInst(op) && !is_breakpoint)
781-
page(p).emplace_back(op);
782-
}
789+
// compile the function resolution to a separate page
790+
if (node.nodeType() == NodeType::Symbol && isFunctionCallingItself(node.string()))
791+
{
792+
// The function is trying to call itself, but this isn't a tail call.
793+
// We can skip the LOAD_FAST function_name and directly push the current
794+
// function page, which will be quicker than a local variable resolution.
795+
// We set its argument to the symbol id of the function we are calling,
796+
// so that the VM knows the name of the last called function.
797+
page(proc_page).emplace_back(GET_CURRENT_PAGE_ADDR, addSymbol(node));
798+
}
799+
else
800+
{
801+
// closure chains have been handled (eg: closure.field.field.function)
802+
compileExpression(node, proc_page, false, false); // storing proc
803+
}
783804

784-
if (isBreakpoint(x))
785-
{
786-
if (exp_count > 1)
787-
buildAndThrowError(fmt::format("`{}' expected at most one argument, but was called with {}", op_name, exp_count), x.constList()[0]);
788-
page(p).emplace_back(op, exp_count);
789-
}
790-
else if (isUnaryInst(op))
791-
{
792-
if (exp_count != 1)
793-
buildAndThrowError(fmt::format("`{}' expected one argument, but was called with {}", op_name, exp_count), x.constList()[0]);
794-
page(p).emplace_back(op);
795-
}
796-
else if (isTernaryInst(op))
797-
{
798-
if (exp_count != 3)
799-
buildAndThrowError(fmt::format("`{}' expected three arguments, but was called with {}", op_name, exp_count), x.constList()[0]);
800-
page(p).emplace_back(op);
801-
}
802-
else if (exp_count <= 1)
803-
buildAndThrowError(fmt::format("`{}' expected two arguments, but was called with {}", op_name, exp_count), x.constList()[0]);
805+
if (m_temp_pages.back().empty())
806+
buildAndThrowError(fmt::format("Can not call {}", x.constList()[0].repr()), x);
807+
808+
const auto label_return = IR::Entity::Label(m_current_label++);
809+
page(p).emplace_back(IR::Entity::Goto(label_return, PUSH_RETURN_ADDRESS));
810+
page(p).back().setSourceLocation(x.filename(), x.position().start.line);
804811

805-
// need to check we didn't push the (op A B C D...) things for operators not supporting it
806-
if (exp_count > 2 && !isRepeatableOperation(op) && !isTernaryInst(op))
807-
buildAndThrowError(fmt::format("`{}' requires 2 arguments, but got {}.", op_name, exp_count), x);
812+
pushFunctionCallArguments(x, p, /* is_tail_call= */ false);
813+
// push proc from temp page
814+
for (const auto& inst : m_temp_pages.back())
815+
page(p).push_back(inst);
816+
m_temp_pages.pop_back();
808817

809-
page(p).back().setSourceLocation(x.filename(), x.position().start.line);
818+
// number of arguments
819+
std::size_t args_count = 0;
820+
for (auto it = x.constList().begin() + start_index, it_end = x.constList().end(); it != it_end; ++it)
821+
{
822+
if (it->nodeType() != NodeType::Capture && !isBreakpoint(*it))
823+
args_count++;
810824
}
825+
// call the procedure
826+
page(p).emplace_back(CALL, args_count);
827+
page(p).back().setSourceLocation(node.filename(), node.position().start.line);
811828

812-
if (is_result_unused)
813-
page(p).emplace_back(POP);
829+
// patch the PUSH_RETURN_ADDRESS instruction with the return location (IP=CALL instruction IP)
830+
page(p).emplace_back(label_return);
831+
return false; // we didn't compile a tail call
814832
}
815833

816834
uint16_t ASTLowerer::addSymbol(const Node& sym)

0 commit comments

Comments
 (0)