From c4948879cdaef994ad913b2616690c18a6644e2b Mon Sep 17 00:00:00 2001 From: mohsenkamini Date: Tue, 22 Apr 2025 17:20:15 +0330 Subject: [PATCH 01/25] tested if and bool variables --- grammer.txt | 6 +++++- input.txt | 9 ++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/grammer.txt b/grammer.txt index 5577f21..bc1f775 100644 --- a/grammer.txt +++ b/grammer.txt @@ -10,7 +10,11 @@ Block -> Define -> "int" Variable ";" | - "bool" Variable ";" + "bool" Variable ";" | + "float" Variable ";" | + "char" Variable ";" | + "string" Variable ";" | + "array" Variable ";" Variable -> diff --git a/input.txt b/input.txt index 8491df5..c87bb43 100644 --- a/input.txt +++ b/input.txt @@ -1,6 +1,13 @@ int x = 0; int i, j; +bool flag = true; for (i = 0; i <= 5; i++) { x = i + 2; } -print(x); \ No newline at end of file +print(x); +/* if (flag == true) { +or + */ +if (flag) { + print(flag); +} From 05a7fe56abe4b9659c99b876a5328bbea5d04253 Mon Sep 17 00:00:00 2001 From: mohsenkamini Date: Tue, 22 Apr 2025 17:33:52 +0330 Subject: [PATCH 02/25] tested else if and else --- .input.txt.swp | Bin 0 -> 12288 bytes input.txt | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 .input.txt.swp diff --git a/.input.txt.swp b/.input.txt.swp new file mode 100644 index 0000000000000000000000000000000000000000..15c297df08f095c849ebf7bb78943399afa72872 GIT binary patch literal 12288 zcmeI%y-EW?6b0a`jac}Df}mJjwJ;jvCW?h@Ql+vG(r9C~#*LYfWY^6q5s3;Gf^U)5 z+Mil{0pG*gTF+=AA`&bW&cI=I=FZISd|k+HdV6D2*OxPbHY#%79w@&S3o@J+sRiYR zU$=$ct?bnz@5pbIQ(>{!%F~Ynovpe#@}q`{kDYu_3r*FJJ|9wrD5&_w*fI6+FkX(E zajz3uqW}edFEAvlE19%iKQ}w0<6~RD-^!H~pa2CZKmiI+fC3btz`qxW!#+8~E&4k* zndtO=-}=sw9R(;r0SZun0u-PC1t>rP3Q&Lo|4_i&6S*E1nHs^z^DOAO}c@>>vdsi)4^_WE!#Wm_*npKmiI+fC3bt00k&O0SZun z0)I!KC4*sP>T#mE;7UvVYQxtP{bACSG*PO0`wQR7-FeOBwB%ts7Kk{D+AY`(i;ma? z(bt0A5l4sb8+L=Bs`hyOm~+( Date: Tue, 22 Apr 2025 17:43:52 +0330 Subject: [PATCH 03/25] testing foat --- .input.txt.swp | Bin 12288 -> 0 bytes grammer.txt | 20 +++++++++++++++----- input.txt | 3 +++ 3 files changed, 18 insertions(+), 5 deletions(-) delete mode 100644 .input.txt.swp diff --git a/.input.txt.swp b/.input.txt.swp deleted file mode 100644 index 15c297df08f095c849ebf7bb78943399afa72872..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI%y-EW?6b0a`jac}Df}mJjwJ;jvCW?h@Ql+vG(r9C~#*LYfWY^6q5s3;Gf^U)5 z+Mil{0pG*gTF+=AA`&bW&cI=I=FZISd|k+HdV6D2*OxPbHY#%79w@&S3o@J+sRiYR zU$=$ct?bnz@5pbIQ(>{!%F~Ynovpe#@}q`{kDYu_3r*FJJ|9wrD5&_w*fI6+FkX(E zajz3uqW}edFEAvlE19%iKQ}w0<6~RD-^!H~pa2CZKmiI+fC3btz`qxW!#+8~E&4k* zndtO=-}=sw9R(;r0SZun0u-PC1t>rP3Q&Lo|4_i&6S*E1nHs^z^DOAO}c@>>vdsi)4^_WE!#Wm_*npKmiI+fC3bt00k&O0SZun z0)I!KC4*sP>T#mE;7UvVYQxtP{bACSG*PO0`wQR7-FeOBwB%ts7Kk{D+AY`(i;ma? z(bt0A5l4sb8+L=Bs`hyOm~+( Number_A -> - Number_B | - "-" Number_B | - "+" Number_B - + Number_B | + Float | + "-" Number_B | + "-" Float | + "+" Number_B | + "+" Float Number_B -> Digit | @@ -153,4 +155,12 @@ ForCondition -> ForUpdate -> Unary | Identifier_A AssignOperation Value - + + + + +Float -> + Number_B "." Number_B | + "." Number_B | + Number_B "." + diff --git a/input.txt b/input.txt index dac033e..0140d85 100644 --- a/input.txt +++ b/input.txt @@ -20,3 +20,6 @@ if (flag == false) { } else { print(x); } + +float ff; +float t=3; From def784f22ad707aecbca7de9ebda861d7a8d0e2b Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:17:20 +0330 Subject: [PATCH 04/25] Update lexer.h --- code/lexer.h | 131 ++++++++++++++++++++++++++------------------------- 1 file changed, 68 insertions(+), 63 deletions(-) diff --git a/code/lexer.h b/code/lexer.h index aa0fc6b..5c7d7e1 100644 --- a/code/lexer.h +++ b/code/lexer.h @@ -4,99 +4,104 @@ #include "llvm/Support/MemoryBuffer.h" #include - class Lexer; -class Token -{ +class Token { friend class Lexer; public: - enum TokenKind : unsigned short - { - semi_colon, // ; - unknown, // unknown token - identifier, // identifier like a, b, c, d, etc. - number, // number like 1, 2, 3, 4, etc. - comma, // , - plus, // + - minus, // - - star, // * - mod, // % - slash, // / - power, // ^ - l_paren, // ( - r_paren, // ) - l_brace, // { - r_brace, // } - plus_equal, // += - plus_plus, // ++ - minus_minus, // -- - minus_equal, // -= - star_equal, // *= - mod_equal, // %= - slash_equal, // /= - equal, // = - equal_equal, // == - not_equal, // != - KW_not, // ! - less, // < - less_equal, // <= - greater, // > - greater_equal, // >= - comment, // /* - uncomment, // */ - KW_int, // int - KW_bool, // bool - KW_if, // if - KW_else, // else - KW_while, // while - KW_for, // for - KW_and, // and - KW_or, // or - KW_true, // true - KW_false, // false - KW_print, // print - eof // end of file + enum TokenKind : unsigned short { + // پایه + semi_colon, // ; + unknown, // ناشناخته + identifier, // شناسه + number, // عدد + comma, // , + l_paren, // ( + r_paren, // ) + l_brace, // { + r_brace, // } + eof, // پایان فایل + + // عملگرها + plus, // + + minus, // - + star, // * + mod, // % + slash, // / + power, // ^ + plus_equal, // += + minus_equal, // -= + star_equal, // *= + mod_equal, // %= + slash_equal, // /= + equal, // = + equal_equal, // == + not_equal, // != + less, // < + less_equal, // <= + greater, // > + greater_equal, // >= + + // کلمات کلیدی + KW_int, // int + KW_bool, // bool + KW_float, // float + KW_char, // char + KW_string, // string + KW_array, // array + KW_if, // if + KW_else, // else + KW_while, // while + KW_for, // for + KW_and, // and + KW_or, // or + KW_true, // true + KW_false, // false + KW_print, // print + KW_not, // ! + + // توکنهای خاص + comment_start, /* + comment_end, */ + increment, // ++ + decrement // -- }; private: - TokenKind Kind; // - llvm::StringRef Text; // + TokenKind Kind; // نوع توکن + llvm::StringRef Text; // متن توکن public: TokenKind getKind() const { return Kind; } bool is(TokenKind K) const { return Kind == K; } llvm::StringRef getText() const { return Text; } - // kind="+" isOneOf(plus, minus) -> true - bool isOneOf(TokenKind K1, TokenKind K2) const - { - return is(K1) || is(K2); + template + bool isOneOf(TokenKind K1, Ts... Ks) const { + return is(K1) || isOneOf(Ks...); } - template // variadic template can take several inputs - bool isOneOf(TokenKind K1, TokenKind K2, Ts... Ks) const - { - return is(K1) || isOneOf(K2, Ks...); + bool isOneOf(TokenKind K1, TokenKind K2) const { + return is(K1) || is(K2); } }; -class Lexer -{ +class Lexer { const char *BufferStart; const char *BufferPtr; public: - Lexer(const llvm::StringRef &Buffer) - { + explicit Lexer(const llvm::StringRef &Buffer) { BufferStart = Buffer.begin(); BufferPtr = BufferStart; } + void next(Token &token); private: void formToken(Token &Result, const char *TokEnd, Token::TokenKind Kind); + void skipMultiLineComment(); }; #endif From e40ddfeede56aebff776a505f5a1e057dba285ad Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:18:57 +0330 Subject: [PATCH 05/25] Update lexer.cpp --- code/lexer.cpp | 424 ++++++++++++++++--------------------------------- 1 file changed, 140 insertions(+), 284 deletions(-) diff --git a/code/lexer.cpp b/code/lexer.cpp index 1ef0975..5401919 100644 --- a/code/lexer.cpp +++ b/code/lexer.cpp @@ -1,294 +1,150 @@ #include "lexer.h" -#include +#include +namespace charinfo { + LLVM_READNONE inline bool isWhitespace(char c) { + return c == ' ' || c == '\t' || c == '\f' || + c == '\v' || c == '\r' || c == '\n'; + } -namespace charinfo -{ - LLVM_READNONE inline bool isWhitespace(char c) - { - return c == ' ' || c == '\t' || c == '\f' || - c == '\v' || c == '\r' || c == '\n'; - } + LLVM_READNONE inline bool isDigit(char c) { + return std::isdigit(c); + } - LLVM_READNONE inline bool isDigit(char c) - { - return c >= '0' && c <= '9'; - } + LLVM_READNONE inline bool isAlpha(char c) { + return std::isalpha(c) || c == '_'; + } - LLVM_READNONE inline bool isLetter(char c) - { - return (c >= 'a' && c <= 'z') || - (c >= 'A' && c <= 'Z'); - } - - LLVM_READNONE inline bool isOperator(char c) - { - return c == '+' || c == '-' || c == '*' || - c == '/' || c == '^' || c == '=' || - c == '%' || c == '<' || c == '>' || - c == '!'; - } - - LLVM_READNONE inline bool isSpecialCharacter(char c) - { - return c == ';' || c == ',' || c == '(' || - c == ')' || c == '{' || c == '}' || c == ','; - } + LLVM_READNONE inline bool isAlnum(char c) { + return isAlpha(c) || isDigit(c); + } } -void Lexer::next(Token &token) -{ - // Skips whitespace like " " - while (*BufferPtr && charinfo::isWhitespace(*BufferPtr)) - { - ++BufferPtr; - } - bool is_comment = false; - if (*BufferPtr == '/' && *(BufferPtr + 1) == '*'){ - is_comment = true; - } - while(is_comment){ - if (!*BufferPtr) - { - token.Kind = Token::eof; - return; - } - while (*BufferPtr && *BufferPtr != '*' && *(BufferPtr + 1) != '/') - { - ++BufferPtr; - } - if(*BufferPtr == '*' && *(BufferPtr + 1) == '/'){ - BufferPtr += 2; - is_comment = false; - } - else{ - ++BufferPtr; - } - while (*BufferPtr && charinfo::isWhitespace(*BufferPtr)) - { - ++BufferPtr; - } - } - // since end of context is 0 -> !0 = true -> end of context - if (!*BufferPtr) - { - token.Kind = Token::eof; - return; - } - // looking for keywords or identifiers like "int", a123 , ... - if (charinfo::isLetter(*BufferPtr)) - { - const char *end = BufferPtr + 1; - - // until reaches the end of lexeme - // example: ".int " -> "i.nt " -> "in.t " -> "int. " - while (charinfo::isLetter(*end) || charinfo::isDigit(*end) || *end == '_') - { - ++end; - } - - llvm::StringRef Context(BufferPtr, end - BufferPtr); // start of lexeme, length of lexeme - Token::TokenKind kind; - if (Context == "int") - { - kind = Token::KW_int; - } - else if (Context == "if") - { - kind = Token::KW_if; - } - else if (Context == "bool") - { - kind = Token::KW_bool; - } - else if (Context == "else") - { - kind = Token::KW_else; - } - else if (Context == "while") - { - kind = Token::KW_while; - } - else if (Context == "for") - { - kind = Token::KW_for; - } - else if (Context == "and") - { - kind = Token::KW_and; - } - else if (Context == "or") - { - kind = Token::KW_or; - } - else if (Context == "true") - { - kind = Token::KW_true; - } - else if (Context == "false") - { - kind = Token::KW_false; - } - else if (Context == "print") - { - kind = Token::KW_print; - } - else - { - kind = Token::identifier; - } - - formToken(token, end, kind); - return; - } - - else if (charinfo::isDigit(*BufferPtr)) - { - - const char *end = BufferPtr + 1; - - while (charinfo::isDigit(*end)) - ++end; - - formToken(token, end, Token::number); - return; - } - else if (charinfo::isSpecialCharacter(*BufferPtr)) - { - switch (*BufferPtr) - { - case ';': - formToken(token, BufferPtr + 1, Token::semi_colon); - break; - case ',': - formToken(token, BufferPtr + 1, Token::comma); - break; - case '(': - formToken(token, BufferPtr + 1, Token::l_paren); - break; - case ')': - formToken(token, BufferPtr + 1, Token::r_paren); - break; - case '{': - formToken(token, BufferPtr + 1, Token::l_brace); - break; - case '}': - formToken(token, BufferPtr + 1, Token::r_brace); - break; - } - } - else - { - - if (*BufferPtr == '=' && *(BufferPtr + 1) == '=') - { // == - formToken(token, BufferPtr + 2, Token::equal_equal); - } - else if (*BufferPtr == '+' && *(BufferPtr + 1) == '=') - { // += - formToken(token, BufferPtr + 2, Token::plus_equal); - } - else if (*BufferPtr == '-' && *(BufferPtr + 1) == '=') - { // -= - formToken(token, BufferPtr + 2, Token::minus_equal); - } - - else if (*BufferPtr == '*' && *(BufferPtr + 1) == '=') - { // *= - formToken(token, BufferPtr + 2, Token::star_equal); - } - - else if (*BufferPtr == '/' && *(BufferPtr + 1) == '=') - { // /= - formToken(token, BufferPtr + 2, Token::slash_equal); - } - - else if (*BufferPtr == '%' && *(BufferPtr + 1) == '=') - { // %= - formToken(token, BufferPtr + 2, Token::mod_equal); - } - - else if (*BufferPtr == '!' && *(BufferPtr + 1) == '=') - { // != - formToken(token, BufferPtr + 2, Token::not_equal); - } - - else if (*BufferPtr == '<' && *(BufferPtr + 1) == '=') - { // <= - formToken(token, BufferPtr + 2, Token::less_equal); - } - - else if (*BufferPtr == '>' && *(BufferPtr + 1) == '=') - { - formToken(token, BufferPtr + 2, Token::greater_equal); // >= - } - else if (*BufferPtr == '/' && *(BufferPtr + 1) == '*') - { - formToken(token, BufferPtr + 2, Token::comment); // comment - } - else if (*BufferPtr == '*' && *(BufferPtr + 1) == '/') - { - formToken(token, BufferPtr + 2, Token::uncomment); // uncomment - } - else if (*BufferPtr == '+' && *(BufferPtr + 1) == '+') - { - formToken(token, BufferPtr + 2, Token::plus_plus); // ++ - } - else if (*BufferPtr == '-' && *(BufferPtr + 1) == '-') - { - formToken(token, BufferPtr + 2, Token::minus_minus); // -- - } - else if (*BufferPtr == '=') - { - formToken(token, BufferPtr + 1, Token::equal); // = - } - else if (*BufferPtr == '+') - { - formToken(token, BufferPtr + 1, Token::plus); // + - } - else if (*BufferPtr == '-') - { - formToken(token, BufferPtr + 1, Token::minus); // - - } - else if (*BufferPtr == '*') - { - formToken(token, BufferPtr + 1, Token::star); // * - } - else if (*BufferPtr == '/') - { - formToken(token, BufferPtr + 1, Token::slash); // / - } - else if (*BufferPtr == '^') - { - formToken(token, BufferPtr + 1, Token::power); // ^ - } - else if (*BufferPtr == '>') - { - formToken(token, BufferPtr + 1, Token::greater); // > - } - else if (*BufferPtr == '<') - { - formToken(token, BufferPtr + 1, Token::less); // < - } - else if (*BufferPtr == '!') - { - formToken(token, BufferPtr + 1, Token::KW_not ); // ! - } - else if (*BufferPtr == '%') - { - formToken(token, BufferPtr + 1, Token::mod); // % - } - else - { - formToken(token, BufferPtr + 1, Token::unknown); - } - } +void Lexer::next(Token &token) { + while (*BufferPtr && charinfo::isWhitespace(*BufferPtr)) ++BufferPtr; + + if (!*BufferPtr) { + token.Kind = Token::eof; + return; + } + + if (*BufferPtr == '/' && *(BufferPtr + 1) == '*') { + skipMultiLineComment(); + return next(token); + } + + // identifiers + if (charinfo::isAlpha(*BufferPtr)) { + const char *end = BufferPtr + 1; + while (charinfo::isAlnum(*end)) ++end; + + llvm::StringRef text(BufferPtr, end - BufferPtr); + Token::TokenKind kind = Token::identifier; + + if (text == "int") kind = Token::KW_int; + else if (text == "bool") kind = Token::KW_bool; + else if (text == "float") kind = Token::KW_float; + else if (text == "char") kind = Token::KW_char; + else if (text == "string") kind = Token::KW_string; + else if (text == "array") kind = Token::KW_array; + else if (text == "if") kind = Token::KW_if; + else if (text == "else") kind = Token::KW_else; + else if (text == "while") kind = Token::KW_while; + else if (text == "for") kind = Token::KW_for; + else if (text == "and") kind = Token::KW_and; + else if (text == "or") kind = Token::KW_or; + else if (text == "true") kind = Token::KW_true; + else if (text == "false") kind = Token::KW_false; + else if (text == "print") kind = Token::KW_print; + + formToken(token, end, kind); + return; + } + + if (charinfo::isDigit(*BufferPtr) || (*BufferPtr == '.' && charinfo::isDigit(*(BufferPtr + 1))) { + const char *end = BufferPtr; + bool hasDot = false; + + while (charinfo::isDigit(*end) || (*end == '.' && !hasDot)) { + if (*end == '.') hasDot = true; + ++end; + } + + formToken(token, end, hasDot ? Token::number : Token::number); + return; + } + + switch (*BufferPtr) { + case ';': formToken(token, BufferPtr + 1, Token::semi_colon); break; + case ',': formToken(token, BufferPtr + 1, Token::comma); break; + case '(': formToken(token, BufferPtr + 1, Token::l_paren); break; + case ')': formToken(token, BufferPtr + 1, Token::r_paren); break; + case '{': formToken(token, BufferPtr + 1, Token::l_brace); break; + case '}': formToken(token, BufferPtr + 1, Token::r_brace); break; + + case '+': + if (*(BufferPtr + 1) == '+') formToken(token, BufferPtr + 2, Token::increment); + else if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::plus_equal); + else formToken(token, BufferPtr + 1, Token::plus); + break; + + case '-': + if (*(BufferPtr + 1) == '-') formToken(token, BufferPtr + 2, Token::decrement); + else if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::minus_equal); + else formToken(token, BufferPtr + 1, Token::minus); + break; + + case '*': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::star_equal); + else formToken(token, BufferPtr + 1, Token::star); + break; + + case '/': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::slash_equal); + else formToken(token, BufferPtr + 1, Token::slash); + break; + + case '^': formToken(token, BufferPtr + 1, Token::power); break; + case '%': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::mod_equal); + else formToken(token, BufferPtr + 1, Token::mod); + break; + + case '=': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::equal_equal); + else formToken(token, BufferPtr + 1, Token::equal); + break; + + case '!': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::not_equal); + else formToken(token, BufferPtr + 1, Token::KW_not); + break; + + case '<': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::less_equal); + else formToken(token, BufferPtr + 1, Token::less); + break; + + case '>': + if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::greater_equal); + else formToken(token, BufferPtr + 1, Token::greater); + break; + + default: + formToken(token, BufferPtr + 1, Token::unknown); + } +} +void Lexer::skipMultiLineComment() { + BufferPtr += 2; + while (*BufferPtr && !(*BufferPtr == '*' && *(BufferPtr + 1) == '/')) { + ++BufferPtr; + } + if (*BufferPtr) BufferPtr += 2; } -void Lexer::formToken(Token &Tok, const char *TokEnd, Token::TokenKind Kind) -{ - Tok.Kind = Kind; - Tok.Text = llvm::StringRef(BufferPtr, TokEnd - BufferPtr); - BufferPtr = TokEnd; -} \ No newline at end of file +void Lexer::formToken(Token &Tok, const char *TokEnd, Token::TokenKind Kind) { + Tok.Kind = Kind; + Tok.Text = llvm::StringRef(BufferPtr, TokEnd - BufferPtr); + BufferPtr = TokEnd; +} From 4cc810db3e8f3c2774f44499b35edfa3a896bb42 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:44:34 +0330 Subject: [PATCH 06/25] Update lexer.h changed a few things --- code/lexer.h | 170 ++++++++++++++++++++++++++------------------------- 1 file changed, 87 insertions(+), 83 deletions(-) diff --git a/code/lexer.h b/code/lexer.h index 5c7d7e1..9478788 100644 --- a/code/lexer.h +++ b/code/lexer.h @@ -1,107 +1,111 @@ #ifndef LEXER_H #define LEXER_H + #include "llvm/ADT/StringRef.h" -#include "llvm/Support/MemoryBuffer.h" #include class Lexer; class Token { - friend class Lexer; - public: - enum TokenKind : unsigned short { - // پایه - semi_colon, // ; - unknown, // ناشناخته - identifier, // شناسه - number, // عدد - comma, // , - l_paren, // ( - r_paren, // ) - l_brace, // { - r_brace, // } - eof, // پایان فایل - - // عملگرها - plus, // + - minus, // - - star, // * - mod, // % - slash, // / - power, // ^ - plus_equal, // += - minus_equal, // -= - star_equal, // *= - mod_equal, // %= - slash_equal, // /= - equal, // = - equal_equal, // == - not_equal, // != - less, // < - less_equal, // <= - greater, // > - greater_equal, // >= - - // کلمات کلیدی - KW_int, // int - KW_bool, // bool - KW_float, // float - KW_char, // char - KW_string, // string - KW_array, // array - KW_if, // if - KW_else, // else - KW_while, // while - KW_for, // for - KW_and, // and - KW_or, // or - KW_true, // true - KW_false, // false - KW_print, // print - KW_not, // ! - - // توکنهای خاص - comment_start, /* - comment_end, */ - increment, // ++ - decrement // -- - }; + enum TokenKind { -private: - TokenKind Kind; // نوع توکن - llvm::StringRef Text; // متن توکن + semi_colon, + identifier, + number, + comma, + l_paren, + r_paren, + l_brace, + r_brace, + l_bracket, + r_bracket, + eof, -public: - TokenKind getKind() const { return Kind; } - bool is(TokenKind K) const { return Kind == K; } - llvm::StringRef getText() const { return Text; } + plus, + minus, + star, + slash, + mod, + caret, + equal, + plus_equal, + minus_equal, + star_equal, + slash_equal, + mod_equal, + equal_equal, + not_equal, + less, + less_equal, + greater, + greater_equal, + and_op, + or_op, + plus_plus, + minus_minus, - template - bool isOneOf(TokenKind K1, Ts... Ks) const { - return is(K1) || isOneOf(Ks...); - } - bool isOneOf(TokenKind K1, TokenKind K2) const { - return is(K1) || is(K2); + KW_int, + KW_bool, + KW_float, + KW_char, + KW_string, + KW_array, + KW_if, + KW_else, + KW_while, + KW_for, + KW_foreach, + KW_in, + KW_print, + KW_true, + KW_false, + KW_try, + KW_catch, + KW_error, + KW_match, + KW_concat, + KW_pow, + KW_abs, + KW_length, + KW_min, + KW_max, + KW_index, + + comment_start, + comment_end, + string_literal, + char_literal, + float_literal, + underscore, + arrow + }; + + TokenKind kind; + llvm::StringRef text; + int line; + int column; + + bool is(TokenKind k) const { return kind == k; } + bool isOneOf(TokenKind k1, TokenKind k2) const { return is(k1) || is(k2); } + template bool isOneOf(TokenKind k1, TokenKind k2, Ts... ks) const { + return is(k1) || isOneOf(k2, ks...); } }; class Lexer { - const char *BufferStart; - const char *BufferPtr; + const char *bufferStart; + const char *bufferPtr; public: - explicit Lexer(const llvm::StringRef &Buffer) { - BufferStart = Buffer.begin(); - BufferPtr = BufferStart; - } - - void next(Token &token); + Lexer(llvm::StringRef buffer); + Token nextToken(); private: - void formToken(Token &Result, const char *TokEnd, Token::TokenKind Kind); - void skipMultiLineComment(); + void skipWhitespace(); + void skipComment(); + Token formToken(Token::TokenKind kind, const char *tokEnd); }; #endif From d941f813b90e944a13ed0ae1b0aee96b15b64f90 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:45:16 +0330 Subject: [PATCH 07/25] Update lexer.cpp synced with lexer.h --- code/lexer.cpp | 260 ++++++++++++++++++++++++++++--------------------- 1 file changed, 148 insertions(+), 112 deletions(-) diff --git a/code/lexer.cpp b/code/lexer.cpp index 5401919..65aba43 100644 --- a/code/lexer.cpp +++ b/code/lexer.cpp @@ -1,46 +1,41 @@ #include "lexer.h" #include -namespace charinfo { - LLVM_READNONE inline bool isWhitespace(char c) { - return c == ' ' || c == '\t' || c == '\f' || - c == '\v' || c == '\r' || c == '\n'; - } - - LLVM_READNONE inline bool isDigit(char c) { - return std::isdigit(c); - } - - LLVM_READNONE inline bool isAlpha(char c) { - return std::isalpha(c) || c == '_'; - } +Lexer::Lexer(llvm::StringRef buffer) { + bufferStart = buffer.begin(); + bufferPtr = bufferStart; +} - LLVM_READNONE inline bool isAlnum(char c) { - return isAlpha(c) || isDigit(c); +void Lexer::skipWhitespace() { + while (isspace(*bufferPtr)) { + if (*bufferPtr == '\n') bufferPtr++; + else bufferPtr++; } } -void Lexer::next(Token &token) { - while (*BufferPtr && charinfo::isWhitespace(*BufferPtr)) ++BufferPtr; - - if (!*BufferPtr) { - token.Kind = Token::eof; - return; +void Lexer::skipComment() { + if (*bufferPtr == '/' && *(bufferPtr + 1) == '*') { + bufferPtr += 2; + while (!(*bufferPtr == '*' && *(bufferPtr + 1) == '/')) bufferPtr++; + bufferPtr += 2; } +} - if (*BufferPtr == '/' && *(BufferPtr + 1) == '*') { - skipMultiLineComment(); - return next(token); - } +Token Lexer::nextToken() { + skipWhitespace(); + skipComment(); - // identifiers - if (charinfo::isAlpha(*BufferPtr)) { - const char *end = BufferPtr + 1; - while (charinfo::isAlnum(*end)) ++end; + if (bufferPtr >= bufferStart + buffer.size()) + return {Token::eof, llvm::StringRef(), 0, 0}; - llvm::StringRef text(BufferPtr, end - BufferPtr); + const char *tokStart = bufferPtr; + + // شناسه‌ها و کلمات کلیدی + if (isalpha(*bufferPtr)) { + while (isalnum(*bufferPtr) || *bufferPtr == '_') bufferPtr++; + llvm::StringRef text(tokStart, bufferPtr - tokStart); + Token::TokenKind kind = Token::identifier; - if (text == "int") kind = Token::KW_int; else if (text == "bool") kind = Token::KW_bool; else if (text == "float") kind = Token::KW_float; @@ -51,100 +46,141 @@ void Lexer::next(Token &token) { else if (text == "else") kind = Token::KW_else; else if (text == "while") kind = Token::KW_while; else if (text == "for") kind = Token::KW_for; - else if (text == "and") kind = Token::KW_and; - else if (text == "or") kind = Token::KW_or; + else if (text == "foreach") kind = Token::KW_foreach; + else if (text == "in") kind = Token::KW_in; + else if (text == "print") kind = Token::KW_print; else if (text == "true") kind = Token::KW_true; else if (text == "false") kind = Token::KW_false; - else if (text == "print") kind = Token::KW_print; - - formToken(token, end, kind); - return; - } - - if (charinfo::isDigit(*BufferPtr) || (*BufferPtr == '.' && charinfo::isDigit(*(BufferPtr + 1))) { - const char *end = BufferPtr; - bool hasDot = false; + else if (text == "try") kind = Token::KW_try; + else if (text == "catch") kind = Token::KW_catch; + else if (text == "error") kind = Token::KW_error; + else if (text == "match") kind = Token::KW_match; + else if (text == "concat") kind = Token::KW_concat; + else if (text == "pow") kind = Token::KW_pow; + else if (text == "abs") kind = Token::KW_abs; + else if (text == "length") kind = Token::KW_length; + else if (text == "min") kind = Token::KW_min; + else if (text == "max") kind = Token::KW_max; + else if (text == "index") kind = Token::KW_index; - while (charinfo::isDigit(*end) || (*end == '.' && !hasDot)) { - if (*end == '.') hasDot = true; - ++end; + return {kind, text, 0, 0}; + } + + // اعداد + if (isdigit(*bufferPtr) || *bufferPtr == '.') { + bool hasDot = (*bufferPtr == '.'); + while (isdigit(*bufferPtr) || (!hasDot && *bufferPtr == '.')) { + if (*bufferPtr == '.') hasDot = true; + bufferPtr++; } - - formToken(token, end, hasDot ? Token::number : Token::number); - return; + return {hasDot ? Token::float_literal : Token::number, + llvm::StringRef(tokStart, bufferPtr - tokStart), 0, 0}; } - - switch (*BufferPtr) { - case ';': formToken(token, BufferPtr + 1, Token::semi_colon); break; - case ',': formToken(token, BufferPtr + 1, Token::comma); break; - case '(': formToken(token, BufferPtr + 1, Token::l_paren); break; - case ')': formToken(token, BufferPtr + 1, Token::r_paren); break; - case '{': formToken(token, BufferPtr + 1, Token::l_brace); break; - case '}': formToken(token, BufferPtr + 1, Token::r_brace); break; - + + // رشته‌ها + if (*bufferPtr == '"') { + bufferPtr++; + const char *start = bufferPtr; + while (*bufferPtr != '"' && *bufferPtr != '\0') bufferPtr++; + return {Token::string_literal, llvm::StringRef(start, bufferPtr - start), 0, 0}; + } + + // کاراکترها + if (*bufferPtr == '\'') { + bufferPtr++; + char c = *bufferPtr; + bufferPtr += 2; // Skip closing ' + return {Token::char_literal, llvm::StringRef(&c, 1), 0, 0}; + } + + // عملگرها و سمبل‌ها + switch (*bufferPtr) { + case ';': bufferPtr++; return {Token::semi_colon, ";", 0, 0}; + case ',': bufferPtr++; return {Token::comma, ",", 0, 0}; + case '(': bufferPtr++; return {Token::l_paren, "(", 0, 0}; + case ')': bufferPtr++; return {Token::r_paren, ")", 0, 0}; + case '{': bufferPtr++; return {Token::l_brace, "{", 0, 0}; + case '}': bufferPtr++; return {Token::r_brace, "}", 0, 0}; + case '[': bufferPtr++; return {Token::l_bracket, "[", 0, 0}; + case ']': bufferPtr++; return {Token::r_bracket, "]", 0, 0}; case '+': - if (*(BufferPtr + 1) == '+') formToken(token, BufferPtr + 2, Token::increment); - else if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::plus_equal); - else formToken(token, BufferPtr + 1, Token::plus); - break; - + if (*(bufferPtr + 1) == '+') { + bufferPtr += 2; + return {Token::plus_plus, "++", 0, 0}; + } else if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::plus_equal, "+=", 0, 0}; + } + bufferPtr++; + return {Token::plus, "+", 0, 0}; case '-': - if (*(BufferPtr + 1) == '-') formToken(token, BufferPtr + 2, Token::decrement); - else if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::minus_equal); - else formToken(token, BufferPtr + 1, Token::minus); - break; - + if (*(bufferPtr + 1) == '-') { + bufferPtr += 2; + return {Token::minus_minus, "--", 0, 0}; + } else if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::minus_equal, "-=", 0, 0}; + } else if (*(bufferPtr + 1) == '>') { + bufferPtr += 2; + return {Token::arrow, "->", 0, 0}; + } + bufferPtr++; + return {Token::minus, "-", 0, 0}; case '*': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::star_equal); - else formToken(token, BufferPtr + 1, Token::star); - break; - + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::star_equal, "*=", 0, 0}; + } + bufferPtr++; + return {Token::star, "*", 0, 0}; case '/': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::slash_equal); - else formToken(token, BufferPtr + 1, Token::slash); - break; - - case '^': formToken(token, BufferPtr + 1, Token::power); break; - case '%': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::mod_equal); - else formToken(token, BufferPtr + 1, Token::mod); - break; - + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::slash_equal, "/=", 0, 0}; + } + bufferPtr++; + return {Token::slash, "/", 0, 0}; + case '%': + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::mod_equal, "%=", 0, 0}; + } + bufferPtr++; + return {Token::mod, "%", 0, 0}; + case '^': + bufferPtr++; + return {Token::caret, "^", 0, 0}; case '=': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::equal_equal); - else formToken(token, BufferPtr + 1, Token::equal); - break; - + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::equal_equal, "==", 0, 0}; + } + bufferPtr++; + return {Token::equal, "=", 0, 0}; case '!': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::not_equal); - else formToken(token, BufferPtr + 1, Token::KW_not); + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::not_equal, "!=", 0, 0}; + } break; - case '<': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::less_equal); - else formToken(token, BufferPtr + 1, Token::less); - break; - + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::less_equal, "<=", 0, 0}; + } + bufferPtr++; + return {Token::less, "<", 0, 0}; case '>': - if (*(BufferPtr + 1) == '=') formToken(token, BufferPtr + 2, Token::greater_equal); - else formToken(token, BufferPtr + 1, Token::greater); - break; - - default: - formToken(token, BufferPtr + 1, Token::unknown); - } -} - -void Lexer::skipMultiLineComment() { - BufferPtr += 2; - while (*BufferPtr && !(*BufferPtr == '*' && *(BufferPtr + 1) == '/')) { - ++BufferPtr; + if (*(bufferPtr + 1) == '=') { + bufferPtr += 2; + return {Token::greater_equal, ">=", 0, 0}; + } + bufferPtr++; + return {Token::greater, ">", 0, 0}; + case '_': + bufferPtr++; + return {Token::underscore, "_", 0, 0}; } - if (*BufferPtr) BufferPtr += 2; -} - -void Lexer::formToken(Token &Tok, const char *TokEnd, Token::TokenKind Kind) { - Tok.Kind = Kind; - Tok.Text = llvm::StringRef(BufferPtr, TokEnd - BufferPtr); - BufferPtr = TokEnd; + + return {Token::eof, llvm::StringRef(), 0, 0}; } From b13cca20d12778047f32d0745a9f44e7ecde20fa Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:46:50 +0330 Subject: [PATCH 08/25] Update lexer.cpp added a few comments --- code/lexer.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/code/lexer.cpp b/code/lexer.cpp index 65aba43..6cd3622 100644 --- a/code/lexer.cpp +++ b/code/lexer.cpp @@ -30,7 +30,7 @@ Token Lexer::nextToken() { const char *tokStart = bufferPtr; - // شناسه‌ها و کلمات کلیدی + // KWords & Identifiers if (isalpha(*bufferPtr)) { while (isalnum(*bufferPtr) || *bufferPtr == '_') bufferPtr++; llvm::StringRef text(tokStart, bufferPtr - tokStart); @@ -66,7 +66,7 @@ Token Lexer::nextToken() { return {kind, text, 0, 0}; } - // اعداد + // numbers if (isdigit(*bufferPtr) || *bufferPtr == '.') { bool hasDot = (*bufferPtr == '.'); while (isdigit(*bufferPtr) || (!hasDot && *bufferPtr == '.')) { @@ -77,7 +77,7 @@ Token Lexer::nextToken() { llvm::StringRef(tokStart, bufferPtr - tokStart), 0, 0}; } - // رشته‌ها + // strings if (*bufferPtr == '"') { bufferPtr++; const char *start = bufferPtr; @@ -85,7 +85,7 @@ Token Lexer::nextToken() { return {Token::string_literal, llvm::StringRef(start, bufferPtr - start), 0, 0}; } - // کاراکترها + // charachters if (*bufferPtr == '\'') { bufferPtr++; char c = *bufferPtr; @@ -93,7 +93,7 @@ Token Lexer::nextToken() { return {Token::char_literal, llvm::StringRef(&c, 1), 0, 0}; } - // عملگرها و سمبل‌ها + // operators & symbols switch (*bufferPtr) { case ';': bufferPtr++; return {Token::semi_colon, ";", 0, 0}; case ',': bufferPtr++; return {Token::comma, ",", 0, 0}; From a7ce273c40ec737bc152ddd1bcc22a1f5a6e8f62 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:48:29 +0330 Subject: [PATCH 09/25] Update parser.h and synced with lexer --- code/parser.h | 188 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 119 insertions(+), 69 deletions(-) diff --git a/code/parser.h b/code/parser.h index 58f4251..baeb87d 100644 --- a/code/parser.h +++ b/code/parser.h @@ -1,75 +1,125 @@ -#ifndef _PARSER_H -#define _PARSER_H -#include "AST.h" +// #ifndef _PARSER_H +// #define _PARSER_H +// #include "AST.h" +// #include "lexer.h" +// #include "llvm/Support/raw_ostream.h" +// #include +// using namespace std; + +// class Parser +// { +// Lexer &Lex; +// Token Tok; +// bool HasError; + +// void error() +// { +// llvm::errs() << "Unexpected: " << Tok.getText() << "\n"; +// HasError = true; +// } + +// void advance() { Lex.next(Tok); } + +// bool expect(Token::TokenKind Kind) +// { +// if (Tok.getKind() != Kind) +// { +// error(); +// return true; +// } +// return false; +// } + +// bool consume(Token::TokenKind Kind) +// { +// if (expect(Kind)) +// return true; +// advance(); +// return false; +// } + +// // parsing functions according to the regex +// // pattern specified. each one produces its own node +// // one node can have multiple subnodes inside it +// public: +// Base *parse(); +// Base *parseStatement(); +// IfStatement *parseIf(); +// ElseIfStatement *parseElseIf(); +// AssignStatement *parseUnaryExpression(Token &token); +// Expression *parseExpression(); +// Expression *parseLogicalComparison(); +// Expression *parseIntExpression(); +// Expression *parseTerm(); +// Expression *parseSign(); +// Expression *parsePower(); +// Expression *parseFactor(); +// ForStatement *parseFor(); +// WhileStatement *parseWhile(); +// AssignStatement *parseAssign(llvm::StringRef name); +// llvm::SmallVector parseDefine(Token::TokenKind token_kind); +// void check_for_semicolon(); + +// public: +// // initializes all members and retrieves the first token +// Parser(Lexer &Lex) : Lex(Lex), HasError(false) +// { +// advance(); +// } + +// // get the value of error flag +// bool hasError() { return HasError; } + +// }; + +// #endif + + + +#ifndef PARSER_H +#define PARSER_H + #include "lexer.h" -#include "llvm/Support/raw_ostream.h" -#include -using namespace std; - -class Parser -{ - Lexer &Lex; - Token Tok; - bool HasError; - - void error() - { - llvm::errs() << "Unexpected: " << Tok.getText() << "\n"; - HasError = true; - } - - void advance() { Lex.next(Tok); } - - bool expect(Token::TokenKind Kind) - { - if (Tok.getKind() != Kind) - { - error(); - return true; - } - return false; - } - - bool consume(Token::TokenKind Kind) - { - if (expect(Kind)) - return true; - advance(); - return false; - } - - // parsing functions according to the regex - // pattern specified. each one produces its own node - // one node can have multiple subnodes inside it -public: - Base *parse(); - Base *parseStatement(); - IfStatement *parseIf(); - ElseIfStatement *parseElseIf(); - AssignStatement *parseUnaryExpression(Token &token); - Expression *parseExpression(); - Expression *parseLogicalComparison(); - Expression *parseIntExpression(); - Expression *parseTerm(); - Expression *parseSign(); - Expression *parsePower(); - Expression *parseFactor(); - ForStatement *parseFor(); - WhileStatement *parseWhile(); - AssignStatement *parseAssign(llvm::StringRef name); - llvm::SmallVector parseDefine(Token::TokenKind token_kind); - void check_for_semicolon(); +#include "ast.h" +#include +#include -public: - // initializes all members and retrieves the first token - Parser(Lexer &Lex) : Lex(Lex), HasError(false) - { - advance(); - } +class Parser { + Lexer &lexer; + Token currentTok; + Token peekTok; - // get the value of error flag - bool hasError() { return HasError; } + void advance(); + void expect(Token::TokenKind kind); + void consume(Token::TokenKind kind); + +public: + Parser(Lexer &lexer); + std::unique_ptr parseProgram(); + // Statement Parsers + std::unique_ptr parseStatement(); + std::unique_ptr parseVarDecl(); + std::unique_ptr parseIfStatement(); + std::unique_ptr parseForLoop(); + std::unique_ptr parseWhileLoop(); + std::unique_ptr parsePrintStatement(); + std::unique_ptr parseBlock(); + + // Expression Parsers + std::unique_ptr parseExpression(); + std::unique_ptr parseLogicalOr(); + std::unique_ptr parseLogicalAnd(); + std::unique_ptr parseEquality(); + std::unique_ptr parseComparison(); + std::unique_ptr parseTerm(); + std::unique_ptr parseFactor(); + std::unique_ptr parseUnary(); + std::unique_ptr parsePrimary(); + + // Special + std::unique_ptr parseArrayLiteral(); + std::unique_ptr parseFunctionCall(); }; -#endif \ No newline at end of file +#endif From 97c4aa61cf5580ede5c2f49b1fe59936e194115f Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 21:53:44 +0330 Subject: [PATCH 10/25] Update parser.cpp & synced with parser.h --- code/parser.cpp | 960 ++++++++++++++++-------------------------------- 1 file changed, 323 insertions(+), 637 deletions(-) diff --git a/code/parser.cpp b/code/parser.cpp index 9b2165a..40dceb8 100644 --- a/code/parser.cpp +++ b/code/parser.cpp @@ -1,717 +1,403 @@ -#ifndef PARSER_H -#define PARSER_H -#include "error.h" #include "parser.h" -#include "optimizer.h" -#include -using namespace std; -#endif +#include +#include +#include -Base *Parser::parse() -{ - llvm::SmallVector statements; - bool isComment = false; - while (!Tok.is(Token::eof)) - { - if (isComment) - { - if (Tok.is(Token::uncomment)) - { - isComment = false; - } - advance(); - continue; - } +Parser::Parser(Lexer &lexer) : lexer(lexer) { + advance(); + advance(); +} - switch (Tok.getKind()) - { - case Token::KW_int: - { - llvm::SmallVector states = Parser::parseDefine(Token::KW_int); - if (states.size() == 0) - { - return nullptr; - } - while (states.size() > 0) - { - statements.push_back(states.back()); - states.pop_back(); - } - break; - } - case Token::KW_bool: - { - llvm::SmallVector states = Parser::parseDefine(Token::KW_bool); - if (states.size() == 0) - { - return nullptr; - } - while (states.size() > 0) - { - statements.push_back(states.back()); - states.pop_back(); - } - break; - } - case Token::identifier: - { - llvm::StringRef name = Tok.getText(); - Token current = Tok; - advance(); - if (!Tok.isOneOf(Token::plus_plus, Token::minus_minus)) - { - AssignStatement *assign = parseAssign(name); - statements.push_back(assign); - } - else - { - AssignStatement *assign = parseUnaryExpression(current); - statements.push_back(assign); - } - Parser::check_for_semicolon(); - break; - } - case Token::KW_print: - { - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - advance(); - // token should be identifier - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - } - Expression *variable_to_be_printed = new Expression(Tok.getText()); - advance(); - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - advance(); - Parser::check_for_semicolon(); - PrintStatement *print_statement = new PrintStatement(variable_to_be_printed); - statements.push_back(print_statement); - break; - } - case Token::comment: - { - isComment = true; - advance(); - break; - } - case Token::KW_if: - { - IfStatement *statement = parseIf(); - statements.push_back(statement); - break; - } - case Token::KW_while: - { - WhileStatement *statement = parseWhile(); - statements.push_back(statement); - break; - } - case Token::KW_for: - { - ForStatement *statement = parseFor(); - /*llvm::SmallVector unrolled = completeUnroll(statement); - for (Statement *s : unrolled) - { - statements.push_back(s); - }*/ - statements.push_back(statement); - break; - } - } - } - return new Base(statements); +void Parser::advance() { + currentTok = peekTok; + peekTok = lexer.nextToken(); } -void Parser::check_for_semicolon() -{ - if (!Tok.is(Token::semi_colon)) - { - Error::SemiColonExpected(); +void Parser::expect(Token::TokenKind kind) { + if (!currentTok.is(kind)) { + throw std::runtime_error("Syntax Error: Expected " + + tokenToString(kind) + + " but found " + + tokenToString(currentTok.kind) + + " at line " + std::to_string(currentTok.line)); } +} + +void Parser::consume(Token::TokenKind kind) { + expect(kind); advance(); } -AssignStatement *Parser::parseUnaryExpression(Token &token) -{ - AssignStatement* res; - if (Tok.is(Token::plus_plus)) - { - advance(); - if (token.is(Token::identifier)) - { - Expression *tok = new Expression(token.getText()); - Expression *one = new Expression(1); - res = new AssignStatement(tok, new BinaryOp(BinaryOp::Plus, tok, one)); - } - else - { - Error::VariableExpected(); - } +std::unique_ptr Parser::parseProgram() { + auto program = std::make_unique(); + while (!currentTok.is(Token::eof)) { + program->statements.push_back(parseStatement()); } - else if (Tok.is(Token::minus_minus)) - { - advance(); - if (token.is(Token::identifier)) - { - Expression *tok = new Expression(token.getText()); - Expression *one = new Expression(1); - res = new AssignStatement(tok, new BinaryOp(BinaryOp::Minus, tok, one)); - } - else - { - Error::VariableExpected(); - } + return program; +} + +std::unique_ptr Parser::parseStatement() { + switch (currentTok.kind) { + case Token::KW_int: + case Token::KW_bool: + case Token::KW_float: + case Token::KW_char: + case Token::KW_string: + case Token::KW_array: + return parseVarDecl(); + + case Token::KW_if: + return parseIfStatement(); + + case Token::KW_for: + return parseForLoop(); + + case Token::KW_while: + return parseWhileLoop(); + + case Token::KW_print: + return parsePrintStatement(); + + case Token::l_brace: + return parseBlock(); + + default: + return parseExpressionStatement(); } - return res; } -llvm::SmallVector Parser::parseDefine(Token::TokenKind token_kind) -{ + +// Variable Declaration +std::unique_ptr Parser::parseVarDecl() { + VarType type = tokenToVarType(currentTok.kind); advance(); - llvm::SmallVector states; - while (!Tok.is(Token::semi_colon)) - { - llvm::StringRef name; - Expression *value = nullptr; - if (Tok.is(Token::identifier)) - { - name = Tok.getText(); - advance(); - } - else - { - Error::VariableExpected(); - } - if (Tok.is(Token::equal)) - { + + std::vector declarations; + do { + std::string name = currentTok.text.str(); + consume(Token::identifier); + + std::unique_ptr value = nullptr; + if (currentTok.is(Token::equal)) { advance(); value = parseExpression(); } - if (Tok.is(Token::comma)) - { - advance(); - } - else if (Tok.is(Token::semi_colon)) - { - // OK do nothing. The while will break itself. - } - else - { - Error::VariableExpected(); - } - DecStatement *state; - if(token_kind == Token::KW_int){ - state = new DecStatement(new Expression(name), value, DecStatement::DecStatementType::Number); - }else{ - state = new DecStatement(new Expression(name), value, DecStatement::DecStatementType::Boolean); - } - states.push_back(state); - } - - advance(); // pass semicolon - return states; + declarations.emplace_back(type, name, std::move(value)); + + } while (currentTok.is(Token::comma) && advance()); + + consume(Token::semi_colon); + return std::make_unique(std::move(declarations)); } -Expression *Parser::parseExpression() -{ - Expression *left = parseLogicalComparison(); - while (Tok.isOneOf(Token::KW_and, Token::KW_or)) - { - BooleanOp::Operator Op; - switch (Tok.getKind()) - { - case Token::KW_and: - Op = BooleanOp::And; - break; - case Token::KW_or: - Op = BooleanOp::Or; - break; - default: - break; - } +// If Statement +std::unique_ptr Parser::parseIfStatement() { + consume(Token::KW_if); + consume(Token::l_paren); + auto condition = parseExpression(); + consume(Token::r_paren); + + auto thenBlock = parseBlock(); + + std::unique_ptr elseBlock = nullptr; + if (currentTok.is(Token::KW_else)) { advance(); - Expression *Right = parseLogicalComparison(); - left = new BooleanOp(Op, left, Right); + elseBlock = currentTok.is(Token::KW_if) ? parseIfStatement() : parseBlock(); } - return left; + + return std::make_unique( + std::move(condition), + std::move(thenBlock), + std::move(elseBlock) + ); } -Expression *Parser::parseLogicalComparison() -{ - Expression *left = parseIntExpression(); - while (Tok.isOneOf(Token::equal_equal, Token::not_equal, Token::less, Token::less_equal, Token::greater, Token::greater_equal)) - { - BooleanOp::Operator Op; - switch (Tok.getKind()) - { - case Token::equal_equal: - Op = BooleanOp::Equal; - break; - case Token::not_equal: - Op = BooleanOp::NotEqual; - break; - case Token::less: - Op = BooleanOp::Less; - break; - case Token::less_equal: - Op = BooleanOp::LessEqual; - break; - case Token::greater: - Op = BooleanOp::Greater; - break; - case Token::greater_equal: - Op = BooleanOp::GreaterEqual; - break; - default: - break; - } - advance(); - Expression *Right = parseIntExpression(); - left = new BooleanOp(Op, left, Right); +// For Loop +std::unique_ptr Parser::parseForLoop() { + consume(Token::KW_for); + consume(Token::l_paren); + + // Initialization + std::unique_ptr init = nullptr; + if (!currentTok.is(Token::semi_colon)) { + init = parseVarDeclOrExpr(); } - return left; + consume(Token::semi_colon); + + // Condition + std::unique_ptr cond = nullptr; + if (!currentTok.is(Token::semi_colon)) { + cond = parseExpression(); + } + consume(Token::semi_colon); + + // Update + std::unique_ptr update = nullptr; + if (!currentTok.is(Token::r_paren)) { + update = parseExpression(); + } + consume(Token::r_paren); + + auto body = parseBlock(); + + return std::make_unique( + std::move(init), + std::move(cond), + std::move(update), + std::move(body) + ); } -Expression *Parser::parseIntExpression() -{ - Expression *Left = parseTerm(); - while (Tok.isOneOf(Token::plus, Token::minus)) - { - BinaryOp::Operator Op = - Tok.is(Token::plus) ? BinaryOp::Plus : BinaryOp::Minus; - advance(); - Expression *Right = parseTerm(); - Left = new BinaryOp(Op, Left, Right); - } - return Left; +// While Loop +std::unique_ptr Parser::parseWhileLoop() { + consume(Token::KW_while); + consume(Token::l_paren); + auto condition = parseExpression(); + consume(Token::r_paren); + + auto body = parseBlock(); + + return std::make_unique( + std::move(condition), + std::move(body) + ); } -Expression *Parser::parseTerm() -{ - Expression *Left = parseSign(); - while (Tok.isOneOf(Token::star, Token::slash, Token::mod)) - { - BinaryOp::Operator Op = - Tok.is(Token::star) ? BinaryOp::Mul : Tok.is(Token::slash) ? BinaryOp::Div - : BinaryOp::Mod; - advance(); - Expression *Right = parseSign(); - Left = new BinaryOp(Op, Left, Right); - } - return Left; +// Print Statement +std::unique_ptr Parser::parsePrintStatement() { + consume(Token::KW_print); + consume(Token::l_paren); + auto expr = parseExpression(); + consume(Token::r_paren); + consume(Token::semi_colon); + return std::make_unique(std::move(expr)); } -Expression *Parser::parseSign() -{ - if (Tok.is(Token::minus)) - { - advance(); - return new BinaryOp(BinaryOp::Mul, new Expression(-1), parsePower()); +// Block Statement +std::unique_ptr Parser::parseBlock() { + consume(Token::l_brace); + auto block = std::make_unique(); + + while (!currentTok.is(Token::r_brace) && !currentTok.is(Token::eof)) { + block->statements.push_back(parseStatement()); } - else if (Tok.is(Token::plus)) - { + + consume(Token::r_brace); + return block; +} + +// Expression Parsing +std::unique_ptr Parser::parseExpression() { + return parseAssignment(); +} + +std::unique_ptr Parser::parseAssignment() { + auto left = parseLogicalOr(); + + if (currentTok.isOneOf(Token::equal, Token::plus_equal, Token::minus_equal, + Token::star_equal, Token::slash_equal, Token::mod_equal)) { + Token op = currentTok; advance(); - return parsePower(); + auto right = parseAssignment(); + return std::make_unique(std::move(left), op, std::move(right)); } - return parsePower(); + + return left; } -Expression *Parser::parsePower() -{ - Expression *Left = parseFactor(); - while (Tok.is(Token::power)) - { - BinaryOp::Operator Op = - BinaryOp::Pow; +std::unique_ptr Parser::parseLogicalOr() { + auto left = parseLogicalAnd(); + while (currentTok.is(Token::or_op)) { advance(); - Expression *Right = parseFactor(); - Left = new BinaryOp(Op, Left, Right); + auto right = parseLogicalAnd(); + left = std::make_unique(BinaryOp::LogicalOr, std::move(left), std::move(right)); } - return Left; + return left; } -Expression *Parser::parseFactor() -{ - Expression *Res = nullptr; - switch (Tok.getKind()) - { - case Token::number: - { - int number; - Tok.getText().getAsInteger(10, number); - Res = new Expression(number); +std::unique_ptr Parser::parseLogicalAnd() { + auto left = parseEquality(); + while (currentTok.is(Token::and_op)) { advance(); - break; + auto right = parseEquality(); + left = std::make_unique(BinaryOp::LogicalAnd, std::move(left), std::move(right)); } - case Token::identifier: - { - Res = new Expression(Tok.getText()); + return left; +} + +std::unique_ptr Parser::parseEquality() { + auto left = parseComparison(); + while (currentTok.isOneOf(Token::equal_equal, Token::not_equal)) { + BinaryOp op = currentTok.is(Token::equal_equal) ? BinaryOp::Equal : BinaryOp::NotEqual; advance(); - break; + auto right = parseComparison(); + left = std::make_unique(op, std::move(left), std::move(right)); } - case Token::l_paren: - { + return left; +} + +std::unique_ptr Parser::parseComparison() { + auto left = parseAddSub(); + while (currentTok.isOneOf(Token::less, Token::less_equal, Token::greater, Token::greater_equal)) { + BinaryOp op; + switch (currentTok.kind) { + case Token::less: op = BinaryOp::Less; break; + case Token::less_equal: op = BinaryOp::LessEqual; break; + case Token::greater: op = BinaryOp::Greater; break; + case Token::greater_equal: op = BinaryOp::GreaterEqual; break; + default: throw std::runtime_error("Invalid comparison operator"); + } advance(); - Res = parseExpression(); - if (!consume(Token::r_paren)) - break; + auto right = parseAddSub(); + left = std::make_unique(op, std::move(left), std::move(right)); } - case Token::KW_true: - { - Res = new Expression(true); + return left; +} + +std::unique_ptr Parser::parseAddSub() { + auto left = parseMulDiv(); + while (currentTok.isOneOf(Token::plus, Token::minus)) { + BinaryOp op = currentTok.is(Token::plus) ? BinaryOp::Add : BinaryOp::Subtract; advance(); - break; + auto right = parseMulDiv(); + left = std::make_unique(op, std::move(left), std::move(right)); } - case Token::KW_false: - { - Res = new Expression(false); + return left; +} + +std::unique_ptr Parser::parseMulDiv() { + auto left = parseUnary(); + while (currentTok.isOneOf(Token::star, Token::slash, Token::mod)) { + BinaryOp op; + switch (currentTok.kind) { + case Token::star: op = BinaryOp::Multiply; break; + case Token::slash: op = BinaryOp::Divide; break; + case Token::mod: op = BinaryOp::Modulus; break; + default: throw std::runtime_error("Invalid multiplicative operator"); + } advance(); - break; - } - default: // error handling - { - Error::NumberVariableExpected(); + auto right = parseUnary(); + left = std::make_unique(op, std::move(left), std::move(right)); } - } - return Res; + return left; } -AssignStatement *Parser::parseAssign(llvm::StringRef name) -{ - Expression *value = nullptr; - if (Tok.is(Token::equal)) - { - advance(); - value = parseExpression(); - }else if(Tok.isOneOf(Token::plus_equal, Token::minus_equal, Token::star_equal, Token::slash_equal, Token::mod_equal)){ - // storing token - Token current_op = Tok; - advance(); - value = parseExpression(); - if(current_op.is(Token::plus_equal)){ - value = new BinaryOp(BinaryOp::Plus, new Expression(name), value); - }else if(current_op.is(Token::minus_equal)){ - value = new BinaryOp(BinaryOp::Minus, new Expression(name), value); - }else if(current_op.is(Token::star_equal)){ - value = new BinaryOp(BinaryOp::Mul, new Expression(name), value); - }else if(current_op.is(Token::slash_equal)){ - value = new BinaryOp(BinaryOp::Div, new Expression(name), value); - }else if(current_op.is(Token::mod_equal)){ - value = new BinaryOp(BinaryOp::Mod, new Expression(name), value); +std::unique_ptr Parser::parseUnary() { + if (currentTok.isOneOf(Token::plus, Token::minus, Token::not_op)) { + UnaryOp op; + switch (currentTok.kind) { + case Token::plus: op = UnaryOp::Plus; break; + case Token::minus: op = UnaryOp::Minus; break; + case Token::not_op: op = UnaryOp::LogicalNot; break; + default: throw std::runtime_error("Invalid unary operator"); } - }else{ - Error::EqualExpected(); + advance(); + auto operand = parseUnary(); + return std::make_unique(op, std::move(operand)); } - - return new AssignStatement(new Expression(name), value); + return parsePrimary(); } -Base *Parser::parseStatement() -{ - llvm::SmallVector statements; - bool isComment = false; - while (!Tok.is(Token::r_brace) && !Tok.is(Token::eof)) - { - if (isComment) - { - if (Tok.is(Token::uncomment)) - { - isComment = false; - } +std::unique_ptr Parser::parsePrimary() { + switch (currentTok.kind) { + case Token::number: + case Token::float_literal: { + auto value = currentTok.text.str(); advance(); - continue; + return std::make_unique(value); } - - switch (Tok.getKind()) - { - case Token::identifier: - { - llvm::StringRef name = Tok.getText(); - Token current = Tok; + + case Token::string_literal: { + auto value = currentTok.text.str(); advance(); - if (!Tok.isOneOf(Token::plus_plus, Token::minus_minus)) - { - AssignStatement *assign = parseAssign(name); - statements.push_back(assign); - } - else - { - AssignStatement *assign = parseUnaryExpression(current); - statements.push_back(assign); - } - Parser::check_for_semicolon(); - break; + return std::make_unique(value); } - case Token::KW_print: - { - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } + + case Token::KW_true: + case Token::KW_false: { + bool value = currentTok.is(Token::KW_true); advance(); - // token should be identifier - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - } - Expression *tok = new Expression(Tok.getText()); + return std::make_unique(value); + } + + case Token::identifier: { + std::string name = currentTok.text.str(); advance(); - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); + + // Function call or array access + if (currentTok.is(Token::l_paren)) { + return parseFunctionCall(name); + } else if (currentTok.is(Token::l_bracket)) { + return parseArrayAccess(name); } - advance(); - Parser::check_for_semicolon(); - PrintStatement *print_statement = new PrintStatement(tok); - statements.push_back(print_statement); - break; + + return std::make_unique(name); } - case Token::comment: - { - isComment = true; + + case Token::l_paren: { advance(); - break; - } - case Token::KW_if: - { - IfStatement *statement = parseIf(); - statements.push_back(statement); - break; - } - case Token::KW_while: - { - WhileStatement* statement = parseWhile(); - statements.push_back(statement); - break; - } - case Token::KW_for: - { - ForStatement* statement = parseFor(); - statements.push_back(statement); - break; + auto expr = parseExpression(); + consume(Token::r_paren); + return expr; } + + case Token::l_bracket: + return parseArrayLiteral(); + default: - { - Error::UnexpectedToken(Tok); - } - } + throw std::runtime_error("Unexpected primary expression"); } - return new Base(statements); } -IfStatement *Parser::parseIf() -{ - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - - advance(); - Expression *condition = parseExpression(); - - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (!Tok.is(Token::l_brace)) - { - Error::LeftBraceExpected(); - } - advance(); - Base *allIfStatements = parseStatement(); - if (!Tok.is(Token::r_brace)) - { - Error::RightBraceExpected(); - } - advance(); - - // parse else if and else statements - llvm::SmallVector elseIfStatements; - ElseStatement *elseStatement = nullptr; - bool hasElseIf = false; - bool hasElse = false; - - while (Tok.is(Token::KW_else)) - { - advance(); - if (Tok.is(Token::KW_if)) - { - elseIfStatements.push_back(parseElseIf()); - hasElseIf = true; - } - else if (Tok.is(Token::l_brace)) - { - advance(); - Base *allElseStatements = parseStatement(); - if (!Tok.is(Token::r_brace)) - { - Error::RightBraceExpected(); - } - advance(); - elseStatement = new ElseStatement(allElseStatements->getStatements(), Statement::StatementType::Else); - hasElse = true; - break; - } - else - { - Error::RightBraceExpected(); - } +std::unique_ptr Parser::parseArrayLiteral() { + consume(Token::l_bracket); + std::vector> elements; + + if (!currentTok.is(Token::r_bracket)) { + do { + elements.push_back(parseExpression()); + } while (currentTok.is(Token::comma) && advance()); } - - return new IfStatement(condition, allIfStatements->getStatements(), elseIfStatements, elseStatement, hasElseIf, hasElse, Statement::StatementType::If); + + consume(Token::r_bracket); + return std::make_unique(std::move(elements)); } -ElseIfStatement *Parser::parseElseIf() -{ - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - - advance(); - Expression *condition = parseExpression(); - - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (!Tok.is(Token::l_brace)) - { - Error::LeftBraceExpected(); - } - advance(); - Base *allIfStatements = parseStatement(); - if (!Tok.is(Token::r_brace)) - { - Error::RightBraceExpected(); - } - advance(); - - return new ElseIfStatement(condition, allIfStatements->getStatements(), Statement::StatementType::ElseIf); +std::unique_ptr Parser::parseArrayAccess(const std::string& name) { + consume(Token::l_bracket); + auto index = parseExpression(); + consume(Token::r_bracket); + return std::make_unique(name, std::move(index)); } -WhileStatement *Parser::parseWhile() -{ - advance(); - if (!Tok.is(Token::l_paren)) - { - Error::LeftParenthesisExpected(); - } - - advance(); - Expression *condition = parseExpression(); - - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (Tok.is(Token::l_brace)) - { - advance(); - Base *allWhileStatements = parseStatement(); - if(!consume(Token::r_brace)) - { - return new WhileStatement(condition, allWhileStatements->getStatements(), Statement::StatementType::While); - } - else - { - Error::RightBraceExpected(); - } - - } - else{ - Error::LeftBraceExpected(); +std::unique_ptr Parser::parseFunctionCall(const std::string& name) { + consume(Token::l_paren); + std::vector> args; + + if (!currentTok.is(Token::r_paren)) { + do { + args.push_back(parseExpression()); + } while (currentTok.is(Token::comma) && advance()); } - advance(); + + consume(Token::r_paren); + return std::make_unique(name, std::move(args)); } -ForStatement *Parser::parseFor() -{ - advance(); - if (!Tok.is(Token::l_paren)) // for(i = 0;i<10;i++) - { - Error::LeftParenthesisExpected(); - } - advance(); //i = 0;i<10;i++) - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - - } - llvm::StringRef name = Tok.getText(); - advance(); //= 0;i<10;i++) - AssignStatement *assign = parseAssign(name); - //Expression *value = nullptr; - /*if (Tok.is(Token::equal)) - { - advance(); //0;i<10;i++) - value = parseExpression(); - AssignStatement *assign = new AssignStatement(new Expression(name), value); - }*/ - //advance(); - check_for_semicolon(); - Expression *condition = parseExpression(); - //advance(); - check_for_semicolon(); - if (!Tok.is(Token::identifier)) - { - Error::VariableExpected(); - - } - llvm::StringRef name_up = Tok.getText(); - Token current = Tok; - AssignStatement *assign_up = nullptr; - advance(); - if (!Tok.isOneOf(Token::plus_plus, Token::minus_minus)) - { - assign_up = parseAssign(name_up); +// Helper Functions +VarType Parser::tokenToVarType(Token::TokenKind kind) { + switch (kind) { + case Token::KW_int: return VarType::Int; + case Token::KW_bool: return VarType::Bool; + case Token::KW_float: return VarType::Float; + case Token::KW_char: return VarType::Char; + case Token::KW_string: return VarType::String; + case Token::KW_array: return VarType::Array; + default: throw std::runtime_error("Invalid variable type"); } - else - { - assign_up = parseUnaryExpression(current); - } - if (!Tok.is(Token::r_paren)) - { - Error::RightParenthesisExpected(); - } - - advance(); - if (Tok.is(Token::l_brace)) - { - advance(); - Base *allForStatements = parseStatement(); - if(!consume(Token::r_brace)) - { - return new ForStatement(condition, allForStatements->getStatements(),assign,assign_up, Statement::StatementType::For); - } - else - { - Error::RightBraceExpected(); - } - - } - else{ - Error::LeftBraceExpected(); - } - advance(); - - } +std::string Parser::tokenToString(Token::TokenKind kind) { + // Implementation depends on your token names + return "Token"; +} From 6a59bd27d3483648ee09fb4fa2c66900554ca003 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:17:23 +0330 Subject: [PATCH 11/25] Update AST.h and changed a lot --- code/AST.h | 662 ++++++++++++++++++----------------------------------- 1 file changed, 217 insertions(+), 445 deletions(-) diff --git a/code/AST.h b/code/AST.h index 5fbf60b..9dd96a4 100644 --- a/code/AST.h +++ b/code/AST.h @@ -1,509 +1,281 @@ #ifndef AST_H #define AST_H -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include + +enum class VarType { + INT, BOOL, FLOAT, CHAR, + STRING, ARRAY, ERROR, NEUTRAL +}; +enum class BinaryOp { + ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD, + EQUAL, NOT_EQUAL, LESS, LESS_EQUAL, GREATER, GREATER_EQUAL, + AND, OR, INDEX, CONCAT, POW, + ARRAY_ADD, ARRAY_SUBTRACT, ARRAY_MULTIPLY, ARRAY_DIVIDE +}; -class AST; // Abstract Syntax Tree -class Base; // top level program -class PrintStatement; -class BinaryOp; // binary operation of numbers and identifiers -class Statement; // top level statement -class BooleanOp; // boolean operation like 3 > 6*2; -class Expression; // top level expression that is evaluated to boolean, int or variable name at last -class IfStatement; -class DecStatement; // declaration statement like int a; -class ElseIfStatement; -class ElseStatement; -class AssignStatement; // assignment statement like a = 3; -class ForStatement; -class WhileStatement; -// class afterCheckStatement; +enum class UnaryOp { + INCREMENT, DECREMENT, + LENGTH, MIN, MAX, ABS +}; +enum class LoopType { FOR, FOREACH, WHILE }; -class ASTVisitor -{ -public: - // Virtual visit functions for each AST node type - virtual void visit(AST&) {} - virtual void visit(Expression&) {} - virtual void visit(Base&) = 0; - virtual void visit(Statement&) = 0; - virtual void visit(BinaryOp&) = 0; - virtual void visit(DecStatement&) = 0; - virtual void visit(AssignStatement&) = 0; - virtual void visit(BooleanOp&) = 0; - virtual void visit(IfStatement&) = 0; - virtual void visit(ElseIfStatement&) = 0; - virtual void visit(ElseStatement&) = 0; - virtual void visit(PrintStatement&) = 0; - virtual void visit(ForStatement&) = 0; - virtual void visit(WhileStatement&) = 0; +enum class StmtType { + IF, ELSE_IF, ELSE, + FOR, WHILE, TRY_CATCH, + VAR_DECL, ASSIGN, PRINT, + BLOCK, MATCH }; -class AST { +class ASTNode { public: - virtual ~AST() {} - virtual void accept(ASTVisitor& V) = 0; + virtual ~ASTNode() = default; + virtual void print(llvm::raw_ostream &os, int indent = 0) const = 0; }; -class TopLevelEntity : AST { +// barname - majmoo e ii az gozaare ha +class ProgramNode : public ASTNode { public: - TopLevelEntity() {} + std::vector> statements; + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class Expression : public TopLevelEntity { +// Var decleration +class VarDeclNode : public ASTNode { public: - enum ExpressionType { - Number, - Identifier, - Boolean, - BinaryOpType, - BooleanOpType - }; -private: - ExpressionType Type; - llvm::StringRef Value; - int NumberVal; - bool BoolVal; - BooleanOp* BOVal; + VarType type; + std::string name; + std::unique_ptr value; + + VarDeclNode(VarType t, std::string n, std::unique_ptr v) + : type(t), name(n), value(std::move(v)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// tarif chand moteghayyer +class MultiVarDeclNode : public ASTNode { public: - Expression() {} - Expression(llvm::StringRef value) : Type(ExpressionType::Identifier), Value(value) {} // store string - Expression(int value) : Type(ExpressionType::Number), NumberVal(value) {} // store number - Expression(bool value) : Type(ExpressionType::Boolean), BoolVal(value) {} // store boolean - Expression(BooleanOp* value) : Type(ExpressionType::BooleanOpType), BOVal(value) {} // store boolean - Expression(ExpressionType type) : Type(type) {} - - bool isNumber() { - if (Type == ExpressionType::Number) - return true; - return false; - } - - bool isBoolean() { - if (Type == ExpressionType::Boolean) - return true; - return false; - } - - bool isVariable() { - if (Type == ExpressionType::Identifier) - return true; - return false; - } - - bool isBinaryOp() { - if (Type == ExpressionType::BinaryOpType) - return true; - return false; - } - - bool isBooleanOp() { - if (Type == ExpressionType::BooleanOpType) - return true; - return false; - } - - llvm::StringRef getValue() { - return Value; - } - - int getNumber() { - return NumberVal; - } - - BooleanOp* getBooleanOp() { - return BOVal; - } - - bool getBoolean() { - return BoolVal; - } - ExpressionType getKind() - { - return Type; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::vector> declarations; + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class BooleanOp : public Expression -{ +// Assignment +class AssignNode : public ASTNode { public: - enum Operator - { - LessEqual, - Less, - Greater, - GreaterEqual, - Equal, - NotEqual, - And, - Or - }; - -private: - Expression* Left; - Expression* Right; - Operator Op; + std::string target; + BinaryOp op; + std::unique_ptr value; + + AssignNode(std::string t, BinaryOp o, std::unique_ptr v) + : target(t), op(o), value(std::move(v)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// Refrence to Variable +class VarRefNode : public ASTNode { public: - BooleanOp(Operator Op, Expression* L, Expression* R) : Op(Op), Left(L), Right(R), Expression(ExpressionType::BooleanOpType) { } - - Expression* getLeft() { return Left; } - - Expression* getRight() { return Right; } - - Operator getOperator() { return Op; } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::string name; + + VarRefNode(std::string n) : name(n) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class BinaryOp : public Expression -{ +// Literals +template +class LiteralNode : public ASTNode { public: - enum Operator - { - Plus, - Minus, - Mul, - Div, - Mod, - Pow - }; + T value; + + LiteralNode(T val) : value(val) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override { + os << value; + } +}; -private: - Expression* Left; // Left-hand side expression - Expression* Right; // Right-hand side expression - Operator Op; // Operator of the binary operation +using IntLiteral = LiteralNode; +using FloatLiteral = LiteralNode; +using BoolLiteral = LiteralNode; +using CharLiteral = LiteralNode; +using StrLiteral = LiteralNode; +// Binaray +class BinaryOpNode : public ASTNode { public: - BinaryOp(Operator Op, Expression* L, Expression* R) : Op(Op), Left(L), Right(R), Expression(ExpressionType::BinaryOpType) {} - - Expression* getLeft() { return Left; } - - Expression* getRight() { return Right; } - - Operator getOperator() { return Op; } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + BinaryOp op; + std::unique_ptr left; + std::unique_ptr right; + + BinaryOpNode(BinaryOp o, std::unique_ptr l, std::unique_ptr r) + : op(o), left(std::move(l)), right(std::move(r)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class Base : public AST -{ -private: - llvm::SmallVector statements; - +// Unary +class UnaryOpNode : public ASTNode { public: - Base(llvm::SmallVector Statements) : statements(Statements) {} - llvm::SmallVector getStatements() { return statements; } - - llvm::SmallVector::const_iterator begin() { return statements.begin(); } - - llvm::SmallVector::const_iterator end() { return statements.end(); } - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + UnaryOp op; + std::unique_ptr operand; + + UnaryOpNode(UnaryOp o, std::unique_ptr opnd) + : op(o), operand(std::move(opnd)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -// stores information of a statement. For example x=56; is a statement -class Statement : public TopLevelEntity -{ +// Code Block +class BlockNode : public ASTNode { public: - enum StatementType - { - If, - ElseIf, - Else, - Print, - Declaration, - Assignment, - While, - For - }; - -private: - StatementType Type; + std::vector> statements; + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// if/else +class IfElseNode : public ASTNode { public: - StatementType getKind() - { - return Type; - } - - Statement(StatementType type) : Type(type) {} - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::unique_ptr condition; + std::unique_ptr thenBlock; + std::unique_ptr elseBlock; // Can be BlockNode or IfElseNode + + IfElseNode(std::unique_ptr cond, + std::unique_ptr thenBlk, + std::unique_ptr elseBlk = nullptr) + : condition(std::move(cond)), + thenBlock(std::move(thenBlk)), + elseBlock(std::move(elseBlk)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class PrintStatement : public Statement -{ - private: - Expression * expr; - public: - // contructor that sets type - PrintStatement(Expression * identifier) : expr(identifier), Statement(Statement::StatementType::Print) {} - Expression * getExpr() - { - return expr; - } - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } +// for +class ForLoopNode : public ASTNode { +public: + std::unique_ptr init; + std::unique_ptr condition; + std::unique_ptr update; + std::unique_ptr body; + + ForLoopNode(std::unique_ptr i, + std::unique_ptr cond, + std::unique_ptr upd, + std::unique_ptr b) + : init(std::move(i)), condition(std::move(cond)), + update(std::move(upd)), body(std::move(b)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class DecStatement : public Statement { +// foreach +class ForeachLoopNode : public ASTNode { public: - enum DecStatementType { - Number, - Boolean, - }; -private: - - Expression* lvalue; - Expression* rvalue; - Statement::StatementType type; - DecStatement::DecStatementType dec_type; + std::string varName; + std::unique_ptr collection; + std::unique_ptr body; + + ForeachLoopNode(std::string var, + std::unique_ptr coll, + std::unique_ptr b) + : varName(var), collection(std::move(coll)), body(std::move(b)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; +}; +// print +class PrintNode : public ASTNode { public: - DecStatement(Expression* lvalue, Expression* rvalue) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Declaration), Statement(type) { } - DecStatement(Expression* lvalue) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Declaration), Statement(type) { rvalue = new Expression(0); } - DecStatement(Expression* lvalue, Expression* rvalue, DecStatement::DecStatementType dec_type) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Declaration), dec_type(dec_type), Statement(type) { } - - Expression* getLValue() { - return lvalue; - } - - Expression* getRValue() { - return rvalue; - } - - DecStatementType getDecType() { - return dec_type; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::unique_ptr expr; + + PrintNode(std::unique_ptr e) : expr(std::move(e)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; - -class AssignStatement : public Statement { -private: - - Expression* lvalue; - Expression* rvalue; - Statement::StatementType type; - +// araye +class ArrayNode : public ASTNode { public: - AssignStatement(Expression* lvalue, Expression* rvalue) : lvalue(lvalue), rvalue(rvalue), type(Statement::StatementType::Assignment), Statement(type) { } - Expression* getLValue() { - return lvalue; - } - - Expression* getRValue() { - return rvalue; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } + std::vector> elements; + + ArrayNode(std::vector> elems) + : elements(std::move(elems)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class IfStatement : public Statement -{ - -private: - Expression *condition; - llvm::SmallVector statements; - llvm::SmallVector elseIfStatements; - ElseStatement *elseStatement; - bool hasElseIf; - bool hasElse; - +// nth khoone Arraye +class ArrayAccessNode : public ASTNode { public: - IfStatement(Expression *condition, - llvm::SmallVector statements, - llvm::SmallVector elseIfStatements, - ElseStatement *elseStatement, - bool hasElseIf, - bool hasElse, - StatementType type) : condition(condition), - statements(statements), - elseIfStatements(elseIfStatements), - elseStatement(elseStatement), - hasElseIf(hasElseIf), - hasElse(hasElse), - Statement(type) {} - - Expression *getCondition() - { - return condition; - } - - bool HasElseIf() - { - return hasElseIf; - } - - bool HasElse() - { - return hasElse; - } - - llvm::SmallVector getElseIfStatements() - { - return elseIfStatements; - } - - llvm::SmallVector getStatements() - { - return statements; - } - - ElseStatement *getElseStatement() - { - return elseStatement; - } - - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::string arrayName; + std::unique_ptr index; + + ArrayAccessNode(std::string name, std::unique_ptr idx) + : arrayName(name), index(std::move(idx)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class ElseIfStatement : public Statement -{ - -private: - Expression *condition; - llvm::SmallVector statements; - +// Concat +class ConcatNode : public ASTNode { public: - ElseIfStatement(Expression *condition, llvm::SmallVector statements, StatementType type) : condition(condition), statements(statements), Statement(type) {} - - Expression *getCondition() - { - return condition; - } - - llvm::SmallVector getStatements() - { - return statements; - } - - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::unique_ptr left; + std::unique_ptr right; + + ConcatNode(std::unique_ptr l, std::unique_ptr r) + : left(std::move(l)), right(std::move(r)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class ElseStatement : public Statement -{ - -private: - llvm::SmallVector statements; - +class PowNode : public ASTNode { public: - ElseStatement(llvm::SmallVector statements, Statement::StatementType type) : statements(statements), Statement(type) {} - - llvm::SmallVector getStatements() - { - return statements; - } - - virtual void accept(ASTVisitor &V) override - { - V.visit(*this); - } + std::unique_ptr base; + std::unique_ptr exponent; + + PowNode(std::unique_ptr b, std::unique_ptr e) + : base(std::move(b)), exponent(std::move(e)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class WhileStatement : public Statement { - -private: - Expression* condition; - llvm::SmallVector statements; - bool optimized; - +// Error Control +class TryCatchNode : public ASTNode { public: - WhileStatement(Expression* condition, llvm::SmallVector statements, StatementType type) : condition(condition), statements(statements), Statement(type) {} - WhileStatement(Expression* condition, llvm::SmallVector statements, StatementType type, bool optimized) : condition(condition), statements(statements), Statement(type), optimized(optimized) { } - Expression* getCondition() - { - return condition; - } - - llvm::SmallVector getStatements() - { - return statements; - } - - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } - - bool isOptimized(){ - return optimized; - } + std::unique_ptr tryBlock; + std::unique_ptr catchBlock; + std::string errorVar; + + TryCatchNode(std::unique_ptr tryBlk, + std::unique_ptr catchBlk, + std::string errVar) + : tryBlock(std::move(tryBlk)), + catchBlock(std::move(catchBlk)), + errorVar(errVar) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; -class ForStatement : public Statement { - -private: - Expression* condition; - llvm::SmallVector statements; - AssignStatement *initial_assign; - AssignStatement *update_assign; - bool optimized; +// match pattern +class MatchNode : public ASTNode { public: - ForStatement(Expression* condition,llvm::SmallVector statements,AssignStatement *initial_assign,AssignStatement *update_assign, Statement type ) : condition(condition), statements(statements),initial_assign(initial_assign),update_assign(update_assign) , Statement(type){} - ForStatement(Expression* condition,llvm::SmallVector statements,AssignStatement *initial_assign,AssignStatement *update_assign, Statement type, bool optimized) : condition(condition), statements(statements),initial_assign(initial_assign),update_assign(update_assign) , Statement(type), optimized(optimized){} - Expression* getCondition() - { - return condition; - } - - llvm::SmallVector getStatements() - { - return statements; - } - bool isOptimized(){ - return optimized; - } - AssignStatement* getInitialAssign(){ - return initial_assign; - } - AssignStatement* getUpdateAssign(){ - return update_assign; - } - virtual void accept(ASTVisitor& V) override - { - V.visit(*this); - } - + std::unique_ptr expr; + std::vector> cases; + + MatchNode(std::unique_ptr e, std::vector> c) + : expr(std::move(e)), cases(std::move(c)) {} + + void print(llvm::raw_ostream &os, int indent = 0) const override; }; + #endif From 5aee9eff34e699668e7b6be0c33ef7552d20a08e Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:32:18 +0330 Subject: [PATCH 12/25] 3 little lines --- code/lexer.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/code/lexer.cpp b/code/lexer.cpp index 6cd3622..a7f958a 100644 --- a/code/lexer.cpp +++ b/code/lexer.cpp @@ -61,6 +61,9 @@ Token Lexer::nextToken() { else if (text == "length") kind = Token::KW_length; else if (text == "min") kind = Token::KW_min; else if (text == "max") kind = Token::KW_max; + else if (text == "and") kind = Token::and_op; + else if (text == "or") kind = Token::or_op; + else if (text == "not") kind = Token::not_op; else if (text == "index") kind = Token::KW_index; return {kind, text, 0, 0}; From 5fd2de0ecc4e18bdbcdf1449842cdbddc6bc3d7b Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:35:54 +0330 Subject: [PATCH 13/25] Update lexer.cpp with few things float for example --- code/lexer.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/code/lexer.cpp b/code/lexer.cpp index a7f958a..7b7a8e4 100644 --- a/code/lexer.cpp +++ b/code/lexer.cpp @@ -71,15 +71,18 @@ Token Lexer::nextToken() { // numbers if (isdigit(*bufferPtr) || *bufferPtr == '.') { - bool hasDot = (*bufferPtr == '.'); - while (isdigit(*bufferPtr) || (!hasDot && *bufferPtr == '.')) { - if (*bufferPtr == '.') hasDot = true; - bufferPtr++; - } - return {hasDot ? Token::float_literal : Token::number, - llvm::StringRef(tokStart, bufferPtr - tokStart), 0, 0}; + bool hasDot = false; + if (*bufferPtr == '.') { + hasDot = true; + bufferPtr++; } - + while (isdigit(*bufferPtr) || (!hasDot && *bufferPtr == '.')) { + if (*bufferPtr == '.') hasDot = true; + bufferPtr++; + } + return {hasDot ? Token::float_literal : Token::number, + llvm::StringRef(tokStart, bufferPtr - tokStart), 0, 0};} + // strings if (*bufferPtr == '"') { bufferPtr++; @@ -164,8 +167,10 @@ Token Lexer::nextToken() { if (*(bufferPtr + 1) == '=') { bufferPtr += 2; return {Token::not_equal, "!=", 0, 0}; + } else { + bufferPtr++; + return {Token::not_op, "!", 0, 0}; } - break; case '<': if (*(bufferPtr + 1) == '=') { bufferPtr += 2; From db4cf9509c67957494d2869f7b75274ca4e9aec4 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:40:11 +0330 Subject: [PATCH 14/25] added not op --- code/lexer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/code/lexer.h b/code/lexer.h index 9478788..248ff73 100644 --- a/code/lexer.h +++ b/code/lexer.h @@ -42,6 +42,7 @@ class Token { greater_equal, and_op, or_op, + not_op plus_plus, minus_minus, From c66fd0432cb0aecea643cca5dcba174d5b075816 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:42:13 +0330 Subject: [PATCH 15/25] 1 function changed --- code/parser.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/code/parser.cpp b/code/parser.cpp index 40dceb8..9ca427e 100644 --- a/code/parser.cpp +++ b/code/parser.cpp @@ -323,7 +323,18 @@ std::unique_ptr Parser::parsePrimary() { case Token::identifier: { std::string name = currentTok.text.str(); advance(); - + + if (currentTok.is(Token::plus_plus) || currentTok.is(Token::minus_minus)) { + Token opTok = currentTok; + advance(); // consume ++ or -- + return std::make_unique( + opTok.is(Token::plus_plus) + ? UnaryOp::INCREMENT + : UnaryOp::DECREMENT, + std::make_unique(name) + ); + } + // Function call or array access if (currentTok.is(Token::l_paren)) { return parseFunctionCall(name); From 5adf88e4fb721c0f1eec2ea045b2dab88e173f04 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:51:59 +0330 Subject: [PATCH 16/25] nothing new --- code/semantic.h | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/code/semantic.h b/code/semantic.h index 6056947..c131944 100644 --- a/code/semantic.h +++ b/code/semantic.h @@ -1,13 +1,12 @@ -#ifndef SEMA_H -#define SEMA_H +#ifndef SEMANTIC_H +#define SEMANTIC_H #include "AST.h" -#include "lexer.h" -class Semantic -{ +class Semantic { public: - bool semantic(AST *Tree); + + bool semantic(ProgramNode *root); }; -#endif \ No newline at end of file +#endif From f33b4351088a310b34855b6cf01c72d6f3fbb08e Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:53:42 +0330 Subject: [PATCH 17/25] Update semantic.cpp with a lot of changes --- code/semantic.cpp | 513 +++++++++++++++++++++------------------------- 1 file changed, 232 insertions(+), 281 deletions(-) diff --git a/code/semantic.cpp b/code/semantic.cpp index dff288e..9b71dc6 100644 --- a/code/semantic.cpp +++ b/code/semantic.cpp @@ -1,312 +1,263 @@ #include "semantic.h" +#include "AST.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/raw_ostream.h" -namespace -{ - class DeclCheck : public ASTVisitor - { - llvm::StringMap variableTypeMap; - bool HasError; - enum ErrorType - { - AlreadyDefinedVariable, - NotDefinedVariable, - DivideByZero, - WrongValueTypeForVariable - }; - - void error(ErrorType errorType, llvm::StringRef V) - { - switch (errorType) - { - case ErrorType::DivideByZero: - llvm::errs() << "Division by zero is not allowed!\n"; - break; - case ErrorType::AlreadyDefinedVariable: - llvm::errs() << "Variable " << V << " is already declared!\n"; - break; - case ErrorType::NotDefinedVariable: - llvm::errs() << "Variable " << V << " is not declared!\n"; - break; - case ErrorType::WrongValueTypeForVariable: - llvm::errs() << "Illegal value for type " << V << "!\n"; - break; - default: - llvm::errs() << "Unknown error\n"; - break; - } - HasError = true; - exit(3); +namespace { + // Helper: convert VarType to printable name + static const char *varTypeName(VarType t) { + switch (t) { + case VarType::INT: return "int"; + case VarType::FLOAT: return "float"; + case VarType::BOOL: return "bool"; + case VarType::CHAR: return "char"; + case VarType::STRING: return "string"; + case VarType::ARRAY: return "array"; + default: return ""; } + } - public: - DeclCheck() : HasError(false) {} - - bool hasError() { return HasError; } - - virtual void visit(Base &Node) override - { - for (auto I = Node.begin(), E = Node.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + class DeclCheck : public ASTVisitor { + llvm::StringMap VarTypes; + bool HasError = false; - virtual void visit(PrintStatement &Node) override - { - Expression *declaration = (Expression *)Node.getExpr(); - declaration->accept(*this); - }; - - virtual void visit(BinaryOp &Node) override - { - if (Node.getLeft()) - { - Node.getLeft()->accept(*this); - } - else - { - HasError = true; - } - if (Node.getRight()) - { - Node.getRight()->accept(*this); - } - else - { - HasError = true; + enum ErrorKind { AlreadyDefined, NotDefined, DivideByZero, TypeMismatch, InvalidOperation }; + void report(ErrorKind kind, const std::string &msg) { + switch (kind) { + case AlreadyDefined: llvm::errs() << "Error: variable '" << msg << "' already defined\n"; break; + case NotDefined: llvm::errs() << "Error: variable '" << msg << "' not defined\n"; break; + case DivideByZero: llvm::errs() << "Error: division by zero!\n"; break; + case TypeMismatch: llvm::errs() << "Error: type mismatch for '" << msg << "'\n"; break; + case InvalidOperation: llvm::errs() << "Error: invalid operation '" << msg << "' for type\n"; break; } + HasError = true; + } - // Divide by zero check - if (Node.getOperator() == BinaryOp::Operator::Div) - { - Expression *right = (Expression *)Node.getRight(); - if (right->isNumber() && right->getNumber() == 0) - { - error(DivideByZero, ((Expression *)Node.getLeft())->getValue()); + // Determine the VarType of an expression node, or VarType::ERROR on failure + VarType typeOf(ASTNode *node) { + if (!node) return VarType::ERROR; + // Literals + if (auto *i = dynamic_cast*>(node)) return VarType::INT; + if (auto *f = dynamic_cast*>(node)) return VarType::FLOAT; + if (auto *b = dynamic_cast*>(node)) return VarType::BOOL; + if (auto *c = dynamic_cast*>(node)) return VarType::CHAR; + if (auto *s = dynamic_cast*>(node)) return VarType::STRING; + // Variable reference + if (auto *v = dynamic_cast(node)) { + if (!VarTypes.count(v->name)) { + report(NotDefined, v->name); + return VarType::ERROR; + } + return VarTypes.lookup(v->name); + } + // Binary operations + if (auto *bin = dynamic_cast(node)) { + VarType lt = typeOf(bin->left.get()); + VarType rt = typeOf(bin->right.get()); + switch (bin->op) { + case BinaryOp::ADD: case BinaryOp::SUBTRACT: + case BinaryOp::MULTIPLY: case BinaryOp::DIVIDE: case BinaryOp::MOD: + case BinaryOp::POW: { + // numeric or array-scalar or scalar-array + if ((lt==VarType::INT||lt==VarType::FLOAT) && (rt==VarType::INT||rt==VarType::FLOAT)) + return (lt==VarType::FLOAT||rt==VarType::FLOAT?VarType::FLOAT:VarType::INT); + if (lt==VarType::ARRAY && (rt==VarType::INT||rt==VarType::FLOAT)) return VarType::ARRAY; + if ((lt==VarType::INT||lt==VarType::FLOAT) && rt==VarType::ARRAY) return VarType::ARRAY; + report(InvalidOperation, varTypeName(lt) + std::string(" op ") + varTypeName(rt)); + return VarType::ERROR; + } + case BinaryOp::AND: case BinaryOp::OR: + if (lt==VarType::BOOL && rt==VarType::BOOL) return VarType::BOOL; + report(InvalidOperation, "logical"); return VarType::ERROR; + case BinaryOp::EQUAL: case BinaryOp::NOT_EQUAL: + case BinaryOp::LESS: case BinaryOp::LESS_EQUAL: + case BinaryOp::GREATER: case BinaryOp::GREATER_EQUAL: + // comparisons produce bool + if ((lt==VarType::INT||lt==VarType::FLOAT) && (rt==VarType::INT||rt==VarType::FLOAT)) return VarType::BOOL; + if (lt==rt) return VarType::BOOL; + report(TypeMismatch, "compare"); return VarType::ERROR; + case BinaryOp::CONCAT: + if (lt==VarType::STRING && rt==VarType::STRING) return VarType::STRING; + report(InvalidOperation, "concat"); return VarType::ERROR; + case BinaryOp::INDEX: + if (lt==VarType::ARRAY && rt==VarType::INT) return VarType::ARRAY; // element type infer omitted + report(InvalidOperation, "index"); return VarType::ERROR; + default: + return VarType::ERROR; } } - }; + // Unary operations + if (auto *un = dynamic_cast(node)) { + VarType ot = typeOf(un->operand.get()); + switch (un->op) { + case UnaryOp::INCREMENT: case UnaryOp::DECREMENT: + if (ot==VarType::INT||ot==VarType::FLOAT) return ot; + break; + case UnaryOp::ABS: + if (ot==VarType::INT||ot==VarType::FLOAT) return ot; + break; + case UnaryOp::LENGTH: + if (ot==VarType::ARRAY||ot==VarType::STRING) return VarType::INT; + break; + case UnaryOp::MIN: case UnaryOp::MAX: + if (ot==VarType::ARRAY) return VarType::ARRAY; + break; + default: break; + } + report(InvalidOperation, varTypeName(ot)); + return VarType::ERROR; + } + // Arrays + if (auto *arr = dynamic_cast(node)) { + if (arr->elements.empty()) return VarType::ARRAY; + VarType et = typeOf(arr->elements[0].get()); + for (auto &e : arr->elements) { + VarType t2 = typeOf(e.get()); + if (t2 != et) report(TypeMismatch, "array elements"); + } + return VarType::ARRAY; + } + // Array access + if (auto *acc = dynamic_cast(node)) { + VarType at = typeOf(acc->index.get()); + VarType arrt = VarTypes.lookup(acc->arrayName); + if (arrt != VarType::ARRAY) report(TypeMismatch, acc->arrayName); + if (at != VarType::INT) report(TypeMismatch, "index type"); + return VarType::ARRAY; + } + // Print + if (auto *p = dynamic_cast(node)) { + return typeOf(p->expr.get()); + } + // Default: error + return VarType::ERROR; + } - virtual void visit(Statement &Node) override - { - if (Node.getKind() == Statement::StatementType::Declaration) - { - DecStatement *declaration = (DecStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Assignment) - { - AssignStatement *declaration = (AssignStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::If) - { - IfStatement *declaration = (IfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::ElseIf) - { - ElseIfStatement *declaration = (ElseIfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Else) - { - ElseStatement *declaration = (ElseStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Print) - { - PrintStatement *declaration = (PrintStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::While) - { - WhileStatement *declaration = (WhileStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::For) - { - ForStatement *declaration = (ForStatement *)&Node; - declaration->accept(*this); - } - - - }; + public: + bool hasError() const { return HasError; } - virtual void visit(BooleanOp &Node) override - { - if (Node.getLeft()) - { - Node.getLeft()->accept(*this); - } - else - { - HasError = true; - } - if (Node.getRight()) - { - Node.getRight()->accept(*this); - } - else - { - HasError = true; - } - }; + // Program and Block + void visit(ProgramNode &node) override { + for (auto &s : node.statements) s->accept(*this); + } + void visit(BlockNode &node) override { + for (auto &s : node.statements) s->accept(*this); + } - virtual void visit(Expression &Node) override - { - if (Node.getKind() == Expression::ExpressionType::Identifier) - { - if (variableTypeMap.count(Node.getValue()) == 0) - { - error(NotDefinedVariable, Node.getValue()); + // Variable declarations + void visit(MultiVarDeclNode &node) override { + for (auto &decl : node.declarations) { + const std::string &name = decl->name; + if (VarTypes.count(name)) report(AlreadyDefined, name); + else { + VarTypes[name] = decl->type; + if (decl->value) decl->value->accept(*this); } } - else if (Node.getKind() == Expression::ExpressionType::BinaryOpType) - { - BinaryOp *declaration = (BinaryOp *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Expression::ExpressionType::BooleanOpType) - { - Node.getBooleanOp()->getLeft()->accept(*this); - Node.getBooleanOp()->getRight()->accept(*this); - } - }; - - virtual void visit(DecStatement &Node) override - { - if (variableTypeMap.count(Node.getLValue()->getValue()) > 0) - { - error(AlreadyDefinedVariable, Node.getLValue()->getValue()); - } - // Add this new variable to variableTypeMap - if (Node.getDecType() == DecStatement::DecStatementType::Boolean) - { - variableTypeMap[Node.getLValue()->getValue()] = 'b'; - } - else - { - variableTypeMap[Node.getLValue()->getValue()] = 'i'; - } + } - Expression *rightValue = (Expression *)Node.getRValue(); - if (rightValue == nullptr) - { - return; - } - if (Node.getDecType() == DecStatement::DecStatementType::Boolean) - { - if (!(rightValue->getKind() == Expression::ExpressionType::Boolean || - rightValue->getKind() == Expression::ExpressionType::BooleanOpType)) - { - error(WrongValueTypeForVariable, "bool"); - } - } - else if (Node.getDecType() == DecStatement::DecStatementType::Number) - { - if (!(rightValue->getKind() == Expression::ExpressionType::Number || - rightValue->getKind() == Expression::ExpressionType::BinaryOpType)) - { - error(WrongValueTypeForVariable, "int"); - } + // Assignments + void visit(AssignNode &node) override { + if (!VarTypes.count(node.target)) report(NotDefined, node.target); + node.value->accept(*this); + VarType vt = VarTypes.lookup(node.target); + VarType rt = typeOf(node.value.get()); + // Allowed ops per type + bool ok = false; + switch (vt) { + case VarType::INT: case VarType::FLOAT: + ok = true; break; + case VarType::CHAR: case VarType::STRING: case VarType::BOOL: case VarType::ARRAY: + ok = (node.op == BinaryOp::EQUAL); + break; + default: break; } + if (!ok) report(InvalidOperation, varTypeName(vt)); + if (rt != vt && !(vt==VarType::FLOAT && rt==VarType::INT)) + report(TypeMismatch, node.target); + } - rightValue->accept(*this); - }; - - virtual void visit(IfStatement &Node) override - { - Expression *declaration = (Expression *)Node.getCondition(); - declaration->accept(*this); - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + // Variable reference + void visit(VarRefNode &node) override { + if (!VarTypes.count(node.name)) report(NotDefined, node.name); + } - virtual void visit(ElseIfStatement &Node) override - { - Expression *declaration = (Expression *)Node.getCondition(); - declaration->accept(*this); - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); + // Binary ops + void visit(BinaryOpNode &node) override { + node.left->accept(*this); + node.right->accept(*this); + if (node.op == BinaryOp::DIVIDE) { + if (auto *lit = dynamic_cast*>(node.right.get())) { + if (lit->value == 0) report(DivideByZero, ""); + } } - }; + typeOf(&node); + } - virtual void visit(ElseStatement &Node) override - { - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + // Unary ops + void visit(UnaryOpNode &node) override { + node.operand->accept(*this); + typeOf(&node); + } - virtual void visit(AssignStatement &Node) override - { - Node.getLValue()->accept(*this); - Node.getRValue()->accept(*this); - if (variableTypeMap.lookup(Node.getLValue()->getValue()) == 'i' && - (Node.getRValue()->getKind() == Expression::ExpressionType::Boolean || - Node.getRValue()->getKind() == Expression::ExpressionType::BooleanOpType)) - { - error(WrongValueTypeForVariable, "int"); - } - if (variableTypeMap.lookup(Node.getLValue()->getValue()) == 'b' && - (Node.getRValue()->getKind() == Expression::ExpressionType::Number || - Node.getRValue()->getKind() == Expression::ExpressionType::BinaryOpType)) - { - error(WrongValueTypeForVariable, "bool"); - } - }; + // Control structures + void visit(IfElseNode &node) override { + node.condition->accept(*this); + if (typeOf(node.condition.get()) != VarType::BOOL) + report(TypeMismatch, "if condition"); + node.thenBlock->accept(*this); + if (node.elseBlock) node.elseBlock->accept(*this); + } + void visit(ForLoopNode &node) override { + if (node.init) node.init->accept(*this); + if (node.condition) { + node.condition->accept(*this); + if (typeOf(node.condition.get()) != VarType::BOOL) + report(TypeMismatch, "for condition"); + } + if (node.update) node.update->accept(*this); + node.body->accept(*this); + } + void visit(WhileLoopNode &node) override { + node.condition->accept(*this); + if (typeOf(node.condition.get()) != VarType::BOOL) + report(TypeMismatch, "while condition"); + node.body->accept(*this); + } - virtual void visit(WhileStatement &Node) override{ - Node.getCondition()->accept(*this); - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; - - virtual void visit(ForStatement &Node) override{ - AssignStatement * initial_assign = Node.getInitialAssign(); - /*if(initial_assign == nullptr){ - exit(0); - }*/ - (initial_assign->getLValue())->accept(*this); - (initial_assign->getRValue())->accept(*this); - Node.getCondition()->accept(*this); - - AssignStatement * update_assign = Node.getUpdateAssign(); - if(update_assign == nullptr){ - exit(0); - } - (update_assign->getLValue())->accept(*this); - (update_assign->getRValue())->accept(*this); + // Print + void visit(PrintNode &node) override { + node.expr->accept(*this); + } - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - }; + // Array + void visit(ArrayNode &node) override { + for (auto &e : node.elements) e->accept(*this); + typeOf(&node); + } + void visit(ArrayAccessNode &node) override { + node.index->accept(*this); + typeOf(&node); + } + // Concat, Pow + void visit(ConcatNode &node) override { + node.left->accept(*this); + node.right->accept(*this); + typeOf(&node); + } + void visit(PowNode &node) override { + node.base->accept(*this); + node.exponent->accept(*this); + typeOf(&node); + } }; } -bool Semantic::semantic(AST *Tree) -{ - if (!Tree) - return false; - DeclCheck Check; - Tree->accept(Check); - return Check.hasError(); +bool Semantic::semantic(ProgramNode *root) { + if (!root) return false; + DeclCheck checker; + root->accept(checker); + return checker.hasError(); } From 147d3855b163714adb56a5ecbbf882f3958fe0b7 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:55:28 +0330 Subject: [PATCH 18/25] deleted somethings --- code/semantic.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/code/semantic.cpp b/code/semantic.cpp index 9b71dc6..7746348 100644 --- a/code/semantic.cpp +++ b/code/semantic.cpp @@ -5,7 +5,6 @@ namespace { - // Helper: convert VarType to printable name static const char *varTypeName(VarType t) { switch (t) { case VarType::INT: return "int"; @@ -34,7 +33,7 @@ namespace { HasError = true; } - // Determine the VarType of an expression node, or VarType::ERROR on failure + VarType typeOf(ASTNode *node) { if (!node) return VarType::ERROR; // Literals From 4c3bf4d9c1e7dff9d24b19455c607c6393386fb0 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:02:04 +0330 Subject: [PATCH 19/25] Update code_generator.h --- code/code_generator.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/code/code_generator.h b/code/code_generator.h index 0012c9f..7beb611 100644 --- a/code/code_generator.h +++ b/code/code_generator.h @@ -1,11 +1,11 @@ -#ifndef CODEGEN_H -#define CODEGEN_H +#ifndef CODE_GENERATOR_H +#define CODE_GENERATOR_H #include "AST.h" -class CodeGen -{ +class CodeGen { public: - void compile(AST *Tree, bool optimize, int k); + void compile(ProgramNode *root, bool optimize, int unroll); }; -#endif + +#endif From a05c72b3f0708c5e7c454db4c361e1454e60e35e Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:03:08 +0330 Subject: [PATCH 20/25] Update code_generator.cpp --- code/code_generator.cpp | 733 ++++++++++++---------------------------- 1 file changed, 207 insertions(+), 526 deletions(-) diff --git a/code/code_generator.cpp b/code/code_generator.cpp index f256fad..28cccf2 100644 --- a/code/code_generator.cpp +++ b/code/code_generator.cpp @@ -1,543 +1,224 @@ #include "code_generator.h" -#include "optimizer.h" #include "llvm/ADT/StringMap.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; -// Define a visitor class for generating LLVM IR from the AST. -namespace -{ - class ToIRVisitor : public ASTVisitor - { - Module *M; - IRBuilder<> Builder; - Type *VoidTy; - Type *Int32Ty; - Type *Int1Ty; - Type *Int8PtrTy; - Type *Int8PtrPtrTy; - Constant *Int32Zero; - Constant *Int1Zero; - - Value *V; - StringMap nameMap; - - llvm::FunctionType *MainFty; - llvm::Function *MainFn; - FunctionType *CalcWriteFnTy; - FunctionType *CalcWriteFnTyBool; - Function *CalcWriteFn; - Function *CalcWriteFnBool; - bool optimize; - int k; - - - public: - // Constructor for the visitor class - ToIRVisitor(Module *M, bool optimize_enable, int k_value) : M(M), Builder(M->getContext()) - { - // Initialize LLVM types and constants - VoidTy = Type::getVoidTy(M->getContext()); - Int32Ty = Type::getInt32Ty(M->getContext()); - Int1Ty = Type::getInt1Ty(M->getContext()); - Int8PtrTy = Type::getInt8PtrTy(M->getContext()); - Int8PtrPtrTy = Int8PtrTy->getPointerTo(); - Int32Zero = ConstantInt::get(Int32Ty, 0, true); - CalcWriteFnTy = FunctionType::get(VoidTy, {Int32Ty}, false); - CalcWriteFnTyBool = FunctionType::get(VoidTy, {Int1Ty}, false); - CalcWriteFn = Function::Create(CalcWriteFnTy, GlobalValue::ExternalLinkage, "print", M); - CalcWriteFnBool = Function::Create(CalcWriteFnTyBool, GlobalValue::ExternalLinkage, "printBool", M); - optimize = optimize_enable; - k = k_value; - } - - // Entry point for generating LLVM IR from the AST - void run(AST *Tree) - { - // Create the main function with the appropriate function type. - MainFty = FunctionType::get(Int32Ty, {Int32Ty, Int8PtrPtrTy}, false); - MainFn = Function::Create(MainFty, GlobalValue::ExternalLinkage, "main", M); - - // Create a basic block for the entry point of the main function. - BasicBlock *BB = BasicBlock::Create(M->getContext(), "entry", MainFn); - Builder.SetInsertPoint(BB); - - // Visit the root node of the AST to generate IR - Tree->accept(*this); - - // Create a return instruction at the end of the main function - Builder.CreateRet(Int32Zero); - } - - virtual void visit(Base &Node) override - { - // TODO: find a better way to not implement this again! - for (auto I = Node.begin(), E = Node.end(); I != E; ++I) - { - (*I)->accept(*this); - } - } - - virtual void visit(Statement &Node) override - { - // TODO: find a better way to not implement this again! - if (Node.getKind() == Statement::StatementType::Declaration) - { - DecStatement *declaration = (DecStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Assignment) - { - AssignStatement *declaration = (AssignStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::If) - { - IfStatement *declaration = (IfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::ElseIf) - { - ElseIfStatement *declaration = (ElseIfStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Else) - { - ElseStatement *declaration = (ElseStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::Print) - { - PrintStatement *declaration = (PrintStatement *)&Node; - declaration->accept(*this); - } - else if (Node.getKind() == Statement::StatementType::While) - { - WhileStatement *declaration = (WhileStatement *)&Node; - declaration->accept(*this); - } - } - - virtual void visit(PrintStatement &Node) override { - // Visit the right-hand side of the expression and get its value. - Node.getExpr()->accept(*this); - Value *val = V; - - // Determine the type of 'val' and select the appropriate print function - Type *valType = val->getType(); - if (valType == Int32Ty) { - CallInst *Call = Builder.CreateCall(CalcWriteFnTy, CalcWriteFn, {val}); - } else if (valType == Int1Ty) { - CallInst *Call = Builder.CreateCall(CalcWriteFnTyBool, CalcWriteFnBool, {val}); +namespace { + +class ToIRVisitor : public ASTVisitor { + LLVMContext &Ctx; + std::unique_ptr M; + IRBuilder<> Builder; + Type *Int32Ty; + Type *Int1Ty; + Constant *Int32Zero; + Constant *Int1False; + + StringMap NameMap; + Value *V; // holds last expression value + Function *MainFn; + FunctionCallee PrintIntFn; + FunctionCallee PrintBoolFn; + bool Optimize; + int UnrollFactor; + +public: + ToIRVisitor(LLVMContext &context, bool optimize, int unroll) + : Ctx(context), M(new Module("main_module", context)), Builder(context), + Optimize(optimize), UnrollFactor(unroll) { + Int32Ty = Type::getInt32Ty(Ctx); + Int1Ty = Type::getInt1Ty(Ctx); + Int32Zero = ConstantInt::get(Int32Ty, 0); + Int1False = ConstantInt::get(Int1Ty, 0); + // declare external print functions + PrintIntFn = M->getOrInsertFunction("print", FunctionType::get(Type::getVoidTy(Ctx), {Int32Ty}, false)); + PrintBoolFn = M->getOrInsertFunction("printBool", FunctionType::get(Type::getVoidTy(Ctx), {Int1Ty}, false)); + } + + Module *getModule() { return M.get(); } + + void run(ProgramNode *root) { + // define main i32 @main() + FunctionType *FT = FunctionType::get(Int32Ty, false); + MainFn = Function::Create(FT, GlobalValue::ExternalLinkage, "main", M.get()); + BasicBlock *BB = BasicBlock::Create(Ctx, "entry", MainFn); + Builder.SetInsertPoint(BB); + // generate body + root->accept(*this); + Builder.CreateRet(Int32Zero); + // verify + verifyFunction(*MainFn); + } + + // Helpers + AllocaInst *createEntryBlockAlloca(Function *F, const std::string &Name, Type *Ty) { + IRBuilder<> TmpB(&F->getEntryBlock(), F->getEntryBlock().begin()); + return TmpB.CreateAlloca(Ty, nullptr, Name); + } + + // Visitor implementations + void visit(ProgramNode &node) override { + for (auto &stmt : node.statements) + stmt->accept(*this); + } + + void visit(MultiVarDeclNode &node) override { + for (auto &decl : node.declarations) { + Type *Ty = (decl->type == VarType::BOOL ? Int1Ty : Int32Ty); + AllocaInst *A = createEntryBlockAlloca(MainFn, decl->name, Ty); + NameMap[decl->name()] = A; + // initialize + if (decl->value) { + decl->value->accept(*this); + Builder.CreateStore(V, A); } else { - // If the type is not supported, print an error message - llvm::errs() << "Unsupported type for print statement\n"; + Builder.CreateStore(Ty==Int1Ty ? Int1False : Int32Zero, A); } } - - virtual void visit(Expression &Node) override - { - if (Node.getKind() == Expression::ExpressionType::Identifier) - { - AllocaInst *allocaInst = nameMap[Node.getValue()]; - if (!allocaInst) { - llvm::errs() << "Undefined variable '" << Node.getValue() << "'\n"; - return; - } - V = Builder.CreateLoad(allocaInst->getAllocatedType(), allocaInst, Node.getValue()); - } - else if (Node.getKind() == Expression::ExpressionType::Number) - { - int int_value = Node.getNumber(); - V = ConstantInt::get(Int32Ty, int_value, true); - } - else if (Node.getKind() == Expression::ExpressionType::Boolean) - { - bool bool_value = Node.getBoolean(); - V = ConstantInt::getBool(Int1Ty, bool_value); - } + } + + void visit(AssignNode &node) override { + node.value->accept(*this); + auto *A = NameMap[node.target]; + Builder.CreateStore(V, A); + } + + void visit(VarRefNode &node) override { + AllocaInst *A = NameMap[node.name]; + V = Builder.CreateLoad(A->getAllocatedType(), A, node.name); + } + + void visit(LiteralNode &node) override { + V = ConstantInt::get(Int32Ty, node.value); + } + void visit(LiteralNode &node) override { + V = ConstantInt::get(Int1Ty, node.value); + } + + void visit(BinaryOpNode &node) override { + node.left->accept(*this); + Value *L = V; + node.right->accept(*this); + Value *R = V; + switch (node.op) { + case BinaryOp::ADD: V = Builder.CreateAdd(L, R); break; + case BinaryOp::SUBTRACT: V = Builder.CreateSub(L, R); break; + case BinaryOp::MULTIPLY: V = Builder.CreateMul(L, R); break; + case BinaryOp::DIVIDE: V = Builder.CreateSDiv(L, R); break; + case BinaryOp::MOD: V = Builder.CreateSRem(L, R); break; + case BinaryOp::EQUAL: V = Builder.CreateICmpEQ(L, R); break; + case BinaryOp::NOT_EQUAL: V = Builder.CreateICmpNE(L, R); break; + case BinaryOp::LESS: V = Builder.CreateICmpSLT(L, R); break; + case BinaryOp::LESS_EQUAL: V = Builder.CreateICmpSLE(L, R); break; + case BinaryOp::GREATER: V = Builder.CreateICmpSGT(L, R); break; + case BinaryOp::GREATER_EQUAL: V = Builder.CreateICmpSGE(L,R); break; + case BinaryOp::AND: V = Builder.CreateAnd(L, R); break; + case BinaryOp::OR: V = Builder.CreateOr(L, R); break; + case BinaryOp::POW: + // simple pow via loop omitted for brevity + V = L; break; + default: + V = UndefValue::get(Int32Ty); } - - virtual void visit(BooleanOp &Node) override - { - // Visit the left-hand side of the binary operation and get its value - Node.getLeft()->accept(*this); - Value *Left = V; - - // Visit the right-hand side of the binary operation and get its value - Node.getRight()->accept(*this); - Value *Right = V; - - // Perform the boolean operation based on the operator type and create the corresponding instruction - switch (Node.getOperator()) - { - case BooleanOp::Equal: - V = Builder.CreateICmpEQ(Left, Right); - break; - case BooleanOp::NotEqual: - V = Builder.CreateICmpNE(Left, Right); - break; - case BooleanOp::Less: - V = Builder.CreateICmpSLT(Left, Right); - break; - case BooleanOp::LessEqual: - V = Builder.CreateICmpSLE(Left, Right); - break; - case BooleanOp::Greater: - V = Builder.CreateICmpSGT(Left, Right); - break; - case BooleanOp::GreaterEqual: - V = Builder.CreateICmpSGE(Left, Right); - break; - case BooleanOp::And: - V = Builder.CreateAnd(Left, Right); - break; - case BooleanOp::Or: - V = Builder.CreateOr(Left, Right); - break; - } - } - - virtual void visit(BinaryOp &Node) override - { - // Visit the left-hand side of the binary operation and get its value - Node.getLeft()->accept(*this); - Value *Left = V; - - // Visit the right-hand side of the binary operation and get its value - Node.getRight()->accept(*this); - Value *Right = V; - - // Perform the binary operation based on the operator type and create the corresponding instruction - switch (Node.getOperator()) - { - case BinaryOp::Plus: - V = Builder.CreateNSWAdd(Left, Right); - break; - case BinaryOp::Minus: - V = Builder.CreateNSWSub(Left, Right); - break; - case BinaryOp::Mul: - V = Builder.CreateNSWMul(Left, Right); - break; - case BinaryOp::Div: - V = Builder.CreateSDiv(Left, Right); - break; - case BinaryOp::Pow:{ - Function *func = Builder.GetInsertBlock()->getParent(); - BasicBlock *preLoopBB = Builder.GetInsertBlock(); - BasicBlock *loopBB = BasicBlock::Create(M->getContext(), "loop", func); - BasicBlock *afterLoopBB = BasicBlock::Create(M->getContext(), "afterloop", func); - - // Setup loop entry - Builder.CreateBr(loopBB); - Builder.SetInsertPoint(loopBB); - - // Create loop variables (PHI nodes) - PHINode *resultPhi = Builder.CreatePHI(Left->getType(), 2, "result"); - resultPhi->addIncoming(Left, preLoopBB); - - PHINode *indexPhi = Builder.CreatePHI(Type::getInt32Ty(M->getContext()), 2, "index"); - indexPhi->addIncoming(ConstantInt::get(Type::getInt32Ty(M->getContext()), 0), preLoopBB); - - // Perform multiplication - Value *updatedResult = Builder.CreateNSWMul(resultPhi, Left, "multemp"); - Value *updatedIndex = Builder.CreateAdd(indexPhi, ConstantInt::get(Type::getInt32Ty(M->getContext()), 1), "indexinc"); - - // Exit condition - Value *condition = Builder.CreateICmpNE(updatedIndex, Right, "loopcond"); - Builder.CreateCondBr(condition, loopBB, afterLoopBB); - - // Update PHI nodes for the next iteration - resultPhi->addIncoming(updatedResult, loopBB); - indexPhi->addIncoming(updatedIndex, loopBB); - - // Complete the loop - Builder.SetInsertPoint(afterLoopBB); - V = resultPhi; // The result of a^b - break; - } - case BinaryOp::Mod: - Value *division = Builder.CreateSDiv(Left, Right); - Value *multiplication = Builder.CreateNSWMul(division, Right); - V = Builder.CreateNSWSub(Left, multiplication); - } + } + + void visit(UnaryOpNode &node) override { + node.operand->accept(*this); + switch (node.op) { + case UnaryOp::INCREMENT: { + Value *One = ConstantInt::get(Int32Ty, 1); + V = Builder.CreateAdd(V, One); + break; } - - virtual void visit(DecStatement &Node) override - { - Value *val = nullptr; - - if (Node.getRValue() != nullptr) - { - // If there is an expression provided, visit it and get its value - Node.getRValue()->accept(*this); - val = V; - } - - // Iterate over the variables declared in the declaration statement - auto I = Node.getLValue()->getValue(); - StringRef Var = I; - - // Create an alloca instruction to allocate memory for the variable - Type *varType = (Node.getDecType() == DecStatement::DecStatementType::Number) ? Int32Ty : Type::getInt1Ty(M->getContext()); - nameMap[Var] = Builder.CreateAlloca(varType); - - // Store the initial value (if any) in the variable's memory location - if (val != nullptr) - { - Builder.CreateStore(val, nameMap[Var]); - } - else - { - if(Node.getDecType() == DecStatement::DecStatementType::Number){ - Value *Zero = ConstantInt::get(Type::getInt32Ty(M->getContext()), 0); - Builder.CreateStore(Zero, nameMap[Var]); - } else{ - Value *False = ConstantInt::get(Type::getInt1Ty(M->getContext()), 0); - Builder.CreateStore(False, nameMap[Var]); - } - - } + case UnaryOp::DECREMENT: { + Value *One = ConstantInt::get(Int32Ty, 1); + V = Builder.CreateSub(V, One); + break; } - - virtual void visit(AssignStatement &Node) override - { - // Visit the right-hand side of the assignment and get its value - Node.getRValue()->accept(*this); - Value *val = V; - - // Get the name of the variable being assigned - auto varName = Node.getLValue()->getValue(); - - // Create a store instruction to assign the value to the variable - Builder.CreateStore(val, nameMap[varName]); + default: + break; } - - virtual void visit(IfStatement &Node) override - { - llvm::BasicBlock *IfCondBB = llvm::BasicBlock::Create(M->getContext(), "if.cond", MainFn); - llvm::BasicBlock *IfBodyBB = llvm::BasicBlock::Create(M->getContext(), "if.body", MainFn); - llvm::BasicBlock *AfterIfBB = llvm::BasicBlock::Create(M->getContext(), "after.if", MainFn); - - Builder.CreateBr(IfCondBB); - Builder.SetInsertPoint(IfCondBB); - - Node.getCondition()->accept(*this); - Value *Cond = V; - - Builder.SetInsertPoint(IfBodyBB); - - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - - Builder.CreateBr(AfterIfBB); - llvm::BasicBlock *BeforeCondBB = IfCondBB; - llvm::BasicBlock *BeforeBodyBB = IfBodyBB; - llvm::Value *BeforeCondVal = Cond; - - if (Node.HasElseIf()) - { - for (auto &elseIf : Node.getElseIfStatements()) - { - llvm::BasicBlock *ElseIfCondBB = llvm::BasicBlock::Create(MainFn->getContext(), "elseIf.cond", MainFn); - llvm::BasicBlock *ElseIfBodyBB = llvm::BasicBlock::Create(MainFn->getContext(), "elseIf.body", MainFn); - - Builder.SetInsertPoint(BeforeCondBB); - Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, ElseIfCondBB); - - Builder.SetInsertPoint(ElseIfCondBB); - elseIf->getCondition()->accept(*this); - llvm::Value *ElseIfCondVal = V; - - Builder.SetInsertPoint(ElseIfBodyBB); - elseIf->accept(*this); - Builder.CreateBr(AfterIfBB); - - BeforeCondBB = ElseIfCondBB; - BeforeCondVal = ElseIfCondVal; - BeforeBodyBB = ElseIfBodyBB; - } - // finishing the last else - // Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, AfterIfBB); - } - - - llvm::BasicBlock *ElseBB = nullptr; - if (Node.HasElse()) - { - ElseStatement *elseS = Node.getElseStatement(); - ElseBB = llvm::BasicBlock::Create(MainFn->getContext(), "else.body", MainFn); - if(Node.HasElseIf()){ - Builder.SetInsertPoint(BeforeCondBB); - Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, ElseBB); - } - Builder.SetInsertPoint(ElseBB); - Node.getElseStatement()->accept(*this); - Builder.CreateBr(AfterIfBB); - } - else - { - Builder.SetInsertPoint(BeforeCondBB); - if(Node.HasElseIf()){ - Builder.CreateCondBr(BeforeCondVal, BeforeBodyBB, AfterIfBB); - }else{ - Builder.CreateCondBr(Cond, IfBodyBB, AfterIfBB); - } - - } - - Builder.SetInsertPoint(AfterIfBB); - } - - virtual void visit(ElseIfStatement &Node) override - { - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - } - - virtual void visit(ElseStatement &Node) override - { - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - } - - virtual void visit(WhileStatement &Node) override - { - if (optimize && !Node.isOptimized()) { - llvm::SmallVector unrolledStatements = completeUnroll(&Node, k); - for (auto I = unrolledStatements.begin(), E = unrolledStatements.end(); I != E; ++I) - { - (*I)->accept(*this); - } - return; - } - llvm::BasicBlock* WhileCondBB = llvm::BasicBlock::Create(M->getContext(), "while.cond", MainFn); - // The basic block for the while body. - llvm::BasicBlock* WhileBodyBB = llvm::BasicBlock::Create(M->getContext(), "while.body", MainFn); - // The basic block after the while statement. - llvm::BasicBlock* AfterWhileBB = llvm::BasicBlock::Create(M->getContext(), "after.while", MainFn); - - // Branch to the condition block. - Builder.CreateBr(WhileCondBB); - - // Set the insertion point to the condition block. - Builder.SetInsertPoint(WhileCondBB); - - // Visit the condition expression and create the conditional branch. - Node.getCondition()->accept(*this); - Value* Cond = V; - Builder.CreateCondBr(Cond, WhileBodyBB, AfterWhileBB); - - // Set the insertion point to the body block. - Builder.SetInsertPoint(WhileBodyBB); - - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - - // Branch back to the condition block. - Builder.CreateBr(WhileCondBB); - // Set the insertion point to the block after the while loop. - Builder.SetInsertPoint(AfterWhileBB); - } - virtual void visit(ForStatement &Node) override - { - if (optimize && !Node.isOptimized()) { - llvm::SmallVector unrolledStatements = completeUnroll(&Node, k); - for (auto I = unrolledStatements.begin(), E = unrolledStatements.end(); I != E; ++I) - { - (*I)->accept(*this); - } - return; - } - llvm::BasicBlock* ForCondBB = llvm::BasicBlock::Create(M->getContext(), "for.cond", MainFn); - // The basic block for the while body. - llvm::BasicBlock* ForBodyBB = llvm::BasicBlock::Create(M->getContext(), "for.body", MainFn); - // The basic block after the while statement. - llvm::BasicBlock* AfterForBB = llvm::BasicBlock::Create(M->getContext(), "after.for", MainFn); - - llvm::BasicBlock* ForUpdateBB = llvm::BasicBlock::Create(M->getContext(), "for.update", MainFn); - - AssignStatement * initial_assign = Node.getInitialAssign(); - ((Expression *)initial_assign->getRValue())->accept(*this); - Value *val = V; - - // Get the name of the variable being assigned - auto varName = ((Expression *)initial_assign->getLValue())->getValue(); - - // Create a store instruction to assign the value to the variable - Builder.CreateStore(val, nameMap[varName]); - - - // Branch to the condition block. - Builder.CreateBr(ForCondBB); - - // Set the insertion point to the condition block. - Builder.SetInsertPoint(ForCondBB); - - // Visit the condition expression and create the conditional branch. - Node.getCondition()->accept(*this); - Value* Cond = V; - Builder.CreateCondBr(Cond, ForBodyBB, AfterForBB); - - // Set the insertion point to the body block. - Builder.SetInsertPoint(ForBodyBB); - - llvm::SmallVector stmts = Node.getStatements(); - for (auto I = stmts.begin(), E = stmts.end(); I != E; ++I) - { - (*I)->accept(*this); - } - - Builder.CreateBr(ForUpdateBB); - - Builder.SetInsertPoint(ForUpdateBB); - - AssignStatement * update_assign = Node.getUpdateAssign(); - ((Expression *)update_assign->getRValue())->accept(*this); - val = V; - - // Get the name of the variable being assigned - varName = ((Expression *)update_assign->getLValue())->getValue(); - - // Create a store instruction to assign the value to the variable - Builder.CreateStore(val, nameMap[varName]); - - // Branch back to the condition block. - Builder.CreateBr(ForCondBB); - // Set the insertion point to the block after the while loop. - Builder.SetInsertPoint(AfterForBB); - } - }; - -}; // namespace - -void CodeGen::compile(AST *Tree, bool optimize, int k) -{ - // Create an LLVM context and a module + } + + void visit(PrintNode &node) override { + node.expr->accept(*this); + if (V->getType()->isIntegerTy(1)) + Builder.CreateCall(PrintBoolFn, {V}); + else + Builder.CreateCall(PrintIntFn, {V}); + } + + void visit(IfElseNode &node) override { + BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then", MainFn); + BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else", MainFn); + BasicBlock *MergeBB= BasicBlock::Create(Ctx, "iftmp", MainFn); + // condition + node.condition->accept(*this); + Builder.CreateCondBr(V, ThenBB, ElseBB); + // then + Builder.SetInsertPoint(ThenBB); + node.thenBlock->accept(*this); + Builder.CreateBr(MergeBB); + // else + Builder.SetInsertPoint(ElseBB); + if (node.elseBlock) + node.elseBlock->accept(*this); + Builder.CreateBr(MergeBB); + // merge + Builder.SetInsertPoint(MergeBB); + } + + void visit(WhileLoopNode &node) override { + BasicBlock *CondBB = BasicBlock::Create(Ctx, "while.cond", MainFn); + BasicBlock *BodyBB = BasicBlock::Create(Ctx, "while.body", MainFn); + BasicBlock *AfterBB= BasicBlock::Create(Ctx, "while.end", MainFn); + Builder.CreateBr(CondBB); + Builder.SetInsertPoint(CondBB); + node.condition->accept(*this); + Builder.CreateCondBr(V, BodyBB, AfterBB); + Builder.SetInsertPoint(BodyBB); + for (auto &s : node.body->statements) + s->accept(*this); + Builder.CreateBr(CondBB); + Builder.SetInsertPoint(AfterBB); + } + + void visit(ForLoopNode &node) override { + // init + if (node.init) node.init->accept(*this); + BasicBlock *CondBB = BasicBlock::Create(Ctx,"for.cond",MainFn); + BasicBlock *BodyBB = BasicBlock::Create(Ctx,"for.body",MainFn); + BasicBlock *AfterBB= BasicBlock::Create(Ctx,"for.end",MainFn); + Builder.CreateBr(CondBB); + Builder.SetInsertPoint(CondBB); + if (node.condition) { node.condition->accept(*this); Builder.CreateCondBr(V, BodyBB, AfterBB);} + Builder.SetInsertPoint(BodyBB); + for (auto &s : node.body->statements) + s->accept(*this); + if (node.update) node.update->accept(*this); + Builder.CreateBr(CondBB); + Builder.SetInsertPoint(AfterBB); + } + + void visit(BlockNode &node) override { + for (auto &s : node.statements) + s->accept(*this); + } + + // other nodes omitted or treated as error +}; + +} // end anonymous namespace + +void CodeGen::compile(ProgramNode *root, bool optimize, int unroll) { LLVMContext Ctx; - Module *M = new Module("mas.expr", Ctx); - - // Create an instance of the ToIRVisitor and run it on the AST to generate LLVM IR - ToIRVisitor ToIRn(M, optimize, k); - - ToIRn.run(Tree); - - // Print the generated module to the standard output - M->print(outs(), nullptr); + ToIRVisitor visitor(Ctx, optimize, unroll); + visitor.run(root); + Module *mod = visitor.getModule(); + mod->print(outs(), nullptr); } From b21be24871b85921a84f9402b60395236c957788 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:07:18 +0330 Subject: [PATCH 21/25] Update main.cpp --- code/main.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/code/main.cpp b/code/main.cpp index 6461b0f..f84bbbc 100644 --- a/code/main.cpp +++ b/code/main.cpp @@ -56,7 +56,8 @@ int main(int argc, const char **argv) Token nextToken; Lexer lexer(contentRef); Parser Parser(lexer); - AST *Tree = Parser.parse(); + std::unique_ptr TreePtr = Parser.parseProgram(); + ProgramNode *Tree = TreePtr.get(); Semantic semantic; if (semantic.semantic(Tree)) From 3779a80f7af04b3440237f490b119f9ecac69bb7 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:08:19 +0330 Subject: [PATCH 22/25] Update CMakeLists.txt --- code/CMakeLists.txt | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/code/CMakeLists.txt b/code/CMakeLists.txt index ea89570..499de1d 100644 --- a/code/CMakeLists.txt +++ b/code/CMakeLists.txt @@ -1,10 +1,14 @@ -add_executable (compiler - main.cpp - code_generator.cpp - lexer.cpp - parser.cpp - semantic.cpp - error.cpp - optimizer.cpp - ) +find_package(LLVM REQUIRED CONFIG) +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(HandleLLVMOptions) +include(AddLLVM) + +llvm_map_components_to_libnames(llvm_libs + Core + IRReader + Support + AsmPrinter +) + target_link_libraries(compiler PRIVATE ${llvm_libs}) From bbc17f6966e2677e76dccd1c4f61d3ce807ac2c4 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:22:15 +0330 Subject: [PATCH 23/25] Update code_generator.h --- code/code_generator.h | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/code/code_generator.h b/code/code_generator.h index 7beb611..b645723 100644 --- a/code/code_generator.h +++ b/code/code_generator.h @@ -1,11 +1,42 @@ #ifndef CODE_GENERATOR_H #define CODE_GENERATOR_H +#include +#include +#include +#include #include "AST.h" +#include +#include +#include +#include +#include +#include + class CodeGen { public: void compile(ProgramNode *root, bool optimize, int unroll); + void dump() const; + +private: + std::unique_ptr context; + std::unique_ptr module; + std::unique_ptr> builder; + + llvm::Function* currentFunc = nullptr; + std::unordered_map symbols; + std::unordered_map arraySizes; + + void declareRuntimeFunctions(); + void generateStatement(ASTNode* node); + llvm::Value* generateValue(ASTNode* node, llvm::Type* expectedType); + llvm::Value* generatePow(llvm::Value* base, llvm::Value* exp); + + void printArray(const std::vector& elements); + void printArrayVar(llvm::Value* arrayPtr, uint64_t size); + void generateTryCatch(TryCatchNode* node); + void generateMatch(MatchNode* node); }; -#endif +#endif From 67bf9774047830ace8f56798508edfc011895532 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:37:22 +0330 Subject: [PATCH 24/25] Update code_generator.cpp --- code/code_generator.cpp | 481 ++++++++++++++++++++++++---------------- 1 file changed, 290 insertions(+), 191 deletions(-) diff --git a/code/code_generator.cpp b/code/code_generator.cpp index 28cccf2..3144a33 100644 --- a/code/code_generator.cpp +++ b/code/code_generator.cpp @@ -1,224 +1,323 @@ -#include "code_generator.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Verifier.h" -#include "llvm/Support/raw_ostream.h" +#include "codegen.h" +#include "AST.h" +#include "semantic.h" +#include +#include +#include using namespace llvm; -namespace { - -class ToIRVisitor : public ASTVisitor { - LLVMContext &Ctx; - std::unique_ptr M; - IRBuilder<> Builder; - Type *Int32Ty; - Type *Int1Ty; - Constant *Int32Zero; - Constant *Int1False; - - StringMap NameMap; - Value *V; // holds last expression value - Function *MainFn; - FunctionCallee PrintIntFn; - FunctionCallee PrintBoolFn; - bool Optimize; - int UnrollFactor; - -public: - ToIRVisitor(LLVMContext &context, bool optimize, int unroll) - : Ctx(context), M(new Module("main_module", context)), Builder(context), - Optimize(optimize), UnrollFactor(unroll) { - Int32Ty = Type::getInt32Ty(Ctx); - Int1Ty = Type::getInt1Ty(Ctx); - Int32Zero = ConstantInt::get(Int32Ty, 0); - Int1False = ConstantInt::get(Int1Ty, 0); - // declare external print functions - PrintIntFn = M->getOrInsertFunction("print", FunctionType::get(Type::getVoidTy(Ctx), {Int32Ty}, false)); - PrintBoolFn = M->getOrInsertFunction("printBool", FunctionType::get(Type::getVoidTy(Ctx), {Int1Ty}, false)); - } +CodeGen::CodeGen() : context(std::make_unique()), + module(std::make_unique("main", *context)), + builder(std::make_unique>(*context)) { + + FunctionType* mainType = FunctionType::get(Type::getInt32Ty(*context), false); + mainFunc = Function::Create(mainType, Function::ExternalLinkage, "main", *module); + BasicBlock* entry = BasicBlock::Create(*context, "entry", mainFunc); + builder->SetInsertPoint(entry); - Module *getModule() { return M.get(); } - - void run(ProgramNode *root) { - // define main i32 @main() - FunctionType *FT = FunctionType::get(Int32Ty, false); - MainFn = Function::Create(FT, GlobalValue::ExternalLinkage, "main", M.get()); - BasicBlock *BB = BasicBlock::Create(Ctx, "entry", MainFn); - Builder.SetInsertPoint(BB); - // generate body - root->accept(*this); - Builder.CreateRet(Int32Zero); - // verify - verifyFunction(*MainFn); - } + printfFunc = Function::Create( + FunctionType::get(Type::getInt32Ty(*context), {Type::getInt8PtrTy(*context)}, true), + Function::ExternalLinkage, "printf", module.get() + ); - // Helpers - AllocaInst *createEntryBlockAlloca(Function *F, const std::string &Name, Type *Ty) { - IRBuilder<> TmpB(&F->getEntryBlock(), F->getEntryBlock().begin()); - return TmpB.CreateAlloca(Ty, nullptr, Name); - } + Function::Create( + FunctionType::get(Type::getInt8PtrTy(*context), {Type::getInt32Ty(*context)}, false), + Function::ExternalLinkage, "malloc", module.get() + ); - // Visitor implementations - void visit(ProgramNode &node) override { - for (auto &stmt : node.statements) - stmt->accept(*this); - } + Function::Create( + FunctionType::get(Type::getVoidTy(*context), + {Type::getInt8PtrTy(*context), Type::getInt8PtrTy(*context), Type::getInt32Ty(*context)}, + false), + Function::ExternalLinkage, "memcpy", module.get() + ); +} - void visit(MultiVarDeclNode &node) override { - for (auto &decl : node.declarations) { - Type *Ty = (decl->type == VarType::BOOL ? Int1Ty : Int32Ty); - AllocaInst *A = createEntryBlockAlloca(MainFn, decl->name, Ty); - NameMap[decl->name()] = A; - // initialize - if (decl->value) { - decl->value->accept(*this); - Builder.CreateStore(V, A); - } else { - Builder.CreateStore(Ty==Int1Ty ? Int1False : Int32Zero, A); - } - } +void CodeGen::generate(ProgramNode& ast) { + for (auto& stmt : ast.statements) { + generateStatement(stmt.get()); } - - void visit(AssignNode &node) override { - node.value->accept(*this); - auto *A = NameMap[node.target]; - Builder.CreateStore(V, A); + + if (!builder->GetInsertBlock()->getTerminator()) { + builder->CreateRet(ConstantInt::get(Type::getInt32Ty(*context), 0)); } + + std::string error; + raw_string_ostream os(error); + if (verifyModule(*module, &os)) { + throw std::runtime_error("Invalid IR: " + error); + } +} - void visit(VarRefNode &node) override { - AllocaInst *A = NameMap[node.name]; - V = Builder.CreateLoad(A->getAllocatedType(), A, node.name); +void CodeGen::generateStatement(ASTNode* node) { + if (auto multiVarDecl = dynamic_cast(node)) { + for (auto& decl : multiVarDecl->declarations) { + generateVarDecl(decl.get()); + } + } else if (auto assign = dynamic_cast(node)) { + generateAssign(assign); + } else if (auto ifElse = dynamic_cast(node)) { + generateIfElse(ifElse); + } else if (auto loop = dynamic_cast(node)) { + generateForLoop(loop); + } else if (auto print = dynamic_cast(node)) { + generatePrint(print); + } else if (auto block = dynamic_cast(node)) { + generateBlock(block); + } else if (auto tryCatch = dynamic_cast(node)) { + generateTryCatch(tryCatch); + } else if (auto match = dynamic_cast(node)) { + generateMatch(match); + } else if (auto unary = dynamic_cast(node)) { + generateUnaryOp(unary); + } else if (auto concat = dynamic_cast(node)) { + generateConcat(concat); } +} - void visit(LiteralNode &node) override { - V = ConstantInt::get(Int32Ty, node.value); +void CodeGen::generateVarDecl(VarDeclNode* node) { + Type* type = nullptr; + switch(node->type) { + case VarType::INT: type = Type::getInt32Ty(*context); break; + case VarType::FLOAT: type = Type::getFloatTy(*context); break; + case VarType::BOOL: type = Type::getInt1Ty(*context); break; + case VarType::STRING: type = Type::getInt8PtrTy(*context); break; + case VarType::ARRAY: + type = PointerType::get(Type::getInt32Ty(*context), 0); + break; + default: throw std::runtime_error("Unknown type"); } - void visit(LiteralNode &node) override { - V = ConstantInt::get(Int1Ty, node.value); + + AllocaInst* alloca = builder->CreateAlloca(type, nullptr, node->name); + symbols[node->name] = alloca; + + if (node->value) { + Value* val = generateValue(node->value.get(), type); + builder->CreateStore(val, alloca); + + if (node->type == VarType::ARRAY) { + if (auto arrLit = dynamic_cast(node->value.get())) { + arraySizes[node->name] = arrLit->elements.size(); + } + } } +} - void visit(BinaryOpNode &node) override { - node.left->accept(*this); - Value *L = V; - node.right->accept(*this); - Value *R = V; - switch (node.op) { - case BinaryOp::ADD: V = Builder.CreateAdd(L, R); break; - case BinaryOp::SUBTRACT: V = Builder.CreateSub(L, R); break; - case BinaryOp::MULTIPLY: V = Builder.CreateMul(L, R); break; - case BinaryOp::DIVIDE: V = Builder.CreateSDiv(L, R); break; - case BinaryOp::MOD: V = Builder.CreateSRem(L, R); break; - case BinaryOp::EQUAL: V = Builder.CreateICmpEQ(L, R); break; - case BinaryOp::NOT_EQUAL: V = Builder.CreateICmpNE(L, R); break; - case BinaryOp::LESS: V = Builder.CreateICmpSLT(L, R); break; - case BinaryOp::LESS_EQUAL: V = Builder.CreateICmpSLE(L, R); break; - case BinaryOp::GREATER: V = Builder.CreateICmpSGT(L, R); break; - case BinaryOp::GREATER_EQUAL: V = Builder.CreateICmpSGE(L,R); break; - case BinaryOp::AND: V = Builder.CreateAnd(L, R); break; - case BinaryOp::OR: V = Builder.CreateOr(L, R); break; - case BinaryOp::POW: - // simple pow via loop omitted for brevity - V = L; break; - default: - V = UndefValue::get(Int32Ty); +Value* CodeGen::generateValue(ASTNode* node, Type* expectedType) { + if (auto binOp = dynamic_cast(node)) { + return generateBinaryOp(binOp, expectedType); + } else if (auto varRef = dynamic_cast(node)) { + if (!symbols.count(varRef->name)) { + throw std::runtime_error("Undefined variable: " + varRef->name); } + return builder->CreateLoad(symbols[varRef->name]->getAllocatedType(), symbols[varRef->name]); + } else if (auto intLit = dynamic_cast*>(node)) { + return ConstantInt::get(Type::getInt32Ty(*context), intLit->value); + } else if (auto floatLit = dynamic_cast*>(node)) { + return ConstantFP::get(Type::getFloatTy(*context), floatLit->value); + } else if (auto strLit = dynamic_cast*>(node)) { + return builder->CreateGlobalStringPtr(strLit->value); + } else if (auto arr = dynamic_cast(node)) { + return generateArray(arr, expectedType); + } else if (auto access = dynamic_cast(node)) { + return generateArrayAccess(access); } + throw std::runtime_error("Unsupported node type"); +} - void visit(UnaryOpNode &node) override { - node.operand->accept(*this); - switch (node.op) { - case UnaryOp::INCREMENT: { - Value *One = ConstantInt::get(Int32Ty, 1); - V = Builder.CreateAdd(V, One); - break; - } - case UnaryOp::DECREMENT: { - Value *One = ConstantInt::get(Int32Ty, 1); - V = Builder.CreateSub(V, One); - break; - } +Value* CodeGen::generateBinaryOp(BinaryOpNode* node, Type* expectedType) { + Value* L = generateValue(node->left.get(), expectedType); + Value* R = generateValue(node->right.get(), expectedType); + + switch(node->op) { + case BinaryOp::ADD: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFAdd(L, R); + } else { + return builder->CreateAdd(L, R); + } + case BinaryOp::SUBTRACT: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFSub(L, R); + } else { + return builder->CreateSub(L, R); + } + case BinaryOp::MULTIPLY: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFMul(L, R); + } else { + return builder->CreateMul(L, R); + } + case BinaryOp::DIVIDE: + if (L->getType()->isFloatTy() || R->getType()->isFloatTy()) { + return builder->CreateFDiv(L, R); + } else { + return builder->CreateSDiv(L, R); + } + case BinaryOp::INDEX: + return generateArrayIndex(L, R); + case BinaryOp::CONCAT: + return generateStringConcat(L, R); default: - break; - } + throw std::runtime_error("Unsupported binary op"); } +} - void visit(PrintNode &node) override { - node.expr->accept(*this); - if (V->getType()->isIntegerTy(1)) - Builder.CreateCall(PrintBoolFn, {V}); - else - Builder.CreateCall(PrintIntFn, {V}); - } +Value* CodeGen::generateArrayIndex(Value* arrayPtr, Value* index) { + Value* gep = builder->CreateGEP( + Type::getInt32Ty(*context), + arrayPtr, + {index} + ); + return builder->CreateLoad(Type::getInt32Ty(*context), gep); +} - void visit(IfElseNode &node) override { - BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then", MainFn); - BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else", MainFn); - BasicBlock *MergeBB= BasicBlock::Create(Ctx, "iftmp", MainFn); - // condition - node.condition->accept(*this); - Builder.CreateCondBr(V, ThenBB, ElseBB); - // then - Builder.SetInsertPoint(ThenBB); - node.thenBlock->accept(*this); - Builder.CreateBr(MergeBB); - // else - Builder.SetInsertPoint(ElseBB); - if (node.elseBlock) - node.elseBlock->accept(*this); - Builder.CreateBr(MergeBB); - // merge - Builder.SetInsertPoint(MergeBB); +Value* CodeGen::generateStringConcat(Value* L, Value* R) { + Function* strlenFunc = module->getFunction("strlen"); + Function* mallocFunc = module->getFunction("malloc"); + Function* memcpyFunc = module->getFunction("memcpy"); + + Value* lenL = builder->CreateCall(strlenFunc, {L}); + Value* lenR = builder->CreateCall(strlenFunc, {R}); + Value* totalLen = builder->CreateAdd(lenL, lenR); + totalLen = builder->CreateAdd(totalLen, ConstantInt::get(Type::getInt32Ty(*context), 1)); + + Value* buffer = builder->CreateCall(mallocFunc, {totalLen}); + builder->CreateCall(memcpyFunc, {buffer, L, lenL}); + Value* dest = builder->CreateGEP(Type::getInt8Ty(*context), buffer, lenL); + builder->CreateCall(memcpyFunc, {dest, R, lenR}); + + return buffer; +} + +void CodeGen::generateIfElse(IfElseNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* thenBB = BasicBlock::Create(*context, "then", func); + BasicBlock* elseBB = BasicBlock::Create(*context, "else"); + BasicBlock* mergeBB = BasicBlock::Create(*context, "ifcont"); + + Value* cond = generateValue(node->condition.get(), Type::getInt1Ty(*context)); + builder->CreateCondBr(cond, thenBB, elseBB); + + builder->SetInsertPoint(thenBB); + generateStatement(node->thenBlock.get()); + builder->CreateBr(mergeBB); + + func->getBasicBlockList().push_back(elseBB); + builder->SetInsertPoint(elseBB); + if (node->elseBlock) { + generateStatement(node->elseBlock.get()); } + builder->CreateBr(mergeBB); + + func->getBasicBlockList().push_back(mergeBB); + builder->SetInsertPoint(mergeBB); +} - void visit(WhileLoopNode &node) override { - BasicBlock *CondBB = BasicBlock::Create(Ctx, "while.cond", MainFn); - BasicBlock *BodyBB = BasicBlock::Create(Ctx, "while.body", MainFn); - BasicBlock *AfterBB= BasicBlock::Create(Ctx, "while.end", MainFn); - Builder.CreateBr(CondBB); - Builder.SetInsertPoint(CondBB); - node.condition->accept(*this); - Builder.CreateCondBr(V, BodyBB, AfterBB); - Builder.SetInsertPoint(BodyBB); - for (auto &s : node.body->statements) - s->accept(*this); - Builder.CreateBr(CondBB); - Builder.SetInsertPoint(AfterBB); +void CodeGen::generateForLoop(ForLoopNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* loopStart = BasicBlock::Create(*context, "loop.start", func); + BasicBlock* loopBody = BasicBlock::Create(*context, "loop.body"); + BasicBlock* loopEnd = BasicBlock::Create(*context, "loop.end"); + + if (node->init) generateStatement(node->init.get()); + builder->CreateBr(loopStart); + + builder->SetInsertPoint(loopStart); + Value* cond = node->condition ? + generateValue(node->condition.get(), Type::getInt1Ty(*context)) : + ConstantInt::getTrue(*context); + builder->CreateCondBr(cond, loopBody, loopEnd); + + func->getBasicBlockList().push_back(loopBody); + builder->SetInsertPoint(loopBody); + generateStatement(node->body.get()); + if (node->update) generateStatement(node->update.get()); + builder->CreateBr(loopStart); + + func->getBasicBlockList().push_back(loopEnd); + builder->SetInsertPoint(loopEnd); +} + +void CodeGen::generatePrint(PrintNode* node) { + Value* value = generateValue(node->expr.get(), nullptr); + Type* ty = value->getType(); + + Constant* format = nullptr; + if (ty->isIntegerTy(32)) { + format = builder->CreateGlobalStringPtr("%d\n"); + } else if (ty->isFloatTy()) { + format = builder->CreateGlobalStringPtr("%f\n"); + } else if (ty->isPointerTy()) { + format = builder->CreateGlobalStringPtr("%s\n"); } + + builder->CreateCall(printfFunc, {format, value}); +} - void visit(ForLoopNode &node) override { - // init - if (node.init) node.init->accept(*this); - BasicBlock *CondBB = BasicBlock::Create(Ctx,"for.cond",MainFn); - BasicBlock *BodyBB = BasicBlock::Create(Ctx,"for.body",MainFn); - BasicBlock *AfterBB= BasicBlock::Create(Ctx,"for.end",MainFn); - Builder.CreateBr(CondBB); - Builder.SetInsertPoint(CondBB); - if (node.condition) { node.condition->accept(*this); Builder.CreateCondBr(V, BodyBB, AfterBB);} - Builder.SetInsertPoint(BodyBB); - for (auto &s : node.body->statements) - s->accept(*this); - if (node.update) node.update->accept(*this); - Builder.CreateBr(CondBB); - Builder.SetInsertPoint(AfterBB); +void CodeGen::generateTryCatch(TryCatchNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* tryBB = BasicBlock::Create(*context, "try", func); + BasicBlock* catchBB = BasicBlock::Create(*context, "catch"); + BasicBlock* contBB = BasicBlock::Create(*context, "try.cont"); + + LandingPadInst* lp = builder->CreateLandingPad( + StructType::get(Type::getInt8PtrTy(*context), Type::getInt32Ty(*context)), 0); + lp->addClause(ConstantPointerNull::get(Type::getInt8PtrTy(*context))); + + builder->CreateBr(tryBB); + + builder->SetInsertPoint(tryBB); + generateStatement(node->tryBlock.get()); + builder->CreateBr(contBB); + + func->getBasicBlockList().push_back(catchBB); + builder->SetInsertPoint(catchBB); + if (!node->errorVar.empty()) { + AllocaInst* alloca = builder->CreateAlloca(Type::getInt8PtrTy(*context), nullptr, node->errorVar); + builder->CreateStore(lp->getOperand(0), alloca); } + generateStatement(node->catchBlock.get()); + builder->CreateBr(contBB); + + func->getBasicBlockList().push_back(contBB); + builder->SetInsertPoint(contBB); +} - void visit(BlockNode &node) override { - for (auto &s : node.statements) - s->accept(*this); +void CodeGen::generateMatch(MatchNode* node) { + Value* expr = generateValue(node->expr.get(), nullptr); + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* endBB = BasicBlock::Create(*context, "match.end"); + + std::vector caseBBs; + for (size_t i = 0; i < node->cases.size(); ++i) { + caseBBs.push_back(BasicBlock::Create(*context, "case." + std::to_string(i), func)); + } + + for (size_t i = 0; i < node->cases.size(); ++i) { + builder->SetInsertPoint(caseBBs[i]); + if (node->cases[i]->value) { + Value* caseVal = generateValue(node->cases[i]->value.get(), expr->getType()); + Value* cmp = builder->CreateICmpEQ(expr, caseVal); + BasicBlock* nextBB = (i < node->cases.size()-1) ? caseBBs[i+1] : endBB; + builder->CreateCondBr(cmp, caseBBs[i], nextBB); + } else { + generateStatement(node->cases[i]->body.get()); + builder->CreateBr(endBB); + } } + + func->getBasicBlockList().push_back(endBB); + builder->SetInsertPoint(endBB); +} + - // other nodes omitted or treated as error -}; -} // end anonymous namespace +void CodeGen::dump() const { + module->print(llvm::outs(), nullptr); +} -void CodeGen::compile(ProgramNode *root, bool optimize, int unroll) { - LLVMContext Ctx; - ToIRVisitor visitor(Ctx, optimize, unroll); - visitor.run(root); - Module *mod = visitor.getModule(); - mod->print(outs(), nullptr); +extern "C" void throw_exception(const char* msg) { + throw std::runtime_error(msg); } From 2442ea984a75d8c3f4daf383a98365056d105af1 Mon Sep 17 00:00:00 2001 From: Am8xin <135750826+Am8xin@users.noreply.github.com> Date: Tue, 22 Apr 2025 23:41:34 +0330 Subject: [PATCH 25/25] Update code_generator.cpp --- code/code_generator.cpp | 243 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) diff --git a/code/code_generator.cpp b/code/code_generator.cpp index 3144a33..2d4602d 100644 --- a/code/code_generator.cpp +++ b/code/code_generator.cpp @@ -314,6 +314,249 @@ void CodeGen::generateMatch(MatchNode* node) { +// ... ادامه از قسمت قبلی + +Value* CodeGen::generateArray(ArrayNode* node, Type* expectedType) { + ArrayType* arrType = ArrayType::get(Type::getInt32Ty(*context), node->elements.size()); + Value* arrayPtr = builder->CreateAlloca(arrType); + + for (size_t i = 0; i < node->elements.size(); ++i) { + Value* idx[] = { + ConstantInt::get(Type::getInt32Ty(*context), 0), + ConstantInt::get(Type::getInt32Ty(*context), i) + }; + Value* elemPtr = builder->CreateInBoundsGEP(arrType, arrayPtr, idx); + Value* val = generateValue(node->elements[i].get(), Type::getInt32Ty(*context)); + builder->CreateStore(val, elemPtr); + } + + return builder->CreateBitCast(arrayPtr, expectedType); +} + +Value* CodeGen::generateArrayAccess(ArrayAccessNode* node) { + Value* arrayPtr = symbols[node->arrayName]; + Value* index = generateValue(node->index.get(), Type::getInt32Ty(*context)); + + // Runtime bounds checking + if (arraySizes.count(node->arrayName)) { + Value* size = ConstantInt::get(Type::getInt32Ty(*context), arraySizes[node->arrayName]); + Value* cond = builder->CreateICmpSGE(index, size); + + BasicBlock* errorBB = BasicBlock::Create(*context, "bounds_error", builder->GetInsertBlock()->getParent()); + BasicBlock* validBB = BasicBlock::Create(*context, "valid_index"); + + builder->CreateCondBr(cond, errorBB, validBB); + + builder->SetInsertPoint(errorBB); + Function* throwFunc = Function::Create( + FunctionType::get(Type::getVoidTy(*context), {Type::getInt8PtrTy(*context)}, false), + Function::ExternalLinkage, "throw_exception", module.get() + ); + Value* msg = builder->CreateGlobalStringPtr("Array index out of bounds!"); + builder->CreateCall(throwFunc, {msg}); + builder->CreateUnreachable(); + + builder->GetInsertBlock()->getParent()->getBasicBlockList().push_back(validBB); + builder->SetInsertPoint(validBB); + } + + return generateArrayIndex(arrayPtr, index); +} + +Value* CodeGen::generateUnaryOp(UnaryOpNode* node) { + Value* operand = generateValue(node->operand.get(), nullptr); + + switch(node->op) { + case UnaryOp::MINUS: + if (operand->getType()->isFloatTy()) { + return builder->CreateFNeg(operand); + } else { + return builder->CreateNeg(operand); + } + case UnaryOp::INCREMENT: { + Value* inc = ConstantInt::get(operand->getType(), 1); + Value* newVal = builder->CreateAdd(operand, inc); + if (auto varRef = dynamic_cast(node->operand.get())) { + builder->CreateStore(newVal, symbols[varRef->name]); + } + return operand; // Post-increment returns original value + } + case UnaryOp::LENGTH: { + if (!arraySizes.count(node->operand->getName())) { + throw std::runtime_error("Length operator on non-array type"); + } + return ConstantInt::get(Type::getInt32Ty(*context), arraySizes[node->operand->getName()]); + } + default: + throw std::runtime_error("Unsupported unary operator"); + } +} + +void CodeGen::generateCompoundAssign(CompoundAssignNode* node) { + Value* target = symbols[node->target]; + Value* current = builder->CreateLoad(target->getType()->getPointerElementType(), target); + Value* rhs = generateValue(node->value.get(), current->getType()); + + Value* result; + switch(node->op) { + case BinaryOp::ADD: + result = current->getType()->isFloatTy() ? + builder->CreateFAdd(current, rhs) : + builder->CreateAdd(current, rhs); + break; + case BinaryOp::SUBTRACT: + result = current->getType()->isFloatTy() ? + builder->CreateFSub(current, rhs) : + builder->CreateSub(current, rhs); + break; + default: + throw std::runtime_error("Unsupported compound assignment"); + } + + builder->CreateStore(result, target); +} + +void CodeGen::generateWhileLoop(WhileLoopNode* node) { + Function* func = builder->GetInsertBlock()->getParent(); + BasicBlock* condBB = BasicBlock::Create(*context, "while.cond", func); + BasicBlock* bodyBB = BasicBlock::Create(*context, "while.body"); + BasicBlock* endBB = BasicBlock::Create(*context, "while.end"); + + builder->CreateBr(condBB); + + // Condition block + builder->SetInsertPoint(condBB); + Value* cond = generateValue(node->condition.get(), Type::getInt1Ty(*context)); + builder->CreateCondBr(cond, bodyBB, endBB); + + // Body block + func->getBasicBlockList().push_back(bodyBB); + builder->SetInsertPoint(bodyBB); + generateStatement(node->body.get()); + builder->CreateBr(condBB); // Loop back + + // End block + func->getBasicBlockList().push_back(endBB); + builder->SetInsertPoint(endBB); +} + +Value* CodeGen::generateTernaryExpr(TernaryExprNode* node) { + Value* cond = generateValue(node->condition.get(), Type::getInt1Ty(*context)); + Function* func = builder->GetInsertBlock()->getParent(); + + BasicBlock* trueBB = BasicBlock::Create(*context, "ternary.true", func); + BasicBlock* falseBB = BasicBlock::Create(*context, "ternary.false"); + BasicBlock* mergeBB = BasicBlock::Create(*context, "ternary.merge"); + + builder->CreateCondBr(cond, trueBB, falseBB); + + // True branch + builder->SetInsertPoint(trueBB); + Value* trueVal = generateValue(node->trueExpr.get(), nullptr); + builder->CreateBr(mergeBB); + + // False branch + func->getBasicBlockList().push_back(falseBB); + builder->SetInsertPoint(falseBB); + Value* falseVal = generateValue(node->falseExpr.get(), nullptr); + builder->CreateBr(mergeBB); + + // Merge + func->getBasicBlockList().push_back(mergeBB); + builder->SetInsertPoint(mergeBB); + PHINode* phi = builder->CreatePHI(trueVal->getType(), 2); + phi->addIncoming(trueVal, trueBB); + phi->addIncoming(falseVal, falseBB); + + return phi; +} + +void CodeGen::generateTypeConversion(TypeConversionNode* node) { + Value* val = generateValue(node->expr.get(), nullptr); + Type* targetType = getLLVMType(node->targetType); + + if (val->getType() == targetType) return val; + + if (targetType->isFloatTy() && val->getType()->isIntegerTy()) { + return builder->CreateSIToFP(val, targetType); + } + if (targetType->isIntegerTy() && val->getType()->isFloatTy()) { + return builder->CreateFPToSI(val, targetType); + } + if (targetType->isPointerTy() && val->getType()->isIntegerTy()) { + return builder->CreateIntToPtr(val, targetType); + } + + throw std::runtime_error("Invalid type conversion"); +} + +Value* CodeGen::generateBuiltInCall(BuiltInCallNode* node) { + if (node->funcName == "pow") { + Value* base = generateValue(node->args[0].get(), Type::getDoubleTy(*context)); + Value* exp = generateValue(node->args[1].get(), Type::getDoubleTy(*context)); + return builder->CreateCall( + Intrinsic::getDeclaration(module.get(), Intrinsic::pow, {Type::getDoubleTy(*context)}), + {base, exp} + ); + } + if (node->funcName == "abs") { + Value* val = generateValue(node->args[0].get(), nullptr); + if (val->getType()->isFloatTy()) { + return builder->CreateCall( + Intrinsic::getDeclaration(module.get(), Intrinsic::fabs, {val->getType()}), + {val} + ); + } else { + Value* zero = ConstantInt::get(val->getType(), 0); + Value* neg = builder->CreateNeg(val); + return builder->CreateSelect( + builder->CreateICmpSGT(val, zero), + val, + neg + ); + } + } + throw std::runtime_error("Unknown built-in function"); +} + +void CodeGen::generateMemoryManagement() { + // Register destructors for stack-allocated arrays + for (auto& [name, alloca] : symbols) { + if (alloca->getAllocatedType()->isPointerTy()) { + Function* freeFunc = Function::Create( + FunctionType::get(Type::getVoidTy(*context), {Type::getInt8PtrTy(*context)}, false), + Function::ExternalLinkage, "free", module.get() + ); + + BasicBlock* currentBB = builder->GetInsertBlock(); + BasicBlock* freeBB = BasicBlock::Create(*context, "free." + name, currentBB->getParent()); + + builder->SetInsertPoint(freeBB); + Value* casted = builder->CreateBitCast(alloca, Type::getInt8PtrTy(*context)); + builder->CreateCall(freeFunc, {casted}); + builder->CreateBr(currentBB); + + currentBB->getParent()->getBasicBlockList().push_back(freeBB); + } + } +} + +void CodeGen::optimizeIR() { + legacy::FunctionPassManager FPM(module.get()); + FPM.add(createPromoteMemoryToRegisterPass()); + FPM.add(createInstructionCombiningPass()); + FPM.add(createReassociatePass()); + FPM.add(createGVNPass()); + FPM.add(createCFGSimplificationPass()); + + FPM.doInitialization(); + for (Function &F : *module) { + FPM.run(F); + } +} + + + void CodeGen::dump() const { module->print(llvm::outs(), nullptr); }