diff --git a/spec/unit/bundle_loader_spec.lua b/spec/unit/bundle_loader_spec.lua index c0415f3..fc8ca26 100644 --- a/spec/unit/bundle_loader_spec.lua +++ b/spec/unit/bundle_loader_spec.lua @@ -370,11 +370,13 @@ runner:then_("^the compiled bundle includes runtime override blocks$", function( assert.equals(true, ctx.compiled.global_shadow.enabled) assert.equals("incident-global-shadow", ctx.compiled.global_shadow.reason) assert.equals("2030-01-01T00:00:00Z", ctx.compiled.global_shadow.expires_at) + assert.is_number(ctx.compiled.global_shadow._expires_epoch) assert.is_table(ctx.compiled.kill_switch_override) assert.equals(true, ctx.compiled.kill_switch_override.enabled) assert.equals("incident-ks-override", ctx.compiled.kill_switch_override.reason) assert.equals("2030-01-01T00:00:00Z", ctx.compiled.kill_switch_override.expires_at) + assert.is_number(ctx.compiled.kill_switch_override._expires_epoch) end) runner:then_("^policies_by_id contains exactly one entry for that ID with the last policy spec$", function(ctx) diff --git a/spec/unit/circuit_breaker_spec.lua b/spec/unit/circuit_breaker_spec.lua index 619a13d..e814f90 100644 --- a/spec/unit/circuit_breaker_spec.lua +++ b/spec/unit/circuit_breaker_spec.lua @@ -212,4 +212,32 @@ describe("circuit_breaker targeted direct coverage", function() assert.is_nil(result.spend_rate) assert.is_nil(result.reason) end) + + it("clears rate keys on auto-reset before resuming traffic", function() + local env = mock_ngx.setup_ngx() + local now = env.time.now() + local limit_key = "org-auto-reset" + local config = { + enabled = true, + spend_rate_threshold_per_minute = 100, + action = "reject", + alert = false, + auto_reset_after_minutes = 5, + } + local current_window = _window_start(now) + local previous_window = current_window - 60 + + env.dict:set(circuit_breaker.build_state_key(limit_key), "open:" .. tostring(now - 300)) + env.dict:set(circuit_breaker.build_rate_key(limit_key, current_window), 100) + env.dict:set(circuit_breaker.build_rate_key(limit_key, previous_window), 100) + + local result = circuit_breaker.check(env.dict, config, limit_key, 1, now) + + assert.is_false(result.tripped) + assert.equals("closed", result.state) + assert.equals(1, result.spend_rate) + assert.is_nil(env.dict:get(circuit_breaker.build_state_key(limit_key))) + assert.equals(1, env.dict:get(circuit_breaker.build_rate_key(limit_key, current_window))) + assert.is_nil(env.dict:get(circuit_breaker.build_rate_key(limit_key, previous_window))) + end) end) diff --git a/spec/unit/rule_engine_spec.lua b/spec/unit/rule_engine_spec.lua index 74b5982..8488ae3 100644 --- a/spec/unit/rule_engine_spec.lua +++ b/spec/unit/rule_engine_spec.lua @@ -916,4 +916,38 @@ describe("rule_engine targeted direct coverage", function() assert.is_true(wrapped.would_reject) debug.setupvalue(rule_engine.evaluate, shadow_index, original_shadow_mode) end) + + it("uses cached override expiry without reparsing expires_at", function() + local ctx = {} + _setup_engine(ctx) + + ctx.matching_policy_ids = { "p1" } + ctx.request_context._descriptors["jwt:org_id"] = "org-shadow-cached" + ctx.bundle.global_shadow = { + enabled = true, + reason = "incident-global-shadow-cached", + expires_at = "not-a-date", + _expires_epoch = ctx.time.now() + 300, + } + ctx.bundle.policies_by_id.p1 = _new_policy("p1", "enforce", { + { + name = "reject_rule", + algorithm = "token_bucket", + limit_keys = { "jwt:org_id" }, + algorithm_config = {}, + }, + }) + ctx.rule_results.reject_rule = { + allowed = false, + reason = "rate_limited", + limit = 100, + remaining = 0, + retry_after = 2, + } + + local decision = ctx.engine.evaluate(ctx.request_context, ctx.bundle) + assert.equals("allow", decision.action) + assert.equals("true", decision.headers["X-Fairvisor-Global-Shadow"]) + assert.equals("incident-global-shadow-cached", decision.headers["X-Fairvisor-Global-Shadow-Reason"]) + end) end) diff --git a/spec/unit/saas_client_spec.lua b/spec/unit/saas_client_spec.lua index 08971fa..545ea0f 100644 --- a/spec/unit/saas_client_spec.lua +++ b/spec/unit/saas_client_spec.lua @@ -743,4 +743,44 @@ describe("saas_client targeted direct coverage", function() local flushed = reloaded.flush_events() assert.equals(3, flushed) end) + + it("rebuilds cached auth header after reinit with a different token", function() + local reloaded = _reload_saas_client() + local http = mock_http.new() + http.queue_response("POST", "https://s/api/v1/edge/register", { status = 200 }) + http.queue_response("POST", "https://s/api/v1/edge/register", { status = 200 }) + + local deps = { + bundle_loader = { get_current = function() end, load_from_string = function() end, apply = function() end }, + health = { set = function() end, inc = function() end }, + http_client = http.client, + } + + local ok, err = reloaded.init({ + edge_id = "e", + edge_token = "token-1", + saas_url = "https://s", + }, deps) + assert.is_true(ok) + assert.is_nil(err) + + ok, err = reloaded.init({ + edge_id = "e", + edge_token = "token-2", + saas_url = "https://s", + }, deps) + assert.is_true(ok) + assert.is_nil(err) + + local register_calls = {} + for _, request in ipairs(http.requests) do + if request.method == "POST" and request.url == "https://s/api/v1/edge/register" then + register_calls[#register_calls + 1] = request + end + end + + assert.equals(2, #register_calls) + assert.equals("Bearer token-1", register_calls[1].headers.Authorization) + assert.equals("Bearer token-2", register_calls[2].headers.Authorization) + end) end) diff --git a/src/fairvisor/bundle_loader.lua b/src/fairvisor/bundle_loader.lua index ee2c3d2..4258328 100644 --- a/src/fairvisor/bundle_loader.lua +++ b/src/fairvisor/bundle_loader.lua @@ -406,6 +406,7 @@ local function _validate_top_level(bundle) if expires_err then return nil, "expires_at_invalid" end + bundle._expires_epoch = expires_epoch if ngx and ngx.now and expires_epoch <= ngx.now() then return nil, "bundle_expired" @@ -457,6 +458,7 @@ local function _validate_top_level(bundle) if expires_err then return nil, field_name .. "_invalid: expires_at_invalid" end + block._expires_epoch = expires_epoch if ngx and ngx.now and expires_epoch <= ngx.now() then return nil, field_name .. "_invalid: expired" diff --git a/src/fairvisor/circuit_breaker.lua b/src/fairvisor/circuit_breaker.lua index 1f2a33b..e0625c5 100644 --- a/src/fairvisor/circuit_breaker.lua +++ b/src/fairvisor/circuit_breaker.lua @@ -128,7 +128,7 @@ function _M.check(dict, config, limit_key, cost, now) } end - dict:delete(state_key) + _M.reset(dict, limit_key, now) else return { tripped = true, diff --git a/src/fairvisor/llm_limiter.lua b/src/fairvisor/llm_limiter.lua index dfbb684..214b68a 100644 --- a/src/fairvisor/llm_limiter.lua +++ b/src/fairvisor/llm_limiter.lua @@ -14,6 +14,7 @@ local string_gsub = string.gsub local string_format = string.format local string_lower = string.lower local string_sub = string.sub +local string_byte = string.byte local os_date = os.date local token_bucket = require("fairvisor.token_bucket") @@ -160,21 +161,23 @@ local function _simple_word_estimate(request_context) if body == "" then return 0 end - if #body > MAX_BODY_SCAN_BYTES then - body = string_sub(body, 1, MAX_BODY_SCAN_BYTES) + + local scan_limit = #body + if scan_limit > MAX_BODY_SCAN_BYTES then + scan_limit = MAX_BODY_SCAN_BYTES end local messages_start = string_find(body, "\"messages\"", 1, true) - if messages_start then + if messages_start and messages_start < scan_limit then local array_start = string_find(body, "[", messages_start, true) - if array_start then + if array_start and array_start < scan_limit then local position = array_start local char_count = 0 while true do -- Find "content" key local key_start = string_find(body, "\"content\"", position, true) - if not key_start then + if not key_start or key_start >= scan_limit then break end @@ -182,7 +185,7 @@ local function _simple_word_estimate(request_context) -- pattern: %s*:%s*" local val_marker_start, val_marker_end = string_find(body, "^%s*:%s*\"", key_start + 9) - if not val_marker_start then + if not val_marker_start or val_marker_start > scan_limit then -- False positive (e.g. key was in a string), skip it position = key_start + 9 else @@ -190,12 +193,15 @@ local function _simple_word_estimate(request_context) local content_end = content_start while true do content_end = string_find(body, "\"", content_end, true) - if not content_end then break end + if not content_end or content_end > scan_limit then + content_end = nil + break + end -- Count backslashes before this quote local bs_count = 0 local p = content_end - 1 - while p >= content_start and string_sub(body, p, p) == "\\" do + while p >= content_start and string_byte(body, p) == 92 do bs_count = bs_count + 1 p = p - 1 end @@ -219,7 +225,9 @@ local function _simple_word_estimate(request_context) end end - return ceil(#body / 4) + local total_len = #body + if total_len > MAX_BODY_SCAN_BYTES then total_len = MAX_BODY_SCAN_BYTES end + return ceil(total_len / 4) end local function _extract_max_tokens(body) diff --git a/src/fairvisor/rule_engine.lua b/src/fairvisor/rule_engine.lua index 455c6a4..17f9758 100644 --- a/src/fairvisor/rule_engine.lua +++ b/src/fairvisor/rule_engine.lua @@ -459,17 +459,22 @@ local function _is_override_active(block, now) return false end + local expires_epoch = block._expires_epoch + if expires_epoch then + return expires_epoch > now + end + local expires_at = block.expires_at if type(expires_at) ~= "string" or expires_at == "" then return false end - local expires_epoch, parse_err = utils.parse_iso8601(expires_at) + local parsed_epoch, parse_err = utils.parse_iso8601(expires_at) if parse_err ~= nil then return false end - return expires_epoch > now + return parsed_epoch > now end local function _maybe_log_override_state(flags) diff --git a/src/fairvisor/saas_client.lua b/src/fairvisor/saas_client.lua index 827094a..4351148 100644 --- a/src/fairvisor/saas_client.lua +++ b/src/fairvisor/saas_client.lua @@ -152,15 +152,22 @@ local function _http_client() end local function _auth_header() + if _state.auth_header then + return _state.auth_header + end + local token = _state.config.edge_token if type(token) ~= "string" then - return "Bearer " + _state.auth_header = "Bearer " + return _state.auth_header end -- Defensive: reject tokens that could inject CR/LF into HTTP headers. if token:find("[\r\n]") then - return "Bearer " + _state.auth_header = "Bearer " + return _state.auth_header end - return "Bearer " .. token + _state.auth_header = "Bearer " .. token + return _state.auth_header end local function _is_non_retriable_status(status) @@ -697,6 +704,7 @@ local function _reset_state(config, deps) _state.register_attempt = 0 _state.register_next_retry_at = 0 _state.last_config_poll_at = 0 + _state.auth_header = nil end local function _validate_config(config)