Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
840 changes: 805 additions & 35 deletions compiler/ASTNodes.hpp

Large diffs are not rendered by default.

132 changes: 132 additions & 0 deletions compiler/ASTRewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "ASTRewriter.hpp"
#include <typeinfo>


// The function assumes that the children have been assigned to a new parent.
void deleteNode(Node* node)
{
// The node's children needs to be nullptrs before deletion.
// If they aren't the children will be deleted too.
for (int i = 0; i < node->numChildren(); i++) {
node->setChild(i,nullptr);
}
delete node;
}


// Helper function for replacing the current node.
// The function assumes that the replacementNode already has received currentNode's children.
void replaceNode(const int childIndex, Node* currentNode, Node* replacementNode)
{
currentNode->getParent()->setChild(childIndex,replacementNode);
deleteNode(currentNode);
}


// The rewriter traverses and modifies the subtree of the root node in a pre-order fashion.
// This means that the algorithm first recurses to the bottom of the tree and modifies the
// tree from the bottom and up.
// The childIndex parameter denotes the placement of root relative to its parent.
// childIndex = 0 means it's the first child, childIndex = 1 means it's the second child.
void ASTRewriter::rewrite(Node* root, const int childIndex)
{

if (root == nullptr) { return; }

if (typeid(*root).name() == typeid(SequenceNode).name()) {
auto current = dynamic_cast<SequenceNode*>(root);
int i = 0;
for (auto n : current->nodes()) {
n->setParent(current);
rewrite(n, i);
++i;
}
}else
// Pattern match for binary operations
if (typeid(*root).name() == typeid(BinaryOpNode).name()) {

BinaryOpNode* current = dynamic_cast<BinaryOpNode*>(root);

current->getChild(0)->setParent(current);
current->getChild(1)->setParent(current);
rewrite(current->getChild(0), 0);
rewrite(current->getChild(1), 1);

// Pattern matching for multiply-add
if (current->op() == Add) {

auto* lhs = dynamic_cast<BinaryOpNode*>(current->getChild(0));
auto* rhs = dynamic_cast<BinaryOpNode*>(current->getChild(1));

// Accounts for a*b+c and a+b*c
BinaryOpNode* mulOpNode;
int addOpNodeIndex = 0;
if (lhs != nullptr && lhs->op() == Multiply) {
mulOpNode = lhs;
// child index of rhs
addOpNodeIndex = 1;
} else
if (rhs != nullptr && rhs->op() == Multiply) {
mulOpNode = rhs;
//child index of lhs
addOpNodeIndex = 0;
} else {
return;
}

auto replacementNode =
new MultiplyAddNode(dynamic_cast<ExpressionNode*>(mulOpNode->getChild(0)),
dynamic_cast<ExpressionNode*>(mulOpNode->getChild(1)),
dynamic_cast<ExpressionNode*>(current->getChild(addOpNodeIndex)));
replaceNode(childIndex, current, replacementNode);
deleteNode(mulOpNode);
} else
if (current->op() == Multiply) {
BinaryOpNode* current = dynamic_cast<BinaryOpNode*>(root);
current->getChild(0)->setParent(current);
current->getChild(1)->setParent(current);
rewrite(current->getChild(0), 0);
rewrite(current->getChild(1), 1);

auto* rhs = dynamic_cast<BinaryOpNode*>(current->getChild(1));

if (rhs != nullptr && rhs->op() == Divide) {

auto replacementNode =
new MultiplyDivideNode(dynamic_cast<ExpressionNode*>(current->getChild(0)),
dynamic_cast<ExpressionNode*>(rhs->getChild(0)),
dynamic_cast<ExpressionNode*>(rhs->getChild(1)));
replaceNode(childIndex, current, replacementNode);
deleteNode(rhs);
} else {
return;
}
} else
if (current->op() == Divide) {
BinaryOpNode* current = dynamic_cast<BinaryOpNode*>(root);
current->getChild(0)->setParent(current);
current->getChild(1)->setParent(current);
rewrite(current->getChild(0), 0);
rewrite(current->getChild(1), 1);

auto* lhs = dynamic_cast<BinaryOpNode*>(current->getChild(0));

if (lhs != nullptr && lhs->op() == Multiply) {

auto replacementNode =
new MultiplyDivideNode(dynamic_cast<ExpressionNode*>(lhs->getChild(0)),
dynamic_cast<ExpressionNode*>(lhs->getChild(1)),
dynamic_cast<ExpressionNode*>(current->getChild(1)));
replaceNode(childIndex, current, replacementNode);
deleteNode(lhs);
}
}
} else {
for ( int i = 0; i < root->numChildren(); i++ ) {
if ( root->getChild(i) != nullptr ){
root->getChild(i)->setParent(root);
rewrite(root->getChild(i), i);
}
}
}
}
13 changes: 13 additions & 0 deletions compiler/ASTRewriter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef ASTREWRITER_HEADER_INCLUDED
#define ASTREWRITER_HEADER_INCLUDED

#include "ASTNodes.hpp"
class Node;

class ASTRewriter
{
public:
void rewrite(Node* root, const int childIndex);
};

#endif // ASTREWRITER_HEADER_INCLUDED
8 changes: 8 additions & 0 deletions compiler/ASTVisitorInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class ArrayNode;
class RandomAccessNode;
class StencilAssignmentNode;
class StencilNode;
class MultiplyAddNode;
class MultiplyDivideNode;


class ASTVisitorInterface
Expand All @@ -55,6 +57,12 @@ class ASTVisitorInterface
virtual void visit(BinaryOpUnitNode& node) {}
virtual void midVisit(BinaryOpUnitNode& node) {}
virtual void postVisit(BinaryOpUnitNode& node) {}
virtual void visit(MultiplyAddNode& node) = 0;
virtual void midVisit(MultiplyAddNode& node) = 0;
virtual void postVisit(MultiplyAddNode& node) = 0;
virtual void visit(MultiplyDivideNode& node) = 0;
virtual void midVisit(MultiplyDivideNode& node) = 0;
virtual void postVisit(MultiplyDivideNode& node) = 0;
virtual void visit(PowerUnitNode& node) {}
virtual void postVisit(PowerUnitNode& node) {}
virtual void visit(StringNode& node) = 0;
Expand Down
26 changes: 25 additions & 1 deletion compiler/CheckASTVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ CheckASTVisitor::~CheckASTVisitor()

bool CheckASTVisitor::isValid()
{
return valid_;
return valid_;
}


Expand Down Expand Up @@ -205,6 +205,30 @@ void CheckASTVisitor::postVisit(BinaryOpNode& node)
}
}

void CheckASTVisitor::visit(MultiplyAddNode& node)
{
}

void CheckASTVisitor::midVisit(MultiplyAddNode& node)
{
}

void CheckASTVisitor::postVisit(MultiplyAddNode& node)
{
}

void CheckASTVisitor::visit(MultiplyDivideNode& node)
{
}

void CheckASTVisitor::midVisit(MultiplyDivideNode& node)
{
}

void CheckASTVisitor::postVisit(MultiplyDivideNode& node)
{
}

void CheckASTVisitor::visit(ComparisonOpNode&)
{
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/CheckASTVisitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class CheckASTVisitor : public ASTVisitorInterface
void visit(BinaryOpNode& node);
void midVisit(BinaryOpNode& node);
void postVisit(BinaryOpNode& node);
void visit(MultiplyDivideNode& node);
void midVisit(MultiplyDivideNode& node);
void postVisit(MultiplyDivideNode& node);
void visit(MultiplyAddNode& node);
void midVisit(MultiplyAddNode& node);
void postVisit(MultiplyAddNode& node);
void visit(ComparisonOpNode& node);
void midVisit(ComparisonOpNode& node);
void postVisit(ComparisonOpNode& node);
Expand Down
2 changes: 1 addition & 1 deletion compiler/CommandLineOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CommandLineOptions {
("verbose", "Verbose output")
("config,c", boost::program_options::value<std::string>(), "Configuration filename (specify command line parameters in file)")
("input,i", boost::program_options::value<std::string>()->required(), "Input Equelle file to compile")
("backend", boost::program_options::value<std::string>()->default_value("cpu"), "Backend of compiler to use (io, ast, ast_equelle, cpu*, cuda, mrst)")
("backend", boost::program_options::value<std::string>()->default_value("cpu"), "Backend of compiler to use (io, ast, ast_equelle, cpu*, cuda, cuda-ast-rewrite, mrst)")
("nondimensional", "Disable dimension checking")
("dump", boost::program_options::value<std::string>()->default_value("none"), "Dump compiler internals (symboltable, io)");
}
Expand Down
13 changes: 13 additions & 0 deletions compiler/NodeInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Node
{
public:
Node()
:parent_(nullptr)
{}
virtual ~Node()
{}
Expand All @@ -30,7 +31,19 @@ class Node
{
loc_ = loc;
}
Node* getParent()
{
return parent_;
}
void setParent(Node* parent)
{
parent_ = parent;
}
virtual int numChildren() = 0;
virtual Node* getChild(const int index) = 0;
virtual void setChild(const int index, Node* child) = 0;
private:
Node* parent_;
// No copying.
Node(const Node&);
// No assignment.
Expand Down
26 changes: 24 additions & 2 deletions compiler/PrintASTVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ PrintASTVisitor::~PrintASTVisitor()





void PrintASTVisitor::visit(SequenceNode&)
{
if (indent_ == 0) {
Expand Down Expand Up @@ -111,6 +109,14 @@ void PrintASTVisitor::visit(BinaryOpNode& node)
++indent_;
}

void PrintASTVisitor::visit(MultiplyAddNode& node)
{
}

void PrintASTVisitor::visit(MultiplyDivideNode& node)
{
}

void PrintASTVisitor::visit(ComparisonOpNode& node)
{
std::string op(" ");
Expand Down Expand Up @@ -288,6 +294,22 @@ void PrintASTVisitor::postVisit(BinaryOpNode&)
--indent_;
}

void PrintASTVisitor::midVisit(MultiplyAddNode& node)
{
}

void PrintASTVisitor::postVisit(MultiplyAddNode& node)
{
}

void PrintASTVisitor::midVisit(MultiplyDivideNode& node)
{
}

void PrintASTVisitor::postVisit(MultiplyDivideNode& node)
{
}

void PrintASTVisitor::midVisit(ComparisonOpNode&)
{
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/PrintASTVisitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ class PrintASTVisitor : public ASTVisitorInterface
void visit(BinaryOpNode& node);
void midVisit(BinaryOpNode& node);
void postVisit(BinaryOpNode& node);
void visit(MultiplyAddNode& node);
void midVisit(MultiplyAddNode& node);
void postVisit(MultiplyAddNode& node);
void visit(MultiplyDivideNode& node);
void midVisit(MultiplyDivideNode& node);
void postVisit(MultiplyDivideNode& node);
void visit(ComparisonOpNode& node);
void midVisit(ComparisonOpNode& node);
void postVisit(ComparisonOpNode& node);
Expand Down
51 changes: 49 additions & 2 deletions compiler/PrintCPUBackendASTVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void PrintCPUBackendASTVisitor::postVisit(SequenceNode&)
std::cout <<
"\n"
"void ensureRequirements(const " << namespaceNameString() <<
"::" << classNameString() << "& er)\n"
"::" << classNameString() << "& er)\n"
"{\n";
if (requirement_strings_.empty()) {
std::cout << " (void)er;\n";
Expand Down Expand Up @@ -169,6 +169,53 @@ void PrintCPUBackendASTVisitor::postVisit(BinaryOpNode&)
std::cout << ')';
}

void PrintCPUBackendASTVisitor::visit(MultiplyAddNode&)
{
if (isSuppressed()) {
return;
}
std::cout << "er.multiplyAdd(";
}

void PrintCPUBackendASTVisitor::midVisit(MultiplyAddNode&)
{
if (isSuppressed()) {
return;
}
std::cout << ", ";
}

void PrintCPUBackendASTVisitor::postVisit(MultiplyAddNode&)
{
if (isSuppressed()) {
return;
}
std::cout << ')';
}

void PrintCPUBackendASTVisitor::visit(MultiplyDivideNode& node)
{
if (isSuppressed()) {
return;
}
std::cout << "er.multiplyDivide(";
}

void PrintCPUBackendASTVisitor::midVisit(MultiplyDivideNode& node)
{
if (isSuppressed()) {
return;
}
std::cout << ", ";
}

void PrintCPUBackendASTVisitor::postVisit(MultiplyDivideNode& node)
{
if (isSuppressed()) {
return;
}
std::cout << ')';
}
void PrintCPUBackendASTVisitor::visit(ComparisonOpNode&)
{
if (isSuppressed()) {
Expand Down Expand Up @@ -697,7 +744,7 @@ void PrintCPUBackendASTVisitor::postVisit(RandomAccessNode& node)
const char* PrintCPUBackendASTVisitor::cppStartString() const
{
if ( use_cartesian_ ) {
return ::impl_cppCartesianStartString();
return ::impl_cppCartesianStartString();
}
return ::impl_cppStartString();
}
Expand Down
Loading