Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/lyz-code/yamlfix
rev: 1.17.0
rev: 1.19.1
hooks:
- id: yamlfix
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-added-large-files
args:
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/pycqa/isort
rev: 6.0.1
rev: 7.0.0
hooks:
- id: isort
name: isort
Expand All @@ -54,14 +54,14 @@ repos:
# - id: reorder-python-imports
# args:
# - --py37-plus
- repo: https://github.com/psf/black
rev: 25.1.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
hooks:
- id: black
language_version: python3.11
exclude: tests/utils/fast_upper_envelope_org.py
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.8
rev: v0.14.10
hooks:
- id: ruff
# exclude: |
Expand All @@ -87,7 +87,7 @@ repos:
- id: nbqa-black
- id: nbqa-ruff
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.22
rev: 1.0.0
hooks:
- id: mdformat
additional_dependencies:
Expand Down
7 changes: 3 additions & 4 deletions src/upper_envelope/fues_jax/fues_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap

from upper_envelope.fues_jax.check_and_scan_funcs import (
Expand Down Expand Up @@ -124,14 +123,14 @@ def fues_jax(

# Because of jax, we always need to perform the same set of computations. Hence,
# if there is no wealth grid point below the first, we just add nans thereafter.
min_id = np.argmin(endog_grid)
min_id = jnp.argmin(endog_grid)
min_wealth_grid = endog_grid[min_id]

# This is the condition, which we do not use at the moment.
# closed_form_cond = min_wealth_grid < endog_grid[0]
grid_points_to_add = jnp.linspace(
min_wealth_grid, endog_grid[0], n_constrained_points_to_add + 1
)[:-1]
min_wealth_grid, endog_grid[0], n_constrained_points_to_add
)
# Compute closed form values
values_to_add = vmap(_compute_value, in_axes=(0, None, None, None))(
grid_points_to_add, value_function, value_function_args, value_function_kwargs
Expand Down
6 changes: 3 additions & 3 deletions tests/test_fues_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def test_fast_upper_envelope_against_org_fues(setup_model):
policy_expected = policy_org[~np.isnan(policy_org)]
value_expected = value_org[~np.isnan(value_org)]

assert np.all(np.in1d(endog_grid_expected, endog_grid_refined))
assert np.all(np.in1d(policy_expected, policy_refined))
assert np.all(np.in1d(value_expected, value_refined))
assert np.all(np.isin(endog_grid_expected, endog_grid_refined))
assert np.all(np.isin(policy_expected, policy_refined))
assert np.all(np.isin(value_expected, value_refined))


@pytest.mark.parametrize("period", [2, 4, 10, 9, 18])
Expand Down
Loading