From 45616d9fc9e94117bec8e9c9f41b0e4092ed2c83 Mon Sep 17 00:00:00 2001 From: Anurag Tiwari Date: Mon, 2 Mar 2026 15:47:55 -0800 Subject: [PATCH 1/2] Add guardrails client. Add a Guardrails client to the ADK that allows users to evaluate content against safety rails (jailbreak, content_moderation, sensitive_data). Guardrail evaluations are automatically captured as tool spans in the ADK trace when used inside @entrypoint. - Add gradient_adk/guardrails.py with Guardrails client, result types, and tracing span integration - Export Guardrails, GuardrailResult, GuardrailsError from __init__.py - Add 13 unit tests covering check(), error handling, and tracing - Update README with guardrails feature documentation and examples --- README.md | 58 +++++++ gradient_adk/__init__.py | 5 + gradient_adk/guardrails.py | 287 ++++++++++++++++++++++++++++++++ tests/guardrails_test.py | 328 +++++++++++++++++++++++++++++++++++++ 4 files changed, 678 insertions(+) create mode 100644 gradient_adk/guardrails.py create mode 100644 tests/guardrails_test.py diff --git a/README.md b/README.md index c64f00b..5a64b0c 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,13 @@ Building AI agents is challenging enough without worrying about observability, e - **Streaming Support**: Full support for streaming responses with trace capture - **Production Ready**: Designed for seamless deployment to DigitalOcean infrastructure +### 🛡️ Guardrails + +- **Built-in Safety**: Evaluate user inputs and AI outputs against content safety rails +- **Multiple Rail Types**: Jailbreak detection, content moderation, and sensitive data detection +- **Simple API**: Single `check()` method with clear pass/fail results +- **Automatic Tracing**: Guardrail evaluations are captured as spans in the ADK trace when used inside `@entrypoint` + ## Installation ```bash @@ -168,6 +175,56 @@ async def main(input: dict, context: RequestContext): yield chunk ``` +### Using Guardrails + +Check user inputs and AI outputs against safety rails before and after LLM calls: + +```python +from gradient_adk import entrypoint, RequestContext, Guardrails + +guardrails = Guardrails() + +@entrypoint +async def main(input: dict, context: RequestContext): + # Check user input before calling the LLM + result = await guardrails.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": input["prompt"]}], + ) + if not result.allowed: + return {"error": "Blocked", "violations": [v.message for v in result.violations]} + + response = await llm.generate(input["prompt"]) + + # Optionally check LLM output before returning + output_check = await guardrails.check( + rail_type="content_moderation", + messages=[{"role": "assistant", "content": response}], + evaluation_type="output", + ) + if not output_check.allowed: + return {"error": "Response blocked by content moderation"} + + return {"response": response} +``` + +The `check()` method returns a `GuardrailResult` with: + +| Field | Type | Description | +| ------------- | ------------------------ | --------------------------------------------- | +| `allowed` | `bool` | Whether the content passed the guardrail | +| `violations` | `list[GuardrailViolation]` | List of violations (empty if allowed) | +| `team_id` | `int` | Team ID associated with the request | +| `token_usage` | `TokenUsage` | Token consumption (`input_tokens`, `output_tokens`, `total_tokens`) | + +**Available rail types:** + +| Rail Type | Description | +| ---------------------- | ------------------------------------------------ | +| `jailbreak` | Detects prompt injection and jailbreak attempts | +| `content_moderation` | Detects harmful, violent, or inappropriate content | +| `sensitive_data` | Detects PII and sensitive information | + ## CLI Commands ### Agent Management @@ -349,6 +406,7 @@ The Gradient ADK is designed to work with any Python-based AI agent framework: - ✅ **LangChain** - Use trace decorators (`@trace_llm`, `@trace_tool`, `@trace_retriever`) for custom spans - ✅ **CrewAI** - Use trace decorators for agent and task execution - ✅ **Custom Frameworks** - Use trace decorators for any function +- ✅ **Guardrails** - Built-in safety checks for jailbreak, content moderation, and sensitive data detection ## Support diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py index 70ba0bf..903ec25 100644 --- a/gradient_adk/__init__.py +++ b/gradient_adk/__init__.py @@ -14,6 +14,7 @@ add_tool_span, add_agent_span, ) +from .guardrails import Guardrails, GuardrailResult, GuardrailsError __all__ = [ "entrypoint", @@ -26,6 +27,10 @@ "add_llm_span", "add_tool_span", "add_agent_span", + # Guardrails + "Guardrails", + "GuardrailResult", + "GuardrailsError", ] __version__ = "0.0.5" diff --git a/gradient_adk/guardrails.py b/gradient_adk/guardrails.py new file mode 100644 index 0000000..648a035 --- /dev/null +++ b/gradient_adk/guardrails.py @@ -0,0 +1,287 @@ +"""Guardrails client for evaluating content against safety rails. + +Provides a simple async client to call the DigitalOcean Guardrails service. +When used inside an ``@entrypoint``-decorated function, guardrail evaluations +are automatically captured as spans in the ADK trace. + +Example usage:: + + from gradient_adk import Guardrails + + guardrails = Guardrails() + + async def check_input(prompt: str): + result = await guardrails.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": prompt}], + ) + if not result.allowed: + raise ValueError(f"Blocked: {result.violations[0].message}") + return result +""" + +from __future__ import annotations + +import os +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import httpx + +from .runtime.helpers import get_tracker, _is_tracing_disabled +from .runtime.interfaces import NodeExecution + +_DEFAULT_TIMEOUT = 30.0 + + +@dataclass +class GuardrailViolation: + """A single guardrail violation.""" + + message: str + rule_name: str + + +@dataclass +class TokenUsage: + """Token consumption for a guardrail evaluation.""" + + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + + +@dataclass +class GuardrailResult: + """Result of a guardrail evaluation.""" + + allowed: bool + team_id: int + violations: List[GuardrailViolation] = field(default_factory=list) + token_usage: TokenUsage = field(default_factory=TokenUsage) + + +class GuardrailsError(Exception): + """Raised when a guardrails evaluation fails.""" + + def __init__(self, message: str, *, status_code: Optional[int] = None): + super().__init__(message) + self.status_code = status_code + + +class Guardrails: + """Client for the DigitalOcean Guardrails service. + + Evaluates content against safety rails (jailbreak, content_moderation, + sensitive_data). Authentication and service configuration are handled + automatically via environment variables. Guardrail evaluations are + captured as tool spans in the ADK trace. + """ + + def __init__(self) -> None: + self._base_url = os.environ.get("GUARDRAILS_URL", "") + self._timeout = _DEFAULT_TIMEOUT + + def _resolve_token(self) -> str: + token = os.environ.get("DIGITALOCEAN_API_TOKEN") + if not token: + raise GuardrailsError( + "DIGITALOCEAN_API_TOKEN environment variable is not set." + ) + return token + + def _resolve_url(self) -> str: + if not self._base_url: + raise GuardrailsError( + "GUARDRAILS_URL environment variable is not set." + ) + return self._base_url.rstrip("/") + + async def check( + self, + rail_type: str, + messages: List[Dict[str, str]], + *, + evaluation_type: str = "input", + ) -> GuardrailResult: + """Evaluate content against a guardrail. + + Args: + rail_type: Type of guardrail — ``"jailbreak"``, + ``"content_moderation"``, or ``"sensitive_data"``. + messages: Messages to evaluate, each with ``role`` and ``content``. + evaluation_type: ``"input"`` (default) to evaluate user messages + before LLM processing, or ``"output"`` to evaluate AI responses. + + Returns: + :class:`GuardrailResult` with ``allowed``, ``violations``, + ``team_id``, and ``token_usage``. + + Raises: + GuardrailsError: On authentication failure, invalid rail type, + or service unavailability. + + Example:: + + result = await guardrails.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello!"}], + ) + if result.allowed: + print("Content is safe") + else: + for v in result.violations: + print(f"Violation: {v.message} ({v.rule_name})") + """ + token = self._resolve_token() + url = self._resolve_url() + payload = { + "rail_type": rail_type, + "messages": messages, + "evaluation_type": evaluation_type, + } + + span = _start_guardrail_span(rail_type, payload) + start_ns = time.monotonic_ns() + + try: + result = await self._call(token, url, payload) + duration_ns = time.monotonic_ns() - start_ns + _end_guardrail_span(span, result, duration_ns) + return result + except Exception as exc: + duration_ns = time.monotonic_ns() - start_ns + _error_guardrail_span(span, exc, duration_ns) + raise + + async def _call( + self, token: str, url: str, payload: Dict[str, Any] + ) -> GuardrailResult: + async with httpx.AsyncClient(timeout=self._timeout) as client: + resp = await client.post( + url, + json=payload, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + ) + + if resp.status_code == 401: + body = resp.json() + raise GuardrailsError( + body.get("description", "Authentication failed"), + status_code=401, + ) + + if resp.status_code != 200: + try: + body = resp.json() + detail = body.get("detail", body.get("message", resp.text)) + except Exception: + detail = resp.text + raise GuardrailsError( + f"Guardrails service error ({resp.status_code}): {detail}", + status_code=resp.status_code, + ) + + body = resp.json() + violations = [ + GuardrailViolation(message=v["message"], rule_name=v["rule_name"]) + for v in body.get("violations", []) + ] + usage = body.get("token_usage", {}) + return GuardrailResult( + allowed=body["allowed"], + team_id=body["team_id"], + violations=violations, + token_usage=TokenUsage( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + ) + + +# --------------------------------------------------------------------------- +# Tracing integration +# --------------------------------------------------------------------------- + +def _start_guardrail_span( + rail_type: str, payload: Dict[str, Any] +) -> Optional[NodeExecution]: + if _is_tracing_disabled(): + return None + tracker = get_tracker() + if not tracker: + return None + + span = NodeExecution( + node_id=str(uuid.uuid4()), + node_name=f"guardrail:{rail_type}", + framework="guardrails", + start_time=datetime.now(timezone.utc), + inputs=payload, + metadata={ + "is_tool_call": True, + "is_programmatic": True, + "rail_type": rail_type, + }, + ) + tracker.on_node_start(span) + return span + + +def _end_guardrail_span( + span: Optional[NodeExecution], + result: GuardrailResult, + duration_ns: int, +) -> None: + if span is None: + return + tracker = get_tracker() + if not tracker: + return + + output = { + "allowed": result.allowed, + "team_id": result.team_id, + "violations": [ + {"message": v.message, "rule_name": v.rule_name} + for v in result.violations + ], + "token_usage": { + "input_tokens": result.token_usage.input_tokens, + "output_tokens": result.token_usage.output_tokens, + "total_tokens": result.token_usage.total_tokens, + }, + } + + meta = span.metadata or {} + meta["duration_ns"] = duration_ns + meta["guardrail_allowed"] = result.allowed + meta["guardrail_violations"] = len(result.violations) + span.metadata = meta + + tracker.on_node_end(span, output) + + +def _error_guardrail_span( + span: Optional[NodeExecution], + exc: Exception, + duration_ns: int, +) -> None: + if span is None: + return + tracker = get_tracker() + if not tracker: + return + + meta = span.metadata or {} + meta["duration_ns"] = duration_ns + span.metadata = meta + + tracker.on_node_error(span, exc) diff --git a/tests/guardrails_test.py b/tests/guardrails_test.py new file mode 100644 index 0000000..0847d6b --- /dev/null +++ b/tests/guardrails_test.py @@ -0,0 +1,328 @@ +"""Tests for the guardrails client module.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from gradient_adk.guardrails import ( + Guardrails, + GuardrailResult, + GuardrailsError, + GuardrailViolation, + TokenUsage, +) + +_TEST_URL = "https://test.guardrails.example.com" + + +@pytest.fixture(autouse=True) +def _set_guardrails_url(monkeypatch): + """Set GUARDRAILS_URL for all tests by default.""" + monkeypatch.setenv("GUARDRAILS_URL", _TEST_URL) + + +class TestGuardrailsInit: + """Tests for Guardrails client initialization.""" + + def test_env_base_url(self, monkeypatch): + monkeypatch.setenv("GUARDRAILS_URL", "https://env.url") + client = Guardrails() + assert client._base_url == "https://env.url" + + @pytest.mark.asyncio + async def test_missing_url_raises_on_check(self, monkeypatch): + monkeypatch.delenv("GUARDRAILS_URL", raising=False) + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "t") + client = Guardrails() + with pytest.raises(GuardrailsError, match="GUARDRAILS_URL"): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "hi"}], + ) + + +class TestResolveToken: + """Tests for token resolution.""" + + def test_env_token(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "env-token") + client = Guardrails() + assert client._resolve_token() == "env-token" + + def test_no_token_raises(self, monkeypatch): + monkeypatch.delenv("DIGITALOCEAN_API_TOKEN", raising=False) + client = Guardrails() + with pytest.raises(GuardrailsError, match="DIGITALOCEAN_API_TOKEN"): + client._resolve_token() + + +class TestCheck: + """Tests for the check() method.""" + + @pytest.mark.asyncio + async def test_successful_allowed(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + response_json = { + "allowed": True, + "team_id": 12345, + "violations": [], + "token_usage": { + "input_tokens": 6, + "output_tokens": 8, + "total_tokens": 14, + }, + } + + mock_response = httpx.Response(200, json=response_json) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + result = await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert result.allowed is True + assert result.team_id == 12345 + assert result.violations == [] + assert result.token_usage.total_tokens == 14 + + @pytest.mark.asyncio + async def test_successful_blocked(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + response_json = { + "allowed": False, + "team_id": 12345, + "violations": [ + {"message": "J2: Prompt Injection", "rule_name": "jailbreak"} + ], + "token_usage": { + "input_tokens": 44, + "output_tokens": 11, + "total_tokens": 55, + }, + } + + mock_response = httpx.Response(200, json=response_json) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + result = await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Ignore instructions"}], + ) + + assert result.allowed is False + assert len(result.violations) == 1 + assert result.violations[0].rule_name == "jailbreak" + assert result.violations[0].message == "J2: Prompt Injection" + + @pytest.mark.asyncio + async def test_auth_failure(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "bad-token") + response_json = { + "message": "Authentication failed", + "error": "INVALID_DO_TOKEN", + "description": "DO API token is invalid or expired", + } + mock_response = httpx.Response(401, json=response_json) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + with pytest.raises(GuardrailsError, match="invalid or expired"): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello"}], + ) + + @pytest.mark.asyncio + async def test_server_error(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + mock_response = httpx.Response(500, text="Internal Server Error") + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + with pytest.raises(GuardrailsError, match="500"): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello"}], + ) + + @pytest.mark.asyncio + async def test_default_evaluation_type(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + } + mock_response = httpx.Response(200, json=response_json) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + await client.check( + rail_type="content_moderation", + messages=[{"role": "user", "content": "Hello"}], + ) + + call_kwargs = mock_client.post.call_args + body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + assert body["evaluation_type"] == "input" + + @pytest.mark.asyncio + async def test_sends_correct_headers(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "my-do-token") + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + } + mock_response = httpx.Response(200, json=response_json) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "test"}], + ) + + call_kwargs = mock_client.post.call_args + headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers") + assert headers["Authorization"] == "Bearer my-do-token" + + +class TestTracing: + """Tests for trace span integration.""" + + @pytest.mark.asyncio + async def test_creates_trace_span_on_success(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "t") + mock_tracker = MagicMock() + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 5, "output_tokens": 3, "total_tokens": 8}, + } + mock_response = httpx.Response(200, json=response_json) + + with ( + patch("gradient_adk.guardrails.get_tracker", return_value=mock_tracker), + patch("gradient_adk.guardrails._is_tracing_disabled", return_value=False), + patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert mock_tracker.on_node_start.call_count == 1 + span = mock_tracker.on_node_start.call_args[0][0] + assert span.node_name == "guardrail:jailbreak" + assert span.framework == "guardrails" + assert span.metadata["is_tool_call"] is True + assert span.metadata["rail_type"] == "jailbreak" + + assert mock_tracker.on_node_end.call_count == 1 + end_output = mock_tracker.on_node_end.call_args[0][1] + assert end_output["allowed"] is True + + @pytest.mark.asyncio + async def test_creates_error_span_on_failure(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "bad") + mock_tracker = MagicMock() + mock_response = httpx.Response(401, json={ + "description": "token expired", + }) + + with ( + patch("gradient_adk.guardrails.get_tracker", return_value=mock_tracker), + patch("gradient_adk.guardrails._is_tracing_disabled", return_value=False), + patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + with pytest.raises(GuardrailsError): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert mock_tracker.on_node_start.call_count == 1 + assert mock_tracker.on_node_error.call_count == 1 + + @pytest.mark.asyncio + async def test_no_span_when_tracing_disabled(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "t") + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + } + mock_response = httpx.Response(200, json=response_json) + + with ( + patch("gradient_adk.guardrails._is_tracing_disabled", return_value=True), + patch("gradient_adk.guardrails.get_tracker") as mock_get_tracker, + patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + result = await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert result.allowed is True + mock_get_tracker.assert_not_called() From 7d140a3c0c650801b7cd8e35d8e0238fe22f18a2 Mon Sep 17 00:00:00 2001 From: Anurag Tiwari Date: Tue, 3 Mar 2026 14:08:07 -0800 Subject: [PATCH 2/2] Refactor code --- README.md | 20 +++--- gradient_adk/__init__.py | 4 +- gradient_adk/guardrails.py | 140 +++++++------------------------------ tests/guardrails_test.py | 93 ++++++++++-------------- 4 files changed, 75 insertions(+), 182 deletions(-) diff --git a/README.md b/README.md index 5a64b0c..f570ca0 100644 --- a/README.md +++ b/README.md @@ -191,8 +191,8 @@ async def main(input: dict, context: RequestContext): rail_type="jailbreak", messages=[{"role": "user", "content": input["prompt"]}], ) - if not result.allowed: - return {"error": "Blocked", "violations": [v.message for v in result.violations]} + if not result["allowed"]: + return {"error": "Blocked", "violations": result["violations"]} response = await llm.generate(input["prompt"]) @@ -202,20 +202,20 @@ async def main(input: dict, context: RequestContext): messages=[{"role": "assistant", "content": response}], evaluation_type="output", ) - if not output_check.allowed: + if not output_check["allowed"]: return {"error": "Response blocked by content moderation"} return {"response": response} ``` -The `check()` method returns a `GuardrailResult` with: +The `check()` method returns a dict with: -| Field | Type | Description | -| ------------- | ------------------------ | --------------------------------------------- | -| `allowed` | `bool` | Whether the content passed the guardrail | -| `violations` | `list[GuardrailViolation]` | List of violations (empty if allowed) | -| `team_id` | `int` | Team ID associated with the request | -| `token_usage` | `TokenUsage` | Token consumption (`input_tokens`, `output_tokens`, `total_tokens`) | +| Key | Type | Description | +| ------------- | ------------ | --------------------------------------------- | +| `allowed` | `bool` | Whether the content passed the guardrail | +| `violations` | `list[dict]` | List of violations, each with `message` and `rule_name` | +| `team_id` | `int` | Team ID associated with the request | +| `token_usage` | `dict` | Token consumption (`input_tokens`, `output_tokens`, `total_tokens`) | **Available rail types:** diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py index 903ec25..4883ed1 100644 --- a/gradient_adk/__init__.py +++ b/gradient_adk/__init__.py @@ -14,7 +14,7 @@ add_tool_span, add_agent_span, ) -from .guardrails import Guardrails, GuardrailResult, GuardrailsError +from .guardrails import Guardrails __all__ = [ "entrypoint", @@ -29,8 +29,6 @@ "add_agent_span", # Guardrails "Guardrails", - "GuardrailResult", - "GuardrailsError", ] __version__ = "0.0.5" diff --git a/gradient_adk/guardrails.py b/gradient_adk/guardrails.py index 648a035..39aa1ba 100644 --- a/gradient_adk/guardrails.py +++ b/gradient_adk/guardrails.py @@ -15,8 +15,8 @@ async def check_input(prompt: str): rail_type="jailbreak", messages=[{"role": "user", "content": prompt}], ) - if not result.allowed: - raise ValueError(f"Blocked: {result.violations[0].message}") + if not result["allowed"]: + raise ValueError(f"Blocked: {result['violations'][0]['message']}") return result """ @@ -25,7 +25,6 @@ async def check_input(prompt: str): import os import time import uuid -from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -34,79 +33,37 @@ async def check_input(prompt: str): from .runtime.helpers import get_tracker, _is_tracing_disabled from .runtime.interfaces import NodeExecution +_GUARDRAILS_ENDPOINT = "https://guardrails.do-ai.run/v2/rail" _DEFAULT_TIMEOUT = 30.0 -@dataclass -class GuardrailViolation: - """A single guardrail violation.""" - - message: str - rule_name: str - - -@dataclass -class TokenUsage: - """Token consumption for a guardrail evaluation.""" - - input_tokens: int = 0 - output_tokens: int = 0 - total_tokens: int = 0 - - -@dataclass -class GuardrailResult: - """Result of a guardrail evaluation.""" - - allowed: bool - team_id: int - violations: List[GuardrailViolation] = field(default_factory=list) - token_usage: TokenUsage = field(default_factory=TokenUsage) - - -class GuardrailsError(Exception): - """Raised when a guardrails evaluation fails.""" - - def __init__(self, message: str, *, status_code: Optional[int] = None): - super().__init__(message) - self.status_code = status_code - - class Guardrails: """Client for the DigitalOcean Guardrails service. Evaluates content against safety rails (jailbreak, content_moderation, - sensitive_data). Authentication and service configuration are handled - automatically via environment variables. Guardrail evaluations are - captured as tool spans in the ADK trace. + sensitive_data). When used inside an ``@entrypoint`` function, guardrail + evaluations are automatically captured as tool spans in the ADK trace. """ def __init__(self) -> None: - self._base_url = os.environ.get("GUARDRAILS_URL", "") + self._endpoint = _GUARDRAILS_ENDPOINT self._timeout = _DEFAULT_TIMEOUT def _resolve_token(self) -> str: token = os.environ.get("DIGITALOCEAN_API_TOKEN") if not token: - raise GuardrailsError( + raise RuntimeError( "DIGITALOCEAN_API_TOKEN environment variable is not set." ) return token - def _resolve_url(self) -> str: - if not self._base_url: - raise GuardrailsError( - "GUARDRAILS_URL environment variable is not set." - ) - return self._base_url.rstrip("/") - async def check( self, rail_type: str, messages: List[Dict[str, str]], *, evaluation_type: str = "input", - ) -> GuardrailResult: + ) -> Dict[str, Any]: """Evaluate content against a guardrail. Args: @@ -117,12 +74,14 @@ async def check( before LLM processing, or ``"output"`` to evaluate AI responses. Returns: - :class:`GuardrailResult` with ``allowed``, ``violations``, - ``team_id``, and ``token_usage``. + A dict with ``allowed`` (bool), ``team_id`` (int), + ``violations`` (list of dicts with ``message`` and ``rule_name``), + and ``token_usage`` (dict with ``input_tokens``, ``output_tokens``, + ``total_tokens``). Raises: - GuardrailsError: On authentication failure, invalid rail type, - or service unavailability. + RuntimeError: If ``DIGITALOCEAN_API_TOKEN`` is not set. + httpx.HTTPStatusError: On non-200 responses from the service. Example:: @@ -130,14 +89,13 @@ async def check( rail_type="jailbreak", messages=[{"role": "user", "content": "Hello!"}], ) - if result.allowed: + if result["allowed"]: print("Content is safe") else: - for v in result.violations: - print(f"Violation: {v.message} ({v.rule_name})") + for v in result["violations"]: + print(f"Violation: {v['message']} ({v['rule_name']})") """ token = self._resolve_token() - url = self._resolve_url() payload = { "rail_type": rail_type, "messages": messages, @@ -148,7 +106,7 @@ async def check( start_ns = time.monotonic_ns() try: - result = await self._call(token, url, payload) + result = await self._call(token, self._endpoint, payload) duration_ns = time.monotonic_ns() - start_ns _end_guardrail_span(span, result, duration_ns) return result @@ -159,7 +117,7 @@ async def check( async def _call( self, token: str, url: str, payload: Dict[str, Any] - ) -> GuardrailResult: + ) -> Dict[str, Any]: async with httpx.AsyncClient(timeout=self._timeout) as client: resp = await client.post( url, @@ -170,40 +128,8 @@ async def _call( }, ) - if resp.status_code == 401: - body = resp.json() - raise GuardrailsError( - body.get("description", "Authentication failed"), - status_code=401, - ) - - if resp.status_code != 200: - try: - body = resp.json() - detail = body.get("detail", body.get("message", resp.text)) - except Exception: - detail = resp.text - raise GuardrailsError( - f"Guardrails service error ({resp.status_code}): {detail}", - status_code=resp.status_code, - ) - - body = resp.json() - violations = [ - GuardrailViolation(message=v["message"], rule_name=v["rule_name"]) - for v in body.get("violations", []) - ] - usage = body.get("token_usage", {}) - return GuardrailResult( - allowed=body["allowed"], - team_id=body["team_id"], - violations=violations, - token_usage=TokenUsage( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - ), - ) + resp.raise_for_status() + return resp.json() # --------------------------------------------------------------------------- @@ -237,7 +163,7 @@ def _start_guardrail_span( def _end_guardrail_span( span: Optional[NodeExecution], - result: GuardrailResult, + result: Dict[str, Any], duration_ns: int, ) -> None: if span is None: @@ -246,27 +172,15 @@ def _end_guardrail_span( if not tracker: return - output = { - "allowed": result.allowed, - "team_id": result.team_id, - "violations": [ - {"message": v.message, "rule_name": v.rule_name} - for v in result.violations - ], - "token_usage": { - "input_tokens": result.token_usage.input_tokens, - "output_tokens": result.token_usage.output_tokens, - "total_tokens": result.token_usage.total_tokens, - }, - } - meta = span.metadata or {} meta["duration_ns"] = duration_ns - meta["guardrail_allowed"] = result.allowed - meta["guardrail_violations"] = len(result.violations) + meta["guardrail_allowed"] = result.get("allowed") + meta["guardrail_violations"] = len(result.get("violations", [])) + token_usage = result.get("token_usage", {}) + meta["guardrail_total_tokens"] = token_usage.get("total_tokens", 0) span.metadata = meta - tracker.on_node_end(span, output) + tracker.on_node_end(span, result) def _error_guardrail_span( diff --git a/tests/guardrails_test.py b/tests/guardrails_test.py index 0847d6b..1d6ec4d 100644 --- a/tests/guardrails_test.py +++ b/tests/guardrails_test.py @@ -5,41 +5,17 @@ import httpx import pytest -from gradient_adk.guardrails import ( - Guardrails, - GuardrailResult, - GuardrailsError, - GuardrailViolation, - TokenUsage, -) +from gradient_adk.guardrails import Guardrails _TEST_URL = "https://test.guardrails.example.com" -@pytest.fixture(autouse=True) -def _set_guardrails_url(monkeypatch): - """Set GUARDRAILS_URL for all tests by default.""" - monkeypatch.setenv("GUARDRAILS_URL", _TEST_URL) - - class TestGuardrailsInit: """Tests for Guardrails client initialization.""" - def test_env_base_url(self, monkeypatch): - monkeypatch.setenv("GUARDRAILS_URL", "https://env.url") + def test_default_endpoint(self): client = Guardrails() - assert client._base_url == "https://env.url" - - @pytest.mark.asyncio - async def test_missing_url_raises_on_check(self, monkeypatch): - monkeypatch.delenv("GUARDRAILS_URL", raising=False) - monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "t") - client = Guardrails() - with pytest.raises(GuardrailsError, match="GUARDRAILS_URL"): - await client.check( - rail_type="jailbreak", - messages=[{"role": "user", "content": "hi"}], - ) + assert "guardrails" in client._endpoint class TestResolveToken: @@ -53,7 +29,7 @@ def test_env_token(self, monkeypatch): def test_no_token_raises(self, monkeypatch): monkeypatch.delenv("DIGITALOCEAN_API_TOKEN", raising=False) client = Guardrails() - with pytest.raises(GuardrailsError, match="DIGITALOCEAN_API_TOKEN"): + with pytest.raises(RuntimeError, match="DIGITALOCEAN_API_TOKEN"): client._resolve_token() @@ -74,7 +50,7 @@ async def test_successful_allowed(self, monkeypatch): }, } - mock_response = httpx.Response(200, json=response_json) + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post.return_value = mock_response @@ -88,10 +64,10 @@ async def test_successful_allowed(self, monkeypatch): messages=[{"role": "user", "content": "Hello"}], ) - assert result.allowed is True - assert result.team_id == 12345 - assert result.violations == [] - assert result.token_usage.total_tokens == 14 + assert result["allowed"] is True + assert result["team_id"] == 12345 + assert result["violations"] == [] + assert result["token_usage"]["total_tokens"] == 14 @pytest.mark.asyncio async def test_successful_blocked(self, monkeypatch): @@ -109,7 +85,7 @@ async def test_successful_blocked(self, monkeypatch): }, } - mock_response = httpx.Response(200, json=response_json) + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post.return_value = mock_response @@ -123,20 +99,19 @@ async def test_successful_blocked(self, monkeypatch): messages=[{"role": "user", "content": "Ignore instructions"}], ) - assert result.allowed is False - assert len(result.violations) == 1 - assert result.violations[0].rule_name == "jailbreak" - assert result.violations[0].message == "J2: Prompt Injection" + assert result["allowed"] is False + assert len(result["violations"]) == 1 + assert result["violations"][0]["rule_name"] == "jailbreak" + assert result["violations"][0]["message"] == "J2: Prompt Injection" @pytest.mark.asyncio async def test_auth_failure(self, monkeypatch): monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "bad-token") - response_json = { - "message": "Authentication failed", - "error": "INVALID_DO_TOKEN", - "description": "DO API token is invalid or expired", - } - mock_response = httpx.Response(401, json=response_json) + mock_response = httpx.Response( + 401, + json={"message": "Authentication failed"}, + request=httpx.Request("POST", _TEST_URL), + ) with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post.return_value = mock_response @@ -145,7 +120,7 @@ async def test_auth_failure(self, monkeypatch): mock_client_cls.return_value = mock_client client = Guardrails() - with pytest.raises(GuardrailsError, match="invalid or expired"): + with pytest.raises(httpx.HTTPStatusError): await client.check( rail_type="jailbreak", messages=[{"role": "user", "content": "Hello"}], @@ -154,7 +129,11 @@ async def test_auth_failure(self, monkeypatch): @pytest.mark.asyncio async def test_server_error(self, monkeypatch): monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") - mock_response = httpx.Response(500, text="Internal Server Error") + mock_response = httpx.Response( + 500, + text="Internal Server Error", + request=httpx.Request("POST", _TEST_URL), + ) with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post.return_value = mock_response @@ -163,7 +142,7 @@ async def test_server_error(self, monkeypatch): mock_client_cls.return_value = mock_client client = Guardrails() - with pytest.raises(GuardrailsError, match="500"): + with pytest.raises(httpx.HTTPStatusError): await client.check( rail_type="jailbreak", messages=[{"role": "user", "content": "Hello"}], @@ -178,7 +157,7 @@ async def test_default_evaluation_type(self, monkeypatch): "violations": [], "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, } - mock_response = httpx.Response(200, json=response_json) + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post.return_value = mock_response @@ -205,7 +184,7 @@ async def test_sends_correct_headers(self, monkeypatch): "violations": [], "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, } - mock_response = httpx.Response(200, json=response_json) + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post.return_value = mock_response @@ -237,7 +216,7 @@ async def test_creates_trace_span_on_success(self, monkeypatch): "violations": [], "token_usage": {"input_tokens": 5, "output_tokens": 3, "total_tokens": 8}, } - mock_response = httpx.Response(200, json=response_json) + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) with ( patch("gradient_adk.guardrails.get_tracker", return_value=mock_tracker), @@ -271,9 +250,11 @@ async def test_creates_trace_span_on_success(self, monkeypatch): async def test_creates_error_span_on_failure(self, monkeypatch): monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "bad") mock_tracker = MagicMock() - mock_response = httpx.Response(401, json={ - "description": "token expired", - }) + mock_response = httpx.Response( + 401, + json={"description": "token expired"}, + request=httpx.Request("POST", _TEST_URL), + ) with ( patch("gradient_adk.guardrails.get_tracker", return_value=mock_tracker), @@ -287,7 +268,7 @@ async def test_creates_error_span_on_failure(self, monkeypatch): mock_client_cls.return_value = mock_client client = Guardrails() - with pytest.raises(GuardrailsError): + with pytest.raises(httpx.HTTPStatusError): await client.check( rail_type="jailbreak", messages=[{"role": "user", "content": "Hi"}], @@ -305,7 +286,7 @@ async def test_no_span_when_tracing_disabled(self, monkeypatch): "violations": [], "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, } - mock_response = httpx.Response(200, json=response_json) + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) with ( patch("gradient_adk.guardrails._is_tracing_disabled", return_value=True), @@ -324,5 +305,5 @@ async def test_no_span_when_tracing_disabled(self, monkeypatch): messages=[{"role": "user", "content": "Hi"}], ) - assert result.allowed is True + assert result["allowed"] is True mock_get_tracker.assert_not_called()