Skip to content

Commit a011bae

Browse files
committed
td.py: + ensure_scalar(...)
1 parent f48aaa3 commit a011bae

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

src/pyblocksim/td.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def get_loop_symbol():
6262

6363

6464
def set_loop_symbol(ls, expr):
65+
ensure_scalar(expr)
6566
assert loop_mappings[ls] is None
6667
loop_mappings[ls] = expr
6768

@@ -163,6 +164,11 @@ def set_inputs(self, input1, **kwargs):
163164
assert len(self.input_vars) == len(rest_list) + 1
164165

165166
self.input_expr_list = [input1, *rest_list]
167+
self._check_inputs()
168+
169+
def _check_inputs(self):
170+
for expr in self.input_expr_list:
171+
ensure_scalar(expr)
166172

167173
def _get_input_exprs_from_kwargs(self, kwargs: dict, set_attrs=False):
168174

@@ -557,3 +563,16 @@ def compute_block_outputs(kk, xx) -> dict:
557563
block.output_res = block_outputs[block] = block.output_fnc(kk, *xx.T)
558564

559565
return block_outputs
566+
567+
568+
def ensure_scalar(expr):
569+
expr = sp.sympify(expr)
570+
if isinstance(expr, sp.MatrixBase):
571+
# isinstance(expr, sp.Basic) might be also true -> ignore it here
572+
msg = "Unexpectedly got matrix where scalar was expected."
573+
elif isinstance(expr, sp.Basic):
574+
# this is now OK
575+
return
576+
else:
577+
msg = f"Unexpectedly got {type(expr)} where scalar was expected."
578+
raise TypeError(msg)

0 commit comments

Comments
 (0)