Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions app/features/seeder/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,4 +475,11 @@ class Phase2EnrichmentResponse(BaseModel):
"(product, replenishment_event, exogenous_signal, sales_returns)."
),
)
records_skipped: dict[str, int] = Field(
default_factory=dict,
description=(
"Count of rows skipped per table on an idempotent re-run "
"(populated when prior phase-2 data is already present in scope)."
),
)
duration_ms: float = Field(description="Wall-clock duration in milliseconds.")
294 changes: 201 additions & 93 deletions app/features/seeder/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import get_settings
from app.core.exceptions import UnprocessableEntityError
from app.core.exceptions import ConflictError, UnprocessableEntityError
from app.core.logging import get_logger
from app.features.data_platform.models import (
Calendar,
Expand Down Expand Up @@ -897,16 +898,28 @@ async def phase2_enrichment(
4. INSERT ``sales_returns`` rows sampled from the existing
positive-quantity ``sales_daily`` rows.

Idempotent (#312): a second call against an already-enriched scope is a
no-op for the inserting steps — ``exogenous_signal`` uses Postgres
``ON CONFLICT DO NOTHING`` against its two partial unique indexes, while
``replenishment_event`` and ``sales_returns`` (no natural-key unique
constraint) skip the section when rows already exist within the seeded
date range. The lifecycle ``UPDATE`` block is naturally idempotent under
a fixed seed.

Args:
db: Async database session.
params: Caller-supplied seed + probabilities.

Returns:
Phase2EnrichmentResponse with per-table row counts and wall-clock.
Phase2EnrichmentResponse with per-table ``records_created`` and
``records_skipped`` counts plus wall-clock.

Raises:
UnprocessableEntityError: When dimensions or calendar are empty
(caller must seed first); when the seeded calendar spans 0 days.
ConflictError: Defensive net — a residual ``IntegrityError`` (should
not fire after the idempotency logic above) is mapped to 409
RFC 7807 rather than bubbling as a raw 500.
"""
start_time = time.perf_counter()
rng = random.Random(params.seed)
Expand Down Expand Up @@ -934,112 +947,207 @@ async def phase2_enrichment(
seed=params.seed,
)

# ---- 1) Lifecycle: UPDATE per product
skipped: dict[str, int] = {
"product": 0,
"replenishment_event": 0,
"exogenous_signal": 0,
"sales_returns": 0,
}

try:
lifecycle_map = _assign_lifecycle(
rng,
product_ids,
start_date,
end_date,
discontinue_probability=params.discontinue_probability,
# ---- 1) Lifecycle: UPDATE per product (deterministic with seed → idempotent)
try:
lifecycle_map = _assign_lifecycle(
rng,
product_ids,
start_date,
end_date,
discontinue_probability=params.discontinue_probability,
)
except ValueError as exc:
raise UnprocessableEntityError(message=str(exc)) from exc
for pid, (launch, disc, stage) in lifecycle_map.items():
await db.execute(
update(Product)
.where(Product.id == pid)
.values(launch_date=launch, discontinue_date=disc, lifecycle_stage=stage)
)
product_updates = len(lifecycle_map)
await db.commit()

# ---- 2) Replenishment events (no unique constraint — section-level skip)
lt_cfg = LeadTimeConfig(
enable=True,
mean_lead_time_days=7,
lead_time_sigma_days=1.5,
safety_stock_days=3,
order_frequency_days=14,
fill_rate_mean=0.97,
fill_rate_sigma=0.05,
)
except ValueError as exc:
raise UnprocessableEntityError(message=str(exc)) from exc
for pid, (launch, disc, stage) in lifecycle_map.items():
await db.execute(
update(Product)
.where(Product.id == pid)
.values(launch_date=launch, discontinue_date=disc, lifecycle_stage=stage)
rep_gen = ReplenishmentGenerator(rng, lt_cfg)
rep_records = rep_gen.generate(store_ids, product_ids, dates, base_demand=100)
existing_rep = (
await db.scalar(
select(func.count())
.select_from(ReplenishmentEvent)
.where(
ReplenishmentEvent.date >= start_date,
ReplenishmentEvent.date <= end_date,
)
)
or 0
)
product_updates = len(lifecycle_map)
await db.commit()

# ---- 2) Replenishment events
lt_cfg = LeadTimeConfig(
enable=True,
mean_lead_time_days=7,
lead_time_sigma_days=1.5,
safety_stock_days=3,
order_frequency_days=14,
fill_rate_mean=0.97,
fill_rate_sigma=0.05,
)
rep_gen = ReplenishmentGenerator(rng, lt_cfg)
rep_records = rep_gen.generate(store_ids, product_ids, dates, base_demand=100)
for i in range(0, len(rep_records), PHASE2_ENRICHMENT_BATCH_SIZE):
chunk = rep_records[i : i + PHASE2_ENRICHMENT_BATCH_SIZE]
if chunk:
await db.execute(pg_insert(ReplenishmentEvent).values(chunk))
await db.commit()

# ---- 3) Exogenous signals (weather + macro)
ex_cfg = ExogenousSignalConfig(
enable_weather=True,
enable_macro=True,
enable_events=False,
weather_climatology_mean_c=15.0,
weather_amplitude_c=12.0,
weather_noise_sigma_c=2.0,
macro_initial_value=100.0,
macro_step_sigma=0.5,
)
ex_gen = ExogenousSignalGenerator(rng, ex_cfg)
ex_records = ex_gen.generate(dates, store_ids)
for i in range(0, len(ex_records), PHASE2_ENRICHMENT_BATCH_SIZE):
chunk = ex_records[i : i + PHASE2_ENRICHMENT_BATCH_SIZE]
if chunk:
await db.execute(pg_insert(ExogenousSignal).values(chunk))
await db.commit()

# ---- 4) Sales returns (sampled from existing positive-quantity sales)
ret_cfg = ReturnsConfig(
enable=True,
return_probability=params.returns_probability,
return_lag_days_min=1,
return_lag_days_max=14,
return_quantity_fraction=0.5,
)
ret_gen = ReturnsGenerator(rng, ret_cfg)
sales_rows = (
await db.execute(
select(
SalesDaily.date,
SalesDaily.store_id,
SalesDaily.product_id,
SalesDaily.quantity,
).where(SalesDaily.quantity > 0)
if existing_rep:
skipped["replenishment_event"] = len(rep_records)
rep_created = 0
logger.info(
"seeder.phase2_enrichment.skip",
table="replenishment_event",
existing_rows=existing_rep,
skipped=len(rep_records),
)
else:
for i in range(0, len(rep_records), PHASE2_ENRICHMENT_BATCH_SIZE):
chunk = rep_records[i : i + PHASE2_ENRICHMENT_BATCH_SIZE]
if chunk:
await db.execute(pg_insert(ReplenishmentEvent).values(chunk))
rep_created = len(rep_records)
await db.commit()

# ---- 3) Exogenous signals — ON CONFLICT DO NOTHING on the two partial unique
# indexes (uq_exogenous_signal_global, uq_exogenous_signal_per_store).
ex_cfg = ExogenousSignalConfig(
enable_weather=True,
enable_macro=True,
enable_events=False,
weather_climatology_mean_c=15.0,
weather_amplitude_c=12.0,
weather_noise_sigma_c=2.0,
macro_initial_value=100.0,
macro_step_sigma=0.5,
)
).fetchall()
sales_records: list[dict[str, date | int | Decimal]] = [
{
"date": r[0],
"store_id": r[1],
"product_id": r[2],
"quantity": int(r[3]),
}
for r in sales_rows
]
ret_records = ret_gen.generate(sales_records, end_date)
for i in range(0, len(ret_records), PHASE2_ENRICHMENT_BATCH_SIZE):
chunk = ret_records[i : i + PHASE2_ENRICHMENT_BATCH_SIZE]
if chunk:
await db.execute(pg_insert(SalesReturn).values(chunk))
await db.commit()
ex_gen = ExogenousSignalGenerator(rng, ex_cfg)
ex_records = ex_gen.generate(dates, store_ids)
existing_ex = (
await db.scalar(
select(func.count())
.select_from(ExogenousSignal)
.where(
ExogenousSignal.date >= start_date,
ExogenousSignal.date <= end_date,
)
)
or 0
)
for i in range(0, len(ex_records), PHASE2_ENRICHMENT_BATCH_SIZE):
chunk = ex_records[i : i + PHASE2_ENRICHMENT_BATCH_SIZE]
if chunk:
await db.execute(pg_insert(ExogenousSignal).values(chunk).on_conflict_do_nothing())
await db.commit()
new_ex = (
await db.scalar(
select(func.count())
.select_from(ExogenousSignal)
.where(
ExogenousSignal.date >= start_date,
ExogenousSignal.date <= end_date,
)
)
or 0
)
ex_created = max(0, new_ex - existing_ex)
skipped["exogenous_signal"] = max(0, len(ex_records) - ex_created)
if skipped["exogenous_signal"]:
logger.info(
"seeder.phase2_enrichment.skip",
table="exogenous_signal",
existing_rows=existing_ex,
skipped=skipped["exogenous_signal"],
)

# ---- 4) Sales returns (no unique constraint — section-level skip)
ret_cfg = ReturnsConfig(
enable=True,
return_probability=params.returns_probability,
return_lag_days_min=1,
return_lag_days_max=14,
return_quantity_fraction=0.5,
)
ret_gen = ReturnsGenerator(rng, ret_cfg)
sales_rows = (
await db.execute(
select(
SalesDaily.date,
SalesDaily.store_id,
SalesDaily.product_id,
SalesDaily.quantity,
).where(SalesDaily.quantity > 0)
)
).fetchall()
sales_records: list[dict[str, date | int | Decimal]] = [
{
"date": r[0],
"store_id": r[1],
"product_id": r[2],
"quantity": int(r[3]),
}
for r in sales_rows
]
ret_records = ret_gen.generate(sales_records, end_date)
existing_ret = (
await db.scalar(
select(func.count())
.select_from(SalesReturn)
.where(
SalesReturn.date >= start_date,
SalesReturn.date <= end_date,
)
)
or 0
)
if existing_ret:
skipped["sales_returns"] = len(ret_records)
ret_created = 0
logger.info(
"seeder.phase2_enrichment.skip",
table="sales_returns",
existing_rows=existing_ret,
skipped=len(ret_records),
)
else:
for i in range(0, len(ret_records), PHASE2_ENRICHMENT_BATCH_SIZE):
chunk = ret_records[i : i + PHASE2_ENRICHMENT_BATCH_SIZE]
if chunk:
await db.execute(pg_insert(SalesReturn).values(chunk))
ret_created = len(ret_records)
await db.commit()
except IntegrityError as exc:
await db.rollback()
raise ConflictError(
message=(
"Phase 2 enrichment hit a residual database constraint conflict "
"(idempotency guards should have caught this — please report)."
),
details={"error": str(exc.orig) if exc.orig else str(exc)},
) from exc

duration_ms = (time.perf_counter() - start_time) * 1000.0
counts = {
"product": product_updates,
"replenishment_event": len(rep_records),
"exogenous_signal": len(ex_records),
"sales_returns": len(ret_records),
"replenishment_event": rep_created,
"exogenous_signal": ex_created,
"sales_returns": ret_created,
}
logger.info(
"seeder.phase2_enrichment.complete",
duration_ms=duration_ms,
**counts,
records_created=counts,
records_skipped=skipped,
)
return schemas.Phase2EnrichmentResponse(
success=True,
records_created=counts,
records_skipped=skipped,
duration_ms=duration_ms,
)
Loading