diff --git a/axonflow/__init__.py b/axonflow/__init__.py index 2c825aa..828824c 100644 --- a/axonflow/__init__.py +++ b/axonflow/__init__.py @@ -207,6 +207,7 @@ StepGateRequest, StepGateResponse, StepType, + ToolContext, WorkflowSource, WorkflowStatus, WorkflowStatusResponse, @@ -368,6 +369,7 @@ "MarkStepCompletedRequest", "AbortWorkflowRequest", "PolicyMatch", + "ToolContext", # WCP Approval types (Feature 5) "ApproveStepResponse", "RejectStepResponse", diff --git a/axonflow/adapters/langgraph.py b/axonflow/adapters/langgraph.py index c8724be..72d2c61 100644 --- a/axonflow/adapters/langgraph.py +++ b/axonflow/adapters/langgraph.py @@ -42,6 +42,7 @@ MarkStepCompletedRequest, StepGateRequest, StepType, + ToolContext, WorkflowSource, ) @@ -132,6 +133,7 @@ async def start_workflow( self, total_steps: int | None = None, metadata: dict[str, Any] | None = None, + trace_id: str | None = None, ) -> str: """Register the workflow with AxonFlow. @@ -140,6 +142,7 @@ async def start_workflow( Args: total_steps: Total number of steps (if known) metadata: Additional workflow metadata + trace_id: External trace ID for correlation (Langsmith, Datadog, OTel) Returns: The assigned workflow ID @@ -147,7 +150,8 @@ async def start_workflow( Example: >>> workflow_id = await adapter.start_workflow( ... total_steps=5, - ... metadata={"customer_id": "cust-123"} + ... metadata={"customer_id": "cust-123"}, + ... trace_id="langsmith-run-abc123", ... ) """ request = CreateWorkflowRequest( @@ -155,6 +159,7 @@ async def start_workflow( source=self.source, total_steps=total_steps, metadata=metadata or {}, + trace_id=trace_id, ) response = await self.client.create_workflow(request) @@ -170,6 +175,7 @@ async def check_gate( step_input: dict[str, Any] | None = None, model: str | None = None, provider: str | None = None, + tool_context: ToolContext | None = None, ) -> bool: """Check if a step is allowed to proceed. @@ -182,6 +188,7 @@ async def check_gate( step_input: Input data for the step (for policy evaluation) model: LLM model being used provider: LLM provider being used + tool_context: Tool-level context for per-tool governance (tool_call steps) Returns: True if step is allowed, False if blocked (when auto_block=False) @@ -214,6 +221,7 @@ async def check_gate( step_input=step_input or {}, model=model, provider=provider, + tool_context=tool_context, ) response = await self.client.step_gate(self.workflow_id, step_id, request) @@ -247,6 +255,9 @@ async def step_completed( step_id: str | None = None, output: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, + tokens_in: int | None = None, + tokens_out: int | None = None, + cost_usd: float | None = None, ) -> None: """Mark a step as completed. @@ -257,6 +268,9 @@ async def step_completed( step_id: Optional step ID (must match the one used in check_gate) output: Output data from the step metadata: Additional metadata + tokens_in: Input tokens consumed + tokens_out: Output tokens produced + cost_usd: Cost in USD Example: >>> await adapter.step_completed("generate", output={"code": result}) @@ -272,10 +286,114 @@ async def step_completed( request = MarkStepCompletedRequest( output=output or {}, metadata=metadata or {}, + tokens_in=tokens_in, + tokens_out=tokens_out, + cost_usd=cost_usd, ) await self.client.mark_step_completed(self.workflow_id, step_id, request) + async def check_tool_gate( + self, + tool_name: str, + tool_type: str | None = None, + *, + step_name: str | None = None, + step_id: str | None = None, + tool_input: dict[str, Any] | None = None, + model: str | None = None, + provider: str | None = None, + ) -> bool: + """Check if a specific tool invocation is allowed. + + Convenience wrapper around check_gate() that sets step_type=TOOL_CALL + and includes ToolContext for per-tool governance. + + Args: + tool_name: Name of the tool being invoked + tool_type: Tool type (function, mcp, api) + step_name: Step name (defaults to "tools/{tool_name}") + step_id: Optional step ID (auto-generated if not provided) + tool_input: Input arguments for the tool + model: LLM model being used + provider: LLM provider being used + + Returns: + True if tool invocation is allowed, False if blocked (when auto_block=False) + + Raises: + WorkflowBlockedError: If tool is blocked and auto_block=True + WorkflowApprovalRequiredError: If tool requires approval + ValueError: If workflow not started + + Example: + >>> if await adapter.check_tool_gate("web_search", "function", + ... tool_input={"query": "latest news"}): + ... result = await web_search(query="latest news") + ... await adapter.tool_completed("web_search", output={"results": result}) + """ + if step_name is None: + step_name = f"tools/{tool_name}" + + tool_context = ToolContext( + tool_name=tool_name, + tool_type=tool_type, + tool_input=tool_input or {}, + ) + + return await self.check_gate( + step_name=step_name, + step_type=StepType.TOOL_CALL, + step_id=step_id, + model=model, + provider=provider, + tool_context=tool_context, + ) + + async def tool_completed( + self, + tool_name: str, + *, + step_name: str | None = None, + step_id: str | None = None, + output: dict[str, Any] | None = None, + tokens_in: int | None = None, + tokens_out: int | None = None, + cost_usd: float | None = None, + ) -> None: + """Mark a tool invocation as completed. + + Convenience wrapper around step_completed() for tool-level tracking. + + Args: + tool_name: Name of the tool that was invoked + step_name: Step name (defaults to "tools/{tool_name}") + step_id: Optional step ID (must match the one used in check_tool_gate) + output: Output data from the tool + tokens_in: Input tokens consumed + tokens_out: Output tokens produced + cost_usd: Cost in USD + + Example: + >>> await adapter.tool_completed("web_search", + ... output={"results": search_results}, + ... tokens_in=150, + ... tokens_out=500, + ... cost_usd=0.002, + ... ) + """ + if step_name is None: + step_name = f"tools/{tool_name}" + + await self.step_completed( + step_name=step_name, + step_id=step_id, + output=output, + tokens_in=tokens_in, + tokens_out=tokens_out, + cost_usd=cost_usd, + ) + async def complete_workflow(self) -> None: """Mark the workflow as completed. diff --git a/axonflow/client.py b/axonflow/client.py index fc56bb8..98d88a6 100644 --- a/axonflow/client.py +++ b/axonflow/client.py @@ -3606,6 +3606,8 @@ async def list_workflows( params.append(f"status={options.status.value}") if options.source: params.append(f"source={options.source.value}") + if options.trace_id: + params.append(f"trace_id={options.trace_id}") if options.limit: params.append(f"limit={options.limit}") if options.offset: diff --git a/axonflow/workflow.py b/axonflow/workflow.py index 70ff19a..b9a8ecb 100644 --- a/axonflow/workflow.py +++ b/axonflow/workflow.py @@ -81,6 +81,10 @@ class CreateWorkflowRequest(BaseModel): metadata: dict[str, Any] = Field( default_factory=dict, description="Additional metadata for the workflow" ) + trace_id: str | None = Field( + default=None, + description="External trace ID for correlation (Langsmith, Datadog, OTel)", + ) class CreateWorkflowResponse(BaseModel): @@ -91,6 +95,15 @@ class CreateWorkflowResponse(BaseModel): source: WorkflowSource = Field(..., description="Source orchestrator") status: WorkflowStatus = Field(..., description="Current status (always 'in_progress' for new)") created_at: datetime = Field(..., description="When the workflow was created") + trace_id: str | None = None + + +class ToolContext(BaseModel): + """Tool-level context for per-tool governance within tool_call steps.""" + + tool_name: str + tool_type: str | None = Field(default=None, description="Tool type: function, mcp, api") + tool_input: dict[str, Any] = Field(default_factory=dict) class StepGateRequest(BaseModel): @@ -105,6 +118,7 @@ class StepGateRequest(BaseModel): ) model: str | None = Field(default=None, description="LLM model being used (if applicable)") provider: str | None = Field(default=None, description="LLM provider (if applicable)") + tool_context: ToolContext | None = None class StepGateResponse(BaseModel): @@ -178,6 +192,7 @@ class WorkflowStatusResponse(BaseModel): steps: list[WorkflowStepInfo] = Field( default_factory=list, description="List of steps in the workflow" ) + trace_id: str | None = None def is_terminal(self) -> bool: """Check if the workflow is in a terminal state (completed, aborted, or failed).""" @@ -191,6 +206,7 @@ class ListWorkflowsOptions(BaseModel): status: WorkflowStatus | None = Field(default=None, description="Filter by workflow status") source: WorkflowSource | None = Field(default=None, description="Filter by source") + trace_id: str | None = Field(default=None, description="Filter by external trace ID") limit: int = Field(default=50, ge=1, le=100, description="Maximum number of results to return") offset: int = Field(default=0, ge=0, description="Offset for pagination")