Skip to content

Commit 183bb8d

Browse files
committed
Lox is progressing though the book
Nearly have an interpreter
1 parent 684cb47 commit 183bb8d

File tree

1 file changed

+211
-35
lines changed
  • teachprogramming/static/language_reference/languages/aqa

1 file changed

+211
-35
lines changed

teachprogramming/static/language_reference/languages/aqa/AQA.py

Lines changed: 211 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import sys
12
import enum
2-
from typing import NamedTuple, Self, Callable
3+
from typing import NamedTuple, Self, Callable, Protocol, Any
34
from collections.abc import Sequence, MutableSequence, Mapping
45
from functools import cached_property
56
from pprint import pprint as pp
67
import functools
78
import logging
8-
9+
from pathlib import Path
910

1011
log = logging.getLogger(__name__)
1112

@@ -60,6 +61,34 @@ class TokenType(enum.StrEnum):
6061
VAR = enum.auto()
6162
WHILE = enum.auto()
6263

64+
KEYWORDS = frozenset((
65+
TokenType.AND,
66+
TokenType.CLASS,
67+
TokenType.ELSE,
68+
TokenType.FALSE,
69+
TokenType.FOR,
70+
TokenType.IF,
71+
TokenType.NIL,
72+
TokenType.OR,
73+
TokenType.PRINT,
74+
TokenType.RETURN,
75+
TokenType.SUPER,
76+
TokenType.THIS,
77+
TokenType.TRUE,
78+
TokenType.VAR,
79+
TokenType.WHILE,
80+
))
81+
OPERATORS_NUMERICAL = frozenset((
82+
TokenType.MINUS,
83+
TokenType.SLASH,
84+
TokenType.STAR,
85+
TokenType.PLUS,
86+
TokenType.GREATER,
87+
TokenType.GREATER_EQUAL,
88+
TokenType.LESS,
89+
TokenType.LESS_EQUAL,
90+
))
91+
6392
# ------------------------------------------------------------------------------
6493

6594
class TextLocation(NamedTuple):
@@ -81,6 +110,8 @@ class MutableTextLocation():
81110
@property
82111
def immutable(self) -> TextLocation:
83112
return TextLocation(self.line, self.col)
113+
def __str__(self) -> str:
114+
return str(self.immutable)
84115
def newLine(self):
85116
self.line += 1
86117
self.col = 0
@@ -162,24 +193,6 @@ def number(s: Scanner) -> wasConsumed:
162193
while s.peek().isdigit(): s.advance()
163194
s.addToken(TokenType.NUMBER, s.source[s.index_start:s.index_current])
164195

165-
KEYWORDS = frozenset((
166-
TokenType.AND,
167-
TokenType.CLASS,
168-
TokenType.ELSE,
169-
TokenType.FALSE,
170-
TokenType.FOR,
171-
TokenType.IF,
172-
TokenType.NIL,
173-
TokenType.OR,
174-
TokenType.PRINT,
175-
TokenType.RETURN,
176-
TokenType.SUPER,
177-
TokenType.THIS,
178-
TokenType.TRUE,
179-
TokenType.VAR,
180-
TokenType.WHILE,
181-
))
182-
183196
def identifier(s: Scanner) -> wasConsumed:
184197
if s.peek().isalpha():
185198
while s.peek().isalnum(): s.advance()
@@ -203,9 +216,10 @@ def _t(s: Scanner) -> wasConsumed:
203216
*map(createDefaultTokenHandlerFor, ('!=', '==', '<=', '>=', '(', ')', '{', '}', ',', '.', '-', '+', ';', '*', '!', '>', '<', '=', '/'))
204217
)
205218

219+
scanner: Callable[[str], Scanner] = functools.partial(Scanner, token_handlers=DEFAULT_TOKEN_HANDLERS)
206220

207221
def test_scanner():
208-
tokens = Scanner('thing = ("test" + 1.23) # This is a comment', DEFAULT_TOKEN_HANDLERS).tokens
222+
tokens = scanner('thing = ("test" + 1.23) # This is a comment').tokens
209223
token_types = tuple(t.type for t in tokens)
210224
assert token_types == (
211225
TokenType.IDENTIFIER,
@@ -221,34 +235,37 @@ def test_scanner():
221235

222236
# ------------------------------------------------------------------------------
223237

224-
import abc
238+
#class (Protocol):
239+
225240

226-
class Expr(abc.ABC):
227-
pass
228-
class Literal(Expr):
229-
def __init__(self, literal: str|bool|None|int|float):
230-
self.literal = literal
241+
type Number = int|float
242+
type PrimitiveValue = str|bool|None|Number
243+
class Literal():
244+
def __init__(self, value: PrimitiveValue):
245+
self.value = value
231246
def __str__(self) -> str:
232-
return str(self.literal)
233-
class Unary(Expr):
234-
def __init__(self, operator: Token, expression: Expr):
247+
return str(self.value)
248+
class Unary():
249+
def __init__(self, operator: Token, expression: 'Expr'):
235250
self.operator = operator
236251
self.expression = expression
237252
def __str__(self) -> str:
238253
return f'{self.operator.type.value}{self.expression}'
239-
class Binary(Expr):
240-
def __init__(self, expression1: Expr, operator: Token, expression2: Expr):
254+
class Binary():
255+
def __init__(self, expression1: 'Expr', operator: Token, expression2: 'Expr'):
241256
self.expression1 = expression1
242257
self.operator = operator
243258
self.expression2 = expression2
244259
def __str__(self) -> str:
245260
return ''.join(map(str, (self.expression1, self.operator.type.value, self.expression2)))
246-
class Grouping(Expr):
247-
def __init__(self, expression: Expr):
261+
class Grouping():
262+
def __init__(self, expression: 'Expr'):
248263
self.expression = expression
249264
def __str__(self) -> str:
250265
return f'({self.expression})'
251266

267+
Expr = Literal | Unary | Binary | Grouping
268+
252269

253270
class Parser():
254271
class ParseError(BaseException): ...
@@ -262,6 +279,7 @@ def parse(self) -> Expr | None:
262279
try:
263280
return self.expression()
264281
except self.ParseError as pe:
282+
log.exception('ParseError: TODO')
265283
return None
266284

267285
@property
@@ -350,4 +368,162 @@ def primary(self) -> Expr:
350368
def test_parser():
351369
tokens = Scanner('12.3 * (45 - "test") >= !10', DEFAULT_TOKEN_HANDLERS).tokens
352370
expr = Parser(tokens).parse
353-
assert False
371+
assert str(expr) == '12.3*(45-test)>=!10'
372+
373+
374+
# ------------------------------------------------------------------------------
375+
376+
377+
class Interpreter():
378+
379+
class RuntimeError(BaseException):
380+
token: Token
381+
def __init__(self, token: Token, message: str):
382+
super().__init__(message)
383+
self.token = token
384+
@property
385+
def message(self) -> str: return self.args[0]
386+
387+
388+
def stringify(self, obj: Any) -> str:
389+
if obj == None: return 'nil'
390+
if isinstance(obj, (float,)):
391+
text = str(obj)
392+
if text.endswith(".0"):
393+
#text = text.substring(0, text.length() - 2)
394+
pass # TODO
395+
return text
396+
return str(obj)
397+
398+
def isEqual(self, a: Any, b: Any) -> bool:
399+
if (a == None and b == None): return True
400+
if (a == None): return False
401+
return a == b
402+
403+
def isTruthy(self, obj: Any) -> bool:
404+
match obj:
405+
case None:
406+
return False
407+
case bool():
408+
return obj
409+
case _:
410+
return True
411+
412+
def checkNumberOperand(self, operator: Token, operand: Any):
413+
if isinstance(object, (float,)): return
414+
raise RuntimeError(operator, "Operand must be a number.")
415+
def checkNumberOperands(self, operator: Token, left: Any, right: Any):
416+
if isinstance(left, (float,)) and isinstance(right, (float,)): return
417+
raise RuntimeError(operator, "Operands must be numbers.")
418+
419+
def evaluateUnary(self, expr: Unary) -> PrimitiveValue:
420+
right = self.evaluate(expr.expression)
421+
match expr.operator.type:
422+
case TokenType.MINUS:
423+
self.checkNumberOperand(expr.operator, right)
424+
return -float(right)
425+
case TokenType.BANG:
426+
return not self.isTruthy(right)
427+
raise NotImplementedError()
428+
429+
def evaluateBinary(self, expr: Binary) -> PrimitiveValue:
430+
left = self.evaluate(expr.expression1)
431+
right = self.evaluate(expr.expression2)
432+
433+
match expr.operator.type:
434+
case TokenType.PLUS:
435+
if isinstance(left,(float,)) and isinstance(right, (float,)):
436+
return left + right
437+
if isinstance(left, str) and isinstance(right, str):
438+
return ''.join((left, right))
439+
raise RuntimeError(expr.operator, "Operands must be two numbers or two strings.")
440+
case TokenType.MINUS:
441+
self.checkNumberOperands(expr.operator, left, right)
442+
return float(left) - float(right)
443+
case TokenType.SLASH:
444+
self.checkNumberOperands(expr.operator, left, right)
445+
return float(left) / float(right)
446+
case TokenType.STAR:
447+
self.checkNumberOperands(expr.operator, left, right)
448+
return float(left) * float(right)
449+
case TokenType.GREATER:
450+
self.checkNumberOperands(expr.operator, left, right)
451+
return float(left) > float(right)
452+
case TokenType.GREATER_EQUAL:
453+
self.checkNumberOperands(expr.operator, left, right)
454+
return float(left) >= float(right)
455+
case TokenType.LESS:
456+
self.checkNumberOperands(expr.operator, left, right)
457+
return float(left) < float(right)
458+
case TokenType.LESS_EQUAL:
459+
self.checkNumberOperands(expr.operator, left, right)
460+
return float(left) <= float(right)
461+
462+
case TokenType.BANG_EQUAL:
463+
return not self.isEqual(left, right)
464+
case TokenType.EQUAL_EQUAL:
465+
return self.isEqual(left, right)
466+
467+
case _:
468+
raise NotImplementedError()
469+
470+
471+
def evaluate(self, expr: Expr) -> PrimitiveValue:
472+
match expr:
473+
case Literal():
474+
return expr.value
475+
case Grouping():
476+
return self.evaluate(expr.expression)
477+
case Unary():
478+
return self.evaluateUnary(expr)
479+
case Binary():
480+
return self.evaluateBinary(expr)
481+
482+
def interpret(self, expression: Expr) -> None:
483+
try:
484+
value = self.evaluate(expression)
485+
print(str(value))
486+
except RuntimeError as error:
487+
log.exception('RuntimeError - MORE TODO HERE')
488+
489+
490+
def test_interperet_evaluate_expression():
491+
expr_str = '5 * 5'
492+
value = Interpreter().evaluate(Parser(Scanner(expr_str, DEFAULT_TOKEN_HANDLERS).tokens).parse)
493+
assert value == 10
494+
495+
496+
# ------------------------------------------------------------------------------
497+
498+
499+
class Lox:
500+
501+
def __init__(self):
502+
self.interpreter = Interpreter()
503+
self.hadError = False
504+
self.hadRuntimeError = False
505+
506+
def report(self, message: str, location: TextLocation):
507+
log.warning(f'[{location}] {message}')
508+
self.hadError = True
509+
510+
def runtimeError(self, error: Interpreter.RuntimeError) -> None:
511+
log.error(f"{error.message}\n[line {error.token.location}]")
512+
self.hadRuntimeError = True
513+
514+
def run(self, source: str) -> None:
515+
tokens = scanner(source).tokens
516+
expr = Parser(tokens).parse
517+
self.interpreter.evaluate(expr)
518+
519+
520+
if __name__ == '__main__':
521+
lox = Lox()
522+
if len(sys.argv) == 2 and (file := Path(sys.argv[1])):
523+
lox.run(file.read_text())
524+
else:
525+
print('Lox REPL')
526+
while _input := input('> '):
527+
lox.run(_input)
528+
if lox.hadError: sys.exit(65)
529+
if lox.hadRuntimeError: sys.exit(75)

0 commit comments

Comments
 (0)