Skip to content

Commit a66065e

Browse files
committed
Friendly error message for not using DataModel in Node
1 parent 08710c2 commit a66065e

File tree

5 files changed

+122
-0
lines changed

5 files changed

+122
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
redis_url = "redis://localhost:6379/0" # required
2+
extra_modules = ["exception_node"]
3+
4+
[[nodes]]
5+
node_name = "exception_node"
6+
node_class = "exception_node"
7+
8+
[nodes.node_args.print_channel_types]
9+
"tick/secs/1" = "tick"
10+
11+
[[nodes]]
12+
node_name = "tick"
13+
node_class = "tick"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import AsyncIterator
2+
from aact.nodes import Node, NodeFactory
3+
from aact.messages import Text
4+
from aact import Message
5+
6+
7+
@NodeFactory.register("exception_node")
8+
class ExceptionNode(Node[Text, Text]):
9+
def event_handler(
10+
self, input_channel: str, input_message: Message[Text]
11+
) -> AsyncIterator[tuple[str, Message[Text]]]:
12+
raise Exception("This is an exception from the node.")

src/aact/nodes/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ def __init__(
176176
redis_url: str = "redis://localhost:6379/0",
177177
):
178178
try:
179+
for _, input_channel_type in input_channel_types:
180+
if not issubclass(input_channel_type, DataModel):
181+
raise TypeError(
182+
f"Input channel type {input_channel_type} is not a subclass of DataModel"
183+
)
184+
for _, output_channel_type in output_channel_types:
185+
if not issubclass(output_channel_type, DataModel):
186+
raise TypeError(
187+
f"Output channel type {output_channel_type} is not a subclass of DataModel"
188+
)
179189
BaseModel.__init__(
180190
self,
181191
input_channel_types=dict(input_channel_types),

tests/messages/test_message.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from aact.messages import DataModel, DataModelFactory
2+
3+
4+
def test_create_data_model() -> None:
5+
# Create a new data model
6+
@DataModelFactory.register("MyDataModel")
7+
class MyDataModel(DataModel):
8+
name: str
9+
age: int
10+
11+
# Create an instance of the data model
12+
instance = MyDataModel(name="John", age=30)
13+
14+
# Validate the instance
15+
assert instance.name == "John"
16+
assert instance.age == 30
17+
18+
# Check if the instance is of type DataModel
19+
assert isinstance(instance, DataModel)

tests/nodes/test_node_creation.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
from aact.messages.base import DataModel
3+
from aact.messages.registry import DataModelFactory
4+
from aact.nodes import Node
5+
from aact.nodes import NodeFactory
6+
7+
8+
def test_create_node() -> None:
9+
class AWrong:
10+
pass
11+
12+
class BWrong:
13+
pass
14+
15+
# Create a new node
16+
@NodeFactory.register("MyNode")
17+
class MyNode(Node[AWrong, BWrong]): # type: ignore[type-var]
18+
def __init__(
19+
self,
20+
) -> None:
21+
super().__init__(
22+
node_name="MyNode",
23+
input_channel_types=[("input", AWrong)],
24+
output_channel_types=[("output", BWrong)],
25+
)
26+
27+
async def event_handler(
28+
self, input_channel: str, input_message: AWrong
29+
) -> None: # type: ignore[override]
30+
# Handle the event
31+
pass
32+
33+
@DataModelFactory.register("data_model_a")
34+
class ACorrect(DataModel):
35+
pass
36+
37+
@DataModelFactory.register("data_model_b")
38+
class BCorrect(DataModel):
39+
pass
40+
41+
# Create a new node
42+
@NodeFactory.register("MyNodeCorrect")
43+
class MyNodeCorect(Node[ACorrect, BCorrect]):
44+
def __init__(
45+
self,
46+
) -> None:
47+
super().__init__(
48+
node_name="MyNode",
49+
input_channel_types=[("input", ACorrect)],
50+
output_channel_types=[("output", BCorrect)],
51+
)
52+
53+
async def event_handler(
54+
self, input_channel: str, input_message: ACorrect
55+
) -> None: # type: ignore[override]
56+
# Handle the event
57+
pass
58+
59+
# Create an instance of the node
60+
with pytest.raises(TypeError) as excinfo:
61+
_ = MyNode()
62+
63+
assert (
64+
"Input channel type <class 'test_node_creation.test_create_node.<locals>.AWrong'> is not a subclass of DataModel"
65+
in str(excinfo.value)
66+
)
67+
68+
_ = MyNodeCorect()

0 commit comments

Comments
 (0)