From 3027cc3bdf4d71b779551f7d266907cc680e2b4c Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 4 Oct 2025 20:10:30 +0530 Subject: [PATCH 1/5] better validations, user friendly with custom errors, some ai slop docstring and more --- fast_flights/__init__.py | 36 ++++- fast_flights/exceptions.py | 33 ++++ fast_flights/fetcher.py | 113 +++++++++++--- fast_flights/integrations/base.py | 62 +++++--- fast_flights/integrations/bright_data.py | 113 +++++++++++--- fast_flights/querying.py | 184 +++++++++++++++++++---- fast_flights/validation.py | 132 ++++++++++++++++ test.py | 5 +- 8 files changed, 587 insertions(+), 91 deletions(-) create mode 100644 fast_flights/exceptions.py create mode 100644 fast_flights/validation.py diff --git a/fast_flights/__init__.py b/fast_flights/__init__.py index c496da65..7d295b04 100644 --- a/fast_flights/__init__.py +++ b/fast_flights/__init__.py @@ -1,5 +1,21 @@ from . import integrations - +from .exceptions import ( + FastFlightsError, + ValidationError, + AirportCodeError, + DateFormatError, + PassengerError, + FlightQueryError, + APIConnectionError, + APIError, +) +from .validation import ( + validate_airport_code, + validate_date, + validate_passengers, + validate_flight_query, + validate_currency, +) from .querying import ( FlightQuery, Query, @@ -10,6 +26,7 @@ from .fetcher import get_flights, fetch_flights_html __all__ = [ + # Core functionality "FlightQuery", "Query", "Passengers", @@ -18,4 +35,21 @@ "get_flights", "fetch_flights_html", "integrations", + + # Exceptions + "FastFlightsError", + "ValidationError", + "AirportCodeError", + "DateFormatError", + "PassengerError", + "FlightQueryError", + "APIConnectionError", + "APIError", + + # Validation utilities + "validate_airport_code", + "validate_date", + "validate_passengers", + "validate_flight_query", + "validate_currency", ] diff --git a/fast_flights/exceptions.py b/fast_flights/exceptions.py new file mode 100644 index 00000000..417ec0e5 --- /dev/null +++ b/fast_flights/exceptions.py @@ -0,0 +1,33 @@ +"""Custom exceptions for the fast_flights package.""" + +class FastFlightsError(Exception): + """Base exception for all fast_flights exceptions.""" + pass + +class ValidationError(FastFlightsError, ValueError): + """Raised when input validation fails.""" + pass + +class AirportCodeError(ValidationError): + """Raised when an invalid airport code is provided.""" + pass + +class DateFormatError(ValidationError): + """Raised when a date string is in an invalid format.""" + pass + +class PassengerError(ValidationError): + """Raised when there's an issue with passenger configuration.""" + pass + +class FlightQueryError(ValidationError): + """Raised when there's an issue with flight query parameters.""" + pass + +class APIConnectionError(FastFlightsError): + """Raised when there's an issue connecting to the flight data API.""" + pass + +class APIError(FastFlightsError): + """Raised when the flight data API returns an error.""" + pass diff --git a/fast_flights/fetcher.py b/fast_flights/fetcher.py index 1729fbf3..8d3e0221 100644 --- a/fast_flights/fetcher.py +++ b/fast_flights/fetcher.py @@ -1,11 +1,16 @@ +import logging from typing import Optional, Union, overload from primp import Client +from .exceptions import APIConnectionError, APIError from .querying import Query from .parser import MetaList, parse from .integrations import Integration +# Set up logging +logger = logging.getLogger(__name__) + URL = "https://www.google.com/travel/flights" @@ -54,11 +59,28 @@ def get_flights( """Get flights. Args: - q: The query. - proxy (str, optional): Proxy. + q: The query string or Query object. + proxy: Optional proxy configuration. + integration: Optional integration to use for fetching data. + + Returns: + MetaList: Parsed flight data. + + Raises: + APIConnectionError: If there's an issue connecting to the flight data source. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. """ - html = fetch_flights_html(q, proxy=proxy, integration=integration) - return parse(html) + try: + logger.debug("Fetching flight data...") + html = fetch_flights_html(q, proxy=proxy, integration=integration) + if not html or not isinstance(html, str): + raise APIError("Received empty or invalid response from the flight data source") + return parse(html) + except Exception as e: + if isinstance(e, (APIConnectionError, APIError, ValueError)): + raise + raise APIConnectionError(f"Failed to fetch flight data: {str(e)}") from e def fetch_flights_html( @@ -68,29 +90,70 @@ def fetch_flights_html( proxy: Optional[str] = None, integration: Optional[Integration] = None, ) -> str: - """Fetch flights and get the **HTML**. + """Fetch flights and get the HTML response. Args: - q: The query. - proxy (str, optional): Proxy. + q: The query string or Query object. + proxy: Optional proxy configuration. + integration: Optional integration to use for fetching data. + + Returns: + str: The HTML content of the flight search results. + + Raises: + APIConnectionError: If there's an issue connecting to the flight data source. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. """ - if integration is None: - client = Client( - impersonate="chrome_133", - impersonate_os="macos", - referer=True, - proxy=proxy, - cookie_store=True, - ) - - if isinstance(q, Query): - params = q.params() + if not q: + raise ValueError("Query cannot be empty") + + try: + if integration is None: + logger.debug("Using default client for fetching flight data") + client = Client( + impersonate="chrome_133", + impersonate_os="macos", + referer=True, + proxy=proxy, + cookie_store=True, + timeout=30, # 30 seconds timeout + ) + + try: + if isinstance(q, Query): + params = q.params() + else: + if not isinstance(q, str): + raise ValueError("Query must be a string or Query object") + params = {"q": q} + + logger.debug(f"Sending request to {URL} with params: {params}") + res = client.get(URL, params=params) + res.raise_for_status() # Raise HTTPError for bad responses + + if not res.text: + raise APIError("Received empty response from the flight data source") + + return res.text + + except Exception as e: + logger.error(f"Error fetching flight data: {str(e)}") + if hasattr(e, 'response') and hasattr(e.response, 'status_code'): + raise APIError( + f"Flight data API returned status code {e.response.status_code}: {str(e)}" + ) from e + raise APIConnectionError(f"Failed to connect to flight data source: {str(e)}") from e else: - params = {"q": q} - - res = client.get(URL, params=params) - return res.text - - else: - return integration.fetch_html(q) + logger.debug("Using integration for fetching flight data") + try: + return integration.fetch_html(q) + except Exception as e: + logger.error(f"Integration error while fetching flight data: {str(e)}") + raise APIError(f"Integration failed to fetch flight data: {str(e)}") from e + + except Exception as e: + if isinstance(e, (APIConnectionError, APIError, ValueError)): + raise + raise APIConnectionError(f"Unexpected error while fetching flight data: {str(e)}") from e diff --git a/fast_flights/integrations/base.py b/fast_flights/integrations/base.py index c7a832a9..83ff0468 100644 --- a/fast_flights/integrations/base.py +++ b/fast_flights/integrations/base.py @@ -1,38 +1,64 @@ +"""Base integration module for flight data providers.""" +import logging import os +from abc import ABC, abstractmethod +from typing import Union, Optional -from abc import ABC -from typing import Union - +from ..exceptions import APIConnectionError, APIError from ..querying import Query +# Set up logging +logger = logging.getLogger(__name__) + try: import dotenv # pip install python-dotenv - dotenv.load_dotenv() - except ModuleNotFoundError: - pass + logger.debug("python-dotenv not installed, skipping .env file loading") class Integration(ABC): + """Abstract base class for flight data integrations. + + This class defines the interface that all flight data integrations must implement. + Subclasses should implement the fetch_html method to retrieve flight data. + """ + + @abstractmethod def fetch_html(self, q: Union[Query, str], /) -> str: """Fetch the flights page HTML from a query. Args: - q: The query. + q: The query string or Query object. + + Returns: + str: The HTML content of the flight search results. + + Raises: + APIConnectionError: If there's an issue connecting to the data source. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. """ - raise NotImplementedError - - -def get_env(k: str, /) -> str: - """(utility) Get environment variable. + raise NotImplementedError("Subclasses must implement this method") - If nothing found, raises an error. +def get_env(k: str, /, default: Optional[str] = None) -> str: + """Get environment variable with optional default value. + + Args: + k: The name of the environment variable. + default: Default value to return if the environment variable is not found. + If not provided, raises an OSError when the variable is not found. + Returns: - str: The value. + str: The value of the environment variable, or the default value if provided. + + Raises: + OSError: If the environment variable is not found and no default is provided. """ - try: - return os.environ[k] - except KeyError: - raise OSError(f"could not find environment variable: {k!r}") + value = os.environ.get(k, default) + if value is None: + error_msg = f"Required environment variable not found: {k!r}" + logger.error(error_msg) + raise OSError(error_msg) + return value diff --git a/fast_flights/integrations/bright_data.py b/fast_flights/integrations/bright_data.py index 88dd0d5e..c15da286 100644 --- a/fast_flights/integrations/bright_data.py +++ b/fast_flights/integrations/bright_data.py @@ -1,19 +1,36 @@ -# original by @Manouchehri -# pr: #64 - +"""BrightData integration for fetching flight data.""" +import logging from typing import Optional, Union + from primp import Client from .base import Integration, get_env from ..querying import Query from ..fetcher import URL +from ..exceptions import APIConnectionError, APIError + +# Set up logging +logger = logging.getLogger(__name__) DEFAULT_API_URL = "https://api.brightdata.com/request" DEFAULT_DATA_SERP_ZONE = "serp_api1" class BrightData(Integration): - __slots__ = ("api_url", "zone") + """BrightData integration for fetching flight data using their API. + + This class provides a way to fetch flight data using BrightData's API. + It requires a valid API key and zone configuration. + + Args: + api_key: BrightData API key. If not provided, will try to get from environment. + api_url: Base URL for the BrightData API. + zone: BrightData zone to use for the requests. + + Raises: + ValueError: If required configuration is missing. + """ + __slots__ = ("api_url", "zone", "client") def __init__( self, @@ -22,22 +39,84 @@ def __init__( api_url: str = DEFAULT_API_URL, zone: str = DEFAULT_DATA_SERP_ZONE, ): + """Initialize the BrightData integration.""" self.api_url = api_url or get_env("BRIGHT_DATA_API_URL") - self.zone = zone + if not self.api_url: + raise ValueError("BrightData API URL is required") + + self.zone = zone or get_env("BRIGHT_DATA_ZONE") + if not self.zone: + raise ValueError("BrightData zone is required") + + api_key = api_key or get_env("BRIGHT_DATA_API_KEY") + if not api_key: + raise ValueError("BrightData API key is required") + + logger.debug("Initializing BrightData integration") self.client = Client( - headers={ - "Authorization": "Bearer " + (api_key or get_env("BRIGHT_DATA_API_KEY")) - } + headers={"Authorization": f"Bearer {api_key}"}, + timeout=30, # 30 seconds timeout ) def fetch_html(self, q: Union[Query, str], /) -> str: - if isinstance(q, str): - res = self.client.post( - self.api_url, json={"url": URL + "?q=" + q, "zone": self.zone} - ) - else: - res = self.client.post( - self.api_url, json={"url": q.url(), "zone": self.zone} + """Fetch flight data HTML using BrightData API. + + Args: + q: The query string or Query object. + + Returns: + str: The HTML content of the flight search results. + + Raises: + APIConnectionError: If there's an issue connecting to BrightData API. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. + """ + if not q: + raise ValueError("Query cannot be empty") + + try: + # Prepare the request payload + if isinstance(q, str): + url = f"{URL}?q={q}" + elif isinstance(q, Query): + url = q.url() + else: + raise ValueError("Query must be a string or Query object") + + payload = { + "url": url, + "zone": self.zone, + } + + logger.debug(f"Sending request to BrightData API: {self.api_url}") + logger.debug(f"Request payload: {payload}") + + # Make the API request + response = self.client.post( + self.api_url, + json=payload, + headers={"Content-Type": "application/json"}, ) - - return res.text + + # Check for HTTP errors + if not response.ok: + error_msg = f"BrightData API returned status code {response.status_code}" + logger.error(error_msg) + raise APIError(error_msg) + + # Check for empty response + if not response.text: + error_msg = "Received empty response from BrightData API" + logger.error(error_msg) + raise APIError(error_msg) + + return response.text + + except Exception as e: + if isinstance(e, (APIConnectionError, APIError, ValueError)): + raise + + error_msg = f"Failed to fetch data from BrightData: {str(e)}" + logger.error(error_msg, exc_info=True) + raise APIConnectionError(error_msg) from e diff --git a/fast_flights/querying.py b/fast_flights/querying.py index defeb2d2..f66a6498 100644 --- a/fast_flights/querying.py +++ b/fast_flights/querying.py @@ -1,8 +1,20 @@ +import re from base64 import b64encode from datetime import datetime as Datetime from dataclasses import dataclass from typing import Literal, Optional, Union +from .exceptions import ( + PassengerError, + FlightQueryError, + ValidationError +) +from .validation import ( + validate_passengers, + validate_flight_query, + validate_currency +) + from .types import Currency, Language, SeatType, TripType from .pb.flights_pb2 import Airport, Info, Passenger, Seat, FlightData, Trip @@ -59,13 +71,61 @@ def __repr__(self) -> str: @dataclass class FlightQuery: + """Represents a flight search query. + + Args: + date: Departure date as a string in YYYY-MM-DD format or a datetime object. + from_airport: IATA code of the departure airport (e.g., 'JFK', 'LAX'). + to_airport: IATA code of the arrival airport (e.g., 'SFO', 'LHR'). + max_stops: Maximum number of stops allowed (None for any number of stops). + airlines: Optional list of airline IATA codes to filter by. + + Raises: + FlightQueryError: If any of the query parameters are invalid. + """ date: Union[str, Datetime] from_airport: str to_airport: str max_stops: Optional[int] = None airlines: Optional[list[str]] = None + def __post_init__(self): + """Validate the flight query parameters after initialization.""" + # Convert to strings in case we get non-string inputs + self.from_airport = str(self.from_airport).strip().upper() + self.to_airport = str(self.to_airport).strip().upper() + + # Validate the flight query + try: + validate_flight_query(self.from_airport, self.to_airport, self.date, self.max_stops) + + # Validate airlines if provided + if self.airlines is not None: + if not isinstance(self.airlines, (list, tuple)) or not all( + isinstance(airline, str) and len(airline) == 2 + for airline in self.airlines + ): + raise FlightQueryError( + "airlines must be a list of 2-letter IATA airline codes" + ) + + # Clean and validate each airline code + self.airlines = [airline.strip().upper() for airline in self.airlines] + for airline in self.airlines: + if not re.match(r'^[A-Z]{2}$', airline): + raise FlightQueryError( + f"Invalid airline code: {airline}. Must be 2 uppercase letters" + ) + + except (ValidationError, ValueError) as e: + raise FlightQueryError(str(e)) from e + def pb(self) -> FlightData: + """Convert this query to a protobuf FlightData message. + + Returns: + FlightData: The protobuf message representing this query. + """ if isinstance(self.date, str): date = self.date else: @@ -75,8 +135,8 @@ def pb(self) -> FlightData: date=date, from_airport=Airport(airport=self.from_airport), to_airport=Airport(airport=self.to_airport), - max_stops=self.max_stops, - airlines=self.airlines, + max_stops=self.max_stops if self.max_stops is not None else None, + airlines=self.airlines if self.airlines else None, ) def _setmaxstops(self, m: Optional[int] = None) -> "FlightQuery": @@ -87,6 +147,17 @@ def _setmaxstops(self, m: Optional[int] = None) -> "FlightQuery": class Passengers: + """Represents a group of passengers for a flight. + + Args: + adults: Number of adults (16+ years). At least one adult is required when traveling with children or infants. + children: Number of children (2-15 years). + infants_in_seat: Number of infants (under 2 years) with their own seat. + infants_on_lap: Number of infants (under 2 years) on an adult's lap. + + Raises: + PassengerError: If the passenger configuration is invalid. + """ def __init__( self, *, @@ -95,13 +166,25 @@ def __init__( infants_in_seat: int = 0, infants_on_lap: int = 0, ): - assert ( - sum((adults, children, infants_in_seat, infants_on_lap)) <= 9 - ), "Too many passengers (> 9)" - assert ( - infants_on_lap <= adults - ), "Must have at least one adult per infant on lap" - + # Convert to integers in case we get strings + try: + adults = int(adults) + children = int(children) + infants_in_seat = int(infants_in_seat) + infants_on_lap = int(infants_on_lap) + except (TypeError, ValueError) as e: + raise PassengerError("Passenger counts must be integers") from e + + # Validate passenger counts are non-negative + if any(count < 0 for count in (adults, children, infants_in_seat, infants_on_lap)): + raise PassengerError("Passenger counts cannot be negative") + + # Validate passenger configuration + try: + validate_passengers(adults, children, infants_in_seat, infants_on_lap) + except PassengerError as e: + raise # Re-raise the validation error with the original message + self.adults = adults self.children = children self.infants_in_seat = infants_in_seat @@ -136,26 +219,71 @@ def create_query( seat: SeatType = "economy", trip: TripType = "one-way", passengers: Passengers = DEFAULT_PASSENGERS, - language: Union[str, Literal[""], Language] = "", - currency: Union[str, Literal[""], Currency] = "", + language: Union[str, Literal[""], Language] = "en-US", + currency: Union[str, Literal[""], Currency] = "USD", max_stops: Optional[int] = None, ) -> Query: - """Create a query. - + """Create a query for flight search. + Args: - flights: The flight queries. - seat: Desired seat type. - trip: Trip type. - passengers: Passengers. - language: Set the language. Use `""` (blank str) to let Google decide. - currency: Set the currency. Use `""` (blank str) to let Google decide. - max_stops (optional): Set the maximum stops for every flight query, if present. + flights: List of FlightQuery objects representing the flight segments. + seat: Seat type (e.g., 'economy', 'business'). Defaults to 'economy'. + trip: Type of trip ('one-way', 'round-trip', 'multi-city'). Defaults to 'one-way'. + passengers: Passengers configuration. Defaults to 1 adult. + language: Language code (e.g., 'en', 'es'). Empty string lets Google decide. + currency: Currency code (e.g., 'USD', 'EUR'). Empty string lets Google decide. + max_stops: Maximum number of stops allowed for all flights. Overrides individual flight settings. + + Returns: + Query: A configured Query object ready for flight search. + + Raises: + ValueError: If any of the input parameters are invalid. + FlightQueryError: If there are issues with the flight queries. """ - return Query( - flight_data=[flight._setmaxstops(max_stops).pb() for flight in flights], - seat=SEAT_LOOKUP[seat], - trip=TRIP_LOOKUP[trip], - passengers=passengers.pb(), - language=language, - currency=currency, - ) + # Validate inputs + if not isinstance(flights, (list, tuple)) or not flights: + raise ValueError("At least one flight segment is required") + + if not all(isinstance(f, FlightQuery) for f in flights): + raise ValueError("All flight segments must be FlightQuery instances") + + if seat not in SEAT_LOOKUP: + valid_seats = ", ".join(f"'{s}'" for s in SEAT_LOOKUP.keys()) + raise ValueError(f"Invalid seat type: '{seat}'. Must be one of: {valid_seats}") + + if trip not in TRIP_LOOKUP: + valid_trips = ", ".join(f"'{t}'" for t in TRIP_LOOKUP.keys()) + raise ValueError(f"Invalid trip type: '{trip}'. Must be one of: {valid_trips}") + + if not isinstance(passengers, Passengers): + raise ValueError("passengers must be an instance of Passengers") + + # Validate language and currency if provided + if language and not isinstance(language, str) and not isinstance(language, Language): + raise ValueError("language must be a string or Language enum value") + + try: + if currency: + validate_currency(str(currency)) + except ValueError as e: + raise ValueError(f"Invalid currency: {e}") from e + + # Apply max_stops to all flights if specified + if max_stops is not None: + if not isinstance(max_stops, int) or max_stops < 0: + raise ValueError("max_stops must be a non-negative integer") + flights = [flight._setmaxstops(max_stops) for flight in flights] + + # Create the query + try: + return Query( + flight_data=[flight.pb() for flight in flights], + seat=SEAT_LOOKUP[seat], + trip=TRIP_LOOKUP[trip], + passengers=passengers.pb(), + language=str(language) if language else "", + currency=str(currency) if currency else "", + ) + except Exception as e: + raise FlightQueryError(f"Failed to create query: {str(e)}") from e diff --git a/fast_flights/validation.py b/fast_flights/validation.py new file mode 100644 index 00000000..7ff04144 --- /dev/null +++ b/fast_flights/validation.py @@ -0,0 +1,132 @@ +"""Input validation utilities for fast_flights package.""" +import re +from datetime import datetime +from typing import Optional, Union + +from .exceptions import ( + AirportCodeError, + DateFormatError, + PassengerError, + FlightQueryError +) + +def validate_airport_code(code: str) -> None: + """Validate an airport IATA code. + + Args: + code: The airport code to validate. + + Raises: + AirportCodeError: If the airport code is invalid. + """ + if not isinstance(code, str): + raise AirportCodeError(f"Airport code must be a string, got {type(code).__name__}") + + if not re.match(r'^[A-Z]{3}$', code.upper()): + raise AirportCodeError( + f"Invalid airport code: {code}. Must be 3 uppercase letters (IATA code)" + ) + +def validate_date(date_str: Union[str, datetime]) -> None: + """Validate a date string or datetime object. + + Args: + date_str: The date to validate (string in YYYY-MM-DD format or datetime object). + + Raises: + DateFormatError: If the date format is invalid. + """ + if isinstance(date_str, datetime): + return + + try: + datetime.strptime(date_str, '%Y-%m-%d') + except (TypeError, ValueError) as e: + raise DateFormatError( + f"Invalid date format: {date_str}. Expected YYYY-MM-DD or datetime object" + ) from e + +def validate_passengers( + adults: int = 0, + children: int = 0, + infants_in_seat: int = 0, + infants_on_lap: int = 0 +) -> None: + """Validate passenger configuration. + + Args: + adults: Number of adults. + children: Number of children. + infants_in_seat: Number of infants with their own seat. + infants_on_lap: Number of infants on lap. + + Raises: + PassengerError: If passenger configuration is invalid. + """ + total = sum((adults, children, infants_in_seat, infants_on_lap)) + + if total > 9: + raise PassengerError(f"Too many passengers ({total} > 9)") + + if adults < 1 and (children > 0 or infants_in_seat > 0 or infants_on_lap > 0): + raise PassengerError("At least one adult is required when traveling with children or infants") + + if infants_on_lap > adults: + raise PassengerError( + f"Number of infants on lap ({infants_on_lap}) exceeds number of adults ({adults})" + ) + +def validate_flight_query( + from_airport: str, + to_airport: str, + date: Union[str, datetime], + max_stops: Optional[int] = None +) -> None: + """Validate flight query parameters. + + Args: + from_airport: Departure airport code. + to_airport: Arrival airport code. + date: Departure date (string in YYYY-MM-DD format or datetime object). + max_stops: Maximum number of stops allowed. + + Raises: + FlightQueryError: If any parameter is invalid. + """ + # Validate airport codes + try: + validate_airport_code(from_airport) + validate_airport_code(to_airport) + except AirportCodeError as e: + raise FlightQueryError(f"Invalid airport code: {e}") + + # Ensure origin and destination are different + if from_airport.upper() == to_airport.upper(): + raise FlightQueryError("Origin and destination airports cannot be the same") + + # Validate date + try: + validate_date(date) + except DateFormatError as e: + raise FlightQueryError(f"Invalid date: {e}") + + # Validate max_stops if provided + if max_stops is not None and (not isinstance(max_stops, int) or max_stops < 0): + raise FlightQueryError("max_stops must be a non-negative integer") + +def validate_currency(currency: str) -> None: + """Validate currency code. + + Args: + currency: The currency code to validate (3-letter ISO code). + + Raises: + ValueError: If the currency code is invalid. + """ + if not isinstance(currency, str): + raise ValueError(f"Currency must be a string, got {type(currency).__name__}") + + if currency and not re.match(r'^[A-Z]{3}$', currency.upper()): + raise ValueError( + f"Invalid currency code: {currency}. Must be a 3-letter ISO code (e.g., 'USD', 'EUR')" + ) diff --git a/test.py b/test.py index dc5f8323..82ad7b99 100644 --- a/test.py +++ b/test.py @@ -1,10 +1,11 @@ from fast_flights import FlightQuery, Passengers, create_query, get_flights from pprint import pprint +import datetime query = create_query( flights=[ FlightQuery( - date="2025-12-22", + date=(datetime.date.today() + datetime.timedelta(days=30)).isoformat(), from_airport="MYJ", to_airport="TPE", ), @@ -12,7 +13,7 @@ seat="economy", trip="one-way", passengers=Passengers(adults=1), - language="zh-TW", + language="en-US", ) res = get_flights(query) pprint(res) From da8c368ec05e9bccf5e73ae0db634afe56770e4a Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 4 Oct 2025 21:33:22 +0530 Subject: [PATCH 2/5] use primp client correctly --- fast_flights/fetcher.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/fast_flights/fetcher.py b/fast_flights/fetcher.py index 8d3e0221..ad5eddc0 100644 --- a/fast_flights/fetcher.py +++ b/fast_flights/fetcher.py @@ -130,20 +130,32 @@ def fetch_flights_html( logger.debug(f"Sending request to {URL} with params: {params}") res = client.get(URL, params=params) - res.raise_for_status() # Raise HTTPError for bad responses + + # Check status code directly since primp's client might not have raise_for_status + if res.status_code >= 400: + error_msg = f"Flight data API returned status code {res.status_code}" + logger.error(error_msg) + raise APIError(error_msg) if not res.text: - raise APIError("Received empty response from the flight data source") + error_msg = "Received empty response from the flight data source" + logger.error(error_msg) + raise APIError(error_msg) return res.text + except APIError: + # Re-raise APIError as is + raise + except ValueError as e: + # Re-raise ValueError as is + logger.error(f"Invalid query: {str(e)}") + raise except Exception as e: - logger.error(f"Error fetching flight data: {str(e)}") - if hasattr(e, 'response') and hasattr(e.response, 'status_code'): - raise APIError( - f"Flight data API returned status code {e.response.status_code}: {str(e)}" - ) from e - raise APIConnectionError(f"Failed to connect to flight data source: {str(e)}") from e + # Handle other exceptions + error_msg = f"Failed to connect to flight data source: {str(e)}" + logger.error(error_msg) + raise APIConnectionError(error_msg) from e else: logger.debug("Using integration for fetching flight data") From 24cef0c91bb25e5a9ac923360ce6ea988db1d511 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 4 Oct 2025 23:03:02 +0530 Subject: [PATCH 3/5] simplify validation, normalization, better error handling and messages --- fast_flights/querying.py | 138 ++++++++++++++++++------------------- fast_flights/validation.py | 105 ++++++++++++++++++---------- 2 files changed, 137 insertions(+), 106 deletions(-) diff --git a/fast_flights/querying.py b/fast_flights/querying.py index f66a6498..97ca6e20 100644 --- a/fast_flights/querying.py +++ b/fast_flights/querying.py @@ -90,35 +90,42 @@ class FlightQuery: airlines: Optional[list[str]] = None def __post_init__(self): - """Validate the flight query parameters after initialization.""" - # Convert to strings in case we get non-string inputs - self.from_airport = str(self.from_airport).strip().upper() - self.to_airport = str(self.to_airport).strip().upper() + """Validate and normalize flight query parameters after initialization. - # Validate the flight query - try: - validate_flight_query(self.from_airport, self.to_airport, self.date, self.max_stops) + Raises: + AirportCodeError: If airport codes are invalid + DateFormatError: If date format is invalid + FlightQueryError: For other validation errors + """ + # Validate and normalize flight query parameters + self.from_airport, self.to_airport = validate_flight_query( + self.from_airport, + self.to_airport, + self.date, + self.max_stops + ) + + # Validate airlines if provided + if self.airlines is not None: + if not isinstance(self.airlines, (list, tuple)): + raise FlightQueryError("Airlines must be a list or tuple") - # Validate airlines if provided - if self.airlines is not None: - if not isinstance(self.airlines, (list, tuple)) or not all( - isinstance(airline, str) and len(airline) == 2 - for airline in self.airlines - ): + airlines = [] + for i, airline in enumerate(self.airlines): + if not isinstance(airline, str): + raise FlightQueryError( + f"Airline code at index {i} must be a string, got {type(airline).__name__}" + ) + + airline = airline.strip().upper() + if not re.match(r'^[A-Z]{2}$', airline): raise FlightQueryError( - "airlines must be a list of 2-letter IATA airline codes" + f"Invalid airline code: '{airline}'. Must be 2 uppercase letters" ) - # Clean and validate each airline code - self.airlines = [airline.strip().upper() for airline in self.airlines] - for airline in self.airlines: - if not re.match(r'^[A-Z]{2}$', airline): - raise FlightQueryError( - f"Invalid airline code: {airline}. Must be 2 uppercase letters" - ) - - except (ValidationError, ValueError) as e: - raise FlightQueryError(str(e)) from e + airlines.append(airline) + + self.airlines = airlines def pb(self) -> FlightData: """Convert this query to a protobuf FlightData message. @@ -161,34 +168,24 @@ class Passengers: def __init__( self, *, - adults: int = 0, + adults: int = 1, # Changed default to 1 to match DEFAULT_PASSENGERS children: int = 0, infants_in_seat: int = 0, infants_on_lap: int = 0, ): - # Convert to integers in case we get strings - try: - adults = int(adults) - children = int(children) - infants_in_seat = int(infants_in_seat) - infants_on_lap = int(infants_on_lap) - except (TypeError, ValueError) as e: - raise PassengerError("Passenger counts must be integers") from e - - # Validate passenger counts are non-negative - if any(count < 0 for count in (adults, children, infants_in_seat, infants_on_lap)): - raise PassengerError("Passenger counts cannot be negative") + # Convert to integers and validate + self.adults = int(adults) if adults is not None else 1 + self.children = int(children) if children is not None else 0 + self.infants_in_seat = int(infants_in_seat) if infants_in_seat is not None else 0 + self.infants_on_lap = int(infants_on_lap) if infants_on_lap is not None else 0 # Validate passenger configuration - try: - validate_passengers(adults, children, infants_in_seat, infants_on_lap) - except PassengerError as e: - raise # Re-raise the validation error with the original message - - self.adults = adults - self.children = children - self.infants_in_seat = infants_in_seat - self.infants_on_lap = infants_on_lap + validate_passengers( + adults=self.adults, + children=self.children, + infants_in_seat=self.infants_in_seat, + infants_on_lap=self.infants_on_lap + ) def pb(self) -> list[Passenger]: return [ @@ -218,11 +215,14 @@ def create_query( flights: list[FlightQuery], seat: SeatType = "economy", trip: TripType = "one-way", - passengers: Passengers = DEFAULT_PASSENGERS, + passengers: Optional[Passengers] = None, language: Union[str, Literal[""], Language] = "en-US", currency: Union[str, Literal[""], Currency] = "USD", max_stops: Optional[int] = None, ) -> Query: + # Use default passengers if not provided + if passengers is None: + passengers = DEFAULT_PASSENGERS """Create a query for flight search. Args: @@ -240,34 +240,37 @@ def create_query( Raises: ValueError: If any of the input parameters are invalid. FlightQueryError: If there are issues with the flight queries. + AirportCodeError: If any airport code is invalid. + DateFormatError: If any date format is invalid. + PassengerError: If passenger configuration is invalid. """ - # Validate inputs + # Validate required inputs if not isinstance(flights, (list, tuple)) or not flights: raise ValueError("At least one flight segment is required") if not all(isinstance(f, FlightQuery) for f in flights): raise ValueError("All flight segments must be FlightQuery instances") + # Validate seat type if seat not in SEAT_LOOKUP: - valid_seats = ", ".join(f"'{s}'" for s in SEAT_LOOKUP.keys()) + valid_seats = ", ".join(f"'{s}'" for s in SEAT_LOOKUP) raise ValueError(f"Invalid seat type: '{seat}'. Must be one of: {valid_seats}") + # Validate trip type if trip not in TRIP_LOOKUP: - valid_trips = ", ".join(f"'{t}'" for t in TRIP_LOOKUP.keys()) + valid_trips = ", ".join(f"'{t}'" for t in TRIP_LOOKUP) raise ValueError(f"Invalid trip type: '{trip}'. Must be one of: {valid_trips}") + # Validate passengers if not isinstance(passengers, Passengers): raise ValueError("passengers must be an instance of Passengers") - # Validate language and currency if provided - if language and not isinstance(language, str) and not isinstance(language, Language): + # Validate language + if language and not isinstance(language, (str, Language)): raise ValueError("language must be a string or Language enum value") - try: - if currency: - validate_currency(str(currency)) - except ValueError as e: - raise ValueError(f"Invalid currency: {e}") from e + # Validate and normalize currency + currency_code = validate_currency(currency) if currency else "" # Apply max_stops to all flights if specified if max_stops is not None: @@ -275,15 +278,12 @@ def create_query( raise ValueError("max_stops must be a non-negative integer") flights = [flight._setmaxstops(max_stops) for flight in flights] - # Create the query - try: - return Query( - flight_data=[flight.pb() for flight in flights], - seat=SEAT_LOOKUP[seat], - trip=TRIP_LOOKUP[trip], - passengers=passengers.pb(), - language=str(language) if language else "", - currency=str(currency) if currency else "", - ) - except Exception as e: - raise FlightQueryError(f"Failed to create query: {str(e)}") from e + # Create and return the query + return Query( + flight_data=[flight.pb() for flight in flights], + seat=SEAT_LOOKUP[seat], + trip=TRIP_LOOKUP[trip], + passengers=passengers.pb(), + language=str(language) if language else "", + currency=currency_code, + ) diff --git a/fast_flights/validation.py b/fast_flights/validation.py index 7ff04144..704ce1e4 100644 --- a/fast_flights/validation.py +++ b/fast_flights/validation.py @@ -10,22 +10,30 @@ FlightQueryError ) -def validate_airport_code(code: str) -> None: - """Validate an airport IATA code. +def validate_airport_code(code: str) -> str: + """Validate and normalize an airport IATA code. Args: code: The airport code to validate. + Returns: + str: The normalized uppercase airport code. + Raises: - AirportCodeError: If the airport code is invalid. + TypeError: If code is not a string. + AirportCodeError: If the airport code format is invalid. """ if not isinstance(code, str): - raise AirportCodeError(f"Airport code must be a string, got {type(code).__name__}") + raise TypeError(f"Airport code must be a string, got {type(code).__name__}") - if not re.match(r'^[A-Z]{3}$', code.upper()): - raise AirportCodeError( - f"Invalid airport code: {code}. Must be 3 uppercase letters (IATA code)" - ) + code = code.strip().upper() + if not code: + raise AirportCodeError("Airport code cannot be empty") + + if not re.match(r'^[A-Z]{3}$', code): + raise AirportCodeError(f"Invalid IATA code: '{code}'. Must be 3 uppercase letters") + + return code def validate_date(date_str: Union[str, datetime]) -> None: """Validate a date string or datetime object. @@ -47,7 +55,7 @@ def validate_date(date_str: Union[str, datetime]) -> None: ) from e def validate_passengers( - adults: int = 0, + adults: int = 1, children: int = 0, infants_in_seat: int = 0, infants_on_lap: int = 0 @@ -55,21 +63,35 @@ def validate_passengers( """Validate passenger configuration. Args: - adults: Number of adults. - children: Number of children. - infants_in_seat: Number of infants with their own seat. - infants_on_lap: Number of infants on lap. + adults: Number of adults (must be >= 1). + children: Number of children (must be >= 0). + infants_in_seat: Number of infants with their own seat (must be >= 0). + infants_on_lap: Number of infants on lap (must be >= 0). Raises: PassengerError: If passenger configuration is invalid. + TypeError: If any count is not an integer. + ValueError: If any count is negative. """ - total = sum((adults, children, infants_in_seat, infants_on_lap)) + # Validate input types and non-negative values + for count, name in [ + (adults, "adults"), + (children, "children"), + (infants_in_seat, "infants_in_seat"), + (infants_on_lap, "infants_on_lap") + ]: + if not isinstance(count, int): + raise TypeError(f"{name} must be an integer, got {type(count).__name__}") + if count < 0: + raise ValueError(f"{name} cannot be negative") - if total > 9: - raise PassengerError(f"Too many passengers ({total} > 9)") + total = adults + children + infants_in_seat + infants_on_lap + + if adults < 1 and total > 0: + raise PassengerError("At least one adult is required when traveling with passengers") - if adults < 1 and (children > 0 or infants_in_seat > 0 or infants_on_lap > 0): - raise PassengerError("At least one adult is required when traveling with children or infants") + if total > 9: + raise PassengerError(f"Maximum of 9 passengers allowed, got {total}") if infants_on_lap > adults: raise PassengerError( @@ -81,8 +103,8 @@ def validate_flight_query( to_airport: str, date: Union[str, datetime], max_stops: Optional[int] = None -) -> None: - """Validate flight query parameters. +) -> tuple[str, str]: + """Validate and normalize flight query parameters. Args: from_airport: Departure airport code. @@ -90,43 +112,52 @@ def validate_flight_query( date: Departure date (string in YYYY-MM-DD format or datetime object). max_stops: Maximum number of stops allowed. + Returns: + tuple: Normalized (from_airport, to_airport) codes. + Raises: FlightQueryError: If any parameter is invalid. """ - # Validate airport codes - try: - validate_airport_code(from_airport) - validate_airport_code(to_airport) - except AirportCodeError as e: - raise FlightQueryError(f"Invalid airport code: {e}") + # Validate and normalize airport codes + from_code = validate_airport_code(from_airport) + to_code = validate_airport_code(to_airport) # Ensure origin and destination are different - if from_airport.upper() == to_airport.upper(): + if from_code == to_code: raise FlightQueryError("Origin and destination airports cannot be the same") # Validate date - try: - validate_date(date) - except DateFormatError as e: - raise FlightQueryError(f"Invalid date: {e}") + validate_date(date) # Validate max_stops if provided if max_stops is not None and (not isinstance(max_stops, int) or max_stops < 0): raise FlightQueryError("max_stops must be a non-negative integer") + + return from_code, to_code -def validate_currency(currency: str) -> None: - """Validate currency code. +def validate_currency(currency: str) -> str: + """Validate and normalize a currency code. Args: currency: The currency code to validate (3-letter ISO code). + Returns: + str: The normalized uppercase currency code. + Raises: - ValueError: If the currency code is invalid. + TypeError: If currency is not a string. + ValueError: If the currency code format is invalid. """ if not isinstance(currency, str): - raise ValueError(f"Currency must be a string, got {type(currency).__name__}") + raise TypeError(f"Currency must be a string, got {type(currency).__name__}") - if currency and not re.match(r'^[A-Z]{3}$', currency.upper()): + currency = currency.strip().upper() + if not currency: + raise ValueError("Currency code cannot be empty") + + if not re.match(r'^[A-Z]{3}$', currency): raise ValueError( - f"Invalid currency code: {currency}. Must be a 3-letter ISO code (e.g., 'USD', 'EUR')" + f"Invalid currency code: '{currency}'. Must be 3 uppercase letters (e.g., 'USD', 'EUR')" ) + + return currency From 99b9b2bc77da7716c85500fda134745824e07a78 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sun, 5 Oct 2025 10:56:04 +0530 Subject: [PATCH 4/5] refactor --- fast_flights/__init__.py | 47 +--- fast_flights/querying.py | 277 +++++++++++++++-------- fast_flights/validation.py | 435 ++++++++++++++++++++++++++++++------- pyproject.toml | 1 + tests/test_validations.py | 143 ++++++++++++ 5 files changed, 698 insertions(+), 205 deletions(-) create mode 100644 tests/test_validations.py diff --git a/fast_flights/__init__.py b/fast_flights/__init__.py index 7d295b04..533f759f 100644 --- a/fast_flights/__init__.py +++ b/fast_flights/__init__.py @@ -1,34 +1,14 @@ from . import integrations -from .exceptions import ( - FastFlightsError, - ValidationError, - AirportCodeError, - DateFormatError, - PassengerError, - FlightQueryError, - APIConnectionError, - APIError, -) -from .validation import ( - validate_airport_code, - validate_date, - validate_passengers, - validate_flight_query, - validate_currency, -) -from .querying import ( - FlightQuery, - Query, - Passengers, - create_query, - create_query as create_filter, # alias -) +from .exceptions import FastFlightsError, APIError +from .querying import FlightQuery, Passengers, create_query from .fetcher import get_flights, fetch_flights_html +# Create alias for backward compatibility +create_filter = create_query + __all__ = [ # Core functionality "FlightQuery", - "Query", "Passengers", "create_query", "create_filter", @@ -36,20 +16,7 @@ "fetch_flights_html", "integrations", - # Exceptions + # Public exceptions "FastFlightsError", - "ValidationError", - "AirportCodeError", - "DateFormatError", - "PassengerError", - "FlightQueryError", - "APIConnectionError", - "APIError", - - # Validation utilities - "validate_airport_code", - "validate_date", - "validate_passengers", - "validate_flight_query", - "validate_currency", + "APIError" ] diff --git a/fast_flights/querying.py b/fast_flights/querying.py index 97ca6e20..557db392 100644 --- a/fast_flights/querying.py +++ b/fast_flights/querying.py @@ -2,16 +2,22 @@ from base64 import b64encode from datetime import datetime as Datetime from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Literal, Optional, TypeVar, Union, get_args, get_origin from .exceptions import ( PassengerError, FlightQueryError, - ValidationError ) from .validation import ( validate_passengers, validate_flight_query, + validate_seat_type, + validate_trip_type, + validate_flights_list, + validate_airlines, + validate_max_stops, + validate_and_normalize_date, + validate_language, validate_currency ) @@ -29,8 +35,43 @@ class Query: passengers: list[Passenger] language: str currency: str + + def __str__(self) -> str: + """Return a human-readable string representation of the query.""" + flight_info = [] + for i, flight in enumerate(self.flight_data, 1): + flight_info.append(f" Flight {i}:") + flight_info.append(f" From: {flight.from_airport.airport if flight.HasField('from_airport') else 'N/A'}") + flight_info.append(f" To: {flight.to_airport.airport if flight.HasField('to_airport') else 'N/A'}") + flight_info.append(f" Date: {flight.date}") + + if flight.airlines: + flight_info.append(f" Airlines: {', '.join(flight.airlines)}") + + if flight.HasField('max_stops'): + flight_info.append(f" Max Stops: {flight.max_stops}") + + # Count passenger types + passenger_counts = {} + for p in self.passengers: + if p == Passenger.ADULT: + passenger_counts['Adults'] = passenger_counts.get('Adults', 0) + 1 + elif p == Passenger.CHILD: + passenger_counts['Children'] = passenger_counts.get('Children', 0) + 1 + elif p in (Passenger.INFANT_IN_SEAT, Passenger.INFANT_ON_LAP): + passenger_counts['Infants'] = passenger_counts.get('Infants', 0) + 1 + + return ( + f"Query Details:\n" + f"Seat Class: {self.seat}\n" + f"Trip Type: {self.trip}\n" + f"Passengers: {', '.join(f'{count} {type_}' for type_, count in passenger_counts.items() if count > 0)}\n" + f"Language: {self.language or 'Default'}\n" + f"Currency: {self.currency or 'Default'}\n" + f"Flights:\n" + '\n'.join(flight_info) + ) - def pb(self) -> Info: + def to_proto(self) -> Info: """(internal) Protobuf data. (`Info`)""" return Info( data=self.flight_data, @@ -41,7 +82,7 @@ def pb(self) -> Info: def to_bytes(self) -> bytes: """Convert this query to bytes.""" - return self.pb().SerializeToString() + return self.to_proto().SerializeToString() def to_str(self) -> str: """Convert this query to a string.""" @@ -97,49 +138,36 @@ def __post_init__(self): DateFormatError: If date format is invalid FlightQueryError: For other validation errors """ + # Validate and normalize date first + self._normalized_date = validate_and_normalize_date(self.date) + # Validate and normalize flight query parameters self.from_airport, self.to_airport = validate_flight_query( self.from_airport, self.to_airport, - self.date, + self._normalized_date, self.max_stops ) - # Validate airlines if provided + # Validate and normalize airlines if provided if self.airlines is not None: - if not isinstance(self.airlines, (list, tuple)): - raise FlightQueryError("Airlines must be a list or tuple") - - airlines = [] - for i, airline in enumerate(self.airlines): - if not isinstance(airline, str): - raise FlightQueryError( - f"Airline code at index {i} must be a string, got {type(airline).__name__}" - ) - - airline = airline.strip().upper() - if not re.match(r'^[A-Z]{2}$', airline): - raise FlightQueryError( - f"Invalid airline code: '{airline}'. Must be 2 uppercase letters" - ) - - airlines.append(airline) - - self.airlines = airlines + try: + self.airlines = validate_airlines(self.airlines) + except ValueError as e: + raise FlightQueryError(str(e)) from e - def pb(self) -> FlightData: + def to_proto(self) -> FlightData: """Convert this query to a protobuf FlightData message. Returns: FlightData: The protobuf message representing this query. + + Note: + All validations should be done in __post_init__. + This method should only convert already validated data to protobuf. """ - if isinstance(self.date, str): - date = self.date - else: - date = self.date.strftime("%Y-%m-%d") - return FlightData( - date=date, + date=self._normalized_date, from_airport=Airport(airport=self.from_airport), to_airport=Airport(airport=self.to_airport), max_stops=self.max_stops if self.max_stops is not None else None, @@ -157,7 +185,7 @@ class Passengers: """Represents a group of passengers for a flight. Args: - adults: Number of adults (16+ years). At least one adult is required when traveling with children or infants. + adults: Number of adults (16+ years). At least one adult is required. children: Number of children (2-15 years). infants_in_seat: Number of infants (under 2 years) with their own seat. infants_on_lap: Number of infants (under 2 years) on an adult's lap. @@ -165,21 +193,25 @@ class Passengers: Raises: PassengerError: If the passenger configuration is invalid. """ + adults: int + children: int + infants_in_seat: int + infants_on_lap: int + def __init__( self, *, - adults: int = 1, # Changed default to 1 to match DEFAULT_PASSENGERS + adults: int = 1, children: int = 0, infants_in_seat: int = 0, infants_on_lap: int = 0, - ): - # Convert to integers and validate - self.adults = int(adults) if adults is not None else 1 - self.children = int(children) if children is not None else 0 - self.infants_in_seat = int(infants_in_seat) if infants_in_seat is not None else 0 - self.infants_on_lap = int(infants_on_lap) if infants_on_lap is not None else 0 + ) -> None: + """Initialize and validate passenger configuration.""" + self.adults = int(adults) + self.children = int(children) + self.infants_in_seat = int(infants_in_seat) + self.infants_on_lap = int(infants_on_lap) - # Validate passenger configuration validate_passengers( adults=self.adults, children=self.children, @@ -187,7 +219,7 @@ def __init__( infants_on_lap=self.infants_on_lap ) - def pb(self) -> list[Passenger]: + def to_proto(self) -> list[Passenger]: return [ *(Passenger.ADULT for _ in range(self.adults)), *(Passenger.CHILD for _ in range(self.children)), @@ -196,18 +228,80 @@ def pb(self) -> list[Passenger]: ] -DEFAULT_PASSENGERS = Passengers(adults=1) -SEAT_LOOKUP = { - "economy": Seat.ECONOMY, - "premium-economy": Seat.PREMIUM_ECONOMY, - "business": Seat.BUSINESS, - "first": Seat.FIRST, -} -TRIP_LOOKUP = { - "round-trip": Trip.ROUND_TRIP, - "one-way": Trip.ONE_WAY, - "multi-city": Trip.MULTI_CITY, -} +from typing import TypeVar, Type, Any, get_origin, get_args, cast + +T = TypeVar('T') + +def _get_literal_values(tp: Any) -> tuple[str, ...]: + """Extract the literal values from a Literal type. + + Returns: + tuple[str, ...]: The literal values as strings, or empty tuple if not a Literal + """ + origin = get_origin(tp) + if origin is not Literal: + return () + + args = get_args(tp) + if not args: + return () + + # Filter out non-string values and ensure they're strings + return tuple(str(arg) for arg in args if isinstance(arg, (str, int, bool, float))) + +def _create_enum_lookup(enum_type: Union[Type[T], Type[str]], literal_type: Any) -> dict[str, T]: + """Create a lookup dictionary from a Literal type to an enum or string type. + + Args: + enum_type: The target type (e.g., Seat, Trip, or str for direct mappings) + literal_type: The Literal type (e.g., SeatType, TripType, Language, Currency) + + Returns: + dict[str, T]: A mapping from string literals to enum values or strings of type T + """ + # Get the allowed string values from the Literal type + literal_values = _get_literal_values(literal_type) + if not literal_values: + return {} + + # If the target type is str, create a direct mapping (for Language and Currency) + if enum_type is str: + return {value: value for value in literal_values if isinstance(value, str)} + + # For enum types, map string values to enum values + lookup: dict[str, T] = {} + for value in literal_values: + if not isinstance(value, str): + continue + + # Convert string like "premium-economy" to "PREMIUM_ECONOMY" + enum_name = value.upper().replace('-', '_') + try: + enum_value = getattr(enum_type, enum_name) + lookup[value] = enum_value + except AttributeError: + continue # Skip if the enum value doesn't exist + + return lookup + +# Create lookups from Literal types to their corresponding values +SEAT_LOOKUP = _create_enum_lookup(Seat, SeatType) +TRIP_LOOKUP = _create_enum_lookup(Trip, TripType) +LANGUAGE_LOOKUP = _create_enum_lookup(str, Language) +CURRENCY_LOOKUP = _create_enum_lookup(str, Currency) + +# Runtime validation of the lookup tables +def _validate_lookup(lookup: dict, type_name: str, expected_values: set) -> None: + """Validate that all expected values are present in the lookup.""" + missing = expected_values - set(lookup.keys()) + if missing: + raise RuntimeError(f"Missing {type_name} values for: {', '.join(sorted(missing))}") + +# Validate all lookups +_validate_lookup(SEAT_LOOKUP, "Seat", set(get_args(SeatType))) +_validate_lookup(TRIP_LOOKUP, "Trip", set(get_args(TripType))) +_validate_lookup(LANGUAGE_LOOKUP, "Language", set(get_args(Language))) +_validate_lookup(CURRENCY_LOOKUP, "Currency", set(get_args(Currency))) def create_query( @@ -220,70 +314,79 @@ def create_query( currency: Union[str, Literal[""], Currency] = "USD", max_stops: Optional[int] = None, ) -> Query: - # Use default passengers if not provided - if passengers is None: - passengers = DEFAULT_PASSENGERS """Create a query for flight search. Args: flights: List of FlightQuery objects representing the flight segments. + Must contain at least one flight. seat: Seat type (e.g., 'economy', 'business'). Defaults to 'economy'. - trip: Type of trip ('one-way', 'round-trip', 'multi-city'). Defaults to 'one-way'. - passengers: Passengers configuration. Defaults to 1 adult. - language: Language code (e.g., 'en', 'es'). Empty string lets Google decide. + trip: Type of trip ('one-way', 'round-trip', 'multi-city'). + Defaults to 'one-way'. + passengers: Passengers configuration. If None, defaults to 1 adult. + language: Language code (e.g., 'en-US', 'es'). Empty string lets Google decide. currency: Currency code (e.g., 'USD', 'EUR'). Empty string lets Google decide. - max_stops: Maximum number of stops allowed for all flights. Overrides individual flight settings. + max_stops: Maximum number of stops allowed for all flights. + Overrides individual flight settings. Use None for no limit. Returns: Query: A configured Query object ready for flight search. Raises: - ValueError: If any of the input parameters are invalid. FlightQueryError: If there are issues with the flight queries. + ValueError: If any of the input parameters are invalid. AirportCodeError: If any airport code is invalid. DateFormatError: If any date format is invalid. PassengerError: If passenger configuration is invalid. + + Example: + >>> query = create_query( + ... flights=[ + ... FlightQuery( + ... date="2025-12-25", + ... from_airport="JFK", + ... to_airport="LAX" + ... ) + ... ], + ... passengers=Passengers(adults=2, children=1) + ... ) """ - # Validate required inputs - if not isinstance(flights, (list, tuple)) or not flights: - raise ValueError("At least one flight segment is required") - - if not all(isinstance(f, FlightQuery) for f in flights): - raise ValueError("All flight segments must be FlightQuery instances") - - # Validate seat type - if seat not in SEAT_LOOKUP: - valid_seats = ", ".join(f"'{s}'" for s in SEAT_LOOKUP) - raise ValueError(f"Invalid seat type: '{seat}'. Must be one of: {valid_seats}") - - # Validate trip type - if trip not in TRIP_LOOKUP: - valid_trips = ", ".join(f"'{t}'" for t in TRIP_LOOKUP) - raise ValueError(f"Invalid trip type: '{trip}'. Must be one of: {valid_trips}") + # Initialize default passengers if not provided + if passengers is None: + passengers = Passengers() # Validate passengers if not isinstance(passengers, Passengers): raise ValueError("passengers must be an instance of Passengers") # Validate language - if language and not isinstance(language, (str, Language)): - raise ValueError("language must be a string or Language enum value") + language = validate_language(language) if language else "" + language = LANGUAGE_LOOKUP.get(language, "") if language else "" + + # Validate currency + if currency: + currency = validate_currency(currency) + currency_code = CURRENCY_LOOKUP.get(currency, "") if currency else "" + + # Validate flights list + validate_flights_list(flights, FlightQuery) + + # Validate and normalize seat type + seat = validate_seat_type(seat) - # Validate and normalize currency - currency_code = validate_currency(currency) if currency else "" + # Validate and normalize trip type + trip = validate_trip_type(trip) - # Apply max_stops to all flights if specified + # Validate and apply max_stops to all flights if specified if max_stops is not None: - if not isinstance(max_stops, int) or max_stops < 0: - raise ValueError("max_stops must be a non-negative integer") + max_stops = validate_max_stops(max_stops) flights = [flight._setmaxstops(max_stops) for flight in flights] # Create and return the query return Query( - flight_data=[flight.pb() for flight in flights], + flight_data=[flight.to_proto() for flight in flights], seat=SEAT_LOOKUP[seat], trip=TRIP_LOOKUP[trip], - passengers=passengers.pb(), + passengers=passengers.to_proto(), language=str(language) if language else "", currency=currency_code, ) diff --git a/fast_flights/validation.py b/fast_flights/validation.py index 704ce1e4..b0f7fbb5 100644 --- a/fast_flights/validation.py +++ b/fast_flights/validation.py @@ -1,14 +1,61 @@ """Input validation utilities for fast_flights package.""" import re from datetime import datetime -from typing import Optional, Union - +from enum import Enum +from typing import Any, List, Literal, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin from .exceptions import ( AirportCodeError, DateFormatError, PassengerError, FlightQueryError ) +from .types import Language, SeatType, TripType, Currency + +T = TypeVar('T') + +def validate_enum_value(value: Any, enum_type: Type[T], field_name: str) -> T: + """Validate that a value is a valid enum value or literal. + + Args: + value: The value to validate. + enum_type: The enum class or Literal type to validate against. + field_name: The name of the field being validated (for error messages). + + Returns: + The validated value. + + Raises: + FlightQueryError: If the value is not a valid enum or literal value. + """ + # Handle Enum types + if isinstance(enum_type, type) and issubclass(enum_type, Enum): + try: + return enum_type(value) + except ValueError: + valid_values = [e.value for e in enum_type] + raise FlightQueryError( + f"Invalid {field_name}: '{value}'. " + f"Must be one of: {', '.join(repr(v) for v in valid_values)}" + ) from None + + # Handle Literal types + origin = get_origin(enum_type) + if origin is Literal: + valid_values = get_args(enum_type) + if value in valid_values: + return value + raise FlightQueryError( + f"Invalid {field_name}: '{value}'. " + f"Must be one of: {', '.join(repr(v) for v in valid_values)}" + ) from None + + # Handle regular types + if isinstance(value, enum_type): + return value + + raise FlightQueryError( + f"Expected {field_name} to be of type {enum_type.__name__}, got {type(value).__name__}" + ) def validate_airport_code(code: str) -> str: """Validate and normalize an airport IATA code. @@ -20,82 +67,87 @@ def validate_airport_code(code: str) -> str: str: The normalized uppercase airport code. Raises: - TypeError: If code is not a string. - AirportCodeError: If the airport code format is invalid. + AirportCodeError: If the airport code is not a string or has an invalid format. """ if not isinstance(code, str): - raise TypeError(f"Airport code must be a string, got {type(code).__name__}") + raise AirportCodeError(f"Airport code must be a string, got {type(code).__name__}") code = code.strip().upper() if not code: raise AirportCodeError("Airport code cannot be empty") - if not re.match(r'^[A-Z]{3}$', code): raise AirportCodeError(f"Invalid IATA code: '{code}'. Must be 3 uppercase letters") return code -def validate_date(date_str: Union[str, datetime]) -> None: - """Validate a date string or datetime object. - - Args: - date_str: The date to validate (string in YYYY-MM-DD format or datetime object). - - Raises: - DateFormatError: If the date format is invalid. - """ - if isinstance(date_str, datetime): - return - - try: - datetime.strptime(date_str, '%Y-%m-%d') - except (TypeError, ValueError) as e: - raise DateFormatError( - f"Invalid date format: {date_str}. Expected YYYY-MM-DD or datetime object" - ) from e - def validate_passengers( - adults: int = 1, - children: int = 0, - infants_in_seat: int = 0, - infants_on_lap: int = 0 + adults: int, + children: int, + infants_in_seat: int, + infants_on_lap: int, + max_passengers: int = 9, ) -> None: - """Validate passenger configuration. + """Validate passenger configuration for a flight booking. + + Ensures that the passenger configuration follows these rules: + - At least one adult is required + - No more than max_passengers (default: 9) total passengers allowed + - Number of infants on lap cannot exceed number of adults + - All counts must be non-negative integers Args: adults: Number of adults (must be >= 1). children: Number of children (must be >= 0). infants_in_seat: Number of infants with their own seat (must be >= 0). infants_on_lap: Number of infants on lap (must be >= 0). + max_passengers: Maximum allowed total passengers (default: 9). Raises: - PassengerError: If passenger configuration is invalid. - TypeError: If any count is not an integer. - ValueError: If any count is negative. + PassengerError: If passenger configuration is invalid, counts are not integers, + counts are negative, or other validation fails. """ - # Validate input types and non-negative values - for count, name in [ - (adults, "adults"), - (children, "children"), - (infants_in_seat, "infants_in_seat"), - (infants_on_lap, "infants_on_lap") - ]: - if not isinstance(count, int): - raise TypeError(f"{name} must be an integer, got {type(count).__name__}") - if count < 0: - raise ValueError(f"{name} cannot be negative") + # Validate input types + if not all(isinstance(x, int) for x in [adults, children, infants_in_seat, infants_on_lap]): + raise PassengerError("All passenger counts must be integers") + + # Validate at least one adult + if adults < 1: + raise PassengerError("At least one adult is required") + + # Validate non-negative values + if children < 0: + raise PassengerError("Number of children cannot be negative") + if infants_in_seat < 0: + raise PassengerError("Number of infants in seat cannot be negative") + if infants_on_lap < 0: + raise PassengerError("Number of infants on lap cannot be negative") total = adults + children + infants_in_seat + infants_on_lap + + # Validate at least one passenger + if total < 1: + raise PassengerError("At least one passenger is required") + + - if adults < 1 and total > 0: - raise PassengerError("At least one adult is required when traveling with passengers") + # Validate total passengers + if total > max_passengers: + raise PassengerError( + f"Maximum of {max_passengers} passengers allowed, got {total} " + f"({adults} adults, {children} children, {infants_in_seat} infants in seat, " + f"{infants_on_lap} infants on lap)" + ) - if total > 9: - raise PassengerError(f"Maximum of 9 passengers allowed, got {total}") + # Validate infants on lap + if infants_on_lap > 0 and adults < 1: + raise PassengerError( + f"Cannot have {infants_on_lap} infants on lap without at least one adult" + ) - if infants_on_lap > adults: + # Validate infants in seat + if infants_in_seat > 0 and adults < 1 and children < 1: raise PassengerError( - f"Number of infants on lap ({infants_on_lap}) exceeds number of adults ({adults})" + f"Cannot have {infants_in_seat} infants in seat without at least one adult or child" ) def validate_flight_query( @@ -106,58 +158,285 @@ def validate_flight_query( ) -> tuple[str, str]: """Validate and normalize flight query parameters. + Validates that: + - Airport codes are valid IATA codes + - Origin and destination are different + - Date is valid and in the future + - max_stops is a non-negative integer or None + Args: - from_airport: Departure airport code. - to_airport: Arrival airport code. + from_airport: Departure airport code (case-insensitive). + to_airport: Arrival airport code (case-insensitive). date: Departure date (string in YYYY-MM-DD format or datetime object). - max_stops: Maximum number of stops allowed. + max_stops: Maximum number of stops allowed (None for no limit). Returns: - tuple: Normalized (from_airport, to_airport) codes. + tuple: Normalized (from_airport, to_airport) codes in uppercase. Raises: FlightQueryError: If any parameter is invalid. + AirportCodeError: If airport codes are invalid. + DateFormatError: If date format is invalid or in the past. """ # Validate and normalize airport codes - from_code = validate_airport_code(from_airport) - to_code = validate_airport_code(to_airport) + from_airport_norm = validate_airport_code(from_airport) + to_airport_norm = validate_airport_code(to_airport) # Ensure origin and destination are different - if from_code == to_code: - raise FlightQueryError("Origin and destination airports cannot be the same") + if from_airport_norm == to_airport_norm: + raise FlightQueryError( + f"Origin and destination airports cannot be the same: {from_airport_norm}" + ) + + # Validate date is in the future + normalized_date = validate_and_normalize_date(date) + today = datetime.now().date() + if isinstance(normalized_date, str): + normalized_date = datetime.strptime(normalized_date, "%Y-%m-%d").date() + + if normalized_date < today: + raise DateFormatError("Departure date cannot be in the past") + + # Validate max_stops + if max_stops is not None: + if not isinstance(max_stops, int): + raise FlightQueryError("max_stops must be an integer or None") + if max_stops < 0: + raise FlightQueryError("max_stops cannot be negative") + + return from_airport_norm, to_airport_norm + +def _validate_literal(value: Any, literal_type: Type[T], type_name: str) -> T: + """Validate a value against a Literal type or Enum. + + Args: + value: The value to validate. + literal_type: The type to validate against (Literal or Enum). + type_name: Human-readable name of the type for error messages. + + Returns: + The validated value of type T. + + Raises: + FlightQueryError: If the value is not a valid value of the literal type. + """ + # Handle Enum types + if isinstance(literal_type, type) and issubclass(literal_type, Enum): + try: + return literal_type(value) + except ValueError: + valid_values = [e.value for e in literal_type] + raise FlightQueryError( + f"Invalid {type_name}: '{value}'. " + f"Must be one of: {', '.join(repr(v) for v in valid_values)}" + ) from None + + # Handle Literal types + origin = get_origin(literal_type) + if origin is Literal: + valid_values = get_args(literal_type) + if value in valid_values: + return value + raise FlightQueryError( + f"Invalid {type_name}: '{value}'. " + f"Must be one of: {', '.join(repr(v) for v in valid_values)}" + ) + + # If we get here, the type is not supported + raise ValueError(f"Unsupported type for validation: {literal_type}") + + +def validate_seat_type(seat: Union[str, SeatType]) -> SeatType: + """Validate that a seat type is valid. - # Validate date - validate_date(date) + Args: + seat: The seat type to validate. + + Returns: + SeatType: The validated seat type. + + Raises: + FlightQueryError: If the seat type is invalid. + """ + return _validate_literal(seat, SeatType, 'seat type') + + +def validate_trip_type(trip: Union[str, TripType]) -> TripType: + """Validate that a trip type is valid. - # Validate max_stops if provided - if max_stops is not None and (not isinstance(max_stops, int) or max_stops < 0): - raise FlightQueryError("max_stops must be a non-negative integer") + Args: + trip: The trip type to validate. + + Returns: + TripType: The validated trip type. + + Raises: + FlightQueryError: If the trip type is invalid. + """ + return _validate_literal(trip, TripType, 'trip type') + + +def validate_language(language: Union[str, Language, None]) -> str: + """Validate and normalize a language code. - return from_code, to_code + Args: + language: The language code to validate. + + Returns: + str: The normalized language code, or empty string if None. + + Raises: + FlightQueryError: If the language code is invalid. + """ + if language is None: + return "" + language_str = str(language).strip() + if not language_str: + return "" # Empty string is allowed (means use default) + + return _validate_literal(language_str, Language, 'language code') + -def validate_currency(currency: str) -> str: +def validate_currency(currency: Union[str, Currency, None]) -> str: """Validate and normalize a currency code. Args: - currency: The currency code to validate (3-letter ISO code). + currency: The currency code to validate. + + Returns: + str: The normalized currency code, or empty string if None. + + Raises: + FlightQueryError: If the currency code is invalid. + """ + if currency is None: + return "" + + currency_str = str(currency).strip().upper() + if not currency_str: + return "" # Empty string is allowed (means use default) + + return _validate_literal(currency_str, Currency, 'currency code') + + +def validate_flights_list(flights: List[Any], expected_type: Type[T]) -> List[T]: + """Validate that a list contains only instances of the expected type. + flights: The list to validate. + expected_type: The expected type of list elements. + + Returns: + List[T]: The validated list. + + Raises: + FlightQueryError: If the list is empty or contains invalid elements. + """ + if not isinstance(flights, (list, tuple)) or not flights: + raise FlightQueryError("At least one flight segment is required") + + if not all(isinstance(f, expected_type) for f in flights): + type_name = expected_type.__name__ + raise FlightQueryError(f"All flight segments must be {type_name} instances") + + return flights # type: ignore + + +def validate_airline_code(code: str) -> str: + """Validate and normalize an airline IATA code. + + Args: + code: The airline code to validate. Returns: - str: The normalized uppercase currency code. + str: The normalized uppercase airline code. Raises: - TypeError: If currency is not a string. - ValueError: If the currency code format is invalid. + FlightQueryError: If the airline code format is invalid. """ - if not isinstance(currency, str): - raise TypeError(f"Currency must be a string, got {type(currency).__name__}") + if not isinstance(code, str): + raise FlightQueryError(f"Airline code must be a string, got {type(code).__name__}") + + code = code.strip().upper() + if not code: + raise FlightQueryError("Airline code cannot be empty") + + if not re.match(r'^[A-Z]{2}$', code): + raise FlightQueryError( + f"Invalid airline code: '{code}'. Must be 2 uppercase letters" + ) + + return code + + +def validate_and_normalize_date(date: Union[str, datetime]) -> str: + """Validate and normalize a date to YYYY-MM-DD format. - currency = currency.strip().upper() - if not currency: - raise ValueError("Currency code cannot be empty") + Args: + date: The date to validate, either as a string in YYYY-MM-DD format or a datetime object. + + Returns: + str: The normalized date string in YYYY-MM-DD format. - if not re.match(r'^[A-Z]{3}$', currency): - raise ValueError( - f"Invalid currency code: '{currency}'. Must be 3 uppercase letters (e.g., 'USD', 'EUR')" + Raises: + DateFormatError: If the date format is invalid. + """ + if isinstance(date, str): + try: + # Try to parse the string to validate it + datetime.strptime(date, '%Y-%m-%d') + return date + except ValueError as e: + raise DateFormatError( + f"Invalid date format: '{date}'. Expected YYYY-MM-DD" + ) from e + elif isinstance(date, datetime): + return date.strftime('%Y-%m-%d') + else: + raise DateFormatError( + f"Date must be a string or datetime object, got {type(date).__name__}" ) + + +def validate_max_stops(max_stops: Optional[int]) -> Optional[int]: + """Validate the max_stops parameter. + + Args: + max_stops: Maximum number of stops allowed, or None for any number of stops. + + Returns: + Optional[int]: The validated max_stops value. + + Raises: + FlightQueryError: If max_stops is not a non-negative integer. + """ + if max_stops is not None: + if not isinstance(max_stops, int) or max_stops < 0: + raise FlightQueryError("max_stops must be a non-negative integer") + return max_stops + + +def validate_airlines(airlines: Optional[list[str]]) -> Optional[list[str]]: + """Validate a list of airline IATA codes. + + Args: + airlines: List of airline codes to validate. Can be None. + + Returns: + Optional[list[str]]: List of normalized airline codes, or None if input was None. + + Raises: + FlightQueryError: If any airline code is invalid or input is not a list. + """ + if airlines is None: + return None + + if not isinstance(airlines, list): + raise FlightQueryError("Airlines must be a list of 2-letter IATA codes") + + if not airlines: # Empty list is valid (means no airline filter) + return None + + try: + return [validate_airline_code(code) for code in airlines] + except ValueError as e: + raise FlightQueryError(str(e)) from e - return currency diff --git a/pyproject.toml b/pyproject.toml index 003cc656..f7075019 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,4 +47,5 @@ pythonVersion = '3.9' dev = [ "ipykernel>=6.29.5", "pip>=25.0.1", + "pytest>=8.0.0", ] diff --git a/tests/test_validations.py b/tests/test_validations.py new file mode 100644 index 00000000..1719133c --- /dev/null +++ b/tests/test_validations.py @@ -0,0 +1,143 @@ +"""Tests for validation functions.""" +import pytest +from enum import Enum +from typing import Literal +from datetime import datetime, timedelta +from fast_flights.validation import ( + validate_enum_value, + validate_language, + validate_currency, + validate_seat_type, + validate_trip_type, + validate_flight_query, + FlightQueryError, + AirportCodeError, + DateFormatError +) +from fast_flights.types import Language, SeatType, TripType, Currency + + +def test_validate_enum_value_with_enum(): + """Test validation with Python enums.""" + class TestEnum(Enum): + A = "A" + B = "B" + + # Valid enum values + assert validate_enum_value("A", TestEnum, "test_field") == TestEnum.A + assert validate_enum_value(TestEnum.B, TestEnum, "test_field") == TestEnum.B + + # Invalid enum value + with pytest.raises(FlightQueryError) as excinfo: + validate_enum_value("C", TestEnum, "test_field") + assert "Invalid test_field: 'C'" in str(excinfo.value) + + +def test_validate_enum_value_with_literal(): + """Test validation with Literal types.""" + TestLiteral = Literal["X", "Y", "Z"] + + # Valid literal values + assert validate_enum_value("X", TestLiteral, "test_field") == "X" + assert validate_enum_value("Y", TestLiteral, "test_field") == "Y" + + # Invalid literal value + with pytest.raises(FlightQueryError) as excinfo: + validate_enum_value("W", TestLiteral, "test_field") + assert "Invalid test_field: 'W'" in str(excinfo.value) + + +def test_validate_language(): + """Test language code validation.""" + # Valid languages + assert validate_language("en-US") == "en-US" + assert validate_language("fr") == "fr" + assert validate_language("zh-CN") == "zh-CN" + + # Invalid languages + with pytest.raises(FlightQueryError): + validate_language("invalid") + with pytest.raises(FlightQueryError): + validate_language("en-US-extra") + + # None case + assert validate_language(None) == "" + + +def test_validate_currency(): + """Test currency code validation.""" + # Valid currencies + assert validate_currency("USD") == "USD" + assert validate_currency("EUR") == "EUR" + + # Invalid currencies + with pytest.raises(FlightQueryError): + validate_currency("US") + with pytest.raises(FlightQueryError): + validate_currency("USDD") + with pytest.raises(FlightQueryError): + validate_currency("123") + + # None case + + +def test_validate_seat_type(): + """Test seat type validation.""" + # Valid seat types + assert validate_seat_type("economy") == "economy" + assert validate_seat_type("premium-economy") == "premium-economy" + assert validate_seat_type("business") == "business" + assert validate_seat_type("first") == "first" + assert validate_seat_type("economy") == "economy" # Test with string value + + # Invalid seat type + with pytest.raises(FlightQueryError) as excinfo: + validate_seat_type("invalid") + assert "seat type" in str(excinfo.value).lower() + + +def test_validate_trip_type(): + """Test trip type validation.""" + # Valid trip types + assert validate_trip_type("one-way") == "one-way" + assert validate_trip_type("round-trip") == "round-trip" + assert validate_trip_type("multi-city") == "multi-city" + + # Invalid trip types + with pytest.raises(FlightQueryError) as excinfo: + validate_trip_type("invalid") + assert "trip type" in str(excinfo.value).lower() + + with pytest.raises(FlightQueryError): + validate_trip_type(123) + + +def test_validate_flight_query(): + """Test flight query validation.""" + future_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d") + + # Valid query + from_airport, to_airport = validate_flight_query("JFK", "LAX", future_date, 2) + assert from_airport == "JFK" + assert to_airport == "LAX" + + # Same origin and destination + with pytest.raises(FlightQueryError): + validate_flight_query("JFK", "JFK", future_date) + + # Past date + past_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d") + with pytest.raises(DateFormatError): + validate_flight_query("JFK", "LAX", past_date) + + # Invalid airport codes + with pytest.raises(AirportCodeError): + validate_flight_query("JK", "LAX", future_date) + with pytest.raises(AirportCodeError): + validate_flight_query("JFK", "LONG", future_date) + + # Invalid max_stops + with pytest.raises(FlightQueryError): + validate_flight_query("JFK", "LAX", future_date, -1) + with pytest.raises(FlightQueryError): + validate_flight_query("JFK", "LAX", future_date, "two") From 90ebeda5195cc44ebc743592edd3ab8ac8a23635 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sun, 5 Oct 2025 11:05:35 +0530 Subject: [PATCH 5/5] retain original simple init logic --- fast_flights/querying.py | 7 ++----- test.py | 2 ++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/fast_flights/querying.py b/fast_flights/querying.py index 557db392..5539868b 100644 --- a/fast_flights/querying.py +++ b/fast_flights/querying.py @@ -309,7 +309,7 @@ def create_query( flights: list[FlightQuery], seat: SeatType = "economy", trip: TripType = "one-way", - passengers: Optional[Passengers] = None, + passengers: Passengers = Passengers(), language: Union[str, Literal[""], Language] = "en-US", currency: Union[str, Literal[""], Currency] = "USD", max_stops: Optional[int] = None, @@ -350,9 +350,6 @@ def create_query( ... passengers=Passengers(adults=2, children=1) ... ) """ - # Initialize default passengers if not provided - if passengers is None: - passengers = Passengers() # Validate passengers if not isinstance(passengers, Passengers): @@ -387,6 +384,6 @@ def create_query( seat=SEAT_LOOKUP[seat], trip=TRIP_LOOKUP[trip], passengers=passengers.to_proto(), - language=str(language) if language else "", + language=language, currency=currency_code, ) diff --git a/test.py b/test.py index 82ad7b99..a773c4dc 100644 --- a/test.py +++ b/test.py @@ -15,5 +15,7 @@ passengers=Passengers(adults=1), language="en-US", ) + +print(query) res = get_flights(query) pprint(res)