Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion src/hb_http_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
-module(hb_http_client).
-behaviour(gen_server).
-include("include/hb.hrl").
-include_lib("eunit/include/eunit.hrl").
-export([start_link/1, request/2]).
-export([init/1, handle_cast/2, handle_call/3, handle_info/2, terminate/2]).

Expand All @@ -14,6 +15,7 @@

-define(DEFAULT_RETRIES, 0).
-define(DEFAULT_RETRY_TIME, 1000).
-define(MAX_REDIRECTS, 5).

%%% ==================================================================
%%% Public interface.
Expand Down Expand Up @@ -151,6 +153,9 @@ gun_req(Args, ReestablishedConnection, Opts) ->
true -> {error, client_error};
false -> gun_req(Args, true, Opts)
end;
{ok, StatusCode, _Headers, _Body} = Reply
when StatusCode >= 301 andalso StatusCode < 400 ->
follow_redirect(Args, Reply, Opts);
Reply ->
Reply
end;
Expand All @@ -177,6 +182,52 @@ gun_req(Args, ReestablishedConnection, Opts) ->
end,
Response.

follow_redirect(Args, {ok, _, ResponseHeaders, _Body} = Reply, Opts) ->
FollowRedirects = maps:get(<<"follow_redirect">>, Opts, false),
CurrentRedirects = maps:get(current_redirects, Opts, 0),
BellowMaxRedirects = CurrentRedirects < ?MAX_REDIRECTS,
case FollowRedirects andalso BellowMaxRedirects of
true ->
#{ peer := Peer, path := Path, method := Method } = Args,
% Only follow the redirect if method is GET.
case Method of
<<"GET">> ->
Location = proplists:get_value(<<"location">>, ResponseHeaders),
#{peer := Peer2, path := Path2} = NewArgs = case Location of
<<"/", _/binary>> = RedirectPath ->
Args#{path => RedirectPath};
<<"http", _/binary>> ->
URI = uri_string:parse(Location),
NewPeer = uri_string:normalize(maps:remove(path, URI)),
NewPath = maps:get(path, URI),
Args#{peer => NewPeer, path => NewPath};
undefined ->
?event(http_client, {error, no_location_header_provided}),
Args
end,
?event(
http_client,
{follow_redirect,
{from, {peer, Peer}, {path, Path}},
{to, {peer, Peer2},{path, Path2}}
}
),
NewOpts = maps:update_with(
current_redirects,
fun (Value) -> Value + 1 end,
1,
Opts
),
Comment on lines +215 to +220
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is the right way to keep track of how many redirects we had, but to do it in another way, I would have to change the code more deeply.

gun_req(NewArgs, true, NewOpts);
_ ->
?event(http_client, {error, unsupported_redirect_method}),
Reply
end;
false ->
?event(http_client, {error, follow_redirect_not_enabled}),
Reply
end.

%% @doc Record the duration of the request in an async process. We write the
%% data to prometheus if the application is enabled, as well as invoking the
%% `http_monitor' if appropriate.
Expand Down Expand Up @@ -781,4 +832,70 @@ get_status_class(Data) when is_binary(Data) ->
get_status_class(Data) when is_atom(Data) ->
atom_to_binary(Data);
get_status_class(_) ->
<<"unknown">>.
<<"unknown">>.

%% Tests

start_mock_gateway(Responses) ->
DefaultResponse = {200, <<>>},
Endpoints = [
{"/redirect1", redirect1, maps:get(redirect1, Responses, DefaultResponse)},
{"/redirect2", redirect2, maps:get(redirect2, Responses, DefaultResponse)},
{"/redirect3", redirect3, maps:get(redirect3, Responses, DefaultResponse)},
{"/redirect4", redirect4, maps:get(redirect4, Responses, DefaultResponse)},
{"/redirect5", redirect5, maps:get(redirect5, Responses, DefaultResponse)},
{"/redirect6", redirect6, maps:get(redirect6, Responses, DefaultResponse)},
{"/ok", ok, maps:get(ok, Responses, DefaultResponse)}
],
hb_mock_server:start(Endpoints).

do_not_follow_redirect_by_default_test() ->
application:ensure_all_started(hb),
{ok, MockServer, ServerHandle} = start_mock_gateway(#{
redirect1 => {301, <<"1">>, #{<<"location">> => <<"/ok">>}}
}),
try
Opts = #{},
Args = #{peer => MockServer, path => "/redirect1", method => <<"GET">>},
Response = request(Args, Opts),
?assertMatch({ok, 301, _, _}, Response),
ok
after
hb_mock_server:stop(ServerHandle)
end.

follow_redirect_test() ->
application:ensure_all_started(hb),
{ok, MockServer, ServerHandle} = start_mock_gateway(#{
redirect1 => {301, <<"1">>, #{<<"location">> => <<"/ok">>}}
}),
try
Opts = #{<<"follow_redirect">> => true},
Args = #{peer => MockServer, path => "/redirect1", method => <<"GET">>},
Response = request(Args, Opts),
?assertMatch({ok, 200, _, _}, Response),
ok
after
hb_mock_server:stop(ServerHandle)
end.

max_redirect_test() ->
application:ensure_all_started(hb),
{ok, MockServer, ServerHandle} = start_mock_gateway(#{
redirect1 => {301, <<"1">>, #{<<"location">> => <<"/redirect2">>}},
redirect2 => {301, <<"2">>, #{<<"location">> => <<"/redirect3">>}},
redirect3 => {301, <<"3">>, #{<<"location">> => <<"/redirect4">>}},
redirect4 => {301, <<"4">>, #{<<"location">> => <<"/redirect5">>}},
redirect5 => {301, <<"5">>, #{<<"location">> => <<"/redirect6">>}},
redirect6 => {301, <<"6">>, #{<<"location">> => <<"/ok">>}}
}),
try
Opts = #{<<"follow_redirect">> => true},
Args = #{peer => MockServer, path => "/redirect1", method => <<"GET">>},
Response = request(Args, Opts),
%% Return the last response
?assertMatch({ok, 301, _, <<"6">>}, Response),
ok
after
hb_mock_server:stop(ServerHandle)
end.
136 changes: 136 additions & 0 deletions src/hb_mock_server.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
%%% @doc Mock HTTP server for testing. Collects request bodies and returns
%%% configurable responses.
-module(hb_mock_server).
-export([start/1, stop/1, get_requests/2, get_requests/3, get_requests/4]).
%% Cowboy handler callback
-export([init/2]).
-include("include/hb.hrl").

%%%===================================================================
%%% Public API
%%%===================================================================

%% @doc Start a generic mock HTTP server that collects request bodies.
%% Usage: start([{"/endpoint", endpoint_tag, {status, body}}, ...])
%% start([{"/endpoint", endpoint_tag, fun(Req) -> {Status, Body} end}, ...])
%% start([{"/endpoint", endpoint_tag}, ...]) for default {200, <<"OK">>}
%%
%% Response formats:
%% {Status, Body} - Static response
%% fun(Req) -> ... - Function called with request map, returns {Status, Body}
%%
%% Paths support Cowboy route patterns:
%% "/price/:amount" - Matches /price/123, /price/abc, etc.
%% "/user/:id/post/:post_id" - Multiple parameters
%% "/files/[...]" - Catch-all (matches /files/anything/here)
%%
%% Automatically generates unique listener ID and dynamic port.
%% Returns: {ok, ServerURL, ServerHandle}
start(Endpoints) ->
%% Ensure cowboy/ranch are started
application:ensure_all_started(cowboy),
CollectorPID = spawn(fun() -> collect_loop(#{}) end),
ListenerID = make_ref(),
NormalizedEndpoints = lists:map(
fun
({Path, Tag, Response}) when is_function(Response) ->
{Path, Tag, Response};
({Path, Tag, {Status, Body, Headers}}) ->
{Path, Tag, {Status, Body, Headers}};
({Path, Tag, {Status, Body}}) ->
{Path, Tag, {Status, Body, #{}}};
({Path, Tag}) ->
{Path, Tag, {200, <<>>, #{}}}
end,
Endpoints
),
Routes = [
{Path, ?MODULE, {Tag, Response, CollectorPID}}
|| {Path, Tag, Response} <- NormalizedEndpoints
],
Dispatch = cowboy_router:compile([{'_', Routes}]),
{ok, _Listener} = cowboy:start_clear(
ListenerID,
[{port, 0}], %% dynamic port allocation
#{env => #{dispatch => Dispatch}}
),
%% Get the port that was assigned
Port = ranch:get_port(ListenerID),
ServerURL = iolist_to_binary(io_lib:format("http://localhost:~p", [Port])),
{ok, ServerURL, {CollectorPID, ListenerID}}.

stop({CollectorPID, ListenerID}) ->
cowboy:stop_listener(ListenerID),
CollectorPID ! stop.

%% @doc Get all requests collected for a given endpoint tag.
%% Returns the accumulated requests without clearing them.
%% Takes the ServerHandle returned from start/1.
get_requests({CollectorPID, _ListenerID}, Tag) ->
CollectorPID ! {get_requests, Tag, self()},
receive
{requests, Requests} -> Requests
after 1000 -> []
end.

get_requests(Type, Count, ServerHandle) ->
get_requests(Type, Count, ServerHandle, 10000).

get_requests(Type, Count, ServerHandle, Timeout) ->
%% Wait for expected transaction
hb_util:wait_until(
fun() ->
Requests = get_requests(ServerHandle, Type),
length(Requests) >= Count
end,
Timeout
),
get_requests(ServerHandle, Type).

%%%===================================================================
%%% Internal Functions
%%%===================================================================

%% @doc Collector process loop for mock server.
collect_loop(State) ->
receive
{request, Tag, Body} ->
?event({request, Tag, Body}),
Requests = maps:get(Tag, State, []),
collect_loop(State#{Tag => [Body | Requests]});
{get_requests, Tag, From} ->
Requests = maps:get(Tag, State, []),
From ! {requests, lists:reverse(Requests)},
%% Keep the requests in state (don't clear them)
collect_loop(State);
stop -> ok
end.

%% @doc Convert a cowboy request to a message (i.e. just convert the atom
%% keys to binaries and add the body)
request_to_message(Req, Body) ->
maps:fold(
fun(Key, Value, Acc) ->
maps:put(hb_util:bin(Key), Value, Acc)
end,
#{<<"body">> => Body},
Req
).

%%%===================================================================
%%% Cowboy Handler Callback
%%%===================================================================

%% @doc Cowboy handler callback - DO NOT CALL DIRECTLY.
%% This is invoked automatically by Cowboy when requests arrive at the
%% mock server. See start/1 for usage.
init(Req0, {Tag, Response, CollectorPID} = State) ->
{ok, Body, Req} = cowboy_req:read_body(Req0),
Msg = request_to_message(Req, Body),
CollectorPID ! {request, Tag, Msg},
%% Determine the response - either call the function or use the static value
{StatusCode, ResponseBody, Headers} = case is_function(Response) of
true -> Response(Msg);
false -> Response
end,
{ok, cowboy_req:reply(StatusCode, Headers, ResponseBody, Req), State}.
20 changes: 19 additions & 1 deletion src/hb_util.erl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
-export([remove_common/2, to_lower/1]).
-export([maybe_throw/2]).
-export([is_hb_module/1, is_hb_module/2, all_hb_modules/0]).
-export([ok/1, ok/2, until/1, until/2, until/3]).
-export([ok/1, ok/2, until/1, until/2, until/3, wait_until/2]).
-export([count/2, mean/1, stddev/1, variance/1, weighted_random/1]).
-export([unique/1]).
-export([split_depth_string_aware/2, split_depth_string_aware_single/2]).
Expand Down Expand Up @@ -134,6 +134,24 @@ until(Condition, Fun, Count) ->
true -> Count
end.

%% @doc Wait until a condition function returns true or timeout is reached.
%% The condition function is polled every 100ms by default.
%% Returns true if the condition was met, false if timeout was reached.
wait_until(ConditionFun, TimeoutMs) ->
StartTime = erlang:system_time(millisecond),
until(
fun() ->
case ConditionFun() of
true -> true;
false ->
CurrentTime = erlang:system_time(millisecond),
CurrentTime - StartTime >= TimeoutMs
end
end
),
%% Check one more time to determine if we succeeded or timed out
ConditionFun().

%% @doc Return the human-readable form of an ID of a message when given either
%% a message explicitly, raw encoded ID, or an Erlang Arweave `tx' record.
id(Item) -> id(Item, unsigned).
Expand Down