diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80e3965..548f73b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,11 +10,11 @@ repos: hooks: - id: yamllint - repo: https://github.com/lyz-code/yamlfix - rev: 1.17.0 + rev: 1.19.1 hooks: - id: yamlfix - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-added-large-files args: @@ -42,7 +42,7 @@ repos: - id: python-use-type-annotations - id: text-unicode-replacement-char - repo: https://github.com/pycqa/isort - rev: 6.0.1 + rev: 7.0.0 hooks: - id: isort name: isort @@ -54,14 +54,14 @@ repos: # - id: reorder-python-imports # args: # - --py37-plus - - repo: https://github.com/psf/black - rev: 25.1.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 25.12.0 hooks: - id: black language_version: python3.11 exclude: tests/utils/fast_upper_envelope_org.py - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.8 + rev: v0.14.10 hooks: - id: ruff # exclude: | @@ -87,7 +87,7 @@ repos: - id: nbqa-black - id: nbqa-ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.22 + rev: 1.0.0 hooks: - id: mdformat additional_dependencies: diff --git a/src/upper_envelope/fues_jax/fues_jax.py b/src/upper_envelope/fues_jax/fues_jax.py index 44abb7c..acb430a 100644 --- a/src/upper_envelope/fues_jax/fues_jax.py +++ b/src/upper_envelope/fues_jax/fues_jax.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp -import numpy as np from jax import vmap from upper_envelope.fues_jax.check_and_scan_funcs import ( @@ -124,14 +123,14 @@ def fues_jax( # Because of jax, we always need to perform the same set of computations. Hence, # if there is no wealth grid point below the first, we just add nans thereafter. - min_id = np.argmin(endog_grid) + min_id = jnp.argmin(endog_grid) min_wealth_grid = endog_grid[min_id] # This is the condition, which we do not use at the moment. # closed_form_cond = min_wealth_grid < endog_grid[0] grid_points_to_add = jnp.linspace( - min_wealth_grid, endog_grid[0], n_constrained_points_to_add + 1 - )[:-1] + min_wealth_grid, endog_grid[0], n_constrained_points_to_add + ) # Compute closed form values values_to_add = vmap(_compute_value, in_axes=(0, None, None, None))( grid_points_to_add, value_function, value_function_args, value_function_kwargs diff --git a/tests/test_fues_numba.py b/tests/test_fues_numba.py index 98189cf..b4ab325 100644 --- a/tests/test_fues_numba.py +++ b/tests/test_fues_numba.py @@ -161,9 +161,9 @@ def test_fast_upper_envelope_against_org_fues(setup_model): policy_expected = policy_org[~np.isnan(policy_org)] value_expected = value_org[~np.isnan(value_org)] - assert np.all(np.in1d(endog_grid_expected, endog_grid_refined)) - assert np.all(np.in1d(policy_expected, policy_refined)) - assert np.all(np.in1d(value_expected, value_refined)) + assert np.all(np.isin(endog_grid_expected, endog_grid_refined)) + assert np.all(np.isin(policy_expected, policy_refined)) + assert np.all(np.isin(value_expected, value_refined)) @pytest.mark.parametrize("period", [2, 4, 10, 9, 18])