From 29d0f8adb6be7a37fc4b849720ea415f5c9a5ee8 Mon Sep 17 00:00:00 2001 From: speeddragon Date: Tue, 18 Nov 2025 14:03:17 +0000 Subject: [PATCH] feat: Follow redirect with gun client --- src/hb_http_client.erl | 119 +++++++++++++++++++++++++++++++++++- src/hb_mock_server.erl | 136 +++++++++++++++++++++++++++++++++++++++++ src/hb_util.erl | 20 +++++- 3 files changed, 273 insertions(+), 2 deletions(-) create mode 100644 src/hb_mock_server.erl diff --git a/src/hb_http_client.erl b/src/hb_http_client.erl index cedc2394e..d9f21dcaf 100644 --- a/src/hb_http_client.erl +++ b/src/hb_http_client.erl @@ -3,6 +3,7 @@ -module(hb_http_client). -behaviour(gen_server). -include("include/hb.hrl"). +-include_lib("eunit/include/eunit.hrl"). -export([start_link/1, request/2]). -export([init/1, handle_cast/2, handle_call/3, handle_info/2, terminate/2]). @@ -14,6 +15,7 @@ -define(DEFAULT_RETRIES, 0). -define(DEFAULT_RETRY_TIME, 1000). +-define(MAX_REDIRECTS, 5). %%% ================================================================== %%% Public interface. @@ -151,6 +153,9 @@ gun_req(Args, ReestablishedConnection, Opts) -> true -> {error, client_error}; false -> gun_req(Args, true, Opts) end; + {ok, StatusCode, _Headers, _Body} = Reply + when StatusCode >= 301 andalso StatusCode < 400 -> + follow_redirect(Args, Reply, Opts); Reply -> Reply end; @@ -177,6 +182,52 @@ gun_req(Args, ReestablishedConnection, Opts) -> end, Response. +follow_redirect(Args, {ok, _, ResponseHeaders, _Body} = Reply, Opts) -> + FollowRedirects = maps:get(<<"follow_redirect">>, Opts, false), + CurrentRedirects = maps:get(current_redirects, Opts, 0), + BellowMaxRedirects = CurrentRedirects < ?MAX_REDIRECTS, + case FollowRedirects andalso BellowMaxRedirects of + true -> + #{ peer := Peer, path := Path, method := Method } = Args, + % Only follow the redirect if method is GET. + case Method of + <<"GET">> -> + Location = proplists:get_value(<<"location">>, ResponseHeaders), + #{peer := Peer2, path := Path2} = NewArgs = case Location of + <<"/", _/binary>> = RedirectPath -> + Args#{path => RedirectPath}; + <<"http", _/binary>> -> + URI = uri_string:parse(Location), + NewPeer = uri_string:normalize(maps:remove(path, URI)), + NewPath = maps:get(path, URI), + Args#{peer => NewPeer, path => NewPath}; + undefined -> + ?event(http_client, {error, no_location_header_provided}), + Args + end, + ?event( + http_client, + {follow_redirect, + {from, {peer, Peer}, {path, Path}}, + {to, {peer, Peer2},{path, Path2}} + } + ), + NewOpts = maps:update_with( + current_redirects, + fun (Value) -> Value + 1 end, + 1, + Opts + ), + gun_req(NewArgs, true, NewOpts); + _ -> + ?event(http_client, {error, unsupported_redirect_method}), + Reply + end; + false -> + ?event(http_client, {error, follow_redirect_not_enabled}), + Reply + end. + %% @doc Record the duration of the request in an async process. We write the %% data to prometheus if the application is enabled, as well as invoking the %% `http_monitor' if appropriate. @@ -781,4 +832,70 @@ get_status_class(Data) when is_binary(Data) -> get_status_class(Data) when is_atom(Data) -> atom_to_binary(Data); get_status_class(_) -> - <<"unknown">>. \ No newline at end of file + <<"unknown">>. + +%% Tests + +start_mock_gateway(Responses) -> + DefaultResponse = {200, <<>>}, + Endpoints = [ + {"/redirect1", redirect1, maps:get(redirect1, Responses, DefaultResponse)}, + {"/redirect2", redirect2, maps:get(redirect2, Responses, DefaultResponse)}, + {"/redirect3", redirect3, maps:get(redirect3, Responses, DefaultResponse)}, + {"/redirect4", redirect4, maps:get(redirect4, Responses, DefaultResponse)}, + {"/redirect5", redirect5, maps:get(redirect5, Responses, DefaultResponse)}, + {"/redirect6", redirect6, maps:get(redirect6, Responses, DefaultResponse)}, + {"/ok", ok, maps:get(ok, Responses, DefaultResponse)} + ], + hb_mock_server:start(Endpoints). + +do_not_follow_redirect_by_default_test() -> + application:ensure_all_started(hb), + {ok, MockServer, ServerHandle} = start_mock_gateway(#{ + redirect1 => {301, <<"1">>, #{<<"location">> => <<"/ok">>}} + }), + try + Opts = #{}, + Args = #{peer => MockServer, path => "/redirect1", method => <<"GET">>}, + Response = request(Args, Opts), + ?assertMatch({ok, 301, _, _}, Response), + ok + after + hb_mock_server:stop(ServerHandle) + end. + +follow_redirect_test() -> + application:ensure_all_started(hb), + {ok, MockServer, ServerHandle} = start_mock_gateway(#{ + redirect1 => {301, <<"1">>, #{<<"location">> => <<"/ok">>}} + }), + try + Opts = #{<<"follow_redirect">> => true}, + Args = #{peer => MockServer, path => "/redirect1", method => <<"GET">>}, + Response = request(Args, Opts), + ?assertMatch({ok, 200, _, _}, Response), + ok + after + hb_mock_server:stop(ServerHandle) + end. + +max_redirect_test() -> + application:ensure_all_started(hb), + {ok, MockServer, ServerHandle} = start_mock_gateway(#{ + redirect1 => {301, <<"1">>, #{<<"location">> => <<"/redirect2">>}}, + redirect2 => {301, <<"2">>, #{<<"location">> => <<"/redirect3">>}}, + redirect3 => {301, <<"3">>, #{<<"location">> => <<"/redirect4">>}}, + redirect4 => {301, <<"4">>, #{<<"location">> => <<"/redirect5">>}}, + redirect5 => {301, <<"5">>, #{<<"location">> => <<"/redirect6">>}}, + redirect6 => {301, <<"6">>, #{<<"location">> => <<"/ok">>}} + }), + try + Opts = #{<<"follow_redirect">> => true}, + Args = #{peer => MockServer, path => "/redirect1", method => <<"GET">>}, + Response = request(Args, Opts), + %% Return the last response + ?assertMatch({ok, 301, _, <<"6">>}, Response), + ok + after + hb_mock_server:stop(ServerHandle) + end. \ No newline at end of file diff --git a/src/hb_mock_server.erl b/src/hb_mock_server.erl new file mode 100644 index 000000000..c2a57ae69 --- /dev/null +++ b/src/hb_mock_server.erl @@ -0,0 +1,136 @@ +%%% @doc Mock HTTP server for testing. Collects request bodies and returns +%%% configurable responses. +-module(hb_mock_server). +-export([start/1, stop/1, get_requests/2, get_requests/3, get_requests/4]). +%% Cowboy handler callback +-export([init/2]). +-include("include/hb.hrl"). + +%%%=================================================================== +%%% Public API +%%%=================================================================== + +%% @doc Start a generic mock HTTP server that collects request bodies. +%% Usage: start([{"/endpoint", endpoint_tag, {status, body}}, ...]) +%% start([{"/endpoint", endpoint_tag, fun(Req) -> {Status, Body} end}, ...]) +%% start([{"/endpoint", endpoint_tag}, ...]) for default {200, <<"OK">>} +%% +%% Response formats: +%% {Status, Body} - Static response +%% fun(Req) -> ... - Function called with request map, returns {Status, Body} +%% +%% Paths support Cowboy route patterns: +%% "/price/:amount" - Matches /price/123, /price/abc, etc. +%% "/user/:id/post/:post_id" - Multiple parameters +%% "/files/[...]" - Catch-all (matches /files/anything/here) +%% +%% Automatically generates unique listener ID and dynamic port. +%% Returns: {ok, ServerURL, ServerHandle} +start(Endpoints) -> + %% Ensure cowboy/ranch are started + application:ensure_all_started(cowboy), + CollectorPID = spawn(fun() -> collect_loop(#{}) end), + ListenerID = make_ref(), + NormalizedEndpoints = lists:map( + fun + ({Path, Tag, Response}) when is_function(Response) -> + {Path, Tag, Response}; + ({Path, Tag, {Status, Body, Headers}}) -> + {Path, Tag, {Status, Body, Headers}}; + ({Path, Tag, {Status, Body}}) -> + {Path, Tag, {Status, Body, #{}}}; + ({Path, Tag}) -> + {Path, Tag, {200, <<>>, #{}}} + end, + Endpoints + ), + Routes = [ + {Path, ?MODULE, {Tag, Response, CollectorPID}} + || {Path, Tag, Response} <- NormalizedEndpoints + ], + Dispatch = cowboy_router:compile([{'_', Routes}]), + {ok, _Listener} = cowboy:start_clear( + ListenerID, + [{port, 0}], %% dynamic port allocation + #{env => #{dispatch => Dispatch}} + ), + %% Get the port that was assigned + Port = ranch:get_port(ListenerID), + ServerURL = iolist_to_binary(io_lib:format("http://localhost:~p", [Port])), + {ok, ServerURL, {CollectorPID, ListenerID}}. + +stop({CollectorPID, ListenerID}) -> + cowboy:stop_listener(ListenerID), + CollectorPID ! stop. + +%% @doc Get all requests collected for a given endpoint tag. +%% Returns the accumulated requests without clearing them. +%% Takes the ServerHandle returned from start/1. +get_requests({CollectorPID, _ListenerID}, Tag) -> + CollectorPID ! {get_requests, Tag, self()}, + receive + {requests, Requests} -> Requests + after 1000 -> [] + end. + +get_requests(Type, Count, ServerHandle) -> + get_requests(Type, Count, ServerHandle, 10000). + +get_requests(Type, Count, ServerHandle, Timeout) -> + %% Wait for expected transaction + hb_util:wait_until( + fun() -> + Requests = get_requests(ServerHandle, Type), + length(Requests) >= Count + end, + Timeout + ), + get_requests(ServerHandle, Type). + +%%%=================================================================== +%%% Internal Functions +%%%=================================================================== + +%% @doc Collector process loop for mock server. +collect_loop(State) -> + receive + {request, Tag, Body} -> + ?event({request, Tag, Body}), + Requests = maps:get(Tag, State, []), + collect_loop(State#{Tag => [Body | Requests]}); + {get_requests, Tag, From} -> + Requests = maps:get(Tag, State, []), + From ! {requests, lists:reverse(Requests)}, + %% Keep the requests in state (don't clear them) + collect_loop(State); + stop -> ok + end. + +%% @doc Convert a cowboy request to a message (i.e. just convert the atom +%% keys to binaries and add the body) +request_to_message(Req, Body) -> + maps:fold( + fun(Key, Value, Acc) -> + maps:put(hb_util:bin(Key), Value, Acc) + end, + #{<<"body">> => Body}, + Req + ). + +%%%=================================================================== +%%% Cowboy Handler Callback +%%%=================================================================== + +%% @doc Cowboy handler callback - DO NOT CALL DIRECTLY. +%% This is invoked automatically by Cowboy when requests arrive at the +%% mock server. See start/1 for usage. +init(Req0, {Tag, Response, CollectorPID} = State) -> + {ok, Body, Req} = cowboy_req:read_body(Req0), + Msg = request_to_message(Req, Body), + CollectorPID ! {request, Tag, Msg}, + %% Determine the response - either call the function or use the static value + {StatusCode, ResponseBody, Headers} = case is_function(Response) of + true -> Response(Msg); + false -> Response + end, + {ok, cowboy_req:reply(StatusCode, Headers, ResponseBody, Req), State}. \ No newline at end of file diff --git a/src/hb_util.erl b/src/hb_util.erl index 98c5719f0..5a3502014 100644 --- a/src/hb_util.erl +++ b/src/hb_util.erl @@ -19,7 +19,7 @@ -export([remove_common/2, to_lower/1]). -export([maybe_throw/2]). -export([is_hb_module/1, is_hb_module/2, all_hb_modules/0]). --export([ok/1, ok/2, until/1, until/2, until/3]). +-export([ok/1, ok/2, until/1, until/2, until/3, wait_until/2]). -export([count/2, mean/1, stddev/1, variance/1, weighted_random/1]). -export([unique/1]). -export([split_depth_string_aware/2, split_depth_string_aware_single/2]). @@ -134,6 +134,24 @@ until(Condition, Fun, Count) -> true -> Count end. +%% @doc Wait until a condition function returns true or timeout is reached. +%% The condition function is polled every 100ms by default. +%% Returns true if the condition was met, false if timeout was reached. +wait_until(ConditionFun, TimeoutMs) -> + StartTime = erlang:system_time(millisecond), + until( + fun() -> + case ConditionFun() of + true -> true; + false -> + CurrentTime = erlang:system_time(millisecond), + CurrentTime - StartTime >= TimeoutMs + end + end + ), + %% Check one more time to determine if we succeeded or timed out + ConditionFun(). + %% @doc Return the human-readable form of an ID of a message when given either %% a message explicitly, raw encoded ID, or an Erlang Arweave `tx' record. id(Item) -> id(Item, unsigned).