Skip to content

Commit 5bc0a44

Browse files
committed
feat: add basic interpreter, need to make functioncall a epression
1 parent 7c49ace commit 5bc0a44

4 files changed

Lines changed: 545 additions & 571 deletions

File tree

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
from typing import Any, Dict, List, Union
2+
from copy import deepcopy
3+
4+
from ..schemas.ast_nodes import *
5+
6+
RuntimeValue = Union[int, float, str, bool]
7+
8+
9+
class ReturnException(Exception):
10+
def __init__(self, value: RuntimeValue):
11+
self.value = value
12+
13+
14+
class Interpreter:
15+
def __init__(self):
16+
self.variables: Dict[str, RuntimeValue] = {}
17+
self.functions: Dict[str, FunctionNode] = {}
18+
self.output: List[str] = []
19+
20+
# -----------------------------------
21+
# Public Entry Point
22+
# -----------------------------------
23+
24+
def run(self, program: ProgramNode):
25+
self.variables = {}
26+
self.functions = {}
27+
self.output = []
28+
29+
# Register functions first
30+
for func in program.functions:
31+
self.functions[func.name] = func
32+
33+
# Execute global statements
34+
if program.global_statements:
35+
self.execute_block(program.global_statements)
36+
37+
return {
38+
"variables": deepcopy(self.variables),
39+
"output": list(self.output)
40+
}
41+
42+
# -----------------------------------
43+
# Block Execution
44+
# -----------------------------------
45+
46+
def execute_block(self, block: BlockNode):
47+
for stmt in block.statements:
48+
self.execute_statement(stmt)
49+
50+
# -----------------------------------
51+
# Statement Execution
52+
# -----------------------------------
53+
54+
def execute_statement(self, node: ASTNode):
55+
56+
if isinstance(node, AssignmentNode):
57+
value = self.evaluate_expression(node.value)
58+
if isinstance(node.target, VariableNode):
59+
self.variables[node.target.name] = value
60+
elif isinstance(node.target, ArrayAccessNode):
61+
arr = self.evaluate_expression(node.target.array)
62+
index = int(self.evaluate_expression(node.target.index))
63+
arr[index] = value
64+
65+
elif isinstance(node, ReturnNode):
66+
value = self.evaluate_expression(node.value)
67+
raise ReturnException(value)
68+
69+
elif isinstance(node, FunctionCallNode):
70+
self.call_function(node)
71+
72+
elif isinstance(node, LoopNode):
73+
self.execute_loop(node)
74+
75+
elif isinstance(node, ConditionalNode):
76+
self.execute_conditional(node)
77+
78+
elif isinstance(node, FunctionNode):
79+
# Already registered in run()
80+
pass
81+
82+
elif isinstance(node, BlockNode):
83+
self.execute_block(node)
84+
85+
else:
86+
raise NotImplementedError(f"Unsupported statement: {type(node)}")
87+
88+
# -----------------------------------
89+
# Loop Execution
90+
# -----------------------------------
91+
92+
def execute_loop(self, node: LoopNode):
93+
94+
if node.loop_type == LoopType.FOR:
95+
start = int(self.evaluate_expression(node.start))
96+
end = int(self.evaluate_expression(node.end))
97+
step = int(self.evaluate_expression(node.step)) if node.step else 1
98+
99+
for i in range(start, end + 1, step):
100+
self.variables[node.iterator.name] = i
101+
self.execute_block(node.body)
102+
103+
elif node.loop_type == LoopType.WHILE:
104+
while self.evaluate_expression(node.condition):
105+
self.execute_block(node.body)
106+
107+
else:
108+
raise NotImplementedError(f"Loop type {node.loop_type} not supported")
109+
110+
# -----------------------------------
111+
# Conditional Execution
112+
# -----------------------------------
113+
114+
def execute_conditional(self, node: ConditionalNode):
115+
116+
if self.evaluate_expression(node.condition):
117+
self.execute_block(node.then_branch)
118+
return
119+
120+
for elif_branch in node.elif_branches:
121+
if self.evaluate_expression(elif_branch.condition):
122+
self.execute_block(elif_branch.then_branch)
123+
return
124+
125+
if node.else_branch:
126+
self.execute_block(node.else_branch)
127+
128+
# -----------------------------------
129+
# Expression Evaluation
130+
# -----------------------------------
131+
132+
def evaluate_expression(self, node: ExpressionNode) -> RuntimeValue:
133+
134+
if isinstance(node, LiteralNode):
135+
return node.value
136+
137+
if isinstance(node, VariableNode):
138+
return self.variables.get(node.name, 0)
139+
140+
if isinstance(node, UnaryOpNode):
141+
value = self.evaluate_expression(node.operand)
142+
if node.operator == OperatorType.SUBTRACT:
143+
return -value
144+
if node.operator == OperatorType.NOT:
145+
return not value
146+
147+
if isinstance(node, BinaryOpNode):
148+
left = self.evaluate_expression(node.left)
149+
right = self.evaluate_expression(node.right)
150+
op = node.operator
151+
152+
if op == OperatorType.ADD:
153+
return left + right
154+
if op == OperatorType.SUBTRACT:
155+
return left - right
156+
if op == OperatorType.MULTIPLY:
157+
return left * right
158+
if op == OperatorType.DIVIDE:
159+
return left / right
160+
if op == OperatorType.MODULO:
161+
return left % right
162+
if op == OperatorType.POWER:
163+
return left ** right
164+
165+
# Comparisons
166+
if op == OperatorType.EQUAL:
167+
return left == right
168+
if op == OperatorType.NOT_EQUAL:
169+
return left != right
170+
if op == OperatorType.LESS_THAN:
171+
return left < right
172+
if op == OperatorType.LESS_EQUAL:
173+
return left <= right
174+
if op == OperatorType.GREATER_THAN:
175+
return left > right
176+
if op == OperatorType.GREATER_EQUAL:
177+
return left >= right
178+
179+
# Logical
180+
if op == OperatorType.AND:
181+
return bool(left) and bool(right)
182+
if op == OperatorType.OR:
183+
return bool(left) or bool(right)
184+
185+
if isinstance(node, ArrayAccessNode):
186+
arr = self.evaluate_expression(node.array)
187+
index = int(self.evaluate_expression(node.index))
188+
return arr[index]
189+
190+
if isinstance(node, RecursiveCallNode):
191+
return self.call_function(node)
192+
193+
if isinstance(node, FunctionCallNode):
194+
return self.call_function(node)
195+
196+
raise NotImplementedError(f"Unsupported expression: {type(node)}")
197+
198+
# -----------------------------------
199+
# Function Calls
200+
# -----------------------------------
201+
202+
def call_function(self, node: FunctionCallNode) -> RuntimeValue:
203+
204+
if node.function_name not in self.functions:
205+
raise Exception(f"Function {node.function_name} not defined")
206+
207+
func = self.functions[node.function_name]
208+
209+
# Evaluate arguments in current scope
210+
arg_values = [self.evaluate_expression(arg) for arg in node.arguments]
211+
212+
# Save previous scope
213+
previous_vars = deepcopy(self.variables)
214+
215+
# Set parameters
216+
for param, value in zip(func.parameters, arg_values):
217+
self.variables[param.name] = value
218+
219+
try:
220+
self.execute_block(func.body)
221+
self.variables = previous_vars
222+
return 0
223+
except ReturnException as e:
224+
self.variables = previous_vars
225+
return e.value

0 commit comments

Comments
 (0)