Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions app/agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@
}
"""

# Shared client for making HTTP requests.
_http_client = httpx.Client(timeout=httpx.Timeout(TIMEOUT, read=READ_TIMEOUT))


class GoogleAPIError:
"""Constants for expected Google API error types."""
Expand Down Expand Up @@ -278,12 +281,10 @@ def search_datasets(query: str) -> str:
Strategy: Start with broad terms like "censo", "ibge", "inep", "rais", then get specific if needed.
Next step: Use `get_dataset_details()` with returned dataset IDs.
""" # noqa: E501
with httpx.Client() as client:
response = client.get(
url=SEARCH_URL,
params={"contains": "tables", "q": query, "page_size": PAGE_SIZE},
timeout=httpx.Timeout(TIMEOUT, read=READ_TIMEOUT),
)
response = _http_client.get(
url=SEARCH_URL,
params={"contains": "tables", "q": query, "page_size": PAGE_SIZE},
)

response.raise_for_status()
data: dict = response.json()
Expand Down Expand Up @@ -333,15 +334,13 @@ def get_dataset_details(dataset_id: str) -> str:

Next step: Use `execute_bigquery_sql()` to execute queries.
""" # noqa: E501
with httpx.Client() as client:
response = client.post(
url=GRAPHQL_URL,
json={
"query": DATASET_DETAILS_QUERY,
"variables": {"id": dataset_id},
},
timeout=httpx.Timeout(TIMEOUT, read=READ_TIMEOUT),
)
response = _http_client.post(
url=GRAPHQL_URL,
json={
"query": DATASET_DETAILS_QUERY,
"variables": {"id": dataset_id},
},
)

response.raise_for_status()
data: dict[str, dict[str, dict]] = response.json()
Expand Down Expand Up @@ -436,11 +435,7 @@ def get_dataset_details(dataset_id: str) -> str:
if gcp_dataset_id is not None:
filename = gcp_dataset_id.replace("_", "-")

with httpx.Client() as client:
response = client.get(
url=f"{BASE_USAGE_GUIDE_URL}/{filename}.md",
timeout=httpx.Timeout(TIMEOUT, read=READ_TIMEOUT),
)
response = _http_client.get(f"{BASE_USAGE_GUIDE_URL}/{filename}.md")

if response.status_code == httpx.codes.OK:
usage_guide = response.text.strip()
Expand Down
11 changes: 6 additions & 5 deletions app/api/dependencies/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False)

_http_client = httpx.AsyncClient()


async def _verify_token(token: str) -> bool:
query = """
Expand All @@ -22,11 +24,10 @@ async def _verify_token(token: str) -> bool:
"""
start = time.perf_counter()
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{settings.BASEDOSDADOS_BASE_URL}/graphql",
json={"query": query, "variables": {"token": token}},
)
response = await _http_client.post(
f"{settings.BASEDOSDADOS_BASE_URL}/graphql",
json={"query": query, "variables": {"token": token}},
)
response.raise_for_status()
except (httpx.HTTPStatusError, httpx.ConnectError):
raise HTTPException(
Expand Down
68 changes: 33 additions & 35 deletions tests/app/api/dependencies/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import uuid
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock

import httpx
import jwt
Expand All @@ -14,12 +13,6 @@
class TestVerifyToken:
"""Tests for _verify_token function."""

def _mock_client(self, mock_response: Any):
"""Create a mock httpx.AsyncClient context manager."""
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
return mock_client

def _mock_graphql_response(self, has_access: bool):
"""Create a mock response for the GraphQL endpoint."""
mock_response = MagicMock()
Expand All @@ -28,58 +21,63 @@ def _mock_graphql_response(self, has_access: bool):
}
return mock_response

async def test_returns_true_when_user_has_access(self):
async def test_returns_true_when_user_has_access(
self, monkeypatch: pytest.MonkeyPatch
):
"""Test returns True when user has chatbot access."""
mock_response = self._mock_graphql_response(has_access=True)
mock_client = self._mock_client(mock_response)

with patch("app.api.dependencies.auth.httpx.AsyncClient") as MockClient:
MockClient.return_value.__aenter__.return_value = mock_client
monkeypatch.setattr(
"app.api.dependencies.auth._http_client",
MagicMock(post=AsyncMock(return_value=mock_response)),
)

result = await _verify_token("valid-token")
result = await _verify_token("valid-token")

assert result is True

async def test_returns_false_when_user_lacks_access(self):
async def test_returns_false_when_user_lacks_access(
self, monkeypatch: pytest.MonkeyPatch
):
"""Test returns False when user lacks chatbot access."""
mock_response = self._mock_graphql_response(has_access=False)
mock_client = self._mock_client(mock_response)

with patch("app.api.dependencies.auth.httpx.AsyncClient") as MockClient:
MockClient.return_value.__aenter__.return_value = mock_client
monkeypatch.setattr(
"app.api.dependencies.auth._http_client",
MagicMock(post=AsyncMock(return_value=mock_response)),
)

result = await _verify_token("valid-token")
result = await _verify_token("valid-token")

assert result is False

async def test_raises_503_on_http_error(self):
async def test_raises_503_on_http_error(self, monkeypatch: pytest.MonkeyPatch):
"""Test raises 503 when GraphQL endpoint returns HTTP error."""
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Server Error",
request=httpx.Request("POST", "http://test"),
response=mock_response,
)
mock_client = self._mock_client(mock_response)

with patch("app.api.dependencies.auth.httpx.AsyncClient") as MockClient:
MockClient.return_value.__aenter__.return_value = mock_client
monkeypatch.setattr(
"app.api.dependencies.auth._http_client",
MagicMock(post=AsyncMock(return_value=mock_response)),
)

with pytest.raises(HTTPException) as e:
await _verify_token("valid-token")
with pytest.raises(HTTPException) as e:
await _verify_token("valid-token")

assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE

async def test_raises_503_on_connect_error(self):
async def test_raises_503_on_connect_error(self, monkeypatch: pytest.MonkeyPatch):
"""Test raises 503 when GraphQL endpoint is unreachable."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.ConnectError("Connection refused")

with patch("app.api.dependencies.auth.httpx.AsyncClient") as MockClient:
MockClient.return_value.__aenter__.return_value = mock_client
monkeypatch.setattr(
"app.api.dependencies.auth._http_client",
MagicMock(
post=AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
),
)

with pytest.raises(HTTPException) as e:
await _verify_token("valid-token")
with pytest.raises(HTTPException) as e:
await _verify_token("valid-token")

assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE

Expand Down