Skip to content

Commit 147e29e

Browse files
committed
PR Changes
Signed-off-by: Max Chesterfield <max.chesterfield@zepben.com>
1 parent 1d2f9cc commit 147e29e

8 files changed

Lines changed: 137 additions & 152 deletions

File tree

src/zepben/evolve/services/network/tracing/phases/set_phases.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _terminals_from_network():
9090
async def _(
9191
self,
9292
start_terminal: Terminal,
93-
phases: Union[PhaseCode, List[SinglePhaseKind]]=None,
93+
phases: Union[PhaseCode, List[SinglePhaseKind], Set[SinglePhaseKind]]=None,
9494
network_state_operators: Type[NetworkStateOperators]=NetworkStateOperators.NORMAL,
9595
seed_terminal: Terminal=None):
9696
"""
@@ -144,11 +144,8 @@ def spread_phases(
144144
:param network_state_operators: The `NetworkStateOperators` to be used when setting phases.
145145
"""
146146

147-
if phases is None:
148-
self.spread_phases(from_terminal, to_terminal, from_terminal.phases.single_phases, network_state_operators)
149-
else:
150-
paths = self._get_nominal_phase_paths(network_state_operators, from_terminal, to_terminal, phases)
151-
self._flow_phases(network_state_operators, from_terminal, to_terminal, paths)
147+
paths = self._get_nominal_phase_paths(network_state_operators, from_terminal, to_terminal, phases or from_terminal.phases.single_phases)
148+
self._flow_phases(network_state_operators, from_terminal, to_terminal, paths)
152149

153150
async def _run_terminals(self, terminals: Iterable[Terminal], network_state_operators: Type[NetworkStateOperators]):
154151

@@ -356,14 +353,15 @@ def _flow_transformer_phases(
356353
# Split the phases into ones we need to flow directly, and ones that have been added by a transformer. In
357354
# the case of an added Y phase (SWER -> LV2 transformer) we need to flow the phases before we can calculate
358355
# the missing phase.
356+
flow_phases = (p for p in paths if p.from_phase == SinglePhaseKind.NONE)
357+
add_phases = (p for p in paths if p.from_phase != SinglePhaseKind.NONE)
358+
for p in flow_phases:
359+
self._try_add_phase(from_terminal, from_phases, to_terminal, to_phases, p.to_phase, allow_suspect_flow,
360+
lambda: updated_phases.append(True))
359361

360-
for path in paths:
361-
if path.from_phase == SinglePhaseKind.NONE:
362-
self._try_add_phase(from_terminal, from_phases, to_terminal, to_phases, path.to_phase, allow_suspect_flow,
363-
lambda: updated_phases.append(True))
364-
else:
365-
self._try_set_phase(from_phases[path.from_phase], from_terminal, from_phases, path.from_phase,
366-
to_terminal, to_phases, path.to_phase, lambda: updated_phases.append(True))
362+
for p in add_phases:
363+
self._try_set_phase(from_phases[p.from_phase], from_terminal, from_phases, p.from_phase,
364+
to_terminal, to_phases, p.to_phase, lambda: updated_phases.append(True))
367365

368366
return any(updated_phases)
369367

@@ -399,8 +397,7 @@ def _try_set_phase(
399397
on_success: Callable[[], None]):
400398

401399
try:
402-
if phase != SinglePhaseKind.NONE:
403-
to_phases[to_] = phase
400+
if phase != SinglePhaseKind.NONE and to_phases.__setitem__(to_, phase):
404401
if self._debug_logger:
405402
self._debug_logger.info(f' {from_terminal.mrid}[{from_}] -> {to_terminal.mrid}[{to_}]: set to {phase}')
406403
on_success()

src/zepben/evolve/services/network/tracing/traversal/debug_logging.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,15 @@ def _wrap_attr(_index: int, _attr: str, _msg: str) -> None:
8585
:param _attr: Method/Function name.
8686
:raises AttributeError: if ``wrappable`` is already wrapped
8787
"""
88+
8889
if isinstance(_attr, tuple):
8990
_attr_name, _log_attr_name = _attr
9091
else:
9192
_attr_name = _log_attr_name = _attr
9293

93-
# Wrapped classes will have __wrapped__ == True - if it exists on the obj passed in, the user is attempting to wrap an
94-
# already wrapped object. This can lead to unexpected outcomes so we do not support it
95-
if (to_wrap := getattr(w_obj, _attr_name)) and hasattr(to_wrap, '__wrapped__'):
96-
raise AttributeError(f'Wrapped objects cannot be rewrapped, pass in the original object instead.')
94+
to_wrap = getattr(w_obj, _attr_name)
9795

9896
setattr(w_obj, _attr_name, self._log_method_call(to_wrap, f'{self.description}: {_log_attr_name}({_index})' + _msg))
99-
setattr(w_obj, '__wrapped__', True)
10097

10198
for clazz in (StepAction, StopCondition, QueueCondition):
10299
if isinstance(w_obj, clazz):

src/zepben/evolve/services/network/tracing/traversal/step_action.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class StepAction(Generic[T]):
2525
"""
2626

2727
def __init__(self, _func: StepActionFunc = None):
28-
self._func = _func
28+
self._func = _func or self._apply
2929

3030
def __init_subclass__(cls):
3131
"""
@@ -50,6 +50,17 @@ def apply(self, item: T, context: StepContext):
5050

5151
return self._func(item, context)
5252

53+
@abstractmethod
54+
def _apply(self, item: T, context: StepContext):
55+
"""
56+
Override this method instead of ``self.apply`` directly
57+
58+
:param item: The current item in the traversal.
59+
:param context: The context associated with the current traversal step.
60+
"""
61+
raise NotImplementedError()
62+
63+
5364
class StepActionWithContextValue(StepAction[T], ContextValueComputer[T]):
5465
"""
5566
Interface representing a step action that utilises a value stored in the :class:`StepContext`.
@@ -59,19 +70,9 @@ class StepActionWithContextValue(StepAction[T], ContextValueComputer[T]):
5970
"""
6071

6172
def __init__(self, key: str, _func: StepActionFunc = None):
62-
StepAction.__init__(self, _func or self._apply)
73+
StepAction.__init__(self, _func)
6374
ContextValueComputer.__init__(self, key)
6475

65-
@abstractmethod
66-
def _apply(self, item: T, context: StepContext):
67-
"""
68-
Override this method instead of ``self.apply`` directly
69-
70-
:param item: The current item in the traversal.
71-
:param context: The context associated with the current traversal step.
72-
"""
73-
raise NotImplementedError()
74-
7576
@abstractmethod
7677
def compute_initial_value(self, item: T):
7778
raise NotImplementedError()

src/zepben/evolve/services/network/tracing/traversal/traversal.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
__all__ = ["Traversal"]
2727

2828
from zepben.evolve.services.network.tracing.traversal.queue import TraversalQueue
29-
from zepben.evolve.util import extra_kwargs_not_allowed
3029

3130
T = TypeVar('T')
3231
U = TypeVar('U')
@@ -205,21 +204,20 @@ def create_new_this(self) -> D:
205204
raise NotImplementedError
206205

207206
@singledispatchmethod
208-
def add_condition(self, condition: ConditionTypes, **kwargs) -> D:
207+
def add_condition(self, condition: ConditionTypes) -> D:
209208
"""
210209
Adds a traversal condition to the traversal.
211210
212211
:param condition: The condition to add.
213-
:keyword allow_re_wrapping: Allow rewrapping of :class:`StopConditions` with debug logging
214212
215213
:return: this traversal instance.
216214
"""
217215

218216
if callable(condition): # Callable[[NetworkTraceStep[T], StepContext], None]
219217
if len(inspect.getfullargspec(condition).args) == 2:
220-
return self.add_stop_condition(condition, **kwargs)
218+
return self.add_stop_condition(condition)
221219
elif len(inspect.getfullargspec(condition).args) == 4:
222-
return self.add_queue_condition(condition, **kwargs)
220+
return self.add_queue_condition(condition)
223221
else:
224222
raise RuntimeError(f'Condition does not match expected: Number of args is not 2(Stop Condition) or 4(QueueCondition)')
225223

@@ -231,30 +229,27 @@ def add_condition(self, condition: ConditionTypes, **kwargs) -> D:
231229

232230
@singledispatchmethod
233231
@add_condition.register(StopCondition)
234-
def add_stop_condition(self, condition: StopConditionTypes, **kwargs) -> D:
232+
def add_stop_condition(self, condition: StopConditionTypes) -> D:
235233
"""
236234
Adds a stop condition to the traversal. If any stop condition returns
237235
``True``, the traversal will not call the callback to queue more items
238236
from the current item.
239237
240238
:param condition: The stop condition to add.
241-
:keyword allow_re_wrapping: Allow rewrapping of :class:`StopCondition`s with debug logging
242239
:return: this traversal instance.
243240
"""
244241

245242
raise RuntimeError(f'Condition [{condition.__class__.__name__}] does not match expected: [StopCondition | StopConditionWithContextValue | Callable]')
246243

247244
@add_stop_condition.register(Callable)
248-
def _(self, condition: ShouldStop, **kwargs):
249-
return self.add_stop_condition(StopCondition(condition), **kwargs)
245+
def _(self, condition: ShouldStop):
246+
return self.add_stop_condition(StopCondition(condition))
250247

251248
@add_stop_condition.register
252-
def _(self, condition: StopCondition, **kwargs):
249+
def _(self, condition: StopCondition):
253250

254251
if self._debug_logger is not None:
255-
self._debug_logger.wrap(condition, kwargs.pop('allow_re_wrapping', False))
256-
257-
extra_kwargs_not_allowed(kwargs, 'add_stop_condition')
252+
self._debug_logger.wrap(condition)
258253

259254
self.stop_conditions.append(condition)
260255
if isinstance(condition, StopConditionWithContextValue):
@@ -281,30 +276,27 @@ def matches_any_stop_condition(self, item: T, context: StepContext) -> bool:
281276

282277
@add_condition.register(QueueCondition)
283278
@singledispatchmethod
284-
def add_queue_condition(self, condition: QueueConditionTypes, **kwargs) -> D:
279+
def add_queue_condition(self, condition: QueueConditionTypes) -> D:
285280
"""
286281
Adds a queue condition to the traversal.
287282
Queue conditions determine whether an item should be queued for traversal.
288283
All registered queue conditions must return true for an item to be queued.
289284
290285
:param condition: The queue condition to add.
291-
:keyword allow_re_wrapping: Allow rewrapping of :class:`QueueCondition`s with debug logging
292286
:returns: The current traversal instance.
293287
"""
294288

295289
raise RuntimeError(f'Condition [{condition.__class__.__name__}] does not match expected: [QueueCondition | QueueConditionWithContextValue | Callable]')
296290

297291
@add_queue_condition.register(Callable)
298-
def _(self, condition: ShouldQueue, **kwargs):
299-
return self.add_queue_condition(QueueCondition(condition), **kwargs)
292+
def _(self, condition: ShouldQueue):
293+
return self.add_queue_condition(QueueCondition(condition))
300294

301295
@add_queue_condition.register
302-
def _(self, condition: QueueCondition, **kwargs):
296+
def _(self, condition: QueueCondition):
303297

304298
if self._debug_logger is not None:
305-
self._debug_logger.wrap(condition, kwargs.pop('allow_re_wrapping', False))
306-
307-
extra_kwargs_not_allowed(kwargs, 'add_queue_condition')
299+
self._debug_logger.wrap(condition)
308300

309301
self.queue_conditions.append(condition)
310302
if isinstance(condition, QueueConditionWithContextValue):
@@ -324,24 +316,21 @@ def copy_queue_conditions(self, other: Traversal[T, D]) -> D:
324316
return self
325317

326318
@singledispatchmethod
327-
def add_step_action(self, action: StepActionTypes, **kwargs) -> D:
319+
def add_step_action(self, action: StepActionTypes) -> D:
328320
"""
329321
Adds an action to be performed on each item in the traversal, including the
330322
starting items.
331323
332324
:param action: The action to perform on each item.
333-
:keyword allow_re_wrapping: Allow rewrapping of :class:`StepAction`s with debug logging
334325
:return: The current traversal instance.
335326
"""
336327

337328
raise RuntimeError(f'StepAction [{action.__class__.__name__}] does not match expected: [StepAction | StepActionWithContextValue | Callable]')
338329

339330
@add_step_action.register
340-
def _(self, action: StepAction, **kwargs):
331+
def _(self, action: StepAction):
341332
if self._debug_logger is not None:
342-
self._debug_logger.wrap(action, kwargs.pop('allow_re_wrapping', False))
343-
344-
extra_kwargs_not_allowed(kwargs, 'add_step_action')
333+
self._debug_logger.wrap(action)
345334

346335
self.step_actions.append(action)
347336
if isinstance(action, StepActionWithContextValue):
@@ -350,47 +339,45 @@ def _(self, action: StepAction, **kwargs):
350339

351340
@add_step_action.register(Callable)
352341
def _(self, action: StepActionFunc, **kwargs):
353-
return self.add_step_action(StepAction(action), **kwargs)
342+
return self.add_step_action(StepAction(action))
354343

355344
@singledispatchmethod
356-
def if_not_stopping(self, action: StepActionTypes, **kwargs) -> D:
345+
def if_not_stopping(self, action: StepActionTypes) -> D:
357346
"""
358347
Adds an action to be performed on each item that does not match any stop condition.
359348
360349
:param action: The action to perform on each non-stopping item.
361-
:keyword allow_re_wrapping: Allow rewrapping of :class:`StepAction`s with debug logging
362350
:return: The current traversal instance.
363351
"""
364352
raise RuntimeError(f'StepAction [{action}] does not match expected: [StepAction | StepActionWithContextValue | Callable]')
365353

366354
@if_not_stopping.register(Callable)
367-
def _(self, action: StepActionFunc, **kwargs) -> D:
368-
return self.add_step_action(lambda it, context: action(it, context) if not context.is_stopping else None, **kwargs)
355+
def _(self, action: StepActionFunc) -> D:
356+
return self.add_step_action(lambda it, context: action(it, context) if not context.is_stopping else None)
369357

370358
@if_not_stopping.register
371-
def _(self, action: StepAction, **kwargs) -> D:
359+
def _(self, action: StepAction) -> D:
372360
action.apply = lambda it, context: action._func(it, context) if not context.is_stopping else None
373-
return self.add_step_action(action, **kwargs)
361+
return self.add_step_action(action)
374362

375363
@singledispatchmethod
376-
def if_stopping(self, action: StepActionTypes, **kwargs) -> D:
364+
def if_stopping(self, action: StepActionTypes) -> D:
377365
"""
378366
Adds an action to be performed on each item that matches a stop condition.
379367
380368
:param action: The action to perform on each stopping item.
381-
:keyword allow_re_wrapping: Allow rewrapping of :class:`StepActions`s with debug logging
382369
:return: The current traversal instance.
383370
"""
384371
raise RuntimeError(f'StepAction [{action}] does not match expected: [StepAction | StepActionWithContextValue | Callable]')
385372

386373
@if_stopping.register(Callable)
387-
def _(self, action: StepActionFunc, **kwargs) -> D:
388-
return self.add_step_action(lambda it, context: action(it, context) if context.is_stopping else None, **kwargs)
374+
def _(self, action: StepActionFunc) -> D:
375+
return self.add_step_action(lambda it, context: action(it, context) if context.is_stopping else None)
389376

390377
@if_stopping.register
391-
def _(self, action: StepAction, **kwargs) -> D:
378+
def _(self, action: StepAction) -> D:
392379
action.apply = lambda it, context: action._func(it, context) if context.is_stopping else None
393-
return self.add_step_action(action, **kwargs)
380+
return self.add_step_action(action)
394381

395382
def copy_step_actions(self, other: Traversal[T, D]) -> D:
396383
"""

src/zepben/evolve/util.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,6 @@ def __get__(self, cls, owner: T) -> T:
183183
return classmethod(self.fget).__get__(None, owner)()
184184

185185

186-
def extra_kwargs_not_allowed(kwargs, function_name):
187-
if kwargs:
188-
raise TypeError(f"'{kwargs.pop()}' is an invalid keyword argument for {function_name}()")
189-
190-
191186
def datetime_to_timestamp(date_time: datetime) -> PBTimestamp:
192187
timestamp = PBTimestamp()
193188
timestamp.FromDatetime(date_time)

0 commit comments

Comments
 (0)