Commit b20eabe

2D diffusion model now works
1 parent 477cd99 commit b20eabe

1 file changed: engiopt/diffusion_2d_cond.py
Lines changed: 232 additions & 71 deletions
@@ -1,5 +1,4 @@
-'''based on the diffuser example from huggingface: https://huggingface.co/learn/diffusion-course/unit1/2
-'''
+"""based on the diffuser example from huggingface: https://huggingface.co/learn/diffusion-course/unit1/2 ."""
 
 from __future__ import annotations
 
@@ -9,12 +8,15 @@
98
import time
109

1110
from diffusers import DDPMScheduler
11+
from diffusers import DDPMPipeline
1212
from diffusers import UNet2DConditionModel
13+
from diffusers.utils import make_image_grid
1314
from engibench.utils.all_problems import BUILTIN_PROBLEMS
1415
import matplotlib.pyplot as plt
1516
import numpy as np
1617
import torch as th
1718
from torch import nn
19+
import torch.nn.functional as F
1820
import torchvision.transforms as transforms
1921
import tqdm
2022
import tyro
@@ -44,7 +46,7 @@ class Args:
     # Algorithm specific
     n_epochs: int = 1000
     """number of epochs of training"""
-    batch_size: int = 1
+    batch_size: int = 32
     """size of the batches"""
     lr: float = 3e-4
"""learning rate"""
@@ -61,6 +63,139 @@ class Args:
     sample_interval: int = 400
     """interval between image samples"""
 
+def beta_schedule(T, start=1e-4, end=0.02, scale=1.0, cosine=False, exp_biasing=False, exp_bias_factor=1):
+    """Returns a beta schedule (default: linear) for the diffusion model.
+
+    Args:
+        T: Number of timesteps
+        start: Starting value of beta
+        end: Ending value of beta
+        scale: Scaling factor for beta
+        cosine: Whether to use a cosine beta schedule
+        exp_biasing: Whether to use exponential biasing
+        exp_bias_factor: Exponential biasing factor
+    """
+    beta = th.linspace(scale * start, scale * end, T)
+    if cosine:
+        beta = []
+        a_func = lambda t_val: math.cos((t_val + 0.008) / 1.008 * np.pi / 2) ** 2
+        for i in range(T):
+            t1 = i / T
+            t2 = (i + 1) / T
+            beta.append(min(1 - a_func(t2) / a_func(t1), 0.999))
+
+        beta = th.tensor(beta)
+
+    if exp_biasing:
+        beta = th.flip(th.exp(-exp_bias_factor * th.linspace(0, 1, T)), dims=[0]) * beta
+
+    return beta
+
+def get_index_from_list(vals, t, x_shape):
+    """Returns a specific index t of a passed list of values vals
+    while considering the batch dimension.
+    Credit:
+    """
+    batch_size = t.shape[0]
+    out = vals.gather(-1, t.cpu())
+    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
+
+class DiffusionSampler():
+    # Precompute the sqrt alphas and sqrt one minus alphas
+    def __init__(self, T, betas):
+        self.T = T
+        self.betas = betas
+        self.alphas = 1. - self.betas
+        self.alphas_cumprod = th.cumprod(self.alphas, axis=0)
+        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
+        self.sqrt_recip_alphas = th.sqrt(1.0 / self.alphas)
+        self.sqrt_alphas_cumprod = th.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = th.sqrt(1. - self.alphas_cumprod)
+        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
+
+    def forward_diffusion_sample(self, x_0, t, device="cpu"):
+        """Takes an image and a timestep as input and
+        returns the noisy version of it.
+        """
+        noise = th.randn_like(x_0).to(device)
+        sqrt_alphas_cumprod_t = get_index_from_list(self.sqrt_alphas_cumprod, t, x_0.shape)
+        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
+            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
+        )
+
+        # mean + variance
+        return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
+            + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
+
+    def forward_diffusion_sample_partial(self, x_0, t_current, t_final, device="cpu"):
+        """Takes an image at a timestep and
+        adds noise to reach the desired timestep.
+        """
+        for i in range(t_final[0] - t_current[0]):
+            t = t_final - i
+
+            noise = th.randn_like(x_0).to(device)
+            x_0 = th.sqrt(get_index_from_list(self.alphas, t, x_0.shape)) * x_0.to(device) \
+                + th.sqrt(get_index_from_list(1 - self.alphas, t, x_0.shape)) * noise.to(device)
+
+        # mean + variance
+        return x_0, noise.to(device)
+
+    def diffusion_step_sample(self, noise_pred, x_noisy, t, device="cpu"):
+        """Takes an image, noise and step; returns denoised image."""
+        betas_t = get_index_from_list(self.betas, t, x_noisy.shape).to(device)
+        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
+            self.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape
+        ).to(device)
+        sqrt_recip_alphas_t = get_index_from_list(self.sqrt_recip_alphas, t, x_noisy.shape).to(device)
+        model_mean = sqrt_recip_alphas_t * (
+            x_noisy - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t
+        )
+        posterior_variance_t = get_index_from_list(self.posterior_variance, t, x_noisy.shape).to(device)
+
+        # mean + variance
+        return (model_mean + th.sqrt(posterior_variance_t) * noise_pred).to(device)
+
+    def lossfn_builder(self):
+        """Returns the loss function for the diffusion model."""
+        def lossfn(noise_pred, noise):
+            return F.mse_loss(noise_pred, noise)
+
+        return lossfn
+
+    def sample_timestep(self, model, x, t, encoder_hidden_states, c=None, t_mask=None):
+        """Calls the model to predict the noise in the image and returns the denoised image.
+        Applies noise to this image, if we are not in the last step yet.
+        """
+        model.eval()
+        with th.no_grad():
+            betas_t = get_index_from_list(self.betas, t, x.shape)
+            sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
+                self.sqrt_one_minus_alphas_cumprod, t, x.shape
+            )
+            sqrt_recip_alphas_t = get_index_from_list(self.sqrt_recip_alphas, t, x.shape)
+
+            # Call model (current image - noise prediction)
+            if c is not None:
+                # with th.cuda.amp.autocast(dtype=th.float16):
+                model_mean = sqrt_recip_alphas_t * (
+                    x - betas_t * model(x, t, c) / sqrt_one_minus_alphas_cumprod_t
+                )
+            else:
+                # with th.cuda.amp.autocast(dtype=th.float16):
+                model_mean = sqrt_recip_alphas_t * (
+                    x - betas_t * model(x, t, encoder_hidden_states).sample / sqrt_one_minus_alphas_cumprod_t
+                )
+
+            posterior_variance_t = get_index_from_list(self.posterior_variance, t, x.shape)
+            if t_mask is None:
+                device = x.device
+                t_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))).to(device)
+
+            return model_mean + th.sqrt(posterior_variance_t) * th.randn_like(x) * t_mask
+
 if __name__ == "__main__":
args = tyro.cli(Args)
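
For reference, the buffers precomputed in DiffusionSampler and the update in sample_timestep follow the standard DDPM identities (Ho et al., 2020). A sketch of the relations the helpers above implement, with \epsilon_\theta denoting the noise-prediction network and c the conditioning:

    \alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s
    x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)  % forward_diffusion_sample
    \mu_\theta(x_t, t, c) = \frac{1}{\sqrt{\alpha_t}} \Bigl( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t, c) \Bigr)  % model_mean in sample_timestep
    \tilde{\beta}_t = \beta_t\, \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}, \qquad x_{t-1} = \mu_\theta(x_t, t, c) + \sqrt{\tilde{\beta}_t}\, z \cdot \mathbb{1}[t > 0], \quad z \sim \mathcal{N}(0, I)  % posterior_variance and t_mask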

@@ -82,22 +217,28 @@ class Args:
 
     os.makedirs("images", exist_ok=True)
 
-    # if th.backends.mps.is_available():
-    #     device = th.device("mps")
-    # elif th.cuda.is_available():
-    #     device = th.device("cuda")
-    # else:
-    device = th.device("cpu")
+    if th.backends.mps.is_available():
+        device = th.device("mps")
+    elif th.cuda.is_available():
+        device = th.device("cuda")
+    else:
+        device = th.device("cpu")
 
     # Loss function
     adversarial_loss = th.nn.MSELoss()
 
-    # Initialize generator and discriminator
+    # Initialize UNet from Huggingface
     model = UNet2DConditionModel(
         sample_size=(100, 100),
         in_channels=1,
         out_channels=1,
         cross_attention_dim=64,
+        block_out_channels=(64, 128),
+        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
+        layers_per_block=1,
+        transformer_layers_per_block=0,
+        only_cross_attention=True,
     )
 
     model.to(device)
@@ -108,84 +249,103 @@ class Args:
     filtered_ds = th.zeros(len(training_ds), 100, 100, device=device)
     for i in range(len(training_ds)):
         filtered_ds[i] = transforms.Resize((100, 100))(training_ds[i]['optimal_design'].reshape(1, training_ds[i]['nelx'], training_ds[i]['nely']))
-    training_ds = th.utils.data.TensorDataset(filtered_ds.flatten(1), training_ds['volfrac'])
+    filtered_ds_max = filtered_ds.max()
+    filtered_ds_min = filtered_ds.min()
+    filtered_ds *= 2
+    filtered_ds -= 1
+    filtered_ds_norm = (filtered_ds - filtered_ds_min) / (filtered_ds_max - filtered_ds_min)
+    training_ds = th.utils.data.TensorDataset(filtered_ds_norm.flatten(1), training_ds['volfrac'])
+    vf_min = training_ds.tensors[1].min()
+    vf_max = training_ds.tensors[1].max()
     dataloader = th.utils.data.DataLoader(
         training_ds,
         batch_size=args.batch_size,
         shuffle=True,
     )
-
+    num_timesteps = 1000
     # Optimizer
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
+    noise_scheduler = DDPMScheduler(num_train_timesteps=num_timesteps, beta_schedule="linear")
 
     # Training loop
 
     optimizer = th.optim.AdamW(model.parameters(), lr=4e-4)
-    @th.no_grad()
-    def sample_designs(n_designs: int) -> th.Tensor:
-        """Samples n_designs from the generator."""
-        # Sample noise
-        z = th.randn((n_designs, args.latent_dim), device=device, dtype=th.float)
-        # THESE BOUNDS ARE PROBLEM DEPENDENT
-
-        linspaces = [th.linspace(objs[:, i].min(), objs[:, i].max(), n_designs, device=device) for i in range(objs.shape[1])]
-
-        objs_small = th.stack(linspaces, dim=1)
-        desired_objs = objs_small.reshape(-1,1,1)
-        desired_objs = desired_objs.expand(-1,1,64)
-        noise = th.randn((25,1,100,100)).to(device)
-        timesteps = th.full((25,), (950))
-        test_ds = th.utils.data.TensorDataset(noise, timesteps, desired_objs)
-        dataloader = th.utils.data.DataLoader(test_ds, batch_size=1)
-        gen_images = []
-        for noise, timesteps, desired_objs in tqdm.tqdm(dataloader):
-            gen_imgs = model(noise, timesteps, encoder_hidden_states=desired_objs)[0]
-            gen_images.append(gen_imgs)
-
-        return objs_small, th.cat(gen_images, dim=0)
-
-    # ----------
-    # Training
-    # ----------
-    for epoch in tqdm.trange(args.n_epochs):
-        for i, data in enumerate(dataloader):
-            # THIS IS PROBLEM DEPENDENT
-            designs = data[0].reshape(-1,1,100,100)
-
-            objs = th.stack((data[1:]), dim=1).reshape(-1,1,1)
-            objs_ex = objs.expand(-1,1,64)
 
-            clean_images = designs
+    ## Schedule Parameters
+    T = num_timesteps  # Number of timesteps
+    start = 1e-4  # Starting variance
+    end = 0.02  # Ending variance
+    # Choose a schedule (if the following are False, then a linear schedule is used)
+    cosine = False  # Use cosine schedule
+    exp_biasing = False  # Use exponential schedule
+    exp_biasing_factor = 1  # Exponential schedule factor (used if exp_biasing=True)
+    ##
 
-            # Sample noise to add to the images
+    # Choose a variance schedule
 
-            noise = th.randn(clean_images.shape).to(clean_images.device)
+    betas = beta_schedule(T=T, start=start, end=end,
+                          scale=1.0, cosine=cosine,
+                          exp_biasing=exp_biasing, exp_bias_factor=exp_biasing_factor
+                          )
 
-            bs = clean_images.shape[0]
+    ddm_sampler = DiffusionSampler(T, betas)
 
-            # Sample a random timestep for each image
-
-            timesteps = th.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
-
-            # Add noise to the clean images according to the noise magnitude at each timestep
-
-            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
-
-            # Get the model prediction
-
-            noise_pred = model(noisy_images, timesteps, return_dict=False, encoder_hidden_states=objs_ex)[0]
+    # Loss function
+    def ddm_loss_fn(noise_pred, noise):
+        return F.l1_loss(noise_pred, noise)
 
-            # Calculate the loss
+    @th.no_grad()
+    def sample_designs(model, n_designs=25):
+        """Samples n_designs designs."""
+        model.eval()
+        with th.no_grad():
 
-            loss = th.nn.functional.mse_loss(noise_pred, noise)
+            dims = (n_designs, 1, 100, 100)
+            image = th.randn(dims, device=device)  # initial image
+            encoder_hidden_states = th.linspace(vf_min, vf_max, n_designs, device=device)
+            encoder_hidden_states = encoder_hidden_states.view(n_designs, 1, 1).expand(n_designs, 1, 32)
+            for i in range(num_timesteps)[::-1]:
+                t = th.full((n_designs,), i, device=device, dtype=th.long)
 
-            loss.backward(loss)
+                image = ddm_sampler.sample_timestep(model, image, t, encoder_hidden_states)
 
-            # Update the model parameters with the optimizer
+        return image, encoder_hidden_states
 
-            optimizer.step()
+    # ----------
+    # Training
+    # ----------
+    for epoch in tqdm.trange(args.n_epochs):
+        for i, data in enumerate(dataloader):
 
+            # Zero the parameter gradients
             optimizer.zero_grad()
+            designs = data[0].reshape(-1,1,100,100)
+            x = designs.to(device)
+            objs = th.stack((data[1:]), dim=1).reshape(-1,1,1)
+            objs_ex = objs.expand(-1,1,32)
+
+            current_batch_size = x.shape[0]
+            t = th.randint(0, T, (current_batch_size,), device=device).long()
+            encoder_hidden_states = objs_ex.to(device)
+
+            # Get the noise and the noisy input
+            x_noisy, noise = ddm_sampler.forward_diffusion_sample(x, t, device)
+
+            # Forward pass
+            # if mp_mode:
+            #     with torch.cuda.amp.autocast(dtype=torch.float16):
+            #         noise_pred = model_diffuser(x_noisy, t, encoder_hidden_states).sample
+            #         loss = ddm_loss_fn(noise_pred, noise)
+            #     scaler.scale(loss).backward()
+            #     scaler.step(optimizer)
+            #     scaler.update()
+            # else:
+            noise_pred = model(x_noisy, t, encoder_hidden_states).sample
+            loss = ddm_loss_fn(noise_pred, noise)
+
+            # Backpropagation
+            loss.backward()
+            optimizer.step()
 
             # ----------
# Logging
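
As a sanity check of the pieces added above, a minimal, hedged sketch of one noise-prediction training step that runs without the diffusers UNet or the EngiBench dataset. The toy Conv2d model and the random batch are stand-ins, not part of the commit, and it assumes the beta_schedule, get_index_from_list, and DiffusionSampler definitions above are in scope:

    import torch as th
    import torch.nn.functional as F
    from torch import nn

    toy_model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for model(x_noisy, t, cond).sample
    opt = th.optim.AdamW(toy_model.parameters(), lr=3e-4)

    betas = beta_schedule(T=1000)                           # linear schedule, as in the commit's defaults
    sampler = DiffusionSampler(1000, betas)

    x = th.rand(8, 1, 100, 100) * 2 - 1                     # fake batch of "designs" roughly in [-1, 1]
    t = th.randint(0, 1000, (8,)).long()

    x_noisy, noise = sampler.forward_diffusion_sample(x, t, device="cpu")
    loss = F.l1_loss(toy_model(x_noisy), noise)             # same L1 objective as ddm_loss_fn
    opt.zero_grad()
    loss.backward()
    opt.step()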
@@ -206,18 +366,19 @@ def sample_designs(n_designs: int) -> th.Tensor:
             # This saves a grid image of 25 generated designs every sample_interval
             if batches_done % args.sample_interval == 0:
                 # Extract 25 designs
-                desired_objs, designs = sample_designs(25)
+
+                designs, hidden_states = sample_designs(model, 25)
                 fig, axes = plt.subplots(5, 5, figsize=(12, 12))
 
                 # Flatten axes for easy indexing
                 axes = axes.flatten()
 
-                # Plot each tensor as a scatter plot
+                # Plot the image created by each output
                 for j, tensor in enumerate(designs):
                     img = tensor.cpu().numpy().reshape(100,100)  # Extract x and y coordinates
-                    do = desired_objs[j].cpu()
-                    axes[j].imshow(img)  # Scatter plot
-                    axes[j].title.set_text(f"volfrac: {do[0]:.2f}")
+                    do = hidden_states[j,0,0].cpu()
+                    axes[j].imshow(img.T)  # image plot
+                    axes[j].title.set_text(f"volfrac: {do:.2f}")  # Set title
                     axes[j].set_xticks([])  # Hide x ticks
axes[j].set_yticks([]) # Hide y ticks
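
A possible follow-up check on the conditioning, not part of this commit: compare the volume fraction requested in the title with the one realized by the sampled design. Thresholding the roughly [-1, 1] output at 0 is an assumption; the two lines below would sit inside the plotting loop above:

    realized_vf = float((img > 0.0).mean())  # fraction of "solid" pixels in the generated design
    axes[j].set_xlabel(f"realized volfrac: {realized_vf:.2f}")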
