-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimulator.py
More file actions
159 lines (121 loc) · 6.14 KB
/
simulator.py
File metadata and controls
159 lines (121 loc) · 6.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Define the simulator."""
# %%
from typing import cast
import numpy as np
import torch
from botorch.test_functions import BraninCurrin
from numpy.typing import NDArray
from scipy.stats import gumbel_r
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal
from axtreme.simulator.base import Simulator
torch.set_default_dtype(torch.float64)
_branin_currin = BraninCurrin(negate=False).to(dtype=torch.double)
# %%
# These are helpers for our dummy simulator, and would not be available in a real problem
def _true_loc_func_torch(x: torch.Tensor) -> torch.Tensor:
"""Takes input of shape (*b, d=2) and returns (*b,) locs."""
# For this toy example we use a Mixture distribution of a MultivariateNormal distribution
dist1_mean, dist1_cov = torch.tensor([0.8, 0.8]), torch.tensor([[0.03, 0], [0, 0.03]])
dist2_mean, dist2_cov = torch.tensor([0.2, 0.8]), torch.tensor([[0.04, 0.01], [0.01, 0.04]])
dist3_mean, dist3_cov = torch.tensor([0.5, 0.2]), torch.tensor([[0.06, 0], [0, 0.06]])
locs = torch.stack([dist1_mean, dist2_mean, dist3_mean])
covs = torch.stack([dist1_cov, dist2_cov, dist3_cov])
component_dist = MultivariateNormal(loc=locs, covariance_matrix=covs)
mix = Categorical(
torch.ones(
3,
)
)
gmm = MixtureSameFamily(mix, component_dist)
return gmm.log_prob(x).exp()
def _true_loc_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
    """Takes input of shape (*b, d=2) and returns (*b,) locs.

    Thin numpy wrapper around ``_true_loc_func_torch``.
    """
    x_tensor = torch.tensor(x)
    locs = _true_loc_func_torch(x_tensor)
    return locs.numpy()
def _true_scale_func_torch(x: torch.Tensor) -> torch.Tensor:
"""Takes input of shape (*b, d=2) and returns (*b,) locs."""
# For this toy example we use a constant scale for simplicity
return torch.ones(x.shape[:-1]) * 0.1
def _true_scale_func(x: NDArray[np.float64]) -> NDArray[np.float64]:
    """Takes input of shape (*b, d=2) and returns (*b,) scales.

    Thin numpy wrapper around ``_true_scale_func_torch``.
    """
    return _true_scale_func_torch(torch.tensor(x)).numpy()
def dummy_simulator_function(x: NDArray[np.float64]) -> NDArray[np.float64]:
    """Generate a sample from a Gumbel distribution where the location and scale are functions of x.

    Args:
        x: (n,2) array of points to simulate.

    Returns:
        (n,1) array of the simulator results for those points.
    """
    location = _true_loc_func(x)
    scale = _true_scale_func(x)
    # gumbel_r.rvs broadcasts loc/scale, returning one draw per input point.
    sample = cast("NDArray[np.float64]", gumbel_r.rvs(loc=location, scale=scale))
    return sample.reshape(-1, 1)
# Module-level alias exposing the dummy simulator under the name ``sim``.
# NOTE: the __main__ block below re-binds ``sim`` to a DummySimulatorSeeded instance.
sim = dummy_simulator_function
class DummySimulatorSeeded(Simulator):
    """A seeded version of ``dummy_simulator_function`` conforming to the ``Simulator`` protocol.

    Each unique point in the x domain has a fixed seed used when generating samples. This can be
    useful for reproducibility. Points still appear "semi" random, as points close together use
    completely different seeds.

    Details:
        - The seed is derived by hashing the point's coordinates, so any representable difference
          in coordinates produces a different seed.
        - Repeated calls at the exact same point produce the exact same results.
    """

    def __call__(
        self, x: np.ndarray[tuple[int, int], np.dtype[np.float64]], n_simulations_per_point: int = 1
    ) -> np.ndarray[tuple[int, int, int], np.dtype[np.float64]]:
        """Evaluate the model at given points.

        Args:
            x: An array of shape (n_points, n_input_dims) of points at which to evaluate the model.
            n_simulations_per_point: The number of simulations to run at each point. Expected to have a default value.

        Returns:
            An array of shape (n_points, n_simulations_per_point, n_output_dims) of the model evaluated at the input
            points.
        """
        # For each unique x point create a unique, deterministic seed.
        seeds = [DummySimulatorSeeded._hash_function(*tuple(x_i)) for x_i in x]
        location = _true_loc_func(x)
        scale = _true_scale_func(x)
        samples = []
        for loc_i, scale_i, seed_i in zip(location, scale, seeds, strict=True):
            sample = cast(
                "NDArray[np.float64]",
                gumbel_r.rvs(loc=loc_i, scale=scale_i, random_state=seed_i, size=n_simulations_per_point),
            )
            samples.append(sample)
        # Stack to (n_points, n_simulations_per_point), then add the trailing output dim.
        return np.expand_dims(np.stack(samples), axis=-1)

    @staticmethod
    def _hash_function(x1: float, x2: float) -> int:
        """Hash two floats to an integer in [0, 2**32 - 2], usable as a numpy random seed."""
        return abs(hash((x1, x2)) % (2**32 - 1))
# %%
if __name__ == "__main__":
    # Quick and dirty tests:
    sim = DummySimulatorSeeded()

    # %%
    x = np.array([[0.5000, 0.5], [0.3, 0.3]])
    # The same value will produce the same result
    assert (sim(x, n_simulations_per_point=5) == sim(x, n_simulations_per_point=5)).all()

    # %%
    # Very similar values produce different results
    # we allow a wide margin of error because results should be completely different due to sampling
    x1 = np.array([[0.5 + 1e-5, 0.5], [0.3, 0.3]])
    assert not np.allclose(sim(x1, n_simulations_per_point=5), sim(x, n_simulations_per_point=5), atol=2)

    # %%
    # Plot the surface over a small area. If sample is not random the values should change slowly.
    x1 = np.linspace(0.5, 0.5 + 1e-8, 10)  # 10 points in a tiny interval around 0.5
    x2 = np.linspace(0.5, 0.5 + 1e-8, 10)
    # Create a grid of (x, y) points
    x1_mesh, x2_mesh = np.meshgrid(x1, x2)
    x = np.column_stack([x1_mesh.flatten(), x2_mesh.flatten()])

    # %%
    import matplotlib.pyplot as plt

    samples = sim(x).flatten()
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")
    _ = ax.scatter(x1_mesh, x2_mesh, samples.reshape(len(x1), len(x2)), cmap="viridis")

    # %% Testing the shape of the underlying function is as expected
    assert _true_loc_func_torch(torch.rand(2)).shape == torch.Size([])
    assert _true_loc_func_torch(torch.rand(5, 2)).shape == torch.Size([5])
    assert _true_loc_func_torch(torch.rand(7, 5, 2)).shape == torch.Size([7, 5])
    assert _true_scale_func_torch(torch.rand(2)).shape == torch.Size([])
    assert _true_scale_func_torch(torch.rand(5, 2)).shape == torch.Size([5])
    assert _true_scale_func_torch(torch.rand(7, 5, 2)).shape == torch.Size([7, 5])