diff --git a/src/knowhere/__init__.py b/src/knowhere/__init__.py index 134a757..4d094b8 100644 --- a/src/knowhere/__init__.py +++ b/src/knowhere/__init__.py @@ -22,6 +22,7 @@ ConflictError, GatewayTimeoutError, InternalServerError, + InvalidStateError, JobFailedError, KnowhereError, NotFoundError, @@ -30,6 +31,7 @@ PollingTimeoutError, RateLimitError, ServiceUnavailableError, + ValidationError, ) from knowhere._types import PollProgressCallback, UploadProgressCallback from knowhere._version import __version__ @@ -58,6 +60,8 @@ "__version__", # Exceptions "KnowhereError", + "ValidationError", + "InvalidStateError", "APIConnectionError", "APITimeoutError", "APIStatusError", diff --git a/src/knowhere/_base_client.py b/src/knowhere/_base_client.py index c84800a..815bf39 100644 --- a/src/knowhere/_base_client.py +++ b/src/knowhere/_base_client.py @@ -25,6 +25,7 @@ from knowhere._exceptions import ( APIConnectionError, APITimeoutError, + ValidationError, makeStatusError, ) from knowhere._logging import getLogger, redactSensitiveHeaders @@ -35,17 +36,23 @@ _logger = getLogger() -# Error codes that are safe to retry -_RETRYABLE_ERROR_CODES: frozenset[str] = frozenset({ - "rate_limit_exceeded", - "service_unavailable", - "gateway_timeout", - "internal_server_error", - "timeout", +# Error codes that are always safe to retry (matches server ALWAYS_RETRYABLE_ERROR_CODES) +_ALWAYS_RETRYABLE_ERROR_CODES: frozenset[str] = frozenset({ + "ABORTED", # 409 - Concurrency conflict + "UNAVAILABLE", # 503 - Service temporarily down + "DEADLINE_EXCEEDED", # 504 - Timeout }) -# Status codes that are safe to retry -_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({408, 429, 500, 502, 503, 504}) +# RESOURCE_EXHAUSTED (429) is conditionally retryable: +# - Rate limit: details.retry_after present → RETRY +# - Quota exceeded: no retry_after → DO NOT RETRY +_CONDITIONALLY_RETRYABLE_ERROR_CODE: str = "RESOURCE_EXHAUSTED" + +# HTTP status codes that are always safe to retry +_ALWAYS_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({409, 502, 503, 504}) + +# HTTP status code that is conditionally retryable (only with retry_after) +_CONDITIONALLY_RETRYABLE_STATUS_CODE: int = 429 class BaseClient: @@ -71,7 +78,7 @@ def __init__( # Resolve: arg > env > default resolved_key: Optional[str] = api_key or os.environ.get(ENV_API_KEY) if not resolved_key: - raise ValueError( + raise ValidationError( "An API key must be provided via the 'api_key' argument " f"or the {ENV_API_KEY} environment variable." ) @@ -122,12 +129,68 @@ def _shouldRetry( self, status_code: int, error_code: Optional[str] = None, - details: Optional[Any] = None, + details: Optional[Dict[str, Any]] = None, ) -> bool: - """Decide whether a request should be retried.""" - if error_code and error_code in _RETRYABLE_ERROR_CODES: + """Decide whether a request should be retried. + + Follows server-side retry semantics: + - ABORTED, UNAVAILABLE, DEADLINE_EXCEEDED → always retry + - RESOURCE_EXHAUSTED (429) → retry only if details.retry_after present + - All other errors → never retry + """ + if error_code: + if error_code in _ALWAYS_RETRYABLE_ERROR_CODES: + return True + if error_code == _CONDITIONALLY_RETRYABLE_ERROR_CODE: + return self._hasRetryAfter(details) + return False + + # Fallback to status code when error_code is unavailable + if status_code in _ALWAYS_RETRYABLE_STATUS_CODES: return True - return status_code in _RETRYABLE_STATUS_CODES + if status_code == _CONDITIONALLY_RETRYABLE_STATUS_CODE: + return self._hasRetryAfter(details) + return False + + @staticmethod + def _hasRetryAfter(details: Optional[Dict[str, Any]]) -> bool: + """Check if details contains a retry_after hint.""" + if not isinstance(details, dict): + return False + retry_after: Any = details.get("retry_after") + return retry_after is not None + + @staticmethod + def _extractRetryAfter( + error_body: Optional[Dict[str, Any]], + response: httpx.Response, + ) -> Optional[float]: + """Extract retry_after from the response body or Retry-After header. + + The server puts retry_after in ``error.details.retry_after``. + Falls back to the HTTP ``Retry-After`` header. + """ + # Prefer body: error.details.retry_after + if isinstance(error_body, dict): + err_obj: Any = error_body.get("error", error_body) + if isinstance(err_obj, dict): + details: Any = err_obj.get("details") + if isinstance(details, dict): + raw: Any = details.get("retry_after") + if raw is not None: + try: + return float(raw) + except (ValueError, TypeError): + pass + + # Fallback: HTTP Retry-After header + header_raw: Optional[str] = response.headers.get("retry-after") + if header_raw is not None: + try: + return float(header_raw) + except (ValueError, TypeError): + pass + return None def _calculateRetryDelay( self, @@ -257,24 +320,24 @@ def _request( response ) error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None if isinstance(error_body, dict): err_obj: Any = error_body.get("error", error_body) if isinstance(err_obj, dict): error_code = err_obj.get("code") + raw_details: Any = err_obj.get("details") + if isinstance(raw_details, dict): + error_details = raw_details if ( attempt < self.max_retries - and self._shouldRetry(response.status_code, error_code) + and self._shouldRetry( + response.status_code, error_code, error_details + ) ): - retry_after_raw: Optional[str] = response.headers.get( - "retry-after" + retry_after_val: Optional[float] = self._extractRetryAfter( + error_body, response ) - retry_after_val: Optional[float] = None - if retry_after_raw: - try: - retry_after_val = float(retry_after_raw) - except (ValueError, TypeError): - pass delay = self._calculateRetryDelay(attempt, retry_after_val) _logger.warning( "Retryable error %d on attempt %d/%d, retrying in %.1fs", @@ -404,22 +467,24 @@ async def _request( error_body: Optional[Dict[str, Any]] = self._parseErrorResponse(response) error_code: Optional[str] = None + error_details: Optional[Dict[str, Any]] = None if isinstance(error_body, dict): err_obj: Any = error_body.get("error", error_body) if isinstance(err_obj, dict): error_code = err_obj.get("code") + raw_details: Any = err_obj.get("details") + if isinstance(raw_details, dict): + error_details = raw_details if ( attempt < self.max_retries - and self._shouldRetry(response.status_code, error_code) + and self._shouldRetry( + response.status_code, error_code, error_details + ) ): - retry_after_raw: Optional[str] = response.headers.get("retry-after") - retry_after_val: Optional[float] = None - if retry_after_raw: - try: - retry_after_val = float(retry_after_raw) - except (ValueError, TypeError): - pass + retry_after_val: Optional[float] = self._extractRetryAfter( + error_body, response + ) delay = self._calculateRetryDelay(attempt, retry_after_val) _logger.warning( "Retryable error %d on attempt %d/%d, retrying in %.1fs", diff --git a/src/knowhere/_client.py b/src/knowhere/_client.py index 0856d96..b2cbc3e 100644 --- a/src/knowhere/_client.py +++ b/src/knowhere/_client.py @@ -13,6 +13,7 @@ from knowhere._base_client import AsyncAPIClient, SyncAPIClient from knowhere._constants import DEFAULT_POLL_INTERVAL, DEFAULT_POLL_TIMEOUT +from knowhere._exceptions import ValidationError from knowhere._logging import getLogger from knowhere._types import ( PollProgressCallback, @@ -94,9 +95,9 @@ def parse( Provide exactly one of *url* or *file*. """ if url and file: - raise ValueError("Provide either 'url' or 'file', not both.") + raise ValidationError("Provide either 'url' or 'file', not both.") if not url and file is None: - raise ValueError("Provide either 'url' or 'file'.") + raise ValidationError("Provide either 'url' or 'file'.") # Determine source type and create job if url: @@ -196,9 +197,9 @@ async def parse( ) -> ParseResult: """Parse a document end-to-end (async version).""" if url and file: - raise ValueError("Provide either 'url' or 'file', not both.") + raise ValidationError("Provide either 'url' or 'file', not both.") if not url and file is None: - raise ValueError("Provide either 'url' or 'file'.") + raise ValidationError("Provide either 'url' or 'file'.") if url: job: Job = await self.jobs.create( diff --git a/src/knowhere/_constants.py b/src/knowhere/_constants.py index cb25d23..7b74dd3 100644 --- a/src/knowhere/_constants.py +++ b/src/knowhere/_constants.py @@ -18,6 +18,7 @@ # Retry configuration DEFAULT_MAX_RETRIES: int = 5 +DEFAULT_UPLOAD_MAX_RETRIES: int = 2 # Polling configuration MAX_POLL_INTERVAL: float = 30.0 diff --git a/src/knowhere/_exceptions.py b/src/knowhere/_exceptions.py index c407108..208aa40 100644 --- a/src/knowhere/_exceptions.py +++ b/src/knowhere/_exceptions.py @@ -41,6 +41,19 @@ def __init__(self, message: str = "Request timed out.") -> None: super().__init__(message) +# --------------------------------------------------------------------------- +# Validation / state +# --------------------------------------------------------------------------- + + +class ValidationError(KnowhereError): + """Raised when the caller provides invalid arguments.""" + + +class InvalidStateError(KnowhereError): + """Raised when an object is in an unexpected state for the operation.""" + + # --------------------------------------------------------------------------- # Polling / job errors # --------------------------------------------------------------------------- @@ -161,9 +174,17 @@ class ConflictError(APIStatusError): class RateLimitError(APIStatusError): - """HTTP 429 — includes optional ``retry_after`` hint.""" + """HTTP 429 — includes optional rate limit hints from the server. + + Attributes: + retry_after: Seconds to wait before retrying (``None`` for quota exceeded). + limit: Maximum allowed requests in the rate window. + period: Rate window unit (``"second"``, ``"minute"``, ``"hour"``, ``"day"``). + """ retry_after: Optional[float] + limit: Optional[int] + period: Optional[str] def __init__( self, @@ -176,6 +197,8 @@ def __init__( body: Optional[Any] = None, response: httpx.Response, retry_after: Optional[float] = None, + limit: Optional[int] = None, + period: Optional[str] = None, ) -> None: super().__init__( status_code, @@ -187,6 +210,8 @@ def __init__( response=response, ) self.retry_after = retry_after + self.limit = limit + self.period = period class InternalServerError(APIStatusError): @@ -194,9 +219,17 @@ class InternalServerError(APIStatusError): class ServiceUnavailableError(APIStatusError): - """HTTP 502 / 503 — includes optional ``retry_after`` hint.""" + """HTTP 502 / 503 — includes optional rate limit hints from the server. + + Attributes: + retry_after: Seconds to wait before retrying. + limit: Maximum allowed requests in the rate window (optional). + period: Rate window unit (optional). + """ retry_after: Optional[float] + limit: Optional[int] + period: Optional[str] def __init__( self, @@ -209,6 +242,8 @@ def __init__( body: Optional[Any] = None, response: httpx.Response, retry_after: Optional[float] = None, + limit: Optional[int] = None, + period: Optional[str] = None, ) -> None: super().__init__( status_code, @@ -220,12 +255,22 @@ def __init__( response=response, ) self.retry_after = retry_after + self.limit = limit + self.period = period class GatewayTimeoutError(APIStatusError): - """HTTP 504 — includes optional ``retry_after`` hint.""" + """HTTP 504 — includes optional rate limit hints from the server. + + Attributes: + retry_after: Seconds to wait before retrying. + limit: Maximum allowed requests in the rate window (optional). + period: Rate window unit (optional). + """ retry_after: Optional[float] + limit: Optional[int] + period: Optional[str] def __init__( self, @@ -238,6 +283,8 @@ def __init__( body: Optional[Any] = None, response: httpx.Response, retry_after: Optional[float] = None, + limit: Optional[int] = None, + period: Optional[str] = None, ) -> None: super().__init__( status_code, @@ -249,6 +296,8 @@ def __init__( response=response, ) self.retry_after = retry_after + self.limit = limit + self.period = period # --------------------------------------------------------------------------- @@ -298,14 +347,36 @@ def makeStatusError( status_code, APIStatusError ) - # Extract retry_after for classes that support it + # Extract retry hints for classes that support them + # Prefer body: error.details.retry_after, fallback to HTTP header retry_after: Optional[float] = None - raw_retry: Optional[str] = response.headers.get("retry-after") - if raw_retry is not None: - try: - retry_after = float(raw_retry) - except (ValueError, TypeError): - retry_after = None + limit: Optional[int] = None + period: Optional[str] = None + + if isinstance(details, dict): + raw_body_retry: Any = details.get("retry_after") + if raw_body_retry is not None: + try: + retry_after = float(raw_body_retry) + except (ValueError, TypeError): + pass + raw_limit: Any = details.get("limit") + if raw_limit is not None: + try: + limit = int(raw_limit) + except (ValueError, TypeError): + pass + raw_period: Any = details.get("period") + if isinstance(raw_period, str): + period = raw_period + + if retry_after is None: + raw_header_retry: Optional[str] = response.headers.get("retry-after") + if raw_header_retry is not None: + try: + retry_after = float(raw_header_retry) + except (ValueError, TypeError): + pass common_kwargs: Dict[str, Any] = dict( code=code, @@ -318,7 +389,11 @@ def makeStatusError( if exception_class in (RateLimitError, ServiceUnavailableError, GatewayTimeoutError): return exception_class( - status_code, **common_kwargs, retry_after=retry_after # type: ignore[call-arg] + status_code, + **common_kwargs, + retry_after=retry_after, # type: ignore[call-arg] + limit=limit, + period=period, ) return exception_class(status_code, **common_kwargs) diff --git a/src/knowhere/lib/upload.py b/src/knowhere/lib/upload.py index af43396..5c7b278 100644 --- a/src/knowhere/lib/upload.py +++ b/src/knowhere/lib/upload.py @@ -2,11 +2,15 @@ from __future__ import annotations +import asyncio +import random +import time from pathlib import Path from typing import BinaryIO, Dict, Optional, Union import httpx +from knowhere._constants import DEFAULT_UPLOAD_MAX_RETRIES from knowhere._exceptions import APIConnectionError, APITimeoutError from knowhere._logging import getLogger from knowhere._types import UploadProgressCallback @@ -16,6 +20,26 @@ # Chunk size for streaming uploads (256 KiB) _UPLOAD_CHUNK_SIZE: int = 256 * 1024 +# Storage-provider HTTP status codes that are safe to retry. +# These are transient errors from S3/GCS/Azure Blob, not Knowhere API codes. +_UPLOAD_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({500, 502, 503, 504}) + + +def _calculateUploadRetryDelay(attempt: int) -> float: + """Exponential backoff with jitter for upload retries.""" + base_delay: float = min(1.0 * (2 ** attempt), 16.0) + jitter: float = random.uniform(0, base_delay * 0.25) + return base_delay + jitter + + +def _isRetryableUploadError(exc: Exception) -> bool: + """Return True if the upload error is transient and worth retrying.""" + if isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)): + return True + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code in _UPLOAD_RETRYABLE_STATUS_CODES + return False + def _prepareFileContent( file: Union[Path, BinaryIO, bytes], @@ -66,41 +90,68 @@ def syncUploadFile( on_progress: Optional[UploadProgressCallback] = None, *, timeout: float = 600.0, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, ) -> None: - """Upload *file* to *upload_url* using a synchronous PUT request.""" + """Upload *file* to *upload_url* using a synchronous PUT request. + + Retries on connection errors, timeouts, and transient storage HTTP errors + (500/502/503/504) up to *max_retries* times. + """ content, total_bytes = _prepareFileContent(file) headers: Dict[str, str] = _buildUploadHeaders(upload_headers, total_bytes) - _logger.debug("Uploading %s bytes to %s", total_bytes, upload_url) - if isinstance(content, bytes): data: bytes = content else: - # BinaryIO — read all for simplicity (already measured size) pos: int = content.tell() data = content.read() content.seek(pos) - if on_progress: - on_progress(0, total_bytes) + last_exc: Optional[Exception] = None - try: - response: httpx.Response = client.put( - upload_url, - content=data, - headers=headers, - timeout=timeout, + for attempt in range(max_retries + 1): + _logger.debug( + "Upload attempt %d/%d — %s bytes to %s", + attempt + 1, max_retries + 1, total_bytes, upload_url, ) - response.raise_for_status() - except httpx.TimeoutException as exc: - raise APITimeoutError(f"Upload timed out: {exc}") from exc - except httpx.HTTPError as exc: - raise APIConnectionError(f"Upload failed: {exc}") from exc - - if on_progress: - on_progress(len(data), total_bytes) - _logger.debug("Upload complete: %d", response.status_code) + if on_progress and attempt == 0: + on_progress(0, total_bytes) + + try: + response: httpx.Response = client.put( + upload_url, + content=data, + headers=headers, + timeout=timeout, + ) + response.raise_for_status() + except (httpx.HTTPError, httpx.TimeoutException) as exc: + last_exc = exc + if attempt < max_retries and _isRetryableUploadError(exc): + delay: float = _calculateUploadRetryDelay(attempt) + _logger.warning( + "Upload attempt %d/%d failed (%s), retrying in %.1fs", + attempt + 1, max_retries + 1, exc, delay, + ) + time.sleep(delay) + continue + # Non-retryable or exhausted retries + if isinstance(exc, httpx.TimeoutException): + raise APITimeoutError(f"Upload timed out: {exc}") from exc + raise APIConnectionError(f"Upload failed: {exc}") from exc + + # Success + if on_progress: + on_progress(len(data), total_bytes) + _logger.debug("Upload complete: %d", response.status_code) + return + + # Should not reach here, but guard against it + if last_exc is not None: + if isinstance(last_exc, httpx.TimeoutException): + raise APITimeoutError(f"Upload timed out: {last_exc}") from last_exc + raise APIConnectionError(f"Upload failed: {last_exc}") from last_exc async def asyncUploadFile( @@ -111,37 +162,62 @@ async def asyncUploadFile( on_progress: Optional[UploadProgressCallback] = None, *, timeout: float = 600.0, + max_retries: int = DEFAULT_UPLOAD_MAX_RETRIES, ) -> None: - """Upload *file* to *upload_url* using an async PUT request.""" + """Upload *file* to *upload_url* using an async PUT request. + + Retries on connection errors, timeouts, and transient storage HTTP errors + (500/502/503/504) up to *max_retries* times. + """ content, total_bytes = _prepareFileContent(file) headers: Dict[str, str] = _buildUploadHeaders(upload_headers, total_bytes) - _logger.debug("Async uploading %s bytes to %s", total_bytes, upload_url) - if isinstance(content, bytes): data: bytes = content else: - pos = content.tell() + pos: int = content.tell() data = content.read() content.seek(pos) - if on_progress: - on_progress(0, total_bytes) + last_exc: Optional[Exception] = None - try: - response: httpx.Response = await client.put( - upload_url, - content=data, - headers=headers, - timeout=timeout, + for attempt in range(max_retries + 1): + _logger.debug( + "Async upload attempt %d/%d — %s bytes to %s", + attempt + 1, max_retries + 1, total_bytes, upload_url, ) - response.raise_for_status() - except httpx.TimeoutException as exc: - raise APITimeoutError(f"Upload timed out: {exc}") from exc - except httpx.HTTPError as exc: - raise APIConnectionError(f"Upload failed: {exc}") from exc - - if on_progress: - on_progress(len(data), total_bytes) - _logger.debug("Async upload complete: %d", response.status_code) + if on_progress and attempt == 0: + on_progress(0, total_bytes) + + try: + response: httpx.Response = await client.put( + upload_url, + content=data, + headers=headers, + timeout=timeout, + ) + response.raise_for_status() + except (httpx.HTTPError, httpx.TimeoutException) as exc: + last_exc = exc + if attempt < max_retries and _isRetryableUploadError(exc): + delay: float = _calculateUploadRetryDelay(attempt) + _logger.warning( + "Async upload attempt %d/%d failed (%s), retrying in %.1fs", + attempt + 1, max_retries + 1, exc, delay, + ) + await asyncio.sleep(delay) + continue + if isinstance(exc, httpx.TimeoutException): + raise APITimeoutError(f"Upload timed out: {exc}") from exc + raise APIConnectionError(f"Upload failed: {exc}") from exc + + if on_progress: + on_progress(len(data), total_bytes) + _logger.debug("Async upload complete: %d", response.status_code) + return + + if last_exc is not None: + if isinstance(last_exc, httpx.TimeoutException): + raise APITimeoutError(f"Upload timed out: {last_exc}") from last_exc + raise APIConnectionError(f"Upload failed: {last_exc}") from last_exc diff --git a/src/knowhere/resources/jobs.py b/src/knowhere/resources/jobs.py index f4da6f6..11fdc21 100644 --- a/src/knowhere/resources/jobs.py +++ b/src/knowhere/resources/jobs.py @@ -8,6 +8,7 @@ import httpx from knowhere._constants import DEFAULT_POLL_INTERVAL, DEFAULT_POLL_TIMEOUT +from knowhere._exceptions import InvalidStateError from knowhere._logging import getLogger from knowhere._types import ( PollProgressCallback, @@ -84,7 +85,7 @@ def upload( """ if isinstance(job, Job): if not job.upload_url: - raise ValueError("Job does not have an upload URL.") + raise InvalidStateError("Job does not have an upload URL.") upload_url: str = job.upload_url upload_headers: Optional[Dict[str, str]] = job.upload_headers else: @@ -134,7 +135,7 @@ def load( """ if isinstance(job_result, JobResult): if not job_result.result_url: - raise ValueError("JobResult does not have a result_url.") + raise InvalidStateError("JobResult does not have a result_url.") result_url: str = job_result.result_url else: result_url = job_result @@ -192,7 +193,7 @@ async def upload( """Upload a file for a job (async).""" if isinstance(job, Job): if not job.upload_url: - raise ValueError("Job does not have an upload URL.") + raise InvalidStateError("Job does not have an upload URL.") upload_url: str = job.upload_url upload_headers: Optional[Dict[str, str]] = job.upload_headers else: @@ -234,7 +235,7 @@ async def load( """Download and parse the result ZIP (async).""" if isinstance(job_result, JobResult): if not job_result.result_url: - raise ValueError("JobResult does not have a result_url.") + raise InvalidStateError("JobResult does not have a result_url.") result_url: str = job_result.result_url else: result_url = job_result diff --git a/src/knowhere/types/result.py b/src/knowhere/types/result.py index d5f3e22..ea02cb7 100644 --- a/src/knowhere/types/result.py +++ b/src/knowhere/types/result.py @@ -9,6 +9,8 @@ from pydantic import BaseModel, Field +from knowhere._exceptions import ValidationError + # --------------------------------------------------------------------------- # Filename sanitisation helper @@ -30,11 +32,11 @@ def _sanitizeFilename(name: str) -> str: def _ensurePathWithinDirectory(base: Path, target: Path) -> Path: - """Raise ``ValueError`` if *target* escapes *base* (Zip Slip prevention).""" + """Raise ``ValidationError`` if *target* escapes *base* (Zip Slip prevention).""" resolved_base: Path = base.resolve() resolved_target: Path = target.resolve() if not str(resolved_target).startswith(str(resolved_base)): - raise ValueError( + raise ValidationError( f"Path '{resolved_target}' escapes output directory '{resolved_base}'." ) return resolved_target diff --git a/tests/test_client.py b/tests/test_client.py index 68476eb..0b074ac 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,6 +8,8 @@ import pytest +from knowhere._exceptions import ValidationError + from knowhere._constants import ( DEFAULT_BASE_URL, DEFAULT_MAX_RETRIES, @@ -62,7 +64,7 @@ def test_missing_api_key_raises_value_error(self) -> None: with patch.dict(os.environ, {}, clear=True): os.environ.pop("KNOWHERE_API_KEY", None) - with pytest.raises(ValueError, match="(?i)api.key"): + with pytest.raises(ValidationError, match="(?i)api.key"): Knowhere() def test_default_timeout(self) -> None: @@ -168,7 +170,7 @@ def test_missing_api_key_raises_value_error(self) -> None: with patch.dict(os.environ, {}, clear=True): os.environ.pop("KNOWHERE_API_KEY", None) - with pytest.raises(ValueError, match="(?i)api.key"): + with pytest.raises(ValidationError, match="(?i)api.key"): AsyncKnowhere() def test_default_values(self) -> None: diff --git a/tests/test_parse.py b/tests/test_parse.py index d91564a..8d545b0 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -8,6 +8,8 @@ import httpx import pytest + +from knowhere._exceptions import ValidationError import respx from tests.conftest import BASE_URL @@ -219,13 +221,13 @@ class TestParseValidation: def test_missing_url_and_file_raises_value_error( self, sync_client: Any ) -> None: - with pytest.raises(ValueError, match="url.*file|file.*url"): + with pytest.raises(ValidationError, match="url.*file|file.*url"): sync_client.parse() def test_both_url_and_file_raises_value_error( self, sync_client: Any ) -> None: - with pytest.raises(ValueError, match="url.*file|file.*url"): + with pytest.raises(ValidationError, match="url.*file|file.*url"): sync_client.parse( url="https://example.com/doc.pdf", file=b"content", diff --git a/tests/test_retry.py b/tests/test_retry.py index 50dd391..f2b9329 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Dict, Optional import httpx import pytest @@ -11,6 +11,8 @@ from knowhere._exceptions import ( AuthenticationError, BadRequestError, + ConflictError, + InternalServerError, RateLimitError, ServiceUnavailableError, ) @@ -29,16 +31,20 @@ } -def _error_body(code: str, message: str) -> Dict[str, Any]: +def _error_body( + code: str, + message: str, + details: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: """Build an error response body.""" - return { - "success": False, - "error": { - "code": code, - "message": message, - "request_id": "req_retry", - }, + error: Dict[str, Any] = { + "code": code, + "message": message, + "request_id": "req_retry", } + if details is not None: + error["details"] = details + return {"success": False, "error": error} # --------------------------------------------------------------------------- @@ -96,21 +102,26 @@ def test_504_triggers_retry(self, sync_client: Any) -> None: # --------------------------------------------------------------------------- -# 429 with retry_after triggers retry +# 429 with retry_after in body details triggers retry # --------------------------------------------------------------------------- class TestRetry429WithRetryAfter: - """Verify 429 with retry-after header triggers retry.""" + """Verify 429 with details.retry_after triggers retry.""" @respx.mock - def test_429_with_retry_after_retries(self, sync_client: Any) -> None: + def test_429_with_retry_after_in_body_retries( + self, sync_client: Any + ) -> None: route = respx.get(GET_URL).mock( side_effect=[ httpx.Response( 429, - json=_error_body("RESOURCE_EXHAUSTED", "Rate limited"), - headers={"retry-after": "0"}, + json=_error_body( + "RESOURCE_EXHAUSTED", + "Rate limited", + details={"retry_after": 0}, + ), ), httpx.Response(200, json=DONE_RESPONSE), ] @@ -123,31 +134,30 @@ def test_429_with_retry_after_retries(self, sync_client: Any) -> None: # --------------------------------------------------------------------------- -# 429 without retry_after also retries (429 is in _RETRYABLE_STATUS_CODES) +# 429 without retry_after does NOT retry (quota exceeded) # --------------------------------------------------------------------------- -class TestRetry429WithoutRetryAfter: - """Verify 429 without retry-after still retries (status is retryable).""" +class TestNoRetry429WithoutRetryAfter: + """Verify 429 without retry_after does NOT retry (quota exceeded).""" @respx.mock - def test_429_without_retry_after_still_retries( + def test_429_without_retry_after_does_not_retry( self, sync_client: Any ) -> None: route = respx.get(GET_URL).mock( - side_effect=[ - httpx.Response( - 429, - json=_error_body("RESOURCE_EXHAUSTED", "Quota exceeded"), + return_value=httpx.Response( + 429, + json=_error_body( + "RESOURCE_EXHAUSTED", "Quota exceeded" ), - httpx.Response(200, json=DONE_RESPONSE), - ] + ) ) - result = sync_client.jobs.get(JOB_ID) + with pytest.raises(RateLimitError): + sync_client.jobs.get(JOB_ID) - assert result.status == "done" - assert route.call_count == 2 + assert route.call_count == 1 # --------------------------------------------------------------------------- @@ -248,3 +258,78 @@ def test_connection_error_triggers_retry( assert result.status == "done" assert route.call_count == 2 + + +# --------------------------------------------------------------------------- +# 409 (ABORTED) triggers retry +# --------------------------------------------------------------------------- + + +class TestRetry409: + """Verify 409 ABORTED responses trigger automatic retry.""" + + @respx.mock + def test_409_triggers_retry(self, sync_client: Any) -> None: + route = respx.get(GET_URL).mock( + side_effect=[ + httpx.Response( + 409, + json=_error_body("ABORTED", "Concurrency conflict"), + ), + httpx.Response(200, json=DONE_RESPONSE), + ] + ) + + result = sync_client.jobs.get(JOB_ID) + + assert result.status == "done" + assert route.call_count == 2 + + +# --------------------------------------------------------------------------- +# 500 does NOT retry +# --------------------------------------------------------------------------- + + +class TestNoRetry500: + """Verify 500 INTERNAL_ERROR responses are NOT retried.""" + + @respx.mock + def test_500_does_not_retry(self, sync_client: Any) -> None: + route = respx.get(GET_URL).mock( + return_value=httpx.Response( + 500, + json=_error_body("INTERNAL_ERROR", "Internal server error"), + ) + ) + + with pytest.raises(InternalServerError): + sync_client.jobs.get(JOB_ID) + + assert route.call_count == 1 + + +# --------------------------------------------------------------------------- +# 502 triggers retry (maps to UNAVAILABLE) +# --------------------------------------------------------------------------- + + +class TestRetry502: + """Verify 502 responses trigger automatic retry.""" + + @respx.mock + def test_502_triggers_retry(self, sync_client: Any) -> None: + route = respx.get(GET_URL).mock( + side_effect=[ + httpx.Response( + 502, + json=_error_body("UNAVAILABLE", "Bad gateway"), + ), + httpx.Response(200, json=DONE_RESPONSE), + ] + ) + + result = sync_client.jobs.get(JOB_ID) + + assert result.status == "done" + assert route.call_count == 2 diff --git a/tests/test_upload.py b/tests/test_upload.py index b970cce..0b4c64b 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -11,7 +11,7 @@ import pytest import respx -from knowhere._exceptions import KnowhereError +from knowhere._exceptions import APIConnectionError, KnowhereError from knowhere.types.job import Job @@ -191,3 +191,63 @@ def test_upload_url_string_sends_put(self, sync_client: Any) -> None: sync_client.jobs.upload(UPLOAD_URL, b"content via url string") assert route.called + + +# --------------------------------------------------------------------------- +# Upload retry on transient storage errors +# --------------------------------------------------------------------------- + + +class TestUploadRetry503: + """Verify upload retries on 503 from storage provider.""" + + @respx.mock + def test_503_triggers_upload_retry(self, sync_client: Any) -> None: + route = respx.put(UPLOAD_URL).mock( + side_effect=[ + httpx.Response(503), + httpx.Response(200), + ] + ) + + job: Job = _make_job_with_upload_url(UPLOAD_URL) + sync_client.jobs.upload(job, b"retry content") + + assert route.call_count == 2 + + +class TestUploadRetryConnectionError: + """Verify upload retries on connection errors.""" + + @respx.mock + def test_connection_error_triggers_upload_retry( + self, sync_client: Any + ) -> None: + route = respx.put(UPLOAD_URL).mock( + side_effect=[ + httpx.ConnectError("Connection refused"), + httpx.Response(200), + ] + ) + + job: Job = _make_job_with_upload_url(UPLOAD_URL) + sync_client.jobs.upload(job, b"retry content") + + assert route.call_count == 2 + + +class TestUploadNoRetry403: + """Verify upload does NOT retry on 403 (expired pre-signed URL).""" + + @respx.mock + def test_403_does_not_retry(self, sync_client: Any) -> None: + route = respx.put(UPLOAD_URL).mock( + return_value=httpx.Response(403) + ) + + job: Job = _make_job_with_upload_url(UPLOAD_URL) + + with pytest.raises(APIConnectionError): + sync_client.jobs.upload(job, b"should not retry") + + assert route.call_count == 1