diff --git a/grpc/lib/grpc/client/connection.ex b/grpc/lib/grpc/client/connection.ex index f51a6e262..234c98efa 100644 --- a/grpc/lib/grpc/client/connection.ex +++ b/grpc/lib/grpc/client/connection.ex @@ -20,7 +20,8 @@ defmodule GRPC.Client.Connection do * The target string is resolved using a [Resolver](GRPC.Client.Resolver). * Depending on the target and service config, a load-balancing module is chosen (e.g. `PickFirst`, `RoundRobin`). - * The orchestrator periodically refreshes the LB decision to adapt to changes. + * Each call to `pick_channel/2` dispatches to the LB module, which selects a channel + per request. DNS re-resolution reconciles the LB's channel list in place. ## Target syntax @@ -68,9 +69,6 @@ defmodule GRPC.Client.Connection do iex> GRPC.Client.Connection.disconnect(ch) {:ok, %GRPC.Channel{...}} - ## Notes - - * The orchestrator refreshes the LB pick every 15 seconds. """ use GenServer alias GRPC.Channel @@ -79,7 +77,6 @@ defmodule GRPC.Client.Connection do @insecure_scheme "http" @secure_scheme "https" - @refresh_interval 15_000 @default_resolve_interval 30_000 @default_max_resolve_interval 300_000 @default_min_resolve_interval 5_000 @@ -122,29 +119,33 @@ defmodule GRPC.Client.Connection do def init(%__MODULE__{} = state) do Process.flag(:trap_exit, true) - # only now persist the chosen channel (which should already have adapter_payload - # because build_initial_state connected real channels and set virtual_channel) - :persistent_term.put( - {__MODULE__, :lb_state, state.virtual_channel.ref}, - state.virtual_channel - ) + connected = connected_channels(state.real_channels) - Process.send_after(self(), :refresh, @refresh_interval) + case state.lb_mod.init(channels: connected) do + {:ok, lb_state} -> + state = %{state | lb_state: lb_state} + + state = + if function_exported?(state.resolver, :init, 2) do + {:ok, resolver_state} = + state.resolver.init(state.resolver_target, + connection_pid: self(), + connect_opts: state.connect_opts + ) + + %{state | 
resolver_state: resolver_state} + else + state + end - state = - if function_exported?(state.resolver, :init, 2) do - {:ok, resolver_state} = - state.resolver.init(state.resolver_target, - connection_pid: self(), - connect_opts: state.connect_opts - ) - - %{state | resolver_state: resolver_state} - else - state - end + :persistent_term.put({__MODULE__, state.virtual_channel.ref}, {state.lb_mod, lb_state}) - {:ok, state} + {:ok, state} + + {:error, reason} -> + disconnect_real_channels(state.real_channels, state.adapter) + {:stop, reason} + end end @doc """ @@ -185,16 +186,10 @@ defmodule GRPC.Client.Connection do case DynamicSupervisor.start_child(GRPC.Client.Supervisor, child_spec(initial_state)) do {:ok, _pid} -> - {:ok, ch} + finalize_connection(ch, opts) {:error, {:already_started, _pid}} -> - case pick_channel(ch, opts) do - {:ok, %Channel{} = channel} -> - {:ok, channel} - - _ -> - {:error, :no_connection} - end + finalize_connection(ch, opts) {:error, reason} -> {:error, reason} @@ -244,12 +239,15 @@ defmodule GRPC.Client.Connection do """ @spec pick_channel(Channel.t(), keyword()) :: {:ok, Channel.t()} | {:error, term()} def pick_channel(%Channel{ref: ref} = _channel, _opts \\ []) do - case :persistent_term.get({__MODULE__, :lb_state, ref}, nil) do - nil -> - {:error, :no_connection} + case :persistent_term.get({__MODULE__, ref}, nil) do + {lb_mod, lb_state} when not is_nil(lb_mod) -> + case lb_mod.pick(lb_state) do + {:ok, %Channel{} = channel, _new_state} -> {:ok, channel} + {:error, _} -> {:error, :no_connection} + end - %Channel{} = channel -> - {:ok, channel} + _ -> + {:error, :no_connection} end end @@ -279,58 +277,14 @@ defmodule GRPC.Client.Connection do state.resolver.shutdown(state.resolver_state) end - resp = {:ok, %Channel{channel | adapter_payload: %{conn_pid: nil}}} - :persistent_term.erase({__MODULE__, :lb_state, channel.ref}) - - if Map.has_key?(state, :real_channels) do - Enum.map(state.real_channels, fn - {_key, {:connected, ch}} -> - 
do_disconnect(adapter, ch) + :persistent_term.erase({__MODULE__, channel.ref}) + disconnect_real_channels(state.real_channels, adapter) - _ -> - :ok - end) - - keys_to_delete = [:real_channels, :virtual_channel] - new_state = Map.drop(state, keys_to_delete) - - {:reply, resp, new_state, {:continue, :stop}} - else - {:reply, resp, state, {:continue, :stop}} - end + resp = {:ok, %Channel{channel | adapter_payload: %{conn_pid: nil}}} + {:reply, resp, state, {:continue, :stop}} end @impl GenServer - def handle_info( - :refresh, - %{lb_mod: lb_mod, lb_state: lb_state, real_channels: channels, virtual_channel: vc} = - state - ) - when not is_nil(lb_mod) do - {:ok, {prefer_host, prefer_port}, new_lb_state} = lb_mod.pick(lb_state) - - channel_key = build_address_key(prefer_host, prefer_port) - - case Map.get(channels, channel_key) do - {:connected, %Channel{} = picked_channel} -> - :persistent_term.put({__MODULE__, :lb_state, vc.ref}, picked_channel) - - Process.send_after(self(), :refresh, @refresh_interval) - {:noreply, %{state | lb_state: new_lb_state, virtual_channel: picked_channel}} - - _nil_or_failed -> - # LB picked a channel that is missing or in {:failed, _} state. - # Don't update persistent_term — keep serving from the current - # virtual_channel until re-resolution provides healthy backends. 
- Logger.warning("LB picked #{channel_key}, but channel is unavailable") - - Process.send_after(self(), :refresh, @refresh_interval) - {:noreply, %{state | lb_state: new_lb_state}} - end - end - - def handle_info(:refresh, state), do: {:noreply, state} - def handle_info({:resolver_update, result}, state) do state = handle_resolve_result(result, state) {:noreply, state} @@ -356,6 +310,16 @@ defmodule GRPC.Client.Connection do {:noreply, state} end + def handle_info({:EXIT, _pid, :normal}, state), do: {:noreply, state} + + def handle_info({:EXIT, pid, reason}, state) do + Logger.warning( + "#{inspect(__MODULE__)} received :EXIT from #{inspect(pid)} reason: #{inspect(reason)}" + ) + + {:noreply, state} + end + def handle_info({:DOWN, _ref, :process, pid, reason}, state) do Logger.warning( "#{inspect(__MODULE__)} received :DOWN from #{inspect(pid)} with reason: #{inspect(reason)}" @@ -378,13 +342,30 @@ defmodule GRPC.Client.Connection do @impl GenServer def terminate(_reason, %{virtual_channel: %{ref: ref}}) do - :persistent_term.erase({__MODULE__, :lb_state, ref}) + :persistent_term.erase({__MODULE__, ref}) + :ok rescue _ -> :ok end def terminate(_reason, _state), do: :ok + defp finalize_connection(%Channel{} = ch, opts) do + case pick_channel(ch, opts) do + {:ok, %Channel{} = channel} -> {:ok, channel} + _ -> {:error, :no_connection} + end + end + + defp disconnect_real_channels(real_channels, adapter) when is_map(real_channels) do + Enum.each(real_channels, fn + {_key, {:connected, ch}} -> do_disconnect(adapter, ch) + _ -> :ok + end) + end + + defp disconnect_real_channels(_real_channels, _adapter), do: :ok + defp handle_resolve_result({:ok, %{addresses: []}}, state), do: state defp handle_resolve_result({:ok, %{addresses: new_addresses}}, state) do @@ -445,64 +426,32 @@ defmodule GRPC.Client.Connection do end) end - defp rebalance_after_reconcile(new_addresses, real_channels, state) do - if state.lb_mod do - case state.lb_mod.init(addresses: new_addresses) do - 
{:ok, new_lb_state} -> - {:ok, {host, port}, picked_lb_state} = state.lb_mod.pick(new_lb_state) - key = build_address_key(host, port) - - case Map.get(real_channels, key) do - {:connected, picked_channel} -> - maybe_update_persistent_term(state.virtual_channel, picked_channel) - - %{ - state - | real_channels: real_channels, - lb_state: picked_lb_state, - virtual_channel: picked_channel - } + defp rebalance_after_reconcile(_new_addresses, real_channels, state) do + connected = connected_channels(real_channels) - _ -> - fallback_to_healthy_channel(state, real_channels, picked_lb_state) - end - - {:error, _} -> - fallback_to_healthy_channel(state, real_channels, state.lb_state) + new_lb_state = + if state.lb_mod do + case reconcile_lb(state.lb_mod, state.lb_state, connected) do + {:ok, s} -> s + {:error, _} -> state.lb_state + end + else + state.lb_state end - else - fallback_to_healthy_channel(state, real_channels, state.lb_state) - end - end - - defp fallback_to_healthy_channel(state, real_channels, lb_state) do - ref = state.virtual_channel.ref - case Enum.find_value(real_channels, fn {_k, v} -> match?({:connected, _}, v) && v end) do - {:connected, healthy_channel} -> - maybe_update_persistent_term(state.virtual_channel, healthy_channel) + if connected == [] do + Logger.warning("No healthy channels available after re-resolution") + end - %{ - state - | real_channels: real_channels, - lb_state: lb_state, - virtual_channel: healthy_channel - } + %{state | real_channels: real_channels, lb_state: new_lb_state} + end - nil -> - Logger.warning("No healthy channels available after re-resolution") - :persistent_term.erase({__MODULE__, :lb_state, ref}) - %{state | real_channels: real_channels, lb_state: lb_state} - end + defp reconcile_lb(lb_mod, lb_state, new_channels) do + lb_mod.update(lb_state, new_channels) end - defp maybe_update_persistent_term(current_channel, new_channel) do - if current_channel != new_channel do - :persistent_term.put( - {__MODULE__, :lb_state, 
new_channel.ref}, - new_channel - ) - end + defp connected_channels(real_channels) do + for {_key, {:connected, ch}} <- real_channels, do: ch end defp channel_alive?({:connected, %{adapter_payload: %{conn_pid: pid}}}) when is_pid(pid) do @@ -625,44 +574,36 @@ defmodule GRPC.Client.Connection do lb_mod = choose_lb(lb_policy) - case lb_mod.init(addresses: addresses) do - {:ok, lb_state} -> - {:ok, {prefer_host, prefer_port}, new_lb_state} = lb_mod.pick(lb_state) - - real_channels = - build_real_channels(addresses, base_state.virtual_channel, norm_opts, adapter) - - key = build_address_key(prefer_host, prefer_port) - - with {:connected, ch} <- Map.get(real_channels, key, {:failed, :no_channel}) do - {:ok, - %__MODULE__{ - base_state - | lb_mod: lb_mod, - lb_state: new_lb_state, - virtual_channel: ch, - real_channels: real_channels - }} - else - {:failed, reason} -> {:error, reason} - end + real_channels = + build_real_channels(addresses, base_state.virtual_channel, norm_opts, adapter) + + case connected_channels(real_channels) do + [] -> + disconnect_real_channels(real_channels, adapter) + {:error, :no_channels} - {:error, :no_addresses} -> - {:error, :no_addresses} + _connected -> + {:ok, + %__MODULE__{ + base_state + | lb_mod: lb_mod, + real_channels: real_channels + }} end end defp build_direct_state(%__MODULE__{} = base_state, norm_target, norm_opts, adapter) do {host, port} = split_host_port(norm_target) vc = base_state.virtual_channel + lb_mod = GRPC.Client.LoadBalancing.PickFirst case connect_real_channel(vc, host, port, norm_opts, adapter) do {:ok, ch} -> {:ok, %__MODULE__{ base_state - | virtual_channel: ch, - real_channels: %{"#{host}:#{port}" => {:connected, ch}} + | real_channels: %{build_address_key(host, port) => {:connected, ch}}, + lb_mod: lb_mod }} {:error, reason} -> diff --git a/grpc/lib/grpc/client/load_balacing.ex b/grpc/lib/grpc/client/load_balacing.ex index fb3a58d1c..de55e9cb3 100644 --- a/grpc/lib/grpc/client/load_balacing.ex +++ 
b/grpc/lib/grpc/client/load_balacing.ex @@ -1,12 +1,13 @@ defmodule GRPC.Client.LoadBalancing do - @moduledoc """ - Load balancing behaviour for gRPC clients. + @moduledoc "Load balancing behaviour for gRPC clients." + + alias GRPC.Channel - This module defines the behaviour that load balancing strategies must implement. - """ @callback init(opts :: keyword()) :: {:ok, state :: any()} | {:error, reason :: any()} @callback pick(state :: any()) :: - {:ok, {host :: String.t(), port :: non_neg_integer()}, new_state :: any()} - | {:error, reason :: any()} + {:ok, Channel.t(), new_state :: any()} | {:error, reason :: any()} + + @callback update(state :: any(), new_channels :: [Channel.t()]) :: + {:ok, new_state :: any()} | {:error, reason :: any()} end diff --git a/grpc/lib/grpc/client/load_balacing/pick_first.ex b/grpc/lib/grpc/client/load_balacing/pick_first.ex index 14e17ca03..9f37b3603 100644 --- a/grpc/lib/grpc/client/load_balacing/pick_first.ex +++ b/grpc/lib/grpc/client/load_balacing/pick_first.ex @@ -1,16 +1,42 @@ defmodule GRPC.Client.LoadBalancing.PickFirst do + @moduledoc "Pick-first load balancer: always returns the first channel in the list." 
+ @behaviour GRPC.Client.LoadBalancing + @current_key :current + @impl true def init(opts) do - case Keyword.get(opts, :addresses, []) do - [] -> {:error, :no_addresses} - addresses -> {:ok, %{addresses: addresses, current: hd(addresses)}} + case Keyword.get(opts, :channels, []) do + [] -> + {:error, :no_channels} + + [first | _] -> + tid = :ets.new(:grpc_lb_pick_first, [:set, :public, read_concurrency: true]) + :ets.insert(tid, {@current_key, first}) + {:ok, %{tid: tid}} end end @impl true - def pick(%{current: %{address: host, port: port}} = state) do - {:ok, {host, port}, state} + def pick(%{tid: tid} = state) do + case :ets.lookup(tid, @current_key) do + [{@current_key, nil}] -> {:error, :no_channels} + [{@current_key, channel}] -> {:ok, channel, state} + [] -> {:error, :no_channels} + end + rescue + ArgumentError -> {:error, :no_channels} + end + + @impl true + def update(%{tid: tid} = state, [first | _]) do + :ets.insert(tid, {@current_key, first}) + {:ok, state} + end + + def update(%{tid: tid} = state, []) do + :ets.insert(tid, {@current_key, nil}) + {:ok, state} end end diff --git a/grpc/lib/grpc/client/load_balacing/round_robin.ex b/grpc/lib/grpc/client/load_balacing/round_robin.ex index af47bf5bb..4ad29a3f6 100644 --- a/grpc/lib/grpc/client/load_balacing/round_robin.ex +++ b/grpc/lib/grpc/client/load_balacing/round_robin.ex @@ -1,22 +1,45 @@ defmodule GRPC.Client.LoadBalancing.RoundRobin do + @moduledoc "Round-robin load balancer (ETS for the channel tuple, `:atomics` for the cursor)." 
+ @behaviour GRPC.Client.LoadBalancing + @channels_key :channels + @impl true def init(opts) do - addresses = Keyword.get(opts, :addresses, []) + case Keyword.get(opts, :channels, []) do + [] -> + {:error, :no_channels} + + channels -> + tid = :ets.new(:grpc_lb_round_robin, [:set, :public, read_concurrency: true]) + aref = :atomics.new(1, signed: false) + + :ets.insert(tid, {@channels_key, List.to_tuple(channels)}) - if addresses == [] do - {:error, :no_addresses} - else - {:ok, %{addresses: addresses, index: 0, n: length(addresses)}} + {:ok, %{tid: tid, atomics: aref}} end end @impl true - def pick(%{addresses: addresses, index: idx, n: n} = state) do - %{address: host, port: port} = Enum.fetch!(addresses, idx) + def pick(%{tid: tid, atomics: aref} = state) do + case :ets.lookup(tid, @channels_key) do + [{@channels_key, channels}] when tuple_size(channels) > 0 -> + idx = :atomics.add_get(aref, 1, 1) + channel = elem(channels, rem(idx - 1, tuple_size(channels))) + {:ok, channel, state} - new_state = %{state | index: rem(idx + 1, n)} - {:ok, {host, port}, new_state} + _ -> + {:error, :no_channels} + end + rescue + ArgumentError -> {:error, :no_channels} + end + + @impl true + def update(%{tid: tid, atomics: aref} = state, new_channels) do + :ets.insert(tid, {@channels_key, List.to_tuple(new_channels)}) + :atomics.put(aref, 1, 0) + {:ok, state} end end diff --git a/grpc/test/grpc/client/connection_test.exs b/grpc/test/grpc/client/connection_test.exs index 8ab0018e9..00cc036ca 100644 --- a/grpc/test/grpc/client/connection_test.exs +++ b/grpc/test/grpc/client/connection_test.exs @@ -16,13 +16,13 @@ defmodule GRPC.Client.ConnectionTest do end describe "pick_channel/2" do - test "returns {:error, :no_connection} when no persistent_term entry exists", %{ref: ref} do + test "returns {:error, :no_connection} when the ref is not registered", %{ref: ref} do channel = %Channel{ref: ref} assert {:error, :no_connection} = Connection.pick_channel(channel) end - test "returns 
{:ok, channel} when a channel is stored in persistent_term", %{ + test "returns {:ok, channel} once a connection has published its LB state", %{ ref: ref, target: target, adapter: adapter @@ -41,8 +41,6 @@ defmodule GRPC.Client.ConnectionTest do adapter: adapter } do {:ok, first_channel} = Connection.connect(target, adapter: adapter, name: ref) - - # Connecting again with the same ref triggers the :already_started path {:ok, second_channel} = Connection.connect(target, adapter: adapter, name: ref) assert first_channel.ref == second_channel.ref @@ -51,6 +49,20 @@ defmodule GRPC.Client.ConnectionTest do Connection.disconnect(first_channel) end + + test "returns {:error, :no_connection} when already_started but the persistent_term entry is missing", + %{ref: ref, target: target, adapter: adapter} do + {:ok, channel} = Connection.connect(target, adapter: adapter, name: ref) + + key = {Connection, ref} + entry = :persistent_term.get(key) + :persistent_term.erase(key) + + assert {:error, :no_connection} = Connection.connect(target, adapter: adapter, name: ref) + + :persistent_term.put(key, entry) + Connection.disconnect(channel) + end end describe "disconnect/1" do @@ -69,7 +81,7 @@ defmodule GRPC.Client.ConnectionTest do assert_receive {:DOWN, ^ref_mon, :process, ^pid, _reason}, 500 end - test "pick_channel returns {:error, :no_connection} after disconnect (persistent_term is erased)", + test "pick_channel returns {:error, :no_connection} after disconnect (persistent_term entry is erased)", %{ref: ref, target: target, adapter: adapter} do {:ok, channel} = Connection.connect(target, adapter: adapter, name: ref) @@ -80,7 +92,7 @@ defmodule GRPC.Client.ConnectionTest do end describe "terminate/2 - persistent_term cleanup on process kill" do - test "persistent_term is erased when process is killed without disconnect", %{ + test "persistent_term entry is erased when process is killed without disconnect", %{ ref: ref, target: target, adapter: adapter @@ -96,6 +108,114 @@ 
defmodule GRPC.Client.ConnectionTest do end end + describe "LB ETS table lifecycle" do + test "disconnect/1 exits the GenServer and the ETS table is freed", %{ + ref: ref, + target: target, + adapter: adapter + } do + {:ok, channel} = + Connection.connect(target, adapter: adapter, name: ref, lb_policy: :round_robin) + + tid = lb_tid(ref) + assert :ets.info(tid) != :undefined + + pid = whereis_name(ref) + ref_mon = Process.monitor(pid) + + {:ok, _} = Connection.disconnect(channel) + + assert_receive {:DOWN, ^ref_mon, :process, ^pid, _reason}, 500 + + assert :ets.info(tid) == :undefined + end + + test "terminate/2 frees the ETS table when the process is killed", %{ + ref: ref, + target: target, + adapter: adapter + } do + {:ok, _channel} = + Connection.connect(target, adapter: adapter, name: ref, lb_policy: :round_robin) + + tid = lb_tid(ref) + pid = whereis_name(ref) + ref_mon = Process.monitor(pid) + + GenServer.stop(pid, :shutdown) + assert_receive {:DOWN, ^ref_mon, :process, ^pid, :shutdown}, 500 + + assert :ets.info(tid) == :undefined + end + end + + describe "pick_channel/2 races with disconnect/1" do + test "many concurrent picks complete without crashing while disconnect runs", %{ + ref: ref, + target: target, + adapter: adapter + } do + {:ok, channel} = + Connection.connect(target, adapter: adapter, name: ref, lb_policy: :round_robin) + + parent = self() + picker_count = 50 + picks_per_proc = 100 + + pickers = + for i <- 1..picker_count do + spawn_link(fn -> + results = + for _ <- 1..picks_per_proc do + Connection.pick_channel(channel) + end + + send(parent, {:done, i, results}) + end) + end + + Process.sleep(2) + {:ok, _} = Connection.disconnect(channel) + + for _ <- 1..picker_count do + assert_receive {:done, _, results}, 2_000 + + for r <- results do + assert match?({:ok, %Channel{}}, r) or r == {:error, :no_connection} + end + end + + for pid <- pickers do + refute Process.alive?(pid) + end + end + end + + describe "resource leaks over repeated 
connect/disconnect" do + test "500 cycles leave persistent_term clean and no per-LB tables leak", %{ + target: target, + adapter: adapter + } do + before_table_count = length(:ets.all()) + before_pt_count = connection_pt_count() + + for _ <- 1..500 do + ref = make_ref() + {:ok, channel} = Connection.connect(target, adapter: adapter, name: ref) + {:ok, _} = Connection.disconnect(channel) + end + + after_pt_count = connection_pt_count() + after_table_count = length(:ets.all()) + + assert after_pt_count == before_pt_count, + "persistent_term leaked: before=#{before_pt_count} after=#{after_pt_count}" + + assert after_table_count - before_table_count <= 5, + "ETS tables leaked: before=#{before_table_count} after=#{after_table_count}" + end + end + describe "connect/2 - distributed named channels" do test "named channels do not conflict across connected nodes" do {:ok, _, port} = GRPC.Server.start(FeatureServer, 0) @@ -128,6 +248,16 @@ defmodule GRPC.Client.ConnectionTest do end end + defp connection_pt_count do + Enum.count(:persistent_term.get(), &match?({{Connection, _}, _}, &1)) + end + + defp lb_tid(ref) do + pid = whereis_name(ref) + %{lb_state: %{tid: tid}} = :sys.get_state(pid) + tid + end + defp start_peer do {:ok, peer, node} = @peer.start_link(%{ diff --git a/grpc/test/grpc/client/dns_resolver_test.exs b/grpc/test/grpc/client/dns_resolver_test.exs index 8029529a7..972ed5e21 100644 --- a/grpc/test/grpc/client/dns_resolver_test.exs +++ b/grpc/test/grpc/client/dns_resolver_test.exs @@ -27,6 +27,7 @@ defmodule GRPC.Client.ReResolveTest do use GRPC.Client.DataCase, async: false import Mox + alias GRPC.Channel alias GRPC.Client.Connection @resolve_interval 200 @@ -438,6 +439,134 @@ defmodule GRPC.Client.ReResolveTest do end end + describe "pick_first reconciliation" do + test "pick_channel returns the new backend after DNS replaces it", ctx do + {:ok, channel} = + connect_with_resolver( + ctx.ref, + ctx.resolver, + ctx.adapter, + [%{address: "10.0.0.1", port: 
50051}], + [] + ) + + {:ok, before} = Connection.pick_channel(channel) + assert before.host == "10.0.0.1" + + stub(ctx.resolver, :resolve, fn _target -> + {:ok, %{addresses: [%{address: "10.0.0.9", port: 50051}], service_config: nil}} + end) + + Process.sleep(@wait) + + {:ok, picked} = Connection.pick_channel(channel) + assert picked.host == "10.0.0.9" + + disconnect_and_wait(channel) + end + end + + describe "pick_channel during shrinking reconcile" do + test "concurrent picks stay correct while backends are removed", ctx do + large = + for i <- 1..8, do: %{address: "10.0.0.#{i}", port: 50051} + + {:ok, channel} = + connect_with_resolver( + ctx.ref, + ctx.resolver, + ctx.adapter, + large, + lb_policy: :round_robin + ) + + small = [ + %{address: "10.0.0.1", port: 50051}, + %{address: "10.0.0.2", port: 50051} + ] + + parent = self() + picker_count = 30 + picks_per_proc = 200 + + pickers = + for i <- 1..picker_count do + spawn_link(fn -> + results = + for _ <- 1..picks_per_proc do + Connection.pick_channel(channel) + end + + send(parent, {:done, i, results}) + end) + end + + stub(ctx.resolver, :resolve, fn _target -> + {:ok, %{addresses: small, service_config: nil}} + end) + + Process.sleep(@wait) + + for _ <- 1..picker_count do + assert_receive {:done, _, results}, 2_000 + + for r <- results do + assert match?({:ok, %Channel{}}, r), + "pick returned #{inspect(r)} — expected {:ok, %Channel{}}" + end + end + + Process.sleep(@wait) + + hosts = + for _ <- 1..20 do + {:ok, picked} = Connection.pick_channel(channel) + picked.host + end + + assert Enum.all?(hosts, &(&1 in ["10.0.0.1", "10.0.0.2"])), + "picked from an already-removed backend: #{inspect(hosts)}" + + for pid <- pickers do + refute Process.alive?(pid) + end + + disconnect_and_wait(channel) + end + end + + describe "pick_channel per-request rotation" do + test "round-robin rotates across all backends on successive picks", ctx do + {:ok, channel} = + connect_with_resolver( + ctx.ref, + ctx.resolver, + 
ctx.adapter, + [ + %{address: "10.0.0.1", port: 50051}, + %{address: "10.0.0.2", port: 50051}, + %{address: "10.0.0.3", port: 50051} + ], + lb_policy: :round_robin + ) + + hosts = + for _ <- 1..9 do + {:ok, picked} = Connection.pick_channel(channel) + picked.host + end + + assert Enum.sort(Enum.uniq(hosts)) == ["10.0.0.1", "10.0.0.2", "10.0.0.3"] + + counts = Enum.frequencies(hosts) + assert counts["10.0.0.1"] == 3 + assert counts["10.0.0.2"] == 3 + assert counts["10.0.0.3"] == 3 + + disconnect_and_wait(channel) + end + end + describe "pick_channel after full backend replacement" do test "picks a channel from the new backend set", ctx do {:ok, channel} = @@ -846,7 +975,7 @@ defmodule GRPC.Client.ReResolveTest do end end - describe "stale persistent_term prevention" do + describe "unhealthy-pick fallback" do setup ctx do Application.put_env(:grpc, :grpc_test_failing_hosts, ["10.0.0.99"]) on_exit(fn -> Application.delete_env(:grpc, :grpc_test_failing_hosts) end) @@ -932,6 +1061,8 @@ defmodule GRPC.Client.ReResolveTest do Process.sleep(@wait) assert {:error, :no_connection} = Connection.pick_channel(channel) + + disconnect_and_wait(channel) end end @@ -1160,14 +1291,12 @@ defmodule GRPC.Client.ReResolveTest do original_pid = state.resolver_state.worker_pid assert Process.alive?(original_pid) - # Kill the worker — Connection traps exits and should re-init Process.exit(original_pid, :kill) Process.sleep(100) conn_pid = whereis_name(ctx.ref) assert Process.alive?(conn_pid) - # resolver_state should have a NEW worker pid state = get_state(ctx.ref) assert state.resolver_state != nil assert state.resolver_state.worker_pid != original_pid @@ -1177,5 +1306,26 @@ defmodule GRPC.Client.ReResolveTest do disconnect_and_wait(channel) end + + test "unrelated :EXIT signals don't stop the Connection", ctx do + {:ok, channel} = + connect_with_resolver( + ctx.ref, + ctx.resolver, + ctx.adapter, + [%{address: "10.0.0.1", port: 50051}], + lb_policy: :round_robin + ) + + conn_pid = 
whereis_name(ctx.ref) + stray_pid = spawn(fn -> :ok end) + send(conn_pid, {:EXIT, stray_pid, :boom}) + + Process.sleep(50) + assert Process.alive?(conn_pid) + assert {:ok, _} = Connection.pick_channel(channel) + + disconnect_and_wait(channel) + end end end diff --git a/grpc/test/grpc/client/load_balacing/pick_first_test.exs b/grpc/test/grpc/client/load_balacing/pick_first_test.exs new file mode 100644 index 000000000..52ba1c30b --- /dev/null +++ b/grpc/test/grpc/client/load_balacing/pick_first_test.exs @@ -0,0 +1,71 @@ +defmodule GRPC.Client.LoadBalancing.PickFirstTest do + use ExUnit.Case, async: true + + alias GRPC.Channel + alias GRPC.Client.LoadBalancing.PickFirst + + defp channels(pairs), + do: Enum.map(pairs, fn {h, p} -> %Channel{host: h, port: p, ref: {h, p}} end) + + describe "init/1" do + test "creates an ETS table and returns the tid in state" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}, {"b", 2}])) + assert %{tid: tid} = state + assert is_reference(tid) + assert :ets.info(tid) != :undefined + end + + test "seeds the table with the first channel" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}, {"b", 2}])) + assert {:ok, %Channel{host: "a", port: 1}, ^state} = PickFirst.pick(state) + end + + test "rejects empty channel lists" do + assert {:error, :no_channels} = PickFirst.init(channels: []) + end + + test "rejects missing :channels option" do + assert {:error, :no_channels} = PickFirst.init([]) + end + end + + describe "pick/1" do + test "always returns the current channel" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}, {"b", 2}])) + + for _ <- 1..3 do + assert {:ok, %Channel{host: "a", port: 1}, ^state} = PickFirst.pick(state) + end + end + + test "returns :no_channels when current is nil" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}])) + {:ok, _} = PickFirst.update(state, []) + assert {:error, :no_channels} = PickFirst.pick(state) + end + + test "returns :no_channels instead of 
raising when the table was deleted" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}])) + :ets.delete(state.tid) + assert {:error, :no_channels} = PickFirst.pick(state) + end + end + + describe "update/2" do + test "swaps current in place without changing the tid" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}])) + original_tid = state.tid + + {:ok, new_state} = PickFirst.update(state, channels([{"x", 9}, {"y", 8}])) + assert new_state.tid == original_tid + + assert {:ok, %Channel{host: "x", port: 9}, _} = PickFirst.pick(new_state) + end + + test "clears current to nil on empty list" do + {:ok, state} = PickFirst.init(channels: channels([{"a", 1}])) + {:ok, state} = PickFirst.update(state, []) + assert {:error, :no_channels} = PickFirst.pick(state) + end + end +end diff --git a/grpc/test/grpc/client/load_balacing/round_robin_test.exs b/grpc/test/grpc/client/load_balacing/round_robin_test.exs new file mode 100644 index 000000000..f8ee17af5 --- /dev/null +++ b/grpc/test/grpc/client/load_balacing/round_robin_test.exs @@ -0,0 +1,128 @@ +defmodule GRPC.Client.LoadBalancing.RoundRobinTest do + use ExUnit.Case, async: true + + alias GRPC.Channel + alias GRPC.Client.LoadBalancing.RoundRobin + + defp channels(pairs), + do: Enum.map(pairs, fn {h, p} -> %Channel{host: h, port: p, ref: {h, p}} end) + + describe "init/1" do + test "creates an ETS table + atomics ref and returns both in state" do + {:ok, state} = RoundRobin.init(channels: channels([{"a", 1}])) + assert %{tid: tid, atomics: aref} = state + assert is_reference(tid) + assert is_reference(aref) + assert :ets.info(tid) != :undefined + assert :atomics.get(aref, 1) == 0 + end + + test "rejects empty channel lists" do + assert {:error, :no_channels} = RoundRobin.init(channels: []) + end + + test "rejects missing :channels option" do + assert {:error, :no_channels} = RoundRobin.init([]) + end + end + + describe "pick/1" do + test "rotates through channels in order" do + {:ok, state} = 
RoundRobin.init(channels: channels([{"a", 1}, {"b", 2}, {"c", 3}])) + + assert {:ok, %Channel{host: "a", port: 1}, _} = RoundRobin.pick(state) + assert {:ok, %Channel{host: "b", port: 2}, _} = RoundRobin.pick(state) + assert {:ok, %Channel{host: "c", port: 3}, _} = RoundRobin.pick(state) + assert {:ok, %Channel{host: "a", port: 1}, _} = RoundRobin.pick(state) + end + + test "wraps around with a single channel" do + {:ok, state} = RoundRobin.init(channels: channels([{"only", 1}])) + + for _ <- 1..5 do + assert {:ok, %Channel{host: "only"}, _} = RoundRobin.pick(state) + end + end + end + + describe "update/2" do + test "replaces channels in place without changing the tid or atomics ref" do + {:ok, state} = RoundRobin.init(channels: channels([{"a", 1}, {"b", 2}])) + original_tid = state.tid + original_aref = state.atomics + + {:ok, new_state} = RoundRobin.update(state, channels([{"x", 9}, {"y", 8}, {"z", 7}])) + assert new_state.tid == original_tid + assert new_state.atomics == original_aref + + assert {:ok, %Channel{host: "x", port: 9}, _} = RoundRobin.pick(new_state) + assert {:ok, %Channel{host: "y", port: 8}, _} = RoundRobin.pick(new_state) + assert {:ok, %Channel{host: "z", port: 7}, _} = RoundRobin.pick(new_state) + end + + test "accepts empty channel lists; pick then returns :no_channels" do + {:ok, state} = RoundRobin.init(channels: channels([{"a", 1}])) + assert {:ok, ^state} = RoundRobin.update(state, []) + assert {:error, :no_channels} = RoundRobin.pick(state) + end + + test "resets cursor so the first pick after update starts at the first channel" do + {:ok, state} = RoundRobin.init(channels: channels([{"a", 1}, {"b", 2}])) + {:ok, _, _} = RoundRobin.pick(state) + {:ok, _, _} = RoundRobin.pick(state) + + {:ok, state} = RoundRobin.update(state, channels([{"new-first", 1}, {"new-second", 2}])) + assert {:ok, %Channel{host: "new-first"}, _} = RoundRobin.pick(state) + end + end + + describe "pick/1 race with table deletion" do + test "returns :no_channels 
instead of raising when the table was deleted" do + {:ok, state} = RoundRobin.init(channels: channels([{"a", 1}])) + :ets.delete(state.tid) + + assert {:error, :no_channels} = RoundRobin.pick(state) + end + end + + describe "concurrency" do + test "pick/1 is safe under many concurrent processes" do + chs = channels(for i <- 1..4, do: {"host#{i}", 1000 + i}) + {:ok, state} = RoundRobin.init(channels: chs) + + parent = self() + picks_per_proc = 250 + procs = 16 + + for _ <- 1..procs do + spawn_link(fn -> + picks = + for _ <- 1..picks_per_proc do + {:ok, %Channel{host: host}, _} = RoundRobin.pick(state) + host + end + + send(parent, {:picks, picks}) + end) + end + + all_picks = + for _ <- 1..procs, reduce: [] do + acc -> + receive do + {:picks, picks} -> picks ++ acc + end + end + + assert length(all_picks) == procs * picks_per_proc + + counts = Enum.frequencies(all_picks) + avg = div(length(all_picks), length(chs)) + + for {_, c} <- counts do + assert abs(c - avg) <= 1, + "uneven pick distribution: #{inspect(counts)}" + end + end + end +end