-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDiffusionModel.py
More file actions
43 lines (35 loc) · 1.84 KB
/
DiffusionModel.py
File metadata and controls
43 lines (35 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
class DiffusionModel:
    """Linear-beta-schedule diffusion process with forward noising and a reverse step.

    The schedules (``betas``, ``alphas``, ``alphas_cumprod``) are built with
    ``torch.linspace`` / ``torch.cumprod`` without an explicit device, so they
    live on the default (CPU) device; per-timestep values are gathered and
    moved to the device of ``t`` on demand by ``get_index_from_list``.
    """

    def __init__(self, start_schedule=0.0001, end_schedule=0.02, timesteps=400, device='cuda'):
        """Build the noise schedule.

        Args:
            start_schedule: beta value at the first timestep.
            end_schedule: beta value at the last timestep.
            timesteps: number of diffusion steps (length of every schedule tensor).
            device: default target device used by ``forward`` when no device
                is passed explicitly. The schedules themselves are NOT moved here.
        """
        self.start_schedule = start_schedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps
        # Linearly spaced betas; alphas and their running product follow from them.
        self.betas = torch.linspace(start_schedule, end_schedule, timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.device = device

    def forward(self, x_0, t, device=None):
        """Noise ``x_0`` directly to timestep ``t`` in closed form.

        Computes ``sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps`` with
        ``eps ~ N(0, I)``, where ``abar`` is ``alphas_cumprod``.

        Args:
            x_0: clean input; the first dimension is the batch dimension
                (the gathered coefficients are reshaped to ``(len(t), 1, ...)``).
            t: 1-D integer tensor of timesteps, one per batch element (or a
                single element broadcast over the batch).
            device: device for the outputs. Defaults to ``self.device``.

        Returns:
            ``(x_t, noise)``: the noised sample and the noise that was added,
            both on ``device``.
        """
        if device is None:
            device = self.device
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alphas_cumprod.sqrt(), t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(
            torch.sqrt(1. - self.alphas_cumprod), t, x_0.shape)
        # Signal term and std-scaled noise term of q(x_t | x_0).
        signal = sqrt_alphas_cumprod_t.to(device) * x_0.to(device)
        scaled_noise = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return signal + scaled_noise, noise.to(device)

    @torch.no_grad()
    def backward(self, x, t, model, **kwargs):
        """Take one reverse (denoising) step from ``x`` at timestep ``t``.

        ``model(x, t, **kwargs)`` is used in the epsilon form of the mean:
        ``mean = 1/sqrt(alpha_t) * (x - beta_t * eps_pred / sqrt(1 - abar_t))``.
        The step variance is ``beta_t``. Elements whose timestep is 0 get the
        mean with no added noise; all other elements get ``mean + sqrt(beta_t) * z``.

        Args:
            x: current sample; first dimension is the batch dimension.
            t: 1-D integer tensor of timesteps (one per batch element, or a
                single element broadcast over the batch). Mixed values are
                supported, e.g. some elements at 0 and others not.
            model: callable ``model(x, t, **kwargs)``; assumed to return a
                tensor broadcastable against ``x``.
            **kwargs: forwarded to ``model``.

        Returns:
            The sample for the previous timestep, same shape as ``x``.
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(
            torch.sqrt(1. - self.alphas_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        predicted_noise = model(x, t, **kwargs)
        mean = sqrt_recip_alphas_t * (x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = betas_t

        # Compare element-wise instead of `if t == 0`, which raises for a
        # multi-element tensor. When every element is at the final step we
        # skip sampling noise altogether, matching the single-timestep behavior.
        is_final_step = t == 0
        if bool(is_final_step.all()):
            return mean
        noise = torch.randn_like(x)
        sampled = mean + torch.sqrt(posterior_variance_t) * noise
        mask = is_final_step.reshape(t.shape[0], *((1,) * (len(x.shape) - 1))).to(x.device)
        return torch.where(mask, mean, sampled)

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        """Gather ``values[t]`` and reshape it to broadcast against ``x_shape``.

        Args:
            values: 1-D schedule tensor indexed by timestep.
            t: 1-D integer tensor of timesteps (any integer dtype, any device).
            x_shape: shape of the tensor the result will be multiplied with.

        Returns:
            Tensor of shape ``(len(t), 1, ..., 1)`` (``len(x_shape)`` dims total)
            on ``t``'s device.
        """
        batch_size = t.shape[0]
        # gather requires an int64 index on the same device as `values`.
        index = t.to(device=values.device, dtype=torch.long)
        result = values.gather(-1, index)
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)