From 48500bd68798ead238f55d5f2905b59dbbb2ba0d Mon Sep 17 00:00:00 2001 From: Krishna Date: Fri, 25 Oct 2024 10:52:16 +0530 Subject: [PATCH 1/2] events.py Simplified conditional checks and utilized get for dictionary access Added logging statements to key points for better traceability. Improved the structure of message transformations, emphasizing clarity and maintainability --- src/controlflow/events/events.py | 88 +++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index a28383e6..ee7b502a 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -1,7 +1,6 @@ -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union, List from pydantic import field_validator, model_validator - from controlflow.agents.agent import Agent from controlflow.events.base import Event, UnpersistedEvent from controlflow.llm.messages import ( @@ -16,6 +15,7 @@ if TYPE_CHECKING: from controlflow.events.message_compiler import CompileContext + logger = get_logger(__name__) ORCHESTRATOR_PREFIX = "The following message is from the orchestrator." @@ -23,84 +23,97 @@ class OrchestratorMessage(Event): """ - Messages from the orchestrator to agents. + Represents messages from the orchestrator to agents. """ event: Literal["orchestrator-message"] = "orchestrator-message" - content: Union[str, list[Union[str, dict]]] + content: Union[str, List[Union[str, dict]]] prefix: Optional[str] = ORCHESTRATOR_PREFIX name: Optional[str] = None - def to_messages(self, context: "CompileContext") -> list[BaseMessage]: - messages = [] - # if self.prefix: - # messages.append(SystemMessage(content=self.prefix)) - messages.append( + def to_messages(self, context: "CompileContext") -> List[BaseMessage]: + logger.debug("Creating orchestrator messages with prefix: %s", self.prefix) + messages = [ HumanMessage(content=f"({self.prefix})\n\n{self.content}", name=self.name) - ) + ] return messages class UserMessage(Event): + """ + Represents messages from the user. + """ + event: Literal["user-message"] = "user-message" - content: Union[str, list[Union[str, dict]]] + content: Union[str, List[Union[str, dict]]] - def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + def to_messages(self, context: "CompileContext") -> List[BaseMessage]: + logger.debug("Creating user message: %s", self.content) return [HumanMessage(content=self.content)] class AgentMessage(Event): + """ + Represents messages from an agent. + """ + event: Literal["agent-message"] = "agent-message" agent: Agent message: dict @field_validator("message", mode="before") - def _message(cls, v): + def validate_message(cls, v): if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "ai" return v @model_validator(mode="after") - def _finalize(self): + def finalize_message(self): self.message["name"] = self.agent.name + logger.debug("Finalized agent message for agent: %s", self.agent.name) return self @property def ai_message(self) -> AIMessage: return AIMessage(**self.message) - def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + def to_messages(self, context: "CompileContext") -> List[BaseMessage]: if self.agent.name == context.agent.name: return [self.ai_message] - elif self.message["content"]: + + if self.message.get("content"): return OrchestratorMessage( prefix=f'The following message was posted by Agent "{self.agent.name}" with ID {self.agent.id}', content=self.message["content"], name=self.agent.name, ).to_messages(context) - else: - return [] + + return [] class AgentMessageDelta(UnpersistedEvent): - event: Literal["agent-message-delta"] = "agent-message-delta" + """ + Represents a delta change in an agent's message. + """ + event: Literal["agent-message-delta"] = "agent-message-delta" agent: Agent delta: dict snapshot: dict @field_validator("delta", "snapshot", mode="before") - def _message(cls, v): + def validate_delta_and_snapshot(cls, v): if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "AIMessageChunk" return v @model_validator(mode="after") - def _finalize(self): + def finalize_delta_and_snapshot(self): self.delta["name"] = self.agent.name self.snapshot["name"] = self.agent.name + logger.debug("Finalized delta and snapshot for agent: %s", self.agent.name) return self @property @@ -109,29 +122,42 @@ def delta_message(self) -> AIMessageChunk: @property def snapshot_message(self) -> AIMessage: - return AIMessage(**self.snapshot | {"type": "ai"}) + return AIMessage(**{**self.snapshot, "type": "ai"}) class EndTurn(Event): + """ + Represents the end of an agent's turn. + """ + event: Literal["end-turn"] = "end-turn" agent: Agent next_agent_name: Optional[str] = None class ToolCallEvent(Event): + """ + Represents a tool call made by an agent. + """ + event: Literal["tool-call"] = "tool-call" agent: Agent tool_call: Union[ToolCall, InvalidToolCall] class ToolResultEvent(Event): + """ + Represents the result of a tool call made by an agent. + """ + event: Literal["tool-result"] = "tool-result" agent: Agent tool_call: Union[ToolCall, InvalidToolCall] tool_result: ToolResult - def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + def to_messages(self, context: "CompileContext") -> List[BaseMessage]: if self.agent.name == context.agent.name: + logger.debug("Creating tool result message for agent: %s", self.agent.name) return [ ToolMessage( content=self.tool_result.str_result, @@ -139,11 +165,11 @@ def to_messages(self, context: "CompileContext") -> list[BaseMessage]: name=self.agent.name, ) ] - else: - return OrchestratorMessage( - prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool ' - f'call: {self.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} ' - f'produced this result:', - content=self.tool_result.str_result, - name=self.agent.name, - ).to_messages(context) + + logger.debug("Creating orchestrator message for tool result from agent: %s", self.agent.name) + return OrchestratorMessage( + prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool call: {self.tool_call}. ' + f'The tool{" failed and" if self.tool_result.is_error else " "} produced this result:', + content=self.tool_result.str_result, + name=self.agent.name, + ).to_messages(context) From 3d7c1c566f92e7e8dc563a791739e2f572fac0ac Mon Sep 17 00:00:00 2001 From: Krishna Date: Wed, 30 Oct 2024 07:30:04 +0530 Subject: [PATCH 2/2] Update events.py made the changes suggested. --- src/controlflow/events/events.py | 197 ++++++++++++++++++++++++------- 1 file changed, 155 insertions(+), 42 deletions(-) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index 422408bc..c2cdc53b 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -1,10 +1,7 @@ -from typing import TYPE_CHECKING, Literal, Optional, Union, List - - +from typing import TYPE_CHECKING, Literal, Optional, Union from pydantic import ConfigDict, field_validator, model_validator - from controlflow.agents.agent import Agent from controlflow.events.base import Event, UnpersistedEvent from controlflow.llm.messages import ( @@ -19,7 +16,6 @@ if TYPE_CHECKING: from controlflow.events.message_compiler import CompileContext - logger = get_logger(__name__) ORCHESTRATOR_PREFIX = "The following message is from the orchestrator." @@ -27,38 +23,70 @@ class OrchestratorMessage(Event): """ - Represents messages from the orchestrator to agents. + Represents a message from the orchestrator to agents. + + Attributes: + event: Literal identifier for orchestrator messages. + content: The message content, either a string or a list of strings/dicts. + prefix: An optional prefix to specify the source of the message. + name: An optional name associated with the message sender. """ event: Literal["orchestrator-message"] = "orchestrator-message" - content: Union[str, List[Union[str, dict]]] + content: Union[str, list[Union[str, dict]]] prefix: Optional[str] = ORCHESTRATOR_PREFIX name: Optional[str] = None - def to_messages(self, context: "CompileContext") -> List[BaseMessage]: - logger.debug("Creating orchestrator messages with prefix: %s", self.prefix) - messages = [ + def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + """ + Converts the orchestrator message into a list of BaseMessage instances. + + Args: + context: The context for message compilation. + + Returns: + A list of BaseMessage objects representing the orchestrator's message. + """ + messages = [] + messages.append( HumanMessage(content=f"({self.prefix})\n\n{self.content}", name=self.name) - ] + ) return messages class UserMessage(Event): """ - Represents messages from the user. + Represents a message sent by a user. + + Attributes: + event: Literal identifier for user messages. + content: The message content, either a string or a list of strings/dicts. """ event: Literal["user-message"] = "user-message" - content: Union[str, List[Union[str, dict]]] + content: Union[str, list[Union[str, dict]]] + + def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + """ + Converts the user message into a list of BaseMessage instances. - def to_messages(self, context: "CompileContext") -> List[BaseMessage]: - logger.debug("Creating user message: %s", self.content) + Args: + context: The context for message compilation. + + Returns: + A list containing a single HumanMessage object. + """ return [HumanMessage(content=self.content)] class AgentMessage(Event): """ - Represents messages from an agent. + Represents a message sent by an agent. + + Attributes: + event: Literal identifier for agent messages. + agent: The agent sending the message. + message: The message content, in dictionary format. """ event: Literal["agent-message"] = "agent-message" @@ -66,72 +94,138 @@ class AgentMessage(Event): message: dict @field_validator("message", mode="before") - def validate_message(cls, v): + def _message(cls, v): + """ + Validates and converts the message format, setting its type to "ai" if needed. + + Args: + v: The initial message content. + + Returns: + The validated message content. + """ if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "ai" return v @model_validator(mode="after") - def finalize_message(self): + def _finalize(self): + """ + Finalizes the message by setting the agent's name. + + Returns: + The updated message with agent's name added. + """ self.message["name"] = self.agent.name - logger.debug("Finalized agent message for agent: %s", self.agent.name) return self @property def ai_message(self) -> AIMessage: + """ + Returns the message as an AIMessage object. + + Returns: + An instance of AIMessage. + """ return AIMessage(**self.message) - def to_messages(self, context: "CompileContext") -> List[BaseMessage]: + def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + """ + Converts the agent message into a list of BaseMessage instances based on the context. + + Args: + context: The context for message compilation. + + Returns: + A list of BaseMessage objects, depending on whether the agent matches the context agent. + """ if self.agent.name == context.agent.name: return [self.ai_message] - - if self.message.get("content"): + elif self.message["content"]: return OrchestratorMessage( prefix=f'The following message was posted by Agent "{self.agent.name}" with ID {self.agent.id}', content=self.message["content"], name=self.agent.name, ).to_messages(context) - - return [] + else: + return [] class AgentMessageDelta(UnpersistedEvent): """ - Represents a delta change in an agent's message. + Represents an incremental update (delta) to an agent's message. + + Attributes: + event: Literal identifier for agent message deltas. + agent: The agent associated with the delta. + delta: The delta content, in dictionary format. + snapshot: The snapshot content representing the complete message at the current state. """ event: Literal["agent-message-delta"] = "agent-message-delta" + agent: Agent delta: dict snapshot: dict @field_validator("delta", "snapshot", mode="before") - def validate_delta_and_snapshot(cls, v): + def _message(cls, v): + """ + Validates and converts the delta and snapshot content format, setting type to "AIMessageChunk". + + Args: + v: The initial delta or snapshot content. + + Returns: + The validated content. + """ if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "AIMessageChunk" return v @model_validator(mode="after") - def finalize_delta_and_snapshot(self): + def _finalize(self): + """ + Finalizes the delta and snapshot by setting the agent's name. + + Returns: + The updated delta and snapshot with agent's name added. + """ self.delta["name"] = self.agent.name self.snapshot["name"] = self.agent.name - logger.debug("Finalized delta and snapshot for agent: %s", self.agent.name) return self @property def delta_message(self) -> AIMessageChunk: + """ + Returns the delta as an AIMessageChunk object. + + Returns: + An instance of AIMessageChunk. + """ return AIMessageChunk(**self.delta) @property def snapshot_message(self) -> AIMessage: - return AIMessage(**{**self.snapshot, "type": "ai"}) + """ + Returns the snapshot as an AIMessage object. + + Returns: + An instance of AIMessage. + """ + return AIMessage(**self.snapshot | {"type": "ai"}) class EndTurn(Event): """ - Represents the end of an agent's turn. + Represents an event signaling the end of an agent's turn. + + Attributes: + event: Literal identifier for end-turn events. + agent: The agent ending their turn. + next_agent_name: Optional name of the next agent to act. """ event: Literal["end-turn"] = "end-turn" @@ -141,7 +235,12 @@ class EndTurn(Event): class ToolCallEvent(Event): """ - Represents a tool call made by an agent. + Represents an event where an agent makes a tool call. + + Attributes: + event: Literal identifier for tool call events. + agent: The agent making the tool call. + tool_call: The tool call, either a valid ToolCall or InvalidToolCall. """ event: Literal["tool-call"] = "tool-call" @@ -151,7 +250,13 @@ class ToolCallEvent(Event): class ToolResultEvent(Event): """ - Represents the result of a tool call made by an agent. + Represents an event where a tool call produces a result. + + Attributes: + event: Literal identifier for tool result events. + agent: The agent receiving the tool result. + tool_call: The initial tool call. + tool_result: The result produced by the tool call. """ event: Literal["tool-result"] = "tool-result" @@ -159,9 +264,17 @@ class ToolResultEvent(Event): tool_call: Union[ToolCall, InvalidToolCall] tool_result: ToolResult - def to_messages(self, context: "CompileContext") -> List[BaseMessage]: + def to_messages(self, context: "CompileContext") -> list[BaseMessage]: + """ + Converts the tool result event into a list of BaseMessage instances. + + Args: + context: The context for message compilation. + + Returns: + A list of BaseMessage objects representing the tool result, tailored to the agent and context. + """ if self.agent.name == context.agent.name: - logger.debug("Creating tool result message for agent: %s", self.agent.name) return [ ToolMessage( content=self.tool_result.str_result, @@ -169,11 +282,11 @@ def to_messages(self, context: "CompileContext") -> List[BaseMessage]: name=self.agent.name, ) ] - - logger.debug("Creating orchestrator message for tool result from agent: %s", self.agent.name) - return OrchestratorMessage( - prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool call: {self.tool_call}. ' - f'The tool{" failed and" if self.tool_result.is_error else " "} produced this result:', - content=self.tool_result.str_result, - name=self.agent.name, - ).to_messages(context) + else: + return OrchestratorMessage( + prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool ' + f'call: {self.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} ' + f'produced this result:', + content=self.tool_result.str_result, + name=self.agent.name, + ).to_messages(context)