|
15 | 15 | from uipath_agent_framework.runtime.schema import get_agent_graph |
16 | 16 |
|
17 | 17 |
|
18 | | -def _make_agent(name="test_agent", tools=None) -> BaseAgent: |
| 18 | +def _make_agent(name="test_agent", tools=None, model_id=None) -> BaseAgent: |
19 | 19 | """Create a mock BaseAgent for testing.""" |
20 | 20 | agent = MagicMock(spec=BaseAgent) |
21 | 21 | agent.name = name |
22 | 22 | agent.default_options = {"tools": tools or []} |
| 23 | + if model_id is not None: |
| 24 | + client = MagicMock() |
| 25 | + client.model_id = model_id |
| 26 | + agent.client = client |
23 | 27 | return agent |
24 | 28 |
|
25 | 29 |
|
@@ -477,3 +481,84 @@ def test_workflow_concurrent_pattern(self): |
477 | 481 | assert ("topics", "merge") in edge_pairs |
478 | 482 | assert ("summary", "merge") in edge_pairs |
479 | 483 | assert ("merge", "__end__") in edge_pairs |
| 484 | + |
| 485 | + def test_agent_executor_with_model_is_model_node(self): |
| 486 | + """AgentExecutor with a chat client becomes a model node.""" |
| 487 | + inner_agent = _make_agent(name="assistant", model_id="gpt-4.1-mini-2025-04-14") |
| 488 | + executors = { |
| 489 | + "assistant": _make_executor("assistant", agent=inner_agent), |
| 490 | + } |
| 491 | + workflow = _make_workflow( |
| 492 | + executors=executors, |
| 493 | + edge_groups=[], |
| 494 | + start_executor_id="assistant", |
| 495 | + ) |
| 496 | + agent = _make_workflow_agent(workflow) |
| 497 | + graph = get_agent_graph(agent) |
| 498 | + |
| 499 | + node = next(n for n in graph.nodes if n.id == "assistant") |
| 500 | + assert node.type == "model" |
| 501 | + assert node.metadata == {"model_name": "gpt-4.1-mini-2025-04-14"} |
| 502 | + |
| 503 | + def test_agent_executor_without_model_is_regular_node(self): |
| 504 | + """AgentExecutor without a chat client stays as a regular node.""" |
| 505 | + inner_agent = _make_agent(name="assistant") |
| 506 | + executors = { |
| 507 | + "assistant": _make_executor("assistant", agent=inner_agent), |
| 508 | + } |
| 509 | + workflow = _make_workflow( |
| 510 | + executors=executors, |
| 511 | + edge_groups=[], |
| 512 | + start_executor_id="assistant", |
| 513 | + ) |
| 514 | + agent = _make_workflow_agent(workflow) |
| 515 | + graph = get_agent_graph(agent) |
| 516 | + |
| 517 | + node = next(n for n in graph.nodes if n.id == "assistant") |
| 518 | + assert node.type == "node" |
| 519 | + assert node.metadata is None |
| 520 | + |
| 521 | + def test_plain_executor_is_regular_node(self): |
| 522 | + """Non-AgentExecutor stays as a regular node.""" |
| 523 | + executors = { |
| 524 | + "step": _make_executor("step"), # no agent |
| 525 | + } |
| 526 | + workflow = _make_workflow( |
| 527 | + executors=executors, |
| 528 | + edge_groups=[], |
| 529 | + start_executor_id="step", |
| 530 | + ) |
| 531 | + agent = _make_workflow_agent(workflow) |
| 532 | + graph = get_agent_graph(agent) |
| 533 | + |
| 534 | + node = next(n for n in graph.nodes if n.id == "step") |
| 535 | + assert node.type == "node" |
| 536 | + assert node.metadata is None |
| 537 | + |
| 538 | + def test_multi_agent_workflow_with_different_models(self): |
| 539 | + """Multiple agents with different models each get their model name.""" |
| 540 | + triage_agent = _make_agent(name="triage", model_id="gpt-4.1-2025-04-14") |
| 541 | + billing_agent = _make_agent( |
| 542 | + name="billing", model_id="anthropic.claude-haiku-4-5-20251001-v1:0" |
| 543 | + ) |
| 544 | + executors = { |
| 545 | + "triage": _make_executor("triage", agent=triage_agent), |
| 546 | + "billing": _make_executor("billing", agent=billing_agent), |
| 547 | + } |
| 548 | + workflow = _make_workflow( |
| 549 | + executors=executors, |
| 550 | + edge_groups=[], |
| 551 | + start_executor_id="triage", |
| 552 | + ) |
| 553 | + agent = _make_workflow_agent(workflow) |
| 554 | + graph = get_agent_graph(agent) |
| 555 | + |
| 556 | + triage_node = next(n for n in graph.nodes if n.id == "triage") |
| 557 | + assert triage_node.type == "model" |
| 558 | + assert triage_node.metadata == {"model_name": "gpt-4.1-2025-04-14"} |
| 559 | + |
| 560 | + billing_node = next(n for n in graph.nodes if n.id == "billing") |
| 561 | + assert billing_node.type == "model" |
| 562 | + assert billing_node.metadata == { |
| 563 | + "model_name": "anthropic.claude-haiku-4-5-20251001-v1:0" |
| 564 | + } |
0 commit comments