From aeb86048ba35158bfea08eedd8d86e51c414d471 Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:37:46 -0400 Subject: [PATCH 1/9] pricing: refresh model catalog to June 2026 rates Correct OpenAI gpt-5.x prices (were 2-7x understated), add the Gemini 3.x generation and still-billable Anthropic legacy models, fix Gemini 2.5 cache rates, and update context windows for the 1M generation. Adds a table-ordering invariant test plus per-provider coverage. --- src/pricing/rates.rs | 212 +++++++++++++++++++++++++++++++++---- tests/unit/pricing_test.rs | 106 ++++++++++++++++--- tests/unit/waste_test.rs | 8 +- 3 files changed, 288 insertions(+), 38 deletions(-) diff --git a/src/pricing/rates.rs b/src/pricing/rates.rs index d9dde1e..0fe9e59 100644 --- a/src/pricing/rates.rs +++ b/src/pricing/rates.rs @@ -24,7 +24,7 @@ /// Date the embedded rate card was last edited, `YYYY-MM-DD`. Bump /// whenever you change [`KNOWN_MODELS`]. The status command warns the user /// if this date is more than 30 days behind today. -pub const PRICING_LAST_UPDATED: &str = "2026-06-09"; +pub const PRICING_LAST_UPDATED: &str = "2026-06-10"; /// USD per million tokens, broken out by token type. #[derive(Debug, Clone, Copy, PartialEq)] @@ -36,10 +36,11 @@ pub struct ModelPricing { } pub const KNOWN_MODELS: &[(&str, ModelPricing)] = &[ - // ─────────── Anthropic (as of June 2026) ─────────── - // Fable 5 (released 2026-06-09): the tier above Opus. 1M context at these - // flat rates. Cache rates follow the standard Anthropic multipliers - // (write 1.25× input for the 5-minute TTL, read 0.1× input). + // ─────────── Anthropic (verified against the published rate card 2026-06-10) ─────────── + // Cache rates follow the standard Anthropic multipliers (write 1.25× input + // for the 5-minute TTL, read 0.1× input). Legacy models are listed too: + // they stay billable until retirement, and a pinned older model would + // otherwise track as $0 — the most expensive miss is the worst one. ( "claude-fable-5", ModelPricing { @@ -49,6 +50,15 @@ pub const KNOWN_MODELS: &[(&str, ModelPricing)] = &[ output_per_mtok: 50.00, }, ), + ( + "claude-mythos-5", + ModelPricing { + input_per_mtok: 10.00, + cache_write_per_mtok: 12.50, + cache_read_per_mtok: 1.00, + output_per_mtok: 50.00, + }, + ), ( "claude-opus-4-8", ModelPricing { @@ -76,6 +86,47 @@ pub const KNOWN_MODELS: &[(&str, ModelPricing)] = &[ output_per_mtok: 25.00, }, ), + ( + "claude-opus-4-5", + ModelPricing { + input_per_mtok: 5.00, + cache_write_per_mtok: 6.25, + cache_read_per_mtok: 0.50, + output_per_mtok: 25.00, + }, + ), + // Opus 4.1 and Opus 4 are deprecated but still billable — at 3× the + // current Opus rate, so missing them would silently drop the priciest + // traffic. Keyed as the alias (`-4-0`) plus the exact dated ID rather + // than a bare `claude-opus-4` prefix, which would shadow-match every + // future `claude-opus-4-9`-style release at the wrong rate. + ( + "claude-opus-4-1", + ModelPricing { + input_per_mtok: 15.00, + cache_write_per_mtok: 18.75, + cache_read_per_mtok: 1.50, + output_per_mtok: 75.00, + }, + ), + ( + "claude-opus-4-0", + ModelPricing { + input_per_mtok: 15.00, + cache_write_per_mtok: 18.75, + cache_read_per_mtok: 1.50, + output_per_mtok: 75.00, + }, + ), + ( + "claude-opus-4-20250514", + ModelPricing { + input_per_mtok: 15.00, + cache_write_per_mtok: 18.75, + cache_read_per_mtok: 1.50, + output_per_mtok: 75.00, + }, + ), ( "claude-sonnet-4-6", ModelPricing { @@ -85,6 +136,33 @@ pub const KNOWN_MODELS: &[(&str, ModelPricing)] = &[ output_per_mtok: 15.00, }, ), + ( + "claude-sonnet-4-5", + ModelPricing { + input_per_mtok: 3.00, + cache_write_per_mtok: 3.75, + cache_read_per_mtok: 0.30, + output_per_mtok: 15.00, + }, + ), + ( + "claude-sonnet-4-0", + ModelPricing { + input_per_mtok: 3.00, + cache_write_per_mtok: 3.75, + cache_read_per_mtok: 0.30, + output_per_mtok: 15.00, + }, + ), + ( + "claude-sonnet-4-20250514", + ModelPricing { + input_per_mtok: 3.00, + cache_write_per_mtok: 3.75, + cache_read_per_mtok: 0.30, + output_per_mtok: 15.00, + }, + ), ( "claude-haiku-4-5", ModelPricing { @@ -94,55 +172,145 @@ pub const KNOWN_MODELS: &[(&str, ModelPricing)] = &[ output_per_mtok: 5.00, }, ), - // ─────────── OpenAI (as of May 2026) ─────────── - // No cache write cost — caching is automatic. + // ─────────── OpenAI (verified against the published rate card 2026-06-10) ─────────── + // No cache write cost — caching is automatic; cached input bills at 10% of + // input. Tiered long-context pricing exists for the flagship models; this + // flat card uses the standard (short-context) tier. `-pro` models have no + // cached-input rate, so cache_read is 0 there. + // Ordering: `gpt-5.5-pro` before `gpt-5.5`; mini/nano/pro before `gpt-5.4`. + ( + "gpt-5.5-pro", + ModelPricing { + input_per_mtok: 30.00, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.0, + output_per_mtok: 180.00, + }, + ), ( "gpt-5.5", ModelPricing { - input_per_mtok: 2.00, + input_per_mtok: 5.00, cache_write_per_mtok: 0.0, - cache_read_per_mtok: 1.00, - output_per_mtok: 10.00, + cache_read_per_mtok: 0.50, + output_per_mtok: 30.00, }, ), - // `gpt-5.4-mini` MUST precede `gpt-5.4` (see module docs). ( "gpt-5.4-mini", ModelPricing { - input_per_mtok: 0.15, + input_per_mtok: 0.75, cache_write_per_mtok: 0.0, cache_read_per_mtok: 0.075, - output_per_mtok: 0.60, + output_per_mtok: 4.50, + }, + ), + ( + "gpt-5.4-nano", + ModelPricing { + input_per_mtok: 0.20, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.02, + output_per_mtok: 1.25, + }, + ), + ( + "gpt-5.4-pro", + ModelPricing { + input_per_mtok: 30.00, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.0, + output_per_mtok: 180.00, }, ), ( "gpt-5.4", ModelPricing { - input_per_mtok: 1.25, + input_per_mtok: 2.50, cache_write_per_mtok: 0.0, - cache_read_per_mtok: 0.625, - output_per_mtok: 10.00, + cache_read_per_mtok: 0.25, + output_per_mtok: 15.00, + }, + ), + // The Codex CLI's dedicated model — high-volume agentic coding traffic. + ( + "gpt-5.3-codex", + ModelPricing { + input_per_mtok: 1.75, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.175, + output_per_mtok: 14.00, + }, + ), + // ─────────── Google Gemini (verified against the published rate card 2026-06-10) ─────────── + // Implicit caching — no explicit cache-write cost on the response path + // (the per-hour cache-storage fee is not response-derivable and is not + // modeled). Tiered >200k-prompt pricing exists on the pro models; this + // flat card uses the standard ≤200k tier. + // Longest prefixes first: `-flash-lite` before `-flash`, `-pro` / `-flash` + // before any shorter family key, per the module docs. + ( + "gemini-3.5-flash", + ModelPricing { + input_per_mtok: 1.50, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.15, + output_per_mtok: 9.00, + }, + ), + // Catches the `gemini-3.1-pro-preview` ID via the `-` suffix rule. + ( + "gemini-3.1-pro", + ModelPricing { + input_per_mtok: 2.00, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.20, + output_per_mtok: 12.00, + }, + ), + ( + "gemini-3.1-flash-lite", + ModelPricing { + input_per_mtok: 0.25, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.025, + output_per_mtok: 1.50, + }, + ), + // Catches the `gemini-3-flash-preview` ID via the `-` suffix rule. + ( + "gemini-3-flash", + ModelPricing { + input_per_mtok: 0.50, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.05, + output_per_mtok: 3.00, }, ), - // ─────────── Google Gemini (as of May 2026) ─────────── - // Implicit caching — no explicit cache-write cost on the response path. - // Longest prefixes first: `gemini-2.5-pro` / `-flash` before any shorter - // family key, per the module docs. ( "gemini-2.5-pro", ModelPricing { input_per_mtok: 1.25, cache_write_per_mtok: 0.0, - cache_read_per_mtok: 0.3125, + cache_read_per_mtok: 0.125, output_per_mtok: 10.00, }, ), + ( + "gemini-2.5-flash-lite", + ModelPricing { + input_per_mtok: 0.10, + cache_write_per_mtok: 0.0, + cache_read_per_mtok: 0.01, + output_per_mtok: 0.40, + }, + ), ( "gemini-2.5-flash", ModelPricing { input_per_mtok: 0.30, cache_write_per_mtok: 0.0, - cache_read_per_mtok: 0.075, + cache_read_per_mtok: 0.03, output_per_mtok: 2.50, }, ), diff --git a/tests/unit/pricing_test.rs b/tests/unit/pricing_test.rs index 12028e9..247fb07 100644 --- a/tests/unit/pricing_test.rs +++ b/tests/unit/pricing_test.rs @@ -47,9 +47,71 @@ fn lookup_strips_openai_date_suffix() { #[test] fn lookup_disambiguates_gpt_mini_from_gpt_base() { // The critical ordering case: `gpt-5.4-mini-2026-03-01` must hit the mini - // rates (0.15/MTok), NOT the base gpt-5.4 rates (1.25/MTok). + // rates (0.75/MTok), NOT the base gpt-5.4 rates (2.50/MTok). let mini = get_pricing("gpt-5.4-mini-2026-03-01").expect("mini variant"); - assert!((mini.input_per_mtok - 0.15).abs() < EPSILON); + assert!((mini.input_per_mtok - 0.75).abs() < EPSILON); + // Same for nano and pro — every longer variant must shadow the base. + let nano = get_pricing("gpt-5.4-nano").expect("nano variant"); + assert!((nano.input_per_mtok - 0.20).abs() < EPSILON); + let pro = get_pricing("gpt-5.4-pro").expect("pro variant"); + assert!((pro.input_per_mtok - 30.00).abs() < EPSILON); +} + +#[test] +fn codex_model_is_priced() { + // The Codex CLI's dedicated model id must resolve — it has no bare + // `gpt-5.3` base entry to fall back to. + let p = get_pricing("gpt-5.3-codex").expect("codex model"); + assert!((p.input_per_mtok - 1.75).abs() < EPSILON); + assert!((p.output_per_mtok - 14.00).abs() < EPSILON); +} + +#[test] +fn legacy_anthropic_models_are_priced() { + // Deprecated-but-billable models must still track cost: Opus 4.1 / Opus 4 + // bill at 3× the current Opus rate — the worst models to silently miss. + for id in [ + "claude-opus-4-1", + "claude-opus-4-1-20250805", + "claude-opus-4-0", + "claude-opus-4-20250514", + ] { + let p = get_pricing(id).unwrap_or_else(|| panic!("{id} should be priced")); + assert!((p.input_per_mtok - 15.00).abs() < EPSILON, "{id}"); + assert!((p.output_per_mtok - 75.00).abs() < EPSILON, "{id}"); + } + for id in [ + "claude-sonnet-4-5", + "claude-sonnet-4-5-20250929", + "claude-sonnet-4-0", + "claude-sonnet-4-20250514", + "claude-opus-4-5-20251101", + ] { + assert!(get_pricing(id).is_some(), "{id} should be priced"); + } +} + +#[test] +fn known_models_table_orders_longer_prefixes_first() { + // The lookup returns the FIRST dash/bracket-prefix match, so any key that + // is itself a dash-prefix of another key must come after it — e.g. + // `gpt-5.4` after `gpt-5.4-mini`, `gemini-2.5-flash` after + // `gemini-2.5-flash-lite`. This guards the invariant for future edits. + let keys: Vec<&str> = burnwall::pricing::KNOWN_MODELS + .iter() + .map(|(k, _)| *k) + .collect(); + for (i, shorter) in keys.iter().enumerate() { + for longer in keys.iter().skip(i + 1) { + let shadowed = longer + .strip_prefix(shorter) + .is_some_and(|rest| rest.starts_with('-') || rest.starts_with('[')); + assert!( + !shadowed, + "table order bug: '{shorter}' (index {i}) shadows the later key '{longer}'" + ); + } + } } #[test] @@ -131,12 +193,12 @@ fn cost_anthropic_uncached_matches_hand_calculation() { #[test] fn cost_openai_cached_matches_hand_calculation() { - // gpt-5.4 rates (1.25, 0.0, 0.625, 10.00). Fixture splits to + // gpt-5.4 rates (2.50, 0.0, 0.25, 15.00). Fixture splits to // input=512, output=512, cache_read=1536: - // input: 512 / 1M * 1.25 = 0.00064 - // cache_read: 1536 / 1M * 0.625 = 0.00096 - // output: 512 / 1M * 10.00 = 0.00512 - // total 0.00672 + // input: 512 / 1M * 2.50 = 0.00128 + // cache_read: 1536 / 1M * 0.25 = 0.000384 + // output: 512 / 1M * 15.00 = 0.00768 + // total 0.009344 let usage = TokenUsage { input_tokens: 512, output_tokens: 512, @@ -144,7 +206,7 @@ fn cost_openai_cached_matches_hand_calculation() { cache_read_tokens: 1536, }; let pricing = get_pricing("gpt-5.4").expect("pricing"); - approx_eq(cost(&usage, pricing), 0.00672, "gpt-5.4 cached cost"); + approx_eq(cost(&usage, pricing), 0.009344, "gpt-5.4 cached cost"); } #[test] @@ -160,12 +222,12 @@ fn lookup_disambiguates_gemini_pro_from_flash() { #[test] fn cost_gemini_cached_matches_hand_calculation() { - // google_cached.json with gemini-2.5-flash rates (0.30, 0.0, 0.075, 2.50). + // google_cached.json with gemini-2.5-flash rates (0.30, 0.0, 0.03, 2.50). // Split: input=512, output=300, cache_read=1536. // input: 512 / 1M * 0.30 = 0.0001536 - // cache_read: 1536 / 1M * 0.075 = 0.0001152 + // cache_read: 1536 / 1M * 0.03 = 0.00004608 // output: 300 / 1M * 2.50 = 0.00075 - // total 0.0010188 + // total 0.00094968 let usage = TokenUsage { input_tokens: 512, output_tokens: 300, @@ -173,7 +235,27 @@ fn cost_gemini_cached_matches_hand_calculation() { cache_read_tokens: 1536, }; let pricing = get_pricing("gemini-2.5-flash").expect("pricing"); - approx_eq(cost(&usage, pricing), 0.0010188, "gemini flash cached cost"); + approx_eq(cost(&usage, pricing), 0.00094968, "gemini flash cached cost"); +} + +#[test] +fn lookup_disambiguates_gemini_flash_lite_from_flash() { + // `gemini-2.5-flash` is a dash-prefix of `gemini-2.5-flash-lite`, so the + // lite entry must come first in the table or it would bill at flash rates. + let lite = get_pricing("gemini-2.5-flash-lite").expect("flash lite"); + assert!((lite.input_per_mtok - 0.10).abs() < EPSILON); + let lite31 = get_pricing("gemini-3.1-flash-lite").expect("3.1 flash lite"); + assert!((lite31.input_per_mtok - 0.25).abs() < EPSILON); +} + +#[test] +fn gemini_3_generation_is_priced() { + // The preview suffixes on current Gemini IDs resolve via the `-` rule. + let pro = get_pricing("gemini-3.1-pro-preview").expect("3.1 pro preview"); + assert!((pro.input_per_mtok - 2.00).abs() < EPSILON); + let flash = get_pricing("gemini-3-flash-preview").expect("3 flash preview"); + assert!((flash.input_per_mtok - 0.50).abs() < EPSILON); + assert!(get_pricing("gemini-3.5-flash").is_some()); } #[test] diff --git a/tests/unit/waste_test.rs b/tests/unit/waste_test.rs index d914ac8..4102eeb 100644 --- a/tests/unit/waste_test.rs +++ b/tests/unit/waste_test.rs @@ -224,8 +224,8 @@ fn flags_heavy_reasoning_on_routine_requests() { .expect("should flag reasoning overuse"); assert_eq!(finding.rule_id, "reasoning-effort-overuse"); assert_eq!(finding.count, 12); - // gpt-5.5 output $10/MTok: 1200 reasoning × 10 / 1e6 = $0.012 each × 12 = $0.144. - assert!((finding.observed_waste_usd - 0.144).abs() < 1e-6); + // gpt-5.5 output $30/MTok: 1200 reasoning × 30 / 1e6 = $0.036 each × 12 = $0.432. + assert!((finding.observed_waste_usd - 0.432).abs() < 1e-6); } #[test] @@ -290,8 +290,8 @@ fn flags_context_window_saturation() { .expect("should flag saturation"); assert_eq!(f.rule_id, "context-window-saturation"); assert_eq!(f.count, 12); - // gpt-5.5 input $2/MTok: 240000 × 2 / 1e6 = $0.48 each × 12 = $5.76. - assert!((f.observed_waste_usd - 5.76).abs() < 1e-6); + // gpt-5.5 input $5/MTok: 240000 × 5 / 1e6 = $1.20 each × 12 = $14.40. + assert!((f.observed_waste_usd - 14.40).abs() < 1e-6); } #[test] From b3cb98bef3b9685c41e9ab4af6012794f8ed73d2 Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:37:55 -0400 Subject: [PATCH 2/9] budget: day/month reset, monthly enforcement, plan-aware caps; loop detector death-spiral fix BudgetTracker is now day- and month-stamped with lazy rollover (restart- and clock-change-proof) so a long-running daemon no longer accumulates across days and 429s everything. Monthly cap is now enforced. budget.enforce_on_plan (default off) keeps the dollar cap from blocking subscription traffic, which is notional. Loop detector splits into a read-only pre-forward peek plus a tee-side record-on-2xx, so blocked 429s and failed-request retries can't refill the window; hash keyed by method+provider+path, GET/body-less skipped, Retry-After added. Also filters empty deny rules at ruleset construction. --- src/budget/limits.rs | 27 ++++ src/budget/loop_detector.rs | 194 +++++++++++++++++++++++++---- src/budget/mod.rs | 113 +++++++++++++++-- src/config/types.rs | 21 +++- tests/integration/budget_test.rs | 147 +++++++++++++++++----- tests/unit/project_profile_test.rs | 1 + 6 files changed, 430 insertions(+), 73 deletions(-) diff --git a/src/budget/limits.rs b/src/budget/limits.rs index 9f785df..3fd1340 100644 --- a/src/budget/limits.rs +++ b/src/budget/limits.rs @@ -15,6 +15,14 @@ pub struct BudgetConfig { /// `x-burnwall-session` request header. `0.0` = unlimited (off). Lets agents /// in a fan-out that share a session id share one blast-radius ceiling. pub per_session_usd: f64, + /// Enforce the dollar caps (daily/monthly/session) on subscription traffic + /// too. Off by default: a flat-rate plan (Claude Pro/Max via OAuth) is not + /// metered per token, so the calculated API-equivalent dollar figure is + /// notional — blocking on it walls the user off from money they are not + /// spending. With `false`, subscription requests are tracked and *warned* + /// but never blocked on the dollar cap; metered API-key traffic is always + /// enforced. The loop detector / cost spiral still apply to both. See B-H4. + pub enforce_on_plan: bool, } impl Default for BudgetConfig { @@ -24,6 +32,7 @@ impl Default for BudgetConfig { monthly_usd: 0.0, // unlimited per SPEC default warn_percent: 80, per_session_usd: 0.0, // off by default + enforce_on_plan: false, } } } @@ -73,6 +82,24 @@ pub fn check_daily(spent_usd: f64, config: &BudgetConfig) -> BudgetStatus { BudgetStatus::Ok } +/// Pure: classify `spent_usd` (month-to-date) against the monthly limit. +/// +/// Mirrors [`check_daily`] but against `monthly_usd` and with no warn tier — +/// the monthly cap is a hard backstop, and the daily warn already nudges. +/// `0.0` monthly limit = unlimited. +pub fn check_monthly(spent_usd: f64, config: &BudgetConfig) -> BudgetStatus { + if config.monthly_usd <= 0.0 { + return BudgetStatus::Ok; + } + if spent_usd >= config.monthly_usd { + return BudgetStatus::Exceeded { + spent: spent_usd, + limit: config.monthly_usd, + }; + } + BudgetStatus::Ok +} + /// Pure: classify a session's `spent_usd` against the per-session cap. Returns /// `Exceeded` once spend reaches the cap; no warn tier (a swarm ceiling is a /// hard stop). `0.0` cap = unlimited. diff --git a/src/budget/loop_detector.rs b/src/budget/loop_detector.rs index f86807b..f50b6da 100644 --- a/src/budget/loop_detector.rs +++ b/src/budget/loop_detector.rs @@ -3,9 +3,9 @@ //! //! Two independent mechanisms: //! -//! - **Repeated-content loop**: hash a prefix of the request body; if the -//! same hash appears `max_identical_requests` times within -//! `window_seconds`, block with HTTP 429. +//! - **Repeated-content loop**: hash the full request body; if the same +//! hash appears `max_identical_requests` times within `window_seconds`, +//! block with HTTP 429. //! - **Cost spiral**: independently of content, if the rolling per-window //! cost exceeds `max_cost_per_window`, block. //! @@ -38,8 +38,6 @@ pub struct LoopConfig { /// logged by `record_cost`, but not enforced — blocking is opt-in so a /// normal burst of spend does not start 429-ing a working session. pub cost_spiral_enforce: bool, - /// Bytes of request body to hash for the dedup signature. - pub hash_prefix_bytes: usize, } impl Default for LoopConfig { @@ -50,7 +48,6 @@ impl Default for LoopConfig { window_seconds: 300, max_cost_per_window: 2.0, cost_spiral_enforce: false, - hash_prefix_bytes: 200, } } } @@ -63,6 +60,10 @@ pub enum LoopVerdict { count: u32, window_seconds: u32, hash: u64, + /// Seconds until the window drains enough to retry (the oldest + /// in-window arrival's expiry). Steers well-behaved SDKs to back off + /// *past* the window instead of hammering it (B-C2). + retry_after_secs: u64, }, /// Rolling cost in the window exceeds the cap. CostSpiral { @@ -77,6 +78,20 @@ impl LoopVerdict { !matches!(self, LoopVerdict::Ok) } + /// Seconds the client should wait before retrying — the `Retry-After` + /// header value. For a repeated-loop block it's the window-drain time; for + /// a cost spiral it's the full window (the rolling cost needs that long to + /// age out). `None` when not blocking. + pub fn retry_after_secs(&self) -> Option { + match self { + LoopVerdict::Ok => None, + LoopVerdict::Repeated { + retry_after_secs, .. + } => Some(*retry_after_secs), + LoopVerdict::CostSpiral { window_seconds, .. } => Some(*window_seconds as u64), + } + } + /// Human-readable message used as `block_reason` in storage and as the /// 429 body's `message` field. pub fn message(&self) -> String { @@ -127,16 +142,37 @@ impl LoopDetector { &self.config } - /// Compute the dedup signature for a request body. - pub fn hash(&self, body: &[u8]) -> u64 { - let take = self.config.hash_prefix_bytes.min(body.len()); + /// Compute the dedup signature for a request. Hashes `(method, provider, + /// path, FULL body)`: + /// + /// - **Full body**, because agentic clients resend the whole (growing) + /// transcript every turn, so any fixed-size prefix is identical across a + /// session and a prefix hash would flag normal activity as a loop. + /// - **method + provider + path**, so body-less requests (every `GET + /// /v1/models` hashes to the same empty body) don't collide into one + /// global bucket across tools and providers (B-H1). The handler also + /// skips loop detection for GET/body-less requests entirely. + pub fn hash(&self, method: &str, provider: &str, path: &str, body: &[u8]) -> u64 { let mut h = DefaultHasher::new(); - body[..take].hash(&mut h); + method.hash(&mut h); + provider.hash(&mut h); + path.hash(&mut h); + body.hash(&mut h); h.finish() } - /// Record a request arrival under its hash and decide if it forms a - /// loop. Always called pre-forward. + /// Read-only pre-forward check: prune expired arrivals and decide whether + /// the window is already full, **without recording** this request. The + /// arrival is recorded later (by [`record_arrival`](Self::record_arrival)), + /// and only if the request was actually forwarded and succeeded. + /// + /// This split is what breaks the death spiral (B-C2): a request the + /// detector blocks returns 429 but is *not* counted, and an SDK that + /// retries that 429 — or retries after an upstream failure — re-peeks + /// without refilling the window, so the window drains after + /// `window_seconds` and the user recovers. Under the old "record then + /// check" model every retry (including retries of the block itself) topped + /// the window back up, so it never drained. pub fn check_request(&self, hash: u64) -> LoopVerdict { if !self.config.enabled { return LoopVerdict::Ok; @@ -145,29 +181,55 @@ impl LoopDetector { let window = Duration::seconds(self.config.window_seconds as i64); let cutoff = now - window; - let count = { - let mut entry = self.hash_history.entry(hash).or_default(); - while let Some(front) = entry.front() { - if *front < cutoff { - entry.pop_front(); - } else { - break; - } + let mut entry = self.hash_history.entry(hash).or_default(); + while let Some(front) = entry.front() { + if *front < cutoff { + entry.pop_front(); + } else { + break; } - entry.push_back(now); - entry.len() as u32 - }; - + } + let count = entry.len() as u32; if count >= self.config.max_identical_requests { + // Window drains when the oldest arrival ages out. + let retry_after_secs = entry + .front() + .map(|oldest| { + let elapsed = (now - *oldest).num_seconds().max(0); + (self.config.window_seconds as i64 - elapsed).max(1) as u64 + }) + .unwrap_or(self.config.window_seconds as u64); return LoopVerdict::Repeated { count, window_seconds: self.config.window_seconds, hash, + retry_after_secs, }; } LoopVerdict::Ok } + /// Record a forwarded-and-succeeded request arrival under its hash. Called + /// from the response tee **only for 2xx responses** — never for blocked or + /// failed requests — so the window counts genuine repeats, not retries of + /// errors. Prunes expired arrivals as it goes. + pub fn record_arrival(&self, hash: u64) { + if !self.config.enabled { + return; + } + let now = Utc::now(); + let cutoff = now - Duration::seconds(self.config.window_seconds as i64); + let mut entry = self.hash_history.entry(hash).or_default(); + while let Some(front) = entry.front() { + if *front < cutoff { + entry.pop_front(); + } else { + break; + } + } + entry.push_back(now); + } + /// Append a recorded cost to the global window and decide whether the /// rolling spend has tripped the cost-spiral cap. /// @@ -258,10 +320,92 @@ mod tests { window_seconds: 300, max_cost_per_window: cap, cost_spiral_enforce: enforce, - hash_prefix_bytes: 200, } } + fn h(det: &LoopDetector, body: &[u8]) -> u64 { + det.hash("POST", "anthropic", "/v1/messages", body) + } + + #[test] + fn growing_transcript_does_not_loop() { + // Regression: agentic clients (Claude Code) resend the entire + // conversation every turn, so consecutive request bodies share a long + // identical prefix — same model, same opening message — while growing + // at the tail. The old 200-byte prefix hash saw those as identical + // and 429'd any session that made 5 requests within 5 minutes. + let det = LoopDetector::with_defaults(); + let prefix = r#"{"model":"claude-fable-5","messages":[{"role":"user","content":"please investigate why successful proxied requests are not recorded and fix the streaming usage parser so the cost tracking pipeline works again"}"#; + assert!(prefix.len() > 200, "prefix must exceed the old hash window"); + for i in 0..10 { + let body = format!( + "{prefix},{{\"role\":\"assistant\",\"content\":\"turn {i}\"}}]}}" + ); + let hash = h(&det, body.as_bytes()); + let verdict = det.check_request(hash); + assert_eq!(verdict, LoopVerdict::Ok, "turn {i} wrongly flagged as loop"); + det.record_arrival(hash); + } + } + + #[test] + fn byte_identical_bodies_still_trip() { + let det = LoopDetector::with_defaults(); + let hash = h(&det, br#"{"model":"m","messages":[{"role":"user","content":"same"}]}"#); + // Five identical *successful* requests are tolerated; the sixth peek + // sees a full window and blocks. Each Ok request records its arrival + // (as the tee does on a 2xx). + for _ in 0..5 { + assert_eq!(det.check_request(hash), LoopVerdict::Ok); + det.record_arrival(hash); + } + assert!(det.check_request(hash).is_blocking()); + } + + #[test] + fn blocked_requests_do_not_feed_the_window() { + // The death-spiral regression (B-C2): the block path calls only + // check_request (never record_arrival), so an SDK that hammers a 429 — + // or retries after an upstream failure — cannot keep the window full. + // check_request is read-only: calling it 100× without a single + // record_arrival must never produce a block. + let det = LoopDetector::with_defaults(); + let hash = h(&det, b"identical-retry-body"); + for _ in 0..100 { + assert_eq!(det.check_request(hash), LoopVerdict::Ok); + } + } + + #[test] + fn distinct_method_path_dont_share_a_bucket() { + // B-H1: body-less requests (empty body) used to collide into one global + // bucket; including method+provider+path keeps GET /v1/models on one + // tool distinct from another tool's. + let det = LoopDetector::with_defaults(); + let a = det.hash("GET", "anthropic", "/v1/models", b""); + let b = det.hash("GET", "openai", "/v1/models", b""); + let c = det.hash("GET", "anthropic", "/v1/models/claude", b""); + assert_ne!(a, b); + assert_ne!(a, c); + } + + #[test] + fn repeated_verdict_carries_retry_after() { + let det = LoopDetector::with_defaults(); + let hash = h(&det, b"loop-body"); + for _ in 0..5 { + det.record_arrival(hash); + } + let v = det.check_request(hash); + match v { + LoopVerdict::Repeated { + retry_after_secs, .. + } => assert!((1..=300).contains(&retry_after_secs)), + other => panic!("expected Repeated, got {other:?}"), + } + assert!(det.check_request(hash).retry_after_secs().is_some()); + } + #[test] fn cost_spiral_not_enforced_by_default() { let det = LoopDetector::new(cfg(false, 2.0)); diff --git a/src/budget/mod.rs b/src/budget/mod.rs index 7eb53b0..1a817be 100644 --- a/src/budget/mod.rs +++ b/src/budget/mod.rs @@ -16,16 +16,23 @@ //! overshoot is harmless. //! //! ### Date awareness -//! The tracker is date-agnostic: it just accumulates. The caller (the proxy -//! / a scheduled reset task) tells it when to reset by calling -//! [`BudgetTracker::reset`] at midnight, and the caller picks UTC vs local. +//! The tracker is **day- and month-aware**: it stamps the local calendar day +//! and month at construction/hydration, and on every [`record`](BudgetTracker::record) +//! / [`check`](BudgetTracker::check) it lazily rolls the counter to zero when +//! the local day (or month) has changed since the stamp. This is restart-proof +//! (hydration re-derives the stamp) and clock-change-proof (any date change +//! triggers it) — unlike the old design where the documented `reset()` task was +//! never wired up, so a multi-day daemon accumulated forever and eventually +//! 429'd all traffic against the daily cap (B-C1). -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; + +use chrono::Datelike; pub mod limits; pub mod loop_detector; -pub use limits::{check_daily, check_session, BudgetConfig, BudgetStatus}; +pub use limits::{check_daily, check_monthly, check_session, BudgetConfig, BudgetStatus}; pub use loop_detector::{LoopConfig, LoopDetector, LoopVerdict}; use crate::storage::Storage; @@ -33,8 +40,28 @@ use crate::storage::Storage; /// 1 USD in microcents = 10⁸. const MICROCENTS_PER_USD: f64 = 100_000_000.0; +/// Local calendar day as a monotonic integer (days since CE), for the +/// day-rollover stamp. +fn local_epoch_day() -> i64 { + chrono::Local::now().date_naive().num_days_from_ce() as i64 +} + +/// Local calendar month as a monotonic integer (`year*12 + month0`), for the +/// month-rollover stamp. +fn local_epoch_month() -> i64 { + let d = chrono::Local::now().date_naive(); + (d.year() as i64) * 12 + (d.month0() as i64) +} + pub struct BudgetTracker { today_microcents: AtomicU64, + /// Month-to-date spend (microcents) for the monthly cap (B-H2). + month_microcents: AtomicU64, + /// Local calendar day the `today_microcents` counter belongs to. When the + /// current local day differs, the counter is reset before use. + day_stamp: AtomicI64, + /// Local calendar month the `month_microcents` counter belongs to. + month_stamp: AtomicI64, /// Per-session/swarm spend (microcents), keyed on the opt-in /// `x-burnwall-session` header. Only populated when a session id is present. session_microcents: dashmap::DashMap, @@ -45,6 +72,9 @@ impl BudgetTracker { pub fn new(config: BudgetConfig) -> Self { Self { today_microcents: AtomicU64::new(0), + month_microcents: AtomicU64::new(0), + day_stamp: AtomicI64::new(local_epoch_day()), + month_stamp: AtomicI64::new(local_epoch_month()), session_microcents: dashmap::DashMap::new(), config, } @@ -58,19 +88,57 @@ impl BudgetTracker { &self.config } - /// Current accumulated spend in USD. + /// Current accumulated spend in USD (after a lazy day-rollover). pub fn today_spent(&self) -> f64 { + self.roll_if_new_period(); (self.today_microcents.load(Ordering::Relaxed) as f64) / MICROCENTS_PER_USD } - /// Add a request's cost to the counter. Lock-free. + /// Month-to-date accumulated spend in USD (after a lazy month-rollover). + pub fn month_spent(&self) -> f64 { + self.roll_if_new_period(); + (self.month_microcents.load(Ordering::Relaxed) as f64) / MICROCENTS_PER_USD + } + + /// Reset the daily and/or monthly counters if the local calendar day or + /// month has advanced past the stamp. Lazy and idempotent: the first caller + /// to observe the new period wins the compare-and-swap and zeroes the + /// counter; concurrent callers see the already-swapped stamp and skip. + /// At a true midnight rollover the new period's storage spend is ~0, so a + /// reset-to-zero is correct without re-reading storage. + fn roll_if_new_period(&self) { + let today = local_epoch_day(); + let stamped_day = self.day_stamp.load(Ordering::Relaxed); + if today != stamped_day + && self + .day_stamp + .compare_exchange(stamped_day, today, Ordering::SeqCst, Ordering::Relaxed) + .is_ok() + { + self.today_microcents.store(0, Ordering::Relaxed); + } + let month = local_epoch_month(); + let stamped_month = self.month_stamp.load(Ordering::Relaxed); + if month != stamped_month + && self + .month_stamp + .compare_exchange(stamped_month, month, Ordering::SeqCst, Ordering::Relaxed) + .is_ok() + { + self.month_microcents.store(0, Ordering::Relaxed); + } + } + + /// Add a request's cost to the day + month counters. Lock-free. /// Negative inputs are clamped to zero — costs are always non-negative. pub fn record(&self, cost_usd: f64) { if !cost_usd.is_finite() || cost_usd <= 0.0 { return; } + self.roll_if_new_period(); let units = (cost_usd * MICROCENTS_PER_USD).round() as u64; self.today_microcents.fetch_add(units, Ordering::Relaxed); + self.month_microcents.fetch_add(units, Ordering::Relaxed); } /// Classify the current state against the configured daily limit. @@ -78,6 +146,11 @@ impl BudgetTracker { check_daily(self.today_spent(), &self.config) } + /// Classify month-to-date spend against the configured monthly limit. + pub fn check_monthly(&self) -> BudgetStatus { + check_monthly(self.month_spent(), &self.config) + } + /// Add a request's cost to a session/swarm counter (keyed on the opt-in /// `x-burnwall-session` header). No-op when per-session capping is off. pub fn record_session(&self, session: &str, cost_usd: f64) { @@ -102,20 +175,38 @@ impl BudgetTracker { check_session(self.session_spent(session), &self.config) } - /// Zero the counter — call at midnight (caller decides UTC vs local). + /// Zero the daily counter and re-stamp to the current local day. Normally + /// the lazy [`roll_if_new_period`](Self::roll_if_new_period) handles + /// rollover; this is kept for explicit resets and tests. pub fn reset(&self) { self.today_microcents.store(0, Ordering::Relaxed); + self.day_stamp.store(local_epoch_day(), Ordering::Relaxed); } /// Load today's spend from storage into the counter on startup, so - /// restarting Burnwall mid-day doesn't reset the budget to zero. + /// restarting Burnwall mid-day doesn't reset the budget to zero. Stamps the + /// counter with the **current** local day so the lazy rollover fires at the + /// next local-day change (production always hydrates today's date; the + /// counter reflects "now", not the queried date). /// - /// `date` is a `YYYY-MM-DD` string; the caller decides whether that's - /// UTC or local. Replaces (not adds to) the existing counter value. + /// `date` is a `YYYY-MM-DD` string. Replaces (not adds to) the existing + /// counter value. pub fn hydrate_for_date(&self, storage: &Storage, date: &str) -> crate::storage::Result<()> { let spent = storage.total_cost_for_date(date)?; let units = (spent * MICROCENTS_PER_USD).round() as u64; self.today_microcents.store(units, Ordering::Relaxed); + self.day_stamp.store(local_epoch_day(), Ordering::Relaxed); + Ok(()) + } + + /// Load month-to-date spend from storage into the monthly counter on + /// startup. `month` is a `YYYY-MM` string (local). Stamps the current local + /// month so the lazy rollover fires at the next local-month change. + pub fn hydrate_for_month(&self, storage: &Storage, month: &str) -> crate::storage::Result<()> { + let spent = storage.total_cost_for_month(month)?; + let units = (spent * MICROCENTS_PER_USD).round() as u64; + self.month_microcents.store(units, Ordering::Relaxed); + self.month_stamp.store(local_epoch_month(), Ordering::Relaxed); Ok(()) } } diff --git a/src/config/types.rs b/src/config/types.rs index 29c3c0c..5f1c8eb 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -119,6 +119,14 @@ pub struct BudgetConfig { /// same session id share one blast-radius ceiling. #[serde(default)] pub per_session: f64, + /// Enforce the dollar caps on subscription (flat-rate plan) traffic too. + /// Off by default — a Claude Pro/Max session authenticates with an OAuth + /// token and is not metered per token, so the calculated dollar figure is + /// notional. Burnwall still *tracks* and *warns* on plan traffic, but does + /// not 429-block it on the dollar cap unless this is `true` (B-H4). Metered + /// API-key traffic is always enforced. + #[serde(default)] + pub enforce_on_plan: bool, } impl Default for BudgetConfig { @@ -128,6 +136,7 @@ impl Default for BudgetConfig { monthly: 0.0, warn_percent: 80, per_session: 0.0, + enforce_on_plan: false, } } } @@ -398,6 +407,7 @@ impl From<&BudgetConfig> for crate::budget::BudgetConfig { monthly_usd: c.monthly, warn_percent: c.warn_percent, per_session_usd: c.per_session, + enforce_on_plan: c.enforce_on_plan, } } } @@ -408,12 +418,14 @@ impl From<&SecurityConfig> for crate::security::Ruleset { fn from(c: &SecurityConfig) -> Self { Self { enabled: c.enabled, - deny_paths: c.deny_paths.clone(), + // Filter blank rules: a hand-edited config with an empty entry + // would otherwise match every leaf and block all traffic (S-H8). + deny_paths: crate::security::rules::non_empty_rules(c.deny_paths.clone()), // `allow_paths` is project-profile-only — the global config has // no allow list. A discovered `.burnwall.yaml` merges into this // afterwards (see `cli::start`). allow_paths: Vec::new(), - deny_commands: c.deny_commands.clone(), + deny_commands: crate::security::rules::non_empty_rules(c.deny_commands.clone()), block_network_mounts: c.block_network_mounts, detect_secrets: c.detect_secrets, detect_egress: c.dlp, @@ -447,18 +459,15 @@ impl ResilienceConfig { } /// Convert the persistent loop_detection block into the runtime -/// [`crate::budget::LoopConfig`]. `hash_prefix_bytes` keeps its built-in -/// default (200) — we don't expose it as a TOML knob in v0.2. +/// [`crate::budget::LoopConfig`]. impl From<&LoopDetectionConfig> for crate::budget::LoopConfig { fn from(c: &LoopDetectionConfig) -> Self { - let defaults = crate::budget::LoopConfig::default(); Self { enabled: c.enabled, max_identical_requests: c.max_identical_requests, window_seconds: c.window_seconds, max_cost_per_window: c.max_cost_per_window, cost_spiral_enforce: c.cost_spiral_enforce, - hash_prefix_bytes: defaults.hash_prefix_bytes, } } } diff --git a/tests/integration/budget_test.rs b/tests/integration/budget_test.rs index 8a30644..7c66438 100644 --- a/tests/integration/budget_test.rs +++ b/tests/integration/budget_test.rs @@ -21,6 +21,7 @@ fn cfg(daily: f64, warn: u8) -> BudgetConfig { monthly_usd: 0.0, warn_percent: warn, per_session_usd: 0.0, + enforce_on_plan: false, } } @@ -30,11 +31,69 @@ fn cfg_session(per_session: f64) -> BudgetConfig { monthly_usd: 0.0, warn_percent: 80, per_session_usd: per_session, + enforce_on_plan: false, } } const EPS: f64 = 1e-9; +// ───────────────────── Monthly cap (B-H2) ───────────────────── + +fn cfg_monthly(monthly: f64) -> BudgetConfig { + BudgetConfig { + daily_usd: 0.0, // unlimited daily, isolate the monthly check + monthly_usd: monthly, + warn_percent: 80, + per_session_usd: 0.0, + enforce_on_plan: false, + } +} + +#[test] +fn monthly_cap_unlimited_when_zero() { + let t = BudgetTracker::new(cfg_monthly(0.0)); + t.record(1_000.0); + assert!(matches!(t.check_monthly(), BudgetStatus::Ok)); +} + +#[test] +fn monthly_cap_blocks_when_exceeded() { + let t = BudgetTracker::new(cfg_monthly(100.0)); + t.record(99.0); + assert!(matches!(t.check_monthly(), BudgetStatus::Ok)); + t.record(2.0); // 101 > 100 + assert!( + matches!(t.check_monthly(), BudgetStatus::Exceeded { .. }), + "monthly cap should block once exceeded" + ); + // The daily check is independent and unlimited here. + assert!(matches!(t.check(), BudgetStatus::Ok)); +} + +#[test] +fn record_accumulates_into_both_day_and_month() { + let t = BudgetTracker::new(BudgetConfig { + daily_usd: 0.0, + monthly_usd: 0.0, + warn_percent: 80, + per_session_usd: 0.0, + enforce_on_plan: false, + }); + t.record(3.0); + t.record(4.0); + assert!((t.today_spent() - 7.0).abs() < EPS); + assert!((t.month_spent() - 7.0).abs() < EPS); +} + +#[test] +fn reset_zeroes_day_but_not_month() { + let t = BudgetTracker::new(cfg_monthly(0.0)); + t.record(5.0); + t.reset(); + assert!((t.today_spent()).abs() < EPS, "daily reset to zero"); + assert!((t.month_spent() - 5.0).abs() < EPS, "month untouched by daily reset"); +} + // ───────────────────────────── Pure check ───────────────────────────── #[test] @@ -262,10 +321,14 @@ fn loop_cfg(max_identical: u32, window: u32, max_cost: f64) -> LoopConfig { window_seconds: window, max_cost_per_window: max_cost, cost_spiral_enforce: false, - hash_prefix_bytes: 200, } } +/// Hash a body with the standard method/provider/path context. +fn lh(det: &LoopDetector, body: &[u8]) -> u64 { + det.hash("POST", "anthropic", "/v1/messages", body) +} + #[test] fn loop_detector_passes_unique_requests() { let det = LoopDetector::new(loop_cfg(3, 60, 1000.0)); @@ -275,43 +338,62 @@ fn loop_detector_passes_unique_requests() { b"third body".as_slice(), ]; for body in &bodies { - let h = det.hash(body); + let h = lh(&det, body); assert_eq!(det.check_request(h), LoopVerdict::Ok); + det.record_arrival(h); } } #[test] -fn loop_detector_blocks_on_nth_identical_request() { - // max_identical_requests = 3 -> the 3rd identical request triggers the block. +fn loop_detector_blocks_after_max_identical_successes() { + // peek/record model: N identical *successful* requests are tolerated (each + // recorded by the tee on a 2xx); the next identical peek blocks. let det = LoopDetector::new(loop_cfg(3, 60, 0.0)); - let body = b"identical body"; - let h = det.hash(body); + let h = lh(&det, b"identical body"); - assert_eq!(det.check_request(h), LoopVerdict::Ok, "1st should pass"); - assert_eq!(det.check_request(h), LoopVerdict::Ok, "2nd should pass"); + for _ in 0..3 { + assert_eq!(det.check_request(h), LoopVerdict::Ok); + det.record_arrival(h); + } let v = det.check_request(h); assert!( matches!(v, LoopVerdict::Repeated { count: 3, .. }), - "3rd should block, got {:?}", + "should block once 3 successes are recorded, got {:?}", v ); } #[test] -fn loop_detector_hashes_only_prefix_bytes() { - // Same prefix (200 bytes by default), different suffix -> same hash. +fn loop_detector_check_is_read_only() { + // The death-spiral regression (B-C2): check_request never records, so a + // client hammering a 429 can't keep its own window full. + let det = LoopDetector::new(loop_cfg(3, 60, 0.0)); + let h = lh(&det, b"retry body"); + for _ in 0..50 { + assert_eq!(det.check_request(h), LoopVerdict::Ok); + } +} + +#[test] +fn loop_detector_hashes_full_body() { + // Same long prefix, different suffix -> DIFFERENT hash. Agentic clients + // resend the whole (growing) transcript every turn, so a shared prefix + // is normal session traffic, not a loop — only byte-identical bodies + // may collide. let mut a = vec![b'A'; 200]; let mut b = a.clone(); a.extend_from_slice(b"-different-suffix-A"); b.extend_from_slice(b"-different-suffix-B"); let det = LoopDetector::with_defaults(); - assert_eq!(det.hash(&a), det.hash(&b)); + assert_ne!(lh(&det, &a), lh(&det, &b)); + + // Identical bodies -> identical hash. + assert_eq!(lh(&det, &a), lh(&det, &a.clone())); - // Different first 200 bytes -> different hash. - let mut c = vec![b'A'; 200]; + // Different content -> different hash. + let c = vec![b'A'; 200]; let d = vec![b'B'; 200]; - c[0] = b'X'; - assert_ne!(det.hash(&c), det.hash(&d)); + assert_ne!(lh(&det, &c), lh(&det, &d)); } #[test] @@ -320,20 +402,24 @@ fn loop_detector_disabled_returns_ok() { enabled: false, ..loop_cfg(1, 60, 1.0) // would block immediately if enabled }); - let h = det.hash(b"any"); - assert_eq!(det.check_request(h), LoopVerdict::Ok); + let h = lh(&det, b"any"); + det.record_arrival(h); + det.record_arrival(h); assert_eq!(det.check_request(h), LoopVerdict::Ok); } #[test] fn loop_detector_independent_hashes_dont_cross_count() { let det = LoopDetector::new(loop_cfg(2, 60, 0.0)); - let h1 = det.hash(b"body one"); - let h2 = det.hash(b"body two"); + let h1 = lh(&det, b"body one"); + let h2 = lh(&det, b"body two"); + // Record one arrival under each — neither reaches the cap of 2. + det.record_arrival(h1); + det.record_arrival(h2); assert_eq!(det.check_request(h1), LoopVerdict::Ok); - assert_eq!(det.check_request(h2), LoopVerdict::Ok); - // Each hash now has count=1, neither should block. + // A second success under h2 brings it to the cap; the next peek blocks. + det.record_arrival(h2); let v = det.check_request(h2); assert!(matches!(v, LoopVerdict::Repeated { count: 2, .. })); } @@ -375,10 +461,10 @@ fn current_window_cost_excludes_expired_entries() { #[test] fn loop_detector_safe_under_concurrent_writers() { - // 8 threads pounding the same hash. Set max_identical=1 so every call - // returns Repeated{count}, letting us verify no increments are lost. + // 8 threads recording arrivals under the same hash; verify no increments + // are lost under contention (the arrival path is what mutates state now). let det = Arc::new(LoopDetector::new(loop_cfg(1, 60, 0.0))); - let h = det.hash(b"shared body"); + let h = lh(&det, b"shared body"); let threads = 8; let per_thread = 1000; let mut handles = Vec::with_capacity(threads); @@ -386,19 +472,18 @@ fn loop_detector_safe_under_concurrent_writers() { let d = det.clone(); handles.push(std::thread::spawn(move || { for _ in 0..per_thread { - let _ = d.check_request(h); + d.record_arrival(h); } })); } - for h in handles { - h.join().unwrap(); + for handle in handles { + handle.join().unwrap(); } - let final_verdict = det.check_request(h); - let final_count = match final_verdict { + let final_count = match det.check_request(h) { LoopVerdict::Repeated { count, .. } => count, v => panic!("expected Repeated, got {:?}", v), }; - let expected = (threads * per_thread + 1) as u32; + let expected = (threads * per_thread) as u32; assert_eq!(final_count, expected, "lost increments under contention"); } diff --git a/tests/unit/project_profile_test.rs b/tests/unit/project_profile_test.rs index b50d46c..79ebb54 100644 --- a/tests/unit/project_profile_test.rs +++ b/tests/unit/project_profile_test.rs @@ -185,6 +185,7 @@ fn budget(daily: f64) -> BudgetConfig { monthly_usd: 0.0, warn_percent: 80, per_session_usd: 0.0, + enforce_on_plan: false, } } From f039508f566dd3ace0f20a6300aa78575d0a134e Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:38:06 -0400 Subject: [PATCH 3/9] security: scope-by-role scanning + false-positive fixes Editor/content tool args and tool_results now get data checks only (a Write or note that merely mentions ~/.ssh no longer 403s); shell tools keep full command checks, path-shaped content args still path-checked. UNC match requires a real share root (escaped Windows paths pass) with WSL/device whitelist; rm literals dropped in favor of the shape detector (scoped deletes pass), tokenizer splits JSON-glued tokens; AWS example keys exempted; match location surfaced. Responses-API input[] round scoping. /Volumes/ dropped. Adds sk-proj-/ASIA/gh[pousr]_/glpat- patterns. Fail-open scan now logs. --- src/security/destructive.rs | 14 +- src/security/mod.rs | 61 +++++++- src/security/rules.rs | 118 +++++++++++++-- src/security/scanner.rs | 225 +++++++++++++++++++++++++++-- src/security/secrets.rs | 43 +++++- tests/integration/security_test.rs | 52 +++++-- 6 files changed, 462 insertions(+), 51 deletions(-) diff --git a/src/security/destructive.rs b/src/security/destructive.rs index 2c51ef2..cc68cf1 100644 --- a/src/security/destructive.rs +++ b/src/security/destructive.rs @@ -87,10 +87,20 @@ fn contains_flag(lower: &str, flag: char) -> bool { false } -/// Split a command line into tokens on whitespace and shell separators. +/// Split a command line into tokens on whitespace, shell separators, and JSON +/// punctuation. The JSON delimiters (`"' {}:,`) matter because tool-call +/// arguments often arrive as a JSON-encoded string, so the command appears as +/// `{"command":"rm -rf /"}` — without splitting on the quote/brace the `rm` +/// token would be `{"command":"rm` and the recursive-delete check would miss +/// it (the gap exposed when the literal `rm -rf /` deny rule was dropped, S-C2). +/// We deliberately do NOT split on `/` so path targets stay intact (`./build` +/// must remain one token so a scoped delete isn't flagged). fn tokens(lower: &str) -> impl Iterator { lower - .split(|c: char| c.is_whitespace() || c == ';' || c == '|' || c == '&') + .split(|c: char| { + c.is_whitespace() + || matches!(c, ';' | '|' | '&' | '"' | '\'' | '{' | '}' | ':' | ',' | '(' | ')') + }) .filter(|t| !t.is_empty()) } diff --git a/src/security/mod.rs b/src/security/mod.rs index 2b92107..6808a9f 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -61,6 +61,32 @@ impl ViolationKind { } } +/// Where in the request body the matching leaf sat. Decisive for the +/// false-positive judgment (S-C3): a hit "in the current tool call" is an +/// action the model is taking now; a hit "in earlier conversation history" is +/// almost always the model quoting/discussing something. The block message +/// surfaces this so the user can tell the two apart. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MatchLocation { + /// In the current in-flight tool call's arguments. + ToolCall, + /// In earlier conversation history (a prior turn the client resent). + History, + /// Elsewhere in the request body (system prompt, chat text, tool defs, + /// or non-shell tool content like a file being written). + Body, +} + +impl MatchLocation { + pub fn describe(&self) -> &'static str { + match self { + MatchLocation::ToolCall => "in the current tool call", + MatchLocation::History => "in earlier conversation history", + MatchLocation::Body => "in the request body", + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Violation { pub kind: ViolationKind, @@ -68,6 +94,8 @@ pub struct Violation { /// secret pattern name) — NOT the matched value, which can contain the /// secret itself. pub matched: String, + /// Where the matching leaf sat in the payload. + pub location: MatchLocation, } impl Violation { @@ -141,6 +169,14 @@ impl SecurityEngine { scanner::scan_request(&json, &self.rules) } + /// Scan an MCP JSON-RPC body. Like [`scan_request`] but for the JSON-RPC + /// envelope: only `tools/call` `params.arguments` get command-shaped checks; + /// the rest is prose (data checks only). See [`scanner::scan_mcp`]. + pub fn scan_mcp(&self, body: &[u8]) -> Option { + let json = self.parse_for_scan(body)?; + scanner::scan_mcp(&json, &self.rules) + } + fn parse_for_scan(&self, body: &[u8]) -> Option { // Master switch — `security.enabled = false` forwards without scanning. if !self.rules.enabled { @@ -151,6 +187,29 @@ impl SecurityEngine { // the fail-open path. Real clients never emit a BOM; this is // defense-in-depth. let body = body.strip_prefix(b"\xef\xbb\xbf").unwrap_or(body); - serde_json::from_slice(body).ok() + match serde_json::from_slice(body) { + Ok(v) => Some(v), + Err(_) => { + // Fail-open, but NOT silently (S-M9): a body the scanner can't + // parse is a body it can't inspect. An empty body is a normal + // GET; a non-empty unparseable one (e.g. an encoding we don't + // handle) is the kind of blind spot that hid the cost-tracking + // outage. Count it and warn periodically rather than never. + if !body.is_empty() { + let n = UNSCANNED_BODIES.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + if n == 1 || n.is_multiple_of(100) { + tracing::warn!( + "security scan skipped: request body #{n} is not parseable JSON ({} bytes) — forwarded unscanned", + body.len() + ); + } + } + None + } + } } } + +/// Count of request bodies the scanner could not parse (and therefore could not +/// inspect). Process-local; surfaced in the periodic warn above. +pub static UNSCANNED_BODIES: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); diff --git a/src/security/rules.rs b/src/security/rules.rs index 8580c28..d14f323 100644 --- a/src/security/rules.rs +++ b/src/security/rules.rs @@ -77,14 +77,20 @@ pub const DEFAULT_DENY_PATHS: &[&str] = &[ "/etc/shadow", ]; -pub const DEFAULT_DENY_COMMANDS: &[&str] = &["rm -rf /", "rm -rf ~", "chmod 777", ":(){ :|:& };:"]; +// `rm -rf /` and `rm -rf ~` are deliberately NOT listed here: substring +// matching made `rm -rf /tmp/build-cache` and `rm -rf ~/.cache/pip` — everyday +// cleanup — read as the catastrophic literal (S-C2). The shape-aware +// `super::destructive` detector (always on for tool args) owns recursive-force +// deletes and only fires on broad/expandable targets, so scoped deletes pass. +pub const DEFAULT_DENY_COMMANDS: &[&str] = &["chmod 777", ":(){ :|:& };:"]; -pub const NETWORK_MOUNT_NEEDLES: &[&str] = &[ - "/Volumes/", - r"\\", // Windows UNC prefix (two backslashes) - "smb://", - "nfs://", -]; +// Substring needles for genuine network-mount URI schemes. The Windows UNC +// prefix (`\\`) is matched separately by [`is_unc_mount`] (a bare-substring +// `\\` fired on every JSON-escaped Windows path — S-C1). `/Volumes/` was +// dropped (S-H7): it is where macOS mounts local USB drives, DMGs, and Time +// Machine, not specifically network shares, so a repo on an external SSD had +// every tool call blocked. +pub const NETWORK_MOUNT_NEEDLES: &[&str] = &["smb://", "nfs://", "cifs://", "afp://"]; /// Does `value` reference a denied path? /// @@ -124,12 +130,67 @@ fn collapse_ws(s: &str) -> String { } pub fn mount_matches(value: &str) -> bool { - // Case-fold only — do NOT unify separators here, or the UNC `\\` needle - // would collide with `//` in ordinary URLs (e.g. `https://...`). let hay = value.to_ascii_lowercase(); NETWORK_MOUNT_NEEDLES .iter() - .any(|needle| hay.contains(&needle.to_ascii_lowercase())) + .any(|needle| hay.contains(needle)) + || is_unc_mount(value) +} + +/// True when `value` contains a Windows **UNC network-share root** — `\\` at a +/// token boundary followed by a hostname-ish character. This deliberately does +/// NOT match a bare `\\` substring: JSON-escaped Windows paths decode to a leaf +/// like `C:\\Users\\me` (and OpenAI/Codex tool arguments are a JSON-encoded +/// string, so `{"path":"C:\\\\Users"}` decodes to `C:\\Users`), which contains +/// `\\` mid-token — not a network mount (S-C1). Local device namespaces +/// (`\\?\`, `\\.\`) and WSL (`\\wsl$`, `\\wsl.localhost`) are whitelisted: they +/// are local, not network. +pub fn is_unc_mount(value: &str) -> bool { + let bytes = value.as_bytes(); + let mut i = 0; + while i + 1 < bytes.len() { + if bytes[i] == b'\\' && bytes[i + 1] == b'\\' { + let at_boundary = i == 0 + || matches!( + bytes[i - 1], + b' ' | b'\t' | b'\n' | b'\r' | b'"' | b'\'' | b'=' | b'(' | b',' | b':' + ); + // `:` allows `path:\\server\share`-style prefixes but the doubled + // backslash in a drive path (`C:\\Users`) has the `\\` preceded by + // `:`? No — there it's `C` `:` `\` `\`, so the byte before `\\` is + // `:`. Guard that: a single drive letter + colon before `\\` is a + // local drive path, not UNC. + let drive_path = i >= 2 && bytes[i - 1] == b':' && (bytes[i - 2] as char).is_ascii_alphabetic(); + if at_boundary && !drive_path { + let rest = &value[i + 2..]; + let rest_lower = rest.to_ascii_lowercase(); + let local = rest.starts_with('?') + || rest.starts_with('.') + || rest_lower.starts_with("wsl$") + || rest_lower.starts_with("wsl.localhost"); + let hostnameish = rest + .chars() + .next() + .map(|c| c.is_ascii_alphanumeric()) + .unwrap_or(false); + if !local && hostnameish { + return true; + } + } + } + i += 1; + } + false +} + +/// Drop empty / whitespace-only rules. A blank deny rule makes `contains("")` +/// true for every leaf, blocking 100% of traffic (S-H8); filter it at ruleset +/// construction so a hand-edited config or installed pack can't brick the proxy. +pub fn non_empty_rules>(rules: I) -> Vec { + rules + .into_iter() + .filter(|r| !r.trim().is_empty()) + .collect() } /// Lowercase and unify path separators (`\` → `/`) for case- and @@ -182,11 +243,42 @@ mod tests { } #[test] - fn mount_matches_case_insensitive_without_url_false_positive() { - assert!(mount_matches("/VOLUMES/backup/secrets")); - assert!(mount_matches("\\\\server\\share")); + fn mount_matches_real_network_schemes_and_unc_only() { + assert!(mount_matches("\\\\server\\share")); // genuine UNC root assert!(mount_matches("SMB://host/share")); + assert!(mount_matches("nfs://host/export")); // A plain https URL must not be flagged as a UNC mount. assert!(!mount_matches("https://api.anthropic.com/v1/messages")); + // /Volumes/ is local on macOS (USB/DMG/Time Machine) — no longer flagged. + assert!(!mount_matches("/Volumes/T7/code/project")); + } + + #[test] + fn unc_match_ignores_escaped_windows_paths() { + // S-C1: the regression that blocked every Codex tool call and every + // file write containing a Windows path. + // A drive path with a doubled (JSON-escaped) backslash is NOT a mount. + assert!(!is_unc_mount(r"C:\\Users\\me\\project")); + assert!(!is_unc_mount(r#"{"path":"C:\\Users\\me"}"#)); + // Local device namespaces and WSL are local, not network. + assert!(!is_unc_mount(r"\\?\C:\very\long\path")); + assert!(!is_unc_mount(r"\\.\PhysicalDrive0")); + assert!(!is_unc_mount(r"\\wsl$\Ubuntu\home\me")); + assert!(!is_unc_mount(r"\\wsl.localhost\Ubuntu\home")); + // A genuine UNC share root IS a mount. + assert!(is_unc_mount(r"\\fileserver\share\secret")); + assert!(is_unc_mount(r#"{"path":"\\fileserver\share"}"#)); + } + + #[test] + fn non_empty_rules_drops_blanks() { + // S-H8: a blank deny rule would match every leaf. + let filtered = non_empty_rules(vec![ + "rm -rf /".to_string(), + "".to_string(), + " ".to_string(), + "chmod 777".to_string(), + ]); + assert_eq!(filtered, vec!["rm -rf /".to_string(), "chmod 777".to_string()]); } } diff --git a/src/security/scanner.rs b/src/security/scanner.rs index 819c0b9..8765bf4 100644 --- a/src/security/scanner.rs +++ b/src/security/scanner.rs @@ -29,16 +29,26 @@ use serde_json::Value; use super::rules::{self, Ruleset}; use super::secrets; -use super::{Violation, ViolationKind}; +use super::{MatchLocation, Violation, ViolationKind}; /// Which checks apply to a string leaf, by where it sits in the payload. #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum Scope { - /// Inside a tool-call argument subtree → full check set. + /// Inside a **shell-ish** tool-call argument subtree (bash/exec/run/…) → + /// full check set. The tool is one that runs a command, so its arguments + /// are commands. ToolArgs, + /// Inside an **editor/content** tool-call argument subtree (Write, Edit, + /// apply_patch, …) → data checks only (secrets, DLP). The argument is file + /// *content* the model is writing, not a command to run — a README that + /// mentions `~/.ssh` or a runbook that mentions `chmod 777` must not 403 + /// (S-H4: the class that blocked this very review session). A secret or + /// card number in that content is still worth catching, so data checks + /// stay on. + ContentArgs, /// Anywhere else (system prompt, chat text, tool definitions, tool /// results) → data checks only (secrets, DLP). Tool-call shapes found - /// here promote their subtree to [`Scope::ToolArgs`]. + /// here promote their subtree to [`Scope::ToolArgs`] / [`Scope::ContentArgs`]. Prose, /// An already-adjudicated conversation turn → data checks only, and /// tool-call shapes do NOT promote. See [`walk_turn_array`]. @@ -55,6 +65,42 @@ pub fn scan_request(value: &Value, rules: &Ruleset) -> Option { walk(value, rules, Scope::Prose) } +/// Context-aware scan for an MCP JSON-RPC body (M-C1). The envelope +/// (`jsonrpc`/`method`/`id` and most of `params`) is **prose** — a memory note +/// or issue title that merely mentions `rm -rf` or `~/.ssh` must not 403. Only +/// the `params.arguments` of a `tools/call` are real tool-call arguments and +/// get the full command set (or content-only checks for an editor-ish tool, +/// keyed on `params.name`). Data checks (secrets, DLP) still run across the +/// whole envelope. Mirrors the prose-safe scoping the LLM proxy already uses — +/// the MCP path was still running the full-strict `scan`. +pub fn scan_mcp(value: &Value, rules: &Ruleset) -> Option { + if value.get("method").and_then(Value::as_str) == Some("tools/call") { + if let Some(params) = value.get("params") { + if let Some(args) = params.get("arguments") { + // MCP tools are overwhelmingly app integrations (memory, search, + // GitHub, …) whose arguments are free text, not commands — so + // the default is data-checks-only (catch credential exfil, the + // real MCP risk). Command-shaped checks apply ONLY when the tool + // name is identifiably a shell/exec tool. This is the inverse of + // the LLM default (where Bash/Read are common and dangerous), and + // is what keeps a memory note that mentions `rm -rf` from 403ing. + let name = params.get("name").and_then(Value::as_str); + let scope = if name.map(is_shell_tool).unwrap_or(false) { + Scope::ToolArgs + } else { + Scope::ContentArgs + }; + if let Some(v) = walk(args, rules, scope) { + return Some(v); + } + } + } + } + // Data checks across the whole envelope; command-shaped checks stay scoped + // to the arguments handled above (prose here, so they don't fire). + walk(value, rules, Scope::Prose) +} + fn walk(value: &Value, rules: &Ruleset, scope: Scope) -> Option { match value { Value::Object(map) => { @@ -62,10 +108,14 @@ fn walk(value: &Value, rules: &Ruleset, scope: Scope) -> Option { // Conversation turn arrays get latest-turn scoping; see // walk_turn_array. Only from Prose — under ToolArgs (full // scan) everything stays strict, and under History nothing - // re-promotes. - if scope == Scope::Prose && (k == "messages" || k == "contents") { + // re-promotes. `input` covers the OpenAI Responses API, whose + // items carry `type` instead of `role` (S-H6). + if scope == Scope::Prose && (k == "messages" || k == "contents" || k == "input") { if let Value::Array(turns) = v { - if turns.iter().any(|t| t.get("role").is_some()) { + if turns + .iter() + .any(|t| t.get("role").is_some() || t.get("type").is_some()) + { if let Some(violation) = walk_turn_array(turns, rules) { return Some(violation); } @@ -75,8 +125,9 @@ fn walk(value: &Value, rules: &Ruleset, scope: Scope) -> Option { } let child_scope = match scope { Scope::ToolArgs => Scope::ToolArgs, - Scope::Prose if holds_tool_args(k, map) => Scope::ToolArgs, - other => other, + Scope::ContentArgs => Scope::ContentArgs, + Scope::Prose => tool_arg_scope(k, map).unwrap_or(Scope::Prose), + Scope::History => Scope::History, }; if let Some(violation) = walk(v, rules, child_scope) { return Some(violation); @@ -129,11 +180,20 @@ fn walk_turn_array(turns: &[Value], rules: &Ruleset) -> Option { None } -/// A turn authored by the model: Anthropic/OpenAI `assistant`, Gemini `model`. +/// A turn authored by the model: Anthropic/OpenAI `assistant`, Gemini `model`, +/// or an OpenAI Responses API `function_call` item (which has no `role`). fn is_actor_turn(turn: &Value) -> bool { - matches!( + if matches!( turn.get("role").and_then(Value::as_str), Some("assistant") | Some("model") + ) { + return true; + } + // Responses API: the model's tool call is a top-level `input` item with + // `type: "function_call"` (or a `*_call` variant) and no role. + matches!( + turn.get("type").and_then(Value::as_str), + Some(t) if t.ends_with("_call") ) } @@ -143,6 +203,14 @@ fn is_actor_turn(turn: &Value) -> bool { /// (Anthropic/Gemini clients may attach extra text alongside the results — /// reminders, environment notes — so one result block is enough to qualify.) fn is_tool_result_turn(turn: &Value) -> bool { + // Responses API: tool output is an `input` item with + // `type: "function_call_output"` and no role. + if matches!( + turn.get("type").and_then(Value::as_str), + Some(t) if t.ends_with("_call_output") + ) { + return true; + } match turn.get("role").and_then(Value::as_str) { Some("tool") => true, Some("user") | Some("function") => { @@ -189,11 +257,128 @@ fn holds_tool_args(key: &str, obj: &serde_json::Map) -> bool { } } +/// If `key` (an entry of `obj`) holds tool-call arguments, return the scope its +/// subtree should get — [`Scope::ToolArgs`] for a shell-ish tool (its args are +/// commands) or [`Scope::ContentArgs`] for an editor/content tool (its args are +/// file content, S-H4). Unknown tool names default to strict `ToolArgs` so an +/// unrecognized tool keeps full coverage. Returns `None` if `key` isn't a +/// tool-args slot. +fn tool_arg_scope(key: &str, obj: &serde_json::Map) -> Option { + if !holds_tool_args(key, obj) { + return None; + } + let name = tool_name(obj); + Some(if name.map(is_editor_tool).unwrap_or(false) { + Scope::ContentArgs + } else { + Scope::ToolArgs + }) +} + +/// Best-effort tool name from a tool-call object: the sibling `name` +/// (Anthropic `tool_use`, OpenAI Responses `function_call`, legacy +/// `function_call`) or the nested `function.name` (OpenAI Chat `tool_calls`). +fn tool_name(obj: &serde_json::Map) -> Option<&str> { + obj.get("name") + .and_then(Value::as_str) + .or_else(|| { + obj.get("function") + .and_then(|f| f.get("name")) + .and_then(Value::as_str) + }) +} + +/// Does this tool name denote a shell/exec tool — one whose arguments are a +/// command line? Used for MCP scoping, where the default is data-only and only +/// a recognized shell tool gets full command-shaped checks. +fn is_shell_tool(name: &str) -> bool { + let n = name.to_ascii_lowercase(); + const SHELL_MARKERS: &[&str] = &[ + "bash", + "shell", + "exec", + "terminal", + "powershell", + "run_command", + "run_shell", + "command_exec", + "system_exec", + "shell_command", + ]; + n == "sh" || n == "cmd" || n == "run" || SHELL_MARKERS.iter().any(|m| n.contains(m)) +} + +/// Does this tool name denote an editor/content tool — one whose arguments are +/// file *content* being written, not a command to execute? Conservative: a name +/// we don't recognize stays strict (full command checks). +fn is_editor_tool(name: &str) -> bool { + let n = name.to_ascii_lowercase(); + const EDITOR_MARKERS: &[&str] = &[ + "write", + "edit", // also matches multiedit / str_replace_editor + "str_replace", + "create_file", + "apply_patch", + "notebook", + "new_file", + "save_file", + "update_file", + "insert_edit", + ]; + EDITOR_MARKERS.iter().any(|m| n.contains(m)) +} + fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { + // Where this leaf sits — surfaced in the block message so a user can tell + // a real action from the model quoting something (S-C3). + let location = match scope { + Scope::ToolArgs | Scope::ContentArgs => MatchLocation::ToolCall, + Scope::Prose => MatchLocation::Body, + Scope::History => MatchLocation::History, + }; + // Which checks run where: + // - Command/destructive/exfil checks: ONLY shell-ish tool args — a command + // is only dangerous where it will be executed. + // - Path/mount checks: shell tool args, plus *path-shaped* leaves of + // content/editor tools — `read_file {"path": "~/.ssh/id_rsa"}` must block + // even though read_file is not a shell. A path-shaped leaf is short and + // single-line; a file body or note being written is neither, so a README + // that mentions `~/.ssh` in its prose passes (S-H4) while a path argument + // pointing AT `~/.ssh` blocks. + // - Prose and history: data checks only. + let command_set = scope == Scope::ToolArgs; + let path_set = command_set || (scope == Scope::ContentArgs && path_shaped(s)); + + if path_set && !command_set { + // Path/mount checks for a content-tool's path-shaped argument. + let path_allowed = rules + .allow_paths + .iter() + .any(|allow| rules::path_matches(s, allow)); + if !path_allowed { + for rule in &rules.deny_paths { + if rules::path_matches(s, rule) { + return Some(Violation { + kind: ViolationKind::Path, + matched: rule.clone(), + location, + }); + } + } + } + if rules.block_network_mounts && rules::mount_matches(s) { + return Some(Violation { + kind: ViolationKind::Mount, + matched: extract_mount_prefix(s).to_string(), + location, + }); + } + } + // Order: paths → commands → mounts → secrets. Paths are the highest- // signal category; secrets last so a path-blocked SSH key dump doesn't // also accidentally trip the private-key regex. - if scope == Scope::ToolArgs { + if command_set { // A leaf matching a project `allow_paths` exception skips the path-deny // checks entirely — but command, mount, and secret checks below still // run, so `allow_paths` can never green-light a dangerous command. @@ -207,6 +392,7 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { return Some(Violation { kind: ViolationKind::Path, matched: rule.clone(), + location, }); } } @@ -216,6 +402,7 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { return Some(Violation { kind: ViolationKind::Command, matched: rule.clone(), + location, }); } } @@ -227,12 +414,14 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { return Some(Violation { kind: ViolationKind::Destructive, matched: label.to_string(), + location, }); } if rules.block_network_mounts && rules::mount_matches(s) { return Some(Violation { kind: ViolationKind::Mount, matched: extract_mount_prefix(s).to_string(), + location, }); } } @@ -243,6 +432,7 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { return Some(Violation { kind: ViolationKind::Secret, matched: name.to_string(), + location, }); } // Pack-contributed patterns are additive (extra detection). Cap the @@ -255,6 +445,7 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { return Some(Violation { kind: ViolationKind::Secret, matched: name.to_string(), + location, }); } } @@ -266,11 +457,12 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { // Technique-shaped exfil (DNS exfil, secret→network) first — highest // signal and names the technique, not the data. Command-shaped, so // tool-args only. - if scope == Scope::ToolArgs { + if command_set { if let Some(name) = super::exfil::first_match(hay) { return Some(Violation { kind: ViolationKind::Exfil, matched: name.to_string(), + location, }); } } @@ -279,6 +471,7 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { return Some(Violation { kind: ViolationKind::Dlp, matched: name.to_string(), + location, }); } } @@ -290,6 +483,14 @@ fn check_string(s: &str, rules: &Ruleset, scope: Scope) -> Option { /// additive pack scan (invariant I5). const MAX_PACK_SCAN_INPUT: usize = 1024 * 1024; +/// Is this leaf plausibly a *path argument* (as opposed to file content / a +/// note body)? Path arguments are short and single-line; content is long or +/// multi-line. Used to apply path checks to content-tool args without flagging +/// prose that merely mentions a protected path. +fn path_shaped(s: &str) -> bool { + s.len() <= 512 && !s.contains('\n') +} + /// Largest prefix of `s` no longer than `max` bytes that ends on a UTF-8 char /// boundary. Returns `s` unchanged when it already fits. fn capped(s: &str, max: usize) -> &str { diff --git a/src/security/secrets.rs b/src/security/secrets.rs index 97222c6..a2f2204 100644 --- a/src/security/secrets.rs +++ b/src/security/secrets.rs @@ -61,10 +61,18 @@ impl SecretPattern { pub static PATTERNS: LazyLock> = LazyLock::new(|| { vec![ SecretPattern::builtin("AWS access key ID", r"\bAKIA[0-9A-Z]{16}\b"), + // STS temporary access keys (S-M12). + SecretPattern::builtin("AWS temporary access key ID", r"\bASIA[0-9A-Z]{16}\b"), SecretPattern::builtin("private key header", r"-----BEGIN [A-Z ]+PRIVATE KEY-----"), - SecretPattern::builtin("GitHub personal access token", r"\bghp_[A-Za-z0-9]{36}\b"), + // ghp_ (classic), gho_/ghu_/ghs_/ghr_ (OAuth/user/server/refresh) — one + // pattern covers all variants (S-M12). + SecretPattern::builtin("GitHub token", r"\bgh[pousr]_[A-Za-z0-9]{36}\b"), + // Modern OpenAI project keys use `sk-proj-…` with hyphens/underscores, + // which the 48-alnum-run pattern misses (S-M12). + SecretPattern::builtin("OpenAI project key", r"\bsk-proj-[A-Za-z0-9_-]{20,}\b"), SecretPattern::builtin("OpenAI API key", r"\bsk-[A-Za-z0-9]{48}\b"), SecretPattern::builtin("Anthropic API key", r"\bsk-ant-[A-Za-z0-9_-]{36,}\b"), + SecretPattern::builtin("GitLab personal access token", r"\bglpat-[A-Za-z0-9_-]{20,}\b"), SecretPattern::builtin("Slack token", r"\bxox[abprs]-[A-Za-z0-9-]{10,}\b"), // Added v0.6. All keep a distinctive prefix + length so the false- // positive rate stays low; deliberately NO generic-entropy or JWT @@ -87,15 +95,34 @@ pub static PATTERNS: LazyLock> = LazyLock::new(|| { ] }); -/// Name of the first **built-in** pattern that matches `value`, or `None`. +/// Well-known documentation / example credentials that vendors publish for +/// tutorials and that constantly appear in READMEs, fixtures, and SDK docs. +/// Flagging them was a top false-positive: an agent reading a file containing +/// AWS's canonical `AKIAIOSFODNN7EXAMPLE` would 403 every later request in the +/// session (S-C3). A match whose text is exactly one of these is not a secret. +const EXAMPLE_SECRETS: &[&str] = &[ + "AKIAIOSFODNN7EXAMPLE", // AWS docs access key id + "ASIAIOSFODNN7EXAMPLE", // AWS docs STS key id +]; + +fn is_example_secret(matched: &str) -> bool { + EXAMPLE_SECRETS.iter().any(|e| e.eq_ignore_ascii_case(matched)) +} + +/// Name of the first **built-in** pattern that matches `value` with a match +/// that is not a known documentation/example credential, or `None`. pub fn first_match(value: &str) -> Option<&'static str> { - PATTERNS.iter().find(|p| p.regex.is_match(value)).map(|p| { - // Built-ins are always borrowed; this is the &'static name. - match &p.name { - Cow::Borrowed(s) => *s, - Cow::Owned(_) => unreachable!("built-in patterns carry borrowed names"), + for p in PATTERNS.iter() { + // Any non-example match counts; scan all matches so a real key elsewhere + // in the leaf isn't masked by a leading example. + if p.regex.find_iter(value).any(|m| !is_example_secret(m.as_str())) { + return match &p.name { + Cow::Borrowed(s) => Some(*s), + Cow::Owned(_) => unreachable!("built-in patterns carry borrowed names"), + }; } - }) + } + None } /// Name of the first pattern in `patterns` that matches `value`, or `None`. diff --git a/tests/integration/security_test.rs b/tests/integration/security_test.rs index 4d96011..5226035 100644 --- a/tests/integration/security_test.rs +++ b/tests/integration/security_test.rs @@ -97,10 +97,19 @@ fn does_not_match_unrelated_directory_with_ssh_in_name() { #[test] fn matches_rm_rf_root() { + // S-C2: `rm -rf /` is now caught by the shape-aware destructive detector, + // not the literal deny list (which dropped the `rm` literals so scoped + // deletes like `rm -rf /tmp/x` aren't false-flagged). let body = br#"{"x": "rm -rf / --no-preserve-root"}"#; let v = engine().scan(body).expect("violation"); - assert_eq!(v.kind, ViolationKind::Command); - assert_eq!(v.matched, "rm -rf /"); + assert_eq!(v.kind, ViolationKind::Destructive); +} + +#[test] +fn scoped_rm_is_not_blocked() { + // The everyday-cleanup case that the substring rule used to false-block. + let body = br#"{"x": "rm -rf /tmp/build-cache"}"#; + assert!(engine().scan(body).is_none()); } #[test] @@ -124,11 +133,12 @@ fn safe_commands_pass() { // ──────────────────────────── Mount rules ──────────────────────────── #[test] -fn blocks_macos_volumes() { +fn volumes_is_local_not_blocked() { + // S-H7: /Volumes/ is where macOS mounts local USB drives, DMGs, and Time + // Machine — not specifically network shares. A repo on an external SSD + // must not have every tool call blocked. let body = br#"{"x": "cp file /Volumes/external/backup"}"#; - let v = engine().scan(body).expect("violation"); - assert_eq!(v.kind, ViolationKind::Mount); - assert_eq!(v.matched, "/Volumes/"); + assert!(engine().scan(body).is_none()); } #[test] @@ -154,7 +164,7 @@ fn mount_blocking_can_be_disabled() { ..Ruleset::default() }; let engine = SecurityEngine::new(rules); - let body = br#"{"x": "ls /Volumes/disk"}"#; + let body = br#"{"x": "mount smb://fileserver/share"}"#; assert!(engine.scan(body).is_none()); } @@ -162,13 +172,22 @@ fn mount_blocking_can_be_disabled() { #[test] fn detects_aws_access_key_id() { - // Fake but pattern-matching key. - let body = br#"{"x": "export AWS_KEY=AKIAIOSFODNN7EXAMPLE"}"#; + // Fake but pattern-matching key (NOT the canonical docs `…EXAMPLE`, which + // is now exempted under S-C3). + let body = br#"{"x": "export AWS_KEY=AKIAIOSFODNN7REALKEY"}"#; let v = engine().scan(body).expect("violation"); assert_eq!(v.kind, ViolationKind::Secret); assert_eq!(v.matched, "AWS access key ID"); } +#[test] +fn aws_example_key_is_exempt() { + // S-C3: the canonical AWS docs key must not 403 a session that merely read + // a file containing it. + let body = br#"{"x": "export AWS_KEY=AKIAIOSFODNN7EXAMPLE"}"#; + assert!(engine().scan(body).is_none()); +} + #[test] fn detects_private_key_header() { let body = br#"{"x": "config: -----BEGIN OPENSSH PRIVATE KEY-----\nMIIEpAIB..."}"#; @@ -183,7 +202,7 @@ fn detects_github_pat() { let body = br#"{"x": "GITHUB_TOKEN=ghp_AbCdEfGhIjKlMnOpQrStUvWxYz0123456789"}"#; let v = engine().scan(body).expect("violation"); assert_eq!(v.kind, ViolationKind::Secret); - assert_eq!(v.matched, "GitHub personal access token"); + assert_eq!(v.matched, "GitHub token"); } #[test] @@ -282,7 +301,9 @@ fn allow_path_exempts_path_but_not_command() { let engine = SecurityEngine::new(rules); let body = br#"{"x": "cat ~/.aws/creds && rm -rf /"}"#; let v = engine.scan(body).expect("violation"); - assert_eq!(v.kind, ViolationKind::Command); + // The path is exempt, but `rm -rf /` is still caught (now by the + // destructive shape detector — S-C2). + assert_eq!(v.kind, ViolationKind::Destructive); } #[test] @@ -294,7 +315,7 @@ fn allow_path_exempts_path_but_not_secret() { ..Ruleset::default() }; let engine = SecurityEngine::new(rules); - let body = br#"{"x": "dump ~/.aws/creds AKIAIOSFODNN7EXAMPLE"}"#; + let body = br#"{"x": "dump ~/.aws/creds AKIAIOSFODNN7REALKEY"}"#; let v = engine.scan(body).expect("violation"); assert_eq!(v.kind, ViolationKind::Secret); } @@ -561,7 +582,8 @@ fn request_scan_blocks_legacy_function_call_arguments() { let body = br#"{"messages":[{"role":"assistant","function_call":{ "name":"bash","arguments":"{\"command\":\"rm -rf / --no-preserve-root\"}"}}]}"#; let v = engine().scan_request(body).expect("violation"); - assert_eq!(v.kind, ViolationKind::Command); + // `rm -rf /` is now a destructive-shape match (S-C2). + assert_eq!(v.kind, ViolationKind::Destructive); } #[test] @@ -585,7 +607,7 @@ fn request_scan_still_detects_secrets_in_prose() { // Data checks stay global: a credential in chat text is exfiltration- // relevant no matter where it sits. let body = br#"{"messages":[{"role":"user", - "content":"my key is AKIAIOSFODNN7EXAMPLE, is that safe to commit?"}]}"#; + "content":"my key is AKIAIOSFODNN7REALKEY, is that safe to commit?"}]}"#; let v = engine().scan_request(body).expect("violation"); assert_eq!(v.kind, ViolationKind::Secret); } @@ -744,7 +766,7 @@ fn request_scan_secrets_still_caught_in_history() { {"role":"assistant","content":[ {"type":"tool_use","id":"t1","name":"bash","input":{"command":"cat notes.txt"}}]}, {"role":"user","content":[ - {"type":"tool_result","tool_use_id":"t1","content":"key=AKIAIOSFODNN7EXAMPLE"}]}, + {"type":"tool_result","tool_use_id":"t1","content":"key=AKIAIOSFODNN7REALKEY"}]}, {"role":"user","content":"summarize that"}, {"role":"assistant","content":[{"type":"text","text":"It contains a key."}]} ]}"#; From 905f1368959d9b74eedae2ccd06f42e54b6bfdca Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:38:17 -0400 Subject: [PATCH 4/9] proxy: upstream timeouts, disconnect-cancel, self-identifying block messages, $0-recording guards Shared HTTP client with connect/keepalive/read timeouts (no more hangs on VPN flips or stalled streams). Client disconnect now drops the upstream stream instead of draining the full billed generation, recording a 499 partial. Every block self-identifies as Burnwall with what/where, escalating escape hatches, Retry-After, and provider-correct error JSON; new 'burnwall report-bug' writes a sanitized local report. OpenAI Responses API now parses (Codex no longer records $0), all-zero usage treated as a parse failure, and unknown models warn once instead of silently costing $0. Cache-projection write moved off the pre-forward hot path into the tee. --- src/cli/mod.rs | 4 + src/cli/report_bug.rs | 96 ++++++++ src/cli/sidecar.rs | 1 + src/cli/start.rs | 84 +++++-- src/providers/openai.rs | 77 ++++-- src/proxy/forwarding.rs | 154 ++++++++++-- src/proxy/handler.rs | 360 ++++++++++++++++++++++++----- src/proxy/mod.rs | 28 ++- src/proxy/streaming.rs | 30 ++- tests/integration/pipeline_test.rs | 161 +++++++++++-- tests/unit/parser_test.rs | 105 +++++++++ 11 files changed, 968 insertions(+), 132 deletions(-) create mode 100644 src/cli/report_bug.rs diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 728574a..4f7b3f9 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -27,6 +27,7 @@ pub mod metrics; pub mod pricing; #[cfg(feature = "observe")] pub mod report; +pub mod report_bug; pub mod routing; pub mod rules; pub mod savings; @@ -69,6 +70,8 @@ pub enum Command { Init(init::InitArgs), /// Inspect security events (blocked attempts). Security(security::SecurityArgs), + /// Write a sanitized, local bug report of recent blocks (nothing is sent). + ReportBug(report_bug::ReportBugArgs), /// Print a shell-completion script to stdout. Completions(completions::CompletionsArgs), /// Pass-through MCP HTTP proxy that logs tools/call invocations. @@ -140,6 +143,7 @@ impl Cli { Command::Config(args) => config_cmd::run_cmd(args), Command::Init(args) => init::run_cmd(args), Command::Security(args) => security::run_cmd(args), + Command::ReportBug(args) => report_bug::run_cmd(args), Command::Completions(args) => completions::run_cmd(args), #[cfg(feature = "mcp")] Command::McpWatch(args) => mcp_watch::run_cmd(args).await, diff --git a/src/cli/report_bug.rs b/src/cli/report_bug.rs new file mode 100644 index 0000000..a69c78e --- /dev/null +++ b/src/cli/report_bug.rs @@ -0,0 +1,96 @@ +//! `burnwall report-bug` — write a **sanitized, local** report of recent blocks +//! so a user who hit a false positive can file a useful issue. Zero-telemetry: +//! nothing is sent anywhere. The report carries only metadata already in the +//! DB — rule labels (`~/.ssh`, `recursive force delete`), pattern *names* +//! (`AWS access key ID`, never the value), event types, timestamps, and +//! provider/model — plus OS/version. The user reviews the file and attaches it +//! to a GitHub issue themselves. + +use std::sync::Arc; + +use anyhow::Context; +use clap::Args; + +use crate::storage::Storage; + +#[derive(Args, Debug)] +pub struct ReportBugArgs { + /// How many days of recent blocks to include (default 1 = today). + #[arg(long, default_value_t = 1)] + pub days: i64, + /// Print the report to stdout instead of writing a file. + #[arg(long)] + pub stdout: bool, +} + +pub fn run_cmd(args: ReportBugArgs) -> anyhow::Result<()> { + let storage = Arc::new(Storage::open_default().context("opening storage")?); + let events = storage.security_events_since_days(args.days.max(1))?; + + let report = build_report(&events, args.days.max(1)); + + if args.stdout { + print!("{report}"); + return Ok(()); + } + + let dir = crate::storage::data_dir().context("locating data dir")?; + let stamp = chrono::Local::now().format("%Y%m%d-%H%M%S"); + let path = dir.join(format!("bug-report-{stamp}.md")); + std::fs::write(&path, &report).with_context(|| format!("writing {}", path.display()))?; + + let issues = format!("{}/issues/new", env!("CARGO_PKG_REPOSITORY")); + println!("📋 Wrote a sanitized bug report (no payload content, nothing sent):"); + println!(" {}", path.display()); + println!(); + println!(" Review it, then open an issue and attach it:"); + println!(" {issues}"); + println!(); + println!(" If a block was a false positive, mention what you were doing when it fired."); + Ok(()) +} + +fn build_report(events: &[crate::storage::SecurityEvent], days: i64) -> String { + let mut s = String::new(); + s.push_str("# Burnwall bug report\n\n"); + s.push_str(&format!("- Version: {}\n", env!("CARGO_PKG_VERSION"))); + s.push_str(&format!("- OS: {} {}\n", std::env::consts::OS, std::env::consts::ARCH)); + s.push_str(&format!( + "- Generated: {}\n", + chrono::Local::now().format("%Y-%m-%d %H:%M:%S %z") + )); + s.push_str(&format!("- Window: last {days} day(s)\n\n")); + s.push_str( + "> This report contains only metadata (rule labels, pattern names, timestamps).\n\ + > No request/response payloads, secrets, or file contents are included.\n\n", + ); + + s.push_str("## Recent blocks\n\n"); + if events.is_empty() { + s.push_str("(none in this window)\n\n"); + } else { + s.push_str("| Time (local) | Type | Rule / pattern | Provider/Model |\n"); + s.push_str("|---|---|---|---|\n"); + for e in events { + let pm = match (&e.provider, &e.model) { + (Some(p), Some(m)) => format!("{p}/{m}"), + (Some(p), None) => p.clone(), + _ => "-".to_string(), + }; + s.push_str(&format!( + "| {} | {} | {} | {} |\n", + e.timestamp + .with_timezone(&chrono::Local) + .format("%Y-%m-%d %H:%M:%S"), + e.event_type, + e.details.replace('|', "\\|"), + pm, + )); + } + s.push('\n'); + } + + s.push_str("## What I was doing\n\n"); + s.push_str("\n"); + s +} diff --git a/src/cli/sidecar.rs b/src/cli/sidecar.rs index 3f1cdd2..bb5116a 100644 --- a/src/cli/sidecar.rs +++ b/src/cli/sidecar.rs @@ -60,6 +60,7 @@ pub async fn run_cmd(args: SidecarArgs) -> anyhow::Result<()> { upstream_google: "https://generativelanguage.googleapis.com".to_string(), rewrite_anthropic_cache: false, no_routing: true, + pause_routing_on_exit: false, }) .await } diff --git a/src/cli/start.rs b/src/cli/start.rs index 582bd63..9ec7a91 100644 --- a/src/cli/start.rs +++ b/src/cli/start.rs @@ -44,6 +44,11 @@ pub struct StartArgs { /// up, and don't pause it when the proxy exits. #[arg(long)] pub no_routing: bool, + /// (internal) Pause routing when this process exits even under + /// `--no-routing`. Injected by the daemon launcher so a gracefully-exiting + /// background child doesn't strand Active env files at a dead port. + #[arg(long, hide = true)] + pub pause_routing_on_exit: bool, } pub async fn run_cmd(args: StartArgs) -> anyhow::Result<()> { @@ -51,7 +56,20 @@ pub async fn run_cmd(args: StartArgs) -> anyhow::Result<()> { return daemon::spawn_background(&args).await; } - init_tracing(); + let cfg_path = config::default_path()?; + let user_config = config::load_or_default(&cfg_path) + .with_context(|| format!("loading config from {}", cfg_path.display()))?; + + // The daemon child (marked by --pause-routing-on-exit) runs with stdio + // detached, so stdout logging goes nowhere — a crashed daemon used to be + // undiagnosable, and `logging.file` was a dead config key (L-H2). Route + // its tracing to the configured log file; foreground keeps stdout. + let log_file = if args.pause_routing_on_exit { + resolved_log_path(&user_config.logging) + } else { + None + }; + init_tracing(log_file, &user_config.logging.level); // Refuse to start a second proxy on top of a running one — `bind` below // is the real backstop, but this gives a clearer message in the common @@ -62,10 +80,6 @@ pub async fn run_cmd(args: StartArgs) -> anyhow::Result<()> { ); } - let cfg_path = config::default_path()?; - let user_config = config::load_or_default(&cfg_path) - .with_context(|| format!("loading config from {}", cfg_path.display()))?; - let storage = Arc::new(Storage::open_default().context("opening default storage")?); let mut ruleset: crate::security::Ruleset = (&user_config.security).into(); @@ -117,6 +131,10 @@ pub async fn run_cmd(args: StartArgs) -> anyhow::Result<()> { budget .hydrate_for_date(&storage, &today) .context("hydrating today's spend")?; + let this_month = chrono::Local::now().format("%Y-%m").to_string(); + budget + .hydrate_for_month(&storage, &this_month) + .context("hydrating this month's spend")?; let port = args.port.unwrap_or(user_config.proxy.port); let host_str = args @@ -172,7 +190,7 @@ pub async fn run_cmd(args: StartArgs) -> anyhow::Result<()> { upstream_anthropic: args.upstream_anthropic.clone(), upstream_openai: args.upstream_openai.clone(), upstream_google: args.upstream_google.clone(), - http_client: reqwest::Client::new(), + http_client: crate::proxy::build_http_client(), security, budget, loop_detector, @@ -207,7 +225,7 @@ pub async fn run_cmd(args: StartArgs) -> anyhow::Result<()> { let result = serve_with_shutdown(listener, Arc::new(state), daemon::shutdown_signal()).await; daemon::remove_pid_file().ok(); - if !args.no_routing { + if !args.no_routing || args.pause_routing_on_exit { super::stop::pause_and_report(); } result.context("proxy serve")?; @@ -268,14 +286,52 @@ pub(crate) fn resume_and_report(proxy_url: &str) { } } -fn init_tracing() { +/// Resolve `logging.file` (with `~/` expansion) to a concrete path. Empty +/// string disables file logging. +pub(crate) fn resolved_log_path(logging: &crate::config::types::LoggingConfig) -> Option { + let raw = logging.file.trim(); + if raw.is_empty() { + return None; + } + if let Some(rest) = raw.strip_prefix("~/").or_else(|| raw.strip_prefix("~\\")) { + return dirs::home_dir().map(|h| h.join(rest)); + } + Some(std::path::PathBuf::from(raw)) +} + +fn init_tracing(log_file: Option, level: &str) { use tracing_subscriber::EnvFilter; - let _ = tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info,hyper=warn,h2=warn")), - ) - .try_init(); + let filter = || { + EnvFilter::try_from_default_env().unwrap_or_else(|_| { + let lvl = if level.trim().is_empty() { "info" } else { level.trim() }; + EnvFilter::new(format!("{lvl},hyper=warn,h2=warn")) + }) + }; + if let Some(path) = log_file { + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + // Size cap without a rotation dep: shove an oversized log aside once + // at startup so the file can't grow unbounded across months of uptime. + const MAX_LOG_BYTES: u64 = 10 * 1024 * 1024; + if std::fs::metadata(&path).map(|m| m.len() > MAX_LOG_BYTES).unwrap_or(false) { + let _ = std::fs::rename(&path, path.with_extension("log.old")); + } + match std::fs::OpenOptions::new().create(true).append(true).open(&path) { + Ok(file) => { + let _ = tracing_subscriber::fmt() + .with_env_filter(filter()) + .with_ansi(false) + .with_writer(std::sync::Arc::new(file)) + .try_init(); + return; + } + Err(e) => { + eprintln!("burnwall: could not open log file {}: {e} — logging to stdout", path.display()); + } + } + } + let _ = tracing_subscriber::fmt().with_env_filter(filter()).try_init(); } /// Apply approved third-party rule packs from `/rules/*.toml`. Each diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 026ec8c..b880874 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,15 +1,26 @@ -//! OpenAI Chat Completions API response parser. +//! OpenAI Chat Completions + Responses API response parser. //! -//! Two response shapes: -//! - **Non-streaming**: single JSON with `model` + `usage` block. [`parse`]. -//! - **SSE streaming** (when `stream_options.include_usage` is set): a stream -//! of `data: {...}` chunks where one — typically the second-to-last — has -//! a populated `usage` field. [`parse_sse`]. +//! Two APIs, each with a streaming and non-streaming shape: +//! - **Chat Completions** (`/v1/chat/completions`): `usage` carries +//! `prompt_tokens` / `completion_tokens` / `prompt_tokens_details.cached_tokens`. +//! - **Responses API** (`/v1/responses`, what Codex CLI defaults to): `usage` +//! carries `input_tokens` / `output_tokens` / `input_tokens_details.cached_tokens`. //! -//! [`parse_any`] tries non-streaming first, then SSE. +//! Non-streaming bodies for both are a single JSON with top-level `model` + +//! `usage` — [`parse`] handles both via serde field aliases. SSE streams +//! differ: Chat Completions puts `model`/`usage` at the top of a chunk (when +//! `stream_options.include_usage` is set, typically the second-to-last chunk); +//! the Responses API nests them under `response` in typed events, with usage +//! arriving on the `response.completed` event — [`parse_sse`] handles both. //! -//! Normalization: `prompt_tokens` is the TOTAL prompt size (cached + not). -//! We subtract `prompt_tokens_details.cached_tokens` to produce the +//! [`parse_any`] tries non-streaming first, then SSE — and treats an all-zero +//! usage as a parse failure: every `Usage` field is `#[serde(default)]`, so an +//! unrecognized usage shape would otherwise "succeed" with zero tokens and be +//! recorded as a $0 row. A real response always bills at least one input +//! token; all-zero is the signature of a shape we didn't understand. +//! +//! Normalization: the prompt/input count is the TOTAL prompt size (cached + +//! not) in both APIs. We subtract the cached portion to produce the //! `input_tokens` (non-cached) field of [`TokenUsage`]. OpenAI never has //! cache writes — caching is automatic, no opt-in. @@ -23,13 +34,17 @@ struct Response { usage: Usage, } +/// Usage block for both OpenAI APIs. The aliases map the Responses API +/// field names (`input_tokens` / `output_tokens` / `input_tokens_details`) +/// onto the Chat Completions ones — the semantics are identical (totals +/// including the cached portion), only the names differ. #[derive(Deserialize, Default, Clone)] struct Usage { - #[serde(default)] + #[serde(default, alias = "input_tokens")] prompt_tokens: u64, - #[serde(default)] + #[serde(default, alias = "output_tokens")] completion_tokens: u64, - #[serde(default)] + #[serde(default, alias = "input_tokens_details")] prompt_tokens_details: PromptDetails, } @@ -53,14 +68,18 @@ fn to_parsed(model: String, usage: Usage) -> ParsedResponse { } } -/// Parse a non-streaming Chat Completions response body. +/// Parse a non-streaming response body — Chat Completions or Responses API +/// (both have top-level `model` + `usage`; the field aliases on [`Usage`] +/// absorb the naming difference). pub fn parse(body: &[u8]) -> Result { let r: Response = serde_json::from_slice(body)?; Ok(to_parsed(r.model, r.usage)) } -/// Parse an SSE stream body. Looks for the chunk with a non-empty `usage` -/// field; reports the first `model` seen. +/// Parse an SSE stream body — Chat Completions chunks or Responses API +/// events. Looks for a non-empty `usage` block (top-level for Chat +/// Completions, under `response` for Responses API events — usage rides on +/// `response.completed`); reports the first `model` seen. pub fn parse_sse(body: &[u8]) -> Option { let text = std::str::from_utf8(body).ok()?; let mut model: Option = None; @@ -76,10 +95,18 @@ pub fn parse_sse(body: &[u8]) -> Option { let Ok(val) = serde_json::from_str::(json_str) else { continue; }; + // Responses API events (`response.created`, `response.completed`, …) + // nest the payload under `response`; Chat Completions chunks carry + // `model`/`usage` at the top level. Events without a `response` + // object (e.g. `response.output_text.delta`) fall through harmlessly. + let payload = val.get("response").unwrap_or(&val); if model.is_none() { - model = val.get("model").and_then(|m| m.as_str()).map(String::from); + model = payload + .get("model") + .and_then(|m| m.as_str()) + .map(String::from); } - if let Some(usage_val) = val.get("usage") { + if let Some(usage_val) = payload.get("usage") { if !usage_val.is_null() { if let Ok(u) = serde_json::from_value::(usage_val.clone()) { // Keep the most recent non-empty usage block. @@ -95,9 +122,21 @@ pub fn parse_sse(body: &[u8]) -> Option { } /// Try [`parse`] (non-streaming JSON), then [`parse_sse`]. +/// +/// All-zero guard: every [`Usage`] field is `#[serde(default)]`, so a body +/// whose usage shape we don't recognize deserializes "successfully" with +/// zero in every bucket. Recording that would silently book a $0 row for a +/// request that cost real money — worse than not recording, because it looks +/// covered. A billable response always has `input_tokens > 0` (a prompt was +/// processed), so all-zero is treated as a parse failure and the caller's +/// not-recorded warning fires instead. pub fn parse_any(body: &[u8]) -> Option { if let Ok(p) = parse(body) { - return Some(p); + if p.usage.total() > 0 { + return Some(p); + } + // Structurally valid JSON but no recognized usage fields — fall + // through to the SSE parser, then report failure. } - parse_sse(body) + parse_sse(body).filter(|p| p.usage.total() > 0) } diff --git a/src/proxy/forwarding.rs b/src/proxy/forwarding.rs index 75d265b..ed1081a 100644 --- a/src/proxy/forwarding.rs +++ b/src/proxy/forwarding.rs @@ -14,7 +14,8 @@ //! resilience disabled the behavior is unchanged: a single upstream, and a //! 5xx is forwarded to the client verbatim. -use std::sync::Arc; +use std::collections::HashSet; +use std::sync::{Arc, LazyLock, Mutex}; use std::time::Instant; use bytes::Bytes; @@ -23,7 +24,7 @@ use hyper::Response; use tracing::{debug, error, warn}; use crate::pricing; -use crate::providers::{anthropic, google, openai, ParsedResponse}; +use crate::providers::{anthropic, google, openai, ParsedResponse, TokenUsage}; use crate::storage::RequestRecord; use super::{streaming, AppState, BoxError, ProxyBody}; @@ -48,6 +49,26 @@ fn is_hop_by_hop(name: &str) -> bool { HOP_BY_HOP.iter().any(|h| name.eq_ignore_ascii_case(h)) } +/// Headers forwarded upstream on the tracked path: hop-by-hop stripped, plus +/// `Accept-Encoding`. The response tee parses the body for usage/cost, and +/// the proxy's HTTP client is built without decompression support — so when +/// the client's `Accept-Encoding` (Claude Code sends `gzip, br, zstd`) is +/// forwarded, the upstream compresses the body and the tee sees opaque bytes: +/// cost tracking silently records nothing. Dropping the header makes the +/// upstream respond in identity encoding; the response still passes through +/// byte-for-byte unchanged. The bypass relay ([`passthrough`]) keeps the +/// client's header — it never parses anything. +fn tracked_outbound_headers(req_headers: &HeaderMap) -> HeaderMap { + let mut out = HeaderMap::new(); + for (name, value) in req_headers.iter() { + if is_hop_by_hop(name.as_str()) || name.as_str().eq_ignore_ascii_case("accept-encoding") { + continue; + } + out.append(name.clone(), value.clone()); + } + out +} + #[allow(clippy::too_many_arguments)] pub async fn forward( method: Method, @@ -58,16 +79,20 @@ pub async fn forward( state: &Arc, provider: &'static str, request_hash_hex: String, + // Loop-detector hash to record an arrival under, but ONLY when the upstream + // returns 2xx — `None` for GET/body-less requests that aren't loop-tracked. + // Recording on the response path (not pre-forward) is what stops blocked + // 429s and failed-request retries from feeding the window (B-C2). + loop_hash: Option, + // Cache-savings projection (USD) to persist off the hot path in the tee, + // instead of a synchronous pre-forward write (D-M5). `None` when cache + // injection is on or the request isn't an eligible Messages-API call. + cache_projection: Option, ) -> Result, BoxError> { // Opt-in session/swarm id for per-session attribution + budget recording. let session_id = super::handler::session_from_headers(&req_headers); - let mut outbound_headers = HeaderMap::new(); - for (name, value) in req_headers.iter() { - if !is_hop_by_hop(name.as_str()) { - outbound_headers.append(name.clone(), value.clone()); - } - } + let outbound_headers = tracked_outbound_headers(&req_headers); let candidates = state.resilience.candidates(provider, primary_base); let use_breaker = state.resilience.enabled; @@ -138,6 +163,14 @@ pub async fn forward( let status_code = status.as_u16() as i64; let resp_headers = upstream_resp.headers().clone(); + // Captured for the tee's parse-failure diagnostics: a non-identity + // encoding here means the body bytes are compressed and unparseable. + let content_encoding = resp_headers + .get("content-encoding") + .and_then(|v| v.to_str().ok()) + .unwrap_or("identity") + .to_string(); + // Subscription-plan limit headroom rides on the upstream response (e.g. // Anthropic's `unified-*` headers); `None` for API keys / unprobed providers. // Parsed here (cheap, in-memory); persisted off the response path in the tee @@ -158,7 +191,28 @@ pub async fn forward( let hash_hex = request_hash_hex; let session_for_tee = session_id.clone(); - let teed = streaming::tee_stream(upstream_resp.bytes_stream(), move |chunks| { + let teed = streaming::tee_stream(upstream_resp.bytes_stream(), move |chunks, aborted| { + // Record a loop-detector arrival only for a forwarded 2xx (B-C2): a + // genuine repeat is an identical body that keeps *succeeding*. Retries + // of a block or of an upstream error never reach here with a 2xx, so + // they can't refill the window. A client-aborted request isn't a + // completed success, so it doesn't count toward a loop either. + if let Some(hash) = loop_hash { + if !aborted && (200..300).contains(&status_code) { + loop_detector.record_arrival(hash); + } + } + + // Deferred cache-savings projection write (D-M5): off the response path, + // so the synchronous SQLite UPSERT/fsync never sits in front of the + // request the way a pre-forward write did. + if let Some(savings) = cache_projection { + let today = chrono::Local::now().format("%Y-%m-%d").to_string(); + if let Err(e) = storage.record_cache_projection(&today, savings) { + debug!("cache projection record failed: {}", e); + } + } + // Persist the subscription-limit snapshot if this was a unified response. // Off the response path — the client already has its bytes. if let Some(snap) = &plan_snapshot { @@ -172,7 +226,7 @@ pub async fn forward( match parse_for_provider(&provider_str, &total) { Some(p) => { - let cost = pricing::calculate_cost(&p.model, &p.usage).unwrap_or(0.0); + let cost = cost_or_zero(&p.model, &p.usage); let mut record = RequestRecord::successful( &provider_str, &p.model, @@ -182,7 +236,10 @@ pub async fn forward( ); record.request_hash = Some(hash_hex.clone()); record.latency_ms = Some(latency_ms); - record.http_status = Some(status_code); + // 499 (client closed request) marks a partial response the user + // cancelled mid-stream, so its cost is attributable but + // distinguishable from a clean completion. + record.http_status = Some(if aborted { 499 } else { status_code }); if let Err(e) = storage.insert_request(&record) { error!("requests insert failed: {}", e); } @@ -228,11 +285,27 @@ pub async fn forward( status_code, ); } - None => { + None if aborted => { + // A client-cancelled stream is usually a partial body that + // can't parse — expected, not a systemic problem. Don't + // warn-spam on every Esc. debug!( - "could not parse {} response body for usage tracking ({} bytes)", + "{} response not recorded — client aborted mid-stream ({} bytes)", provider_str, - total.len() + total.len(), + ); + } + None => { + // warn, not debug: an unparseable body means this request is + // invisible to cost tracking and coverage. A long stretch of + // these in the log is the signal that something systemic + // (e.g. an encoding we don't handle) is hiding traffic. + warn!( + "could not parse {} response for usage tracking ({} bytes, content-encoding: {}, status {}) — request not recorded", + provider_str, + total.len(), + content_encoding, + status_code, ); } } @@ -268,6 +341,33 @@ fn parse_for_provider(provider: &str, body: &[u8]) -> Option { } } +/// Cost for a parsed response, or `0.0` when the model has no pricing entry. +/// +/// Fail-open: the row is still recorded (the token counts are real and the +/// request must stay visible to history/budget), but pricing an unknown model +/// at $0 silently would understate spend with no trace — so the first time +/// each model name misses, warn and point at the override file. Once per +/// model per process, not per request: an agent can replay the same unknown +/// model thousands of times an hour. +fn cost_or_zero(model: &str, usage: &TokenUsage) -> f64 { + match pricing::calculate_cost(model, usage) { + Some(c) => c, + None => { + static WARNED: LazyLock>> = + LazyLock::new(|| Mutex::new(HashSet::new())); + let mut warned = WARNED.lock().unwrap_or_else(|p| p.into_inner()); + if warned.insert(model.to_string()) { + warn!( + "unknown model '{}' — no pricing entry, cost recorded as $0. \ + Add a [[model]] override in ~/.burnwall/pricing.toml to price it.", + model, + ); + } + 0.0 + } + } +} + /// Pure pass-through: forward `method/headers/body` to `upstream_base + path_and_query`, /// stream the response back. No security scan, no parsing, no storage write, /// no failover, no breaker. Used by the BURNWALL_BYPASS kill-switch (L2). @@ -315,3 +415,29 @@ pub async fn passthrough( } Ok(response.body(body).expect("passthrough: response build failed")) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tracked_outbound_headers_strips_accept_encoding_and_hop_by_hop() { + let mut h = HeaderMap::new(); + h.insert("accept-encoding", HeaderValue::from_static("gzip, br, zstd")); + h.insert("connection", HeaderValue::from_static("keep-alive")); + h.insert("content-length", HeaderValue::from_static("42")); + h.insert("x-api-key", HeaderValue::from_static("k")); + h.insert( + "anthropic-version", + HeaderValue::from_static("2023-06-01"), + ); + let out = tracked_outbound_headers(&h); + // Forwarding the client's accept-encoding lets the upstream compress + // the body, which the tee can't parse — cost tracking goes dark. + assert!(out.get("accept-encoding").is_none()); + assert!(out.get("connection").is_none()); + assert!(out.get("content-length").is_none()); + assert_eq!(out.get("x-api-key").unwrap(), "k"); + assert_eq!(out.get("anthropic-version").unwrap(), "2023-06-01"); + } +} diff --git a/src/proxy/handler.rs b/src/proxy/handler.rs index 30e1585..fefa4f2 100644 --- a/src/proxy/handler.rs +++ b/src/proxy/handler.rs @@ -137,79 +137,152 @@ pub async fn handle( tracing::error!("blocked-request insert failed: {}", e); } - let msg = format!("Burnwall blocked: {}", violation.message()); - return Ok(error_response( - StatusCode::FORBIDDEN, + let what = format!("{} ({}).", violation.message(), violation.location.describe()); + return Ok(block::build( + provider, "security_blocked", - &msg, + StatusCode::FORBIDDEN, + &what, + block::SECURITY_REMEDIES, + None, )); } // ─── budget check ─── - match state.budget.check() { - BudgetStatus::Exceeded { spent, limit } => { - warn!("💰 BUDGET EXCEEDED: ${:.2}/${:.2}", spent, limit); - let record = RequestRecord::blocked(provider, &model, "budget_exceeded", None); - if let Err(e) = state.storage.insert_request(&record) { - tracing::error!("blocked-request insert failed: {}", e); + // Plan-aware (B-H4): a subscription request (OAuth bearer, no API key) is + // not metered per token, so the dollar cap is notional — we track and warn + // but do not 429-block it unless `budget.enforce_on_plan` is set. Metered + // API-key traffic is always enforced. + let metered = auth_kind(&parts.headers, provider) == AuthKind::Metered; + let enforce_dollar_cap = metered || state.budget.config().enforce_on_plan; + + // Monthly cap first (the hard backstop), then daily. + for (status, label) in [ + (state.budget.check_monthly(), "monthly"), + (state.budget.check(), "daily"), + ] { + match status { + BudgetStatus::Exceeded { spent, limit } => { + if enforce_dollar_cap { + warn!("💰 {} BUDGET EXCEEDED: ${:.2}/${:.2}", label, spent, limit); + let kind = if label == "monthly" { + "monthly_budget_exceeded" + } else { + "budget_exceeded" + }; + let record = RequestRecord::blocked(provider, &model, kind, None); + if let Err(e) = state.storage.insert_request(&record) { + tracing::error!("blocked-request insert failed: {}", e); + } + let reset = if label == "monthly" { + "the 1st of next month" + } else { + "local midnight" + }; + let what = format!( + "Your {label} budget of ${:.2} is used up (${:.2} spent). It resets at {reset}.", + limit, spent + ); + return Ok(block::build( + provider, + kind, + StatusCode::TOO_MANY_REQUESTS, + &what, + block::BUDGET_REMEDIES, + Some(block::seconds_until_local_midnight()), + )); + } else { + // Subscription traffic: notional dollars, plan is the real + // limit. Warn once-ish, never block. + warn!( + "💰 {} notional spend ${:.2} over ${:.2} cap — plan traffic, not blocking (set budget.enforce_on_plan=true to enforce)", + label, spent, limit + ); + } } - let msg = format!( - "Daily budget of ${:.2} exceeded (${:.2} spent)", - limit, spent - ); - return Ok(error_response( - StatusCode::TOO_MANY_REQUESTS, - "budget_exceeded", - &msg, - )); - } - BudgetStatus::Warn { - spent, - limit, - percent, - } => { - warn!("⚠️ Budget {}% used (${:.2}/${:.2})", percent, spent, limit); + BudgetStatus::Warn { + spent, + limit, + percent, + } => { + warn!( + "⚠️ {} budget {}% used (${:.2}/${:.2})", + label, percent, spent, limit + ); + } + BudgetStatus::Ok => {} } - BudgetStatus::Ok => {} } // ─── per-session / swarm budget ceiling (opt-in via x-burnwall-session) ─── + // Same plan-aware gate as the daily/monthly caps: an explicit per-session + // cap is still enforced on metered traffic, but a notional cap on plan + // traffic only warns unless the user opted in. if let Some(sid) = &session_id { if let BudgetStatus::Exceeded { spent, limit } = state.budget.check_session(sid) { - warn!("💰 SESSION BUDGET EXCEEDED: ${:.2}/${:.2}", spent, limit); - let record = - RequestRecord::blocked(provider, &model, "session_budget_exceeded", Some(sid.clone())); - if let Err(e) = state.storage.insert_request(&record) { - tracing::error!("blocked-request insert failed: {}", e); + if enforce_dollar_cap { + warn!("💰 SESSION BUDGET EXCEEDED: ${:.2}/${:.2}", spent, limit); + let record = + RequestRecord::blocked(provider, &model, "session_budget_exceeded", Some(sid.clone())); + if let Err(e) = state.storage.insert_request(&record) { + tracing::error!("blocked-request insert failed: {}", e); + } + let what = format!( + "This session/swarm hit its ${:.2} cap (${:.2} spent).", + limit, spent + ); + return Ok(block::build( + provider, + "session_budget_exceeded", + StatusCode::TOO_MANY_REQUESTS, + &what, + block::SESSION_REMEDIES, + None, + )); + } else { + warn!( + "💰 session notional spend ${:.2} over ${:.2} cap — plan traffic, not blocking", + spent, limit + ); } - let msg = format!( - "Session budget of ${:.2} exceeded (${:.2} spent) — swarm/session cap hit", - limit, spent - ); - return Ok(error_response( - StatusCode::TOO_MANY_REQUESTS, - "session_budget_exceeded", - &msg, - )); } } // ─── loop detection ─── - let request_hash = state.loop_detector.hash(&body_bytes); + // Skip body-less / GET requests entirely (B-H1): a `GET /v1/models` cannot + // be a runaway agent loop worth blocking, and all empty bodies would + // otherwise collide into one bucket. `should_track` gates both the + // pre-forward peek and the on-2xx arrival recording. + let should_track_loop = parts.method != hyper::Method::GET && !body_bytes.is_empty(); + let request_hash = state + .loop_detector + .hash(parts.method.as_str(), provider, &rest, &body_bytes); let request_hash_hex = format!("{:016x}", request_hash); - let verdict = state.loop_detector.check_request(request_hash); - if verdict.is_blocking() { - warn!("🔄 LOOP BLOCKED {}: {}", provider, verdict.message()); - let mut record = RequestRecord::blocked(provider, &model, &verdict.message(), None); - record.request_hash = Some(request_hash_hex.clone()); - if let Err(e) = state.storage.insert_request(&record) { - tracing::error!("blocked-request insert failed: {}", e); + if should_track_loop { + // Read-only peek — the arrival is recorded later by the tee, and only + // on a 2xx, so a blocked 429 (or a retry after an upstream failure) + // never feeds the window. This is the death-spiral fix (B-C2). + let verdict = state.loop_detector.check_request(request_hash); + if verdict.is_blocking() { + warn!("🔄 LOOP BLOCKED {}: {}", provider, verdict.message()); + let mut record = RequestRecord::blocked(provider, &model, &verdict.message(), None); + record.request_hash = Some(request_hash_hex.clone()); + if let Err(e) = state.storage.insert_request(&record) { + tracing::error!("blocked-request insert failed: {}", e); + } + let what = format!( + "{}. This usually means your tool retried an identical request; it clears automatically.", + verdict.message() + ); + return Ok(block::build( + provider, + "loop_detected", + StatusCode::TOO_MANY_REQUESTS, + &what, + block::LOOP_REMEDIES, + verdict.retry_after_secs(), + )); } - return Ok(error_response( - StatusCode::TOO_MANY_REQUESTS, - "loop_detected", - &verdict.message(), - )); } // ─── cost-spiral enforcement (opt-in) ─── @@ -224,10 +297,14 @@ pub async fn handle( if let Err(e) = state.storage.insert_request(&record) { tracing::error!("blocked-request insert failed: {}", e); } - return Ok(error_response( - StatusCode::TOO_MANY_REQUESTS, + let what = format!("{}.", spiral.message()); + return Ok(block::build( + provider, "cost_spiral", - &spiral.message(), + StatusCode::TOO_MANY_REQUESTS, + &what, + block::COST_SPIRAL_REMEDIES, + spiral.retry_after_secs(), )); } @@ -240,6 +317,11 @@ pub async fn handle( // gate to provider=anthropic + path=/v1/messages (the only Anthropic // endpoint that accepts these markers). let messages_api = provider == "anthropic" && cache_injection::is_messages_path(&rest); + // Cache-savings projection (cache injection OFF): the estimate is an + // in-memory parse here, but the DB write is deferred to the tee callback + // (off the response path) instead of a synchronous pre-forward fsync that + // could stall the request behind a contended write — D-M5. + let mut cache_projection = None; let forward_body = if state.cache_injection && messages_api { let outcome = cache_injection::inject_if_eligible(&body_bytes); if outcome.modified { @@ -250,16 +332,16 @@ pub async fn handle( if !state.cache_injection && messages_api { let projected = cache_injection::estimate_savings_usd(&body_bytes); if projected > 0.0 { - let today = chrono::Local::now().format("%Y-%m-%d").to_string(); - if let Err(e) = state.storage.record_cache_projection(&today, projected) { - tracing::warn!("cache projection record failed: {}", e); - } + cache_projection = Some(projected); } } body_bytes }; // ─── forward (with optional failover) + tee-parse ─── + // Pass the loop hash so the tee can record the arrival on a 2xx (and only + // then). `None` when this request isn't loop-tracked (GET/body-less). + let loop_hash = should_track_loop.then_some(request_hash); match forwarding::forward( parts.method, &upstream_base, @@ -269,6 +351,8 @@ pub async fn handle( &state, provider, request_hash_hex, + loop_hash, + cache_projection, ) .await { @@ -307,6 +391,160 @@ fn escape_json(s: &str) -> String { s.replace('\\', "\\\\").replace('"', "\\\"") } +/// Which credential kind a request carries — drives plan-aware budget +/// enforcement (B-H4). We classify the *kind* only and never read or log the +/// credential value. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AuthKind { + /// Metered API key (`x-api-key`, or any bearer we can't identify as a + /// subscription) — real per-token dollars, so the dollar cap applies. + Metered, + /// Flat-rate subscription (Claude Pro/Max via an OAuth bearer) — not + /// metered per token, so the dollar figure is notional. + Subscription, +} + +/// Classify the request's credential kind. Defaults to [`AuthKind::Metered`] so +/// enforcement is only ever *relaxed* for a positively-identified subscription, +/// never weakened for an unknown auth shape. +fn auth_kind(headers: &hyper::HeaderMap, provider: &str) -> AuthKind { + // An API key is unambiguously metered. + if headers + .get("x-api-key") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { + return AuthKind::Metered; + } + // Anthropic OAuth tokens (Claude Code on a Pro/Max plan) start with + // `sk-ant-oat`. The API authenticates with `x-api-key`, so a bearer of this + // shape is a subscription. We inspect only the prefix; the token is never + // logged. OpenAI/Google bearers are API-metered, so they fall through to + // Metered. + if provider == "anthropic" { + if let Some(auth) = headers + .get(hyper::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + { + let token = auth + .strip_prefix("Bearer ") + .or_else(|| auth.strip_prefix("bearer ")) + .unwrap_or(""); + if token.starts_with("sk-ant-oat") { + return AuthKind::Subscription; + } + } + } + AuthKind::Metered +} + +/// Self-identifying, actionable block responses (W1-7). Every block Burnwall +/// imposes tells the user: (1) that *Burnwall* did it, before the request left +/// the machine; (2) what matched and where; (3) how to proceed if it's a false +/// positive, escalating inspect → narrow → bypass → pause; and (4) how to +/// report it. Limit blocks also carry a `Retry-After`. The JSON envelope +/// matches the upstream provider's error shape (P-M2) so the AI tool renders a +/// clean error instead of a raw blob. +pub(crate) mod block { + use bytes::Bytes; + use hyper::{Response, StatusCode}; + use serde_json::json; + + use crate::proxy::{streaming, ProxyBody}; + + pub const SECURITY_REMEDIES: &[&str] = &[ + "See exactly what was caught: burnwall security", + "If it's wrong, adjust the rule in ~/.burnwall/config.toml (security.deny_paths / deny_commands), or disable a pack: burnwall rules disable ", + "Bypass Burnwall for this session — UNPROTECTED: set BURNWALL_BYPASS=1 and restart your AI tool", + "Turn Burnwall off entirely — UNPROTECTED: burnwall stop", + ]; + pub const BUDGET_REMEDIES: &[&str] = &[ + "See today's spend: burnwall status", + "Raise or remove the cap: burnwall config set budget.daily (0 = unlimited)", + "On a flat-rate plan? The dollar cap is notional — plan traffic isn't blocked by default (budget.enforce_on_plan).", + "Bypass for this session — UNPROTECTED: set BURNWALL_BYPASS=1 and restart your AI tool", + ]; + pub const SESSION_REMEDIES: &[&str] = &[ + "Raise or turn off the per-session cap: burnwall config set budget.per_session (0 = off)", + "Bypass for this session — UNPROTECTED: set BURNWALL_BYPASS=1 and restart your AI tool", + ]; + pub const LOOP_REMEDIES: &[&str] = &[ + "This clears on its own once the retry window drains — usually a client resending an identical request.", + "Tune the threshold: burnwall config set loop_detection.max_identical_requests ", + "Disable loop detection: burnwall config set loop_detection.enabled false", + "Bypass for this session — UNPROTECTED: set BURNWALL_BYPASS=1 and restart your AI tool", + ]; + pub const COST_SPIRAL_REMEDIES: &[&str] = &[ + "Raise the window cap: burnwall config set loop_detection.max_cost_per_window ", + "Disable spiral blocking: burnwall config set loop_detection.cost_spiral_enforce false", + "Bypass for this session — UNPROTECTED: set BURNWALL_BYPASS=1 and restart your AI tool", + ]; + + /// Seconds until the next local midnight — the daily budget reset time. + pub fn seconds_until_local_midnight() -> u64 { + use chrono::Timelike; + let secs_today = chrono::Local::now().num_seconds_from_midnight() as u64; + 86_400u64.saturating_sub(secs_today).max(1) + } + + /// Assemble the human-readable block message: self-identify, what/where, + /// escape hatches, report path. + fn message(what: &str, remedies: &[&str]) -> String { + let mut m = String::new(); + m.push_str("🛡️ Burnwall blocked this request before it left your machine.\n"); + m.push_str(what); + if !remedies.is_empty() { + m.push_str("\n\nIf this is a false positive, you can:"); + for r in remedies { + m.push_str("\n • "); + m.push_str(r); + } + } + m.push_str("\n\nReport a false positive (nothing leaves your machine): burnwall report-bug"); + m + } + + /// Build the provider-correct JSON error response with the block message + /// and an optional `Retry-After` header. + pub fn build( + provider: &str, + kind: &str, + status: StatusCode, + what: &str, + remedies: &[&str], + retry_after_secs: Option, + ) -> Response { + let msg = message(what, remedies); + // Match each provider's native error envelope so the client SDK renders + // it as an error rather than failing to parse an unexpected shape. + let value = match provider { + "anthropic" => json!({"type": "error", "error": {"type": kind, "message": msg}}), + "google" => { + let gstatus = match status { + StatusCode::TOO_MANY_REQUESTS => "RESOURCE_EXHAUSTED", + StatusCode::FORBIDDEN => "PERMISSION_DENIED", + _ => "FAILED_PRECONDITION", + }; + json!({"error": {"code": status.as_u16(), "message": msg, "status": gstatus}}) + } + _ => json!({"error": {"message": msg, "type": kind, "code": kind}}), + }; + let body = serde_json::to_string(&value) + .unwrap_or_else(|_| r#"{"error":{"message":"Burnwall blocked this request."}}"#.to_string()); + + let mut builder = Response::builder() + .status(status) + .header("content-type", "application/json") + .header("x-burnwall-blocked", kind); + if let Some(secs) = retry_after_secs { + builder = builder.header("retry-after", secs.to_string()); + } + builder + .body(streaming::full(Bytes::from(body))) + .expect("block::build: response builder failed") + } +} + /// Best-effort extraction of the `model` field from a request body. Used /// to populate `RequestRecord.model` even when the request was blocked. fn extract_model(body: &[u8]) -> Option { diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 9127c91..ad6c587 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -33,6 +33,32 @@ pub mod streaming; pub use resilience::Resilience; pub use streaming::{BoxError, ProxyBody}; +/// Build the upstream HTTP client with deadlines and TCP keepalive (P-C1). A +/// bare `reqwest::Client::new()` has no connect timeout, no read timeout, and +/// no keepalive, so a VPN flip / captive portal blackholes a request for the OS +/// connect timeout (tens of seconds, freezing the user's tool), and a stalled +/// stream after laptop sleep/wake blocks the tee task forever — the request is +/// never recorded and the task plus its buffered body leak until restart. +/// +/// - `connect_timeout`: fail fast to a clean 502 instead of a long hang. +/// - `tcp_keepalive`: detect a silently-dead socket (no FIN/RST) so a stalled +/// stream eventually errors instead of blocking forever. +/// - `read_timeout` (per-read, NOT total `timeout`): reclaims a socket that has +/// gone quiet, while still allowing arbitrarily long SSE streams — Anthropic +/// sends periodic pings, so a live stream keeps resetting the per-read clock. +/// A total `.timeout()` would wrongly kill long legitimate generations. +pub fn build_http_client() -> reqwest::Client { + reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(10)) + .tcp_keepalive(std::time::Duration::from_secs(60)) + .read_timeout(std::time::Duration::from_secs(600)) + .build() + .unwrap_or_else(|e| { + tracing::warn!("falling back to default HTTP client: {e}"); + reqwest::Client::new() + }) +} + /// Shared, immutable-from-the-handler-side state. Each component is `Arc`'d /// so the tee callback (which runs in a spawned task) can clone the parts /// it needs without copying the whole struct. @@ -71,7 +97,7 @@ impl AppState { upstream_anthropic, upstream_openai, upstream_google: "https://generativelanguage.googleapis.com".to_string(), - http_client: reqwest::Client::new(), + http_client: build_http_client(), security: Arc::new(SecurityEngine::with_defaults()), budget: Arc::new(BudgetTracker::with_defaults()), loop_detector: Arc::new(LoopDetector::with_defaults()), diff --git a/src/proxy/streaming.rs b/src/proxy/streaming.rs index e3b9b8a..d800c64 100644 --- a/src/proxy/streaming.rs +++ b/src/proxy/streaming.rs @@ -10,9 +10,13 @@ //! `on_complete` with the accumulated chunks so the caller can parse //! usage data and write storage rows. //! -//! If the client disconnects mid-stream, the channel send fails; we keep -//! reading from upstream and still fire `on_complete` so cost tracking -//! reflects what the upstream actually delivered. +//! If the client disconnects mid-stream (e.g. the user presses Esc in their +//! AI tool), the channel send fails. We then **stop** reading and drop the +//! upstream stream so the provider stops generating — otherwise we'd bill the +//! full response for output nobody will read, and a stalled tail could leak the +//! task forever (P-C2). `on_complete` still fires with the bytes collected so +//! far and an `aborted` flag, so a partial response is recorded rather than +//! silently lost. use std::convert::Infallible; use std::pin::Pin; @@ -62,26 +66,32 @@ where pub fn tee_stream(stream: S, on_complete: F) -> ChannelStream where S: Stream> + Send + 'static, - F: FnOnce(Vec) + Send + 'static, + F: FnOnce(Vec, bool) + Send + 'static, { let (tx, rx) = unbounded_channel(); tokio::spawn(async move { let mut collected: Vec = Vec::new(); let mut stream = Box::pin(stream); - let mut client_alive = true; + let mut aborted = false; while let Some(item) = stream.next().await { if let Ok(ref b) = item { collected.push(b.clone()); } - if client_alive && tx.send(item).is_err() { - // Client closed — stop forwarding, but keep draining so we - // still call on_complete with the full accumulated body. - client_alive = false; + if tx.send(item).is_err() { + // Client hung up. Stop reading and drop the upstream stream so + // the connection aborts and the provider stops generating — + // billing for output nobody reads, and leaking a task on a + // stalled tail, are both worse than a partial cost record. + aborted = true; + break; } } + // Drop the upstream stream promptly (before the blocking parse) so the + // socket closes on a client abort. + drop(stream); // Run the usage parse + storage writes on the blocking pool so the // synchronous SQLite I/O never stalls an async worker thread. - let _ = tokio::task::spawn_blocking(move || on_complete(collected)).await; + let _ = tokio::task::spawn_blocking(move || on_complete(collected, aborted)).await; }); ChannelStream(rx) } diff --git a/tests/integration/pipeline_test.rs b/tests/integration/pipeline_test.rs index d3e2de1..1a67679 100644 --- a/tests/integration/pipeline_test.rs +++ b/tests/integration/pipeline_test.rs @@ -164,8 +164,8 @@ async fn safe_openai_request_records_cost_with_cache() { assert_eq!(rows[0].input_tokens, 512); assert_eq!(rows[0].cache_read_tokens, 1536); assert_eq!(rows[0].output_tokens, 512); - // Cost: 512*1.25 + 1536*0.625 + 512*10.00, all / 1M = $0.00672 - assert!((rows[0].cost_usd - 0.00672).abs() < 1e-6); + // Cost: 512*2.50 + 1536*0.25 + 512*15.00, all / 1M = $0.009344 + assert!((rows[0].cost_usd - 0.009344).abs() < 1e-6); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -250,6 +250,7 @@ async fn budget_exceeded_returns_429_without_forwarding() { monthly_usd: 0.0, warn_percent: 80, per_session_usd: 0.0, + enforce_on_plan: false, })); budget.record(2.50); // already past the $1 cap @@ -279,10 +280,10 @@ async fn budget_exceeded_returns_429_without_forwarding() { assert_eq!(resp.status(), 429); let body: serde_json::Value = resp.json().await.unwrap(); assert_eq!(body["error"]["type"], "budget_exceeded"); - assert!(body["error"]["message"] - .as_str() - .unwrap() - .contains("Daily budget")); + // W1-7: the block message self-identifies as Burnwall and names the cap. + let msg = body["error"]["message"].as_str().unwrap(); + assert!(msg.contains("Burnwall"), "should self-identify: {msg}"); + assert!(msg.contains("budget"), "should name the budget: {msg}"); settle().await; let rows = storage.requests_for_date(&today()).unwrap(); @@ -291,6 +292,73 @@ async fn budget_exceeded_returns_429_without_forwarding() { assert_eq!(rows[0].block_reason.as_deref(), Some("budget_exceeded")); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn subscription_traffic_not_blocked_by_dollar_cap() { + // B-H4: a subscription request (Anthropic OAuth bearer, no API key) carries + // notional dollars — the daily cap must NOT 429 it (it's tracked + warned + // instead). The same over-budget tracker blocks a metered API-key request. + let mock = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "msg", + "model": "claude-fable-5", + "usage": {"input_tokens": 10, "output_tokens": 5} + }))) + .mount(&mock) + .await; + + let budget = Arc::new(BudgetTracker::new(BudgetConfig { + daily_usd: 1.0, + monthly_usd: 0.0, + warn_percent: 80, + per_session_usd: 0.0, + enforce_on_plan: false, // default: plan traffic isn't dollar-capped + })); + budget.record(5.00); // well past the $1 cap + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = AppState { + upstream_anthropic: mock.uri(), + upstream_openai: "http://127.0.0.1:1".to_string(), + http_client: reqwest::Client::new(), + security: Arc::new(SecurityEngine::with_defaults()), + budget, + loop_detector: Arc::new(LoopDetector::with_defaults()), + storage: storage.clone(), + cache_injection: false, + upstream_google: "http://127.0.0.1:1".to_string(), + resilience: Default::default(), + otel: None, + }; + let addr = spawn_proxy(state).await; + + // Subscription bearer → forwarded despite being over the dollar cap. + let resp = client() + .post(format!("http://{}/anthropic/v1/messages", addr)) + .header("authorization", "Bearer sk-ant-oat01-fake-oauth-token") + .json(&json!({"model": "claude-fable-5"})) + .send() + .await + .unwrap(); + assert_eq!( + resp.status(), + 200, + "subscription traffic must not be dollar-capped by default" + ); + let _ = resp.bytes().await; + + // Metered API key → blocked by the same over-budget tracker. + let resp = client() + .post(format!("http://{}/anthropic/v1/messages", addr)) + .header("x-api-key", "sk-ant-api03-fake-metered-key") + .json(&json!({"model": "claude-fable-5"})) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 429, "metered traffic is dollar-capped"); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn sse_streaming_response_records_cost_from_message_start() { // Realistic Anthropic SSE payload with input_tokens in message_start and @@ -385,6 +453,7 @@ async fn budget_warning_does_not_block() { monthly_usd: 0.0, warn_percent: 80, per_session_usd: 0.0, + enforce_on_plan: false, })); budget.record(9.50); @@ -426,15 +495,17 @@ async fn loop_detection_blocks_after_threshold_identical_requests() { .mount(&mock) .await; - // Detector tuned to block on the 3rd identical request within 60s. + // Detector tuned so that once 2 identical requests have *succeeded* (been + // recorded by the tee on a 2xx), the next identical request is blocked. + // Arrivals are recorded on the response path now (B-C2), so the test + // settles between requests to let each recording land before the next peek. let detector = Arc::new(burnwall::budget::LoopDetector::new( burnwall::budget::LoopConfig { enabled: true, - max_identical_requests: 3, + max_identical_requests: 2, window_seconds: 60, max_cost_per_window: 0.0, // disable cost-spiral for this test cost_spiral_enforce: false, - hash_prefix_bytes: 200, }, )); @@ -457,7 +528,8 @@ async fn loop_detection_blocks_after_threshold_identical_requests() { let body = json!({"model": "claude-haiku-4-5", "messages": [{"role": "user", "content": "hi"}]}); - // First two: forwarded + // First two: forwarded. Settle after each so the tee records the arrival + // (on the 2xx) before the next request's pre-forward peek. for i in 1..=2 { let resp = client() .post(format!("http://{}/anthropic/v1/messages", addr)) @@ -467,6 +539,7 @@ async fn loop_detection_blocks_after_threshold_identical_requests() { .unwrap(); assert_eq!(resp.status(), 200, "request {} should pass", i); let _ = resp.bytes().await; // drain + settle().await; } // Third identical: blocked @@ -507,6 +580,69 @@ async fn loop_detection_blocks_after_threshold_identical_requests() { assert_eq!(successful[0].request_hash, successful[1].request_hash); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn accept_encoding_is_not_forwarded_upstream() { + // Regression: when the client's `accept-encoding` (Claude Code sends + // `gzip, br, zstd`) reached the upstream, the response came back + // compressed and the tee couldn't parse usage from it — every successful + // request was silently invisible to cost tracking and coverage. + let mock = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "msg", + "model": "claude-haiku-4-5", + "usage": {"input_tokens": 10, "output_tokens": 5} + }))) + .mount(&mock) + .await; + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = AppState { + upstream_anthropic: mock.uri(), + upstream_openai: "http://127.0.0.1:1".to_string(), + http_client: reqwest::Client::new(), + security: Arc::new(SecurityEngine::with_defaults()), + budget: Arc::new(BudgetTracker::with_defaults()), + loop_detector: Arc::new(LoopDetector::with_defaults()), + storage: storage.clone(), + cache_injection: false, + upstream_google: "http://127.0.0.1:1".to_string(), + resilience: Default::default(), + otel: None, + }; + let addr = spawn_proxy(state).await; + + let resp = client() + .post(format!("http://{}/anthropic/v1/messages", addr)) + .header("accept-encoding", "gzip, br, zstd") + .json(&json!({ + "model": "claude-haiku-4-5", + "messages": [{"role": "user", "content": "hi"}] + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + let _ = resp.bytes().await; + + settle().await; + + let received = mock.received_requests().await.unwrap(); + assert_eq!(received.len(), 1); + assert!( + received[0].headers.get("accept-encoding").is_none(), + "accept-encoding must be stripped so the upstream replies in identity encoding" + ); + + // With a parseable (identity) body, the tee records the request. + let rows = storage.requests_for_date(&today()).unwrap(); + assert_eq!(rows.len(), 1, "the forwarded request must be recorded"); + assert!(!rows[0].blocked); + assert_eq!(rows[0].input_tokens, 10); + assert_eq!(rows[0].output_tokens, 5); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn security_log_redact_details_strips_rule_from_storage() { use burnwall::security::{Ruleset, SecurityEngine}; @@ -598,7 +734,6 @@ async fn distinct_requests_dont_trip_loop_detector() { window_seconds: 60, max_cost_per_window: 0.0, cost_spiral_enforce: false, - hash_prefix_bytes: 200, }, )), storage: storage.clone(), @@ -877,9 +1012,9 @@ async fn gemini_request_records_cost_and_latency() { assert_eq!(rows[0].output_tokens, 300); // 200 + 100 thoughts assert_eq!(rows[0].http_status, Some(200)); assert!(rows[0].latency_ms.is_some(), "latency recorded"); - // gemini-2.5-flash: 512*0.30 + 1536*0.075 + 300*2.50, /1M = 0.0010188 + // gemini-2.5-flash: 512*0.30 + 1536*0.03 + 300*2.50, /1M = 0.00094968 assert!( - (rows[0].cost_usd - 0.0010188).abs() < 1e-7, + (rows[0].cost_usd - 0.00094968).abs() < 1e-7, "got {}", rows[0].cost_usd ); diff --git a/tests/unit/parser_test.rs b/tests/unit/parser_test.rs index c4c1ca2..2bd16d4 100644 --- a/tests/unit/parser_test.rs +++ b/tests/unit/parser_test.rs @@ -121,6 +121,111 @@ fn openai_invalid_json_returns_error() { assert!(openai::parse(b"").is_err()); } +// ──────────────────────── OpenAI Responses API ────────────────────────── + +#[test] +fn openai_responses_api_body_parses_input_output_and_cached() { + // /v1/responses (Codex CLI default) names the usage fields + // input_tokens/output_tokens/input_tokens_details — same semantics as + // Chat Completions (input includes the cached portion), different names. + let body = br#"{ + "id": "resp_abc123", + "object": "response", + "status": "completed", + "model": "gpt-5.4-codex", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "ok"}]}], + "usage": { + "input_tokens": 2048, + "input_tokens_details": {"cached_tokens": 1536}, + "output_tokens": 256, + "output_tokens_details": {"reasoning_tokens": 64}, + "total_tokens": 2304 + } + }"#; + let parsed = openai::parse(body).expect("parse Responses API body"); + + // input_tokens=2048, cached=1536 → non-cached input=512, cache_read=1536 + assert_eq!(parsed.model, "gpt-5.4-codex"); + assert_eq!( + parsed.usage, + TokenUsage { + input_tokens: 512, + output_tokens: 256, + cache_creation_tokens: 0, + cache_read_tokens: 1536, + } + ); + + // The proxy tee goes through parse_any — same result. + assert_eq!(openai::parse_any(body), Some(parsed)); +} + +#[test] +fn openai_responses_api_sse_reads_usage_from_completed_event() { + // Responses API streaming nests model/usage under `response` in typed + // events; usage arrives on the final `response.completed` event. + let sse = "event: response.created\n\ +data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.4-codex\",\"status\":\"in_progress\",\"usage\":null}}\n\ +\n\ +event: response.output_text.delta\n\ +data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hello\"}\n\ +\n\ +event: response.completed\n\ +data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.4-codex\",\"status\":\"completed\",\"usage\":{\"input_tokens\":1000,\"input_tokens_details\":{\"cached_tokens\":400},\"output_tokens\":50,\"total_tokens\":1050}}}\n\n"; + + let parsed = openai::parse_sse(sse.as_bytes()).expect("sse parse"); + assert_eq!(parsed.model, "gpt-5.4-codex"); + assert_eq!( + parsed.usage, + TokenUsage { + input_tokens: 600, + output_tokens: 50, + cache_creation_tokens: 0, + cache_read_tokens: 400, + } + ); +} + +#[test] +fn openai_chat_completions_still_parses_via_parse_any() { + // The Responses API support must not disturb the Chat Completions path + // the tee already relies on. + let parsed = openai::parse_any(&fixture("openai_cached.json")).expect("parse_any"); + assert_eq!(parsed.model, "gpt-5.4-2026-01-15"); + assert_eq!( + parsed.usage, + TokenUsage { + input_tokens: 512, + output_tokens: 512, + cache_creation_tokens: 0, + cache_read_tokens: 1536, + } + ); +} + +#[test] +fn openai_all_zero_usage_returns_none_from_parse_any() { + // Every Usage field is #[serde(default)], so an unrecognized usage shape + // deserializes "successfully" with zero tokens. parse_any must treat that + // as a parse failure (None → tee warns) instead of recording a $0 row. + let empty_usage = br#"{"model":"gpt-5.4","usage":{}}"#; + assert_eq!(openai::parse_any(empty_usage), None); + + let unknown_shape = br#"{"model":"gpt-5.4","usage":{"weird_tokens":123}}"#; + assert_eq!(openai::parse_any(unknown_shape), None); +} + +#[test] +fn openai_zero_output_with_nonzero_input_still_parses() { + // The all-zero guard must not reject legitimate edge cases: a response + // that billed input but produced no output tokens is still a real, + // billable response. + let body = br#"{"model":"gpt-5.4","usage":{"prompt_tokens":300,"completion_tokens":0}}"#; + let parsed = openai::parse_any(body).expect("nonzero input must parse"); + assert_eq!(parsed.usage.input_tokens, 300); + assert_eq!(parsed.usage.output_tokens, 0); +} + // ──────────────────────────────── Google ──────────────────────────────── #[test] From 0f79165b49a0aa27ee67e4cd90fdfd0fe24aaa03 Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:38:25 -0400 Subject: [PATCH 5/9] storage/logscrape: schema versioning, race-safe migrations, off-by-one + corpus-rescan fixes busy_timeout set before the WAL switch and duplicate-column tolerated (first-open races); PRAGMA user_version stamped and a newer-than-supported DB refused. daily_totals off-by-one corrected to match sibling window queries (history --days 7 now shows 7), history clamps days>=1. Logscrape prunes by file mtime and streams lines instead of slurping whole multi-MB session files, so status/waste no longer re-parse the entire corpus each call. Adds total_cost_for_month for the budget cap. --- src/cli/history.rs | 11 ++-- src/logscrape/aider.rs | 20 ++++++- src/logscrape/claude_code.rs | 22 ++++--- src/logscrape/codex.rs | 53 ++++++++++------ src/logscrape/mod.rs | 113 +++++++++++++++++++++++++++++++---- src/logscrape/opencode.rs | 10 +++- src/storage/mod.rs | 47 +++++++++++++-- src/storage/repository.rs | 22 ++++++- tests/unit/logscrape_test.rs | 108 +++++++++++++++++++++++++++++++++ tests/unit/storage_test.rs | 34 +++++++++++ 10 files changed, 388 insertions(+), 52 deletions(-) diff --git a/src/cli/history.rs b/src/cli/history.rs index 462edd5..aeda9e4 100644 --- a/src/cli/history.rs +++ b/src/cli/history.rs @@ -68,8 +68,11 @@ fn days_in_month(year: i32, month: u32) -> u32 { } pub fn run_cmd(args: HistoryArgs) -> anyhow::Result<()> { + // A non-positive --days would produce an invalid SQLite date modifier + // and a silently empty table — clamp to at least one day (today). + let days = args.days.max(1); let storage = Arc::new(Storage::open_default().context("opening storage")?); - let totals = storage.daily_totals(args.days)?; + let totals = storage.daily_totals(days)?; let cfg_path = config::default_path()?; let cfg = config::load_or_default(&cfg_path).context("loading config")?; @@ -78,7 +81,7 @@ pub fn run_cmd(args: HistoryArgs) -> anyhow::Result<()> { let mut out = std::io::stdout().lock(); if args.json { let value = serde_json::json!({ - "days": args.days, + "days": days, "rows": totals.iter().map(|t| serde_json::json!({ "date": t.date, "total_cost_usd": t.total_cost, @@ -102,8 +105,8 @@ pub fn run_cmd(args: HistoryArgs) -> anyhow::Result<()> { writeln!( out, "📅 Last {} day{}", - args.days, - if args.days == 1 { "" } else { "s" } + days, + if days == 1 { "" } else { "s" } )?; if totals.is_empty() { writeln!(out, " (no data)")?; diff --git a/src/logscrape/aider.rs b/src/logscrape/aider.rs index f96664e..d3f1423 100644 --- a/src/logscrape/aider.rs +++ b/src/logscrape/aider.rs @@ -31,6 +31,7 @@ //! the parser yields nothing. use std::path::PathBuf; +use std::time::SystemTime; use chrono::DateTime; use serde_json::Value; @@ -48,13 +49,26 @@ pub fn parse_str(contents: &str) -> Vec { /// Read and parse the Aider analytics log. Fail-open: returns empty if the /// file is absent or unreadable (analytics off, or never run). pub fn collect() -> Vec { + collect_since(None) +} + +/// [`collect`] with an optional mtime cutoff: an analytics log untouched +/// since before the window start (minus the safety margin) is skipped +/// unread; otherwise it is streamed line by line, never slurped whole. +pub fn collect_since(cutoff: Option) -> Vec { let Some(path) = analytics_path() else { return Vec::new(); }; - let Ok(contents) = std::fs::read_to_string(&path) else { + if super::path_is_stale(&path, cutoff) { return Vec::new(); - }; - parse_str(&contents) + } + let mut out = Vec::new(); + super::for_each_line(&path, |line| { + if let Some(entry) = parse_line(line) { + out.push(entry); + } + }); + out } /// Path to Aider's analytics log. `BURNWALL_AIDER_ANALYTICS` overrides it diff --git a/src/logscrape/claude_code.rs b/src/logscrape/claude_code.rs index 7f4fc4e..79131e9 100644 --- a/src/logscrape/claude_code.rs +++ b/src/logscrape/claude_code.rs @@ -32,6 +32,7 @@ use std::collections::HashSet; use std::path::PathBuf; +use std::time::SystemTime; use chrono::{DateTime, Utc}; use serde_json::Value; @@ -62,25 +63,32 @@ pub fn parse_str(contents: &str) -> Vec { /// de-duplicated across files. Fail-open: returns empty if the log /// directory is absent or unreadable. pub fn collect() -> Vec { + collect_since(None) +} + +/// [`collect`] with an optional mtime cutoff: session files untouched since +/// before the window start (minus the safety margin) are skipped unread — +/// these files can run to 100MB+, so the lines are streamed, never slurped. +pub fn collect_since(cutoff: Option) -> Vec { let Some(root) = log_root() else { return Vec::new(); }; let mut seen: HashSet = HashSet::new(); let mut out = Vec::new(); - for path in super::find_jsonl_files(&root) { - let Ok(contents) = std::fs::read_to_string(&path) else { - continue; - }; - for turn in parse_str(&contents) { + for path in super::find_jsonl_files(&root, cutoff) { + super::for_each_line(&path, |line| { + let Some(turn) = parse_line(line) else { + return; + }; // Repeated (message.id, requestId) across files = the same API // call re-logged by a resumed/forked session — drop the repeat. if let Some(key) = turn.dedup_key { if !seen.insert(key) { - continue; + return; } } out.push(turn.entry); - } + }); } out } diff --git a/src/logscrape/codex.rs b/src/logscrape/codex.rs index 3ead11d..3b73884 100644 --- a/src/logscrape/codex.rs +++ b/src/logscrape/codex.rs @@ -24,6 +24,7 @@ //! events with no known model are skipped, never fatal. use std::path::{Path, PathBuf}; +use std::time::SystemTime; use chrono::{DateTime, Local, NaiveDate, Utc}; use serde_json::Value; @@ -39,22 +40,32 @@ const TOOL: &str = "codex"; pub fn parse_str(contents: &str, fallback_date: Option) -> Vec { let mut state = SessionState::default(); let mut out = Vec::new(); - for line in contents.lines() { - let Ok(value) = serde_json::from_str::(line) else { - continue; - }; - match value.get("type").and_then(Value::as_str) { - Some("turn_context") | Some("session_meta") => state.update_from(&value), - Some("event_msg") => { - if let Some(entry) = parse_token_count(&value, &state, fallback_date) { - out.push(entry); - } + parse_line_into(line, &mut state, fallback_date, &mut out); + } + out +} + +/// Absorb one rollout line: context lines update `state`, `token_count` +/// events append to `out`, everything else is skipped (fail-open). +fn parse_line_into( + line: &str, + state: &mut SessionState, + fallback_date: Option, + out: &mut Vec, +) { + let Ok(value) = serde_json::from_str::(line) else { + return; + }; + match value.get("type").and_then(Value::as_str) { + Some("turn_context") | Some("session_meta") => state.update_from(&value), + Some("event_msg") => { + if let Some(entry) = parse_token_count(&value, state, fallback_date) { + out.push(entry); } - _ => {} } + _ => {} } - out } /// The most recent session context — model, working directory, and session @@ -89,15 +100,23 @@ impl SessionState { /// Discover and parse every Codex rollout log under the log root. /// Fail-open: returns empty if the log directory is absent or unreadable. pub fn collect() -> Vec { + collect_since(None) +} + +/// [`collect`] with an optional mtime cutoff: rollout files untouched since +/// before the window start (minus the safety margin) are skipped unread; +/// the rest are streamed line by line, never slurped whole. +pub fn collect_since(cutoff: Option) -> Vec { let Some(root) = log_root() else { return Vec::new(); }; let mut out = Vec::new(); - for path in super::find_jsonl_files(&root) { - let Ok(contents) = std::fs::read_to_string(&path) else { - continue; - }; - out.extend(parse_str(&contents, date_from_path(&path))); + for path in super::find_jsonl_files(&root, cutoff) { + let fallback_date = date_from_path(&path); + let mut state = SessionState::default(); + super::for_each_line(&path, |line| { + parse_line_into(line, &mut state, fallback_date, &mut out); + }); } out } diff --git a/src/logscrape/mod.rs b/src/logscrape/mod.rs index 64fc668..7d821a7 100644 --- a/src/logscrape/mod.rs +++ b/src/logscrape/mod.rs @@ -37,9 +37,11 @@ pub mod codex; pub mod opencode; use std::collections::BTreeMap; +use std::io::BufRead; use std::path::{Path, PathBuf}; +use std::time::{Duration as StdDuration, SystemTime}; -use chrono::{DateTime, Local, Utc}; +use chrono::{DateTime, Local, NaiveDate, Utc}; use crate::pricing; use crate::providers::TokenUsage; @@ -130,29 +132,101 @@ pub fn collect_all() -> Vec { collect_selected(Tools::all()) } +/// [`collect_all`] with an mtime cutoff — see [`collect_selected_since`]. +pub fn collect_all_since(cutoff: Option) -> Vec { + collect_selected_since(Tools::all(), cutoff) +} + /// Collect entries only from the selected tools — honors the per-tool /// `[tools]` config switches so a disabled tool is never read. pub fn collect_selected(tools: Tools) -> Vec { + collect_selected_since(tools, None) +} + +/// [`collect_selected`] with an optional mtime cutoff: log files whose +/// mtime predates `cutoff` by more than [`MTIME_SAFETY_MARGIN`] are skipped +/// without being read — a file untouched since before the window started +/// cannot contribute rows inside it. `None` reads everything (the previous +/// behavior). +pub fn collect_selected_since(tools: Tools, cutoff: Option) -> Vec { let mut entries = Vec::new(); if tools.claude_code { - entries.extend(claude_code::collect()); + entries.extend(claude_code::collect_since(cutoff)); } if tools.codex { - entries.extend(codex::collect()); + entries.extend(codex::collect_since(cutoff)); } if tools.opencode { - entries.extend(opencode::collect()); + entries.extend(opencode::collect_since(cutoff)); } if tools.aider { - entries.extend(aider::collect()); + entries.extend(aider::collect_since(cutoff)); } entries } /// Scrape every supported tool's logs and aggregate the entries that fall -/// on `date` (a *local* `YYYY-MM-DD` string) by tool + model. +/// on `date` (a *local* `YYYY-MM-DD` string) by tool + model. Files whose +/// mtime predates `date` (minus the safety margin) are never read. pub fn scrape_for_date(date: &str) -> Vec { - aggregate(collect_all(), date) + aggregate(collect_all_since(cutoff_for_local_date(date)), date) +} + +/// How far past a window-start cutoff a file's mtime may lag before the +/// file is skipped unread. One day absorbs clock skew, coarse filesystem +/// timestamps, and tools that buffer writes — a file untouched for longer +/// than this before the window start cannot hold entries inside the window. +pub const MTIME_SAFETY_MARGIN: StdDuration = StdDuration::from_secs(24 * 60 * 60); + +/// Pure cutoff predicate: true when a file last modified at `mtime` cannot +/// contain entries at or after `cutoff` — i.e. the mtime predates the +/// window start by more than [`MTIME_SAFETY_MARGIN`]. +pub fn mtime_is_stale(mtime: SystemTime, cutoff: SystemTime) -> bool { + match cutoff.duration_since(mtime) { + Ok(gap) => gap > MTIME_SAFETY_MARGIN, + // mtime is at/after the cutoff — definitely fresh. + Err(_) => false, + } +} + +/// The window-start instant for a local `YYYY-MM-DD` date string — local +/// midnight of that date. `None` when the string doesn't parse (fail-open: +/// no pruning rather than wrong pruning). +pub fn cutoff_for_local_date(date: &str) -> Option { + let day = NaiveDate::parse_from_str(date, "%Y-%m-%d").ok()?; + let midnight = day.and_hms_opt(0, 0, 0)?.and_local_timezone(Local).earliest()?; + Some(SystemTime::from(midnight)) +} + +/// True when `path`'s mtime says the file cannot contribute entries at or +/// after `cutoff`. An unreadable mtime keeps the file (fail-open — never +/// drop data over a metadata hiccup); `cutoff == None` keeps everything. +pub(crate) fn path_is_stale(path: &Path, cutoff: Option) -> bool { + let Some(cutoff) = cutoff else { + return false; + }; + match std::fs::metadata(path).and_then(|m| m.modified()) { + Ok(mtime) => mtime_is_stale(mtime, cutoff), + Err(_) => false, + } +} + +/// Stream `path` line by line through `f`, without slurping the whole file +/// into memory (Claude Code session files can run to 100MB+). Fail-open per +/// line: a non-UTF-8 line is skipped, matching the "skip unparseable lines" +/// policy; any other I/O error stops reading the file, keeping the lines +/// already seen. +pub(crate) fn for_each_line(path: &Path, mut f: impl FnMut(&str)) { + let Ok(file) = std::fs::File::open(path) else { + return; + }; + for line in std::io::BufReader::new(file).lines() { + match line { + Ok(line) => f(&line), + Err(e) if e.kind() == std::io::ErrorKind::InvalidData => continue, + Err(_) => break, + } + } } /// Pure aggregation step, split out so tests can feed synthetic entries. @@ -213,14 +287,20 @@ pub fn subtotal(rows: &[ScrapeBreakdown]) -> f64 { /// Recursively collect `*.jsonl` files under `root`. See /// [`find_files_with_ext`]. -pub(crate) fn find_jsonl_files(root: &Path) -> Vec { - find_files_with_ext(root, "jsonl") +pub(crate) fn find_jsonl_files(root: &Path, cutoff: Option) -> Vec { + find_files_with_ext(root, "jsonl", cutoff) } -/// Recursively collect files with extension `ext` under `root`. Returns an -/// empty vec if `root` does not exist or cannot be read; unreadable -/// sub-entries are skipped (fail-open). -pub(crate) fn find_files_with_ext(root: &Path, ext: &str) -> Vec { +/// Recursively collect files with extension `ext` under `root`, pruning +/// files whose mtime predates `cutoff` by more than the safety margin (see +/// [`mtime_is_stale`]; `None` keeps everything). Returns an empty vec if +/// `root` does not exist or cannot be read; unreadable sub-entries are +/// skipped, and a file whose mtime can't be read is kept (fail-open). +pub(crate) fn find_files_with_ext( + root: &Path, + ext: &str, + cutoff: Option, +) -> Vec { let mut out = Vec::new(); let mut stack = vec![root.to_path_buf()]; while let Some(dir) = stack.pop() { @@ -236,6 +316,13 @@ pub(crate) fn find_files_with_ext(root: &Path, ext: &str) -> Vec { stack.push(path); } else if file_type.is_file() && path.extension().and_then(|e| e.to_str()) == Some(ext) { + if let Some(cutoff) = cutoff { + if let Ok(mtime) = entry.metadata().and_then(|m| m.modified()) { + if mtime_is_stale(mtime, cutoff) { + continue; + } + } + } out.push(path); } } diff --git a/src/logscrape/opencode.rs b/src/logscrape/opencode.rs index eb15288..b8762a3 100644 --- a/src/logscrape/opencode.rs +++ b/src/logscrape/opencode.rs @@ -32,6 +32,7 @@ //! block, or reports zero tokens, contributes nothing. use std::path::PathBuf; +use std::time::SystemTime; use chrono::{DateTime, Utc}; use serde_json::Value; @@ -44,11 +45,18 @@ const TOOL: &str = "opencode"; /// Discover and parse every OpenCode message file under the message root. /// Fail-open: returns empty if the directory is absent or unreadable. pub fn collect() -> Vec { + collect_since(None) +} + +/// [`collect`] with an optional mtime cutoff: message files untouched since +/// before the window start (minus the safety margin) are skipped unread. +/// Each file is one small JSON object (not JSONL), so whole-file reads stay. +pub fn collect_since(cutoff: Option) -> Vec { let Some(root) = message_root() else { return Vec::new(); }; let mut out = Vec::new(); - for path in super::find_files_with_ext(&root, "json") { + for path in super::find_files_with_ext(&root, "json", cutoff) { let Ok(contents) = std::fs::read_to_string(&path) else { continue; }; diff --git a/src/storage/mod.rs b/src/storage/mod.rs index cc2e698..3a60301 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -165,6 +165,11 @@ pub enum StorageError { Io(#[from] std::io::Error), #[error("home directory not found")] NoHomeDir, + #[error( + "database schema v{found} is newer than this binary supports (v{supported}) — \ + it was written by a newer Burnwall. Upgrade, or point BURNWALL_DATA_DIR elsewhere." + )] + SchemaTooNew { found: i64, supported: i64 }, } pub type Result = std::result::Result; @@ -228,11 +233,35 @@ impl Storage { /// write wait-and-retry instead of failing immediately with `SQLITE_BUSY`. /// Both are harmless on an in-memory database (journal mode stays `memory`). fn configure(conn: &Connection) -> Result<()> { - conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")?; + // Set `busy_timeout` FIRST, as its own statement, *before* the WAL switch + // (D-M6). The one-time DELETE→WAL conversion on the first launch after a + // WAL-introducing upgrade needs brief exclusivity; with no busy handler + // armed, a concurrent statusline/daemon open races it into an instant + // `SQLITE_BUSY` that aborts `burnwall start`. Arming the timeout first + // makes the loser wait-and-retry instead. + conn.execute_batch("PRAGMA busy_timeout=5000;")?; + conn.execute_batch("PRAGMA journal_mode=WAL;")?; Ok(()) } +/// Schema version this binary writes/understands. Bump on every migration so +/// an older binary can refuse a DB it would mis-read (D-M7). +const SCHEMA_VERSION: i64 = 1; + fn migrate(conn: &Connection) -> Result<()> { + // Refuse to open a DB stamped newer than we understand: an old binary + // running against a newer schema (after a rolled-back upgrade) silently + // mis-reading rows is the worst post-update failure. Additive migrations + // are still downgrade-safe today (version 0/1), so only a *strictly + // greater* stamp is fatal. + let on_disk: i64 = conn.query_row("PRAGMA user_version", [], |r| r.get(0))?; + if on_disk > SCHEMA_VERSION { + return Err(StorageError::SchemaTooNew { + found: on_disk, + supported: SCHEMA_VERSION, + }); + } + conn.execute_batch(SCHEMA)?; // Forward-add columns introduced after a table first shipped. Idempotent: // skipped when the column already exists (a DB created from the current @@ -246,6 +275,10 @@ fn migrate(conn: &Connection) -> Result<()> { // v0.7 observability: per-request upstream latency + HTTP status. ensure_column(conn, "requests", "latency_ms", "INTEGER")?; ensure_column(conn, "requests", "http_status", "INTEGER")?; + + if on_disk < SCHEMA_VERSION { + conn.execute_batch(&format!("PRAGMA user_version={SCHEMA_VERSION};"))?; + } Ok(()) } @@ -260,10 +293,14 @@ fn ensure_column(conn: &Connection, table: &str, column: &str, decl: &str) -> Re .any(|name| name == column); drop(stmt); if !present { - conn.execute( - &format!("ALTER TABLE {table} ADD COLUMN {column} {decl}"), - [], - )?; + match conn.execute(&format!("ALTER TABLE {table} ADD COLUMN {column} {decl}"), []) { + Ok(_) => {} + // Tolerate the check-then-ALTER race (D-M6): two processes opening + // at once can both see the column missing; the loser's ALTER fails + // with "duplicate column name", which is success for our purposes. + Err(e) if e.to_string().contains("duplicate column name") => {} + Err(e) => return Err(e.into()), + } } Ok(()) } diff --git a/src/storage/repository.rs b/src/storage/repository.rs index ab2484a..1e45b8d 100644 --- a/src/storage/repository.rs +++ b/src/storage/repository.rs @@ -284,6 +284,21 @@ impl Storage { }) } + /// Total spend for a local calendar month. `month` is a `YYYY-MM` string; + /// rows are bucketed by their local-time month so the boundary matches the + /// daily query and the user's clock. Powers the monthly budget cap (B-H2). + pub fn total_cost_for_month(&self, month: &str) -> Result { + self.with_conn(|conn| { + let cost: f64 = conn.query_row( + "SELECT COALESCE(SUM(cost_usd), 0.0) FROM requests + WHERE strftime('%Y-%m', timestamp, 'localtime') = ?1", + params![month], + |row| row.get(0), + )?; + Ok(cost) + }) + } + /// The most recent successful (non-blocked) request, if any. Powers the /// DB-sourced status ribbon (`burnwall watch` / editor bar): the last /// real turn's model, token counts, and cost. @@ -370,8 +385,11 @@ impl Storage { pub fn daily_totals(&self, days: i64) -> Result> { self.with_conn(|conn| { // `DATE('now', 'localtime', '-N days')` gives the local date N - // days ago. Bind `-N days` as a parameter, not concatenated. - let offset = format!("-{} days", days); + // days ago. A window of `days` days *includes* today, so the + // earliest included date is `days - 1` back — matching the other + // `*_since_days` queries. Bind `-N days` as a parameter, not + // concatenated. + let offset = format!("-{} days", days - 1); let mut stmt = conn.prepare( "SELECT DATE(timestamp, 'localtime') AS date, diff --git a/tests/unit/logscrape_test.rs b/tests/unit/logscrape_test.rs index 5c80523..9386bc2 100644 --- a/tests/unit/logscrape_test.rs +++ b/tests/unit/logscrape_test.rs @@ -4,6 +4,7 @@ use std::fs; use std::path::Path; use std::sync::Mutex; +use std::time::{Duration as StdDuration, SystemTime}; use chrono::{DateTime, Duration, Local, NaiveDate, Utc}; @@ -405,6 +406,113 @@ fn aggregate_empty_input_is_empty() { assert!(logscrape::aggregate(Vec::new(), &local_date(0)).is_empty()); } +// ──────────────────────── mtime cutoff pruning ──────────────────────── + +/// Rewind a file's mtime by `days` days from now. +fn age_file(path: &Path, days: u64) { + let mtime = SystemTime::now() - StdDuration::from_secs(days * 24 * 60 * 60); + let file = fs::OpenOptions::new().write(true).open(path).unwrap(); + file.set_modified(mtime).unwrap(); +} + +/// A window-start cutoff `days` days before now. +fn cutoff_days_ago(days: u64) -> SystemTime { + SystemTime::now() - StdDuration::from_secs(days * 24 * 60 * 60) +} + +#[test] +fn mtime_staleness_allows_a_one_day_margin_past_the_cutoff() { + let cutoff = SystemTime::now(); + let hour = StdDuration::from_secs(3600); + // At or after the cutoff → fresh. + assert!(!logscrape::mtime_is_stale(cutoff, cutoff)); + assert!(!logscrape::mtime_is_stale(cutoff + hour, cutoff)); + // Before the cutoff but within the 1-day safety margin → still fresh + // (clock skew / buffered writes must not drop in-window data). + assert!(!logscrape::mtime_is_stale(cutoff - 23 * hour, cutoff)); + // More than the margin before the cutoff → stale, skipped unread. + assert!(logscrape::mtime_is_stale(cutoff - 25 * hour, cutoff)); +} + +#[test] +fn cutoff_for_local_date_parses_dates_fail_open() { + // A valid local date maps to its local midnight: today's cutoff is in + // the past, and yesterday's is strictly earlier. + let today = logscrape::cutoff_for_local_date(&local_date(0)).expect("valid date"); + let yesterday = logscrape::cutoff_for_local_date(&local_date(-1)).expect("valid date"); + assert!(today <= SystemTime::now()); + assert!(yesterday < today); + // Garbage yields no cutoff — scrape everything rather than prune wrongly. + assert!(logscrape::cutoff_for_local_date("not-a-date").is_none()); + assert!(logscrape::cutoff_for_local_date("").is_none()); +} + +#[test] +fn claude_code_collect_since_prunes_files_older_than_the_window() { + let dir = tempfile::tempdir().unwrap(); + let sub = dir.path().join("project-a"); + fs::create_dir_all(&sub).unwrap(); + + // An old session file (mtime 10 days back) and a fresh one written now, + // with distinct dedup keys so pruning — not dedup — decides the count. + let old = sub.join("old.jsonl"); + fs::write(&old, fixture("claude_code_session.jsonl")).unwrap(); + age_file(&old, 10); + let fresh = sub.join("fresh.jsonl"); + fs::write( + &fresh, + r#"{"type":"assistant","timestamp":"2026-06-10T09:00:05.000Z","requestId":"req_fresh","sessionId":"sess_f","cwd":"/w","message":{"id":"msg_fresh","model":"claude-opus-4-7","usage":{"input_tokens":10,"output_tokens":5}}}"#, + ) + .unwrap(); + + let _guard = set_log_dir("BURNWALL_CLAUDE_LOG_DIR", dir.path()); + + // Window starts 2 days ago: the 10-day-old file cannot contribute rows + // inside it (even with the 1-day margin) and is skipped unread; the + // file modified today is parsed. + let entries = claude_code::collect_since(Some(cutoff_days_ago(2))); + assert_eq!(entries.len(), 1, "got {entries:?}"); + assert_eq!(entries[0].model, "claude-opus-4-7"); + assert_eq!(entries[0].session_id.as_deref(), Some("sess_f")); + + // No cutoff preserves the old read-everything behavior: + // 3 deduped turns from the old file + 1 fresh. + assert_eq!(claude_code::collect_since(None).len(), 4); +} + +#[test] +fn aider_collect_since_skips_a_stale_analytics_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("analytics.jsonl"); + fs::write(&path, fixture("aider_analytics.jsonl")).unwrap(); + age_file(&path, 10); + + let _guard = set_log_dir("BURNWALL_AIDER_ANALYTICS", &path); + + // The analytics log was last touched well before the window → skipped. + assert!(aider::collect_since(Some(cutoff_days_ago(2))).is_empty()); + // No cutoff still reads it (previous behavior preserved). + assert_eq!(aider::collect_since(None).len(), 2); + // A file touched today survives the same cutoff. + age_file(&path, 0); + assert_eq!(aider::collect_since(Some(cutoff_days_ago(2))).len(), 2); +} + +#[test] +fn codex_collect_since_prunes_stale_rollouts() { + let dir = tempfile::tempdir().unwrap(); + let day = dir.path().join("2026").join("05").join("14"); + fs::create_dir_all(&day).unwrap(); + let rollout = day.join("rollout-abc.jsonl"); + fs::write(&rollout, fixture("codex_session.jsonl")).unwrap(); + age_file(&rollout, 10); + + let _guard = set_log_dir("BURNWALL_CODEX_LOG_DIR", dir.path()); + assert!(codex::collect_since(Some(cutoff_days_ago(2))).is_empty()); + // Streaming without a cutoff parses the same 3 events as before. + assert_eq!(codex::collect_since(None).len(), 3); +} + #[test] fn subtotal_sums_row_costs() { let rows = logscrape::aggregate( diff --git a/tests/unit/storage_test.rs b/tests/unit/storage_test.rs index e6b89cb..96ba946 100644 --- a/tests/unit/storage_test.rs +++ b/tests/unit/storage_test.rs @@ -242,6 +242,40 @@ fn daily_totals_groups_by_date_and_aggregates() { assert!((totals[1].total_cost - 0.20).abs() < 1e-9); } +#[test] +fn daily_totals_one_day_window_returns_only_today() { + let storage = Storage::open_in_memory().unwrap(); + + // One row today, one yesterday. A `days = 1` window means *today only* — + // the same inclusive-of-today convention as the `*_since_days` queries + // (regression test for the off-by-one that made `history --days 7` + // print 8 days). + let mut today_row = + RequestRecord::successful("anthropic", "claude-sonnet-4-6", &sample_usage(), 0.05, None); + today_row.timestamp = local_noon(0); + storage.insert_request(&today_row).unwrap(); + + let mut yesterday_row = + RequestRecord::successful("openai", "gpt-5.4", &sample_usage(), 0.20, None); + yesterday_row.timestamp = local_noon(-1); + storage.insert_request(&yesterday_row).unwrap(); + + let totals = storage.daily_totals(1).unwrap(); + assert_eq!( + totals.len(), + 1, + "1-day window must hold today only, got {totals:?}" + ); + assert_eq!(totals[0].date, local_date(0)); + assert_eq!(totals[0].total_requests, 1); + assert!((totals[0].total_cost - 0.05).abs() < 1e-9); + + // A 2-day window picks yesterday back up. + let totals = storage.daily_totals(2).unwrap(); + assert_eq!(totals.len(), 2); + assert_eq!(totals[1].date, local_date(-1)); +} + // ─────────────────────────── Security events ─────────────────────────── #[test] From b9fbb03f26a581dbd96aa33d91c1b7fe122697c6 Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:38:37 -0400 Subject: [PATCH 6/9] mcp/audit: prose-safe scanning, approval-reset UX, key-loss guard, hardening MCP watcher uses prose-safe scoping (only shell-tool args get command checks) and strips accept-encoding so gzip can't blind the firewall. Rug-pull re-pend now keys on the schema fingerprint (description-only changes warn, don't revoke); approval/deny 403s are proper JSON-RPC errors naming the remediation command; tools/list reads are timeout-bounded. Audit: a lost/changed key is detected and refuses to seal (with a rekey command) instead of forever reporting TAMPERED; seal is transactional; SARIF results carry locations; query strings stripped before persist; pack ids validated; rules fetch compares against the TOFU pin. --- src/audit/mod.rs | 247 ++++++++++++++++++++++-- src/audit/sarif.rs | 12 ++ src/cli/audit.rs | 28 +++ src/cli/mcp_watch.rs | 12 +- src/cli/rules.rs | 27 ++- src/mcp/firewall.rs | 31 ++- src/mcp/mod.rs | 183 +++++++++++++++--- tests/integration/mcp_watch_test.rs | 282 +++++++++++++++++++++++++++- tests/unit/audit_test.rs | 169 +++++++++++++++++ tests/unit/rulepack_test.rs | 18 ++ 10 files changed, 951 insertions(+), 58 deletions(-) create mode 100644 tests/unit/audit_test.rs diff --git a/src/audit/mod.rs b/src/audit/mod.rs index a8d3a5b..db0b04d 100644 --- a/src/audit/mod.rs +++ b/src/audit/mod.rs @@ -22,7 +22,7 @@ pub mod aibom; pub mod sarif; -use std::path::Path; +use std::path::{Path, PathBuf}; use ed25519_dalek::{Signature, Signer, SigningKey}; use sha2::{Digest as _, Sha256}; @@ -43,6 +43,11 @@ pub enum AuditError { Storage(#[from] crate::storage::StorageError), #[error("audit signing key is malformed (expected 32 bytes, found {0})")] BadKey(usize), + #[error( + "audit key changed or lost — existing chain was signed by {old_key}…; new receipts \ + would fork it. Run `burnwall audit rekey` to start a new chain segment." + )] + KeyChanged { old_key: String }, } pub type Result = std::result::Result; @@ -50,6 +55,14 @@ pub type Result = std::result::Result; /// Holds the local Ed25519 signing key and seals/verifies receipts. pub struct AuditChain { key: SigningKey, + /// True when `open()` had to generate a fresh keypair because the key file + /// was missing. Combined with the chain-pubkey sidecar this lets `seal` + /// refuse to silently fork a chain whose original key was lost (M-H1). + regenerated: bool, + /// Sidecar recording the hex public key the existing chain was signed + /// with (`.pub`, next to the key). Written on first seal; + /// compared on every later seal. + chain_pub_path: PathBuf, } impl AuditChain { @@ -61,6 +74,7 @@ impl AuditChain { /// Load (or, if absent, generate) the signing key at `path`. pub fn open(path: &Path) -> Result { + let mut regenerated = false; let key = if path.exists() { let bytes = std::fs::read(path)?; let seed: [u8; 32] = bytes @@ -75,9 +89,130 @@ impl AuditChain { } std::fs::write(path, key.to_bytes())?; set_key_perms(path)?; + regenerated = true; key }; - Ok(Self { key }) + Ok(Self { + key, + regenerated, + chain_pub_path: path.with_extension("pub"), + }) + } + + /// The chain public key recorded by an earlier seal, if any. + fn stored_chain_pubkey(&self) -> Option { + std::fs::read_to_string(&self.chain_pub_path) + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + } + + /// Record the current key as the chain's public key. + fn record_chain_pubkey(&self) -> Result<()> { + std::fs::write(&self.chain_pub_path, self.public_key_hex())?; + Ok(()) + } + + /// M-H1 guard: refuse to extend an existing chain with a key that is not + /// the one the chain was signed with. Without this, a lost key file would + /// silently regenerate and every receipt sealed from then on would make + /// `verify` report the whole chain TAMPERED. + fn guard_key_continuity(&self, storage: &Storage) -> Result<()> { + let current = self.public_key_hex(); + let stored = self.stored_chain_pubkey(); + if storage.last_receipt_hash()?.is_some() { + match &stored { + Some(stored) if *stored != current => { + return Err(AuditError::KeyChanged { + old_key: stored.chars().take(8).collect(), + }); + } + Some(_) => {} + None => { + // Legacy chain sealed before the sidecar existed. If the + // key file went missing (regenerated) the fresh key cannot + // have signed the existing tail — check the tail signature + // rather than trusting our luck. + if self.regenerated && !self.tail_signature_matches(storage)? { + return Err(AuditError::KeyChanged { + old_key: "an unknown key".to_string(), + }); + } + } + } + } + // Continuity holds (or the chain is empty): pin the key the next + // receipts will be signed with, so a future key loss is detectable. + if stored.as_deref() != Some(current.as_str()) { + self.record_chain_pubkey()?; + } + Ok(()) + } + + /// Does the chain tail's Ed25519 signature verify under the current key? + /// `true` for an empty chain. + fn tail_signature_matches(&self, storage: &Storage) -> Result { + let tail: Option<(String, String)> = storage.with_conn(|conn| { + use rusqlite::OptionalExtension as _; + Ok(conn + .query_row( + "SELECT hash, signature FROM audit_receipts ORDER BY seq DESC LIMIT 1", + [], + |row| Ok((row.get(0)?, row.get(1)?)), + ) + .optional()?) + })?; + let Some((hash, signature)) = tail else { + return Ok(true); + }; + Ok(decode_hex(&signature) + .and_then(|b| Signature::from_slice(&b).ok()) + .map(|sig| { + self.key + .verifying_key() + .verify_strict(hash.as_bytes(), &sig) + .is_ok() + }) + .unwrap_or(false)) + } + + /// Deliberately start a new chain segment under the current key after the + /// previous key was lost or replaced (`burnwall audit rekey`). Archives the + /// closing segment (old public key, chain head, receipt count) next to the + /// sidecar, then records the current key so `seal` can resume. + pub fn rekey(&self, storage: &Storage) -> Result { + let old_key = self.stored_chain_pubkey(); + let chain_head = storage.last_receipt_hash()?; + let receipts: u64 = storage.with_conn(|conn| { + Ok(conn.query_row("SELECT COUNT(*) FROM audit_receipts", [], |row| row.get(0))?) + })?; + + // Append-only archive of closed segments — the external record of + // where each key's coverage ends, so an auditor can still verify the + // old segment against the old public key. + let archive = self.chain_pub_path.with_file_name("audit_chain_segments.log"); + let line = format!( + "{} closed-segment pubkey={} head={} receipts={}\n", + chrono::Utc::now().to_rfc3339(), + old_key.as_deref().unwrap_or("unknown"), + chain_head.as_deref().unwrap_or(GENESIS_HASH), + receipts, + ); + use std::io::Write as _; + std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&archive)? + .write_all(line.as_bytes())?; + + self.record_chain_pubkey()?; + Ok(RekeyReport { + old_key, + new_key: self.public_key_hex(), + chain_head, + receipts, + archive, + }) } /// The verifying (public) key, hex-encoded. Safe to publish — it lets a @@ -96,10 +231,11 @@ impl AuditChain { /// Seal every not-yet-sealed request + security event into the chain, in /// chronological order. Idempotent: rows already sealed are skipped (the /// `audit_receipts.UNIQUE(source, source_id)` constraint backs this). + /// + /// Refuses outright when the local key is not the one the existing chain + /// was signed with (M-H1) — see [`AuditError::KeyChanged`]. pub fn seal(&self, storage: &Storage) -> Result { - let mut prev = storage - .last_receipt_hash()? - .unwrap_or_else(|| GENESIS_HASH.to_string()); + self.guard_key_continuity(storage)?; let mut pending: Vec = Vec::new(); for r in storage.unsealed_requests()? { @@ -117,28 +253,86 @@ impl AuditChain { .then_with(|| a.source_id().cmp(&b.source_id())) }); + // M-M3: read-the-tail + append must be one atomic unit. Two concurrent + // `seal` runs (e.g. a cron'd seal racing `audit pack`) could otherwise + // both read the same tail hash and append receipts with the same + // `prev_hash` — a fork that `verify` would flag forever. An IMMEDIATE + // transaction takes the SQLite write lock up front; the loser waits + // (busy_timeout) and then re-reads the new tail, skipping any rows the + // winner already sealed. + let sealed = storage.with_conn(|conn| { + conn.execute_batch("BEGIN IMMEDIATE")?; + match self.seal_in_txn(conn, &pending) { + Ok(sealed) => { + conn.execute_batch("COMMIT")?; + Ok(sealed) + } + Err(e) => { + let _ = conn.execute_batch("ROLLBACK"); + Err(e) + } + } + })?; + Ok(SealReport { sealed }) + } + + /// The seal loop body, run while holding the SQLite write lock. Uses the + /// raw connection (not the `Storage` helpers, which would re-lock). + fn seal_in_txn( + &self, + conn: &rusqlite::Connection, + pending: &[Pending], + ) -> crate::storage::Result { + use rusqlite::OptionalExtension as _; + let mut prev: String = conn + .query_row( + "SELECT hash FROM audit_receipts ORDER BY seq DESC LIMIT 1", + [], + |row| row.get(0), + ) + .optional()? + .unwrap_or_else(|| GENESIS_HASH.to_string()); + let mut sealed = 0u64; - for p in &pending { + for p in pending { + // A concurrent sealer may have sealed this row between our pending + // scan and taking the write lock — skip it instead of forking. + let already: Option = conn + .query_row( + "SELECT 1 FROM audit_receipts WHERE source = ?1 AND source_id = ?2", + rusqlite::params![p.source(), p.source_id()], + |row| row.get(0), + ) + .optional()?; + if already.is_some() { + continue; + } let content_hash = sha256_hex(p.canonical().as_bytes()); let hash = link_hash(&prev, &content_hash); let signature = hex(&self.key.sign(hash.as_bytes()).to_bytes()); - storage.insert_receipt( - p.source(), - p.source_id(), - &p.timestamp().to_rfc3339(), - p.action(), - p.provider(), - p.model(), - p.detail().as_deref(), - &content_hash, - &prev, - &hash, - &signature, + conn.execute( + "INSERT INTO audit_receipts + (source, source_id, timestamp, action, provider, model, detail, + content_hash, prev_hash, hash, signature) + VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11)", + rusqlite::params![ + p.source(), + p.source_id(), + p.timestamp().to_rfc3339(), + p.action(), + p.provider(), + p.model(), + p.detail(), + content_hash, + prev, + hash, + signature + ], )?; prev = hash; sealed += 1; } - Ok(SealReport { sealed }) + Ok(sealed) } /// Re-walk the chain: check each hash link, re-derive each `content_hash` @@ -225,6 +419,21 @@ pub struct SealReport { pub sealed: u64, } +/// Outcome of an `audit rekey` run. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RekeyReport { + /// Public key the closed segment was recorded under, if known. + pub old_key: Option, + /// Public key new receipts will be signed with. + pub new_key: String, + /// Hash of the last receipt in the closed segment (the segment boundary). + pub chain_head: Option, + /// Receipts in the closed segment. + pub receipts: u64, + /// Where the closed segment was archived. + pub archive: PathBuf, +} + /// Outcome of a `verify` run. #[derive(Debug, Clone, PartialEq)] pub enum VerifyReport { diff --git a/src/audit/sarif.rs b/src/audit/sarif.rs index b44fb77..7b970b5 100644 --- a/src/audit/sarif.rs +++ b/src/audit/sarif.rs @@ -34,6 +34,18 @@ pub fn build(events: &[SecurityEvent]) -> Value { "ruleId": e.event_type, "level": "error", "message": {"text": format!("Burnwall blocked a {} attempt: {}", e.event_type, e.details)}, + // GitHub code scanning rejects results without a location + // (M-M4). Security events have no source file, so emit a + // synthetic per-event URI; `region` is required alongside it + // by the upload validator. + "locations": [{ + "physicalLocation": { + "artifactLocation": { + "uri": format!("burnwall://security-events/{}", e.id.unwrap_or(0)), + }, + "region": {"startLine": 1}, + } + }], "properties": { "provider": e.provider, "model": e.model, diff --git a/src/cli/audit.rs b/src/cli/audit.rs index 92bfd58..5521616 100644 --- a/src/cli/audit.rs +++ b/src/cli/audit.rs @@ -30,6 +30,10 @@ pub enum AuditCommand { Seal, /// Verify the receipt chain — hashes, signatures, and live source rows. Verify, + /// Deliberately start a new chain segment under the current key after the + /// previous audit key was lost or replaced. Archives the old segment's + /// public key and chain head, then lets `seal` resume. + Rekey, /// Export the audit receipts. Export(ExportArgs), /// Export a CycloneDX AI Bill of Materials for the window. @@ -104,6 +108,30 @@ pub fn run_cmd(args: AuditArgs) -> anyhow::Result<()> { } } } + AuditCommand::Rekey => { + let chain = AuditChain::open_default().context("opening audit key")?; + let report = chain.rekey(&storage)?; + writeln!(out, "🔑 Started a new audit chain segment.")?; + writeln!( + out, + " Closed segment: {} receipt{} signed by {} (head {})", + report.receipts, + plural(report.receipts), + report.old_key.as_deref().unwrap_or("an unknown key"), + report + .chain_head + .as_deref() + .map(|h| &h[..h.len().min(8)]) + .unwrap_or("genesis"), + )?; + writeln!(out, " Segment record: {}", report.archive.display())?; + writeln!(out, " New public key: {}", report.new_key)?; + writeln!( + out, + " Receipts sealed before the rekey verify only against the archived key; \ + `burnwall audit seal` can now resume." + )?; + } AuditCommand::Export(a) => { let receipts = storage.all_receipts()?; let public_key = AuditChain::open_default().ok().map(|c| c.public_key_hex()); diff --git a/src/cli/mcp_watch.rs b/src/cli/mcp_watch.rs index 2ca4f59..c380ad0 100644 --- a/src/cli/mcp_watch.rs +++ b/src/cli/mcp_watch.rs @@ -129,11 +129,21 @@ pub async fn run_cmd(args: McpWatchArgs) -> anyhow::Result<()> { ); } + // M-C3: bounded timeouts so a hung upstream can never freeze the watcher + // (a fully-buffered `tools/list` against an un-timed client froze session + // init). No total-request timeout: tool calls can legitimately stream for + // a long time; `read_timeout` only fires when the connection goes silent. + let http_client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(10)) + .read_timeout(std::time::Duration::from_secs(60)) + .build() + .context("building upstream HTTP client")?; + let state = WatchState { upstream: args.upstream.clone().unwrap_or_default(), servers, require_approval, - http_client: reqwest::Client::new(), + http_client, storage, security, auto_approve: user_config.mcp.auto_approve.clone(), diff --git a/src/cli/rules.rs b/src/cli/rules.rs index d9dfba6..8c9d557 100644 --- a/src/cli/rules.rs +++ b/src/cli/rules.rs @@ -334,11 +334,28 @@ fn test(pack_ref: &str, file: &Path) -> anyhow::Result<()> { // ── add / revoke (third-party, TOFU) ─────────────────────────────────────── +/// M-M6: a pack id becomes a file name under the rules dir, so an id like +/// `..\..\x` would escape it. Reject anything but the registry id alphabet +/// before the id is ever joined to a path. +pub fn validate_pack_id(id: &str) -> anyhow::Result<()> { + let ok = !id.is_empty() + && id + .chars() + .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_' || c == '-'); + if !ok { + anyhow::bail!( + "invalid pack id '{id}' — ids may only contain lowercase letters, digits, '-' and '_'" + ); + } + Ok(()) +} + fn add(src: &Path, yes: bool) -> anyhow::Result<()> { let content = std::fs::read_to_string(src).with_context(|| format!("reading {}", src.display()))?; let pack = packs::RulePack::parse(&content).context("file did not parse as a valid rule pack")?; + validate_pack_id(&pack.id)?; let hash = packs::content_hash(content.as_bytes()); let store = Storage::open_default().context("opening storage")?; @@ -366,6 +383,7 @@ fn add(src: &Path, yes: bool) -> anyhow::Result<()> { } fn revoke(name: &str) -> anyhow::Result<()> { + validate_pack_id(name)?; let store = Storage::open_default().context("opening storage")?; let pin_removed = store.revoke_rule_pack(name)?; let dest = storage::data_dir() @@ -634,13 +652,19 @@ fn fetch(url: &str, sig_url: Option<&str>, extra: &[String], yes: bool) -> anyho let content = String::from_utf8(pack_bytes).context("pack is not valid UTF-8")?; let pack = packs::RulePack::parse(&content) .context("fetched file did not parse as a valid rule pack")?; + validate_pack_id(&pack.id)?; let hash = packs::content_hash(content.as_bytes()); + // M-M7: compare against the prior TOFU pin so a re-fetch that CHANGED the + // pack is flagged in the summary instead of looking like a fresh install. + let store = Storage::open_default().context("opening storage")?; + let prior = store.rule_pack_approved_hash(&pack.id)?; + println!( "📥 Fetched '{}' v{} — signature verified (publisher '{}').", pack.id, pack.version, signer ); - print_add_summary(&pack, None, &hash); + print_add_summary(&pack, prior.as_deref(), &hash); if !yes && !prompt_yes()? { println!("Aborted — '{}' not installed.", pack.id); @@ -653,7 +677,6 @@ fn fetch(url: &str, sig_url: Option<&str>, extra: &[String], yes: bool) -> anyho std::fs::create_dir_all(&dir).context("creating rules dir")?; let dest = dir.join(format!("{}.toml", pack.id)); std::fs::write(&dest, content.as_bytes()).context("installing pack file")?; - let store = Storage::open_default().context("opening storage")?; store.approve_rule_pack(&pack.id, &dest.to_string_lossy(), &hash)?; println!( "✅ Installed '{}' (publisher '{}'). It applies on the next `burnwall start`.", diff --git a/src/mcp/firewall.rs b/src/mcp/firewall.rs index b0fc918..c6b85bc 100644 --- a/src/mcp/firewall.rs +++ b/src/mcp/firewall.rs @@ -32,11 +32,18 @@ pub struct AdvertisedTool { pub name: String, pub description: String, /// Stable content fingerprint over name + description + input schema. - /// Used to detect silent post-approval changes ("rug pulls"). This is + /// A change *tripwire* over the tool's full advertised identity. This is /// FNV-1a: deterministic across runs and platforms (so persisted /// fingerprints stay comparable across binary upgrades), but it is a /// change *tripwire*, not a collision-resistant cryptographic hash. pub fingerprint: String, + /// Fingerprint over name + input schema ONLY (M-C2). This is the value + /// persisted by the watcher and the one whose change resets an approved + /// tool back to `pending`: a description-only edit (typo fix, version + /// bump in prose) must WARN but never silently revoke approval, while a + /// schema change alters what the tool can actually be asked to do and + /// therefore must force re-approval. + pub schema_fingerprint: String, /// The raw tool object, kept so the caller can re-scan it with the /// existing `SecurityEngine` (secret / path / command patterns). pub raw: Value, @@ -69,10 +76,12 @@ pub fn parse_tools_list(body: &[u8]) -> Vec { .to_string(); let schema = tool.get("inputSchema").cloned().unwrap_or(Value::Null); let fingerprint = fingerprint_tool(&name, &description, &schema); + let schema_fingerprint = fingerprint_schema(&name, &schema); Some(AdvertisedTool { name, description, fingerprint, + schema_fingerprint, raw: tool.clone(), }) }) @@ -145,15 +154,27 @@ fn is_hidden_char(c: char) -> bool { /// same. Hex-encoded for storage. fn fingerprint_tool(name: &str, description: &str, schema: &Value) -> String { let schema = serde_json::to_string(schema).unwrap_or_default(); - let mut hash: u64 = 0xcbf2_9ce4_8422_2325; - for part in [ + fnv1a_hex(&[ name.as_bytes(), b"\0", description.as_bytes(), b"\0", schema.as_bytes(), - ] { - for &byte in part { + ]) +} + +/// FNV-1a (64-bit) over name + canonicalised schema only — the persisted +/// fingerprint that drives enforce-mode re-pending (M-C2). Description text is +/// deliberately excluded; see [`AdvertisedTool::schema_fingerprint`]. +fn fingerprint_schema(name: &str, schema: &Value) -> String { + let schema = serde_json::to_string(schema).unwrap_or_default(); + fnv1a_hex(&[name.as_bytes(), b"\0", schema.as_bytes()]) +} + +fn fnv1a_hex(parts: &[&[u8]]) -> String { + let mut hash: u64 = 0xcbf2_9ce4_8422_2325; + for part in parts { + for &byte in *part { hash ^= byte as u64; hash = hash.wrapping_mul(0x0000_0100_0000_01b3); } diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 1246cbf..3d9c1e3 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -12,7 +12,8 @@ pub mod firewall; use std::convert::Infallible; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; +use std::time::Duration; use bytes::Bytes; use http_body_util::BodyExt as _; @@ -316,6 +317,15 @@ async fn handle( route.forward_path, query ); + // M-H2: never persist or log the query string — an upstream URI like + // `...?api_key=...` must not reach the database (mcp_events.upstream_uri + // is exported by `burnwall mcp export`). The full URI is still used for + // the forward itself; only the recorded copy is stripped. + let logged_uri = upstream_uri + .split('?') + .next() + .unwrap_or(&upstream_uri) + .to_string(); let (parts, body) = req.into_parts(); let body_bytes = match body.collect().await { @@ -338,13 +348,14 @@ async fn handle( let is_tools_list = method == Method::POST && parse_rpc_method(&body_bytes).as_deref() == Some("tools/list"); - // Security scan: the same engine the LLM proxy uses, applied to the - // raw JSON-RPC body. Walks every string leaf — that means `tools/call` - // arguments get the path / command / mount / secret denylist for free. - // A violation returns 403 and never forwards (mirrors the LLM proxy's - // 403 path); the `security_events` row gets `provider="mcp"` and the - // tool name when we have one, so `burnwall security` shows the source. - if let Some(violation) = state.security.scan(&body_bytes) { + // Security scan: the same engine the LLM proxy uses, but with MCP-aware + // scoping (M-C1). Command-shaped checks apply only to a `tools/call`'s + // `params.arguments`; the rest of the JSON-RPC envelope (and other methods) + // is treated as prose, so a memory note or issue title that merely mentions + // `rm -rf` / `~/.ssh` is not blocked. Data checks (secrets, DLP) still run + // everywhere. A violation returns 403 and never forwards; the + // `security_events` row gets `provider="mcp"` and the tool name. + if let Some(violation) = state.security.scan_mcp(&body_bytes) { warn!("🛡️ MCP BLOCKED: {}", violation.message()); let redact = state.security.rules().log_redact_details; let stored_details = if redact { @@ -376,7 +387,18 @@ async fn handle( if let Err(e) = state.storage.insert_security_event(&event) { error!("mcp security_event insert failed: {}", e); } - return Ok(error_response(StatusCode::FORBIDDEN, "auto_denied")); + // M-C2: a JSON-RPC error (not a bare body) so MCP clients render + // the reason instead of a generic transport failure. + return Ok(jsonrpc_error_response( + StatusCode::FORBIDDEN, + "auto_denied", + raw_rpc_id(&body_bytes), + format!( + "Burnwall: tool '{}' on '{}' is blocked by [mcp].auto_deny policy. \ + Remove the matching glob from [mcp].auto_deny in config.toml to allow it.", + call.name, route.server + ), + )); } } @@ -405,16 +427,35 @@ async fn handle( if let Err(e) = state.storage.insert_security_event(&event) { error!("mcp security_event insert failed: {}", e); } - return Ok(error_response(StatusCode::FORBIDDEN, "approval_required")); + // M-C2: a proper JSON-RPC error naming the exact remediation + // command, so the client surfaces it instead of a generic + // transport failure. + return Ok(jsonrpc_error_response( + StatusCode::FORBIDDEN, + "approval_required", + raw_rpc_id(&body_bytes), + format!( + "Burnwall: tool '{}' on '{}' awaits approval. Run: burnwall mcp approve {}", + call.name, route.server, route.server + ), + )); } } } + // Strip hop-by-hop headers AND `accept-encoding` (M-C4). The watcher + // inspects `tools/list` bodies for poisoning/rug-pull, and its HTTP client + // is built without decompression — so a forwarded `accept-encoding` lets + // the upstream gzip the body, blinding the firewall (and, in enforce mode, + // bricking it: nothing registers, so every call 403s with nothing to + // approve). Dropping it makes the upstream reply in identity encoding; the + // response still passes through byte-for-byte. Mirrors the LLM proxy fix. let mut outbound_headers = HeaderMap::new(); for (name, value) in parts.headers.iter() { - if !is_hop_by_hop(name.as_str()) { - outbound_headers.append(name.clone(), value.clone()); + if is_hop_by_hop(name.as_str()) || name.as_str().eq_ignore_ascii_case("accept-encoding") { + continue; } + outbound_headers.append(name.clone(), value.clone()); } let mut builder = state @@ -428,12 +469,12 @@ async fn handle( let upstream_resp = match builder.send().await { Ok(r) => r, Err(e) => { - warn!("mcp-watch upstream error for {}: {}", upstream_uri, e); + warn!("mcp-watch upstream error for {}: {}", logged_uri, e); // We still record the tool_call attempt with status 0 so // operators can spot upstream connectivity issues in the log. if let Some(call) = tool_call { let event = McpEvent::new(&call.name, call.id.as_deref(), 0) - .with_upstream_uri(&upstream_uri); + .with_upstream_uri(&logged_uri); if let Err(e) = state.storage.insert_mcp_event(&event) { error!("mcp_event insert failed: {}", e); } @@ -444,11 +485,11 @@ async fn handle( let status = upstream_resp.status(); let resp_headers = upstream_resp.headers().clone(); - debug!("mcp-watch ← {} {}", status.as_u16(), upstream_uri); + debug!("mcp-watch ← {} {}", status.as_u16(), logged_uri); if let Some(call) = tool_call { let event = McpEvent::new(&call.name, call.id.as_deref(), status.as_u16() as i64) - .with_upstream_uri(&upstream_uri); + .with_upstream_uri(&logged_uri); if let Err(e) = state.storage.insert_mcp_event(&event) { error!("mcp_event insert failed: {}", e); } @@ -457,16 +498,28 @@ async fn handle( // For `tools/list` we buffer the (small JSON) reply, run the firewall // inspection, then forward the exact same bytes — read-only, the response // is never altered. Every other shape streams straight through unbuffered. + // M-C3: the buffering is bounded by a hard 20s timeout so a stalled + // upstream (e.g. an SSE stream that never completes) cannot freeze the + // client's session init forever. The bytes are partially consumed by then, + // so pass-through is no longer possible — answer 504 instead of hanging. let body = if is_tools_list { - match upstream_resp.bytes().await { - Ok(bytes) => { - inspect_tools_list(&bytes, &state, &route.server); + match tokio::time::timeout(Duration::from_secs(20), upstream_resp.bytes()).await { + Ok(Ok(bytes)) => { + inspect_tools_list(&bytes, &state, &route.server, &route.upstream); streaming::full(bytes) } - Err(e) => { - warn!("mcp-watch upstream body error for {}: {}", upstream_uri, e); + Ok(Err(e)) => { + warn!("mcp-watch upstream body error for {}: {}", logged_uri, e); return Ok(error_response(StatusCode::BAD_GATEWAY, "upstream_error")); } + Err(_) => { + warn!( + "mcp-watch: tools/list body from {} did not complete within 20s — \ + answering 504 (body was partially consumed; pass-through impossible)", + logged_uri + ); + return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "upstream_timeout")); + } } } else { streaming::from_stream(upstream_resp.bytes_stream()) @@ -489,12 +542,22 @@ async fn handle( Ok(response.body(body).expect("response: build failed")) } +/// Last description seen per advertised tool, keyed by +/// `|/` (the upstream URL disambiguates watchers that +/// share a server name, e.g. several single-upstream instances in one +/// process). Process-local on purpose: the *persisted* state is the schema +/// fingerprint in `mcp_tools`, which drives enforce-mode re-pending; this map +/// only powers the advisory description-drift warning (M-C2), so losing it on +/// restart costs one missed warning, never an enforcement change. +static SEEN_DESCRIPTIONS: LazyLock> = + LazyLock::new(dashmap::DashMap::new); + /// Inspect a buffered `tools/list` reply for poisoned or silently-changed /// tool definitions. Read-only: findings are recorded as `security_events` /// (so `burnwall security` surfaces them) and the caller forwards the /// response bytes unchanged. Fail-open — a non-`tools/list` body yields no /// tools and no findings. -fn inspect_tools_list(body: &[u8], state: &WatchState, server: &str) { +fn inspect_tools_list(body: &[u8], state: &WatchState, server: &str, upstream: &str) { for tool in firewall::parse_tools_list(body) { // 1. Prompt-injection tells in the advertised name + description. let surface = format!("{} {}", tool.name, tool.description); @@ -516,20 +579,44 @@ fn inspect_tools_list(body: &[u8], state: &WatchState, server: &str) { } } - // 3. Rug pull — definition changed since we last fingerprinted it. - match state + // 3. Rug pull — the persisted fingerprint (name + inputSchema, M-C2) + // changed since we last saw this tool. Only a schema change resets + // an approved tool to 'pending' (via the storage layer): the schema + // is what the tool can actually be asked to do. + let schema_changed = match state .storage - .observe_mcp_tool(server, &tool.name, &tool.fingerprint) + .observe_mcp_tool(server, &tool.name, &tool.schema_fingerprint) { Ok(McpToolObservation::Changed) => { warn!( - "🛡️ MCP tool '{}' definition changed since last seen (possible rug pull)", - tool.name + "🛡️ MCP tool '{}' on server '{}' changed its input schema since last seen \ + (possible rug pull) — approval reset to pending", + tool.name, server + ); + record_mcp_security(state, "mcp_tool_changed", &tool.name, &tool.name); + true + } + Ok(_) => false, + Err(e) => { + error!("mcp_tools observe failed: {}", e); + false + } + }; + + // 4. Description drift (M-C2): a description-only change is recorded + // and warned about — descriptions are prompt-visible, so a swap is + // worth an operator's eyes — but it does NOT revoke approval. A + // routine version bump in prose must not re-pend every tool. + let desc_key = format!("{upstream}|{server}/{}", tool.name); + if let Some(prev) = SEEN_DESCRIPTIONS.insert(desc_key, tool.description.clone()) { + if prev != tool.description && !schema_changed { + warn!( + "MCP tool '{}' on server '{}' changed its description \ + (schema unchanged — approval kept)", + tool.name, server ); record_mcp_security(state, "mcp_tool_changed", &tool.name, &tool.name); } - Ok(_) => {} - Err(e) => error!("mcp_tools observe failed: {}", e), } } } @@ -557,6 +644,44 @@ fn error_response(status: StatusCode, kind: &str) -> Response { .expect("error_response: response builder failed") } +/// The raw JSON-RPC `id` of a request body (string, number, or null), +/// preserved as-is so an error response can echo it. `Null` when the body is +/// not parseable JSON or carries no id (a notification). +fn raw_rpc_id(body: &[u8]) -> Value { + let body = body.strip_prefix(b"\xef\xbb\xbf").unwrap_or(body); + serde_json::from_slice::(body) + .ok() + .and_then(|v| v.get("id").cloned()) + .unwrap_or(Value::Null) +} + +/// A blocked `tools/call` answered as a *proper JSON-RPC error* (M-C2), so MCP +/// clients show the message — which names the exact remediation command — +/// instead of a generic transport failure. The legacy `"type"` discriminator is +/// kept inside the error object for existing consumers of the 403 body. +fn jsonrpc_error_response( + status: StatusCode, + kind: &str, + id: Value, + message: String, +) -> Response { + let body = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "error": { + "code": -32000, + "message": message, + "type": kind, + }, + }); + let bytes = serde_json::to_vec(&body).unwrap_or_default(); + Response::builder() + .status(status) + .header("content-type", "application/json") + .body(streaming::full(Bytes::from(bytes))) + .expect("jsonrpc_error_response: response builder failed") +} + #[cfg(test)] mod policy_tests { use super::{glob_match, policy_matches}; diff --git a/tests/integration/mcp_watch_test.rs b/tests/integration/mcp_watch_test.rs index c82a827..eaad15e 100644 --- a/tests/integration/mcp_watch_test.rs +++ b/tests/integration/mcp_watch_test.rs @@ -449,7 +449,10 @@ async fn denied_command_in_tool_arguments_is_blocked() { let sec = storage.security_events_for_date(&today()).unwrap(); assert_eq!(sec.len(), 1); - assert_eq!(sec[0].event_type, "command_blocked"); + // `rm -rf /` is now caught by the shape-aware destructive detector rather + // than the literal deny list (S-C2 dropped the `rm` literals so scoped + // deletes like `rm -rf /tmp/x` aren't false-flagged). + assert_eq!(sec[0].event_type, "destructive_blocked"); assert_eq!(sec[0].provider.as_deref(), Some("mcp")); } @@ -477,8 +480,11 @@ async fn secret_pattern_in_tool_arguments_is_blocked() { "jsonrpc": "2.0", "method": "tools/call", "params": { + // A realistic (non-example) AWS key id — the canonical + // `AKIAIOSFODNN7EXAMPLE` is now exempted as a documentation key + // (S-C3), so use one that isn't. "name": "upload", - "arguments": {"body": "AKIAIOSFODNN7EXAMPLE"}, + "arguments": {"body": "AKIAIOSFODNN7REALKEY"}, }, "id": 13, })) @@ -492,6 +498,48 @@ async fn secret_pattern_in_tool_arguments_is_blocked() { assert_eq!(sec[0].event_type, "secret_detected"); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn prose_mentioning_denied_command_is_not_blocked() { + // M-C1: the MCP path must be prose-safe. A non-tools/call method, or + // free-text arguments that merely *mention* a denied command, must forward + // — not 403. Here a memory-note tool stores text containing "rm -rf /". + let upstream = MockServer::start().await; + Mock::given(method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_string("{}")) + .mount(&upstream) + .await; + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = WatchState::single_upstream( + upstream.uri(), + reqwest::Client::new(), + storage.clone(), + Arc::new(SecurityEngine::with_defaults()), + ); + let addr = spawn_watcher(state).await; + + // A prose note that mentions a dangerous command — the tool is a note + // store, the text is data, so this must pass through. + let resp = client() + .post(format!("http://{}/mcp/rpc", addr)) + .json(&json!({ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "create_memory", + "arguments": {"text": "Reminder: never run `rm -rf /` on the prod server."}, + }, + "id": 21, + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200, "prose mention must not be blocked"); + + let sec = storage.security_events_for_date(&today()).unwrap(); + assert!(sec.is_empty(), "no security event for a prose mention"); +} + // ─────────────────── Approval workflow / enforce mode (v0.6.5) ─────────────────── /// An enforce-mode watcher in front of `upstream` (single default route). @@ -590,6 +638,236 @@ async fn enforce_mode_forwards_an_approved_tool() { assert!(sec.iter().all(|e| e.event_type != "mcp_tool_unapproved")); } +// ─────────────────── M-C2: JSON-RPC error shape on 403 ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn enforce_mode_block_is_a_jsonrpc_error_naming_the_remedy() { + let upstream = MockServer::start().await; + Mock::given(method("POST")) + .respond_with(ResponseTemplate::new(200)) + .expect(0) + .mount(&upstream) + .await; + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = enforce_state(upstream.uri(), storage.clone()); + let addr = spawn_watcher(state).await; + + let resp = client() + .post(format!("http://{}/mcp", addr)) + .json(&json!({ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "read_file", "arguments": {"path": "ok.txt"}}, + "id": 42, + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 403); + + // The body must be a proper JSON-RPC error object — id echoed, code set, + // message naming the exact remediation command — so MCP clients render it + // instead of a generic transport failure. + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["jsonrpc"], "2.0"); + assert_eq!(body["id"], 42, "request id must be echoed as-is"); + assert_eq!(body["error"]["code"], -32000); + let msg = body["error"]["message"].as_str().unwrap(); + assert!( + msg.contains("tool 'read_file' on 'default' awaits approval"), + "got: {msg}" + ); + assert!( + msg.contains("burnwall mcp approve default"), + "message must name the remediation command, got: {msg}" + ); + // Legacy discriminator preserved for existing consumers of the 403 body. + assert_eq!(body["error"]["type"], "approval_required"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auto_denied_block_is_a_jsonrpc_error_with_string_id_echo() { + let upstream = MockServer::start().await; + Mock::given(method("POST")) + .respond_with(ResponseTemplate::new(200)) + .expect(0) + .mount(&upstream) + .await; + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = WatchState { + upstream: upstream.uri(), + servers: Vec::new(), + require_approval: false, + http_client: reqwest::Client::new(), + storage: storage.clone(), + security: Arc::new(SecurityEngine::with_defaults()), + auto_approve: Vec::new(), + auto_deny: vec!["default/evil_*".to_string()], + }; + let addr = spawn_watcher(state).await; + + let resp = client() + .post(format!("http://{}/mcp", addr)) + .json(&json!({ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "evil_exec", "arguments": {}}, + "id": "abc-1", + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 403); + + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["jsonrpc"], "2.0"); + assert_eq!(body["id"], "abc-1", "string ids must echo as strings"); + assert_eq!(body["error"]["code"], -32000); + assert_eq!(body["error"]["type"], "auto_denied"); + assert!(body["error"]["message"] + .as_str() + .unwrap() + .contains("auto_deny")); +} + +// ─────────────────── M-C2: description-only change keeps approval ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn description_only_change_warns_but_keeps_approval() { + fn reply(description: &str, schema: serde_json::Value) -> serde_json::Value { + json!({ + "jsonrpc": "2.0", + "id": 1, + "result": {"tools": [ + {"name": "drift_probe", "description": description, "inputSchema": schema} + ]} + }) + } + let schema_v1 = json!({"type": "object"}); + let schema_v2 = json!({"type": "object", "properties": {"force": {"type": "boolean"}}}); + + let upstream = MockServer::start().await; + // Three calls in order: original, description-only change, schema change. + for (i, body) in [ + reply("Reads files. v1.0.0", schema_v1.clone()), + reply("Reads files. v1.0.1 — typo fixes", schema_v1.clone()), + reply("Reads files. v1.0.1 — typo fixes", schema_v2.clone()), + ] + .into_iter() + .enumerate() + { + Mock::given(method("POST")) + .respond_with(ResponseTemplate::new(200).set_body_json(body)) + .up_to_n_times(1) + .with_priority((i + 1) as u8) + .mount(&upstream) + .await; + } + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = WatchState::single_upstream( + upstream.uri(), + reqwest::Client::new(), + storage.clone(), + Arc::new(SecurityEngine::with_defaults()), + ); + let addr = spawn_watcher(state).await; + let list = || async { + let r = client() + .post(format!("http://{}/mcp", addr)) + .json(&tools_list_request()) + .send() + .await + .unwrap(); + assert_eq!(r.status(), 200); + let _ = r.bytes().await; + }; + + // First sighting, then the user approves the tool. + list().await; + assert!(storage.approve_mcp_tool("default", "drift_probe").unwrap()); + + // A description-only change (routine version bump) is recorded as a + // change event but must NOT revoke approval. + list().await; + assert_eq!( + storage + .mcp_tool_trust_state("default", "drift_probe") + .unwrap() + .as_deref(), + Some("approved"), + "description-only change must not re-pend an approved tool" + ); + let after_desc = storage.security_events_for_date(&today()).unwrap(); + assert_eq!( + after_desc + .iter() + .filter(|e| e.event_type == "mcp_tool_changed") + .count(), + 1, + "description drift should still be recorded; got {after_desc:?}" + ); + + // A schema change is the real rug-pull signal: approval resets to pending. + list().await; + assert_eq!( + storage + .mcp_tool_trust_state("default", "drift_probe") + .unwrap() + .as_deref(), + Some("pending"), + "a schema change must force re-approval" + ); +} + +// ─────────────────── M-H2: query string never persisted ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn upstream_query_string_is_forwarded_but_never_persisted() { + let upstream = MockServer::start().await; + Mock::given(method("POST")) + .and(wiremock::matchers::query_param("api_key", "sekret123")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({"ok": true}))) + .expect(1) + .mount(&upstream) + .await; + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = WatchState::single_upstream( + upstream.uri(), + reqwest::Client::new(), + storage.clone(), + Arc::new(SecurityEngine::with_defaults()), + ); + let addr = spawn_watcher(state).await; + + let resp = client() + .post(format!("http://{}/rpc?api_key=sekret123", addr)) + .json(&json!({ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "ping", "arguments": {}}, + "id": 1, + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // The query reached the upstream (mock matched), but the persisted event + // must hold the stripped URI — credentials never hit disk. + let events = storage.mcp_events_for_date(&today()).unwrap(); + assert_eq!(events.len(), 1); + let stored = events[0].upstream_uri.as_deref().unwrap(); + assert!( + !stored.contains('?') && !stored.contains("sekret123"), + "query string must be stripped from the persisted URI, got {stored}" + ); + assert!(stored.ends_with("/rpc"), "got {stored}"); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn observe_mode_forwards_unapproved_tools_call() { // Default (require_approval = false): an unapproved call still forwards. diff --git a/tests/unit/audit_test.rs b/tests/unit/audit_test.rs new file mode 100644 index 0000000..66947b3 --- /dev/null +++ b/tests/unit/audit_test.rs @@ -0,0 +1,169 @@ +//! Audit-chain hardening tests (M-H1 / M-M3 / M-M4): +//! +//! - M-H1: a lost/regenerated audit key must REFUSE to seal (instead of +//! silently forking the chain into a forever-TAMPERED state), and +//! `audit rekey` must archive the old segment and let sealing resume. +//! - M-M3: two concurrent `seal` runs must not fork the chain. +//! - M-M4: SARIF results must carry a `locations` array (GitHub code scanning +//! rejects results without one). + +use burnwall::audit::{sarif, AuditChain, VerifyReport}; +use burnwall::providers::TokenUsage; +use burnwall::storage::{RequestRecord, SecurityEvent, Storage}; + +fn usage(input: u64, output: u64) -> TokenUsage { + TokenUsage { + input_tokens: input, + output_tokens: output, + cache_creation_tokens: 0, + cache_read_tokens: 0, + } +} + +fn seed_request(storage: &Storage) { + storage + .insert_request(&RequestRecord::successful( + "anthropic", + "claude", + &usage(100, 50), + 0.5, + None, + )) + .unwrap(); +} + +// ── M-H1: key loss → refuse to seal; rekey → resume ───────────────────────── + +#[test] +fn lost_key_refuses_to_seal_and_rekey_resumes() { + let dir = tempfile::tempdir().unwrap(); + let key_path = dir.path().join("audit_ed25519.key"); + let storage = Storage::open_in_memory().unwrap(); + + seed_request(&storage); + let original = AuditChain::open(&key_path).unwrap(); + assert_eq!(original.seal(&storage).unwrap().sealed, 1); + drop(original); + + // Simulate key loss: the key file vanishes, receipts + sidecar remain. + std::fs::remove_file(&key_path).unwrap(); + let regenerated = AuditChain::open(&key_path).unwrap(); + + seed_request(&storage); + let err = regenerated.seal(&storage).expect_err("seal must refuse"); + let msg = err.to_string(); + assert!( + msg.contains("audit key changed or lost"), + "unexpected error: {msg}" + ); + assert!( + msg.contains("burnwall audit rekey"), + "error must name the remediation command: {msg}" + ); + + // Deliberate rekey: archives the closed segment, records the new pubkey, + // and sealing resumes. + let report = regenerated.rekey(&storage).unwrap(); + assert!(report.old_key.is_some(), "old segment key should be known"); + assert_eq!(report.receipts, 1); + assert!(report.chain_head.is_some()); + assert!(report.archive.exists(), "segment archive must be written"); + let archived = std::fs::read_to_string(&report.archive).unwrap(); + assert!(archived.contains(report.old_key.as_deref().unwrap())); + + assert_eq!(regenerated.seal(&storage).unwrap().sealed, 1); +} + +#[test] +fn legacy_chain_without_sidecar_still_refuses_a_regenerated_key() { + let dir = tempfile::tempdir().unwrap(); + let key_path = dir.path().join("audit_ed25519.key"); + let storage = Storage::open_in_memory().unwrap(); + + seed_request(&storage); + let original = AuditChain::open(&key_path).unwrap(); + assert_eq!(original.seal(&storage).unwrap().sealed, 1); + drop(original); + + // Pre-sidecar chain: both the key AND the recorded pubkey are gone. The + // tail-signature check must still detect that the fresh key never signed + // the existing chain. + std::fs::remove_file(&key_path).unwrap(); + std::fs::remove_file(key_path.with_extension("pub")).unwrap(); + let regenerated = AuditChain::open(&key_path).unwrap(); + + seed_request(&storage); + let err = regenerated.seal(&storage).expect_err("seal must refuse"); + assert!(err.to_string().contains("burnwall audit rekey")); +} + +#[test] +fn reopening_the_same_key_seals_without_refusal() { + let dir = tempfile::tempdir().unwrap(); + let key_path = dir.path().join("audit_ed25519.key"); + let storage = Storage::open_in_memory().unwrap(); + + seed_request(&storage); + AuditChain::open(&key_path).unwrap().seal(&storage).unwrap(); + + // Same key file, fresh open — the normal restart path must be untouched. + seed_request(&storage); + let reopened = AuditChain::open(&key_path).unwrap(); + assert_eq!(reopened.seal(&storage).unwrap().sealed, 1); + assert_eq!( + reopened.verify(&storage).unwrap(), + VerifyReport::Intact { count: 2 } + ); +} + +// ── M-M3: concurrent seals must not fork the chain ────────────────────────── + +#[test] +fn concurrent_seals_do_not_fork_the_chain() { + let dir = tempfile::tempdir().unwrap(); + let db = dir.path().join("burnwall.db"); + let key = dir.path().join("k.key"); + + let s1 = Storage::open(&db).unwrap(); + for _ in 0..6 { + seed_request(&s1); + } + let s2 = Storage::open(&db).unwrap(); + let c1 = AuditChain::open(&key).unwrap(); + let c2 = AuditChain::open(&key).unwrap(); + + use std::sync::atomic::{AtomicU64, Ordering}; + let total = AtomicU64::new(0); + std::thread::scope(|scope| { + scope.spawn(|| total.fetch_add(c1.seal(&s1).unwrap().sealed, Ordering::SeqCst)); + scope.spawn(|| total.fetch_add(c2.seal(&s2).unwrap().sealed, Ordering::SeqCst)); + }); + + // Every row sealed exactly once between the two runs, and the resulting + // chain is a single intact line — no duplicate prev_hash fork. + assert_eq!(total.load(Ordering::SeqCst), 6); + assert_eq!( + c1.verify(&s1).unwrap(), + VerifyReport::Intact { count: 6 } + ); +} + +// ── M-M4: SARIF results carry synthetic locations ──────────────────────────── + +#[test] +fn sarif_results_carry_synthetic_locations() { + let mut event = SecurityEvent::new("path_blocked", "~/.ssh/id_rsa"); + event.id = Some(7); + let log = sarif::build(&[event]); + + let result = &log["runs"][0]["results"][0]; + let location = &result["locations"][0]["physicalLocation"]; + assert_eq!( + location["artifactLocation"]["uri"], + "burnwall://security-events/7" + ); + assert!( + location["region"]["startLine"].is_number(), + "GitHub's SARIF validator wants a region next to the artifactLocation" + ); +} diff --git a/tests/unit/rulepack_test.rs b/tests/unit/rulepack_test.rs index ad9edde..1e97954 100644 --- a/tests/unit/rulepack_test.rs +++ b/tests/unit/rulepack_test.rs @@ -284,6 +284,24 @@ fn lint_flags_empty_pack_and_missing_id() { .any(|x| x.code == "missing-id")); } +// ── M-M6 — pack id is used as a filename; reject traversal attempts ───────── + +#[test] +fn pack_id_validation_blocks_path_traversal() { + use burnwall::cli::rules::validate_pack_id; + // Registry alphabet passes. + assert!(validate_pack_id("django").is_ok()); + assert!(validate_pack_id("data-science_2").is_ok()); + // Anything that could escape the rules dir (or surprise the FS) fails. + assert!(validate_pack_id("..\\..\\x").is_err()); + assert!(validate_pack_id("../escape").is_err()); + assert!(validate_pack_id("a/b").is_err()); + assert!(validate_pack_id("a.b").is_err()); + assert!(validate_pack_id("UPPER").is_err()); + assert!(validate_pack_id("").is_err()); + assert!(validate_pack_id("nul:").is_err()); +} + #[test] fn lint_clean_pack_passes_with_only_warnings() { use burnwall::security::packs; From 178d7542795ba415f448b4595bfb56bda9bb5227 Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:38:48 -0400 Subject: [PATCH 7/9] lifecycle/surfaces: dead-proxy safety, PowerShell routing, PID identity, honest status Env files are liveness-gated so a crashed/rebooted proxy degrades a shell to DIRECT instead of breaking every tool; the daemon child pauses routing on graceful exit. Every surface (ribbon, status, VS Code bar) shows a loud 'proxy down' when routed at a dead port. PowerShell now gets a persistent CurrentUserAllHosts profile hook (no longer a silent dead end), bash chains into login profiles. PID files carry an image-name identity check so stop can't kill an innocent process and autostart can't bail against a reused PID. config doctor prints a per-shell routing matrix; daemon writes a size-capped log file; upgrade/install use one canonical dir and a PATH-resolved restart. Statusline msg is turn-aware, watch annotates idle data, plan window suppressed past reset, warning-grade plan status no longer reads as throttled, combined-today no longer double-counts proxied traffic. --- dist-workspace.toml | 8 + editor/vscode/src/format.ts | 30 +++- editor/vscode/test/format.test.ts | 43 ++++- src/cli/config_cmd.rs | 50 ++++++ src/cli/daemon.rs | 98 ++++++++++- src/cli/enable_routing.rs | 62 ++++--- src/cli/init.rs | 228 +++++++++++++++++++++++-- src/cli/routing.rs | 270 +++++++++++++++++++++++++----- src/cli/status.rs | 142 ++++++++++++---- src/cli/statusline.rs | 77 +++++++-- src/cli/upgrade.rs | 28 +++- src/cli/watch.rs | 104 +++++++++--- src/plan.rs | 48 +++++- src/ribbon.rs | 45 ++++- 14 files changed, 1056 insertions(+), 177 deletions(-) diff --git a/dist-workspace.toml b/dist-workspace.toml index 6b7dacf..3f21452 100644 --- a/dist-workspace.toml +++ b/dist-workspace.toml @@ -20,6 +20,14 @@ ci = "github" # npm -> an npm package using the esbuild optionalDependencies layout # msi -> a native Windows installer installers = ["shell", "powershell", "homebrew", "npm", "msi"] +# Install to the SAME directory the hand-written README installer (install.ps1 +# / install.sh) uses and persists on PATH (L-C3). Without this, cargo-dist +# defaults to $CARGO_HOME/bin, so `burnwall upgrade` (which runs the dist +# installer) wrote the new binary to a *different* dir than the running one — +# leaving the restart pointed at the old path, a second PATH entry, and an +# autostart Run-key aimed at a now-stale exe. One canonical dir removes the +# whole class. +install-path = "~/.burnwall/bin" # Target platforms to build apps for (Rust target-triple syntax) targets = ["aarch64-apple-darwin", "aarch64-unknown-linux-gnu", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"] # Where the Homebrew formula is published (the existing tap repo). diff --git a/editor/vscode/src/format.ts b/editor/vscode/src/format.ts index 3c58282..57971c4 100644 --- a/editor/vscode/src/format.ts +++ b/editor/vscode/src/format.ts @@ -6,6 +6,8 @@ export interface StatusJson { total_cost_usd?: number; combined_total_usd?: number; + proxy_running?: boolean; + env_routing?: string; blocked_requests?: number; security_events?: number; budget?: { daily_limit_usd?: number; spent_today_usd?: number }; @@ -62,6 +64,9 @@ export interface StatusSummary { plan: PlanSummary | null; /** Per-tool coverage; empty when no supported tools are installed. */ coverage: CoverageItem[]; + /** True when the env routes to the proxy but the proxy process is not + * running — every request from that environment will fail (U-C1). */ + proxyDown: boolean; } /** "time until" label for a reset countdown: `45m`, `2h28m`, `2d7h`, `now`. */ @@ -98,7 +103,9 @@ function planSummary(s: StatusJson): PlanSummary | null { primaryResetInSecs: primary.reset_in_secs, secondaryLabel: secondary ? secondary.label : null, secondaryPct: secondary ? secondary.utilization * 100 : null, - throttled: prov.status !== "allowed", + // Only positively-throttling statuses — Anthropic emits warning-grade + // intermediates (`allowed_warning`) while requests still succeed (U-H4). + throttled: ["throttled", "rejected", "blocked", "rate_limited"].includes(prov.status), }; if (!best || cand.primaryPct > best.primaryPct) { best = cand; @@ -108,7 +115,10 @@ function planSummary(s: StatusJson): PlanSummary | null { } export function summarize(s: StatusJson): StatusSummary { - const costToday = s.combined_total_usd ?? s.total_cost_usd ?? 0; + // Headline figure: the proxied total. `combined_total_usd` is now deduped + // server-side (X4), but proxied spend is the number Burnwall can vouch for; + // the combined figure is detail for the panel, not the bar. + const costToday = s.total_cost_usd ?? s.combined_total_usd ?? 0; let cacheRead = 0; let promptTotal = 0; @@ -140,12 +150,18 @@ export function summarize(s: StatusJson): StatusSummary { budgetPercent, plan: planSummary(s), coverage, + proxyDown: s.env_routing === "proxied" && s.proxy_running === false, }; } /** One-line status-bar label (VS Code `$(icon)` codicons allowed). On a * subscription, dollars are notional, so the binding limit window leads instead. */ export function statusBarText(s: StatusSummary): string { + // Routed at a dead proxy beats every other message: the user's tools are + // actively failing with connection-refused right now (U-C1). + if (s.proxyDown) { + return "$(error) Burnwall proxy DOWN — run `burnwall start`"; + } const bypassed = s.coverage.filter((c) => c.state === "bypasses"); const bypassPart = bypassed.length > 0 @@ -205,14 +221,22 @@ export function tooltip(s: StatusSummary): string { s.cacheHitRate !== null ? `Cache hit rate: ${Math.round(s.cacheHitRate * 100)}%` : `Cache hit rate: n/a`; + // On a flat-rate plan the dollar figure is notional (API-equivalent), not a + // bill — label it so a subscriber doesn't read it as money owed. + const costLine = s.plan + ? `Cost: $${s.costToday.toFixed(2)} (notional — flat-rate plan)` + : `Cost: $${s.costToday.toFixed(2)}`; const lines = [ "Burnwall — today", - `Cost: $${s.costToday.toFixed(2)}`, + costLine, budgetLine, cacheLine, `Blocked requests: ${s.blocked}`, `Security events: ${s.securityEvents}`, ]; + if (s.proxyDown) { + lines.splice(1, 0, "⛔ PROXY DOWN — tools routed here will fail to connect. Run `burnwall start`."); + } if (s.plan) { const p = s.plan; lines.push( diff --git a/editor/vscode/test/format.test.ts b/editor/vscode/test/format.test.ts index 8610805..61a71a9 100644 --- a/editor/vscode/test/format.test.ts +++ b/editor/vscode/test/format.test.ts @@ -18,9 +18,12 @@ test("summarize computes cost, blocked, cache hit rate, and budget %", () => { assert.equal(Math.round(s.budgetPercent ?? 0), 35); }); -test("combined_total_usd is preferred over total_cost_usd", () => { +test("the bar headlines the proxied total, not the combined figure (X4/U-H3)", () => { + // The proxied number is what Burnwall can vouch for; combined (proxied + + // unproxied logs) is panel detail, and previously double-counted proxied + // Claude Code into the headline. const s = summarize({ total_cost_usd: 1, combined_total_usd: 5 }); - assert.equal(s.costToday, 5); + assert.equal(s.costToday, 1); }); test("no tokens -> null cache hit rate; no limit -> null budget %", () => { @@ -99,6 +102,42 @@ test("subscription plan: throttled flag surfaces", () => { assert.ok(statusBarText(s).includes("throttled")); }); +test("warning-grade plan status is NOT throttled (U-H4)", () => { + const s = summarize({ + plan: { + providers: [ + { + provider: "anthropic", + status: "allowed_warning", + windows: [{ label: "5h", utilization: 0.85, reset_in_secs: 600 }], + }, + ], + }, + }); + assert.equal(s.plan?.throttled, false); + assert.ok(!statusBarText(s).includes("throttled")); +}); + +test("routed at a dead proxy beats all other status (U-C1)", () => { + const s = summarize({ + total_cost_usd: 2, + env_routing: "proxied", + proxy_running: false, + }); + assert.equal(s.proxyDown, true); + assert.ok(statusBarText(s).includes("DOWN")); + assert.ok(tooltip(s).includes("PROXY DOWN")); +}); + +test("proxy running while routed is not flagged down", () => { + const s = summarize({ + total_cost_usd: 2, + env_routing: "proxied", + proxy_running: true, + }); + assert.equal(s.proxyDown, false); +}); + test("coverage: a bypassing tool warns in the status bar and tooltip", () => { const s = summarize({ total_cost_usd: 2, diff --git a/src/cli/config_cmd.rs b/src/cli/config_cmd.rs index ea7a97e..39afbc1 100644 --- a/src/cli/config_cmd.rs +++ b/src/cli/config_cmd.rs @@ -175,6 +175,56 @@ fn doctor(path: &Path) -> anyhow::Result<()> { )?; } + // Per-shell routing matrix (L-H4): env-file state × rc-hook presence × + // proxy liveness — the exact table a stranded "connection refused" user + // needs, which no single surface printed before. Names the precise + // missing link per shell rather than a generic "run enable-routing". + writeln!(out)?; + writeln!(out, "Routing matrix (per shell):")?; + let proxy_up = crate::cli::routing::proxy_port_alive( + cfg.proxy.port, + std::time::Duration::from_millis(120), + ); + writeln!( + out, + " proxy: {} (port {})", + if proxy_up { "🟢 listening" } else { "⚪ not running" }, + cfg.proxy.port + )?; + for shell in crate::cli::init::Shell::ALL { + use crate::cli::routing::{env_file_state, rc_hook_present, EnvFileState}; + let env = match env_file_state(shell) { + Some(EnvFileState::Active) => "active", + Some(EnvFileState::Paused) => "paused", + Some(EnvFileState::Disabled) => "disabled", + None => "absent", + }; + let hook = rc_hook_present(shell); + let verdict = match (env, hook, proxy_up) { + ("active", true, true) => "🟢 routed".to_string(), + ("active", true, false) => { + "🟡 will route once the proxy starts (liveness-gated)".to_string() + } + // Diagnostic only — machine state, not config state, so it never + // flips the doctor's error/warning summary. + ("active", false, _) | ("paused", false, _) => format!( + "⚠️ env file present but no shell hook — add it with `burnwall enable-routing` (run from {})", + shell.label() + ), + ("paused", true, _) => "⏸ paused — `burnwall start` re-enables".to_string(), + ("disabled", _, _) => "⏹ explicitly disabled".to_string(), + _ => "— not configured".to_string(), + }; + writeln!( + out, + " {:<11} env:{:<9} hook:{:<3} {}", + shell.label(), + env, + if hook { "yes" } else { "no" }, + verdict + )?; + } + writeln!(out)?; if errors == 0 && warnings == 0 { writeln!(out, "✅ No problems found.")?; diff --git a/src/cli/daemon.rs b/src/cli/daemon.rs index 50461b3..52768dd 100644 --- a/src/cli/daemon.rs +++ b/src/cli/daemon.rs @@ -142,6 +142,11 @@ pub async fn spawn_background(args: &StartArgs) -> anyhow::Result<()> { resolved_port(args) )); } + // Name the log file so a later crash is diagnosable (L-H2) — + // before this, a dead daemon left nothing to look at. + if let Some(log) = resolved_child_log_path() { + println!(" Logs: {}", log.display()); + } println!(" Check it with `burnwall status`; stop it with `burnwall stop`."); return Ok(()); } @@ -163,10 +168,18 @@ pub async fn spawn_background(args: &StartArgs) -> anyhow::Result<()> { } /// Rebuild the `start` argument list for the child, dropping `--daemon`. -/// The child always gets `--no-routing`: the launcher handles routing (and -/// its messaging) after readiness, and `burnwall stop` handles the pause. +/// The child gets `--no-routing` (the launcher handles the resume and its +/// messaging after readiness) plus `--pause-routing-on-exit` so a *gracefully* +/// exiting daemon still pauses routing itself — `burnwall stop` covers the +/// normal path, but a child that shuts down without `stop` (SIGTERM from the +/// OS, session logout) must not strand Active env files (L-C1). Hard kills get +/// no cleanup anywhere — the liveness-gated env files cover that case. fn child_args(args: &StartArgs) -> Vec { - let mut out = vec!["start".to_string(), "--no-routing".to_string()]; + let mut out = vec![ + "start".to_string(), + "--no-routing".to_string(), + "--pause-routing-on-exit".to_string(), + ]; if let Some(port) = args.port { out.push("--port".to_string()); out.push(port.to_string()); @@ -187,6 +200,15 @@ fn child_args(args: &StartArgs) -> Vec { out } +/// The log file the daemon child will write — same config resolution the +/// child itself performs. +fn resolved_child_log_path() -> Option { + let cfg = crate::config::default_path() + .ok() + .and_then(|p| crate::config::load_or_default(&p).ok())?; + super::start::resolved_log_path(&cfg.logging) +} + /// The port the child will serve on: the explicit flag, else the configured /// port, else the built-in default — same resolution `start` itself uses. fn resolved_port(args: &StartArgs) -> u16 { @@ -357,14 +379,55 @@ fn append_arg_quoted(cmd: &mut Vec, arg: &std::ffi::OsStr) { } } -/// Is a process with this PID currently alive? +/// Is a process with this PID currently alive **and actually burnwall**? +/// +/// PID files have an inherent reuse hazard (L-H1): after a reboot or crash the +/// stale file's PID is frequently reassigned to an unrelated process. Without +/// an identity check, autostart would bail "already running" against a random +/// process (so the proxy never starts while env files claim routing), and +/// `burnwall stop` could hard-kill an innocent process — the user's browser or +/// IDE. A PID that exists but isn't burnwall is treated as *stale*. #[cfg(unix)] pub fn process_is_alive(pid: u32) -> bool { // kill(pid, 0) sends no signal — it just reports whether the process // exists and is signalable. EPERM means it exists but is owned by - // someone else, which still counts as "alive". + // someone else (and so is certainly not our daemon). let ret = unsafe { libc::kill(pid as libc::pid_t, 0) }; - ret == 0 || std::io::Error::last_os_error().raw_os_error() == Some(libc::EPERM) + if ret != 0 { + return false; + } + process_is_burnwall(pid) +} + +/// Identity check via the process image name. Fail-open: if the platform +/// lookup fails (permissions, exotic kernel), assume it IS burnwall — wrongly +/// treating a live daemon as stale would double-start, which is worse than the +/// rare false "already running". +#[cfg(unix)] +fn process_is_burnwall(pid: u32) -> bool { + // Linux: /proc//exe symlink. macOS: no /proc — fall back to `ps`. + #[cfg(target_os = "linux")] + { + match std::fs::read_link(format!("/proc/{pid}/exe")) { + Ok(p) => p + .file_name() + .map(|n| n.to_string_lossy().contains("burnwall")) + .unwrap_or(true), + Err(_) => true, + } + } + #[cfg(not(target_os = "linux"))] + { + match std::process::Command::new("ps") + .args(["-p", &pid.to_string(), "-o", "comm="]) + .output() + { + Ok(out) if out.status.success() => { + String::from_utf8_lossy(&out.stdout).contains("burnwall") + } + _ => true, + } + } } /// Ask the process to terminate. Unix sends SIGTERM, which the proxy @@ -380,12 +443,14 @@ pub fn terminate_process(pid: u32) -> anyhow::Result<()> { } } -/// Is a process with this PID currently alive? +/// Is a process with this PID currently alive **and actually burnwall**? +/// See the Unix variant for why the identity check matters (PID reuse, L-H1). #[cfg(windows)] pub fn process_is_alive(pid: u32) -> bool { use windows_sys::Win32::Foundation::CloseHandle; use windows_sys::Win32::System::Threading::{ - GetExitCodeProcess, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION, + GetExitCodeProcess, OpenProcess, QueryFullProcessImageNameW, + PROCESS_QUERY_LIMITED_INFORMATION, }; // A process that has fully exited reports an exit code other than // STILL_ACTIVE (259). A process that genuinely exits *with* 259 would be @@ -398,8 +463,23 @@ pub fn process_is_alive(pid: u32) -> bool { } let mut exit_code: u32 = 0; let queried = GetExitCodeProcess(handle, &mut exit_code); + if queried == 0 || exit_code != STILL_ACTIVE { + CloseHandle(handle); + return false; + } + // Identity check (L-H1): the PID is live, but is it burnwall? A reused + // PID belonging to another program must read as stale — otherwise + // autostart bails against a random process and `stop` could kill it. + // Fail-open on lookup failure (assume burnwall) — see the Unix variant. + let mut buf = [0u16; 1024]; + let mut len = buf.len() as u32; + let ok = QueryFullProcessImageNameW(handle, 0, buf.as_mut_ptr(), &mut len); CloseHandle(handle); - queried != 0 && exit_code == STILL_ACTIVE + if ok == 0 { + return true; + } + let image = String::from_utf16_lossy(&buf[..len as usize]).to_ascii_lowercase(); + image.contains("burnwall") } } diff --git a/src/cli/enable_routing.rs b/src/cli/enable_routing.rs index a0d90e4..76387d4 100644 --- a/src/cli/enable_routing.rs +++ b/src/cli/enable_routing.rs @@ -84,27 +84,25 @@ pub async fn run_cmd(args: EnableRoutingArgs) -> Result<()> { let mut writes: Vec = Vec::new(); for shell in targets { let env_path = routing::write_env_file(shell, &args.proxy_url)?; - let hook = if shell.rc_path().is_some() { - match routing::install_rc_hook(shell, &env_path) { - Ok(b) => Some(b), - Err(e) => { - // A real I/O failure on a shell that *does* have an rc file. - if !eval_mode { - let est = Styler::stderr(); - eprintln!( - "{}", - est.yellow(&format!( - "burnwall: could not install rc hook for {} ({e}). \ - The env file is written but won't auto-load.", - shell.label() - )) - ); - } - Some(false) + // Every shell gets a persistent hook now — including PowerShell, whose + // CurrentUserAllHosts profile(s) install_rc_hook manages (L-C2: the + // default Windows shell used to be a silent dead end here). + let hook = match routing::install_rc_hook(shell, &env_path) { + Ok(b) => Some(b), + Err(e) => { + if !eval_mode { + let est = Styler::stderr(); + eprintln!( + "{}", + est.yellow(&format!( + "burnwall: could not install rc hook for {} ({e}). \ + The env file is written but won't auto-load.", + shell.label() + )) + ); } + Some(false) } - } else { - None // PowerShell: we don't auto-edit the profile (by design). }; writes.push(ShellWrite { shell, @@ -135,21 +133,33 @@ pub async fn run_cmd(args: EnableRoutingArgs) -> Result<()> { sty.bold(&tag), sty.blue(&w.env_path.display().to_string()) )?; - match (w.hook, w.shell.rc_path()) { - (Some(true), Some(rc)) => writeln!( + let hook_label = if w.shell == crate::cli::init::Shell::Powershell { + routing::powershell_profile_paths() + .iter() + .map(|p| p.display().to_string()) + .collect::>() + .join(", ") + } else { + w.shell + .rc_path() + .map(|p| p.display().to_string()) + .unwrap_or_else(|| w.shell.label().to_string()) + }; + match w.hook { + Some(true) => writeln!( out, " rc hook: {} (sourced on new shells)", - sty.blue(&rc.display().to_string()) + sty.blue(&hook_label) )?, - (Some(false), Some(rc)) => writeln!( + Some(false) => writeln!( out, " rc hook: {} (already present — left unchanged)", - sty.blue(&rc.display().to_string()) + sty.blue(&hook_label) )?, - _ => writeln!( + None => writeln!( out, " rc hook: {}", - sty.yellow("PowerShell profile not auto-edited — use the eval line below") + sty.yellow("not installed — use the eval line below for this session") )?, } } diff --git a/src/cli/init.rs b/src/cli/init.rs index 59a81c9..a3c5095 100644 --- a/src/cli/init.rs +++ b/src/cli/init.rs @@ -206,6 +206,77 @@ pub fn binary_in_path_var(name: &str, path_var: &std::ffi::OsStr) -> bool { false } +/// Locate a Git-for-Windows `bash.exe` by finding `git.exe` on the given +/// PATH-formatted value and probing the Git install tree around it. +/// +/// Keyed off `git.exe` rather than `bash.exe` deliberately: WSL also ships a +/// `bash.exe` (in System32), but WSL has its own home and filesystem, so a +/// hook written to the Windows `~/.bashrc` would never reach it. Git Bash +/// keeps `HOME` at `%USERPROFILE%` — exactly where our rc hook lands. +pub fn git_bash_from_path_var(path_var: &std::ffi::OsStr) -> Option { + for dir in env::split_paths(path_var) { + if !dir.join("git.exe").is_file() { + continue; + } + // git.exe lives in `\cmd`, `\bin`, or + // `\mingw64\bin`; bash.exe in `\bin` or + // `\usr\bin`. Probing two ancestors covers all three. + let ancestors = [dir.parent(), dir.parent().and_then(Path::parent)]; + for root in ancestors.into_iter().flatten() { + for cand in [ + root.join("bin").join("bash.exe"), + root.join("usr").join("bin").join("bash.exe"), + ] { + if cand.is_file() { + return Some(cand); + } + } + } + } + None +} + +/// Where this shell's source hook lands, for human-readable output. +/// PowerShell hooks live in the managed `CurrentUserAllHosts` profile(s) +/// rather than a classic rc file (L-C2). +fn hook_target_label(shell: Shell) -> String { + if shell == Shell::Powershell { + let paths = super::routing::powershell_profile_paths(); + if paths.is_empty() { + return "the PowerShell profile".to_string(); + } + return paths + .iter() + .map(|p| p.display().to_string()) + .collect::>() + .join(" and "); + } + shell + .rc_path() + .map(|p| p.display().to_string()) + .unwrap_or_else(|| format!("the {} profile", shell.label())) +} + +/// Find Git Bash on this machine: PATH first, then the standard installer +/// locations (Git for Windows can be installed without PATH integration). +pub fn git_bash_path() -> Option { + if let Some(p) = git_bash_from_path_var(&env::var_os("PATH").unwrap_or_default()) { + return Some(p); + } + let roots = [ + env::var_os("ProgramFiles").map(|p| PathBuf::from(p).join("Git")), + env::var_os("ProgramFiles(x86)").map(|p| PathBuf::from(p).join("Git")), + env::var_os("LOCALAPPDATA").map(|p| PathBuf::from(p).join("Programs").join("Git")), + ]; + for root in roots.into_iter().flatten() { + let cand = root.join("bin").join("bash.exe"); + if cand.is_file() { + return Some(cand); + } + } + None +} + const MARKER: &str = "# Added by burnwall init"; /// Append `lines` to `rc_path`, separated from existing content with a @@ -300,13 +371,25 @@ pub fn run_cmd(args: InitArgs) -> anyhow::Result<()> { for line in super::routing::export_lines(s, &args.proxy_url) { writeln!(out, " {}", line)?; } - if let Some(rc) = s.rc_path() { - writeln!(out, " append source line to {}", rc.display())?; - } else { - writeln!(out, " (no rc file for {} — manual step needed)", s.label())?; - } + writeln!(out, " append source line to {}", hook_target_label(s))?; if args.apply { - let env_path = super::routing::write_env_file(s, &args.proxy_url)?; + // Preflight (M1): writing an Active env file with no proxy serving + // means every new terminal exports a dead-port URL — the user's + // first contact with Burnwall becomes "it broke my AI tool". When + // the proxy isn't up yet, write the *paused* stub instead; `start` + // flips it Active automatically once the port is actually bound. + let proxy_up = super::routing::proxy_alive_for_url(&args.proxy_url).unwrap_or(false); + let env_path = if proxy_up { + super::routing::write_env_file(s, &args.proxy_url)? + } else { + let path = super::routing::env_file_path(s) + .ok_or_else(|| anyhow::anyhow!("locating env file path"))?; + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + std::fs::write(&path, super::routing::env_file_paused(s))?; + path + }; let hook_added = match super::routing::install_rc_hook(s, &env_path) { Ok(b) => b, Err(e) => { @@ -314,18 +397,65 @@ pub fn run_cmd(args: InitArgs) -> anyhow::Result<()> { false } }; - writeln!(out, " ✓ env file written: {}", env_path.display())?; + if proxy_up { + writeln!(out, " ✓ env file written: {}", env_path.display())?; + } else { + writeln!( + out, + " ✓ env file written (paused): {} — routing activates when you run `burnwall start`", + env_path.display() + )?; + } if hook_added { - if let Some(rc) = s.rc_path() { - writeln!(out, " ✓ rc hook added to {}", rc.display())?; - } - } else if let Some(rc) = s.rc_path() { - writeln!(out, " • rc hook already present in {}", rc.display())?; + writeln!(out, " ✓ rc hook added to {}", hook_target_label(s))?; + } else { + writeln!(out, " • rc hook already present in {}", hook_target_label(s))?; } } } else { writeln!(out, " (shell not detected — set ANTHROPIC_BASE_URL / OPENAI_BASE_URL manually)")?; } + + // Git Bash on Windows: init run from a PowerShell terminal detects + // PowerShell, but Git Bash commonly coexists and shares the same Windows + // home — and an unhooked bash session silently goes direct to the + // provider. Detect it and offer to wire it up in the same pass. + if cfg!(windows) + && shell == Some(Shell::Powershell) + && !super::routing::rc_hook_present(Shell::Bash) + && git_bash_path().is_some() + { + let rc_label = Shell::Bash + .rc_path() + .map(|p| p.display().to_string()) + .unwrap_or_else(|| "~/.bashrc".to_string()); + writeln!(out)?; + writeln!(out, " Git Bash detected — bash sessions are not routed yet.")?; + if !args.apply { + let env_file = super::routing::env_file_path(Shell::Bash) + .map(|p| p.display().to_string()) + .unwrap_or_else(|| "".to_string()); + writeln!(out, " {action_label}: write env file ({env_file})")?; + writeln!(out, " append source line to {rc_label}")?; + } else { + let hook_bash = + args.yes || prompt_yes_no(&mut out, " Also enable routing for Git Bash?")?; + if hook_bash { + let env_path = super::routing::write_env_file(Shell::Bash, &args.proxy_url)?; + writeln!(out, " ✓ env file written: {}", env_path.display())?; + match super::routing::install_rc_hook(Shell::Bash, &env_path) { + Ok(true) => writeln!(out, " ✓ rc hook added to {rc_label}")?, + Ok(false) => writeln!(out, " • rc hook already present in {rc_label}")?, + Err(e) => writeln!(out, " ⚠ rc hook skipped: {}", e)?, + } + } else { + writeln!( + out, + " • skipped (run `burnwall enable-routing` from Git Bash to add it later)" + )?; + } + } + } writeln!(out)?; // 2. Login service (always opt-in: --install-service flag or interactive @@ -393,10 +523,17 @@ pub fn run_cmd(args: InitArgs) -> anyhow::Result<()> { writeln!(out)?; } - // 3. Next steps. + // 3. Next steps. Starting the proxy comes FIRST: routing only activates + // once the port is bound, so it is the step everything else hangs on. writeln!(out, "▶ Next steps")?; if args.apply { - writeln!(out, " • New shells will source the env file automatically.")?; + if !want_service { + writeln!(out, " • Start the proxy: burnwall start --daemon")?; + } + writeln!( + out, + " • New shells then source the env file automatically (routing engages only while the proxy is up)." + )?; writeln!(out, " • Apply to *this* shell now without restarting:")?; match shell { Some(Shell::Powershell) => { @@ -406,9 +543,6 @@ pub fn run_cmd(args: InitArgs) -> anyhow::Result<()> { writeln!(out, " eval \"$(burnwall enable-routing)\"")?; } } - if !want_service { - writeln!(out, " • Start the proxy: burnwall start --daemon")?; - } writeln!(out, " • Kill switch (instant bypass): export BURNWALL_BYPASS=1")?; } else { writeln!(out, " • Re-run with --apply to execute.")?; @@ -437,3 +571,63 @@ fn prompt_yes_no(out: &mut W, question: &str) -> anyhow::Result let answer = line.trim().to_ascii_lowercase(); Ok(answer.is_empty() || answer == "y" || answer == "yes") } + +#[cfg(test)] +mod tests { + use super::*; + + fn touch(path: &Path) { + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write(path, "").unwrap(); + } + + #[test] + fn git_bash_found_next_to_git_exe() { + // Standard Git-for-Windows layout: git.exe in cmd\, bash.exe in bin\. + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path().join("Git"); + touch(&root.join("cmd").join("git.exe")); + touch(&root.join("bin").join("bash.exe")); + let path_var = env::join_paths([root.join("cmd")]).unwrap(); + assert_eq!( + git_bash_from_path_var(&path_var), + Some(root.join("bin").join("bash.exe")) + ); + } + + #[test] + fn git_bash_found_from_mingw64_bin() { + // PATH carries mingw64\bin; bash.exe is two levels up under usr\bin. + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path().join("Git"); + touch(&root.join("mingw64").join("bin").join("git.exe")); + touch(&root.join("usr").join("bin").join("bash.exe")); + let path_var = env::join_paths([root.join("mingw64").join("bin")]).unwrap(); + assert_eq!( + git_bash_from_path_var(&path_var), + Some(root.join("usr").join("bin").join("bash.exe")) + ); + } + + #[test] + fn wsl_style_bash_without_git_is_not_git_bash() { + // WSL ships System32\bash.exe with no git.exe beside it. WSL has its + // own home, so hooking the Windows ~/.bashrc would do nothing — the + // detector must not count it. + let tmp = tempfile::tempdir().unwrap(); + let sys32 = tmp.path().join("System32"); + touch(&sys32.join("bash.exe")); + let path_var = env::join_paths([sys32]).unwrap(); + assert_eq!(git_bash_from_path_var(&path_var), None); + } + + #[test] + fn git_without_bash_is_not_git_bash() { + // MinGit / scm-only installs have git.exe but no bash. + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path().join("Git"); + touch(&root.join("cmd").join("git.exe")); + let path_var = env::join_paths([root.join("cmd")]).unwrap(); + assert_eq!(git_bash_from_path_var(&path_var), None); + } +} diff --git a/src/cli/routing.rs b/src/cli/routing.rs index fb84228..46f8767 100644 --- a/src/cli/routing.rs +++ b/src/cli/routing.rs @@ -235,26 +235,71 @@ pub fn manual_unset_hint(shell: Shell) -> &'static str { } } -/// Lines that set the proxy env vars for the given shell. +/// Lines that set the proxy env vars for the given shell — **liveness-gated** +/// (L-C1): the exports only happen if the proxy port actually answers at the +/// moment the shell starts. This is the structural fix for the dead-proxy +/// trap: a crash, `kill`, or reboot can never run any cleanup, so without the +/// gate every new shell would export a base URL pointing at a dead port and +/// every AI tool would fail with connection-refused until the user figured out +/// `burnwall start`. With the gate, a shell opened against a dead proxy +/// silently goes DIRECT (unprotected, but *working*) and the next `start` +/// covers new shells again. +/// +/// Probe cost: a loopback TCP connect is sub-millisecond when the proxy is +/// listening and fails immediately (RST) when nothing is bound — there's no +/// human-perceptible shell-startup cost. pub fn export_lines(shell: Shell, proxy_url: &str) -> Vec { let anthropic = format!("{}/anthropic", proxy_url); let openai = format!("{}/openai", proxy_url); + let port = proxy_url_port(proxy_url); match shell { - Shell::Zsh | Shell::Bash => vec![ - format!("export ANTHROPIC_BASE_URL=\"{}\"", anthropic), - format!("export OPENAI_BASE_URL=\"{}\"", openai), - ], + Shell::Zsh | Shell::Bash => vec![format!( + "if (exec 3<>/dev/tcp/127.0.0.1/{port}) 2>/dev/null; then exec 3>&-; export ANTHROPIC_BASE_URL=\"{anthropic}\"; export OPENAI_BASE_URL=\"{openai}\"; fi" + )], Shell::Fish => vec![ - format!("set -gx ANTHROPIC_BASE_URL \"{}\"", anthropic), - format!("set -gx OPENAI_BASE_URL \"{}\"", openai), - ], - Shell::Powershell => vec![ - format!("$env:ANTHROPIC_BASE_URL = \"{}\"", anthropic), - format!("$env:OPENAI_BASE_URL = \"{}\"", openai), + // fish has no /dev/tcp; probe via bash when available (it is on any + // dev box that also has fish), otherwise export ungated. + format!( + "if not command -q bash; or bash -c 'exec 3<>/dev/tcp/127.0.0.1/{port}' 2>/dev/null; set -gx ANTHROPIC_BASE_URL \"{anthropic}\"; set -gx OPENAI_BASE_URL \"{openai}\"; end" + ), ], + Shell::Powershell => vec![format!( + "try {{ $__bw = [Net.Sockets.TcpClient]::new('127.0.0.1', {port}); $__bw.Dispose(); $env:ANTHROPIC_BASE_URL = \"{anthropic}\"; $env:OPENAI_BASE_URL = \"{openai}\" }} catch {{}}" + )], } } +/// Extract the port from a proxy URL (`http://localhost:4100` → 4100), falling +/// back to the default proxy port. +fn proxy_url_port(proxy_url: &str) -> u16 { + let after_scheme = proxy_url.split("://").nth(1).unwrap_or(proxy_url); + let authority = after_scheme.split(['/', '?', '#']).next().unwrap_or(""); + authority + .rsplit(':') + .next() + .and_then(|p| p.parse().ok()) + .unwrap_or(4100) +} + +/// Quick TCP liveness probe of the local proxy port (used by status surfaces +/// to distinguish "routed and protected" from "routed at a dead port"). +pub fn proxy_port_alive(port: u16, timeout: std::time::Duration) -> bool { + let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port)); + std::net::TcpStream::connect_timeout(&addr, timeout).is_ok() +} + +/// Liveness-probe the proxy that `base_url` points at. `None` if the URL isn't +/// loopback (nothing local to probe). +pub fn proxy_alive_for_url(base_url: &str) -> Option { + if !url_is_loopback(base_url) { + return None; + } + Some(proxy_port_alive( + proxy_url_port(base_url), + std::time::Duration::from_millis(80), + )) +} + /// Lines that unset the proxy env vars for the given shell. Used by /// `disable-routing` in eval-output mode so the current shell drops them /// without a restart. @@ -420,11 +465,60 @@ pub fn env_file_present(shell: Shell) -> bool { env_file_path(shell).map(|p| p.exists()).unwrap_or(false) } +/// The PowerShell `CurrentUserAllHosts` profile paths burnwall manages. Both +/// editions are covered on Windows — Windows PowerShell 5.1 reads +/// `Documents\WindowsPowerShell\profile.ps1` and PowerShell 7+ reads +/// `Documents\PowerShell\profile.ps1` — because either can be the user's daily +/// shell. `dirs::document_dir()` resolves known-folder redirection (OneDrive). +/// PowerShell *was* the one shell never auto-edited, which made persistent +/// routing on the default Windows shell a silent dead end (L-C2). +pub fn powershell_profile_paths() -> Vec { + #[cfg(windows)] + { + let Some(docs) = dirs::document_dir() else { + return Vec::new(); + }; + vec![ + docs.join("WindowsPowerShell").join("profile.ps1"), + docs.join("PowerShell").join("profile.ps1"), + ] + } + #[cfg(not(windows))] + { + let Some(home) = dirs::home_dir() else { + return Vec::new(); + }; + vec![home.join(".config").join("powershell").join("profile.ps1")] + } +} + +/// Bash *login-shell* profile files, in bash's own lookup order. Git Bash +/// terminals and macOS Terminal run login shells, which read the first of +/// these that exists and only read `.bashrc` if that file chains to it — so a +/// hook placed solely in `.bashrc` can silently never execute (L-H3). +fn bash_profile_paths() -> Vec { + let Some(home) = dirs::home_dir() else { + return Vec::new(); + }; + vec![ + home.join(".bash_profile"), + home.join(".bash_login"), + home.join(".profile"), + ] +} + /// True if this shell's rc file carries our source-hook marker — i.e. the user /// previously wired this shell up. The strongest signal that a shell is /// "configured", and the one that disambiguates bash vs zsh (which share a -/// single `env.sh`). +/// single `env.sh`). PowerShell checks its managed profile paths. pub fn rc_hook_present(shell: Shell) -> bool { + if shell == Shell::Powershell { + return powershell_profile_paths().iter().any(|p| { + std::fs::read_to_string(p) + .map(|c| c.contains(RC_MARKER)) + .unwrap_or(false) + }); + } shell .rc_path() .and_then(|rc| std::fs::read_to_string(rc).ok()) @@ -438,17 +532,14 @@ pub fn routing_active(shell: Shell) -> bool { env_file_state(shell) == Some(EnvFileState::Active) } -/// Append the rc-source line to the user's shell rc, if not already there. -/// Returns `true` if the file was modified. -pub fn install_rc_hook(shell: Shell, env_path: &Path) -> Result { - let rc = shell - .rc_path() - .ok_or_else(|| anyhow::anyhow!("no rc file for shell {}", shell.label()))?; - let existing = std::fs::read_to_string(&rc).unwrap_or_default(); +/// Append the marker-carrying `line` to `path` if it isn't already there, +/// creating parent dirs. Returns `true` if the file was modified. +fn append_hook_line(path: &Path, line: &str) -> Result { + let existing = std::fs::read_to_string(path).unwrap_or_default(); if existing.contains(RC_MARKER) { return Ok(false); } - if let Some(parent) = rc.parent() { + if let Some(parent) = path.parent() { std::fs::create_dir_all(parent) .with_context(|| format!("creating {}", parent.display()))?; } @@ -456,20 +547,68 @@ pub fn install_rc_hook(shell: Shell, env_path: &Path) -> Result { if !content.is_empty() && !content.ends_with('\n') { content.push('\n'); } - content.push_str(&rc_source_line(shell, env_path)); + content.push_str(line); content.push('\n'); - std::fs::write(&rc, content).with_context(|| format!("writing {}", rc.display()))?; + std::fs::write(path, content).with_context(|| format!("writing {}", path.display()))?; Ok(true) } -/// Remove the rc-source line (the one carrying [`RC_MARKER`]) from the user's -/// shell rc. Used by `uninstall`. Returns `true` if a line was removed. Missing -/// rc file or no marker line → `false` (nothing to do). -pub fn remove_rc_hook(shell: Shell) -> Result { - let Some(rc) = shell.rc_path() else { - return Ok(false); - }; - let existing = match std::fs::read_to_string(&rc) { +/// Append the rc-source line to the user's shell rc, if not already there. +/// Returns `true` if any file was modified. +/// +/// PowerShell: writes the managed `CurrentUserAllHosts` profile(s) — every +/// edition whose profile dir already exists, or the first (Windows PowerShell) +/// one when none does (L-C2). The dot-source line is `Test-Path`-guarded, so a +/// machine with script-execution disabled merely no-ops. +/// +/// Bash: also chains into the first existing login-profile file +/// (`.bash_profile` / `.bash_login` / `.profile`) when that file doesn't read +/// `.bashrc` — Git Bash and macOS terminals run *login* shells, which never +/// see a hook that lives only in `.bashrc` (L-H3). +pub fn install_rc_hook(shell: Shell, env_path: &Path) -> Result { + if shell == Shell::Powershell { + let line = rc_source_line(shell, env_path); + let paths = powershell_profile_paths(); + if paths.is_empty() { + anyhow::bail!("could not locate a PowerShell profile directory"); + } + let mut targets: Vec<&PathBuf> = paths + .iter() + .filter(|p| p.parent().map(|d| d.exists()).unwrap_or(false)) + .collect(); + if targets.is_empty() { + targets.push(&paths[0]); + } + let mut changed = false; + for p in targets { + changed |= append_hook_line(p, &line)?; + } + return Ok(changed); + } + + let rc = shell + .rc_path() + .ok_or_else(|| anyhow::anyhow!("no rc file for shell {}", shell.label()))?; + let mut changed = append_hook_line(&rc, &rc_source_line(shell, env_path))?; + + if shell == Shell::Bash { + // Login-shell chaining (L-H3): if a profile file exists and neither + // sources .bashrc nor carries our hook, login shells would never run + // the hook above — add it to the first such file in bash's own order. + if let Some(profile) = bash_profile_paths().iter().find(|p| p.exists()) { + let contents = std::fs::read_to_string(profile).unwrap_or_default(); + if !contents.contains(".bashrc") && !contents.contains(RC_MARKER) { + changed |= append_hook_line(profile, &rc_source_line(shell, env_path))?; + } + } + } + Ok(changed) +} + +/// Strip marker-carrying lines from one file. `false` when the file is missing +/// or carries no marker. +fn remove_hook_lines(path: &Path) -> Result { + let existing = match std::fs::read_to_string(path) { Ok(s) => s, Err(_) => return Ok(false), }; @@ -484,35 +623,82 @@ pub fn remove_rc_hook(shell: Shell) -> Result { if !out.is_empty() { out.push('\n'); } - std::fs::write(&rc, out).with_context(|| format!("writing {}", rc.display()))?; + std::fs::write(path, out).with_context(|| format!("writing {}", path.display()))?; Ok(true) } +/// Remove the rc-source line (the one carrying [`RC_MARKER`]) from the user's +/// shell rc. Used by `uninstall`. Returns `true` if a line was removed. Missing +/// rc file or no marker line → `false` (nothing to do). Cleans every file +/// [`install_rc_hook`] can write: the PowerShell profiles, and for bash the +/// login-profile files alongside `.bashrc`. +pub fn remove_rc_hook(shell: Shell) -> Result { + if shell == Shell::Powershell { + let mut removed = false; + for p in powershell_profile_paths() { + removed |= remove_hook_lines(&p)?; + } + return Ok(removed); + } + let Some(rc) = shell.rc_path() else { + return Ok(false); + }; + let mut removed = remove_hook_lines(&rc)?; + if shell == Shell::Bash { + for p in bash_profile_paths() { + removed |= remove_hook_lines(&p)?; + } + } + Ok(removed) +} + #[cfg(test)] mod tests { use super::*; #[test] - fn export_lines_posix() { + fn export_lines_posix_are_liveness_gated() { let lines = export_lines(Shell::Zsh, "http://localhost:4100"); - assert_eq!(lines.len(), 2); - assert!(lines[0].starts_with("export ANTHROPIC_BASE_URL=")); - assert!(lines[0].contains("http://localhost:4100/anthropic")); - assert!(lines[1].starts_with("export OPENAI_BASE_URL=")); - assert!(lines[1].contains("http://localhost:4100/openai")); + let joined = lines.join("\n"); + // L-C1: exports must be gated on a live proxy port so a shell opened + // after a crash/reboot goes DIRECT instead of pointing at a dead port. + assert!(joined.contains("/dev/tcp/127.0.0.1/4100"), "{joined}"); + assert!(joined.contains("export ANTHROPIC_BASE_URL=\"http://localhost:4100/anthropic\"")); + assert!(joined.contains("export OPENAI_BASE_URL=\"http://localhost:4100/openai\"")); } #[test] - fn export_lines_powershell() { + fn export_lines_powershell_are_liveness_gated() { let lines = export_lines(Shell::Powershell, "http://localhost:4100"); - assert!(lines[0].starts_with("$env:ANTHROPIC_BASE_URL =")); - assert!(lines[1].starts_with("$env:OPENAI_BASE_URL =")); + let joined = lines.join("\n"); + assert!(joined.contains("TcpClient"), "{joined}"); + assert!(joined.contains("$env:ANTHROPIC_BASE_URL =")); + assert!(joined.contains("$env:OPENAI_BASE_URL =")); + assert!(joined.contains("catch"), "probe failure must be swallowed: {joined}"); } #[test] - fn export_lines_fish() { + fn export_lines_fish_are_liveness_gated() { let lines = export_lines(Shell::Fish, "http://localhost:4100"); - assert!(lines[0].starts_with("set -gx ANTHROPIC_BASE_URL")); + let joined = lines.join("\n"); + assert!(joined.contains("set -gx ANTHROPIC_BASE_URL")); + assert!(joined.contains("/dev/tcp/127.0.0.1/4100"), "{joined}"); + } + + #[test] + fn proxy_url_port_parses_common_shapes() { + assert_eq!(proxy_url_port("http://localhost:4100"), 4100); + assert_eq!(proxy_url_port("http://127.0.0.1:5000/x"), 5000); + assert_eq!(proxy_url_port("localhost"), 4100); // fallback + } + + #[test] + fn dead_port_probe_reports_not_alive() { + // Port 1 on loopback is essentially never bound; the probe must come + // back fast and false rather than hanging. + let started = std::time::Instant::now(); + assert!(!proxy_port_alive(1, std::time::Duration::from_millis(200))); + assert!(started.elapsed() < std::time::Duration::from_secs(2)); } #[test] diff --git a/src/cli/status.rs b/src/cli/status.rs index 492ac85..2319519 100644 --- a/src/cli/status.rs +++ b/src/cli/status.rs @@ -180,17 +180,69 @@ fn write_coverage( Ok(()) } +/// Cross-tool "today" without double counting (X4). A tool routed through the +/// proxy is recorded twice — once in the proxy DB and once in its own session +/// log — so summing the two buckets read ~2× reality for the recommended +/// setup. We exclude a log row when its tool's provider demonstrably had +/// proxied traffic today (claude-code → anthropic, codex → openai); tools with +/// ambiguous providers (aider, opencode) stay included, which can only +/// over-count, never hide spend. True per-turn dedup needs message-id matching +/// and is tracked separately. +#[cfg(feature = "logscrape")] +fn combined_today( + today_cost: f64, + log_rows: &[crate::logscrape::ScrapeBreakdown], + breakdown: &[ModelBreakdown], +) -> f64 { + let proxied_provider = |p: &str| breakdown.iter().any(|b| b.provider == p && b.cost > 0.0); + let unproxied_logs: f64 = log_rows + .iter() + .filter(|r| match r.tool { + "claude-code" => !proxied_provider("anthropic"), + "codex" => !proxied_provider("openai"), + _ => true, + }) + .map(|r| r.cost) + .sum(); + today_cost + unproxied_logs +} + /// Routing readout for the shell `burnwall status` runs in: is the AI tool you'd /// launch here actually pointed at the proxy? Catches the "proxy up but traffic /// goes direct" gap that leaves a user unprotected without any error. fn write_routing(w: &mut impl Write, sty: &Styler) -> std::io::Result<()> { use crate::cli::routing::{current_routing, EnvRouting}; match current_routing("anthropic") { - EnvRouting::Proxied => writeln!( - w, - " {} this shell points Anthropic traffic at the proxy.", - sty.green("🟢 Routed —") - ), + EnvRouting::Proxied => { + // Routed per the env — but cross-check the proxy is actually + // answering (U-C1): "routed at a dead port" means every AI tool in + // this shell fails with connection-refused, and a green line here + // would half-reassure the user into blaming the provider. + let alive = std::env::var("ANTHROPIC_BASE_URL") + .ok() + .and_then(|u| crate::cli::routing::proxy_alive_for_url(&u)); + if alive == Some(false) { + writeln!( + w, + " {} this shell routes to the proxy, but nothing answers on that port.", + sty.red("⛔ Routed to a DEAD proxy —") + )?; + writeln!( + w, + " Every AI tool launched from this shell will fail to connect." + )?; + return writeln!( + w, + " Fix: {} (or `burnwall stop` to pause routing and go direct)", + sty.bold("burnwall start") + ); + } + writeln!( + w, + " {} this shell points Anthropic traffic at the proxy.", + sty.green("🟢 Routed —") + ) + } EnvRouting::Direct => { writeln!( w, @@ -284,7 +336,7 @@ fn write_table( #[cfg(feature = "logscrape")] if let Some(rows) = log_scrape { - writeln!(w, " Tracked via log files (not proxied)")?; + writeln!(w, " Tracked via local session logs")?; if rows.is_empty() { writeln!(w, " (no Claude Code or Codex activity today)")?; } else { @@ -309,11 +361,21 @@ fn write_table( writeln!(w, " {}", "─".repeat(63))?; writeln!(w, " Log-file subtotal: ${:.2}", log_subtotal)?; writeln!(w)?; - writeln!( - w, - " Combined today (proxied + log files): ${:.2}", - today_cost + log_subtotal - )?; + // X4: a proxied tool's traffic shows up in BOTH buckets (a proxy DB + // row and a session-log row), so a naive proxied+logs sum read ~2× + // reality for exactly the recommended setup. Exclude the log rows + // of tools whose provider demonstrably flowed through the proxy + // today; the remainder is the genuinely unproxied add-on. + let combined = combined_today(today_cost, rows, breakdown); + if (combined - (today_cost + log_subtotal)).abs() > 0.005 { + writeln!( + w, + " Combined today: ${:.2} (proxied + unproxied logs; overlapping tool logs excluded)", + combined + )?; + } else { + writeln!(w, " Combined today (proxied + log files): ${:.2}", combined)?; + } } writeln!(w)?; } @@ -434,31 +496,27 @@ fn write_json( use serde_json::json; let bcfg = budget.config(); - // `log_scrape` JSON + subtotal — `null` / 0.0 when the feature is off or - // scraping is disabled; otherwise the per-tool/model rows plus subtotal. + // `log_scrape` JSON — `null` when the feature is off or scraping is + // disabled; otherwise the per-tool/model rows plus subtotal. #[cfg(feature = "logscrape")] - let (log_scrape_json, log_subtotal) = { - let subtotal = log_scrape.map(logscrape::subtotal).unwrap_or(0.0); - let rows_json = log_scrape.map(|rows| { - json!({ - "rows": rows.iter().map(|r| json!({ - "tool": r.tool, - "model": r.model, - "cost_usd": r.cost, - "turns": r.turns, - "input_tokens": r.usage.input_tokens, - "cache_creation_tokens": r.usage.cache_creation_tokens, - "cache_read_tokens": r.usage.cache_read_tokens, - "output_tokens": r.usage.output_tokens, - "cache_hit_rate": r.cache_hit_rate(), - })).collect::>(), - "subtotal_usd": logscrape::subtotal(rows), - }) - }); - (rows_json, subtotal) - }; + let log_scrape_json = log_scrape.map(|rows| { + json!({ + "rows": rows.iter().map(|r| json!({ + "tool": r.tool, + "model": r.model, + "cost_usd": r.cost, + "turns": r.turns, + "input_tokens": r.usage.input_tokens, + "cache_creation_tokens": r.usage.cache_creation_tokens, + "cache_read_tokens": r.usage.cache_read_tokens, + "output_tokens": r.usage.output_tokens, + "cache_hit_rate": r.cache_hit_rate(), + })).collect::>(), + "subtotal_usd": logscrape::subtotal(rows), + }) + }); #[cfg(not(feature = "logscrape"))] - let (log_scrape_json, log_subtotal) = (Option::::None, 0.0_f64); + let log_scrape_json = Option::::None; // Subscription-plan limit headroom, per provider, for the status bar / IDE // extension. `null` when no fresh snapshot exists (API user, or the proxy @@ -496,10 +554,24 @@ fn write_json( crate::cli::routing::EnvRouting::Direct => "direct", crate::cli::routing::EnvRouting::Bypassed => "bypassed", }; + // Liveness, not just a PID file: lets the extension flag "routed but the + // proxy is dead" (U-C1) instead of showing green over connection-refused. + let proxy_running = super::daemon::running_pid().ok().flatten().is_some(); + + // De-duplicated cross-tool total (X4): excludes log rows of tools whose + // provider flowed through the proxy today, so proxied Claude Code isn't + // counted twice in the headline figure. + #[cfg(feature = "logscrape")] + let combined_total = log_scrape + .map(|rows| combined_today(today_cost, rows, breakdown)) + .unwrap_or(today_cost); + #[cfg(not(feature = "logscrape"))] + let combined_total = today_cost; let value = json!({ "date": date, "env_routing": env_routing, + "proxy_running": proxy_running, "total_cost_usd": today_cost, "total_requests": total_requests, "blocked_requests": blocked, @@ -530,7 +602,7 @@ fn write_json( // `null` when log scraping is disabled or compiled out; otherwise the // per-tool/model rows plus their subtotal. Read-only — not the proxy DB. "log_scrape": log_scrape_json, - "combined_total_usd": today_cost + log_subtotal, + "combined_total_usd": combined_total, // Per-provider subscription limit headroom; `null` for API-only usage. "plan": plan_json, // Per-tool coverage: which installed tools route through the proxy, diff --git a/src/cli/statusline.rs b/src/cli/statusline.rs index 90f18b4..a21653f 100644 --- a/src/cli/statusline.rs +++ b/src/cli/statusline.rs @@ -141,10 +141,25 @@ fn build_ribbon(cc: &CcInput) -> Ribbon { /// Claude Code and inherits its environment, so the tool's `*_BASE_URL` tells us /// whether traffic is actually reaching the proxy. We key off the model's /// provider (Claude Code is Anthropic, but be correct if that ever changes). +/// +/// When the env says Proxied we additionally **liveness-probe the proxy port** +/// (U-C1): an already-open session keeps its env vars after a crash or +/// `burnwall stop`, and a green ribbon over a dead port — every request failing +/// with connection-refused — was the worst "Burnwall broke my setup" signal. +/// The probe is a sub-millisecond loopback connect, paid once per render. fn routing_state(model_id: &str) -> ribbon::Routing { let provider = provider_of(model_id); match crate::cli::routing::current_routing(provider) { - crate::cli::routing::EnvRouting::Proxied => ribbon::Routing::Proxied, + crate::cli::routing::EnvRouting::Proxied => { + let var = crate::cli::routing::base_url_var_for_provider(provider); + match std::env::var(var) + .ok() + .and_then(|u| crate::cli::routing::proxy_alive_for_url(&u)) + { + Some(false) => ribbon::Routing::ProxyDown, + _ => ribbon::Routing::Proxied, + } + } crate::cli::routing::EnvRouting::Direct => ribbon::Routing::Direct, crate::cli::routing::EnvRouting::Bypassed => ribbon::Routing::Bypassed, } @@ -154,7 +169,12 @@ fn routing_state(model_id: &str) -> ribbon::Routing { /// surfaces). Defaults to `anthropic` — the Claude Code case. fn provider_of(model_id: &str) -> &'static str { let m = model_id.to_ascii_lowercase(); - if m.contains("gpt") || m.starts_with("o1") || m.starts_with("o3") || m.contains("openai") { + if m.contains("gpt") + || m.starts_with("o1") + || m.starts_with("o3") + || m.starts_with("o4") + || m.contains("openai") + { "openai" } else if m.contains("gemini") || m.contains("google") { "google" @@ -174,20 +194,55 @@ fn plan_limits() -> Option { crate::plan::freshest(now, 12 * 3600).and_then(|s| s.to_ribbon_limits(now)) } -/// Claude Code reports *cumulative* session cost; cache the previous total per -/// session and return this turn's delta. `None` when we have no prior reading -/// (first turn of a session) so the ribbon shows session-only cost. Best-effort -/// — any I/O error just yields `None`. +/// Claude Code reports *cumulative* session cost, and re-renders the status +/// line many times per turn (~300ms cadence while streaming). A naive +/// "delta since last render" therefore showed only the last streaming +/// increment — $0.05 of a $0.40 turn, or $0.00 after any idle re-render — the +/// most-watched number, systematically wrong-low (U-H1). +/// +/// Turn-aware delta instead: track `(baseline, last_seen, last_msg)` per +/// session. While the total is moving (a turn is streaming), `msg` is the live +/// delta from the baseline — the turn's cost so far. When the total stops +/// moving (turn over), the final delta is locked in as `last_msg` and the +/// baseline advances, so the ribbon keeps showing the *completed* turn's cost +/// until the next turn starts. Best-effort — any I/O error yields `None`. fn session_msg_delta(session: Option<&str>, total: f64) -> Option { let session = session?; let dir = crate::storage::data_dir().ok()?.join("statusline"); let _ = std::fs::create_dir_all(&dir); let path = dir.join(format!("{}.last", sanitize(session))); - let prev = std::fs::read_to_string(&path) - .ok() - .and_then(|s| s.trim().parse::().ok()); - let _ = std::fs::write(&path, total.to_string()); - prev.map(|p| (total - p).max(0.0)) + + let state = std::fs::read_to_string(&path).ok().and_then(|s| { + let mut it = s.split_whitespace().filter_map(|t| t.parse::().ok()); + Some((it.next()?, it.next(), it.next())) + }); + + let (msg, baseline, last_msg) = match state { + // Legacy single-value file (just a total) or fresh triple. + Some((baseline, last_seen, last_msg)) => { + let last_seen = last_seen.unwrap_or(baseline); + let last_msg = last_msg.unwrap_or(0.0); + if total > last_seen + 1e-9 { + // Turn in progress: live cost-so-far from the baseline. + let live = (total - baseline).max(0.0); + (Some(live), baseline, live) + } else { + // Total stopped moving: the turn is over. Lock in its final + // cost and advance the baseline for the next turn. + let final_msg = if total > baseline + 1e-9 { + (total - baseline).max(0.0) + } else { + last_msg + }; + (Some(final_msg), total, final_msg) + } + } + // First render of a session — no baseline yet. + None => (None, total, 0.0), + }; + + let _ = std::fs::write(&path, format!("{baseline} {total} {last_msg}")); + msg } /// Keep a session id safe as a filename component (it's normally a UUID, but be diff --git a/src/cli/upgrade.rs b/src/cli/upgrade.rs index 927e72b..fb261b0 100644 --- a/src/cli/upgrade.rs +++ b/src/cli/upgrade.rs @@ -69,7 +69,12 @@ pub fn run_cmd(args: UpgradeArgs) -> Result<()> { // failed or --no-restart — pause routing so shells aren't left pointed // at a dead port. if was_running && !args.no_restart { - match std::process::Command::new(&exe) + // Resolve the binary fresh rather than reusing the captured `exe`: on + // Windows that path was renamed to `.old`, and the freshly-installed + // binary lives at the canonical install dir / on PATH (L-C3). Prefer + // the canonical dir, then PATH, then the original path. + let restart = restart_binary(&exe); + match std::process::Command::new(&restart) .args(["start", "--daemon"]) .status() { @@ -99,6 +104,27 @@ pub fn sweep_stale_artifact() { } } +/// Pick the binary to invoke for the post-upgrade restart. The freshly +/// installed binary lives at the canonical install dir (`~/.burnwall/bin`, +/// matching `install-path`); on Windows the previously-running `exe` was just +/// renamed to `.old`, so reusing it would fail. Order: canonical dir → bare +/// `burnwall` (PATH-resolved) → the original path as a last resort. +fn restart_binary(original_exe: &std::path::Path) -> std::path::PathBuf { + let bin_name = if cfg!(windows) { "burnwall.exe" } else { "burnwall" }; + if let Some(home) = dirs::home_dir() { + let canonical = home.join(".burnwall").join("bin").join(bin_name); + if canonical.exists() { + return canonical; + } + } + // If the original path still has a real binary (non-Windows, or install dir + // differs), prefer it; otherwise fall back to PATH resolution. + if original_exe.exists() { + return original_exe.to_path_buf(); + } + std::path::PathBuf::from("burnwall") +} + fn installer_url() -> String { // `releases/latest/download/…` always resolves to the newest release asset. let filename = if cfg!(windows) { diff --git a/src/cli/watch.rs b/src/cli/watch.rs index 83fecc0..1a36895 100644 --- a/src/cli/watch.rs +++ b/src/cli/watch.rs @@ -96,10 +96,23 @@ fn title_frame(db: &Storage) -> String { format!("\x1b]0;{}\x07", ribbon_from_db(db).render(false)) } -/// Render the current frame to a string (pure given the DB snapshot) — the -/// one-line ribbon or the multi-line dashboard. +/// Render the current frame to a string — the one-line ribbon or the +/// multi-line dashboard. fn render_frame(db: &Storage, args: &WatchArgs) -> String { - let ribbon = ribbon_from_db(db); + render_frame_with_plan(db, args, live_plan()) +} + +/// [`render_frame`] with the subscription-plan segment supplied by the +/// caller — pure given the DB snapshot and the plan. Split out so tests stay +/// hermetic: the live lookup reads the real data dir, and a fresh +/// `plan_limits.json` on the host (any subscriber's machine) swaps the +/// ribbon's dollar segment for plan headroom and changes the output. +fn render_frame_with_plan( + db: &Storage, + args: &WatchArgs, + plan: Option, +) -> String { + let ribbon = ribbon_with_plan(db, plan); let color = !args.no_color; if args.oneline { format!("{}\n", ribbon.render(color)) @@ -108,10 +121,24 @@ fn render_frame(db: &Storage, args: &WatchArgs) -> String { } } +/// Subscription headroom from the freshest proxy-captured snapshot — the +/// universal surface for CLIs without their own status bar (run `watch` in a +/// side pane). +fn live_plan() -> Option { + let now = chrono::Utc::now().timestamp(); + crate::plan::freshest(now, 12 * 3600).and_then(|s| s.to_ribbon_limits(now)) +} + /// Build the cross-tool ribbon from the proxy database. The originating tool /// isn't recoverable from proxied HTTP (every tool hits the same provider /// route), so `tool` and `sess` are left unset; `today` is the cross-tool total. fn ribbon_from_db(db: &Storage) -> Ribbon { + ribbon_with_plan(db, live_plan()) +} + +/// [`ribbon_from_db`] with the plan segment injected (see +/// [`render_frame_with_plan`] for why). +fn ribbon_with_plan(db: &Storage, plan: Option) -> Ribbon { let today = chrono::Local::now().format("%Y-%m-%d").to_string(); let today_usd = db.total_cost_for_date(&today).unwrap_or(0.0); let blocks = db @@ -121,16 +148,30 @@ fn ribbon_from_db(db: &Storage) -> Ribbon { let last = db.most_recent_request().ok().flatten(); let (model, up, down, msg_usd, ctx) = match last { + // A last-request row older than an hour is history, not "live": render + // the model with an idle annotation and drop the per-message cost and + // ctx gauge, so Monday's pane doesn't present Friday's dead session as + // a current turn (U-M4). Some(r) => { - let prompt = r.input_tokens + r.cache_creation_tokens + r.cache_read_tokens; - let ctx = ribbon::ctx_estimate(&r.model, prompt); - ( - ribbon::short_model(&r.model), - prompt, - r.output_tokens, - Some(r.cost_usd), - ctx, - ) + let age_secs = (chrono::Utc::now() - r.timestamp).num_seconds().max(0); + if age_secs > 3600 { + let label = format!( + "{} (idle {})", + ribbon::short_model(&r.model), + human_age(age_secs) + ); + (label, 0, 0, None, Ctx::Hidden) + } else { + let prompt = r.input_tokens + r.cache_creation_tokens + r.cache_read_tokens; + let ctx = ribbon::ctx_estimate(&r.model, prompt); + ( + ribbon::short_model(&r.model), + prompt, + r.output_tokens, + Some(r.cost_usd), + ctx, + ) + } } None => ("—".to_string(), 0, 0, None, Ctx::Hidden), }; @@ -144,12 +185,7 @@ fn ribbon_from_db(db: &Storage) -> Ribbon { sess_usd: None, // the aggregate view has no session concept today_usd: Some(today_usd), blocks_today: blocks, - // Subscription headroom (freshest provider) — the universal surface for - // CLIs without their own status bar (run `watch` in a side pane). - plan: { - let now = chrono::Utc::now().timestamp(); - crate::plan::freshest(now, 12 * 3600).and_then(|s| s.to_ribbon_limits(now)) - }, + plan, // The aggregate DB view spans every tool; there's no single tool // environment to judge routing from, so stay silent here. Per-tool // coverage is shown in the dashboard's `coverage:` block instead. @@ -202,6 +238,25 @@ fn mtime(path: &std::path::PathBuf) -> Option { std::fs::metadata(path).and_then(|m| m.modified()).ok() } +/// Compact human age for the idle annotation: "5h", "2d4h", "3w". +fn human_age(secs: i64) -> String { + let (m, h, d) = (secs / 60, secs / 3600, secs / 86_400); + if d >= 14 { + format!("{}w", d / 7) + } else if d >= 1 { + let rem_h = h - d * 24; + if rem_h > 0 { + format!("{d}d{rem_h}h") + } else { + format!("{d}d") + } + } else if h >= 1 { + format!("{h}h") + } else { + format!("{}m", m.max(1)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -224,15 +279,16 @@ mod tests { #[test] fn ribbon_from_db_uses_last_request_and_estimates_ctx() { let db = db_with_request(); - let r = ribbon_from_db(&db); + let r = ribbon_with_plan(&db, None); assert_eq!(r.model, "sonnet-4.6"); assert_eq!(r.up, 13_000); // input + cache_creation + cache_read assert_eq!(r.down, 615); assert_eq!(r.msg_usd, Some(0.05)); assert_eq!(r.sess_usd, None); // no session concept in the aggregate view - // 13k / 200k ≈ 6.5% → an Estimate (marked ~ at render time). + // 13k / 1M ≈ 1.3% (Sonnet 4.6 runs a 1M window) → an Estimate + // (marked ~ at render time). match r.ctx { - Ctx::Estimate(p) => assert!(p > 6.0 && p < 7.0), + Ctx::Estimate(p) => assert!(p > 1.0 && p < 2.0), other => panic!("expected Estimate, got {other:?}"), } } @@ -240,7 +296,7 @@ mod tests { #[test] fn ribbon_from_empty_db_is_safe() { let db = Storage::open_in_memory().unwrap(); - let r = ribbon_from_db(&db); + let r = ribbon_with_plan(&db, None); assert_eq!(r.model, "—"); assert_eq!(r.msg_usd, None); assert_eq!(r.ctx, Ctx::Hidden); @@ -258,7 +314,7 @@ mod tests { no_color: true, title: false, }; - let frame = render_frame(&db, &args); + let frame = render_frame_with_plan(&db, &args, None); assert!(frame.contains("🔥 burnwall · sonnet-4.6")); assert!(frame.contains("$0.05 msg")); } @@ -273,7 +329,7 @@ mod tests { no_color: true, title: false, }; - let frame = render_frame(&db, &args); + let frame = render_frame_with_plan(&db, &args, None); assert!(frame.contains("burnwall · live")); assert!(frame.contains("today by model:")); assert!(frame.contains("anthropic/sonnet-4.6")); diff --git a/src/plan.rs b/src/plan.rs index ab92fd2..53a5e71 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -67,9 +67,15 @@ impl PlanSnapshot { } /// Map to the renderer's [`crate::ribbon::PlanLimits`] (binding window as - /// primary, next as secondary). `None` if there are no windows. + /// primary, next as secondary). `None` if there are no windows, or when the + /// binding window's own reset time has passed — the snapshot says it has + /// expired, so showing yesterday's 92% as live headroom is worse than + /// showing nothing (U-M7). pub fn to_ribbon_limits(&self, now: i64) -> Option { let primary = self.windows.first()?; + if primary.reset <= now { + return None; + } Some(crate::ribbon::PlanLimits { primary_label: primary.label.clone(), primary_pct: (primary.utilization * 100.0).clamp(0.0, 100.0), @@ -78,7 +84,14 @@ impl PlanSnapshot { .windows .get(1) .map(|w| (w.label.clone(), (w.utilization * 100.0).clamp(0.0, 100.0))), - throttled: self.status != "allowed", + // Only a positively-throttling status renders the ⛔ chip. Anthropic + // emits warning-grade intermediates (e.g. `allowed_warning`) near + // the limit while requests still succeed — "anything ≠ allowed" + // showed a false THROTTLED at ~80% utilization (U-H4). + throttled: matches!( + self.status.as_str(), + "throttled" | "rejected" | "blocked" | "rate_limited" + ), }) } } @@ -254,6 +267,37 @@ mod tests { assert!(!rl.throttled); } + #[test] + fn warning_grade_status_is_not_throttled() { + // U-H4: Anthropic emits intermediates like `allowed_warning` near the + // limit while requests still succeed — must not render ⛔ throttled. + let mut h = unified(); + h.insert( + "anthropic-ratelimit-unified-status", + hyper::header::HeaderValue::from_static("allowed_warning"), + ); + let snap = parse_limits("anthropic", &h, 1780951905).unwrap(); + let rl = snap.to_ribbon_limits(1780951905).unwrap(); + assert!(!rl.throttled, "warning-grade status must not show throttled"); + + let mut h = unified(); + h.insert( + "anthropic-ratelimit-unified-status", + hyper::header::HeaderValue::from_static("rejected"), + ); + let snap = parse_limits("anthropic", &h, 1780951905).unwrap(); + assert!(snap.to_ribbon_limits(1780951905).unwrap().throttled); + } + + #[test] + fn expired_window_yields_no_ribbon_limits() { + // U-M7: once the binding window's reset has passed, the reading is + // self-describedly expired — show nothing, not yesterday's 92%. + let snap = parse_limits("anthropic", &unified(), 1780951905).unwrap(); + let after_reset = 1780960800 + 60; + assert!(snap.to_ribbon_limits(after_reset).is_none()); + } + #[test] fn snapshot_json_round_trips() { let snap = parse_limits("anthropic", &unified(), 1780951905).unwrap(); diff --git a/src/ribbon.rs b/src/ribbon.rs index 5ff1705..7175a97 100644 --- a/src/ribbon.rs +++ b/src/ribbon.rs @@ -49,6 +49,10 @@ pub enum Routing { /// Routed, but the `BURNWALL_BYPASS` kill switch makes the proxy a pure /// relay (checks off). Rendered as a softer caution. Bypassed, + /// Routed at the proxy, but the proxy port doesn't answer — every request + /// from this environment will fail with connection-refused. The loudest + /// warning of all: the user's tool is actively broken (U-C1). + ProxyDown, /// The surface has no environment context to judge routing. Renders nothing. Unknown, } @@ -120,6 +124,13 @@ impl Ribbon { Routing::Bypassed => { let _ = write!(s, " · {}", warn_segment("⚠ bypass", color, Hue::Yellow)); } + Routing::ProxyDown => { + let _ = write!( + s, + " · {}", + warn_segment("⛔ PROXY DOWN — run `burnwall start`", color, Hue::Red) + ); + } Routing::Proxied | Routing::Unknown => {} } let _ = write!(s, " · ↑{} ↓{}", human_k(self.up), human_k(self.down)); @@ -251,11 +262,21 @@ pub fn short_model(id: &str) -> String { /// Known model context-window sizes (tokens), matched by name prefix. Used only /// to *estimate* the gauge for tools that don't report it; an unknown model /// yields no estimate (the caller renders [`Ctx::Unknown`]). +/// +/// First prefix match wins, so generation-specific entries (Opus 4.6+ and +/// Sonnet 4.6 moved to 1M windows) must precede their shorter family keys. const CONTEXT_WINDOWS: &[(&str, u64)] = &[ - ("claude-opus-4", 200_000), - ("claude-sonnet-4", 200_000), + ("claude-fable-5", 1_000_000), + ("claude-mythos-5", 1_000_000), + ("claude-opus-4-8", 1_000_000), + ("claude-opus-4-7", 1_000_000), + ("claude-opus-4-6", 1_000_000), + ("claude-opus-4", 200_000), // 4.5 and earlier + ("claude-sonnet-4-6", 1_000_000), + ("claude-sonnet-4", 200_000), // 4.5 and earlier ("claude-haiku-4", 200_000), ("gpt-5", 400_000), + ("gemini-3", 1_000_000), ("gemini-2.5", 1_000_000), ("gemini-2.0", 1_000_000), ]; @@ -521,17 +542,31 @@ mod tests { #[test] fn ctx_estimate_trusts_known_window_and_flags_overflow() { - // Within a known window → Estimate. - match ctx_estimate("claude-sonnet-4-6", 44_000) { + // Within a known window → Estimate (haiku-4.5's window is 200k). + match ctx_estimate("claude-haiku-4-5", 44_000) { Ctx::Estimate(p) => assert!((p - 22.0).abs() < 0.5), other => panic!("expected Estimate, got {other:?}"), } // Prompt exceeds the assumed window (extended mode) → Unknown, not a wrong %. - assert_eq!(ctx_estimate("claude-sonnet-4-6", 512_000), Ctx::Unknown); + assert_eq!(ctx_estimate("claude-haiku-4-5", 512_000), Ctx::Unknown); // Unknown model → Unknown. assert_eq!(ctx_estimate("who-knows-1", 1000), Ctx::Unknown); } + #[test] + fn ctx_windows_track_the_1m_generation() { + // Opus 4.6+ / Sonnet 4.6 / Fable 5 run 1M windows; the pre-4.6 + // generation stays at 200k. The generation-specific prefix must win + // over the shorter family key. + assert_eq!(context_window_for("claude-fable-5"), Some(1_000_000)); + assert_eq!(context_window_for("claude-fable-5[1m]"), Some(1_000_000)); + assert_eq!(context_window_for("claude-opus-4-8"), Some(1_000_000)); + assert_eq!(context_window_for("claude-sonnet-4-6"), Some(1_000_000)); + assert_eq!(context_window_for("claude-sonnet-4-5-20250929"), Some(200_000)); + assert_eq!(context_window_for("claude-opus-4-5-20251101"), Some(200_000)); + assert_eq!(context_window_for("gemini-3.1-pro-preview"), Some(1_000_000)); + } + #[test] fn color_output_contains_ansi() { let s = base().render(true); From acf1d6948e1f8624e90bad0e6ae6dd8b5a305bfd Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 15:38:58 -0400 Subject: [PATCH 8/9] tests: torture-proxy suite for streaming/timeout/disconnect Raw-TCP fake upstream exercising the paths idealized tests missed: SSE delivered one byte per flush round-trips intact and records usage; a stalled upstream is bounded by read_timeout instead of hanging; a client disconnect mid-stream leaves the proxy responsive. Registers the torture_test and audit_test targets. --- Cargo.toml | 8 + tests/integration/torture_test.rs | 250 ++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) create mode 100644 tests/integration/torture_test.rs diff --git a/Cargo.toml b/Cargo.toml index 7f22262..90c037a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -204,6 +204,14 @@ path = "tests/unit/observe_test.rs" name = "audit_cli_test" path = "tests/integration/audit_cli_test.rs" +[[test]] +name = "audit_test" +path = "tests/unit/audit_test.rs" + +[[test]] +name = "torture_test" +path = "tests/integration/torture_test.rs" + [profile.release] opt-level = "z" # Optimize for size lto = true # Link-time optimization diff --git a/tests/integration/torture_test.rs b/tests/integration/torture_test.rs new file mode 100644 index 0000000..c844a1b --- /dev/null +++ b/tests/integration/torture_test.rs @@ -0,0 +1,250 @@ +//! Torture-proxy suite (P0): adversarial upstream behaviour the wiremock +//! happy-path tests can't express — SSE split across tiny TCP frames, an +//! upstream that accepts then stalls forever, and a client that disconnects +//! mid-stream. These exercise the streaming tee and the new timeout/keepalive +//! and disconnect-cancel paths (P-C1/P-C2) that earlier idealized tests missed. +//! +//! The fake upstream is a raw `tokio::net::TcpListener` (not wiremock) so we +//! control flush boundaries and can stall a live socket. Every case is wrapped +//! in `tokio::time::timeout` so a regression *hangs the test deadline* rather +//! than the whole suite. + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use burnwall::budget::{BudgetTracker, LoopDetector}; +use burnwall::proxy::{serve, AppState}; +use burnwall::security::SecurityEngine; +use burnwall::storage::Storage; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +/// A realistic Anthropic SSE response: input/cache tokens in `message_start`, +/// output tokens in `message_delta`. +const SSE: &str = "event: message_start\n\ +data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_x\",\"model\":\"claude-haiku-4-5\",\"usage\":{\"input_tokens\":2000,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":500,\"output_tokens\":0}}}\n\ +\n\ +event: message_delta\n\ +data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":300}}\n\ +\n\ +event: message_stop\n\ +data: {\"type\":\"message_stop\"}\n\n"; + +fn today() -> String { + chrono::Local::now().format("%Y-%m-%d").to_string() +} + +/// Build an `AppState` pointed at `upstream`, with a caller-supplied HTTP +/// client (so a test can set a short read_timeout to exercise stall recovery). +fn state_for(upstream: String, storage: Arc, client: reqwest::Client) -> AppState { + AppState { + upstream_anthropic: upstream, + upstream_openai: "http://127.0.0.1:1".to_string(), + upstream_google: "http://127.0.0.1:1".to_string(), + http_client: client, + security: Arc::new(SecurityEngine::with_defaults()), + budget: Arc::new(BudgetTracker::with_defaults()), + loop_detector: Arc::new(LoopDetector::with_defaults()), + storage, + cache_injection: false, + resilience: Default::default(), + #[cfg(feature = "observe")] + otel: None, + } +} + +async fn spawn_proxy(state: AppState) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let _ = serve(listener, Arc::new(state)).await; + }); + addr +} + +/// Read past the end of an HTTP request's headers on `sock` (we don't care +/// about the body for these tests — the proxy has already sent it). +async fn drain_request_headers(sock: &mut tokio::net::TcpStream) { + let mut buf = [0u8; 4096]; + // One read is enough to get the headers for our small POSTs; we just need + // the upstream to have accepted and consumed enough to reply. + let _ = sock.read(&mut buf).await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn sse_split_across_tiny_frames_round_trips_and_records() { + // The tee must reassemble a stream delivered one byte at a time: the client + // gets the exact bytes, and usage is parsed from the reassembled body. + let upstream = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let up_addr = upstream.local_addr().unwrap(); + tokio::spawn(async move { + let (mut sock, _) = upstream.accept().await.unwrap(); + drain_request_headers(&mut sock).await; + let header = format!( + "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\n\r\n", + SSE.len() + ); + sock.write_all(header.as_bytes()).await.unwrap(); + // One byte per write, each flushed — maximally adversarial framing. + for b in SSE.as_bytes() { + sock.write_all(&[*b]).await.unwrap(); + sock.flush().await.unwrap(); + } + sock.shutdown().await.ok(); + }); + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = state_for( + format!("http://{up_addr}"), + storage.clone(), + reqwest::Client::new(), + ); + let addr = spawn_proxy(state).await; + + let body = tokio::time::timeout(Duration::from_secs(10), async { + let resp = reqwest::Client::new() + .post(format!("http://{addr}/anthropic/v1/messages")) + .json(&json!({"model": "claude-haiku-4-5", "stream": true})) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + resp.bytes().await.unwrap() + }) + .await + .expect("byte-at-a-time stream must not hang"); + + assert_eq!(body.as_ref(), SSE.as_bytes(), "stream must round-trip intact"); + + tokio::time::sleep(Duration::from_millis(250)).await; + let rows = storage.requests_for_date(&today()).unwrap(); + assert_eq!(rows.len(), 1, "the reassembled stream should record one row"); + assert!(rows[0].cost_usd > 0.0, "usage parsed from reassembled body"); + assert_eq!(rows[0].input_tokens, 2000); + assert_eq!(rows[0].output_tokens, 300); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn upstream_that_stalls_forever_is_bounded_by_read_timeout() { + // P-C1: an upstream that sends headers + a partial body then goes silent + // must NOT hang the proxy/client forever. With a short read_timeout the + // socket is reclaimed; without the fix this test's deadline would trip. + let upstream = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let up_addr = upstream.local_addr().unwrap(); + let _server = tokio::spawn(async move { + let (mut sock, _) = upstream.accept().await.unwrap(); + drain_request_headers(&mut sock).await; + // Claim a long body, send a sliver, then stall indefinitely. + sock.write_all( + b"HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: 100000\r\n\r\nevent: ping\n", + ) + .await + .unwrap(); + sock.flush().await.unwrap(); + // Never write the rest. Hold the socket open. + tokio::time::sleep(Duration::from_secs(120)).await; + }); + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + // Short read_timeout stands in for the production 600s backstop so the test + // resolves quickly — the point is that a stalled read is reclaimed at all. + let stall_client = reqwest::Client::builder() + .read_timeout(Duration::from_millis(800)) + .build() + .unwrap(); + let state = state_for(format!("http://{up_addr}"), storage.clone(), stall_client); + let addr = spawn_proxy(state).await; + + // The whole exchange must finish well inside the deadline: the client gets + // headers (200) then the body stream errors out when the upstream read + // times out. Either way it must not hang. + let outcome = tokio::time::timeout(Duration::from_secs(8), async { + let resp = reqwest::Client::builder() + .build() + .unwrap() + .post(format!("http://{addr}/anthropic/v1/messages")) + .json(&json!({"model": "claude-haiku-4-5", "stream": true})) + .send() + .await; + // Read the (truncated) body to completion or error. + if let Ok(r) = resp { + let _ = r.bytes().await; + } + }) + .await; + + assert!( + outcome.is_ok(), + "a stalled upstream must be bounded by read_timeout, not hang" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn client_disconnect_midstream_does_not_hang_the_proxy() { + // P-C2: when the client drops mid-stream, the tee stops draining and the + // proxy stays responsive. We assert the proxy serves a *subsequent* request + // fine after a client abandoned a prior streaming response. + let upstream = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let up_addr = upstream.local_addr().unwrap(); + tokio::spawn(async move { + loop { + let Ok((mut sock, _)) = upstream.accept().await else { + break; + }; + tokio::spawn(async move { + drain_request_headers(&mut sock).await; + let header = format!( + "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\n\r\n", + SSE.len() + ); + let _ = sock.write_all(header.as_bytes()).await; + // Trickle the body slowly so the client can disconnect mid-way. + for chunk in SSE.as_bytes().chunks(8) { + if sock.write_all(chunk).await.is_err() { + break; + } + let _ = sock.flush().await; + tokio::time::sleep(Duration::from_millis(20)).await; + } + let _ = sock.shutdown().await; + }); + } + }); + + let storage = Arc::new(Storage::open_in_memory().unwrap()); + let state = state_for( + format!("http://{up_addr}"), + storage.clone(), + reqwest::Client::new(), + ); + let addr = spawn_proxy(state).await; + + // First request: start streaming, then drop the response without reading it + // all (simulates the user pressing Esc). + { + let resp = reqwest::Client::new() + .post(format!("http://{addr}/anthropic/v1/messages")) + .json(&json!({"model": "claude-haiku-4-5", "stream": true})) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + drop(resp); // abandon mid-stream + } + + // Second request must still be served promptly — the proxy isn't wedged. + let ok = tokio::time::timeout(Duration::from_secs(8), async { + let resp = reqwest::Client::new() + .post(format!("http://{addr}/anthropic/v1/messages")) + .json(&json!({"model": "claude-haiku-4-5", "stream": true})) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + let _ = resp.bytes().await; + }) + .await; + assert!(ok.is_ok(), "proxy must stay responsive after a client disconnect"); +} From 2d78240e38d7e40ff03b5ea426f9e3b83746ace2 Mon Sep 17 00:00:00 2001 From: codehippie1 Date: Wed, 10 Jun 2026 16:10:12 -0400 Subject: [PATCH 9/9] =?UTF-8?q?v0.9.14:=20dogfooding=20robustness=20pass?= =?UTF-8?q?=20=E2=80=94=20budget=20reset,=20loop/security=20FP=20fixes,=20?= =?UTF-8?q?dead-proxy=20safety?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 46 ++++++++++++++++++++++++++++++++++++++ Cargo.lock | 2 +- Cargo.toml | 2 +- editor/vscode/package.json | 2 +- packaging/mcp/server.json | 2 +- 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12644dc..349e423 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,52 @@ All notable changes to Burnwall. +## [0.9.14] — 2026-06-10 + +A real-world robustness pass driven by dogfooding: a multi-agent review of +every feature, focused on the failure modes that make a tool freeze, falsely +block, or mislead — the kind that trigger an uninstall. + +### Fixed + +- **The daily budget now resets at midnight.** A long-running proxy used to + accumulate spend across days and eventually return "budget exceeded" on every + request even though the day's real spend was small. The counter is now + day- and month-aware (restart- and clock-change-proof), and the monthly cap + is actually enforced. +- **Loop detection no longer gets stuck on retries.** A blocked request (and a + client's automatic retry of it, or a retry after a provider outage) no longer + feeds the loop-detection window, so a transient blip can't wedge a session + into a permanent 429 loop. Blocks now carry a `Retry-After`, and the window is + keyed per method/provider/path so unrelated requests don't collide. +- **Fewer false security blocks.** Writing or discussing a file that merely + mentions a sensitive path (e.g. `~/.ssh` in a README) no longer 403s — only + shell-tool arguments get command checks. Windows paths in tool arguments are + no longer mistaken for network mounts, scoped deletes like `rm -rf /tmp/x` + pass, and well-known documentation/example keys are exempt. Blocks now explain + what was caught and how to proceed, and `burnwall report-bug` writes a + sanitized local report for false positives. +- **The proxy no longer hangs on a stalled or unreachable upstream**, and + cancelling a request (Esc) stops the upstream instead of billing the full + response. +- **Accurate cost capture for more tools.** OpenAI's Responses API (used by + Codex) is now parsed instead of silently recording $0, unknown models warn + instead of recording $0, and the cross-tool "today" total no longer + double-counts traffic that went through the proxy. + +### Changed + +- **A crashed or stopped proxy no longer breaks your terminals.** Shell routing + is liveness-gated: if the proxy isn't running, a new shell talks directly to + the provider (unprotected but working) instead of failing to connect. Every + status surface shows a clear "proxy down" warning when routing points at a + dead port. PowerShell now gets persistent routing like the other shells. +- Plan-aware budgeting: on a flat-rate subscription, the dollar cap is treated + as advisory (tracked and warned, not blocked) unless you opt in. +- Hardening across MCP (prose-safe scanning, clearer approval errors), the audit + chain (lost-key detection), storage (schema versioning), and the daemon + (a real log file, PID identity checks). + ## [0.9.13] — 2026-06-09 ### Fixed diff --git a/Cargo.lock b/Cargo.lock index b11fd0b..357a38d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,7 +171,7 @@ checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" [[package]] name = "burnwall" -version = "0.9.13" +version = "0.9.14" dependencies = [ "anyhow", "assert_cmd", diff --git a/Cargo.toml b/Cargo.toml index 90c037a..e30b115 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "burnwall" -version = "0.9.13" +version = "0.9.14" edition = "2024" rust-version = "1.87" description = "Local proxy for AI coding tools (Claude Code, Codex CLI, Aider): cache-aware cost tracking, path/command security checks, daily budget enforcement. Zero telemetry." diff --git a/editor/vscode/package.json b/editor/vscode/package.json index 078e65e..5b1b5f5 100644 --- a/editor/vscode/package.json +++ b/editor/vscode/package.json @@ -2,7 +2,7 @@ "name": "burnwall", "displayName": "Burnwall", "description": "Cost + security for your AI coding agents, at a glance — reads your local Burnwall CLI.", - "version": "0.9.13", + "version": "0.9.14", "publisher": "intbot", "license": "FSL-1.1-MIT", "repository": { "type": "git", "url": "https://github.com/intbot/burnwall" }, diff --git a/packaging/mcp/server.json b/packaging/mcp/server.json index 7b6c057..44c5fc4 100644 --- a/packaging/mcp/server.json +++ b/packaging/mcp/server.json @@ -6,7 +6,7 @@ "url": "https://github.com/intbot/burnwall", "source": "github" }, - "version": "0.9.13", + "version": "0.9.14", "packages": [ { "registryType": "oci",