diff --git a/.ai-instructions b/.ai-instructions index d15d554a..528a0118 160000 --- a/.ai-instructions +++ b/.ai-instructions @@ -1 +1 @@ -Subproject commit d15d554a54785e84cb80165443fa432a22adc45c +Subproject commit 528a0118ff9e02233bfc073da891b60e81b34754 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d0ea4a07..80b272c1 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cpu @@ -59,7 +59,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: type-checking @@ -82,7 +82,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cuda12 @@ -101,7 +101,7 @@ jobs: - uses: actions/checkout@v6 - uses: prefix-dev/setup-pixi@v0.9.5 with: - pixi-version: v0.66.0 + pixi-version: v0.67.2 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: tests-cuda12 @@ -116,7 +116,7 @@ jobs: # - uses: actions/checkout@v6 # - uses: prefix-dev/setup-pixi@v0.9.5 # with: - # pixi-version: v0.66.0 + # pixi-version: v0.67.2 # cache: true # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} # environments: tests-cpu diff --git a/.gitignore b/.gitignore index db181e4d..03acd3c8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,9 +31,6 @@ docs/_build/ .pixi/ node_modules/ -# pytask -.pytask.sqlite3 - # Python __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 270b9b25..f51ffb31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,11 +50,11 @@ repos: hooks: - id: yamllint - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.37.1 + rev: 0.37.2 hooks: - id: check-github-workflows - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.11 + rev: v0.15.12 hooks: - id: ruff-check args: @@ -86,7 +86,7 @@ repos: args: - --wrap - '88' - files: (AGENTS\.md|CLAUDE\.md|README\.md|modules/.*\.md|profiles/.*\.md) + files: (AGENTS\.md|CLAUDE\.md|README\.md) - id: mdformat additional_dependencies: - mdformat-myst diff --git a/AGENTS.md b/AGENTS.md index f9777113..7aea2844 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,6 @@ @.ai-instructions/profiles/tier-a.md @.ai-instructions/modules/jax.md @.ai-instructions/modules/pandas.md @.ai-instructions/modules/plotting.md +@.ai-instructions/modules/dags.md # PyLCM diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 3ef4c9b2..477d5635 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -55,7 +55,7 @@ def _build() -> tuple[object, object, object]: ) model = create_benchmark_model() - _, model_params = get_benchmark_params() + _, model_params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=_N_SUBJECTS, seed=0 ) diff --git a/pixi.lock b/pixi.lock index 5b3f6f2e..e507ccc8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -270,7 +270,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=adc8a19328608781a5cb2a65ab2d93d580163aae#adc8a19328608781a5cb2a65ab2d93d580163aae + - pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5328,7 +5328,7 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 -- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=adc8a19328608781a5cb2a65ab2d93d580163aae#adc8a19328608781a5cb2a65ab2d93d580163aae +- pypi: git+https://github.com/OpenSourceEconomics/aca-model.git?rev=134286108b7445f3e17e8824bcdd1739a98b6089#134286108b7445f3e17e8824bcdd1739a98b6089 name: aca-model version: 0.0.0 requires_dist: @@ -13962,8 +13962,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev142+g95ab1b648.d20260423 - sha256: ae4d75e092f6528909d9f185ed13eccc4aabeae96f5c6d41987cbb2704afa7e7 + version: 0.0.2.dev130+g3769d2d63.d20260429 + sha256: 365fd6893bcba5ab807032371430b19a9a9056c80ef3a9a201b56d34ced99e0e requires_dist: - cloudpickle>=3.1.2 - dags>=0.5.1 diff --git a/pyproject.toml b/pyproject.toml index ced35ccc..3fe01162 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" } tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" } type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" } [tool.pixi.feature.benchmarks.pypi-dependencies] -aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "adc8a19328608781a5cb2a65ab2d93d580163aae" } +aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "134286108b7445f3e17e8824bcdd1739a98b6089" } [tool.pixi.feature.cuda12] platforms = [ "linux-64" ] system-requirements = { cuda = "12" } diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index c81f9c6a..9a147c44 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -116,12 +116,22 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array: class LogSpacedGrid(UniformContinuousGrid): """A logarithmically spaced grid of continuous values. + Requires `start > 0`. + Example: -------- Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`. """ + def __post_init__(self) -> None: + _validate_continuous_grid( + start=self.start, + stop=self.stop, + n_points=self.n_points, + requires_positive_start=True, + ) + def to_jax(self) -> Float1D: """Convert the grid to a Jax array.""" return grid_coordinates.logspace( @@ -188,9 +198,23 @@ def pass_points_at_runtime(self) -> bool: return self.points is None def to_jax(self) -> Float1D: - """Convert the grid to a Jax array.""" + """Convert the grid to a Jax array. + + Raises `GridInitializationError` for runtime-supplied grids + (`pass_points_at_runtime=True`). To get the substituted points, + call `internal_regime.state_action_space(regime_params=...)` and + read from `.states[name]` or `.continuous_actions[name]`. + """ if self.points is None: - return jnp.full(self.n_points, jnp.nan) + raise GridInitializationError( + f"IrregSpacedGrid declared with n_points={self.n_points} and " + f"no points; values are supplied at runtime via " + f"params['']['']['points']. To get the " + f"substituted points, call " + f"`internal_regime.state_action_space(regime_params=...)` and " + f"read from `.states[name]` or `.continuous_actions[name]`. " + f"Use `.n_points` if only the shape is needed." + ) return jnp.asarray(self.points) @overload @@ -213,6 +237,7 @@ def _validate_continuous_grid( start: float, stop: float, n_points: int, + requires_positive_start: bool = False, ) -> None: """Validate the continuous grid parameters. @@ -220,6 +245,8 @@ def _validate_continuous_grid( start: The start value of the grid. stop: The stop value of the grid. n_points: The number of points in the grid. + requires_positive_start: If True, also require `start > 0` (used by + log-spaced grids since `log(x)` is undefined for `x <= 0`). Raises: GridInitializationError: If the grid parameters are invalid. @@ -235,6 +262,15 @@ def _validate_continuous_grid( if not valid_stop_type: error_messages.append("stop must be a scalar int or float value") + # Reject NaN/inf early — `start >= stop` returns False for NaN, so an + # un-finite start would otherwise pass silently and produce a broken grid. + if valid_start_type and not jnp.isfinite(start): + error_messages.append(f"start must be finite, got {start}") + valid_start_type = False + if valid_stop_type and not jnp.isfinite(stop): + error_messages.append(f"stop must be finite, got {stop}") + valid_stop_type = False + if not isinstance(n_points, int) or n_points < 1: error_messages.append( f"n_points must be an int greater than 0 but is {n_points}", @@ -243,6 +279,12 @@ def _validate_continuous_grid( if valid_start_type and valid_stop_type and start >= stop: error_messages.append("start must be less than stop") + if valid_start_type and requires_positive_start and start <= 0: + error_messages.append( + f"start must be > 0 for a log-spaced grid (got {start}); " + f"`log(x)` is undefined for `x <= 0`." + ) + if error_messages: msg = format_messages(error_messages) raise GridInitializationError(msg) @@ -275,15 +317,24 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None: f"Non-numeric elements found at indices: {non_numeric}" ) else: - # Check that points are in ascending order - for i in range(len(points) - 1): - if points[i] >= points[i + 1]: - error_messages.append( - "Points must be in strictly ascending order. " - f"Found points[{i}]={points[i]} >= " - f"points[{i + 1}]={points[i + 1]}" - ) - break + # Reject NaN/inf — comparisons with NaN are False, so the + # ascending-order check below would silently let them through. + non_finite = [(i, p) for i, p in enumerate(points) if not jnp.isfinite(p)] + if non_finite: + error_messages.append( + f"All elements of points must be finite. " + f"Non-finite elements found at: {non_finite}" + ) + else: + # Check that points are in strictly ascending order + for i in range(len(points) - 1): + if points[i] >= points[i + 1]: + error_messages.append( + "Points must be in strictly ascending order. " + f"Found points[{i}]={points[i]} >= " + f"points[{i + 1}]={points[i + 1]}" + ) + break if error_messages: msg = format_messages(error_messages) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 9ff454b2..59617eb6 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -242,12 +242,12 @@ class InternalRegime: """Flat resolved fixed params for this regime, used by to_dataframe targets.""" def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpace: - """Return the state-action space with runtime state grids filled in. + """Return the state-action space with runtime grids filled in. - For IrregSpacedGrid with runtime-supplied points, the grid points come from - params as `{state_name}__points`. For _ShockGrid with runtime-supplied params, - the grid points are computed from shock params in the params dict or - resolved_fixed_params. + For IrregSpacedGrid (state or continuous action) with runtime-supplied + points, the grid points come from params as `{name}__points`. For + `_ShockGrid` with runtime-supplied params, the grid points are computed + from shock params in the params dict or `resolved_fixed_params`. Args: regime_params: Flat regime parameters supplied at runtime. @@ -257,35 +257,68 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac """ all_params = {**self.resolved_fixed_params, **regime_params} - replacements: dict[str, ContinuousState | DiscreteState] = {} - for state_name, spec in self.grids.items(): - if state_name not in self._base_state_action_space.states: + state_replacements: dict[str, ContinuousState | DiscreteState] = {} + action_replacements: dict[str, ContinuousAction] = {} + for name, spec in self.grids.items(): + in_states = name in self._base_state_action_space.states + in_continuous_actions = ( + name in self._base_state_action_space.continuous_actions + ) + if not (in_states or in_continuous_actions): continue if isinstance(spec, IrregSpacedGrid) and spec.pass_points_at_runtime: - points_key = f"{state_name}__points" + points_key = f"{name}__points" if points_key not in all_params: continue - replacements[state_name] = cast( - "ContinuousState", all_params[points_key] - ) - elif isinstance(spec, _ShockGrid) and spec.params_to_pass_at_runtime: + if in_states: + state_replacements[name] = cast( + "ContinuousState", all_params[points_key] + ) + else: + action_replacements[name] = cast( + "ContinuousAction", all_params[points_key] + ) + # `_ShockGrid` is state-only by construction (intrinsic + # transitions, forbidden as actions per AGENTS.md). The + # `in_states` gate makes that invariant explicit — a + # `_ShockGrid` reaching the action branch would be a model + # bug, not something this method should silently substitute. + elif ( + in_states + and isinstance(spec, _ShockGrid) + and spec.params_to_pass_at_runtime + ): all_present = all( - f"{state_name}__{p}" in all_params - for p in spec.params_to_pass_at_runtime + f"{name}__{p}" in all_params for p in spec.params_to_pass_at_runtime ) if not all_present: continue shock_kw: dict[str, float] = dict(spec.params) for p in spec.params_to_pass_at_runtime: - shock_kw[p] = cast("float", all_params[f"{state_name}__{p}"]) - replacements[state_name] = spec.compute_gridpoints(**shock_kw) + shock_kw[p] = cast("float", all_params[f"{name}__{p}"]) + state_replacements[name] = spec.compute_gridpoints(**shock_kw) - if not replacements: + if not state_replacements and not action_replacements: return self._base_state_action_space - new_states = dict(self._base_state_action_space.states) | replacements + new_states = ( + MappingProxyType( + dict(self._base_state_action_space.states) | state_replacements + ) + if state_replacements + else None + ) + new_continuous_actions = ( + MappingProxyType( + dict(self._base_state_action_space.continuous_actions) + | action_replacements + ) + if action_replacements + else None + ) return self._base_state_action_space.replace( - states=MappingProxyType(new_states) + states=new_states, + continuous_actions=new_continuous_actions, ) diff --git a/src/lcm/pandas_utils.py b/src/lcm/pandas_utils.py index 805932e2..2d696a45 100644 --- a/src/lcm/pandas_utils.py +++ b/src/lcm/pandas_utils.py @@ -504,12 +504,15 @@ def _resolve_per_target_template_key( def _is_runtime_grid_param(*, func_name: FunctionName, regime: Regime) -> bool: """Check if a template function key refers to a runtime grid param.""" - if func_name not in regime.states: - return False - grid = regime.states[func_name] - return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or ( - isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime) - ) + if func_name in regime.states: + grid = regime.states[func_name] + return (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime) or ( + isinstance(grid, _ShockGrid) and bool(grid.params_to_pass_at_runtime) + ) + if func_name in regime.actions: + grid = regime.actions[func_name] + return isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime + return False def _fail_if_period_level(sr: pd.Series) -> None: diff --git a/src/lcm/params/regime_template.py b/src/lcm/params/regime_template.py index 299b2c27..85a66b47 100644 --- a/src/lcm/params/regime_template.py +++ b/src/lcm/params/regime_template.py @@ -32,7 +32,7 @@ def create_regime_params_template( Grids with runtime-supplied values (IrregSpacedGrid without points, `_ShockGrid` without full shock_params) add entries to the template under - pseudo-function keys matching the state name. + pseudo-function keys matching the state or action name. Args: regime: The regime as provided by the user. @@ -70,27 +70,70 @@ def create_regime_params_template( _validate_no_shadowing(function_params, regime) + _add_runtime_grid_params(function_params, regime) + + return MappingProxyType( + {k: MappingProxyType(v) for k, v in function_params.items()} + ) + + +def _add_runtime_grid_params( + function_params: dict[FunctionName, dict[str, str]], + regime: Regime, +) -> None: + """Add runtime-supplied state/action grid params to the template in place.""" for state_name, grid in regime.states.items(): if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: - if state_name in function_params: - raise InvalidNameError( - f"IrregSpacedGrid state '{state_name}' (with runtime-supplied " - f"points) conflicts with a function of the same name in the regime." - ) + _fail_if_runtime_grid_shadows_function( + function_params=function_params, name=state_name, kind="state" + ) function_params[state_name] = {"points": "Float1D"} elif isinstance(grid, _ShockGrid) and grid.params_to_pass_at_runtime: - if state_name in function_params: - raise InvalidNameError( - f"_ShockGrid state '{state_name}' (with runtime-supplied params) " - f"conflicts with a function of the same name in the regime." - ) + _fail_if_runtime_grid_shadows_function( + function_params=function_params, + name=state_name, + kind="_ShockGrid state", + ) function_params[state_name] = dict.fromkeys( grid.params_to_pass_at_runtime, "float" ) - return MappingProxyType( - {k: MappingProxyType(v) for k, v in function_params.items()} - ) + for action_name, grid in regime.actions.items(): + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + _fail_if_runtime_grid_shadows_function( + function_params=function_params, name=action_name, kind="action" + ) + function_params[action_name] = {"points": "Float1D"} + + +def _fail_if_runtime_grid_shadows_function( + *, + function_params: dict[FunctionName, dict[str, str]], + name: str, + kind: str, +) -> None: + """Raise if a runtime grid name collides with an existing function name. + + Runtime-supplied state and action grids contribute pseudo-function entries + to the params template (keyed by the state or action name). Letting such a + pseudo-function entry shadow a real regime function would silently break + parameter resolution, so we reject it at template-construction time. + + Args: + function_params: Template entries collected so far, keyed by + (pseudo-)function name. + name: State or action name being added. + kind: `"state"` or `"action"`, surfaced in the error message. + + Raises: + InvalidNameError: If `name` already exists in `function_params`. + + """ + if name in function_params: + raise InvalidNameError( + f"IrregSpacedGrid {kind} '{name}' (with runtime-supplied " + f"points/params) conflicts with a function of the same name in the regime." + ) def _collect_all_functions_for_template( diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 0754c0a7..92656692 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -5,6 +5,9 @@ """ +import ast +import inspect +import textwrap from collections.abc import Callable, Mapping from typing import TypeAliasType @@ -125,6 +128,7 @@ def validate_logical_consistency(regime: Regime) -> None: error_messages.extend(_validate_active(regime.active)) error_messages.extend(_validate_state_transitions(regime)) + error_messages.extend(_validate_function_output_grid_indexing(regime)) states_and_actions_overlap = set(regime.states) & set(regime.actions) if states_and_actions_overlap: @@ -138,6 +142,103 @@ def validate_logical_consistency(regime: Regime) -> None: raise RegimeInitializationError(msg) +def _validate_function_output_grid_indexing(regime: Regime) -> list[str]: + """Detect the regime-function-output / discrete-grid-indexed-input name clash. + + The unsafe pattern is: a regime function `f` takes a discrete grid `g` + (state, action, or derived categorical) as an input — so `f`'s output + is a per-cell scalar — and a consumer then indexes `f[g]`. The + consumer is indexing a 0-d array by a scalar integer, which raises + `IndexError` at trace time. The fix is to drop the redundant `[g]` + in the consumer (or refactor `f` not to take `g`). + """ + function_output_names = set(regime.functions) + discrete_grid_names = ( + {name for name, grid in regime.states.items() if isinstance(grid, DiscreteGrid)} + | { + name + for name, grid in regime.actions.items() + if isinstance(grid, DiscreteGrid) + } + | set(regime.derived_categoricals) + ) + if not function_output_names or not discrete_grid_names: + return [] + + # Only treat `func_output[grid]` as unsafe when the producing function + # *also* takes `grid` as an input — that is the case where the output + # is per-cell scalar and the consumer's indexing is wrong. If the + # producing function does not take `grid`, its output shape is + # whatever it computed (typically an array indexable by `grid`) and + # the consumer pattern is correct. + function_inputs: dict[str, set[str]] = {} + for name, func in regime.functions.items(): + try: + function_inputs[name] = set(inspect.signature(func).parameters) + except ValueError, TypeError: + function_inputs[name] = set() + + consumers: list[tuple[str, Callable]] = [] + consumers.extend(regime.functions.items()) + consumers.extend(regime.constraints.items()) + if callable(regime.transition): + consumers.append(("regime_transition", regime.transition)) + + errors: list[str] = [] + for consumer_name, func in consumers: + clashes = _find_function_output_grid_indexing( + func=func, + function_output_names=function_output_names, + discrete_grid_names=discrete_grid_names, + ) + for func_output_name, grid_name in clashes: + if grid_name not in function_inputs.get(func_output_name, set()): + continue + errors.append( + f"Consumer '{consumer_name}' indexes regime function output " + f"'{func_output_name}' by discrete grid '{grid_name}' " + f"(`{func_output_name}[{grid_name}]`), but '{func_output_name}' " + f"already takes '{grid_name}' as input — its output is a " + f"per-cell scalar, so the indexing raises IndexError at trace " + f"time. Drop the redundant `[{grid_name}]` in '{consumer_name}', " + f"or refactor '{func_output_name}' not to take '{grid_name}' " + f"as input." + ) + return errors + + +def _find_function_output_grid_indexing( + *, + func: Callable, + function_output_names: set[str], + discrete_grid_names: set[str], +) -> list[tuple[str, str]]: + """Return `(function_output_name, grid_name)` clashes inside `func`'s body.""" + try: + source = textwrap.dedent(inspect.getsource(func)) + except OSError, TypeError: + return [] + try: + tree = ast.parse(source) + except SyntaxError: + return [] + + clashes: list[tuple[str, str]] = [] + for node in ast.walk(tree): + if not isinstance(node, ast.Subscript): + continue + if not isinstance(node.value, ast.Name): + continue + if node.value.id not in function_output_names: + continue + if not isinstance(node.slice, ast.Name): + continue + if node.slice.id not in discrete_grid_names: + continue + clashes.append((node.value.id, node.slice.id)) + return clashes + + def collect_state_transitions( states: Mapping[StateName, Grid], state_transitions: Mapping[ diff --git a/src/lcm/simulation/initial_conditions.py b/src/lcm/simulation/initial_conditions.py index 1be89309..2cdd3c3b 100644 --- a/src/lcm/simulation/initial_conditions.py +++ b/src/lcm/simulation/initial_conditions.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Mapping, Sequence from types import MappingProxyType -from typing import Never +from typing import Never, cast import jax import numpy as np @@ -25,6 +25,7 @@ from lcm.regime_building.Q_and_F import _get_feasibility from lcm.typing import ( ActionName, + FlatRegimeParams, InternalParams, RegimeName, RegimeNamesToIds, @@ -579,11 +580,21 @@ def _check_regime_feasibility( # noqa: C901 if not action_names: return None - grids = MappingProxyType( - {name: spec.to_jax() for name, spec in internal_regime.grids.items()} + # Build the state-action space with runtime-supplied grid points + # substituted. The base grid's `to_jax()` raises for runtime-supplied + # `IrregSpacedGrid`s declared with `pass_points_at_runtime=True`, so the + # validator must read points from `state_action_space(regime_params=...)`. + state_action_space = internal_regime.state_action_space( + regime_params=cast("FlatRegimeParams", MappingProxyType(dict(regime_params))), + ) + action_grids: dict[str, Array] = { + **state_action_space.discrete_actions, + **state_action_space.continuous_actions, + } + flat_actions = _build_flat_action_grid( + action_names=action_names, + grids=MappingProxyType(action_grids), ) - - flat_actions = _build_flat_action_grid(action_names=action_names, grids=grids) filtered_params = {k: v for k, v in regime_params.items() if k in accepted} state_names = list(internal_regime.variable_info.query("is_state").index) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 9c017db3..d54d2674 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -244,6 +244,7 @@ def _simulate_regime_in_period( state_action_space = create_regime_state_action_space( internal_regime=internal_regime, states=states, + regime_params=internal_params[regime_name], ) # Compute optimal actions # We need to pass the value function array of the next period to the diff --git a/src/lcm/simulation/transitions.py b/src/lcm/simulation/transitions.py index 67e6df67..18bdbb6b 100644 --- a/src/lcm/simulation/transitions.py +++ b/src/lcm/simulation/transitions.py @@ -15,7 +15,7 @@ from lcm.interfaces import InternalRegime, StateActionSpace from lcm.simulation.random import generate_simulation_keys -from lcm.state_action_space import create_state_action_space +from lcm.state_action_space import _validate_all_states_present from lcm.typing import ( ActionName, Bool1D, @@ -31,29 +31,37 @@ def create_regime_state_action_space( *, internal_regime: InternalRegime, states: MappingProxyType[str, Array], + regime_params: FlatRegimeParams, ) -> StateActionSpace: """Create the state-action space containing only the relevant subjects in a regime. + Continuous action grids declared with `pass_points_at_runtime=True` are + completed from `regime_params` (via + `InternalRegime.state_action_space`). + Args: internal_regime: The internal regime instance. states: The current states of all subjects. + regime_params: Flat regime parameters supplied at runtime, used to + substitute runtime-supplied action gridpoints. Returns: The state-action space for the subjects in the regime. """ - relevant_state_names = internal_regime.variable_info.query("is_state").index + base = internal_regime.state_action_space(regime_params=regime_params) + relevant_state_names = internal_regime.variable_info.query("is_state").index states_for_state_action_space = { sn: states[f"{internal_regime.name}__{sn}"] for sn in relevant_state_names } - - return create_state_action_space( - variable_info=internal_regime.variable_info, - grids=internal_regime.grids, - states=states_for_state_action_space, + _validate_all_states_present( + provided_states=states_for_state_action_space, + required_state_names=set(relevant_state_names), ) + return base.replace(states=MappingProxyType(states_for_state_action_space)) + def calculate_next_states( *, diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index ad6f6c1f..7c88e165 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -1,9 +1,10 @@ from types import MappingProxyType +import jax.numpy as jnp import pandas as pd from jax import Array -from lcm.grids import Grid +from lcm.grids import Grid, IrregSpacedGrid from lcm.interfaces import StateActionSpace from lcm.typing import StateName, StateOrActionName @@ -33,7 +34,8 @@ def create_state_action_space( """ if states is None: _states = { - sn: grids[sn].to_jax() for sn in variable_info.query("is_state").index + sn: _grid_to_jax_or_placeholder(grids[sn]) + for sn in variable_info.query("is_state").index } else: _validate_all_states_present( @@ -47,7 +49,7 @@ def create_state_action_space( for name in variable_info.query("is_action & is_discrete").index } continuous_actions = { - name: grids[name].to_jax() + name: _grid_to_jax_or_placeholder(grids[name]) for name in variable_info.query("is_action & is_continuous").index } state_and_discrete_action_names = tuple( @@ -62,6 +64,21 @@ def create_state_action_space( ) +def _grid_to_jax_or_placeholder(grid: Grid) -> Array: + """Return the grid's points, or a NaN placeholder for runtime-supplied grids. + + `IrregSpacedGrid.to_jax()` raises when its points haven't been supplied — that + is the right behaviour everywhere except here: the base state-action space + needs a *shape-correct* array to wire through pytree structures and AOT + tracing before runtime substitution by + `InternalRegime.state_action_space(regime_params=...)`. NaN (rather than + zero) makes any accidental computation against the placeholder fail loudly. + """ + if isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime: + return jnp.full(grid.n_points, jnp.nan) + return grid.to_jax() + + def _validate_all_states_present( *, provided_states: dict[StateName, Array], required_state_names: set[StateName] ) -> None: diff --git a/tests/test_function_output_grid_indexing.py b/tests/test_function_output_grid_indexing.py new file mode 100644 index 00000000..adc37816 --- /dev/null +++ b/tests/test_function_output_grid_indexing.py @@ -0,0 +1,272 @@ +"""Tests for the regime-function-output / discrete-grid-indexed-input name clash. + +The unsafe pattern is: a regime function `f` takes a discrete grid `g` (state, +action, or derived categorical) as an input — so `f`'s output is a per-cell +scalar — and a consumer then indexes `f[g]`. The consumer is indexing a 0-d +array by a scalar integer, which raises `IndexError` at trace time. The +validator catches the pattern at construction so the user gets a clear message +instead of a cryptic JAX trace error during solve. + +The fix is to drop the redundant `[g]` in the consumer (or refactor `f` not +to take `g`). +""" + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.exceptions import RegimeInitializationError +from lcm.typing import ContinuousAction, DiscreteAction, DiscreteState, FloatND + + +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _per_type_scale_takes_pref_type( + pref_type: DiscreteState, some_param: FloatND +) -> FloatND: + """Takes pref_type — output is per-cell scalar; consumer must not re-index.""" + return jnp.abs(1.0 / (1.0 - some_param[pref_type])) + + +def _utility_redundantly_indexes( + consumption: ContinuousAction, + pref_type: DiscreteState, + per_type_scale: FloatND, +) -> FloatND: + # The clash: per_type_scale's producer takes pref_type, so its output is a + # per-cell scalar. Indexing that scalar by pref_type again raises IndexError + # at trace time. + return per_type_scale[pref_type] * jnp.log(consumption + 1.0) + + +def _next_regime(period: int) -> FloatND: + return jnp.where(period >= 1, RegimeId.dead, RegimeId.alive) + + +def _make_clashing_model() -> Model: + alive = Regime( + functions={ + "utility": _utility_redundantly_indexes, + "per_type_scale": _per_type_scale_takes_pref_type, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_function_output_indexed_by_state_raises(): + """A regime function output redundantly indexed by a discrete state inside + another regime function must raise on construction.""" + with pytest.raises( + RegimeInitializationError, + match=r"per_type_scale.*pref_type", + ): + _make_clashing_model() + + +def _utility_consumes_scalar( + consumption: ContinuousAction, + pref_type: DiscreteState, # noqa: ARG001 + per_type_scale: FloatND, +) -> FloatND: + # Safe variant: per_type_scale is consumed as a scalar (no [pref_type] indexing). + return per_type_scale * jnp.log(consumption + 1.0) + + +def test_safe_pattern_does_not_raise(): + """The safe pattern (function takes the state, returns a scalar; consumer + uses it directly) builds fine.""" + alive = Regime( + functions={ + "utility": _utility_consumes_scalar, + "per_type_scale": _per_type_scale_takes_pref_type, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def _per_type_scale_array_output(some_param: FloatND) -> FloatND: + """Does NOT take pref_type — output is `(n_pref_types,)`-shaped. + + A consumer indexing this output by `pref_type` is correct: the indexing + selects the per-type entry. The validator must NOT flag this case. + """ + return jnp.abs(1.0 / (1.0 - some_param)) + + +def test_array_valued_producer_indexed_by_state_does_not_raise(): + """When the producing function does NOT take the discrete grid as input, + its output stays array-shaped and `func_output[grid]` is correct code.""" + alive = Regime( + functions={ + "utility": _utility_redundantly_indexes, + "per_type_scale": _per_type_scale_array_output, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_function_output_indexed_by_derived_categorical_raises(): + """The check applies to derived categoricals (function outputs treated as + categoricals via `derived_categoricals`), not only states.""" + + @categorical(ordered=False) + class IsMarried: + single: int + married: int + + def _is_married(spousal_income: DiscreteState) -> DiscreteState: + return jnp.int32(spousal_income > 0) + + def _per_marital_scale(is_married: DiscreteState, some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param[is_married])) + + def _utility_clash( + consumption: ContinuousAction, + is_married: DiscreteState, + per_marital_scale: FloatND, + ) -> FloatND: + return per_marital_scale[is_married] * jnp.log(consumption + 1.0) + + @categorical(ordered=True) + class SpousalIncome: + single: int + married_no_inc: int + married_has_inc: int + + with pytest.raises( + RegimeInitializationError, + match=r"per_marital_scale.*is_married", + ): + Regime( + functions={ + "utility": _utility_clash, + "per_marital_scale": _per_marital_scale, + "is_married": _is_married, + }, + states={"spousal_income": DiscreteGrid(SpousalIncome)}, + state_transitions={"spousal_income": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + derived_categoricals={"is_married": DiscreteGrid(IsMarried)}, + transition=_next_regime, + active=lambda age: age < 2, + ) + + +def test_function_output_indexed_by_discrete_action_raises(): + """The check applies to discrete actions, not only states/derived categoricals.""" + + @categorical(ordered=False) + class WorkChoice: + no_work: int + work: int + + def _per_choice_scale(labor_supply: DiscreteAction, some_param: FloatND) -> FloatND: + return jnp.abs(1.0 / (1.0 - some_param[labor_supply])) + + def _utility_clash( + consumption: ContinuousAction, + labor_supply: DiscreteAction, + per_choice_scale: FloatND, + ) -> FloatND: + return per_choice_scale[labor_supply] * jnp.log(consumption + 1.0) + + with pytest.raises( + RegimeInitializationError, + match=r"per_choice_scale.*labor_supply", + ): + Regime( + functions={ + "utility": _utility_clash, + "per_choice_scale": _per_choice_scale, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={ + "consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5), + "labor_supply": DiscreteGrid(WorkChoice), + }, + transition=_next_regime, + active=lambda age: age < 2, + ) + + +def test_constraint_indexing_function_output_by_state_raises(): + """The check applies to regime constraints too, not only `functions`.""" + + def _constraint_indexing_function_output( + consumption: ContinuousAction, + pref_type: DiscreteState, + per_type_scale: FloatND, + ) -> FloatND: + return consumption <= per_type_scale[pref_type] + + with pytest.raises( + RegimeInitializationError, + match=r"per_type_scale.*pref_type", + ): + Regime( + functions={ + "utility": lambda consumption, pref_type, per_type_scale: jnp.log( # noqa: ARG005 + consumption + 1.0 + ), + "per_type_scale": _per_type_scale_takes_pref_type, + }, + states={"pref_type": DiscreteGrid(PrefType)}, + state_transitions={"pref_type": None}, + actions={"consumption": LinSpacedGrid(start=0.1, stop=5.0, n_points=5)}, + constraints={"feasibility": _constraint_indexing_function_output}, + transition=_next_regime, + active=lambda age: age < 2, + ) diff --git a/tests/test_grids.py b/tests/test_grids.py index 045f829b..da2370d9 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -217,6 +217,36 @@ def test_logspace_grid_invalid_start(): LogSpacedGrid(start=1, stop=0, n_points=10) +def test_logspace_grid_rejects_zero_start(): + with pytest.raises(GridInitializationError, match="log-spaced grid"): + LogSpacedGrid(start=0, stop=10, n_points=5) + + +def test_logspace_grid_rejects_negative_start(): + with pytest.raises(GridInitializationError, match="log-spaced grid"): + LogSpacedGrid(start=-1.0, stop=10, n_points=5) + + +def test_validate_continuous_grid_rejects_nan_start(): + with pytest.raises(GridInitializationError, match="start must be finite"): + _validate_continuous_grid(start=float("nan"), stop=10, n_points=5) + + +def test_validate_continuous_grid_rejects_inf_stop(): + with pytest.raises(GridInitializationError, match="stop must be finite"): + _validate_continuous_grid(start=1, stop=float("inf"), n_points=5) + + +def test_irreg_spaced_grid_rejects_nan_points(): + with pytest.raises(GridInitializationError, match="must be finite"): + IrregSpacedGrid(points=(1.0, float("nan"), 3.0)) + + +def test_irreg_spaced_grid_rejects_inf_points(): + with pytest.raises(GridInitializationError, match="must be finite"): + IrregSpacedGrid(points=(1.0, 2.0, float("inf"))) + + def test_replace_mixin(): grid = LinSpacedGrid(start=1, stop=5, n_points=5) new_grid = grid.replace(start=0) diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index 9d588649..61b2414a 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -69,7 +69,15 @@ def test_runtime_grid_creation(): grid = IrregSpacedGrid(n_points=5) assert grid.pass_points_at_runtime assert grid.n_points == 5 - assert grid.to_jax().shape == (5,) # placeholder zeros + + +def test_runtime_grid_to_jax_raises_before_substitution(): + """`to_jax()` on a runtime grid is a bug — substitution happens at + solve/simulate time via `InternalRegime.state_action_space(regime_params=...)`. + The error message must point the caller at that API.""" + grid = IrregSpacedGrid(n_points=5) + with pytest.raises(Exception, match="state_action_space"): + grid.to_jax() def test_fixed_grid_not_runtime(): @@ -141,3 +149,167 @@ def test_runtime_grid_matches_fixed(): for period in V_fixed: if "alive" in V_fixed[period] and "alive" in V_runtime[period]: aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) + + +def _make_action_grid_model(*, consumption_grid: IrregSpacedGrid) -> Model: + """Create a 2-regime model where consumption is the runtime-points action grid.""" + alive = Regime( + functions={"utility": _utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": consumption_grid}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age >= 2, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_runtime_action_grid_in_params_template(): + """IrregSpacedGrid action with runtime-supplied points adds 'points' to template.""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + alive_template = model._params_template["alive"] + assert "consumption" in alive_template + assert "points" in alive_template["consumption"] + + +def test_solve_with_runtime_action_grid(): + """Solve should work when action grid points are provided via params.""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}, + } + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + assert len(period_to_regime_to_V_arr) > 0 + + +def test_runtime_action_grid_matches_fixed(): + """Runtime action grid with same points gives same V as a fixed action grid.""" + points = jnp.linspace(0.1, 5.0, 5) + + model_fixed = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(points=list(points.tolist())), + ) + params_fixed = {"discount_factor": 0.95, "interest_rate": 0.05} + V_fixed = model_fixed.solve(params=params_fixed, log_level="off") + + model_runtime = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params_runtime = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": points}}, + } + V_runtime = model_runtime.solve(params=params_runtime, log_level="off") + + for period in V_fixed: + if "alive" in V_fixed[period] and "alive" in V_runtime[period]: + aaae(V_fixed[period]["alive"], V_runtime[period]["alive"]) + + +def test_runtime_action_grid_changes_solution(): + """Different runtime action points should yield different V (sanity check).""" + model = _make_action_grid_model( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + base = {"discount_factor": 0.95, "interest_rate": 0.05} + V_low = model.solve( + params=base | {"alive": {"consumption": {"points": jnp.linspace(0.1, 1.0, 5)}}}, + log_level="off", + ) + V_high = model.solve( + params=base | {"alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}}, + log_level="off", + ) + # Period 0 alive value should differ when the action support differs + assert not jnp.allclose(V_low[0]["alive"], V_high[0]["alive"]) + + +def _make_action_grid_model_with_stateful_dead( + *, consumption_grid: IrregSpacedGrid +) -> Model: + """Variant where `dead` has a `wealth` state so its utility depends on it. + + Used to surface NaN propagation when the simulate path forgets to + substitute runtime-supplied action gridpoints. + """ + + def _alive_utility( + consumption: ContinuousAction, wealth: ContinuousState + ) -> FloatND: + return jnp.log(consumption + 1) + 0.01 * wealth + + def _dead_utility(wealth: ContinuousState) -> FloatND: + return jnp.log(wealth + 1) + + alive = Regime( + functions={"utility": _alive_utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + state_transitions={ + "wealth": { + "alive": _next_wealth, + "dead": _next_wealth, + }, + }, + actions={"consumption": consumption_grid}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age < 2, + ) + dead = Regime( + transition=None, + functions={"utility": _dead_utility}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=5)}, + active=lambda _age: True, + ) + return Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RegimeId, + ) + + +def test_simulate_with_runtime_action_grid_no_nan() -> None: + """Simulate must substitute runtime-supplied action gridpoints into the + state-action space; otherwise the action grid is filled with NaN + placeholders, optimal_actions become NaN, next_states propagate NaN to + the dead regime, and `validate_V` raises. + """ + model = _make_action_grid_model_with_stateful_dead( + consumption_grid=IrregSpacedGrid(n_points=5), + ) + params = { + "discount_factor": 0.95, + "interest_rate": 0.05, + "alive": {"consumption": {"points": jnp.linspace(0.1, 5.0, 5)}}, + } + initial_conditions = { + "regime": jnp.array([RegimeId.alive, RegimeId.alive, RegimeId.alive]), + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([2.0, 5.0, 9.0]), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + df = result.to_dataframe() + assert not df["value"].isna().any() diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py new file mode 100644 index 00000000..53a18594 --- /dev/null +++ b/tests/test_single_feasible_action.py @@ -0,0 +1,545 @@ +"""Reproduce ways `solve` can raise `InvalidValueFunctionError: all values are NaN`. + +Two hypotheses are exercised here: + +1. **single/all infeasible action**: when constraints leave a single (or no) + consumption gridpoint feasible. `max(Q, where=F, initial=-inf)` is supposed + to mask infeasible Q values, but if Q itself contains NaN at infeasible + cells (e.g. `log(0)` at `consumption=0`), the mask is enough — proven by + the tests in this file. `-inf` from all-infeasible cells does not cascade + to NaN by itself either. + +2. **CRRA bequest with discrete-state indexing under jnp.where**: a bequest + function evaluates *both* branches of `jnp.where(jnp.isclose(gamma, 1), + log_branch, power_branch)`. For a parameter set where gamma is exactly 1 + for one categorical type but not the other, the power_branch divides by + `1 - gamma = 0` for the gamma=1 type. NaN/Inf from the *unselected* + branch leaks through `jnp.where`'s gradient/forward pass under JIT + tracing in some configurations. + +The model uses an `IrregSpacedGrid` consumption action with runtime-supplied +points to also exercise the runtime-action-grids substitution path. +""" + +import jax.numpy as jnp +import pytest + +from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical +from lcm.grids import IrregSpacedGrid +from lcm.grids.coordinates import get_irreg_coordinate +from lcm.regime_building.ndimage import map_coordinates +from lcm.typing import ContinuousAction, ContinuousState, DiscreteState, FloatND + + +@categorical(ordered=False) +class RegimeId: + alive: int + dead: int + + +def _utility(consumption: ContinuousAction) -> FloatND: + """CRRA-like utility. log requires consumption > 0.""" + return jnp.log(consumption) + + +def _next_wealth( + wealth: ContinuousState, + consumption: ContinuousAction, +) -> ContinuousState: + return wealth - consumption + + +def _borrowing_constraint( + consumption: ContinuousAction, wealth: ContinuousState +) -> FloatND: + return consumption <= wealth + + +def _next_regime(age: int, last_alive_age: int) -> FloatND: + """Deterministic regime transition: alive→alive while age= last_alive_age, RegimeId.dead, RegimeId.alive) + + +def _build_model( + *, + wealth_lo: float, + wealth_hi: float, + n_wealth: int, + consumption_lo: float, + consumption_hi: float, + n_consumption: int, + n_periods: int, +) -> tuple[Model, dict]: + """Build a 2-regime (alive, dead) model with runtime consumption points. + + The wealth grid is fixed-LinSpaced (so changing it doesn't perturb the + runtime-action-grid path). The consumption grid is the runtime-supplied + IrregSpacedGrid. + """ + last_alive_age = n_periods - 2 # alive at ages 0..n-2; dead at n-1 + alive = Regime( + functions={"utility": _utility}, + states={ + "wealth": LinSpacedGrid(start=wealth_lo, stop=wealth_hi, n_points=n_wealth) + }, + state_transitions={"wealth": _next_wealth}, + actions={"consumption": IrregSpacedGrid(n_points=n_consumption)}, + constraints={"borrowing_constraint": _borrowing_constraint}, + transition=_next_regime, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age > last_alive_age, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=n_periods - 1, step="Y"), + regime_id_class=RegimeId, + ) + consumption_points = jnp.linspace(consumption_lo, consumption_hi, n_consumption) + params = { + "discount_factor": 0.95, + "alive": { + "consumption": {"points": consumption_points}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + } + return model, params + + +def test_baseline_no_nan(): + """Healthy regime: at every wealth, every consumption point is feasible. + + consumption_lo == wealth_lo so the smallest action is always feasible at + every wealth gridpoint. `consumption_hi <= wealth_lo` ensures even the + largest consumption is feasible at the lowest wealth state. + """ + model, params = _build_model( + wealth_lo=10.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + + +def test_some_states_have_only_one_feasible_action(): + """At low-wealth states, only consumption[0] satisfies `consumption <= wealth`. + + consumption_lo < wealth_lo (so smallest action feasible at every wealth), + but consumption[1:] > wealth_lo (so at the lowest wealth gridpoint only + consumption[0] is feasible). + """ + model, params = _build_model( + wealth_lo=1.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=0.5, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + # At wealth = 1.0, only consumption = 0.5 is feasible (1.625, 2.75, 3.875, + # 5.0 all > 1.0). + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)), ( + "single-feasible-action state should not produce NaN V" + ) + + +def test_some_states_have_no_feasible_action(): + """At sufficiently low wealth, *every* consumption gridpoint is infeasible. + + consumption_lo > wealth_lo means at the lowest wealth state no + consumption point satisfies the borrowing constraint. + """ + model, params = _build_model( + wealth_lo=0.1, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + # At wealth = 0.1, no consumption point is feasible. max returns -inf. + # The next period interpolates V over the resulting -inf cells. + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)), ( + "all-infeasible state should yield -inf, not NaN" + ) + + +def test_log_zero_consumption_propagates_nan_via_max_when_unconstrained(): + """U(c=0) = -inf is fine, but U evaluated at infeasible negative wealth + via `next_wealth = wealth - c` going through interpolation in the next + period should not pollute V.""" + # consumption_lo = 0 → log(0) = -inf at the smallest action, regardless + # of feasibility. `where=F_arr` should mask this out of the max. + model, params = _build_model( + wealth_lo=1.0, + wealth_hi=10.0, + n_wealth=5, + consumption_lo=0.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + # log(0) = -inf is not NaN, but combined with where-mask edge + # cases this could in principle leak; check it does not. + assert not jnp.any(jnp.isnan(V_arr)) + + +@pytest.mark.parametrize( + ("wealth_lo", "consumption_lo", "label"), + [ + (1.0, 0.5, "single-feasible"), + (0.1, 1.0, "all-infeasible"), + ], +) +def test_simulate_with_constrained_action_grid(wealth_lo, consumption_lo, label): + """End-to-end solve+simulate for both regimes.""" + model, params = _build_model( + wealth_lo=wealth_lo, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=consumption_lo, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([wealth_lo, 5.0, 20.0]), + "regime": jnp.array( + [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 + ), + } + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + check_initial_conditions=False, + log_level="off", + ) + df = result.to_dataframe() + assert not df["value"].isna().any(), ( + f"{label}: simulated value column should not contain NaN" + ) + + +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + + +@categorical(ordered=False) +class AliveDeadRegimeId: + alive: int + dead: int + + +def _crra_bequest( + assets: ContinuousState, + pref_type: DiscreteState, + bequest_shifter: float, + consumption_weight: FloatND, + coefficient_rra: FloatND, +) -> FloatND: + """Simplified CRRA bequest used to probe NaN propagation under jnp.where. + + `consumption_weight` and `coefficient_rra` are FloatND indexed by + `pref_type`. Both branches of the `jnp.where` are traced. + """ + alpha = consumption_weight[pref_type] + gamma = coefficient_rra[pref_type] + assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter + one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + return jnp.where( + jnp.isclose(gamma, 1.0), + jnp.log(assets_shifted), + assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, + ) + + +def _alive_utility( + consumption: ContinuousAction, + pref_type: DiscreteState, + consumption_weight: FloatND, +) -> FloatND: + """Make pref_type matter in alive's utility too (otherwise pylcm complains).""" + alpha = consumption_weight[pref_type] + return alpha * jnp.log(consumption) + + +def _next_assets( + assets: ContinuousState, consumption: ContinuousAction +) -> ContinuousState: + return assets - consumption + + +def _alive_borrow(consumption: ContinuousAction, assets: ContinuousState) -> FloatND: + return consumption <= assets + + +def _alive_to_dead(age: int, last_alive_age: int) -> FloatND: + """Deterministic regime transition; returns a scalar regime ID.""" + return jnp.where( + age >= last_alive_age, AliveDeadRegimeId.dead, AliveDeadRegimeId.alive + ) + + +def _build_alive_dead_model( + *, + coefficient_rra: tuple[float, float], + consumption_weight: tuple[float, float], + n_periods: int = 3, +) -> tuple[Model, dict]: + last_alive_age = n_periods - 2 + + alive = Regime( + functions={"utility": _alive_utility}, + states={ + "assets": LinSpacedGrid(start=1.0, stop=20.0, n_points=5), + "pref_type": DiscreteGrid(PrefType, batch_size=1), + }, + state_transitions={"assets": _next_assets, "pref_type": None}, + actions={"consumption": IrregSpacedGrid(n_points=5)}, + constraints={"borrowing_constraint": _alive_borrow}, + transition=_alive_to_dead, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": _crra_bequest}, + states={ + "assets": LinSpacedGrid(start=1.0, stop=20.0, n_points=5), + "pref_type": DiscreteGrid(PrefType, batch_size=1), + }, + active=lambda _age: True, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=n_periods - 1, step="Y"), + regime_id_class=AliveDeadRegimeId, + ) + cw_arr = jnp.asarray(consumption_weight) + params = { + "discount_factor": 0.95, + "alive": { + "utility": {"consumption_weight": cw_arr}, + "consumption": {"points": jnp.linspace(0.5, 5.0, 5)}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + "dead": { + "utility": { + "bequest_shifter": 100.0, + "consumption_weight": cw_arr, + "coefficient_rra": jnp.asarray(coefficient_rra), + }, + }, + } + return model, params + + +def test_bequest_gamma_close_to_one_is_safe(): + """gamma=0.999077 (the benchmark value for type_1) should not produce NaN.""" + model, params = _build_alive_dead_model( + coefficient_rra=(3.84, 0.999077), + consumption_weight=(0.68, 0.88), + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for regime_to_V in period_to_regime_to_V_arr.values(): + for V_arr in regime_to_V.values(): + assert not jnp.any(jnp.isnan(V_arr)) + + +def test_bequest_gamma_exactly_one_for_one_type_only(): + """gamma=1.0 exactly for one pref type triggers `jnp.where(isclose, log, power)`. + + The unselected `power` branch divides by `1 - gamma = 0` for that type. JAX + evaluates both branches of `jnp.where`; the non-finite from the unselected + branch is masked, but `0/0` produces NaN that may not be masked correctly + when the operand to `jnp.where` is itself NaN under XLA's nan-prop rules. + """ + model, params = _build_alive_dead_model( + coefficient_rra=(3.84, 1.0), # type_1 hits the log branch + consumption_weight=(0.68, 0.88), + ) + period_to_regime_to_V_arr = model.solve(params=params, log_level="off") + for period, regime_to_V in period_to_regime_to_V_arr.items(): + for regime, V_arr in regime_to_V.items(): + assert not jnp.any(jnp.isnan(V_arr)), ( + f"NaN in V[{regime}, period={period}] when one type has gamma=1.0" + ) + + +@pytest.mark.parametrize("bad_coord", [jnp.inf, -jnp.inf, jnp.nan]) +def test_map_coordinates_returns_nan_for_non_finite_coordinate(bad_coord): + """`map_coordinates` cannot recover from a non-finite continuous-state + coordinate: the linear-interp weights become `inf` and `1 - inf = -inf`, + and `inf * V[k] - inf * V[k-1]` reduces to NaN. + + Implication for callers: any path that can feed `inf` or `NaN` into the + coordinate finder (division by zero in a state transition, overflow + when V values are large, or `0/0` in a degenerate IrregSpacedGrid + segment) will produce NaN in V. + """ + V_arr = jnp.array([1.0, 5.0, 12.0]) + out = map_coordinates(V_arr, coordinates=[jnp.array(bad_coord)]) + assert jnp.isnan(out) + + +def test_irreg_coordinate_divides_by_zero_on_duplicate_grid_points(): + """`get_irreg_coordinate` divides by `upper_point - lower_point`. If a + runtime-supplied points array contains duplicates, the divisor is 0. + + Reproduces a class of failures specific to runtime-supplied grids: when + the caller constructs points from a parameter that can collapse (e.g. a + `geomspace(lo, hi, n)` with `lo == hi`, or any param-driven `linspace` + whose endpoints can coincide). + """ + # Duplicate adjacent points where the query value equals the duplicate. + # `searchsorted([0, 1, 1], 1.0, side='right')=3` → clipped to n-1=2, + # idx_lower=1, lower_point=points[1]=1.0, upper_point=points[2]=1.0, + # step_size=0 → decimal_part = 0/0 = nan. + points = jnp.array([0.0, 1.0, 1.0]) + coord = get_irreg_coordinate(value=jnp.array(1.0), points=points) + assert not jnp.isfinite(coord), ( + "Duplicate adjacent grid points cause `step_size = 0` in " + "`get_irreg_coordinate`; current behaviour is to silently return " + "inf/nan, which then poisons V interpolation downstream." + ) + + +def _runtime_state_grid_model() -> tuple[Model, dict, dict]: + """Build a 2-regime model with a runtime-supplied IrregSpacedGrid *state*. + + Reproduces the failure mode where `validate_initial_conditions` reads + `_base_state_action_space` directly — still holding the placeholder + zeros for runtime-supplied `IrregSpacedGrid`s. With a feasibility + constraint that the all-zero placeholder fails, every subject is + reported infeasible even though the real (post-substitution) grid + would pass. The same mechanism affects runtime grids whether they + are states or actions. + """ + + @categorical(ordered=False) + class RuntimeRegimeId: + alive: int + dead: int + + def utility(consumption, wealth): + return jnp.log(consumption) + 0.0 * wealth + + def next_wealth(wealth, consumption): + return wealth - consumption + + def borrow(consumption, wealth): # noqa: ARG001 + # The validator sees `wealth` as a per-subject array with the + # subject-supplied initial values, but `consumption` as the *grid* + # (placeholder zeros for runtime grids). With a feasibility check + # that requires `consumption > 0`, every action gridpoint is + # infeasible until the runtime points replace the placeholder. + return consumption > 0 + + def next_regime(age, last_alive_age): + return jnp.where( + age >= last_alive_age, RuntimeRegimeId.dead, RuntimeRegimeId.alive + ) + + last_alive_age = 1 + alive = Regime( + functions={"utility": utility}, + states={"wealth": IrregSpacedGrid(n_points=4)}, # runtime state grid + state_transitions={"wealth": next_wealth}, + actions={"consumption": LinSpacedGrid(start=0.5, stop=5.0, n_points=5)}, + constraints={"borrow": borrow}, + transition=next_regime, + active=lambda age: age <= last_alive_age, + ) + dead = Regime( + transition=None, + functions={"utility": lambda: 0.0}, + active=lambda age: age > last_alive_age, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=RuntimeRegimeId, + ) + params = { + "discount_factor": 0.95, + "alive": { + "wealth": {"points": jnp.linspace(1.0, 10.0, 4)}, + "next_regime": {"last_alive_age": last_alive_age}, + }, + } + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([2.0, 5.0, 9.0]), + "regime": jnp.array( + [RuntimeRegimeId.alive, RuntimeRegimeId.alive, RuntimeRegimeId.alive], + dtype=jnp.int32, + ), + } + return model, params, initial_conditions + + +def test_runtime_action_grid_passes_initial_conditions_validation(): + """`feature/runtime-action-grids` regression: initial-conditions + feasibility check must use the *substituted* action grid, not the + `_base_state_action_space` placeholder zeros.""" + model, params = _build_model( + wealth_lo=10.0, + wealth_hi=20.0, + n_wealth=5, + consumption_lo=1.0, + consumption_hi=5.0, + n_consumption=5, + n_periods=3, + ) + initial_conditions = { + "age": jnp.array([0.0, 0.0, 0.0]), + "wealth": jnp.array([10.0, 15.0, 20.0]), + "regime": jnp.array( + [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 + ), + } + # `check_initial_conditions=True` (the default) must pass — the + # runtime-supplied consumption points are well-formed. + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + ) + assert result.n_subjects == 3 + + +def test_runtime_state_grid_passes_initial_conditions_validation(): + """Same regression for runtime-supplied *state* grids.""" + model, params, initial_conditions = _runtime_state_grid_model() + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + ) + assert result.n_subjects == 3 diff --git a/tests/test_validation_scalar_actions.py b/tests/test_validation_scalar_actions.py index 6855e4b9..89289f5b 100644 --- a/tests/test_validation_scalar_actions.py +++ b/tests/test_validation_scalar_actions.py @@ -7,8 +7,8 @@ Validation must do the same. This pattern arises in models that pass multi-dimensional lookup tables as parameters -via MappingLeaf — e.g. tax schedules and pension accrual tables in aca-model, or -tax-transfer schedules in ttsim/gettsim. +via MappingLeaf — e.g. tax schedules, pension accrual tables, or tax-transfer +schedules. """ import jax