1- '''based on the diffuser example from huggingface: https://huggingface.co/learn/diffusion-course/unit1/2
2- '''
1+ """based on the diffuser example from huggingface: https://huggingface.co/learn/diffusion-course/unit1/2 ."""
32
43from __future__ import annotations
54
98import time
109
1110from diffusers import DDPMScheduler
11+ from diffusers import DDPMPipeline
1212from diffusers import UNet2DConditionModel
13+ from diffusers .utils import make_image_grid
1314from engibench .utils .all_problems import BUILTIN_PROBLEMS
1415import matplotlib .pyplot as plt
1516import numpy as np
1617import torch as th
1718from torch import nn
19+ import torch .nn .functional as F
1820import torchvision .transforms as transforms
1921import tqdm
2022import tyro
@@ -44,7 +46,7 @@ class Args:
4446 # Algorithm specific
4547 n_epochs : int = 1000
4648 """number of epochs of training"""
47- batch_size : int = 1
49+ batch_size : int = 32
4850 """size of the batches"""
4951 lr : float = 3e-4
5052 """learning rate"""
@@ -61,6 +63,139 @@ class Args:
6163 sample_interval : int = 400
6264 """interval between image samples"""
6365
def beta_schedule (T , start = 1e-4 , end = 0.02 , scale = 1.0 , cosine = False , exp_biasing = False , exp_bias_factor = 1 ):
    """Return a beta (variance) schedule for the diffusion model.

    By default a linear schedule from ``scale * start`` to ``scale * end`` is
    produced; alternatively the cosine schedule of Nichol & Dhariwal (2021)
    can be selected. Either schedule may then be reweighted by a flipped
    exponential decay.

    Args:
        T: Number of diffusion timesteps.
        start: Starting value of beta (linear schedule only).
        end: Ending value of beta (linear schedule only).
        scale: Scaling factor applied to ``start`` and ``end``.
        cosine: Whether to use the cosine schedule instead of the linear one.
        exp_biasing: Whether to multiply the schedule by a flipped exponential decay.
        exp_bias_factor: Decay rate of the exponential bias (used if ``exp_biasing``).

    Returns:
        A 1-D tensor of length ``T`` with the betas.
    """
    if cosine:
        # alpha-bar(t) of the improved-DDPM cosine schedule.
        # Fix: use np.cos — ``math`` is not among the visible imports, while
        # ``np`` is (the original mixed math.cos with np.pi).
        def a_bar(t_val):
            return np.cos((t_val + 0.008) / 1.008 * np.pi / 2) ** 2

        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at 0.999 for
        # numerical stability near t = T.
        beta = th.tensor(
            [min(1 - a_bar((i + 1) / T) / a_bar(i / T), 0.999) for i in range(T)]
        )
    else:
        beta = th.linspace(scale * start, scale * end, T)

    if exp_biasing:
        # Flipped exponential: damping is strongest at the early timesteps.
        beta = th.flip(th.exp(-exp_bias_factor * th.linspace(0, 1, T)), dims=[0]) * beta

    return beta
93+
def get_index_from_list (vals , t , x_shape ):
    """Gather the entries of ``vals`` at the timesteps in ``t`` and reshape
    them to broadcast over a batch shaped like ``x_shape``.

    Returns a tensor of shape ``(batch, 1, ..., 1)`` on ``t``'s device.
    """
    n_batch = t.shape[0]
    picked = vals.gather(-1, t.cpu())
    broadcast_shape = (n_batch,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape).to(t.device)
102+
class DiffusionSampler ():
    """DDPM sampler with precomputed schedule constants.

    Implements the closed-form forward (noising) process q(x_t | x_0) and the
    reverse (denoising) ancestral-sampling step of Ho et al., "Denoising
    Diffusion Probabilistic Models" (2020).
    """

    def __init__ (self , T , betas ):
        """Precompute alphas, their cumulative products and the posterior variance.

        Args:
            T: Number of diffusion timesteps.
            betas: 1-D tensor of length ``T`` with the variance schedule.
        """
        self .T = T
        self .betas = betas
        self .alphas = (1. - self .betas )
        self .alphas_cumprod = th .cumprod (self .alphas , axis = 0 )
        # alpha-bar shifted right by one step, with alpha-bar_0 := 1.
        self .alphas_cumprod_prev = F .pad (self .alphas_cumprod [:- 1 ], (1 , 0 ), value = 1.0 )
        self .sqrt_recip_alphas = th .sqrt (1.0 / self .alphas )
        self .sqrt_alphas_cumprod = th .sqrt (self .alphas_cumprod )
        self .sqrt_one_minus_alphas_cumprod = th .sqrt (1. - self .alphas_cumprod )
        # Variance of q(x_{t-1} | x_t, x_0); entry 0 is exactly zero.
        self .posterior_variance = self .betas * (1. - self .alphas_cumprod_prev ) / (1. - self .alphas_cumprod )

    def forward_diffusion_sample (self , x_0 , t , device = "cpu" ):
        """Noise a clean batch ``x_0`` directly to timestep ``t``.

        Uses the closed form x_t = sqrt(alpha-bar_t) * x_0 + sqrt(1 - alpha-bar_t) * eps.

        Returns:
            Tuple ``(x_t, eps)`` of the noisy batch and the Gaussian noise used.
        """
        noise = th .randn_like (x_0 ).to (device )
        sqrt_alphas_cumprod_t = get_index_from_list (self .sqrt_alphas_cumprod , t , x_0 .shape )
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list (
            self .sqrt_one_minus_alphas_cumprod , t , x_0 .shape
        )

        # mean + variance
        return sqrt_alphas_cumprod_t .to (device ) * x_0 .to (device )\
            + sqrt_one_minus_alphas_cumprod_t .to (device ) * noise .to (device ), noise .to (device )

    def forward_diffusion_sample_partial (self , x_0 , t_current , t_final , device = "cpu" ):
        """Noise a batch already at timestep ``t_current`` up to ``t_final``
        one step at a time.

        NOTE(review): only element 0 of ``t_current``/``t_final`` drives the
        loop, so this assumes all batch elements share the same timestep —
        confirm at the callers.

        Returns:
            Tuple of the noised batch and the noise added in the last step.
        """
        # Fix: initialize before the loop so a zero-iteration call
        # (t_final == t_current) returns (x_0, zeros) instead of raising
        # NameError on ``noise``.
        noise = th .zeros_like (x_0 ).to (device )
        for i in range (t_final [0 ]- t_current [0 ]):
            t = t_final - i

            noise = th .randn_like (x_0 , ).to (device )
            x_0 = th .sqrt (get_index_from_list (self .alphas , t , x_0 .shape )) * x_0 .to (device )\
                + th .sqrt (get_index_from_list (1 - self .alphas , t , x_0 .shape )) * noise .to (device )

        # mean + variance
        return x_0 , noise .to (device )

    def diffusion_step_sample (self , noise_pred , x_noisy , t , device = "cpu" ):
        """One reverse (denoising) step given the model's noise prediction.

        Computes the posterior mean and adds the posterior-variance term.

        Fix: the stochastic term now uses fresh Gaussian noise rather than the
        predicted noise (DDPM Algorithm 2), and is masked out at t == 0 —
        matching ``sample_timestep`` below.
        """
        betas_t = get_index_from_list (self .betas , t , x_noisy .shape ).to (device )
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list (
            self .sqrt_one_minus_alphas_cumprod , t , x_noisy .shape
        ).to (device )
        sqrt_recip_alphas_t = get_index_from_list (self .sqrt_recip_alphas , t , x_noisy .shape ).to (device )
        model_mean = sqrt_recip_alphas_t * (
            x_noisy - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t
        )
        posterior_variance_t = get_index_from_list (self .posterior_variance , t , x_noisy .shape ).to (device )
        # No noise is added on the final step (t == 0).
        t_mask = ((t != 0 ).float ().view (- 1 , * ([1 ] * (len (x_noisy .shape ) - 1 )))).to (device )

        # mean + variance
        return (model_mean + th .sqrt (posterior_variance_t ) * th .randn_like (x_noisy ) * t_mask ).to (device )

    def lossfn_builder (self ):
        """Return the (MSE) loss function for the diffusion model."""
        def lossfn (noise_pred , noise ):
            return F .mse_loss (noise_pred , noise )

        return lossfn

    def sample_timestep (self , model , x , t , encoder_hidden_states , c = None , t_mask = None ):
        """Call the model to predict the noise in ``x`` and return the denoised
        image; noise is re-added unless we are at the last step (t == 0).

        Args:
            model: Noise-prediction network. If ``c`` is None the model is
                assumed to return an object with a ``.sample`` tensor (diffusers
                UNet); otherwise it is called as ``model(x, t, c)`` and assumed
                to return a tensor directly — TODO confirm at callers.
            x: Current noisy batch.
            t: Per-sample timestep indices.
            encoder_hidden_states: Conditioning passed to the diffusers model.
            c: Optional alternative conditioning for non-diffusers models.
            t_mask: Optional precomputed mask zeroing the noise at t == 0.
        """
        model .eval ()
        with th .no_grad ():
            betas_t = get_index_from_list (self .betas , t , x .shape )
            sqrt_one_minus_alphas_cumprod_t = get_index_from_list (
                self .sqrt_one_minus_alphas_cumprod , t , x .shape
            )
            sqrt_recip_alphas_t = get_index_from_list (self .sqrt_recip_alphas , t , x .shape )

            # Call model (current image - noise prediction)
            if c is not None :
                model_mean = sqrt_recip_alphas_t * (
                    x - betas_t * model (x , t , c ) / sqrt_one_minus_alphas_cumprod_t
                )
            else :
                model_mean = sqrt_recip_alphas_t * (
                    x - betas_t * model (x , t , encoder_hidden_states ).sample / sqrt_one_minus_alphas_cumprod_t
                )

            posterior_variance_t = get_index_from_list (self .posterior_variance , t , x .shape )
            if t_mask is None :
                device = x .device
                # Zero out the noise term on the final step (t == 0).
                t_mask = ((t != 0 ).float ().view (- 1 , * ([1 ] * (len (x .shape ) - 1 )))).to (device )

            return model_mean + th .sqrt (posterior_variance_t ) * th .randn_like (x ) * t_mask
198+
64199if __name__ == "__main__" :
65200 args = tyro .cli (Args )
66201
@@ -82,22 +217,28 @@ class Args:
82217
83218 os .makedirs ("images" , exist_ok = True )
84219
85- # if th.backends.mps.is_available():
86- # device = th.device("mps")
87- # elif th.cuda.is_available():
88- # device = th.device("cuda")
89- # else:
90- device = th .device ("cpu" )
220+ if th .backends .mps .is_available ():
221+ device = th .device ("mps" )
222+ elif th .cuda .is_available ():
223+ device = th .device ("cuda" )
224+ else :
225+ device = th .device ("cpu" )
91226
92227 # Loss function
93228 adversarial_loss = th .nn .MSELoss ()
94229
95- # Initialize generator and discriminator
230+ # Initialize UNet from Huggingface
96231 model = UNet2DConditionModel (
97232 sample_size = (100 , 100 ),
98233 in_channels = 1 ,
99234 out_channels = 1 ,
100235 cross_attention_dim = 64 ,
236+ block_out_channels = (64 , 128 ),
237+ down_block_types = ("CrossAttnDownBlock2D" , "DownBlock2D" ),
238+ up_block_types = ("UpBlock2D" , "CrossAttnUpBlock2D" ),
239+ layers_per_block = 1 ,
240+ transformer_layers_per_block = 0 ,
241+ only_cross_attention = True ,
101242 )
102243
103244 model .to (device )
@@ -108,84 +249,103 @@ class Args:
108249 filtered_ds = th .zeros (len (training_ds ), 100 , 100 , device = device )
109250 for i in range (len (training_ds )):
110251 filtered_ds [i ] = transforms .Resize ((100 , 100 ))(training_ds [i ]['optimal_design' ].reshape (1 , training_ds [i ]['nelx' ], training_ds [i ]['nely' ]))
111- training_ds = th .utils .data .TensorDataset (filtered_ds .flatten (1 ), training_ds ['volfrac' ])
252+ filtered_ds_max = filtered_ds .max ()
253+ filtered_ds_min = filtered_ds .min ()
254+ filtered_ds *= 2
255+ filtered_ds -= 1
256+ filtered_ds_norm = (filtered_ds - filtered_ds_min ) / (filtered_ds_max - filtered_ds_min )
257+ training_ds = th .utils .data .TensorDataset (filtered_ds_norm .flatten (1 ), training_ds ['volfrac' ])
258+ vf_min = training_ds .tensors [1 ].min ()
259+ vf_max = training_ds .tensors [1 ].max ()
112260 dataloader = th .utils .data .DataLoader (
113261 training_ds ,
114262 batch_size = args .batch_size ,
115263 shuffle = True ,
116264 )
117-
265+ num_timesteps = 1000
118266 # Optimizer
119- noise_scheduler = DDPMScheduler (num_train_timesteps = 1000 , beta_schedule = "squaredcos_cap_v2 " )
267+ noise_scheduler = DDPMScheduler (num_train_timesteps = num_timesteps , beta_schedule = "linear " )
120268
121269 # Training loop
122270
123271 optimizer = th .optim .AdamW (model .parameters (), lr = 4e-4 )
124- @th .no_grad ()
125- def sample_designs (n_designs : int ) -> th .Tensor :
126- """Samples n_designs from the generator."""
127- # Sample noise
128- z = th .randn ((n_designs , args .latent_dim ), device = device , dtype = th .float )
129- # THESE BOUNDS ARE PROBLEM DEPENDENT
130-
131- linspaces = [th .linspace (objs [:, i ].min (), objs [:, i ].max (), n_designs , device = device ) for i in range (objs .shape [1 ])]
132-
133- objs_small = th .stack (linspaces , dim = 1 )
134- desired_objs = objs_small .reshape (- 1 ,1 ,1 )
135- desired_objs = desired_objs .expand (- 1 ,1 ,64 )
136- noise = th .randn ((25 ,1 ,100 ,100 )).to (device )
137- timesteps = th .full ((25 ,), (950 ))
138- test_ds = th .utils .data .TensorDataset (noise , timesteps , desired_objs )
139- dataloader = th .utils .data .DataLoader (test_ds , batch_size = 1 )
140- gen_images = []
141- for noise , timesteps , desired_objs in tqdm .tqdm (dataloader ):
142- gen_imgs = model (noise , timesteps , encoder_hidden_states = desired_objs )[0 ]
143- gen_images .append (gen_imgs )
144-
145- return objs_small , th .cat (gen_images , dim = 0 )
146-
147- # ----------
148- # Training
149- # ----------
150- for epoch in tqdm .trange (args .n_epochs ):
151- for i , data in enumerate (dataloader ):
152- # THIS IS PROBLEM DEPENDENT
153- designs = data [0 ].reshape (- 1 ,1 ,100 ,100 )
154-
155- objs = th .stack ((data [1 :]), dim = 1 ).reshape (- 1 ,1 ,1 )
156- objs_ex = objs .expand (- 1 ,1 ,64 )
157272
158- clean_images = designs
273+ ## Schedule Parameters
274+ T = num_timesteps # Number of timesteps
275+ start = 1e-4 # Starting variance
276+ end = 0.02 # Ending variance
277+ # Choose a schedule (if the following are False, then a linear schedule is used)
278+ cosine = False # Use cosine schedule
279+ exp_biasing = False # Use exponential schedule
280+ exp_biasing_factor = 1 # Exponential schedule factor (used if exp_biasing=True)
281+ ##
159282
160- # Sample noise to add to the images
283+ # Choose a variance schedule
161284
162- noise = th .randn (clean_images .shape ).to (clean_images .device )
285+ betas = beta_schedule (T = T , start = start , end = end ,
286+ scale = 1.0 , cosine = cosine ,
287+ exp_biasing = exp_biasing , exp_bias_factor = exp_biasing_factor
288+ )
163289
164- bs = clean_images . shape [ 0 ]
290+ ddm_sampler = DiffusionSampler ( T , betas )
165291
166- # Sample a random timestep for each image
167-
168- timesteps = th .randint (0 , noise_scheduler .num_train_timesteps , (bs ,), device = clean_images .device ).long ()
169-
170- # Add noise to the clean images according to the noise magnitude at each timestep
171-
172- noisy_images = noise_scheduler .add_noise (clean_images , noise , timesteps )
173-
174- # Get the model prediction
175-
176- noise_pred = model (noisy_images , timesteps , return_dict = False , encoder_hidden_states = objs_ex )[0 ]
292+ # Loss function
293+ def ddm_loss_fn (noise_pred , noise ):
294+ return F .l1_loss (noise_pred , noise )
177295
178- # Calculate the loss
296+ @th .no_grad ()
297+ def sample_designs (model , n_designs = 25 ):
298+ """Samples n_designs designs."""
299+ model .eval ()
300+ with th .no_grad ():
179301
180- loss = th .nn .functional .mse_loss (noise_pred , noise )
302+ dims = (n_designs , 1 , 100 , 100 )
303+ image = th .randn (dims , device = device ) # initial image
304+ encoder_hidden_states = th .linspace (vf_min , vf_max , n_designs , device = device )
305+ encoder_hidden_states = encoder_hidden_states .view (n_designs , 1 , 1 ).expand (n_designs , 1 , 32 )
306+ for i in range (num_timesteps )[::- 1 ]:
307+ t = th .full ((n_designs ,), i , device = device , dtype = th .long )
181308
182- loss . backward ( loss )
309+ image = ddm_sampler . sample_timestep ( model , image , t , encoder_hidden_states )
183310
184- # Update the model parameters with the optimizer
311+ return image , encoder_hidden_states
185312
186- optimizer .step ()
313+ # ----------
314+ # Training
315+ # ----------
316+ for epoch in tqdm .trange (args .n_epochs ):
317+ for i , data in enumerate (dataloader ):
187318
319+ # Zero the parameter gradients
188320 optimizer .zero_grad ()
321+ designs = data [0 ].reshape (- 1 ,1 ,100 ,100 )
322+ x = designs .to (device )
323+ objs = th .stack ((data [1 :]), dim = 1 ).reshape (- 1 ,1 ,1 )
324+ objs_ex = objs .expand (- 1 ,1 ,32 )
325+
326+
327+ current_batch_size = x .shape [0 ]
328+ t = th .randint (0 , T , (current_batch_size ,), device = device ).long ()
329+ encoder_hidden_states = objs_ex .to (device )
330+
331+ # Get the noise and the noisy input
332+ x_noisy , noise = ddm_sampler .forward_diffusion_sample (x , t , device )
333+
334+ # Forward pass
335+ # if mp_mode:
336+ # with torch.cuda.amp.autocast(dtype=torch.float16):
337+ # noise_pred = model_diffuser(x_noisy, t, encoder_hidden_states).sample
338+ # loss = ddm_loss_fn(noise_pred, noise)
339+ # scaler.scale(loss).backward()
340+ # scaler.step(optimizer)
341+ # scaler.update()
342+ # else:
343+ noise_pred = model (x_noisy , t , encoder_hidden_states ).sample
344+ loss = ddm_loss_fn (noise_pred , noise )
345+
346+ # Backpropagation
347+ loss .backward ()
348+ optimizer .step ()
189349
190350 # ----------
191351 # Logging
@@ -206,18 +366,19 @@ def sample_designs(n_designs: int) -> th.Tensor:
206366 # This saves a grid image of 25 generated designs every sample_interval
207367 if batches_done % args .sample_interval == 0 :
208368 # Extract 25 designs
209- desired_objs , designs = sample_designs (25 )
369+
370+ designs , hidden_states = sample_designs (model , 25 )
210371 fig , axes = plt .subplots (5 , 5 , figsize = (12 , 12 ))
211372
212373 # Flatten axes for easy indexing
213374 axes = axes .flatten ()
214375
215- # Plot each tensor as a scatter plot
376+ # Plot the image created by each output
216377 for j , tensor in enumerate (designs ):
217378 img = tensor .cpu ().numpy ().reshape (100 ,100 ) # Extract x and y coordinates
218- do = desired_objs [ j ].cpu ()
219- axes [j ].imshow (img ) # Scatter plot
220- axes [j ].title .set_text (f"volfrac: { do [ 0 ] :.2f} " )
379+ do = hidden_states [ j , 0 , 0 ].cpu ()
380+ axes [j ].imshow (img . T ) # image plot
381+ axes [j ].title .set_text (f"volfrac: { do :.2f} " ) # Set title
221382 axes [j ].set_xticks ([]) # Hide x ticks
222383 axes [j ].set_yticks ([]) # Hide y ticks
223384
0 commit comments