Skip to content

Commit 9ed6b51

Browse files
authored
#29: add task for nature import to layer 2 (#397)
* #29: add a layer-2 task for the nature import. Also: add an orphaned-PGCs test; update tests; style checks; add a dry-run mode; move the log-level option to the main CLI group; add a PGC index; improve logging; use executemany for batched insertions; combine migrations.
1 parent cadf7dd commit 9ed6b51

15 files changed

Lines changed: 427 additions & 20 deletions

app/data/model/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CIResultObjectCollision,
1616
CIResultObjectExisting,
1717
CIResultObjectNew,
18+
NatureRecord,
1819
Record,
1920
RecordCrossmatch,
2021
RecordWithPGC,
@@ -45,8 +46,9 @@
4546
"CIResultObjectCollision",
4647
"CIResultObjectExisting",
4748
"CIResultObjectNew",
48-
"RecordWithPGC",
49+
"NatureRecord",
4950
"Record",
51+
"RecordWithPGC",
5052
"Layer2CatalogObject",
5153
"Layer2Object",
5254
"RawCatalog",

app/data/model/interface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class RawCatalog(enum.Enum):
3030
aggregated data on layer 2.
3131
"""
3232

33+
ALL = "all"
3334
ICRS = "icrs"
3435
DESIGNATION = "designation"
3536
REDSHIFT = "redshift"

app/data/model/records.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,10 @@ class RecordCrossmatch:
4646
class RecordWithPGC:
4747
pgc: int
4848
record: Record
49+
50+
51+
@dataclass
class NatureRecord:
    """A single layer-1 nature (object type) observation tied to a PGC object.

    Attributes:
        pgc: PGC number of the object this record belongs to.
        record_id: identifier of the originating layer-0 record.
        type_name: object type reported by that record.
    """

    pgc: int
    record_id: str
    type_name: str

app/data/repositories/layer1.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def save_structured_data(self, table: str, columns: list[str], ids: list[str], d
3333
)
3434
rows = [[rid] + vals for rid, vals in zip(ids, data, strict=True)]
3535
with self.with_tx():
36-
cursor = self._storage.get_connection().cursor()
37-
cursor.executemany(query, rows)
36+
self._storage.execute_batch(query, rows)
3837

3938
def save_data(self, records: list[model.Record]) -> None:
4039
all_catalog_objects = []
@@ -126,6 +125,22 @@ def get_new_observations(
126125

127126
return records
128127

128+
def get_new_nature_records(self, dt: datetime.datetime, limit: int, offset: int) -> list[model.NatureRecord]:
    """Fetch nature records for objects whose layer-0 records changed after *dt*.

    Pagination is cursor-based: the inner query picks the next *limit*
    distinct PGC numbers greater than *offset*; the outer query then returns
    every nature record for those PGCs so the caller can aggregate per object.
    """
    query = """SELECT o.pgc, l1.record_id, l1.type_name
        FROM nature.data AS l1
        JOIN layer0.records AS o ON l1.record_id = o.id
        WHERE o.pgc IN (
            SELECT DISTINCT o.pgc
            FROM nature.data AS l1
            JOIN layer0.records AS o ON l1.record_id = o.id
            WHERE o.modification_time > %s AND o.pgc > %s
            ORDER BY o.pgc
            LIMIT %s
        )
        ORDER BY o.pgc ASC"""

    result: list[model.NatureRecord] = []
    for row in self._storage.query(query, params=[dt, offset, limit]):
        result.append(
            model.NatureRecord(pgc=int(row["pgc"]), record_id=row["record_id"], type_name=row["type_name"])
        )
    return result
143+
129144
def query_records(
130145
self,
131146
catalogs: list[model.RawCatalog],

app/data/repositories/layer2/repository.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ def __init__(self, storage: postgres.PgStorage, logger: structlog.stdlib.BoundLo
2424
self._logger = logger
2525
self._storage = storage
2626

27-
def get_last_update_time(self) -> datetime.datetime:
28-
return self._storage.query_one("SELECT dt FROM layer2.last_update WHERE catalog = %s", params=["all"])["dt"]
27+
def get_last_update_time(self, catalog: model.RawCatalog) -> datetime.datetime:
    """Return the timestamp of the last layer-2 update for *catalog*."""
    row = self._storage.query_one(
        "SELECT dt FROM layer2.last_update WHERE catalog = %s",
        params=[catalog.value],
    )
    return row["dt"]
2931

30-
def update_last_update_time(self, dt: datetime.datetime):
32+
def update_last_update_time(self, dt: datetime.datetime, catalog: model.RawCatalog) -> None:
    """Persist *dt* as the last layer-2 update time for *catalog*."""
    query = "UPDATE layer2.last_update SET dt = %s WHERE catalog = %s"
    self._storage.exec(query, params=[dt, catalog.value])
3537

3638
def get_orphaned_pgcs(self, catalogs: list[model.RawCatalog]) -> dict[str, list[int]]:
@@ -96,6 +98,20 @@ def save_data(self, objects: list[model.Layer2CatalogObject]):
9698

9799
self._storage.exec(query, params=params)
98100

101+
def save(self, table: str, columns: list[str], pgcs: list[int], data: list[list[Any]]) -> None:
    """Upsert one row per PGC into *table* as a single batched statement.

    Each row is ``[pgc, *values]``; on a pgc conflict every column is
    overwritten with the new value. No-op when *pgcs* is empty.
    Raises ValueError (via zip strict) if *pgcs* and *data* differ in length.
    """
    if not pgcs:
        return

    cols = ["pgc", *columns]
    values_clause = ",".join("%s" for _ in cols)
    update_clause = ", ".join(f"{c} = EXCLUDED.{c}" for c in cols)
    query = (
        f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({values_clause}) "
        f"ON CONFLICT (pgc) DO UPDATE SET {update_clause}"
    )

    batch = [[pgc, *values] for pgc, values in zip(pgcs, data, strict=True)]
    with self.with_tx():
        self._storage.execute_batch(query, batch)
114+
99115
def _construct_batch_query(
100116
self,
101117
catalogs: list[model.RawCatalog],

app/lib/storage/postgres/postgres_storage.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
from typing import Any
23

34
import numpy as np
@@ -101,6 +102,15 @@ def exec(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None
101102
cursor = self._connection.cursor()
102103
cursor.execute(query, params)
103104

105+
def execute_batch(self, query: str, rows: Sequence[Sequence[Any]]) -> None:
    """Run *query* once per row in *rows* on a single cursor (executemany).

    Raises RuntimeError if the connection has not been established.
    """
    if self._connection is None:
        raise RuntimeError("Unable to execute query: connection to Postgres was not established")

    log.debug("SQL execute batch", query=query.replace("\n", " "), num_rows=len(rows))

    cur = self._connection.cursor()
    cur.executemany(query, rows)
113+
104114
def query(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> list[rows.DictRow]:
105115
if params is None:
106116
params = []

app/tasks/layer2_import.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def prepare(self, config: interface.Config):
3737
self.layer2_repository = repositories.Layer2Repository(self.pg_storage, self.log)
3838

3939
def run(self):
40-
last_update_dt = self.layer2_repository.get_last_update_time()
40+
last_update_dt = self.layer2_repository.get_last_update_time(model.RawCatalog.ALL)
4141

4242
self.log.info("Starting Layer 2 import", last_update=last_update_dt.ctime())
4343

@@ -77,7 +77,7 @@ def run(self):
7777

7878
self.log.info("Updated catalog", catalog=catalog.value)
7979

80-
self.layer2_repository.update_last_update_time(datetime.datetime.now(tz=datetime.UTC))
80+
self.layer2_repository.update_last_update_time(datetime.datetime.now(tz=datetime.UTC), model.RawCatalog.ALL)
8181
self.log.info("Layer 2 import completed", last_update=last_update_dt.ctime())
8282

8383
def cleanup(self):

app/tasks/layer2_import_nature.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import datetime
2+
from typing import final
3+
4+
import structlog
5+
6+
from app.data import model, repositories
7+
from app.lib import containers
8+
from app.lib.storage import postgres
9+
from app.tasks import interface
10+
11+
12+
@final
class Layer2ImportNatureTask(interface.Task):
    """Aggregate layer-1 object-nature (type) records into layer 2.

    For every PGC object whose nature records changed since the last import,
    the most frequent ``type_name`` among its layer-1 records is chosen and
    upserted into ``layer2.nature``. Orphaned PGCs are removed afterwards.
    In dry-run mode nothing is written; a summary table is printed instead.
    """

    def __init__(
        self,
        logger: structlog.stdlib.BoundLogger,
        batch_size: int = 100000,
        dry_run: bool = False,
    ) -> None:
        self.log = logger
        self.batch_size = batch_size
        self.dry_run = dry_run

    @classmethod
    def name(cls) -> str:
        return "layer2-import-nature"

    def prepare(self, config: interface.Config) -> None:
        """Open the Postgres connection and build the layer repositories."""
        self.pg_storage = postgres.PgStorage(config.storage, self.log)
        self.pg_storage.connect()
        self.layer1_repository = repositories.Layer1Repository(self.pg_storage, self.log)
        self.layer2_repository = repositories.Layer2Repository(self.pg_storage, self.log)

    def run(self) -> None:
        """Import changed nature records, prune orphans, advance the watermark."""
        last_update_dt = self.layer2_repository.get_last_update_time(model.RawCatalog.NATURE)
        self.log.info(
            "Starting Layer 2 nature import",
            last_update=last_update_dt.ctime(),
            dry_run=self.dry_run,
        )

        objects_to_save = 0
        type_distribution: dict[str, int] = {}
        for offset, records in containers.read_batches(
            self.layer1_repository.get_new_nature_records,
            lambda data: len(data) == 0,
            0,
            # Cursor for the next page: largest pgc in the current batch.
            lambda d, _: d[-1].pgc,
            last_update_dt,
            batch_size=self.batch_size,
        ):
            records_by_pgc = containers.group_by(records, key_func=lambda r: r.pgc)
            pgcs: list[int] = []
            data: list[list[str]] = []
            for pgc, pgc_records in records_by_pgc.items():
                # Majority vote over this object's records; ties are broken by
                # first occurrence, matching dict insertion order.
                type_counts: dict[str, int] = {}
                for rec in pgc_records:
                    type_counts[rec.type_name] = type_counts.get(rec.type_name, 0) + 1
                max_type = max(type_counts, key=type_counts.__getitem__)
                type_distribution[max_type] = type_distribution.get(max_type, 0) + 1
                pgcs.append(pgc)
                data.append([max_type])
            if pgcs:
                objects_to_save += len(pgcs)
                if not self.dry_run:
                    self.layer2_repository.save("layer2.nature", ["type_name"], pgcs, data)
            self.log.info(
                "Processed batch",
                last_pgc=offset,
                batch_size=len(records),
                total_processed=objects_to_save,
            )

        orphaned = self.layer2_repository.get_orphaned_pgcs([model.RawCatalog.NATURE])
        pgcs_to_remove = [pgc for pgcs in orphaned.values() for pgc in pgcs]
        orphans_to_delete = len(pgcs_to_remove)
        if pgcs_to_remove and not self.dry_run:
            self.layer2_repository.remove_pgcs([model.RawCatalog.NATURE], pgcs_to_remove)

        if not self.dry_run:
            self.layer2_repository.update_last_update_time(
                datetime.datetime.now(tz=datetime.UTC), model.RawCatalog.NATURE
            )
        self.log.info("Layer 2 nature import completed", last_update=last_update_dt.ctime())

        if self.dry_run:
            self._print_summary(objects_to_save, orphans_to_delete, type_distribution)

    def _print_summary(
        self,
        objects_to_save: int,
        orphans_to_delete: int,
        type_distribution: dict[str, int],
    ) -> None:
        """Print an ASCII table describing what a non-dry run would change."""
        col_desc = "Description"
        col_count = "Count"
        type_rows = sorted(type_distribution.items())

        # Column widths: wide enough for the headers, a fixed minimum of 30,
        # and every cell that will be rendered below.
        width_desc = max(
            [len(col_desc), 30, len("Distribution by type")]
            + [len(f" {t}") for t, _ in type_rows]
        )
        width_count = max(
            [len(col_count), len(str(objects_to_save)), len(str(orphans_to_delete))]
            + [len(str(c)) for _, c in type_rows]
        )

        sep = f"+{'-' * (width_desc + 2)}+{'-' * (width_count + 2)}+"
        lines = [
            sep,
            f"| {col_desc:<{width_desc}} | {col_count:>{width_count}} |",
            sep,
            f"| {'Objects to be saved':<{width_desc}} | {objects_to_save:>{width_count}} |",
            f"| {'Orphans to be deleted':<{width_desc}} | {orphans_to_delete:>{width_count}} |",
            sep,
        ]
        if type_rows:
            lines.append(f"| {'Distribution by type':<{width_desc}} | {'':>{width_count}} |")
            lines.extend(f"| {f' {t}':<{width_desc}} | {c:>{width_count}} |" for t, c in type_rows)
            lines.append(sep)
        print("\n".join(lines))

    def cleanup(self) -> None:
        """Close the Postgres connection opened in prepare()."""
        self.pg_storage.disconnect()

app/tasks/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
interface,
88
layer0_marking,
99
layer2_import,
10+
layer2_import_nature,
1011
layer2_orphan_cleanup,
1112
submit_crossmatch,
1213
)
@@ -16,6 +17,7 @@
1617
layer0_marking.Layer0MarkingTask,
1718
submit_crossmatch.SubmitCrossmatchTask,
1819
layer2_import.Layer2ImportTask,
20+
layer2_import_nature.Layer2ImportNatureTask,
1921
layer2_orphan_cleanup.Layer2OrphanCleanupTask,
2022
]
2123

main.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,16 @@
1010

1111

1212
@click.group()
13-
def cli():
14-
pass
13+
@click.option(
14+
"--log-level",
15+
type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False),
16+
default="info",
17+
help="Set the logging level (for runtask and other commands that use it)",
18+
)
19+
@click.pass_context
20+
def cli(ctx: click.Context, log_level: str) -> None:
    """Root command group; records the chosen log level for subcommands."""
    ctx.ensure_object(dict)
    ctx.obj["log_level"] = log_level
1523

1624

1725
@cli.command(short_help=AdminAPICommand.help())
@@ -50,14 +58,15 @@ def dataapi(config: str):
5058
type=str,
5159
help="Path to input data file",
5260
)
53-
@click.option(
54-
"--log-level",
55-
type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False),
56-
default="info",
57-
help="Set the logging level",
58-
)
61+
@click.pass_context
5962
@click.argument("task_args", nargs=-1, type=click.UNPROCESSED)
60-
def runtask(task_name: str, input_data: str | None, log_level: str, task_args: tuple[str, ...]):
63+
def runtask(
    ctx: click.Context,
    task_name: str,
    input_data: str | None,
    task_args: tuple[str, ...],
) -> None:
    """Dispatch the named task, using the group-level log level if present."""
    group_opts = ctx.obj or {}
    log_level = group_opts.get("log_level", "info")
    commands.run(RunTaskCommand(task_name, input_data, None, task_args, log_level))
6271

6372

0 commit comments

Comments
 (0)