Skip to content

Commit d717cd8

Browse files
Add RoundTripResult model and search_round_trip to all providers
- models.py: Add RoundTripResult(outbound, inbound, total_price, currency) - duffel.py: Add _convert_offer_slice() and DuffelProvider.search_round_trip() using two-slice offer_requests API for native round-trip fares - google.py: Add GoogleProvider.search_round_trip() using TripType.ROUND_TRIP with paired outbound/return FlightSegments - search.py: Add SearchEngine.search_round_trip() with caching, max_price filtering, and sorted-by-total_price output Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 31d9ae9 commit d717cd8

4 files changed

Lines changed: 199 additions & 3 deletions

File tree

src/opensky/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ class FlightResult(BaseModel):
7777
booking_url: str = ""
7878

7979

80+
class RoundTripResult(BaseModel):
81+
outbound: FlightResult
82+
inbound: FlightResult
83+
total_price: float
84+
currency: str
85+
86+
8087
class ScoredFlight(BaseModel):
8188
flight: FlightResult
8289
origin: str

src/opensky/providers/duffel.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import httpx
66
from ratelimit import limits, sleep_and_retry
77

8-
from opensky.models import FlightLeg, FlightResult
8+
from opensky.models import FlightLeg, FlightResult, RoundTripResult
99
from opensky.providers import parse_iso_duration
1010

1111
log = logging.getLogger(__name__)
@@ -20,6 +20,47 @@
2020
}
2121

2222

23+
def _convert_offer_slice(offer: dict, slice_index: int, currency: str) -> FlightResult:
24+
"""Extract a single slice from a Duffel offer as a FlightResult (price=0, use offer total)."""
25+
offer_currency = offer.get("total_currency", currency)
26+
slices = offer.get("slices", [])
27+
if slice_index >= len(slices):
28+
return FlightResult(price=0, currency=offer_currency, duration_minutes=0, stops=0, legs=[], provider="duffel")
29+
30+
slc = slices[slice_index]
31+
legs: list[FlightLeg] = []
32+
total_duration = 0
33+
34+
for seg in slc.get("segments", []):
35+
dep = seg.get("departing_at", "")
36+
arr = seg.get("arriving_at", "")
37+
dur = parse_iso_duration(seg.get("duration", ""))
38+
total_duration += dur
39+
40+
carrier = seg.get("operating_carrier", {})
41+
airline = carrier.get("iata_code", seg.get("marketing_carrier", {}).get("iata_code", ""))
42+
flight_num = seg.get("operating_carrier_flight_number") or seg.get("marketing_carrier_flight_number") or ""
43+
44+
legs.append(FlightLeg(
45+
airline=airline,
46+
flight_number=flight_num,
47+
departure_airport=seg.get("origin", {}).get("iata_code", ""),
48+
arrival_airport=seg.get("destination", {}).get("iata_code", ""),
49+
departure_time=dep,
50+
arrival_time=arr,
51+
duration_minutes=dur,
52+
))
53+
54+
return FlightResult(
55+
price=0,
56+
currency=offer_currency,
57+
duration_minutes=total_duration,
58+
stops=max(len(legs) - 1, 0),
59+
legs=legs,
60+
provider="duffel",
61+
)
62+
63+
2364
def _convert_offer(offer: dict, currency: str) -> FlightResult:
2465
"""Convert a Duffel offer dict to a FlightResult."""
2566
price = float(offer.get("total_amount", 0))
@@ -118,5 +159,55 @@ def search(
118159

119160
return results
120161

162+
@sleep_and_retry
163+
@limits(calls=10, period=1)
164+
def search_round_trip(
165+
self,
166+
origin: str,
167+
dest: str,
168+
outbound_date: str,
169+
return_date: str,
170+
cabin: str,
171+
currency: str,
172+
max_stops: int | None,
173+
) -> list[RoundTripResult]:
174+
cabin_class = CABIN_MAP.get(cabin, "economy")
175+
176+
payload = {
177+
"data": {
178+
"slices": [
179+
{"origin": origin, "destination": dest, "departure_date": outbound_date},
180+
{"origin": dest, "destination": origin, "departure_date": return_date},
181+
],
182+
"passengers": [{"type": "adult"}],
183+
"cabin_class": cabin_class,
184+
"currency": currency,
185+
"return_offers": True,
186+
"max_connections": max_stops if max_stops is not None else 2,
187+
}
188+
}
189+
190+
resp = self._client.post("/air/offer_requests", json=payload)
191+
resp.raise_for_status()
192+
data = resp.json().get("data", {})
193+
offers = data.get("offers", [])
194+
195+
results: list[RoundTripResult] = []
196+
for offer in offers:
197+
total_price = float(offer.get("total_amount", 0))
198+
offer_currency = offer.get("total_currency", currency)
199+
outbound = _convert_offer_slice(offer, 0, offer_currency)
200+
inbound = _convert_offer_slice(offer, 1, offer_currency)
201+
if max_stops is not None and (outbound.stops > max_stops or inbound.stops > max_stops):
202+
continue
203+
results.append(RoundTripResult(
204+
outbound=outbound,
205+
inbound=inbound,
206+
total_price=total_price,
207+
currency=offer_currency,
208+
))
209+
210+
return results
211+
121212
def close(self) -> None:
122213
self._client.close()

src/opensky/providers/google.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
SeatType,
1111
SearchFlights,
1212
SortBy,
13+
TripType,
1314
)
14-
from opensky.models import FlightLeg, FlightResult
15+
from opensky.models import FlightLeg, FlightResult, RoundTripResult
1516

1617
SEAT_MAP: dict[str, SeatType] = {
1718
"economy": SeatType.ECONOMY,
@@ -104,6 +105,55 @@ def search(
104105
return []
105106
return [_convert_result(r, currency) for r in raw_results]
106107

108+
def search_round_trip(
109+
self,
110+
origin: str,
111+
dest: str,
112+
outbound_date: str,
113+
return_date: str,
114+
cabin: str,
115+
currency: str,
116+
max_stops: int | None,
117+
) -> list[RoundTripResult]:
118+
seat = SEAT_MAP.get(cabin, SeatType.ECONOMY)
119+
stops = STOPS_MAP.get(max_stops, MaxStops.ANY)
120+
filters = FlightSearchFilters(
121+
trip_type=TripType.ROUND_TRIP,
122+
flight_segments=[
123+
FlightSegment(
124+
departure_airport=[[origin, 0]],
125+
arrival_airport=[[dest, 0]],
126+
travel_date=outbound_date,
127+
),
128+
FlightSegment(
129+
departure_airport=[[dest, 0]],
130+
arrival_airport=[[origin, 0]],
131+
travel_date=return_date,
132+
),
133+
],
134+
passenger_info=PassengerInfo(adults=1),
135+
seat_type=seat,
136+
stops=stops,
137+
sort_by=SortBy.CHEAPEST,
138+
)
139+
api = self._get_api()
140+
raw: list[tuple] = api.search(filters) or [] # type: ignore[assignment]
141+
results: list[RoundTripResult] = []
142+
for pair in raw:
143+
if not isinstance(pair, tuple) or len(pair) != 2:
144+
continue
145+
out_raw, in_raw = pair
146+
outbound = _convert_result(out_raw, currency)
147+
inbound = _convert_result(in_raw, currency)
148+
total_price = outbound.price + inbound.price
149+
results.append(RoundTripResult(
150+
outbound=outbound,
151+
inbound=inbound,
152+
total_price=total_price,
153+
currency=currency,
154+
))
155+
return results
156+
107157
def close(self) -> None:
108158
if self._api:
109159
self._api.close()

src/opensky/search.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from opensky import cache
1010
from opensky.config import ScanConfig, VALID_CABINS, VALID_STOPS
11-
from opensky.models import FlightLeg, FlightResult, RiskLevel, ScoredFlight
11+
from opensky.models import FlightLeg, FlightResult, RiskLevel, RoundTripResult, ScoredFlight
1212
from opensky.providers import FlightProvider, configured_providers
1313
from opensky.safety import check_route
1414

@@ -284,6 +284,54 @@ def search_scored_report(
284284
failed_providers=search_report.failed_providers,
285285
)
286286

287+
def search_round_trip(
288+
self,
289+
origin: str,
290+
dest: str,
291+
outbound_date: str,
292+
return_date: str,
293+
max_price: float = 0,
294+
) -> list[RoundTripResult]:
295+
all_results: list[RoundTripResult] = []
296+
297+
for provider in self._providers:
298+
if not hasattr(provider, "search_round_trip"):
299+
continue
300+
ck = cache.cache_key(
301+
f"rt:{origin}", dest, f"{outbound_date}:{return_date}",
302+
seat=self.seat, currency=self.currency,
303+
stops=self.max_stops, provider=provider.name,
304+
)
305+
if self.use_cache:
306+
cached = cache.get(ck)
307+
if cached is not None:
308+
all_results.extend(cached)
309+
continue
310+
311+
try:
312+
results = provider.search_round_trip(
313+
origin, dest, outbound_date, return_date,
314+
cabin=self.seat,
315+
currency=self.currency,
316+
max_stops=self.max_stops,
317+
)
318+
except Exception as exc:
319+
log.warning(
320+
"Provider %s round-trip failed for %s->%s %s/%s: %s",
321+
provider.name, origin, dest, outbound_date, return_date, exc,
322+
)
323+
continue
324+
325+
if self.use_cache:
326+
cache.put(ck, results)
327+
all_results.extend(results)
328+
329+
if max_price > 0:
330+
all_results = [r for r in all_results if r.total_price <= max_price]
331+
332+
all_results.sort(key=lambda r: r.total_price)
333+
return all_results
334+
287335
def scan(
288336
self,
289337
config: ScanConfig,

0 commit comments

Comments
 (0)