forked from vivinvinod/LFaB-for-QC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGPR_model.py
More file actions
56 lines (47 loc) · 1.89 KB
/
GPR_model.py
File metadata and controls
56 lines (47 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import gpytorch
import os
from tqdm import tqdm
import numpy as np
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a constant mean and a scaled RBF kernel.

    Accepts numpy arrays or torch tensors for the training data; numpy
    inputs are converted to float32 torch tensors before being handed to
    the ``ExactGP`` base class.
    """

    def __init__(self, train_x, train_y, likelihood):
        """
        Parameters
        ----------
        train_x : numpy.ndarray or torch.Tensor
            Training inputs.
        train_y : numpy.ndarray or torch.Tensor
            Training targets.
        likelihood : gpytorch.likelihoods.Likelihood
            Likelihood used for exact GP inference.
        """
        # Allow numpy inputs by converting them to float32 tensors so the
        # model works with either array library.
        if isinstance(train_x, np.ndarray):
            train_x = torch.from_numpy(train_x).float()
        if isinstance(train_y, np.ndarray):
            train_y = torch.from_numpy(train_y).float()
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        """Return the prior MultivariateNormal distribution at inputs ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
def train_hypers(model, likelihood, lr=0.05, maxiter=2000,
                 save_path='params.pth', tol=1e-8):
    """Fit GP hyperparameters by minimizing the negative exact marginal log-likelihood.

    Parameters
    ----------
    model : gpytorch.models.ExactGP
        Model to optimize; its stored ``train_inputs`` / ``train_targets``
        are used as the training data.
    likelihood : gpytorch.likelihoods.Likelihood
        Likelihood paired with the model.
    lr : float, optional
        Adam learning rate.
    maxiter : int, optional
        Maximum number of optimization steps.
    save_path : str or None, optional
        Path for saving the trained ``state_dict``; pass ``None`` to skip saving.
    tol : float, optional
        Early-stopping threshold on the absolute change in loss between
        consecutive iterations.

    Returns
    -------
    numpy.ndarray
        Per-iteration loss values (length <= maxiter).
    """
    model.train()
    likelihood.train()
    # Adam over model.parameters() — includes the GaussianLikelihood
    # parameters, per the original author's note.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    losses = []
    for i in tqdm(range(maxiter), desc='Training hyper-params for GPR using MarLogLike'):
        # Zero gradients from the previous iteration before the new step.
        optimizer.zero_grad()
        output = model(model.train_inputs[0])
        # Negate the MLL so that minimizing the loss maximizes the likelihood.
        loss = -mll(output, model.train_targets)
        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        losses.append(current_loss)
        # Stop early once the loss change drops below tol; wait a few
        # warm-up iterations so losses[-2] exists and the optimizer settles.
        if i > 5 and abs(current_loss - losses[-2]) < tol:
            break
    # Idiomatic None check (was: type(save_path) != type(None)).
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return np.asarray(losses)