-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathacquisition.py
More file actions
425 lines (308 loc) · 13.5 KB
/
acquisition.py
File metadata and controls
425 lines (308 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
import torch
from torch import Tensor
import pygmo as pg
import numpy as np
from botorch.models.model import Model
from botorch.acquisition import AnalyticAcquisitionFunction
from botorch.optim import optimize_acqf
from botorch.sampling.pathwise.posterior_samplers import draw_matheron_paths

# Module-wide numeric precision and compute device; every tensor this module
# creates is placed with these (GPU when available, else CPU).
dtype = torch.double
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
class PosteriorSample(AnalyticAcquisitionFunction):
    """Acquisition function that evaluates one fixed posterior sample path.

    A single Matheron-rule pathwise sample is drawn from the model once at
    construction time; every subsequent ``forward`` call evaluates that same
    deterministic path (Thompson-sampling style acquisition).
    """

    def __init__(
        self,
        model: Model,
    ) -> None:
        # Deliberately skip AnalyticAcquisitionFunction.__init__ and invoke
        # AcquisitionFunction.__init__ directly (no posterior transform here).
        super(AnalyticAcquisitionFunction, self).__init__(model)
        # Draw the path once; it stays fixed for this object's lifetime.
        self.path = draw_matheron_paths(self.model, torch.Size([1]))

    def forward(self, X: Tensor):
        """Evaluate the stored sample path at X.

        Args:
            X: candidate locations of shape N x ... x d.

        Returns:
            Flattened path values, shape (N,).
        """
        return self.path(X).flatten()
class PenalizedUpperConfidenceBound(AnalyticAcquisitionFunction):
    """Upper Confidence Bound penalized around busy (pending) evaluation points.

    For asynchronous/batch Bayesian optimization: the UCB value is multiplied
    (in log space) by a smooth penalty that shrinks near locations whose
    evaluations are still in flight (``busy``), discouraging redundant
    sampling.  The penalty radius per busy point is scaled by an estimated
    Lipschitz constant obtained by maximizing the analytic posterior-mean
    gradient norm (globally, or locally around each busy point).
    """

    def __init__(
        self,
        model: Model,
        beta: Tensor,
        bounds: Tensor,
        busy: Tensor | None = None,
        y_max = None,
        local = False
    ) -> None:
        """
        Args:
            model: Fitted BoTorch model.
            beta: UCB exploration weight; used as sqrt(beta) * std in forward.
            bounds: 2 x d box bounds of the search space.
            busy: b x d pending locations, or None for plain UCB behavior.
            y_max: Incumbent best value; forward reads ``self.y_max`` whenever
                ``busy`` is set.  # NOTE(review): not validated here — passing
                # busy without y_max makes forward raise AttributeError.
            local: If True, estimate one Lipschitz constant per busy point in
                a lengthscale-sized box around it; else one global estimate.
        """
        # Skip AnalyticAcquisitionFunction.__init__; call
        # AcquisitionFunction.__init__ directly.
        super(AnalyticAcquisitionFunction, self).__init__(model)
        self.register_buffer("beta", torch.as_tensor(beta, dtype=dtype, device = device))
        self.register_buffer("bounds", torch.as_tensor(bounds, dtype=dtype, device = device))
        if y_max is not None:
            self.register_buffer("y_max", torch.as_tensor(y_max, dtype=dtype, device = device))
        if busy is not None:
            self.register_buffer("busy", torch.as_tensor(busy, dtype=dtype, device = device))
            # Lipschitz estimate: maximize the analytic gradient norm of the
            # posterior mean over the relevant region(s).
            grad_norm = AnalyticPostMeanGradientNorm(self.model)
            d = self.bounds.shape[-1]
            if local:
                # One lengthscale-sized box per busy point, clipped to bounds.
                ls = self.model.covar_module.lengthscale # 1xd
                bounds_l = torch.clamp_min(self.busy-ls, self.bounds[0]) # bxd
                bounds_u = torch.clamp_max(self.busy+ls, self.bounds[1]) # bxd
                bounds_ = torch.stack((bounds_l, bounds_u), dim = 1) # bx2xd
                L = []
                norm_maxer = []
                for i in range(len(self.busy)):
                    maxer, l = optimize_acqf( acq_function=grad_norm,
                                  bounds=bounds_[i],
                                  q=1,
                                  num_restarts=10,
                                  raw_samples=d*1000,
                                  options={"batch_limit": 50, "maxiter": 200},
                                  )
                    L.append(l)
                    norm_maxer.append(maxer)
                L = torch.tensor(L, dtype=dtype, device = device).reshape(1,len(self.busy)) # 1xb
                norm_maxer = torch.cat(norm_maxer).reshape(-1, d) # bxd
            else:
                # Single global estimate over the whole search box.
                norm_maxer, L = optimize_acqf( acq_function=grad_norm,
                                bounds=self.bounds,
                                q=1,
                                num_restarts=10,
                                raw_samples=d*1000,
                                options={"batch_limit": 50, "maxiter": 200},
                                )
                L = L.to(dtype=dtype, device = device).reshape(1,1)
            self.register_buffer("L", torch.as_tensor(L, dtype=dtype, device=device))
            self.register_buffer("norm_maxer", torch.as_tensor(norm_maxer, dtype=dtype))
        else:
            self.busy = busy

    def forward(self, X: Tensor) -> Tensor:
        """
        Args:
            X (tensor): Nxq=1xd
        Returns:
            acqf value (tensor): N
        """
        # Soft-min exponent; more negative -> sharper penalty transition.
        p = -5
        posterior = self.model.posterior(X)
        mean = posterior.mean
        std = posterior.variance.sqrt()
        ucb = (mean + self.beta.sqrt() * std).flatten() # N
        if self.busy is None:
            # No pending evaluations: plain UCB.
            return ucb
        post_b = self.model.posterior(self.busy)
        mean_b = post_b.mean # bx1
        eps = 1e-8
        std_b = post_b.variance.sqrt() #bx1
        if len(X.shape) == 3: X = X.squeeze(1) # remove q-batch dim
        norm = torch.cdist(X, self.busy) # + 1e-8 # Nxb
        # Penalty radius per busy point: (|mu_b - y_max| + std_b) / L,
        # i.e. the plausible improvement gap scaled by the Lipschitz estimate.
        s = ((torch.abs(mean_b - self.y_max) + 1 * std_b)).reshape(1,-1) / (self.L) # 1xb
        weights = norm / s # Nxb
        # Smooth penalty factor per busy point: ~1 far away, small nearby.
        diff_weights = (weights**p + 1)**(1/p) # Nxb
        pen = torch.sum(torch.log(diff_weights), dim=1) # N
        # Multiply UCB by the product of penalties, computed in log space;
        # clamp keeps log finite for non-positive UCB values.
        pen_ucb = torch.exp(torch.log(ucb.clamp_min(eps)) + pen)
        return pen_ucb # N
class AnalyticPostMeanGradientNorm(AnalyticAcquisitionFunction):
    """Euclidean norm of the analytic gradient of the GP posterior mean.

    Precomputes alpha = K_noise^{-1} y once at construction, so each forward
    call only evaluates the cross-covariance gradient.  The gradient formula
    k(x, X) * (X - x) @ Theta^{-1} matches a squared-exponential-type kernel
    with per-dimension lengthscales.  # NOTE(review): confirm the model's
    # kernel family — the formula is wrong for e.g. Matern kernels.
    """

    def __init__(self,
        model: Model,
    ) -> None:
        super().__init__(model)
        self.k = model.covar_module
        # Theta^{-1}: diagonal of inverse squared lengthscales.
        self.Theta_inv = torch.atleast_2d(torch.diag(1/self.k.lengthscale.flatten()**2))
        # Work in the original (untransformed) input/outcome space.
        self.train_X = model.input_transform.untransform(model.train_inputs[0])
        self.train_Y = model.outcome_transform.untransform(model.train_targets)[0].reshape(-1,1)
        K_X_X = self.k(self.train_X).evaluate()
        sig_squ = model.likelihood.noise
        # Observation noise plus jitter on the diagonal for stability.
        K_noise = K_X_X + (sig_squ + 1e-8) * torch.eye(K_X_X.size(0), dtype=dtype, device=device)
        L = torch.linalg.cholesky(K_noise + 1e-8 * torch.eye(K_X_X.size(0), dtype=dtype, device=device))
        K_noise_inv = torch.cholesky_inverse(L)
        # Cache alpha = K_noise^{-1} y; reused on every forward call.
        self.K_noise_inv_Y = torch.matmul(K_noise_inv, self.train_Y)

    def forward(self, X: Tensor) -> Tensor:
        # X: N,d
        if len(X.shape) == 3: X = X.squeeze(1)  # drop q-batch dim if present
        K_st_X = self.k(X, self.train_X).evaluate().unsqueeze(-1)
        # Pairwise differences (X_train - x), one block per query point.
        D = (self.train_X.unsqueeze(0)-X.unsqueeze(1))
        grad_K_st_X = K_st_X * D @ self.Theta_inv
        # d mu / dx = (d k(x, X_train) / dx)^T alpha
        dmu_dx = torch.linalg.matmul(grad_K_st_X.transpose(1,2),
        self.K_noise_inv_Y).squeeze(-1)
        grad_norm = torch.linalg.vector_norm(dmu_dx, dim=-1)
        # Clamp away from zero so downstream division/log stays finite.
        return grad_norm.clamp_min(1e-8) # N
# The following function is taken directly from the authors of AEGiS
# https://github.com/georgedeath/aegis/blob/main/aegis/batch/nsga2_pygo.py
def NSGA2_pygmo(model, fevals, lb, ub, cf=None):
    """Finds the estimated Pareto front of a gpytorch model using NSGA2 [1]_.
    Parameters
    ----------
    model: gpytorch.models.ExactGP
        gpytorch regression model on which to find the Pareto front
        of its mean prediction and standard deviation.
    fevals : int
        Maximum number of times to evaluate a location using the model.
    lb : (D, ) torch.tensor
        Lower bound box constraint on D
    ub : (D, ) torch.tensor
        Upper bound box constraint on D
    cf : callable, optional
        Constraint function that returns True if it is called with a
        valid decision vector, else False.
    Returns
    -------
    X_front : (F, D) numpy.ndarray
        The F D-dimensional locations on the estimated Pareto front.
    musigma_front : (F, 2) numpy.ndarray
        The corresponding mean response and standard deviation of the locations
        on the front such that a point X_front[i, :] has a mean prediction
        musigma_front[i, 0] and standard deviation musigma_front[i, 1].
    Notes
    -----
    NSGA2 [1]_ discards locations on the pareto front if the size of the front
    is greater than that of the population size. We counteract this by storing
    every location and its corresponding mean and standard deviation and
    calculate the Pareto front from this - thereby making the most of every
    GP model evaluation.
    References
    ----------
    .. [1] Kalyanmoy Deb, Amrit Pratap, Sameer Agarwal, and T. Meyarivan.
        A fast and elitist multiobjective genetic algorithm: NSGA-II.
        IEEE Transactions on Evolutionary Computation 6, 2 (2001), 182-197.
    """
    # internal class for the pygmo optimiser
    class GPYTORCH_WRAPPER(object):
        def __init__(self, model, lb, ub, cf, evals):
            # model = gpytorch model
            # lb = torch.tensor of lower bounds on X
            # ub = torch.tensor of upper bounds on X
            # cf = callable constraint function
            # evals = total evaluations to be carried out
            self.model = model
            self.lb = lb.numpy()
            self.ub = ub.numpy()
            self.nd = lb.numel()
            self.got_cf = cf is not None
            self.cf = cf
            self.i = 0  # evaluation pointer
            self.dtype = model.train_targets.dtype

        def get_bounds(self):
            # pygmo problem API: box bounds as (lower, upper) numpy arrays.
            return (self.lb, self.ub)

        def get_nobj(self):
            # Two objectives: posterior mean and (negated) standard deviation.
            return 2

        def fitness(self, X):
            X = np.atleast_2d(X)
            X = torch.as_tensor(X, dtype=self.dtype)
            # Slice indices let model_fitness record these evaluations into
            # its preallocated history arrays.
            f = model_fitness(
                X,
                self.model,
                self.cf,
                self.got_cf,
                self.i,
                self.i + X.shape[0],
            )
            self.i += X.shape[0]
            return f.ravel()

        def has_batch_fitness(self):
            return True

        def batch_fitness(self, X):
            # pygmo hands over a flattened batch of decision vectors;
            # reshape to (n, d) before evaluating.
            X = X.reshape(-1, self.nd)
            return self.fitness(X)

    # fitness function for the optimiser
    def model_fitness(X, model, cf, got_cf, start_slice, end_slice):
        n = X.shape[0]
        f = np.zeros((n, 2))
        valid_mask = np.ones(n, dtype="bool")

        # if we select a location that violates the constraint,
        # ensure it cannot dominate anything by having its fitness values
        # maximally bad (i.e. set to infinity)
        if got_cf:
            for i in range(n):
                if not cf(X[i]):
                    f[i] = [np.inf, np.inf]
                    valid_mask[i] = False

        if np.any(valid_mask):
            output = model(X[valid_mask])
            output = model.likelihood(
                output,
                noise=torch.full_like(
                    output.mean, model.likelihood.noise.mean()
                ),
            )
            # note the negative stdev here as NSGA2 is minimising
            # so we want to minimise the negative stdev
            f[valid_mask, 0] = output.mean.numpy()
            f[valid_mask, 1] = -np.sqrt(output.variance.numpy())

        # store every location ever evaluated (function attributes set below,
        # before the optimisation starts)
        model_fitness.X[start_slice:end_slice, :] = X
        model_fitness.Y[start_slice:end_slice, :] = f
        return f

    # get the problem dimensionality
    D = lb.numel()

    # NSGA-II settings
    POPSIZE = D * 100

    # -1 here because the pop is evaluated first before iterating N_GENS times
    N_GENS = int(np.ceil(fevals / POPSIZE)) - 1
    TOTAL_EVALUATIONS = POPSIZE * (N_GENS + 1)

    _nsga2 = pg.nsga2(
        gen=1,  # number of generations to evaluate per evolve() call
        cr=0.8,  # cross-over probability.
        eta_c=20.0,  # distribution index (cr)
        m=1 / D,  # mutation rate
        eta_m=20.0,  # distribution index (m)
    )

    # batch fitness evaluator -- this is the strange way we
    # tell pygmo that we have a batch_fitness method
    bfe = pg.bfe()

    # tell nsgaII about it
    _nsga2.set_bfe(bfe)
    nsga2 = pg.algorithm(_nsga2)

    # preallocate the storage of every location and fitness to be evaluated
    model_fitness.X = np.zeros((TOTAL_EVALUATIONS, D))
    model_fitness.Y = np.zeros((TOTAL_EVALUATIONS, 2))

    # problem instance
    gpytorch_problem = GPYTORCH_WRAPPER(model, lb, ub, cf, TOTAL_EVALUATIONS)
    problem = pg.problem(gpytorch_problem)

    # skip all gradient calculations as we don't need them
    with torch.no_grad():
        # initialise the population -- in batch (using bfe)
        population = pg.population(problem, size=POPSIZE, b=bfe)

        # evolve the population
        for i in range(N_GENS):
            population = nsga2.evolve(population)

    # indices non-dominated points across the entire NSGA-II run
    front_inds = pg.non_dominated_front_2d(model_fitness.Y)

    X_front = model_fitness.X[front_inds, :]
    musigma_front = model_fitness.Y[front_inds, :]

    # convert the standard deviations back to positive values; nsga2 minimises
    # the negative standard deviation (i.e. maximises the standard deviation)
    musigma_front[:, 1] *= -1

    # convert it to torch
    X_front = torch.as_tensor(X_front, dtype=model.train_targets.dtype)
    musigma_front = torch.as_tensor(
        musigma_front, dtype=model.train_targets.dtype
    )

    return X_front, musigma_front
class AEGIS(AnalyticAcquisitionFunction):
    """AEGiS-style acquisition: randomly chooses one of three modes at build.

    With probability 1 - 2*eps the posterior mean is used ("exploit"), with
    probability eps a fixed Thompson sample path ("Thompson"), and otherwise
    the approximate mean/std Pareto front is precomputed ("Pareto"),
    where eps = min(1/sqrt(d), 0.5).
    """

    def __init__(
        self,
        model: Model,
        bounds: Tensor,
    ) -> None:
        # Skip AnalyticAcquisitionFunction.__init__; call
        # AcquisitionFunction.__init__ directly.
        super(AnalyticAcquisitionFunction, self).__init__(model)
        # Input dimensionality d, taken from the kernel's lengthscale vector.
        d = torch.tensor(len(model.covar_module.lengthscale[0]))
        eps = torch.min(1.0 / torch.sqrt(d), torch.tensor(0.5))
        # Single random draw fixes the mode for this object's lifetime.
        r = torch.rand(1)
        if r < 1 - (eps + eps):
            # exploit
            self.mode = "exploit"
        elif r < 1 - eps:
            # Thompson
            self.mode = "Thompson"
            self.path = draw_matheron_paths(self.model, torch.Size([1]))
        else:
            # approx Pareto selection
            self.mode = "Pareto"
            self.pareto_front, _ = NSGA2_pygmo(
                model=model, fevals=1, lb=bounds[0], ub=bounds[1], cf=None
            )
        self.model = model

    def forward(self, X):
        # NOTE(review): there is no "Pareto" branch here, so in that mode
        # forward returns None -- presumably the caller detects
        # ``self.mode == "Pareto"`` and samples from ``self.pareto_front``
        # directly instead of optimizing this acquisition; confirm at call
        # sites.
        if self.mode == "exploit":
            # Pure exploitation: posterior mean at X.
            post = self.model.posterior(X)
            y = post.mean.squeeze(-1)
            return y.flatten()
        elif self.mode == "Thompson":
            # Evaluate the fixed pathwise sample drawn in __init__.
            y = self.path(X)
            return y.flatten()