diff --git a/fast_flights/__init__.py b/fast_flights/__init__.py index c496da65..533f759f 100644 --- a/fast_flights/__init__.py +++ b/fast_flights/__init__.py @@ -1,21 +1,22 @@ from . import integrations - -from .querying import ( - FlightQuery, - Query, - Passengers, - create_query, - create_query as create_filter, # alias -) +from .exceptions import FastFlightsError, APIError +from .querying import FlightQuery, Passengers, create_query from .fetcher import get_flights, fetch_flights_html +# Create alias for backward compatibility +create_filter = create_query + __all__ = [ + # Core functionality "FlightQuery", - "Query", "Passengers", "create_query", "create_filter", "get_flights", "fetch_flights_html", "integrations", + + # Public exceptions + "FastFlightsError", + "APIError" ] diff --git a/fast_flights/exceptions.py b/fast_flights/exceptions.py new file mode 100644 index 00000000..417ec0e5 --- /dev/null +++ b/fast_flights/exceptions.py @@ -0,0 +1,33 @@ +"""Custom exceptions for the fast_flights package.""" + +class FastFlightsError(Exception): + """Base exception for all fast_flights exceptions.""" + pass + +class ValidationError(FastFlightsError, ValueError): + """Raised when input validation fails.""" + pass + +class AirportCodeError(ValidationError): + """Raised when an invalid airport code is provided.""" + pass + +class DateFormatError(ValidationError): + """Raised when a date string is in an invalid format.""" + pass + +class PassengerError(ValidationError): + """Raised when there's an issue with passenger configuration.""" + pass + +class FlightQueryError(ValidationError): + """Raised when there's an issue with flight query parameters.""" + pass + +class APIConnectionError(FastFlightsError): + """Raised when there's an issue connecting to the flight data API.""" + pass + +class APIError(FastFlightsError): + """Raised when the flight data API returns an error.""" + pass diff --git a/fast_flights/fetcher.py b/fast_flights/fetcher.py index 
1729fbf3..ad5eddc0 100644 --- a/fast_flights/fetcher.py +++ b/fast_flights/fetcher.py @@ -1,11 +1,16 @@ +import logging from typing import Optional, Union, overload from primp import Client +from .exceptions import APIConnectionError, APIError from .querying import Query from .parser import MetaList, parse from .integrations import Integration +# Set up logging +logger = logging.getLogger(__name__) + URL = "https://www.google.com/travel/flights" @@ -54,11 +59,28 @@ def get_flights( """Get flights. Args: - q: The query. - proxy (str, optional): Proxy. + q: The query string or Query object. + proxy: Optional proxy configuration. + integration: Optional integration to use for fetching data. + + Returns: + MetaList: Parsed flight data. + + Raises: + APIConnectionError: If there's an issue connecting to the flight data source. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. """ - html = fetch_flights_html(q, proxy=proxy, integration=integration) - return parse(html) + try: + logger.debug("Fetching flight data...") + html = fetch_flights_html(q, proxy=proxy, integration=integration) + if not html or not isinstance(html, str): + raise APIError("Received empty or invalid response from the flight data source") + return parse(html) + except Exception as e: + if isinstance(e, (APIConnectionError, APIError, ValueError)): + raise + raise APIConnectionError(f"Failed to fetch flight data: {str(e)}") from e def fetch_flights_html( @@ -68,29 +90,82 @@ def fetch_flights_html( proxy: Optional[str] = None, integration: Optional[Integration] = None, ) -> str: - """Fetch flights and get the **HTML**. + """Fetch flights and get the HTML response. Args: - q: The query. - proxy (str, optional): Proxy. + q: The query string or Query object. + proxy: Optional proxy configuration. + integration: Optional integration to use for fetching data. + + Returns: + str: The HTML content of the flight search results. 
+ + Raises: + APIConnectionError: If there's an issue connecting to the flight data source. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. """ - if integration is None: - client = Client( - impersonate="chrome_133", - impersonate_os="macos", - referer=True, - proxy=proxy, - cookie_store=True, - ) - - if isinstance(q, Query): - params = q.params() + if not q: + raise ValueError("Query cannot be empty") + + try: + if integration is None: + logger.debug("Using default client for fetching flight data") + client = Client( + impersonate="chrome_133", + impersonate_os="macos", + referer=True, + proxy=proxy, + cookie_store=True, + timeout=30, # 30 seconds timeout + ) + + try: + if isinstance(q, Query): + params = q.params() + else: + if not isinstance(q, str): + raise ValueError("Query must be a string or Query object") + params = {"q": q} + + logger.debug(f"Sending request to {URL} with params: {params}") + res = client.get(URL, params=params) + + # Check status code directly since primp's client might not have raise_for_status + if res.status_code >= 400: + error_msg = f"Flight data API returned status code {res.status_code}" + logger.error(error_msg) + raise APIError(error_msg) + + if not res.text: + error_msg = "Received empty response from the flight data source" + logger.error(error_msg) + raise APIError(error_msg) + + return res.text + + except APIError: + # Re-raise APIError as is + raise + except ValueError as e: + # Re-raise ValueError as is + logger.error(f"Invalid query: {str(e)}") + raise + except Exception as e: + # Handle other exceptions + error_msg = f"Failed to connect to flight data source: {str(e)}" + logger.error(error_msg) + raise APIConnectionError(error_msg) from e else: - params = {"q": q} - - res = client.get(URL, params=params) - return res.text - - else: - return integration.fetch_html(q) + logger.debug("Using integration for fetching flight data") + try: + return 
integration.fetch_html(q) + except Exception as e: + logger.error(f"Integration error while fetching flight data: {str(e)}") + raise APIError(f"Integration failed to fetch flight data: {str(e)}") from e + + except Exception as e: + if isinstance(e, (APIConnectionError, APIError, ValueError)): + raise + raise APIConnectionError(f"Unexpected error while fetching flight data: {str(e)}") from e diff --git a/fast_flights/integrations/base.py b/fast_flights/integrations/base.py index c7a832a9..83ff0468 100644 --- a/fast_flights/integrations/base.py +++ b/fast_flights/integrations/base.py @@ -1,38 +1,64 @@ +"""Base integration module for flight data providers.""" +import logging import os +from abc import ABC, abstractmethod +from typing import Union, Optional -from abc import ABC -from typing import Union - +from ..exceptions import APIConnectionError, APIError from ..querying import Query +# Set up logging +logger = logging.getLogger(__name__) + try: import dotenv # pip install python-dotenv - dotenv.load_dotenv() - except ModuleNotFoundError: - pass + logger.debug("python-dotenv not installed, skipping .env file loading") class Integration(ABC): + """Abstract base class for flight data integrations. + + This class defines the interface that all flight data integrations must implement. + Subclasses should implement the fetch_html method to retrieve flight data. + """ + + @abstractmethod def fetch_html(self, q: Union[Query, str], /) -> str: """Fetch the flights page HTML from a query. Args: - q: The query. + q: The query string or Query object. + + Returns: + str: The HTML content of the flight search results. + + Raises: + APIConnectionError: If there's an issue connecting to the data source. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. """ - raise NotImplementedError - - -def get_env(k: str, /) -> str: - """(utility) Get environment variable. 
+ raise NotImplementedError("Subclasses must implement this method") - If nothing found, raises an error. +def get_env(k: str, /, default: Optional[str] = None) -> str: + """Get environment variable with optional default value. + + Args: + k: The name of the environment variable. + default: Default value to return if the environment variable is not found. + If not provided, raises an OSError when the variable is not found. + Returns: - str: The value. + str: The value of the environment variable, or the default value if provided. + + Raises: + OSError: If the environment variable is not found and no default is provided. """ - try: - return os.environ[k] - except KeyError: - raise OSError(f"could not find environment variable: {k!r}") + value = os.environ.get(k, default) + if value is None: + error_msg = f"Required environment variable not found: {k!r}" + logger.error(error_msg) + raise OSError(error_msg) + return value diff --git a/fast_flights/integrations/bright_data.py b/fast_flights/integrations/bright_data.py index 88dd0d5e..c15da286 100644 --- a/fast_flights/integrations/bright_data.py +++ b/fast_flights/integrations/bright_data.py @@ -1,19 +1,36 @@ -# original by @Manouchehri -# pr: #64 - +"""BrightData integration for fetching flight data.""" +import logging from typing import Optional, Union + from primp import Client from .base import Integration, get_env from ..querying import Query from ..fetcher import URL +from ..exceptions import APIConnectionError, APIError + +# Set up logging +logger = logging.getLogger(__name__) DEFAULT_API_URL = "https://api.brightdata.com/request" DEFAULT_DATA_SERP_ZONE = "serp_api1" class BrightData(Integration): - __slots__ = ("api_url", "zone") + """BrightData integration for fetching flight data using their API. + + This class provides a way to fetch flight data using BrightData's API. + It requires a valid API key and zone configuration. + + Args: + api_key: BrightData API key. 
If not provided, will try to get from environment. + api_url: Base URL for the BrightData API. + zone: BrightData zone to use for the requests. + + Raises: + ValueError: If required configuration is missing. + """ + __slots__ = ("api_url", "zone", "client") def __init__( self, @@ -22,22 +39,84 @@ def __init__( api_url: str = DEFAULT_API_URL, zone: str = DEFAULT_DATA_SERP_ZONE, ): + """Initialize the BrightData integration.""" self.api_url = api_url or get_env("BRIGHT_DATA_API_URL") - self.zone = zone + if not self.api_url: + raise ValueError("BrightData API URL is required") + + self.zone = zone or get_env("BRIGHT_DATA_ZONE") + if not self.zone: + raise ValueError("BrightData zone is required") + + api_key = api_key or get_env("BRIGHT_DATA_API_KEY") + if not api_key: + raise ValueError("BrightData API key is required") + + logger.debug("Initializing BrightData integration") self.client = Client( - headers={ - "Authorization": "Bearer " + (api_key or get_env("BRIGHT_DATA_API_KEY")) - } + headers={"Authorization": f"Bearer {api_key}"}, + timeout=30, # 30 seconds timeout ) def fetch_html(self, q: Union[Query, str], /) -> str: - if isinstance(q, str): - res = self.client.post( - self.api_url, json={"url": URL + "?q=" + q, "zone": self.zone} - ) - else: - res = self.client.post( - self.api_url, json={"url": q.url(), "zone": self.zone} + """Fetch flight data HTML using BrightData API. + + Args: + q: The query string or Query object. + + Returns: + str: The HTML content of the flight search results. + + Raises: + APIConnectionError: If there's an issue connecting to BrightData API. + APIError: If the API returns an error or invalid response. + ValueError: If the input query is invalid. 
+ """ + if not q: + raise ValueError("Query cannot be empty") + + try: + # Prepare the request payload + if isinstance(q, str): + url = f"{URL}?q={q}" + elif isinstance(q, Query): + url = q.url() + else: + raise ValueError("Query must be a string or Query object") + + payload = { + "url": url, + "zone": self.zone, + } + + logger.debug(f"Sending request to BrightData API: {self.api_url}") + logger.debug(f"Request payload: {payload}") + + # Make the API request + response = self.client.post( + self.api_url, + json=payload, + headers={"Content-Type": "application/json"}, ) - - return res.text + + # Check for HTTP errors + if not response.ok: + error_msg = f"BrightData API returned status code {response.status_code}" + logger.error(error_msg) + raise APIError(error_msg) + + # Check for empty response + if not response.text: + error_msg = "Received empty response from BrightData API" + logger.error(error_msg) + raise APIError(error_msg) + + return response.text + + except Exception as e: + if isinstance(e, (APIConnectionError, APIError, ValueError)): + raise + + error_msg = f"Failed to fetch data from BrightData: {str(e)}" + logger.error(error_msg, exc_info=True) + raise APIConnectionError(error_msg) from e diff --git a/fast_flights/querying.py b/fast_flights/querying.py index defeb2d2..5539868b 100644 --- a/fast_flights/querying.py +++ b/fast_flights/querying.py @@ -1,7 +1,25 @@ +import re from base64 import b64encode from datetime import datetime as Datetime from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Literal, Optional, TypeVar, Union, get_args, get_origin + +from .exceptions import ( + PassengerError, + FlightQueryError, +) +from .validation import ( + validate_passengers, + validate_flight_query, + validate_seat_type, + validate_trip_type, + validate_flights_list, + validate_airlines, + validate_max_stops, + validate_and_normalize_date, + validate_language, + validate_currency +) from .types import Currency, 
Language, SeatType, TripType from .pb.flights_pb2 import Airport, Info, Passenger, Seat, FlightData, Trip @@ -17,8 +35,43 @@ class Query: passengers: list[Passenger] language: str currency: str + + def __str__(self) -> str: + """Return a human-readable string representation of the query.""" + flight_info = [] + for i, flight in enumerate(self.flight_data, 1): + flight_info.append(f" Flight {i}:") + flight_info.append(f" From: {flight.from_airport.airport if flight.HasField('from_airport') else 'N/A'}") + flight_info.append(f" To: {flight.to_airport.airport if flight.HasField('to_airport') else 'N/A'}") + flight_info.append(f" Date: {flight.date}") + + if flight.airlines: + flight_info.append(f" Airlines: {', '.join(flight.airlines)}") + + if flight.HasField('max_stops'): + flight_info.append(f" Max Stops: {flight.max_stops}") + + # Count passenger types + passenger_counts = {} + for p in self.passengers: + if p == Passenger.ADULT: + passenger_counts['Adults'] = passenger_counts.get('Adults', 0) + 1 + elif p == Passenger.CHILD: + passenger_counts['Children'] = passenger_counts.get('Children', 0) + 1 + elif p in (Passenger.INFANT_IN_SEAT, Passenger.INFANT_ON_LAP): + passenger_counts['Infants'] = passenger_counts.get('Infants', 0) + 1 + + return ( + f"Query Details:\n" + f"Seat Class: {self.seat}\n" + f"Trip Type: {self.trip}\n" + f"Passengers: {', '.join(f'{count} {type_}' for type_, count in passenger_counts.items() if count > 0)}\n" + f"Language: {self.language or 'Default'}\n" + f"Currency: {self.currency or 'Default'}\n" + f"Flights:\n" + '\n'.join(flight_info) + ) - def pb(self) -> Info: + def to_proto(self) -> Info: """(internal) Protobuf data. 
(`Info`)""" return Info( data=self.flight_data, @@ -29,7 +82,7 @@ def pb(self) -> Info: def to_bytes(self) -> bytes: """Convert this query to bytes.""" - return self.pb().SerializeToString() + return self.to_proto().SerializeToString() def to_str(self) -> str: """Convert this query to a string.""" @@ -59,24 +112,66 @@ def __repr__(self) -> str: @dataclass class FlightQuery: + """Represents a flight search query. + + Args: + date: Departure date as a string in YYYY-MM-DD format or a datetime object. + from_airport: IATA code of the departure airport (e.g., 'JFK', 'LAX'). + to_airport: IATA code of the arrival airport (e.g., 'SFO', 'LHR'). + max_stops: Maximum number of stops allowed (None for any number of stops). + airlines: Optional list of airline IATA codes to filter by. + + Raises: + FlightQueryError: If any of the query parameters are invalid. + """ date: Union[str, Datetime] from_airport: str to_airport: str max_stops: Optional[int] = None airlines: Optional[list[str]] = None - def pb(self) -> FlightData: - if isinstance(self.date, str): - date = self.date - else: - date = self.date.strftime("%Y-%m-%d") + def __post_init__(self): + """Validate and normalize flight query parameters after initialization. + + Raises: + AirportCodeError: If airport codes are invalid + DateFormatError: If date format is invalid + FlightQueryError: For other validation errors + """ + # Validate and normalize date first + self._normalized_date = validate_and_normalize_date(self.date) + + # Validate and normalize flight query parameters + self.from_airport, self.to_airport = validate_flight_query( + self.from_airport, + self.to_airport, + self._normalized_date, + self.max_stops + ) + + # Validate and normalize airlines if provided + if self.airlines is not None: + try: + self.airlines = validate_airlines(self.airlines) + except ValueError as e: + raise FlightQueryError(str(e)) from e + def to_proto(self) -> FlightData: + """Convert this query to a protobuf FlightData message. 
+ + Returns: + FlightData: The protobuf message representing this query. + + Note: + All validations should be done in __post_init__. + This method should only convert already validated data to protobuf. + """ return FlightData( - date=date, + date=self._normalized_date, from_airport=Airport(airport=self.from_airport), to_airport=Airport(airport=self.to_airport), - max_stops=self.max_stops, - airlines=self.airlines, + max_stops=self.max_stops if self.max_stops is not None else None, + airlines=self.airlines if self.airlines else None, ) def _setmaxstops(self, m: Optional[int] = None) -> "FlightQuery": @@ -87,27 +182,44 @@ def _setmaxstops(self, m: Optional[int] = None) -> "FlightQuery": class Passengers: + """Represents a group of passengers for a flight. + + Args: + adults: Number of adults (16+ years). At least one adult is required. + children: Number of children (2-15 years). + infants_in_seat: Number of infants (under 2 years) with their own seat. + infants_on_lap: Number of infants (under 2 years) on an adult's lap. + + Raises: + PassengerError: If the passenger configuration is invalid. 
+ """ + adults: int + children: int + infants_in_seat: int + infants_on_lap: int + def __init__( self, *, - adults: int = 0, + adults: int = 1, children: int = 0, infants_in_seat: int = 0, infants_on_lap: int = 0, - ): - assert ( - sum((adults, children, infants_in_seat, infants_on_lap)) <= 9 - ), "Too many passengers (> 9)" - assert ( - infants_on_lap <= adults - ), "Must have at least one adult per infant on lap" - - self.adults = adults - self.children = children - self.infants_in_seat = infants_in_seat - self.infants_on_lap = infants_on_lap - - def pb(self) -> list[Passenger]: + ) -> None: + """Initialize and validate passenger configuration.""" + self.adults = int(adults) + self.children = int(children) + self.infants_in_seat = int(infants_in_seat) + self.infants_on_lap = int(infants_on_lap) + + validate_passengers( + adults=self.adults, + children=self.children, + infants_in_seat=self.infants_in_seat, + infants_on_lap=self.infants_on_lap + ) + + def to_proto(self) -> list[Passenger]: return [ *(Passenger.ADULT for _ in range(self.adults)), *(Passenger.CHILD for _ in range(self.children)), @@ -116,18 +228,80 @@ def pb(self) -> list[Passenger]: ] -DEFAULT_PASSENGERS = Passengers(adults=1) -SEAT_LOOKUP = { - "economy": Seat.ECONOMY, - "premium-economy": Seat.PREMIUM_ECONOMY, - "business": Seat.BUSINESS, - "first": Seat.FIRST, -} -TRIP_LOOKUP = { - "round-trip": Trip.ROUND_TRIP, - "one-way": Trip.ONE_WAY, - "multi-city": Trip.MULTI_CITY, -} +from typing import TypeVar, Type, Any, get_origin, get_args, cast + +T = TypeVar('T') + +def _get_literal_values(tp: Any) -> tuple[str, ...]: + """Extract the literal values from a Literal type. 
+ + Returns: + tuple[str, ...]: The literal values as strings, or empty tuple if not a Literal + """ + origin = get_origin(tp) + if origin is not Literal: + return () + + args = get_args(tp) + if not args: + return () + + # Filter out non-string values and ensure they're strings + return tuple(str(arg) for arg in args if isinstance(arg, (str, int, bool, float))) + +def _create_enum_lookup(enum_type: Union[Type[T], Type[str]], literal_type: Any) -> dict[str, T]: + """Create a lookup dictionary from a Literal type to an enum or string type. + + Args: + enum_type: The target type (e.g., Seat, Trip, or str for direct mappings) + literal_type: The Literal type (e.g., SeatType, TripType, Language, Currency) + + Returns: + dict[str, T]: A mapping from string literals to enum values or strings of type T + """ + # Get the allowed string values from the Literal type + literal_values = _get_literal_values(literal_type) + if not literal_values: + return {} + + # If the target type is str, create a direct mapping (for Language and Currency) + if enum_type is str: + return {value: value for value in literal_values if isinstance(value, str)} + + # For enum types, map string values to enum values + lookup: dict[str, T] = {} + for value in literal_values: + if not isinstance(value, str): + continue + + # Convert string like "premium-economy" to "PREMIUM_ECONOMY" + enum_name = value.upper().replace('-', '_') + try: + enum_value = getattr(enum_type, enum_name) + lookup[value] = enum_value + except AttributeError: + continue # Skip if the enum value doesn't exist + + return lookup + +# Create lookups from Literal types to their corresponding values +SEAT_LOOKUP = _create_enum_lookup(Seat, SeatType) +TRIP_LOOKUP = _create_enum_lookup(Trip, TripType) +LANGUAGE_LOOKUP = _create_enum_lookup(str, Language) +CURRENCY_LOOKUP = _create_enum_lookup(str, Currency) + +# Runtime validation of the lookup tables +def _validate_lookup(lookup: dict, type_name: str, expected_values: set) -> None: + 
"""Validate that all expected values are present in the lookup.""" + missing = expected_values - set(lookup.keys()) + if missing: + raise RuntimeError(f"Missing {type_name} values for: {', '.join(sorted(missing))}") + +# Validate all lookups +_validate_lookup(SEAT_LOOKUP, "Seat", set(get_args(SeatType))) +_validate_lookup(TRIP_LOOKUP, "Trip", set(get_args(TripType))) +_validate_lookup(LANGUAGE_LOOKUP, "Language", set(get_args(Language))) +_validate_lookup(CURRENCY_LOOKUP, "Currency", set(get_args(Currency))) def create_query( @@ -135,27 +309,81 @@ def create_query( flights: list[FlightQuery], seat: SeatType = "economy", trip: TripType = "one-way", - passengers: Passengers = DEFAULT_PASSENGERS, - language: Union[str, Literal[""], Language] = "", - currency: Union[str, Literal[""], Currency] = "", + passengers: Passengers = Passengers(), + language: Union[str, Literal[""], Language] = "en-US", + currency: Union[str, Literal[""], Currency] = "USD", max_stops: Optional[int] = None, ) -> Query: - """Create a query. - + """Create a query for flight search. + Args: - flights: The flight queries. - seat: Desired seat type. - trip: Trip type. - passengers: Passengers. - language: Set the language. Use `""` (blank str) to let Google decide. - currency: Set the currency. Use `""` (blank str) to let Google decide. - max_stops (optional): Set the maximum stops for every flight query, if present. + flights: List of FlightQuery objects representing the flight segments. + Must contain at least one flight. + seat: Seat type (e.g., 'economy', 'business'). Defaults to 'economy'. + trip: Type of trip ('one-way', 'round-trip', 'multi-city'). + Defaults to 'one-way'. + passengers: Passengers configuration. If None, defaults to 1 adult. + language: Language code (e.g., 'en-US', 'es'). Empty string lets Google decide. + currency: Currency code (e.g., 'USD', 'EUR'). Empty string lets Google decide. + max_stops: Maximum number of stops allowed for all flights. 
+ Overrides individual flight settings. Use None for no limit. + + Returns: + Query: A configured Query object ready for flight search. + + Raises: + FlightQueryError: If there are issues with the flight queries. + ValueError: If any of the input parameters are invalid. + AirportCodeError: If any airport code is invalid. + DateFormatError: If any date format is invalid. + PassengerError: If passenger configuration is invalid. + + Example: + >>> query = create_query( + ... flights=[ + ... FlightQuery( + ... date="2025-12-25", + ... from_airport="JFK", + ... to_airport="LAX" + ... ) + ... ], + ... passengers=Passengers(adults=2, children=1) + ... ) """ + + # Validate passengers + if not isinstance(passengers, Passengers): + raise ValueError("passengers must be an instance of Passengers") + + # Validate language + language = validate_language(language) if language else "" + language = LANGUAGE_LOOKUP.get(language, "") if language else "" + + # Validate currency + if currency: + currency = validate_currency(currency) + currency_code = CURRENCY_LOOKUP.get(currency, "") if currency else "" + + # Validate flights list + validate_flights_list(flights, FlightQuery) + + # Validate and normalize seat type + seat = validate_seat_type(seat) + + # Validate and normalize trip type + trip = validate_trip_type(trip) + + # Validate and apply max_stops to all flights if specified + if max_stops is not None: + max_stops = validate_max_stops(max_stops) + flights = [flight._setmaxstops(max_stops) for flight in flights] + + # Create and return the query return Query( - flight_data=[flight._setmaxstops(max_stops).pb() for flight in flights], + flight_data=[flight.to_proto() for flight in flights], seat=SEAT_LOOKUP[seat], trip=TRIP_LOOKUP[trip], - passengers=passengers.pb(), + passengers=passengers.to_proto(), language=language, - currency=currency, + currency=currency_code, ) diff --git a/fast_flights/validation.py b/fast_flights/validation.py new file mode 100644 index 00000000..b0f7fbb5 
--- /dev/null +++ b/fast_flights/validation.py @@ -0,0 +1,442 @@ +"""Input validation utilities for fast_flights package.""" +import re +from datetime import datetime +from enum import Enum +from typing import Any, List, Literal, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from .exceptions import ( + AirportCodeError, + DateFormatError, + PassengerError, + FlightQueryError +) +from .types import Language, SeatType, TripType, Currency + +T = TypeVar('T') + +def validate_enum_value(value: Any, enum_type: Type[T], field_name: str) -> T: + """Validate that a value is a valid enum value or literal. + + Args: + value: The value to validate. + enum_type: The enum class or Literal type to validate against. + field_name: The name of the field being validated (for error messages). + + Returns: + The validated value. + + Raises: + FlightQueryError: If the value is not a valid enum or literal value. + """ + # Handle Enum types + if isinstance(enum_type, type) and issubclass(enum_type, Enum): + try: + return enum_type(value) + except ValueError: + valid_values = [e.value for e in enum_type] + raise FlightQueryError( + f"Invalid {field_name}: '{value}'. " + f"Must be one of: {', '.join(repr(v) for v in valid_values)}" + ) from None + + # Handle Literal types + origin = get_origin(enum_type) + if origin is Literal: + valid_values = get_args(enum_type) + if value in valid_values: + return value + raise FlightQueryError( + f"Invalid {field_name}: '{value}'. " + f"Must be one of: {', '.join(repr(v) for v in valid_values)}" + ) from None + + # Handle regular types + if isinstance(value, enum_type): + return value + + raise FlightQueryError( + f"Expected {field_name} to be of type {enum_type.__name__}, got {type(value).__name__}" + ) + +def validate_airport_code(code: str) -> str: + """Validate and normalize an airport IATA code. + + Args: + code: The airport code to validate. + + Returns: + str: The normalized uppercase airport code. 
+ + Raises: + AirportCodeError: If the airport code is not a string or has an invalid format. + """ + if not isinstance(code, str): + raise AirportCodeError(f"Airport code must be a string, got {type(code).__name__}") + + code = code.strip().upper() + if not code: + raise AirportCodeError("Airport code cannot be empty") + if not re.match(r'^[A-Z]{3}$', code): + raise AirportCodeError(f"Invalid IATA code: '{code}'. Must be 3 uppercase letters") + + return code + +def validate_passengers( + adults: int, + children: int, + infants_in_seat: int, + infants_on_lap: int, + max_passengers: int = 9, +) -> None: + """Validate passenger configuration for a flight booking. + + Ensures that the passenger configuration follows these rules: + - At least one adult is required + - No more than max_passengers (default: 9) total passengers allowed + - Number of infants on lap cannot exceed number of adults + - All counts must be non-negative integers + + Args: + adults: Number of adults (must be >= 1). + children: Number of children (must be >= 0). + infants_in_seat: Number of infants with their own seat (must be >= 0). + infants_on_lap: Number of infants on lap (must be >= 0). + max_passengers: Maximum allowed total passengers (default: 9). + + Raises: + PassengerError: If passenger configuration is invalid, counts are not integers, + counts are negative, or other validation fails. 
+ """ + # Validate input types + if not all(isinstance(x, int) for x in [adults, children, infants_in_seat, infants_on_lap]): + raise PassengerError("All passenger counts must be integers") + + # Validate at least one adult + if adults < 1: + raise PassengerError("At least one adult is required") + + # Validate non-negative values + if children < 0: + raise PassengerError("Number of children cannot be negative") + if infants_in_seat < 0: + raise PassengerError("Number of infants in seat cannot be negative") + if infants_on_lap < 0: + raise PassengerError("Number of infants on lap cannot be negative") + + total = adults + children + infants_in_seat + infants_on_lap + + # Validate at least one passenger + if total < 1: + raise PassengerError("At least one passenger is required") + + + + # Validate total passengers + if total > max_passengers: + raise PassengerError( + f"Maximum of {max_passengers} passengers allowed, got {total} " + f"({adults} adults, {children} children, {infants_in_seat} infants in seat, " + f"{infants_on_lap} infants on lap)" + ) + + # Validate infants on lap + if infants_on_lap > 0 and adults < 1: + raise PassengerError( + f"Cannot have {infants_on_lap} infants on lap without at least one adult" + ) + + # Validate infants in seat + if infants_in_seat > 0 and adults < 1 and children < 1: + raise PassengerError( + f"Cannot have {infants_in_seat} infants in seat without at least one adult or child" + ) + +def validate_flight_query( + from_airport: str, + to_airport: str, + date: Union[str, datetime], + max_stops: Optional[int] = None +) -> tuple[str, str]: + """Validate and normalize flight query parameters. + + Validates that: + - Airport codes are valid IATA codes + - Origin and destination are different + - Date is valid and in the future + - max_stops is a non-negative integer or None + + Args: + from_airport: Departure airport code (case-insensitive). + to_airport: Arrival airport code (case-insensitive). 
+ date: Departure date (string in YYYY-MM-DD format or datetime object). + max_stops: Maximum number of stops allowed (None for no limit). + + Returns: + tuple: Normalized (from_airport, to_airport) codes in uppercase. + + Raises: + FlightQueryError: If any parameter is invalid. + AirportCodeError: If airport codes are invalid. + DateFormatError: If date format is invalid or in the past. + """ + # Validate and normalize airport codes + from_airport_norm = validate_airport_code(from_airport) + to_airport_norm = validate_airport_code(to_airport) + + # Ensure origin and destination are different + if from_airport_norm == to_airport_norm: + raise FlightQueryError( + f"Origin and destination airports cannot be the same: {from_airport_norm}" + ) + + # Validate date is in the future + normalized_date = validate_and_normalize_date(date) + today = datetime.now().date() + if isinstance(normalized_date, str): + normalized_date = datetime.strptime(normalized_date, "%Y-%m-%d").date() + + if normalized_date < today: + raise DateFormatError("Departure date cannot be in the past") + + # Validate max_stops + if max_stops is not None: + if not isinstance(max_stops, int): + raise FlightQueryError("max_stops must be an integer or None") + if max_stops < 0: + raise FlightQueryError("max_stops cannot be negative") + + return from_airport_norm, to_airport_norm + +def _validate_literal(value: Any, literal_type: Type[T], type_name: str) -> T: + """Validate a value against a Literal type or Enum. + + Args: + value: The value to validate. + literal_type: The type to validate against (Literal or Enum). + type_name: Human-readable name of the type for error messages. + + Returns: + The validated value of type T. + + Raises: + FlightQueryError: If the value is not a valid value of the literal type. 
def _validate_literal(value: Any, literal_type: Type[T], type_name: str) -> T:
    """Check a value against an ``Enum`` class or a ``typing.Literal`` alias.

    Args:
        value: The value to validate.
        literal_type: An ``Enum`` subclass or a ``Literal[...]`` alias.
        type_name: Human-readable name used in error messages.

    Returns:
        For an Enum, the matching member; for a Literal, ``value`` unchanged.

    Raises:
        FlightQueryError: If ``value`` is not one of the allowed values.
        ValueError: If ``literal_type`` is neither an Enum nor a Literal.
    """
    # Enum classes: the Enum constructor performs the by-value lookup.
    if isinstance(literal_type, type) and issubclass(literal_type, Enum):
        try:
            return literal_type(value)
        except ValueError:
            valid_values = [member.value for member in literal_type]
            raise FlightQueryError(
                f"Invalid {type_name}: '{value}'. "
                f"Must be one of: {', '.join(repr(v) for v in valid_values)}"
            ) from None

    # Literal aliases: membership test against the declared arguments.
    if get_origin(literal_type) is Literal:
        valid_values = get_args(literal_type)
        if value in valid_values:
            return value
        raise FlightQueryError(
            f"Invalid {type_name}: '{value}'. "
            f"Must be one of: {', '.join(repr(v) for v in valid_values)}"
        )

    # Anything else is a programming error in the caller, not bad user input.
    raise ValueError(f"Unsupported type for validation: {literal_type}")


def validate_seat_type(seat: Union[str, SeatType]) -> SeatType:
    """Validate a seat type against ``SeatType``.

    Args:
        seat: The seat type to validate.

    Returns:
        SeatType: The validated seat type.

    Raises:
        FlightQueryError: If the seat type is invalid.
    """
    return _validate_literal(seat, SeatType, 'seat type')


def validate_trip_type(trip: Union[str, TripType]) -> TripType:
    """Validate a trip type against ``TripType``.

    Args:
        trip: The trip type to validate.

    Returns:
        TripType: The validated trip type.

    Raises:
        FlightQueryError: If the trip type is invalid.
    """
    return _validate_literal(trip, TripType, 'trip type')


def validate_language(language: Union[str, Language, None]) -> str:
    """Validate a language code against ``Language``.

    Args:
        language: The language code to validate; ``None`` or a blank string
            means "use the default".

    Returns:
        str: The whitespace-stripped language code, or an empty string when
        no language was given.

    Raises:
        FlightQueryError: If the language code is invalid.
    """
    if language is None:
        return ""

    language_str = str(language).strip()
    if not language_str:
        return ""  # Empty string is allowed (means use default)

    return _validate_literal(language_str, Language, 'language code')


def validate_currency(currency: Union[str, Currency, None]) -> str:
    """Validate a currency code against ``Currency``.

    Args:
        currency: The currency code to validate; ``None`` or a blank string
            means "use the default".

    Returns:
        str: The stripped, upper-cased currency code, or an empty string when
        no currency was given.

    Raises:
        FlightQueryError: If the currency code is invalid.
    """
    if currency is None:
        return ""

    currency_str = str(currency).strip().upper()
    if not currency_str:
        return ""  # Empty string is allowed (means use default)

    return _validate_literal(currency_str, Currency, 'currency code')


def validate_flights_list(flights: List[Any], expected_type: Type[T]) -> List[T]:
    """Validate that a sequence contains only instances of the expected type.

    Args:
        flights: The list (or tuple) to validate.
        expected_type: The expected type of every element.

    Returns:
        List[T]: The validated elements. A list is returned as is; a tuple is
        copied into a new list so the result matches the annotation.

    Raises:
        FlightQueryError: If the input is not a list/tuple, is empty, or
            contains an element of the wrong type.
    """
    if not isinstance(flights, (list, tuple)) or not flights:
        raise FlightQueryError("At least one flight segment is required")

    if not all(isinstance(f, expected_type) for f in flights):
        type_name = expected_type.__name__
        raise FlightQueryError(f"All flight segments must be {type_name} instances")

    return flights if isinstance(flights, list) else list(flights)
def validate_airline_code(code: str) -> str:
    """Validate and normalize an airline IATA code.

    IATA airline designators are two characters long and may contain a digit
    (for example ``B6`` or ``U2``), so one digit is accepted alongside letters.

    Args:
        code: The airline code to validate; surrounding whitespace and letter
            case are ignored.

    Returns:
        str: The stripped, upper-cased airline code.

    Raises:
        FlightQueryError: If the code is not a string, is empty, or is not a
            two-character designator containing at least one letter.
    """
    if not isinstance(code, str):
        raise FlightQueryError(f"Airline code must be a string, got {type(code).__name__}")

    code = code.strip().upper()
    if not code:
        raise FlightQueryError("Airline code cannot be empty")

    # Two alphanumerics, at least one of which is a letter.
    if not re.fullmatch(r'(?=.*[A-Z])[A-Z0-9]{2}', code):
        raise FlightQueryError(
            f"Invalid airline code: '{code}'. "
            "Must be 2 characters: letters, or one letter and one digit"
        )

    return code


def validate_and_normalize_date(date: Union[str, datetime]) -> str:
    """Validate a date and return it as a zero-padded YYYY-MM-DD string.

    Args:
        date: A string in YYYY-MM-DD form, a ``datetime``, or a plain
            ``datetime.date``.

    Returns:
        str: The date in canonical ISO form, e.g. ``"2030-1-5"`` becomes
        ``"2030-01-05"``.

    Raises:
        DateFormatError: If a string cannot be parsed as YYYY-MM-DD, or the
            argument is of an unsupported type.
    """
    # Imported locally under an alias because the parameter shadows ``date``.
    from datetime import date as _date_type

    if isinstance(date, str):
        try:
            parsed = datetime.strptime(date, '%Y-%m-%d')
        except ValueError as e:
            raise DateFormatError(
                f"Invalid date format: '{date}'. Expected YYYY-MM-DD"
            ) from e
        # strptime tolerates non-padded fields, so re-emit the canonical form
        # instead of echoing the input back.
        return parsed.date().isoformat()

    # datetime is a subclass of date, so this branch covers both.
    if isinstance(date, _date_type):
        return _date_type(date.year, date.month, date.day).isoformat()

    raise DateFormatError(
        f"Date must be a string or datetime object, got {type(date).__name__}"
    )


def validate_max_stops(max_stops: Optional[int]) -> Optional[int]:
    """Validate the max_stops parameter.

    Args:
        max_stops: Maximum number of stops allowed, or None for any number.

    Returns:
        Optional[int]: The value unchanged when it is valid.

    Raises:
        FlightQueryError: If max_stops is not a non-negative integer
            (booleans are rejected even though ``bool`` subclasses ``int``).
    """
    if max_stops is None:
        return None

    if isinstance(max_stops, bool) or not isinstance(max_stops, int) or max_stops < 0:
        raise FlightQueryError("max_stops must be a non-negative integer")
    return max_stops


def validate_airlines(airlines: Optional[list[str]]) -> Optional[list[str]]:
    """Validate a list of airline IATA codes.

    Args:
        airlines: List of airline codes to validate. Can be None.

    Returns:
        Optional[list[str]]: The normalized codes, or None when the input is
        None or an empty list (both mean "no airline filter").

    Raises:
        FlightQueryError: If the input is not a list or any code is invalid.
    """
    if airlines is None:
        return None

    if not isinstance(airlines, list):
        raise FlightQueryError("Airlines must be a list of 2-character IATA codes")

    if not airlines:  # Empty list is valid (means no airline filter)
        return None

    # validate_airline_code already raises FlightQueryError, so its errors
    # propagate as they are; no re-wrapping is needed.
    return [validate_airline_code(code) for code in airlines]
def test_validate_enum_value_with_enum():
    """Test validation with Python enums."""
    class TestEnum(Enum):
        A = "A"
        B = "B"

    # Valid enum values: raw values and members both resolve to the member
    assert validate_enum_value("A", TestEnum, "test_field") == TestEnum.A
    assert validate_enum_value(TestEnum.B, TestEnum, "test_field") == TestEnum.B

    # Invalid enum value
    with pytest.raises(FlightQueryError) as excinfo:
        validate_enum_value("C", TestEnum, "test_field")
    assert "Invalid test_field: 'C'" in str(excinfo.value)


def test_validate_enum_value_with_literal():
    """Test validation with Literal types."""
    TestLiteral = Literal["X", "Y", "Z"]

    # Valid literal values
    assert validate_enum_value("X", TestLiteral, "test_field") == "X"
    assert validate_enum_value("Y", TestLiteral, "test_field") == "Y"

    # Invalid literal value
    with pytest.raises(FlightQueryError) as excinfo:
        validate_enum_value("W", TestLiteral, "test_field")
    assert "Invalid test_field: 'W'" in str(excinfo.value)


def test_validate_language():
    """Test language code validation."""
    # Valid languages
    assert validate_language("en-US") == "en-US"
    assert validate_language("fr") == "fr"
    assert validate_language("zh-CN") == "zh-CN"

    # Invalid languages
    with pytest.raises(FlightQueryError):
        validate_language("invalid")
    with pytest.raises(FlightQueryError):
        validate_language("en-US-extra")

    # None case
    assert validate_language(None) == ""


def test_validate_currency():
    """Test currency code validation."""
    # Valid currencies
    assert validate_currency("USD") == "USD"
    assert validate_currency("EUR") == "EUR"

    # Invalid currencies
    with pytest.raises(FlightQueryError):
        validate_currency("US")
    with pytest.raises(FlightQueryError):
        validate_currency("USDD")
    with pytest.raises(FlightQueryError):
        validate_currency("123")

    # None and blank input both mean "use the default"
    assert validate_currency(None) == ""
    assert validate_currency("") == ""


def test_validate_seat_type():
    """Test seat type validation."""
    # Valid seat types
    assert validate_seat_type("economy") == "economy"
    assert validate_seat_type("premium-economy") == "premium-economy"
    assert validate_seat_type("business") == "business"
    assert validate_seat_type("first") == "first"

    # Invalid seat type
    with pytest.raises(FlightQueryError) as excinfo:
        validate_seat_type("invalid")
    assert "seat type" in str(excinfo.value).lower()


def test_validate_trip_type():
    """Test trip type validation."""
    # Valid trip types
    assert validate_trip_type("one-way") == "one-way"
    assert validate_trip_type("round-trip") == "round-trip"
    assert validate_trip_type("multi-city") == "multi-city"

    # Invalid trip types
    with pytest.raises(FlightQueryError) as excinfo:
        validate_trip_type("invalid")
    assert "trip type" in str(excinfo.value).lower()

    with pytest.raises(FlightQueryError):
        validate_trip_type(123)


def test_validate_flight_query():
    """Test flight query validation."""
    future_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")

    # Valid query
    from_airport, to_airport = validate_flight_query("JFK", "LAX", future_date, 2)
    assert from_airport == "JFK"
    assert to_airport == "LAX"

    # Same origin and destination
    with pytest.raises(FlightQueryError):
        validate_flight_query("JFK", "JFK", future_date)

    # Past date
    past_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
    with pytest.raises(DateFormatError):
        validate_flight_query("JFK", "LAX", past_date)

    # Invalid airport codes
    with pytest.raises(AirportCodeError):
        validate_flight_query("JK", "LAX", future_date)
    with pytest.raises(AirportCodeError):
        validate_flight_query("JFK", "LONG", future_date)

    # Invalid max_stops
    with pytest.raises(FlightQueryError):
        validate_flight_query("JFK", "LAX", future_date, -1)
    with pytest.raises(FlightQueryError):
        validate_flight_query("JFK", "LAX", future_date, "two")