Skip to content

Commit 730b3ee

Browse files
Add kw/default args, RUN operator (not to be confused with old RUN operator, which is now CL. This one is more like Python eval)
1 parent d42a5eb commit 730b3ee

File tree

4 files changed

+231
-45
lines changed

4 files changed

+231
-45
lines changed

asmln.exe

4.28 KB
Binary file not shown.

interpreter.py

Lines changed: 155 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Assignment,
1414
Block,
1515
BreakStatement,
16+
CallArgument,
1617
CallExpression,
1718
Expression,
1819
ExpressionStatement,
@@ -24,6 +25,7 @@
2425
IfBranch,
2526
IfStatement,
2627
Literal,
28+
Param,
2729
Parser,
2830
Program,
2931
ReturnStatement,
@@ -155,8 +157,7 @@ def snapshot(self) -> Dict[str, str]:
155157
@dataclass
156158
class Function:
157159
name: str
158-
params: List[str]
159-
param_types: List[str]
160+
params: List[Param]
160161
return_type: str
161162
body: Block
162163
closure: Environment
@@ -295,6 +296,7 @@ def __init__(self) -> None:
295296
self._register_custom("MAIN", 0, 0, self._main)
296297
self._register_custom("OS", 0, 0, self._os)
297298
self._register_custom("IMPORT", 1, 1, self._import)
299+
self._register_custom("RUN", 1, 1, self._run)
298300
self._register_custom("INPUT", 0, 0, self._input)
299301
self._register_custom("PRINT", 0, None, self._print)
300302
self._register_custom("ASSERT", 1, 1, self._assert)
@@ -707,7 +709,6 @@ def _import(
707709
interpreter.functions[dotted_name] = Function(
708710
name=dotted_name,
709711
params=fn.params,
710-
param_types=fn.param_types,
711712
return_type=fn.return_type,
712713
body=fn.body,
713714
closure=fn.closure,
@@ -717,7 +718,6 @@ def _import(
717718
interpreter.functions[dotted_name] = Function(
718719
name=dotted_name,
719720
params=fn.params,
720-
param_types=fn.param_types,
721721
return_type=fn.return_type,
722722
body=fn.body,
723723
closure=module_env,
@@ -728,6 +728,29 @@ def _import(
728728
env.set(dotted, v, declared_type=v.type)
729729
return Value(TYPE_INT, 0)
730730

731+
def _run(
732+
self,
733+
interpreter: "Interpreter",
734+
args: List[Value],
735+
__: List[Expression],
736+
env: Environment,
737+
location: SourceLocation,
738+
) -> Value:
739+
# RUN(source): execute the provided source code string within the
740+
# current environment (mutating `env` and `interpreter.functions`).
741+
source_text = self._expect_str(args[0], "RUN", location)
742+
run_filename = location.file if location and location.file else "<run>"
743+
744+
lexer = Lexer(source_text, run_filename)
745+
tokens = lexer.tokenize()
746+
parser = Parser(tokens, run_filename, source_text.splitlines())
747+
program = parser.parse()
748+
749+
# Execute parsed statements in the caller's environment so that
750+
# assignments and function definitions are visible to the caller.
751+
interpreter._execute_block(program.statements, env)
752+
return Value(TYPE_INT, 0)
753+
731754
def _input(
732755
self,
733756
interpreter: "Interpreter",
@@ -1061,12 +1084,9 @@ def _execute_statement(self, statement: Statement, env: Environment) -> None:
10611084
raise ASMRuntimeError(
10621085
f"Function name '{statement.name}' conflicts with built-in", location=statement.location
10631086
)
1064-
param_types = [ptype for ptype, _ in statement.params]
1065-
param_names = [pname for _, pname in statement.params]
10661087
self.functions[statement.name] = Function(
10671088
name=statement.name,
1068-
params=param_names,
1069-
param_types=param_types,
1089+
params=statement.params,
10701090
return_type=statement.return_type,
10711091
body=statement.body,
10721092
closure=env,
@@ -1180,27 +1200,49 @@ def _evaluate_expression(self, expression: Expression, env: Environment) -> Valu
11801200
raise
11811201
if isinstance(expression, CallExpression):
11821202
if expression.name == "IMPORT":
1183-
module_label = expression.args[0].name if (expression.args and isinstance(expression.args[0], Identifier)) else None
1203+
if any(arg.name for arg in expression.args):
1204+
raise ASMRuntimeError("IMPORT does not accept keyword arguments", location=expression.location, rewrite_rule="IMPORT")
1205+
first_expr = expression.args[0].expression if expression.args else None
1206+
module_label = first_expr.name if isinstance(first_expr, Identifier) else None
11841207
dummy_args: List[Value] = [Value(TYPE_INT, 0)] * len(expression.args)
1208+
arg_nodes = [arg.expression for arg in expression.args]
11851209
try:
1186-
result = self.builtins.invoke(self, expression.name, dummy_args, expression.args, env, expression.location)
1210+
result = self.builtins.invoke(self, expression.name, dummy_args, arg_nodes, env, expression.location)
11871211
except ASMRuntimeError:
11881212
self._log_step(rule="IMPORT", location=expression.location, extra={"module": module_label, "status": "error"})
11891213
raise
11901214
self._log_step(rule="IMPORT", location=expression.location, extra={"module": module_label, "result": result.value})
11911215
return result
11921216
if expression.name in ("DEL", "EXIST"):
1217+
if any(arg.name for arg in expression.args):
1218+
raise ASMRuntimeError(
1219+
f"{expression.name} does not accept keyword arguments",
1220+
location=expression.location,
1221+
rewrite_rule=expression.name,
1222+
)
11931223
dummy_args: List[Value] = [Value(TYPE_INT, 0)] * len(expression.args)
1224+
arg_nodes = [arg.expression for arg in expression.args]
11941225
try:
1195-
result = self.builtins.invoke(self, expression.name, dummy_args, expression.args, env, expression.location)
1226+
result = self.builtins.invoke(self, expression.name, dummy_args, arg_nodes, env, expression.location)
11961227
except ASMRuntimeError:
11971228
self._log_step(rule=expression.name, location=expression.location, extra={"args": None, "status": "error"})
11981229
raise
11991230
self._log_step(rule=expression.name, location=expression.location, extra={"args": None, "result": result.value})
12001231
return result
1201-
args: List[Value] = []
1232+
positional_args: List[Value] = []
1233+
keyword_args: Dict[str, Value] = {}
12021234
for arg in expression.args:
1203-
args.append(self._evaluate_expression(arg, env))
1235+
value = self._evaluate_expression(arg.expression, env)
1236+
if arg.name is None:
1237+
positional_args.append(value)
1238+
else:
1239+
if arg.name in keyword_args:
1240+
raise ASMRuntimeError(
1241+
f"Duplicate keyword argument '{arg.name}'",
1242+
location=expression.location,
1243+
rewrite_rule=expression.name,
1244+
)
1245+
keyword_args[arg.name] = value
12041246
func_name: Optional[str] = None
12051247
if expression.name in self.functions:
12061248
func_name = expression.name
@@ -1213,32 +1255,120 @@ def _evaluate_expression(self, expression: Expression, env: Environment) -> Valu
12131255
if candidate in self.functions:
12141256
func_name = candidate
12151257
if func_name is not None:
1216-
self._log_step(rule="CALL", location=expression.location, extra={"function": func_name, "args": [a.value for a in args]})
1217-
return self._call_user_function(self.functions[func_name], args, expression.location)
1258+
self._log_step(
1259+
rule="CALL",
1260+
location=expression.location,
1261+
extra={
1262+
"function": func_name,
1263+
"positional": [a.value for a in positional_args],
1264+
"keyword": {k: v.value for k, v in keyword_args.items()},
1265+
},
1266+
)
1267+
return self._call_user_function(
1268+
self.functions[func_name],
1269+
positional_args,
1270+
keyword_args,
1271+
expression.location,
1272+
)
12181273
try:
1219-
result = self.builtins.invoke(self, expression.name, args, expression.args, env, expression.location)
1274+
if keyword_args:
1275+
raise ASMRuntimeError(
1276+
f"{expression.name} does not accept keyword arguments",
1277+
location=expression.location,
1278+
rewrite_rule=expression.name,
1279+
)
1280+
arg_nodes = [a.expression for a in expression.args]
1281+
result = self.builtins.invoke(self, expression.name, positional_args, arg_nodes, env, expression.location)
12201282
except ASMRuntimeError:
1221-
self._log_step(rule=expression.name, location=expression.location, extra={"args": [a.value for a in args], "status": "error"})
1283+
self._log_step(
1284+
rule=expression.name,
1285+
location=expression.location,
1286+
extra={
1287+
"args": [a.value for a in positional_args],
1288+
"keyword": {k: v.value for k, v in keyword_args.items()},
1289+
"status": "error",
1290+
},
1291+
)
12221292
raise
1223-
self._log_step(rule=expression.name, location=expression.location, extra={"args": [a.value for a in args], "result": result.value})
1293+
self._log_step(
1294+
rule=expression.name,
1295+
location=expression.location,
1296+
extra={
1297+
"args": [a.value for a in positional_args],
1298+
"keyword": {k: v.value for k, v in keyword_args.items()},
1299+
"result": result.value,
1300+
},
1301+
)
12241302
return result
12251303
raise ASMRuntimeError("Unsupported expression", location=expression.location)
12261304

1227-
def _call_user_function(self, function: Function, args: List[Value], call_location: SourceLocation) -> Value:
1228-
if len(args) != len(function.params):
1305+
def _call_user_function(
1306+
self,
1307+
function: Function,
1308+
positional_args: List[Value],
1309+
keyword_args: Dict[str, Value],
1310+
call_location: SourceLocation,
1311+
) -> Value:
1312+
if len(positional_args) > len(function.params):
12291313
raise ASMRuntimeError(
1230-
f"Function {function.name} expects {len(function.params)} arguments but received {len(args)}",
1314+
f"Function {function.name} expects at most {len(function.params)} positional arguments but received {len(positional_args)}",
12311315
location=call_location,
1316+
rewrite_rule=function.name,
12321317
)
1318+
12331319
env = Environment(parent=function.closure)
1234-
for param_name, param_type, arg in zip(function.params, function.param_types, args):
1235-
if arg.type != param_type:
1320+
1321+
kwds = dict(keyword_args)
1322+
1323+
for param, arg in zip(function.params, positional_args):
1324+
if arg.type != param.type:
12361325
raise ASMRuntimeError(
1237-
f"Argument for '{param_name}' expected {param_type} but got {arg.type}",
1326+
f"Argument for '{param.name}' expected {param.type} but got {arg.type}",
12381327
location=call_location,
12391328
rewrite_rule=function.name,
12401329
)
1241-
env.set(param_name, arg, declared_type=param_type)
1330+
env.set(param.name, arg, declared_type=param.type)
1331+
1332+
remaining_params = function.params[len(positional_args) :]
1333+
for param in remaining_params:
1334+
if param.name in kwds:
1335+
if param.default is None:
1336+
raise ASMRuntimeError(
1337+
f"Parameter '{param.name}' does not accept keyword arguments",
1338+
location=call_location,
1339+
rewrite_rule=function.name,
1340+
)
1341+
value = kwds.pop(param.name)
1342+
if value.type != param.type:
1343+
raise ASMRuntimeError(
1344+
f"Argument for '{param.name}' expected {param.type} but got {value.type}",
1345+
location=call_location,
1346+
rewrite_rule=function.name,
1347+
)
1348+
env.set(param.name, value, declared_type=param.type)
1349+
continue
1350+
if param.default is None:
1351+
raise ASMRuntimeError(
1352+
f"Missing required argument '{param.name}' for function {function.name}",
1353+
location=call_location,
1354+
rewrite_rule=function.name,
1355+
)
1356+
default_value = self._evaluate_expression(param.default, env)
1357+
if default_value.type != param.type:
1358+
raise ASMRuntimeError(
1359+
f"Default for '{param.name}' expected {param.type} but got {default_value.type}",
1360+
location=call_location,
1361+
rewrite_rule=function.name,
1362+
)
1363+
env.set(param.name, default_value, declared_type=param.type)
1364+
1365+
if kwds:
1366+
unexpected = ", ".join(sorted(kwds.keys()))
1367+
raise ASMRuntimeError(
1368+
f"Unexpected keyword arguments: {unexpected}",
1369+
location=call_location,
1370+
rewrite_rule=function.name,
1371+
)
12421372
frame = self._new_frame(function.name, env, call_location)
12431373
self.call_stack.append(frame)
12441374
try:

parser.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,18 @@ class ForStatement(Statement):
7474
@dataclass
7575
class FuncDef(Statement):
7676
name: str
77-
params: List[Tuple[str, str]]
77+
params: List["Param"]
7878
return_type: str
7979
body: Block
8080

8181

82+
@dataclass
83+
class Param:
84+
type: str
85+
name: str
86+
default: Optional["Expression"]
87+
88+
8289
@dataclass
8390
class ReturnStatement(Statement):
8491
expression: Optional["Expression"]
@@ -122,7 +129,13 @@ class Identifier(Expression):
122129
@dataclass
123130
class CallExpression(Expression):
124131
name: str
125-
args: List[Expression]
132+
args: List["CallArgument"]
133+
134+
135+
@dataclass
136+
class CallArgument:
137+
name: Optional[str]
138+
expression: Expression
126139

127140

128141
class Parser:
@@ -189,12 +202,21 @@ def _parse_func(self) -> FuncDef:
189202
keyword = self._consume("FUNC")
190203
name_token = self._consume("IDENT")
191204
self._consume("LPAREN")
192-
params: List[Tuple[str, str]] = []
205+
params: List[Param] = []
206+
seen_default = False
193207
if self._peek().type != "RPAREN":
194208
while True:
195209
type_token = self._consume_type_token()
196210
self._consume("COLON")
197-
params.append((type_token.value, self._consume("IDENT").value))
211+
name_tok = self._consume("IDENT")
212+
default_expr: Optional[Expression] = None
213+
if self._match("EQUALS"):
214+
seen_default = True
215+
default_expr = self._parse_expression()
216+
elif seen_default:
217+
raise ASMParseError(
218+
f"Positional parameter cannot follow parameter with default at line {name_tok.line}")
219+
params.append(Param(type=type_token.value, name=name_tok.value, default=default_expr))
198220
if not self._match("COMMA"):
199221
break
200222
self._consume("RPAREN")
@@ -298,10 +320,21 @@ def _parse_expression(self) -> Expression:
298320
ident: Token = self._consume("IDENT")
299321
location: SourceLocation = self._location_from_token(ident)
300322
if self._match("LPAREN"):
301-
args: List[Expression] = []
323+
args: List[CallArgument] = []
324+
seen_kw = False
302325
if self._peek().type != "RPAREN":
303326
while True:
304-
args.append(self._parse_expression())
327+
if self._peek().type == "IDENT" and self._peek_next().type == "EQUALS":
328+
name_tok = self._consume("IDENT")
329+
self._consume("EQUALS")
330+
arg_expr = self._parse_expression()
331+
seen_kw = True
332+
args.append(CallArgument(name=name_tok.value, expression=arg_expr))
333+
else:
334+
if seen_kw:
335+
raise ASMParseError(
336+
f"Positional argument cannot follow keyword argument at line {self._peek().line}")
337+
args.append(CallArgument(name=None, expression=self._parse_expression()))
305338
if not self._match("COMMA"):
306339
break
307340
self._consume("RPAREN")

0 commit comments

Comments
 (0)