Skip to content

Commit d0a7d76

Browse files
Add IN operator
1 parent 91c9617 commit d0a7d76

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

SPECIFICATION.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ Type notation: union signatures such as `INT|STR` restrict arguments to the list
287287
- `SLEN(STR: s):INT` returns the length of the supplied `STR` in characters as an `INT`. The argument must be a `STR`; passing an `INT` raises a runtime error.
288288
- `ILEN(INT: n):INT` returns the length in binary digits of the absolute value of the supplied `INT`. `ILEN(0)` returns 1. The argument must be an `INT`; passing a `STR` raises a runtime error.
289289

290+
- `IN(ANY: value, TNS: tensor):INT` ; returns `1` if `value` is equal to any element of `tensor` (using the language's equality semantics), otherwise `0`.
291+
290292
### Arithmetic (`INT` only)
291293
- `ADD(INT: a, INT: b):INT` ; a + b
292294
- `SUB(INT: a, INT: b):INT` ; a - b

asmln.exe

-188 Bytes
Binary file not shown.

interpreter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def __init__(self) -> None:
329329
self._register_custom("XOR", 2, 2, self._xor)
330330
self._register_custom("NOT", 1, 1, self._not)
331331
self._register_custom("EQ", 2, 2, self._eq)
332+
self._register_custom("IN", 2, 2, self._in)
332333
self._register_int_only("GT", 2, lambda a, b: 1 if a > b else 0)
333334
self._register_int_only("LT", 2, lambda a, b: 1 if a < b else 0)
334335
self._register_int_only("GTE", 2, lambda a, b: 1 if a >= b else 0)
@@ -662,6 +663,19 @@ def _eq(self, interpreter: "Interpreter", args: List[Value], __: List[Expression
662663
return Value(TYPE_INT, 1 if interpreter._tensor_equal(a.value, b.value) else 0)
663664
return Value(TYPE_INT, 1 if a.value == b.value else 0)
664665

666+
def _in(self, interpreter: "Interpreter", args: List[Value], __: List[Expression], ___: Environment, location: SourceLocation) -> Value:
667+
# IN(ANY: value, TNS: tensor):INT -> 1 if value is contained anywhere in tensor, else 0
668+
if len(args) != 2:
669+
raise ASMRuntimeError("IN requires two arguments", location=location, rewrite_rule="IN")
670+
needle, haystack = args
671+
if haystack.type != TYPE_TNS:
672+
raise ASMRuntimeError("IN requires a tensor as second argument", location=location, rewrite_rule="IN")
673+
assert isinstance(haystack.value, Tensor)
674+
for item in haystack.value.data.flat:
675+
if interpreter._values_equal(needle, item):
676+
return Value(TYPE_INT, 1)
677+
return Value(TYPE_INT, 0)
678+
665679
def _slice(self, _: "Interpreter", args: List[Value], __: List[Expression], ___: Environment, location: SourceLocation) -> Value:
666680
target, hi_val, lo_val = args
667681
hi = self._expect_int(hi_val, "SLICE", location)

0 commit comments

Comments
 (0)