Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lightcurvedb/models/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ class InstrumentNotFoundException(StorageException):

class CutoutNotFoundException(StorageException):
pass


class FluxMeasurementNotFoundException(StorageException):
pass
4 changes: 4 additions & 0 deletions lightcurvedb/models/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ class FeedResult(BaseModel):
frequency: int

total_number_of_sources: int

@property
def band_name(self) -> str:
return f"f{self.frequency}"
11 changes: 8 additions & 3 deletions lightcurvedb/storage/postgres/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def get_source_statistics_for_frequency_and_module(
"source_id = %(source_id)s",
"frequency = %(frequency)s",
]
params: dict[str, int | str | datetime] = {
params: dict[str, int | str | datetime | UUID] = {
"source_id": source_id,
"frequency": frequency,
}
Expand All @@ -58,10 +58,11 @@ async def get_source_statistics_for_frequency_and_module(
where_clauses.append("module = %(module)s")
params["module"] = module

module_col = "%(module)s" if module != "all" else "'all'"
query = f"""
SELECT
%(source_id)s as source_id,
{"%(module)s" if module != "all" else "'all'"} as module,
{module_col} as module,
%(frequency)s as frequency,
COUNT(*) as measurement_count,
MIN(flux) as min_flux,
Expand All @@ -80,11 +81,15 @@ async def get_source_statistics_for_frequency_and_module(
WHERE {" AND ".join(where_clauses)}
"""

async with self.flux_storage.conn.cursor(
async with self.flux_storage.cursor(
row_factory=class_row(SourceStatistics)
) as cur:
await cur.execute(query, params)
row = await cur.fetchone()
if row is None:
raise ValueError(
f"No statistics found for source {source_id}"
)
return row

async def get_source_statistics_for_frequency(
Expand Down
13 changes: 10 additions & 3 deletions lightcurvedb/storage/postgres/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def setup(self) -> None:
await cur.execute(CUTOUT_SCHEMA)
await cur.execute(CUTOUT_INDEXES)

async def create(self, cutout: Cutout):
async def create(self, cutout: Cutout) -> UUID:
"""
Store a cutout for a given source and band.
"""
Expand Down Expand Up @@ -53,9 +53,11 @@ async def create(self, cutout: Cutout):
async with self.cursor() as cur:
await cur.execute(query, params)

if cutout.measurement_id is None:
raise ValueError("Cutout measurement_id must not be None after creation")
return cutout.measurement_id

async def create_batch(self, cutouts: list[Cutout]):
async def create_batch(self, cutouts: list[Cutout]) -> list[UUID]:
"""
Store a cutout for a given source and band.
"""
Expand Down Expand Up @@ -88,7 +90,12 @@ async def create_batch(self, cutouts: list[Cutout]):
async with self.cursor() as cur:
await cur.executemany(query, params_list)

return [c.measurement_id for c in cutouts]
measurement_ids: list[UUID] = []
for c in cutouts:
if c.measurement_id is None:
raise ValueError("Cutout measurement_id must not be None after creation")
measurement_ids.append(c.measurement_id)
return measurement_ids

async def retrieve_cutout(self, source_id: UUID, measurement_id: UUID) -> Cutout:
"""
Expand Down
15 changes: 11 additions & 4 deletions lightcurvedb/storage/postgres/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ async def create(self, measurement: FluxMeasurementCreate) -> UUID:
async with self.cursor() as cur:
await cur.execute(query, params)
row = await cur.fetchone()
if row:
return row[0]
else:
raise ValueError("Unexpected error retrieving generated UUID")
if row is None:
raise ValueError("INSERT RETURNING measurement_id returned no row")
return row[0]

async def create_batch(
self, measurements: list[FluxMeasurementCreate]
Expand Down Expand Up @@ -110,6 +109,14 @@ async def get(self, measurement_id: UUID) -> FluxMeasurement:
async with self.cursor(row_factory=class_row(FluxMeasurement)) as cur:
await cur.execute(query, {"measurement_id": measurement_id})
row = await cur.fetchone()
if row is None:
from lightcurvedb.models.exceptions import (
FluxMeasurementNotFoundException,
)

raise FluxMeasurementNotFoundException(
f"FluxMeasurement {measurement_id} not found"
)
return row

async def delete(self, measurement_id: UUID) -> None:
Expand Down
6 changes: 4 additions & 2 deletions lightcurvedb/storage/postgres/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ async def create(self, instrument: Instrument) -> str:
async with self.cursor() as cur:
await cur.execute(query, params)
row = await cur.fetchone()
if row is None:
raise ValueError("INSERT RETURNING instrument returned no row")
return row[0]

async def create_batch(self, instruments: list[Instrument]) -> int:
async def create_batch(self, instruments: list[Instrument]) -> list[str]:
"""
Bulk insert instruments.
"""
Expand Down Expand Up @@ -71,7 +73,7 @@ async def create_batch(self, instruments: list[Instrument]) -> int:
async with self.cursor() as cur:
await cur.execute(query, data)

return len(instruments)
return [instrument.instrument for instrument in instruments]

async def get(self, frequency: int, module: str) -> Instrument:
"""Get instrument by frequency and module."""
Expand Down
84 changes: 73 additions & 11 deletions lightcurvedb/storage/postgres/lightcurves.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncio
import datetime
from typing import Literal
from typing import Literal, overload
from uuid import UUID

from psycopg.rows import class_row
Expand Down Expand Up @@ -60,7 +60,7 @@ async def get_instrument_lightcurve(
)
"""

async with self.flux_storage.conn.cursor(
async with self.flux_storage.cursor(
row_factory=class_row(InstrumentLightcurve)
) as cur:
await cur.execute(
Expand All @@ -72,7 +72,13 @@ async def get_instrument_lightcurve(
"limit": limit,
},
)
return await cur.fetchone()
row = await cur.fetchone()
if row is None:
raise ValueError(
f"No instrument lightcurve found for source {source_id}, "
f"module {module}, frequency {frequency}"
)
return row

async def get_frequency_lightcurve(
self, source_id: UUID, frequency: int, limit: int = 1000000
Expand Down Expand Up @@ -102,13 +108,19 @@ async def get_frequency_lightcurve(
)
"""

async with self.flux_storage.conn.cursor(
async with self.flux_storage.cursor(
row_factory=class_row(FrequencyLightcurve)
) as cur:
await cur.execute(
query, {"source_id": source_id, "frequency": frequency, "limit": limit}
)
return await cur.fetchone()
row = await cur.fetchone()
if row is None:
raise ValueError(
f"No frequency lightcurve found for source {source_id}, "
f"frequency {frequency}"
)
return row

async def get_binned_instrument_lightcurve(
self,
Expand Down Expand Up @@ -159,7 +171,7 @@ async def get_binned_instrument_lightcurve(
) AS binned
"""

async with self.flux_storage.conn.cursor(
async with self.flux_storage.cursor(
row_factory=class_row(BinnedInstrumentLightcurve)
) as cur:
await cur.execute(
Expand All @@ -174,7 +186,13 @@ async def get_binned_instrument_lightcurve(
"limit": limit,
},
)
return await cur.fetchone()
row = await cur.fetchone()
if row is None:
raise ValueError(
f"No binned instrument lightcurve found for source {source_id}, "
f"module {module}, frequency {frequency}"
)
return row

async def get_binned_frequency_lightcurve(
self,
Expand Down Expand Up @@ -225,7 +243,7 @@ async def get_binned_frequency_lightcurve(
) AS binned
"""

async with self.flux_storage.conn.cursor(
async with self.flux_storage.cursor(
row_factory=class_row(BinnedFrequencyLightcurve)
) as cur:
await cur.execute(
Expand All @@ -239,7 +257,13 @@ async def get_binned_frequency_lightcurve(
"limit": limit,
},
)
return await cur.fetchone()
row = await cur.fetchone()
if row is None:
raise ValueError(
f"No binned frequency lightcurve found for source {source_id}, "
f"frequency {frequency}"
)
return row

async def get_frequencies_for_source(self, source_id: UUID) -> list[int]:
"""
Expand All @@ -251,7 +275,7 @@ async def get_frequencies_for_source(self, source_id: UUID) -> list[int]:
WHERE source_id = %(source_id)s
"""

async with self.flux_storage.conn.cursor() as cur:
async with self.flux_storage.cursor() as cur:
await cur.execute(query, {"source_id": source_id})
rows = await cur.fetchall()
return [row[0] for row in rows]
Expand All @@ -268,11 +292,27 @@ async def get_module_frequency_pairs_for_source(
WHERE source_id = %(source_id)s
"""

async with self.flux_storage.conn.cursor() as cur:
async with self.flux_storage.cursor() as cur:
await cur.execute(query, {"source_id": source_id})
rows = await cur.fetchall()
return [(row[1], row[0]) for row in rows]

@overload
async def get_source_lightcurve(
self,
source_id: UUID,
selection_strategy: Literal["frequency"],
limit: int = 1000000,
) -> SourceLightcurveFrequency: ...

@overload
async def get_source_lightcurve(
self,
source_id: UUID,
selection_strategy: Literal["instrument"],
limit: int = 1000000,
) -> SourceLightcurveInstrument: ...

async def get_source_lightcurve(
self,
source_id: UUID,
Expand Down Expand Up @@ -318,6 +358,28 @@ async def get_source_lightcurve(
else:
raise ValueError(f"Invalid strategy: {selection_strategy}")

@overload
async def get_binned_source_lightcurve(
self,
source_id: UUID,
selection_strategy: Literal["frequency"],
binning_strategy: Literal["1 day", "7 days", "30 days"],
start_time: datetime.datetime,
end_time: datetime.datetime,
limit: int = 1000000,
) -> SourceLightcurveBinnedFrequency: ...

@overload
async def get_binned_source_lightcurve(
self,
source_id: UUID,
selection_strategy: Literal["instrument"],
binning_strategy: Literal["1 day", "7 days", "30 days"],
start_time: datetime.datetime,
end_time: datetime.datetime,
limit: int = 1000000,
) -> SourceLightcurveBinnedInstrument: ...

async def get_binned_source_lightcurve(
self,
source_id: UUID,
Expand Down
29 changes: 22 additions & 7 deletions lightcurvedb/storage/postgres/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,36 @@
"""

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, overload

from psycopg import AsyncClientCursor
from psycopg.cursor_async import AsyncRowFactory, Row
from psycopg.rows import BaseRowFactory, Row
from psycopg_pool import AsyncConnectionPool


class PostgresPoolUser:
def __init__(self, pool: AsyncConnectionPool):
self.pool = pool

@asynccontextmanager
@overload
def cursor(
self,
*,
row_factory: BaseRowFactory[Row],
) -> AbstractAsyncContextManager[AsyncClientCursor[Row]]: ...

@overload
def cursor(self) -> AbstractAsyncContextManager[AsyncClientCursor[Any]]: ...

@asynccontextmanager # type: ignore[misc]
async def cursor(
self, *, row_factory: AsyncRowFactory[Row] | None = None, **kwargs
) -> AsyncIterator[AsyncClientCursor[Row]]:
self, *, row_factory: BaseRowFactory[Row] | None = None
) -> AsyncIterator[AsyncClientCursor[Any]]:
async with self.pool.connection() as conn:
async with conn.cursor(row_factory=row_factory, **kwargs) as cur:
yield cur
if row_factory is not None:
async with conn.cursor(row_factory=row_factory) as cur:
yield cur
else:
async with conn.cursor() as cur:
yield cur
2 changes: 2 additions & 0 deletions lightcurvedb/storage/postgres/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ async def create(self, source: Source) -> UUID:
await cur.execute(query, params)
row = await cur.fetchone()

if row is None:
raise ValueError("INSERT RETURNING source_id returned no row")
return row[0]

async def create_batch(self, sources: list[Source]) -> list[UUID]:
Expand Down
4 changes: 2 additions & 2 deletions lightcurvedb/storage/prototype/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ async def setup(self) -> None:
"""
...

async def create(self, cutout: Cutout) -> int:
async def create(self, cutout: Cutout) -> UUID:
"""
Store a cutout for a given source and band.
"""
...

async def create_batch(self, cutouts: list[Cutout]) -> list[int]:
async def create_batch(self, cutouts: list[Cutout]) -> list[UUID]:
"""
Store a cutout for a given source and band.
"""
Expand Down
Loading
Loading