Skip to content

Commit 50b2f7c

Browse files
authored
[EWB-4558] await only awaitable step actions (#203)
Signed-off-by: Max Chesterfield <max.chesterfield@zepben.com>
1 parent 915f01c commit 50b2f7c

4 files changed

Lines changed: 53 additions & 4 deletions

File tree

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
### Fixes
1313
* Moved ZepbenTokenAuth to use python dataclasses instead of `zepben.ewb.dataclassy`, existing code should work as is.
14+
* `TypeError`s occurring in `StepAction`s will no longer silently pass
1415

1516
### Notes
1617
* None.

src/zepben/ewb/services/network/tracing/traversal/traversal.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,9 @@ def copy_step_actions(self, other: Traversal[T, D]) -> D:
390390

391391
async def apply_step_actions(self, item: T, context: StepContext) -> D:
392392
for it in self.step_actions:
393-
try:
394-
await it.apply(item, context)
395-
except TypeError:
396-
pass
393+
_apply = it.apply(item, context)
394+
if inspect.iscoroutine(_apply):
395+
await _apply
397396
return self
398397

399398
def add_context_value_computer(self, computer: ContextValueComputer[T]) -> D:

test/services/network/tracing/traversal/test_step_action.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# This Source Code Form is subject to the terms of the Mozilla Public
33
# License, v. 2.0. If a copy of the MPL was not distributed with this
44
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
5+
import pytest
56
from pytest import raises
67

78
from zepben.ewb import StepAction, StepContext
@@ -48,3 +49,15 @@ def _apply(self, item: T, context: StepContext):
4849
step_action.apply(expected_item, expected_ctx)
4950

5051
assert captured == [(expected_item, expected_ctx)]
52+
53+
@pytest.mark.asyncio
54+
async def test_async_step_action(self):
55+
captured = []
56+
57+
class MyStepAction(StepAction):
58+
async def _apply(self, item: T, context: StepContext):
59+
captured.append(item)
60+
61+
step_action = MyStepAction()
62+
await step_action.apply(1, None)
63+
assert captured == [1]

test/services/network/tracing/traversal/test_traversal.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,39 @@ async def test_multiple_start_items_respect_can_stop_on_start(self):
510510
await traversal.run(can_stop_on_start_item=False)
511511

512512
assert steps == [1, 11, 2, 12]
513+
514+
@pytest.mark.asyncio
515+
async def test_can_use_async_step_action(self):
516+
steps = []
517+
518+
class MyStepAction(StepAction):
519+
async def _apply(self, item: T, context: StepContext):
520+
steps.append(item)
521+
522+
traversal = (
523+
_create_traversal(queue=TraversalQueue.breadth_first())
524+
.add_stop_condition(lambda item, x: True)
525+
.add_step_action(MyStepAction())
526+
.add_start_item(1)
527+
.add_start_item(11)
528+
)
529+
await traversal.run(can_stop_on_start_item=False)
530+
531+
assert steps == [1, 11, 2, 12]
532+
533+
@pytest.mark.asyncio
534+
async def test_errors_in_step_action_arent_masked(self):
535+
class MyStepAction(StepAction):
536+
async def _apply(self, item: T, context: StepContext):
537+
# noinspection PyTypeChecker
538+
int(1 + "abc")
539+
540+
traversal = (
541+
_create_traversal(queue=TraversalQueue.breadth_first())
542+
.add_stop_condition(lambda item, x: True)
543+
.add_step_action(MyStepAction())
544+
.add_start_item(1)
545+
.add_start_item(11)
546+
)
547+
with pytest.raises(TypeError):
548+
await traversal.run(can_stop_on_start_item=False)

0 commit comments

Comments
 (0)