Skip to content

Commit 7cb7b40

Browse files
Guanyi Li
authored and committed
add kwargs to invoke_context and add an example
1 parent 09c8c54 commit 7cb7b40

7 files changed

Lines changed: 2198 additions & 1935 deletions

File tree

docs/docs/user-guide/models.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ assistant_message = Message(
5858
| `conversation_id` | Unique identifier for a conversation between user and assistant. |
5959
| `invoke_id` | Unique identifier for each conversation invoke - an invoke can involve multiple agents.|
6060
| `assistant_request_id` | Created when an agent receives a request from the user. |
61-
| `user_id` | Optional user identifier, defaults to empty string. |
61+
| `user_id` | Optional user identifier, defaults to empty string. |
62+
| `kwargs` | Optional additional keyword arguments and context for the workflow. |
6263

6364
### InvokeContext Usage Example
6465

grafi/common/models/invoke_context.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
from typing import Dict
13
from typing import Optional
24

35
from pydantic import BaseModel
@@ -10,9 +12,22 @@ class InvokeContext(BaseModel):
1012
invoke_id: invoke id of each conversation, an invoke can involve multiple agents
1113
assistant_request_id: created when the agent receives a request from the user
1214
user_id: user id
15+
kwargs: optional field for any additional context or keyword arguments that need to be passed through the workflow
1316
"""
1417

15-
conversation_id: str
16-
invoke_id: str
17-
assistant_request_id: str
18-
user_id: Optional[str] = Field(default="")
18+
conversation_id: str = Field(
19+
description="Unique identifier for a conversation between user and assistant"
20+
)
21+
invoke_id: str = Field(
22+
description="Unique identifier for each conversation invoke - an invoke can involve multiple agents"
23+
)
24+
assistant_request_id: str = Field(
25+
description="Created when an agent receives a request from the user"
26+
)
27+
user_id: Optional[str] = Field(
28+
default="", description="Optional user identifier, defaults to empty string"
29+
)
30+
kwargs: Optional[Dict[str, Any]] = Field(
31+
default_factory=dict,
32+
description="Additional keyword arguments and context for the workflow",
33+
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dev = [
3030
"llama-index-core>=0.12.48",
3131
"llama-index-embeddings-openai>=0.3.1",
3232
"llama-index-llms-openai>=0.4.7",
33+
"markdown>=3.8.2",
3334
"mcp>=1.11.0",
3435
"mypy>=1.16.1",
3536
"ollama>=0.5.1",
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import List
2+
3+
from jinja2 import Template
4+
from loguru import logger
5+
6+
from grafi.common.events.topic_events.consume_from_topic_event import (
7+
ConsumeFromTopicEvent,
8+
)
9+
from grafi.common.models.invoke_context import InvokeContext
10+
from grafi.common.models.message import Message
11+
from grafi.common.models.message import Messages
12+
from grafi.tools.llms.llm_command import LLMCommand
13+
14+
15+
class LLMPromptTemplateCommand(LLMCommand):
    """LLM command that optionally rewrites the input via a Jinja2 template.

    When the invoke context's ``kwargs`` contains a ``prompt_template`` entry,
    the last input message's content is rendered into that template (exposed
    to the template as ``input_text``) and a single user message carrying the
    rendered prompt replaces the original messages. In every other case —
    no template, no messages, empty content, or a render failure — the
    messages from the base command pass through unchanged.
    """

    def get_tool_input(
        self,
        invoke_context: InvokeContext,
        node_input: List[ConsumeFromTopicEvent],
    ) -> Messages:
        """Prepare the input for the LLM command based on the node input and invoke context."""
        messages = super().get_tool_input(invoke_context, node_input)

        # `kwargs` is Optional on InvokeContext: a caller may pass an explicit
        # None, and `"x" in None` raises TypeError. Normalize to a dict first.
        kwargs = invoke_context.kwargs or {}
        if "prompt_template" not in kwargs:
            return messages

        message = messages[-1] if messages else None
        if not (message and message.content):
            # No content to substitute into the template; fall through unchanged.
            return messages

        template: Template = Template(kwargs["prompt_template"])
        try:
            # Render the Jinja template with the message content as input.
            rendered_prompt = template.render(input_text=message.content)
        except Exception as e:
            logger.error(f"Error rendering template: {e}")
            # Fallback to original message if template rendering fails
            return messages

        logger.info(
            f"Rendered prompt template with input: {message.content[:100]}..."
        )
        # Create a new message with the rendered template
        return [Message(role="user", content=rendered_prompt)]
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
from typing import Optional
3+
from typing import Self
4+
5+
from openinference.semconv.trace import OpenInferenceSpanKindValues
6+
from pydantic import Field
7+
8+
from grafi.assistants.assistant import Assistant
9+
from grafi.assistants.assistant_base import AssistantBaseBuilder
10+
from grafi.common.topics.output_topic import agent_output_topic
11+
from grafi.common.topics.topic import agent_input_topic
12+
from grafi.nodes.node import Node
13+
from grafi.tools.llms.impl.openai_tool import OpenAITool
14+
from grafi.workflows.impl.event_driven_workflow import EventDrivenWorkflow
15+
from tests_integration.invoke_kwargs.llm_prompt_template_command import (
16+
LLMPromptTemplateCommand,
17+
)
18+
19+
20+
class SimpleLLMPromptTemplateAssistant(Assistant):
    """Minimal assistant backed by a single OpenAI LLM node.

    The workflow wires exactly one node from the agent input topic to the
    agent output topic. The node uses ``LLMPromptTemplateCommand`` so that a
    ``prompt_template`` supplied through the invoke context's ``kwargs`` is
    applied to the user message before the model is called.

    Attributes:
        api_key (str): OpenAI API key; falls back to the ``OPENAI_API_KEY``
            environment variable when not provided.
        system_message (str): Optional system prompt for the OpenAI tool.
        model (str): Name of the OpenAI model to use.
    """

    oi_span_type: OpenInferenceSpanKindValues = Field(
        default=OpenInferenceSpanKindValues.AGENT
    )
    name: str = Field(default="SimpleLLMPromptTemplateAssistant")
    type: str = Field(default="SimpleLLMPromptTemplateAssistant")
    api_key: Optional[str] = Field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))
    system_message: Optional[str] = Field(default=None)
    model: str = Field(default="gpt-4o-mini")

    @classmethod
    def builder(cls) -> "SimpleLLMPromptTemplateAssistantBuilder":
        """Return a builder for SimpleLLMPromptTemplateAssistant."""
        return SimpleLLMPromptTemplateAssistantBuilder(cls)

    def _construct_workflow(self) -> "SimpleLLMPromptTemplateAssistant":
        # Build the OpenAI tool first, then hand it to the template-aware command.
        openai_tool = (
            OpenAITool.builder()
            .name("OpenAITool")
            .api_key(self.api_key)
            .model(self.model)
            .system_message(self.system_message)
            .build()
        )

        # Single node: consumes from the agent input topic, publishes to output.
        openai_node: Node = (
            Node.builder()
            .name("OpenAINode")
            .subscribe(agent_input_topic)
            .publish_to(agent_output_topic)
            .build()
        )
        openai_node.command = LLMPromptTemplateCommand(tool=openai_tool)

        # Assemble the event-driven workflow around that one node.
        self.workflow = (
            EventDrivenWorkflow.builder()
            .name("SimpleLLMWorkflow")
            .node(openai_node)
            .build()
        )

        return self
75+
76+
77+
class SimpleLLMPromptTemplateAssistantBuilder(
    AssistantBaseBuilder[SimpleLLMPromptTemplateAssistant]
):
    """Concrete builder for SimpleLLMPromptTemplateAssistant."""

    def api_key(self, api_key: str) -> Self:
        # Forwarded to the assistant's `api_key` field at build time.
        self.kwargs["api_key"] = api_key
        return self

    def system_message(self, system_message: str) -> Self:
        # Optional system prompt passed through to the OpenAI tool.
        self.kwargs["system_message"] = system_message
        return self

    def model(self, model: str) -> Self:
        # OpenAI model name; the assistant defaults to "gpt-4o-mini".
        self.kwargs["model"] = model
        return self
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# We will test the SimpleLLMAssistant class in this file.
2+
3+
import os
4+
import uuid
5+
6+
import markdown
7+
8+
from grafi.common.containers.container import container
9+
from grafi.common.instrumentations.tracing import TracingOptions
10+
from grafi.common.instrumentations.tracing import setup_tracing
11+
from grafi.common.models.invoke_context import InvokeContext
12+
from grafi.common.models.message import Message
13+
from tests_integration.invoke_kwargs.simple_llm_prompt_template_assistant import (
14+
SimpleLLMPromptTemplateAssistant,
15+
)
16+
17+
18+
container.register_tracer(setup_tracing(tracing_options=TracingOptions.IN_MEMORY))
19+
event_store = container.event_store
20+
21+
api_key = os.getenv("OPENAI_API_KEY", "")
22+
23+
24+
def get_invoke_context() -> InvokeContext:
    """Build a fresh InvokeContext whose `kwargs` carries a Jinja2 prompt template.

    The workflow's LLMPromptTemplateCommand renders this template, substituting
    the user's message for the `{{ input_text }}` placeholder, before the LLM
    is invoked. IDs are random per call; the conversation id is a fixed stub.
    """
    return InvokeContext(
        conversation_id="conversation_id",
        invoke_id=uuid.uuid4().hex,
        assistant_request_id=uuid.uuid4().hex,
        kwargs={
            "prompt_template": """You are a skilled poet and literary analyst. Your task is to analyze the given input and create a 14-line sonnet based on your analysis.

## Input Analysis Instructions:
1. Carefully read and understand the provided input
2. Identify the main themes, emotions, imagery, and key concepts
3. Determine the appropriate tone and mood for the sonnet
4. Consider metaphors, symbolism, or literary devices that would enhance the poem

## Sonnet Requirements:
- Must be exactly 14 lines
- Follow traditional sonnet structure (Shakespearean or Petrarchan)
- Use appropriate rhyme scheme (ABAB CDCD EFEF GG for Shakespearean, or ABBAABBA CDECDE/CDCDCD for Petrarchan)
- Maintain consistent meter (preferably iambic pentameter)
- Include a clear thematic development with a turn (volta)
- End with a powerful concluding couplet or tercet

## Input to Analyze:
{{ input_text }}

## Analysis:
Please first provide a brief analysis of the input, identifying:
- Central themes:
- Emotional tone:
- Key imagery:
- Poetic approach:

## Generated Sonnet:
Based on your analysis, create a 14-line sonnet that captures the essence of the input:

```
[Your 14-line sonnet here]
```

## Explanation:
Briefly explain how your sonnet reflects the input and the literary devices used."""
        },
    )
68+
69+
70+
def test_simple_llm_assistant() -> None:
    """End-to-end check that a `prompt_template` in the invoke context's
    kwargs is rendered around the user message before the LLM call.

    Requires a live OPENAI_API_KEY. Asserts the assistant produced output
    and that exactly 12 events were recorded in the event store.
    """
    invoke_context = get_invoke_context()
    assistant = (
        SimpleLLMPromptTemplateAssistant.builder()
        .name("SimpleLLMPromptTemplateAssistant")
        .build()
    )
    event_store.clear_events()

    input_data = [
        Message(
            content="Graphite is a event driven agentic AI platform, it offers real time observability, comprehensive auditing, and high performance workflow.",
            role="user",
        )
    ]
    output = assistant.invoke(invoke_context, input_data)

    # Validate before dereferencing: the original asserted `output is not None`
    # *after* using output[0].content, so a missing/empty response crashed with
    # TypeError/IndexError instead of a clean assertion failure.
    assert output is not None
    assert len(output) > 0
    assert len(event_store.get_events()) == 12

    # Render the model's markdown answer to HTML for manual inspection.
    html = markdown.markdown(output[0].content)
    print(html)
91+
92+
93+
# Guard the direct invocation: a bare call at module level runs the test at
# import time, so pytest would execute it twice (once during collection
# import, once as the collected test) and any failure would abort collection.
if __name__ == "__main__":
    test_simple_llm_assistant()

0 commit comments

Comments
 (0)