Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion horizon/facts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ async def build_forward_request(
)

full_path = urljoin(f"/v2/facts/{project_id}/{environment_id}/", path.removeprefix("/"))
_query_params = {**request.query_params, **(query_params or {})}
_query_params = list(request.query_params.multi_items())
if query_params:
override_keys = set(query_params)
_query_params = [(key, value) for key, value in _query_params if key not in override_keys]
_query_params.extend(query_params.items())
return self.client.build_request(
method=request.method,
url=full_path,
Expand Down
44 changes: 42 additions & 2 deletions horizon/tests/test_facts_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from starlette.requests import Request as FastApiRequest


def _make_request(headers: dict[str, str] | None = None) -> FastApiRequest:
def _make_request(headers: dict[str, str] | None = None, query_string: bytes = b"") -> FastApiRequest:
scope = {
"type": "http",
"method": "POST",
"path": "/facts/users",
"raw_path": b"/facts/users",
"query_string": b"",
"query_string": query_string,
"headers": [(k.lower().encode(), v.encode()) for k, v in (headers or {}).items()],
}

Expand Down Expand Up @@ -59,6 +59,46 @@ async def test_build_forward_request_omits_header_by_default():
assert forward_request.headers.get(CONSISTENT_UPDATE_HEADER) is None


@pytest.mark.asyncio
async def test_build_forward_request_preserves_repeated_query_params():
client = FactsClient()

mock_remote_config = MagicMock()
mock_remote_config.context = {"project_id": "proj1", "env_id": "env1"}

with (
patch("horizon.facts.client.get_remote_config", return_value=mock_remote_config),
patch("horizon.facts.client.get_env_api_key", return_value="test_api_key"),
):
request = _make_request(query_string=b"tenant=tenant_id&user=user_1&user=user_2")
forward_request = await client.build_forward_request(request, "/role_assignments")

assert forward_request.url.params.get_list("user") == ["user_1", "user_2"]
assert forward_request.url.params["tenant"] == "tenant_id"


@pytest.mark.asyncio
async def test_build_forward_request_query_param_overrides_replace_existing_values():
client = FactsClient()

mock_remote_config = MagicMock()
mock_remote_config.context = {"project_id": "proj1", "env_id": "env1"}

with (
patch("horizon.facts.client.get_remote_config", return_value=mock_remote_config),
patch("horizon.facts.client.get_env_api_key", return_value="test_api_key"),
):
request = _make_request(query_string=b"return_deleted=false&user=user_1&user=user_2")
forward_request = await client.build_forward_request(
request,
"/role_assignments",
query_params={"return_deleted": True},
)

assert forward_request.url.params.get_list("user") == ["user_1", "user_2"]
assert forward_request.url.params["return_deleted"] == "true"


@pytest.mark.asyncio
async def test_send_forward_request_propagates_consistent_update_kwarg():
"""send_forward_request must plumb is_consistent_update into the built request's headers."""
Expand Down