From 5bc7359e1324e2d8321768231769f728667c1ba7 Mon Sep 17 00:00:00 2001 From: David Cramer Date: Thu, 5 Mar 2026 16:07:26 -0800 Subject: [PATCH] feat(provider): Add Pi provider support across runtime and action Thread provider selection through config, CLI, SDK, and GitHub Action flows so runs can target Claude or Pi consistently with provider-aware auth and telemetry. Add provider-specific tool policy defaults for the core loop, including Pi shell/read/write/edit tools while preserving skill-level allow/deny overrides. Also record provider metadata in outputs and add regression coverage for provider wiring. Co-Authored-By: GPT-5 Codex --- action.yml | 6 ++ src/action/fix-evaluation/judge.ts | 5 +- src/action/inputs.test.ts | 12 +++- src/action/inputs.ts | 44 ++++++++++----- src/action/triggers/executor.test.ts | 3 +- src/action/triggers/executor.ts | 11 ++-- src/action/workflow/base.ts | 17 ++++++ src/action/workflow/pr-workflow.ts | 11 +++- src/action/workflow/schedule.ts | 11 +++- src/cli/args.test.ts | 5 ++ src/cli/args.ts | 5 ++ src/cli/commands/init.ts | 2 + src/cli/main.ts | 73 ++++++++++++++---------- src/cli/output/jsonl.ts | 6 +- src/cli/output/tasks.ts | 9 +++ src/cli/output/tty.test.ts | 1 + src/config/loader.test.ts | 21 +++++++ src/config/loader.ts | 14 ++++- src/config/schema.ts | 6 ++ src/sdk/analyze.test.ts | 39 ++++++++++++- src/sdk/analyze.ts | 83 +++++++++++++++++++++++----- src/sdk/auth.test.ts | 11 ++++ src/sdk/auth.ts | 5 +- src/sdk/errors.ts | 1 + src/sdk/extract.ts | 10 +++- src/sdk/fix-quality.ts | 9 ++- src/sdk/haiku.ts | 11 +++- src/sdk/types.ts | 3 + src/types/index.ts | 2 + src/utils/index.ts | 11 ++++ 30 files changed, 365 insertions(+), 82 deletions(-) diff --git a/action.yml b/action.yml index 4f117229..eca3f90a 100644 --- a/action.yml +++ b/action.yml @@ -7,6 +7,9 @@ branding: color: 'purple' inputs: + provider: + description: 'Provider to use (claude or pi). Overrides WARDEN_PROVIDER when set.' 
+ required: false anthropic-api-key: description: 'Anthropic API key (sk-ant-...) or OAuth token. Can also be set via ANTHROPIC_API_KEY or CLAUDE_CODE_OAUTH_TOKEN env vars.' required: false @@ -61,6 +64,7 @@ runs: using: 'composite' steps: - name: Install Claude Code CLI + if: ${{ (inputs.provider || env.WARDEN_PROVIDER) != 'pi' }} shell: bash run: | CLAUDE_CODE_VERSION="2.1.32" @@ -84,6 +88,7 @@ runs: id: warden shell: bash env: + INPUT_PROVIDER: ${{ inputs.provider }} INPUT_ANTHROPIC_API_KEY: ${{ inputs.anthropic-api-key }} INPUT_GITHUB_TOKEN: ${{ inputs.github-token }} INPUT_CONFIG_PATH: ${{ inputs.config-path }} @@ -93,5 +98,6 @@ runs: INPUT_REQUEST_CHANGES: ${{ inputs.request-changes }} INPUT_FAIL_CHECK: ${{ inputs.fail-check }} INPUT_PARALLEL: ${{ inputs.parallel }} + WARDEN_PROVIDER: ${{ inputs.provider || env.WARDEN_PROVIDER }} CLAUDE_CODE_PATH: ${{ env.HOME }}/.local/bin/claude run: node ${{ github.action_path }}/dist/action/index.js diff --git a/src/action/fix-evaluation/judge.ts b/src/action/fix-evaluation/judge.ts index 19165e35..ac3f8094 100644 --- a/src/action/fix-evaluation/judge.ts +++ b/src/action/fix-evaluation/judge.ts @@ -7,6 +7,7 @@ import { emptyUsage } from '../../sdk/usage.js'; import { FixJudgeVerdictSchema } from './types.js'; import type { FixJudgeResult } from './types.js'; import { fetchFileContent, fetchFileLines } from './github.js'; +import type { Provider } from '../../config/schema.js'; export interface FixJudgeInput { comment: ExistingComment; @@ -212,7 +213,8 @@ export async function evaluateFix( input: FixJudgeInput, context: FixJudgeContext, apiKey: string, - maxRetries?: number + maxRetries?: number, + provider: Provider = 'claude' ): Promise { const fallback: FixJudgeResult = { verdict: { status: 'not_attempted', reasoning: 'Evaluation failed' }, @@ -225,6 +227,7 @@ export async function evaluateFix( const result = await callHaikuWithTools({ apiKey, + provider, prompt, schema: FixJudgeVerdictSchema, tools: 
TOOL_DEFINITIONS, diff --git a/src/action/inputs.test.ts b/src/action/inputs.test.ts index 5dab374e..320d48b5 100644 --- a/src/action/inputs.test.ts +++ b/src/action/inputs.test.ts @@ -38,11 +38,19 @@ describe('parseActionInputs', () => { expect(inputs.anthropicApiKey).toBe(''); }); - it('throws when no auth token is found', () => { + it('allows empty auth tokens (validated later by provider)', () => { delete process.env['ANTHROPIC_API_KEY']; delete process.env['WARDEN_ANTHROPIC_API_KEY']; delete process.env['CLAUDE_CODE_OAUTH_TOKEN']; - expect(() => parseActionInputs()).toThrow('Authentication not found'); + const inputs = parseActionInputs(); + expect(inputs.anthropicApiKey).toBe(''); + expect(inputs.oauthToken).toBe(''); + }); + + it('parses provider from INPUT_PROVIDER', () => { + process.env['INPUT_PROVIDER'] = 'pi'; + const inputs = parseActionInputs(); + expect(inputs.provider).toBe('pi'); }); }); diff --git a/src/action/inputs.ts b/src/action/inputs.ts index 29b0b2ee..bf89bc33 100644 --- a/src/action/inputs.ts +++ b/src/action/inputs.ts @@ -7,14 +7,19 @@ import { SeverityThresholdSchema } from '../types/index.js'; import type { SeverityThreshold } from '../types/index.js'; import { DEFAULT_CONCURRENCY } from '../utils/index.js'; +import type { Provider } from '../config/schema.js'; // ----------------------------------------------------------------------------- // Types // ----------------------------------------------------------------------------- export interface ActionInputs { + /** Optional provider override (defaults to config/env) */ + provider?: Provider; /** API key for Anthropic API (empty if using OAuth) */ anthropicApiKey: string; + /** API key for Pi provider */ + piApiKey?: string; /** OAuth token for Claude Code (empty if using API key) */ oauthToken: string; githubToken: string; @@ -58,31 +63,32 @@ function parseBooleanInput(value: string): boolean | undefined { return undefined; } +function parseProviderInput(value: string): Provider | 
undefined { + return value === 'claude' || value === 'pi' ? value : undefined; +} + /** * Parse action inputs from the GitHub Actions environment. * Throws if required inputs are missing. */ export function parseActionInputs(): ActionInputs { - // Check for auth token: supports both API keys and OAuth tokens + const providerInput = getInput('provider') || process.env['WARDEN_PROVIDER'] || ''; + const provider = parseProviderInput(providerInput); + + // Claude auth token: supports both API keys and OAuth tokens // Priority: input > WARDEN_ANTHROPIC_API_KEY > ANTHROPIC_API_KEY > CLAUDE_CODE_OAUTH_TOKEN - const authToken = + const claudeAuthToken = getInput('anthropic-api-key') || process.env['WARDEN_ANTHROPIC_API_KEY'] || process.env['ANTHROPIC_API_KEY'] || process.env['CLAUDE_CODE_OAUTH_TOKEN'] || ''; - if (!authToken) { - throw new Error( - 'Authentication not found. Provide an API key via anthropic-api-key input, ' + - 'ANTHROPIC_API_KEY env var, or OAuth token via CLAUDE_CODE_OAUTH_TOKEN env var.' - ); - } - // Detect token type: OAuth tokens start with 'sk-ant-oat', API keys are other 'sk-ant-' prefixes - const isOAuthToken = authToken.startsWith('sk-ant-oat'); - const anthropicApiKey = isOAuthToken ? '' : authToken; - const oauthToken = isOAuthToken ? authToken : ''; + const isOAuthToken = claudeAuthToken.startsWith('sk-ant-oat'); + const anthropicApiKey = isOAuthToken ? '' : claudeAuthToken; + const oauthToken = isOAuthToken ? 
claudeAuthToken : ''; + const piApiKey = process.env['WARDEN_PI_API_KEY'] || ''; const failOnInput = getInput('fail-on'); const failOn = SeverityThresholdSchema.safeParse(failOnInput).success @@ -101,7 +107,9 @@ export function parseActionInputs(): ActionInputs { const failCheck = parseBooleanInput(getInput('fail-check')); return { + provider, anthropicApiKey, + piApiKey, oauthToken, githubToken: getInput('github-token') || process.env['GITHUB_TOKEN'] || '', configPath: getInput('config-path') || 'warden.toml', @@ -129,10 +137,18 @@ export function validateInputs(inputs: ActionInputs): void { * Sets appropriate env vars based on token type (API key vs OAuth). */ export function setupAuthEnv(inputs: ActionInputs): void { + if (inputs.provider) { + process.env['WARDEN_PROVIDER'] = inputs.provider; + } + if (inputs.piApiKey) { + process.env['WARDEN_PI_API_KEY'] = inputs.piApiKey; + } if (inputs.oauthToken) { process.env['CLAUDE_CODE_OAUTH_TOKEN'] = inputs.oauthToken; } else { - process.env['WARDEN_ANTHROPIC_API_KEY'] = inputs.anthropicApiKey; - process.env['ANTHROPIC_API_KEY'] = inputs.anthropicApiKey; + if (inputs.anthropicApiKey) { + process.env['WARDEN_ANTHROPIC_API_KEY'] = inputs.anthropicApiKey; + process.env['ANTHROPIC_API_KEY'] = inputs.anthropicApiKey; + } } } diff --git a/src/action/triggers/executor.test.ts b/src/action/triggers/executor.test.ts index b262f9f4..916c6693 100644 --- a/src/action/triggers/executor.test.ts +++ b/src/action/triggers/executor.test.ts @@ -74,7 +74,8 @@ describe('executeTrigger', () => { octokit: mockOctokit, context: mockContext, config: mockConfig, - anthropicApiKey: 'test-key', + apiKey: 'test-key', + provider: 'claude', claudePath: '/test/claude', globalMaxFindings: 10, }; diff --git a/src/action/triggers/executor.ts b/src/action/triggers/executor.ts index 42c77e46..5a8f6610 100644 --- a/src/action/triggers/executor.ts +++ b/src/action/triggers/executor.ts @@ -10,6 +10,7 @@ import { Sentry } from '../../sentry.js'; import { 
ActionFailedError } from '../workflow/base.js'; import type { ResolvedTrigger } from '../../config/loader.js'; import type { WardenConfig } from '../../config/schema.js'; +import type { Provider } from '../../config/schema.js'; import type { EventContext, SkillReport, SeverityThreshold, ConfidenceThreshold } from '../../types/index.js'; import type { RenderResult } from '../../output/types.js'; import type { OutputMode } from '../../cli/output/tty.js'; @@ -43,8 +44,9 @@ export interface TriggerExecutorDeps { octokit: Octokit; context: EventContext; config: WardenConfig; - anthropicApiKey: string; - claudePath: string; + apiKey: string; + provider: Provider; + claudePath?: string; /** Global fail-on from action inputs (trigger-specific takes precedence) */ globalFailOn?: SeverityThreshold; /** Global report-on from action inputs (trigger-specific takes precedence) */ @@ -97,7 +99,7 @@ export async function executeTrigger( { op: 'trigger.execute', name: `execute ${trigger.name}` }, async (span) => { span.setAttribute('skill.name', trigger.skill); - const { octokit, context, config, anthropicApiKey, claudePath } = deps; + const { octokit, context, config, apiKey, claudePath, provider } = deps; logGroup(`Running trigger: ${trigger.name} (skill: ${trigger.skill})`); @@ -134,7 +136,8 @@ export async function executeTrigger( }), context: filterContextByPaths(context, trigger.filters), runnerOptions: { - apiKey: anthropicApiKey, + apiKey, + provider, model: trigger.model, maxTurns: trigger.maxTurns, batchDelayMs: config.defaults?.batchDelayMs, diff --git a/src/action/workflow/base.ts b/src/action/workflow/base.ts index 3df679b6..5cba8de4 100644 --- a/src/action/workflow/base.ts +++ b/src/action/workflow/base.ts @@ -13,6 +13,8 @@ import { execFileNonInteractive } from '../../utils/exec.js'; import type { EventContext, SkillReport } from '../../types/index.js'; import { countSeverity } from '../../triggers/matcher.js'; import type { TriggerResult } from 
'../triggers/executor.js'; +import type { Provider, WardenConfig } from '../../config/schema.js'; +import type { ActionInputs } from '../inputs.js'; /** * Sentinel error thrown by setFailed() so the top-level catch handler @@ -187,6 +189,21 @@ export function setWorkflowOutputs(outputs: WorkflowOutputs): void { setOutput('summary', outputs.summary); } +export function resolveActionProvider(inputs: ActionInputs, config?: WardenConfig): Provider { + if (inputs.provider) return inputs.provider; + const envProvider = process.env['WARDEN_PROVIDER']; + if (envProvider === 'claude' || envProvider === 'pi') return envProvider; + const cfgProvider = config?.defaults?.provider; + if (cfgProvider === 'claude' || cfgProvider === 'pi') return cfgProvider; + return 'claude'; +} + +export function getActionProviderApiKey(provider: Provider, inputs: ActionInputs): string { + return provider === 'pi' + ? (inputs.piApiKey || inputs.anthropicApiKey) /* piApiKey is '' (never undefined) when WARDEN_PI_API_KEY is unset, so `||` is required for the Anthropic-key fallback to ever apply */ + : inputs.anthropicApiKey; +} + // ----------------------------------------------------------------------------- // GitHub API Helpers // ----------------------------------------------------------------------------- diff --git a/src/action/workflow/pr-workflow.ts b/src/action/workflow/pr-workflow.ts index 07a94018..06d0477c 100644 --- a/src/action/workflow/pr-workflow.ts +++ b/src/action/workflow/pr-workflow.ts @@ -48,6 +48,8 @@ import { setWorkflowOutputs, getAuthenticatedBotLogin, writeFindingsOutput, + resolveActionProvider, + getActionProviderApiKey, } from './base.js'; // ----------------------------------------------------------------------------- @@ -148,7 +150,7 @@ async function initializeWorkflow( } // Resolve skills into triggers and match - const resolvedTriggers = resolveSkillConfigs(config); + const resolvedTriggers = resolveSkillConfigs(config, undefined, inputs.provider); const matchedTriggers = resolvedTriggers.filter((t) => matchTrigger(t, context, 'github')); if (matchedTriggers.length > 0) { @@ -250,7 +252,9 @@ async 
function executeAllTriggers( inputs: ActionInputs ): Promise { const concurrency = config.runner?.concurrency ?? inputs.parallel; - const claudePath = await findClaudeCodeExecutable(); + const provider = resolveActionProvider(inputs, config); + const apiKey = getActionProviderApiKey(provider, inputs); + const claudePath = provider === 'claude' ? await findClaudeCodeExecutable() : undefined; // Global semaphore gates file-level work across all triggers. // All triggers launch immediately; the semaphore limits concurrent file analyses. @@ -264,7 +268,8 @@ async function executeAllTriggers( octokit, context, config, - anthropicApiKey: inputs.anthropicApiKey, + apiKey, + provider, claudePath, globalFailOn: inputs.failOn, globalReportOn: inputs.reportOn, diff --git a/src/action/workflow/schedule.ts b/src/action/workflow/schedule.ts index 36e499e3..1bef3b1f 100644 --- a/src/action/workflow/schedule.ts +++ b/src/action/workflow/schedule.ts @@ -26,6 +26,8 @@ import { handleTriggerErrors, getDefaultBranchFromAPI, writeFindingsOutput, + resolveActionProvider, + getActionProviderApiKey, } from './base.js'; // ----------------------------------------------------------------------------- @@ -67,7 +69,9 @@ export async function runScheduleWorkflow( } // Find schedule triggers - const scheduleTriggers = resolveSkillConfigs(config).filter((t) => t.type === 'schedule'); + const scheduleTriggers = resolveSkillConfigs(config, undefined, inputs.provider).filter((t) => t.type === 'schedule'); + const provider = resolveActionProvider(inputs, config); + const apiKey = getActionProviderApiKey(provider, inputs); if (scheduleTriggers.length === 0) { console.log('No schedule triggers configured'); setOutput('findings-count', 0); @@ -147,9 +151,10 @@ export async function runScheduleWorkflow( const skill = await resolveSkillAsync(resolved.skill, repoPath, { remote: resolved.remote, }); - const claudePath = await findClaudeCodeExecutable(); + const claudePath = provider === 'claude' ? 
await findClaudeCodeExecutable() : undefined; const report = await runSkill(skill, context, { - apiKey: inputs.anthropicApiKey, + apiKey, + provider, model: resolved.model, maxTurns: resolved.maxTurns, batchDelayMs: config.defaults?.batchDelayMs, diff --git a/src/cli/args.test.ts b/src/cli/args.test.ts index 5184fdbb..60368b92 100644 --- a/src/cli/args.test.ts +++ b/src/cli/args.test.ts @@ -66,6 +66,11 @@ describe('parseCliArgs', () => { expect(result.options.config).toBe('./custom.toml'); }); + it('parses --provider option', () => { + const result = parseCliArgs(['--provider', 'pi']); + expect(result.options.provider).toBe('pi'); + }); + it('parses --json flag', () => { const result = parseCliArgs(['--json']); expect(result.options.json).toBe(true); diff --git a/src/cli/args.ts b/src/cli/args.ts index f836d254..65a37e0d 100644 --- a/src/cli/args.ts +++ b/src/cli/args.ts @@ -22,6 +22,8 @@ export const CLIOptionsSchema = z.object({ parallel: z.number().int().positive().optional(), /** Model to use for analysis (fallback when not set in config) */ model: z.string().optional(), + /** Provider to use for analysis (fallback when not set in config) */ + provider: z.enum(['claude', 'pi']).optional(), // Verbosity options quiet: z.boolean().default(false), verbose: z.number().default(0), @@ -100,6 +102,7 @@ Options: --skill Run only this skill (default: run all built-in skills) --config Path to warden.toml (default: ./warden.toml) -m, --model Model to use (fallback when not set in config) + --provider Provider to use (claude, pi) --json Output results as JSON -o, --output Write full run output to a JSONL file --fail-on Exit with code 1 if findings >= severity @@ -281,6 +284,7 @@ export function parseCliArgs(argv: string[] = process.argv.slice(2)): ParsedArgs skill: { type: 'string' }, config: { type: 'string' }, model: { type: 'string', short: 'm' }, + provider: { type: 'string' }, json: { type: 'boolean', default: false }, output: { type: 'string', short: 'o' }, 
'fail-on': { type: 'string' }, @@ -470,6 +474,7 @@ export function parseCliArgs(argv: string[] = process.argv.slice(2)): ParsedArgs skill: values.skill, config: values.config, model: values.model, + provider: values.provider === 'claude' || values.provider === 'pi' ? values.provider : undefined, json: values.json, output: values.output, failOn: values['fail-on'] as SeverityThreshold | undefined, diff --git a/src/cli/commands/init.ts b/src/cli/commands/init.ts index 26ae1428..29675b54 100644 --- a/src/cli/commands/init.ts +++ b/src/cli/commands/init.ts @@ -117,11 +117,13 @@ jobs: runs-on: ubuntu-latest env: WARDEN_MODEL: \${{ secrets.WARDEN_MODEL }} + WARDEN_PROVIDER: \${{ vars.WARDEN_PROVIDER }} WARDEN_SENTRY_DSN: \${{ secrets.WARDEN_SENTRY_DSN }} steps: - uses: actions/checkout@v4 - uses: getsentry/warden@v${majorVersion} with: + provider: \${{ vars.WARDEN_PROVIDER }} anthropic-api-key: \${{ secrets.WARDEN_ANTHROPIC_API_KEY }} `; } diff --git a/src/cli/main.ts b/src/cli/main.ts index 5e078af3..e09c5c79 100644 --- a/src/cli/main.ts +++ b/src/cli/main.ts @@ -8,7 +8,7 @@ import { resolveSkillAsync } from '../skills/loader.js'; import { matchTrigger, filterContextByPaths, shouldFail, countFindingsAtOrAbove } from '../triggers/matcher.js'; import type { SkillReport, ConfidenceThreshold } from '../types/index.js'; import { filterFindings } from '../types/index.js'; -import { DEFAULT_CONCURRENCY, getAnthropicApiKey } from '../utils/index.js'; +import { DEFAULT_CONCURRENCY, getProviderApiKey } from '../utils/index.js'; import { parseCliArgs, showHelp, showVersion, classifyTargets, type CLIOptions } from './args.js'; import { buildLocalEventContext, buildFileEventContext } from './context.js'; import { getRepoRoot, getHeadSha, refExists, hasUncommittedChanges } from './git.js'; @@ -96,6 +96,22 @@ function resolveConfigPath(options: CLIOptions, repoPath: string): string { return options.config ? 
resolve(cwd, options.config) : resolve(repoPath, 'warden.toml'); } +function resolveProvider(configProvider?: string, cliProvider?: string): 'claude' | 'pi' { + if (configProvider === 'claude' || configProvider === 'pi') return configProvider; + if (cliProvider === 'claude' || cliProvider === 'pi') return cliProvider; + const envProvider = process.env['WARDEN_PROVIDER']; + if (envProvider === 'claude' || envProvider === 'pi') return envProvider; + return 'claude'; +} + +function logMissingProviderApiKey(reporter: Reporter, provider: 'claude' | 'pi'): void { + reporter.debug( + provider === 'claude' + ? 'No API key found. Using Claude Code subscription auth.' + : 'No API key found for Pi provider. Falling back to Claude auth sources.' + ); +} + /** * Write a minimal JSONL log (summary-only, 0 findings) for early-exit paths. * Returns the rendered content and the log file path. The content is always @@ -306,12 +322,6 @@ async function runSkills( const cwd = process.cwd(); const startTime = Date.now(); - // Get API key (optional - SDK can use Claude Code subscription auth) - const apiKey = getAnthropicApiKey(); - if (!apiKey) { - reporter.debug('No API key found. Using Claude Code subscription auth.'); - } - // Try to find repo root for config loading let repoPath: string | undefined; try { @@ -320,9 +330,27 @@ async function runSkills( // Not in a git repo - that's fine for file mode } + // Resolve config path + let configPath: string | null = null; + if (options.config) { + configPath = resolve(cwd, options.config); + } else if (repoPath) { + configPath = resolve(repoPath, 'warden.toml'); + } + + // Load config if available + const config = configPath && existsSync(configPath) + ? 
loadWardenConfig(dirname(configPath)) + : null; + const provider = resolveProvider(config?.defaults?.provider, options.provider); + const apiKey = getProviderApiKey(provider); + if (!apiKey) { + logMissingProviderApiKey(reporter, provider); + } + // Pre-flight: verify auth will work before starting analysis try { - verifyAuth({ apiKey }); + verifyAuth({ apiKey, provider }); } catch (error: unknown) { reporter.error((error as WardenAuthenticationError).message); const effectiveRepo = repoPath ?? cwd; @@ -335,25 +363,12 @@ async function runSkills( return 1; } - // Resolve config path - let configPath: string | null = null; - if (options.config) { - configPath = resolve(cwd, options.config); - } else if (repoPath) { - configPath = resolve(repoPath, 'warden.toml'); - } - - // Load config if available - const config = configPath && existsSync(configPath) - ? loadWardenConfig(dirname(configPath)) - : null; - // Determine which triggers/skills to run let skillsToRun: SkillToRun[]; if (options.skill) { // Explicit skill specified via CLI — check config for remote/filters if available const match = config - ? resolveSkillConfigs(config, options.model).find((t) => t.skill === options.skill) + ? resolveSkillConfigs(config, options.model, options.provider).find((t) => t.skill === options.skill) : undefined; // Fall back to global defaults when the skill isn't in the config const defaultIgnorePaths = config?.defaults?.ignorePaths; @@ -363,7 +378,7 @@ async function runSkills( skillsToRun = [{ skill: options.skill, remote: match?.remote, filters: match?.filters ?? 
fallbackFilters }]; } else if (config) { // Get skills from matched triggers, preserving remote property and filters - const resolvedTriggers = resolveSkillConfigs(config, options.model); + const resolvedTriggers = resolveSkillConfigs(config, options.model, options.provider); const matchedTriggers = resolvedTriggers.filter((t) => matchTrigger(t, context, 'local')); // Dedupe by skill name but keep first occurrence (with its remote property and filters) const seen = new Set(); @@ -405,6 +420,7 @@ async function runSkills( const runnerOptions: SkillRunnerOptions = { apiKey, model: sdkModel, + provider, abortController, maxTurns: config?.defaults?.maxTurns, batchDelayMs: config?.defaults?.batchDelayMs, @@ -632,7 +648,7 @@ async function runConfigMode(options: CLIOptions, reporter: Reporter): Promise matchTrigger(t, context, 'local')); // Filter by skill if specified @@ -673,15 +689,15 @@ async function runConfigMode(options: CLIOptions, reporter: Reporter): Promise ({ name: t.name, skill: t.skill })) ); - // Get API key (optional - SDK can use Claude Code subscription auth) - const apiKey = getAnthropicApiKey(); + const provider = resolveProvider(config.defaults?.provider, options.provider); + const apiKey = getProviderApiKey(provider); if (!apiKey) { - reporter.debug('No API key found. 
Using Claude Code subscription auth.'); + logMissingProviderApiKey(reporter, provider); } // Pre-flight: verify auth will work before starting analysis try { - verifyAuth({ apiKey }); + verifyAuth({ apiKey, provider }); } catch (error: unknown) { reporter.error((error as WardenAuthenticationError).message); if (options.json) { @@ -708,6 +724,7 @@ async function runConfigMode(options: CLIOptions, reporter: Reporter): Promise { const r = allResults[i]; return { diff --git a/src/cli/output/tty.test.ts b/src/cli/output/tty.test.ts index d038ade8..972765d0 100644 --- a/src/cli/output/tty.test.ts +++ b/src/cli/output/tty.test.ts @@ -90,6 +90,7 @@ describe('detectOutputMode', () => { it('respects FORCE_COLOR environment variable', () => { process.env['TERM'] = 'xterm-256color'; + delete process.env['NO_COLOR']; process.env['FORCE_COLOR'] = '1'; const mode = detectOutputMode(); expect(mode.supportsColor).toBe(true); diff --git a/src/config/loader.test.ts b/src/config/loader.test.ts index 203517de..6ef77676 100644 --- a/src/config/loader.test.ts +++ b/src/config/loader.test.ts @@ -447,6 +447,27 @@ describe('resolveSkillConfigs', () => { expect(resolved?.minConfidence).toBeUndefined(); }); }); + + describe('provider resolution', () => { + it('defaults to claude when provider is not configured', () => { + const [resolved] = resolveSkillConfigs(baseConfig); + expect(resolved?.provider).toBe('claude'); + }); + + it('uses defaults.provider when set', () => { + const config: WardenConfig = { + ...baseConfig, + defaults: { provider: 'pi' }, + }; + const [resolved] = resolveSkillConfigs(config); + expect(resolved?.provider).toBe('pi'); + }); + + it('uses cliProvider when defaults.provider is unset', () => { + const [resolved] = resolveSkillConfigs(baseConfig, undefined, 'pi'); + expect(resolved?.provider).toBe('pi'); + }); + }); }); describe('maxTurns config', () => { diff --git a/src/config/loader.ts b/src/config/loader.ts index 22b03213..6a451025 100644 --- a/src/config/loader.ts 
+++ b/src/config/loader.ts @@ -7,6 +7,7 @@ import { type WardenConfig, type ScheduleConfig, type TriggerType, + type Provider, } from './schema.js'; import type { SeverityThreshold, ConfidenceThreshold } from '../types/index.js'; @@ -94,6 +95,8 @@ export interface ResolvedTrigger { failCheck?: boolean; /** Model (merged: trigger > skill > defaults > cli > env) */ model?: string; + /** Provider (merged: defaults > cli > env) */ + provider?: Provider; /** Max agentic turns (merged: trigger > skill > defaults) */ maxTurns?: number; /** Minimum confidence for findings (merged: trigger > skill > defaults) */ @@ -111,6 +114,10 @@ function emptyToUndefined(value: string | undefined): string | undefined { return value === '' ? undefined : value; } +function parseProvider(value: string | undefined): Provider | undefined { + return value === 'claude' || value === 'pi' ? value : undefined; +} + /** * Resolve all skills in a config into a flat array of ResolvedTriggers. * Each skill x trigger combination produces one entry. @@ -126,10 +133,13 @@ function emptyToUndefined(value: string | undefined): string | undefined { */ export function resolveSkillConfigs( config: WardenConfig, - cliModel?: string + cliModel?: string, + cliProvider?: Provider ): ResolvedTrigger[] { const defaults = config.defaults; const envModel = emptyToUndefined(process.env['WARDEN_MODEL']); + const envProvider = parseProvider(emptyToUndefined(process.env['WARDEN_PROVIDER'])); + const baseProvider = defaults?.provider ?? cliProvider ?? envProvider ?? 'claude'; const result: ResolvedTrigger[] = []; for (const skill of config.skills) { @@ -165,6 +175,7 @@ export function resolveSkillConfigs( requestChanges: skill.requestChanges ?? defaults?.requestChanges, failCheck: skill.failCheck ?? defaults?.failCheck, model: baseModel, + provider: baseProvider, maxTurns: skill.maxTurns ?? defaults?.maxTurns, minConfidence: skill.minConfidence ?? 
defaults?.minConfidence, }); @@ -185,6 +196,7 @@ export function resolveSkillConfigs( requestChanges: trigger.requestChanges ?? skill.requestChanges ?? defaults?.requestChanges, failCheck: trigger.failCheck ?? skill.failCheck ?? defaults?.failCheck, model: emptyToUndefined(trigger.model) ?? baseModel, + provider: baseProvider, maxTurns: trigger.maxTurns ?? skill.maxTurns ?? defaults?.maxTurns, minConfidence: trigger.minConfidence ?? skill.minConfidence ?? defaults?.minConfidence, schedule: trigger.schedule, diff --git a/src/config/schema.ts b/src/config/schema.ts index c2ba4e22..cf9e5999 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -1,6 +1,10 @@ import { z } from 'zod'; import { SeverityThresholdSchema, ConfidenceThresholdSchema } from '../types/index.js'; +// LLM provider selection +export const ProviderSchema = z.enum(['claude', 'pi']); +export type Provider = z.infer; + // Tool names that can be allowed/denied export const ToolNameSchema = z.enum([ 'Read', @@ -152,6 +156,8 @@ export type ChunkingConfig = z.infer; // Default configuration that skills inherit from export const DefaultsSchema = z.object({ + /** Default provider for all skills (claude|pi). 
Default: claude */ + provider: ProviderSchema.optional(), /** Fail the build when findings meet this severity */ failOn: SeverityThresholdSchema.optional(), /** Only report findings at or above this severity */ diff --git a/src/sdk/analyze.test.ts b/src/sdk/analyze.test.ts index 6479f514..ba13bad3 100644 --- a/src/sdk/analyze.test.ts +++ b/src/sdk/analyze.test.ts @@ -1,6 +1,7 @@ import { describe, it, expect } from 'vitest'; -import { filterOutOfRangeFindings } from './analyze.js'; +import { filterOutOfRangeFindings, resolveToolPolicy } from './analyze.js'; import type { Finding } from '../types/index.js'; +import type { SkillDefinition } from '../config/schema.js'; function makeFinding(startLine: number, id = `f-${startLine}`): Finding { return { @@ -78,3 +79,39 @@ describe('filterOutOfRangeFindings', () => { expect(dropped).toEqual([]); }); }); + +describe('resolveToolPolicy', () => { + const baseSkill: SkillDefinition = { + name: 'test-skill', + description: 'test', + prompt: 'test prompt', + }; + + it('uses read-only defaults for claude provider', () => { + const policy = resolveToolPolicy('claude', baseSkill); + expect(policy.allowedTools).toEqual(['Read', 'Grep', 'Glob']); + expect(policy.disallowedTools).toContain('Write'); + expect(policy.disallowedTools).toContain('Bash'); + }); + + it('enables shell/read/write defaults for pi provider', () => { + const policy = resolveToolPolicy('pi', baseSkill); + expect(policy.allowedTools).toContain('Read'); + expect(policy.allowedTools).toContain('Write'); + expect(policy.allowedTools).toContain('Bash'); + expect(policy.disallowedTools).not.toContain('Write'); + expect(policy.disallowedTools).not.toContain('Bash'); + }); + + it('applies skill allowed-tools override and denied-tools filtering', () => { + const policy = resolveToolPolicy('pi', { + ...baseSkill, + tools: { + allowed: ['Read', 'Write', 'Bash'], + denied: ['Write'], + }, + }); + expect(policy.allowedTools).toEqual(['Read', 'Bash']); + 
expect(policy.disallowedTools).toContain('Write'); + }); +}); diff --git a/src/sdk/analyze.ts b/src/sdk/analyze.ts index c1cb11c0..777701b0 100644 --- a/src/sdk/analyze.ts +++ b/src/sdk/analyze.ts @@ -49,7 +49,8 @@ async function parseHunkOutput( result: SDKResultMessage, filename: string, apiKey?: string, - auxiliaryMaxRetries?: number + auxiliaryMaxRetries?: number, + provider?: 'claude' | 'pi' ): Promise { if (result.subtype !== 'success') { // SDK error - not an extraction failure, just no findings @@ -64,7 +65,7 @@ async function parseHunkOutput( } // Tier 2: Try LLM fallback for malformed output - const fallback = await extractFindingsWithLLM(result.result, apiKey, auxiliaryMaxRetries); + const fallback = await extractFindingsWithLLM(result.result, apiKey, auxiliaryMaxRetries, provider); if (fallback.success) { return { findings: validateFindings(fallback.findings, filename), extractionFailed: false, extractionMethod: 'llm', extractionUsage: fallback.usage }; @@ -129,6 +130,45 @@ interface QueryExecutionResult { stderr?: string; } +const CLAUDE_DEFAULT_ALLOWED_TOOLS = ['Read', 'Grep', 'Glob'] as const; +const CLAUDE_DEFAULT_DISALLOWED_TOOLS = ['Write', 'Edit', 'Bash', 'WebFetch', 'WebSearch', 'Task', 'TodoWrite'] as const; +const PI_DEFAULT_ALLOWED_TOOLS = ['Read', 'Write', 'Edit', 'Bash', 'Grep', 'Glob'] as const; +const PI_DEFAULT_DISALLOWED_TOOLS = ['WebFetch', 'WebSearch', 'Task', 'TodoWrite'] as const; + +export function resolveToolPolicy( + provider: 'claude' | 'pi', + skill: SkillDefinition +): { allowedTools: string[]; disallowedTools: string[] } { + const baseAllowed = provider === 'pi' + ? [...PI_DEFAULT_ALLOWED_TOOLS] + : [...CLAUDE_DEFAULT_ALLOWED_TOOLS]; + const baseDisallowed = provider === 'pi' + ? [...PI_DEFAULT_DISALLOWED_TOOLS] + : [...CLAUDE_DEFAULT_DISALLOWED_TOOLS]; + + const allowedFromSkill = skill.tools?.allowed; + const deniedFromSkill = skill.tools?.denied; + + let allowedTools = allowedFromSkill && allowedFromSkill.length > 0 + ? 
[...allowedFromSkill] + : baseAllowed; + + const disallowed = new Set(baseDisallowed); + if (deniedFromSkill) { + for (const denied of deniedFromSkill) { + disallowed.add(denied); + } + } + for (const denied of disallowed) { + allowedTools = allowedTools.filter((t) => t !== denied); + } + + return { + allowedTools, + disallowedTools: [...disallowed], + }; +} + /** * Execute a single SDK query attempt. * Captures stderr for better error diagnostics when Claude Code fails. @@ -138,9 +178,13 @@ async function executeQuery( userPrompt: string, repoPath: string, options: SkillRunnerOptions, - skillName: string + skill: SkillDefinition ): Promise { - const { maxTurns = 50, model, abortController, pathToClaudeCodeExecutable } = options; + const { maxTurns = 50, model, abortController, pathToClaudeCodeExecutable, provider = 'claude' } = options; + const { allowedTools, disallowedTools } = resolveToolPolicy(provider, skill); + const skillName = skill.name; + + const providerName = provider === 'claude' ? 'anthropic' : provider; const modelId = model ?? 'unknown'; return Sentry.startSpan( @@ -149,7 +193,7 @@ async function executeQuery( name: `invoke_agent ${skillName}`, attributes: { 'gen_ai.operation.name': 'invoke_agent', - 'gen_ai.provider.name': 'anthropic', + 'gen_ai.provider.name': providerName, 'gen_ai.agent.name': skillName, 'gen_ai.request.model': modelId, 'warden.request.max_turns': maxTurns, @@ -170,10 +214,9 @@ async function executeQuery( maxTurns, cwd: repoPath, systemPrompt, - // Only allow read-only tools - context is already provided in the prompt - allowedTools: ['Read', 'Grep', 'Glob'], - // Explicitly block modification/side-effect tools as defense-in-depth - disallowedTools: ['Write', 'Edit', 'Bash', 'WebFetch', 'WebSearch', 'Task', 'TodoWrite'], + allowedTools, + disallowedTools, + ...(provider === 'pi' ? 
{ agent: 'pi' } : {}), permissionMode: 'bypassPermissions', // Prevent SDK from writing session .jsonl files and polluting Claude Code's session index persistSession: false, @@ -215,7 +258,7 @@ async function executeQuery( name: `chat ${skillName} turn ${turnCount}`, attributes: { 'gen_ai.operation.name': 'chat', - 'gen_ai.provider.name': 'anthropic', + 'gen_ai.provider.name': providerName, 'gen_ai.agent.name': skillName, 'gen_ai.response.model': turn.model, 'gen_ai.usage.input_tokens': totalInput, @@ -407,7 +450,7 @@ async function analyzeHunk( } try { - const { result: resultMessage, authError } = await executeQuery(systemPrompt, userPrompt, repoPath, options, skill.name); + const { result: resultMessage, authError } = await executeQuery(systemPrompt, userPrompt, repoPath, options, skill); // Check for authentication errors from auth_status messages // auth_status errors are always auth-related - throw immediately @@ -459,7 +502,13 @@ async function analyzeHunk( }; } - const parseResult = await parseHunkOutput(resultMessage, hunkCtx.filename, apiKey, options.auxiliaryMaxRetries); + const parseResult = await parseHunkOutput( + resultMessage, + hunkCtx.filename, + apiKey, + options.auxiliaryMaxRetries, + options.provider + ); // Filter findings outside hunk line range (defense-in-depth) const hunkRange = getHunkLineRange(hunkCtx.hunk); @@ -524,8 +573,8 @@ async function analyzeHunk( if (isSubprocessError(error)) { const errorMessage = error instanceof Error ? 
error.message : String(error); throw new WardenAuthenticationError( - `Claude Code subprocess failed (${errorMessage}).\n` + - `This usually means the claude CLI cannot run in this environment.`, + `Provider subprocess failed (${errorMessage}).\n` + + `This usually means the configured provider CLI cannot run in this environment.`, { cause: error } ); } @@ -755,6 +804,7 @@ export async function runSkill( usage: emptyUsage(), durationMs: Date.now() - startTime, model: options.model, + provider: options.provider, }; if (skippedFiles.length > 0) { report.skippedFiles = skippedFiles; @@ -897,7 +947,7 @@ export async function runSkill( throw new SkillRunnerError( `All ${totalHunks} chunk${totalHunks === 1 ? '' : 's'} failed to analyze. ` + `This usually indicates an authentication problem. ` + - `Verify WARDEN_ANTHROPIC_API_KEY is set correctly, or run 'claude login' if using Claude Code subscription.` + `Verify provider credentials are set correctly (WARDEN_ANTHROPIC_API_KEY or WARDEN_PI_API_KEY).` ); } @@ -910,6 +960,7 @@ export async function runSkill( apiKey: options.apiKey, repoPath: context.repoPath, maxRetries: options.auxiliaryMaxRetries, + provider: options.provider, }); let mergedFindings = mergeResult.findings; if (mergeResult.usage) { @@ -919,6 +970,7 @@ export async function runSkill( repoPath: context.repoPath, apiKey: options.apiKey, maxRetries: options.auxiliaryMaxRetries, + provider: options.provider, }); mergedFindings = sanitized.findings; if (sanitized.usage) { @@ -953,6 +1005,7 @@ export async function runSkill( usage: totalUsage, durationMs: Date.now() - startTime, model: options.model, + provider: options.provider, files: fileResults.map((fr) => ({ filename: fr.filename, findingCount: fr.result.findings.length, diff --git a/src/sdk/auth.test.ts b/src/sdk/auth.test.ts index 2c0f6684..4ebda478 100644 --- a/src/sdk/auth.test.ts +++ b/src/sdk/auth.test.ts @@ -21,6 +21,17 @@ describe('verifyAuth', () => { expect(mockExec).not.toHaveBeenCalled(); }); + 
it('accepts PI API key for pi provider', () => { + verifyAuth({ apiKey: 'pi-key', provider: 'pi' }); + expect(mockExec).not.toHaveBeenCalled(); + }); + + it('falls back to CLI auth checks for pi provider without API key', () => { + mockExec.mockReturnValue('1.0.0'); + verifyAuth({ apiKey: undefined, provider: 'pi' }); + expect(mockExec).toHaveBeenCalledWith('claude', ['--version'], { timeout: 5000 }); + }); + it('checks for claude binary when no API key', () => { mockExec.mockReturnValue('1.0.0'); verifyAuth({ apiKey: undefined }); diff --git a/src/sdk/auth.ts b/src/sdk/auth.ts index b50e65e3..85679684 100644 --- a/src/sdk/auth.ts +++ b/src/sdk/auth.ts @@ -1,5 +1,6 @@ import { ExecError, execFileNonInteractive } from '../utils/exec.js'; import { WardenAuthenticationError } from './errors.js'; +import type { Provider } from '../config/schema.js'; /** * Pre-flight auth check: verify that authentication will work before starting analysis. @@ -13,7 +14,9 @@ import { WardenAuthenticationError } from './errors.js'; * Subtler failures (binary exists but sandbox blocks IPC) are caught by the * isSubprocessError() handler in analyzeHunk(). */ -export function verifyAuth({ apiKey }: { apiKey?: string }): void { +export function verifyAuth({ apiKey, provider = 'claude' }: { apiKey?: string; provider?: Provider }): void { + void provider; + // Direct API auth — no subprocess needed if (apiKey) return; diff --git a/src/sdk/errors.ts b/src/sdk/errors.ts index c286e44b..be3d4c0d 100644 --- a/src/sdk/errors.ts +++ b/src/sdk/errors.ts @@ -35,6 +35,7 @@ export function isAuthenticationErrorMessage(message: string): boolean { const AUTH_ERROR_GUIDANCE = ` claude login # Use Claude Code subscription export WARDEN_ANTHROPIC_API_KEY=sk-... # Or use API key + export WARDEN_PI_API_KEY=... 
# Pi provider API key https://console.anthropic.com/ for API keys`; diff --git a/src/sdk/extract.ts b/src/sdk/extract.ts index a74a289d..cd022253 100644 --- a/src/sdk/extract.ts +++ b/src/sdk/extract.ts @@ -8,6 +8,7 @@ import type { Finding, Location, UsageStats } from '../types/index.js'; import { Sentry } from '../sentry.js'; import { callHaiku, DEFAULT_AUXILIARY_MAX_RETRIES, HAIKU_MODEL, setGenAiResponseAttrs } from './haiku.js'; import { apiUsageToStats } from './pricing.js'; +import type { Provider } from '../config/schema.js'; /** Pattern to match the start of findings JSON (allows whitespace after brace) */ export const FINDINGS_JSON_START = /\{\s*"findings"/; @@ -174,8 +175,10 @@ export function truncateForLLMFallback(rawText: string, maxChars: number): strin export async function extractFindingsWithLLM( rawText: string, apiKey?: string, - maxRetries?: number + maxRetries?: number, + provider: Provider = 'claude' ): Promise { + const providerName = provider === 'claude' ? 'anthropic' : provider; if (!apiKey) { return { success: false, @@ -202,7 +205,7 @@ export async function extractFindingsWithLLM( name: `chat ${HAIKU_MODEL}`, attributes: { 'gen_ai.operation.name': 'chat', - 'gen_ai.provider.name': 'anthropic', + 'gen_ai.provider.name': providerName, 'gen_ai.request.model': HAIKU_MODEL, 'gen_ai.request.max_tokens': LLM_FALLBACK_MAX_TOKENS, }, @@ -476,7 +479,7 @@ function readSnippet(repoPath: string, filePath: string, startLine: number, cont */ export async function mergeCrossLocationFindings( findings: Finding[], - options?: { apiKey?: string; repoPath?: string; maxRetries?: number } + options?: { apiKey?: string; repoPath?: string; maxRetries?: number; provider?: Provider } ): Promise { const apiKey = options?.apiKey; const repoPath = options?.repoPath ?? '.'; @@ -507,6 +510,7 @@ Singletons should not appear. 
Return [] if no findings describe the same issue.` const result = await callHaiku({ apiKey, + provider: options?.provider, prompt, schema: MergeGroupsSchema, maxTokens: 512, diff --git a/src/sdk/fix-quality.ts b/src/sdk/fix-quality.ts index 192e6d5c..e416ae20 100644 --- a/src/sdk/fix-quality.ts +++ b/src/sdk/fix-quality.ts @@ -6,6 +6,7 @@ import { applyDiffToContent } from '../diff/apply.js'; import type { Finding, UsageStats } from '../types/index.js'; import { callHaiku } from './haiku.js'; import { aggregateUsage } from './usage.js'; +import type { Provider } from '../config/schema.js'; export interface FixQualityStats { checked: number; @@ -24,6 +25,7 @@ interface SanitizeSuggestedFixesOptions { repoPath: string; apiKey?: string; maxRetries?: number; + provider?: Provider; } const SEMANTIC_PROMPT_MAX_CHARS = 4000; @@ -128,7 +130,8 @@ async function runSemanticGate( fileContent: string, patchedContent: string, apiKey?: string, - maxRetries?: number + maxRetries?: number, + provider?: Provider ): Promise<{ verdict: 'pass' | 'fail' | 'unavailable'; usage?: UsageStats }> { if (!apiKey) { return { verdict: 'unavailable' }; @@ -156,6 +159,7 @@ async function runSemanticGate( const result = await callHaiku({ apiKey, + provider, prompt, schema: SemanticFixVerdictSchema, maxTokens: 220, @@ -203,7 +207,8 @@ export async function sanitizeFindingsSuggestedFixes( deterministic.fileContent, deterministic.patchedContent, options.apiKey, - options.maxRetries + options.maxRetries, + options.provider ); if (semantic.usage) { semanticUsage.push(semantic.usage); diff --git a/src/sdk/haiku.ts b/src/sdk/haiku.ts index aa3d90fd..7f6db405 100644 --- a/src/sdk/haiku.ts +++ b/src/sdk/haiku.ts @@ -5,6 +5,7 @@ import type { UsageStats } from '../types/index.js'; import { Sentry } from '../sentry.js'; import { apiUsageToStats } from './pricing.js'; import { aggregateUsage, emptyUsage } from './usage.js'; +import type { Provider } from '../config/schema.js'; export const HAIKU_MODEL = 
'claude-haiku-4-5'; export const DEFAULT_AUXILIARY_MAX_RETRIES = 5; @@ -123,6 +124,7 @@ export interface CallHaikuOptions { apiKey: string; prompt: string; schema: z.ZodType; + provider?: Provider; maxTokens?: number; timeout?: number; maxRetries?: number; @@ -145,6 +147,8 @@ function inferPrefill(schema: z.ZodType): string | undefined { */ export async function callHaiku(options: CallHaikuOptions): Promise> { const { apiKey, prompt, schema, maxTokens = DEFAULT_MAX_TOKENS, timeout = DEFAULT_TIMEOUT_MS, maxRetries = DEFAULT_AUXILIARY_MAX_RETRIES } = options; + const provider = options.provider ?? 'claude'; + const providerName = provider === 'claude' ? 'anthropic' : provider; return Sentry.startSpan( { @@ -152,7 +156,7 @@ export async function callHaiku(options: CallHaikuOptions): Promise { apiKey: string; prompt: string; schema: z.ZodType; + provider?: Provider; tools: Anthropic.Tool[]; executeTool: (name: string, input: Record) => Promise; maxTokens?: number; @@ -236,6 +241,7 @@ export async function callHaikuWithTools(options: CallHaikuWithToolsOptions(options: CallHaikuWithToolsOptions(options: CallHaikuWithToolsOptions