Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions axonflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@
StepGateRequest,
StepGateResponse,
StepType,
ToolContext,
WorkflowSource,
WorkflowStatus,
WorkflowStatusResponse,
Expand Down Expand Up @@ -368,6 +369,7 @@
"MarkStepCompletedRequest",
"AbortWorkflowRequest",
"PolicyMatch",
"ToolContext",
# WCP Approval types (Feature 5)
"ApproveStepResponse",
"RejectStepResponse",
Expand Down
120 changes: 119 additions & 1 deletion axonflow/adapters/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
MarkStepCompletedRequest,
StepGateRequest,
StepType,
ToolContext,
WorkflowSource,
)

Expand Down Expand Up @@ -132,6 +133,7 @@ async def start_workflow(
self,
total_steps: int | None = None,
metadata: dict[str, Any] | None = None,
trace_id: str | None = None,
) -> str:
"""Register the workflow with AxonFlow.

Expand All @@ -140,21 +142,24 @@ async def start_workflow(
Args:
total_steps: Total number of steps (if known)
metadata: Additional workflow metadata
trace_id: External trace ID for correlation (Langsmith, Datadog, OTel)

Returns:
The assigned workflow ID

Example:
>>> workflow_id = await adapter.start_workflow(
... total_steps=5,
... metadata={"customer_id": "cust-123"}
... metadata={"customer_id": "cust-123"},
... trace_id="langsmith-run-abc123",
... )
"""
request = CreateWorkflowRequest(
workflow_name=self.workflow_name,
source=self.source,
total_steps=total_steps,
metadata=metadata or {},
trace_id=trace_id,
)

response = await self.client.create_workflow(request)
Expand All @@ -170,6 +175,7 @@ async def check_gate(
step_input: dict[str, Any] | None = None,
model: str | None = None,
provider: str | None = None,
tool_context: ToolContext | None = None,
) -> bool:
"""Check if a step is allowed to proceed.

Expand All @@ -182,6 +188,7 @@ async def check_gate(
step_input: Input data for the step (for policy evaluation)
model: LLM model being used
provider: LLM provider being used
tool_context: Tool-level context for per-tool governance (tool_call steps)

Returns:
True if step is allowed, False if blocked (when auto_block=False)
Expand Down Expand Up @@ -214,6 +221,7 @@ async def check_gate(
step_input=step_input or {},
model=model,
provider=provider,
tool_context=tool_context,
)

response = await self.client.step_gate(self.workflow_id, step_id, request)
Expand Down Expand Up @@ -247,6 +255,9 @@ async def step_completed(
step_id: str | None = None,
output: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
tokens_in: int | None = None,
tokens_out: int | None = None,
cost_usd: float | None = None,
) -> None:
"""Mark a step as completed.

Expand All @@ -257,6 +268,9 @@ async def step_completed(
step_id: Optional step ID (must match the one used in check_gate)
output: Output data from the step
metadata: Additional metadata
tokens_in: Input tokens consumed
tokens_out: Output tokens produced
cost_usd: Cost in USD

Example:
>>> await adapter.step_completed("generate", output={"code": result})
Expand All @@ -272,10 +286,114 @@ async def step_completed(
request = MarkStepCompletedRequest(
output=output or {},
metadata=metadata or {},
tokens_in=tokens_in,
tokens_out=tokens_out,
cost_usd=cost_usd,
)

await self.client.mark_step_completed(self.workflow_id, step_id, request)

async def check_tool_gate(
    self,
    tool_name: str,
    tool_type: str | None = None,
    *,
    step_name: str | None = None,
    step_id: str | None = None,
    tool_input: dict[str, Any] | None = None,
    model: str | None = None,
    provider: str | None = None,
) -> bool:
    """Gate a single tool invocation through the workflow policy engine.

    Thin wrapper over check_gate() that tags the step as
    StepType.TOOL_CALL and attaches a ToolContext so per-tool
    governance policies can be evaluated against this invocation.

    Args:
        tool_name: Name of the tool being invoked
        tool_type: Tool type (function, mcp, api)
        step_name: Step name; falls back to "tools/{tool_name}" when omitted
        step_id: Optional step ID (auto-generated if not provided)
        tool_input: Input arguments for the tool
        model: LLM model being used
        provider: LLM provider being used

    Returns:
        True if tool invocation is allowed, False if blocked (when auto_block=False)

    Raises:
        WorkflowBlockedError: If tool is blocked and auto_block=True
        WorkflowApprovalRequiredError: If tool requires approval
        ValueError: If workflow not started

    Example:
        >>> if await adapter.check_tool_gate("web_search", "function",
        ...                                  tool_input={"query": "latest news"}):
        ...     result = await web_search(query="latest news")
        ...     await adapter.tool_completed("web_search", output={"results": result})
    """
    # Only substitute the default when the caller passed nothing at all;
    # an explicit (even empty) step_name is respected as-is.
    effective_step_name = f"tools/{tool_name}" if step_name is None else step_name

    return await self.check_gate(
        step_name=effective_step_name,
        step_type=StepType.TOOL_CALL,
        step_id=step_id,
        model=model,
        provider=provider,
        tool_context=ToolContext(
            tool_name=tool_name,
            tool_type=tool_type,
            tool_input=tool_input or {},
        ),
    )

async def tool_completed(
    self,
    tool_name: str,
    *,
    step_name: str | None = None,
    step_id: str | None = None,
    output: dict[str, Any] | None = None,
    tokens_in: int | None = None,
    tokens_out: int | None = None,
    cost_usd: float | None = None,
) -> None:
    """Record that a tool invocation has finished.

    Thin wrapper over step_completed() for tool-level tracking; forwards
    output and usage accounting (tokens, cost) to the underlying step.

    Args:
        tool_name: Name of the tool that was invoked
        step_name: Step name; falls back to "tools/{tool_name}" when omitted
        step_id: Optional step ID (must match the one used in check_tool_gate)
        output: Output data from the tool
        tokens_in: Input tokens consumed
        tokens_out: Output tokens produced
        cost_usd: Cost in USD

    Example:
        >>> await adapter.tool_completed("web_search",
        ...     output={"results": search_results},
        ...     tokens_in=150,
        ...     tokens_out=500,
        ...     cost_usd=0.002,
        ... )
    """
    # Mirror check_tool_gate's default so gate and completion resolve
    # to the same step name when the caller omits it.
    effective_step_name = f"tools/{tool_name}" if step_name is None else step_name

    await self.step_completed(
        step_name=effective_step_name,
        step_id=step_id,
        output=output,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        cost_usd=cost_usd,
    )

async def complete_workflow(self) -> None:
"""Mark the workflow as completed.

Expand Down
2 changes: 2 additions & 0 deletions axonflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3606,6 +3606,8 @@ async def list_workflows(
params.append(f"status={options.status.value}")
if options.source:
params.append(f"source={options.source.value}")
if options.trace_id:
params.append(f"trace_id={options.trace_id}")
if options.limit:
params.append(f"limit={options.limit}")
if options.offset:
Expand Down
16 changes: 16 additions & 0 deletions axonflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class CreateWorkflowRequest(BaseModel):
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata for the workflow"
)
trace_id: str | None = Field(
default=None,
description="External trace ID for correlation (Langsmith, Datadog, OTel)",
)


class CreateWorkflowResponse(BaseModel):
Expand All @@ -91,6 +95,15 @@ class CreateWorkflowResponse(BaseModel):
source: WorkflowSource = Field(..., description="Source orchestrator")
status: WorkflowStatus = Field(..., description="Current status (always 'in_progress' for new)")
created_at: datetime = Field(..., description="When the workflow was created")
trace_id: str | None = None


class ToolContext(BaseModel):
    """Tool-level context for per-tool governance within tool_call steps.

    Carried on StepGateRequest.tool_context so the backend can evaluate
    policies against the specific tool being invoked, not just the step.
    """

    # Name of the tool being invoked (required).
    tool_name: str
    # Free-form tool category; conventionally one of: function, mcp, api.
    tool_type: str | None = Field(default=None, description="Tool type: function, mcp, api")
    # Arguments passed to the tool; defaults to an empty dict per instance.
    tool_input: dict[str, Any] = Field(default_factory=dict)


class StepGateRequest(BaseModel):
Expand All @@ -105,6 +118,7 @@ class StepGateRequest(BaseModel):
)
model: str | None = Field(default=None, description="LLM model being used (if applicable)")
provider: str | None = Field(default=None, description="LLM provider (if applicable)")
tool_context: ToolContext | None = None


class StepGateResponse(BaseModel):
Expand Down Expand Up @@ -178,6 +192,7 @@ class WorkflowStatusResponse(BaseModel):
steps: list[WorkflowStepInfo] = Field(
default_factory=list, description="List of steps in the workflow"
)
trace_id: str | None = None

def is_terminal(self) -> bool:
"""Check if the workflow is in a terminal state (completed, aborted, or failed)."""
Expand All @@ -191,6 +206,7 @@ class ListWorkflowsOptions(BaseModel):

status: WorkflowStatus | None = Field(default=None, description="Filter by workflow status")
source: WorkflowSource | None = Field(default=None, description="Filter by source")
trace_id: str | None = Field(default=None, description="Filter by external trace ID")
limit: int = Field(default=50, ge=1, le=100, description="Maximum number of results to return")
offset: int = Field(default=0, ge=0, description="Offset for pagination")

Expand Down