From 8dfde4fe40c2de611d990ee38f3fb25ce755ff11 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Wed, 4 Feb 2026 04:51:18 +0000 Subject: [PATCH 1/9] Fix quality --- .env | 54 +++ .gitignore | 30 +- Makefile | 139 ++++++- README.md | 15 +- VERSION | 1 + claude_code_api/api/chat.py | 92 +++-- claude_code_api/api/projects.py | 7 +- claude_code_api/config/models.json | 45 +++ claude_code_api/core/claude_manager.py | 10 +- claude_code_api/core/config.py | 51 ++- claude_code_api/core/database.py | 3 +- claude_code_api/core/session_manager.py | 3 +- claude_code_api/main.py | 58 ++- claude_code_api/models/claude.py | 124 +++---- claude_code_api/models/openai.py | 66 +++- claude_code_api/utils/parser.py | 103 ++++-- claude_code_api/utils/streaming.py | 194 +++++----- pyproject.toml | 3 + scripts/record_claude_fixture.py | 110 ++++++ scripts/run-sonar-cloud.sh | 82 +++++ scripts/upload-sbom.sh | 98 +++++ scripts/vault-helper.sh | 347 ++++++++++++++++++ sonar-project.properties | 29 ++ tests/conftest.py | 61 ++- tests/fixtures/claude_stream_simple.jsonl | 3 + tests/fixtures/claude_stream_tool_calls.jsonl | 5 + tests/fixtures/index.json | 17 + tests/test_api.sh | 12 +- tests/test_claude_working.py | 144 ++------ tests/test_e2e_live_api.py | 87 +++++ tests/test_end_to_end.py | 104 ++++-- tests/test_openapi.py | 22 ++ tests/test_real_api.py | 51 +-- 33 files changed, 1737 insertions(+), 433 deletions(-) create mode 100644 .env create mode 100644 VERSION create mode 100644 claude_code_api/config/models.json create mode 100644 scripts/record_claude_fixture.py create mode 100755 scripts/run-sonar-cloud.sh create mode 100755 scripts/upload-sbom.sh create mode 100755 scripts/vault-helper.sh create mode 100644 sonar-project.properties create mode 100644 tests/fixtures/claude_stream_simple.jsonl create mode 100644 tests/fixtures/claude_stream_tool_calls.jsonl create mode 100644 tests/fixtures/index.json create mode 100644 tests/test_e2e_live_api.py create mode 100644 tests/test_openapi.py diff 
--git a/.env b/.env new file mode 100644 index 0000000..c1d804b --- /dev/null +++ b/.env @@ -0,0 +1,54 @@ +# This file is committed to git; it must not contain secrets. + +#!/usr/bin/env bash + +# Vault environment loader wrapper +# Usage: source .env.vault (never run directly) + +set -o pipefail + +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + echo "This script must be sourced: source .env.vault" >&2 + exit 1 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +HELPER_PATH="${VAULT_HELPER_PATH:-${SCRIPT_DIR}/scripts/vault-helper.sh}" + +if [[ ! -r "$HELPER_PATH" ]]; then + echo "Vault helper not found at $HELPER_PATH" >&2 + return 1 +fi + +# shellcheck source=./scripts/vault-helper.sh +source "$HELPER_PATH" + +DEFAULT_VAULT_SECRET_DEFS=$'kv/Sonarqube/sonarqube|SONAR_TOKEN=SONAR_TOKEN SONAR_TOKEN=sonar_token SONAR_TOKEN=token\nkv/dependencytrack|DTRACK_API_KEY=DTRACK_API_KEY DTRACK_API_KEY=api_key DTRACK_API_KEY=token' +DEFAULT_VAULT_REQUIRED_VARS="SONAR_TOKEN DTRACK_API_KEY" + +if [[ -z "${VAULT_TOKEN:-}" ]]; then + token_candidates=() + if [[ -n "${VAULT_TOKEN_FILE:-}" ]]; then + token_candidates+=("$VAULT_TOKEN_FILE") + fi + token_candidates+=("${HOME}/.vault-token" "/home/vscode/.vault-token" "/root/.vault-token") + for token_path in "${token_candidates[@]}"; do + if [[ -r "$token_path" ]]; then + VAULT_TOKEN_FILE="$token_path" + export VAULT_TOKEN_FILE + break + fi + done +fi + +SECRET_DEFS="${VAULT_SECRET_PATHS:-$DEFAULT_VAULT_SECRET_DEFS}" +REQUIRED_VARS="${VAULT_REQUIRED_VARS:-$DEFAULT_VAULT_REQUIRED_VARS}" + +vault_helper::load_from_definitions "$SECRET_DEFS" "$REQUIRED_VARS" "$VAULT_TOKEN_FILE" + +# Commented out for CI/automated testing +# SONAR_TOKEN="" +DTR_PROJECT_KEY= +# DTRACK_API_KEY="" +DTRACK_PROJECT=sonarqube-mcp +DTRACK_PROJECT_VERSION=main diff --git a/.gitignore b/.gitignore index 8bf1c53..8d917ca 100644 --- a/.gitignore +++ b/.gitignore @@ -124,7 +124,8 @@ celerybeat.pid *.sage.py # Environments -.env +.env.vault +.env.cloud .venv env/ venv/ @@ 
-132,6 +133,31 @@ ENV/ env.bak/ venv.bak/ +# SonarQube scanner working directory +dist/quality/sonar/scannerwork/ + +# Quality and Security Artifacts +dist/quality/ +dist/security/ +dist/artifacts/ + +# Gitleaks reports +*.gitleaks-report.json +gitleaks-report.json + +# SBOM reports +*.sbom.json +sbom.json +cyclonedx-*.json + +# Renovate reports +renovate-report.json +renovate-summary-*.txt + +# Semgrep reports +semgrep-report*.json +semgrep-report*.sarif + # Spyder project settings .spyderproject .spyproject @@ -328,4 +354,4 @@ tests/__pycache__/ docs/start.md docs/typescript-translation-plan.md -# Note: test scripts in tests/ directory should be tracked in git \ No newline at end of file +# Note: test scripts in tests/ directory should be tracked in git diff --git a/Makefile b/Makefile index 6425ef8..2ec321c 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,32 @@ # Claude Code API - Simple & Working +SHELL := /bin/bash +.DEFAULT_GOAL := help + +# Directory structure +ARTIFACTS_DIR ?= dist/artifacts +QUALITY_DIR ?= dist/quality +SECURITY_DIR ?= dist/security +TESTS_DIR ?= dist/tests + +BUILD_DIR := $(ARTIFACTS_DIR)/bin +COVERAGE_DIR := $(QUALITY_DIR)/coverage +SONAR_DIR := $(QUALITY_DIR)/sonar +SBOM_DIR := $(SECURITY_DIR)/sbom +GITLEAKS_DIR := $(SECURITY_DIR)/gitleaks + +# Version info +VERSION_FILE := $(shell cat VERSION 2>/dev/null || echo "1.0.0") +VERSION ?= $(VERSION_FILE) +COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") +BUILD_DATE ?= $(shell date -u +'%Y-%m-%dT%H:%M:%SZ' 2>/dev/null || "") + +# Dependency-Track settings +DTRACK_BASE_URL ?= +DTRACK_API_KEY ?= +DTRACK_PROJECT ?= claude-code-api +DTRACK_PROJECT_VERSION ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "dev") + # Python targets install: pip install -e . 
@@ -56,6 +83,105 @@ kill: fi; \ fi +# Quality and Security targets +.PHONY: sonar sonar-cloud coverage-sonar sbom sbom-upload gitleaks fmt lint vet + +sonar: ## Run sonar-scanner for SonarQube analysis + @mkdir -p $(SONAR_DIR) $(COVERAGE_DIR) + @echo "Generating coverage report for SonarQube..." + @python -m pytest --cov=claude_code_api --cov-report=xml --cov-report=term-missing -v tests/ + @if command -v sonar-scanner >/dev/null 2>&1; then \ + if [ -f ".env.vault" ]; then \ + . ./.env.vault; \ + fi; \ + if [ -f ".env" ]; then \ + set -a; . ./.env; set +a; \ + fi; \ + if [ -n "$${VAULT_SECRET_PATHS:-}" ] || [ -n "$${VAULT_REQUIRED_VARS:-}" ]; then \ + if [ -f "./scripts/vault-helper.sh" ]; then \ + . ./scripts/vault-helper.sh; \ + vault_helper::load_from_definitions "$${VAULT_SECRET_PATHS:-}" "$${VAULT_REQUIRED_VARS:-}" "$${VAULT_TOKEN_FILE:-}"; \ + fi; \ + fi; \ + SONAR_HOST_URL="$${SONAR_HOST_URL:-$${SONAR_URL:-}}"; \ + if [ -z "$$SONAR_HOST_URL" ]; then \ + echo "SONAR_URL or SONAR_HOST_URL is required (e.g., https://sonarcloud.io or https://sonar.local)"; \ + exit 1; \ + fi; \ + case "$$SONAR_HOST_URL" in \ + http://*|https://*) ;; \ + *) echo "SONAR_URL must include http(s) scheme: $$SONAR_HOST_URL"; exit 1 ;; \ + esac; \ + if [ -z "$${SONAR_TOKEN:-}" ]; then \ + echo "SONAR_TOKEN not set - proceeding without authentication"; \ + fi; \ + sonar-scanner \ + -Dsonar.host.url=$$SONAR_HOST_URL \ + -Dsonar.token=$$SONAR_TOKEN \ + -Dsonar.projectVersion=$${VERSION:-1.0.0} \ + -Dsonar.working.directory=$(SONAR_DIR)/scannerwork; \ + else \ + echo "sonar-scanner not found. Install with: brew install sonar-scanner or download from https://docs.sonarqube.org/latest/analysis/scan/sonarscanner/"; \ + exit 1; \ + fi + +sonar-cloud: ## Run sonar-scanner for SonarCloud (uses different token/env) + @echo "Running SonarCloud scanner..." 
+ @./scripts/run-sonar-cloud.sh + +coverage-sonar: ## Generate coverage for SonarQube + @mkdir -p $(COVERAGE_DIR) + @python -m pytest --cov=claude_code_api --cov-report=xml --cov-report=term-missing -v tests/ + @echo "Coverage XML generated: $(COVERAGE_DIR)/coverage.xml" + +sbom: ## Generate SBOM with syft + @mkdir -p $(SBOM_DIR) + @if command -v syft >/dev/null 2>&1; then \ + syft dir:. -o cyclonedx-json=$(SBOM_DIR)/sbom.json; \ + echo "SBOM generated: $(SBOM_DIR)/sbom.json"; \ + else \ + echo "syft not found. Install with: brew install syft or visit https://github.com/anchore/syft"; \ + exit 1; \ + fi + +sbom-upload: sbom ## Generate (if needed) and upload SBOM to Dependency-Track + @./scripts/upload-sbom.sh $(SBOM_DIR)/sbom.json + +gitleaks: ## Run gitleaks to detect secrets + @mkdir -p $(GITLEAKS_DIR) + @if command -v gitleaks >/dev/null 2>&1; then \ + gitleaks detect --source . --report-path $(GITLEAKS_DIR)/gitleaks-report.json; \ + else \ + echo "gitleaks not found. Install with: brew install gitleaks"; \ + exit 1; \ + fi + +fmt: ## Format Python code with black + @if command -v black >/dev/null 2>&1; then \ + black claude_code_api/ tests/; \ + else \ + echo "black not found. Install with: pip install black"; \ + fi + +lint: ## Run Python linters (flake8, isort) + @if command -v flake8 >/dev/null 2>&1; then \ + flake8 claude_code_api/ tests/; \ + else \ + echo "flake8 not found. Install with: pip install flake8"; \ + fi + @if command -v isort >/dev/null 2>&1; then \ + isort --check-only claude_code_api/ tests/; \ + else \ + echo "isort not found. Install with: pip install isort"; \ + fi + +vet: ## Run type checking with mypy + @if command -v mypy >/dev/null 2>&1; then \ + mypy claude_code_api/; \ + else \ + echo "mypy not found. 
Install with: pip install mypy"; \ + fi + help: @echo "Claude Code API Commands:" @echo "" @@ -70,7 +196,7 @@ help: @echo " make start-prod - Start Python API server (production)" @echo "" @echo "TypeScript API:" - @echo " make install-js - Install TypeScript dependencies" + @echo " make install-js - Install TypeScript dependencies" @echo " make test-js - Run TypeScript unit tests" @echo " make test-js-real - Run Python test suite against TypeScript API" @echo " make start-js - Start TypeScript API server (production)" @@ -78,6 +204,17 @@ help: @echo " make start-js-prod - Build and start TypeScript API server (production)" @echo " make build-js - Build TypeScript project" @echo "" + @echo "Quality & Security:" + @echo " make sonar - Run SonarQube analysis (generates coverage + scans)" + @echo " make sonar-cloud - Run SonarCloud scanner (uses SONAR_CLOUD_TOKEN)" + @echo " make coverage-sonar - Generate coverage XML for SonarQube" + @echo " make sbom - Generate SBOM with syft" + @echo " make sbom-upload - Upload SBOM to Dependency-Track" + @echo " make gitleaks - Run gitleaks to detect secrets" + @echo " make fmt - Format Python code with black" + @echo " make lint - Run Python linters (flake8, isort)" + @echo " make vet - Run type checking with mypy" + @echo "" @echo "General:" @echo " make clean - Clean up Python cache files" @echo " make kill PORT=X - Kill process on specific port" diff --git a/README.md b/README.md index 58e6eef..4512d4e 100644 --- a/README.md +++ b/README.md @@ -52,9 +52,12 @@ make start ## Supported Models -- `claude-opus-4-20250514` - Claude Opus 4 (Most powerful) -- `claude-sonnet-4-20250514` - Claude Sonnet 4 (Latest Sonnet) -- `claude-3-7-sonnet-20250219` - Claude Sonnet 3.7 (Advanced) +Model definitions live in `claude_code_api/config/models.json`. +Override with `CLAUDE_CODE_API_MODELS_PATH` to point at a custom JSON file. 
+ +- `claude-opus-4-5-20250929` - Claude Opus 4.5 (Most powerful) +- `claude-sonnet-4-5-20250929` - Claude Sonnet 4.5 (Latest Sonnet) +- `claude-haiku-4-5-20250929` - Claude Haiku 4.5 (Fast & cost-effective) - `claude-3-5-haiku-20241022` - Claude Haiku 3.5 (Fast & cost-effective) ## Quick Start @@ -133,7 +136,7 @@ make check-claude # Check if Claude Code CLI is available curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "claude-3-5-haiku-20241022", + "model": "claude-sonnet-4-5-20250929", "messages": [ {"role": "user", "content": "Hello!"} ] @@ -152,7 +155,7 @@ curl http://localhost:8000/v1/models curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "claude-3-5-haiku-20241022", + "model": "claude-sonnet-4-5-20250929", "messages": [ {"role": "user", "content": "Tell me a joke"} ], @@ -273,4 +276,4 @@ Response: ## License -This project is licensed under the GNU General Public License v3.0 - see the LICENSE file for details. \ No newline at end of file +This project is licensed under the GNU General Public License v3.0 - see the LICENSE file for details. 
diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..3eefcb9 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +1.0.0 diff --git a/claude_code_api/api/chat.py b/claude_code_api/api/chat.py index 5e928ae..9c92334 100644 --- a/claude_code_api/api/chat.py +++ b/claude_code_api/api/chat.py @@ -1,34 +1,52 @@ """Chat completions API endpoint - OpenAI compatible.""" -import uuid import json -from datetime import datetime from typing import Dict, Any from fastapi import APIRouter, Request, HTTPException, status -from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.responses import StreamingResponse from pydantic import ValidationError import structlog from claude_code_api.models.openai import ( ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionChoice, - ChatMessage, - ChatCompletionUsage, ErrorResponse ) -from claude_code_api.models.claude import validate_claude_model, get_model_info +from claude_code_api.models.claude import validate_claude_model from claude_code_api.core.claude_manager import create_project_directory -from claude_code_api.core.session_manager import SessionManager, ConversationManager +from claude_code_api.core.session_manager import SessionManager from claude_code_api.utils.streaming import create_sse_response, create_non_streaming_response -from claude_code_api.utils.parser import ClaudeOutputParser, estimate_tokens +from claude_code_api.utils.parser import ClaudeOutputParser, OpenAIConverter, estimate_tokens, normalize_claude_message logger = structlog.get_logger() router = APIRouter() +CHAT_COMPLETION_RESPONSES = { + 200: { + "description": "Chat completion response (JSON when stream=false, SSE when stream=true).", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ChatCompletionResponse"} + }, + "text/event-stream": { + "schema": {"$ref": "#/components/schemas/ChatCompletionChunk"} + } + } + }, + 400: {"model": ErrorResponse}, + 422: {"model": ErrorResponse}, + 503: {"model": 
ErrorResponse}, + 500: {"model": ErrorResponse} +} -@router.post("/chat/completions") + +@router.post( + "/chat/completions", + response_model=ChatCompletionResponse, + responses=CHAT_COMPLETION_RESPONSES +) async def create_chat_completion( + request: ChatCompletionRequest, req: Request ) -> Any: """Create a chat completion, compatible with OpenAI API.""" @@ -44,30 +62,6 @@ async def create_chat_completion( user_agent=req.headers.get("user-agent", "unknown"), raw_body=raw_body.decode()[:1000] if raw_body else "empty" ) - - # Parse JSON manually to see validation errors - if raw_body: - try: - json_data = json.loads(raw_body.decode()) - logger.info("JSON parsed successfully", data_keys=list(json_data.keys())) - except json.JSONDecodeError as e: - logger.error("JSON decode error", error=str(e)) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": {"message": f"Invalid JSON: {str(e)}", "type": "invalid_request_error"}} - ) - - # Try to validate with Pydantic - try: - request = ChatCompletionRequest(**json_data) - logger.info("Pydantic validation successful") - except ValidationError as e: - logger.error("Pydantic validation failed", error=str(e), errors=e.errors()) - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail={"error": {"message": f"Validation error: {str(e)}", "type": "invalid_request_error", "details": e.errors()}} - ) - except HTTPException: raise except Exception as e: @@ -97,7 +91,6 @@ async def create_chat_completion( try: # Validate model claude_model = validate_claude_model(request.model) - model_info = get_model_info(claude_model) # Validate message format if not request.messages: @@ -204,10 +197,11 @@ async def create_chat_completion( # Return streaming response return StreamingResponse( create_sse_response(claude_session_id, claude_model, claude_process), - media_type="text/plain", + media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", + 
"X-Accel-Buffering": "no", "X-Session-ID": claude_session_id, "X-Project-ID": project_id } @@ -215,6 +209,7 @@ async def create_chat_completion( else: # Collect all output for non-streaming response messages = [] + parser = ClaudeOutputParser() async for claude_message in claude_process.get_output(): # Log each message from Claude @@ -229,15 +224,11 @@ async def create_chat_completion( ) messages.append(claude_message) - - # Check if it's a final message by looking at dict structure - is_final = False - if isinstance(claude_message, dict): - is_final = claude_message.get("type") == "result" - - # Stop on final message or after a reasonable number of messages - if is_final or len(messages) > 10: # Safety limit for testing - break + normalized = normalize_claude_message(claude_message) + if normalized: + parser.parse_message(normalized) + if parser.is_final_message(normalized): + break # Log what we collected logger.info( @@ -246,12 +237,11 @@ async def create_chat_completion( message_types=[msg.get("type") if isinstance(msg, dict) else type(msg).__name__ for msg in messages] ) - # Simple usage tracking without parsing Claude internals - usage_summary = {"total_tokens": 50, "total_cost": 0.001} + usage_summary = OpenAIConverter.calculate_usage(parser) await session_manager.update_session( session_id=claude_session_id, - tokens_used=50, - cost=0.001 + tokens_used=usage_summary.get("total_tokens", 0), + cost=parser.total_cost ) # Create non-streaming response @@ -259,7 +249,7 @@ async def create_chat_completion( messages=messages, session_id=claude_session_id, model=claude_model, - usage_summary=usage_summary + usage=usage_summary ) # Add extension fields @@ -273,7 +263,7 @@ async def create_chat_completion( has_choices_0=bool(response.get("choices") and len(response["choices"]) > 0), choices_0_keys=list(response["choices"][0].keys()) if response.get("choices") and len(response["choices"]) > 0 else [], message_keys=list(response["choices"][0]["message"].keys()) if 
response.get("choices") and len(response["choices"]) > 0 and "message" in response["choices"][0] else [], - content_length=len(response["choices"][0]["message"].get("content", "")) if response.get("choices") and len(response["choices"]) > 0 and "message" in response["choices"][0] else 0, + content_length=len((response["choices"][0]["message"].get("content") or "")) if response.get("choices") and len(response["choices"]) > 0 and "message" in response["choices"][0] else 0, full_response_keys=list(response.keys()), response_size=len(str(response)) ) diff --git a/claude_code_api/api/projects.py b/claude_code_api/api/projects.py index f7791cf..0b0795a 100644 --- a/claude_code_api/api/projects.py +++ b/claude_code_api/api/projects.py @@ -72,12 +72,7 @@ async def create_project( # Create project directory if project_request.path: # Validate path - try: - project_path = validate_path(project_request.path, settings.project_root) - try: - project_path = validate_path(project_request.path, settings.project_root) - except HTTPException: - raise + project_path = validate_path(project_request.path, settings.project_root) os.makedirs(project_path, exist_ok=True) else: diff --git a/claude_code_api/config/models.json b/claude_code_api/config/models.json new file mode 100644 index 0000000..07f0cec --- /dev/null +++ b/claude_code_api/config/models.json @@ -0,0 +1,45 @@ +{ + "default_model": "claude-sonnet-4-5-20250929", + "models": [ + { + "id": "claude-opus-4-5-20250929", + "name": "Claude Opus 4.5", + "description": "Most powerful Claude model for complex reasoning", + "max_tokens": 500000, + "input_cost_per_1k": 15.0, + "output_cost_per_1k": 75.0, + "supports_streaming": true, + "supports_tools": true + }, + { + "id": "claude-sonnet-4-5-20250929", + "name": "Claude Sonnet 4.5", + "description": "Latest Sonnet model with enhanced capabilities", + "max_tokens": 500000, + "input_cost_per_1k": 3.0, + "output_cost_per_1k": 15.0, + "supports_streaming": true, + "supports_tools": true + 
}, + { + "id": "claude-haiku-4-5-20250929", + "name": "Claude Haiku 4.5", + "description": "Fast and cost-effective model for quick tasks", + "max_tokens": 200000, + "input_cost_per_1k": 0.25, + "output_cost_per_1k": 1.25, + "supports_streaming": true, + "supports_tools": true + }, + { + "id": "claude-3-5-haiku-20241022", + "name": "Claude Haiku 3.5", + "description": "Fast and cost-effective model for quick tasks", + "max_tokens": 200000, + "input_cost_per_1k": 0.25, + "output_cost_per_1k": 1.25, + "supports_streaming": true, + "supports_tools": true + } + ] +} diff --git a/claude_code_api/core/claude_manager.py b/claude_code_api/core/claude_manager.py index 97ff77c..c3dd299 100644 --- a/claude_code_api/core/claude_manager.py +++ b/claude_code_api/core/claude_manager.py @@ -11,6 +11,7 @@ import structlog from .config import settings +from claude_code_api.models.claude import get_default_model logger = structlog.get_logger() @@ -56,7 +57,7 @@ async def start( "Starting Claude process", session_id=self.session_id, project_path=self.project_path, - model=model or settings.default_model + model=model or get_default_model() ) # Start process from src directory (where Claude works without API key) @@ -104,7 +105,10 @@ async def _read_output(self): continue try: - data = json.loads(line_text) + payload = line_text + if payload.startswith("data: "): + payload = payload[6:].strip() + data = json.loads(payload) # Extract Claude's session ID from the first message if not claude_session_id and data.get("session_id"): claude_session_id = data["session_id"] @@ -272,7 +276,7 @@ async def create_session( # Start process success = await process.start( prompt=prompt, - model=model or settings.default_model, + model=model or get_default_model(), system_prompt=system_prompt, resume_session=resume_session ) diff --git a/claude_code_api/core/config.py b/claude_code_api/core/config.py index 9a69275..5ee5a5f 100644 --- a/claude_code_api/core/config.py +++ b/claude_code_api/core/config.py 
@@ -4,7 +4,7 @@ import shutil from typing import List, Union from pydantic import Field, field_validator -from pydantic_settings import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict def find_claude_binary() -> str: @@ -51,8 +51,50 @@ def find_claude_binary() -> str: return "claude" # Final fallback +def _looks_like_dotenv(path: str) -> bool: + """Return True when a file appears to be a simple KEY=VALUE dotenv file.""" + try: + with open(path, "r", encoding="utf-8") as handle: + for line in handle: + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if stripped.startswith("#!") or stripped.startswith("set "): + return False + if "BASH_SOURCE" in stripped or "[[" in stripped: + return False + if stripped.startswith(("if ", "fi", "for ", "done", "source ")): + return False + if stripped.startswith("export "): + stripped = stripped[len("export "):].lstrip() + return "=" in stripped + except FileNotFoundError: + return False + except OSError: + return False + return True + + +def _resolve_env_file() -> str | None: + """Pick a dotenv file only when it is likely compatible.""" + explicit = os.getenv("CLAUDE_CODE_API_ENV_FILE") + if explicit: + return explicit + for candidate in (".env.local", ".env"): + if os.path.exists(candidate) and _looks_like_dotenv(candidate): + return candidate + return None + + class Settings(BaseSettings): """Application settings.""" + + model_config = SettingsConfigDict( + env_file=_resolve_env_file(), + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore", + ) # API Configuration api_title: str = "Claude Code API Gateway" @@ -77,7 +119,7 @@ def parse_api_keys(cls, v): # Claude Configuration claude_binary_path: str = find_claude_binary() claude_api_key: str = "" - default_model: str = "claude-3-5-sonnet-20241022" + default_model: str = "claude-sonnet-4-5-20250929" max_concurrent_sessions: int = 10 session_timeout_minutes: int = 30 @@ -112,11 +154,6 @@ def 
parse_cors_lists(cls, v): streaming_chunk_size: int = 1024 streaming_timeout_seconds: int = 300 - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - case_sensitive = False - # Create global settings instance settings = Settings() diff --git a/claude_code_api/core/database.py b/claude_code_api/core/database.py index 6d07481..791e1ac 100644 --- a/claude_code_api/core/database.py +++ b/claude_code_api/core/database.py @@ -12,6 +12,7 @@ import structlog from .config import settings +from claude_code_api.models.claude import get_default_model logger = structlog.get_logger() @@ -53,7 +54,7 @@ class Session(Base): id = Column(String, primary_key=True) project_id = Column(String, ForeignKey("projects.id"), nullable=False) title = Column(String) - model = Column(String, default=settings.default_model) + model = Column(String, default=get_default_model) system_prompt = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) diff --git a/claude_code_api/core/session_manager.py b/claude_code_api/core/session_manager.py index 870c579..3747ec4 100644 --- a/claude_code_api/core/session_manager.py +++ b/claude_code_api/core/session_manager.py @@ -7,6 +7,7 @@ import structlog from claude_code_api.core.config import settings +from claude_code_api.models.claude import get_default_model from claude_code_api.core.database import db_manager, Session, Message from claude_code_api.core.claude_manager import ClaudeProcess @@ -74,7 +75,7 @@ async def create_session( session_info = SessionInfo( session_id=session_id, project_id=project_id, - model=model or settings.default_model, + model=model or get_default_model(), system_prompt=system_prompt ) diff --git a/claude_code_api/main.py b/claude_code_api/main.py index ebdce9e..d4282d0 100644 --- a/claude_code_api/main.py +++ b/claude_code_api/main.py @@ -10,7 +10,8 @@ from contextlib import asynccontextmanager from typing import 
AsyncGenerator -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, status +from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import structlog @@ -24,6 +25,7 @@ from claude_code_api.api.projects import router as projects_router from claude_code_api.api.sessions import router as sessions_router from claude_code_api.core.auth import auth_middleware +from claude_code_api.models.openai import ChatCompletionChunk # Configure structured logging @@ -91,6 +93,40 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: lifespan=lifespan ) + +def custom_openapi(): + """Extend OpenAPI schema with streaming chunk models.""" + if app.openapi_schema: + return app.openapi_schema + from fastapi.openapi.utils import get_openapi + + schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes + ) + + components = schema.setdefault("components", {}).setdefault("schemas", {}) + chunk_schema = ChatCompletionChunk.model_json_schema( + ref_template="#/components/schemas/{model}" + ) + defs = chunk_schema.pop("$defs", {}) + for name, definition in defs.items(): + components.setdefault(name, definition) + components.setdefault("ChatCompletionChunk", chunk_schema) + if "ChatMessage" not in components: + if "ChatMessage-Input" in components: + components["ChatMessage"] = components["ChatMessage-Input"] + elif "ChatMessage-Output" in components: + components["ChatMessage"] = components["ChatMessage-Output"] + + app.openapi_schema = schema + return app.openapi_schema + + +app.openapi = custom_openapi + # CORS middleware app.add_middleware( CORSMiddleware, @@ -118,6 +154,26 @@ async def http_exception_handler(request, exc): ) +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + """Return OpenAI-style errors for validation failures.""" + status_code 
= status.HTTP_422_UNPROCESSABLE_ENTITY + for error in exc.errors(): + if error.get("type") in {"value_error.jsondecode", "json_invalid"}: + status_code = status.HTTP_400_BAD_REQUEST + break + return JSONResponse( + status_code=status_code, + content={ + "error": { + "message": "Validation error", + "type": "invalid_request_error", + "details": exc.errors() + } + } + ) + + @app.exception_handler(Exception) async def global_exception_handler(request, exc): """Global exception handler with structured logging.""" diff --git a/claude_code_api/models/claude.py b/claude_code_api/models/claude.py index 057f4d0..487cd7d 100644 --- a/claude_code_api/models/claude.py +++ b/claude_code_api/models/claude.py @@ -1,17 +1,13 @@ """Claude Code specific models and utilities.""" from datetime import datetime +from functools import lru_cache +from pathlib import Path from typing import List, Optional, Dict, Any, Union, Literal -from pydantic import BaseModel, Field +import json +import os from enum import Enum - - -class ClaudeModel(str, Enum): - """Available Claude models - matching Claude Code CLI supported models.""" - OPUS_4 = "claude-opus-4-20250514" - SONNET_4 = "claude-sonnet-4-20250514" - SONNET_37 = "claude-3-7-sonnet-20250219" - HAIKU_35 = "claude-3-5-haiku-20241022" +from pydantic import BaseModel, Field class ClaudeMessageType(str, Enum): @@ -70,6 +66,10 @@ class ClaudeToolResult(BaseModel): is_error: Optional[bool] = Field(False, description="Whether this is an error result") +def _default_model_factory() -> str: + return get_default_model() + + class ClaudeSessionInfo(BaseModel): """Claude session information.""" session_id: str = Field(..., description="Session ID") @@ -125,7 +125,7 @@ class ClaudeProjectConfig(BaseModel): project_id: str = Field(..., description="Project ID") name: str = Field(..., description="Project name") path: str = Field(..., description="Project path") - default_model: str = Field(ClaudeModel.HAIKU_35, description="Default model") + 
default_model: str = Field(default_factory=_default_model_factory, description="Default model") system_prompt: Optional[str] = Field(None, description="Default system prompt") tools_enabled: List[ClaudeToolType] = Field(default_factory=list, description="Enabled tools") max_tokens: Optional[int] = Field(None, description="Maximum tokens per request") @@ -194,72 +194,70 @@ class ClaudeModelInfo(BaseModel): supports_tools: bool = Field(True, description="Whether model supports tool use") +MODELS_CONFIG_ENV = "CLAUDE_CODE_API_MODELS_PATH" +DEFAULT_MODELS_PATH = Path(__file__).resolve().parents[1] / "config" / "models.json" + + +def _models_config_path() -> Path: + env_path = os.getenv(MODELS_CONFIG_ENV) + if env_path: + return Path(env_path).expanduser() + return DEFAULT_MODELS_PATH + + +@lru_cache +def _load_models_config() -> Dict[str, Any]: + path = _models_config_path() + if not path.exists(): + raise FileNotFoundError( + f"Models config not found at {path}. Set {MODELS_CONFIG_ENV} to override." 
+ ) + with path.open("r", encoding="utf-8") as handle: + data = json.load(handle) + if not isinstance(data, dict) or "models" not in data: + raise ValueError("Models config must contain a top-level 'models' list.") + return data + + +def _model_index() -> Dict[str, ClaudeModelInfo]: + config = _load_models_config() + models = config.get("models", []) + model_map: Dict[str, ClaudeModelInfo] = {} + for entry in models: + info = ClaudeModelInfo(**entry) + model_map[info.id] = info + return model_map + + # Utility functions for model validation def validate_claude_model(model: str) -> str: """Validate and normalize Claude model name.""" - # Direct Claude model names - valid_models = [model.value for model in ClaudeModel] - - if model in valid_models: + model_map = _model_index() + if model in model_map: return model - - # Default to Haiku for testing - return ClaudeModel.HAIKU_35 + return get_default_model() def get_default_model() -> str: """Get the default Claude model.""" - return ClaudeModel.HAIKU_35 + config = _load_models_config() + default_model = config.get("default_model") + if isinstance(default_model, str) and default_model: + return default_model + model_map = _model_index() + if model_map: + return next(iter(model_map.keys())) + raise ValueError("Models config did not contain any models.") def get_model_info(model_id: str) -> ClaudeModelInfo: """Get information about a Claude model.""" - model_info = { - ClaudeModel.OPUS_4: ClaudeModelInfo( - id=ClaudeModel.OPUS_4, - name="Claude Opus 4", - description="Most powerful Claude model for complex reasoning", - max_tokens=500000, - input_cost_per_1k=15.0, - output_cost_per_1k=75.0, - supports_streaming=True, - supports_tools=True - ), - ClaudeModel.SONNET_4: ClaudeModelInfo( - id=ClaudeModel.SONNET_4, - name="Claude Sonnet 4", - description="Latest Sonnet model with enhanced capabilities", - max_tokens=500000, - input_cost_per_1k=3.0, - output_cost_per_1k=15.0, - supports_streaming=True, - supports_tools=True - 
), - ClaudeModel.SONNET_37: ClaudeModelInfo( - id=ClaudeModel.SONNET_37, - name="Claude Sonnet 3.7", - description="Advanced Sonnet model for complex tasks", - max_tokens=200000, - input_cost_per_1k=3.0, - output_cost_per_1k=15.0, - supports_streaming=True, - supports_tools=True - ), - ClaudeModel.HAIKU_35: ClaudeModelInfo( - id=ClaudeModel.HAIKU_35, - name="Claude Haiku 3.5", - description="Fast and cost-effective model for quick tasks", - max_tokens=200000, - input_cost_per_1k=0.25, - output_cost_per_1k=1.25, - supports_streaming=True, - supports_tools=True - ) - } - - return model_info.get(model_id, model_info[ClaudeModel.HAIKU_35]) + model_map = _model_index() + if model_id in model_map: + return model_map[model_id] + return model_map[get_default_model()] def get_available_models() -> List[ClaudeModelInfo]: """Get list of all available Claude models.""" - return [get_model_info(model) for model in ClaudeModel] + return list(_model_index().values()) diff --git a/claude_code_api/models/openai.py b/claude_code_api/models/openai.py index b1feae6..02904f2 100644 --- a/claude_code_api/models/openai.py +++ b/claude_code_api/models/openai.py @@ -2,17 +2,72 @@ from datetime import datetime from typing import List, Optional, Dict, Any, Union, Literal -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field + + +class ToolFunction(BaseModel): + """Tool function definition (OpenAI compatible).""" + name: str = Field(..., description="The name of the function to call") + description: Optional[str] = Field(None, description="The function description") + parameters: Dict[str, Any] = Field(..., description="The JSON schema for the function parameters") + + +class ToolDefinition(BaseModel): + """Tool definition for chat completion requests.""" + type: Literal["function"] = Field("function", description="Tool type") + function: ToolFunction = Field(..., description="Function tool definition") + + +class ToolChoiceFunction(BaseModel): + """Tool 
choice function selector.""" + name: str = Field(..., description="Name of the function to call") + + +class ToolChoice(BaseModel): + """Tool choice definition.""" + type: Literal["function"] = Field("function", description="Tool choice type") + function: ToolChoiceFunction = Field(..., description="Tool choice function") + + +class ToolCallFunction(BaseModel): + """Tool call function payload.""" + name: str = Field(..., description="Function name") + arguments: str = Field(..., description="JSON-encoded arguments string") + + +class ToolCall(BaseModel): + """Tool call object for responses.""" + id: str = Field(..., description="Tool call ID") + type: Literal["function"] = Field("function", description="Tool call type") + function: ToolCallFunction = Field(..., description="Function call details") + + +class ToolCallFunctionDelta(BaseModel): + """Streaming delta for tool call function.""" + name: Optional[str] = Field(None, description="Function name") + arguments: Optional[str] = Field(None, description="Partial arguments payload") + + +class ToolCallDelta(BaseModel): + """Streaming delta for tool calls.""" + index: int = Field(..., description="Tool call index") + id: Optional[str] = Field(None, description="Tool call ID") + type: Optional[Literal["function"]] = Field(None, description="Tool call type") + function: Optional[ToolCallFunctionDelta] = Field(None, description="Function delta") class ChatMessage(BaseModel): """Chat message model - accepts any content format.""" - role: Literal["system", "user", "assistant"] = Field(..., description="The role of the message author") - content: Any = Field(..., description="The content of the message") # Accept anything + role: Literal["system", "user", "assistant", "tool"] = Field(..., description="The role of the message author") + content: Optional[Any] = Field(None, description="The content of the message") name: Optional[str] = Field(None, description="Optional name for the message author") + tool_calls: 
Optional[List[ToolCall]] = Field(None, description="Tool calls generated by the assistant") + tool_call_id: Optional[str] = Field(None, description="Tool call ID this tool message is responding to") def get_text_content(self) -> str: """Extract text content from any format.""" + if self.content is None: + return "" if isinstance(self.content, str): return self.content elif isinstance(self.content, list): @@ -43,6 +98,10 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Frequency penalty") presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Presence penalty") user: Optional[str] = Field(None, description="Unique identifier representing your end-user") + tools: Optional[List[ToolDefinition]] = Field(None, description="Tools available for the model to call") + tool_choice: Optional[Union[str, ToolChoice]] = Field( + None, description="Tool choice preference (e.g. 'auto', 'none', or a specific tool)" + ) # Extension fields for Claude Code project_id: Optional[str] = Field(None, description="Project ID for Claude Code context") @@ -85,6 +144,7 @@ class ChatCompletionChunkDelta(BaseModel): """Delta object for streaming responses.""" role: Optional[str] = Field(None, description="The role of the author of this message") content: Optional[str] = Field(None, description="The contents of the chunk message") + tool_calls: Optional[List[ToolCallDelta]] = Field(None, description="Tool call deltas") class ChatCompletionChunkChoice(BaseModel): diff --git a/claude_code_api/utils/parser.py b/claude_code_api/utils/parser.py index b92a06e..bd30b53 100644 --- a/claude_code_api/utils/parser.py +++ b/claude_code_api/utils/parser.py @@ -1,7 +1,7 @@ """JSONL parser for Claude Code output.""" import json -import re +import uuid from typing import Dict, Any, Optional, List, Generator from datetime import datetime import structlog @@ -29,34 +29,39 @@ def parse_line(self, line: str) -> 
Optional[ClaudeMessage]: try: data = json.loads(line.strip()) message = ClaudeMessage(**data) - - # Extract session info on first message - if message.session_id and not self.session_id: - self.session_id = message.session_id - - if message.model and not self.model: - self.model = message.model - - # Track metrics - if message.usage: - input_tokens = message.usage.get("input_tokens", 0) - output_tokens = message.usage.get("output_tokens", 0) - self.total_tokens += input_tokens + output_tokens - - if message.cost_usd: - self.total_cost += message.cost_usd - - if message.type in ["user", "assistant"]: - self.message_count += 1 - - return message - + return self.parse_message(message) except json.JSONDecodeError as e: logger.warning("Failed to parse JSONL line", line=line[:100], error=str(e)) return None except Exception as e: logger.error("Error parsing message", line=line[:100], error=str(e)) return None + + def parse_message(self, message: ClaudeMessage) -> Optional[ClaudeMessage]: + """Parse a ClaudeMessage and update metrics.""" + if not message: + return None + + # Extract session info on first message + if message.session_id and not self.session_id: + self.session_id = message.session_id + + if message.model and not self.model: + self.model = message.model + + # Track metrics + if message.usage: + input_tokens = message.usage.get("input_tokens", 0) + output_tokens = message.usage.get("output_tokens", 0) + self.total_tokens += input_tokens + output_tokens + + if message.cost_usd: + self.total_cost += message.cost_usd + + if message.type in ["user", "assistant"]: + self.message_count += 1 + + return message def parse_stream(self, lines: List[str]) -> Generator[ClaudeMessage, None, None]: """Parse multiple JSONL lines.""" @@ -73,17 +78,24 @@ def extract_text_content(self, message: ClaudeMessage) -> str: content = message.message.get("content", []) if isinstance(content, str): return content + if isinstance(content, dict): + if "text" in content: + return 
str(content.get("text", "")) + if "content" in content: + return str(content.get("content", "")) if isinstance(content, list): text_parts = [] for part in content: if isinstance(part, dict): - if part.get("type") == "text": + if part.get("type") == "text" or "text" in part: text = part.get("text", "") if isinstance(text, str): text_parts.append(text) elif isinstance(text, dict) and "text" in text: text_parts.append(text["text"]) + elif "content" in part: + text_parts.append(str(part.get("content", ""))) elif isinstance(part, str): text_parts.append(part) return "\n".join(text_parts) @@ -275,14 +287,17 @@ def __init__(self): self.current_assistant_content = "" self.parser = ClaudeOutputParser() - def add_message(self, message: ClaudeMessage): + def add_message(self, message: Any): """Add message to aggregator.""" - self.messages.append(message) - self.parser.parse_line(message.json()) + normalized = normalize_claude_message(message) + if not normalized: + return + self.messages.append(normalized) + self.parser.parse_message(normalized) # Aggregate assistant content for complete response - if message.is_assistant_message(): - content = self.parser.extract_text_content(message) + if self.parser.is_assistant_message(normalized): + content = self.parser.extract_text_content(normalized) if content: self.current_assistant_content += content @@ -349,6 +364,36 @@ def estimate_tokens(text: str) -> int: return max(1, len(text) // 4) +def normalize_claude_message(raw: Any) -> Optional[ClaudeMessage]: + """Normalize a raw Claude output object into a ClaudeMessage.""" + if isinstance(raw, ClaudeMessage): + return raw + if isinstance(raw, dict): + try: + return ClaudeMessage(**raw) + except Exception as e: + logger.warning("Failed to normalize Claude message", error=str(e)) + return None + return None + + +def tool_use_to_openai_call(tool_use: ClaudeToolUse) -> Dict[str, Any]: + """Convert a Claude tool use to an OpenAI tool call object.""" + tool_id = tool_use.id or 
f"call_{uuid.uuid4().hex}" + try: + arguments = json.dumps(tool_use.input or {}, separators=(",", ":"), ensure_ascii=False) + except TypeError: + arguments = json.dumps({"input": str(tool_use.input)}, separators=(",", ":"), ensure_ascii=False) + return { + "id": tool_id, + "type": "function", + "function": { + "name": tool_use.name, + "arguments": arguments + } + } + + def format_timestamp(timestamp: Optional[str]) -> str: """Format timestamp for display.""" if not timestamp: diff --git a/claude_code_api/utils/streaming.py b/claude_code_api/utils/streaming.py index 97cd355..c3aa7b6 100644 --- a/claude_code_api/utils/streaming.py +++ b/claude_code_api/utils/streaming.py @@ -7,8 +7,12 @@ from typing import AsyncGenerator, Dict, Any, Optional import structlog -from claude_code_api.models.claude import ClaudeMessage -from claude_code_api.utils.parser import ClaudeOutputParser, OpenAIConverter, MessageAggregator +from claude_code_api.utils.parser import ( + ClaudeOutputParser, + OpenAIConverter, + normalize_claude_message, + tool_use_to_openai_call +) from claude_code_api.core.claude_manager import ClaudeProcess logger = structlog.get_logger() @@ -60,6 +64,8 @@ def __init__(self, model: str, session_id: str): self.completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" self.created = int(datetime.utcnow().timestamp()) self.chunk_index = 0 + self.parser = ClaudeOutputParser() + self.tool_call_index = 0 async def convert_stream( self, @@ -81,62 +87,65 @@ async def convert_stream( } yield SSEFormatter.format_event(initial_chunk) - assistant_started = False - last_content = "" - chunk_count = 0 - max_chunks = 5 # Limit chunks for better UX + saw_assistant_text = False + saw_tool_calls = False # Process Claude output async for claude_message in claude_process.get_output(): - chunk_count += 1 - if chunk_count > max_chunks: - logger.info("Reached max chunks limit, terminating stream") - break try: - # Simple: just look for assistant messages in the dict - if 
isinstance(claude_message, dict): - if (claude_message.get("type") == "assistant" and - claude_message.get("message", {}).get("content")): - - message_content = claude_message["message"]["content"] - text_content = "" - - # Handle content array format: [{"type":"text","text":"..."}] - if isinstance(message_content, list): - for content_item in message_content: - if (isinstance(content_item, dict) and - content_item.get("type") == "text" and - content_item.get("text")): - text_content = content_item["text"] - break - # Handle simple string content - elif isinstance(message_content, str): - text_content = message_content - - if text_content.strip(): - chunk = { - "id": self.completion_id, - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [{ - "index": 0, - "delta": {"content": text_content}, - "finish_reason": None - }] - } - yield SSEFormatter.format_event(chunk) - assistant_started = True - - # Stop on result type - if claude_message.get("type") == "result": - break + message = normalize_claude_message(claude_message) + if not message: + continue + self.parser.parse_message(message) + + if self.parser.is_assistant_message(message): + text_content = self.parser.extract_text_content(message).strip() + if text_content: + chunk = { + "id": self.completion_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "choices": [{ + "index": 0, + "delta": {"content": text_content}, + "finish_reason": None + }] + } + yield SSEFormatter.format_event(chunk) + saw_assistant_text = True + + tool_uses = self.parser.extract_tool_uses(message) + if tool_uses: + tool_calls = [] + for tool_use in tool_uses: + call = tool_use_to_openai_call(tool_use) + call["index"] = self.tool_call_index + self.tool_call_index += 1 + tool_calls.append(call) + tool_chunk = { + "id": self.completion_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "choices": [{ + "index": 0, 
+ "delta": {"tool_calls": tool_calls}, + "finish_reason": None + }] + } + yield SSEFormatter.format_event(tool_chunk) + saw_tool_calls = True + + if self.parser.is_final_message(message): + break except Exception as e: logger.error("Error processing Claude message", error=str(e)) continue # Send final chunk + finish_reason = "tool_calls" if (saw_tool_calls and not saw_assistant_text) else "stop" final_chunk = { "id": self.completion_id, "object": "chat.completion.chunk", @@ -145,7 +154,7 @@ async def convert_stream( "choices": [{ "index": 0, "delta": {}, - "finish_reason": "stop" + "finish_reason": finish_reason }] } yield SSEFormatter.format_event(final_chunk) @@ -332,7 +341,7 @@ def create_non_streaming_response( messages: list, session_id: str, model: str, - usage_summary: Dict[str, Any] + usage: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """Create non-streaming response.""" completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" @@ -346,50 +355,42 @@ def create_non_streaming_response( completion_id=completion_id ) + parser = ClaudeOutputParser() + tool_calls = [] # Extract assistant content from Claude messages content_parts = [] for i, msg in enumerate(messages): + normalized = normalize_claude_message(msg) + if not normalized: + continue + parser.parse_message(normalized) logger.info( f"Processing message {i}", - msg_type=msg.get("type") if isinstance(msg, dict) else type(msg).__name__, - msg_keys=list(msg.keys()) if isinstance(msg, dict) else [], - is_assistant=isinstance(msg, dict) and msg.get("type") == "assistant" + msg_type=normalized.type, + msg_keys=list(normalized.model_dump().keys()), + is_assistant=parser.is_assistant_message(normalized) ) - if isinstance(msg, dict): - # Handle dict messages directly - if msg.get("type") == "assistant" and msg.get("message"): - message_content = msg["message"].get("content", []) - - logger.info( - f"Found assistant message {i}", - content_type=type(message_content).__name__, - 
content_preview=str(message_content)[:100] if message_content else "empty" - ) - - # Handle content array format: [{"type":"text","text":"..."}] - if isinstance(message_content, list): - for content_item in message_content: - if isinstance(content_item, dict) and content_item.get("type") == "text": - text = content_item.get("text", "").strip() - if text: - content_parts.append(text) - logger.info(f"Extracted text from array: {text[:50]}...") - # Handle simple string content - elif isinstance(message_content, str) and message_content.strip(): - text = message_content.strip() - content_parts.append(text) - logger.info(f"Extracted text from string: {text[:50]}...") + if parser.is_assistant_message(normalized): + text_content = parser.extract_text_content(normalized).strip() + logger.info( + f"Found assistant message {i}", + content_length=len(text_content), + content_preview=text_content[:100] if text_content else "empty" + ) + if text_content: + content_parts.append(text_content) + logger.info(f"Extracted assistant text: {text_content[:50]}...") + + tool_uses = parser.extract_tool_uses(normalized) + for tool_use in tool_uses: + tool_calls.append(tool_use_to_openai_call(tool_use)) - # Use the actual content or fallback - ensure we always have content + # Use the actual content or fallback if content_parts: complete_content = "\n".join(content_parts).strip() else: - complete_content = "Hello! I'm Claude, ready to help." - - # Ensure content is never empty - if not complete_content: - complete_content = "Response received but content was empty." 
+ complete_content = "" logger.info( "Final response content", @@ -399,6 +400,18 @@ def create_non_streaming_response( ) # Return simple OpenAI-compatible response with basic usage stats + if usage is None: + usage = OpenAIConverter.calculate_usage(parser) + + finish_reason = "tool_calls" if (tool_calls and not complete_content) else "stop" + + message_payload: Dict[str, Any] = { + "role": "assistant", + "content": complete_content or None + } + if tool_calls: + message_payload["tool_calls"] = tool_calls + response = { "id": completion_id, "object": "chat.completion", @@ -406,16 +419,13 @@ def create_non_streaming_response( "model": model, "choices": [{ "index": 0, - "message": { - "role": "assistant", - "content": complete_content - }, - "finish_reason": "stop" + "message": message_payload, + "finish_reason": finish_reason }], "usage": { - "prompt_tokens": 10, - "completion_tokens": len(complete_content.split()) if complete_content else 5, - "total_tokens": 10 + (len(complete_content.split()) if complete_content else 5) + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0) }, "session_id": session_id } @@ -424,7 +434,7 @@ def create_non_streaming_response( "Response created successfully", response_id=response["id"], choices_count=len(response["choices"]), - message_content_length=len(response["choices"][0]["message"]["content"]) + message_content_length=len(response["choices"][0]["message"]["content"] or "") ) return response diff --git a/pyproject.toml b/pyproject.toml index 26556ac..80cee6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ claude-code-api = "claude_code_api.main:main" where = ["."] include = ["claude_code_api*"] +[tool.setuptools.package-data] +claude_code_api = ["config/*.json"] + [tool.pytest.ini_options] minversion = "7.0" addopts = "-ra -q --strict-markers --strict-config" diff --git a/scripts/record_claude_fixture.py 
b/scripts/record_claude_fixture.py new file mode 100644 index 0000000..ef50f1f --- /dev/null +++ b/scripts/record_claude_fixture.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +"""Record Claude CLI stream-json output into a sanitized fixture.""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from typing import Any + + +def _replace_in_obj(value: Any, needle: str, replacement: str) -> Any: + if isinstance(value, str): + return value.replace(needle, replacement) + if isinstance(value, list): + return [_replace_in_obj(item, needle, replacement) for item in value] + if isinstance(value, dict): + return {k: _replace_in_obj(v, needle, replacement) for k, v in value.items()} + return value + + +def _sanitize_event(event: dict, cwd_path: str | None, session_id: str | None) -> dict: + if session_id: + event["session_id"] = session_id + if cwd_path: + if event.get("cwd") == cwd_path: + event["cwd"] = "." + event = _replace_in_obj(event, cwd_path, ".") + return event + + +def _run_claude(args: list[str], cwd: str | None) -> bytes: + result = subprocess.run( + args, + cwd=cwd or None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + text=False, + ) + if result.returncode != 0: + stderr = result.stderr.decode("utf-8", errors="replace") + raise RuntimeError(f"Claude CLI failed: {stderr.strip()}") + return result.stdout + + +def main() -> int: + parser = argparse.ArgumentParser(description="Record Claude stream-json output to a fixture.") + parser.add_argument("--prompt", required=True, help="Prompt to send to Claude.") + parser.add_argument("--out", required=True, help="Output JSONL fixture path.") + parser.add_argument("--model", default="", help="Claude model id or alias.") + parser.add_argument("--claude-bin", default=os.getenv("CLAUDE_BINARY_PATH", "claude")) + parser.add_argument("--session-id", default="", help="Stable session id to embed in fixture.") + parser.add_argument("--cwd", 
default="", help="Working directory for Claude CLI.") + parser.add_argument("--permission-mode", default="bypassPermissions") + parser.add_argument("--include-partial-messages", action="store_true") + parser.add_argument("--tools", default="", help="Comma-separated tool list for Claude CLI.") + args = parser.parse_args() + + cmd = [ + args.claude_bin, + "--print", + "--output-format", + "stream-json", + "--verbose", + ] + if args.model: + cmd.extend(["--model", args.model]) + if args.permission_mode: + cmd.extend(["--permission-mode", args.permission_mode]) + if args.include_partial_messages: + cmd.append("--include-partial-messages") + if args.tools != "": + cmd.extend(["--tools", args.tools]) + if args.session_id: + cmd.extend(["--session-id", args.session_id]) + cmd.append(args.prompt) + + cwd = args.cwd or None + raw = _run_claude(cmd, cwd) + if not raw.strip(): + raise RuntimeError("Claude CLI returned empty output.") + + out_path = os.path.abspath(args.out) + os.makedirs(os.path.dirname(out_path), exist_ok=True) + + cwd_path = os.path.abspath(cwd) if cwd else None + session_id = args.session_id or None + + lines = raw.splitlines() + with open(out_path, "w", encoding="utf-8") as handle: + for line in lines: + line = line.strip() + if not line: + continue + payload = line + if payload.startswith(b"data: "): + payload = payload[6:] + event = json.loads(payload.decode("utf-8")) + event = _sanitize_event(event, cwd_path, session_id) + handle.write(json.dumps(event, ensure_ascii=False) + "\n") + + print(f"Wrote fixture to {out_path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/run-sonar-cloud.sh b/scripts/run-sonar-cloud.sh new file mode 100755 index 0000000..bda7b8a --- /dev/null +++ b/scripts/run-sonar-cloud.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash +set -e + +# SonarCloud-specific environment loader +echo "Loading SonarCloud configuration..." 
+ +# Load base .env file +if [ -f .env ]; then + echo "Loading base environment from .env" + # Source with automatic export + set -a + . ./.env 2>/dev/null || true + set +a +fi + +# Load SonarCloud-specific .env.cloud file +if [ -f .env.cloud ]; then + echo "Loading SonarCloud configuration from .env.cloud" + set -a + . ./.env.cloud 2>/dev/null || true + set +a +fi + +# Verify VAULT_ADDR is set +if [ -z "${VAULT_ADDR:-}" ]; then + echo "Error: VAULT_ADDR not set" + echo " Ensure .env sets VAULT_ADDR" + exit 1 +fi + +echo "Using Vault: $VAULT_ADDR" + +# Load SONAR_CLOUD_TOKEN from Vault (different path than regular SONAR_TOKEN) +echo "Loading SONAR_CLOUD_TOKEN from Vault..." + +# Try the correct Vault path - adjust for your organization +CLOUD_SECRET=$(vault kv get -field=token kv/sonarcloud 2>/dev/null || echo "") + +if [ -n "$CLOUD_SECRET" ]; then + export SONAR_CLOUD_TOKEN="$CLOUD_SECRET" + echo "SONAR_CLOUD_TOKEN loaded from Vault (path: kv/sonarcloud)" +else + echo "SONAR_CLOUD_TOKEN not found in Vault at kv/sonarcloud" + echo " Tried path: kv/sonarcloud" + echo " Set SONAR_CLOUD_TOKEN manually or add secret to Vault" +fi + +# Check if we have required configuration +if [ -z "${SONAR_CLOUD_TOKEN:-}" ]; then + echo "Error: SONAR_CLOUD_TOKEN not set" + echo " Configure one of the following:" + echo " 1. Add secret to Vault at: kv/sonarcloud" + echo " 2. Set environment variable: export SONAR_CLOUD_TOKEN=your-token" + echo " 3. Add to .env.cloud: SONAR_CLOUD_TOKEN=your-token" + exit 1 +fi + +# Set defaults if not provided +SONAR_HOST_URL="${SONAR_CLOUD_URL:-https://sonarcloud.io}" +SONAR_ORG="${SONAR_CLOUD_ORG:-}" +SONAR_PROJECT_KEY="${SONAR_CLOUD_PROJECT:-claude-code-api}" + +# Generate coverage for SonarCloud +echo "Generating coverage report for SonarCloud..." +mkdir -p dist/quality/coverage dist/quality/sonar +python -m pytest --cov=claude_code_api --cov-report=xml --cov-report=term-missing -v tests/ + +echo "Running SonarCloud scanner..." 
+echo " Organization: $SONAR_ORG" +echo " Project Key: $SONAR_PROJECT_KEY" +echo " Host URL: $SONAR_HOST_URL" + +# Run sonar-scanner with SonarCloud settings +sonar-scanner \ + -Dsonar.host.url="$SONAR_HOST_URL" \ + -Dsonar.token="$SONAR_CLOUD_TOKEN" \ + -Dsonar.organization="$SONAR_ORG" \ + -Dsonar.projectKey="$SONAR_PROJECT_KEY" \ + -Dsonar.projectVersion="$(cat VERSION 2>/dev/null || echo "1.0.0")" \ + -Dsonar.projectBaseDir=. \ + -Dsonar.scm.provider=git \ + -Dsonar.working.directory=dist/quality/sonar/scannerwork diff --git a/scripts/upload-sbom.sh b/scripts/upload-sbom.sh new file mode 100755 index 0000000..ad153b7 --- /dev/null +++ b/scripts/upload-sbom.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash +set -e + +# Load environment if .env exists +if [ -f .env ]; then + echo "Loading environment overrides from .env" + # Source .env with automatic export + set -a + . ./.env || true # Don't fail if there are warnings + set +a + + # Verify critical variables are set + if [ -z "${DTRACK_API_KEY:-}" ]; then + echo "Warning: DTRACK_API_KEY not set, cannot upload SBOM" >&2 + echo "Set DTRACK_API_KEY in .env or via Vault" >&2 + exit 1 + fi + if [ -z "${DTRACK_BASE_URL:-}" ]; then + echo "Warning: DTRACK_BASE_URL not set" >&2 + exit 1 + fi +fi + +# Check required variables +if [ -z "${DTRACK_BASE_URL:-}" ]; then + echo "Error: DTRACK_BASE_URL is not set" >&2 + exit 1 +fi + +if [ -z "${DTRACK_API_KEY:-}" ]; then + echo "Error: DTRACK_API_KEY is not set" >&2 + exit 1 +fi + +PROJECT="${DTRACK_PROJECT:-claude-code-api}" +VERSION="${DTRACK_PROJECT_VERSION:-$(git rev-parse --short HEAD 2>/dev/null || echo "dev")}" +BASE_URL="${DTRACK_BASE_URL%/}" +SBOM_FILE="${1:-dist/security/sbom/sbom.json}" + +echo "Uploading SBOM to $BASE_URL" +echo " Project: $PROJECT" +echo " Version: $VERSION" +echo " File: $SBOM_FILE" + +if [ ! 
-f "$SBOM_FILE" ]; then + echo "Error: SBOM file not found at $SBOM_FILE" >&2 + echo "Run 'make sbom' first to generate the SBOM" >&2 + exit 1 +fi + +BOM_B64=$(base64 < "$SBOM_FILE" | tr -d '\n') +TMP_RESP=$(mktemp) + +# Try JSON upload first +JSON_PAYLOAD=$(jq -n \ + --arg pn "$PROJECT" \ + --arg pv "$VERSION" \ + --arg bom "$BOM_B64" \ + '{projectName: $pn, projectVersion: $pv, autoCreate: true, bom: $bom}') + +upload_json() { + curl -sS -o "$TMP_RESP" -w "%{http_code}" -X PUT \ + -H "X-Api-Key: $DTRACK_API_KEY" \ + -H "Content-Type: application/json" \ + -d "$JSON_PAYLOAD" \ + "$BASE_URL/api/v1/bom" +} + +status=$(upload_json || echo "curl_failed") +if [ "$status" != "curl_failed" ] && [ "$status" -ge 200 ] && [ "$status" -lt 300 ]; then + echo "SBOM uploaded to $BASE_URL (JSON)" + rm -f "$TMP_RESP" + exit 0 +fi + +echo "Dependency-Track JSON upload failed (status=$status):" +if [ -f "$TMP_RESP" ]; then cat "$TMP_RESP"; fi + +# Try multipart form upload +echo "Retrying with multipart form upload..." +status=$(curl -sS -o "$TMP_RESP" -w "%{http_code}" -X PUT \ + -H "X-Api-Key: $DTRACK_API_KEY" \ + -F "projectName=$PROJECT" \ + -F "projectVersion=$VERSION" \ + -F "autoCreate=true" \ + -F "bom@$SBOM_FILE;type=application/json" \ + "$BASE_URL/api/v1/bom" || echo "curl_failed") + +if [ "$status" != "curl_failed" ] && [ "$status" -ge 200 ] && [ "$status" -lt 300 ]; then + echo "SBOM uploaded to $BASE_URL (multipart)" + rm -f "$TMP_RESP" + exit 0 +fi + +echo "Dependency-Track upload failed (status=$status). 
Response:" +cat "$TMP_RESP" 2>/dev/null || true +rm -f "$TMP_RESP" +exit 1 diff --git a/scripts/vault-helper.sh b/scripts/vault-helper.sh new file mode 100755 index 0000000..5b55eb5 --- /dev/null +++ b/scripts/vault-helper.sh @@ -0,0 +1,347 @@ +#!/usr/bin/env bash + +vault_helper::log_info() { + printf '[INFO] %s\n' "$1" +} + +vault_helper::log_warn() { + printf '[WARN] %s\n' "$1" +} + +vault_helper::log_error() { + printf '[ERROR] %s\n' "$1" >&2 +} + +vault_helper::require_cli() { + local bin="$1" + if ! command -v "$bin" >/dev/null 2>&1; then + vault_helper::log_error "Required command '$bin' not found in PATH" + return 1 + fi +} + +vault_helper::is_truthy() { + case "${1,,}" in + 1|true|yes|y|on) return 0 ;; + *) return 1 ;; + esac +} + +vault_helper::kv_get() { + local full_path="$1" + local mount="${VAULT_KV_MOUNT:-}" + local secret_path="$full_path" + local use_mount=0 + local skip_lookup kv_version + + skip_lookup="${VAULT_KV_SKIP_MOUNT_LOOKUP:-}" + + if [[ -n "$mount" ]]; then + use_mount=1 + if [[ "$full_path" == "$mount/"* ]]; then + secret_path="${full_path#${mount}/}" + fi + else + if [[ "$full_path" == */* ]]; then + mount="${full_path%%/*}" + secret_path="${full_path#*/}" + fi + fi + + if vault_helper::is_truthy "$skip_lookup"; then + kv_version="${VAULT_KV_ENGINE_VERSION:-${VAULT_KV_VERSION:-}}" + if vault_helper::is_truthy "${VAULT_KV_V2:-}"; then + kv_version=2 + fi + + if [[ -z "$mount" || -z "$secret_path" || "$secret_path" == "$full_path" ]]; then + vault read -format=json "$full_path" + return $? + fi + + if [[ -z "$kv_version" ]]; then + local payload + if payload=$(vault read -format=json "${mount}/data/${secret_path}" 2>/dev/null); then + printf '%s' "$payload" + return 0 + fi + vault read -format=json "${mount}/${secret_path}" + return $? + fi + + if [[ "$kv_version" == "2" ]]; then + vault read -format=json "${mount}/data/${secret_path}" + return $? + fi + + vault read -format=json "${mount}/${secret_path}" + return $? 
+ fi + + if [[ -n "$mount" ]]; then + use_mount=1 + fi + + if [[ "$use_mount" -eq 1 && -n "$mount" && -n "$secret_path" && "$secret_path" != "$full_path" ]]; then + vault kv get -mount="$mount" -format=json "$secret_path" + return $? + fi + + vault kv get -format=json "$full_path" +} + +vault_helper::trim_token() { + tr -d '\r\n' <<<"$1" +} + +vault_helper::trim_string() { + local str="$1" + str="${str#"${str%%[![:space:]]*}"}" + str="${str%"${str##*[![:space:]]}"}" + printf '%s' "$str" +} + +vault_helper::load_token_from_file() { + local file="$1" + [[ -r "$file" ]] || return 1 + vault_helper::trim_token "$(cat "$file")" +} + +vault_helper::save_token() { + local token="$1" + local file="$2" + mkdir -p "$(dirname "$file")" + umask 077 + printf '%s\n' "$token" >"$file" + chmod 600 "$file" 2>/dev/null || true + vault_helper::log_info "Saved Vault token to $file" +} + +vault_helper::validate_token() { + local token="$1" + [[ -n "$token" ]] || return 1 + if VAULT_TOKEN="$token" vault token lookup >/dev/null 2>&1; then + return 0 + fi + return 1 +} + +vault_helper::authenticate() { + local username password login_json token + read -r -p "Vault username: " username >&2 + read -r -s -p "Vault password: " password + echo >&2 + + if ! 
login_json=$(vault login -format=json -method=userpass username="$username" password="$password" 2>/dev/null); then + vault_helper::log_error "Vault authentication failed" + return 1 + fi + + token=$(jq -r '.auth.client_token // empty' <<<"$login_json") + unset login_json + + if [[ -z "$token" ]]; then + vault_helper::log_error "Vault login response did not include a token" + return 1 + fi + + VAULT_TOKEN="$token" + export VAULT_TOKEN + vault_helper::save_token "$token" "$VAULT_TOKEN_FILE" + + unset username password token + return 0 +} + +vault_helper::set_secret_if_empty() { + local var="$1" + local value="$2" + local source="$3" + + if [[ -z "${!var:-}" && -n "$value" ]]; then + printf -v "$var" '%s' "$value" + export "$var" + vault_helper::log_info "Mapped ${source} -> ${var}" + fi +} + +vault_helper::apply_mappings() { + local json="$1" + local path="$2" + local mappings="$3" + local normalized entry var key value + + [[ -z "$mappings" ]] && return 0 + + normalized=$(printf '%s' "$mappings" | tr ',;' ' ') + for entry in $normalized; do + [[ "$entry" != *=* ]] && continue + var="$(vault_helper::trim_string "${entry%%=*}")" + key="$(vault_helper::trim_string "${entry#*=}")" + [[ -z "$var" || -z "$key" ]] && continue + value=$(jq -r --arg k "$key" '.[$k] // empty' <<<"$json") + if [[ -n "$value" && "$value" != "null" ]]; then + vault_helper::set_secret_if_empty "$var" "$value" "${key}@${path}" + fi + done +} + +vault_helper::fetch_and_export() { + local path="$1" + local mappings="$2" + local payload exports count data_json + + vault_helper::log_info "Fetching secrets from ${path}..." + if ! payload=$(vault_helper::kv_get "$path" 2>&1); then + vault_helper::log_error "Failed to fetch secrets from ${path}" + vault_helper::log_error "$payload" + return 1 + fi + + if ! 
data_json=$(printf '%s' "$payload" | jq -c '.data.data // .data // {}'); then + vault_helper::log_error "Unable to parse secrets JSON from ${path}" + return 1 + fi + + if [[ "$data_json" == "{}" ]]; then + vault_helper::log_warn "No secrets to export at ${path}" + return 0 + fi + + if ! exports=$(printf '%s' "$data_json" | jq -r ' + to_entries[]? | + "export \(.key)=\(.value | @sh)" + '); then + vault_helper::log_error "Unable to parse secrets from ${path}" + return 1 + fi + + eval "$exports" + vault_helper::apply_mappings "$data_json" "$path" "$mappings" + + count=$(printf '%s' "$data_json" | jq 'length') + vault_helper::log_info "Loaded ${count} secret(s) from ${path}" + return 0 +} + +vault_helper::validate_required_vars() { + local missing=() + local var + for var in "$@"; do + if [[ -z "${!var:-}" ]]; then + missing+=("$var") + fi + done + + if [[ "${#missing[@]}" -gt 0 ]]; then + vault_helper::log_error "Missing required secret(s): ${missing[*]}" + return 1 + fi + + vault_helper::log_info "Validated required secret(s): ${*}" + return 0 +} + +vault_helper::load_from_definitions() { + local secret_defs_raw="$1" + local required_vars_raw="$2" + VAULT_TOKEN_FILE="${3:-$HOME/.vault-token}" + + local -a secret_defs required_vars + local entry path mappings token_from_file + local skip_verify no_auth force_auth + + if vault_helper::is_truthy "${VAULT_SKIP_LOAD:-${VAULT_DISABLE:-}}"; then + vault_helper::log_info "Skipping Vault secret load (VAULT_SKIP_LOAD=1)." + return 0 + fi + + if [[ -z "$(vault_helper::trim_string "$secret_defs_raw")" ]]; then + vault_helper::log_error "No Vault secret paths configured. Set VAULT_SECRET_PATHS or provide a default." + return 1 + fi + + mapfile -t secret_defs < <(printf '%s\n' "$secret_defs_raw" | awk 'NF') + + if [[ "${#secret_defs[@]}" -eq 0 ]]; then + vault_helper::log_error "No Vault secret paths configured. Set VAULT_SECRET_PATHS or provide a default." 
+ return 1 + fi + + if [[ -n "$(vault_helper::trim_string "$required_vars_raw")" ]]; then + mapfile -t required_vars < <(printf '%s\n' "$required_vars_raw" | tr ', \t' '\n' | awk 'NF') + else + required_vars=() + fi + + if ! vault_helper::require_cli vault || ! vault_helper::require_cli jq; then + return 1 + fi + + skip_verify="${VAULT_SKIP_VERIFY:-}" + no_auth="${VAULT_NO_AUTH:-${VAULT_NONINTERACTIVE:-}}" + force_auth="${VAULT_FORCE_AUTH:-}" + + if [[ -z "${VAULT_TOKEN:-}" ]]; then + if token_from_file=$(vault_helper::load_token_from_file "$VAULT_TOKEN_FILE" 2>/dev/null); then + VAULT_TOKEN="$token_from_file" + export VAULT_TOKEN + vault_helper::log_info "Loaded Vault token from $VAULT_TOKEN_FILE" + fi + fi + + if [[ -n "${VAULT_TOKEN:-}" ]]; then + if vault_helper::is_truthy "$skip_verify"; then + vault_helper::log_info "Skipping Vault token validation (VAULT_SKIP_VERIFY=1)." + elif vault_helper::validate_token "${VAULT_TOKEN:-}"; then + vault_helper::log_info "Existing Vault token is valid." + else + if vault_helper::is_truthy "$force_auth" && [[ -t 0 ]]; then + vault_helper::log_info "Vault token invalid; starting authentication (VAULT_FORCE_AUTH=1)." + if ! vault_helper::authenticate; then + return 1 + fi + else + vault_helper::log_warn "Vault token validation failed; proceeding with provided token." + fi + fi + else + if vault_helper::is_truthy "$no_auth" || [[ ! -t 0 ]]; then + vault_helper::log_error "Vault token missing and authentication disabled." + return 1 + fi + vault_helper::log_info "Vault token missing; starting authentication." + if ! 
vault_helper::authenticate; then + return 1 + fi + fi + + for entry in "${secret_defs[@]}"; do + entry="$(vault_helper::trim_string "$entry")" + path="$entry" + mappings="" + + if [[ "$entry" == *"|"* ]]; then + path="$(vault_helper::trim_string "${entry%%|*}")" + mappings="$(vault_helper::trim_string "${entry#*|}")" + fi + + if [[ -z "$path" ]]; then + vault_helper::log_warn "Skipping empty path definition: $entry" + continue + fi + + if ! vault_helper::fetch_and_export "$path" "$mappings"; then + return 1 + fi + done + + if [[ "${#required_vars[@]}" -gt 0 ]]; then + if ! vault_helper::validate_required_vars "${required_vars[@]}"; then + return 1 + fi + fi + + vault_helper::log_info "Vault secrets loaded successfully." + return 0 +} diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..038ca5f --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,29 @@ +# SonarQube Project Configuration for claude-code-api +sonar.projectKey=claude-code-api +sonar.projectName=claude-code-api +# Version is set dynamically via Makefile from VERSION file + +# Source and test directories +sonar.sources=claude_code_api +sonar.tests=tests +sonar.test.inclusions=**/test_*.py,**/*_test.py +sonar.exclusions=**/*_test.py,tmp/**,dist/**,out/**,.vscode/**,.husky/**,scripts/**,vendor/**,node_modules/**,**/__pycache__/** + +# Python-specific settings +sonar.python.version=3.11 +sonar.python.coverage.reportPaths=dist/quality/coverage/coverage.xml +# Python test report (pytest can generate JUnit XML) +sonar.python.xunit.reportPath=dist/quality/sonar/xunit-report.xml + +# Coverage exclusions (test files are already excluded from sources) +sonar.coverage.exclusions=**/test_*.py,**/*_test.py,**/tests/** + +# Scanner working directory +sonar.working.directory=dist/quality/sonar/scannerwork + +# Analysis settings +sonar.sourceEncoding=UTF-8 +sonar.scm.provider=git +sonar.scm.forceReloadAll=true +sonar.scm.disabled=false +sonar.host.url= diff --git 
a/tests/conftest.py b/tests/conftest.py index 36fc933..3cfab3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import sys import tempfile import shutil +import json from pathlib import Path from fastapi.testclient import TestClient from httpx import AsyncClient @@ -15,6 +16,7 @@ # Now import the app and configuration from claude_code_api.main import app +from claude_code_api.models.claude import get_default_model from claude_code_api.core.config import settings @@ -37,18 +39,54 @@ def setup_test_environment(): settings.project_root = os.path.join(temp_dir, "projects") settings.require_auth = False - # Keep the real Claude binary path if it exists, otherwise use a mock - # settings.claude_binary_path should remain as found by find_claude_binary() - if not shutil.which(settings.claude_binary_path) and not os.path.exists(settings.claude_binary_path): - # Create a mock binary for CI/Sandbox environments + # Prefer deterministic fixtures unless explicitly using real Claude + use_real_claude = os.environ.get("CLAUDE_CODE_API_USE_REAL_CLAUDE") == "1" + if not use_real_claude: + fixtures_dir = Path(__file__).parent / "fixtures" + index_path = fixtures_dir / "index.json" + default_fixture = fixtures_dir / "claude_stream_simple.jsonl" + + fixture_rules = [] + if index_path.exists(): + try: + fixture_rules = json.loads(index_path.read_text(encoding="utf-8")) + except Exception as exc: + raise RuntimeError(f"Failed to parse fixture index: {exc}") from exc + + # Create a mock binary that replays recorded JSONL fixtures mock_path = os.path.join(temp_dir, "claude") with open(mock_path, "w") as f: - f.write('#!/bin/bash\n') + f.write('#!/usr/bin/env bash\n') f.write('if [ "$1" == "--version" ]; then echo "Claude Code 1.0.0"; exit 0; fi\n') - f.write('echo \'{"type":"message","message":{"role":"assistant","content":"Mock response"}}\'\n') - f.write('echo \'{"type":"result","result":"done"}\'\n') + f.write('prompt=""\n') + f.write('args=("$@")\n') + 
f.write('for ((i=0; i<${#args[@]}; i++)); do\n') + f.write(' if [ "${args[$i]}" == "-p" ]; then\n') + f.write(' prompt="${args[$((i+1))]}"\n') + f.write(' break\n') + f.write(' fi\n') + f.write('done\n') + f.write('prompt_lower="$(printf "%s" "$prompt" | tr "[:upper:]" "[:lower:]")"\n') + f.write(f'fixture_default="{default_fixture}"\n') + f.write('fixture_match="$fixture_default"\n') + for rule in fixture_rules: + matches = rule.get("match", []) + fixture_file = rule.get("file") + if not fixture_file or not matches: + continue + fixture_path = fixtures_dir / fixture_file + for match in matches: + match_escaped = str(match).replace('"', '\\"') + f.write(f'if echo "$prompt_lower" | grep -q "{match_escaped}"; then fixture_match="{fixture_path}"; fi\n') + f.write('cat "$fixture_match"\n') os.chmod(mock_path, 0o755) settings.claude_binary_path = mock_path + else: + # Ensure the real binary is available when requested + if not shutil.which(settings.claude_binary_path) and not os.path.exists(settings.claude_binary_path): + raise RuntimeError( + f"CLAUDE_CODE_API_USE_REAL_CLAUDE=1 but binary not found at {settings.claude_binary_path}" + ) settings.database_url = f"sqlite:///{temp_dir}/test.db" settings.debug = True @@ -88,7 +126,7 @@ async def async_test_client(): def sample_chat_request(): """Sample chat completion request.""" return { - "model": "claude-3-5-sonnet-20241022", + "model": get_default_model(), "messages": [ {"role": "user", "content": "Hi"} ], @@ -100,7 +138,7 @@ def sample_chat_request(): def sample_streaming_request(): """Sample streaming chat completion request.""" return { - "model": "claude-3-5-sonnet-20241022", + "model": get_default_model(), "messages": [ {"role": "user", "content": "Tell me a joke"} ], @@ -123,7 +161,7 @@ def sample_session_request(): return { "project_id": "test-project", "title": "Test Session", - "model": "claude-3-5-sonnet-20241022" + "model": get_default_model() } @@ -139,6 +177,9 @@ def pytest_configure(config): 
config.addinivalue_line( "markers", "unit: marks tests as unit tests" ) + config.addinivalue_line( + "markers", "e2e: marks tests as end-to-end tests" + ) def pytest_collection_modifyitems(config, items): diff --git a/tests/fixtures/claude_stream_simple.jsonl b/tests/fixtures/claude_stream_simple.jsonl new file mode 100644 index 0000000..5d42d1a --- /dev/null +++ b/tests/fixtures/claude_stream_simple.jsonl @@ -0,0 +1,3 @@ +{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_simple_1","model":"claude-3-5-haiku-20241022","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} +{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello! How can I help today?"}]},"session_id":"sess_simple_1","model":"claude-3-5-haiku-20241022"} +{"type":"result","result":"ok","session_id":"sess_simple_1","model":"claude-3-5-haiku-20241022","usage":{"input_tokens":12,"output_tokens":8},"cost_usd":0.00002,"duration_ms":1200,"num_turns":1} diff --git a/tests/fixtures/claude_stream_tool_calls.jsonl b/tests/fixtures/claude_stream_tool_calls.jsonl new file mode 100644 index 0000000..d893cff --- /dev/null +++ b/tests/fixtures/claude_stream_tool_calls.jsonl @@ -0,0 +1,5 @@ +{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} +{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"I'll list the files."},{"type":"tool_use","id":"toolu_123","name":"bash","input":{"command":"ls -1"}}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022"} +{"type":"tool_result","message":{"role":"tool","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"README.md\nclaude_code_api\n","is_error":false}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022"} 
+{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"I found README.md and claude_code_api."}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022"} +{"type":"result","result":"ok","session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022","usage":{"input_tokens":20,"output_tokens":15},"cost_usd":0.00005,"duration_ms":1500,"num_turns":1} diff --git a/tests/fixtures/index.json b/tests/fixtures/index.json new file mode 100644 index 0000000..d8cee50 --- /dev/null +++ b/tests/fixtures/index.json @@ -0,0 +1,17 @@ +[ + { + "match": [ + "list files", + "list the files", + "use a tool" + ], + "file": "claude_stream_tool_calls.jsonl" + }, + { + "match": [ + "hi", + "hello" + ], + "file": "claude_stream_simple.jsonl" + } +] diff --git a/tests/test_api.sh b/tests/test_api.sh index e29e823..3869400 100644 --- a/tests/test_api.sh +++ b/tests/test_api.sh @@ -1,19 +1,19 @@ #!/bin/bash -echo "🚀 Testing Claude Code API Gateway" +echo "Testing Claude Code API Gateway" echo -echo "📋 Testing Models Endpoint:" +echo "Testing Models Endpoint:" curl -s http://localhost:8000/v1/models | jq . echo -echo "❤️ Testing Health Endpoint:" +echo "Testing Health Endpoint:" curl -s http://localhost:8000/health | jq . echo -echo "✅ API is working! No authentication required." -echo "📝 Available endpoints:" +echo "API is working! No authentication required." 
+echo "Available endpoints:" echo " - GET /v1/models" echo " - POST /v1/chat/completions" echo " - GET /health" -echo " - GET /docs (API documentation)" \ No newline at end of file +echo " - GET /docs (API documentation)" diff --git a/tests/test_claude_working.py b/tests/test_claude_working.py index 6617a24..7eba735 100644 --- a/tests/test_claude_working.py +++ b/tests/test_claude_working.py @@ -1,110 +1,44 @@ -#!/usr/bin/env python3 -""" -Test that demonstrates Claude CLI is working and create a proper test -""" +"""Fixture-based tests for Claude CLI output parsing.""" -import subprocess import json -import time +from pathlib import Path -def test_claude_directly(): - """Test Claude CLI directly to prove it works""" - print("🧪 Testing Claude CLI directly...") - - cmd = [ - "/usr/local/share/nvm/versions/node/v23.11.1/bin/claude", - "-p", "Say hello and return", - "--model", "claude-3-5-haiku-20241022", - "--output-format", "stream-json", - "--verbose", - "--dangerously-skip-permissions" - ] - - try: - result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) - print(f"✅ Exit code: {result.returncode}") - if result.stdout: - lines = result.stdout.strip().split('\n') - print(f"✅ Got {len(lines)} lines of output") - for i, line in enumerate(lines[:3]): # Show first 3 lines - try: - data = json.loads(line) - print(f" Line {i+1}: {data.get('type', 'unknown')} - {line[:100]}...") - except: - print(f" Line {i+1}: {line[:100]}...") - - if result.stderr: - print(f"⚠️ stderr: {result.stderr[:200]}") - - return result.returncode == 0 - - except subprocess.TimeoutExpired: - print("❌ Command timed out") - return False - except Exception as e: - print(f"❌ Error: {e}") - return False +from claude_code_api.utils.streaming import create_non_streaming_response +from claude_code_api.models.claude import get_default_model -def test_api_with_real_claude(): - """Test if our API is working now""" - import requests - - print("\n🌐 Testing API with real Claude...") - - try: 
- # Test health first - health = requests.get("http://localhost:8000/health", timeout=5) - print(f"Health check: {health.status_code}") - - # Test chat completion - payload = { - "model": "claude-3-5-haiku-20241022", - "messages": [{"role": "user", "content": "Hi"}], - "stream": False - } - - print("Making chat completion request...") - response = requests.post( - "http://localhost:8000/v1/chat/completions", - json=payload, - timeout=30 - ) - - print(f"✅ Chat completion status: {response.status_code}") - if response.status_code == 200: - data = response.json() - if 'choices' in data: - content = data['choices'][0]['message']['content'] - print(f"✅ Response: {content[:100]}...") - return True - else: - print(f"❌ Error response: {response.text[:200]}") - - except requests.exceptions.Timeout: - print("❌ API request timed out") - except Exception as e: - print(f"❌ API test error: {e}") - - return False +FIXTURES_DIR = Path(__file__).parent / "fixtures" -if __name__ == "__main__": - print("🚀 Testing Claude Code Integration") - print("=" * 50) - - # Test 1: Direct Claude CLI - claude_works = test_claude_directly() - - # Test 2: API with Claude - api_works = test_api_with_real_claude() - - print("\n" + "=" * 50) - print(f"📊 Results:") - print(f" Claude CLI: {'✅ WORKS' if claude_works else '❌ FAILS'}") - print(f" API: {'✅ WORKS' if api_works else '❌ FAILS'}") - - if claude_works and not api_works: - print("\n💡 Claude CLI works but API fails - this means the issue is in our Python async handling!") - elif claude_works and api_works: - print("\n🎉 Everything works! 
API is ready!") - else: - print("\n❌ Claude CLI itself has issues") \ No newline at end of file + +def load_fixture(filename: str): + path = FIXTURES_DIR / filename + return [json.loads(line) for line in path.read_text().splitlines() if line.strip()] + + +def test_fixture_simple_non_streaming_response(): + """Ensure basic fixture output produces a valid response.""" + messages = load_fixture("claude_stream_simple.jsonl") + response = create_non_streaming_response( + messages=messages, + session_id="sess_simple_1", + model=get_default_model() + ) + + choice = response["choices"][0] + assert choice["message"]["role"] == "assistant" + assert choice["message"]["content"].startswith("Hello!") + assert response["usage"]["total_tokens"] >= 0 + + +def test_fixture_tool_calls_response(): + """Ensure tool calls are surfaced from fixture output.""" + messages = load_fixture("claude_stream_tool_calls.jsonl") + response = create_non_streaming_response( + messages=messages, + session_id="sess_tool_1", + model=get_default_model() + ) + + message = response["choices"][0]["message"] + assert "tool_calls" in message + assert len(message["tool_calls"]) > 0 + assert message["tool_calls"][0]["function"]["name"] == "bash" diff --git a/tests/test_e2e_live_api.py b/tests/test_e2e_live_api.py new file mode 100644 index 0000000..09beaa8 --- /dev/null +++ b/tests/test_e2e_live_api.py @@ -0,0 +1,87 @@ +"""End-to-end tests against a running API server.""" + +import os +import json +import pytest +import httpx +from claude_code_api.models.claude import get_default_model + + +BASE_URL = os.getenv("CLAUDE_CODE_API_BASE_URL", "http://localhost:8000") + + +def _should_run_e2e() -> bool: + return os.getenv("CLAUDE_CODE_API_E2E") == "1" + + +@pytest.fixture(scope="session") +def live_client(): + if not _should_run_e2e(): + pytest.skip("Set CLAUDE_CODE_API_E2E=1 to run live API tests.") + + try: + response = httpx.get(f"{BASE_URL}/health", timeout=5) + if response.status_code != 200: + 
pytest.skip(f"API not healthy at {BASE_URL}.") + except Exception as exc: + pytest.skip(f"API not reachable at {BASE_URL}: {exc}") + + with httpx.Client(base_url=BASE_URL, timeout=60) as client: + yield client + + +def _parse_sse_lines(lines): + events = [] + for line in lines: + if not line.startswith("data: "): + continue + payload = line[6:] + if payload == "[DONE]": + break + events.append(json.loads(payload)) + return events + + +@pytest.mark.e2e +def test_live_health(live_client): + response = live_client.get("/health") + assert response.status_code == 200 + payload = response.json() + assert payload.get("status") == "healthy" + + +@pytest.mark.e2e +def test_live_models(live_client): + response = live_client.get("/v1/models") + assert response.status_code == 200 + payload = response.json() + assert payload.get("object") == "list" + assert payload.get("data") + + +@pytest.mark.e2e +def test_live_chat_completion(live_client): + payload = { + "model": get_default_model(), + "messages": [{"role": "user", "content": "Say only 'hi'."}], + "stream": False + } + response = live_client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + data = response.json() + assert data.get("object") == "chat.completion" + assert data.get("choices") + + +@pytest.mark.e2e +def test_live_chat_streaming(live_client): + payload = { + "model": get_default_model(), + "messages": [{"role": "user", "content": "Say only 'hi'."}], + "stream": True + } + with live_client.stream("POST", "/v1/chat/completions", json=payload) as response: + assert response.status_code == 200 + lines = [line for line in response.iter_lines() if line] + events = _parse_sse_lines(lines) + assert any(event.get("object") == "chat.completion.chunk" for event in events) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index e323d34..56c08d7 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -20,6 +20,19 @@ import tempfile import shutil + +def 
parse_sse_events(body_text: str) -> List[Dict[str, Any]]: + """Parse SSE events from a streaming response body.""" + events = [] + for line in body_text.splitlines(): + if not line.startswith("data: "): + continue + payload = line[6:] + if payload == "[DONE]": + break + events.append(json.loads(payload)) + return events + # Import the FastAPI app import sys from pathlib import Path @@ -30,6 +43,11 @@ from claude_code_api.main import app from claude_code_api.core.config import settings +from claude_code_api.models.claude import get_available_models, get_default_model + + +AVAILABLE_MODELS = get_available_models() +DEFAULT_MODEL = get_default_model() class TestConfig: @@ -129,11 +147,12 @@ def test_list_models(self, client): def test_get_specific_model(self, client): """Test getting specific model.""" # Test Claude model - response = client.get("/v1/models/claude-3-5-haiku-20241022") + model_id = AVAILABLE_MODELS[0].id + response = client.get(f"/v1/models/{model_id}") assert response.status_code == 200 data = response.json() - assert data["id"] == "claude-3-5-haiku-20241022" + assert data["id"] == model_id assert data["object"] == "model" def test_get_openai_alias_model(self, client): @@ -162,7 +181,7 @@ class TestChatCompletions: def test_simple_chat_completion_non_streaming(self, client): """Test simple non-streaming chat completion.""" request_data = { - "model": "claude-3-5-haiku-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi"} ], @@ -176,7 +195,7 @@ def test_simple_chat_completion_non_streaming(self, client): assert "id" in data assert data["object"] == "chat.completion" assert "created" in data - assert data["model"] == "claude-3-5-haiku-20241022" + assert data["model"] == DEFAULT_MODEL assert "choices" in data assert len(data["choices"]) > 0 @@ -190,7 +209,7 @@ def test_simple_chat_completion_non_streaming(self, client): def test_chat_completion_with_system_prompt(self, client): """Test chat completion with system prompt.""" 
request_data = { - "model": "claude-3-5-haiku-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} @@ -202,7 +221,7 @@ def test_chat_completion_with_system_prompt(self, client): assert response.status_code == 200 data = response.json() - assert data["model"] == "claude-3-5-haiku-20241022" + assert data["model"] == DEFAULT_MODEL assert len(data["choices"]) > 0 def test_chat_completion_with_invalid_model_fallback(self, client): @@ -222,7 +241,7 @@ def test_chat_completion_with_invalid_model_fallback(self, client): def test_chat_completion_streaming(self, client): """Test streaming chat completion.""" request_data = { - "model": "claude-3-5-haiku-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Tell me a short joke"} ], @@ -231,17 +250,60 @@ def test_chat_completion_streaming(self, client): response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" + assert "text/event-stream" in response.headers["content-type"] # Check that we get streaming data content = response.text assert "data: " in content - assert "event: " in content or "[DONE]" in content + assert "[DONE]" in content + events = parse_sse_events(content) + assert any(event.get("object") == "chat.completion.chunk" for event in events) + + def test_chat_completion_with_tool_calls(self, client): + """Test chat completion that includes tool calls.""" + request_data = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "user", "content": "Please use a tool to list files"} + ], + "stream": False + } + + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + choice = data["choices"][0] + message = choice["message"] + assert "tool_calls" in message + assert len(message["tool_calls"]) > 
0 + assert message["tool_calls"][0]["function"]["name"] == "bash" + + def test_chat_completion_streaming_tool_calls(self, client): + """Test streaming tool call deltas.""" + request_data = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "user", "content": "Please use a tool to list files"} + ], + "stream": True + } + + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + assert "text/event-stream" in response.headers["content-type"] + + events = parse_sse_events(response.text) + assert any( + choice.get("delta", {}).get("tool_calls") + for event in events + for choice in event.get("choices", []) + ) def test_chat_completion_with_project_context(self, client): """Test chat completion with project context.""" request_data = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi, I'm working on a Python project"} ], @@ -259,7 +321,7 @@ def test_chat_completion_with_project_context(self, client): def test_chat_completion_missing_messages(self, client): """Test chat completion with missing messages.""" request_data = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [], "stream": False } @@ -274,7 +336,7 @@ def test_chat_completion_missing_messages(self, client): def test_chat_completion_no_user_message(self, client): """Test chat completion with no user message.""" request_data = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant."} ], @@ -310,7 +372,7 @@ def test_conversation_continuity(self, client): """Test conversation continuity across messages.""" # First message request_data_1 = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "My name is Alice"} ], @@ -326,7 +388,7 @@ def test_conversation_continuity(self, client): if session_id: # Follow-up message in same 
session request_data_2 = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "My name is Alice"}, {"role": "assistant", "content": data_1["choices"][0]["message"]["content"]}, @@ -342,7 +404,7 @@ def test_conversation_continuity(self, client): def test_multiple_user_messages(self, client): """Test handling multiple user messages.""" request_data = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi"}, {"role": "user", "content": "How are you doing today?"} @@ -435,7 +497,7 @@ def test_create_session(self, client): session_data = { "project_id": "test-project", "title": "Test Session", - "model": "claude-3-5-sonnet-20241022" + "model": DEFAULT_MODEL } response = client.post("/v1/sessions", json=session_data) @@ -443,7 +505,7 @@ def test_create_session(self, client): data = response.json() assert data["project_id"] == "test-project" - assert data["model"] == "claude-3-5-sonnet-20241022" + assert data["model"] == DEFAULT_MODEL assert "id" in data assert "created_at" in data @@ -485,7 +547,7 @@ def test_missing_required_fields(self, client): def test_invalid_message_role(self, client): """Test handling of invalid message role.""" request_data = { - "model": "claude-3-5-sonnet-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "invalid_role", "content": "Hi"} ] @@ -501,7 +563,7 @@ class TestRealWorldScenarios: def test_simple_greeting(self, client): """Test simple greeting - most common use case.""" request_data = { - "model": "claude-3-5-haiku-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi"} ] @@ -521,7 +583,7 @@ def test_simple_greeting(self, client): def test_code_generation_request(self, client): """Test code generation request.""" request_data = { - "model": "claude-3-5-haiku-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Write a Python function to calculate fibonacci 
numbers"} ] @@ -539,7 +601,7 @@ def test_multi_turn_conversation(self, client): """Test multi-turn conversation simulation.""" # Simulate a multi-turn conversation in a single request request_data = { - "model": "claude-3-5-haiku-20241022", + "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi, I'm learning Python"}, {"role": "assistant", "content": "Hello! That's great that you're learning Python. It's an excellent programming language for beginners and professionals alike. What specifically would you like to know about Python?"}, diff --git a/tests/test_openapi.py b/tests/test_openapi.py new file mode 100644 index 0000000..0fc4d5e --- /dev/null +++ b/tests/test_openapi.py @@ -0,0 +1,22 @@ +"""OpenAPI schema checks for streaming and tool calls.""" + + +def test_openapi_chat_completions_schema(test_client): + response = test_client.get("/openapi.json") + assert response.status_code == 200 + + schema = response.json() + assert "/v1/chat/completions" in schema["paths"] + + chat_post = schema["paths"]["/v1/chat/completions"]["post"] + assert "requestBody" in chat_post + + content = chat_post["responses"]["200"]["content"] + assert "application/json" in content + assert "text/event-stream" in content + + components = schema.get("components", {}).get("schemas", {}) + assert "ChatMessage" in components + assert "tool_calls" in components["ChatMessage"]["properties"] + assert "ChatCompletionChunkDelta" in components + assert "tool_calls" in components["ChatCompletionChunkDelta"]["properties"] diff --git a/tests/test_real_api.py b/tests/test_real_api.py index 8cb00ab..7a866e6 100755 --- a/tests/test_real_api.py +++ b/tests/test_real_api.py @@ -12,6 +12,7 @@ import signal import os from typing import Optional +from claude_code_api.models.claude import get_default_model class RealAPITester: def __init__(self, base_url: str = "http://localhost:8000"): @@ -22,7 +23,7 @@ def test_health(self) -> bool: """Test health endpoint.""" try: response = 
self.session.get(f"{self.base_url}/health", timeout=5) - print(f"🔍 Health Check: {response.status_code}") + print(f"Health Check: {response.status_code}") if response.status_code == 200: data = response.json() print(f" Status: {data.get('status')}") @@ -32,14 +33,14 @@ def test_health(self) -> bool: print(f" Error: {response.text}") return False except Exception as e: - print(f"❌ Health check failed: {e}") + print(f"Health check failed: {e}") return False def test_models(self) -> bool: """Test models endpoint.""" try: response = self.session.get(f"{self.base_url}/v1/models", timeout=5) - print(f"🔍 Models API: {response.status_code}") + print(f"Models API: {response.status_code}") if response.status_code == 200: data = response.json() models = data.get('data', []) @@ -51,7 +52,7 @@ def test_models(self) -> bool: print(f" Error: {response.text}") return False except Exception as e: - print(f"❌ Models test failed: {e}") + print(f"Models test failed: {e}") return False def test_auth_bypass(self) -> bool: @@ -59,35 +60,35 @@ def test_auth_bypass(self) -> bool: try: # Test without any auth headers response = self.session.get(f"{self.base_url}/v1/models", timeout=5) - print(f"🔍 Auth Bypass Test: {response.status_code}") + print(f"Auth Bypass Test: {response.status_code}") if response.status_code == 200: - print(" ✅ API works without authentication") + print(" API works without authentication") return True elif response.status_code == 401: - print(" ❌ API requires authentication") + print(" API requires authentication") error = response.json() print(f" Error: {error.get('error', {}).get('message', 'Unknown auth error')}") return False else: - print(f" ❌ Unexpected status: {response.text}") + print(f" Unexpected status: {response.text}") return False except Exception as e: - print(f"❌ Auth test failed: {e}") + print(f"Auth test failed: {e}") return False def test_chat_completion(self) -> bool: """Test chat completion endpoint (may be slow).""" try: payload = { - "model": 
"claude-3-5-haiku-20241022", + "model": get_default_model(), "messages": [ {"role": "user", "content": "Say 'test successful' and nothing else"} ], "stream": False } - print("🔍 Chat Completion (this may take a while)...") + print("Chat Completion (this may take a while)...") response = self.session.post( f"{self.base_url}/v1/chat/completions", json=payload, @@ -107,15 +108,15 @@ def test_chat_completion(self) -> bool: return response.status_code == 200 except requests.exceptions.Timeout: - print(" ⏰ Chat completion timed out (expected with mock setup)") + print(" Chat completion timed out (expected with mock setup)") return True # Timeout is expected with echo mock except Exception as e: - print(f"❌ Chat completion failed: {e}") + print(f"Chat completion failed: {e}") return False def run_all_tests(self) -> bool: """Run all tests and return overall success.""" - print("🚀 REAL End-to-End API Tests") + print("REAL End-to-End API Tests") print("=" * 40) tests = [ @@ -127,26 +128,26 @@ def run_all_tests(self) -> bool: results = [] for test_name, test_func in tests: - print(f"\n📋 {test_name}:") + print(f"\n{test_name}:") try: result = test_func() results.append(result) - status = "✅ PASS" if result else "❌ FAIL" + status = "PASS" if result else "FAIL" print(f" {status}") except Exception as e: - print(f" ❌ FAIL: {e}") + print(f" FAIL: {e}") results.append(False) print("\n" + "=" * 40) passed = sum(results) total = len(results) - print(f"📊 Results: {passed}/{total} tests passed") + print(f"Results: {passed}/{total} tests passed") if passed == total: - print("🎉 ALL TESTS PASSED!") + print("ALL TESTS PASSED!") return True else: - print("💥 SOME TESTS FAILED!") + print("SOME TESTS FAILED!") return False @@ -160,14 +161,14 @@ def check_server_running(url: str = "http://localhost:8000") -> bool: def main(): - print("🔍 Checking if API server is running...") + print("Checking if API server is running...") if not check_server_running(): - print("❌ API server not running on 
http://localhost:8000") - print("💡 Start the server with: make start") + print("API server not running on http://localhost:8000") + print("Start the server with: make start") sys.exit(1) - print("✅ Server is running!") + print("Server is running!") print() tester = RealAPITester() @@ -177,4 +178,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 5539cc11d21f4e9741a76c5a5752278dfa71b814 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Wed, 4 Feb 2026 15:48:03 +0000 Subject: [PATCH 2/9] Fix QA & coverage --- Makefile | 6 +- README.md | 1 - claude_code_api/api/__init__.py | 7 +- claude_code_api/api/chat.py | 508 ++++++++++-------- claude_code_api/api/models.py | 76 ++- claude_code_api/api/projects.py | 147 +++-- claude_code_api/api/sessions.py | 97 ++-- claude_code_api/config/models.json | 10 - claude_code_api/core/__init__.py | 2 +- claude_code_api/core/auth.py | 95 ++-- claude_code_api/core/claude_manager.py | 278 +++++----- claude_code_api/core/config.py | 103 ++-- claude_code_api/core/database.py | 124 +++-- claude_code_api/core/security.py | 9 +- claude_code_api/core/session_manager.py | 194 ++++--- claude_code_api/main.py | 68 +-- claude_code_api/models/__init__.py | 2 +- claude_code_api/models/claude.py | 73 ++- claude_code_api/models/openai.py | 206 +++++-- claude_code_api/tests/test_gpt_turbo.py | 34 +- claude_code_api/utils/__init__.py | 2 +- claude_code_api/utils/parser.py | 250 ++++----- claude_code_api/utils/streaming.py | 407 +++++++------- claude_code_api/utils/time.py | 13 + setup.cfg | 6 + tests/__init__.py | 1 + tests/conftest.py | 101 ++-- tests/fixtures/claude_stream_simple.jsonl | 6 +- tests/fixtures/claude_stream_tool_calls.jsonl | 10 +- tests/model_utils.py | 14 + tests/test_auth.py | 196 +++++++ tests/test_claude_manager_unit.py | 48 ++ tests/test_claude_working.py | 46 +- tests/test_config.py | 106 ++++ tests/test_e2e_live_api.py | 36 +- tests/test_end_to_end.py | 320 ++++++----- tests/test_parser.py | 170 
++++++ tests/test_real_api.py | 76 +-- tests/test_security.py | 5 + tests/test_session_manager_unit.py | 77 +++ tests/test_utils_time.py | 18 + 41 files changed, 2439 insertions(+), 1509 deletions(-) create mode 100644 claude_code_api/utils/time.py create mode 100644 setup.cfg create mode 100644 tests/__init__.py create mode 100644 tests/model_utils.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_claude_manager_unit.py create mode 100644 tests/test_config.py create mode 100644 tests/test_parser.py create mode 100644 tests/test_session_manager_unit.py create mode 100644 tests/test_utils_time.py diff --git a/Makefile b/Makefile index 2ec321c..f5163df 100644 --- a/Makefile +++ b/Makefile @@ -89,7 +89,7 @@ kill: sonar: ## Run sonar-scanner for SonarQube analysis @mkdir -p $(SONAR_DIR) $(COVERAGE_DIR) @echo "Generating coverage report for SonarQube..." - @python -m pytest --cov=claude_code_api --cov-report=xml --cov-report=term-missing -v tests/ + @python -m pytest --cov=claude_code_api --cov-report=xml:$(COVERAGE_DIR)/coverage.xml --cov-report=term-missing --junitxml=$(SONAR_DIR)/xunit-report.xml -v tests/ @if command -v sonar-scanner >/dev/null 2>&1; then \ if [ -f ".env.vault" ]; then \ . ./.env.vault; \ @@ -131,7 +131,7 @@ sonar-cloud: ## Run sonar-scanner for SonarCloud (uses different token/env) coverage-sonar: ## Generate coverage for SonarQube @mkdir -p $(COVERAGE_DIR) - @python -m pytest --cov=claude_code_api --cov-report=xml --cov-report=term-missing -v tests/ + @python -m pytest --cov=claude_code_api --cov-report=xml:$(COVERAGE_DIR)/coverage.xml --cov-report=term-missing --junitxml=$(SONAR_DIR)/xunit-report.xml -v tests/ @echo "Coverage XML generated: $(COVERAGE_DIR)/coverage.xml" sbom: ## Generate SBOM with syft @@ -220,4 +220,4 @@ help: @echo " make kill PORT=X - Kill process on specific port" @echo "" @echo "IMPORTANT: Both implementations are functionally equivalent!" 
- @echo "Use Python or TypeScript - both provide the same OpenAI-compatible API." \ No newline at end of file + @echo "Use Python or TypeScript - both provide the same OpenAI-compatible API." diff --git a/README.md b/README.md index 4512d4e..d010c1b 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,6 @@ Override with `CLAUDE_CODE_API_MODELS_PATH` to point at a custom JSON file. - `claude-opus-4-5-20250929` - Claude Opus 4.5 (Most powerful) - `claude-sonnet-4-5-20250929` - Claude Sonnet 4.5 (Latest Sonnet) - `claude-haiku-4-5-20250929` - Claude Haiku 4.5 (Fast & cost-effective) -- `claude-3-5-haiku-20241022` - Claude Haiku 3.5 (Fast & cost-effective) ## Quick Start diff --git a/claude_code_api/api/__init__.py b/claude_code_api/api/__init__.py index 9136fee..a648775 100644 --- a/claude_code_api/api/__init__.py +++ b/claude_code_api/api/__init__.py @@ -5,9 +5,4 @@ from claude_code_api.api.projects import router as projects_router from claude_code_api.api.sessions import router as sessions_router -__all__ = [ - "chat_router", - "models_router", - "projects_router", - "sessions_router" -] +__all__ = ["chat_router", "models_router", "projects_router", "sessions_router"] diff --git a/claude_code_api/api/chat.py b/claude_code_api/api/chat.py index 9c92334..3429830 100644 --- a/claude_code_api/api/chat.py +++ b/claude_code_api/api/chat.py @@ -1,22 +1,31 @@ """Chat completions API endpoint - OpenAI compatible.""" import json -from typing import Dict, Any -from fastapi import APIRouter, Request, HTTPException, status +from typing import Any, Dict, Tuple + +import structlog +from fastapi import APIRouter, HTTPException, Request, status from fastapi.responses import StreamingResponse from pydantic import ValidationError -import structlog +from claude_code_api.core.claude_manager import create_project_directory +from claude_code_api.core.session_manager import SessionManager +from claude_code_api.models.claude import validate_claude_model from claude_code_api.models.openai 
import ( - ChatCompletionRequest, + ChatCompletionRequest, ChatCompletionResponse, - ErrorResponse + ErrorResponse, +) +from claude_code_api.utils.parser import ( + ClaudeOutputParser, + OpenAIConverter, + estimate_tokens, + normalize_claude_message, +) +from claude_code_api.utils.streaming import ( + create_non_streaming_response, + create_sse_response, ) -from claude_code_api.models.claude import validate_claude_model -from claude_code_api.core.claude_manager import create_project_directory -from claude_code_api.core.session_manager import SessionManager -from claude_code_api.utils.streaming import create_sse_response, create_non_streaming_response -from claude_code_api.utils.parser import ClaudeOutputParser, OpenAIConverter, estimate_tokens, normalize_claude_message logger = structlog.get_logger() router = APIRouter() @@ -30,54 +39,231 @@ }, "text/event-stream": { "schema": {"$ref": "#/components/schemas/ChatCompletionChunk"} - } - } + }, + }, }, 400: {"model": ErrorResponse}, 422: {"model": ErrorResponse}, 503: {"model": ErrorResponse}, - 500: {"model": ErrorResponse} + 500: {"model": ErrorResponse}, } +def _http_error( + status_code: int, message: str, error_type: str, code: str +) -> HTTPException: + return HTTPException( + status_code=status_code, + detail={"error": {"message": message, "type": error_type, "code": code}}, + ) + + +async def _log_raw_request(req: Request) -> None: + raw_body = await req.body() + content_type = req.headers.get("content-type", "unknown") + logger.info( + "Raw request received", + content_type=content_type, + body_size=len(raw_body), + user_agent=req.headers.get("user-agent", "unknown"), + raw_body=raw_body.decode()[:1000] if raw_body else "empty", + ) + + +def _extract_prompts(request: ChatCompletionRequest) -> Tuple[str, str]: + if not request.messages: + raise _http_error( + status.HTTP_400_BAD_REQUEST, + "At least one message is required", + "invalid_request_error", + "missing_messages", + ) + user_messages = [msg for msg in 
request.messages if msg.role == "user"] + if not user_messages: + raise _http_error( + status.HTTP_400_BAD_REQUEST, + "At least one user message is required", + "invalid_request_error", + "missing_user_message", + ) + user_prompt = user_messages[-1].get_text_content() + system_messages = [msg for msg in request.messages if msg.role == "system"] + system_prompt = ( + system_messages[0].get_text_content() + if system_messages + else request.system_prompt + ) + return user_prompt, system_prompt + + +async def _resolve_session( + session_manager: SessionManager, + request: ChatCompletionRequest, + project_id: str, + claude_model: str, + system_prompt: str, +) -> str: + if request.session_id: + session_id = request.session_id + session_info = await session_manager.get_session(session_id) + if not session_info: + raise _http_error( + status.HTTP_404_NOT_FOUND, + f"Session {session_id} not found", + "invalid_request_error", + "session_not_found", + ) + return session_id + return await session_manager.create_session( + project_id=project_id, model=claude_model, system_prompt=system_prompt + ) + + +async def _collect_non_streaming_response( + claude_process, + session_manager: SessionManager, + session_id: str, + model: str, + project_id: str, +) -> Dict[str, Any]: + messages, parser = await _gather_claude_messages(claude_process) + _log_message_summary(messages) + + usage_summary = OpenAIConverter.calculate_usage(parser) + await _update_session_usage( + session_manager, session_id, usage_summary, parser.total_cost + ) + + response = _build_non_streaming_response( + messages, session_id, model, usage_summary, project_id + ) + _log_response_payload(response) + return response + + +async def _gather_claude_messages(claude_process) -> Tuple[list, ClaudeOutputParser]: + messages = [] + parser = ClaudeOutputParser() + async for claude_message in claude_process.get_output(): + _log_claude_message(claude_message) + messages.append(claude_message) + normalized = 
normalize_claude_message(claude_message) + if not normalized: + continue + parser.parse_message(normalized) + if parser.is_final_message(normalized): + break + return messages, parser + + +def _log_claude_message(claude_message: Any) -> None: + logger.info( + "Received Claude message", + message_type=( + claude_message.get("type") + if isinstance(claude_message, dict) + else type(claude_message).__name__ + ), + message_keys=( + list(claude_message.keys()) if isinstance(claude_message, dict) else [] + ), + has_assistant_content=bool( + isinstance(claude_message, dict) + and claude_message.get("type") == "assistant" + and claude_message.get("message", {}).get("content") + ), + message_preview=str(claude_message)[:200] if claude_message else "None", + ) + + +def _log_message_summary(messages: list) -> None: + logger.info( + "Claude messages collected", + total_messages=len(messages), + message_types=[ + msg.get("type") if isinstance(msg, dict) else type(msg).__name__ + for msg in messages + ], + ) + + +async def _update_session_usage( + session_manager: SessionManager, + session_id: str, + usage_summary: Dict[str, Any], + total_cost: float, +) -> None: + await session_manager.update_session( + session_id=session_id, + tokens_used=usage_summary.get("total_tokens", 0), + cost=total_cost, + ) + + +def _build_non_streaming_response( + messages: list, + session_id: str, + model: str, + usage_summary: Dict[str, Any], + project_id: str, +) -> Dict[str, Any]: + response = create_non_streaming_response( + messages=messages, session_id=session_id, model=model, usage=usage_summary + ) + response["project_id"] = project_id + return response + + +def _log_response_payload(response: Dict[str, Any]) -> None: + choices = response.get("choices") or [] + first_choice = choices[0] if choices else {} + message = first_choice.get("message", {}) if isinstance(first_choice, dict) else {} + content = message.get("content") if isinstance(message, dict) else None + + logger.info( + "Returning 
chat completion response", + response_id=response.get("id"), + choices_count=len(choices), + has_choices_0=bool(choices), + choices_0_keys=( + list(first_choice.keys()) if isinstance(first_choice, dict) else [] + ), + message_keys=list(message.keys()) if isinstance(message, dict) else [], + content_length=len(content or ""), + full_response_keys=list(response.keys()), + response_size=len(str(response)), + ) + + @router.post( "/chat/completions", response_model=ChatCompletionResponse, - responses=CHAT_COMPLETION_RESPONSES + responses=CHAT_COMPLETION_RESPONSES, ) -async def create_chat_completion( - request: ChatCompletionRequest, - req: Request -) -> Any: +async def create_chat_completion(request: ChatCompletionRequest, req: Request) -> Any: """Create a chat completion, compatible with OpenAI API.""" - + # Log raw request for debugging try: - raw_body = await req.body() - content_type = req.headers.get("content-type", "unknown") - logger.info( - "Raw request received", - content_type=content_type, - body_size=len(raw_body), - user_agent=req.headers.get("user-agent", "unknown"), - raw_body=raw_body.decode()[:1000] if raw_body else "empty" - ) + await _log_raw_request(req) except HTTPException: raise except Exception as e: logger.error("Failed to process request", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"error": {"message": "Internal server error", "type": "internal_error"}} + raise _http_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + "Internal server error", + "internal_error", + "internal_error", ) - + # Get managers from app state session_manager: SessionManager = req.app.state.session_manager claude_manager = req.app.state.claude_manager - + # Extract client info for logging - client_id = getattr(req.state, 'client_id', 'anonymous') - + client_id = getattr(req.state, "client_id", "anonymous") + logger.info( "Chat completion request validated", client_id=client_id, @@ -85,75 +271,28 @@ async def 
create_chat_completion( messages_count=len(request.messages), stream=request.stream, project_id=request.project_id, - session_id=request.session_id + session_id=request.session_id, ) - + try: # Validate model claude_model = validate_claude_model(request.model) - - # Validate message format - if not request.messages: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": { - "message": "At least one message is required", - "type": "invalid_request_error", - "code": "missing_messages" - } - } - ) - - # Extract the user prompt (last user message) - user_messages = [msg for msg in request.messages if msg.role == "user"] - if not user_messages: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": { - "message": "At least one user message is required", - "type": "invalid_request_error", - "code": "missing_user_message" - } - } - ) - - user_prompt = user_messages[-1].get_text_content() - - # Extract system prompt - system_messages = [msg for msg in request.messages if msg.role == "system"] - system_prompt = system_messages[0].get_text_content() if system_messages else request.system_prompt - + + user_prompt, system_prompt = _extract_prompts(request) + # Handle project context project_id = request.project_id or f"default-{client_id}" project_path = create_project_directory(project_id) - + # Handle session management - if request.session_id: - # Continue existing session - session_id = request.session_id - session_info = await session_manager.get_session(session_id) - - if not session_info: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "error": { - "message": f"Session {session_id} not found", - "type": "invalid_request_error", - "code": "session_not_found" - } - } - ) - else: - # Create new session - session_id = await session_manager.create_session( - project_id=project_id, - model=claude_model, - system_prompt=system_prompt - ) - + session_id = await _resolve_session( + 
session_manager=session_manager, + request=request, + project_id=project_id, + claude_model=claude_model, + system_prompt=system_prompt, + ) + # Start Claude Code process try: claude_process = await claude_manager.create_session( @@ -162,39 +301,31 @@ async def create_chat_completion( prompt=user_prompt, model=claude_model, system_prompt=system_prompt, - resume_session=request.session_id ) except Exception as e: logger.error( - "Failed to create Claude session", - session_id=session_id, - error=str(e) + "Failed to create Claude session", session_id=session_id, error=str(e) ) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail={ - "error": { - "message": f"Failed to start Claude Code: {str(e)}", - "type": "service_unavailable", - "code": "claude_unavailable" - } - } + raise _http_error( + status.HTTP_503_SERVICE_UNAVAILABLE, + f"Failed to start Claude Code: {str(e)}", + "service_unavailable", + "claude_unavailable", ) - + # Use Claude's actual session ID claude_session_id = claude_process.session_id - + # Update session with user message await session_manager.update_session( session_id=claude_session_id, message_content=user_prompt, role="user", - tokens_used=estimate_tokens(user_prompt) + tokens_used=estimate_tokens(user_prompt), ) - + # Handle streaming vs non-streaming if request.stream: - # Return streaming response return StreamingResponse( create_sse_response(claude_session_id, claude_model, claude_process), media_type="text/event-stream", @@ -203,73 +334,18 @@ async def create_chat_completion( "Connection": "keep-alive", "X-Accel-Buffering": "no", "X-Session-ID": claude_session_id, - "X-Project-ID": project_id - } - ) - else: - # Collect all output for non-streaming response - messages = [] - parser = ClaudeOutputParser() - - async for claude_message in claude_process.get_output(): - # Log each message from Claude - logger.info( - "Received Claude message", - message_type=claude_message.get("type") if isinstance(claude_message, 
dict) else type(claude_message).__name__, - message_keys=list(claude_message.keys()) if isinstance(claude_message, dict) else [], - has_assistant_content=bool(isinstance(claude_message, dict) and - claude_message.get("type") == "assistant" and - claude_message.get("message", {}).get("content")), - message_preview=str(claude_message)[:200] if claude_message else "None" - ) - - messages.append(claude_message) - normalized = normalize_claude_message(claude_message) - if normalized: - parser.parse_message(normalized) - if parser.is_final_message(normalized): - break - - # Log what we collected - logger.info( - "Claude messages collected", - total_messages=len(messages), - message_types=[msg.get("type") if isinstance(msg, dict) else type(msg).__name__ for msg in messages] + "X-Project-ID": project_id, + }, ) - - usage_summary = OpenAIConverter.calculate_usage(parser) - await session_manager.update_session( - session_id=claude_session_id, - tokens_used=usage_summary.get("total_tokens", 0), - cost=parser.total_cost - ) - - # Create non-streaming response - response = create_non_streaming_response( - messages=messages, - session_id=claude_session_id, - model=claude_model, - usage=usage_summary - ) - - # Add extension fields - response["project_id"] = project_id - - # Log the complete response before returning - logger.info( - "Returning chat completion response", - response_id=response.get("id"), - choices_count=len(response.get("choices", [])), - has_choices_0=bool(response.get("choices") and len(response["choices"]) > 0), - choices_0_keys=list(response["choices"][0].keys()) if response.get("choices") and len(response["choices"]) > 0 else [], - message_keys=list(response["choices"][0]["message"].keys()) if response.get("choices") and len(response["choices"]) > 0 and "message" in response["choices"][0] else [], - content_length=len((response["choices"][0]["message"].get("content") or "")) if response.get("choices") and len(response["choices"]) > 0 and "message" in 
response["choices"][0] else 0, - full_response_keys=list(response.keys()), - response_size=len(str(response)) - ) - - return response - + + return await _collect_non_streaming_response( + claude_process=claude_process, + session_manager=session_manager, + session_id=claude_session_id, + model=claude_model, + project_id=project_id, + ) + except HTTPException: # Re-raise HTTP exceptions raise @@ -278,7 +354,7 @@ async def create_chat_completion( "Unexpected error in chat completion", client_id=client_id, error=str(e), - exc_info=True + exc_info=True, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -286,22 +362,19 @@ async def create_chat_completion( "error": { "message": "Internal server error", "type": "internal_error", - "code": "unexpected_error" + "code": "unexpected_error", } - } + }, ) @router.get("/chat/completions/{session_id}/status") -async def get_completion_status( - session_id: str, - req: Request -) -> Dict[str, Any]: +async def get_completion_status(session_id: str, req: Request) -> Dict[str, Any]: """Get status of a chat completion session.""" - + session_manager: SessionManager = req.app.state.session_manager claude_manager = req.app.state.claude_manager - + # Get session info session_info = await session_manager.get_session(session_id) if not session_info: @@ -311,15 +384,15 @@ async def get_completion_status( "error": { "message": f"Session {session_id} not found", "type": "not_found", - "code": "session_not_found" + "code": "session_not_found", } - } + }, ) - + # Get Claude process status - claude_process = await claude_manager.get_session(session_id) + claude_process = claude_manager.get_session(session_id) is_running = claude_process is not None and claude_process.is_running - + return { "session_id": session_id, "project_id": session_info.project_id, @@ -329,7 +402,7 @@ async def get_completion_status( "updated_at": session_info.updated_at.isoformat(), "total_tokens": session_info.total_tokens, "total_cost": 
session_info.total_cost, - "message_count": session_info.message_count + "message_count": session_info.message_count, } @@ -339,18 +412,18 @@ async def debug_chat_completion(req: Request) -> Dict[str, Any]: try: raw_body = await req.body() headers = dict(req.headers) - + logger.info( "Debug request", content_type=headers.get("content-type"), body_size=len(raw_body), headers=headers, - raw_body=raw_body.decode() if raw_body else "empty" + raw_body=raw_body.decode() if raw_body else "empty", ) - + if raw_body: json_data = json.loads(raw_body.decode()) - + # Try validation try: request = ChatCompletionRequest(**json_data) @@ -360,45 +433,36 @@ async def debug_chat_completion(req: Request) -> Dict[str, Any]: "parsed_data": { "model": request.model, "messages_count": len(request.messages), - "stream": request.stream - } + "stream": request.stream, + }, } except ValidationError as e: return { "status": "validation_error", "message": str(e), "errors": e.errors(), - "raw_data": json_data + "raw_data": json_data, } - + return {"status": "no_body"} - + except Exception as e: - return { - "status": "error", - "message": str(e) - } + return {"status": "error", "message": str(e)} @router.delete("/chat/completions/{session_id}") -async def stop_completion( - session_id: str, - req: Request -) -> Dict[str, str]: +async def stop_completion(session_id: str, req: Request) -> Dict[str, str]: """Stop a running chat completion session.""" - + session_manager: SessionManager = req.app.state.session_manager claude_manager = req.app.state.claude_manager - + # Stop Claude process await claude_manager.stop_session(session_id) - + # End session - await session_manager.end_session(session_id) - + session_manager.end_session(session_id) + logger.info("Chat completion stopped", session_id=session_id) - - return { - "session_id": session_id, - "status": "stopped" - } + + return {"session_id": session_id, "status": "stopped"} diff --git a/claude_code_api/api/models.py 
b/claude_code_api/api/models.py index 1bf5adf..01c892f 100644 --- a/claude_code_api/api/models.py +++ b/claude_code_api/api/models.py @@ -1,12 +1,12 @@ """Models API endpoint - OpenAI compatible.""" from datetime import datetime -from typing import List -from fastapi import APIRouter, Request + import structlog +from fastapi import APIRouter, Request -from claude_code_api.models.openai import ModelObject, ModelListResponse from claude_code_api.models.claude import get_available_models +from claude_code_api.models.openai import ModelListResponse, ModelObject logger = structlog.get_logger() router = APIRouter() @@ -15,58 +15,53 @@ @router.get("/models", response_model=ModelListResponse) async def list_models(req: Request) -> ModelListResponse: """List available models, compatible with OpenAI API.""" - + # Get Claude Code version for owned_by field claude_manager = req.app.state.claude_manager try: claude_version = await claude_manager.get_version() owned_by = f"anthropic-claude-{claude_version}" - except: + except Exception: owned_by = "anthropic" - + # Get available Claude models claude_models = get_available_models() - + # Convert to OpenAI format model_objects = [] base_timestamp = int(datetime(2024, 1, 1).timestamp()) - + for idx, model_info in enumerate(claude_models): model_obj = ModelObject( id=model_info.id, object="model", created=base_timestamp + idx, # Stagger timestamps - owned_by=owned_by + owned_by=owned_by, ) model_objects.append(model_obj) - + # Only Claude models - no OpenAI aliases all_models = model_objects - + logger.info( - "Listed models", - count=len(all_models), - claude_models=len(model_objects) - ) - - return ModelListResponse( - object="list", - data=all_models + "Listed models", count=len(all_models), claude_models=len(model_objects) ) + return ModelListResponse(object="list", data=all_models) + @router.get("/models/{model_id}") async def get_model(model_id: str, req: Request) -> ModelObject: """Get specific model information.""" - + # Get 
Claude Code version claude_manager = req.app.state.claude_manager try: claude_version = await claude_manager.get_version() owned_by = f"anthropic-claude-{claude_version}" - except: + except Exception: owned_by = "anthropic" - + # Check if it's a Claude model claude_models = get_available_models() for model_info in claude_models: @@ -75,31 +70,32 @@ async def get_model(model_id: str, req: Request) -> ModelObject: id=model_info.id, object="model", created=int(datetime(2024, 1, 1).timestamp()), - owned_by=owned_by + owned_by=owned_by, ) - + # No OpenAI aliases supported - + # Model not found from fastapi import HTTPException, status + raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error": { "message": f"Model {model_id} not found", "type": "not_found", - "code": "model_not_found" + "code": "model_not_found", } - } + }, ) @router.get("/models/capabilities") async def get_model_capabilities(): """Get detailed model capabilities (extension endpoint).""" - + claude_models = get_available_models() - + capabilities = [] for model_info in claude_models: capability = { @@ -112,29 +108,27 @@ async def get_model_capabilities(): "pricing": { "input_cost_per_1k_tokens": model_info.input_cost_per_1k, "output_cost_per_1k_tokens": model_info.output_cost_per_1k, - "currency": "USD" + "currency": "USD", }, "features": [ "text_generation", "conversation", "code_generation", "analysis", - "reasoning" - ] + "reasoning", + ], } - + if model_info.supports_tools: - capability["features"].extend([ - "file_operations", - "bash_execution", - "project_management" - ]) - + capability["features"].extend( + ["file_operations", "bash_execution", "project_management"] + ) + capabilities.append(capability) - + return { "models": capabilities, "total": len(capabilities), "provider": "anthropic", - "adapter": "claude-code-api" + "adapter": "claude-code-api", } diff --git a/claude_code_api/api/projects.py b/claude_code_api/api/projects.py index 0b0795a..abf4063 100644 --- 
a/claude_code_api/api/projects.py +++ b/claude_code_api/api/projects.py @@ -1,23 +1,28 @@ """Projects API endpoint - Extension to OpenAI API.""" -import uuid +import math import os -from datetime import datetime -from typing import List, Optional -from fastapi import APIRouter, Request, HTTPException, status, Depends -from fastapi.responses import JSONResponse +import uuid + import structlog +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.responses import JSONResponse +from sqlalchemy.exc import SQLAlchemyError +from claude_code_api.core.claude_manager import ( + cleanup_project_directory, + create_project_directory, +) +from claude_code_api.core.config import settings +from claude_code_api.core.database import db_manager +from claude_code_api.core.security import validate_path from claude_code_api.models.openai import ( - ProjectInfo, CreateProjectRequest, PaginatedResponse, - PaginationInfo + PaginationInfo, + ProjectInfo, ) -from claude_code_api.core.database import db_manager, Project -from claude_code_api.core.claude_manager import create_project_directory, cleanup_project_directory -from claude_code_api.core.security import validate_path -from claude_code_api.core.config import settings +from claude_code_api.utils.time import utc_now logger = structlog.get_logger() router = APIRouter() @@ -25,50 +30,48 @@ @router.get("/projects", response_model=PaginatedResponse) async def list_projects( - page: int = 1, - per_page: int = 20, - req: Request = None + page: int = 1, per_page: int = 20, req: Request = None ) -> PaginatedResponse: """List all projects.""" - - # TODO: Implement proper pagination with database - # For now, return mock data - projects = [ + page = max(1, page) + per_page = max(1, per_page) + total_items = await db_manager.count_projects() + total_pages = math.ceil(total_items / per_page) if total_items else 0 + projects = await db_manager.list_projects(page, per_page) + + project_infos = [ ProjectInfo( - id="project-1", 
- name="Sample Project", - description="A sample project for testing", - path="/tmp/claude_projects/project-1", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - is_active=True + id=project.id, + name=project.name, + description=project.description, + path=project.path, + created_at=project.created_at, + updated_at=project.updated_at, + is_active=project.is_active, ) + for project in projects ] - + pagination = PaginationInfo( page=page, per_page=per_page, - total_items=len(projects), - total_pages=1, - has_next=False, - has_prev=False - ) - - return PaginatedResponse( - data=projects, - pagination=pagination + total_items=total_items, + total_pages=total_pages, + has_next=page < total_pages, + has_prev=page > 1, ) + return PaginatedResponse(data=project_infos, pagination=pagination) + @router.post("/projects", response_model=ProjectInfo) async def create_project( - project_request: CreateProjectRequest, - req: Request + project_request: CreateProjectRequest, req: Request ) -> ProjectInfo: """Create a new project.""" - + project_id = str(uuid.uuid4()) - + # Create project directory if project_request.path: # Validate path @@ -77,50 +80,50 @@ async def create_project( os.makedirs(project_path, exist_ok=True) else: project_path = create_project_directory(project_id) - + # Create project in database project_data = { "id": project_id, "name": project_request.name, "description": project_request.description, "path": project_path, - "created_at": datetime.utcnow(), - "updated_at": datetime.utcnow(), - "is_active": True + "created_at": utc_now(), + "updated_at": utc_now(), + "is_active": True, } - + try: await db_manager.create_project(project_data) - + project_info = ProjectInfo(**project_data) - + logger.info( "Project created", project_id=project_id, name=project_request.name, - path=project_path + path=project_path, ) - + return project_info - - except Exception as e: - logger.error("Failed to create project", error=str(e)) + + except SQLAlchemyError 
as exc: + logger.exception("Failed to create project") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={ "error": { - "message": f"Failed to create project: {str(e)}", + "message": "Failed to create project.", "type": "internal_error", - "code": "project_creation_failed" + "code": "project_creation_failed", } - } - ) + }, + ) from exc @router.get("/projects/{project_id}", response_model=ProjectInfo) async def get_project(project_id: str, req: Request) -> ProjectInfo: """Get project by ID.""" - + project = await db_manager.get_project(project_id) if not project: raise HTTPException( @@ -129,11 +132,11 @@ async def get_project(project_id: str, req: Request) -> ProjectInfo: "error": { "message": f"Project {project_id} not found", "type": "not_found", - "code": "project_not_found" + "code": "project_not_found", } - } + }, ) - + return ProjectInfo( id=project.id, name=project.name, @@ -141,14 +144,14 @@ async def get_project(project_id: str, req: Request) -> ProjectInfo: path=project.path, created_at=project.created_at, updated_at=project.updated_at, - is_active=project.is_active + is_active=project.is_active, ) @router.delete("/projects/{project_id}") async def delete_project(project_id: str, req: Request) -> JSONResponse: """Delete project by ID.""" - + project = await db_manager.get_project(project_id) if not project: raise HTTPException( @@ -157,19 +160,15 @@ async def delete_project(project_id: str, req: Request) -> JSONResponse: "error": { "message": f"Project {project_id} not found", "type": "not_found", - "code": "project_not_found" + "code": "project_not_found", } - } + }, ) - - # TODO: Implement project deletion in database - # cleanup_project_directory(project.path) - + + deleted = await db_manager.delete_project(project_id) + if deleted: + cleanup_project_directory(project.path) + logger.info("Project deleted", project_id=project_id) - - return JSONResponse( - content={ - "project_id": project_id, - "status": "deleted" - } - ) 
+ + return JSONResponse(content={"project_id": project_id, "status": "deleted"}) diff --git a/claude_code_api/api/sessions.py b/claude_code_api/api/sessions.py index d4520ad..b174682 100644 --- a/claude_code_api/api/sessions.py +++ b/claude_code_api/api/sessions.py @@ -1,17 +1,18 @@ """Sessions API endpoint - Extension to OpenAI API.""" -from typing import List, Dict, Any -from fastapi import APIRouter, Request, HTTPException, status -from fastapi.responses import JSONResponse +from typing import Any, Dict + import structlog +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.responses import JSONResponse +from claude_code_api.core.session_manager import SessionManager from claude_code_api.models.openai import ( - SessionInfo, CreateSessionRequest, PaginatedResponse, - PaginationInfo + PaginationInfo, + SessionInfo, ) -from claude_code_api.core.session_manager import SessionManager logger = structlog.get_logger() router = APIRouter() @@ -19,15 +20,12 @@ @router.get("/sessions", response_model=PaginatedResponse) async def list_sessions( - page: int = 1, - per_page: int = 20, - project_id: str = None, - req: Request = None + page: int = 1, per_page: int = 20, project_id: str = None, req: Request = None ) -> PaginatedResponse: """List all sessions.""" - + session_manager: SessionManager = req.app.state.session_manager - + # Get active sessions active_sessions = [] for session_id, session_info in session_manager.active_sessions.items(): @@ -43,48 +41,44 @@ async def list_sessions( is_active=session_info.is_active, total_tokens=session_info.total_tokens, total_cost=session_info.total_cost, - message_count=session_info.message_count + message_count=session_info.message_count, ) active_sessions.append(session_data) - + # Simple pagination start_idx = (page - 1) * per_page end_idx = start_idx + per_page paginated_sessions = active_sessions[start_idx:end_idx] - + pagination = PaginationInfo( page=page, per_page=per_page, 
total_items=len(active_sessions), total_pages=(len(active_sessions) + per_page - 1) // per_page, has_next=end_idx < len(active_sessions), - has_prev=page > 1 - ) - - return PaginatedResponse( - data=paginated_sessions, - pagination=pagination + has_prev=page > 1, ) + return PaginatedResponse(data=paginated_sessions, pagination=pagination) + @router.post("/sessions", response_model=SessionInfo) async def create_session( - session_request: CreateSessionRequest, - req: Request + session_request: CreateSessionRequest, req: Request ) -> SessionInfo: """Create a new session.""" - + session_manager: SessionManager = req.app.state.session_manager - + try: session_id = await session_manager.create_session( project_id=session_request.project_id, model=session_request.model, - system_prompt=session_request.system_prompt + system_prompt=session_request.system_prompt, ) - + session_info = await session_manager.get_session(session_id) - + response = SessionInfo( id=session_info.session_id, project_id=session_info.project_id, @@ -96,17 +90,17 @@ async def create_session( is_active=session_info.is_active, total_tokens=session_info.total_tokens, total_cost=session_info.total_cost, - message_count=session_info.message_count + message_count=session_info.message_count, ) - + logger.info( "Session created", session_id=session_id, - project_id=session_request.project_id + project_id=session_request.project_id, ) - + return response - + except Exception as e: logger.error("Failed to create session", error=str(e)) raise HTTPException( @@ -115,9 +109,9 @@ async def create_session( "error": { "message": f"Failed to create session: {str(e)}", "type": "internal_error", - "code": "session_creation_failed" + "code": "session_creation_failed", } - } + }, ) @@ -134,16 +128,16 @@ async def get_session_stats(req: Request) -> Dict[str, Any]: return { "session_stats": session_stats, "active_claude_sessions": len(active_claude_sessions), - "claude_sessions": active_claude_sessions + "claude_sessions": 
active_claude_sessions, } @router.get("/sessions/{session_id}", response_model=SessionInfo) async def get_session(session_id: str, req: Request) -> SessionInfo: """Get session by ID.""" - + session_manager: SessionManager = req.app.state.session_manager - + session_info = await session_manager.get_session(session_id) if not session_info: raise HTTPException( @@ -152,11 +146,11 @@ async def get_session(session_id: str, req: Request) -> SessionInfo: "error": { "message": f"Session {session_id} not found", "type": "not_found", - "code": "session_not_found" + "code": "session_not_found", } - } + }, ) - + return SessionInfo( id=session_info.session_id, project_id=session_info.project_id, @@ -168,28 +162,23 @@ async def get_session(session_id: str, req: Request) -> SessionInfo: is_active=session_info.is_active, total_tokens=session_info.total_tokens, total_cost=session_info.total_cost, - message_count=session_info.message_count + message_count=session_info.message_count, ) @router.delete("/sessions/{session_id}") async def delete_session(session_id: str, req: Request) -> JSONResponse: """Delete session by ID.""" - + session_manager: SessionManager = req.app.state.session_manager claude_manager = req.app.state.claude_manager - + # Stop Claude process if running await claude_manager.stop_session(session_id) - + # End session - await session_manager.end_session(session_id) - + session_manager.end_session(session_id) + logger.info("Session deleted", session_id=session_id) - - return JSONResponse( - content={ - "session_id": session_id, - "status": "deleted" - } - ) + + return JSONResponse(content={"session_id": session_id, "status": "deleted"}) diff --git a/claude_code_api/config/models.json b/claude_code_api/config/models.json index 07f0cec..8a9087c 100644 --- a/claude_code_api/config/models.json +++ b/claude_code_api/config/models.json @@ -30,16 +30,6 @@ "output_cost_per_1k": 1.25, "supports_streaming": true, "supports_tools": true - }, - { - "id": 
"claude-3-5-haiku-20241022", - "name": "Claude Haiku 3.5", - "description": "Fast and cost-effective model for quick tasks", - "max_tokens": 200000, - "input_cost_per_1k": 0.25, - "output_cost_per_1k": 1.25, - "supports_streaming": true, - "supports_tools": true } ] } diff --git a/claude_code_api/core/__init__.py b/claude_code_api/core/__init__.py index 6dd7805..76f85a4 100644 --- a/claude_code_api/core/__init__.py +++ b/claude_code_api/core/__init__.py @@ -1 +1 @@ -"""Core package.""" \ No newline at end of file +"""Core package.""" diff --git a/claude_code_api/core/auth.py b/claude_code_api/core/auth.py index 4f76fa8..39a24c8 100644 --- a/claude_code_api/core/auth.py +++ b/claude_code_api/core/auth.py @@ -1,10 +1,11 @@ """Authentication middleware and utilities.""" import time -from typing import Optional, List -from fastapi import Request, HTTPException, status -from fastapi.responses import JSONResponse +from typing import Optional + import structlog +from fastapi import Request, status +from fastapi.responses import JSONResponse from .config import settings @@ -16,50 +17,49 @@ class RateLimiter: """Simple in-memory rate limiter.""" - + def __init__(self, requests_per_minute: int = 60, burst: int = 10): self.requests_per_minute = requests_per_minute self.burst = burst self.store = {} - + def is_allowed(self, key: str) -> bool: """Check if request is allowed for the given key.""" now = time.time() - + if key not in self.store: - self.store[key] = {'requests': [], 'burst_used': 0} - + self.store[key] = {"requests": [], "burst_used": 0} + user_data = self.store[key] - + # Remove old requests (older than 1 minute) - user_data['requests'] = [ - req_time for req_time in user_data['requests'] - if now - req_time < 60 + user_data["requests"] = [ + req_time for req_time in user_data["requests"] if now - req_time < 60 ] - + # Check burst limit - if user_data['burst_used'] >= self.burst: + if user_data["burst_used"] >= self.burst: # Reset burst if enough time has passed - 
if len(user_data['requests']) == 0: - user_data['burst_used'] = 0 + if len(user_data["requests"]) == 0: + user_data["burst_used"] = 0 else: return False - + # Check rate limit - if len(user_data['requests']) >= self.requests_per_minute: + if len(user_data["requests"]) >= self.requests_per_minute: return False - + # Allow request - user_data['requests'].append(now) - user_data['burst_used'] += 1 - + user_data["requests"].append(now) + user_data["burst_used"] += 1 + return True # Global rate limiter instance rate_limiter = RateLimiter( requests_per_minute=settings.rate_limit_requests_per_minute, - burst=settings.rate_limit_burst + burst=settings.rate_limit_burst, ) @@ -69,17 +69,17 @@ def extract_api_key(request: Request) -> Optional[str]: auth_header = request.headers.get("Authorization") if auth_header and auth_header.startswith("Bearer "): return auth_header[7:] # Remove "Bearer " prefix - + # Check x-api-key header api_key = request.headers.get("x-api-key") if api_key: return api_key - + # Check query parameter (less secure, but sometimes needed) api_key = request.query_params.get("api_key") if api_key: return api_key - + return None @@ -87,11 +87,11 @@ def validate_api_key(api_key: str) -> bool: """Validate API key against configured keys.""" if not settings.require_auth: return True - + if not settings.api_keys: logger.warning("No API keys configured but authentication is required") return False - + return api_key in settings.api_keys @@ -101,41 +101,44 @@ async def auth_middleware(request: Request, call_next): public_paths = ["/", "/health", "/docs", "/redoc", "/openapi.json"] if request.url.path in public_paths: return await call_next(request) - + # Skip all auth and rate limiting when authentication is disabled (test mode) if not settings.require_auth: # Still set client_id for logging request.state.api_key = None request.state.client_id = "testclient" return await call_next(request) - + # Extract API key api_key = extract_api_key(request) - + # Validate API 
key if required if not api_key: logger.warning( "Missing API key", path=request.url.path, - client_ip=request.client.host if request.client else "unknown" + client_ip=request.client.host if request.client else "unknown", ) return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "error": { - "message": "Missing API key. Provide it via Authorization header (Bearer token) or x-api-key header.", + "message": ( + "Missing API key. Provide it via Authorization header " + "(Bearer token) or x-api-key header." + ), "type": "authentication_error", - "code": "missing_api_key" + "code": "missing_api_key", } - } + }, ) - + if not validate_api_key(api_key): logger.warning( "Invalid API key", path=request.url.path, client_ip=request.client.host if request.client else "unknown", - api_key_prefix=api_key[:8] if api_key else "none" + api_key_prefix=api_key[:8] if api_key else "none", ) return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, @@ -143,18 +146,16 @@ async def auth_middleware(request: Request, call_next): "error": { "message": "Invalid API key", "type": "authentication_error", - "code": "invalid_api_key" + "code": "invalid_api_key", } - } + }, ) - + # Rate limiting client_id = api_key or request.client.host if request.client else "anonymous" if not rate_limiter.is_allowed(client_id): logger.warning( - "Rate limit exceeded", - client_id=client_id, - path=request.url.path + "Rate limit exceeded", client_id=client_id, path=request.url.path ) return JSONResponse( status_code=status.HTTP_429_TOO_MANY_REQUESTS, @@ -162,13 +163,13 @@ async def auth_middleware(request: Request, call_next): "error": { "message": "Rate limit exceeded", "type": "rate_limit_error", - "code": "rate_limit_exceeded" + "code": "rate_limit_exceeded", } - } + }, ) - + # Add API key to request state for downstream use request.state.api_key = api_key request.state.client_id = client_id - + return await call_next(request) diff --git a/claude_code_api/core/claude_manager.py 
b/claude_code_api/core/claude_manager.py index c3dd299..e51013e 100644 --- a/claude_code_api/core/claude_manager.py +++ b/claude_code_api/core/claude_manager.py @@ -4,21 +4,20 @@ import json import os import subprocess -import tempfile -import uuid -from pathlib import Path -from typing import Optional, Dict, List, AsyncGenerator, Any +from typing import Any, AsyncGenerator, Dict, List, Optional + import structlog -from .config import settings from claude_code_api.models.claude import get_default_model +from .config import settings + logger = structlog.get_logger() class ClaudeProcess: """Manages a single Claude Code process.""" - + def __init__(self, session_id: str, project_path: str): self.session_id = session_id self.project_path = project_path @@ -26,70 +25,84 @@ def __init__(self, session_id: str, project_path: str): self.is_running = False self.output_queue = asyncio.Queue() self.error_queue = asyncio.Queue() - + self._output_task: Optional[asyncio.Task] = None + self._error_task: Optional[asyncio.Task] = None + async def start( - self, - prompt: str, - model: str = None, - system_prompt: str = None, - resume_session: str = None + self, prompt: str, model: str = None, system_prompt: str = None ) -> bool: """Start Claude Code process and wait for completion.""" try: # Prepare real command - using exact format from working Claudia example cmd = [settings.claude_binary_path] cmd.extend(["-p", prompt]) - + if system_prompt: cmd.extend(["--system-prompt", system_prompt]) - + if model: cmd.extend(["--model", model]) - + # Always use stream-json output format (exact order from working example) - cmd.extend([ - "--output-format", "stream-json", - "--verbose", - "--dangerously-skip-permissions" - ]) - + cmd.extend( + [ + "--output-format", + "stream-json", + "--verbose", + "--dangerously-skip-permissions", + ] + ) + logger.info( "Starting Claude process", session_id=self.session_id, project_path=self.project_path, - model=model or get_default_model() + model=model or 
get_default_model(), ) - + # Start process from src directory (where Claude works without API key) - src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) logger.info(f"Starting Claude from directory: {src_dir}") logger.info(f"Command: {' '.join(cmd)}") - + # Start process asynchronously self.process = await asyncio.create_subprocess_exec( *cmd, cwd=src_dir, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - stdin=asyncio.subprocess.PIPE + stdin=asyncio.subprocess.PIPE, ) - + self.is_running = True - + # Start background tasks to read output - asyncio.create_task(self._read_output()) - asyncio.create_task(self._read_error()) - + self._output_task = asyncio.create_task(self._read_output()) + self._error_task = asyncio.create_task(self._read_error()) + return True - + except Exception as e: logger.error( "Failed to start Claude process", session_id=self.session_id, - error=str(e) + error=str(e), ) return False + def _decode_output_line(self, line: bytes) -> Optional[Dict[str, Any]]: + line_text = line.decode().strip() + if not line_text: + return None + + payload = line_text + if payload.startswith("data: "): + payload = payload[6:].strip() + try: + return json.loads(payload) + except json.JSONDecodeError: + return {"type": "text", "content": line_text} + async def _read_output(self): """Read stdout from process line by line.""" claude_session_id = None @@ -100,43 +113,28 @@ async def _read_output(self): if not line: break - line_text = line.decode().strip() - if not line_text: + data = self._decode_output_line(line) + if not data: continue - try: - payload = line_text - if payload.startswith("data: "): - payload = payload[6:].strip() - data = json.loads(payload) - # Extract Claude's session ID from the first message - if not claude_session_id and data.get("session_id"): - claude_session_id = data["session_id"] - logger.info(f"Extracted Claude session ID: 
{claude_session_id}") - # Update our session_id to match Claude's - self.session_id = claude_session_id - await self.output_queue.put(data) - except json.JSONDecodeError: - # Handle non-JSON output - await self.output_queue.put({"type": "text", "content": line_text}) + # Extract Claude's session ID from the first message + if not claude_session_id and data.get("session_id"): + claude_session_id = data["session_id"] + logger.info( + "Extracted Claude session ID", session_id=claude_session_id + ) + # Update our session_id to match Claude's + self.session_id = claude_session_id + + await self.output_queue.put(data) except Exception as e: logger.error("Error reading output", error=str(e)) finally: await self.output_queue.put(None) self.is_running = False - # Wait for process to exit - if self.process: - try: - # Don't wait forever, just check if it's done or wait a bit - # But actually we should let it run until it's done or stopped - pass - except Exception: - pass - logger.info( - "Claude process output stream ended", - session_id=self.session_id + "Claude process output stream ended", session_id=self.session_id ) async def _read_error(self): @@ -152,37 +150,30 @@ async def _read_error(self): logger.warning("Claude stderr", message=error_text) except Exception as e: logger.error("Error reading stderr", error=str(e)) - - + async def get_output(self) -> AsyncGenerator[Dict[str, Any], None]: """Get output from Claude process.""" while True: try: # Wait for output with timeout output = await asyncio.wait_for( - self.output_queue.get(), - timeout=settings.streaming_timeout_seconds + self.output_queue.get(), timeout=settings.streaming_timeout_seconds ) - + if output is None: # End signal break - + yield output - + except asyncio.TimeoutError: - logger.warning( - "Output timeout", - session_id=self.session_id - ) + logger.warning("Output timeout", session_id=self.session_id) break except Exception as e: logger.error( - "Error getting output", - session_id=self.session_id, - 
error=str(e) + "Error getting output", session_id=self.session_id, error=str(e) ) break - + async def send_input(self, text: str): """Send input to Claude process.""" if self.process and self.process.stdin and self.is_running: @@ -191,15 +182,17 @@ async def send_input(self, text: str): await self.process.stdin.drain() except Exception as e: logger.error( - "Error sending input", - session_id=self.session_id, - error=str(e) + "Error sending input", session_id=self.session_id, error=str(e) ) - + async def stop(self): """Stop Claude process.""" self.is_running = False - + + for task in (self._output_task, self._error_task): + if task and not task.done(): + task.cancel() + if self.process: try: self.process.terminate() @@ -209,26 +202,43 @@ async def stop(self): await self.process.wait() except Exception as e: logger.error( - "Error stopping process", - session_id=self.session_id, - error=str(e) + "Error stopping process", session_id=self.session_id, error=str(e) ) finally: self.process = None - - logger.info( - "Claude process stopped", - session_id=self.session_id - ) + self._output_task = None + self._error_task = None + + logger.info("Claude process stopped", session_id=self.session_id) + + +class ClaudeManagerError(RuntimeError): + """Base error for Claude manager operations.""" + + +class ClaudeBinaryNotFoundError(ClaudeManagerError): + """Raised when the Claude binary cannot be located.""" + + +class ClaudeVersionError(ClaudeManagerError): + """Raised when the Claude version cannot be determined.""" + + +class ClaudeConcurrencyError(ClaudeManagerError): + """Raised when the concurrent session limit is exceeded.""" + + +class ClaudeProcessStartError(ClaudeManagerError): + """Raised when a Claude process fails to start.""" class ClaudeManager: """Manages multiple Claude Code processes.""" - + def __init__(self): self.processes: Dict[str, ClaudeProcess] = {} self.max_concurrent = settings.max_concurrent_sessions - + async def get_version(self) -> str: """Get 
Claude Code version.""" try: @@ -236,23 +246,26 @@ async def get_version(self) -> str: settings.claude_binary_path, "--version", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) - + stdout, stderr = await result.communicate() - + if result.returncode == 0: version = stdout.decode().strip() return version - else: - error = stderr.decode().strip() - raise Exception(f"Claude version check failed: {error}") - - except FileNotFoundError: - raise Exception(f"Claude binary not found at: {settings.claude_binary_path}") - except Exception as e: - raise Exception(f"Failed to get Claude version: {str(e)}") - + error = stderr.decode().strip() + raise ClaudeVersionError(f"Claude version check failed: {error}") + + except FileNotFoundError as exc: + raise ClaudeBinaryNotFoundError( + f"Claude binary not found at: {settings.claude_binary_path}" + ) from exc + except OSError as exc: + raise ClaudeVersionError( + f"Failed to get Claude version: {str(exc)}" + ) from exc + async def create_session( self, session_id: str, @@ -260,79 +273,75 @@ async def create_session( prompt: str, model: str = None, system_prompt: str = None, - resume_session: str = None ) -> ClaudeProcess: """Create new Claude session.""" # Check concurrent session limit if len(self.processes) >= self.max_concurrent: - raise Exception(f"Maximum concurrent sessions ({self.max_concurrent}) reached") - + raise ClaudeConcurrencyError( + f"Maximum concurrent sessions ({self.max_concurrent}) reached" + ) + # Ensure project directory exists os.makedirs(project_path, exist_ok=True) - + # Create process process = ClaudeProcess(session_id, project_path) - + # Start process success = await process.start( prompt=prompt, model=model or get_default_model(), system_prompt=system_prompt, - resume_session=resume_session ) - + if not success: - raise Exception("Failed to start Claude process") - + raise ClaudeProcessStartError("Failed to start Claude process") + # Don't store 
processes since Claude CLI completes immediately # This prevents the "max concurrent sessions" error - + logger.info( "Claude session created", session_id=process.session_id, # Use Claude's actual session ID - active_sessions=len(self.processes) + active_sessions=len(self.processes), ) - + return process - - async def get_session(self, session_id: str) -> Optional[ClaudeProcess]: + + def get_session(self, session_id: str) -> Optional[ClaudeProcess]: """Get existing Claude session.""" return self.processes.get(session_id) - + async def stop_session(self, session_id: str): """Stop Claude session.""" if session_id in self.processes: process = self.processes[session_id] await process.stop() del self.processes[session_id] - + logger.info( "Claude session stopped", session_id=session_id, - active_sessions=len(self.processes) + active_sessions=len(self.processes), ) - + async def cleanup_all(self): """Stop all Claude sessions.""" - for session_id in list(self.processes.keys()): + for session_id in tuple(self.processes): await self.stop_session(session_id) - + logger.info("All Claude sessions cleaned up") - + def get_active_sessions(self) -> List[str]: """Get list of active session IDs.""" return list(self.processes.keys()) - - async def continue_conversation( - self, - session_id: str, - prompt: str - ) -> bool: + + async def continue_conversation(self, session_id: str, prompt: str) -> bool: """Continue existing conversation.""" process = self.processes.get(session_id) if not process: return False - + await process.send_input(prompt) return True @@ -349,11 +358,14 @@ def cleanup_project_directory(project_path: str): """Clean up project directory.""" try: import shutil + if os.path.exists(project_path): shutil.rmtree(project_path) logger.info("Project directory cleaned up", path=project_path) except Exception as e: - logger.error("Failed to cleanup project directory", path=project_path, error=str(e)) + logger.error( + "Failed to cleanup project directory", 
path=project_path, error=str(e) + ) def validate_claude_binary() -> bool: @@ -363,7 +375,7 @@ def validate_claude_binary() -> bool: [settings.claude_binary_path, "--version"], capture_output=True, text=True, - timeout=10 + timeout=10, ) return result.returncode == 0 except Exception: diff --git a/claude_code_api/core/config.py b/claude_code_api/core/config.py index 5ee5a5f..9497f7d 100644 --- a/claude_code_api/core/config.py +++ b/claude_code_api/core/config.py @@ -2,7 +2,8 @@ import os import shutil -from typing import List, Union +from typing import List + from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -10,47 +11,70 @@ def find_claude_binary() -> str: """Find Claude binary path automatically.""" # First check environment variable - if 'CLAUDE_BINARY_PATH' in os.environ: - claude_path = os.environ['CLAUDE_BINARY_PATH'] + if "CLAUDE_BINARY_PATH" in os.environ: + claude_path = os.environ["CLAUDE_BINARY_PATH"] if os.path.exists(claude_path): return claude_path - + # Try to find claude in PATH - this should work for npm global installs claude_path = shutil.which("claude") if claude_path: return claude_path - + # Import npm environment if needed try: import subprocess + # Try to get npm global bin path - result = subprocess.run(['npm', 'bin', '-g'], capture_output=True, text=True) + result = subprocess.run(["npm", "bin", "-g"], capture_output=True, text=True) if result.returncode == 0: npm_bin_path = result.stdout.strip() - claude_npm_path = os.path.join(npm_bin_path, 'claude') + claude_npm_path = os.path.join(npm_bin_path, "claude") if os.path.exists(claude_npm_path): return claude_npm_path except Exception: pass - + # Fallback to common npm/nvm locations import glob + common_patterns = [ "/usr/local/bin/claude", "/usr/local/share/nvm/versions/node/*/bin/claude", "~/.nvm/versions/node/*/bin/claude", ] - + for pattern in common_patterns: expanded_pattern = os.path.expanduser(pattern) matches = 
glob.glob(expanded_pattern) if matches: # Return the most recent version return sorted(matches)[-1] - + return "claude" # Final fallback +def default_project_root() -> str: + """Default project root under the current working directory.""" + return os.path.join(os.getcwd(), "claude_projects") + + +def _is_shell_script_line(line: str) -> bool: + if not line: + return False + if line.startswith("#!") or line.startswith("set "): + return True + if "BASH_SOURCE" in line or "[[" in line: + return True + return line.startswith(("if ", "fi", "for ", "done", "source ")) + + +def _strip_export_prefix(line: str) -> str: + if line.startswith("export "): + return line[len("export ") :].lstrip() + return line + + def _looks_like_dotenv(path: str) -> bool: """Return True when a file appears to be a simple KEY=VALUE dotenv file.""" try: @@ -59,17 +83,10 @@ def _looks_like_dotenv(path: str) -> bool: stripped = line.strip() if not stripped or stripped.startswith("#"): continue - if stripped.startswith("#!") or stripped.startswith("set "): - return False - if "BASH_SOURCE" in stripped or "[[" in stripped: - return False - if stripped.startswith(("if ", "fi", "for ", "done", "source ")): + if _is_shell_script_line(stripped): return False - if stripped.startswith("export "): - stripped = stripped[len("export "):].lstrip() + stripped = _strip_export_prefix(stripped) return "=" in stripped - except FileNotFoundError: - return False except OSError: return False return True @@ -95,65 +112,71 @@ class Settings(BaseSettings): case_sensitive=False, extra="ignore", ) - + # API Configuration api_title: str = "Claude Code API Gateway" api_version: str = "1.0.0" api_description: str = "OpenAI-compatible API for Claude Code" - + # Server Configuration host: str = "0.0.0.0" port: int = 8000 debug: bool = False - + # Authentication api_keys: List[str] = Field(default_factory=list) require_auth: bool = False - - @field_validator('api_keys', mode='before') + + @field_validator("api_keys", 
mode="before") def parse_api_keys(cls, v): if isinstance(v, str): - return [x.strip() for x in v.split(',') if x.strip()] + return [x.strip() for x in v.split(",") if x.strip()] return v or [] - - # Claude Configuration + + # Claude Configuration claude_binary_path: str = find_claude_binary() claude_api_key: str = "" default_model: str = "claude-sonnet-4-5-20250929" max_concurrent_sessions: int = 10 session_timeout_minutes: int = 30 - + # Project Configuration - project_root: str = "/tmp/claude_projects" + project_root: str = default_project_root() max_project_size_mb: int = 1000 cleanup_interval_minutes: int = 60 - + # Database Configuration database_url: str = "sqlite:///./claude_api.db" - + # Logging Configuration log_level: str = "INFO" log_format: str = "json" - + # CORS Configuration - allowed_origins: List[str] = Field(default=["http://localhost:8000", "http://127.0.0.1:8000"]) - allowed_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"]) + allowed_origins: List[str] = Field( + default=["http://localhost:8000", "http://127.0.0.1:8000"] + ) + allowed_methods: List[str] = Field( + default=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"] + ) allowed_headers: List[str] = Field(default=["*"]) - - @field_validator('allowed_origins', 'allowed_methods', 'allowed_headers', mode='before') + + @field_validator( + "allowed_origins", "allowed_methods", "allowed_headers", mode="before" + ) def parse_cors_lists(cls, v): if isinstance(v, str): - return [x.strip() for x in v.split(',') if x.strip()] + return [x.strip() for x in v.split(",") if x.strip()] return v - + # Rate Limiting rate_limit_requests_per_minute: int = 100 rate_limit_burst: int = 10 - + # Streaming Configuration streaming_chunk_size: int = 1024 streaming_timeout_seconds: int = 300 - + # Create global settings instance settings = Settings() diff --git a/claude_code_api/core/database.py b/claude_code_api/core/database.py index 791e1ac..e8ce5e6 100644 
--- a/claude_code_api/core/database.py +++ b/claude_code_api/core/database.py @@ -1,18 +1,28 @@ """Database models and connection management.""" -from datetime import datetime -from typing import Optional, List +from typing import AsyncGenerator, List, Optional + +import structlog from sqlalchemy import ( - Column, Integer, String, Text, DateTime, Boolean, Float, - ForeignKey, create_engine, MetaData + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Integer, + String, + Text, + func, + select, ) +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, relationship -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -import structlog +from sqlalchemy.orm import relationship -from .config import settings from claude_code_api.models.claude import get_default_model +from claude_code_api.utils.time import utc_now + +from .config import settings logger = structlog.get_logger() @@ -33,81 +43,89 @@ class Project(Base): """Project model.""" + __tablename__ = "projects" - + id = Column(String, primary_key=True) name = Column(String, nullable=False) description = Column(Text) path = Column(String, nullable=False, unique=True) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created_at = Column(DateTime, default=utc_now) + updated_at = Column(DateTime, default=utc_now, onupdate=utc_now) is_active = Column(Boolean, default=True) - + # Relationships - sessions = relationship("Session", back_populates="project", cascade="all, delete-orphan") + sessions = relationship( + "Session", back_populates="project", cascade="all, delete-orphan" + ) class Session(Base): """Session model.""" + __tablename__ = "sessions" - + id = Column(String, primary_key=True) project_id = Column(String, ForeignKey("projects.id"), 
nullable=False) title = Column(String) model = Column(String, default=get_default_model) system_prompt = Column(Text) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created_at = Column(DateTime, default=utc_now) + updated_at = Column(DateTime, default=utc_now, onupdate=utc_now) is_active = Column(Boolean, default=True) - + # Session metrics total_tokens = Column(Integer, default=0) total_cost = Column(Float, default=0.0) message_count = Column(Integer, default=0) - + # Relationships project = relationship("Project", back_populates="sessions") - messages = relationship("Message", back_populates="session", cascade="all, delete-orphan") + messages = relationship( + "Message", back_populates="session", cascade="all, delete-orphan" + ) class Message(Base): """Message model.""" + __tablename__ = "messages" - + id = Column(Integer, primary_key=True, autoincrement=True) session_id = Column(String, ForeignKey("sessions.id"), nullable=False) role = Column(String, nullable=False) # user, assistant, system content = Column(Text, nullable=False) message_metadata = Column(Text) # JSON metadata - created_at = Column(DateTime, default=datetime.utcnow) - + created_at = Column(DateTime, default=utc_now) + # Token usage input_tokens = Column(Integer, default=0) output_tokens = Column(Integer, default=0) cost = Column(Float, default=0.0) - + # Relationships session = relationship("Session", back_populates="messages") class APIKey(Base): """API Key model for tracking usage.""" + __tablename__ = "api_keys" - + id = Column(Integer, primary_key=True, autoincrement=True) key_hash = Column(String, nullable=False, unique=True) name = Column(String) is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.utcnow) + created_at = Column(DateTime, default=utc_now) last_used_at = Column(DateTime) - + # Usage tracking total_requests = Column(Integer, default=0) 
total_tokens = Column(Integer, default=0) total_cost = Column(Float, default=0.0) -async def get_db() -> AsyncSession: +async def get_db() -> AsyncGenerator[AsyncSession, None]: """Get database session.""" async with AsyncSessionLocal() as session: try: @@ -132,14 +150,35 @@ async def close_database(): # Database utilities class DatabaseManager: """Database operations manager.""" - + @staticmethod async def get_project(project_id: str) -> Optional[Project]: """Get project by ID.""" async with AsyncSessionLocal() as session: result = await session.get(Project, project_id) return result - + + @staticmethod + async def list_projects(page: int, per_page: int) -> List[Project]: + """List projects with pagination.""" + offset = max(0, (page - 1) * per_page) + async with AsyncSessionLocal() as session: + stmt = ( + select(Project) + .order_by(Project.created_at) + .offset(offset) + .limit(per_page) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + @staticmethod + async def count_projects() -> int: + """Count total projects.""" + async with AsyncSessionLocal() as session: + result = await session.execute(select(func.count(Project.id))) + return int(result.scalar_one() or 0) + @staticmethod async def create_project(project_data: dict) -> Project: """Create new project.""" @@ -149,14 +188,25 @@ async def create_project(project_data: dict) -> Project: await session.commit() await session.refresh(project) return project - + + @staticmethod + async def delete_project(project_id: str) -> bool: + """Delete project by ID.""" + async with AsyncSessionLocal() as session: + project = await session.get(Project, project_id) + if not project: + return False + await session.delete(project) + await session.commit() + return True + @staticmethod async def get_session(session_id: str) -> Optional[Session]: """Get session by ID.""" async with AsyncSessionLocal() as session: result = await session.get(Session, session_id) return result - + @staticmethod 
async def create_session(session_data: dict) -> Session: """Create new session.""" @@ -166,7 +216,7 @@ async def create_session(session_data: dict) -> Session: await session.commit() await session.refresh(session_obj) return session_obj - + @staticmethod async def add_message(message_data: dict) -> Message: """Add message to session.""" @@ -176,13 +226,9 @@ async def add_message(message_data: dict) -> Message: await session.commit() await session.refresh(message) return message - + @staticmethod - async def update_session_metrics( - session_id: str, - tokens_used: int, - cost: float - ): + async def update_session_metrics(session_id: str, tokens_used: int, cost: float): """Update session usage metrics.""" async with AsyncSessionLocal() as session: session_obj = await session.get(Session, session_id) @@ -190,7 +236,7 @@ async def update_session_metrics( session_obj.total_tokens += tokens_used session_obj.total_cost += cost session_obj.message_count += 1 - session_obj.updated_at = datetime.utcnow() + session_obj.updated_at = utc_now() await session.commit() diff --git a/claude_code_api/core/security.py b/claude_code_api/core/security.py index 7e8a266..d836488 100644 --- a/claude_code_api/core/security.py +++ b/claude_code_api/core/security.py @@ -1,11 +1,13 @@ """Security utilities.""" import os + import structlog from fastapi import HTTPException, status logger = structlog.get_logger() + def validate_path(path: str, base_path: str) -> str: """ Validate that a path is safe and within the base path. 
@@ -39,11 +41,11 @@ def validate_path(path: str, base_path: str) -> str: "Path traversal attempt detected", path=path, resolved_path=abs_path, - base_path=abs_base_path + base_path=abs_base_path, ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: Path traversal detected" + detail="Invalid path: Path traversal detected", ) return abs_path @@ -53,6 +55,5 @@ def validate_path(path: str, base_path: str) -> str: except Exception as e: logger.error("Path validation error", error=str(e), path=path) raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid path: {str(e)}" + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid path: {str(e)}" ) diff --git a/claude_code_api/core/session_manager.py b/claude_code_api/core/session_manager.py index 3747ec4..c7ee52b 100644 --- a/claude_code_api/core/session_manager.py +++ b/claude_code_api/core/session_manager.py @@ -2,34 +2,31 @@ import asyncio import uuid -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Any +from datetime import timedelta +from typing import Any, Dict, List, Optional + import structlog from claude_code_api.core.config import settings +from claude_code_api.core.database import db_manager from claude_code_api.models.claude import get_default_model -from claude_code_api.core.database import db_manager, Session, Message -from claude_code_api.core.claude_manager import ClaudeProcess +from claude_code_api.utils.time import utc_now logger = structlog.get_logger() class SessionInfo: """Session information and metadata.""" - + def __init__( - self, - session_id: str, - project_id: str, - model: str, - system_prompt: str = None + self, session_id: str, project_id: str, model: str, system_prompt: str = None ): self.session_id = session_id self.project_id = project_id self.model = model self.system_prompt = system_prompt - self.created_at = datetime.utcnow() - self.updated_at = datetime.utcnow() + self.created_at = 
utc_now() + self.updated_at = utc_now() self.message_count = 0 self.total_tokens = 0 self.total_cost = 0.0 @@ -38,50 +35,56 @@ def __init__( class SessionManager: """Manages active sessions and their lifecycle.""" - + def __init__(self): self.active_sessions: Dict[str, SessionInfo] = {} self.cleanup_task: Optional[asyncio.Task] = None + self._shutdown_event = asyncio.Event() self._start_cleanup_task() - + def _start_cleanup_task(self): """Start periodic cleanup task.""" if self.cleanup_task is None or self.cleanup_task.done(): self.cleanup_task = asyncio.create_task(self._periodic_cleanup()) - + async def _periodic_cleanup(self): """Periodic cleanup of expired sessions.""" while True: try: - await asyncio.sleep(settings.cleanup_interval_minutes * 60) - await self.cleanup_expired_sessions() - except asyncio.CancelledError: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=settings.cleanup_interval_minutes * 60, + ) break + except asyncio.TimeoutError: + self.cleanup_expired_sessions() + except asyncio.CancelledError: + raise except Exception as e: logger.error("Error in periodic cleanup", error=str(e)) - + async def create_session( self, project_id: str, model: str = None, system_prompt: str = None, - session_id: str = None + session_id: str = None, ) -> str: """Create new session.""" if session_id is None: session_id = str(uuid.uuid4()) - + # Create session info session_info = SessionInfo( session_id=session_id, project_id=project_id, model=model or get_default_model(), - system_prompt=system_prompt + system_prompt=system_prompt, ) - + # Store in active sessions self.active_sessions[session_id] = session_info - + # Create database record session_data = { "id": session_id, @@ -90,26 +93,26 @@ async def create_session( "system_prompt": system_prompt, "title": f"Session {session_id[:8]}", "created_at": session_info.created_at, - "updated_at": session_info.updated_at + "updated_at": session_info.updated_at, } - + await 
db_manager.create_session(session_data) - + logger.info( "Session created", session_id=session_id, project_id=project_id, - model=session_info.model + model=session_info.model, ) - + return session_id - + async def get_session(self, session_id: str) -> Optional[SessionInfo]: """Get session information.""" # Check active sessions first if session_id in self.active_sessions: return self.active_sessions[session_id] - + # Load from database if not in memory db_session = await db_manager.get_session(session_id) if db_session and db_session.is_active: @@ -118,40 +121,40 @@ async def get_session(self, session_id: str) -> Optional[SessionInfo]: session_id=db_session.id, project_id=db_session.project_id, model=db_session.model, - system_prompt=db_session.system_prompt + system_prompt=db_session.system_prompt, ) session_info.created_at = db_session.created_at session_info.updated_at = db_session.updated_at session_info.message_count = db_session.message_count session_info.total_tokens = db_session.total_tokens session_info.total_cost = db_session.total_cost - + self.active_sessions[session_id] = session_info return session_info - + return None - + async def update_session( self, session_id: str, tokens_used: int = 0, cost: float = 0.0, message_content: str = None, - role: str = "user" + role: str = "user", ): """Update session with new message and metrics.""" session_info = await self.get_session(session_id) if not session_info: return - + # Update session info - session_info.updated_at = datetime.utcnow() + session_info.updated_at = utc_now() session_info.total_tokens += tokens_used session_info.total_cost += cost - + if message_content: session_info.message_count += 1 - + # Add message to database message_data = { "session_id": session_id, @@ -160,153 +163,138 @@ async def update_session( "input_tokens": tokens_used if role == "user" else 0, "output_tokens": tokens_used if role == "assistant" else 0, "cost": cost, - "created_at": datetime.utcnow() + "created_at": 
utc_now(), } - + await db_manager.add_message(message_data) - + # Update database metrics await db_manager.update_session_metrics(session_id, tokens_used, cost) - + logger.debug( "Session updated", session_id=session_id, tokens_used=tokens_used, cost=cost, - total_tokens=session_info.total_tokens + total_tokens=session_info.total_tokens, ) - - async def end_session(self, session_id: str): + + def end_session(self, session_id: str): """End session and cleanup.""" if session_id in self.active_sessions: session_info = self.active_sessions[session_id] session_info.is_active = False del self.active_sessions[session_id] - + logger.info( "Session ended", session_id=session_id, - duration_minutes=(datetime.utcnow() - session_info.created_at).total_seconds() / 60, + duration_minutes=(utc_now() - session_info.created_at).total_seconds() + / 60, total_tokens=session_info.total_tokens, - total_cost=session_info.total_cost + total_cost=session_info.total_cost, ) - - async def cleanup_expired_sessions(self): + + def cleanup_expired_sessions(self): """Clean up expired sessions.""" - current_time = datetime.utcnow() + current_time = utc_now() timeout_delta = timedelta(minutes=settings.session_timeout_minutes) expired_sessions = [] - + for session_id, session_info in self.active_sessions.items(): if current_time - session_info.updated_at > timeout_delta: expired_sessions.append(session_id) - + for session_id in expired_sessions: - await self.end_session(session_id) + self.end_session(session_id) logger.info("Session expired and cleaned up", session_id=session_id) - + async def cleanup_all(self): """Clean up all sessions.""" session_ids = list(self.active_sessions.keys()) for session_id in session_ids: - await self.end_session(session_id) - + self.end_session(session_id) + if self.cleanup_task and not self.cleanup_task.done(): - self.cleanup_task.cancel() - try: - await self.cleanup_task - except asyncio.CancelledError: - pass - + self._shutdown_event.set() + await self.cleanup_task 
+ logger.info("All sessions cleaned up") - + def get_active_session_count(self) -> int: """Get number of active sessions.""" return len(self.active_sessions) - + def get_session_stats(self) -> Dict[str, Any]: """Get session statistics.""" total_tokens = sum(s.total_tokens for s in self.active_sessions.values()) total_cost = sum(s.total_cost for s in self.active_sessions.values()) total_messages = sum(s.message_count for s in self.active_sessions.values()) - + return { "active_sessions": len(self.active_sessions), "total_tokens": total_tokens, "total_cost": total_cost, "total_messages": total_messages, - "models_in_use": list(set(s.model for s in self.active_sessions.values())) + "models_in_use": list({s.model for s in self.active_sessions.values()}), } class ConversationManager: """Manages conversation flow and context.""" - + def __init__(self, session_manager: SessionManager): self.session_manager = session_manager self.conversation_history: Dict[str, List[Dict[str, Any]]] = {} - + async def add_message( - self, - session_id: str, - role: str, - content: str, - metadata: Dict[str, Any] = None + self, session_id: str, role: str, content: str, metadata: Dict[str, Any] = None ): """Add message to conversation history.""" if session_id not in self.conversation_history: self.conversation_history[session_id] = [] - + message = { "role": role, "content": content, - "timestamp": datetime.utcnow().isoformat(), - "metadata": metadata or {} + "timestamp": utc_now().isoformat(), + "metadata": metadata or {}, } - + self.conversation_history[session_id].append(message) - + # Update session await self.session_manager.update_session( - session_id=session_id, - message_content=content, - role=role + session_id=session_id, message_content=content, role=role ) - + def get_conversation_history( - self, - session_id: str, - limit: int = None + self, session_id: str, limit: int = None ) -> List[Dict[str, Any]]: """Get conversation history for session.""" history = 
self.conversation_history.get(session_id, []) if limit: return history[-limit:] return history - + def format_messages_for_claude( - self, - session_id: str, - include_system: bool = True + self, session_id: str, include_system: bool = True ) -> List[Dict[str, str]]: """Format messages for Claude Code input.""" history = self.get_conversation_history(session_id) formatted = [] - + for msg in history: if msg["role"] == "system" and not include_system: continue - - formatted.append({ - "role": msg["role"], - "content": msg["content"] - }) - + + formatted.append({"role": msg["role"], "content": msg["content"]}) + return formatted - - async def clear_conversation(self, session_id: str): + + def clear_conversation(self, session_id: str): """Clear conversation history.""" if session_id in self.conversation_history: del self.conversation_history[session_id] - - await self.session_manager.end_session(session_id) + + self.session_manager.end_session(session_id) diff --git a/claude_code_api/main.py b/claude_code_api/main.py index d4282d0..85dd8f4 100644 --- a/claude_code_api/main.py +++ b/claude_code_api/main.py @@ -5,29 +5,26 @@ while leveraging Claude Code's powerful workflow capabilities. 
""" -import os -import logging from contextlib import asynccontextmanager from typing import AsyncGenerator +import structlog from fastapi import FastAPI, HTTPException, status from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -import structlog -from claude_code_api.core.config import settings -from claude_code_api.core.database import create_tables, close_database -from claude_code_api.core.session_manager import SessionManager -from claude_code_api.core.claude_manager import ClaudeManager from claude_code_api.api.chat import router as chat_router from claude_code_api.api.models import router as models_router from claude_code_api.api.projects import router as projects_router from claude_code_api.api.sessions import router as sessions_router from claude_code_api.core.auth import auth_middleware +from claude_code_api.core.claude_manager import ClaudeManager +from claude_code_api.core.config import settings +from claude_code_api.core.database import close_database, create_tables +from claude_code_api.core.session_manager import SessionManager from claude_code_api.models.openai import ChatCompletionChunk - # Configure structured logging structlog.configure( processors=[ @@ -39,7 +36,7 @@ structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), @@ -54,16 +51,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Application lifespan manager.""" logger.info("Starting Claude Code API Gateway", version="1.0.0") - + # Initialize database await create_tables() logger.info("Database initialized") - + # Initialize managers app.state.session_manager = SessionManager() app.state.claude_manager = ClaudeManager() logger.info("Managers initialized") - + # 
Verify Claude Code availability try: claude_version = await app.state.claude_manager.get_version() @@ -72,11 +69,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: logger.error("Claude Code not available", error=str(e)) raise HTTPException( status_code=503, - detail="Claude Code CLI not available. Please ensure Claude Code is installed and accessible." + detail="Claude Code CLI not available. Please ensure Claude Code is installed and accessible.", ) - + yield - + # Cleanup logger.info("Shutting down Claude Code API Gateway") await app.state.session_manager.cleanup_all() @@ -90,7 +87,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: version="1.0.0", docs_url="/docs", redoc_url="/redoc", - lifespan=lifespan + lifespan=lifespan, ) @@ -104,7 +101,7 @@ def custom_openapi(): title=app.title, version=app.version, description=app.description, - routes=app.routes + routes=app.routes, ) components = schema.setdefault("components", {}).setdefault("schemas", {}) @@ -144,14 +141,8 @@ def custom_openapi(): async def http_exception_handler(request, exc): """Custom handler for HTTP exceptions to support OpenAI error format.""" if isinstance(exc.detail, dict) and "error" in exc.detail: - return JSONResponse( - status_code=exc.status_code, - content=exc.detail - ) - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail} - ) + return JSONResponse(status_code=exc.status_code, content=exc.detail) + return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) @app.exception_handler(RequestValidationError) @@ -168,9 +159,9 @@ async def validation_exception_handler(request, exc): "error": { "message": "Validation error", "type": "invalid_request_error", - "details": exc.errors() + "details": exc.errors(), } - } + }, ) @@ -182,7 +173,7 @@ async def global_exception_handler(request, exc): path=request.url.path, method=request.method, error=str(exc), - exc_info=True + exc_info=True, ) return 
JSONResponse( status_code=500, @@ -190,9 +181,9 @@ async def global_exception_handler(request, exc): "error": { "message": "Internal server error", "type": "internal_error", - "code": "internal_error" + "code": "internal_error", } - } + }, ) @@ -202,21 +193,17 @@ async def health_check(): try: # Check Claude Code availability claude_version = await app.state.claude_manager.get_version() - + return { "status": "healthy", "version": "1.0.0", "claude_version": claude_version, - "active_sessions": len(app.state.session_manager.active_sessions) + "active_sessions": len(app.state.session_manager.active_sessions), } except Exception as e: logger.error("Health check failed", error=str(e)) return JSONResponse( - status_code=503, - content={ - "status": "unhealthy", - "error": str(e) - } + status_code=503, content={"status": "unhealthy", "error": str(e)} ) @@ -231,10 +218,10 @@ async def root(): "chat": "/v1/chat/completions", "models": "/v1/models", "projects": "/v1/projects", - "sessions": "/v1/sessions" + "sessions": "/v1/sessions", }, "docs": "/docs", - "health": "/health" + "health": "/health", } @@ -247,10 +234,11 @@ async def root(): if __name__ == "__main__": import uvicorn + uvicorn.run( "claude_code_api.main:app", host="0.0.0.0", port=8000, reload=True, - log_level="info" + log_level="info", ) diff --git a/claude_code_api/models/__init__.py b/claude_code_api/models/__init__.py index 961620a..53f3a7c 100644 --- a/claude_code_api/models/__init__.py +++ b/claude_code_api/models/__init__.py @@ -1 +1 @@ -"""Models package.""" \ No newline at end of file +"""Models package.""" diff --git a/claude_code_api/models/claude.py b/claude_code_api/models/claude.py index 487cd7d..b2d5fd2 100644 --- a/claude_code_api/models/claude.py +++ b/claude_code_api/models/claude.py @@ -1,17 +1,24 @@ """Claude Code specific models and utilities.""" -from datetime import datetime -from functools import lru_cache -from pathlib import Path -from typing import List, Optional, Dict, Any, Union, 
Literal import json import os +from datetime import datetime from enum import Enum +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + from pydantic import BaseModel, Field +from claude_code_api.utils.time import utc_now + +SESSION_ID_DESC = "Session ID" +PROJECT_PATH_DESC = "Project path" + class ClaudeMessageType(str, Enum): """Claude message types from JSONL output.""" + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" @@ -23,6 +30,7 @@ class ClaudeMessageType(str, Enum): class ClaudeToolType(str, Enum): """Claude Code built-in tools.""" + BASH = "bash" EDIT = "edit" READ = "read" @@ -36,10 +44,11 @@ class ClaudeToolType(str, Enum): class ClaudeMessage(BaseModel): """Claude message from JSONL output.""" + type: str = Field(..., description="Message type") subtype: Optional[str] = Field(None, description="Message subtype") message: Optional[Dict[str, Any]] = Field(None, description="Message content") - session_id: Optional[str] = Field(None, description="Session ID") + session_id: Optional[str] = Field(None, description=SESSION_ID_DESC) model: Optional[str] = Field(None, description="Model used") cwd: Optional[str] = Field(None, description="Current working directory") tools: Optional[List[str]] = Field(None, description="Available tools") @@ -54,6 +63,7 @@ class ClaudeMessage(BaseModel): class ClaudeToolUse(BaseModel): """Claude tool use information.""" + id: str = Field(..., description="Tool use ID") name: str = Field(..., description="Tool name") input: Dict[str, Any] = Field(..., description="Tool input parameters") @@ -61,9 +71,12 @@ class ClaudeToolUse(BaseModel): class ClaudeToolResult(BaseModel): """Claude tool result information.""" + tool_use_id: str = Field(..., description="Tool use ID") content: Union[str, Dict[str, Any]] = Field(..., description="Tool result content") - is_error: Optional[bool] = Field(False, description="Whether this is an error result") + is_error: 
Optional[bool] = Field( + False, description="Whether this is an error result" + ) def _default_model_factory() -> str: @@ -72,8 +85,9 @@ def _default_model_factory() -> str: class ClaudeSessionInfo(BaseModel): """Claude session information.""" - session_id: str = Field(..., description="Session ID") - project_path: str = Field(..., description="Project path") + + session_id: str = Field(..., description=SESSION_ID_DESC) + project_path: str = Field(..., description=PROJECT_PATH_DESC) model: str = Field(..., description="Model being used") started_at: datetime = Field(..., description="Session start time") is_running: bool = Field(..., description="Whether session is running") @@ -84,6 +98,7 @@ class ClaudeSessionInfo(BaseModel): class ClaudeProcessStatus(str, Enum): """Claude process status.""" + STARTING = "starting" RUNNING = "running" COMPLETED = "completed" @@ -94,8 +109,9 @@ class ClaudeProcessStatus(str, Enum): class ClaudeExecutionRequest(BaseModel): """Claude execution request.""" + prompt: str = Field(..., description="User prompt") - project_path: str = Field(..., description="Project path") + project_path: str = Field(..., description=PROJECT_PATH_DESC) model: Optional[str] = Field(None, description="Model to use") system_prompt: Optional[str] = Field(None, description="System prompt") resume_session: Optional[str] = Field(None, description="Session ID to resume") @@ -104,7 +120,8 @@ class ClaudeExecutionRequest(BaseModel): class ClaudeExecutionResponse(BaseModel): """Claude execution response.""" - session_id: str = Field(..., description="Session ID") + + session_id: str = Field(..., description=SESSION_ID_DESC) status: ClaudeProcessStatus = Field(..., description="Execution status") messages: List[ClaudeMessage] = Field(..., description="Messages from execution") total_tokens: int = Field(0, description="Total tokens used") @@ -114,7 +131,8 @@ class ClaudeExecutionResponse(BaseModel): class ClaudeStreamingChunk(BaseModel): """Claude streaming 
chunk.""" - session_id: str = Field(..., description="Session ID") + + session_id: str = Field(..., description=SESSION_ID_DESC) chunk_type: str = Field(..., description="Type of chunk") data: ClaudeMessage = Field(..., description="Chunk data") is_final: bool = Field(False, description="Whether this is the final chunk") @@ -122,20 +140,28 @@ class ClaudeStreamingChunk(BaseModel): class ClaudeProjectConfig(BaseModel): """Claude project configuration.""" + project_id: str = Field(..., description="Project ID") name: str = Field(..., description="Project name") - path: str = Field(..., description="Project path") - default_model: str = Field(default_factory=_default_model_factory, description="Default model") + path: str = Field(..., description=PROJECT_PATH_DESC) + default_model: str = Field( + default_factory=_default_model_factory, description="Default model" + ) system_prompt: Optional[str] = Field(None, description="Default system prompt") - tools_enabled: List[ClaudeToolType] = Field(default_factory=list, description="Enabled tools") + tools_enabled: List[ClaudeToolType] = Field( + default_factory=list, description="Enabled tools" + ) max_tokens: Optional[int] = Field(None, description="Maximum tokens per request") temperature: Optional[float] = Field(None, description="Temperature setting") - created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") - updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") + created_at: datetime = Field(default_factory=utc_now, description="Creation time") + updated_at: datetime = Field( + default_factory=utc_now, description="Last update time" + ) class ClaudeFileInfo(BaseModel): """Claude file information.""" + path: str = Field(..., description="File path") name: str = Field(..., description="File name") size: int = Field(..., description="File size in bytes") @@ -146,6 +172,7 @@ class ClaudeFileInfo(BaseModel): class ClaudeWorkspaceInfo(BaseModel): 
"""Claude workspace information.""" + path: str = Field(..., description="Workspace path") files: List[ClaudeFileInfo] = Field(..., description="Files in workspace") total_files: int = Field(..., description="Total number of files") @@ -155,6 +182,7 @@ class ClaudeWorkspaceInfo(BaseModel): class ClaudeVersionInfo(BaseModel): """Claude version information.""" + version: str = Field(..., description="Claude Code version") build: Optional[str] = Field(None, description="Build information") is_available: bool = Field(..., description="Whether Claude is available") @@ -163,15 +191,19 @@ class ClaudeVersionInfo(BaseModel): class ClaudeErrorInfo(BaseModel): """Claude error information.""" + error_type: str = Field(..., description="Type of error") message: str = Field(..., description="Error message") - session_id: Optional[str] = Field(None, description="Session ID where error occurred") - timestamp: datetime = Field(default_factory=datetime.utcnow, description="Error timestamp") + session_id: Optional[str] = Field( + None, description="Session ID where error occurred" + ) + timestamp: datetime = Field(default_factory=utc_now, description="Error timestamp") traceback: Optional[str] = Field(None, description="Error traceback") class ClaudeMetrics(BaseModel): """Claude usage metrics.""" + total_sessions: int = Field(..., description="Total number of sessions") active_sessions: int = Field(..., description="Currently active sessions") total_tokens: int = Field(..., description="Total tokens processed") @@ -184,13 +216,16 @@ class ClaudeMetrics(BaseModel): class ClaudeModelInfo(BaseModel): """Claude model information.""" + id: str = Field(..., description="Model ID") name: str = Field(..., description="Model display name") description: str = Field(..., description="Model description") max_tokens: int = Field(..., description="Maximum tokens supported") input_cost_per_1k: float = Field(..., description="Input cost per 1K tokens") output_cost_per_1k: float = Field(..., 
description="Output cost per 1K tokens") - supports_streaming: bool = Field(True, description="Whether model supports streaming") + supports_streaming: bool = Field( + True, description="Whether model supports streaming" + ) supports_tools: bool = Field(True, description="Whether model supports tool use") diff --git a/claude_code_api/models/openai.py b/claude_code_api/models/openai.py index 02904f2..5cb04fc 100644 --- a/claude_code_api/models/openai.py +++ b/claude_code_api/models/openai.py @@ -1,42 +1,53 @@ """OpenAI-compatible Pydantic models.""" from datetime import datetime -from typing import List, Optional, Dict, Any, Union, Literal +from typing import Any, Dict, List, Literal, Optional, Union + from pydantic import BaseModel, Field +OBJECT_TYPE_DESC = "Object type" + class ToolFunction(BaseModel): """Tool function definition (OpenAI compatible).""" + name: str = Field(..., description="The name of the function to call") description: Optional[str] = Field(None, description="The function description") - parameters: Dict[str, Any] = Field(..., description="The JSON schema for the function parameters") + parameters: Dict[str, Any] = Field( + ..., description="The JSON schema for the function parameters" + ) class ToolDefinition(BaseModel): """Tool definition for chat completion requests.""" + type: Literal["function"] = Field("function", description="Tool type") function: ToolFunction = Field(..., description="Function tool definition") class ToolChoiceFunction(BaseModel): """Tool choice function selector.""" + name: str = Field(..., description="Name of the function to call") class ToolChoice(BaseModel): """Tool choice definition.""" + type: Literal["function"] = Field("function", description="Tool choice type") function: ToolChoiceFunction = Field(..., description="Tool choice function") class ToolCallFunction(BaseModel): """Tool call function payload.""" + name: str = Field(..., description="Function name") arguments: str = Field(..., 
description="JSON-encoded arguments string") class ToolCall(BaseModel): """Tool call object for responses.""" + id: str = Field(..., description="Tool call ID") type: Literal["function"] = Field("function", description="Tool call type") function: ToolCallFunction = Field(..., description="Function call details") @@ -44,26 +55,39 @@ class ToolCall(BaseModel): class ToolCallFunctionDelta(BaseModel): """Streaming delta for tool call function.""" + name: Optional[str] = Field(None, description="Function name") arguments: Optional[str] = Field(None, description="Partial arguments payload") class ToolCallDelta(BaseModel): """Streaming delta for tool calls.""" + index: int = Field(..., description="Tool call index") id: Optional[str] = Field(None, description="Tool call ID") type: Optional[Literal["function"]] = Field(None, description="Tool call type") - function: Optional[ToolCallFunctionDelta] = Field(None, description="Function delta") + function: Optional[ToolCallFunctionDelta] = Field( + None, description="Function delta" + ) class ChatMessage(BaseModel): """Chat message model - accepts any content format.""" - role: Literal["system", "user", "assistant", "tool"] = Field(..., description="The role of the message author") + + role: Literal["system", "user", "assistant", "tool"] = Field( + ..., description="The role of the message author" + ) content: Optional[Any] = Field(None, description="The content of the message") - name: Optional[str] = Field(None, description="Optional name for the message author") - tool_calls: Optional[List[ToolCall]] = Field(None, description="Tool calls generated by the assistant") - tool_call_id: Optional[str] = Field(None, description="Tool call ID this tool message is responding to") - + name: Optional[str] = Field( + None, description="Optional name for the message author" + ) + tool_calls: Optional[List[ToolCall]] = Field( + None, description="Tool calls generated by the assistant" + ) + tool_call_id: Optional[str] = Field( + None, 
description="Tool call ID this tool message is responding to" + ) + def get_text_content(self) -> str: """Extract text content from any format.""" if self.content is None: @@ -88,101 +112,166 @@ def get_text_content(self) -> str: class ChatCompletionRequest(BaseModel): """Chat completion request model.""" + model: str = Field(..., description="ID of the model to use") - messages: List[ChatMessage] = Field(..., description="List of messages comprising the conversation") - temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0, description="Sampling temperature") - top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter") - max_tokens: Optional[int] = Field(None, ge=1, description="Maximum number of tokens to generate") - stream: Optional[bool] = Field(False, description="Whether to stream partial message deltas") - stop: Optional[Union[str, List[str]]] = Field(None, description="Up to 4 sequences where the API will stop generating") - frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Frequency penalty") - presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Presence penalty") - user: Optional[str] = Field(None, description="Unique identifier representing your end-user") - tools: Optional[List[ToolDefinition]] = Field(None, description="Tools available for the model to call") + messages: List[ChatMessage] = Field( + ..., description="List of messages comprising the conversation" + ) + temperature: Optional[float] = Field( + 1.0, ge=0.0, le=2.0, description="Sampling temperature" + ) + top_p: Optional[float] = Field( + 1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter" + ) + max_tokens: Optional[int] = Field( + None, ge=1, description="Maximum number of tokens to generate" + ) + stream: Optional[bool] = Field( + False, description="Whether to stream partial message deltas" + ) + stop: Optional[Union[str, List[str]]] = Field( + None, description="Up to 4 sequences where 
the API will stop generating" + ) + frequency_penalty: Optional[float] = Field( + 0.0, ge=-2.0, le=2.0, description="Frequency penalty" + ) + presence_penalty: Optional[float] = Field( + 0.0, ge=-2.0, le=2.0, description="Presence penalty" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + tools: Optional[List[ToolDefinition]] = Field( + None, description="Tools available for the model to call" + ) tool_choice: Optional[Union[str, ToolChoice]] = Field( - None, description="Tool choice preference (e.g. 'auto', 'none', or a specific tool)" + None, + description="Tool choice preference (e.g. 'auto', 'none', or a specific tool)", ) - + # Extension fields for Claude Code - project_id: Optional[str] = Field(None, description="Project ID for Claude Code context") - session_id: Optional[str] = Field(None, description="Session ID to continue conversation") + project_id: Optional[str] = Field( + None, description="Project ID for Claude Code context" + ) + session_id: Optional[str] = Field( + None, description="Session ID to continue conversation" + ) system_prompt: Optional[str] = Field(None, description="System prompt override") class ChatCompletionChoice(BaseModel): """Chat completion choice model.""" + index: int = Field(..., description="The index of the choice") message: ChatMessage = Field(..., description="The message generated by the model") - finish_reason: Optional[Literal["stop", "length", "content_filter", "tool_calls"]] = Field( - None, description="The reason the model stopped generating tokens" - ) + finish_reason: Optional[ + Literal["stop", "length", "content_filter", "tool_calls"] + ] = Field(None, description="The reason the model stopped generating tokens") class ChatCompletionUsage(BaseModel): """Token usage information.""" + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") - completion_tokens: int = Field(..., description="Number of tokens in the completion") + 
completion_tokens: int = Field( + ..., description="Number of tokens in the completion" + ) total_tokens: int = Field(..., description="Total number of tokens used") class ChatCompletionResponse(BaseModel): """Chat completion response model.""" + id: str = Field(..., description="Unique identifier for the chat completion") - object: Literal["chat.completion"] = Field("chat.completion", description="Object type") - created: int = Field(..., description="Unix timestamp of when the completion was created") + object: Literal["chat.completion"] = Field( + "chat.completion", description=OBJECT_TYPE_DESC + ) + created: int = Field( + ..., description="Unix timestamp of when the completion was created" + ) model: str = Field(..., description="Model used for the completion") - choices: List[ChatCompletionChoice] = Field(..., description="List of completion choices") - usage: ChatCompletionUsage = Field(..., description="Usage statistics for the completion") - + choices: List[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: ChatCompletionUsage = Field( + ..., description="Usage statistics for the completion" + ) + # Extension fields - session_id: Optional[str] = Field(None, description="Session ID for this completion") - project_id: Optional[str] = Field(None, description="Project ID for this completion") + session_id: Optional[str] = Field( + None, description="Session ID for this completion" + ) + project_id: Optional[str] = Field( + None, description="Project ID for this completion" + ) # Streaming Models class ChatCompletionChunkDelta(BaseModel): """Delta object for streaming responses.""" - role: Optional[str] = Field(None, description="The role of the author of this message") - content: Optional[str] = Field(None, description="The contents of the chunk message") - tool_calls: Optional[List[ToolCallDelta]] = Field(None, description="Tool call deltas") + + role: Optional[str] = Field( + None, description="The role of the author 
of this message" + ) + content: Optional[str] = Field( + None, description="The contents of the chunk message" + ) + tool_calls: Optional[List[ToolCallDelta]] = Field( + None, description="Tool call deltas" + ) class ChatCompletionChunkChoice(BaseModel): """Choice object for streaming responses.""" + index: int = Field(..., description="The index of the choice") - delta: ChatCompletionChunkDelta = Field(..., description="Delta containing message changes") - finish_reason: Optional[Literal["stop", "length", "content_filter", "tool_calls"]] = Field( - None, description="The reason the model stopped generating tokens" + delta: ChatCompletionChunkDelta = Field( + ..., description="Delta containing message changes" ) + finish_reason: Optional[ + Literal["stop", "length", "content_filter", "tool_calls"] + ] = Field(None, description="The reason the model stopped generating tokens") class ChatCompletionChunk(BaseModel): """Streaming chat completion chunk.""" + id: str = Field(..., description="Unique identifier for the chat completion") - object: Literal["chat.completion.chunk"] = Field("chat.completion.chunk", description="Object type") - created: int = Field(..., description="Unix timestamp of when the completion was created") + object: Literal["chat.completion.chunk"] = Field( + "chat.completion.chunk", description=OBJECT_TYPE_DESC + ) + created: int = Field( + ..., description="Unix timestamp of when the completion was created" + ) model: str = Field(..., description="Model used for the completion") - choices: List[ChatCompletionChunkChoice] = Field(..., description="List of completion choices") + choices: List[ChatCompletionChunkChoice] = Field( + ..., description="List of completion choices" + ) # Models endpoint class ModelObject(BaseModel): """Model object.""" + id: str = Field(..., description="Model identifier") - object: Literal["model"] = Field("model", description="Object type") - created: int = Field(..., description="Unix timestamp of when the model was 
created") + object: Literal["model"] = Field("model", description=OBJECT_TYPE_DESC) + created: int = Field( + ..., description="Unix timestamp of when the model was created" + ) owned_by: str = Field(..., description="Organization that owns the model") class ModelListResponse(BaseModel): """Model list response.""" - object: Literal["list"] = Field("list", description="Object type") + + object: Literal["list"] = Field("list", description=OBJECT_TYPE_DESC) data: List[ModelObject] = Field(..., description="List of model objects") # Error Models class ErrorDetail(BaseModel): """Error detail object.""" + message: str = Field(..., description="Human-readable error message") type: str = Field(..., description="Error type") code: Optional[str] = Field(None, description="Error code") @@ -190,12 +279,14 @@ class ErrorDetail(BaseModel): class ErrorResponse(BaseModel): """Error response model.""" + error: ErrorDetail = Field(..., description="Error details") # Extension Models for Claude Code specific features class ProjectInfo(BaseModel): """Project information model.""" + id: str = Field(..., description="Project identifier") name: str = Field(..., description="Project name") description: Optional[str] = Field(None, description="Project description") @@ -207,6 +298,7 @@ class ProjectInfo(BaseModel): class CreateProjectRequest(BaseModel): """Create project request model.""" + name: str = Field(..., description="Project name") description: Optional[str] = Field(None, description="Project description") path: Optional[str] = Field(None, description="Custom project path") @@ -214,11 +306,14 @@ class CreateProjectRequest(BaseModel): class SessionInfo(BaseModel): """Session information model.""" + id: str = Field(..., description="Session identifier") project_id: str = Field(..., description="Associated project ID") title: Optional[str] = Field(None, description="Session title") model: str = Field(..., description="Model used in this session") - system_prompt: Optional[str] = 
Field(None, description="System prompt for this session") + system_prompt: Optional[str] = Field( + None, description="System prompt for this session" + ) created_at: datetime = Field(..., description="Creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") is_active: bool = Field(True, description="Whether the session is active") @@ -229,6 +324,7 @@ class SessionInfo(BaseModel): class CreateSessionRequest(BaseModel): """Create session request model.""" + project_id: str = Field(..., description="Project ID for the session") title: Optional[str] = Field(None, description="Session title") model: Optional[str] = Field(None, description="Model to use") @@ -238,6 +334,7 @@ class CreateSessionRequest(BaseModel): # Tool execution models class ToolInfo(BaseModel): """Tool information model.""" + name: str = Field(..., description="Tool name") description: str = Field(..., description="Tool description") parameters: Dict[str, Any] = Field(..., description="Tool parameters schema") @@ -245,6 +342,7 @@ class ToolInfo(BaseModel): class ToolExecutionRequest(BaseModel): """Tool execution request model.""" + name: str = Field(..., description="Tool name to execute") parameters: Dict[str, Any] = Field(..., description="Tool parameters") project_id: str = Field(..., description="Project context for tool execution") @@ -252,6 +350,7 @@ class ToolExecutionRequest(BaseModel): class ToolExecutionResponse(BaseModel): """Tool execution response model.""" + success: bool = Field(..., description="Whether the tool execution was successful") result: Any = Field(..., description="Tool execution result") error: Optional[str] = Field(None, description="Error message if execution failed") @@ -261,7 +360,10 @@ class ToolExecutionResponse(BaseModel): # Health check model class HealthCheckResponse(BaseModel): """Health check response model.""" - status: Literal["healthy", "unhealthy"] = Field(..., description="Service health status") + + status: 
Literal["healthy", "unhealthy"] = Field( + ..., description="Service health status" + ) version: str = Field(..., description="API version") claude_version: Optional[str] = Field(None, description="Claude Code version") active_sessions: int = Field(..., description="Number of active sessions") @@ -271,17 +373,21 @@ class HealthCheckResponse(BaseModel): # Usage statistics models class UsageStats(BaseModel): """Usage statistics model.""" + total_requests: int = Field(..., description="Total number of requests") total_tokens: int = Field(..., description="Total tokens processed") total_cost: float = Field(..., description="Total cost incurred") active_sessions: int = Field(..., description="Currently active sessions") models_used: List[str] = Field(..., description="Models that have been used") - avg_response_time_ms: float = Field(..., description="Average response time in milliseconds") + avg_response_time_ms: float = Field( + ..., description="Average response time in milliseconds" + ) # Webhook models (for future extension) class WebhookEvent(BaseModel): """Webhook event model.""" + event_type: str = Field(..., description="Type of event") session_id: str = Field(..., description="Session ID") project_id: str = Field(..., description="Project ID") @@ -292,6 +398,7 @@ class WebhookEvent(BaseModel): # Pagination models class PaginationInfo(BaseModel): """Pagination information.""" + page: int = Field(1, ge=1, description="Current page number") per_page: int = Field(20, ge=1, le=100, description="Items per page") total_items: int = Field(..., description="Total number of items") @@ -302,6 +409,7 @@ class PaginationInfo(BaseModel): class PaginatedResponse(BaseModel): """Generic paginated response.""" + data: List[Any] = Field(..., description="List of items") pagination: PaginationInfo = Field(..., description="Pagination information") @@ -309,6 +417,7 @@ class PaginatedResponse(BaseModel): # File upload models (for project files) class FileUploadResponse(BaseModel): 
"""File upload response model.""" + filename: str = Field(..., description="Uploaded filename") size: int = Field(..., description="File size in bytes") path: str = Field(..., description="File path in project") @@ -318,6 +427,7 @@ class FileUploadResponse(BaseModel): # Configuration models class APIConfiguration(BaseModel): """API configuration model.""" + max_concurrent_sessions: int = Field(..., description="Maximum concurrent sessions") session_timeout_minutes: int = Field(..., description="Session timeout in minutes") supported_models: List[str] = Field(..., description="List of supported models") diff --git a/claude_code_api/tests/test_gpt_turbo.py b/claude_code_api/tests/test_gpt_turbo.py index 3645e1c..1b87ea8 100644 --- a/claude_code_api/tests/test_gpt_turbo.py +++ b/claude_code_api/tests/test_gpt_turbo.py @@ -1,46 +1,52 @@ """Tests for GPT-3.5 Turbo integration.""" import os + import pytest -import openai from openai import OpenAI + from claude_code_api.models.openai import ChatCompletionRequest, ChatMessage # Note: This test requires setting the OPENAI_API_KEY environment variable # You can set this by running: export OPENAI_API_KEY='your-api-key-here' + @pytest.mark.skipif( - not os.environ.get('OPENAI_API_KEY'), - reason="OpenAI API key is not set. Set OPENAI_API_KEY environment variable to run this test." + not os.environ.get("OPENAI_API_KEY"), + reason="OpenAI API key is not set. 
Set OPENAI_API_KEY environment variable to run this test.", ) def test_gpt_turbo_prompt(): """Test basic prompt with GPT-3.5 Turbo.""" # You'll need to replace this with your actual OpenAI API key client = OpenAI(api_key="your-openai-api-key") - + # Create a test request that matches the OpenAI API structure messages = [ {"role": "system", "content": "You are a helpful coding assistant."}, - {"role": "user", "content": "Write a Python function to calculate the factorial of a number."} + { + "role": "user", + "content": "Write a Python function to calculate the factorial of a number.", + }, ] - + # Make the API call response = client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, temperature=0, # For deterministic output - max_tokens=250 # Limit response length + max_tokens=250, # Limit response length ) - + # Validate the response assert response.choices[0].message.role == "assistant" assert response.choices[0].message.content is not None - + # Optional: Check the function content function_code = response.choices[0].message.content assert "def factorial" in function_code or "def fact" in function_code assert "return" in function_code + def test_claude_gpt_turbo_compatibility(): """Ensure our models can parse GPT-3.5 Turbo request.""" # Create a request matching GPT-3.5 Turbo structure @@ -48,20 +54,20 @@ def test_claude_gpt_turbo_compatibility(): "model": "gpt-3.5-turbo", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, world!"} + {"role": "user", "content": "Hello, world!"}, ], "temperature": 0.7, "max_tokens": 100, - "stream": False + "stream": False, } - + # Try parsing with our models request = ChatCompletionRequest(**request_data) - + # Validate parsed request assert request.model == "gpt-3.5-turbo" assert len(request.messages) == 2 assert all(isinstance(msg, ChatMessage) for msg in request.messages) assert request.temperature == 0.7 assert request.max_tokens == 100 - assert 
not request.stream \ No newline at end of file + assert not request.stream diff --git a/claude_code_api/utils/__init__.py b/claude_code_api/utils/__init__.py index 1d226f6..f9a8320 100644 --- a/claude_code_api/utils/__init__.py +++ b/claude_code_api/utils/__init__.py @@ -1 +1 @@ -"""Utils package.""" \ No newline at end of file +"""Utils package.""" diff --git a/claude_code_api/utils/parser.py b/claude_code_api/utils/parser.py index bd30b53..d8a935e 100644 --- a/claude_code_api/utils/parser.py +++ b/claude_code_api/utils/parser.py @@ -2,30 +2,60 @@ import json import uuid -from typing import Dict, Any, Optional, List, Generator from datetime import datetime +from typing import Any, Dict, Generator, List, Optional + import structlog -from claude_code_api.models.claude import ClaudeMessage, ClaudeToolUse, ClaudeToolResult +from claude_code_api.models.claude import ClaudeMessage, ClaudeToolResult, ClaudeToolUse +from claude_code_api.utils.time import utc_now logger = structlog.get_logger() +def _normalize_text_value(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, dict) and "text" in value: + return str(value.get("text", "")) + return str(value) + + +def _text_from_mapping(payload: Dict[str, Any]) -> Optional[str]: + if "text" in payload: + return _normalize_text_value(payload.get("text")) + if "content" in payload: + return _normalize_text_value(payload.get("content")) + return None + + +def _text_from_part(part: Any) -> Optional[str]: + if isinstance(part, dict): + if part.get("type") == "text" and "text" in part: + return _normalize_text_value(part.get("text")) + return _text_from_mapping(part) + if isinstance(part, str): + return part + return None + + class ClaudeOutputParser: """Parser for Claude Code JSONL output.""" - + def __init__(self): self.session_id: Optional[str] = None self.model: Optional[str] = None self.total_tokens = 0 self.total_cost = 0.0 self.message_count = 0 - + 
def parse_line(self, line: str) -> Optional[ClaudeMessage]: """Parse a single JSONL line.""" if not line.strip(): return None - + try: data = json.loads(line.strip()) message = ClaudeMessage(**data) @@ -62,55 +92,44 @@ def parse_message(self, message: ClaudeMessage) -> Optional[ClaudeMessage]: self.message_count += 1 return message - + def parse_stream(self, lines: List[str]) -> Generator[ClaudeMessage, None, None]: """Parse multiple JSONL lines.""" for line in lines: message = self.parse_line(line) if message: yield message - + def extract_text_content(self, message: ClaudeMessage) -> str: """Extract text content from a message.""" if not message.message: return "" - - content = message.message.get("content", []) + + content = message.message.get("content") + if content is None: + return "" if isinstance(content, str): return content if isinstance(content, dict): - if "text" in content: - return str(content.get("text", "")) - if "content" in content: - return str(content.get("content", "")) - + return _text_from_mapping(content) or "" if isinstance(content, list): text_parts = [] for part in content: - if isinstance(part, dict): - if part.get("type") == "text" or "text" in part: - text = part.get("text", "") - if isinstance(text, str): - text_parts.append(text) - elif isinstance(text, dict) and "text" in text: - text_parts.append(text["text"]) - elif "content" in part: - text_parts.append(str(part.get("content", ""))) - elif isinstance(part, str): - text_parts.append(part) + text = _text_from_part(part) + if text: + text_parts.append(text) return "\n".join(text_parts) - - return "" - + return str(content) + def extract_tool_uses(self, message: ClaudeMessage) -> List[ClaudeToolUse]: """Extract tool uses from a message.""" if not message.message: return [] - + content = message.message.get("content", []) if not isinstance(content, list): return [] - + tool_uses = [] for part in content: if isinstance(part, dict) and part.get("type") == "tool_use": @@ -118,23 +137,23 
@@ def extract_tool_uses(self, message: ClaudeMessage) -> List[ClaudeToolUse]: tool_use = ClaudeToolUse( id=part.get("id", ""), name=part.get("name", ""), - input=part.get("input", {}) + input=part.get("input", {}), ) tool_uses.append(tool_use) except Exception as e: logger.warning("Failed to parse tool use", part=part, error=str(e)) - + return tool_uses - + def extract_tool_results(self, message: ClaudeMessage) -> List[ClaudeToolResult]: """Extract tool results from a message.""" if not message.message: return [] - + content = message.message.get("content", []) if not isinstance(content, list): return [] - + tool_results = [] for part in content: if isinstance(part, dict) and part.get("type") == "tool_result": @@ -142,32 +161,36 @@ def extract_tool_results(self, message: ClaudeMessage) -> List[ClaudeToolResult] tool_result = ClaudeToolResult( tool_use_id=part.get("tool_use_id", ""), content=part.get("content", ""), - is_error=part.get("is_error", False) + is_error=part.get("is_error", False), ) tool_results.append(tool_result) except Exception as e: - logger.warning("Failed to parse tool result", part=part, error=str(e)) - + logger.warning( + "Failed to parse tool result", part=part, error=str(e) + ) + return tool_results - + def is_system_message(self, message: ClaudeMessage) -> bool: """Check if message is a system message.""" return message.type == "system" - + def is_user_message(self, message: ClaudeMessage) -> bool: """Check if message is from user.""" - return (message.type == "user" or - (message.message and message.message.get("role") == "user")) - + return message.type == "user" or ( + message.message and message.message.get("role") == "user" + ) + def is_assistant_message(self, message: ClaudeMessage) -> bool: """Check if message is from assistant.""" - return (message.type == "assistant" or - (message.message and message.message.get("role") == "assistant")) - + return message.type == "assistant" or ( + message.message and message.message.get("role") == 
"assistant" + ) + def is_final_message(self, message: ClaudeMessage) -> bool: """Check if this is a final result message.""" return message.type == "result" - + def get_session_summary(self) -> Dict[str, Any]: """Get summary of parsed session.""" return { @@ -175,9 +198,9 @@ def get_session_summary(self) -> Dict[str, Any]: "model": self.model, "total_tokens": self.total_tokens, "total_cost": self.total_cost, - "message_count": self.message_count + "message_count": self.message_count, } - + def reset(self): """Reset parser state.""" self.session_id = None @@ -189,104 +212,84 @@ def reset(self): class OpenAIConverter: """Converts Claude messages to OpenAI format.""" - + @staticmethod def claude_message_to_openai(message: ClaudeMessage) -> Optional[Dict[str, Any]]: """Convert Claude message to OpenAI chat format.""" if message.is_system_message(): - return { - "role": "system", - "content": message.extract_text_content() - } - + return {"role": "system", "content": message.extract_text_content()} + if message.is_user_message(): - return { - "role": "user", - "content": message.extract_text_content() - } - + return {"role": "user", "content": message.extract_text_content()} + if message.is_assistant_message(): content = message.extract_text_content() if content: - return { - "role": "assistant", - "content": content - } - + return {"role": "assistant", "content": content} + return None - + @staticmethod def claude_stream_to_openai_chunk( - message: ClaudeMessage, - chunk_id: str, - model: str, - created: int + message: ClaudeMessage, chunk_id: str, model: str, created: int ) -> Optional[Dict[str, Any]]: """Convert Claude stream message to OpenAI chunk format.""" if not message.is_assistant_message(): return None - + content = message.extract_text_content() if not content: return None - + return { "id": chunk_id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": [{ - "index": 0, - "delta": { - "role": "assistant", - "content": content - 
}, - "finish_reason": None - }] + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": content}, + "finish_reason": None, + } + ], } - + @staticmethod def create_final_chunk( - chunk_id: str, - model: str, - created: int, - finish_reason: str = "stop" + chunk_id: str, model: str, created: int, finish_reason: str = "stop" ) -> Dict[str, Any]: """Create final chunk to end streaming.""" return { "id": chunk_id, - "object": "chat.completion.chunk", + "object": "chat.completion.chunk", "created": created, "model": model, - "choices": [{ - "index": 0, - "delta": {}, - "finish_reason": finish_reason - }] + "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}], } - + @staticmethod def calculate_usage(parser: ClaudeOutputParser) -> Dict[str, int]: """Calculate token usage from parser.""" # Estimate prompt tokens (this is approximate) prompt_tokens = max(0, parser.total_tokens - parser.message_count * 100) completion_tokens = parser.total_tokens - prompt_tokens - + return { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, - "total_tokens": parser.total_tokens + "total_tokens": parser.total_tokens, } class MessageAggregator: """Aggregates streaming messages into complete responses.""" - + def __init__(self): self.messages: List[ClaudeMessage] = [] self.current_assistant_content = "" self.parser = ClaudeOutputParser() - + def add_message(self, message: Any): """Add message to aggregator.""" normalized = normalize_claude_message(message) @@ -294,25 +297,25 @@ def add_message(self, message: Any): return self.messages.append(normalized) self.parser.parse_message(normalized) - + # Aggregate assistant content for complete response if self.parser.is_assistant_message(normalized): content = self.parser.extract_text_content(normalized) if content: self.current_assistant_content += content - + def get_complete_response(self) -> str: """Get complete aggregated response.""" return self.current_assistant_content - + def 
get_messages(self) -> List[ClaudeMessage]: """Get all messages.""" return self.messages - + def get_usage_summary(self) -> Dict[str, Any]: """Get usage summary.""" return self.parser.get_session_summary() - + def clear(self): """Clear aggregator state.""" self.messages.clear() @@ -324,20 +327,20 @@ def sanitize_content(content: str) -> str: """Sanitize content for safe transmission.""" if not content: return "" - + # Remove null bytes - content = content.replace('\x00', '') - + content = content.replace("\x00", "") + # Normalize line endings - content = content.replace('\r\n', '\n').replace('\r', '\n') - + content = content.replace("\r\n", "\n").replace("\r", "\n") + # Ensure valid UTF-8 try: - content.encode('utf-8') + content.encode("utf-8") except UnicodeEncodeError: # Replace invalid characters - content = content.encode('utf-8', errors='replace').decode('utf-8') - + content = content.encode("utf-8", errors="replace").decode("utf-8") + return content @@ -345,16 +348,16 @@ def extract_error_from_message(message: ClaudeMessage) -> Optional[str]: """Extract error information from Claude message.""" if message.error: return message.error - + if message.type == "result" and not message.result: return "Execution completed without result" - + # Check for error in tool results tool_results = ClaudeOutputParser().extract_tool_results(message) for result in tool_results: if result.is_error: return str(result.content) - + return None @@ -371,7 +374,7 @@ def normalize_claude_message(raw: Any) -> Optional[ClaudeMessage]: if isinstance(raw, dict): try: return ClaudeMessage(**raw) - except Exception as e: + except (TypeError, ValueError) as e: logger.warning("Failed to normalize Claude message", error=str(e)) return None return None @@ -381,27 +384,28 @@ def tool_use_to_openai_call(tool_use: ClaudeToolUse) -> Dict[str, Any]: """Convert a Claude tool use to an OpenAI tool call object.""" tool_id = tool_use.id or f"call_{uuid.uuid4().hex}" try: - arguments = 
json.dumps(tool_use.input or {}, separators=(",", ":"), ensure_ascii=False) + arguments = json.dumps( + tool_use.input or {}, separators=(",", ":"), ensure_ascii=False + ) except TypeError: - arguments = json.dumps({"input": str(tool_use.input)}, separators=(",", ":"), ensure_ascii=False) + arguments = json.dumps( + {"input": str(tool_use.input)}, separators=(",", ":"), ensure_ascii=False + ) return { "id": tool_id, "type": "function", - "function": { - "name": tool_use.name, - "arguments": arguments - } + "function": {"name": tool_use.name, "arguments": arguments}, } def format_timestamp(timestamp: Optional[str]) -> str: """Format timestamp for display.""" if not timestamp: - return datetime.utcnow().isoformat() - + return utc_now().isoformat() + try: # Try parsing ISO format - dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) + dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) return dt.isoformat() - except: + except (ValueError, TypeError): return timestamp diff --git a/claude_code_api/utils/streaming.py b/claude_code_api/utils/streaming.py index c3aa7b6..1689c9e 100644 --- a/claude_code_api/utils/streaming.py +++ b/claude_code_api/utils/streaming.py @@ -1,26 +1,29 @@ """Server-Sent Events streaming utilities for OpenAI compatibility.""" -import json import asyncio +import json import uuid -from datetime import datetime -from typing import AsyncGenerator, Dict, Any, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple + import structlog +from claude_code_api.core.claude_manager import ClaudeProcess from claude_code_api.utils.parser import ( ClaudeOutputParser, OpenAIConverter, normalize_claude_message, - tool_use_to_openai_call + tool_use_to_openai_call, ) -from claude_code_api.core.claude_manager import ClaudeProcess +from claude_code_api.utils.time import utc_timestamp logger = structlog.get_logger() +CHUNK_OBJECT_TYPE = "chat.completion.chunk" + class SSEFormatter: """Formats data for Server-Sent Events.""" - + 
@staticmethod def format_event(data: Dict[str, Any]) -> str: """ @@ -29,26 +32,22 @@ def format_event(data: Dict[str, Any]) -> str: We deliberately omit the `event:` line so the default event-type **message** is used. """ - json_data = json.dumps(data, separators=(',', ':')) + json_data = json.dumps(data, separators=(",", ":")) return f"data: {json_data}\n\n" - + @staticmethod def format_completion(data: str) -> str: """Format completion signal.""" return "data: [DONE]\n\n" - + @staticmethod def format_error(error: str, error_type: str = "error") -> str: """Format error message.""" error_data = { - "error": { - "message": error, - "type": error_type, - "code": "stream_error" - } + "error": {"message": error, "type": error_type, "code": "stream_error"} } return SSEFormatter.format_event(error_data) - + @staticmethod def format_heartbeat() -> str: """Format heartbeat ping.""" @@ -57,115 +56,103 @@ def format_heartbeat() -> str: class OpenAIStreamConverter: """Converts Claude Code output to OpenAI-compatible streaming format.""" - + def __init__(self, model: str, session_id: str): self.model = model self.session_id = session_id self.completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" - self.created = int(datetime.utcnow().timestamp()) + self.created = utc_timestamp() self.chunk_index = 0 self.parser = ClaudeOutputParser() self.tool_call_index = 0 - + + def _build_chunk( + self, delta: Dict[str, Any], finish_reason: Optional[str] = None + ) -> Dict[str, Any]: + return { + "id": self.completion_id, + "object": CHUNK_OBJECT_TYPE, + "created": self.created, + "model": self.model, + "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], + } + + def _build_tool_calls(self, tool_uses: List[Any]) -> List[Dict[str, Any]]: + tool_calls = [] + for tool_use in tool_uses: + call = tool_use_to_openai_call(tool_use) + call["index"] = self.tool_call_index + self.tool_call_index += 1 + tool_calls.append(call) + return tool_calls + + def _assistant_chunks(self, 
message: Any) -> Tuple[List[str], bool, bool]: + chunks: List[str] = [] + saw_text = False + saw_tool_calls = False + + text_content = self.parser.extract_text_content(message).strip() + if text_content: + chunks.append( + SSEFormatter.format_event(self._build_chunk({"content": text_content})) + ) + saw_text = True + + tool_uses = self.parser.extract_tool_uses(message) + if tool_uses: + tool_calls = self._build_tool_calls(tool_uses) + chunks.append( + SSEFormatter.format_event(self._build_chunk({"tool_calls": tool_calls})) + ) + saw_tool_calls = True + + return chunks, saw_text, saw_tool_calls + async def convert_stream( - self, - claude_process: ClaudeProcess + self, claude_process: ClaudeProcess ) -> AsyncGenerator[str, None]: """Convert Claude Code output stream to OpenAI format.""" try: # Send initial chunk to establish streaming - initial_chunk = { - "id": self.completion_id, - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [{ - "index": 0, - "delta": {"role": "assistant", "content": ""}, - "finish_reason": None - }] - } - yield SSEFormatter.format_event(initial_chunk) - + yield SSEFormatter.format_event( + self._build_chunk({"role": "assistant", "content": ""}) + ) + saw_assistant_text = False saw_tool_calls = False - + # Process Claude output async for claude_message in claude_process.get_output(): - try: - message = normalize_claude_message(claude_message) - if not message: - continue - self.parser.parse_message(message) - - if self.parser.is_assistant_message(message): - text_content = self.parser.extract_text_content(message).strip() - if text_content: - chunk = { - "id": self.completion_id, - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [{ - "index": 0, - "delta": {"content": text_content}, - "finish_reason": None - }] - } - yield SSEFormatter.format_event(chunk) - saw_assistant_text = True - - tool_uses = self.parser.extract_tool_uses(message) - if 
tool_uses: - tool_calls = [] - for tool_use in tool_uses: - call = tool_use_to_openai_call(tool_use) - call["index"] = self.tool_call_index - self.tool_call_index += 1 - tool_calls.append(call) - tool_chunk = { - "id": self.completion_id, - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [{ - "index": 0, - "delta": {"tool_calls": tool_calls}, - "finish_reason": None - }] - } - yield SSEFormatter.format_event(tool_chunk) - saw_tool_calls = True - - if self.parser.is_final_message(message): - break - - except Exception as e: - logger.error("Error processing Claude message", error=str(e)) + message = normalize_claude_message(claude_message) + if not message: continue - + self.parser.parse_message(message) + + if self.parser.is_assistant_message(message): + chunks, saw_text, saw_tools = self._assistant_chunks(message) + for chunk in chunks: + yield chunk + saw_assistant_text = saw_assistant_text or saw_text + saw_tool_calls = saw_tool_calls or saw_tools + + if self.parser.is_final_message(message): + break + # Send final chunk - finish_reason = "tool_calls" if (saw_tool_calls and not saw_assistant_text) else "stop" - final_chunk = { - "id": self.completion_id, - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [{ - "index": 0, - "delta": {}, - "finish_reason": finish_reason - }] - } - yield SSEFormatter.format_event(final_chunk) - + finish_reason = ( + "tool_calls" if (saw_tool_calls and not saw_assistant_text) else "stop" + ) + yield SSEFormatter.format_event( + self._build_chunk({}, finish_reason=finish_reason) + ) + # Send completion signal yield SSEFormatter.format_completion("") - + except Exception as e: logger.error("Error in stream conversion", error=str(e)) yield SSEFormatter.format_error(f"Stream error: {str(e)}") - + def get_final_response(self) -> Dict[str, Any]: """Get complete response in OpenAI format.""" return { @@ -173,53 +160,43 @@ def 
get_final_response(self) -> Dict[str, Any]: "object": "chat.completion", "created": self.created, "model": self.model, - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Response completed" - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 5, - "total_tokens": 15 - }, - "session_id": self.session_id + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Response completed"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + "session_id": self.session_id, } class StreamingManager: """Manages multiple streaming connections.""" - + def __init__(self): self.active_streams: Dict[str, OpenAIStreamConverter] = {} self.heartbeat_interval = 30 # seconds - + async def create_stream( - self, - session_id: str, - model: str, - claude_process: ClaudeProcess + self, session_id: str, model: str, claude_process: ClaudeProcess ) -> AsyncGenerator[str, None]: """Create new streaming connection.""" converter = OpenAIStreamConverter(model, session_id) self.active_streams[session_id] = converter - + try: # Start heartbeat task - heartbeat_task = asyncio.create_task( - self._send_heartbeats(session_id) - ) - + heartbeat_task = asyncio.create_task(self._send_heartbeats(session_id)) + # Stream conversion async for chunk in converter.convert_stream(claude_process): yield chunk - + # Cancel heartbeat heartbeat_task.cancel() - + except Exception as e: logger.error("Streaming error", session_id=session_id, error=str(e)) yield SSEFormatter.format_error(f"Streaming failed: {str(e)}") @@ -227,45 +204,42 @@ async def create_stream( # Cleanup if session_id in self.active_streams: del self.active_streams[session_id] - + async def _send_heartbeats(self, session_id: str): """Send periodic heartbeats to keep connection alive.""" - try: - while session_id in self.active_streams: - await asyncio.sleep(self.heartbeat_interval) - # 
Heartbeats are handled by the SSE client - except asyncio.CancelledError: - pass - + while session_id in self.active_streams: + await asyncio.sleep(self.heartbeat_interval) + # Heartbeats are handled by the SSE client + def get_active_stream_count(self) -> int: """Get number of active streams.""" return len(self.active_streams) - - async def cleanup_stream(self, session_id: str): + + def cleanup_stream(self, session_id: str): """Cleanup specific stream.""" if session_id in self.active_streams: del self.active_streams[session_id] - - async def cleanup_all_streams(self): + + def cleanup_all_streams(self): """Cleanup all streams.""" self.active_streams.clear() class ChunkBuffer: """Buffers chunks for smooth streaming.""" - + def __init__(self, max_size: int = 1000): self.buffer = [] self.max_size = max_size self.lock = asyncio.Lock() - + async def add_chunk(self, chunk: str): """Add chunk to buffer.""" async with self.lock: self.buffer.append(chunk) if len(self.buffer) > self.max_size: self.buffer.pop(0) # Remove oldest chunk - + async def get_chunks(self) -> AsyncGenerator[str, None]: """Get chunks from buffer.""" while True: @@ -279,45 +253,45 @@ async def get_chunks(self) -> AsyncGenerator[str, None]: class AdaptiveStreaming: """Adaptive streaming with backpressure handling.""" - + def __init__(self): self.chunk_size = 1024 self.min_chunk_size = 256 self.max_chunk_size = 4096 self.adjustment_factor = 1.1 - + async def stream_with_backpressure( self, data_source: AsyncGenerator[str, None], - client_ready_callback: Optional[callable] = None + client_ready_callback: Optional[callable] = None, ) -> AsyncGenerator[str, None]: """Stream with adaptive chunk sizing based on client readiness.""" buffer = "" - + async for data in data_source: buffer += data - + # Check if we have enough data to send while len(buffer) >= self.chunk_size: - chunk = buffer[:self.chunk_size] - buffer = buffer[self.chunk_size:] - + chunk = buffer[: self.chunk_size] + buffer = 
buffer[self.chunk_size :] + # Adjust chunk size based on client readiness if client_ready_callback and not client_ready_callback(): # Client is slow, reduce chunk size self.chunk_size = max( self.min_chunk_size, - int(self.chunk_size / self.adjustment_factor) + int(self.chunk_size / self.adjustment_factor), ) else: # Client is ready, can increase chunk size self.chunk_size = min( self.max_chunk_size, - int(self.chunk_size * self.adjustment_factor) + int(self.chunk_size * self.adjustment_factor), ) - + yield chunk - + # Send remaining buffer if buffer: yield buffer @@ -328,77 +302,90 @@ async def stream_with_backpressure( async def create_sse_response( - session_id: str, - model: str, - claude_process: ClaudeProcess + session_id: str, model: str, claude_process: ClaudeProcess ) -> AsyncGenerator[str, None]: """Create SSE response for Claude Code output.""" - async for chunk in streaming_manager.create_stream(session_id, model, claude_process): + async for chunk in streaming_manager.create_stream( + session_id, model, claude_process + ): yield chunk -def create_non_streaming_response( - messages: list, - session_id: str, - model: str, - usage: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - """Create non-streaming response.""" - completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" - created = int(datetime.utcnow().timestamp()) - - logger.info( - "Creating non-streaming response", - session_id=session_id, - model=model, - messages_count=len(messages), - completion_id=completion_id - ) - - parser = ClaudeOutputParser() - tool_calls = [] - # Extract assistant content from Claude messages - content_parts = [] +def _extract_assistant_payload( + messages: list, parser: ClaudeOutputParser +) -> Tuple[List[str], List[Dict[str, Any]]]: + tool_calls: List[Dict[str, Any]] = [] + content_parts: List[str] = [] + for i, msg in enumerate(messages): normalized = normalize_claude_message(msg) if not normalized: continue parser.parse_message(normalized) logger.info( - 
f"Processing message {i}", + "Processing message", + message_index=i, msg_type=normalized.type, msg_keys=list(normalized.model_dump().keys()), - is_assistant=parser.is_assistant_message(normalized) + is_assistant=parser.is_assistant_message(normalized), + ) + + if not parser.is_assistant_message(normalized): + continue + + text_content = parser.extract_text_content(normalized).strip() + logger.info( + "Found assistant message", + message_index=i, + content_length=len(text_content), + content_preview=text_content[:100] if text_content else "empty", ) - - if parser.is_assistant_message(normalized): - text_content = parser.extract_text_content(normalized).strip() + if text_content: + content_parts.append(text_content) logger.info( - f"Found assistant message {i}", - content_length=len(text_content), - content_preview=text_content[:100] if text_content else "empty" + "Extracted assistant text", + message_index=i, + content_preview=text_content[:50], ) - if text_content: - content_parts.append(text_content) - logger.info(f"Extracted assistant text: {text_content[:50]}...") - - tool_uses = parser.extract_tool_uses(normalized) - for tool_use in tool_uses: - tool_calls.append(tool_use_to_openai_call(tool_use)) - + + tool_uses = parser.extract_tool_uses(normalized) + for tool_use in tool_uses: + tool_calls.append(tool_use_to_openai_call(tool_use)) + + return content_parts, tool_calls + + +def create_non_streaming_response( + messages: list, session_id: str, model: str, usage: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """Create non-streaming response.""" + completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" + created = utc_timestamp() + + logger.info( + "Creating non-streaming response", + session_id=session_id, + model=model, + messages_count=len(messages), + completion_id=completion_id, + ) + + parser = ClaudeOutputParser() + content_parts, tool_calls = _extract_assistant_payload(messages, parser) + # Use the actual content or fallback if content_parts: 
complete_content = "\n".join(content_parts).strip() else: complete_content = "" - + logger.info( "Final response content", content_parts_count=len(content_parts), final_content_length=len(complete_content), - final_content_preview=complete_content[:100] if complete_content else "empty" + final_content_preview=complete_content[:100] if complete_content else "empty", ) - + # Return simple OpenAI-compatible response with basic usage stats if usage is None: usage = OpenAIConverter.calculate_usage(parser) @@ -407,7 +394,7 @@ def create_non_streaming_response( message_payload: Dict[str, Any] = { "role": "assistant", - "content": complete_content or None + "content": complete_content or None, } if tool_calls: message_payload["tool_calls"] = tool_calls @@ -417,24 +404,22 @@ def create_non_streaming_response( "object": "chat.completion", "created": created, "model": model, - "choices": [{ - "index": 0, - "message": message_payload, - "finish_reason": finish_reason - }], + "choices": [ + {"index": 0, "message": message_payload, "finish_reason": finish_reason} + ], "usage": { "prompt_tokens": usage.get("prompt_tokens", 0), "completion_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0) + "total_tokens": usage.get("total_tokens", 0), }, - "session_id": session_id + "session_id": session_id, } - + logger.info( "Response created successfully", response_id=response["id"], choices_count=len(response["choices"]), - message_content_length=len(response["choices"][0]["message"]["content"] or "") + message_content_length=len(response["choices"][0]["message"]["content"] or ""), ) - + return response diff --git a/claude_code_api/utils/time.py b/claude_code_api/utils/time.py new file mode 100644 index 0000000..d8e737f --- /dev/null +++ b/claude_code_api/utils/time.py @@ -0,0 +1,13 @@ +"""Time helpers for consistent UTC timestamps.""" + +from datetime import datetime, timezone + + +def utc_now() -> datetime: + """Return a naive UTC datetime (timezone 
stripped for storage compatibility).""" + return datetime.now(timezone.utc).replace(tzinfo=None) + + +def utc_timestamp() -> int: + """Return a UTC unix timestamp in seconds.""" + return int(utc_now().timestamp()) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..9ed812e --- /dev/null +++ b/setup.cfg @@ -0,0 +1,6 @@ +[flake8] +max-line-length = 120 +extend-ignore = E203, W503 +per-file-ignores = + tests/conftest.py:E402 + tests/test_end_to_end.py:E402 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..38bb211 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package.""" diff --git a/tests/conftest.py b/tests/conftest.py index 3cfab3a..15822b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,13 @@ """Pytest configuration and fixtures.""" -import pytest +import json import os +import shutil import sys import tempfile -import shutil -import json from pathlib import Path + +import pytest from fastapi.testclient import TestClient from httpx import AsyncClient @@ -14,27 +15,30 @@ PROJECT_ROOT = Path(__file__).parent.parent sys.path.insert(0, str(PROJECT_ROOT)) +from claude_code_api.core.config import settings + # Now import the app and configuration from claude_code_api.main import app -from claude_code_api.models.claude import get_default_model -from claude_code_api.core.config import settings +from tests.model_utils import get_test_model_id @pytest.fixture(scope="session", autouse=True) def setup_test_environment(): """Setup test environment before all tests.""" # Create temporary directory for testing - temp_dir = tempfile.mkdtemp(prefix="claude_api_test_") - + test_root = PROJECT_ROOT / "dist" / "tests" + test_root.mkdir(parents=True, exist_ok=True) + temp_dir = tempfile.mkdtemp(prefix="claude_api_test_", dir=str(test_root)) + # Store original settings original_settings = { "project_root": getattr(settings, "project_root", None), "require_auth": getattr(settings, "require_auth", 
False), "claude_binary_path": getattr(settings, "claude_binary_path", "claude"), "database_url": getattr(settings, "database_url", "sqlite:///./test.db"), - "debug": getattr(settings, "debug", False) + "debug": getattr(settings, "debug", False), } - + # Set test settings settings.project_root = os.path.join(temp_dir, "projects") settings.require_auth = False @@ -56,17 +60,21 @@ def setup_test_environment(): # Create a mock binary that replays recorded JSONL fixtures mock_path = os.path.join(temp_dir, "claude") with open(mock_path, "w") as f: - f.write('#!/usr/bin/env bash\n') - f.write('if [ "$1" == "--version" ]; then echo "Claude Code 1.0.0"; exit 0; fi\n') + f.write("#!/usr/bin/env bash\n") + f.write( + 'if [ "$1" == "--version" ]; then echo "Claude Code 1.0.0"; exit 0; fi\n' + ) f.write('prompt=""\n') f.write('args=("$@")\n') - f.write('for ((i=0; i<${#args[@]}; i++)); do\n') + f.write("for ((i=0; i<${#args[@]}; i++)); do\n") f.write(' if [ "${args[$i]}" == "-p" ]; then\n') f.write(' prompt="${args[$((i+1))]}"\n') - f.write(' break\n') - f.write(' fi\n') - f.write('done\n') - f.write('prompt_lower="$(printf "%s" "$prompt" | tr "[:upper:]" "[:lower:]")"\n') + f.write(" break\n") + f.write(" fi\n") + f.write("done\n") + f.write( + 'prompt_lower="$(printf "%s" "$prompt" | tr "[:upper:]" "[:lower:]")"\n' + ) f.write(f'fixture_default="{default_fixture}"\n') f.write('fixture_match="$fixture_default"\n') for rule in fixture_rules: @@ -77,31 +85,38 @@ def setup_test_environment(): fixture_path = fixtures_dir / fixture_file for match in matches: match_escaped = str(match).replace('"', '\\"') - f.write(f'if echo "$prompt_lower" | grep -q "{match_escaped}"; then fixture_match="{fixture_path}"; fi\n') + line = ( + f'if echo "$prompt_lower" | grep -q "{match_escaped}"; ' + f'then fixture_match="{fixture_path}"; ' + "fi\n" + ) + f.write(line) f.write('cat "$fixture_match"\n') os.chmod(mock_path, 0o755) settings.claude_binary_path = mock_path else: # Ensure the real binary 
is available when requested - if not shutil.which(settings.claude_binary_path) and not os.path.exists(settings.claude_binary_path): + if not shutil.which(settings.claude_binary_path) and not os.path.exists( + settings.claude_binary_path + ): raise RuntimeError( f"CLAUDE_CODE_API_USE_REAL_CLAUDE=1 but binary not found at {settings.claude_binary_path}" ) settings.database_url = f"sqlite:///{temp_dir}/test.db" settings.debug = True - + # Create directories os.makedirs(settings.project_root, exist_ok=True) - + yield temp_dir - + # Cleanup try: shutil.rmtree(temp_dir) except Exception as e: print(f"Cleanup warning: {e}") - + # Restore original settings (if they existed) for key, value in original_settings.items(): if value is not None: @@ -126,11 +141,9 @@ async def async_test_client(): def sample_chat_request(): """Sample chat completion request.""" return { - "model": get_default_model(), - "messages": [ - {"role": "user", "content": "Hi"} - ], - "stream": False + "model": get_test_model_id(), + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, } @@ -138,21 +151,16 @@ def sample_chat_request(): def sample_streaming_request(): """Sample streaming chat completion request.""" return { - "model": get_default_model(), - "messages": [ - {"role": "user", "content": "Tell me a joke"} - ], - "stream": True + "model": get_test_model_id(), + "messages": [{"role": "user", "content": "Tell me a joke"}], + "stream": True, } @pytest.fixture def sample_project_request(): """Sample project creation request.""" - return { - "name": "Test Project", - "description": "A test project" - } + return {"name": "Test Project", "description": "A test project"} @pytest.fixture @@ -161,7 +169,7 @@ def sample_session_request(): return { "project_id": "test-project", "title": "Test Session", - "model": get_default_model() + "model": get_test_model_id(), } @@ -171,15 +179,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "slow: marks tests as slow (deselect with 
'-m \"not slow\"')" ) - config.addinivalue_line( - "markers", "integration: marks tests as integration tests" - ) - config.addinivalue_line( - "markers", "unit: marks tests as unit tests" - ) - config.addinivalue_line( - "markers", "e2e: marks tests as end-to-end tests" - ) + config.addinivalue_line("markers", "integration: marks tests as integration tests") + config.addinivalue_line("markers", "unit: marks tests as unit tests") + config.addinivalue_line("markers", "e2e: marks tests as end-to-end tests") def pytest_collection_modifyitems(config, items): @@ -190,7 +192,10 @@ def pytest_collection_modifyitems(config, items): item.add_marker(pytest.mark.integration) elif "unit" in item.nodeid: item.add_marker(pytest.mark.unit) - + # Mark slow tests - if any(keyword in item.name.lower() for keyword in ["concurrent", "performance", "large"]): + if any( + keyword in item.name.lower() + for keyword in ["concurrent", "performance", "large"] + ): item.add_marker(pytest.mark.slow) diff --git a/tests/fixtures/claude_stream_simple.jsonl b/tests/fixtures/claude_stream_simple.jsonl index 5d42d1a..a33d3ca 100644 --- a/tests/fixtures/claude_stream_simple.jsonl +++ b/tests/fixtures/claude_stream_simple.jsonl @@ -1,3 +1,3 @@ -{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_simple_1","model":"claude-3-5-haiku-20241022","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} -{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello! 
How can I help today?"}]},"session_id":"sess_simple_1","model":"claude-3-5-haiku-20241022"} -{"type":"result","result":"ok","session_id":"sess_simple_1","model":"claude-3-5-haiku-20241022","usage":{"input_tokens":12,"output_tokens":8},"cost_usd":0.00002,"duration_ms":1200,"num_turns":1} +{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_simple_1","model":"claude-haiku-4-5-20250929","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} +{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello! How can I help today?"}]},"session_id":"sess_simple_1","model":"claude-haiku-4-5-20250929"} +{"type":"result","result":"ok","session_id":"sess_simple_1","model":"claude-haiku-4-5-20250929","usage":{"input_tokens":12,"output_tokens":8},"cost_usd":0.00002,"duration_ms":1200,"num_turns":1} diff --git a/tests/fixtures/claude_stream_tool_calls.jsonl b/tests/fixtures/claude_stream_tool_calls.jsonl index d893cff..3084de5 100644 --- a/tests/fixtures/claude_stream_tool_calls.jsonl +++ b/tests/fixtures/claude_stream_tool_calls.jsonl @@ -1,5 +1,5 @@ -{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} -{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"I'll list the files."},{"type":"tool_use","id":"toolu_123","name":"bash","input":{"command":"ls -1"}}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022"} -{"type":"tool_result","message":{"role":"tool","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"README.md\nclaude_code_api\n","is_error":false}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022"} -{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"I found README.md and 
claude_code_api."}]},"session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022"} -{"type":"result","result":"ok","session_id":"sess_tool_1","model":"claude-3-5-haiku-20241022","usage":{"input_tokens":20,"output_tokens":15},"cost_usd":0.00005,"duration_ms":1500,"num_turns":1} +{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_tool_1","model":"claude-haiku-4-5-20250929","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} +{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"I'll list the files."},{"type":"tool_use","id":"toolu_123","name":"bash","input":{"command":"ls -1"}}]},"session_id":"sess_tool_1","model":"claude-haiku-4-5-20250929"} +{"type":"tool_result","message":{"role":"tool","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"README.md\nclaude_code_api\n","is_error":false}]},"session_id":"sess_tool_1","model":"claude-haiku-4-5-20250929"} +{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"I found README.md and claude_code_api."}]},"session_id":"sess_tool_1","model":"claude-haiku-4-5-20250929"} +{"type":"result","result":"ok","session_id":"sess_tool_1","model":"claude-haiku-4-5-20250929","usage":{"input_tokens":20,"output_tokens":15},"cost_usd":0.00005,"duration_ms":1500,"num_turns":1} diff --git a/tests/model_utils.py b/tests/model_utils.py new file mode 100644 index 0000000..bf4fab1 --- /dev/null +++ b/tests/model_utils.py @@ -0,0 +1,14 @@ +"""Shared test model selection helpers.""" + +import os + +from claude_code_api.models.claude import get_available_models, get_default_model + +TEST_MODEL_ID = os.getenv("CLAUDE_CODE_API_TEST_MODEL", "claude-haiku-4-5-20250929") + + +def get_test_model_id() -> str: + available = {model.id for model in get_available_models()} + if TEST_MODEL_ID in available: + return TEST_MODEL_ID + return get_default_model() diff --git a/tests/test_auth.py 
b/tests/test_auth.py new file mode 100644 index 0000000..eaa9eba --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,196 @@ +"""Tests for auth utilities and middleware.""" + +import json + +import pytest +from fastapi.responses import JSONResponse +from starlette.requests import Request + +from claude_code_api.core import auth as auth_module +from claude_code_api.core.config import settings + + +def _build_request( + path: str = "/v1/models", headers=None, query_string: bytes = b"" +) -> Request: + headers = headers or [] + scope = { + "type": "http", + "method": "GET", + "path": path, + "headers": headers, + "query_string": query_string, + "client": ("127.0.0.1", 12345), + "server": ("testserver", 80), + "scheme": "http", + } + return Request(scope) + + +def test_extract_api_key_sources(): + request = _build_request(headers=[(b"authorization", b"Bearer secret")]) + assert auth_module.extract_api_key(request) == "secret" + + request = _build_request(headers=[(b"x-api-key", b"apikey")]) + assert auth_module.extract_api_key(request) == "apikey" + + request = _build_request(query_string=b"api_key=querykey") + assert auth_module.extract_api_key(request) == "querykey" + + request = _build_request() + assert auth_module.extract_api_key(request) is None + + +def test_rate_limiter_basic(): + limiter = auth_module.RateLimiter(requests_per_minute=2, burst=10) + assert limiter.is_allowed("client") is True + assert limiter.is_allowed("client") is True + assert limiter.is_allowed("client") is False + + +def test_rate_limiter_burst_reset(monkeypatch): + limiter = auth_module.RateLimiter(requests_per_minute=100, burst=1) + now = [100.0] + monkeypatch.setattr(auth_module.time, "time", lambda: now[0]) + assert limiter.is_allowed("client") is True + # Move time forward so requests are cleared, but burst_used is still set. 
+ now[0] += 61.0 + assert limiter.is_allowed("client") is True + + +def test_validate_api_key_toggle(): + original_require_auth = settings.require_auth + original_keys = list(settings.api_keys) + try: + settings.require_auth = False + settings.api_keys = ["secret"] + assert auth_module.validate_api_key("secret") is True + + settings.require_auth = True + settings.api_keys = [] + assert auth_module.validate_api_key("secret") is False + + settings.api_keys = ["secret"] + assert auth_module.validate_api_key("secret") is True + assert auth_module.validate_api_key("bad") is False + finally: + settings.require_auth = original_require_auth + settings.api_keys = original_keys + + +@pytest.mark.asyncio +async def test_auth_middleware_allows_when_disabled(): + original_require_auth = settings.require_auth + settings.require_auth = False + captured = {} + + async def call_next(req: Request): + captured["api_key"] = req.state.api_key + captured["client_id"] = req.state.client_id + return JSONResponse({"ok": True}) + + request = _build_request() + response = await auth_module.auth_middleware(request, call_next) + assert response.status_code == 200 + assert captured["api_key"] is None + assert captured["client_id"] == "testclient" + + settings.require_auth = original_require_auth + + +@pytest.mark.asyncio +async def test_auth_middleware_missing_key(): + original_require_auth = settings.require_auth + original_keys = list(settings.api_keys) + settings.require_auth = True + settings.api_keys = ["secret"] + + request = _build_request() + + async def call_next(req: Request): + return JSONResponse({"ok": True}) + + response = await auth_module.auth_middleware(request, call_next) + assert response.status_code == 401 + payload = response.json() if hasattr(response, "json") else None + if payload is None: + payload = json.loads(response.body.decode()) + assert payload["error"]["code"] == "missing_api_key" + + settings.require_auth = original_require_auth + settings.api_keys = 
original_keys + + +@pytest.mark.asyncio +async def test_auth_middleware_invalid_key(): + original_require_auth = settings.require_auth + original_keys = list(settings.api_keys) + settings.require_auth = True + settings.api_keys = ["secret"] + + request = _build_request(headers=[(b"authorization", b"Bearer bad")]) + + async def call_next(req: Request): + return JSONResponse({"ok": True}) + + response = await auth_module.auth_middleware(request, call_next) + assert response.status_code == 401 + payload = response.json() if hasattr(response, "json") else None + if payload is None: + payload = json.loads(response.body.decode()) + assert payload["error"]["code"] == "invalid_api_key" + + settings.require_auth = original_require_auth + settings.api_keys = original_keys + + +@pytest.mark.asyncio +async def test_auth_middleware_rate_limited(monkeypatch): + original_require_auth = settings.require_auth + original_keys = list(settings.api_keys) + settings.require_auth = True + settings.api_keys = ["secret"] + + monkeypatch.setattr(auth_module.rate_limiter, "is_allowed", lambda _key: False) + + request = _build_request(headers=[(b"authorization", b"Bearer secret")]) + + async def call_next(req: Request): + return JSONResponse({"ok": True}) + + response = await auth_module.auth_middleware(request, call_next) + assert response.status_code == 429 + payload = response.json() if hasattr(response, "json") else None + if payload is None: + payload = json.loads(response.body.decode()) + assert payload["error"]["code"] == "rate_limit_exceeded" + + settings.require_auth = original_require_auth + settings.api_keys = original_keys + + +@pytest.mark.asyncio +async def test_auth_middleware_valid_key(): + original_require_auth = settings.require_auth + original_keys = list(settings.api_keys) + settings.require_auth = True + settings.api_keys = ["secret"] + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(auth_module.rate_limiter, "is_allowed", lambda _key: True) + + captured = {} + + 
async def call_next(req: Request): + captured["api_key"] = req.state.api_key + captured["client_id"] = req.state.client_id + return JSONResponse({"ok": True}) + + request = _build_request(headers=[(b"authorization", b"Bearer secret")]) + response = await auth_module.auth_middleware(request, call_next) + assert response.status_code == 200 + assert captured["api_key"] == "secret" + assert captured["client_id"] == "secret" + + monkeypatch.undo() + settings.require_auth = original_require_auth + settings.api_keys = original_keys diff --git a/tests/test_claude_manager_unit.py b/tests/test_claude_manager_unit.py new file mode 100644 index 0000000..b4a517f --- /dev/null +++ b/tests/test_claude_manager_unit.py @@ -0,0 +1,48 @@ +"""Unit tests for Claude manager helpers.""" + +import os + +from claude_code_api.core import claude_manager as cm +from claude_code_api.core.config import settings + + +def test_create_and_cleanup_project_directory(tmp_path): + original_root = settings.project_root + try: + settings.project_root = str(tmp_path) + project_path = cm.create_project_directory("proj1") + assert os.path.isdir(project_path) + cm.cleanup_project_directory(project_path) + assert not os.path.exists(project_path) + finally: + settings.project_root = original_root + + +def test_validate_claude_binary(monkeypatch): + class Result: + def __init__(self, returncode): + self.returncode = returncode + + def fake_run(*_args, **_kwargs): + return Result(0) + + monkeypatch.setattr(cm.subprocess, "run", fake_run) + assert cm.validate_claude_binary() is True + + def fake_run_fail(*_args, **_kwargs): + raise OSError("nope") + + monkeypatch.setattr(cm.subprocess, "run", fake_run_fail) + assert cm.validate_claude_binary() is False + + +def test_decode_output_line(): + process = cm.ClaudeProcess(session_id="sess", project_path="/tmp") + data = process._decode_output_line(b'{"type":"assistant"}\n') + assert data["type"] == "assistant" + + data = process._decode_output_line(b'data: 
{"type":"assistant"}\n') + assert data["type"] == "assistant" + + data = process._decode_output_line(b"not-json\n") + assert data["type"] == "text" diff --git a/tests/test_claude_working.py b/tests/test_claude_working.py index 7eba735..e5ffdf5 100644 --- a/tests/test_claude_working.py +++ b/tests/test_claude_working.py @@ -4,7 +4,7 @@ from pathlib import Path from claude_code_api.utils.streaming import create_non_streaming_response -from claude_code_api.models.claude import get_default_model +from tests.model_utils import get_test_model_id FIXTURES_DIR = Path(__file__).parent / "fixtures" @@ -18,9 +18,7 @@ def test_fixture_simple_non_streaming_response(): """Ensure basic fixture output produces a valid response.""" messages = load_fixture("claude_stream_simple.jsonl") response = create_non_streaming_response( - messages=messages, - session_id="sess_simple_1", - model=get_default_model() + messages=messages, session_id="sess_simple_1", model=get_test_model_id() ) choice = response["choices"][0] @@ -33,12 +31,46 @@ def test_fixture_tool_calls_response(): """Ensure tool calls are surfaced from fixture output.""" messages = load_fixture("claude_stream_tool_calls.jsonl") response = create_non_streaming_response( - messages=messages, - session_id="sess_tool_1", - model=get_default_model() + messages=messages, session_id="sess_tool_1", model=get_test_model_id() ) message = response["choices"][0]["message"] assert "tool_calls" in message assert len(message["tool_calls"]) > 0 assert message["tool_calls"][0]["function"]["name"] == "bash" + + +def test_tool_calls_without_text_finish_reason(): + """Tool-only responses should report tool_calls finish reason.""" + messages = [ + { + "type": "assistant", + "message": { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_1", + "name": "bash", + "input": {"command": "ls"}, + } + ], + }, + "session_id": "sess_tool_only", + "model": get_test_model_id(), + }, + { + "type": "result", + "result": "ok", + 
"session_id": "sess_tool_only", + "model": get_test_model_id(), + "usage": {"input_tokens": 5, "output_tokens": 5}, + }, + ] + response = create_non_streaming_response( + messages=messages, session_id="sess_tool_only", model=get_test_model_id() + ) + + choice = response["choices"][0] + assert choice["finish_reason"] == "tool_calls" + assert choice["message"]["content"] in ("", None) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..42995be --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,106 @@ +"""Tests for configuration helpers.""" + +import glob +import os +import subprocess + +from claude_code_api.core import config as config_module + + +def test_find_claude_binary_env(monkeypatch, tmp_path): + fake = tmp_path / "claude" + fake.write_text("bin") + monkeypatch.setenv("CLAUDE_BINARY_PATH", str(fake)) + assert config_module.find_claude_binary() == str(fake) + + +def test_find_claude_binary_shutil(monkeypatch): + monkeypatch.delenv("CLAUDE_BINARY_PATH", raising=False) + monkeypatch.setattr( + config_module.shutil, "which", lambda _name: "/usr/local/bin/claude" + ) + assert config_module.find_claude_binary() == "/usr/local/bin/claude" + + +def test_find_claude_binary_npm(monkeypatch, tmp_path): + monkeypatch.delenv("CLAUDE_BINARY_PATH", raising=False) + monkeypatch.setattr(config_module.shutil, "which", lambda _name: None) + + npm_bin = tmp_path / "bin" + npm_bin.mkdir() + (npm_bin / "claude").write_text("bin") + + class Result: + returncode = 0 + stdout = str(npm_bin) + + monkeypatch.setattr(subprocess, "run", lambda *_a, **_k: Result()) + assert config_module.find_claude_binary() == str(npm_bin / "claude") + + +def test_find_claude_binary_glob(monkeypatch): + monkeypatch.delenv("CLAUDE_BINARY_PATH", raising=False) + monkeypatch.setattr(config_module.shutil, "which", lambda _name: None) + monkeypatch.setattr( + subprocess, "run", lambda *_a, **_k: (_ for _ in ()).throw(OSError("no")) + ) + monkeypatch.setattr(glob, 
"glob", lambda _pattern: ["/a/claude", "/b/claude"]) + assert config_module.find_claude_binary() == "/b/claude" + + +def test_find_claude_binary_fallback(monkeypatch): + monkeypatch.delenv("CLAUDE_BINARY_PATH", raising=False) + monkeypatch.setattr(config_module.shutil, "which", lambda _name: None) + monkeypatch.setattr( + subprocess, "run", lambda *_a, **_k: (_ for _ in ()).throw(OSError("no")) + ) + monkeypatch.setattr(glob, "glob", lambda _pattern: []) + assert config_module.find_claude_binary() == "claude" + + +def test_looks_like_dotenv(tmp_path): + good = tmp_path / "good.env" + good.write_text("KEY=VALUE\n") + assert config_module._looks_like_dotenv(str(good)) is True + + export = tmp_path / "export.env" + export.write_text("export KEY=VALUE\n") + assert config_module._looks_like_dotenv(str(export)) is True + + bad = tmp_path / "bad.env" + bad.write_text("#!/bin/bash\necho nope\n") + assert config_module._looks_like_dotenv(str(bad)) is False + + bad2 = tmp_path / "bad2.env" + bad2.write_text("if [ 1 = 1 ]; then\n") + assert config_module._looks_like_dotenv(str(bad2)) is False + + +def test_looks_like_dotenv_missing_file(): + assert config_module._looks_like_dotenv("/tmp/does-not-exist.env") is False + + +def test_shell_script_line_detection(): + assert config_module._is_shell_script_line("if something") is True + assert config_module._is_shell_script_line("BASH_SOURCE") is True + + +def test_resolve_env_file(monkeypatch, tmp_path): + monkeypatch.delenv("CLAUDE_CODE_API_ENV_FILE", raising=False) + cwd = os.getcwd() + try: + os.chdir(tmp_path) + env_path = tmp_path / ".env" + env_path.write_text("KEY=VALUE\n") + assert config_module._resolve_env_file() == ".env" + finally: + os.chdir(cwd) + + monkeypatch.setenv("CLAUDE_CODE_API_ENV_FILE", "/tmp/explicit.env") + assert config_module._resolve_env_file() == "/tmp/explicit.env" + + +def test_settings_parsers(): + settings = config_module.Settings() + assert settings.parse_api_keys("a, b ,") == ["a", "b"] + assert 
settings.parse_cors_lists("x,y") == ["x", "y"] diff --git a/tests/test_e2e_live_api.py b/tests/test_e2e_live_api.py index 09beaa8..90af89d 100644 --- a/tests/test_e2e_live_api.py +++ b/tests/test_e2e_live_api.py @@ -1,13 +1,15 @@ """End-to-end tests against a running API server.""" -import os import json -import pytest +import os + import httpx -from claude_code_api.models.claude import get_default_model +import pytest +from tests.model_utils import get_test_model_id BASE_URL = os.getenv("CLAUDE_CODE_API_BASE_URL", "http://localhost:8000") +MODEL_ID = get_test_model_id() def _should_run_e2e() -> bool: @@ -62,9 +64,9 @@ def test_live_models(live_client): @pytest.mark.e2e def test_live_chat_completion(live_client): payload = { - "model": get_default_model(), + "model": MODEL_ID, "messages": [{"role": "user", "content": "Say only 'hi'."}], - "stream": False + "stream": False, } response = live_client.post("/v1/chat/completions", json=payload) assert response.status_code == 200 @@ -76,12 +78,32 @@ def test_live_chat_completion(live_client): @pytest.mark.e2e def test_live_chat_streaming(live_client): payload = { - "model": get_default_model(), + "model": MODEL_ID, "messages": [{"role": "user", "content": "Say only 'hi'."}], - "stream": True + "stream": True, } with live_client.stream("POST", "/v1/chat/completions", json=payload) as response: assert response.status_code == 200 lines = [line for line in response.iter_lines() if line] events = _parse_sse_lines(lines) assert any(event.get("object") == "chat.completion.chunk" for event in events) + + +@pytest.mark.e2e +def test_live_tool_calls(live_client): + payload = { + "model": MODEL_ID, + "messages": [ + { + "role": "user", + "content": "Use the bash tool to run 'ls -1' and return the output.", + } + ], + "stream": False, + } + response = live_client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + data = response.json() + message = data["choices"][0]["message"] + assert 
message.get("tool_calls") + assert message["tool_calls"][0]["function"]["name"] == "bash" diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 56c08d7..691a06e 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -4,21 +4,19 @@ This test suite tests the complete API functionality including: - OpenAI-compatible chat completions - Model endpoints -- Project and session management +- Project and session management - Streaming and non-streaming responses """ -import pytest -import asyncio import json -import uuid -from datetime import datetime -from typing import Dict, Any, List -from httpx import AsyncClient -from fastapi.testclient import TestClient import os -import tempfile import shutil +import tempfile +from typing import Any, Dict, List + +import pytest +from fastapi.testclient import TestClient +from httpx import AsyncClient def parse_sse_events(body_text: str) -> List[Dict[str, Any]]: @@ -33,6 +31,7 @@ def parse_sse_events(body_text: str) -> List[Dict[str, Any]]: events.append(json.loads(payload)) return events + # Import the FastAPI app import sys from pathlib import Path @@ -41,39 +40,41 @@ def parse_sse_events(body_text: str) -> List[Dict[str, Any]]: PROJECT_ROOT = Path(__file__).parent.parent sys.path.insert(0, str(PROJECT_ROOT)) -from claude_code_api.main import app from claude_code_api.core.config import settings -from claude_code_api.models.claude import get_available_models, get_default_model - +from claude_code_api.main import app +from claude_code_api.models.claude import get_available_models +from tests.model_utils import get_test_model_id AVAILABLE_MODELS = get_available_models() -DEFAULT_MODEL = get_default_model() +DEFAULT_MODEL = get_test_model_id() class TestConfig: """Test configuration.""" - + @classmethod def setup_test_environment(cls): """Setup test environment with mock Claude binary.""" # Create temporary directories for testing - cls.temp_dir = tempfile.mkdtemp() + test_root = PROJECT_ROOT / "dist" / 
"tests" + test_root.mkdir(parents=True, exist_ok=True) + cls.temp_dir = tempfile.mkdtemp(dir=str(test_root)) cls.project_root = os.path.join(cls.temp_dir, "projects") os.makedirs(cls.project_root, exist_ok=True) - + # Override settings for testing settings.project_root = cls.project_root settings.require_auth = False # Disable auth for testing # Keep real Claude binary - DO NOT mock it! # settings.claude_binary_path should remain as found by find_claude_binary() settings.database_url = f"sqlite:///{cls.temp_dir}/test.db" - + return cls.temp_dir - + @classmethod def cleanup_test_environment(cls): """Cleanup test environment.""" - if hasattr(cls, 'temp_dir') and os.path.exists(cls.temp_dir): + if hasattr(cls, "temp_dir") and os.path.exists(cls.temp_dir): shutil.rmtree(cls.temp_dir) @@ -101,22 +102,22 @@ async def async_client(test_environment): class TestHealthAndBasics: """Test basic API functionality.""" - + def test_health_check(self, client): """Test health check endpoint.""" response = client.get("/health") assert response.status_code == 200 - + data = response.json() assert data["status"] == "healthy" assert "version" in data assert "active_sessions" in data - + def test_root_endpoint(self, client): """Test root endpoint.""" response = client.get("/") assert response.status_code == 200 - + data = response.json() assert data["name"] == "Claude Code API Gateway" assert "endpoints" in data @@ -125,17 +126,17 @@ def test_root_endpoint(self, client): class TestModelsAPI: """Test models API endpoints.""" - + def test_list_models(self, client): """Test listing available models.""" response = client.get("/v1/models") assert response.status_code == 200 - + data = response.json() assert data["object"] == "list" assert "data" in data assert len(data["data"]) > 0 - + # Check model structure model = data["data"][0] assert "id" in model @@ -143,32 +144,32 @@ def test_list_models(self, client): assert model["object"] == "model" assert "created" in model assert "owned_by" in 
model - + def test_get_specific_model(self, client): """Test getting specific model.""" # Test Claude model model_id = AVAILABLE_MODELS[0].id response = client.get(f"/v1/models/{model_id}") assert response.status_code == 200 - + data = response.json() assert data["id"] == model_id assert data["object"] == "model" - + def test_get_openai_alias_model(self, client): """Test getting non-existent OpenAI model (not supported).""" response = client.get("/v1/models/gpt-4") assert response.status_code == 404 - + def test_get_nonexistent_model(self, client): """Test getting non-existent model.""" response = client.get("/v1/models/nonexistent-model") assert response.status_code == 404 - + data = response.json() assert "error" in data assert data["error"]["code"] == "model_not_found" - + def test_model_capabilities(self, client): """Test model capabilities endpoint.""" # Skip capabilities test for now - extension endpoint @@ -177,20 +178,18 @@ def test_model_capabilities(self, client): class TestChatCompletions: """Test chat completions API.""" - + def test_simple_chat_completion_non_streaming(self, client): """Test simple non-streaming chat completion.""" request_data = { "model": DEFAULT_MODEL, - "messages": [ - {"role": "user", "content": "Hi"} - ], - "stream": False + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() assert "id" in data assert data["object"] == "chat.completion" @@ -198,60 +197,56 @@ def test_simple_chat_completion_non_streaming(self, client): assert data["model"] == DEFAULT_MODEL assert "choices" in data assert len(data["choices"]) > 0 - + choice = data["choices"][0] assert choice["index"] == 0 assert "message" in choice assert choice["message"]["role"] == "assistant" assert "content" in choice["message"] assert "usage" in data - + def test_chat_completion_with_system_prompt(self, client): """Test chat 
completion with system prompt.""" request_data = { "model": DEFAULT_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ], - "stream": False + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() assert data["model"] == DEFAULT_MODEL assert len(data["choices"]) > 0 - + def test_chat_completion_with_invalid_model_fallback(self, client): """Test chat completion with invalid model (should fallback to default).""" request_data = { "model": "invalid-model", - "messages": [ - {"role": "user", "content": "What's 2+2?"} - ], - "stream": False + "messages": [{"role": "user", "content": "What's 2+2?"}], + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) # Should work with fallback to default model assert response.status_code in [200, 503] # 503 if Claude not available - + def test_chat_completion_streaming(self, client): """Test streaming chat completion.""" request_data = { "model": DEFAULT_MODEL, - "messages": [ - {"role": "user", "content": "Tell me a short joke"} - ], - "stream": True + "messages": [{"role": "user", "content": "Tell me a short joke"}], + "stream": True, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 assert "text/event-stream" in response.headers["content-type"] - + # Check that we get streaming data content = response.text assert "data: " in content @@ -266,12 +261,12 @@ def test_chat_completion_with_tool_calls(self, client): "messages": [ {"role": "user", "content": "Please use a tool to list files"} ], - "stream": False + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() choice = data["choices"][0] message = 
choice["message"] @@ -286,9 +281,9 @@ def test_chat_completion_streaming_tool_calls(self, client): "messages": [ {"role": "user", "content": "Please use a tool to list files"} ], - "stream": True + "stream": True, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 assert "text/event-stream" in response.headers["content-type"] @@ -299,7 +294,7 @@ def test_chat_completion_streaming_tool_calls(self, client): for event in events for choice in event.get("choices", []) ) - + def test_chat_completion_with_project_context(self, client): """Test chat completion with project context.""" request_data = { @@ -308,58 +303,50 @@ def test_chat_completion_with_project_context(self, client): {"role": "user", "content": "Hi, I'm working on a Python project"} ], "project_id": "test-project-123", - "stream": False + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() assert "project_id" in data assert data["project_id"] == "test-project-123" - + def test_chat_completion_missing_messages(self, client): """Test chat completion with missing messages.""" - request_data = { - "model": DEFAULT_MODEL, - "messages": [], - "stream": False - } - + request_data = {"model": DEFAULT_MODEL, "messages": [], "stream": False} + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 400 - + data = response.json() assert "error" in data assert data["error"]["code"] == "missing_messages" - + def test_chat_completion_no_user_message(self, client): """Test chat completion with no user message.""" request_data = { "model": DEFAULT_MODEL, - "messages": [ - {"role": "system", "content": "You are a helpful assistant."} - ], - "stream": False + "messages": [{"role": "system", "content": "You are a helpful assistant."}], + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) assert 
response.status_code == 400 - + data = response.json() assert "error" in data assert data["error"]["code"] == "missing_user_message" - + def test_chat_completion_invalid_model(self, client): """Test chat completion with invalid model.""" request_data = { "model": "invalid-model", - "messages": [ - {"role": "user", "content": "Hi"} - ], - "stream": False + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) # Should still work as model gets converted to default assert response.status_code in [200, 503] # 503 if Claude not available @@ -367,54 +354,59 @@ def test_chat_completion_invalid_model(self, client): class TestConversationFlow: """Test conversation flow and session management.""" - + def test_conversation_continuity(self, client): """Test conversation continuity across messages.""" # First message request_data_1 = { - "model": DEFAULT_MODEL, - "messages": [ - {"role": "user", "content": "My name is Alice"} - ], - "stream": False + "model": DEFAULT_MODEL, + "messages": [{"role": "user", "content": "My name is Alice"}], + "stream": False, } - + response_1 = client.post("/v1/chat/completions", json=request_data_1) assert response_1.status_code == 200 - + data_1 = response_1.json() session_id = data_1.get("session_id") - + if session_id: # Follow-up message in same session request_data_2 = { "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "My name is Alice"}, - {"role": "assistant", "content": data_1["choices"][0]["message"]["content"]}, - {"role": "user", "content": "What's my name?"} + { + "role": "assistant", + "content": data_1["choices"][0]["message"]["content"], + }, + {"role": "user", "content": "What's my name?"}, ], "session_id": session_id, - "stream": False + "stream": False, } - + response_2 = client.post("/v1/chat/completions", json=request_data_2) - assert response_2.status_code in [200, 404, 503] # May fail if session management incomplete - + 
assert response_2.status_code in [ + 200, + 404, + 503, + ] # May fail if session management incomplete + def test_multiple_user_messages(self, client): """Test handling multiple user messages.""" request_data = { "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi"}, - {"role": "user", "content": "How are you doing today?"} + {"role": "user", "content": "How are you doing today?"}, ], - "stream": False + "stream": False, } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + # Should use the last user message data = response.json() assert len(data["choices"]) > 0 @@ -422,59 +414,59 @@ def test_multiple_user_messages(self, client): class TestProjectsAPI: """Test projects API endpoints.""" - + def test_list_projects(self, client): """Test listing projects.""" response = client.get("/v1/projects") assert response.status_code == 200 - + data = response.json() assert "data" in data assert "pagination" in data - + def test_create_project(self, client): """Test creating a project.""" project_data = { "name": "Test Project", - "description": "A test project for API testing" + "description": "A test project for API testing", } - + response = client.post("/v1/projects", json=project_data) assert response.status_code == 200 - + data = response.json() assert data["name"] == "Test Project" assert data["description"] == "A test project for API testing" assert "id" in data assert "path" in data assert "created_at" in data - + def test_get_project(self, client): """Test getting a specific project.""" # First create a project project_data = { "name": "Test Project for Get", - "description": "Test description" + "description": "Test description", } - + create_response = client.post("/v1/projects", json=project_data) assert create_response.status_code == 200 - + project_id = create_response.json()["id"] - + # Now get the project response = client.get(f"/v1/projects/{project_id}") assert response.status_code == 
200 - + data = response.json() assert data["id"] == project_id assert data["name"] == "Test Project for Get" - + def test_get_nonexistent_project(self, client): """Test getting non-existent project.""" response = client.get("/v1/projects/nonexistent-id") assert response.status_code == 404 - + data = response.json() assert "error" in data assert data["error"]["code"] == "project_not_found" @@ -482,38 +474,38 @@ def test_get_nonexistent_project(self, client): class TestSessionsAPI: """Test sessions API endpoints.""" - + def test_list_sessions(self, client): """Test listing sessions.""" response = client.get("/v1/sessions") assert response.status_code == 200 - + data = response.json() assert "data" in data assert "pagination" in data - + def test_create_session(self, client): """Test creating a session.""" session_data = { "project_id": "test-project", "title": "Test Session", - "model": DEFAULT_MODEL + "model": DEFAULT_MODEL, } - + response = client.post("/v1/sessions", json=session_data) assert response.status_code == 200 - + data = response.json() assert data["project_id"] == "test-project" assert data["model"] == DEFAULT_MODEL assert "id" in data assert "created_at" in data - + def test_get_session_stats(self, client): """Test getting session statistics.""" response = client.get("/v1/sessions/stats") assert response.status_code == 200 - + data = response.json() assert "session_stats" in data assert "active_claude_sessions" in data @@ -521,82 +513,78 @@ def test_get_session_stats(self, client): class TestErrorHandling: """Test error handling scenarios.""" - + def test_invalid_json(self, client): """Test handling of invalid JSON.""" response = client.post( "/v1/chat/completions", data="invalid json", - headers={"content-type": "application/json"} + headers={"content-type": "application/json"}, ) # API returns 400 for JSON decode errors (handled manually) assert response.status_code == 400 - + def test_missing_required_fields(self, client): """Test handling of 
missing required fields.""" request_data = { - "messages": [ - {"role": "user", "content": "Hi"} - ] + "messages": [{"role": "user", "content": "Hi"}] # Missing required "model" field } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 422 # Validation error - + def test_invalid_message_role(self, client): """Test handling of invalid message role.""" request_data = { "model": DEFAULT_MODEL, - "messages": [ - {"role": "invalid_role", "content": "Hi"} - ] + "messages": [{"role": "invalid_role", "content": "Hi"}], } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 422 # Validation error class TestRealWorldScenarios: """Test real-world usage scenarios.""" - + def test_simple_greeting(self, client): """Test simple greeting - most common use case.""" request_data = { "model": DEFAULT_MODEL, - "messages": [ - {"role": "user", "content": "Hi"} - ] + "messages": [{"role": "user", "content": "Hi"}], } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() assert "choices" in data assert len(data["choices"]) > 0 assert "message" in data["choices"][0] # Content might be empty with mock setup, just check structure assert "content" in data["choices"][0]["message"] - def test_code_generation_request(self, client): """Test code generation request.""" request_data = { "model": DEFAULT_MODEL, "messages": [ - {"role": "user", "content": "Write a Python function to calculate fibonacci numbers"} - ] + { + "role": "user", + "content": "Write a Python function to calculate fibonacci numbers", + } + ], } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() assert "choices" in data assert len(data["choices"]) > 0 # Could check for code-like content but Echo won't generate real code - + def test_multi_turn_conversation(self, client): 
"""Test multi-turn conversation simulation.""" # Simulate a multi-turn conversation in a single request @@ -604,30 +592,30 @@ def test_multi_turn_conversation(self, client): "model": DEFAULT_MODEL, "messages": [ {"role": "user", "content": "Hi, I'm learning Python"}, - {"role": "assistant", "content": "Hello! That's great that you're learning Python. It's an excellent programming language for beginners and professionals alike. What specifically would you like to know about Python?"}, - {"role": "user", "content": "How do I create a list?"} - ] + { + "role": "assistant", + "content": ( + "Hello! That's great that you're learning Python. It's an " + "excellent programming language for beginners and professionals " + "alike. What specifically would you like to know about Python?" + ), + }, + {"role": "user", "content": "How do I create a list?"}, + ], } - + response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 - + data = response.json() assert "choices" in data assert len(data["choices"]) > 0 - - # Test configuration and markers pytestmark = pytest.mark.asyncio if __name__ == "__main__": # Run tests with coverage - pytest.main([ - __file__, - "-v", - "--tb=short", - "--disable-warnings" - ]) + pytest.main([__file__, "-v", "--tb=short", "--disable-warnings"]) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..5513b4a --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,170 @@ +"""Tests for parser utilities.""" + +import json +from types import SimpleNamespace + +from claude_code_api.models.claude import ClaudeMessage, ClaudeToolUse +from claude_code_api.utils.parser import ( + ClaudeOutputParser, + MessageAggregator, + estimate_tokens, + extract_error_from_message, + format_timestamp, + normalize_claude_message, + sanitize_content, + tool_use_to_openai_call, +) + + +def _message_with_content(content): + return ClaudeMessage( + type="assistant", + message={"role": "assistant", 
"content": content}, + session_id="sess", + model="claude", + ) + + +def test_extract_text_content_variants(): + parser = ClaudeOutputParser() + assert parser.extract_text_content(_message_with_content("plain")) == "plain" + assert ( + parser.extract_text_content(_message_with_content({"text": "nested"})) + == "nested" + ) + assert ( + parser.extract_text_content(_message_with_content({"content": "inner"})) + == "inner" + ) + + content = [{"type": "text", "text": "hello"}, "world", {"content": "x"}] + assert ( + parser.extract_text_content(_message_with_content(content)) == "hello\nworld\nx" + ) + + assert parser.extract_text_content(_message_with_content(123)) == "123" + + +def test_extract_tool_uses_and_results(): + parser = ClaudeOutputParser() + message = ClaudeMessage( + type="assistant", + message={ + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool1", + "name": "bash", + "input": {"command": "ls"}, + } + ], + }, + ) + tool_uses = parser.extract_tool_uses(message) + assert tool_uses + assert tool_uses[0].name == "bash" + + result_message = ClaudeMessage( + type="tool_result", + message={ + "role": "tool", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool1", + "content": "ok", + "is_error": False, + } + ], + }, + ) + results = parser.extract_tool_results(result_message) + assert results + assert results[0].tool_use_id == "tool1" + + +def test_parse_message_updates_metrics(): + parser = ClaudeOutputParser() + message = ClaudeMessage( + type="assistant", + message={"role": "assistant", "content": "hi"}, + session_id="sess", + model="claude", + usage={"input_tokens": 3, "output_tokens": 5}, + cost_usd=0.01, + ) + parser.parse_message(message) + assert parser.session_id == "sess" + assert parser.model == "claude" + assert parser.total_tokens == 8 + assert parser.total_cost == 0.01 + assert parser.message_count == 1 + + +def test_error_extraction_helpers(): + message = ClaudeMessage(type="error", error="boom") + assert 
extract_error_from_message(message) == "boom" + + message = ClaudeMessage(type="result", result=None) + assert extract_error_from_message(message) == "Execution completed without result" + + message = ClaudeMessage( + type="tool_result", + message={ + "role": "tool", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool1", + "content": "bad", + "is_error": True, + } + ], + }, + ) + assert extract_error_from_message(message) == "bad" + + +def test_misc_utilities(): + assert estimate_tokens("1234") >= 1 + assert sanitize_content("a\x00b") == "ab" + assert "?" in sanitize_content("bad\udcff") + + assert normalize_claude_message({"type": "assistant"}).type == "assistant" + assert normalize_claude_message(["not", "a", "dict"]) is None + + tool_use = ClaudeToolUse(id="", name="bash", input={"command": "ls"}) + call = tool_use_to_openai_call(tool_use) + assert call["function"]["name"] == "bash" + + tool_use = SimpleNamespace(id="", name="bash", input=set([1, 2])) + call = tool_use_to_openai_call(tool_use) + assert "input" in json.loads(call["function"]["arguments"]) + + assert "T" in format_timestamp(None) + assert format_timestamp("2026-02-04T00:00:00Z").startswith("2026-02-04T") + + +def test_message_aggregator(): + aggregator = MessageAggregator() + aggregator.add_message( + { + "type": "assistant", + "message": {"role": "assistant", "content": "hello"}, + "session_id": "sess", + } + ) + aggregator.add_message( + { + "type": "assistant", + "message": {"role": "assistant", "content": " world"}, + "session_id": "sess", + } + ) + assert aggregator.get_complete_response() == "hello world" + + +def test_parse_line_invalid_json(): + parser = ClaudeOutputParser() + assert parser.parse_line("{not-json}") is None diff --git a/tests/test_real_api.py b/tests/test_real_api.py index 7a866e6..9e107e7 100755 --- a/tests/test_real_api.py +++ b/tests/test_real_api.py @@ -4,21 +4,18 @@ Unlike the fake tests that import the app directly. 
""" -import requests -import json -import time import sys -import subprocess -import signal -import os -from typing import Optional -from claude_code_api.models.claude import get_default_model + +import requests + +from tests.model_utils import get_test_model_id + class RealAPITester: def __init__(self, base_url: str = "http://localhost:8000"): self.base_url = base_url self.session = requests.Session() - + def test_health(self) -> bool: """Test health endpoint.""" try: @@ -35,7 +32,7 @@ def test_health(self) -> bool: except Exception as e: print(f"Health check failed: {e}") return False - + def test_models(self) -> bool: """Test models endpoint.""" try: @@ -43,7 +40,7 @@ def test_models(self) -> bool: print(f"Models API: {response.status_code}") if response.status_code == 200: data = response.json() - models = data.get('data', []) + models = data.get("data", []) print(f" Found {len(models)} models:") for model in models[:2]: # Show first 2 print(f" - {model.get('id')}") @@ -54,21 +51,23 @@ def test_models(self) -> bool: except Exception as e: print(f"Models test failed: {e}") return False - + def test_auth_bypass(self) -> bool: """Test that API works without auth (should work with current config).""" try: # Test without any auth headers response = self.session.get(f"{self.base_url}/v1/models", timeout=5) print(f"Auth Bypass Test: {response.status_code}") - + if response.status_code == 200: print(" API works without authentication") return True elif response.status_code == 401: print(" API requires authentication") error = response.json() - print(f" Error: {error.get('error', {}).get('message', 'Unknown auth error')}") + print( + f" Error: {error.get('error', {}).get('message', 'Unknown auth error')}" + ) return False else: print(f" Unexpected status: {response.text}") @@ -76,56 +75,57 @@ def test_auth_bypass(self) -> bool: except Exception as e: print(f"Auth test failed: {e}") return False - + def test_chat_completion(self) -> bool: """Test chat completion endpoint 
(may be slow).""" try: payload = { - "model": get_default_model(), + "model": get_test_model_id(), "messages": [ - {"role": "user", "content": "Say 'test successful' and nothing else"} + { + "role": "user", + "content": "Say 'test successful' and nothing else", + } ], - "stream": False + "stream": False, } - + print("Chat Completion (this may take a while)...") response = self.session.post( - f"{self.base_url}/v1/chat/completions", - json=payload, - timeout=30 + f"{self.base_url}/v1/chat/completions", json=payload, timeout=30 ) - + print(f" Status: {response.status_code}") if response.status_code == 200: data = response.json() - if 'choices' in data and len(data['choices']) > 0: - content = data['choices'][0].get('message', {}).get('content', '') + if "choices" in data and len(data["choices"]) > 0: + content = data["choices"][0].get("message", {}).get("content", "") print(f" Response: {content[:100]}...") return True else: print(f" Error: {response.text[:200]}...") - + return response.status_code == 200 - + except requests.exceptions.Timeout: print(" Chat completion timed out (expected with mock setup)") return True # Timeout is expected with echo mock except Exception as e: print(f"Chat completion failed: {e}") return False - + def run_all_tests(self) -> bool: """Run all tests and return overall success.""" print("REAL End-to-End API Tests") print("=" * 40) - + tests = [ ("Health Check", self.test_health), - ("Models API", self.test_models), + ("Models API", self.test_models), ("Auth Bypass", self.test_auth_bypass), ("Chat Completion", self.test_chat_completion), ] - + results = [] for test_name, test_func in tests: print(f"\n{test_name}:") @@ -137,12 +137,12 @@ def run_all_tests(self) -> bool: except Exception as e: print(f" FAIL: {e}") results.append(False) - + print("\n" + "=" * 40) passed = sum(results) total = len(results) print(f"Results: {passed}/{total} tests passed") - + if passed == total: print("ALL TESTS PASSED!") return True @@ -156,24 +156,24 @@ def 
check_server_running(url: str = "http://localhost:8000") -> bool: try: response = requests.get(f"{url}/health", timeout=2) return response.status_code == 200 - except: + except Exception: return False def main(): print("Checking if API server is running...") - + if not check_server_running(): print("API server not running on http://localhost:8000") print("Start the server with: make start") sys.exit(1) - + print("Server is running!") print() - + tester = RealAPITester() success = tester.run_all_tests() - + sys.exit(0 if success else 1) diff --git a/tests/test_security.py b/tests/test_security.py index c2ee73d..42c5916 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,12 +1,15 @@ import pytest from fastapi import HTTPException + from claude_code_api.core.security import validate_path + def test_validate_path_valid(): base = "/tmp/projects" path = "project1" assert validate_path(path, base) == "/tmp/projects/project1" + def test_validate_path_traversal(): base = "/tmp/projects" path = "../etc/passwd" @@ -15,6 +18,7 @@ def test_validate_path_traversal(): assert exc.value.status_code == 400 assert "Path traversal detected" in exc.value.detail + def test_validate_path_absolute_traversal(): base = "/tmp/projects" path = "/etc/passwd" @@ -23,6 +27,7 @@ def test_validate_path_absolute_traversal(): assert exc.value.status_code == 400 assert "Path traversal detected" in exc.value.detail + def test_validate_path_absolute_valid(): base = "/tmp/projects" path = "/tmp/projects/project1" diff --git a/tests/test_session_manager_unit.py b/tests/test_session_manager_unit.py new file mode 100644 index 0000000..e14e5a3 --- /dev/null +++ b/tests/test_session_manager_unit.py @@ -0,0 +1,77 @@ +"""Unit tests for session manager behaviors.""" + +import types +from datetime import timedelta + +import pytest + +from claude_code_api.core import session_manager as sm_module +from claude_code_api.core.session_manager import SessionInfo, SessionManager +from 
claude_code_api.utils.time import utc_now + + +@pytest.mark.asyncio +async def test_get_session_from_db(monkeypatch): + manager = SessionManager() + + fake_db_session = types.SimpleNamespace( + id="sess1", + project_id="proj", + model="claude", + system_prompt=None, + created_at=utc_now(), + updated_at=utc_now(), + message_count=2, + total_tokens=10, + total_cost=0.01, + is_active=True, + ) + + async def fake_get_session(_session_id): + return fake_db_session + + monkeypatch.setattr(sm_module.db_manager, "get_session", fake_get_session) + + session = await manager.get_session("sess1") + assert session is not None + assert session.session_id == "sess1" + assert session.total_tokens == 10 + + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_cleanup_expired_sessions(): + manager = SessionManager() + session = SessionInfo(session_id="sess", project_id="proj", model="claude") + session.updated_at = utc_now() - timedelta(minutes=60) + manager.active_sessions["sess"] = session + + manager.cleanup_expired_sessions() + assert "sess" not in manager.active_sessions + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_session_stats(): + manager = SessionManager() + s1 = SessionInfo(session_id="s1", project_id="p1", model="m1") + s1.total_tokens = 5 + s1.total_cost = 1.5 + s1.message_count = 2 + + s2 = SessionInfo(session_id="s2", project_id="p1", model="m2") + s2.total_tokens = 3 + s2.total_cost = 0.5 + s2.message_count = 1 + + manager.active_sessions["s1"] = s1 + manager.active_sessions["s2"] = s2 + + stats = manager.get_session_stats() + assert stats["active_sessions"] == 2 + assert stats["total_tokens"] == 8 + assert stats["total_cost"] == 2.0 + assert stats["total_messages"] == 3 + assert set(stats["models_in_use"]) == {"m1", "m2"} + await manager.cleanup_all() diff --git a/tests/test_utils_time.py b/tests/test_utils_time.py new file mode 100644 index 0000000..da72053 --- /dev/null +++ b/tests/test_utils_time.py @@ -0,0 +1,18 @@ 
+"""Tests for time utilities and defaults.""" + +import os + +from claude_code_api.core.config import default_project_root +from claude_code_api.utils.time import utc_now, utc_timestamp + + +def test_utc_time_helpers_monotonic(): + first = utc_now() + second = utc_now() + assert second >= first + assert isinstance(utc_timestamp(), int) + + +def test_default_project_root_is_under_cwd(): + expected = os.path.join(os.getcwd(), "claude_projects") + assert default_project_root() == expected From 5f1646a9440d507bd1957a9975031a78bcff9c66 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Wed, 4 Feb 2026 18:42:38 +0000 Subject: [PATCH 3/9] Fix session mapping --- .gitignore | 1 + claude_code_api/api/chat.py | 17 ++-- claude_code_api/api/sessions.py | 2 +- claude_code_api/core/claude_manager.py | 75 +++++++++++--- claude_code_api/core/config.py | 6 ++ claude_code_api/core/database.py | 10 ++ claude_code_api/core/session_manager.py | 98 ++++++++++++++++--- tests/conftest.py | 2 + .../fixtures/claude_stream_session_map.jsonl | 3 + tests/fixtures/index.json | 7 ++ tests/test_end_to_end.py | 30 ++++++ tests/test_session_manager_unit.py | 9 +- 12 files changed, 221 insertions(+), 39 deletions(-) create mode 100644 tests/fixtures/claude_stream_session_map.jsonl diff --git a/.gitignore b/.gitignore index 8d917ca..0789aef 100644 --- a/.gitignore +++ b/.gitignore @@ -324,6 +324,7 @@ $RECYCLE.BIN/ .claude/ sessions/ projects/ +claude_sessions/ # API keys and tokens api_keys.txt diff --git a/claude_code_api/api/chat.py b/claude_code_api/api/chat.py index 3429830..d573bce 100644 --- a/claude_code_api/api/chat.py +++ b/claude_code_api/api/chat.py @@ -295,12 +295,17 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request) - # Start Claude Code process try: + + def _register_cli_session(cli_session_id: str): + session_manager.register_cli_session(session_id, cli_session_id) + claude_process = await claude_manager.create_session( session_id=session_id, 
project_path=project_path, prompt=user_prompt, model=claude_model, system_prompt=system_prompt, + on_cli_session_id=_register_cli_session, ) except Exception as e: logger.error( @@ -314,11 +319,11 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request) - ) # Use Claude's actual session ID - claude_session_id = claude_process.session_id + api_session_id = session_id # Update session with user message await session_manager.update_session( - session_id=claude_session_id, + session_id=api_session_id, message_content=user_prompt, role="user", tokens_used=estimate_tokens(user_prompt), @@ -327,13 +332,13 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request) - # Handle streaming vs non-streaming if request.stream: return StreamingResponse( - create_sse_response(claude_session_id, claude_model, claude_process), + create_sse_response(api_session_id, claude_model, claude_process), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", - "X-Session-ID": claude_session_id, + "X-Session-ID": api_session_id, "X-Project-ID": project_id, }, ) @@ -341,7 +346,7 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request) - return await _collect_non_streaming_response( claude_process=claude_process, session_manager=session_manager, - session_id=claude_session_id, + session_id=api_session_id, model=claude_model, project_id=project_id, ) @@ -461,7 +466,7 @@ async def stop_completion(session_id: str, req: Request) -> Dict[str, str]: await claude_manager.stop_session(session_id) # End session - session_manager.end_session(session_id) + await session_manager.end_session(session_id) logger.info("Chat completion stopped", session_id=session_id) diff --git a/claude_code_api/api/sessions.py b/claude_code_api/api/sessions.py index b174682..50667f3 100644 --- a/claude_code_api/api/sessions.py +++ b/claude_code_api/api/sessions.py @@ -177,7 +177,7 @@ 
async def delete_session(session_id: str, req: Request) -> JSONResponse: await claude_manager.stop_session(session_id) # End session - session_manager.end_session(session_id) + await session_manager.end_session(session_id) logger.info("Session deleted", session_id=session_id) diff --git a/claude_code_api/core/claude_manager.py b/claude_code_api/core/claude_manager.py index e51013e..71bb82d 100644 --- a/claude_code_api/core/claude_manager.py +++ b/claude_code_api/core/claude_manager.py @@ -4,7 +4,7 @@ import json import os import subprocess -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional import structlog @@ -18,8 +18,15 @@ class ClaudeProcess: """Manages a single Claude Code process.""" - def __init__(self, session_id: str, project_path: str): + def __init__( + self, + session_id: str, + project_path: str, + on_cli_session_id: Optional[Callable[[str], None]] = None, + on_end: Optional[Callable[["ClaudeProcess"], None]] = None, + ): self.session_id = session_id + self.cli_session_id: Optional[str] = None self.project_path = project_path self.process: Optional[asyncio.subprocess.Process] = None self.is_running = False @@ -27,6 +34,8 @@ def __init__(self, session_id: str, project_path: str): self.error_queue = asyncio.Queue() self._output_task: Optional[asyncio.Task] = None self._error_task: Optional[asyncio.Task] = None + self._on_cli_session_id = on_cli_session_id + self._on_end = on_end async def start( self, prompt: str, model: str = None, system_prompt: str = None @@ -123,8 +132,9 @@ async def _read_output(self): logger.info( "Extracted Claude session ID", session_id=claude_session_id ) - # Update our session_id to match Claude's - self.session_id = claude_session_id + self.cli_session_id = claude_session_id + if self._on_cli_session_id: + self._on_cli_session_id(claude_session_id) await self.output_queue.put(data) except Exception as e: @@ -136,6 +146,8 @@ async def 
_read_output(self): logger.info( "Claude process output stream ended", session_id=self.session_id ) + if self._on_end: + self._on_end(self) async def _read_error(self): """Read stderr from process.""" @@ -237,6 +249,7 @@ class ClaudeManager: def __init__(self): self.processes: Dict[str, ClaudeProcess] = {} + self.cli_session_index: Dict[str, str] = {} self.max_concurrent = settings.max_concurrent_sessions async def get_version(self) -> str: @@ -273,6 +286,7 @@ async def create_session( prompt: str, model: str = None, system_prompt: str = None, + on_cli_session_id: Optional[Callable[[str], None]] = None, ) -> ClaudeProcess: """Create new Claude session.""" # Check concurrent session limit @@ -285,7 +299,17 @@ async def create_session( os.makedirs(project_path, exist_ok=True) # Create process - process = ClaudeProcess(session_id, project_path) + def _handle_cli_session_id(cli_session_id: str): + self._register_cli_session(session_id, cli_session_id) + if on_cli_session_id: + on_cli_session_id(cli_session_id) + + process = ClaudeProcess( + session_id=session_id, + project_path=project_path, + on_cli_session_id=_handle_cli_session_id, + on_end=self._cleanup_process, + ) # Start process success = await process.start( @@ -297,12 +321,11 @@ async def create_session( if not success: raise ClaudeProcessStartError("Failed to start Claude process") - # Don't store processes since Claude CLI completes immediately - # This prevents the "max concurrent sessions" error + self.processes[session_id] = process logger.info( "Claude session created", - session_id=process.session_id, # Use Claude's actual session ID + session_id=process.session_id, active_sessions=len(self.processes), ) @@ -310,18 +333,22 @@ async def create_session( def get_session(self, session_id: str) -> Optional[ClaudeProcess]: """Get existing Claude session.""" - return self.processes.get(session_id) + resolved_id = self._resolve_session_id(session_id) + if not resolved_id: + return None + return 
self.processes.get(resolved_id) async def stop_session(self, session_id: str): """Stop Claude session.""" - if session_id in self.processes: - process = self.processes[session_id] + resolved_id = self._resolve_session_id(session_id) + if resolved_id and resolved_id in self.processes: + process = self.processes[resolved_id] await process.stop() - del self.processes[session_id] + self._cleanup_process(process) logger.info( "Claude session stopped", - session_id=session_id, + session_id=resolved_id, active_sessions=len(self.processes), ) @@ -338,13 +365,33 @@ def get_active_sessions(self) -> List[str]: async def continue_conversation(self, session_id: str, prompt: str) -> bool: """Continue existing conversation.""" - process = self.processes.get(session_id) + resolved_id = self._resolve_session_id(session_id) + if not resolved_id: + return False + process = self.processes.get(resolved_id) if not process: return False await process.send_input(prompt) return True + def _register_cli_session(self, api_session_id: str, cli_session_id: str): + if not cli_session_id: + return + self.cli_session_index[cli_session_id] = api_session_id + + def _resolve_session_id(self, session_id: str) -> Optional[str]: + if session_id in self.processes: + return session_id + return self.cli_session_index.get(session_id) + + def _cleanup_process(self, process: ClaudeProcess): + api_session_id = process.session_id + if api_session_id in self.processes: + del self.processes[api_session_id] + if process.cli_session_id: + self.cli_session_index.pop(process.cli_session_id, None) + # Utility functions for project management def create_project_directory(project_id: str) -> str: diff --git a/claude_code_api/core/config.py b/claude_code_api/core/config.py index 9497f7d..27772c0 100644 --- a/claude_code_api/core/config.py +++ b/claude_code_api/core/config.py @@ -59,6 +59,11 @@ def default_project_root() -> str: return os.path.join(os.getcwd(), "claude_projects") +def default_session_map_path() -> str: + 
"""Default path for CLI-to-API session mapping.""" + return os.path.join(os.getcwd(), "claude_sessions", "session_map.json") + + def _is_shell_script_line(line: str) -> bool: if not line: return False @@ -144,6 +149,7 @@ def parse_api_keys(cls, v): project_root: str = default_project_root() max_project_size_mb: int = 1000 cleanup_interval_minutes: int = 60 + session_map_path: str = default_session_map_path() # Database Configuration database_url: str = "sqlite:///./claude_api.db" diff --git a/claude_code_api/core/database.py b/claude_code_api/core/database.py index e8ce5e6..82ccfb2 100644 --- a/claude_code_api/core/database.py +++ b/claude_code_api/core/database.py @@ -239,6 +239,16 @@ async def update_session_metrics(session_id: str, tokens_used: int, cost: float) session_obj.updated_at = utc_now() await session.commit() + @staticmethod + async def deactivate_session(session_id: str): + """Mark session as inactive.""" + async with AsyncSessionLocal() as session: + session_obj = await session.get(Session, session_id) + if session_obj: + session_obj.is_active = False + session_obj.updated_at = utc_now() + await session.commit() + # Create global database manager instance db_manager = DatabaseManager() diff --git a/claude_code_api/core/session_manager.py b/claude_code_api/core/session_manager.py index c7ee52b..7c302aa 100644 --- a/claude_code_api/core/session_manager.py +++ b/claude_code_api/core/session_manager.py @@ -1,6 +1,8 @@ """Session management for Claude Code API Gateway.""" import asyncio +import json +import os import uuid from datetime import timedelta from typing import Any, Dict, List, Optional @@ -22,6 +24,7 @@ def __init__( self, session_id: str, project_id: str, model: str, system_prompt: str = None ): self.session_id = session_id + self.cli_session_id: Optional[str] = None self.project_id = project_id self.model = model self.system_prompt = system_prompt @@ -38,9 +41,12 @@ class SessionManager: def __init__(self): self.active_sessions: Dict[str, 
SessionInfo] = {} + self.cli_session_index: Dict[str, str] = {} + self.session_map_path = settings.session_map_path self.cleanup_task: Optional[asyncio.Task] = None self._shutdown_event = asyncio.Event() self._start_cleanup_task() + self._load_cli_session_map() def _start_cleanup_task(self): """Start periodic cleanup task.""" @@ -57,12 +63,47 @@ async def _periodic_cleanup(self): ) break except asyncio.TimeoutError: - self.cleanup_expired_sessions() + await self.cleanup_expired_sessions() except asyncio.CancelledError: raise except Exception as e: logger.error("Error in periodic cleanup", error=str(e)) + def _load_cli_session_map(self): + if not self.session_map_path: + return + try: + with open(self.session_map_path, "r", encoding="utf-8") as handle: + data = json.load(handle) + except FileNotFoundError: + return + except Exception as exc: + logger.warning("Failed to load session map", error=str(exc)) + return + + mapping = data.get("cli_to_api", data) if isinstance(data, dict) else {} + if isinstance(mapping, dict): + self.cli_session_index = { + str(cli_id): str(api_id) + for cli_id, api_id in mapping.items() + if cli_id and api_id + } + + def _persist_cli_session_map(self): + if not self.session_map_path: + return + try: + directory = os.path.dirname(self.session_map_path) + if directory: + os.makedirs(directory, exist_ok=True) + tmp_path = f"{self.session_map_path}.tmp" + payload = {"cli_to_api": self.cli_session_index} + with open(tmp_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + os.replace(tmp_path, self.session_map_path) + except Exception as exc: + logger.warning("Failed to persist session map", error=str(exc)) + async def create_session( self, project_id: str, @@ -109,12 +150,16 @@ async def create_session( async def get_session(self, session_id: str) -> Optional[SessionInfo]: """Get session information.""" + resolved_id = self._resolve_session_id(session_id) + if resolved_id is None: + resolved_id = 
session_id + # Check active sessions first - if session_id in self.active_sessions: - return self.active_sessions[session_id] + if resolved_id in self.active_sessions: + return self.active_sessions[resolved_id] # Load from database if not in memory - db_session = await db_manager.get_session(session_id) + db_session = await db_manager.get_session(resolved_id) if db_session and db_session.is_active: # Restore to active sessions session_info = SessionInfo( @@ -129,7 +174,7 @@ async def get_session(self, session_id: str) -> Optional[SessionInfo]: session_info.total_tokens = db_session.total_tokens session_info.total_cost = db_session.total_cost - self.active_sessions[session_id] = session_info + self.active_sessions[resolved_id] = session_info return session_info return None @@ -157,7 +202,7 @@ async def update_session( # Add message to database message_data = { - "session_id": session_id, + "session_id": session_info.session_id, "role": role, "content": message_content, "input_tokens": tokens_used if role == "user" else 0, @@ -169,7 +214,9 @@ async def update_session( await db_manager.add_message(message_data) # Update database metrics - await db_manager.update_session_metrics(session_id, tokens_used, cost) + await db_manager.update_session_metrics( + session_info.session_id, tokens_used, cost + ) logger.debug( "Session updated", @@ -179,12 +226,17 @@ async def update_session( total_tokens=session_info.total_tokens, ) - def end_session(self, session_id: str): + async def end_session(self, session_id: str): """End session and cleanup.""" - if session_id in self.active_sessions: - session_info = self.active_sessions[session_id] + resolved_id = self._resolve_session_id(session_id) or session_id + if resolved_id in self.active_sessions: + session_info = self.active_sessions[resolved_id] session_info.is_active = False - del self.active_sessions[session_id] + await db_manager.deactivate_session(resolved_id) + if session_info.cli_session_id: + 
self.cli_session_index.pop(session_info.cli_session_id, None) + self._persist_cli_session_map() + del self.active_sessions[resolved_id] logger.info( "Session ended", @@ -195,7 +247,7 @@ def end_session(self, session_id: str): total_cost=session_info.total_cost, ) - def cleanup_expired_sessions(self): + async def cleanup_expired_sessions(self): """Clean up expired sessions.""" current_time = utc_now() timeout_delta = timedelta(minutes=settings.session_timeout_minutes) @@ -206,14 +258,14 @@ def cleanup_expired_sessions(self): expired_sessions.append(session_id) for session_id in expired_sessions: - self.end_session(session_id) + await self.end_session(session_id) logger.info("Session expired and cleaned up", session_id=session_id) async def cleanup_all(self): """Clean up all sessions.""" session_ids = list(self.active_sessions.keys()) for session_id in session_ids: - self.end_session(session_id) + await self.end_session(session_id) if self.cleanup_task and not self.cleanup_task.done(): self._shutdown_event.set() @@ -221,6 +273,20 @@ async def cleanup_all(self): logger.info("All sessions cleaned up") + def register_cli_session(self, api_session_id: str, cli_session_id: str): + if not cli_session_id: + return + session_info = self.active_sessions.get(api_session_id) + if session_info: + session_info.cli_session_id = cli_session_id + self.cli_session_index[cli_session_id] = api_session_id + self._persist_cli_session_map() + + def _resolve_session_id(self, session_id: str) -> Optional[str]: + if session_id in self.active_sessions: + return session_id + return self.cli_session_index.get(session_id) + def get_active_session_count(self) -> int: """Get number of active sessions.""" return len(self.active_sessions) @@ -292,9 +358,9 @@ def format_messages_for_claude( return formatted - def clear_conversation(self, session_id: str): + async def clear_conversation(self, session_id: str): """Clear conversation history.""" if session_id in self.conversation_history: del 
self.conversation_history[session_id] - self.session_manager.end_session(session_id) + await self.session_manager.end_session(session_id) diff --git a/tests/conftest.py b/tests/conftest.py index 15822b6..67de9a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,7 @@ def setup_test_environment(): "claude_binary_path": getattr(settings, "claude_binary_path", "claude"), "database_url": getattr(settings, "database_url", "sqlite:///./test.db"), "debug": getattr(settings, "debug", False), + "session_map_path": getattr(settings, "session_map_path", None), } # Set test settings @@ -105,6 +106,7 @@ def setup_test_environment(): settings.database_url = f"sqlite:///{temp_dir}/test.db" settings.debug = True + settings.session_map_path = os.path.join(temp_dir, "session_map.json") # Create directories os.makedirs(settings.project_root, exist_ok=True) diff --git a/tests/fixtures/claude_stream_session_map.jsonl b/tests/fixtures/claude_stream_session_map.jsonl new file mode 100644 index 0000000..c027f50 --- /dev/null +++ b/tests/fixtures/claude_stream_session_map.jsonl @@ -0,0 +1,3 @@ +{"type":"system","message":{"role":"system","content":[{"type":"text","text":"You are Claude Code."}]},"session_id":"sess_map_1","model":"claude-haiku-4-5-20250929","cwd":".","tools":["bash","read"],"timestamp":"2026-02-04T00:00:00Z"} +{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Mapping acknowledged."}]},"session_id":"sess_map_1","model":"claude-haiku-4-5-20250929"} +{"type":"result","result":"ok","session_id":"sess_map_1","model":"claude-haiku-4-5-20250929","usage":{"input_tokens":10,"output_tokens":6},"cost_usd":0.00001,"duration_ms":900,"num_turns":1} diff --git a/tests/fixtures/index.json b/tests/fixtures/index.json index d8cee50..98f7b1e 100644 --- a/tests/fixtures/index.json +++ b/tests/fixtures/index.json @@ -1,4 +1,11 @@ [ + { + "match": [ + "mapping test", + "session map" + ], + "file": "claude_stream_session_map.jsonl" + }, { 
"match": [ "list files", diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 691a06e..3885688 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -8,6 +8,7 @@ - Streaming and non-streaming responses """ +import asyncio import json import os import shutil @@ -41,6 +42,7 @@ def parse_sse_events(body_text: str) -> List[Dict[str, Any]]: sys.path.insert(0, str(PROJECT_ROOT)) from claude_code_api.core.config import settings +from claude_code_api.core.session_manager import SessionManager from claude_code_api.main import app from claude_code_api.models.claude import get_available_models from tests.model_utils import get_test_model_id @@ -219,6 +221,34 @@ def test_chat_completion_with_system_prompt(self, client): response = client.post("/v1/chat/completions", json=request_data) assert response.status_code == 200 + def test_cli_session_mapping_persisted(self, client): + """Persist CLI-to-API session mapping to disk and reload it.""" + request_data = { + "model": DEFAULT_MODEL, + "messages": [{"role": "user", "content": "mapping test"}], + "stream": False, + } + + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + data = response.json() + api_session_id = data["session_id"] + + with open(settings.session_map_path, "r", encoding="utf-8") as handle: + saved = json.load(handle) + mapping = saved.get("cli_to_api", saved) + assert mapping["sess_map_1"] == api_session_id + + async def _load_session(): + new_manager = SessionManager() + session = await new_manager.get_session("sess_map_1") + await new_manager.cleanup_all() + return session + + session = asyncio.run(_load_session()) + assert session is not None + assert session.session_id == api_session_id + data = response.json() assert data["model"] == DEFAULT_MODEL assert len(data["choices"]) > 0 diff --git a/tests/test_session_manager_unit.py b/tests/test_session_manager_unit.py index e14e5a3..08111c1 100644 --- 
a/tests/test_session_manager_unit.py +++ b/tests/test_session_manager_unit.py @@ -41,13 +41,18 @@ async def fake_get_session(_session_id): @pytest.mark.asyncio -async def test_cleanup_expired_sessions(): +async def test_cleanup_expired_sessions(monkeypatch): manager = SessionManager() session = SessionInfo(session_id="sess", project_id="proj", model="claude") session.updated_at = utc_now() - timedelta(minutes=60) manager.active_sessions["sess"] = session - manager.cleanup_expired_sessions() + async def fake_deactivate(_session_id): + return None + + monkeypatch.setattr(sm_module.db_manager, "deactivate_session", fake_deactivate) + + await manager.cleanup_expired_sessions() assert "sess" not in manager.active_sessions await manager.cleanup_all() From 7c337796bd99f57bdc846064771cdd1fcd41707d Mon Sep 17 00:00:00 2001 From: Mehdi Date: Wed, 4 Feb 2026 23:10:58 +0000 Subject: [PATCH 4/9] Security fix --- claude_code_api/api/projects.py | 10 ++-- claude_code_api/core/claude_manager.py | 5 +- claude_code_api/core/security.py | 78 +++++++++++++++++++------- claude_code_api/utils/streaming.py | 24 +++++--- 4 files changed, 79 insertions(+), 38 deletions(-) diff --git a/claude_code_api/api/projects.py b/claude_code_api/api/projects.py index abf4063..e2f9e3e 100644 --- a/claude_code_api/api/projects.py +++ b/claude_code_api/api/projects.py @@ -1,7 +1,6 @@ """Projects API endpoint - Extension to OpenAI API.""" import math -import os import uuid import structlog @@ -15,7 +14,7 @@ ) from claude_code_api.core.config import settings from claude_code_api.core.database import db_manager -from claude_code_api.core.security import validate_path +from claude_code_api.core.security import ensure_directory_within_base from claude_code_api.models.openai import ( CreateProjectRequest, PaginatedResponse, @@ -74,10 +73,9 @@ async def create_project( # Create project directory if project_request.path: - # Validate path - project_path = validate_path(project_request.path, 
settings.project_root) - - os.makedirs(project_path, exist_ok=True) + project_path = ensure_directory_within_base( + project_request.path, settings.project_root + ) else: project_path = create_project_directory(project_id) diff --git a/claude_code_api/core/claude_manager.py b/claude_code_api/core/claude_manager.py index 71bb82d..43eb6bc 100644 --- a/claude_code_api/core/claude_manager.py +++ b/claude_code_api/core/claude_manager.py @@ -11,6 +11,7 @@ from claude_code_api.models.claude import get_default_model from .config import settings +from .security import ensure_directory_within_base logger = structlog.get_logger() @@ -396,9 +397,7 @@ def _cleanup_process(self, process: ClaudeProcess): # Utility functions for project management def create_project_directory(project_id: str) -> str: """Create project directory.""" - project_path = os.path.join(settings.project_root, project_id) - os.makedirs(project_path, exist_ok=True) - return project_path + return ensure_directory_within_base(project_id, settings.project_root) def cleanup_project_directory(project_path: str): diff --git a/claude_code_api/core/security.py b/claude_code_api/core/security.py index d836488..277369d 100644 --- a/claude_code_api/core/security.py +++ b/claude_code_api/core/security.py @@ -1,6 +1,7 @@ """Security utilities.""" import os +from pathlib import Path import structlog from fastapi import HTTPException, status @@ -8,13 +9,13 @@ logger = structlog.get_logger() -def validate_path(path: str, base_path: str) -> str: +def resolve_path_within_base(path: str, base_path: str) -> str: """ - Validate that a path is safe and within the base path. - Prevents directory traversal attacks. + Resolve a user-provided path within a base directory. + Prevents directory traversal and symlink escapes. 
Args: - path: The path to validate (can be absolute or relative) + path: The path to resolve (can be absolute or relative) base_path: The allowed base directory Returns: @@ -24,36 +25,71 @@ def validate_path(path: str, base_path: str) -> str: HTTPException: If path is invalid or outside base_path """ try: - # Normalize base path to absolute path - abs_base_path = os.path.abspath(base_path) - - # Handle relative paths by joining with base_path - if not os.path.isabs(path): - abs_path = os.path.abspath(os.path.join(abs_base_path, path)) - else: - abs_path = os.path.abspath(path) - - # Check if path is within base_path - # os.path.commonpath returns the longest common sub-path - # If valid, commonpath should be equal to base_path - if os.path.commonpath([abs_base_path, abs_path]) != abs_base_path: + if path is None or not str(path).strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Path is required", + ) + if "\x00" in str(path): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Null byte detected", + ) + + abs_base_path = Path(base_path).resolve() + candidate_path = Path(path) + if not candidate_path.is_absolute(): + candidate_path = abs_base_path / candidate_path + + resolved_path = candidate_path.resolve(strict=False) + + if not resolved_path.is_relative_to(abs_base_path): logger.warning( "Path traversal attempt detected", path=path, - resolved_path=abs_path, - base_path=abs_base_path, + resolved_path=str(resolved_path), + base_path=str(abs_base_path), ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid path: Path traversal detected", ) - return abs_path + return str(resolved_path) except HTTPException: raise except Exception as e: logger.error("Path validation error", error=str(e), path=path) raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid path: {str(e)}" + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid 
path: Path validation failed", ) + + +def ensure_directory_within_base( + path: str, base_path: str, *, allow_subpaths: bool = True +) -> str: + """Validate a path within base_path and create the directory.""" + path_value = os.fspath(path) + if not allow_subpaths: + if os.path.isabs(path_value): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Absolute paths are not allowed", + ) + for sep in (os.path.sep, os.path.altsep): + if sep and sep in path_value: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Path separators are not allowed", + ) + + resolved_path = resolve_path_within_base(path_value, base_path) + os.makedirs(resolved_path, exist_ok=True) + return resolved_path + + +def validate_path(path: str, base_path: str) -> str: + """Backward-compatible wrapper for path resolution.""" + return resolve_path_within_base(path, base_path) diff --git a/claude_code_api/utils/streaming.py b/claude_code_api/utils/streaming.py index 1689c9e..04ecebc 100644 --- a/claude_code_api/utils/streaming.py +++ b/claude_code_api/utils/streaming.py @@ -150,8 +150,8 @@ async def convert_stream( yield SSEFormatter.format_completion("") except Exception as e: - logger.error("Error in stream conversion", error=str(e)) - yield SSEFormatter.format_error(f"Stream error: {str(e)}") + logger.error("Error in stream conversion", error=str(e), exc_info=True) + yield SSEFormatter.format_error("Stream error") def get_final_response(self) -> Dict[str, Any]: """Get complete response in OpenAI format.""" @@ -198,8 +198,10 @@ async def create_stream( heartbeat_task.cancel() except Exception as e: - logger.error("Streaming error", session_id=session_id, error=str(e)) - yield SSEFormatter.format_error(f"Streaming failed: {str(e)}") + logger.error( + "Streaming error", session_id=session_id, error=str(e), exc_info=True + ) + yield SSEFormatter.format_error("Streaming failed") finally: # Cleanup if session_id in 
self.active_streams: @@ -305,10 +307,16 @@ async def create_sse_response( session_id: str, model: str, claude_process: ClaudeProcess ) -> AsyncGenerator[str, None]: """Create SSE response for Claude Code output.""" - async for chunk in streaming_manager.create_stream( - session_id, model, claude_process - ): - yield chunk + try: + async for chunk in streaming_manager.create_stream( + session_id, model, claude_process + ): + yield chunk + except Exception as e: + logger.error( + "SSE response error", session_id=session_id, error=str(e), exc_info=True + ) + yield SSEFormatter.format_error("Stream error") def _extract_assistant_payload( From 803d6b3f4303ef0d40fd0028a94b2edcacdf0dfc Mon Sep 17 00:00:00 2001 From: Mehdi Date: Thu, 5 Feb 2026 00:28:19 +0000 Subject: [PATCH 5/9] Update sonar script --- Makefile | 38 +------------------------ scripts/run-sonar.sh | 68 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 37 deletions(-) create mode 100755 scripts/run-sonar.sh diff --git a/Makefile b/Makefile index f5163df..35cbd72 100644 --- a/Makefile +++ b/Makefile @@ -87,43 +87,7 @@ kill: .PHONY: sonar sonar-cloud coverage-sonar sbom sbom-upload gitleaks fmt lint vet sonar: ## Run sonar-scanner for SonarQube analysis - @mkdir -p $(SONAR_DIR) $(COVERAGE_DIR) - @echo "Generating coverage report for SonarQube..." - @python -m pytest --cov=claude_code_api --cov-report=xml:$(COVERAGE_DIR)/coverage.xml --cov-report=term-missing --junitxml=$(SONAR_DIR)/xunit-report.xml -v tests/ - @if command -v sonar-scanner >/dev/null 2>&1; then \ - if [ -f ".env.vault" ]; then \ - . ./.env.vault; \ - fi; \ - if [ -f ".env" ]; then \ - set -a; . ./.env; set +a; \ - fi; \ - if [ -n "$${VAULT_SECRET_PATHS:-}" ] || [ -n "$${VAULT_REQUIRED_VARS:-}" ]; then \ - if [ -f "./scripts/vault-helper.sh" ]; then \ - . 
./scripts/vault-helper.sh; \ - vault_helper::load_from_definitions "$${VAULT_SECRET_PATHS:-}" "$${VAULT_REQUIRED_VARS:-}" "$${VAULT_TOKEN_FILE:-}"; \ - fi; \ - fi; \ - SONAR_HOST_URL="$${SONAR_HOST_URL:-$${SONAR_URL:-}}"; \ - if [ -z "$$SONAR_HOST_URL" ]; then \ - echo "SONAR_URL or SONAR_HOST_URL is required (e.g., https://sonarcloud.io or https://sonar.local)"; \ - exit 1; \ - fi; \ - case "$$SONAR_HOST_URL" in \ - http://*|https://*) ;; \ - *) echo "SONAR_URL must include http(s) scheme: $$SONAR_HOST_URL"; exit 1 ;; \ - esac; \ - if [ -z "$${SONAR_TOKEN:-}" ]; then \ - echo "SONAR_TOKEN not set - proceeding without authentication"; \ - fi; \ - sonar-scanner \ - -Dsonar.host.url=$$SONAR_HOST_URL \ - -Dsonar.token=$$SONAR_TOKEN \ - -Dsonar.projectVersion=$${VERSION:-1.0.0} \ - -Dsonar.working.directory=$(SONAR_DIR)/scannerwork; \ - else \ - echo "sonar-scanner not found. Install with: brew install sonar-scanner or download from https://docs.sonarqube.org/latest/analysis/scan/sonarscanner/"; \ - exit 1; \ - fi + @SONAR_DIR=$(SONAR_DIR) COVERAGE_DIR=$(COVERAGE_DIR) VERSION=$(VERSION) ./scripts/run-sonar.sh sonar-cloud: ## Run sonar-scanner for SonarCloud (uses different token/env) @echo "Running SonarCloud scanner..." diff --git a/scripts/run-sonar.sh b/scripts/run-sonar.sh new file mode 100755 index 0000000..922fe44 --- /dev/null +++ b/scripts/run-sonar.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +SONAR_DIR="${SONAR_DIR:-dist/quality/sonar}" +COVERAGE_DIR="${COVERAGE_DIR:-dist/quality/coverage}" +VERSION="${VERSION:-1.0.0}" + +mkdir -p "$SONAR_DIR" "$COVERAGE_DIR" + +echo "Generating coverage report for SonarQube..." +python -m pytest \ + --cov=claude_code_api \ + --cov-report=xml:"$COVERAGE_DIR/coverage.xml" \ + --cov-report=term-missing \ + --junitxml="$SONAR_DIR/xunit-report.xml" \ + -v tests/ + +if ! 
command -v sonar-scanner >/dev/null 2>&1; then + echo "sonar-scanner not found. Install with: brew install sonar-scanner or download from https://docs.sonarqube.org/latest/analysis/scan/sonarscanner/" + exit 1 +fi + +if [ -f ".env.vault" ]; then + . ./.env.vault +fi + +if [ -f ".env" ]; then + set -a + . ./.env + set +a +fi + +if [ -n "${VAULT_SECRET_PATHS:-}" ] || [ -n "${VAULT_REQUIRED_VARS:-}" ]; then + if [ -f "./scripts/vault-helper.sh" ]; then + . ./scripts/vault-helper.sh + vault_helper::load_from_definitions \ + "${VAULT_SECRET_PATHS:-}" \ + "${VAULT_REQUIRED_VARS:-}" \ + "${VAULT_TOKEN_FILE:-}" + fi +fi + +SONAR_HOST_URL="${SONAR_HOST_URL:-${SONAR_URL:-}}" +if [ -z "$SONAR_HOST_URL" ]; then + echo "SONAR_URL or SONAR_HOST_URL is required (e.g., https://sonarcloud.io or https://sonar.local)" + exit 1 +fi + +case "$SONAR_HOST_URL" in + http://*|https://*) ;; + *) + echo "SONAR_URL must include http(s) scheme: $SONAR_HOST_URL" + exit 1 + ;; +esac + +if [ -z "${SONAR_TOKEN:-}" ]; then + echo "SONAR_TOKEN not set - proceeding without authentication" +fi + +sonar-scanner \ + -Dsonar.host.url="$SONAR_HOST_URL" \ + -Dsonar.token="$SONAR_TOKEN" \ + -Dsonar.projectVersion="$VERSION" \ + -Dsonar.working.directory="$SONAR_DIR/scannerwork" From b0ae39a68f6d4392d4ad205b89bbd5b9ce2a5061 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Thu, 5 Feb 2026 00:45:03 +0000 Subject: [PATCH 6/9] Multiple qa fixes --- .env | 8 ++- .gitignore | 1 + claude_code_api/api/chat.py | 20 +++++- claude_code_api/config/models.json | 22 +++--- claude_code_api/core/claude_manager.py | 21 +++++- claude_code_api/core/database.py | 24 ++++--- claude_code_api/core/security.py | 55 +++++++++++++- claude_code_api/utils/parser.py | 18 ++--- claude_code_api/utils/streaming.py | 99 ++++++++++++++++---------- claude_code_api/utils/time.py | 2 +- scripts/run-sonar.sh | 6 ++ scripts/vault-helper.sh | 19 ++++- tests/model_utils.py | 7 +- 13 files changed, 221 insertions(+), 81 deletions(-) diff --git a/.env 
b/.env index c1d804b..c52a7b7 100644 --- a/.env +++ b/.env @@ -1,7 +1,4 @@ # This file is committed to git no secrets. - -#!/usr/bin/env bash - # Vault environment loader wrapper # Usage: source .env.vault (never run directly) @@ -45,6 +42,11 @@ SECRET_DEFS="${VAULT_SECRET_PATHS:-$DEFAULT_VAULT_SECRET_DEFS}" REQUIRED_VARS="${VAULT_REQUIRED_VARS:-$DEFAULT_VAULT_REQUIRED_VARS}" vault_helper::load_from_definitions "$SECRET_DEFS" "$REQUIRED_VARS" "$VAULT_TOKEN_FILE" +vault_status=$? +if [[ $vault_status -ne 0 ]]; then + echo "Error: vault_helper::load_from_definitions failed with exit code $vault_status" >&2 + return "$vault_status" +fi # Commented out for CI/automated testing # SONAR_TOKEN="" diff --git a/.gitignore b/.gitignore index 0789aef..e5d25cd 100644 --- a/.gitignore +++ b/.gitignore @@ -126,6 +126,7 @@ celerybeat.pid # Environments .env.vault .env.cloud +!.env .venv env/ venv/ diff --git a/claude_code_api/api/chat.py b/claude_code_api/api/chat.py index d573bce..521c403 100644 --- a/claude_code_api/api/chat.py +++ b/claude_code_api/api/chat.py @@ -1,5 +1,6 @@ """Chat completions API endpoint - OpenAI compatible.""" +import hashlib import json from typing import Any, Dict, Tuple @@ -61,12 +62,27 @@ def _http_error( async def _log_raw_request(req: Request) -> None: raw_body = await req.body() content_type = req.headers.get("content-type", "unknown") + sensitive_headers = { + "authorization", + "proxy-authorization", + "x-api-key", + "api-key", + "x-auth-token", + } + sanitized_headers = {} + for key, value in req.headers.items(): + if key.lower() in sensitive_headers: + sanitized_headers[key] = "" + else: + sanitized_headers[key] = value + body_hash = hashlib.sha256(raw_body).hexdigest() if raw_body else None logger.info( "Raw request received", content_type=content_type, body_size=len(raw_body), - user_agent=req.headers.get("user-agent", "unknown"), - raw_body=raw_body.decode()[:1000] if raw_body else "empty", + user_agent=sanitized_headers.get("user-agent", 
"unknown"), + headers=sanitized_headers, + body_hash=body_hash or "empty", ) diff --git a/claude_code_api/config/models.json b/claude_code_api/config/models.json index 8a9087c..c6d1f28 100644 --- a/claude_code_api/config/models.json +++ b/claude_code_api/config/models.json @@ -2,12 +2,12 @@ "default_model": "claude-sonnet-4-5-20250929", "models": [ { - "id": "claude-opus-4-5-20250929", + "id": "claude-opus-4-5-20251101", "name": "Claude Opus 4.5", "description": "Most powerful Claude model for complex reasoning", - "max_tokens": 500000, - "input_cost_per_1k": 15.0, - "output_cost_per_1k": 75.0, + "max_tokens": 65536, + "input_cost_per_1k": 0.005, + "output_cost_per_1k": 0.025, "supports_streaming": true, "supports_tools": true }, @@ -15,19 +15,19 @@ "id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet 4.5", "description": "Latest Sonnet model with enhanced capabilities", - "max_tokens": 500000, - "input_cost_per_1k": 3.0, - "output_cost_per_1k": 15.0, + "max_tokens": 65536, + "input_cost_per_1k": 0.003, + "output_cost_per_1k": 0.015, "supports_streaming": true, "supports_tools": true }, { - "id": "claude-haiku-4-5-20250929", + "id": "claude-haiku-4-5-20251001", "name": "Claude Haiku 4.5", "description": "Fast and cost-effective model for quick tasks", - "max_tokens": 200000, - "input_cost_per_1k": 0.25, - "output_cost_per_1k": 1.25, + "max_tokens": 65536, + "input_cost_per_1k": 0.001, + "output_cost_per_1k": 0.005, "supports_streaming": true, "supports_tools": true } diff --git a/claude_code_api/core/claude_manager.py b/claude_code_api/core/claude_manager.py index 43eb6bc..65bed3a 100644 --- a/claude_code_api/core/claude_manager.py +++ b/claude_code_api/core/claude_manager.py @@ -72,8 +72,20 @@ async def start( # Start process from src directory (where Claude works without API key) src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + safe_cmd: List[str] = [] + redact_next = False + for part in cmd: + if redact_next: + safe_cmd.append("") 
+ redact_next = False + continue + if part in ("-p", "--system-prompt"): + safe_cmd.append(part) + redact_next = True + continue + safe_cmd.append(part) logger.info(f"Starting Claude from directory: {src_dir}") - logger.info(f"Command: {' '.join(cmd)}") + logger.info(f"Command: {' '.join(safe_cmd)}") # Start process asynchronously self.process = await asyncio.create_subprocess_exec( @@ -397,7 +409,12 @@ def _cleanup_process(self, process: ClaudeProcess): # Utility functions for project management def create_project_directory(project_id: str) -> str: """Create project directory.""" - return ensure_directory_within_base(project_id, settings.project_root) + return ensure_directory_within_base( + project_id, + settings.project_root, + allow_subpaths=False, + sanitize_leaf=True, + ) def cleanup_project_directory(project_path: str): diff --git a/claude_code_api/core/database.py b/claude_code_api/core/database.py index 82ccfb2..32844a4 100644 --- a/claude_code_api/core/database.py +++ b/claude_code_api/core/database.py @@ -14,6 +14,7 @@ Text, func, select, + update, ) from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base @@ -161,7 +162,9 @@ async def get_project(project_id: str) -> Optional[Project]: @staticmethod async def list_projects(page: int, per_page: int) -> List[Project]: """List projects with pagination.""" - offset = max(0, (page - 1) * per_page) + page = max(1, page) + per_page = max(1, min(per_page, 100)) + offset = (page - 1) * per_page async with AsyncSessionLocal() as session: stmt = ( select(Project) @@ -231,13 +234,18 @@ async def add_message(message_data: dict) -> Message: async def update_session_metrics(session_id: str, tokens_used: int, cost: float): """Update session usage metrics.""" async with AsyncSessionLocal() as session: - session_obj = await session.get(Session, session_id) - if session_obj: - session_obj.total_tokens += tokens_used - 
session_obj.total_cost += cost - session_obj.message_count += 1 - session_obj.updated_at = utc_now() - await session.commit() + stmt = ( + update(Session) + .where(Session.id == session_id) + .values( + total_tokens=Session.total_tokens + tokens_used, + total_cost=Session.total_cost + cost, + message_count=Session.message_count + 1, + updated_at=utc_now(), + ) + ) + await session.execute(stmt) + await session.commit() @staticmethod async def deactivate_session(session_id: str): diff --git a/claude_code_api/core/security.py b/claude_code_api/core/security.py index 277369d..7be6eea 100644 --- a/claude_code_api/core/security.py +++ b/claude_code_api/core/security.py @@ -1,6 +1,7 @@ """Security utilities.""" import os +import re from pathlib import Path import structlog @@ -43,7 +44,14 @@ def resolve_path_within_base(path: str, base_path: str) -> str: resolved_path = candidate_path.resolve(strict=False) - if not resolved_path.is_relative_to(abs_base_path): + try: + common_path = os.path.commonpath( + [os.fspath(abs_base_path), os.fspath(resolved_path)] + ) + except ValueError: + common_path = "" + + if common_path != os.fspath(abs_base_path): logger.warning( "Path traversal attempt detected", path=path, @@ -68,7 +76,11 @@ def resolve_path_within_base(path: str, base_path: str) -> str: def ensure_directory_within_base( - path: str, base_path: str, *, allow_subpaths: bool = True + path: str, + base_path: str, + *, + allow_subpaths: bool = True, + sanitize_leaf: bool = False, ) -> str: """Validate a path within base_path and create the directory.""" path_value = os.fspath(path) @@ -85,8 +97,45 @@ def ensure_directory_within_base( detail="Invalid path: Path separators are not allowed", ) + if sanitize_leaf: + if allow_subpaths: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Sanitization only allowed for leaf paths", + ) + sanitized = re.sub(r"[^A-Za-z0-9._-]", "_", path_value) + if not sanitized.strip("._-"): + raise 
HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Path is required", + ) + path_value = sanitized + resolved_path = resolve_path_within_base(path_value, base_path) - os.makedirs(resolved_path, exist_ok=True) + abs_base_path = os.path.abspath(base_path) + abs_resolved_path = os.path.abspath(resolved_path) + try: + common_path = os.path.commonpath([abs_base_path, abs_resolved_path]) + except ValueError: + common_path = "" + if common_path != abs_base_path: + logger.warning( + "Path traversal attempt detected (post-validate)", + path=path_value, + resolved_path=resolved_path, + base_path=abs_base_path, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Path traversal detected", + ) + try: + os.makedirs(resolved_path, exist_ok=True) + except FileExistsError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid path: {resolved_path} exists and is not a directory", + ) from e return resolved_path diff --git a/claude_code_api/utils/parser.py b/claude_code_api/utils/parser.py index d8a935e..de2c050 100644 --- a/claude_code_api/utils/parser.py +++ b/claude_code_api/utils/parser.py @@ -216,14 +216,15 @@ class OpenAIConverter: @staticmethod def claude_message_to_openai(message: ClaudeMessage) -> Optional[Dict[str, Any]]: """Convert Claude message to OpenAI chat format.""" - if message.is_system_message(): - return {"role": "system", "content": message.extract_text_content()} + parser = ClaudeOutputParser() + if parser.is_system_message(message): + return {"role": "system", "content": parser.extract_text_content(message)} - if message.is_user_message(): - return {"role": "user", "content": message.extract_text_content()} + if parser.is_user_message(message): + return {"role": "user", "content": parser.extract_text_content(message)} - if message.is_assistant_message(): - content = message.extract_text_content() + if parser.is_assistant_message(message): + content = 
parser.extract_text_content(message) if content: return {"role": "assistant", "content": content} @@ -234,10 +235,11 @@ def claude_stream_to_openai_chunk( message: ClaudeMessage, chunk_id: str, model: str, created: int ) -> Optional[Dict[str, Any]]: """Convert Claude stream message to OpenAI chunk format.""" - if not message.is_assistant_message(): + parser = ClaudeOutputParser() + if not parser.is_assistant_message(message): return None - content = message.extract_text_content() + content = parser.extract_text_content(message) if not content: return None diff --git a/claude_code_api/utils/streaming.py b/claude_code_api/utils/streaming.py index 04ecebc..b27d2d4 100644 --- a/claude_code_api/utils/streaming.py +++ b/claude_code_api/utils/streaming.py @@ -1,8 +1,10 @@ """Server-Sent Events streaming utilities for OpenAI compatibility.""" import asyncio +import contextlib import json import uuid +from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple import structlog @@ -36,7 +38,7 @@ def format_event(data: Dict[str, Any]) -> str: return f"data: {json_data}\n\n" @staticmethod - def format_completion(data: str) -> str: + def format_completion() -> str: """Format completion signal.""" return "data: [DONE]\n\n" @@ -139,44 +141,30 @@ async def convert_stream( break # Send final chunk - finish_reason = ( - "tool_calls" if (saw_tool_calls and not saw_assistant_text) else "stop" - ) + finish_reason = "tool_calls" if saw_tool_calls else "stop" yield SSEFormatter.format_event( self._build_chunk({}, finish_reason=finish_reason) ) # Send completion signal - yield SSEFormatter.format_completion("") + yield SSEFormatter.format_completion() except Exception as e: logger.error("Error in stream conversion", error=str(e), exc_info=True) yield SSEFormatter.format_error("Stream error") - def get_final_response(self) -> Dict[str, Any]: - """Get complete response in OpenAI format.""" - return { - "id": self.completion_id, - "object": 
"chat.completion", - "created": self.created, - "model": self.model, - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Response completed"}, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, - "session_id": self.session_id, - } + +@dataclass +class StreamState: + converter: OpenAIStreamConverter + heartbeat_queue: asyncio.Queue[Optional[str]] class StreamingManager: """Manages multiple streaming connections.""" def __init__(self): - self.active_streams: Dict[str, OpenAIStreamConverter] = {} + self.active_streams: Dict[str, StreamState] = {} self.heartbeat_interval = 30 # seconds async def create_stream( @@ -184,34 +172,63 @@ async def create_stream( ) -> AsyncGenerator[str, None]: """Create new streaming connection.""" converter = OpenAIStreamConverter(model, session_id) - self.active_streams[session_id] = converter + heartbeat_queue: asyncio.Queue[Optional[str]] = asyncio.Queue() + self.active_streams[session_id] = StreamState( + converter=converter, heartbeat_queue=heartbeat_queue + ) + + async def _pump_stream(): + try: + async for chunk in converter.convert_stream(claude_process): + await heartbeat_queue.put(chunk) + finally: + await heartbeat_queue.put(None) + heartbeat_task: Optional[asyncio.Task] = None + stream_task: Optional[asyncio.Task] = None try: # Start heartbeat task - heartbeat_task = asyncio.create_task(self._send_heartbeats(session_id)) + heartbeat_task = asyncio.create_task( + self._send_heartbeats(session_id, heartbeat_queue) + ) - # Stream conversion - async for chunk in converter.convert_stream(claude_process): + stream_task = asyncio.create_task(_pump_stream()) + while True: + chunk = await heartbeat_queue.get() + if chunk is None: + break yield chunk - # Cancel heartbeat - heartbeat_task.cancel() - except Exception as e: logger.error( "Streaming error", session_id=session_id, error=str(e), exc_info=True ) yield SSEFormatter.format_error("Streaming 
failed") finally: + if heartbeat_task: + heartbeat_task.cancel() + if stream_task: + stream_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + if heartbeat_task: + await heartbeat_task + with contextlib.suppress(asyncio.CancelledError): + if stream_task: + await stream_task # Cleanup if session_id in self.active_streams: del self.active_streams[session_id] - async def _send_heartbeats(self, session_id: str): + async def _send_heartbeats( + self, session_id: str, heartbeat_queue: asyncio.Queue[Optional[str]] + ): """Send periodic heartbeats to keep connection alive.""" - while session_id in self.active_streams: - await asyncio.sleep(self.heartbeat_interval) - # Heartbeats are handled by the SSE client + try: + while session_id in self.active_streams: + await asyncio.sleep(self.heartbeat_interval) + await heartbeat_queue.put(SSEFormatter.format_heartbeat()) + except asyncio.CancelledError: + return def get_active_stream_count(self) -> int: """Get number of active streams.""" @@ -346,14 +363,22 @@ def _extract_assistant_payload( "Found assistant message", message_index=i, content_length=len(text_content), - content_preview=text_content[:100] if text_content else "empty", + ) + logger.debug( + "Found assistant message preview", + message_index=i, + content_preview="" if text_content else "empty", ) if text_content: content_parts.append(text_content) logger.info( "Extracted assistant text", message_index=i, - content_preview=text_content[:50], + ) + logger.debug( + "Extracted assistant text preview", + message_index=i, + content_preview="", ) tool_uses = parser.extract_tool_uses(normalized) @@ -398,7 +423,7 @@ def create_non_streaming_response( if usage is None: usage = OpenAIConverter.calculate_usage(parser) - finish_reason = "tool_calls" if (tool_calls and not complete_content) else "stop" + finish_reason = "tool_calls" if tool_calls else "stop" message_payload: Dict[str, Any] = { "role": "assistant", diff --git a/claude_code_api/utils/time.py 
b/claude_code_api/utils/time.py index d8e737f..a11a105 100644 --- a/claude_code_api/utils/time.py +++ b/claude_code_api/utils/time.py @@ -10,4 +10,4 @@ def utc_now() -> datetime: def utc_timestamp() -> int: """Return a UTC unix timestamp in seconds.""" - return int(utc_now().timestamp()) + return int(datetime.now(timezone.utc).timestamp()) diff --git a/scripts/run-sonar.sh b/scripts/run-sonar.sh index 922fe44..65c6f5f 100755 --- a/scripts/run-sonar.sh +++ b/scripts/run-sonar.sh @@ -40,6 +40,12 @@ if [ -n "${VAULT_SECRET_PATHS:-}" ] || [ -n "${VAULT_REQUIRED_VARS:-}" ]; then "${VAULT_SECRET_PATHS:-}" \ "${VAULT_REQUIRED_VARS:-}" \ "${VAULT_TOKEN_FILE:-}" + else + echo "Error: vault-helper.sh is required when VAULT_SECRET_PATHS or VAULT_REQUIRED_VARS is set." >&2 + echo "Missing helper: ./scripts/vault-helper.sh" >&2 + echo "VAULT_SECRET_PATHS=${VAULT_SECRET_PATHS:-}" >&2 + echo "VAULT_REQUIRED_VARS=${VAULT_REQUIRED_VARS:-}" >&2 + exit 1 fi fi diff --git a/scripts/vault-helper.sh b/scripts/vault-helper.sh index 5b55eb5..54c31ba 100755 --- a/scripts/vault-helper.sh +++ b/scripts/vault-helper.sh @@ -210,13 +210,28 @@ vault_helper::fetch_and_export() { if ! exports=$(printf '%s' "$data_json" | jq -r ' to_entries[]? | - "export \(.key)=\(.value | @sh)" + [.key, (.value | tostring | @base64)] | + @tsv '); then vault_helper::log_error "Unable to parse secrets from ${path}" return 1 fi - eval "$exports" + while IFS=$'\t' read -r key b64_value; do + [[ -z "$key" ]] && continue + if [[ ! "$key" =~ ^[A-Z_][A-Z0-9_]*$ ]]; then + vault_helper::log_error "Invalid secret key name: ${key}" + return 1 + fi + if ! value=$(printf '%s' "$b64_value" | base64 --decode 2>/dev/null); then + if ! 
value=$(printf '%s' "$b64_value" | base64 -d 2>/dev/null); then + vault_helper::log_error "Failed to decode secret for ${key}" + return 1 + fi + fi + printf -v "$key" '%s' "$value" + export "$key" + done <<< "$exports" vault_helper::apply_mappings "$data_json" "$path" "$mappings" count=$(printf '%s' "$data_json" | jq 'length') diff --git a/tests/model_utils.py b/tests/model_utils.py index bf4fab1..320ac64 100644 --- a/tests/model_utils.py +++ b/tests/model_utils.py @@ -4,11 +4,10 @@ from claude_code_api.models.claude import get_available_models, get_default_model -TEST_MODEL_ID = os.getenv("CLAUDE_CODE_API_TEST_MODEL", "claude-haiku-4-5-20250929") - def get_test_model_id() -> str: + test_model_id = os.getenv("CLAUDE_CODE_API_TEST_MODEL") or "" available = {model.id for model in get_available_models()} - if TEST_MODEL_ID in available: - return TEST_MODEL_ID + if test_model_id and test_model_id in available: + return test_model_id return get_default_model() From 6c14e3d70a5f532658a7695db408b6b34329012b Mon Sep 17 00:00:00 2001 From: Mehdi Date: Thu, 5 Feb 2026 01:06:04 +0000 Subject: [PATCH 7/9] Update Sonarqube --- .github/workflows/sonarcloud.yml | 43 +++++++++++++++ Makefile | 2 +- claude_code_api/core/security.py | 84 +++++++++++++++--------------- claude_code_api/utils/streaming.py | 2 +- scripts/vault-helper.sh | 4 +- sonar-project.properties | 1 - 6 files changed, 89 insertions(+), 47 deletions(-) create mode 100644 .github/workflows/sonarcloud.yml diff --git a/.github/workflows/sonarcloud.yml b/.github/workflows/sonarcloud.yml new file mode 100644 index 0000000..62f215b --- /dev/null +++ b/.github/workflows/sonarcloud.yml @@ -0,0 +1,43 @@ +name: SonarCloud QA Gate + +on: + push: + pull_request: + +jobs: + sonarcloud: + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install dependencies 
+ run: make install-dev + + - name: Run tests with coverage + run: make coverage-sonar + + - name: SonarCloud scan + uses: SonarSource/sonarcloud-github-action@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + with: + args: > + -Dsonar.host.url=https://sonarcloud.io + -Dsonar.organization=${{ secrets.SONAR_ORG }} + + - name: SonarCloud quality gate + uses: SonarSource/sonarqube-quality-gate-action@v1.1.0 + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + SONAR_HOST_URL: https://sonarcloud.io + timeout-minutes: 5 diff --git a/Makefile b/Makefile index 35cbd72..2edc81e 100644 --- a/Makefile +++ b/Makefile @@ -94,7 +94,7 @@ sonar-cloud: ## Run sonar-scanner for SonarCloud (uses different token/env) @./scripts/run-sonar-cloud.sh coverage-sonar: ## Generate coverage for SonarQube - @mkdir -p $(COVERAGE_DIR) + @mkdir -p $(COVERAGE_DIR) $(SONAR_DIR) @python -m pytest --cov=claude_code_api --cov-report=xml:$(COVERAGE_DIR)/coverage.xml --cov-report=term-missing --junitxml=$(SONAR_DIR)/xunit-report.xml -v tests/ @echo "Coverage XML generated: $(COVERAGE_DIR)/coverage.xml" diff --git a/claude_code_api/core/security.py b/claude_code_api/core/security.py index 7be6eea..5df82e5 100644 --- a/claude_code_api/core/security.py +++ b/claude_code_api/core/security.py @@ -10,6 +10,42 @@ logger = structlog.get_logger() +def _bad_request(detail: str) -> HTTPException: + return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) + + +def _ensure_leaf_path(path_value: str) -> None: + if os.path.isabs(path_value): + raise _bad_request("Invalid path: Absolute paths are not allowed") + for sep in (os.path.sep, os.path.altsep): + if sep and sep in path_value: + raise _bad_request("Invalid path: Path separators are not allowed") + + +def _sanitize_leaf_value(path_value: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9._-]", "_", path_value) + if not sanitized.strip("._-"): + raise _bad_request("Invalid path: Path is required") + 
return sanitized + + +def _ensure_within_base(path_value: str, base_path: str, resolved_path: str) -> None: + abs_base_path = os.path.abspath(base_path) + abs_resolved_path = os.path.abspath(resolved_path) + try: + common_path = os.path.commonpath([abs_base_path, abs_resolved_path]) + except ValueError: + common_path = "" + if common_path != abs_base_path: + logger.warning( + "Path traversal attempt detected (post-validate)", + path=path_value, + resolved_path=resolved_path, + base_path=abs_base_path, + ) + raise _bad_request("Invalid path: Path traversal detected") + + def resolve_path_within_base(path: str, base_path: str) -> str: """ Resolve a user-provided path within a base directory. @@ -85,56 +121,20 @@ def ensure_directory_within_base( """Validate a path within base_path and create the directory.""" path_value = os.fspath(path) if not allow_subpaths: - if os.path.isabs(path_value): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: Absolute paths are not allowed", - ) - for sep in (os.path.sep, os.path.altsep): - if sep and sep in path_value: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: Path separators are not allowed", - ) + _ensure_leaf_path(path_value) if sanitize_leaf: if allow_subpaths: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: Sanitization only allowed for leaf paths", - ) - sanitized = re.sub(r"[^A-Za-z0-9._-]", "_", path_value) - if not sanitized.strip("._-"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: Path is required", - ) - path_value = sanitized + raise _bad_request("Invalid path: Sanitization only allowed for leaf paths") + path_value = _sanitize_leaf_value(path_value) resolved_path = resolve_path_within_base(path_value, base_path) - abs_base_path = os.path.abspath(base_path) - abs_resolved_path = os.path.abspath(resolved_path) - try: - common_path = 
os.path.commonpath([abs_base_path, abs_resolved_path]) - except ValueError: - common_path = "" - if common_path != abs_base_path: - logger.warning( - "Path traversal attempt detected (post-validate)", - path=path_value, - resolved_path=resolved_path, - base_path=abs_base_path, - ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: Path traversal detected", - ) + _ensure_within_base(path_value, base_path, resolved_path) try: os.makedirs(resolved_path, exist_ok=True) except FileExistsError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid path: {resolved_path} exists and is not a directory", + raise _bad_request( + f"Invalid path: {resolved_path} exists and is not a directory" ) from e return resolved_path diff --git a/claude_code_api/utils/streaming.py b/claude_code_api/utils/streaming.py index b27d2d4..6f55ced 100644 --- a/claude_code_api/utils/streaming.py +++ b/claude_code_api/utils/streaming.py @@ -228,7 +228,7 @@ async def _send_heartbeats( await asyncio.sleep(self.heartbeat_interval) await heartbeat_queue.put(SSEFormatter.format_heartbeat()) except asyncio.CancelledError: - return + raise def get_active_stream_count(self) -> int: """Get number of active streams.""" diff --git a/scripts/vault-helper.sh b/scripts/vault-helper.sh index 54c31ba..9205c51 100755 --- a/scripts/vault-helper.sh +++ b/scripts/vault-helper.sh @@ -220,8 +220,8 @@ vault_helper::fetch_and_export() { while IFS=$'\t' read -r key b64_value; do [[ -z "$key" ]] && continue if [[ ! "$key" =~ ^[A-Z_][A-Z0-9_]*$ ]]; then - vault_helper::log_error "Invalid secret key name: ${key}" - return 1 + vault_helper::log_warn "Skipping non-env secret key: ${key}" + continue fi if ! value=$(printf '%s' "$b64_value" | base64 --decode 2>/dev/null); then if ! 
value=$(printf '%s' "$b64_value" | base64 -d 2>/dev/null); then diff --git a/sonar-project.properties b/sonar-project.properties index 038ca5f..fcbef21 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -26,4 +26,3 @@ sonar.sourceEncoding=UTF-8 sonar.scm.provider=git sonar.scm.forceReloadAll=true sonar.scm.disabled=false -sonar.host.url= From a0b5875fb446694fef7c8e71d039417d39c869e8 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Thu, 5 Feb 2026 01:19:32 +0000 Subject: [PATCH 8/9] Fix Action --- .github/workflows/sonarcloud.yml | 3 ++- claude_code_api/core/security.py | 35 ++++++++++++++++++-------------- scripts/run-sonar-cloud.sh | 6 +++--- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/.github/workflows/sonarcloud.yml b/.github/workflows/sonarcloud.yml index 62f215b..071fc4e 100644 --- a/.github/workflows/sonarcloud.yml +++ b/.github/workflows/sonarcloud.yml @@ -33,7 +33,8 @@ jobs: with: args: > -Dsonar.host.url=https://sonarcloud.io - -Dsonar.organization=${{ secrets.SONAR_ORG }} + -Dsonar.organization=codingworkflow + -Dsonar.projectKey=codingworkflow_claude-code-a-api - name: SonarCloud quality gate uses: SonarSource/sonarqube-quality-gate-action@v1.1.0 diff --git a/claude_code_api/core/security.py b/claude_code_api/core/security.py index 5df82e5..c58be20 100644 --- a/claude_code_api/core/security.py +++ b/claude_code_api/core/security.py @@ -30,8 +30,8 @@ def _sanitize_leaf_value(path_value: str) -> str: def _ensure_within_base(path_value: str, base_path: str, resolved_path: str) -> None: - abs_base_path = os.path.abspath(base_path) - abs_resolved_path = os.path.abspath(resolved_path) + abs_base_path = os.path.realpath(base_path) + abs_resolved_path = os.path.realpath(resolved_path) try: common_path = os.path.commonpath([abs_base_path, abs_resolved_path]) except ValueError: @@ -73,33 +73,38 @@ def resolve_path_within_base(path: str, base_path: str) -> str: detail="Invalid path: Null byte detected", ) - abs_base_path = 
Path(base_path).resolve() - candidate_path = Path(path) - if not candidate_path.is_absolute(): - candidate_path = abs_base_path / candidate_path - - resolved_path = candidate_path.resolve(strict=False) + abs_base_path = os.path.realpath(base_path) + path_value = os.fspath(path) + normalized_path = os.path.normpath(path_value) + if not os.path.isabs(normalized_path): + if normalized_path == ".." or normalized_path.startswith(f"..{os.path.sep}"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid path: Path traversal detected", + ) + if os.path.isabs(normalized_path): + resolved_path = os.path.realpath(normalized_path) + else: + resolved_path = os.path.realpath(os.path.join(abs_base_path, normalized_path)) try: - common_path = os.path.commonpath( - [os.fspath(abs_base_path), os.fspath(resolved_path)] - ) + common_path = os.path.commonpath([abs_base_path, resolved_path]) except ValueError: common_path = "" - if common_path != os.fspath(abs_base_path): + if common_path != abs_base_path: logger.warning( "Path traversal attempt detected", path=path, - resolved_path=str(resolved_path), - base_path=str(abs_base_path), + resolved_path=resolved_path, + base_path=abs_base_path, ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid path: Path traversal detected", ) - return str(resolved_path) + return resolved_path except HTTPException: raise diff --git a/scripts/run-sonar-cloud.sh b/scripts/run-sonar-cloud.sh index bda7b8a..f703651 100755 --- a/scripts/run-sonar-cloud.sh +++ b/scripts/run-sonar-cloud.sh @@ -56,9 +56,9 @@ if [ -z "${SONAR_CLOUD_TOKEN:-}" ]; then fi # Set defaults if not provided -SONAR_HOST_URL="${SONAR_CLOUD_URL:-https://sonarcloud.io}" -SONAR_ORG="${SONAR_CLOUD_ORG:-}" -SONAR_PROJECT_KEY="${SONAR_CLOUD_PROJECT:-claude-code-api}" +SONAR_HOST_URL="${SONAR_HOST_URL:-${SONAR_CLOUD_URL:-https://sonarcloud.io}}" +SONAR_ORG="${SONAR_ORG:-${SONAR_CLOUD_ORG:-}}" 
+SONAR_PROJECT_KEY="${SONAR_PROJECT_KEY:-${SONAR_CLOUD_PROJECT:-claude-code-api}}" # Generate coverage for SonarCloud echo "Generating coverage report for SonarCloud..." From 05ed92e18f697cd0aee26f9b55697bb89f73f572 Mon Sep 17 00:00:00 2001 From: Mehdi Date: Thu, 5 Feb 2026 01:24:58 +0000 Subject: [PATCH 9/9] Fix Sonar --- .github/workflows/sonarcloud.yml | 6 ++++-- claude_code_api/core/security.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/sonarcloud.yml b/.github/workflows/sonarcloud.yml index 071fc4e..fb7d277 100644 --- a/.github/workflows/sonarcloud.yml +++ b/.github/workflows/sonarcloud.yml @@ -1,8 +1,10 @@ name: SonarCloud QA Gate on: - push: pull_request: + push: + branches: + - main jobs: sonarcloud: @@ -34,7 +36,7 @@ jobs: args: > -Dsonar.host.url=https://sonarcloud.io -Dsonar.organization=codingworkflow - -Dsonar.projectKey=codingworkflow_claude-code-a-api + -Dsonar.projectKey=codingworkflow_claude-code-api - name: SonarCloud quality gate uses: SonarSource/sonarqube-quality-gate-action@v1.1.0 diff --git a/claude_code_api/core/security.py b/claude_code_api/core/security.py index c58be20..41ce790 100644 --- a/claude_code_api/core/security.py +++ b/claude_code_api/core/security.py @@ -136,7 +136,7 @@ def ensure_directory_within_base( resolved_path = resolve_path_within_base(path_value, base_path) _ensure_within_base(path_value, base_path, resolved_path) try: - os.makedirs(resolved_path, exist_ok=True) + os.makedirs(resolved_path, exist_ok=True) # codeql[py/path-injection] except FileExistsError as e: raise _bad_request( f"Invalid path: {resolved_path} exists and is not a directory"