Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion api/src/feder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@
get_flights, FlightQuery, stream_trajectories, stream_trajectory_arrays
)
from .available import available_days, available_times, available_sources # noqa
from .common.models import DataSource, Point, Trajectory, TrajectoryArray # noqa
from .common.models import ( # noqa
DataSource, Point, Trajectory, TrajectoryArray, TrajectoryArrayBatch
)
from .common.db import BoundingBox, TemporalQueryType, SpatialQueryType # noqa
from .common.version import get_feder_version # noqa

Expand All @@ -119,6 +121,7 @@
'Point',
'Trajectory',
'TrajectoryArray',
'TrajectoryArrayBatch',
'BoundingBox',
'TemporalQueryType',
'SpatialQueryType',
Expand Down
14 changes: 7 additions & 7 deletions api/src/feder/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from feder.common import (
DB, DataSource, BoundingBox, TemporalQueryType, SpatialQueryType,
Trajectory, TrajectoryArray
Trajectory, TrajectoryArrayBatch
)


Expand Down Expand Up @@ -270,14 +270,14 @@ def stream_trajectory_arrays(
*,
native_endian: bool = True,
missing_as_nan: bool = True,
) -> Generator[TrajectoryArray, None, None]:
"""Stream all trajectories for one day with points as numpy arrays.
) -> Generator[TrajectoryArrayBatch, None, None]:
"""Stream decoded trajectory-array batches for one day.

The data directory is read from the `FEDER_DATA_DIR` environment variable.
Trajectories are processed in batches to keep memory use bounded. No
ordering guarantee is part of the public API. Returned point arrays should
be treated as read-only; callers that need to modify them should make a
copy.
Each yielded value corresponds to one SQLite/decode chunk. Ordering should not
be relied upon by callers as part of the public contract. Returned point
arrays should be treated as read-only; callers that need to modify them
should make a copy.
"""
if not isinstance(day, date) or isinstance(day, datetime):
raise TypeError('day must be a datetime.date')
Expand Down
38 changes: 26 additions & 12 deletions libs/common/src/feder_common/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import lz4.frame
import numpy as np

from .models import DataSource, Point, Trajectory, TrajectoryArray
from .models import (
DataSource, Point, Trajectory, TrajectoryArray, TrajectoryArrayBatch
)
from .utils import MISSING_VALUE

_N_WORKERS = min(8, os.cpu_count() or 4)
Expand Down Expand Up @@ -84,16 +86,16 @@ def _process_array_chunk(
rows: list,
native_endian: bool,
missing_as_nan: bool,
) -> list[TrajectoryArray]:
return [
) -> tuple[TrajectoryArray, ...]:
return tuple(
TrajectoryArray(
source=DataSource(row[0]), source_id=row[1], transponder_id=row[2],
orig=row[3], dest=row[4], callsign=row[5], aircraft_type=row[6],
points=_prepare_array(row[7], native_endian, missing_as_nan),
partial=False
)
for row in rows
]
)


class TemporalQueryType(Enum):
Expand Down Expand Up @@ -422,14 +424,15 @@ def stream_trajectory_arrays(
*,
native_endian: bool = True,
missing_as_nan: bool = True,
) -> Generator[TrajectoryArray, None, None]:
"""Yield all trajectories with points as structured numpy arrays.
) -> Generator[TrajectoryArrayBatch, None, None]:
"""Yield decoded trajectory-array batches.

Rows are scanned in SQLite row ID order and processed in batches to
keep memory use bounded. The ordering is deterministic, but callers
should not rely on a particular order as part of the public contract.
Point arrays should be treated as read-only; callers that need to
modify them should make a copy.
Each yielded value corresponds to one SQLite/decode chunk. Point arrays
should be treated as read-only; callers that need to modify them should
make a copy.
"""
if batch_size < 1:
raise ValueError('batch_size must be at least 1')
Expand All @@ -443,10 +446,19 @@ def stream_trajectory_arrays(
)

for rows in batched(traj_cur, batch_size):
yield from self._rows_to_trajectory_arrays(
list(rows), native_endian=native_endian,
row_list = list(rows)
trajectories = self._rows_to_trajectory_arrays(
row_list, native_endian=native_endian,
missing_as_nan=missing_as_nan
)
yield TrajectoryArrayBatch(
day=self.ref_date.date(),
db_path=self.db_file(),
row_count=len(row_list),
trajectory_count=len(trajectories),
point_count=sum(len(traj.points) for traj in trajectories),
trajectories=trajectories,
)

def _retrieve(
self,
Expand Down Expand Up @@ -543,7 +555,7 @@ def _rows_to_trajectory_arrays(
*,
native_endian: bool,
missing_as_nan: bool,
) -> Generator[TrajectoryArray, None, None]:
) -> tuple[TrajectoryArray, ...]:
# Split into at most N_WORKERS chunks so thread-pool overhead is O(workers)
# rather than O(trajectories). Array preparation runs inside the threads too.
n_chunks = min(_N_WORKERS, len(rows))
Expand All @@ -554,8 +566,10 @@ def _rows_to_trajectory_arrays(
_pool.submit(_process_array_chunk, chunk, native_endian, missing_as_nan)
for chunk in chunks
]
trajectories: list[TrajectoryArray] = []
for future in futures:
yield from future.result()
trajectories.extend(future.result())
return tuple(trajectories)

def _build_sql_conditions(self, conditions, ids, source_ids):
sql_conditions = [p[0] for p in conditions]
Expand Down
20 changes: 19 additions & 1 deletion libs/common/src/feder_common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timezone
from datetime import date, datetime, timezone
from enum import Enum
from operator import attrgetter
from typing import Self, cast
Expand Down Expand Up @@ -152,6 +152,24 @@ class TrajectoryArray:
"""Was the trajectory generated from a query using waypoint filtering?"""


@dataclass(slots=True)
class TrajectoryArrayBatch:
"""A decoded batch of trajectory arrays from one daily database."""

day: date
"""The Feder day represented by the streamed database."""
db_path: str
"""Path to the SQLite database streamed for this batch."""
row_count: int
"""Number of SQLite trajectory rows decoded for this batch."""
trajectory_count: int
"""Number of trajectories yielded in this batch."""
point_count: int
"""Total number of decoded point rows in this batch."""
trajectories: tuple[TrajectoryArray, ...]
"""Decoded trajectories in SQLite row ID order within the batch."""


@dataclass(slots=True)
class Trajectory:
"""A single flight trajectory."""
Expand Down
9 changes: 4 additions & 5 deletions scripts/stream_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,15 @@ def main() -> None:
waypoint_count = 0

started = perf_counter()
for trajectory in stream_trajectory_arrays(
for batch in stream_trajectory_arrays(
day,
batch_size=args.batch_size,
native_endian=not args.raw_arrays,
missing_as_nan=not args.raw_arrays,
):
trajectory_count += 1
waypoint_count += len(trajectory.points)
if trajectory_count % args.batch_size == 0:
print(f'processed {trajectory_count} trajectories...', file=sys.stderr)
trajectory_count += batch.trajectory_count
waypoint_count += batch.point_count
print(f'processed {trajectory_count} trajectories...', file=sys.stderr)
elapsed = perf_counter() - started

print(f'database_file: {db_file}')
Expand Down
31 changes: 24 additions & 7 deletions tests/api/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import pytest

from feder import TrajectoryArray, stream_trajectories, stream_trajectory_arrays
from feder import (
TrajectoryArray, TrajectoryArrayBatch, stream_trajectories,
stream_trajectory_arrays
)
from feder.common import DB


Expand Down Expand Up @@ -51,20 +54,34 @@ def test_stream_trajectories_rejects_invalid_batch_size():
list(stream_trajectories(DAY, batch_size=0))


def test_stream_trajectory_arrays_returns_trajectory_arrays():
trajectories = list(stream_trajectory_arrays(DAY))
assert len(trajectories) > 0
def test_stream_trajectory_arrays_returns_trajectory_array_batches():
batches = list(stream_trajectory_arrays(DAY))
trajectories = [traj for batch in batches for traj in batch.trajectories]
assert len(batches) > 0
assert all(isinstance(batch, TrajectoryArrayBatch) for batch in batches)
assert all(batch.day == DAY for batch in batches)
assert all(batch.row_count == batch.trajectory_count for batch in batches)
assert all(
batch.point_count == sum(len(traj.points) for traj in batch.trajectories)
for batch in batches
)
assert all(isinstance(traj, TrajectoryArray) for traj in trajectories)
assert all(not traj.partial for traj in trajectories)


def test_stream_trajectory_arrays_count_matches_stream_trajectories():
assert len(list(stream_trajectory_arrays(DAY))) == len(list(stream_trajectories(DAY)))
batches = list(stream_trajectory_arrays(DAY))
streamed_count = sum(batch.trajectory_count for batch in batches)
assert streamed_count == len(list(stream_trajectories(DAY)))


def test_stream_trajectory_arrays_small_batch_returns_all_rows():
expected = len(list(stream_trajectory_arrays(DAY)))
assert len(list(stream_trajectory_arrays(DAY, batch_size=1))) == expected
expected = sum(
batch.trajectory_count for batch in stream_trajectory_arrays(DAY)
)
batches = list(stream_trajectory_arrays(DAY, batch_size=1))
assert len(batches) == expected
assert all(batch.row_count == 1 for batch in batches)


@pytest.mark.parametrize(
Expand Down
35 changes: 27 additions & 8 deletions tests/common_lib/test_db_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pytest

from feder_common import DB, DataSource, Point
from feder_common import DB, DataSource, Point, TrajectoryArrayBatch
from feder_common.db import _BLOB_VERSION


Expand Down Expand Up @@ -44,12 +44,22 @@ def test_stream_trajectories_rejects_invalid_batch_size(batch_size):
db.close()


def test_stream_trajectory_arrays_returns_all_rows():
def test_stream_trajectory_arrays_returns_batches_with_all_rows():
db = DB(str(DATA_DIR), REF_DATE)
try:
trajectories = list(db.stream_trajectory_arrays())
batches = list(db.stream_trajectory_arrays())
trajectories = [traj for batch in batches for traj in batch.trajectories]
assert sum(batch.row_count for batch in batches) == db.size()
assert sum(batch.trajectory_count for batch in batches) == db.size()
assert len(trajectories) == db.size()
assert len(trajectories) > 0
assert len(batches) > 0
assert all(isinstance(batch, TrajectoryArrayBatch) for batch in batches)
assert all(batch.day == REF_DATE for batch in batches)
assert all(batch.db_path == db.db_file() for batch in batches)
assert all(
batch.point_count == sum(len(traj.points) for traj in batch.trajectories)
for batch in batches
)
assert all(not traj.partial for traj in trajectories)
assert all(isinstance(traj.points, np.ndarray) for traj in trajectories)
finally:
Expand All @@ -59,7 +69,10 @@ def test_stream_trajectory_arrays_returns_all_rows():
def test_stream_trajectory_arrays_small_batch_returns_all_rows():
db = DB(str(DATA_DIR), REF_DATE)
try:
assert len(list(db.stream_trajectory_arrays(batch_size=1))) == db.size()
batches = list(db.stream_trajectory_arrays(batch_size=1))
assert len(batches) == db.size()
assert all(batch.row_count == 1 for batch in batches)
assert sum(batch.trajectory_count for batch in batches) == db.size()
finally:
db.close()

Expand All @@ -77,7 +90,8 @@ def test_stream_trajectory_arrays_rejects_invalid_batch_size(batch_size):
def test_stream_trajectory_arrays_defaults_to_native_endian():
db = DB(str(DATA_DIR), REF_DATE)
try:
traj = next(db.stream_trajectory_arrays())
batch = next(db.stream_trajectory_arrays())
traj = batch.trajectories[0]
assert traj.points.dtype['time'].byteorder in ('=', '|')
assert traj.points.dtype['lon'].byteorder in ('=', '|')
finally:
Expand All @@ -87,9 +101,10 @@ def test_stream_trajectory_arrays_defaults_to_native_endian():
def test_stream_trajectory_arrays_fast_mode_returns_raw_endian():
db = DB(str(DATA_DIR), REF_DATE)
try:
traj = next(db.stream_trajectory_arrays(
batch = next(db.stream_trajectory_arrays(
native_endian=False, missing_as_nan=False
))
traj = batch.trajectories[0]
assert traj.points.dtype['time'].byteorder == '>'
assert traj.points.dtype['lon'].byteorder == '>'
finally:
Expand Down Expand Up @@ -145,7 +160,11 @@ def test_stream_trajectory_arrays_converts_missing_values_to_nan(tmp_path):

db = DB(str(data_dir), day)
try:
traj = next(db.stream_trajectory_arrays())
batch = next(db.stream_trajectory_arrays())
traj = batch.trajectories[0]
assert batch.row_count == 1
assert batch.trajectory_count == 1
assert batch.point_count == 2
assert np.isnan(traj.points['alt'][0])
assert np.isnan(traj.points['alt_gnss'][0])
assert np.isnan(traj.points['heading'][0])
Expand Down
Loading