Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/lyz-code/yamlfix
rev: 1.17.0
rev: 1.19.1
hooks:
- id: yamlfix
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-added-large-files
args:
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/pycqa/isort
rev: 6.0.1
rev: 7.0.0
hooks:
- id: isort
name: isort
Expand All @@ -54,14 +54,14 @@ repos:
# - id: reorder-python-imports
# args:
# - --py37-plus
- repo: https://github.com/psf/black
rev: 25.1.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
hooks:
- id: black
language_version: python3.11
language_version: python3.12
exclude: tests/utils/fast_upper_envelope_org.py
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.8
rev: v0.14.10
hooks:
- id: ruff
# exclude: |
Expand All @@ -87,7 +87,7 @@ repos:
- id: nbqa-black
- id: nbqa-ruff
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.22
rev: 1.0.0
hooks:
- id: mdformat
additional_dependencies:
Expand All @@ -97,12 +97,12 @@ repos:
- --wrap
- '88'
files: (README\.md)
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
- id: codespell
additional_dependencies:
- tomli
# - repo: https://github.com/codespell-project/codespell
# rev: v2.4.1
# hooks:
# - id: codespell
# additional_dependencies:
# - tomli
# - repo: https://github.com/mgedmin/check-manifest
# rev: "0.49"
# hooks:
Expand Down
38 changes: 30 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,40 @@
# Fast Upper Envelope Scan (FUES)
# Upper Envelope Package

[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)

[![PyPI version](https://badge.fury.io/py/upper-envelope.svg)](https://badge.fury.io/py/upper-envelope)
[![Downloads](https://pepy.tech/badge/upper-envelope)](https://pepy.tech/project/upper-envelope)

[![Continuous Integration Workflow](https://github.com/OpenSourceEconomics/upper-envelope/actions/workflows/main.yml/badge.svg)](https://github.com/OpenSourceEconomics/upper-envelope/actions/workflows/main.yml)
[![Codecov](https://codecov.io/gh/OpenSourceEconomics/upper-envelope/branch/main/graph/badge.svg)](https://app.codecov.io/gh/OpenSourceEconomics/upper-envelope)
[![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

Extension of the Fast Upper-Envelope Scan (FUES) for solving discrete-continuous dynamic
programming problems based on Dobrescu & Shanker (2022). Both `jax` and `numba` versions
are available.
This package collects several HPC implementations of upper-envelopes used to correct the
value and policy functions in discrete-continuous dynamic programming problems.

The following implementations are available:

- Extension of the Fast Upper-Envelope Scan (FUES) for solving discrete-continuous
dynamic programming problems based on Dobrescu & Shanker (2022). Both `jax` and
`numba` versions are available. We provide the original version without endogenous
jump detection.

- Line segment interpolation and selection of the upper envelope based on Druedahl &
Jorgensen (2017). Both `jax` and `numba` versions are available.

- Also contained for test reasons is the original upper-envelope implementation from
Iskhakov et al. (2017). It is not optimized and can not yet be imported when
installing the package.

## References

1. Iskhakov, Jorgensen, Rust, & Schjerning (2017).
1. Dobrescu & Shanker (2022).
[Fast Upper-Envelope Scan for Discrete-Continuous Dynamic Programming](https://dx.doi.org/10.2139/ssrn.4181302).

1. Druedahl & Jørgensen (2017).
[A general endogenous grid method for multi-dimensional models with non-convexities and constraints](https://www.sciencedirect.com/science/article/abs/pii/S0165188916301920).
*Journal of Economic Dynamics and Control*

1. Iskhakov, Jørgensen, Rust, & Schjerning (2017).
[The Endogenous Grid Method for Discrete-Continuous Dynamic Choice Models with (or without) Taste Shocks](http://onlinelibrary.wiley.com/doi/10.3982/QE643/full).
*Quantitative Economics*

1. Loretti I. Dobrescu & Akshay Shanker (2022).
[Fast Upper-Envelope Scan for Discrete-Continuous Dynamic Programming](https://dx.doi.org/10.2139/ssrn.4181302).
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ ignore:
- tests/*
- tests/**/*
- .tox/**/*
- docs/
236 changes: 236 additions & 0 deletions docs/time_period2_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import argparse
import time
from pathlib import Path
from typing import Dict

import jax
import jax.numpy as jnp
import numpy as np
from numba import njit

import upper_envelope as upenv

jax.config.update("jax_enable_x64", True)

ROOT_DIR = Path(__file__).resolve().parents[1]
n_runs = 10


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--runs", type=int, default=10)
return parser.parse_args()


def utility_crra_jax(
consumption: jnp.ndarray, choice: int, params: Dict[str, float]
) -> jnp.ndarray:
utility_consumption = (consumption ** (1 - params["rho"]) - 1) / (1 - params["rho"])
utility = utility_consumption - (1 - choice) * params["delta"]
return utility


@njit
def value_func_numba(
consumption, choice, beta, rho, delta, continuation_at_zero_savings
):
utility_consumption = (consumption ** (1 - rho) - 1) / (1 - rho)
utility = utility_consumption - (1 - choice) * delta
return utility + beta * continuation_at_zero_savings


test_resources = ROOT_DIR / "tests" / "resources" / "upper_envelope_period_tests"

period = 2
value_egm = np.genfromtxt(
test_resources / f"val{period}.csv", delimiter=",", dtype=float
)
policy_egm = np.genfromtxt(
test_resources / f"pol{period}.csv", delimiter=",", dtype=float
)

params: Dict[str, float] = {"beta": 0.95, "rho": 1.95, "delta": 0.35}
state_choice = {"lagged_choice": 0, "choice": 0}


def value_func_jax(consumption, choice, params):
return (
utility_crra_jax(consumption, choice, params) + params["beta"] * value_egm[1, 0]
)


def fues_jax_partial(endog, pol, val, exp_val_zero):
return upenv.fues_jax(
endog_grid=jnp.asarray(endog),
policy=jnp.asarray(pol),
value=jnp.asarray(val),
expected_value_zero_savings=exp_val_zero,
value_function=value_func_jax,
value_function_args=(state_choice["choice"], params),
)


fues_jax_partial_jit = jax.jit(fues_jax_partial)

# Compile time
start = time.time()
jax.block_until_ready(
fues_jax_partial_jit(
endog=policy_egm[0, 1:],
pol=policy_egm[1, 1:],
val=value_egm[1, 1:],
exp_val_zero=value_egm[1, 0],
)
)
end = time.time()
print(f"JAX FUES compilation time: {end - start:.4f} seconds")

tot_time = 0.0
for _ in range(n_runs):
start = time.time()
jax.block_until_ready(
fues_jax_partial_jit(
endog=policy_egm[0, 1:],
pol=policy_egm[1, 1:],
val=value_egm[1, 1:],
exp_val_zero=value_egm[1, 0],
)
)
end = time.time()
tot_time += end - start

print(f"JAX FUES average time over {n_runs} runs: {tot_time / n_runs:.6f} seconds")


def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
return upenv.drued_jorg_jax(
endog_grid=endog,
policy=pol,
value=val,
m_grid=m_grid,
expected_value_zero_savings=exp_val_zero,
value_function=value_func_jax,
value_function_args=(state_choice["choice"], params),
)


drued_jorg_jax_partial_jit = jax.jit(drued_jorg_jax_partial)
endog_jax = jnp.asarray(policy_egm[0, 1:])
pol_jax = jnp.asarray(policy_egm[1, 1:])
val_jax = jnp.asarray(value_egm[1, 1:])

m_min = float(np.min(policy_egm[0, 1:]))
m_max = float(np.max(policy_egm[0, 1:]))
m_grid = np.linspace(m_min, m_max, 500)
m_grid_jax = jnp.asarray(m_grid)

# Compile time
start = time.time()
jax.block_until_ready(
drued_jorg_jax_partial_jit(
endog=endog_jax,
pol=pol_jax,
val=val_jax,
m_grid=m_grid_jax,
exp_val_zero=value_egm[1, 0],
)
)
end = time.time()
print(f"JAX DRUED-JORG compilation time: {end - start:.4f} seconds")

tot_time = 0.0
for _ in range(n_runs):
start = time.time()
jax.block_until_ready(
drued_jorg_jax_partial_jit(
endog=endog_jax,
pol=pol_jax,
val=val_jax,
m_grid=m_grid_jax,
exp_val_zero=value_egm[1, 0],
)
)
end = time.time()
tot_time += end - start

print(
f"JAX DRUED-JORG average time over {n_runs} runs: {tot_time / n_runs:.6f} seconds"
)

numba_args = (
int(state_choice["choice"]),
float(params["beta"]),
float(params["rho"]),
float(params["delta"]),
float(value_egm[1, 0]),
)

# Numba FUES
start = time.time()
jax.block_until_ready(
upenv.fues_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func_numba,
value_function_args=numba_args,
)
)
end = time.time()
print(f"Numba FUES compilation time: {end - start:.4f} seconds")

tot_time = 0.0
for _ in range(n_runs):
start = time.time()
jax.block_until_ready(
upenv.fues_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func_numba,
value_function_args=numba_args,
)
)
end = time.time()
tot_time += end - start

print(f"Numba FUES average time over {n_runs} runs: {tot_time / n_runs:.6f} seconds")

# Numba DRUED-JORG
start = time.time()
jax.block_until_ready(
upenv.drued_jorg_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
m_grid=m_grid,
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func_numba,
value_function_args=numba_args,
)
)
end = time.time()
print(f"Numba DRUED-JORG compilation time: {end - start:.4f} seconds")

tot_time = 0.0
for _ in range(n_runs):
start = time.time()
jax.block_until_ready(
upenv.drued_jorg_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
m_grid=m_grid,
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func_numba,
value_function_args=numba_args,
)
)
end = time.time()
tot_time += end - start

print(
f"Numba DRUED-JORG average time over {n_runs} runs: {tot_time / n_runs:.6f} seconds"
)
Loading
Loading