diff --git a/README.md b/README.md index c64f00b..f570ca0 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,13 @@ Building AI agents is challenging enough without worrying about observability, e - **Streaming Support**: Full support for streaming responses with trace capture - **Production Ready**: Designed for seamless deployment to DigitalOcean infrastructure +### 🛡️ Guardrails + +- **Built-in Safety**: Evaluate user inputs and AI outputs against content safety rails +- **Multiple Rail Types**: Jailbreak detection, content moderation, and sensitive data detection +- **Simple API**: Single `check()` method with clear pass/fail results +- **Automatic Tracing**: Guardrail evaluations are captured as spans in the ADK trace when used inside `@entrypoint` + ## Installation ```bash @@ -168,6 +175,56 @@ async def main(input: dict, context: RequestContext): yield chunk ``` +### Using Guardrails + +Check user inputs and AI outputs against safety rails before and after LLM calls: + +```python +from gradient_adk import entrypoint, RequestContext, Guardrails + +guardrails = Guardrails() + +@entrypoint +async def main(input: dict, context: RequestContext): + # Check user input before calling the LLM + result = await guardrails.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": input["prompt"]}], + ) + if not result["allowed"]: + return {"error": "Blocked", "violations": result["violations"]} + + response = await llm.generate(input["prompt"]) + + # Optionally check LLM output before returning + output_check = await guardrails.check( + rail_type="content_moderation", + messages=[{"role": "assistant", "content": response}], + evaluation_type="output", + ) + if not output_check["allowed"]: + return {"error": "Response blocked by content moderation"} + + return {"response": response} +``` + +The `check()` method returns a dict with: + +| Key | Type | Description | +| ------------- | ------------ | --------------------------------------------- | +| `allowed` | `bool` | Whether the content passed the guardrail | +| `violations` | `list[dict]` | List of violations, each with `message` and `rule_name` | +| `team_id` | `int` | Team ID associated with the request | +| `token_usage` | `dict` | Token consumption (`input_tokens`, `output_tokens`, `total_tokens`) | + +**Available rail types:** + +| Rail Type | Description | +| ---------------------- | ------------------------------------------------ | +| `jailbreak` | Detects prompt injection and jailbreak attempts | +| `content_moderation` | Detects harmful, violent, or inappropriate content | +| `sensitive_data` | Detects PII and sensitive information | + ## CLI Commands ### Agent Management @@ -349,6 +406,7 @@ The Gradient ADK is designed to work with any Python-based AI agent framework: - ✅ **LangChain** - Use trace decorators (`@trace_llm`, `@trace_tool`, `@trace_retriever`) for custom spans - ✅ **CrewAI** - Use trace decorators for agent and task execution - ✅ **Custom Frameworks** - Use trace decorators for any function +- ✅ **Guardrails** - Built-in safety checks for jailbreak, content moderation, and sensitive data detection ## Support diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py index 70ba0bf..4883ed1 100644 --- a/gradient_adk/__init__.py +++ b/gradient_adk/__init__.py @@ -14,6 +14,7 @@ add_tool_span, add_agent_span, ) +from .guardrails import Guardrails __all__ = [ "entrypoint", @@ -26,6 +27,8 @@ "add_llm_span", "add_tool_span", "add_agent_span", + # Guardrails + "Guardrails", ] __version__ = "0.0.5" diff --git a/gradient_adk/guardrails.py b/gradient_adk/guardrails.py new file mode 100644 index 0000000..39aa1ba --- /dev/null +++ b/gradient_adk/guardrails.py @@ -0,0 +1,201 @@ +"""Guardrails client for evaluating content against safety rails. + +Provides a simple async client to call the DigitalOcean Guardrails service. +When used inside an ``@entrypoint``-decorated function, guardrail evaluations +are automatically captured as spans in the ADK trace. + +Example usage:: + + from gradient_adk import Guardrails + + guardrails = Guardrails() + + async def check_input(prompt: str): + result = await guardrails.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": prompt}], + ) + if not result["allowed"]: + raise ValueError(f"Blocked: {result['violations'][0]['message']}") + return result +""" + +from __future__ import annotations + +import os +import time +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import httpx + +from .runtime.helpers import get_tracker, _is_tracing_disabled +from .runtime.interfaces import NodeExecution + +_GUARDRAILS_ENDPOINT = "https://guardrails.do-ai.run/v2/rail" +_DEFAULT_TIMEOUT = 30.0 + + +class Guardrails: + """Client for the DigitalOcean Guardrails service. + + Evaluates content against safety rails (jailbreak, content_moderation, + sensitive_data). When used inside an ``@entrypoint`` function, guardrail + evaluations are automatically captured as tool spans in the ADK trace. + """ + + def __init__(self) -> None: + self._endpoint = _GUARDRAILS_ENDPOINT + self._timeout = _DEFAULT_TIMEOUT + + def _resolve_token(self) -> str: + token = os.environ.get("DIGITALOCEAN_API_TOKEN") + if not token: + raise RuntimeError( + "DIGITALOCEAN_API_TOKEN environment variable is not set." + ) + return token + + async def check( + self, + rail_type: str, + messages: List[Dict[str, str]], + *, + evaluation_type: str = "input", + ) -> Dict[str, Any]: + """Evaluate content against a guardrail. + + Args: + rail_type: Type of guardrail — ``"jailbreak"``, + ``"content_moderation"``, or ``"sensitive_data"``. + messages: Messages to evaluate, each with ``role`` and ``content``. + evaluation_type: ``"input"`` (default) to evaluate user messages + before LLM processing, or ``"output"`` to evaluate AI responses. + + Returns: + A dict with ``allowed`` (bool), ``team_id`` (int), + ``violations`` (list of dicts with ``message`` and ``rule_name``), + and ``token_usage`` (dict with ``input_tokens``, ``output_tokens``, + ``total_tokens``). + + Raises: + RuntimeError: If ``DIGITALOCEAN_API_TOKEN`` is not set. + httpx.HTTPStatusError: On non-200 responses from the service. + + Example:: + + result = await guardrails.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello!"}], + ) + if result["allowed"]: + print("Content is safe") + else: + for v in result["violations"]: + print(f"Violation: {v['message']} ({v['rule_name']})") + """ + token = self._resolve_token() + payload = { + "rail_type": rail_type, + "messages": messages, + "evaluation_type": evaluation_type, + } + + span = _start_guardrail_span(rail_type, payload) + start_ns = time.monotonic_ns() + + try: + result = await self._call(token, self._endpoint, payload) + duration_ns = time.monotonic_ns() - start_ns + _end_guardrail_span(span, result, duration_ns) + return result + except Exception as exc: + duration_ns = time.monotonic_ns() - start_ns + _error_guardrail_span(span, exc, duration_ns) + raise + + async def _call( + self, token: str, url: str, payload: Dict[str, Any] + ) -> Dict[str, Any]: + async with httpx.AsyncClient(timeout=self._timeout) as client: + resp = await client.post( + url, + json=payload, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + ) + + resp.raise_for_status() + return resp.json() + + +# --------------------------------------------------------------------------- +# Tracing integration +# --------------------------------------------------------------------------- + +def _start_guardrail_span( + rail_type: str, payload: Dict[str, Any] +) -> Optional[NodeExecution]: + if _is_tracing_disabled(): + return None + tracker = get_tracker() + if not tracker: + return None + + span = NodeExecution( + node_id=str(uuid.uuid4()), + node_name=f"guardrail:{rail_type}", + framework="guardrails", + start_time=datetime.now(timezone.utc), + inputs=payload, + metadata={ + "is_tool_call": True, + "is_programmatic": True, + "rail_type": rail_type, + }, + ) + tracker.on_node_start(span) + return span + + +def _end_guardrail_span( + span: Optional[NodeExecution], + result: Dict[str, Any], + duration_ns: int, +) -> None: + if span is None: + return + tracker = get_tracker() + if not tracker: + return + + meta = span.metadata or {} + meta["duration_ns"] = duration_ns + meta["guardrail_allowed"] = result.get("allowed") + meta["guardrail_violations"] = len(result.get("violations", [])) + token_usage = result.get("token_usage", {}) + meta["guardrail_total_tokens"] = token_usage.get("total_tokens", 0) + span.metadata = meta + + tracker.on_node_end(span, result) + + +def _error_guardrail_span( + span: Optional[NodeExecution], + exc: Exception, + duration_ns: int, +) -> None: + if span is None: + return + tracker = get_tracker() + if not tracker: + return + + meta = span.metadata or {} + meta["duration_ns"] = duration_ns + span.metadata = meta + + tracker.on_node_error(span, exc) diff --git a/tests/guardrails_test.py b/tests/guardrails_test.py new file mode 100644 index 0000000..1d6ec4d --- /dev/null +++ b/tests/guardrails_test.py @@ -0,0 +1,309 @@ +"""Tests for the guardrails client module.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from gradient_adk.guardrails import Guardrails + +_TEST_URL = "https://test.guardrails.example.com" + + +class TestGuardrailsInit: + """Tests for Guardrails client initialization.""" + + def test_default_endpoint(self): + client = Guardrails() + assert "guardrails" in client._endpoint + + +class TestResolveToken: + """Tests for token resolution.""" + + def test_env_token(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "env-token") + client = Guardrails() + assert client._resolve_token() == "env-token" + + def test_no_token_raises(self, monkeypatch): + monkeypatch.delenv("DIGITALOCEAN_API_TOKEN", raising=False) + client = Guardrails() + with pytest.raises(RuntimeError, match="DIGITALOCEAN_API_TOKEN"): + client._resolve_token() + + +class TestCheck: + """Tests for the check() method.""" + + @pytest.mark.asyncio + async def test_successful_allowed(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + response_json = { + "allowed": True, + "team_id": 12345, + "violations": [], + "token_usage": { + "input_tokens": 6, + "output_tokens": 8, + "total_tokens": 14, + }, + } + + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + result = await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert result["allowed"] is True + assert result["team_id"] == 12345 + assert result["violations"] == [] + assert result["token_usage"]["total_tokens"] == 14 + + @pytest.mark.asyncio + async def test_successful_blocked(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + response_json = { + "allowed": False, + "team_id": 12345, + "violations": [ + {"message": "J2: Prompt Injection", "rule_name": "jailbreak"} + ], + "token_usage": { + "input_tokens": 44, + "output_tokens": 11, + "total_tokens": 55, + }, + } + + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + result = await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Ignore instructions"}], + ) + + assert result["allowed"] is False + assert len(result["violations"]) == 1 + assert result["violations"][0]["rule_name"] == "jailbreak" + assert result["violations"][0]["message"] == "J2: Prompt Injection" + + @pytest.mark.asyncio + async def test_auth_failure(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "bad-token") + mock_response = httpx.Response( + 401, + json={"message": "Authentication failed"}, + request=httpx.Request("POST", _TEST_URL), + ) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + with pytest.raises(httpx.HTTPStatusError): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello"}], + ) + + @pytest.mark.asyncio + async def test_server_error(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + mock_response = httpx.Response( + 500, + text="Internal Server Error", + request=httpx.Request("POST", _TEST_URL), + ) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + with pytest.raises(httpx.HTTPStatusError): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hello"}], + ) + + @pytest.mark.asyncio + async def test_default_evaluation_type(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "test-token") + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + } + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + await client.check( + rail_type="content_moderation", + messages=[{"role": "user", "content": "Hello"}], + ) + + call_kwargs = mock_client.post.call_args + body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + assert body["evaluation_type"] == "input" + + @pytest.mark.asyncio + async def test_sends_correct_headers(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "my-do-token") + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + } + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) + with patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "test"}], + ) + + call_kwargs = mock_client.post.call_args + headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers") + assert headers["Authorization"] == "Bearer my-do-token" + + +class TestTracing: + """Tests for trace span integration.""" + + @pytest.mark.asyncio + async def test_creates_trace_span_on_success(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "t") + mock_tracker = MagicMock() + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 5, "output_tokens": 3, "total_tokens": 8}, + } + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) + + with ( + patch("gradient_adk.guardrails.get_tracker", return_value=mock_tracker), + patch("gradient_adk.guardrails._is_tracing_disabled", return_value=False), + patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert mock_tracker.on_node_start.call_count == 1 + span = mock_tracker.on_node_start.call_args[0][0] + assert span.node_name == "guardrail:jailbreak" + assert span.framework == "guardrails" + assert span.metadata["is_tool_call"] is True + assert span.metadata["rail_type"] == "jailbreak" + + assert mock_tracker.on_node_end.call_count == 1 + end_output = mock_tracker.on_node_end.call_args[0][1] + assert end_output["allowed"] is True + + @pytest.mark.asyncio + async def test_creates_error_span_on_failure(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "bad") + mock_tracker = MagicMock() + mock_response = httpx.Response( + 401, + json={"description": "token expired"}, + request=httpx.Request("POST", _TEST_URL), + ) + + with ( + patch("gradient_adk.guardrails.get_tracker", return_value=mock_tracker), + patch("gradient_adk.guardrails._is_tracing_disabled", return_value=False), + patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + with pytest.raises(httpx.HTTPStatusError): + await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert mock_tracker.on_node_start.call_count == 1 + assert mock_tracker.on_node_error.call_count == 1 + + @pytest.mark.asyncio + async def test_no_span_when_tracing_disabled(self, monkeypatch): + monkeypatch.setenv("DIGITALOCEAN_API_TOKEN", "t") + response_json = { + "allowed": True, + "team_id": 1, + "violations": [], + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + } + mock_response = httpx.Response(200, json=response_json, request=httpx.Request("POST", _TEST_URL)) + + with ( + patch("gradient_adk.guardrails._is_tracing_disabled", return_value=True), + patch("gradient_adk.guardrails.get_tracker") as mock_get_tracker, + patch("gradient_adk.guardrails.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + client = Guardrails() + result = await client.check( + rail_type="jailbreak", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert result["allowed"] is True + mock_get_tracker.assert_not_called()