diff --git a/code/AST.h b/code/AST.h index 5fbf60b..9dd96a4 100644 --- a/code/AST.h +++ b/code/AST.h @@ -1,509 +1,281 @@ #ifndef AST_H #define AST_H -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include + +enum class VarType { + INT, BOOL, FLOAT, CHAR, + STRING, ARRAY, ERROR, NEUTRAL +}; +enum class BinaryOp { + ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD, + EQUAL, NOT_EQUAL, LESS, LESS_EQUAL, GREATER, GREATER_EQUAL, + AND, OR, INDEX, CONCAT, POW, + ARRAY_ADD, ARRAY_SUBTRACT, ARRAY_MULTIPLY, ARRAY_DIVIDE +}; -class AST; // Abstract Syntax Tree -class Base; // top level program -class PrintStatement; -class BinaryOp; // binary operation of numbers and identifiers -class Statement; // top level statement -class BooleanOp; // boolean operation like 3 > 6*2; -class Expression; // top level expression that is evaluated to boolean, int or variable name at last -class IfStatement; -class DecStatement; // declaration statement like int a; -class ElseIfStatement; -class ElseStatement; -class AssignStatement; // assignment statement like a = 3; -class ForStatement; -class WhileStatement; -// class afterCheckStatement; +enum class UnaryOp { + INCREMENT, DECREMENT, + LENGTH, MIN, MAX, ABS +}; +enum class LoopType { FOR, FOREACH, WHILE }; -class ASTVisitor -{ -public: - // Virtual visit functions for each AST node type - virtual void visit(AST&) {} - virtual void visit(Expression&) {} - virtual void visit(Base&) = 0; - virtual void visit(Statement&) = 0; - virtual void visit(BinaryOp&) = 0; - virtual void visit(DecStatement&) = 0; - virtual void visit(AssignStatement&) = 0; - virtual void visit(BooleanOp&) = 0; - virtual void visit(IfStatement&) = 0; - virtual void visit(ElseIfStatement&) = 0; - virtual void visit(ElseStatement&) = 0; - virtual void visit(PrintStatement&) = 0; - virtual void visit(ForStatement&) = 0; - virtual void visit(WhileStatement&) = 0; +enum class StmtType { + IF, ELSE_IF, ELSE, + FOR, WHILE, TRY_CATCH, + VAR_DECL, ASSIGN, PRINT, + BLOCK, MATCH }; -class AST { +class ASTNode { public: - virtual ~AST() {} - virtual void accept(ASTVisitor& V) = 0; + virtual ~ASTNode() = default; + virtual void print(llvm::raw_ostream &os, int indent = 0) const = 0; }; -class TopLevelEntity : AST { +// barname - majmoo e ii az gozaare ha +class ProgramNode : public ASTNode { public: - TopLevelEntity() {} + std::vector> statements; + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class Expression : public TopLevelEntity { +// Var decleration +class VarDeclNode : public ASTNode { public: - enum ExpressionType { - Number, - Identifier, - Boolean, - BinaryOpType, - BooleanOpType - }; -private: - ExpressionType Type; - llvm::StringRef Value; - int NumberVal; - bool BoolVal; - BooleanOp* BOVal; + VarType type; + std::string name; + std::unique_ptr value; + + VarDeclNode(VarType t, std::string n, std::unique_ptr v) + : type(t), name(n), value(std::move(v)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// tarif chand moteghayyer +class MultiVarDeclNode : public ASTNode { public: - Expression() {} - Expression(llvm::StringRef value) : Type(ExpressionType::Identifier), Value(value) {} // store string - Expression(int value) : Type(ExpressionType::Number), NumberVal(value) {} // store number - Expression(bool value) : Type(ExpressionType::Boolean), BoolVal(value) {} // store boolean - Expression(BooleanOp* value) : Type(ExpressionType::BooleanOpType), BOVal(value) {} // store boolean - Expression(ExpressionType type) : Type(type) {} - - bool isNumber() { - if (Type == ExpressionType::Number) - return true; - return false; - } - - bool isBoolean() { - if (Type == ExpressionType::Boolean) - return true; - return false; - } - - bool isVariable() { - if (Type == ExpressionType::Identifier) - return true; - return false; - } - - bool isBinaryOp() { - if (Type == ExpressionType::BinaryOpType) - return true; - return false; - } - - bool isBooleanOp() { - if (Type == ExpressionType::BooleanOpType) - return true; - return false; - } - - llvm::StringRef getValue() { - return Value; - } - - int getNumber() { - return NumberVal; - } - - BooleanOp* getBooleanOp() { - return BOVal; - } - - bool getBoolean() { - return BoolVal; - } - ExpressionType getKind() - { - return Type; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::vector> declarations; + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class BooleanOp : public Expression -{ +// Assignment +class AssignNode : public ASTNode { public: - enum Operator - { - LessEqual, - Less, - Greater, - GreaterEqual, - Equal, - NotEqual, - And, - Or - }; - -private: - Expression* Left; - Expression* Right; - Operator Op; + std::string target; + BinaryOp op; + std::unique_ptr value; + + AssignNode(std::string t, BinaryOp o, std::unique_ptr v) + : target(t), op(o), value(std::move(v)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// Refrence to Variable +class VarRefNode : public ASTNode { public: - BooleanOp(Operator Op, Expression* L, Expression* R) : Op(Op), Left(L), Right(R), Expression(ExpressionType::BooleanOpType) { } - - Expression* getLeft() { return Left; } - - Expression* getRight() { return Right; } - - Operator getOperator() { return Op; } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::string name; + + VarRefNode(std::string n) : name(n) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class BinaryOp : public Expression -{ +// Literals +template +class LiteralNode : public ASTNode { public: - enum Operator - { - Plus, - Minus, - Mul, - Div, - Mod, - Pow - }; + T value; + + LiteralNode(T val) : value(val) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override { + os << value; + } +}; -private: - Expression* Left; // Left-hand side expression - Expression* Right; // Right-hand side expression - Operator Op; // Operator of the binary operation +using IntLiteral = LiteralNode; +using FloatLiteral = LiteralNode; +using BoolLiteral = LiteralNode; +using CharLiteral = LiteralNode; +using StrLiteral = LiteralNode; +// Binaray +class BinaryOpNode : public ASTNode { public: - BinaryOp(Operator Op, Expression* L, Expression* R) : Op(Op), Left(L), Right(R), Expression(ExpressionType::BinaryOpType) {} - - Expression* getLeft() { return Left; } - - Expression* getRight() { return Right; } - - Operator getOperator() { return Op; } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + BinaryOp op; + std::unique_ptr left; + std::unique_ptr right; + + BinaryOpNode(BinaryOp o, std::unique_ptr l, std::unique_ptr r) + : op(o), left(std::move(l)), right(std::move(r)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class Base : public AST -{ -private: - llvm::SmallVector statements; - +// Unary +class UnaryOpNode : public ASTNode { public: - Base(llvm::SmallVector Statements) : statements(Statements) {} - llvm::SmallVector getStatements() { return statements; } - - llvm::SmallVector::const_iterator begin() { return statements.begin(); } - - llvm::SmallVector::const_iterator end() { return statements.end(); } - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + UnaryOp op; + std::unique_ptr operand; + + UnaryOpNode(UnaryOp o, std::unique_ptr opnd) + : op(o), operand(std::move(opnd)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -// stores information of a statement. For example x=56; is a statement -class Statement : public TopLevelEntity -{ +// Code Block +class BlockNode : public ASTNode { public: - enum StatementType - { - If, - ElseIf, - Else, - Print, - Declaration, - Assignment, - While, - For - }; - -private: - StatementType Type; + std::vector> statements; + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// if/else +class IfElseNode : public ASTNode { public: - StatementType getKind() - { - return Type; - } - - Statement(StatementType type) : Type(type) {} - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::unique_ptr condition; + std::unique_ptr thenBlock; + std::unique_ptr elseBlock; // Can be BlockNode or IfElseNode + + IfElseNode(std::unique_ptr cond, + std::unique_ptr thenBlk, + std::unique_ptr elseBlk = nullptr) + : condition(std::move(cond)), + thenBlock(std::move(thenBlk)), + elseBlock(std::move(elseBlk)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class PrintStatement : public Statement -{ - private: - Expression * expr; - public: - // contructor that sets type - PrintStatement(Expression * identifier) : expr(identifier), Statement(Statement::StatementType::Print) {} - Expression * getExpr() - { - return expr; - } - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } +// for +class ForLoopNode : public ASTNode { +public: + std::unique_ptr init; + std::unique_ptr condition; + std::unique_ptr update; + std::unique_ptr body; + + ForLoopNode(std::unique_ptr i, + std::unique_ptr cond, + std::unique_ptr upd, + std::unique_ptr b) + : init(std::move(i)), condition(std::move(cond)), + update(std::move(upd)), body(std::move(b)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class DecStatement : public Statement { +// foreach +class ForeachLoopNode : public ASTNode { public: - enum DecStatementType { - Number, - Boolean, - }; -private: - - Expression* lvalue; - Expression* rvalue; - Statement::StatementType type; - DecStatement::DecStatementType dec_type; + std::string varName; + std::unique_ptr collection; + std::unique_ptr body; + + ForeachLoopNode(std::string var, + std::unique_ptr coll, + std::unique_ptr b) + : varName(var), collection(std::move(coll)), body(std::move(b)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// print +class PrintNode : public ASTNode { public: - DecStatement(Expression* lvalue, Expression* rvalue) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Declaration), Statement(type) { } - DecStatement(Expression* lvalue) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Declaration), Statement(type) { rvalue = new Expression(0); } - DecStatement(Expression* lvalue, Expression* rvalue, DecStatement::DecStatementType dec_type) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Declaration), dec_type(dec_type), Statement(type) { } - - Expression* getLValue() { - return lvalue; - } - - Expression* getRValue() { - return rvalue; - } - - DecStatementType getDecType() { - return dec_type; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::unique_ptr expr; + + PrintNode(std::unique_ptr e) : expr(std::move(e)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class AssignStatement : public Statement { -private: - - Expression* lvalue; - Expression* rvalue; - Statement::StatementType type; - +// araye +class ArrayNode : public ASTNode { public: - AssignStatement(Expression* lvalue, Expression* rvalue) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Assignment), Statement(type) { } - Expression* getLValue() { - return lvalue; - } - - Expression* getRValue() { - return rvalue; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::vector> elements; + + ArrayNode(std::vector> elems) + : elements(std::move(elems)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class IfStatement : public Statement -{ - -private: - Expression *condition; - llvm::SmallVector statements; - llvm::SmallVector elseIfStatements; - ElseStatement *elseStatement; - bool hasElseIf; - bool hasElse; - +// nth khoone Arraye +class ArrayAccessNode : public ASTNode { public: - IfStatement(Expression *condition, - llvm::SmallVector statements, - llvm::SmallVector elseIfStatements, - ElseStatement *elseStatement, - bool hasElseIf, - bool hasElse, - StatementType type) : condition(condition), - statements(statements), - elseIfStatements(elseIfStatements), - elseStatement(elseStatement), - hasElseIf(hasElseIf), - hasElse(hasElse), - Statement(type) {} - - Expression *getCondition() - { - return condition; - } - - bool HasElseIf() - { - return hasElseIf; - } - - bool HasElse() - { - return hasElse; - } - - llvm::SmallVector getElseIfStatements() - { - return elseIfStatements; - } - - llvm::SmallVector getStatements() - { - return statements; - } - - ElseStatement *getElseStatement() - { - return elseStatement; - } - - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::string arrayName; + std::unique_ptr index; + + ArrayAccessNode(std::string name, std::unique_ptr idx) + : arrayName(name), index(std::move(idx)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class ElseIfStatement : public Statement -{ - -private: - Expression *condition; - llvm::SmallVector statements; - +// Concat +class ConcatNode : public ASTNode { public: - ElseIfStatement(Expression *condition, llvm::SmallVector statements, StatementType type) : condition(condition), statements(statements), Statement(type) {} - - Expression *getCondition() - { - return condition; - } - - llvm::SmallVector getStatements() - { - return statements; - } - - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::unique_ptr left; + std::unique_ptr right; + + ConcatNode(std::unique_ptr l, std::unique_ptr r) + : left(std::move(l)), right(std::move(r)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class ElseStatement : public Statement -{ - -private: - llvm::SmallVector statements; - +class PowNode : public ASTNode { public: - ElseStatement(llvm::SmallVector statements, Statement::StatementType type) : statements(statements), Statement(type) {} - - llvm::SmallVector getStatements() - { - return statements; - } - - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::unique_ptr base; + std::unique_ptr exponent; + + PowNode(std::unique_ptr b, std::unique_ptr e) + : base(std::move(b)), exponent(std::move(e)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class WhileStatement : public Statement { - -private: - Expression* condition; - llvm::SmallVector statements; - bool optimized; - +// Error Control +class TryCatchNode : public ASTNode { public: - WhileStatement(Expression* condition, llvm::SmallVector statements, StatementType type) : condition(condition), statements(statements), Statement(type) {} - WhileStatement(Expression* condition, llvm::SmallVector statements, StatementType type, bool optimized) : condition(condition), statements(statements), Statement(type), optimized(optimized) { } - Expression* getCondition() - { - return condition; - } - - llvm::SmallVector getStatements() - { - return statements; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } - - bool isOptimized(){ - return optimized; - } + std::unique_ptr tryBlock; + std::unique_ptr catchBlock; + std::string errorVar; + + TryCatchNode(std::unique_ptr tryBlk, + std::unique_ptr catchBlk, + std::string errVar) + : tryBlock(std::move(tryBlk)), + catchBlock(std::move(catchBlk)), + errorVar(errVar) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class ForStatement : public Statement { - -private: - Expression* condition; - llvm::SmallVector statements; - AssignStatement *initial_assign; - AssignStatement *update_assign; - bool optimized; +// match pattern +class MatchNode : public ASTNode { public: - ForStatement(Expression* condition,llvm::SmallVector statements,AssignStatement *initial_assign,AssignStatement *update_assign, Statement type ) : condition(condition), statements(statements),initial_assign(initial_assign),update_assign(update_assign) , Statement(type){} - ForStatement(Expression* condition,llvm::SmallVector statements,AssignStatement *initial_assign,AssignStatement *update_assign, Statement type, bool optimized) : condition(condition), statements(statements),initial_assign(initial_assign),update_assign(update_assign) , Statement(type), optimized(optimized){} - Expression* getCondition() - { - return condition; - } - - llvm::SmallVector getStatements() - { - return statements; - } - bool isOptimized(){ - return optimized; - } - AssignStatement* getInitialAssign(){ - return initial_assign; - } - AssignStatement* getUpdateAssign(){ - return update_assign; - } - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } - + std::unique_ptr expr; + std::vector> cases; + + MatchNode(std::unique_ptr e, std::vector> c) + : expr(std::move(e)), cases(std::move(c)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; + #endif diff --git a/code/CMakeLists.txt b/code/CMakeLists.txt index ea89570..499de1d 100644 --- a/code/CMakeLists.txt +++ b/code/CMakeLists.txt @@ -1,10 +1,14 @@ -add_executable (compiler - main.cpp - code_generator.cpp - lexer.cpp - parser.cpp - semantic.cpp - error.cpp - optimizer.cpp - ) +find_package(LLVM REQUIRED CONFIG) +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(HandleLLVMOptions) +include(AddLLVM) + +llvm_map_components_to_libnames(llvm_libs + Core + IRReader + Support + AsmPrinter +) + target_link_libraries(compiler PRIVATE ${llvm_libs}) diff --git a/code/code_generator.cpp b/code/code_generator.cpp index f256fad..2d4602d 100644 --- a/code/code_generator.cpp +++ b/code/code_generator.cpp @@ -1,543 +1,566 @@ -#include "code_generator.h" -#include "optimizer.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/Support/raw_ostream.h" +#include "codegen.h" +#include "AST.h" +#include "semantic.h" +#include +#include +#include using namespace llvm; -// Define a visitor class for generating LLVM IR from the AST. -namespace -{ - class ToIRVisitor : public ASTVisitor - { - Module *M; - IRBuilder<> Builder; - Type *VoidTy; - Type *Int32Ty; - Type *Int1Ty; - Type *Int8PtrTy; - Type *Int8PtrPtrTy; - Constant *Int32Zero; - Constant *Int1Zero; - - Value *V; - StringMap nameMap; - - llvm::FunctionType *MainFty; - llvm::Function *MainFn; - FunctionType *CalcWriteFnTy; - FunctionType *CalcWriteFnTyBool; - Function *CalcWriteFn; - Function *CalcWriteFnBool; - bool optimize; - int k; - - - public: - // Constructor for the visitor class - ToIRVisitor(Module *M, bool optimize_enable, int k_value) : M(M), Builder(M->getContext()) - { - // Initialize LLVM types and constants - VoidTy = Type::getVoidTy(M->getContext()); - Int32Ty = Type::getInt32Ty(M->getContext()); - Int1Ty = Type::getInt1Ty(M->getContext()); - Int8PtrTy = Type::getInt8PtrTy(M->getContext()); - Int8PtrPtrTy = Int8PtrTy->getPointerTo(); - Int32Zero = ConstantInt::get(Int32Ty, 0, true); - CalcWriteFnTy = FunctionType::get(VoidTy, {Int32Ty}, false); - CalcWriteFnTyBool = FunctionType::get(VoidTy, {Int1Ty}, false); - CalcWriteFn = Function::Create(CalcWriteFnTy, GlobalValue::ExternalLinkage, "print", M); - CalcWriteFnBool = Function::Create(CalcWriteFnTyBool, GlobalValue::ExternalLinkage, "printBool", M); - optimize = optimize_enable; - k = k_value; - } - - // Entry point for generating LLVM IR from the AST - void run(AST *Tree) - { - // Create the main function with the appropriate function type. - MainFty = FunctionType::get(Int32Ty, {Int32Ty, Int8PtrPtrTy}, false); - MainFn = Function::Create(MainFty, GlobalValue::ExternalLinkage, "main", M); - - // Create a basic block for the entry point of the main function. - BasicBlock *BB = BasicBlock::Create(M->getContext(), "entry", MainFn); - Builder.SetInsertPoint(BB); +CodeGen::CodeGen() : context(std::make_unique()), + module(std::make_unique("main", *context)), + builder(std::make_unique>(*context)) { + + FunctionType* mainType = FunctionType::get(Type::getInt32Ty(*context), false); + mainFunc = Function::Create(mainType, Function::ExternalLinkage, "main", *module); + BasicBlock* entry = BasicBlock::Create(*context, "entry", mainFunc); + builder->SetInsertPoint(entry); + + printfFunc = Function::Create( + FunctionType::get(Type::getInt32Ty(*context), {Type::getInt8PtrTy(*context)}, true), + Function::ExternalLinkage, "printf", module.get() + ); + + Function::Create( + FunctionType::get(Type::getInt8PtrTy(*context), {Type::getInt32Ty(*context)}, false), + Function::ExternalLinkage, "malloc", module.get() + ); + + Function::Create( + FunctionType::get(Type::getVoidTy(*context), + {Type::getInt8PtrTy(*context), Type::getInt8PtrTy(*context), Type::getInt32Ty(*context)}, + false), + Function::ExternalLinkage, "memcpy", module.get() + ); +} - // Visit the root node of the AST to generate IR - Tree->accept(*this); +void CodeGen::generate(ProgramNode& ast) { + for (auto& stmt : ast.statements) { + generateStatement(stmt.get()); + } + + if (!builder->GetInsertBlock()->getTerminator()) { + builder->CreateRet(ConstantInt::get(Type::getInt32Ty(*context), 0)); + } + + std::string error; + raw_string_ostream os(error); + if (verifyModule(*module, &os)) { + throw std::runtime_error("Invalid IR: " + error); + } +} - // Create a return instruction at the end of the main function - Builder.CreateRet(Int32Zero); +void CodeGen::generateStatement(ASTNode* node) { + if (auto multiVarDecl = dynamic_cast(node)) { + for (auto& decl : multiVarDecl->declarations) { + generateVarDecl(decl.get()); } + } else if (auto assign = dynamic_cast(node)) { + generateAssign(assign); + } else if (auto ifElse = dynamic_cast(node)) { + generateIfElse(ifElse); + } else if (auto loop = dynamic_cast(node)) { + generateForLoop(loop); + } else if (auto print = dynamic_cast(node)) { + generatePrint(print); + } else if (auto block = dynamic_cast(node)) { + generateBlock(block); + } else if (auto tryCatch = dynamic_cast(node)) { + generateTryCatch(tryCatch); + } else if (auto match = dynamic_cast(node)) { + generateMatch(match); + } else if (auto unary = dynamic_cast(node)) { + generateUnaryOp(unary); + } else if (auto concat = dynamic_cast(node)) { + generateConcat(concat); + } +} - virtual void visit(Base &Node) override - { - // TODO: find a better way to not implement this again! - for (auto I = Node.begin(), E = Node.end(); I != E; ++I) - { - (*I)->accept(*this); +void CodeGen::generateVarDecl(VarDeclNode* node) { + Type* type = nullptr; + switch(node->type) { + case VarType::INT: type = Type::getInt32Ty(*context); break; + case VarType::FLOAT: type = Type::getFloatTy(*context); break; + case VarType::BOOL: type = Type::getInt1Ty(*context); break; + case VarType::STRING: type = Type::getInt8PtrTy(*context); break; + case VarType::ARRAY: + type = PointerType::get(Type::getInt32Ty(*context), 0); + break; + default: throw std::runtime_error("Unknown type"); + } + + AllocaInst* alloca = builder->CreateAlloca(type, nullptr, node->name); + symbols[node->name] = alloca; + + if (node->value) { + Value* val = generateValue(node->value.get(), type); + builder->CreateStore(val, alloca); + + if (node->type == VarType::ARRAY) { + if (auto arrLit = dynamic_cast(node->value.get())) { + arraySizes[node->name] = arrLit->elements.size(); } } + } +} - virtual void visit(Statement &Node) override - { - // TODO: find a better way to not implement this again! - if (Node.getKind() == Statement::StatementType::Declaration) - { - DecStatement *declaration = (DecStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Assignment) - { - AssignStatement *declaration = (AssignStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::If) - { - IfStatement *declaration = (IfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::ElseIf) - { - ElseIfStatement *declaration = (ElseIfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Else) - { - ElseStatement *declaration = (ElseStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Print) - { - PrintStatement *declaration = (PrintStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::While) - { - WhileStatement *declaration = (WhileStatement *)&Node; - declaration->accept(*this); - } +Value* CodeGen::generateValue(ASTNode* node, Type* expectedType) { + if (auto binOp = dynamic_cast(node)) { + return generateBinaryOp(binOp, expectedType); + } else if (auto varRef = dynamic_cast(node)) { + if (!symbols.count(varRef->name)) { + throw std::runtime_error("Undefined variable: " + varRef->name); } + return builder->CreateLoad(symbols[varRef->name]->getAllocatedType(), symbols[varRef->name]); + } else if (auto intLit = dynamic_cast*>(node)) { + return ConstantInt::get(Type::getInt32Ty(*context), intLit->value); + } else if (auto floatLit = dynamic_cast*>(node)) { + return ConstantFP::get(Type::getFloatTy(*context), floatLit->value); + } else if (auto strLit = dynamic_cast*>(node)) { + return builder->CreateGlobalStringPtr(strLit->value); + } else if (auto arr = dynamic_cast(node)) { + return generateArray(arr, expectedType); + } else if (auto access = dynamic_cast(node)) { + return generateArrayAccess(access); + } + throw std::runtime_error("Unsupported node type"); +} - virtual void visit(PrintStatement &Node) override { - // Visit the right-hand side of the expression and get its value. - Node.getExpr()->accept(*this); - Value *val = V; - - // Determine the type of 'val' and select the appropriate print function - Type *valType = val->getType(); - if (valType == Int32Ty) { - CallInst *Call = Builder.CreateCall(CalcWriteFnTy, CalcWriteFn, {val}); - } else if (valType == Int1Ty) { - CallInst *Call = Builder.CreateCall(CalcWriteFnTyBool, CalcWriteFnBool, {val}); +Value* CodeGen::generateBinaryOp(BinaryOpNode* node, Type* expectedType) { + Value* L = generateValue(node->left.get(), expectedType); + Value* R = generateValue(node->right.get(), expectedType); + + switch(node->op) { + case BinaryOp::ADD: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFAdd(L, R); } else { - // If the type is not supported, print an error message - llvm::errs() << "Unsupported type for print statement\n"; + return builder->CreateAdd(L, R); } - } - - virtual void visit(Expression &Node) override - { - if (Node.getKind() == Expression::ExpressionType::Identifier) - { - AllocaInst *allocaInst = nameMap[Node.getValue()]; - if (!allocaInst) { - llvm::errs() << "Undefined variable '" << Node.getValue() << "'\n"; - return; - } - V = Builder.CreateLoad(allocaInst->getAllocatedType(), allocaInst, Node.getValue()); - } - else if (Node.getKind() == Expression::ExpressionType::Number) - { - int int_value = Node.getNumber(); - V = ConstantInt::get(Int32Ty, int_value, true); - } - else if (Node.getKind() == Expression::ExpressionType::Boolean) - { - bool bool_value = Node.getBoolean(); - V = ConstantInt::getBool(Int1Ty, bool_value); - } - } - - virtual void visit(BooleanOp &Node) override - { - // Visit the left-hand side of the binary operation and get its value - Node.getLeft()->accept(*this); - Value *Left = V; - - // Visit the right-hand side of the binary operation and get its value - Node.getRight()->accept(*this); - Value *Right = V; - - // Perform the boolean operation based on the operator type and create the corresponding instruction - switch (Node.getOperator()) - { - case BooleanOp::Equal: - V = Builder.CreateICmpEQ(Left, Right); - break; - case BooleanOp::NotEqual: - V = Builder.CreateICmpNE(Left, Right); - break; - case BooleanOp::Less: - V = Builder.CreateICmpSLT(Left, Right); - break; - case BooleanOp::LessEqual: - V = Builder.CreateICmpSLE(Left, Right); - break; - case BooleanOp::Greater: - V = Builder.CreateICmpSGT(Left, Right); - break; - case BooleanOp::GreaterEqual: - V = Builder.CreateICmpSGE(Left, Right); - break; - case BooleanOp::And: - V = Builder.CreateAnd(Left, Right); - break; - case BooleanOp::Or: - V = Builder.CreateOr(Left, Right); - break; + case BinaryOp::SUBTRACT: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFSub(L, R); + } else { + return builder->CreateSub(L, R); } - } - - virtual void visit(BinaryOp &Node) override - { - // Visit the left-hand side of the binary operation and get its value - Node.getLeft()->accept(*this); - Value *Left = V; - - // Visit the right-hand side of the binary operation and get its value - Node.getRight()->accept(*this); - Value *Right = V; - - // Perform the binary operation based on the operator type and create the corresponding instruction - switch (Node.getOperator()) - { - case BinaryOp::Plus: - V = Builder.CreateNSWAdd(Left, Right); - break; - case BinaryOp::Minus: - V = Builder.CreateNSWSub(Left, Right); - break; - case BinaryOp::Mul: - V = Builder.CreateNSWMul(Left, Right); - break; - case BinaryOp::Div: - V = Builder.CreateSDiv(Left, Right); - break; - case BinaryOp::Pow:{ - Function *func = Builder.GetInsertBlock()->getParent(); - BasicBlock *preLoopBB = Builder.GetInsertBlock(); - BasicBlock *loopBB = BasicBlock::Create(M->getContext(), "loop", func); - BasicBlock *afterLoopBB = BasicBlock::Create(M->getContext(), "afterloop", func); - - // Setup loop entry - Builder.CreateBr(loopBB); - Builder.SetInsertPoint(loopBB); - - // Create loop variables (PHI nodes) - PHINode *resultPhi = Builder.CreatePHI(Left->getType(), 2, "result"); - resultPhi->addIncoming(Left, preLoopBB); - - PHINode *indexPhi = Builder.CreatePHI(Type::getInt32Ty(M->getContext()), 2, "index"); - indexPhi->addIncoming(ConstantInt::get(Type::getInt32Ty(M->getContext()), 0), preLoopBB); - - // Perform multiplication - Value *updatedResult = Builder.CreateNSWMul(resultPhi, Left, "multemp"); - Value *updatedIndex = Builder.CreateAdd(indexPhi, ConstantInt::get(Type::getInt32Ty(M->getContext()), 1), "indexinc"); - - // Exit condition - Value *condition = Builder.CreateICmpNE(updatedIndex, Right, "loopcond"); - Builder.CreateCondBr(condition, loopBB, afterLoopBB); - - // Update PHI nodes for the next iteration - resultPhi->addIncoming(updatedResult, loopBB); - indexPhi->addIncoming(updatedIndex, loopBB); - - // Complete the loop - Builder.SetInsertPoint(afterLoopBB); - V = resultPhi; // The result of a^b - break; + case BinaryOp::MULTIPLY: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFMul(L, R); + } else { + return builder->CreateMul(L, R); } - case BinaryOp::Mod: - Value *division = Builder.CreateSDiv(Left, Right); - Value *multiplication = Builder.CreateNSWMul(division, Right); - V = Builder.CreateNSWSub(Left, multiplication); + case BinaryOp::DIVIDE: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFDiv(L, R); + } else { + return builder->CreateSDiv(L, R); } - } - - virtual void visit(DecStatement &Node) override - { - Value *val = nullptr; + case BinaryOp::INDEX: + return generateArrayIndex(L, R); + case BinaryOp::CONCAT: + return generateStringConcat(L, R); + default: + throw std::runtime_error("Unsupported binary op"); + } +} - if (Node.getRValue() != nullptr) - { - // If there is an expression provided, visit it and get its value - Node.getRValue()->accept(*this); - val = V; - } +Value* CodeGen::generateArrayIndex(Value* arrayPtr, Value* index) { + Value* gep = builder->CreateGEP( + Type::getInt32Ty(*context), + arrayPtr, + {index} + ); + return builder->CreateLoad(Type::getInt32Ty(*context), gep); +} - // Iterate over the variables declared in the declaration statement - auto I = Node.getLValue()->getValue(); - StringRef Var = I; +Value* CodeGen::generateStringConcat(Value* L, Value* R) { + Function* strlenFunc = module->getFunction("strlen"); + Function* mallocFunc = module->getFunction("malloc"); + Function* memcpyFunc = module->getFunction("memcpy"); + + Value* lenL = builder->CreateCall(strlenFunc, {L}); + Value* lenR = builder->CreateCall(strlenFunc, {R}); + Value* totalLen = builder->CreateAdd(lenL, lenR); + totalLen = builder->CreateAdd(totalLen, ConstantInt::get(Type::getInt32Ty(*context), 1)); + + Value* buffer = builder->CreateCall(mallocFunc, {totalLen}); + builder->CreateCall(memcpyFunc, {buffer, L, lenL}); + Value* dest = builder->CreateGEP(Type::getInt8Ty(*context), buffer, lenL); + builder->CreateCall(memcpyFunc, {dest, R, lenR}); + + return buffer; +} - // Create an alloca instruction to allocate memory for the variable - Type *varType = (Node.getDecType() == DecStatement::DecStatementType::Number) ? Int32Ty : Type::getInt1Ty(M->getContext()); - nameMap[Var] = Builder.CreateAlloca(varType); +void CodeGen::generateIfElse(IfElseNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* thenBB = BasicBlock::Create(*context, "then", func); + BasicBlock* elseBB = BasicBlock::Create(*context, "else"); + BasicBlock* mergeBB = BasicBlock::Create(*context, "ifcont"); + + Value* cond = generateValue(node->condition.get(), Type::getInt1Ty(*context)); + builder->CreateCondBr(cond, thenBB, elseBB); + + builder->SetInsertPoint(thenBB); + generateStatement(node->thenBlock.get()); + builder->CreateBr(mergeBB); + + func->getBasicBlockList().push_back(elseBB); + builder->SetInsertPoint(elseBB); + if (node->elseBlock) { + generateStatement(node->elseBlock.get()); + } + builder->CreateBr(mergeBB); + + func->getBasicBlockList().push_back(mergeBB); + builder->SetInsertPoint(mergeBB); +} - // Store the initial value (if any) in the variable's memory location - if (val != nullptr) - { - Builder.CreateStore(val, nameMap[Var]); - } - else - { - if(Node.getDecType() == DecStatement::DecStatementType::Number){ - Value *Zero = ConstantInt::get(Type::getInt32Ty(M->getContext()), 0); - Builder.CreateStore(Zero, nameMap[Var]); - } else{ - Value *False = ConstantInt::get(Type::getInt1Ty(M->getContext()), 0); - Builder.CreateStore(False, nameMap[Var]); - } - - } - } +void CodeGen::generateForLoop(ForLoopNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* loopStart = BasicBlock::Create(*context, "loop.start", func); + BasicBlock* loopBody = BasicBlock::Create(*context, "loop.body"); + BasicBlock* loopEnd = BasicBlock::Create(*context, "loop.end"); + + if (node->init) generateStatement(node->init.get()); + builder->CreateBr(loopStart); + + builder->SetInsertPoint(loopStart); + Value* cond = node->condition ? + generateValue(node->condition.get(), Type::getInt1Ty(*context)) : + ConstantInt::getTrue(*context); + builder->CreateCondBr(cond, loopBody, loopEnd); + + func->getBasicBlockList().push_back(loopBody); + builder->SetInsertPoint(loopBody); + generateStatement(node->body.get()); + if (node->update) generateStatement(node->update.get()); + builder->CreateBr(loopStart); + + func->getBasicBlockList().push_back(loopEnd); + builder->SetInsertPoint(loopEnd); +} - virtual void visit(AssignStatement &Node) override - { - // Visit the right-hand side of the assignment and get its value - Node.getRValue()->accept(*this); - Value *val = V; +void CodeGen::generatePrint(PrintNode* node) { + Value* value = generateValue(node->expr.get(), nullptr); + Type* ty = value->getType(); + + Constant* format = nullptr; + if (ty->isIntegerTy(32)) { + format = builder->CreateGlobalStringPtr("%d\n"); + } else if (ty->isFloatTy()) { + format = builder->CreateGlobalStringPtr("%f\n"); + } else if (ty->isPointerTy()) { + format = builder->CreateGlobalStringPtr("%s\n"); + } + + builder->CreateCall(printfFunc, {format, value}); +} - // Get the name of the variable being assigned - auto varName = Node.getLValue()->getValue(); +void CodeGen::generateTryCatch(TryCatchNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* tryBB = BasicBlock::Create(*context, "try", func); + BasicBlock* catchBB = BasicBlock::Create(*context, "catch"); + BasicBlock* contBB = BasicBlock::Create(*context, "try.cont"); + + LandingPadInst* lp = builder->CreateLandingPad( + StructType::get(Type::getInt8PtrTy(*context), Type::getInt32Ty(*context)), 0); + lp->addClause(ConstantPointerNull::get(Type::getInt8PtrTy(*context))); + + builder->CreateBr(tryBB); + + builder->SetInsertPoint(tryBB); + generateStatement(node->tryBlock.get()); + builder->CreateBr(contBB); + + func->getBasicBlockList().push_back(catchBB); + builder->SetInsertPoint(catchBB); + if (!node->errorVar.empty()) { + AllocaInst* alloca = builder->CreateAlloca(Type::getInt8PtrTy(*context), nullptr, node->errorVar); + builder->CreateStore(lp->getOperand(0), alloca); + } + generateStatement(node->catchBlock.get()); + builder->CreateBr(contBB); + + func->getBasicBlockList().push_back(contBB); + builder->SetInsertPoint(contBB); +} - // Create a store instruction to assign the value to the variable - Builder.CreateStore(val, nameMap[varName]); +void CodeGen::generateMatch(MatchNode* node) { + Value* expr = generateValue(node->expr.get(), nullptr); + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* endBB = BasicBlock::Create(*context, "match.end"); + + std::vector caseBBs; + for (size_t i = 0; i < node->cases.size(); ++i) { + caseBBs.push_back(BasicBlock::Create(*context, "case." + std::to_string(i), func)); + } + + for (size_t i = 0; i < node->cases.size(); ++i) { + builder->SetInsertPoint(caseBBs[i]); + if (node->cases[i]->value) { + Value* caseVal = generateValue(node->cases[i]->value.get(), expr->getType()); + Value* cmp = builder->CreateICmpEQ(expr, caseVal); + BasicBlock* nextBB = (i < node->cases.size()-1) ? caseBBs[i+1] : endBB; + builder->CreateCondBr(cmp, caseBBs[i], nextBB); + } else { + generateStatement(node->cases[i]->body.get()); + builder->CreateBr(endBB); } + } + + func->getBasicBlockList().push_back(endBB); + builder->SetInsertPoint(endBB); +} - virtual void visit(IfStatement &Node) override - { - llvm::BasicBlock *IfCondBB = llvm::BasicBlock::Create(M->getContext(), "if.cond", MainFn); - llvm::BasicBlock *IfBodyBB = llvm::BasicBlock::Create(M->getContext(), "if.body", MainFn); - llvm::BasicBlock *AfterIfBB = llvm::BasicBlock::Create(M->getContext(), "after.if", MainFn); - - Builder.CreateBr(IfCondBB); - Builder.SetInsertPoint(IfCondBB); - - Node.getCondition()->accept(*this); - Value *Cond = V; - - Builder.SetInsertPoint(IfBodyBB); - - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - Builder.CreateBr(AfterIfBB); - llvm::BasicBlock *BeforeCondBB = IfCondBB; - llvm::BasicBlock *BeforeBodyBB = IfBodyBB; - llvm::Value *BeforeCondVal = Cond; - - if (Node.HasElseIf()) - { - for (auto &elseIf : Node.getElseIfStatements()) - { - llvm::BasicBlock *ElseIfCondBB = llvm::BasicBlock::Create(MainFn->getContext(), "elseIf.cond", MainFn); - llvm::BasicBlock *ElseIfBodyBB = llvm::BasicBlock::Create(MainFn->getContext(), "elseIf.body", MainFn); - - Builder.SetInsertPoint(BeforeCondBB); - Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, ElseIfCondBB); - - Builder.SetInsertPoint(ElseIfCondBB); - elseIf->getCondition()->accept(*this); - llvm::Value *ElseIfCondVal = V; - - Builder.SetInsertPoint(ElseIfBodyBB); - elseIf->accept(*this); - Builder.CreateBr(AfterIfBB); - - BeforeCondBB = ElseIfCondBB; - BeforeCondVal = ElseIfCondVal; - BeforeBodyBB = ElseIfBodyBB; - } - // finishing the last else - // Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, AfterIfBB); - } +// ... ادامه از قسمت قبلی - llvm::BasicBlock *ElseBB = nullptr; - if (Node.HasElse()) - { - ElseStatement *elseS = Node.getElseStatement(); - ElseBB = llvm::BasicBlock::Create(MainFn->getContext(), "else.body", MainFn); - if(Node.HasElseIf()){ - Builder.SetInsertPoint(BeforeCondBB); - Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, ElseBB); - } - Builder.SetInsertPoint(ElseBB); - Node.getElseStatement()->accept(*this); - Builder.CreateBr(AfterIfBB); - } - else - { - Builder.SetInsertPoint(BeforeCondBB); - if(Node.HasElseIf()){ - Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, AfterIfBB); - }else{ - Builder.CreateCondBr(Cond, IfBodyBB, AfterIfBB); - } - - } +Value* CodeGen::generateArray(ArrayNode* node, Type* expectedType) { + ArrayType* arrType = ArrayType::get(Type::getInt32Ty(*context), node->elements.size()); + Value* arrayPtr = builder->CreateAlloca(arrType); + + for (size_t i = 0; i < node->elements.size(); ++i) { + Value* idx[] = { + ConstantInt::get(Type::getInt32Ty(*context), 0), + ConstantInt::get(Type::getInt32Ty(*context), i) + }; + Value* elemPtr = builder->CreateInBoundsGEP(arrType, arrayPtr, idx); + Value* val = generateValue(node->elements[i].get(), Type::getInt32Ty(*context)); + builder->CreateStore(val, elemPtr); + } + + return builder->CreateBitCast(arrayPtr, expectedType); +} - Builder.SetInsertPoint(AfterIfBB); - } +Value* CodeGen::generateArrayAccess(ArrayAccessNode* node) { + Value* arrayPtr = symbols[node->arrayName]; + Value* index = generateValue(node->index.get(), Type::getInt32Ty(*context)); + + // Runtime bounds checking + if (arraySizes.count(node->arrayName)) { + Value* size = ConstantInt::get(Type::getInt32Ty(*context), arraySizes[node->arrayName]); + Value* cond = builder->CreateICmpSGE(index, size); + + BasicBlock* errorBB = BasicBlock::Create(*context, "bounds_error", builder->GetInsertBlock()->getParent()); + BasicBlock* validBB = BasicBlock::Create(*context, "valid_index"); + + builder->CreateCondBr(cond, errorBB, validBB); + + builder->SetInsertPoint(errorBB); + Function* throwFunc = Function::Create( + FunctionType::get(Type::getVoidTy(*context), {Type::getInt8PtrTy(*context)}, false), + Function::ExternalLinkage, "throw_exception", module.get() + ); + Value* msg = builder->CreateGlobalStringPtr("Array index out of bounds!"); + builder->CreateCall(throwFunc, {msg}); + builder->CreateUnreachable(); + + builder->GetInsertBlock()->getParent()->getBasicBlockList().push_back(validBB); + builder->SetInsertPoint(validBB); + } + + return generateArrayIndex(arrayPtr, index); +} - virtual void visit(ElseIfStatement &Node) override - { - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); +Value* CodeGen::generateUnaryOp(UnaryOpNode* node) { + Value* operand = generateValue(node->operand.get(), nullptr); + + switch(node->op) { + case UnaryOp::MINUS: + if (operand->getType()->isFloatTy()) { + return builder->CreateFNeg(operand); + } else { + return builder->CreateNeg(operand); } - } - - virtual void visit(ElseStatement &Node) override - { - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); + case UnaryOp::INCREMENT: { + Value* inc = ConstantInt::get(operand->getType(), 1); + Value* newVal = builder->CreateAdd(operand, inc); + if (auto varRef = dynamic_cast(node->operand.get())) { + builder->CreateStore(newVal, symbols[varRef->name]); } + return operand; // Post-increment returns original value } - - virtual void visit(WhileStatement &Node) override - { - if (optimize && !Node.isOptimized()) { - llvm::SmallVector unrolledStatements = completeUnroll(&Node, k); - for (auto I = unrolledStatements.begin(), E = unrolledStatements.end(); I != E; ++I) - { - (*I)->accept(*this); - } - return; - } - llvm::BasicBlock* WhileCondBB = llvm::BasicBlock::Create(M->getContext(), "while.cond", MainFn); - // The basic block for the while body. - llvm::BasicBlock* WhileBodyBB = llvm::BasicBlock::Create(M->getContext(), "while.body", MainFn); - // The basic block after the while statement. - llvm::BasicBlock* AfterWhileBB = llvm::BasicBlock::Create(M->getContext(), "after.while", MainFn); - - // Branch to the condition block. - Builder.CreateBr(WhileCondBB); - - // Set the insertion point to the condition block. - Builder.SetInsertPoint(WhileCondBB); - - // Visit the condition expression and create the conditional branch. - Node.getCondition()->accept(*this); - Value* Cond = V; - Builder.CreateCondBr(Cond, WhileBodyBB, AfterWhileBB); - - // Set the insertion point to the body block. - Builder.SetInsertPoint(WhileBodyBB); - - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); + case UnaryOp::LENGTH: { + if (!arraySizes.count(node->operand->getName())) { + throw std::runtime_error("Length operator on non-array type"); } - - // Branch back to the condition block. - Builder.CreateBr(WhileCondBB); - // Set the insertion point to the block after the while loop. - Builder.SetInsertPoint(AfterWhileBB); + return ConstantInt::get(Type::getInt32Ty(*context), arraySizes[node->operand->getName()]); } - virtual void visit(ForStatement &Node) override - { - if (optimize && !Node.isOptimized()) { - llvm::SmallVector unrolledStatements = completeUnroll(&Node, k); - for (auto I = unrolledStatements.begin(), E = unrolledStatements.end(); I != E; ++I) - { - (*I)->accept(*this); - } - return; - } - llvm::BasicBlock* ForCondBB = llvm::BasicBlock::Create(M->getContext(), "for.cond", MainFn); - // The basic block for the while body. - llvm::BasicBlock* ForBodyBB = llvm::BasicBlock::Create(M->getContext(), "for.body", MainFn); - // The basic block after the while statement. - llvm::BasicBlock* AfterForBB = llvm::BasicBlock::Create(M->getContext(), "after.for", MainFn); - - llvm::BasicBlock* ForUpdateBB = llvm::BasicBlock::Create(M->getContext(), "for.update", MainFn); - - AssignStatement * initial_assign = Node.getInitialAssign(); - ((Expression *)initial_assign->getRValue())->accept(*this); - Value *val = V; - - // Get the name of the variable being assigned - auto varName = ((Expression *)initial_assign->getLValue())->getValue(); - - // Create a store instruction to assign the value to the variable - Builder.CreateStore(val, nameMap[varName]); - - - // Branch to the condition block. - Builder.CreateBr(ForCondBB); - - // Set the insertion point to the condition block. - Builder.SetInsertPoint(ForCondBB); - - // Visit the condition expression and create the conditional branch. - Node.getCondition()->accept(*this); - Value* Cond = V; - Builder.CreateCondBr(Cond, ForBodyBB, AfterForBB); - - // Set the insertion point to the body block. - Builder.SetInsertPoint(ForBodyBB); - - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } + default: + throw std::runtime_error("Unsupported unary operator"); + } +} - Builder.CreateBr(ForUpdateBB); +void CodeGen::generateCompoundAssign(CompoundAssignNode* node) { + Value* target = symbols[node->target]; + Value* current = builder->CreateLoad(target->getType()->getPointerElementType(), target); + Value* rhs = generateValue(node->value.get(), current->getType()); + + Value* result; + switch(node->op) { + case BinaryOp::ADD: + result = current->getType()->isFloatTy() ? + builder->CreateFAdd(current, rhs) : + builder->CreateAdd(current, rhs); + break; + case BinaryOp::SUBTRACT: + result = current->getType()->isFloatTy() ? + builder->CreateFSub(current, rhs) : + builder->CreateSub(current, rhs); + break; + default: + throw std::runtime_error("Unsupported compound assignment"); + } + + builder->CreateStore(result, target); +} - Builder.SetInsertPoint(ForUpdateBB); +void CodeGen::generateWhileLoop(WhileLoopNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* condBB = BasicBlock::Create(*context, "while.cond", func); + BasicBlock* bodyBB = BasicBlock::Create(*context, "while.body"); + BasicBlock* endBB = BasicBlock::Create(*context, "while.end"); + + builder->CreateBr(condBB); + + // Condition block + builder->SetInsertPoint(condBB); + Value* cond = generateValue(node->condition.get(), Type::getInt1Ty(*context)); + builder->CreateCondBr(cond, bodyBB, endBB); + + // Body block + func->getBasicBlockList().push_back(bodyBB); + builder->SetInsertPoint(bodyBB); + generateStatement(node->body.get()); + builder->CreateBr(condBB); // Loop back + + // End block + func->getBasicBlockList().push_back(endBB); + builder->SetInsertPoint(endBB); +} - AssignStatement * update_assign = Node.getUpdateAssign(); - ((Expression *)update_assign->getRValue())->accept(*this); - val = V; +Value* CodeGen::generateTernaryExpr(TernaryExprNode* node) { + Value* cond = generateValue(node->condition.get(), Type::getInt1Ty(*context)); + Function* func = builder->GetInsertBlock()->getParent(); + + BasicBlock* trueBB = BasicBlock::Create(*context, "ternary.true", func); + BasicBlock* falseBB = BasicBlock::Create(*context, "ternary.false"); + BasicBlock* mergeBB = BasicBlock::Create(*context, "ternary.merge"); + + builder->CreateCondBr(cond, trueBB, falseBB); + + // True branch + builder->SetInsertPoint(trueBB); + Value* trueVal = generateValue(node->trueExpr.get(), nullptr); + builder->CreateBr(mergeBB); + + // False branch + func->getBasicBlockList().push_back(falseBB); + builder->SetInsertPoint(falseBB); + Value* falseVal = generateValue(node->falseExpr.get(), nullptr); + builder->CreateBr(mergeBB); + + // Merge + func->getBasicBlockList().push_back(mergeBB); + builder->SetInsertPoint(mergeBB); + PHINode* phi = builder->CreatePHI(trueVal->getType(), 2); + phi->addIncoming(trueVal, trueBB); + phi->addIncoming(falseVal, falseBB); + + return phi; +} - // Get the name of the variable being assigned - varName = ((Expression *)update_assign->getLValue())->getValue(); +void CodeGen::generateTypeConversion(TypeConversionNode* node) { + Value* val = generateValue(node->expr.get(), nullptr); + Type* targetType = getLLVMType(node->targetType); + + if (val->getType() == targetType) return val; + + if (targetType->isFloatTy() && val->getType()->isIntegerTy()) { + return builder->CreateSIToFP(val, targetType); + } + if (targetType->isIntegerTy() && val->getType()->isFloatTy()) { + return builder->CreateFPToSI(val, targetType); + } + if (targetType->isPointerTy() && val->getType()->isIntegerTy()) { + return builder->CreateIntToPtr(val, targetType); + } + + throw std::runtime_error("Invalid type conversion"); +} - // Create a store instruction to assign the value to the variable - Builder.CreateStore(val, nameMap[varName]); +Value* CodeGen::generateBuiltInCall(BuiltInCallNode* node) { + if (node->funcName == "pow") { + Value* base = generateValue(node->args[0].get(), Type::getDoubleTy(*context)); + Value* exp = generateValue(node->args[1].get(), Type::getDoubleTy(*context)); + return builder->CreateCall( + Intrinsic::getDeclaration(module.get(), Intrinsic::pow, {Type::getDoubleTy(*context)}), + {base, exp} + ); + } + if (node->funcName == "abs") { + Value* val = generateValue(node->args[0].get(), nullptr); + if (val->getType()->isFloatTy()) { + return builder->CreateCall( + Intrinsic::getDeclaration(module.get(), Intrinsic::fabs, {val->getType()}), + {val} + ); + } else { + Value* zero = ConstantInt::get(val->getType(), 0); + Value* neg = builder->CreateNeg(val); + return builder->CreateSelect( + builder->CreateICmpSGT(val, zero), + val, + neg + ); + } + } + throw std::runtime_error("Unknown built-in function"); +} - // Branch back to the condition block. - Builder.CreateBr(ForCondBB); - // Set the insertion point to the block after the while loop. - Builder.SetInsertPoint(AfterForBB); +void CodeGen::generateMemoryManagement() { + // Register destructors for stack-allocated arrays + for (auto& [name, alloca] : symbols) { + if (alloca->getAllocatedType()->isPointerTy()) { + Function* freeFunc = Function::Create( + FunctionType::get(Type::getVoidTy(*context), {Type::getInt8PtrTy(*context)}, false), + Function::ExternalLinkage, "free", module.get() + ); + + BasicBlock* currentBB = builder->GetInsertBlock(); + BasicBlock* freeBB = BasicBlock::Create(*context, "free." + name, currentBB->getParent()); + + builder->SetInsertPoint(freeBB); + Value* casted = builder->CreateBitCast(alloca, Type::getInt8PtrTy(*context)); + builder->CreateCall(freeFunc, {casted}); + builder->CreateBr(currentBB); + + currentBB->getParent()->getBasicBlockList().push_back(freeBB); } - }; + } +} + +void CodeGen::optimizeIR() { + legacy::FunctionPassManager FPM(module.get()); + FPM.add(createPromoteMemoryToRegisterPass()); + FPM.add(createInstructionCombiningPass()); + FPM.add(createReassociatePass()); + FPM.add(createGVNPass()); + FPM.add(createCFGSimplificationPass()); -}; // namespace + FPM.doInitialization(); + for (Function &F : *module) { + FPM.run(F); + } +} -void CodeGen::compile(AST *Tree, bool optimize, int k) -{ - // Create an LLVM context and a module - LLVMContext Ctx; - Module *M = new Module("mas.expr", Ctx); - // Create an instance of the ToIRVisitor and run it on the AST to generate LLVM IR - ToIRVisitor ToIRn(M, optimize, k); - ToIRn.run(Tree); +void CodeGen::dump() const { + module->print(llvm::outs(), nullptr); +} - // Print the generated module to the standard output - M->print(outs(), nullptr); +extern "C" void throw_exception(const char* msg) { + throw std::runtime_error(msg); } diff --git a/code/code_generator.h b/code/code_generator.h index 0012c9f..b645723 100644 --- a/code/code_generator.h +++ b/code/code_generator.h @@ -1,11 +1,42 @@ -#ifndef CODEGEN_H -#define CODEGEN_H +#ifndef CODE_GENERATOR_H +#define CODE_GENERATOR_H +#include +#include +#include +#include #include "AST.h" -class CodeGen -{ +#include +#include +#include +#include +#include +#include + +class CodeGen { public: - void compile(AST *Tree, bool optimize, int k); + void compile(ProgramNode *root, bool optimize, int unroll); + void dump() const; + +private: + std::unique_ptr context; + std::unique_ptr module; + std::unique_ptr> builder; + + llvm::Function* currentFunc = nullptr; + std::unordered_map symbols; + std::unordered_map arraySizes; + + void declareRuntimeFunctions(); + void generateStatement(ASTNode* node); + llvm::Value* generateValue(ASTNode* node, llvm::Type* expectedType); + llvm::Value* generatePow(llvm::Value* base, llvm::Value* exp); + + void printArray(const std::vector& elements); + void printArrayVar(llvm::Value* arrayPtr, uint64_t size); + void generateTryCatch(TryCatchNode* node); + void generateMatch(MatchNode* node); }; + #endif diff --git a/code/lexer.cpp b/code/lexer.cpp index 1ef0975..7b7a8e4 100644 --- a/code/lexer.cpp +++ b/code/lexer.cpp @@ -1,294 +1,194 @@ #include "lexer.h" -#include +#include - -namespace charinfo -{ - LLVM_READNONE inline bool isWhitespace(char c) - { - return c == ' ' || c == '\t' || c == '\f' || - c == '\v' || c == '\r' || c == '\n'; - } - - LLVM_READNONE inline bool isDigit(char c) - { - return c >= '0' && c <= '9'; - } - - LLVM_READNONE inline bool isLetter(char c) - { - return (c >= 'a' && c <= 'z') || - (c >= 'A' && c <= 'Z'); - } - - LLVM_READNONE inline bool isOperator(char c) - { - return c == '+' || c == '-' || c == '*' || - c == '/' || c == '^' || c == '=' || - c == '%' || c == '<' || c == '>' || - c == '!'; - } - - LLVM_READNONE inline bool isSpecialCharacter(char c) - { - return c == ';' || c == ',' || c == '(' || - c == ')' || c == '{' || c == '}' || c == ','; - } +Lexer::Lexer(llvm::StringRef buffer) { + bufferStart = buffer.begin(); + bufferPtr = bufferStart; } -void Lexer::next(Token &token) -{ - // Skips whitespace like " " - while (*BufferPtr && charinfo::isWhitespace(*BufferPtr)) - { - ++BufferPtr; - } - bool is_comment = false; - if (*BufferPtr == '/' && *(BufferPtr + 1) == '*'){ - is_comment = true; - } - while(is_comment){ - if (!*BufferPtr) - { - token.Kind = Token::eof; - return; - } - while (*BufferPtr && *BufferPtr != '*' && *(BufferPtr + 1) != '/') - { - ++BufferPtr; - } - if(*BufferPtr == '*' && *(BufferPtr + 1) == '/'){ - BufferPtr += 2; - is_comment = false; - } - else{ - ++BufferPtr; - } - while (*BufferPtr && charinfo::isWhitespace(*BufferPtr)) - { - ++BufferPtr; - } - } - // since end of context is 0 -> !0 = true -> end of context - if (!*BufferPtr) - { - token.Kind = Token::eof; - return; - } - // looking for keywords or identifiers like "int", a123 , ... - if (charinfo::isLetter(*BufferPtr)) - { - const char *end = BufferPtr + 1; - - // until reaches the end of lexeme - // example: ".int " -> "i.nt " -> "in.t " -> "int. " - while (charinfo::isLetter(*end) || charinfo::isDigit(*end) || *end == '_') - { - ++end; - } - - llvm::StringRef Context(BufferPtr, end - BufferPtr); // start of lexeme, length of lexeme - Token::TokenKind kind; - if (Context == "int") - { - kind = Token::KW_int; - } - else if (Context == "if") - { - kind = Token::KW_if; - } - else if (Context == "bool") - { - kind = Token::KW_bool; - } - else if (Context == "else") - { - kind = Token::KW_else; - } - else if (Context == "while") - { - kind = Token::KW_while; - } - else if (Context == "for") - { - kind = Token::KW_for; - } - else if (Context == "and") - { - kind = Token::KW_and; - } - else if (Context == "or") - { - kind = Token::KW_or; - } - else if (Context == "true") - { - kind = Token::KW_true; - } - else if (Context == "false") - { - kind = Token::KW_false; - } - else if (Context == "print") - { - kind = Token::KW_print; - } - else - { - kind = Token::identifier; - } - - formToken(token, end, kind); - return; - } - - else if (charinfo::isDigit(*BufferPtr)) - { - - const char *end = BufferPtr + 1; - - while (charinfo::isDigit(*end)) - ++end; - - formToken(token, end, Token::number); - return; - } - else if (charinfo::isSpecialCharacter(*BufferPtr)) - { - switch (*BufferPtr) - { - case ';': - formToken(token, BufferPtr + 1, Token::semi_colon); - break; - case ',': - formToken(token, BufferPtr + 1, Token::comma); - break; - case '(': - formToken(token, BufferPtr + 1, Token::l_paren); - break; - case ')': - formToken(token, BufferPtr + 1, Token::r_paren); - break; - case '{': - formToken(token, BufferPtr + 1, Token::l_brace); - break; - case '}': - formToken(token, BufferPtr + 1, Token::r_brace); - break; - } - } - else - { - - if (*BufferPtr == '=' && *(BufferPtr + 1) == '=') - { // == - formToken(token, BufferPtr + 2, Token::equal_equal); - } - else if (*BufferPtr == '+' && *(BufferPtr + 1) == '=') - { // += - formToken(token, BufferPtr + 2, Token::plus_equal); - } - else if (*BufferPtr == '-' && *(BufferPtr + 1) == '=') - { // -= - formToken(token, BufferPtr + 2, Token::minus_equal); - } - - else if (*BufferPtr == '*' && *(BufferPtr + 1) == '=') - { // *= - formToken(token, BufferPtr + 2, Token::star_equal); - } - - else if (*BufferPtr == '/' && *(BufferPtr + 1) == '=') - { // /= - formToken(token, BufferPtr + 2, Token::slash_equal); - } - - else if (*BufferPtr == '%' && *(BufferPtr + 1) == '=') - { // %= - formToken(token, BufferPtr + 2, Token::mod_equal); - } - - else if (*BufferPtr == '!' && *(BufferPtr + 1) == '=') - { // != - formToken(token, BufferPtr + 2, Token::not_equal); - } - - else if (*BufferPtr == '<' && *(BufferPtr + 1) == '=') - { // <= - formToken(token, BufferPtr + 2, Token::less_equal); - } - - else if (*BufferPtr == '>' && *(BufferPtr + 1) == '=') - { - formToken(token, BufferPtr + 2, Token::greater_equal); // >= - } - else if (*BufferPtr == '/' && *(BufferPtr + 1) == '*') - { - formToken(token, BufferPtr + 2, Token::comment); // comment - } - else if (*BufferPtr == '*' && *(BufferPtr + 1) == '/') - { - formToken(token, BufferPtr + 2, Token::uncomment); // uncomment - } - else if (*BufferPtr == '+' && *(BufferPtr + 1) == '+') - { - formToken(token, BufferPtr + 2, Token::plus_plus); // ++ - } - else if (*BufferPtr == '-' && *(BufferPtr + 1) == '-') - { - formToken(token, BufferPtr + 2, Token::minus_minus); // -- - } - else if (*BufferPtr == '=') - { - formToken(token, BufferPtr + 1, Token::equal); // = - } - else if (*BufferPtr == '+') - { - formToken(token, BufferPtr + 1, Token::plus); // + - } - else if (*BufferPtr == '-') - { - formToken(token, BufferPtr + 1, Token::minus); // - - } - else if (*BufferPtr == '*') - { - formToken(token, BufferPtr + 1, Token::star); // * - } - else if (*BufferPtr == '/') - { - formToken(token, BufferPtr + 1, Token::slash); // / - } - else if (*BufferPtr == '^') - { - formToken(token, BufferPtr + 1, Token::power); // ^ - } - else if (*BufferPtr == '>') - { - formToken(token, BufferPtr + 1, Token::greater); // > - } - else if (*BufferPtr == '<') - { - formToken(token, BufferPtr + 1, Token::less); // < - } - else if (*BufferPtr == '!') - { - formToken(token, BufferPtr + 1, Token::KW_not ); // ! - } - else if (*BufferPtr == '%') - { - formToken(token, BufferPtr + 1, Token::mod); // % - } - else - { - formToken(token, BufferPtr + 1, Token::unknown); - } - } +void Lexer::skipWhitespace() { + while (isspace(*bufferPtr)) { + if (*bufferPtr == '\n') bufferPtr++; + else bufferPtr++; + } +} +void Lexer::skipComment() { + if (*bufferPtr == '/' && *(bufferPtr + 1) == '*') { + bufferPtr += 2; + while (!(*bufferPtr == '*' && *(bufferPtr + 1) == '/')) bufferPtr++; + bufferPtr += 2; + } } -void Lexer::formToken(Token &Tok, const char *TokEnd, Token::TokenKind Kind) -{ - Tok.Kind = Kind; - Tok.Text = llvm::StringRef(BufferPtr, TokEnd - BufferPtr); - BufferPtr = TokEnd; -} \ No newline at end of file +Token Lexer::nextToken() { + skipWhitespace(); + skipComment(); + + if (bufferPtr >= bufferStart + buffer.size()) + return {Token::eof, llvm::StringRef(), 0, 0}; + + const char *tokStart = bufferPtr; + + // KWords & Identifiers + if (isalpha(*bufferPtr)) { + while (isalnum(*bufferPtr) || *bufferPtr == '_') bufferPtr++; + llvm::StringRef text(tokStart, bufferPtr - tokStart); + + Token::TokenKind kind = Token::identifier; + if (text == "int") kind = Token::KW_int; + else if (text == "bool") kind = Token::KW_bool; + else if (text == "float") kind = Token::KW_float; + else if (text == "char") kind = Token::KW_char; + else if (text == "string") kind = Token::KW_string; + else if (text == "array") kind = Token::KW_array; + else if (text == "if") kind = Token::KW_if; + else if (text == "else") kind = Token::KW_else; + else if (text == "while") kind = Token::KW_while; + else if (text == "for") kind = Token::KW_for; + else if (text == "foreach") kind = Token::KW_foreach; + else if (text == "in") kind = Token::KW_in; + else if (text == "print") kind = Token::KW_print; + else if (text == "true") kind = Token::KW_true; + else if (text == "false") kind = Token::KW_false; + else if (text == "try") kind = Token::KW_try; + else if (text == "catch") kind = Token::KW_catch; + else if (text == "error") kind = Token::KW_error; + else if (text == "match") kind = Token::KW_match; + else if (text == "concat") kind = Token::KW_concat; + else if (text == "pow") kind = Token::KW_pow; + else if (text == "abs") kind = Token::KW_abs; + else if (text == "length") kind = Token::KW_length; + else if (text == "min") kind = Token::KW_min; + else if (text == "max") kind = Token::KW_max; + else if (text == "and") kind = Token::and_op; + else if (text == "or") kind = Token::or_op; + else if (text == "not") kind = Token::not_op; + else if (text == "index") kind = Token::KW_index; + + return {kind, text, 0, 0}; + } + + // numbers + if (isdigit(*bufferPtr) || *bufferPtr == '.') { + bool hasDot = false; + if (*bufferPtr == '.') { + hasDot = true; + bufferPtr++; + } + while (isdigit(*bufferPtr) || (!hasDot && *bufferPtr == '.')) { + if (*bufferPtr == '.') hasDot = true; + bufferPtr++; + } + return {hasDot ? Token::float_literal : Token::number, + llvm::StringRef(tokStart, bufferPtr - tokStart), 0, 0};} + + // strings + if (*bufferPtr == '"') { + bufferPtr++; + const char *start = bufferPtr; + while (*bufferPtr != '"' && *bufferPtr != '\0') bufferPtr++; + return {Token::string_literal, llvm::StringRef(start, bufferPtr - start), 0, 0}; + } + + // charachters + if (*bufferPtr == '\'') { + bufferPtr++; + char c = *bufferPtr; + bufferPtr += 2; // Skip closing ' + return {Token::char_literal, llvm::StringRef(&c, 1), 0, 0}; + } + + // operators & symbols + switch (*bufferPtr) { + case ';': bufferPtr++; return {Token::semi_colon, ";", 0, 0}; + case ',': bufferPtr++; return {Token::comma, ",", 0, 0}; + case '(': bufferPtr++; return {Token::l_paren, "(", 0, 0}; + case ')': bufferPtr++; return {Token::r_paren, ")", 0, 0}; + case '{': bufferPtr++; return {Token::l_brace, "{", 0, 0}; + case '}': bufferPtr++; return {Token::r_brace, "}", 0, 0}; + case '[': bufferPtr++; return {Token::l_bracket, "[", 0, 0}; + case ']': bufferPtr++; return {Token::r_bracket, "]", 0, 0}; + case '+': + if (*(bufferPtr + 1) == '+') { + bufferPtr += 2; + return {Token::plus_plus, "++", 0, 0}; + } else if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::plus_equal, "+=", 0, 0}; + } + bufferPtr++; + return {Token::plus, "+", 0, 0}; + case '-': + if (*(bufferPtr + 1) == '-') { + bufferPtr += 2; + return {Token::minus_minus, "--", 0, 0}; + } else if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::minus_equal, "-=", 0, 0}; + } else if (*(bufferPtr + 1) == '>') { + bufferPtr += 2; + return {Token::arrow, "->", 0, 0}; + } + bufferPtr++; + return {Token::minus, "-", 0, 0}; + case '*': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::star_equal, "*=", 0, 0}; + } + bufferPtr++; + return {Token::star, "*", 0, 0}; + case '/': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::slash_equal, "/=", 0, 0}; + } + bufferPtr++; + return {Token::slash, "/", 0, 0}; + case '%': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::mod_equal, "%=", 0, 0}; + } + bufferPtr++; + return {Token::mod, "%", 0, 0}; + case '^': + bufferPtr++; + return {Token::caret, "^", 0, 0}; + case '=': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::equal_equal, "==", 0, 0}; + } + bufferPtr++; + return {Token::equal, "=", 0, 0}; + case '!': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::not_equal, "!=", 0, 0}; + } else { + bufferPtr++; + return {Token::not_op, "!", 0, 0}; + } + case '<': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::less_equal, "<=", 0, 0}; + } + bufferPtr++; + return {Token::less, "<", 0, 0}; + case '>': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::greater_equal, ">=", 0, 0}; + } + bufferPtr++; + return {Token::greater, ">", 0, 0}; + case '_': + bufferPtr++; + return {Token::underscore, "_", 0, 0}; + } + + return {Token::eof, llvm::StringRef(), 0, 0}; +} diff --git a/code/lexer.h b/code/lexer.h index aa0fc6b..248ff73 100644 --- a/code/lexer.h +++ b/code/lexer.h @@ -1,102 +1,112 @@ #ifndef LEXER_H #define LEXER_H + #include "llvm/ADT/StringRef.h" -#include "llvm/Support/MemoryBuffer.h" #include - class Lexer; -class Token -{ - friend class Lexer; - +class Token { public: - enum TokenKind : unsigned short - { - semi_colon, // ; - unknown, // unknown token - identifier, // identifier like a, b, c, d, etc. - number, // number like 1, 2, 3, 4, etc. - comma, // , - plus, // + - minus, // - - star, // * - mod, // % - slash, // / - power, // ^ - l_paren, // ( - r_paren, // ) - l_brace, // { - r_brace, // } - plus_equal, // += - plus_plus, // ++ - minus_minus, // -- - minus_equal, // -= - star_equal, // *= - mod_equal, // %= - slash_equal, // /= - equal, // = - equal_equal, // == - not_equal, // != - KW_not, // ! - less, // < - less_equal, // <= - greater, // > - greater_equal, // >= - comment, // /* - uncomment, // */ - KW_int, // int - KW_bool, // bool - KW_if, // if - KW_else, // else - KW_while, // while - KW_for, // for - KW_and, // and - KW_or, // or - KW_true, // true - KW_false, // false - KW_print, // print - eof // end of file - }; + enum TokenKind { -private: - TokenKind Kind; // - llvm::StringRef Text; // + semi_colon, + identifier, + number, + comma, + l_paren, + r_paren, + l_brace, + r_brace, + l_bracket, + r_bracket, + eof, -public: - TokenKind getKind() const { return Kind; } - bool is(TokenKind K) const { return Kind == K; } - llvm::StringRef getText() const { return Text; } + plus, + minus, + star, + slash, + mod, + caret, + equal, + plus_equal, + minus_equal, + star_equal, + slash_equal, + mod_equal, + equal_equal, + not_equal, + less, + less_equal, + greater, + greater_equal, + and_op, + or_op, + not_op + plus_plus, + minus_minus, - // kind="+" isOneOf(plus, minus) -> true - bool isOneOf(TokenKind K1, TokenKind K2) const - { - return is(K1) || is(K2); - } - template // variadic template can take several inputs - bool isOneOf(TokenKind K1, TokenKind K2, Ts... Ks) const - { - return is(K1) || isOneOf(K2, Ks...); + KW_int, + KW_bool, + KW_float, + KW_char, + KW_string, + KW_array, + KW_if, + KW_else, + KW_while, + KW_for, + KW_foreach, + KW_in, + KW_print, + KW_true, + KW_false, + KW_try, + KW_catch, + KW_error, + KW_match, + KW_concat, + KW_pow, + KW_abs, + KW_length, + KW_min, + KW_max, + KW_index, + + comment_start, + comment_end, + string_literal, + char_literal, + float_literal, + underscore, + arrow + }; + + TokenKind kind; + llvm::StringRef text; + int line; + int column; + + bool is(TokenKind k) const { return kind == k; } + bool isOneOf(TokenKind k1, TokenKind k2) const { return is(k1) || is(k2); } + template bool isOneOf(TokenKind k1, TokenKind k2, Ts... ks) const { + return is(k1) || isOneOf(k2, ks...); } }; -class Lexer -{ - const char *BufferStart; - const char *BufferPtr; +class Lexer { + const char *bufferStart; + const char *bufferPtr; public: - Lexer(const llvm::StringRef &Buffer) - { - BufferStart = Buffer.begin(); - BufferPtr = BufferStart; - } - void next(Token &token); + Lexer(llvm::StringRef buffer); + Token nextToken(); private: - void formToken(Token &Result, const char *TokEnd, Token::TokenKind Kind); + void skipWhitespace(); + void skipComment(); + Token formToken(Token::TokenKind kind, const char *tokEnd); }; #endif diff --git a/code/main.cpp b/code/main.cpp index 6461b0f..f84bbbc 100644 --- a/code/main.cpp +++ b/code/main.cpp @@ -56,7 +56,8 @@ int main(int argc, const char **argv) Token nextToken; Lexer lexer(contentRef); Parser Parser(lexer); - AST *Tree = Parser.parse(); + std::unique_ptr TreePtr = Parser.parseProgram(); + ProgramNode *Tree = TreePtr.get(); Semantic semantic; if (semantic.semantic(Tree)) diff --git a/code/parser.cpp b/code/parser.cpp index 9b2165a..9ca427e 100644 --- a/code/parser.cpp +++ b/code/parser.cpp @@ -1,717 +1,414 @@ -#ifndef PARSER_H -#define PARSER_H -#include "error.h" #include "parser.h" -#include "optimizer.h" -#include -using namespace std; -#endif +#include +#include +#include -Base *Parser::parse() -{ - llvm::SmallVector statements; - bool isComment = false; - while (!Tok.is(Token::eof)) - { - if (isComment) - { - if (Tok.is(Token::uncomment)) - { - isComment = false; - } - advance(); - continue; - } +Parser::Parser(Lexer &lexer) : lexer(lexer) { + advance(); + advance(); +} - switch (Tok.getKind()) - { - case Token::KW_int: - { - llvm::SmallVector states = Parser::parseDefine(Token::KW_int); - if (states.size() == 0) - { - return nullptr; - } - while (states.size() > 0) - { - statements.push_back(states.back()); - states.pop_back(); - } - break; - } - case Token::KW_bool: - { - llvm::SmallVector states = Parser::parseDefine(Token::KW_bool); - if (states.size() == 0) - { - return nullptr; - } - while (states.size() > 0) - { - statements.push_back(states.back()); - states.pop_back(); - } - break; - } - case Token::identifier: - { - llvm::StringRef name = Tok.getText(); - Token current = Tok; - advance(); - if (!Tok.isOneOf(Token::plus_plus, Token::minus_minus)) - { - AssignStatement *assign = parseAssign(name); - statements.push_back(assign); - } - else - { - AssignStatement *assign = parseUnaryExpression(current); - statements.push_back(assign); - } - Parser::check_for_semicolon(); - break; - } - case Token::KW_print: - { - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - advance(); - // token should be identifier - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - } - Expression *variable_to_be_printed = new Expression(Tok.getText()); - advance(); - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - advance(); - Parser::check_for_semicolon(); - PrintStatement *print_statement = new PrintStatement(variable_to_be_printed); - statements.push_back(print_statement); - break; - } - case Token::comment: - { - isComment = true; - advance(); - break; - } - case Token::KW_if: - { - IfStatement *statement = parseIf(); - statements.push_back(statement); - break; - } - case Token::KW_while: - { - WhileStatement *statement = parseWhile(); - statements.push_back(statement); - break; - } - case Token::KW_for: - { - ForStatement *statement = parseFor(); - /*llvm::SmallVector unrolled = completeUnroll(statement); - for (Statement *s : unrolled) - { - statements.push_back(s); - }*/ - statements.push_back(statement); - break; - } - } - } - return new Base(statements); +void Parser::advance() { + currentTok = peekTok; + peekTok = lexer.nextToken(); } -void Parser::check_for_semicolon() -{ - if (!Tok.is(Token::semi_colon)) - { - Error::SemiColonExpected(); +void Parser::expect(Token::TokenKind kind) { + if (!currentTok.is(kind)) { + throw std::runtime_error("Syntax Error: Expected " + + tokenToString(kind) + + " but found " + + tokenToString(currentTok.kind) + + " at line " + std::to_string(currentTok.line)); } +} + +void Parser::consume(Token::TokenKind kind) { + expect(kind); advance(); } -AssignStatement *Parser::parseUnaryExpression(Token &token) -{ - AssignStatement* res; - if (Tok.is(Token::plus_plus)) - { - advance(); - if (token.is(Token::identifier)) - { - Expression *tok = new Expression(token.getText()); - Expression *one = new Expression(1); - res = new AssignStatement(tok, new BinaryOp(BinaryOp::Plus, tok, one)); - } - else - { - Error::VariableExpected(); - } +std::unique_ptr Parser::parseProgram() { + auto program = std::make_unique(); + while (!currentTok.is(Token::eof)) { + program->statements.push_back(parseStatement()); } - else if (Tok.is(Token::minus_minus)) - { - advance(); - if (token.is(Token::identifier)) - { - Expression *tok = new Expression(token.getText()); - Expression *one = new Expression(1); - res = new AssignStatement(tok, new BinaryOp(BinaryOp::Minus, tok, one)); - } - else - { - Error::VariableExpected(); - } + return program; +} + +std::unique_ptr Parser::parseStatement() { + switch (currentTok.kind) { + case Token::KW_int: + case Token::KW_bool: + case Token::KW_float: + case Token::KW_char: + case Token::KW_string: + case Token::KW_array: + return parseVarDecl(); + + case Token::KW_if: + return parseIfStatement(); + + case Token::KW_for: + return parseForLoop(); + + case Token::KW_while: + return parseWhileLoop(); + + case Token::KW_print: + return parsePrintStatement(); + + case Token::l_brace: + return parseBlock(); + + default: + return parseExpressionStatement(); } - return res; } -llvm::SmallVector Parser::parseDefine(Token::TokenKind token_kind) -{ + +// Variable Declaration +std::unique_ptr Parser::parseVarDecl() { + VarType type = tokenToVarType(currentTok.kind); advance(); - llvm::SmallVector states; - while (!Tok.is(Token::semi_colon)) - { - llvm::StringRef name; - Expression *value = nullptr; - if (Tok.is(Token::identifier)) - { - name = Tok.getText(); - advance(); - } - else - { - Error::VariableExpected(); - } - if (Tok.is(Token::equal)) - { + + std::vector declarations; + do { + std::string name = currentTok.text.str(); + consume(Token::identifier); + + std::unique_ptr value = nullptr; + if (currentTok.is(Token::equal)) { advance(); value = parseExpression(); } - if (Tok.is(Token::comma)) - { - advance(); - } - else if (Tok.is(Token::semi_colon)) - { - // OK do nothing. The while will break itself. - } - else - { - Error::VariableExpected(); - } - DecStatement *state; - if(token_kind == Token::KW_int){ - state = new DecStatement(new Expression(name), value, DecStatement::DecStatementType::Number); - }else{ - state = new DecStatement(new Expression(name), value, DecStatement::DecStatementType::Boolean); - } - states.push_back(state); - } - - advance(); // pass semicolon - return states; + declarations.emplace_back(type, name, std::move(value)); + + } while (currentTok.is(Token::comma) && advance()); + + consume(Token::semi_colon); + return std::make_unique(std::move(declarations)); } -Expression *Parser::parseExpression() -{ - Expression *left = parseLogicalComparison(); - while (Tok.isOneOf(Token::KW_and, Token::KW_or)) - { - BooleanOp::Operator Op; - switch (Tok.getKind()) - { - case Token::KW_and: - Op = BooleanOp::And; - break; - case Token::KW_or: - Op = BooleanOp::Or; - break; - default: - break; - } +// If Statement +std::unique_ptr Parser::parseIfStatement() { + consume(Token::KW_if); + consume(Token::l_paren); + auto condition = parseExpression(); + consume(Token::r_paren); + + auto thenBlock = parseBlock(); + + std::unique_ptr elseBlock = nullptr; + if (currentTok.is(Token::KW_else)) { advance(); - Expression *Right = parseLogicalComparison(); - left = new BooleanOp(Op, left, Right); + elseBlock = currentTok.is(Token::KW_if) ? parseIfStatement() : parseBlock(); } - return left; + + return std::make_unique( + std::move(condition), + std::move(thenBlock), + std::move(elseBlock) + ); } -Expression *Parser::parseLogicalComparison() -{ - Expression *left = parseIntExpression(); - while (Tok.isOneOf(Token::equal_equal, Token::not_equal, Token::less, Token::less_equal, Token::greater, Token::greater_equal)) - { - BooleanOp::Operator Op; - switch (Tok.getKind()) - { - case Token::equal_equal: - Op = BooleanOp::Equal; - break; - case Token::not_equal: - Op = BooleanOp::NotEqual; - break; - case Token::less: - Op = BooleanOp::Less; - break; - case Token::less_equal: - Op = BooleanOp::LessEqual; - break; - case Token::greater: - Op = BooleanOp::Greater; - break; - case Token::greater_equal: - Op = BooleanOp::GreaterEqual; - break; - default: - break; - } - advance(); - Expression *Right = parseIntExpression(); - left = new BooleanOp(Op, left, Right); +// For Loop +std::unique_ptr Parser::parseForLoop() { + consume(Token::KW_for); + consume(Token::l_paren); + + // Initialization + std::unique_ptr init = nullptr; + if (!currentTok.is(Token::semi_colon)) { + init = parseVarDeclOrExpr(); } - return left; + consume(Token::semi_colon); + + // Condition + std::unique_ptr cond = nullptr; + if (!currentTok.is(Token::semi_colon)) { + cond = parseExpression(); + } + consume(Token::semi_colon); + + // Update + std::unique_ptr update = nullptr; + if (!currentTok.is(Token::r_paren)) { + update = parseExpression(); + } + consume(Token::r_paren); + + auto body = parseBlock(); + + return std::make_unique( + std::move(init), + std::move(cond), + std::move(update), + std::move(body) + ); } -Expression *Parser::parseIntExpression() -{ - Expression *Left = parseTerm(); - while (Tok.isOneOf(Token::plus, Token::minus)) - { - BinaryOp::Operator Op = - Tok.is(Token::plus) ? BinaryOp::Plus : BinaryOp::Minus; - advance(); - Expression *Right = parseTerm(); - Left = new BinaryOp(Op, Left, Right); - } - return Left; +// While Loop +std::unique_ptr Parser::parseWhileLoop() { + consume(Token::KW_while); + consume(Token::l_paren); + auto condition = parseExpression(); + consume(Token::r_paren); + + auto body = parseBlock(); + + return std::make_unique( + std::move(condition), + std::move(body) + ); } -Expression *Parser::parseTerm() -{ - Expression *Left = parseSign(); - while (Tok.isOneOf(Token::star, Token::slash, Token::mod)) - { - BinaryOp::Operator Op = - Tok.is(Token::star) ? BinaryOp::Mul : Tok.is(Token::slash) ? BinaryOp::Div - : BinaryOp::Mod; - advance(); - Expression *Right = parseSign(); - Left = new BinaryOp(Op, Left, Right); - } - return Left; +// Print Statement +std::unique_ptr Parser::parsePrintStatement() { + consume(Token::KW_print); + consume(Token::l_paren); + auto expr = parseExpression(); + consume(Token::r_paren); + consume(Token::semi_colon); + return std::make_unique(std::move(expr)); } -Expression *Parser::parseSign() -{ - if (Tok.is(Token::minus)) - { - advance(); - return new BinaryOp(BinaryOp::Mul, new Expression(-1), parsePower()); +// Block Statement +std::unique_ptr Parser::parseBlock() { + consume(Token::l_brace); + auto block = std::make_unique(); + + while (!currentTok.is(Token::r_brace) && !currentTok.is(Token::eof)) { + block->statements.push_back(parseStatement()); } - else if (Tok.is(Token::plus)) - { + + consume(Token::r_brace); + return block; +} + +// Expression Parsing +std::unique_ptr Parser::parseExpression() { + return parseAssignment(); +} + +std::unique_ptr Parser::parseAssignment() { + auto left = parseLogicalOr(); + + if (currentTok.isOneOf(Token::equal, Token::plus_equal, Token::minus_equal, + Token::star_equal, Token::slash_equal, Token::mod_equal)) { + Token op = currentTok; advance(); - return parsePower(); + auto right = parseAssignment(); + return std::make_unique(std::move(left), op, std::move(right)); } - return parsePower(); + + return left; } -Expression *Parser::parsePower() -{ - Expression *Left = parseFactor(); - while (Tok.is(Token::power)) - { - BinaryOp::Operator Op = - BinaryOp::Pow; +std::unique_ptr Parser::parseLogicalOr() { + auto left = parseLogicalAnd(); + while (currentTok.is(Token::or_op)) { advance(); - Expression *Right = parseFactor(); - Left = new BinaryOp(Op, Left, Right); + auto right = parseLogicalAnd(); + left = std::make_unique(BinaryOp::LogicalOr, std::move(left), std::move(right)); } - return Left; + return left; } -Expression *Parser::parseFactor() -{ - Expression *Res = nullptr; - switch (Tok.getKind()) - { - case Token::number: - { - int number; - Tok.getText().getAsInteger(10, number); - Res = new Expression(number); +std::unique_ptr Parser::parseLogicalAnd() { + auto left = parseEquality(); + while (currentTok.is(Token::and_op)) { advance(); - break; + auto right = parseEquality(); + left = std::make_unique(BinaryOp::LogicalAnd, std::move(left), std::move(right)); } - case Token::identifier: - { - Res = new Expression(Tok.getText()); + return left; +} + +std::unique_ptr Parser::parseEquality() { + auto left = parseComparison(); + while (currentTok.isOneOf(Token::equal_equal, Token::not_equal)) { + BinaryOp op = currentTok.is(Token::equal_equal) ? BinaryOp::Equal : BinaryOp::NotEqual; advance(); - break; + auto right = parseComparison(); + left = std::make_unique(op, std::move(left), std::move(right)); } - case Token::l_paren: - { + return left; +} + +std::unique_ptr Parser::parseComparison() { + auto left = parseAddSub(); + while (currentTok.isOneOf(Token::less, Token::less_equal, Token::greater, Token::greater_equal)) { + BinaryOp op; + switch (currentTok.kind) { + case Token::less: op = BinaryOp::Less; break; + case Token::less_equal: op = BinaryOp::LessEqual; break; + case Token::greater: op = BinaryOp::Greater; break; + case Token::greater_equal: op = BinaryOp::GreaterEqual; break; + default: throw std::runtime_error("Invalid comparison operator"); + } advance(); - Res = parseExpression(); - if (!consume(Token::r_paren)) - break; + auto right = parseAddSub(); + left = std::make_unique(op, std::move(left), std::move(right)); } - case Token::KW_true: - { - Res = new Expression(true); + return left; +} + +std::unique_ptr Parser::parseAddSub() { + auto left = parseMulDiv(); + while (currentTok.isOneOf(Token::plus, Token::minus)) { + BinaryOp op = currentTok.is(Token::plus) ? BinaryOp::Add : BinaryOp::Subtract; advance(); - break; + auto right = parseMulDiv(); + left = std::make_unique(op, std::move(left), std::move(right)); } - case Token::KW_false: - { - Res = new Expression(false); + return left; +} + +std::unique_ptr Parser::parseMulDiv() { + auto left = parseUnary(); + while (currentTok.isOneOf(Token::star, Token::slash, Token::mod)) { + BinaryOp op; + switch (currentTok.kind) { + case Token::star: op = BinaryOp::Multiply; break; + case Token::slash: op = BinaryOp::Divide; break; + case Token::mod: op = BinaryOp::Modulus; break; + default: throw std::runtime_error("Invalid multiplicative operator"); + } advance(); - break; + auto right = parseUnary(); + left = std::make_unique(op, std::move(left), std::move(right)); } - default: // error handling - { - Error::NumberVariableExpected(); - } - } - return Res; + return left; } -AssignStatement *Parser::parseAssign(llvm::StringRef name) -{ - Expression *value = nullptr; - if (Tok.is(Token::equal)) - { - advance(); - value = parseExpression(); - }else if(Tok.isOneOf(Token::plus_equal, Token::minus_equal, Token::star_equal, Token::slash_equal, Token::mod_equal)){ - // storing token - Token current_op = Tok; - advance(); - value = parseExpression(); - if(current_op.is(Token::plus_equal)){ - value = new BinaryOp(BinaryOp::Plus, new Expression(name), value); - }else if(current_op.is(Token::minus_equal)){ - value = new BinaryOp(BinaryOp::Minus, new Expression(name), value); - }else if(current_op.is(Token::star_equal)){ - value = new BinaryOp(BinaryOp::Mul, new Expression(name), value); - }else if(current_op.is(Token::slash_equal)){ - value = new BinaryOp(BinaryOp::Div, new Expression(name), value); - }else if(current_op.is(Token::mod_equal)){ - value = new BinaryOp(BinaryOp::Mod, new Expression(name), value); +std::unique_ptr Parser::parseUnary() { + if (currentTok.isOneOf(Token::plus, Token::minus, Token::not_op)) { + UnaryOp op; + switch (currentTok.kind) { + case Token::plus: op = UnaryOp::Plus; break; + case Token::minus: op = UnaryOp::Minus; break; + case Token::not_op: op = UnaryOp::LogicalNot; break; + default: throw std::runtime_error("Invalid unary operator"); } - }else{ - Error::EqualExpected(); + advance(); + auto operand = parseUnary(); + return std::make_unique(op, std::move(operand)); } - - return new AssignStatement(new Expression(name), value); + return parsePrimary(); } -Base *Parser::parseStatement() -{ - llvm::SmallVector statements; - bool isComment = false; - while (!Tok.is(Token::r_brace) && !Tok.is(Token::eof)) - { - if (isComment) - { - if (Tok.is(Token::uncomment)) - { - isComment = false; - } +std::unique_ptr Parser::parsePrimary() { + switch (currentTok.kind) { + case Token::number: + case Token::float_literal: { + auto value = currentTok.text.str(); advance(); - continue; + return std::make_unique(value); } - - switch (Tok.getKind()) - { - case Token::identifier: - { - llvm::StringRef name = Tok.getText(); - Token current = Tok; + + case Token::string_literal: { + auto value = currentTok.text.str(); advance(); - if (!Tok.isOneOf(Token::plus_plus, Token::minus_minus)) - { - AssignStatement *assign = parseAssign(name); - statements.push_back(assign); - } - else - { - AssignStatement *assign = parseUnaryExpression(current); - statements.push_back(assign); - } - Parser::check_for_semicolon(); - break; + return std::make_unique(value); } - case Token::KW_print: - { + + case Token::KW_true: + case Token::KW_false: { + bool value = currentTok.is(Token::KW_true); advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } + return std::make_unique(value); + } + + case Token::identifier: { + std::string name = currentTok.text.str(); advance(); - // token should be identifier - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); + + if (currentTok.is(Token::plus_plus) || currentTok.is(Token::minus_minus)) { + Token opTok = currentTok; + advance(); // consume ++ or -- + return std::make_unique( + opTok.is(Token::plus_plus) + ? UnaryOp::INCREMENT + : UnaryOp::DECREMENT, + std::make_unique(name) + ); } - Expression *tok = new Expression(Tok.getText()); - advance(); - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); + + // Function call or array access + if (currentTok.is(Token::l_paren)) { + return parseFunctionCall(name); + } else if (currentTok.is(Token::l_bracket)) { + return parseArrayAccess(name); } - advance(); - Parser::check_for_semicolon(); - PrintStatement *print_statement = new PrintStatement(tok); - statements.push_back(print_statement); - break; + + return std::make_unique(name); } - case Token::comment: - { - isComment = true; + + case Token::l_paren: { advance(); - break; - } - case Token::KW_if: - { - IfStatement *statement = parseIf(); - statements.push_back(statement); - break; - } - case Token::KW_while: - { - WhileStatement* statement = parseWhile(); - statements.push_back(statement); - break; - } - case Token::KW_for: - { - ForStatement* statement = parseFor(); - statements.push_back(statement); - break; + auto expr = parseExpression(); + consume(Token::r_paren); + return expr; } + + case Token::l_bracket: + return parseArrayLiteral(); + default: - { - Error::UnexpectedToken(Tok); - } - } + throw std::runtime_error("Unexpected primary expression"); } - return new Base(statements); } -IfStatement *Parser::parseIf() -{ - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - - advance(); - Expression *condition = parseExpression(); - - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (!Tok.is(Token::l_brace)) - { - Error::LeftBraceExpected(); - } - advance(); - Base *allIfStatements = parseStatement(); - if (!Tok.is(Token::r_brace)) - { - Error::RightBraceExpected(); - } - advance(); - - // parse else if and else statements - llvm::SmallVector elseIfStatements; - ElseStatement *elseStatement = nullptr; - bool hasElseIf = false; - bool hasElse = false; - - while (Tok.is(Token::KW_else)) - { - advance(); - if (Tok.is(Token::KW_if)) - { - elseIfStatements.push_back(parseElseIf()); - hasElseIf = true; - } - else if (Tok.is(Token::l_brace)) - { - advance(); - Base *allElseStatements = parseStatement(); - if (!Tok.is(Token::r_brace)) - { - Error::RightBraceExpected(); - } - advance(); - elseStatement = new ElseStatement(allElseStatements->getStatements(), Statement::StatementType::Else); - hasElse = true; - break; - } - else - { - Error::RightBraceExpected(); - } +std::unique_ptr Parser::parseArrayLiteral() { + consume(Token::l_bracket); + std::vector> elements; + + if (!currentTok.is(Token::r_bracket)) { + do { + elements.push_back(parseExpression()); + } while (currentTok.is(Token::comma) && advance()); } - - return new IfStatement(condition, allIfStatements->getStatements(), elseIfStatements, elseStatement, hasElseIf, hasElse, Statement::StatementType::If); + + consume(Token::r_bracket); + return std::make_unique(std::move(elements)); } -ElseIfStatement *Parser::parseElseIf() -{ - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - - advance(); - Expression *condition = parseExpression(); - - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (!Tok.is(Token::l_brace)) - { - Error::LeftBraceExpected(); - } - advance(); - Base *allIfStatements = parseStatement(); - if (!Tok.is(Token::r_brace)) - { - Error::RightBraceExpected(); - } - advance(); - - return new ElseIfStatement(condition, allIfStatements->getStatements(), Statement::StatementType::ElseIf); +std::unique_ptr Parser::parseArrayAccess(const std::string& name) { + consume(Token::l_bracket); + auto index = parseExpression(); + consume(Token::r_bracket); + return std::make_unique(name, std::move(index)); } -WhileStatement *Parser::parseWhile() -{ - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - - advance(); - Expression *condition = parseExpression(); - - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (Tok.is(Token::l_brace)) - { - advance(); - Base *allWhileStatements = parseStatement(); - if(!consume(Token::r_brace)) - { - return new WhileStatement(condition, allWhileStatements->getStatements(), Statement::StatementType::While); - } - else - { - Error::RightBraceExpected(); - } - - } - else{ - Error::LeftBraceExpected(); +std::unique_ptr Parser::parseFunctionCall(const std::string& name) { + consume(Token::l_paren); + std::vector> args; + + if (!currentTok.is(Token::r_paren)) { + do { + args.push_back(parseExpression()); + } while (currentTok.is(Token::comma) && advance()); } - advance(); + + consume(Token::r_paren); + return std::make_unique(name, std::move(args)); } -ForStatement *Parser::parseFor() -{ - advance(); - if (!Tok.is(Token::l_paren)) // for(i = 0;i<10;i++) - { - Error::LeftParenthesisExpected(); - } - advance(); //i = 0;i<10;i++) - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - - } - llvm::StringRef name = Tok.getText(); - advance(); //= 0;i<10;i++) - AssignStatement *assign = parseAssign(name); - //Expression *value = nullptr; - /*if (Tok.is(Token::equal)) - { - advance(); //0;i<10;i++) - value = parseExpression(); - AssignStatement *assign = new AssignStatement(new Expression(name), value); - }*/ - //advance(); - check_for_semicolon(); - Expression *condition = parseExpression(); - //advance(); - check_for_semicolon(); - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - - } - llvm::StringRef name_up = Tok.getText(); - Token current = Tok; - AssignStatement *assign_up = nullptr; - advance(); - if (!Tok.isOneOf(Token::plus_plus, Token::minus_minus)) - { - assign_up = parseAssign(name_up); +// Helper Functions +VarType Parser::tokenToVarType(Token::TokenKind kind) { + switch (kind) { + case Token::KW_int: return VarType::Int; + case Token::KW_bool: return VarType::Bool; + case Token::KW_float: return VarType::Float; + case Token::KW_char: return VarType::Char; + case Token::KW_string: return VarType::String; + case Token::KW_array: return VarType::Array; + default: throw std::runtime_error("Invalid variable type"); } - else - { - assign_up = parseUnaryExpression(current); - } - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (Tok.is(Token::l_brace)) - { - advance(); - Base *allForStatements = parseStatement(); - if(!consume(Token::r_brace)) - { - return new ForStatement(condition, allForStatements->getStatements(),assign,assign_up, Statement::StatementType::For); - } - else - { - Error::RightBraceExpected(); - } - - } - else{ - Error::LeftBraceExpected(); - } - advance(); - - } +std::string Parser::tokenToString(Token::TokenKind kind) { + // Implementation depends on your token names + return "Token"; +} diff --git a/code/parser.h b/code/parser.h index 58f4251..baeb87d 100644 --- a/code/parser.h +++ b/code/parser.h @@ -1,75 +1,125 @@ -#ifndef _PARSER_H -#define _PARSER_H -#include "AST.h" +// #ifndef _PARSER_H +// #define _PARSER_H +// #include "AST.h" +// #include "lexer.h" +// #include "llvm/Support/raw_ostream.h" +// #include +// using namespace std; + +// class Parser +// { +// Lexer &Lex; +// Token Tok; +// bool HasError; + +// void error() +// { +// llvm::errs() << "Unexpected: " << Tok.getText() << "\n"; +// HasError = true; +// } + +// void advance() { Lex.next(Tok); } + +// bool expect(Token::TokenKind Kind) +// { +// if (Tok.getKind() != Kind) +// { +// error(); +// return true; +// } +// return false; +// } + +// bool consume(Token::TokenKind Kind) +// { +// if (expect(Kind)) +// return true; +// advance(); +// return false; +// } + +// // parsing functions according to the regex +// // pattern specified. each one produces its own node +// // one node can have multiple subnodes inside it +// public: +// Base *parse(); +// Base *parseStatement(); +// IfStatement *parseIf(); +// ElseIfStatement *parseElseIf(); +// AssignStatement *parseUnaryExpression(Token &token); +// Expression *parseExpression(); +// Expression *parseLogicalComparison(); +// Expression *parseIntExpression(); +// Expression *parseTerm(); +// Expression *parseSign(); +// Expression *parsePower(); +// Expression *parseFactor(); +// ForStatement *parseFor(); +// WhileStatement *parseWhile(); +// AssignStatement *parseAssign(llvm::StringRef name); +// llvm::SmallVector parseDefine(Token::TokenKind token_kind); +// void check_for_semicolon(); + +// public: +// // initializes all members and retrieves the first token +// Parser(Lexer &Lex) : Lex(Lex), HasError(false) +// { +// advance(); +// } + +// // get the value of error flag +// bool hasError() { return HasError; } + +// }; + +// #endif + + + +#ifndef PARSER_H +#define PARSER_H + #include "lexer.h" -#include "llvm/Support/raw_ostream.h" -#include -using namespace std; - -class Parser -{ - Lexer &Lex; - Token Tok; - bool HasError; - - void error() - { - llvm::errs() << "Unexpected: " << Tok.getText() << "\n"; - HasError = true; - } - - void advance() { Lex.next(Tok); } - - bool expect(Token::TokenKind Kind) - { - if (Tok.getKind() != Kind) - { - error(); - return true; - } - return false; - } - - bool consume(Token::TokenKind Kind) - { - if (expect(Kind)) - return true; - advance(); - return false; - } - - // parsing functions according to the regex - // pattern specified. each one produces its own node - // one node can have multiple subnodes inside it -public: - Base *parse(); - Base *parseStatement(); - IfStatement *parseIf(); - ElseIfStatement *parseElseIf(); - AssignStatement *parseUnaryExpression(Token &token); - Expression *parseExpression(); - Expression *parseLogicalComparison(); - Expression *parseIntExpression(); - Expression *parseTerm(); - Expression *parseSign(); - Expression *parsePower(); - Expression *parseFactor(); - ForStatement *parseFor(); - WhileStatement *parseWhile(); - AssignStatement *parseAssign(llvm::StringRef name); - llvm::SmallVector parseDefine(Token::TokenKind token_kind); - void check_for_semicolon(); +#include "ast.h" +#include +#include -public: - // initializes all members and retrieves the first token - Parser(Lexer &Lex) : Lex(Lex), HasError(false) - { - advance(); - } +class Parser { + Lexer &lexer; + Token currentTok; + Token peekTok; - // get the value of error flag - bool hasError() { return HasError; } + void advance(); + void expect(Token::TokenKind kind); + void consume(Token::TokenKind kind); + +public: + Parser(Lexer &lexer); + std::unique_ptr parseProgram(); + // Statement Parsers + std::unique_ptr parseStatement(); + std::unique_ptr parseVarDecl(); + std::unique_ptr parseIfStatement(); + std::unique_ptr parseForLoop(); + std::unique_ptr parseWhileLoop(); + std::unique_ptr parsePrintStatement(); + std::unique_ptr parseBlock(); + + // Expression Parsers + std::unique_ptr parseExpression(); + std::unique_ptr parseLogicalOr(); + std::unique_ptr parseLogicalAnd(); + std::unique_ptr parseEquality(); + std::unique_ptr parseComparison(); + std::unique_ptr parseTerm(); + std::unique_ptr parseFactor(); + std::unique_ptr parseUnary(); + std::unique_ptr parsePrimary(); + + // Special + std::unique_ptr parseArrayLiteral(); + std::unique_ptr parseFunctionCall(); }; -#endif \ No newline at end of file +#endif diff --git a/code/semantic.cpp b/code/semantic.cpp index dff288e..7746348 100644 --- a/code/semantic.cpp +++ b/code/semantic.cpp @@ -1,312 +1,262 @@ #include "semantic.h" +#include "AST.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/raw_ostream.h" -namespace -{ - class DeclCheck : public ASTVisitor - { - llvm::StringMap variableTypeMap; - bool HasError; - enum ErrorType - { - AlreadyDefinedVariable, - NotDefinedVariable, - DivideByZero, - WrongValueTypeForVariable - }; - - void error(ErrorType errorType, llvm::StringRef V) - { - switch (errorType) - { - case ErrorType::DivideByZero: - llvm::errs() << "Division by zero is not allowed!\n"; - break; - case ErrorType::AlreadyDefinedVariable: - llvm::errs() << "Variable " << V << " is already declared!\n"; - break; - case ErrorType::NotDefinedVariable: - llvm::errs() << "Variable " << V << " is not declared!\n"; - break; - case ErrorType::WrongValueTypeForVariable: - llvm::errs() << "Illegal value for type " << V << "!\n"; - break; - default: - llvm::errs() << "Unknown error\n"; - break; - } - HasError = true; - exit(3); +namespace { + static const char *varTypeName(VarType t) { + switch (t) { + case VarType::INT: return "int"; + case VarType::FLOAT: return "float"; + case VarType::BOOL: return "bool"; + case VarType::CHAR: return "char"; + case VarType::STRING: return "string"; + case VarType::ARRAY: return "array"; + default: return ""; } + } - public: - DeclCheck() : HasError(false) {} - - bool hasError() { return HasError; } + class DeclCheck : public ASTVisitor { + llvm::StringMap VarTypes; + bool HasError = false; - virtual void visit(Base &Node) override - { - for (auto I = Node.begin(), E = Node.end(); I != E; ++I) - { - (*I)->accept(*this); + enum ErrorKind { AlreadyDefined, NotDefined, DivideByZero, TypeMismatch, InvalidOperation }; + void report(ErrorKind kind, const std::string &msg) { + switch (kind) { + case AlreadyDefined: llvm::errs() << "Error: variable '" << msg << "' already defined\n"; break; + case NotDefined: llvm::errs() << "Error: variable '" << msg << "' not defined\n"; break; + case DivideByZero: llvm::errs() << "Error: division by zero!\n"; break; + case TypeMismatch: llvm::errs() << "Error: type mismatch for '" << msg << "'\n"; break; + case InvalidOperation: llvm::errs() << "Error: invalid operation '" << msg << "' for type\n"; break; } - }; + HasError = true; + } - virtual void visit(PrintStatement &Node) override - { - Expression *declaration = (Expression *)Node.getExpr(); - declaration->accept(*this); - }; - virtual void visit(BinaryOp &Node) override - { - if (Node.getLeft()) - { - Node.getLeft()->accept(*this); - } - else - { - HasError = true; - } - if (Node.getRight()) - { - Node.getRight()->accept(*this); - } - else - { - HasError = true; - } - - // Divide by zero check - if (Node.getOperator() == BinaryOp::Operator::Div) - { - Expression *right = (Expression *)Node.getRight(); - if (right->isNumber() && right->getNumber() == 0) - { - error(DivideByZero, ((Expression *)Node.getLeft())->getValue()); + VarType typeOf(ASTNode *node) { + if (!node) return VarType::ERROR; + // Literals + if (auto *i = dynamic_cast*>(node)) return VarType::INT; + if (auto *f = dynamic_cast*>(node)) return VarType::FLOAT; + if (auto *b = dynamic_cast*>(node)) return VarType::BOOL; + if (auto *c = dynamic_cast*>(node)) return VarType::CHAR; + if (auto *s = dynamic_cast*>(node)) return VarType::STRING; + // Variable reference + if (auto *v = dynamic_cast(node)) { + if (!VarTypes.count(v->name)) { + report(NotDefined, v->name); + return VarType::ERROR; + } + return VarTypes.lookup(v->name); + } + // Binary operations + if (auto *bin = dynamic_cast(node)) { + VarType lt = typeOf(bin->left.get()); + VarType rt = typeOf(bin->right.get()); + switch (bin->op) { + case BinaryOp::ADD: case BinaryOp::SUBTRACT: + case BinaryOp::MULTIPLY: case BinaryOp::DIVIDE: case BinaryOp::MOD: + case BinaryOp::POW: { + // numeric or array-scalar or scalar-array + if ((lt==VarType::INT||lt==VarType::FLOAT) && (rt==VarType::INT||rt==VarType::FLOAT)) + return (lt==VarType::FLOAT||rt==VarType::FLOAT?VarType::FLOAT:VarType::INT); + if (lt==VarType::ARRAY && (rt==VarType::INT||rt==VarType::FLOAT)) return VarType::ARRAY; + if ((lt==VarType::INT||lt==VarType::FLOAT) && rt==VarType::ARRAY) return VarType::ARRAY; + report(InvalidOperation, varTypeName(lt) + std::string(" op ") + varTypeName(rt)); + return VarType::ERROR; + } + case BinaryOp::AND: case BinaryOp::OR: + if (lt==VarType::BOOL && rt==VarType::BOOL) return VarType::BOOL; + report(InvalidOperation, "logical"); return VarType::ERROR; + case BinaryOp::EQUAL: case BinaryOp::NOT_EQUAL: + case BinaryOp::LESS: case BinaryOp::LESS_EQUAL: + case BinaryOp::GREATER: case BinaryOp::GREATER_EQUAL: + // comparisons produce bool + if ((lt==VarType::INT||lt==VarType::FLOAT) && (rt==VarType::INT||rt==VarType::FLOAT)) return VarType::BOOL; + if (lt==rt) return VarType::BOOL; + report(TypeMismatch, "compare"); return VarType::ERROR; + case BinaryOp::CONCAT: + if (lt==VarType::STRING && rt==VarType::STRING) return VarType::STRING; + report(InvalidOperation, "concat"); return VarType::ERROR; + case BinaryOp::INDEX: + if (lt==VarType::ARRAY && rt==VarType::INT) return VarType::ARRAY; // element type infer omitted + report(InvalidOperation, "index"); return VarType::ERROR; + default: + return VarType::ERROR; } } - }; + // Unary operations + if (auto *un = dynamic_cast(node)) { + VarType ot = typeOf(un->operand.get()); + switch (un->op) { + case UnaryOp::INCREMENT: case UnaryOp::DECREMENT: + if (ot==VarType::INT||ot==VarType::FLOAT) return ot; + break; + case UnaryOp::ABS: + if (ot==VarType::INT||ot==VarType::FLOAT) return ot; + break; + case UnaryOp::LENGTH: + if (ot==VarType::ARRAY||ot==VarType::STRING) return VarType::INT; + break; + case UnaryOp::MIN: case UnaryOp::MAX: + if (ot==VarType::ARRAY) return VarType::ARRAY; + break; + default: break; + } + report(InvalidOperation, varTypeName(ot)); + return VarType::ERROR; + } + // Arrays + if (auto *arr = dynamic_cast(node)) { + if (arr->elements.empty()) return VarType::ARRAY; + VarType et = typeOf(arr->elements[0].get()); + for (auto &e : arr->elements) { + VarType t2 = typeOf(e.get()); + if (t2 != et) report(TypeMismatch, "array elements"); + } + return VarType::ARRAY; + } + // Array access + if (auto *acc = dynamic_cast(node)) { + VarType at = typeOf(acc->index.get()); + VarType arrt = VarTypes.lookup(acc->arrayName); + if (arrt != VarType::ARRAY) report(TypeMismatch, acc->arrayName); + if (at != VarType::INT) report(TypeMismatch, "index type"); + return VarType::ARRAY; + } + // Print + if (auto *p = dynamic_cast(node)) { + return typeOf(p->expr.get()); + } + // Default: error + return VarType::ERROR; + } - virtual void visit(Statement &Node) override - { - if (Node.getKind() == Statement::StatementType::Declaration) - { - DecStatement *declaration = (DecStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Assignment) - { - AssignStatement *declaration = (AssignStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::If) - { - IfStatement *declaration = (IfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::ElseIf) - { - ElseIfStatement *declaration = (ElseIfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Else) - { - ElseStatement *declaration = (ElseStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Print) - { - PrintStatement *declaration = (PrintStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::While) - { - WhileStatement *declaration = (WhileStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::For) - { - ForStatement *declaration = (ForStatement *)&Node; - declaration->accept(*this); - } - - - }; + public: + bool hasError() const { return HasError; } - virtual void visit(BooleanOp &Node) override - { - if (Node.getLeft()) - { - Node.getLeft()->accept(*this); - } - else - { - HasError = true; - } - if (Node.getRight()) - { - Node.getRight()->accept(*this); - } - else - { - HasError = true; - } - }; + // Program and Block + void visit(ProgramNode &node) override { + for (auto &s : node.statements) s->accept(*this); + } + void visit(BlockNode &node) override { + for (auto &s : node.statements) s->accept(*this); + } - virtual void visit(Expression &Node) override - { - if (Node.getKind() == Expression::ExpressionType::Identifier) - { - if (variableTypeMap.count(Node.getValue()) == 0) - { - error(NotDefinedVariable, Node.getValue()); + // Variable declarations + void visit(MultiVarDeclNode &node) override { + for (auto &decl : node.declarations) { + const std::string &name = decl->name; + if (VarTypes.count(name)) report(AlreadyDefined, name); + else { + VarTypes[name] = decl->type; + if (decl->value) decl->value->accept(*this); } } - else if (Node.getKind() == Expression::ExpressionType::BinaryOpType) - { - BinaryOp *declaration = (BinaryOp *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Expression::ExpressionType::BooleanOpType) - { - Node.getBooleanOp()->getLeft()->accept(*this); - Node.getBooleanOp()->getRight()->accept(*this); - } - }; - - virtual void visit(DecStatement &Node) override - { - if (variableTypeMap.count(Node.getLValue()->getValue()) > 0) - { - error(AlreadyDefinedVariable, Node.getLValue()->getValue()); - } - // Add this new variable to variableTypeMap - if (Node.getDecType() == DecStatement::DecStatementType::Boolean) - { - variableTypeMap[Node.getLValue()->getValue()] = 'b'; - } - else - { - variableTypeMap[Node.getLValue()->getValue()] = 'i'; - } + } - Expression *rightValue = (Expression *)Node.getRValue(); - if (rightValue == nullptr) - { - return; - } - if (Node.getDecType() == DecStatement::DecStatementType::Boolean) - { - if (!(rightValue->getKind() == Expression::ExpressionType::Boolean || - rightValue->getKind() == Expression::ExpressionType::BooleanOpType)) - { - error(WrongValueTypeForVariable, "bool"); - } - } - else if (Node.getDecType() == DecStatement::DecStatementType::Number) - { - if (!(rightValue->getKind() == Expression::ExpressionType::Number || - rightValue->getKind() == Expression::ExpressionType::BinaryOpType)) - { - error(WrongValueTypeForVariable, "int"); - } + // Assignments + void visit(AssignNode &node) override { + if (!VarTypes.count(node.target)) report(NotDefined, node.target); + node.value->accept(*this); + VarType vt = VarTypes.lookup(node.target); + VarType rt = typeOf(node.value.get()); + // Allowed ops per type + bool ok = false; + switch (vt) { + case VarType::INT: case VarType::FLOAT: + ok = true; break; + case VarType::CHAR: case VarType::STRING: case VarType::BOOL: case VarType::ARRAY: + ok = (node.op == BinaryOp::EQUAL); + break; + default: break; } + if (!ok) report(InvalidOperation, varTypeName(vt)); + if (rt != vt && !(vt==VarType::FLOAT && rt==VarType::INT)) + report(TypeMismatch, node.target); + } - rightValue->accept(*this); - }; - - virtual void visit(IfStatement &Node) override - { - Expression *declaration = (Expression *)Node.getCondition(); - declaration->accept(*this); - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + // Variable reference + void visit(VarRefNode &node) override { + if (!VarTypes.count(node.name)) report(NotDefined, node.name); + } - virtual void visit(ElseIfStatement &Node) override - { - Expression *declaration = (Expression *)Node.getCondition(); - declaration->accept(*this); - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); + // Binary ops + void visit(BinaryOpNode &node) override { + node.left->accept(*this); + node.right->accept(*this); + if (node.op == BinaryOp::DIVIDE) { + if (auto *lit = dynamic_cast*>(node.right.get())) { + if (lit->value == 0) report(DivideByZero, ""); + } } - }; + typeOf(&node); + } - virtual void visit(ElseStatement &Node) override - { - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + // Unary ops + void visit(UnaryOpNode &node) override { + node.operand->accept(*this); + typeOf(&node); + } - virtual void visit(AssignStatement &Node) override - { - Node.getLValue()->accept(*this); - Node.getRValue()->accept(*this); - if (variableTypeMap.lookup(Node.getLValue()->getValue()) == 'i' && - (Node.getRValue()->getKind() == Expression::ExpressionType::Boolean || - Node.getRValue()->getKind() == Expression::ExpressionType::BooleanOpType)) - { - error(WrongValueTypeForVariable, "int"); - } - if (variableTypeMap.lookup(Node.getLValue()->getValue()) == 'b' && - (Node.getRValue()->getKind() == Expression::ExpressionType::Number || - Node.getRValue()->getKind() == Expression::ExpressionType::BinaryOpType)) - { - error(WrongValueTypeForVariable, "bool"); - } - }; + // Control structures + void visit(IfElseNode &node) override { + node.condition->accept(*this); + if (typeOf(node.condition.get()) != VarType::BOOL) + report(TypeMismatch, "if condition"); + node.thenBlock->accept(*this); + if (node.elseBlock) node.elseBlock->accept(*this); + } + void visit(ForLoopNode &node) override { + if (node.init) node.init->accept(*this); + if (node.condition) { + node.condition->accept(*this); + if (typeOf(node.condition.get()) != VarType::BOOL) + report(TypeMismatch, "for condition"); + } + if (node.update) node.update->accept(*this); + node.body->accept(*this); + } + void visit(WhileLoopNode &node) override { + node.condition->accept(*this); + if (typeOf(node.condition.get()) != VarType::BOOL) + report(TypeMismatch, "while condition"); + node.body->accept(*this); + } - virtual void visit(WhileStatement &Node) override{ - Node.getCondition()->accept(*this); - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; - - virtual void visit(ForStatement &Node) override{ - AssignStatement * initial_assign = Node.getInitialAssign(); - /*if(initial_assign == nullptr){ - exit(0); - }*/ - (initial_assign->getLValue())->accept(*this); - (initial_assign->getRValue())->accept(*this); - Node.getCondition()->accept(*this); - - AssignStatement * update_assign = Node.getUpdateAssign(); - if(update_assign == nullptr){ - exit(0); - } - (update_assign->getLValue())->accept(*this); - (update_assign->getRValue())->accept(*this); + // Print + void visit(PrintNode &node) override { + node.expr->accept(*this); + } - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + // Array + void visit(ArrayNode &node) override { + for (auto &e : node.elements) e->accept(*this); + typeOf(&node); + } + void visit(ArrayAccessNode &node) override { + node.index->accept(*this); + typeOf(&node); + } + // Concat, Pow + void visit(ConcatNode &node) override { + node.left->accept(*this); + node.right->accept(*this); + typeOf(&node); + } + void visit(PowNode &node) override { + node.base->accept(*this); + node.exponent->accept(*this); + typeOf(&node); + } }; } -bool Semantic::semantic(AST *Tree) -{ - if (!Tree) - return false; - DeclCheck Check; - Tree->accept(Check); - return Check.hasError(); +bool Semantic::semantic(ProgramNode *root) { + if (!root) return false; + DeclCheck checker; + root->accept(checker); + return checker.hasError(); } diff --git a/code/semantic.h b/code/semantic.h index 6056947..c131944 100644 --- a/code/semantic.h +++ b/code/semantic.h @@ -1,13 +1,12 @@ -#ifndef SEMA_H -#define SEMA_H +#ifndef SEMANTIC_H +#define SEMANTIC_H #include "AST.h" -#include "lexer.h" -class Semantic -{ +class Semantic { public: - bool semantic(AST *Tree); + + bool semantic(ProgramNode *root); }; -#endif \ No newline at end of file +#endif diff --git a/grammer.txt b/grammer.txt index 5577f21..c377432 100644 --- a/grammer.txt +++ b/grammer.txt @@ -10,7 +10,11 @@ Block -> Define -> "int" Variable ";" | - "bool" Variable ";" + "bool" Variable ";" | + "float" Variable ";" | + "char" Variable ";" | + "string" Variable ";" | + "array" Variable ";" Variable -> @@ -25,10 +29,12 @@ Digit -> Number_A -> - Number_B | - "-" Number_B | - "+" Number_B - + Number_B | + Float | + "-" Number_B | + "-" Float | + "+" Number_B | + "+" Float Number_B -> Digit | @@ -149,4 +155,12 @@ ForCondition -> ForUpdate -> Unary | Identifier_A AssignOperation Value - + + + + +Float -> + Number_B "." Number_B | + "." Number_B | + Number_B "." + diff --git a/input.txt b/input.txt index 8491df5..0140d85 100644 --- a/input.txt +++ b/input.txt @@ -3,4 +3,23 @@ int i, j; for (i = 0; i <= 5; i++) { x = i + 2; } -print(x); \ No newline at end of file +print(x); + + + +bool flag = true; +/* if (flag == true) { +or + */ +/* if (flag) { */ + +if (flag == false) { + print(x); +} else if (flag) { + print(x); +} else { + print(x); +} + +float ff; +float t=3;