Skip to content

Commit e71c012

Browse files
authored
Merge pull request #778 from dmamelin/feature/improve-class-loading
Improve class loading
2 parents cbb2f44 + e4e4eda commit e71c012

File tree

2 files changed

+126
-11
lines changed

2 files changed

+126
-11
lines changed

custom_components/pyscript/eval.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import sys
1414
import time
15+
import traceback
1516
import weakref
1617

1718
import yaml
@@ -1090,12 +1091,18 @@ async def ast_while(self, arg):
10901091
async def ast_classdef(self, arg):
10911092
"""Evaluate class definition."""
10921093
bases = [(await self.aeval(base)) for base in arg.bases]
1094+
keywords = {kw.arg: await self.aeval(kw.value) for kw in arg.keywords}
1095+
metaclass = keywords.pop("metaclass", type(bases[0]) if bases else type)
1096+
10931097
if self.curr_func and arg.name in self.curr_func.global_names:
10941098
sym_table_assign = self.global_sym_table
10951099
else:
10961100
sym_table_assign = self.sym_table
10971101
sym_table_assign[arg.name] = EvalLocalVar(arg.name)
1098-
sym_table = {}
1102+
if hasattr(metaclass, "__prepare__"):
1103+
sym_table = metaclass.__prepare__(arg.name, tuple(bases), **keywords)
1104+
else:
1105+
sym_table = {}
10991106
self.sym_table_stack.append(self.sym_table)
11001107
self.sym_table = sym_table
11011108
for arg1 in arg.body:
@@ -1106,11 +1113,17 @@ async def ast_classdef(self, arg):
11061113
raise SyntaxError(f"{val.name()} statement outside loop")
11071114
self.sym_table = self.sym_table_stack.pop()
11081115

1116+
decorators = [await self.aeval(dec) for dec in arg.decorator_list]
11091117
sym_table["__init__evalfunc_wrap__"] = None
11101118
if "__init__" in sym_table:
11111119
sym_table["__init__evalfunc_wrap__"] = sym_table["__init__"]
11121120
del sym_table["__init__"]
1113-
sym_table_assign[arg.name].set(type(arg.name, tuple(bases), sym_table))
1121+
cls = metaclass(arg.name, tuple(bases), sym_table, **keywords)
1122+
if inspect.iscoroutine(cls):
1123+
cls = await cls
1124+
for dec_func in reversed(decorators):
1125+
cls = await self.call_func(dec_func, None, cls)
1126+
sym_table_assign[arg.name].set(cls)
11141127

11151128
async def ast_functiondef(self, arg, async_func=False):
11161129
"""Evaluate function definition."""
@@ -1487,7 +1500,11 @@ async def ast_augassign(self, arg):
14871500
await self.recurse_assign(arg.target, new_val)
14881501

14891502
async def ast_annassign(self, arg):
1490-
"""Execute type hint assignment statement (just ignore the type hint)."""
1503+
"""Execute type hint assignment statement and track __annotations__."""
1504+
if isinstance(arg.target, ast.Name):
1505+
annotations = self.sym_table.setdefault("__annotations__", {})
1506+
if arg.annotation:
1507+
annotations[arg.target.id] = await self.aeval(arg.annotation)
14911508
if arg.value is not None:
14921509
rhs = await self.aeval(arg.value)
14931510
await self.recurse_assign(arg.target, rhs)
@@ -1961,19 +1978,25 @@ async def call_func(self, func, func_name, *args, **kwargs):
19611978
if isinstance(func, (EvalFunc, EvalFuncVar)):
19621979
return await func.call(self, *args, **kwargs)
19631980
if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"):
1964-
inst = func()
1981+
has_init_wrapper = getattr(func, "__init__evalfunc_wrap__") is not None
1982+
inst = func(*args, **kwargs) if not has_init_wrapper else func()
19651983
#
19661984
# we use weak references when we bind the method calls to the instance inst;
19671985
# otherwise these self references cause the object to not be deleted until
19681986
# it is later garbage collected
19691987
#
19701988
inst_weak = weakref.ref(inst)
19711989
for name in dir(inst):
1972-
value = getattr(inst, name)
1990+
try:
1991+
value = getattr(inst, name)
1992+
except AttributeError:
1993+
# same effect as hasattr (which also catches AttributeError)
1994+
# dir() may list names that aren't actually accessible attributes
1995+
continue
19731996
if type(value) is not EvalFuncVar:
19741997
continue
19751998
setattr(inst, name, EvalFuncVarClassInst(value.get_func(), value.get_ast_ctx(), inst_weak))
1976-
if getattr(func, "__init__evalfunc_wrap__") is not None:
1999+
if has_init_wrapper:
19772000
#
19782001
# since our __init__ function is async, call the renamed one
19792002
#
@@ -2197,11 +2220,9 @@ def format_exc(self, exc, lineno=None, col_offset=None, short=False, code_list=N
21972220
else:
21982221
mesg = f"Exception in <{self.filename}>:\n"
21992222
mesg += f"{type(exc).__name__}: {exc}"
2200-
#
2201-
# to get a more detailed traceback on exception (eg, when chasing an internal
2202-
# error), add an "import traceback" above, and uncomment this next line
2203-
#
2204-
# return mesg + "\n" + traceback.format_exc(-1)
2223+
2224+
if _LOGGER.isEnabledFor(logging.DEBUG):
2225+
mesg += "\n" + traceback.format_exc()
22052226
return mesg
22062227

22072228
def get_exception(self):

tests/test_unit_eval.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,100 @@
144144
["x: int = [10, 20]; x", [10, 20]],
145145
["Foo = type('Foo', (), {'x': 100}); Foo.x = 10; Foo.x", 10],
146146
["Foo = type('Foo', (), {'x': 100}); Foo.x += 10; Foo.x", 110],
147+
[
148+
"""
149+
from enum import IntEnum
150+
151+
class TestIntMode(IntEnum):
152+
VAL1 = 1
153+
VAL2 = 2
154+
VAL3 = 3
155+
[TestIntMode.VAL2 == 2, isinstance(TestIntMode.VAL3, IntEnum)]
156+
""",
157+
[True, True],
158+
],
159+
[
160+
"""
161+
from enum import StrEnum
162+
163+
class TestStrEnum(StrEnum):
164+
VAL1 = "val1"
165+
VAL2 = "val2"
166+
VAL3 = "val3"
167+
[TestStrEnum.VAL2 == "val2", isinstance(TestStrEnum.VAL3, StrEnum)]
168+
""",
169+
[True, True],
170+
],
171+
[
172+
"""
173+
from enum import Enum, EnumMeta
174+
175+
class Color(Enum):
176+
RED = 1
177+
BLUE = 2
178+
[type(Color) is EnumMeta, isinstance(Color.RED, Color), list(Color.__members__.keys())]
179+
""",
180+
[True, True, ["RED", "BLUE"]],
181+
],
182+
[
183+
"""
184+
from dataclasses import dataclass
185+
186+
@dataclass()
187+
class DT:
188+
name: str
189+
num: int = 32
190+
obj1 = DT(name="abc")
191+
obj2 = DT("xyz", 5)
192+
[obj1.name, obj1.num, obj2.name, obj2.num]
193+
""",
194+
["abc", 32, "xyz", 5],
195+
],
196+
[
197+
"""
198+
class Meta(type):
199+
def __new__(mcls, name, bases, ns, flag=False):
200+
ns["flag"] = flag
201+
return type.__new__(mcls, name, bases, ns)
202+
203+
class Foo(metaclass=Meta, flag=True):
204+
pass
205+
[Foo.flag, isinstance(Foo, Meta)]
206+
""",
207+
[True, True],
208+
],
209+
[
210+
"""
211+
def deco(label):
212+
def wrap(cls):
213+
cls.labels.append(label)
214+
return cls
215+
return wrap
216+
217+
@deco("first")
218+
@deco("second")
219+
class Decorated:
220+
labels = []
221+
Decorated.labels
222+
""",
223+
["second", "first"],
224+
],
225+
[
226+
"""
227+
hits = []
228+
229+
def anno():
230+
hits.append("ok")
231+
return int
232+
233+
class Annotated:
234+
a: anno()
235+
b: int = 3
236+
c = "skip"
237+
[hits, Annotated.__annotations__, Annotated.b, hasattr(Annotated, "c")]
238+
""",
239+
[["ok"], {"a": int, "b": int}, 3, True],
240+
],
147241
["Foo = [type('Foo', (), {'x': 100})]; Foo[0].x = 10; Foo[0].x", 10],
148242
["Foo = [type('Foo', (), {'x': [100, 101]})]; Foo[0].x[1] = 10; Foo[0].x", [100, 10]],
149243
["Foo = [type('Foo', (), {'x': [0, [[100, 101]]]})]; Foo[0].x[1][0][1] = 10; Foo[0].x[1]", [[100, 10]]],

0 commit comments

Comments
 (0)