diff --git a/lib/aikido/zen.rb b/lib/aikido/zen.rb index 6ac3e649..491c4b25 100644 --- a/lib/aikido/zen.rb +++ b/lib/aikido/zen.rb @@ -10,6 +10,7 @@ require_relative "zen/worker" require_relative "zen/agent" require_relative "zen/api_client" +require_relative "zen/api_stream" require_relative "zen/context" require_relative "zen/current_context" require_relative "zen/detached_agent" diff --git a/lib/aikido/zen/agent.rb b/lib/aikido/zen/agent.rb index 484d9ff1..ff145e58 100644 --- a/lib/aikido/zen/agent.rb +++ b/lib/aikido/zen/agent.rb @@ -21,15 +21,17 @@ def initialize( collector: Aikido::Zen.collector, detached_agent: Aikido::Zen.detached_agent, worker: Aikido::Zen::Worker.new(config: config), - api_client: Aikido::Zen::APIClient.new(config: config) + api_client: Aikido::Zen::APIClient.new(config: config), + api_stream: Aikido::Zen::APIStream.new(config: config) ) - @started_at = nil - @config = config - @worker = worker - @api_client = api_client @collector = collector @detached_agent = detached_agent + @worker = worker + @api_client = api_client + @api_stream = api_stream + + @started_at = nil end def started? @@ -59,7 +61,7 @@ def start! at_exit { stop! if started? } report(Events::Started.new(time: @started_at)) do |response| - if Aikido::Zen.runtime_settings.update_from_runtime_config_json(response) + if update_settings_from_runtime_config!(response) updated_settings! @config.logger.info("Updated runtime settings") end @@ -68,7 +70,7 @@ def start! end begin - Aikido::Zen.runtime_settings.update_from_runtime_firewall_lists_json(@api_client.fetch_runtime_firewall_lists) + update_settings_from_runtime_firewall_lists!(@api_client.fetch_runtime_firewall_lists) @config.logger.info("Updated runtime firewall list") rescue => err @config.logger.error(err.message) @@ -82,6 +84,15 @@ def start! @config.logger.info("Executed initial heartbeat after #{heartbeat_delay} seconds") end end + + if @config.realtime_updates_enabled? + if @api_stream.can_connect? + @api_stream.handle("config-updated") { |event| settings_updated(event) } + @api_stream.start! + else + @config.logger.warn("Can't reach #{Aikido::Zen.config.realtime_endpoint}, make sure it's in your outbound firewall allowlist. Realtime config updates won't be available, switched to polling.") + end + end end # Clean up any ongoing threads, and reset the state. Called automatically @@ -92,6 +103,8 @@ def stop! @config.logger.info("Stopping Aikido agent") @started_at = nil @worker.shutdown + + @api_stream.stop! end # Respond to the runtime settings changing after being fetched from the @@ -157,11 +170,11 @@ def send_heartbeat(at: Time.now.utc) heartbeat = @collector.flush report(heartbeat) do |response| - if Aikido::Zen.runtime_settings.update_from_runtime_config_json(response) + if update_settings_from_runtime_config!(response) updated_settings! @config.logger.info("Updated runtime settings after heartbeat") - Aikido::Zen.runtime_settings.update_from_runtime_firewall_lists_json(@api_client.fetch_runtime_firewall_lists) + update_settings_from_runtime_firewall_lists!(@api_client.fetch_runtime_firewall_lists) @config.logger.info("Updated runtime firewall list after heartbeat") end end @@ -177,23 +190,88 @@ def send_heartbeat(at: Time.now.utc) def poll_for_setting_updates @worker.every(@config.polling_interval) do if @api_client.should_fetch_settings? - if Aikido::Zen.runtime_settings.update_from_runtime_config_json(@api_client.fetch_runtime_config) + if update_settings_from_runtime_config!(@api_client.fetch_runtime_config) updated_settings! @config.logger.info("Updated runtime settings after polling") end - Aikido::Zen.runtime_settings.update_from_runtime_firewall_lists_json(@api_client.fetch_runtime_firewall_lists) + update_settings_from_runtime_firewall_lists!(@api_client.fetch_runtime_firewall_lists) @config.logger.info("Updated runtime firewall list after polling") end end end - private def heartbeats + private + + def settings_updated(event) + updated_at = Time.at(event[:data]["configUpdatedAt"].to_i) + + if should_fetch_settings?(updated_at) + if update_settings_from_runtime_config!(@api_client.fetch_runtime_config) + updated_settings! + @config.logger.info("Updated runtime settings after server-side event") + + update_settings_from_runtime_firewall_lists!(@api_client.fetch_runtime_firewall_lists) + @config.logger.info("Updated runtime firewall list after server-side event") + end + end + end + + def should_fetch_settings?(updated_at, last_updated_at = Aikido::Zen.runtime_settings.updated_at) + return false unless @api_client.can_make_requests? + return true if last_updated_at.nil? + + updated_at > last_updated_at + end + + def heartbeats @heartbeats ||= Aikido::Zen::Agent::HeartbeatsManager.new( config: @config, worker: @worker ) end + + module ExclusiveUpdater + # Define a method `method_name` that returns early if the method is running. + # + # @param method_name [Symbol, String] the name of the method to define + # @yield the block to execute + # @yieldparam args [Array] the positional arguments passed to the method + # @yieldparam blk [Proc] the block passed to the method + # @yieldparam kwargs [Hash] the keyword arguments passed to the method + # @yieldreturn [Object] the return value of the method + # @return [void] + def exclusive_updater(method_name, &block) + raise ArgumentError, "block required" unless block + + instance_variable = :"@__updater_#{block.object_id}" + + define_method(method_name) do |*args, **kwargs| + updating = instance_variable_get(instance_variable) || + instance_variable_set(instance_variable, Concurrent::AtomicBoolean.new) + + return unless updating.make_true + begin + instance_exec(*args, **kwargs, &block) + ensure + updating.make_false + end + end + end + end + extend ExclusiveUpdater + + # @param data [Hash] + # @return [Boolean, nil] + exclusive_updater :update_settings_from_runtime_config! do |data| + Aikido::Zen.runtime_settings.update_from_runtime_config_json(data) + end + + # @param data [Hash] + # @return [Boolean, nil] + exclusive_updater :update_settings_from_runtime_firewall_lists! do |data| + Aikido::Zen.runtime_settings.update_from_runtime_firewall_lists_json(data) + end end end diff --git a/lib/aikido/zen/api_client.rb b/lib/aikido/zen/api_client.rb index a2418ea3..b6f258a3 100644 --- a/lib/aikido/zen/api_client.rb +++ b/lib/aikido/zen/api_client.rb @@ -41,7 +41,7 @@ def should_fetch_settings?(last_updated_at = Aikido::Zen.runtime_settings.update base_url: @config.realtime_endpoint ) - new_updated_at = Time.at(response["configUpdatedAt"].to_i / 1000) + new_updated_at = Time.at(response["configUpdatedAt"].to_i) new_updated_at > last_updated_at end diff --git a/lib/aikido/zen/api_stream.rb b/lib/aikido/zen/api_stream.rb new file mode 100644 index 00000000..b6bcbd8a --- /dev/null +++ b/lib/aikido/zen/api_stream.rb @@ -0,0 +1,197 @@ +# frozen_string_literal: true + +require "net/http" +require "uri" +require "json" + +module Aikido::Zen + class APIStream + def initialize( + config: Aikido::Zen.config, + min_backoff: 5, + max_backoff: 60, + backoff_reset: 30, + open_timeout: 5, + write_timeout: open_timeout, + read_timeout: 70 + ) + @config = config + @min_backoff = min_backoff + @max_backoff = max_backoff + @backoff_reset = backoff_reset + @open_timeout = open_timeout + @write_timeout = write_timeout + @read_timeout = read_timeout + + @running = Concurrent::AtomicBoolean.new + @executor = nil + + @host = @config.realtime_endpoint.host + @port = @config.realtime_endpoint.port + @use_ssl = @config.realtime_endpoint.scheme == "https" + @token = @config.api_token + + @handlers = Concurrent::Array.new + end + + # @return [Boolean] whether we could connect to the realtime endpoint + def can_connect? + http = Net::HTTP.new(@host, @port) + http.use_ssl = @use_ssl + http.open_timeout = 5 + http.write_timeout = 5 + http.read_timeout = 5 + http.max_retries = 0 + + request = Net::HTTP::Get.new("/config") + request["Authorization"] = @token + + begin + http.request(request) + + return true + rescue Timeout::Error, SocketError, IOError, SystemCallError, OpenSSL::OpenSSLError => err + @config.logger.debug("Error probing realtime endpoint: #{err.class}: #{err.message}") + rescue => err + @config.logger.error("Error probing realtime endpoint: #{err.class}: #{err.message}") + end + + false + end + + def running? + @running.true? + end + alias_method :started?, :running? + + def start! + return false unless @running.make_true + + @executor = Concurrent::SingleThreadExecutor.new + + @executor.post do + backoff = @min_backoff + + while running? + time_before = Process.clock_gettime(Process::CLOCK_MONOTONIC, :second) + + begin + work + rescue Timeout::Error, SocketError, IOError, SystemCallError, OpenSSL::OpenSSLError => err + @config.logger.debug("Error in API stream: #{err.class}: #{err.message}") + rescue => err + @config.logger.error("Error in API stream: #{err.class}: #{err.message}") + end + + break unless running? + + time_after = Process.clock_gettime(Process::CLOCK_MONOTONIC, :second) + + backoff = if time_after - time_before > @backoff_reset + @min_backoff + else + [backoff * 2, @max_backoff].min + end + + jitter = rand * backoff / 2 + + @config.logger.debug("API stream reconnecting in %d seconds" % (backoff + jitter).ceil) + + sleep(backoff + jitter) + end + end + + true + end + + def stop! + return false unless @running.make_false + + @executor.shutdown + @executor.wait_for_termination(@read_timeout) + + true + end + + def handle(type, &block) + raise ArgumentError, "block required" unless block + + @handlers << proc do |event| + block.call(event) if type === event[:type] + end + end + + private def work + http = Net::HTTP.new(@host, @port) + http.use_ssl = @use_ssl + http.open_timeout = @open_timeout + http.write_timeout = @write_timeout + http.read_timeout = @read_timeout + http.max_retries = 0 + + request = Net::HTTP::Get.new("/api/runtime/stream") + request["Authorization"] = @token + request["Accept"] = "text/event-stream" + request["Cache-Control"] = "no-cache" + + @config.logger.debug("API stream connecting") + http.start + @config.logger.debug("API stream connected") + + begin + http.request(request) do |response| + case response.code.to_i + when 200 + # empty + when 401, 403 + @running.make_false + return nil + else + return nil + end + + buffer = +"" + + response.read_body do |chunk| + return nil unless running? + + @config.logger.debug("API stream received chunk of #{chunk.bytesize} bytes") + + buffer << chunk + + while (index = buffer.index("\n\n")) + event_str = buffer.slice!(0..index + 1) + buffer = buffer.lstrip + + event = {} + + begin + event_str.each_line do |line| + case line + when /^event:\s*(.+)/ + event[:type] = $1.strip + when /^data:\s*(.+)/ + event[:data] = JSON.parse($1.strip) + end + end + rescue => err + @config.logger.error("Error in API stream: #{err.class}: #{err.message}") + next + end + + @handlers.each do |handler| + handler.call(event) + rescue => err + @config.logger.error("Error in API stream: #{err.class}: #{err.message}") + end + end + end + end + ensure + @config.logger.debug("API stream disconnecting") + http.finish + @config.logger.debug("API stream disconnected") + end + end + end +end diff --git a/lib/aikido/zen/config.rb b/lib/aikido/zen/config.rb index eb43773a..1e8a7788 100644 --- a/lib/aikido/zen/config.rb +++ b/lib/aikido/zen/config.rb @@ -216,6 +216,11 @@ class Config # Defaults to 1000 entries. attr_accessor :idor_max_cache_entries + # @return [Boolean] whether the realtime updates feature is enabled. + # Defaults to false. + attr_accessor :realtime_updates_enabled + alias_method :realtime_updates_enabled?, :realtime_updates_enabled + def initialize self.insert_middleware_after = ::ActionDispatch::RemoteIp self.disabled = read_boolean_from_env(ENV.fetch("AIKIDO_DISABLE", false)) || read_boolean_from_env(ENV.fetch("AIKIDO_DISABLED", false)) @@ -261,6 +266,7 @@ def initialize self.idor_tenant_column_name = nil self.idor_excluded_table_names = [] self.idor_max_cache_entries = 1000 + self.realtime_updates_enabled = false end # Set the base URL for API requests. diff --git a/lib/aikido/zen/runtime_settings.rb b/lib/aikido/zen/runtime_settings.rb index 12d9d093..8c1a480d 100644 --- a/lib/aikido/zen/runtime_settings.rb +++ b/lib/aikido/zen/runtime_settings.rb @@ -80,11 +80,11 @@ def initialize(*) # # @param data [Hash] the decoded JSON payload from the /api/runtime/config # API endpoint. - # @return [bool] + # @return [Boolean] def update_from_runtime_config_json(data) last_updated_at = updated_at - self.updated_at = Time.at(data["configUpdatedAt"].to_i / 1000) + self.updated_at = Time.at(data["configUpdatedAt"].to_i) self.heartbeat_interval = data["heartbeatIntervalInMS"].to_i / 1000 self.endpoints = RuntimeSettings::Endpoints.from_json(data["endpoints"]) self.blocked_user_ids = data["blockedUserIds"] @@ -105,7 +105,7 @@ def update_from_runtime_config_json(data) # # @param data [Hash] the decoded JSON payload from the /api/runtime/firewall/lists # API endpoint. - # @return [void] + # @return [Boolean] def update_from_runtime_firewall_lists_json(data) self.blocked_user_agent_regexp = pattern(data["blockedUserAgents"]) @@ -142,6 +142,8 @@ def update_from_runtime_firewall_lists_json(data) data["monitoredIPAddresses"]&.each do |ip_list| monitored_ip_lists << RuntimeSettings::IPList.from_json(ip_list) end + + true end # Construct a regular expression from the non-nil and non-empty string, diff --git a/test/aikido/zen/agent_test.rb b/test/aikido/zen/agent_test.rb index 7580c626..f4dd918f 100644 --- a/test/aikido/zen/agent_test.rb +++ b/test/aikido/zen/agent_test.rb @@ -23,18 +23,32 @@ def report(event) end end + class MockAPIStream < Aikido::Zen::APIStream + def work + nil + end + end + + def stub_probe_realtime_endpoint + stub_request(:get, "#{@config.realtime_endpoint}/config") + end + setup do @config = Aikido::Zen.config @config.api_token = "TOKEN" + @config.realtime_updates_enabled = true - @api_client = Minitest::Mock.new(MockAPIClient.new) @collector = Aikido::Zen.collector @worker = MockWorker.new + @api_client = Minitest::Mock.new(MockAPIClient.new) + @api_stream = Minitest::Mock.new(MockAPIStream.new) @agent = Aikido::Zen::Agent.new( - api_client: @api_client, + config: @config, collector: @collector, - worker: @worker + worker: @worker, + api_client: @api_client, + api_stream: @api_stream ) @test_sink = Aikido::Zen::Sink.new("test", scanners: [NOOP]) @@ -45,6 +59,8 @@ def report(event) end test "knows if it has started" do + stub_probe_realtime_endpoint + refute @agent.started? @agent.start! @@ -55,6 +71,8 @@ def report(event) end test "#start! fails if attempted to start multiple times" do + stub_probe_realtime_endpoint + @agent.start! err = assert_raises Aikido::ZenError do @@ -65,12 +83,16 @@ def report(event) end test "#start! sets the start time for our stats funnel" do + stub_probe_realtime_endpoint + assert_changes "@collector.stats.started_at", from: nil do @agent.start! end end test "#start! warns if blocking mode is disabled" do + stub_probe_realtime_endpoint + @config.blocking_mode = false @agent.start! @@ -79,6 +101,8 @@ def report(event) end test "#start! notifies if blocking mode is enabled" do + stub_probe_realtime_endpoint + @config.blocking_mode = true @agent.start! @@ -87,14 +111,18 @@ def report(event) end test "#start! notifies if an API token has been set" do + stub_probe_realtime_endpoint + @config.api_token = "TOKEN" @agent.start! - assert_logged :debug, /api token set! reporting has been enabled/i + assert_logged :info, /api token set! reporting has been enabled/i refute_logged :warn, /no api token set! reporting has been disabled/i end test "#start! warns if there's no API token set" do + stub_probe_realtime_endpoint + @config.api_token = nil @agent.start! @@ -102,7 +130,79 @@ def report(event) refute_logged :debug, /api token set! reporting has been enabled/i end + test "#start! probes the realtime endpoint" do + request = stub_probe_realtime_endpoint + .to_return(status: 200, body: "") + + @config.api_token = "TOKEN" + @agent.start! + + assert_requested request + + refute_logged :debug, /error probing realtime endpoint/i + refute_logged :error, /error probing realtime endpoint/i + refute_logged :warn, /can't reach #{Aikido::Zen.config.realtime_endpoint}/i + end + + test "#start! probes the realtime endpont and logs warning after open timeout" do + request = stub_probe_realtime_endpoint + .to_raise(Net::OpenTimeout) + + @config.api_token = "TOKEN" + @agent.start! + + assert_requested request + + assert_logged :debug, /error probing realtime endpoint/i + refute_logged :error, /error probing realtime endpoint/i + assert_logged :warn, /can't reach #{Aikido::Zen.config.realtime_endpoint}/i + end + + test "#start! probes the realtime endpont and logs warning after write timeout" do + request = stub_probe_realtime_endpoint + .to_raise(Net::WriteTimeout) + + @config.api_token = "TOKEN" + @agent.start! + + assert_requested request + + assert_logged :debug, /error probing realtime endpoint/i + refute_logged :error, /error probing realtime endpoint/i + assert_logged :warn, /can't reach #{Aikido::Zen.config.realtime_endpoint}/i + end + + test "#start! probes the realtime endpont and logs warning after read timeout" do + request = stub_probe_realtime_endpoint + .to_raise(Net::ReadTimeout) + + @config.api_token = "TOKEN" + @agent.start! + + assert_requested request + + assert_logged :debug, /error probing realtime endpoint/i + refute_logged :error, /error probing realtime endpoint/i + assert_logged :warn, /can't reach #{Aikido::Zen.config.realtime_endpoint}/i + end + + test "#start! probes the realtime endpont and logs error after unexpected error" do + request = stub_probe_realtime_endpoint + .to_raise(RuntimeError) + + @config.api_token = "TOKEN" + @agent.start! + + assert_requested request + + refute_logged :debug, /error probing realtime endpoint/i + assert_logged :error, /error probing realtime endpoint/i + assert_logged :warn, /can't reach #{Aikido::Zen.config.realtime_endpoint}/i + end + test "#start! reports a STARTED event" do + stub_probe_realtime_endpoint + @api_client.expect :report, {}, [Aikido::Zen::Events::Started] @agent.start! @@ -111,8 +211,10 @@ def report(event) end test "#start! takes the response of the STARTED event as runtime settings" do + stub_probe_realtime_endpoint + @api_client.expect :report, - {"configUpdatedAt" => 1234567890000}, + {"configUpdatedAt" => 1234567890}, [Aikido::Zen::Events::Started] assert_changes -> { Aikido::Zen.runtime_settings.updated_at }, to: Time.at(1234567890) do @@ -136,6 +238,8 @@ def @api_client.report(event) end test "#start! starts polling for setting updates every minute" do + stub_probe_realtime_endpoint + @api_client.expect :should_fetch_settings?, false assert_difference "@worker.jobs.size", +1 do @@ -151,8 +255,10 @@ def @api_client.report(event) end test "#start! updates the runtime settings after polling if needed" do + stub_probe_realtime_endpoint + @api_client.expect :should_fetch_settings?, true - @api_client.expect :fetch_runtime_config, {"configUpdatedAt" => 1234567890000} + @api_client.expect :fetch_runtime_config, {"configUpdatedAt" => 1234567890} assert_changes -> { Aikido::Zen.runtime_settings.updated_at }, to: Time.at(1234567890) do @agent.start! @@ -331,6 +437,8 @@ def @api_client.report(event) end test "#start! queues a one-off tasks for each initial heartbeat delay" do + stub_probe_realtime_endpoint + size = @config.initial_heartbeat_delays.size assert_difference "@worker.delayed.size", size do @@ -346,6 +454,8 @@ def @api_client.report(event) end test "#start! successfully sends the initial heartbeats" do + stub_probe_realtime_endpoint + # Make sure there are _some_ stats @collector.track_request @@ -432,13 +542,4 @@ def exception(*) Aikido::Zen::UnderAttackError.new(self) end end - - def stub_context(path = "/", env = {}) - env = Rack::MockRequest.env_for(path, {"REQUEST_METHOD" => "GET"}.merge(env)) - Aikido::Zen.current_context = Aikido::Zen::Context.from_rack_env(env) - end - - def stub_request(path = "/", env = {}) - stub_context(path, env).request - end end diff --git a/test/aikido/zen/api_client_test.rb b/test/aikido/zen/api_client_test.rb index 202a7184..fd64f69e 100644 --- a/test/aikido/zen/api_client_test.rb +++ b/test/aikido/zen/api_client_test.rb @@ -83,7 +83,7 @@ class CheckIfStaleConfigTest < ActiveSupport::TestCase test "returns false if the updated_at from the server is the same or older than the one we have" do stub_request(:get, "https://runtime.aikido.dev/config") - .to_return(status: 200, body: JSON.dump(configUpdatedAt: 1234567890000)) + .to_return(status: 200, body: JSON.dump(configUpdatedAt: 1234567890)) Aikido::Zen.runtime_settings.updated_at = Time.at(1234567890) assert_not @client.should_fetch_settings? @@ -94,7 +94,7 @@ class CheckIfStaleConfigTest < ActiveSupport::TestCase test "returns true if the updated_at from the server is newer than the one we have" do stub_request(:get, "https://runtime.aikido.dev/config") - .to_return(status: 200, body: JSON.dump(configUpdatedAt: 1234567890000)) + .to_return(status: 200, body: JSON.dump(configUpdatedAt: 1234567890)) Aikido::Zen.runtime_settings.updated_at = Time.at(1234567890 - 1) assert @client.should_fetch_settings? @@ -102,7 +102,7 @@ class CheckIfStaleConfigTest < ActiveSupport::TestCase test "sets the User-Agent on the request" do stub_request(:get, "https://runtime.aikido.dev/config") - .to_return(status: 200, body: JSON.dump(configUpdatedAt: 1234567890000)) + .to_return(status: 200, body: JSON.dump(configUpdatedAt: 1234567890)) @client.should_fetch_settings? diff --git a/test/aikido/zen/api_stream_test.rb b/test/aikido/zen/api_stream_test.rb new file mode 100644 index 00000000..9831d8f9 --- /dev/null +++ b/test/aikido/zen/api_stream_test.rb @@ -0,0 +1,238 @@ +# frozen_string_literal: true + +require "test_helper" +require "securerandom" + +class Aikido::Zen::StreamTest < ActiveSupport::TestCase + setup do + config = Aikido::Zen.config + config.api_token = "TOKEN" + + @endpoint = "#{config.realtime_endpoint}/api/runtime/stream" + + @api_stream = Aikido::Zen::APIStream.new( + min_backoff: 0.02, + max_backoff: 0.08, + backoff_reset: 0.04, + open_timeout: 1, + read_timeout: 1 + ) + end + + teardown do + @api_stream.stop! + end + + DEFAULT_SSE_BODY = <<~SSE + event: config-updated + data: {"serviceId":1,"configUpdatedAt":1779292466} + + event: config-updated + data: {"serviceId":1,"configUpdatedAt":1779292467} + + : ping + + SSE + + test "#start! returns false if already running" do + stub_request(:get, @endpoint) + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert @api_stream.start! + assert_equal false, @api_stream.start! + end + + test "#handle raises ArgumentError without a block" do + assert_raises(ArgumentError) { @api_stream.handle("config-updated") } + end + + test "it starts and connects" do + connection = stub_request(:get, @endpoint) + .with( + headers: { + "Authorization" => "TOKEN", + "Accept" => "text/event-stream", + "Cache-Control" => "no-cache" + } + ) + .to_return(status: 200, body: "", headers: {}) + + assert_connects(connection, times: 1) + end + + test "it handles valid events" do + connection = stub_request(:get, @endpoint) + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + events = Concurrent::Array.new + @api_stream.handle("config-updated") { |event| events << event } + + assert_connects(connection, times: 1) + + assert_equal 2, events.size + + assert_equal "config-updated", events[0][:type] + assert_equal 1, events[0][:data]["serviceId"] + assert_equal 1779292466, events[0][:data]["configUpdatedAt"] + + assert_equal "config-updated", events[1][:type] + assert_equal 1, events[1][:data]["serviceId"] + assert_equal 1779292467, events[1][:data]["configUpdatedAt"] + end + + test "it skips invalid events and continues processing" do + body = <<~SSE + event: config-updated + data: not valid json + + event: config-updated + data: {"serviceId":1,"configUpdatedAt":1779292466} + + SSE + + connection = stub_request(:get, @endpoint) + .to_return(status: 200, body: body) + + events = Concurrent::Array.new + @api_stream.handle("config-updated") { |event| events << event } + + assert_connects(connection, times: 1) + + assert_equal 1, events.size + + assert_equal "config-updated", events[0][:type] + assert_equal 1, events[0][:data]["serviceId"] + assert_equal 1779292466, events[0][:data]["configUpdatedAt"] + end + + test "it skips handler errors and continues processing" do + connection = stub_request(:get, @endpoint) + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + events = Concurrent::Array.new + @api_stream.handle("config-updated") { |_event| raise "handler error" } + @api_stream.handle("config-updated") { |event| events << event } + + assert_connects(connection, times: 1) + + assert_equal 2, events.size + end + + test "it reconnects after the stream ends naturally" do + connection = stub_request(:get, @endpoint) + .to_return(status: 200, body: DEFAULT_SSE_BODY).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after connection reset" do + connection = stub_request(:get, @endpoint) + .to_raise(Errno::ECONNRESET).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after connection refused" do + connection = stub_request(:get, @endpoint) + .to_raise(Errno::ECONNREFUSED).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after open timeout" do + connection = stub_request(:get, @endpoint) + .to_raise(Net::OpenTimeout).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after write timeout" do + connection = stub_request(:get, @endpoint) + .to_raise(Net::WriteTimeout).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after read timeout" do + connection = stub_request(:get, @endpoint) + .to_raise(Net::ReadTimeout).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after unexpected error" do + connection = stub_request(:get, @endpoint) + .to_raise(RuntimeError).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + assert_connects(connection, times: 2) + end + + test "it reconnects after unexpected HTTP status code" do + connection = stub_request(:get, @endpoint) + .to_return(status: 418).then + .to_return(status: 200, body: DEFAULT_SSE_BODY) + + @api_stream.start! + + assert @api_stream.running? + + wait_until(timeout: 2) { connected?(connection, times: 2) } + + assert @api_stream.running? + + assert_requested connection, times: 2 + end + + test "it does not reconnect after 401 Unauthorized" do + connection = stub_request(:get, @endpoint) + .to_return(status: 401) + + @api_stream.start! + + assert @api_stream.running? + + wait_until(timeout: 2) { connected?(connection, times: 1) } + + refute @api_stream.running? + + assert_requested connection, times: 1 + end + + test "it does not reconnect after 403 Forbidden" do + connection = stub_request(:get, @endpoint) + .to_return(status: 403) + + @api_stream.start! + + assert @api_stream.running? + + wait_until(timeout: 2) { connected?(connection, times: 1) } + + refute @api_stream.running? + + assert_requested connection, times: 1 + end + + private + + def connected?(connection, times: 1) + WebMock::RequestRegistry.instance.times_executed(connection.request_pattern) == times + end + + def assert_connects(connection, times:, timeout: 2) + @api_stream.start! + + wait_until(timeout: timeout) { connected?(connection, times: times) } + + @api_stream.stop! + + assert_requested connection, times: times + end +end diff --git a/test/aikido/zen/config_test.rb b/test/aikido/zen/config_test.rb index 4bb88126..14504a84 100644 --- a/test/aikido/zen/config_test.rb +++ b/test/aikido/zen/config_test.rb @@ -54,6 +54,7 @@ class Aikido::Zen::ConfigTest < ActiveSupport::TestCase assert_nil @config.idor_tenant_column_name assert_equal [], @config.idor_excluded_table_names assert_equal 1000, @config.idor_max_cache_entries + assert_equal false, @config.realtime_updates_enabled? end test "can set AIKIDO_DISABLE to configure if the agent should be turned off" do diff --git a/test/aikido/zen/runtime_settings_test.rb b/test/aikido/zen/runtime_settings_test.rb index 8b39ec1a..b127487d 100644 --- a/test/aikido/zen/runtime_settings_test.rb +++ b/test/aikido/zen/runtime_settings_test.rb @@ -11,7 +11,7 @@ class Aikido::Zen::RuntimeSettingsTest < ActiveSupport::TestCase assert @settings.update_from_runtime_config_json({ "success" => true, "serviceId" => 1234, - "configUpdatedAt" => 1717171717000, + "configUpdatedAt" => 1717171717, "heartbeatIntervalInMS" => 60000, "endpoints" => [], "blockedUserIds" => [], @@ -60,7 +60,7 @@ class Aikido::Zen::RuntimeSettingsTest < ActiveSupport::TestCase assert @settings.update_from_runtime_config_json({ "success" => true, "serviceId" => 1234, - "configUpdatedAt" => 1717171717000, + "configUpdatedAt" => 1717171717, "heartbeatIntervalInMS" => 60000, "endpoints" => [], "blockedUserIds" => [], @@ -81,7 +81,7 @@ class Aikido::Zen::RuntimeSettingsTest < ActiveSupport::TestCase payload = { "success" => true, "serviceId" => 1234, - "configUpdatedAt" => 1717171717000, + "configUpdatedAt" => 1717171717, "heartbeatIntervalInMS" => 60000, "endpoints" => [], "blockedUserIds" => [], @@ -197,7 +197,7 @@ class Aikido::Zen::RuntimeSettingsTest < ActiveSupport::TestCase assert @settings.update_from_runtime_config_json({ "success" => true, "serviceId" => 1234, - "configUpdatedAt" => 1717171717000, + "configUpdatedAt" => 1717171717, "heartbeatIntervalInMS" => 60000, "endpoints" => [ { @@ -265,7 +265,7 @@ class Aikido::Zen::RuntimeSettingsTest < ActiveSupport::TestCase assert @settings.update_from_runtime_config_json({ "success" => true, "serviceId" => 1234, - "configUpdatedAt" => 1717171717000, + "configUpdatedAt" => 1717171717, "heartbeatIntervalInMS" => 60000, "endpoints" => [], "blockedUserIds" => [], diff --git a/test/aikido/zen/sinks/mysql2_test.rb b/test/aikido/zen/sinks/mysql2_test.rb index 9776325d..b4ff437a 100644 --- a/test/aikido/zen/sinks/mysql2_test.rb +++ b/test/aikido/zen/sinks/mysql2_test.rb @@ -9,6 +9,7 @@ class Aikido::Zen::Sinks::Mysql2Test < ActiveSupport::TestCase setup do @db = Mysql2::Client.new( host: ENV.fetch("MYSQL_HOST", "127.0.0.1"), + port: ENV.fetch("MYSQL_PORT", "3306"), username: ENV.fetch("MYSQL_USERNAME", "root"), password: ENV.fetch("MYSQL_PASSWORD", "") ) @@ -75,6 +76,7 @@ def with_mocked_protector(params = nil) setup do @db = Mysql2::Client.new( host: ENV.fetch("MYSQL_HOST", "127.0.0.1"), + port: ENV.fetch("MYSQL_PORT", "3306"), username: ENV.fetch("MYSQL_USERNAME", "root"), password: ENV.fetch("MYSQL_PASSWORD", "") ) diff --git a/test/aikido/zen/sinks/pg_test.rb b/test/aikido/zen/sinks/pg_test.rb index aa2d65a1..dc8e2188 100644 --- a/test/aikido/zen/sinks/pg_test.rb +++ b/test/aikido/zen/sinks/pg_test.rb @@ -9,6 +9,7 @@ class Aikido::Zen::Sinks::PGTest < ActiveSupport::TestCase setup do @db = PG.connect( host: ENV.fetch("POSTGRES_HOST", "127.0.0.1"), + port: ENV.fetch("POSTGRES_PORT", "5432"), user: ENV.fetch("POSTGRES_USERNAME", "postgres"), password: ENV.fetch("POSTGRES_PASSWORD", "password"), dbname: ENV.fetch("POSTGRES_DATABASE", "postgres") @@ -237,6 +238,7 @@ def with_mocked_protector(params = nil) setup do @db = PG.connect( host: ENV.fetch("POSTGRES_HOST", "127.0.0.1"), + port: ENV.fetch("POSTGRES_PORT", "5432"), user: ENV.fetch("POSTGRES_USERNAME", "postgres"), password: ENV.fetch("POSTGRES_PASSWORD", "password"), dbname: ENV.fetch("POSTGRES_DATABASE", "postgres") diff --git a/test/aikido/zen/sinks/trilogy_test.rb b/test/aikido/zen/sinks/trilogy_test.rb index e8cb32f3..99d29205 100644 --- a/test/aikido/zen/sinks/trilogy_test.rb +++ b/test/aikido/zen/sinks/trilogy_test.rb @@ -9,6 +9,7 @@ class Aikido::Zen::Sinks::TrilogyTest < ActiveSupport::TestCase setup do @db = Trilogy.new( host: ENV.fetch("MYSQL_HOST", "127.0.0.1"), + port: ENV.fetch("MYSQL_PORT", "3306"), username: ENV.fetch("MYSQL_USERNAME", "root"), password: ENV.fetch("MYSQL_PASSWORD", "") ) @@ -75,6 +76,7 @@ def with_mocked_protector(params = nil) setup do @db = Trilogy.new( host: ENV.fetch("MYSQL_HOST", "127.0.0.1"), + port: ENV.fetch("MYSQL_PORT", "3306"), username: ENV.fetch("MYSQL_USERNAME", "root"), password: ENV.fetch("MYSQL_PASSWORD", "") ) diff --git a/test/support/wait_helpers.rb b/test/support/wait_helpers.rb new file mode 100644 index 00000000..2d011341 --- /dev/null +++ b/test/support/wait_helpers.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +module WaitHelpers + def wait_until(timeout:) + start_time = Time.now + until yield || (Time.now - start_time) > timeout + sleep 0.01 + end + end +end diff --git a/test/test_helper.rb b/test/test_helper.rb index 909b5f73..f2937ddc 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -58,11 +58,13 @@ def handle_fork require_relative "support/rate_limiting_assertions" require_relative "support/sink_attack_helpers" require_relative "support/worker_helpers" +require_relative "support/wait_helpers" # Utility proc that does nothing. NOOP = ->(*args, **opts) {} class ActiveSupport::TestCase + include WaitHelpers self.file_fixture_path = "test/fixtures" # Reset any global state before each test @@ -142,7 +144,7 @@ def assert_logged(level = nil, pattern) "matches #{pattern.inspect}. ".squeeze("\s") + "Log messages:\n#{lines.map { |line| "\t* #{line}" }.join("\n")}" - assert lines.any? { |line| pattern === line && (match_level === line or true) }, reason + assert lines.any? { |line| pattern === line && (match_level.nil? || line.include?(match_level)) }, reason end def refute_logged(level = nil, pattern) @@ -155,7 +157,7 @@ def refute_logged(level = nil, pattern) "to match #{pattern.inspect}".squeeze("\s") + "Log messages:\n#{lines.map { |line| "\t* #{line}" }.join("\n")}" - refute lines.any? { |line| pattern === line && (match_level === line or true) }, reason + refute lines.any? { |line| pattern === line && (match_level.nil? || line.include?(match_level)) }, reason end # rubocop:enable Style/OptionalArguments