Skip to content

Commit f09dec0

Browse files
authored
Merge pull request #236 from w7-mgfcode/fix/planner-model-exogenous-and-empty-dates
fix(jobs,ui): reach model_exogenous + block empty assumption dates in the planner
2 parents df612a1 + 34104c9 commit f09dec0

6 files changed

Lines changed: 218 additions & 11 deletions

File tree

app/features/jobs/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class JobCreate(BaseModel):
2525
**Job Types and Required Params**:
2626
2727
- **train**: Train a forecasting model
28-
- `model_type`: Required - 'naive', 'seasonal_naive', 'linear_regression', etc.
28+
- `model_type`: Required - 'naive', 'seasonal_naive', 'moving_average', 'regression'.
2929
- `store_id`: Required - Store ID from /dimensions/stores
3030
- `product_id`: Required - Product ID from /dimensions/products
3131
- `start_date`: Required - Training data start (YYYY-MM-DD)

app/features/jobs/service.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ async def _execute_train(
426426
from app.features.forecasting.schemas import (
427427
MovingAverageModelConfig,
428428
NaiveModelConfig,
429+
RegressionModelConfig,
429430
SeasonalNaiveModelConfig,
430431
)
431432
from app.features.forecasting.service import ForecastingService
@@ -457,6 +458,12 @@ async def _execute_train(
457458
elif model_type == "moving_average":
458459
window_size = params.get("window_size", 7)
459460
config = MovingAverageModelConfig(window_size=window_size)
461+
elif model_type == "regression":
462+
config = RegressionModelConfig(
463+
max_iter=params.get("max_iter", 200),
464+
learning_rate=params.get("learning_rate", 0.05),
465+
max_depth=params.get("max_depth", 6),
466+
)
460467
else:
461468
msg = f"Unsupported model_type: {model_type}"
462469
raise ValueError(msg)

app/features/jobs/tests/test_service.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
import math
1010
from datetime import date
11+
from typing import Any, cast
12+
from unittest.mock import AsyncMock, patch
13+
14+
import pytest
15+
from sqlalchemy.ext.asyncio import AsyncSession
1116

1217
from app.features.backtesting.schemas import (
1318
BacktestResponse,
@@ -16,7 +21,9 @@
1621
SplitBoundary,
1722
SplitConfig,
1823
)
19-
from app.features.jobs.service import _finite, _shape_backtest_result
24+
from app.features.forecasting.schemas import RegressionModelConfig, TrainResponse
25+
from app.features.forecasting.service import ForecastingService
26+
from app.features.jobs.service import JobService, _finite, _shape_backtest_result
2027

2128

2229
def _fold(idx: int, mae: float, smape: float, wape: float, bias: float) -> FoldResult:
@@ -158,3 +165,59 @@ def test_finite_coerces_non_finite_values() -> None:
158165
assert _finite(math.nan) == 0.0
159166
assert _finite(math.inf) == 0.0
160167
assert _finite(-math.inf) == 0.0
168+
169+
170+
# =============================================================================
171+
# _execute_train regression-model support (#229)
172+
# =============================================================================
173+
174+
175+
def _fake_train_response(model_type: str) -> TrainResponse:
176+
"""Build a TrainResponse stub for mocking ForecastingService.train_model."""
177+
return TrainResponse(
178+
store_id=1,
179+
product_id=1,
180+
model_type=model_type,
181+
model_path="/data/artifacts/model_abc123def456.joblib",
182+
config_hash="cfg-hash",
183+
n_observations=400,
184+
train_start_date=date(2024, 1, 1),
185+
train_end_date=date(2024, 12, 31),
186+
duration_ms=12.0,
187+
)
188+
189+
190+
_REGRESSION_PARAMS: dict[str, Any] = {
191+
"model_type": "regression",
192+
"store_id": 1,
193+
"product_id": 1,
194+
"start_date": "2024-01-01",
195+
"end_date": "2024-12-31",
196+
}
197+
198+
199+
async def test_execute_train_builds_regression_config() -> None:
200+
"""A train job with model_type='regression' builds a RegressionModelConfig (#229)."""
201+
fake = _fake_train_response("regression")
202+
with patch.object(
203+
ForecastingService, "train_model", new=AsyncMock(return_value=fake)
204+
) as mock_train:
205+
result = await JobService()._execute_train(
206+
db=cast(AsyncSession, AsyncMock()),
207+
params=_REGRESSION_PARAMS,
208+
)
209+
assert mock_train.call_args is not None
210+
config = mock_train.call_args.kwargs["config"]
211+
assert isinstance(config, RegressionModelConfig)
212+
assert result["model_type"] == "regression"
213+
# run_id is parsed from the model_abc123def456.joblib artifact path.
214+
assert result["run_id"] == "abc123def456"
215+
216+
217+
async def test_execute_train_rejects_unsupported_model_type() -> None:
218+
"""_execute_train still rejects a genuinely unsupported model_type (e.g. lightgbm)."""
219+
with pytest.raises(ValueError, match="Unsupported model_type"):
220+
await JobService()._execute_train(
221+
db=cast(AsyncSession, AsyncMock()),
222+
params={**_REGRESSION_PARAMS, "model_type": "lightgbm"},
223+
)

frontend/src/lib/scenario-utils.test.ts

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { describe, expect, it } from 'vitest'
22
import {
3+
assumptionDateErrors,
34
buildMultiSeries,
45
coverageLabel,
56
coverageVariant,
@@ -143,3 +144,61 @@ describe('methodLabel', () => {
143144
expect(methodLabel('model_exogenous')).toBe('Model-driven')
144145
})
145146
})
147+
148+
describe('assumptionDateErrors', () => {
149+
const NONE = {
150+
priceEnabled: false,
151+
priceStart: '',
152+
priceEnd: '',
153+
promoEnabled: false,
154+
promoStart: '',
155+
promoEnd: '',
156+
}
157+
158+
it('reports no errors when nothing is enabled', () => {
159+
expect(assumptionDateErrors(NONE).hasErrors).toBe(false)
160+
})
161+
162+
it('flags both price dates when price is enabled and blank', () => {
163+
const e = assumptionDateErrors({ ...NONE, priceEnabled: true })
164+
expect(e.priceStart).toBe(true)
165+
expect(e.priceEnd).toBe(true)
166+
expect(e.hasErrors).toBe(true)
167+
})
168+
169+
it('clears price errors once both dates are filled', () => {
170+
const e = assumptionDateErrors({
171+
...NONE,
172+
priceEnabled: true,
173+
priceStart: '2026-07-01',
174+
priceEnd: '2026-07-14',
175+
})
176+
expect(e.hasErrors).toBe(false)
177+
})
178+
179+
it('flags only the blank promotion date', () => {
180+
const e = assumptionDateErrors({
181+
...NONE,
182+
promoEnabled: true,
183+
promoStart: '2026-07-01',
184+
promoEnd: '',
185+
})
186+
expect(e.promoStart).toBe(false)
187+
expect(e.promoEnd).toBe(true)
188+
expect(e.hasErrors).toBe(true)
189+
})
190+
191+
it('isolates errors per assumption (price ok, promo blank)', () => {
192+
const e = assumptionDateErrors({
193+
priceEnabled: true,
194+
priceStart: '2026-07-01',
195+
priceEnd: '2026-07-14',
196+
promoEnabled: true,
197+
promoStart: '',
198+
promoEnd: '',
199+
})
200+
expect(e.priceStart).toBe(false)
201+
expect(e.promoStart).toBe(true)
202+
expect(e.hasErrors).toBe(true)
203+
})
204+
})

frontend/src/lib/scenario-utils.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,41 @@ export function buildMultiSeries(comparison: MultiScenarioComparison): MultiSeri
135135
export function methodLabel(method: 'heuristic' | 'model_exogenous'): string {
136136
return method === 'model_exogenous' ? 'Model-driven' : 'Heuristic'
137137
}
138+
139+
/** Form state for the date-bearing assumptions (price, promotion). */
140+
export interface AssumptionDateState {
141+
priceEnabled: boolean
142+
priceStart: string
143+
priceEnd: string
144+
promoEnabled: boolean
145+
promoStart: string
146+
promoEnd: string
147+
}
148+
149+
/** Which enabled assumption date inputs are still blank. */
150+
export interface AssumptionDateErrors {
151+
priceStart: boolean
152+
priceEnd: boolean
153+
promoStart: boolean
154+
promoEnd: boolean
155+
hasErrors: boolean
156+
}
157+
158+
/**
159+
* Flag every enabled Price/Promotion assumption whose From/To date is blank.
160+
* The planner blocks Run/Save while `hasErrors` is true so the backend never
161+
* receives an empty-string date (which fails Pydantic date validation → 422).
162+
*/
163+
export function assumptionDateErrors(state: AssumptionDateState): AssumptionDateErrors {
164+
const priceStart = state.priceEnabled && !state.priceStart
165+
const priceEnd = state.priceEnabled && !state.priceEnd
166+
const promoStart = state.promoEnabled && !state.promoStart
167+
const promoEnd = state.promoEnabled && !state.promoEnd
168+
return {
169+
priceStart,
170+
priceEnd,
171+
promoStart,
172+
promoEnd,
173+
hasErrors: priceStart || priceEnd || promoStart || promoEnd,
174+
}
175+
}

frontend/src/pages/visualize/planner.tsx

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import {
3535
import { downloadCsv, toCsv } from '@/lib/csv-export'
3636
import { formatCurrency, formatNumber, getErrorMessage } from '@/lib/api'
3737
import {
38+
assumptionDateErrors,
3839
buildMultiSeries,
3940
coverageLabel,
4041
coverageVariant,
@@ -73,8 +74,11 @@ export default function WhatIfPlannerPage() {
7374
const [selectedJobId, setSelectedJobId] = useState('')
7475
const [horizon, setHorizon] = useState(14)
7576
const { data: job } = useJob(selectedJobId, !!selectedJobId)
76-
// A predict job's params.run_id is the baseline model artifact key.
77-
const baselineRunId = typeof job?.params?.run_id === 'string' ? job.params.run_id : null
77+
// A completed `train` job stores result.run_id — the model-artifact key
78+
// POST /scenarios/simulate resolves. (This is NOT a registry run id.)
79+
// A `regression` baseline routes the simulate call down the model_exogenous
80+
// re-forecast branch; other model types fall back to the heuristic factor.
81+
const baselineRunId = typeof job?.result?.run_id === 'string' ? job.result.run_id : null
7882

7983
// -- Assumption form state ---------------------------------------------
8084
const [priceEnabled, setPriceEnabled] = useState(false)
@@ -97,6 +101,19 @@ export default function WhatIfPlannerPage() {
97101
const [lifecycleStage, setLifecycleStage] =
98102
useState<(typeof LIFECYCLE_STAGES)[number]>('maturity')
99103

104+
// -- Derived validation ------------------------------------------------
105+
// Enabling Price/Promotion without filling both dates would submit empty
106+
// strings — Pydantic date validation rejects those with an RFC 7807 422.
107+
// Gate Run/Save on this so the form can never produce that request (#228).
108+
const dateErrors = assumptionDateErrors({
109+
priceEnabled,
110+
priceStart,
111+
priceEnd,
112+
promoEnabled,
113+
promoStart,
114+
promoEnd,
115+
})
116+
100117
// -- Results / persistence state ---------------------------------------
101118
const [simulated, setSimulated] = useState<ScenarioComparison | null>(null)
102119
const [planName, setPlanName] = useState('')
@@ -152,7 +169,7 @@ export default function WhatIfPlannerPage() {
152169
}
153170

154171
async function handleRun() {
155-
if (!baselineRunId) return
172+
if (!baselineRunId || dateErrors.hasErrors) return
156173
setRunError(null)
157174
setReloadId('')
158175
try {
@@ -169,7 +186,7 @@ export default function WhatIfPlannerPage() {
169186
}
170187

171188
async function handleSave() {
172-
if (!baselineRunId || !planName.trim()) return
189+
if (!baselineRunId || !planName.trim() || dateErrors.hasErrors) return
173190
setRunError(null)
174191
try {
175192
await createScenario.mutateAsync({
@@ -245,12 +262,15 @@ export default function WhatIfPlannerPage() {
245262
<CardHeader>
246263
<CardTitle>1. Pick a baseline</CardTitle>
247264
<CardDescription>
248-
Choose a completed prediction job — its model is the baseline this scenario adjusts.
265+
Choose a completed training job — its model is the baseline this scenario
266+
adjusts. A regression baseline is genuinely re-forecast through the model
267+
(model-driven); naive, seasonal-naive and moving-average baselines use a
268+
heuristic adjustment factor.
249269
</CardDescription>
250270
</CardHeader>
251271
<CardContent className="space-y-4">
252272
<JobPicker
253-
jobType="predict"
273+
jobType="train"
254274
selectedJobId={selectedJobId}
255275
onSelect={setSelectedJobId}
256276
autoSelectLatest
@@ -274,7 +294,7 @@ export default function WhatIfPlannerPage() {
274294
</div>
275295
{selectedJobId && !baselineRunId && (
276296
<p className="text-sm text-muted-foreground">
277-
The selected job has no model artifact — pick a completed predict job.
297+
The selected job has no model artifact — pick a completed train job.
278298
</p>
279299
)}
280300
</CardContent>
@@ -317,6 +337,9 @@ export default function WhatIfPlannerPage() {
317337
value={priceStart}
318338
onChange={(event) => setPriceStart(event.target.value)}
319339
/>
340+
{dateErrors.priceStart && (
341+
<p className="text-xs text-destructive">Required</p>
342+
)}
320343
</div>
321344
<div className="space-y-1">
322345
<span className="text-xs text-muted-foreground">To</span>
@@ -326,6 +349,9 @@ export default function WhatIfPlannerPage() {
326349
value={priceEnd}
327350
onChange={(event) => setPriceEnd(event.target.value)}
328351
/>
352+
{dateErrors.priceEnd && (
353+
<p className="text-xs text-destructive">Required</p>
354+
)}
329355
</div>
330356
</div>
331357
)}
@@ -368,6 +394,9 @@ export default function WhatIfPlannerPage() {
368394
value={promoStart}
369395
onChange={(event) => setPromoStart(event.target.value)}
370396
/>
397+
{dateErrors.promoStart && (
398+
<p className="text-xs text-destructive">Required</p>
399+
)}
371400
</div>
372401
<div className="space-y-1">
373402
<span className="text-xs text-muted-foreground">To</span>
@@ -377,6 +406,9 @@ export default function WhatIfPlannerPage() {
377406
value={promoEnd}
378407
onChange={(event) => setPromoEnd(event.target.value)}
379408
/>
409+
{dateErrors.promoEnd && (
410+
<p className="text-xs text-destructive">Required</p>
411+
)}
380412
</div>
381413
</div>
382414
)}
@@ -462,7 +494,10 @@ export default function WhatIfPlannerPage() {
462494
</div>
463495

464496
<div className="flex flex-wrap items-center gap-3 border-t pt-4">
465-
<Button onClick={handleRun} disabled={!baselineRunId || simulate.isPending}>
497+
<Button
498+
onClick={handleRun}
499+
disabled={!baselineRunId || simulate.isPending || dateErrors.hasErrors}
500+
>
466501
{simulate.isPending ? (
467502
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
468503
) : (
@@ -593,7 +628,12 @@ export default function WhatIfPlannerPage() {
593628
</div>
594629
<Button
595630
onClick={handleSave}
596-
disabled={!baselineRunId || !planName.trim() || createScenario.isPending}
631+
disabled={
632+
!baselineRunId ||
633+
!planName.trim() ||
634+
createScenario.isPending ||
635+
dateErrors.hasErrors
636+
}
597637
>
598638
{createScenario.isPending ? (
599639
<Loader2 className="mr-2 h-4 w-4 animate-spin" />

0 commit comments

Comments
 (0)