Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion axtreme_knowledge_base
9 changes: 6 additions & 3 deletions examples/basic_example_usecase/problem/brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _brute_force_calc(
N_ENV_SAMPLES_PER_PERIOD = N_YEARS_IN_PERIOD * N_SECONDS_IN_YEAR // N_SECONDS_IN_TIME_STEP
N_ENV_SAMPLES_PER_PERIOD = 1000

samples, x_max = collect_or_calculate_results(N_ENV_SAMPLES_PER_PERIOD, 300_000)
samples, x_max = collect_or_calculate_results(N_ENV_SAMPLES_PER_PERIOD, 10_000)

_ = plt.hist(samples, bins=100, density=True)
_ = plt.title(
Expand All @@ -199,8 +199,11 @@ def _brute_force_calc(
f"results/brute_force/erd_n_sample_per_period_{N_ENV_SAMPLES_PER_PERIOD}.png",
)
plt.show()

_ = plt.scatter(x_max[:, 0], x_max[:, 1])
_ = plt.scatter(x_max[:, 0], x_max[:, 1], alpha=0.5, s=5)
_ = plt.title("Locations of max response in environment space (point is sample erd)")
plt.savefig(
f"results/brute_force/erd_n_sample_per_period_x{N_ENV_SAMPLES_PER_PERIOD}.png",
)
plt.show()

# %%
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 25 additions & 4 deletions examples/basic_example_usecase/problem/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

# %%
# These are helpers for our dummy simulator, and would not be available in a real problem
def _true_loc_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
def _true_loc_func_torch(x: torch.Tensor) -> torch.Tensor:
"""Takes input of shape (*b, d=2) and returns (*b,) locs."""
# For this toy example we use a Mixture distribution of a MultivariateNormal distribution
dist1_mean, dist1_cov = torch.tensor([0.8, 0.8]), torch.tensor([[0.03, 0], [0, 0.03]])
dist2_mean, dist2_cov = torch.tensor([0.2, 0.8]), torch.tensor([[0.04, 0.01], [0.01, 0.04]])
Expand All @@ -35,12 +36,23 @@ def _true_loc_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
)
)
gmm = MixtureSameFamily(mix, component_dist)
return np.exp(gmm.log_prob(torch.tensor(x)).numpy())
return gmm.log_prob(x).exp()


def _true_scale_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
def _true_loc_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
"""Takes input of shape (*b, d=2) and returns (*b,) locs."""
return _true_loc_func_torch(torch.tensor(x)).numpy()


def _true_scale_func_torch(x: torch.Tensor) -> torch.Tensor:
"""Takes input of shape (*b, d=2) and returns (*b,) locs."""
# For this toy example we use a constant scale for simplicity
return np.ones(x.shape[0]) * 0.1
return torch.ones(x.shape[:-1]) * 0.1


def _true_scale_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
"""Takes input of shape (*b, d=2) and returns (*b,) locs."""
return _true_scale_func_torch(torch.tensor(x)).numpy()


def dummy_simulator_function(x: NDArray[np.float64]) -> NDArray[np.float64]:
Expand Down Expand Up @@ -136,3 +148,12 @@ def _hash_function(x1: float, x2: float) -> int:
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
_ = ax.scatter(x1_mesh, x2_mesh, samples.reshape(len(x1), len(x2)), cmap="viridis")

# %% Testing the shape of the underlying function is as expected
assert _true_loc_func_torch(torch.rand(2)).shape == torch.Size([])
assert _true_loc_func_torch(torch.rand(5, 2)).shape == torch.Size([5])
assert _true_loc_func_torch(torch.rand(7, 5, 2)).shape == torch.Size([7, 5])

assert _true_scale_func_torch(torch.rand(2)).shape == torch.Size([])
assert _true_scale_func_torch(torch.rand(5, 2)).shape == torch.Size([5])
assert _true_scale_func_torch(torch.rand(7, 5, 2)).shape == torch.Size([7, 5])
30 changes: 26 additions & 4 deletions src/axtreme/plotting/gp_fit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Plotting module for visualizing how well the GP fits the data."""

from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import TypeAlias

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -43,16 +43,38 @@ def plot_surface_over_2d_search_space(
# Extract the parameter names and ranges from the search space
assert len(search_space.parameters) == 2, "Only 2D search spaces are supported for now." # noqa: PLR2004

(x1_name, x1_param), (x2_name, x2_param) = list(search_space.parameters.items())
(_, x1_param), (_, x2_param) = list(search_space.parameters.items())

if not (isinstance(x1_param, RangeParameter) and isinstance(x2_param, RangeParameter)):
msg = f"""Expect search_space.parameters to all be of type RangeParameter.
Instead got {type(x1_param) = }, and {type(x2_param) = }."""
raise NotImplementedError(msg)

bounds = [(x1_param.lower, x1_param.upper), (x2_param.lower, x2_param.upper)]
return plot_surface_over_2d_space(bounds, funcs, colors, num_points)


def plot_surface_over_2d_space(
bounds: Sequence[tuple[float, float]],
funcs: list[Callable[[Numpy2dArray], Numpy1dArray]],
colors: list[str] | None = None,
num_points: int = 101,
) -> Figure:
"""Creates a figure with the functions `funcs` plotted over the bounds.

Note:
Currently only support search spaces with 2 parameters.

Args:
bounds: For evaluation and plotting `[(x1_low, x1_high), (x2_low, x2_high)]`
funcs: A list of callables that take in a numpy array with shape (num_values, num_parameters=2 )
and return a numpy array with (num_values) elements.
colors: A list of colors to use for each function. If None, will use default Plotly colors.
num_points: The number of points in each dimension to evaluate the functions at.
"""
# Generate parameter ranges using NumPy
x1_values = np.linspace(x1_param.lower, x1_param.upper, num_points)
x2_values = np.linspace(x2_param.lower, x2_param.upper, num_points)
x1_values = np.linspace(bounds[0][0], bounds[0][1], num_points)
x2_values = np.linspace(bounds[1][0], bounds[1][1], num_points)

# Create a meshgrid for the parameter values

Expand Down
2 changes: 1 addition & 1 deletion src/axtreme/plotting/histogram3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def histogram3d( # noqa: PLR0913
**{
"intensity": [0, 0, 0, 0, z_value, z_value, z_value, z_value],
"flatshading": flatshading,
"coloaxtremeis": "coloaxtremeis",
"coloraxis": "coloraxis",
**mesh3d_kwargs,
},
),
Expand Down
2 changes: 1 addition & 1 deletion tests/qoi/test_gp_brute_force_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
This script is designed to be run interactively as well as though pytest.
"""

# ruff: noqa: T201
# ruff: noqa: T201 PT028
# pyright: reportUnnecessaryTypeIgnoreComment=false

# %%
Expand Down
2 changes: 1 addition & 1 deletion tests/qoi/test_marginal_cdf_extrapolation_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
This script is designed to be run interactively as well as though pytest.
"""

# ruff: noqa: T201
# ruff: noqa: T201 PT028
# pyright: reportUnnecessaryTypeIgnoreComment=false
# %%
import json
Expand Down
4 changes: 2 additions & 2 deletions tutorials/ax_botorch/optimisation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@
"- runtime: 3 seconds\n",
"- function eval: 100\n",
"\n",
"#### **3: Limit the maxiter or max functio neval**\n",
"- By default perfroms 200 maxiter and ~400 feval are used. Can reducing this give a suitably accurate result?\n",
"#### **3: Limit the `maxiter` or max function eval (`maxfev`)**\n",
"- By default performs 200 `maxiter` and ~400 `maxfev` are used. Can reducing this give a suitably accurate result?\n",
"\n",
"\n",
"```python\n",
Expand Down
Loading