Skip to content

Commit 17728b1

Browse files
gh-17: Add new kwargs to MATCH operator
1 parent da933ae commit 17728b1

File tree

4 files changed

+94
-20
lines changed

4 files changed

+94
-20
lines changed

asm-lang.exe

-1.23 KB
Binary file not shown.

docs/SPECIFICATION.html

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@
699699
700700
- `VALUEIN(ANY: value, MAP: map):INT` - Returns `1` if any value stored in `map` is equal to `value` using the language's equality semantics, otherwise `0`.
701701
702-
- `MATCH(MAP: map, MAP: template, INT: typing = 0):INT` - Returns `1` if every key in `template` is present in `map`. If `typing` is true, also require that for each matching key the stored value types are identical between `map` and `template`; otherwise return `0`.
702+
- `MATCH(MAP: map, MAP: template, INT: typing = 0, INT: recurse = 0, INT: shape = 0):INT` - Returns `1` if every key in `template` is present in `map`. If `typing` is true, also require that for each matching key the stored value types are identical between `map` and `template`. If `shape` is true, require that for any matching key where either side is a `TNS`, both sides are `TNS` and their shapes are identical. If `recurse` is true, apply these same rules to every `MAP` value nested anywhere within `map` (recursively). On any failure return `0`.
703703
704704
- `INV(MAP: map):MAP` - Return a new map whose key/value pairs are reversed. Each value in `map` becomes a key in the returned map and each original key becomes the corresponding value. All values in `map` MUST be scalar (`INT`, `FLT`, or `STR`) so they can be used as keys; otherwise `INV` raises a runtime error. If `map` contains duplicate values (by key type and value), `INV` raises a runtime error.
705705
@@ -725,9 +725,7 @@
725725
726726
### 12.9 Logarithms
727727
728-
- `LOG(INT: a):INT` ; floor(log2(a)) for a > 0
729-
730-
- `LOG(FLT: a):FLT` ; floor(log2(a)) for a > 0
728+
- `LOG(INT|FLT: a):INT|FLT` ; floor(log2(a)) for a > 0
731729
732730
- `CLOG(INT: a):INT` ; ceil(log2(a)) for a > 0
733731

interpreter.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def __init__(self) -> None:
536536
self._register_custom("KEYS", 1, 1, self._keys)
537537
self._register_custom("VALUES", 1, 1, self._values)
538538
self._register_custom("KEYIN", 2, 2, self._keyin)
539-
self._register_custom("MATCH", 2, 3, self._match)
539+
self._register_custom("MATCH", 2, 5, self._match)
540540
self._register_custom("VALUEIN", 2, 2, self._valuein)
541541
self._register_custom("INV", 1, 1, self._inv)
542542
self._register_custom("EXPORT", 2, 2, self._export)
@@ -1271,39 +1271,78 @@ def _match(
12711271
___: Environment,
12721272
location: SourceLocation,
12731273
) -> Value:
1274-
# MATCH(MAP: map, MAP: template, INT: typing = 0):INT
1274+
# MATCH(MAP: map, MAP: template, INT: typing = 0, INT: recurse = 0, INT: shape = 0):INT
12751275
# Return 1 if every key in `template` is present in `map`.
12761276
# If `typing` is true (non-zero), also require the types of the
12771277
# associated values to match between the two maps.
1278-
if len(args) not in (2, 3):
1279-
raise ASMRuntimeError("MATCH requires two or three arguments", location=location, rewrite_rule="MATCH")
1278+
# If `recurse` is true, apply the same rules to every MAP value nested
1279+
# anywhere within `map` (recursively).
1280+
# If `shape` is true, require that any matching TNS values have identical
1281+
# shapes between `map` and `template`.
1282+
if len(args) not in (2, 3, 4, 5):
1283+
raise ASMRuntimeError("MATCH requires 2 to 5 arguments", location=location, rewrite_rule="MATCH")
12801284
mval = args[0]
12811285
tval = args[1]
1282-
typing_flag = Value(TYPE_INT, 0)
1283-
if len(args) == 3:
1284-
typing_flag = args[2]
1286+
typing_flag = args[2] if len(args) >= 3 else Value(TYPE_INT, 0)
1287+
recurse_flag = args[3] if len(args) >= 4 else Value(TYPE_INT, 0)
1288+
shape_flag = args[4] if len(args) >= 5 else Value(TYPE_INT, 0)
12851289

12861290
if mval.type != TYPE_MAP:
12871291
raise ASMRuntimeError("MATCH expects a MAP as first argument", location=location, rewrite_rule="MATCH")
12881292
if tval.type != TYPE_MAP:
12891293
raise ASMRuntimeError("MATCH expects a MAP as second argument", location=location, rewrite_rule="MATCH")
12901294
if typing_flag.type != TYPE_INT:
12911295
raise ASMRuntimeError("MATCH expects typing flag to be INT", location=location, rewrite_rule="MATCH")
1296+
if recurse_flag.type != TYPE_INT:
1297+
raise ASMRuntimeError("MATCH expects recurse flag to be INT", location=location, rewrite_rule="MATCH")
1298+
if shape_flag.type != TYPE_INT:
1299+
raise ASMRuntimeError("MATCH expects shape flag to be INT", location=location, rewrite_rule="MATCH")
12921300

12931301
m = mval.value
12941302
t = tval.value
12951303
assert isinstance(m, Map) and isinstance(t, Map)
12961304

12971305
require_typing = 0 if typing_flag.value == 0 else 1
1306+
require_recurse = 0 if recurse_flag.value == 0 else 1
1307+
require_shape = 0 if shape_flag.value == 0 else 1
12981308

1299-
for key in t.data.keys():
1300-
if key not in m.data:
1301-
return Value(TYPE_INT, 0)
1302-
if require_typing:
1303-
mv = m.data[key]
1309+
def _match_one(mm: Map) -> int:
1310+
for key in t.data.keys():
1311+
if key not in mm.data:
1312+
return 0
1313+
mv = mm.data[key]
13041314
tv = t.data[key]
1305-
if mv.type != tv.type:
1306-
return Value(TYPE_INT, 0)
1315+
if require_typing and mv.type != tv.type:
1316+
return 0
1317+
if require_shape:
1318+
if (mv.type == TYPE_TNS) or (tv.type == TYPE_TNS):
1319+
if mv.type != TYPE_TNS or tv.type != TYPE_TNS:
1320+
return 0
1321+
mt = mv.value
1322+
tt = tv.value
1323+
if not isinstance(mt, Tensor) or not isinstance(tt, Tensor):
1324+
return 0
1325+
if mt.shape != tt.shape:
1326+
return 0
1327+
return 1
1328+
1329+
if not require_recurse:
1330+
return Value(TYPE_INT, _match_one(m))
1331+
1332+
# Recurse into nested MAP values within `map`.
1333+
visited: set[int] = set()
1334+
stack: List[Map] = [m]
1335+
while stack:
1336+
mm = stack.pop()
1337+
mm_id = id(mm)
1338+
if mm_id in visited:
1339+
continue
1340+
visited.add(mm_id)
1341+
if _match_one(mm) == 0:
1342+
return Value(TYPE_INT, 0)
1343+
for v in mm.data.values():
1344+
if v.type == TYPE_MAP and isinstance(v.value, Map):
1345+
stack.append(v.value)
13071346

13081347
return Value(TYPE_INT, 1)
13091348

@@ -1370,7 +1409,7 @@ def _bool(self, interpreter: "Interpreter", args: List[Value], __: List[Expressi
13701409
# BOOL(ANY: item):INT -> truthiness of item (INT: nonzero, STR: non-empty, TNS: any true element)
13711410
return Value(TYPE_INT, 1 if self._as_bool_value(interpreter, args[0], loc) != 0 else 0)
13721411

1373-
def _eq(self, interpreter: "Interpreter", args: List[Value], __: List[Expression], ___: Environment, ___loc: SourceLocation) -> Value:
1412+
def _eq(self, interpreter: "Interpreter", args: List[Value], __: List[Expression], ___: Environment, location: SourceLocation) -> Value:
13741413
a, b = args
13751414
return Value(TYPE_INT, 1 if interpreter._values_equal(a, b) else 0)
13761415

@@ -1639,7 +1678,7 @@ def _os(
16391678
args: List[Value],
16401679
__: List[Expression],
16411680
___: Environment,
1642-
__loc: SourceLocation,
1681+
loc: SourceLocation,
16431682
) -> Value:
16441683
# Return a short lowercase host OS family string as STR.
16451684
plat = sys.platform.lower()

test.asmln

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,43 @@ FUNC BUILTIN_TESTS():INT{
306306
ASSERT( EQ( MATCH(mm2, tmpl3), 1 ) )
307307
MAP: tmpl4 = < "three" = 1 >
308308
ASSERT( EQ( MATCH(mm2, tmpl4), 0 ) )
309+
310+
# MATCH recurse flag: when true, apply same template rules to nested MAPs.
311+
MAP: mmr_ok = < "one" = 1, "child" = < "one" = 1 > >
312+
ASSERT( EQ( MATCH(mmr_ok, tmpl1, 0, 0), 1 ) )
313+
ASSERT( EQ( MATCH(mmr_ok, tmpl1, 0, 1), 1 ) )
314+
MAP: mmr_bad = < "one" = 1, "child" = < "two" = 10 > >
315+
ASSERT( EQ( MATCH(mmr_bad, tmpl1, 0, 0), 1 ) )
316+
ASSERT( EQ( MATCH(mmr_bad, tmpl1, 0, 1), 0 ) )
317+
# recurse + typing: nested MAP must also satisfy typing constraint.
318+
MAP: mmr_tbad = < "one" = 1, "child" = < "one" = "x" > >
319+
ASSERT( EQ( MATCH(mmr_tbad, tmpl1, 1, 0), 1 ) )
320+
ASSERT( EQ( MATCH(mmr_tbad, tmpl1, 1, 1), 0 ) )
321+
322+
# MATCH shape flag: when true, require matching TNS shapes.
323+
TNS: s1 = [1, 10, 11]
324+
TNS: s1b = [0, 0, 0]
325+
TNS: s2 = [[1,10],[11,100]]
326+
MAP: mshape_ok = < "t" = s1 >
327+
MAP: tshape_ok = < "t" = s1b >
328+
ASSERT( EQ( MATCH(mshape_ok, tshape_ok, 1, 0, 1), 1 ) )
329+
MAP: tshape_bad = < "t" = s2 >
330+
ASSERT( EQ( MATCH(mshape_ok, tshape_bad, 1, 0, 1), 0 ) )
331+
# If shape is enabled and either side is a TNS, both must be TNS.
332+
MAP: mshape_non = < "t" = 1 >
333+
ASSERT( EQ( MATCH(mshape_non, tshape_ok, 0, 0, 0), 1 ) )
334+
ASSERT( EQ( MATCH(mshape_non, tshape_ok, 0, 0, 1), 0 ) )
335+
336+
DEL(mmr_ok)
337+
DEL(mmr_bad)
338+
DEL(mmr_tbad)
339+
DEL(s1)
340+
DEL(s1b)
341+
DEL(s2)
342+
DEL(mshape_ok)
343+
DEL(tshape_ok)
344+
DEL(tshape_bad)
345+
DEL(mshape_non)
309346
DEL(tmpl1)
310347
DEL(tmpl2)
311348
DEL(tmpl3)

0 commit comments

Comments
 (0)