From a3d8838f16a09c9abcf1d739d51441b23de9992c Mon Sep 17 00:00:00 2001 From: Abhijeet Vishwasrao Date: Fri, 7 Nov 2025 23:00:16 -0500 Subject: [PATCH 1/2] Added 3DEDMPrecond class with 3D version of songunet for diffusion models, utilized for urban flow project under NVIDIA grant Signed-off-by: Abhijeet Vishwasrao --- physicsnemo/models/diffusion/__init__.py | 2 + .../models/diffusion/preconditioning.py | 105 ++ physicsnemo/models/diffusion/song_unet3d.py | 903 ++++++++++++++++++ test/models/data/ddmpp_unet3d_output.pth | Bin 0 -> 5821 bytes test/models/data/ncsnpp_unet3d_output.pth | Bin 0 -> 5828 bytes test/models/diffusion/test_song_unet3d.py | 389 ++++++++ 6 files changed, 1399 insertions(+) create mode 100644 physicsnemo/models/diffusion/song_unet3d.py create mode 100644 test/models/data/ddmpp_unet3d_output.pth create mode 100644 test/models/data/ncsnpp_unet3d_output.pth create mode 100644 test/models/diffusion/test_song_unet3d.py diff --git a/physicsnemo/models/diffusion/__init__.py b/physicsnemo/models/diffusion/__init__.py index db850c14b6..96ad501520 100644 --- a/physicsnemo/models/diffusion/__init__.py +++ b/physicsnemo/models/diffusion/__init__.py @@ -26,10 +26,12 @@ UNetBlock, ) from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .song_unet3d import SongUNet3D from .dhariwal_unet import DhariwalUNet from .unet import UNet, StormCastUNet from .preconditioning import ( EDMPrecond, + EDMPrecond3D, EDMPrecondSuperResolution, EDMPrecondSR, VEPrecond, diff --git a/physicsnemo/models/diffusion/preconditioning.py b/physicsnemo/models/diffusion/preconditioning.py index c42faff028..09dc2ff82e 100644 --- a/physicsnemo/models/diffusion/preconditioning.py +++ b/physicsnemo/models/diffusion/preconditioning.py @@ -689,6 +689,111 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]): return torch.as_tensor(sigma) +@dataclass +class EDMPrecond3DMetaData(ModelMetaData): + """EDMPrecond meta data""" + + name: str = "EDMPrecond3D" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecond3D(EDMPrecond): + """ + Apply EDM preconditioning to denoise a 3D volumetric input. + + Parameters + ---------- + x : torch.Tensor + Noisy volumetric input of shape (B, C, D, H, W) where B is batch size, + C is channels, and D, H, W are spatial dimensions. + sigma : torch.Tensor + Noise level(s) of shape (B,) or (B, 1). + condition : torch.Tensor, optional + Additional conditioning input to concatenate along channel dimension. + Must have shape (B, C_cond, D, H, W), by default None. + class_labels : torch.Tensor, optional + Class labels for conditional generation of shape (B, label_dim). + If None and label_dim > 0, zero labels are used, by default None. + force_fp32 : bool, optional + Force FP32 precision regardless of `use_fp16` setting, by default False. + **model_kwargs : dict + Additional keyword arguments passed to the underlying model's forward method. + + Returns + ------- + torch.Tensor + Denoised volumetric output of shape (B, C, D, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype (when not + using autocast). + """ + + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1, 1) + + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + + F_x = self.model( + arg.to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @dataclass class EDMPrecondSuperResolutionMetaData(ModelMetaData): """EDMPrecondSuperResolution meta data""" diff --git a/physicsnemo/models/diffusion/song_unet3d.py b/physicsnemo/models/diffusion/song_unet3d.py new file mode 100644 index 0000000000..6a71189a84 --- /dev/null +++ b/physicsnemo/models/diffusion/song_unet3d.py @@ -0,0 +1,903 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange +from torch.nn.functional import silu +from torch.utils.checkpoint import checkpoint + +from physicsnemo.models.diffusion import ( + AttentionOp, + FourierEmbedding, + Linear, + PositionalEmbedding, + weight_init, +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + + +class Conv3d(torch.nn.Module): + """ + 3D convolution layer with optional upsampling and downsampling. + + This layer implements a 3D convolution operation with optional bilinear + resampling (upsampling or downsampling) capabilities. It supports both + fused and non-fused resampling modes for efficiency. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel : int + Kernel size for the convolution (applied uniformly across all spatial dimensions). + bias : bool, optional + Whether to include a learnable bias term, by default True. + up : bool, optional + Whether to apply 2x upsampling before/after convolution, by default False. + down : bool, optional + Whether to apply 2x downsampling before/after convolution, by default False. + resample_filter : List[int], optional + 1D filter coefficients for bilinear resampling, by default [1, 1]. + The 3D filter is constructed as outer product of this 1D filter. + fused_resample : bool, optional + Whether to fuse resampling with convolution for efficiency, by default False. + init_mode : str, optional + Weight initialization mode, by default "kaiming_normal". + init_weight : float, optional + Multiplier for weight initialization, by default 1.0. + init_bias : float, optional + Multiplier for bias initialization, by default 0.0. + + Raises + ------ + ValueError + If both `up` and `down` are set to True simultaneously. + + Note + ---- + When `fused_resample=True`, the resampling operation is combined with + convolution for improved computational efficiency. The resample filter + is constructed as a 3D separable filter from the 1D coefficients. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + bias: bool = True, + up: bool = False, + down: bool = False, + resample_filter: List[int] = [1, 1], + fused_resample: bool = False, + init_mode: str = "kaiming_normal", + init_weight: float = 1.0, + init_bias: float = 0.0, + ): + if up and down: + raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + self.fused_resample = fused_resample + init_kwargs = dict( + mode=init_mode, + fan_in=in_channels * kernel * kernel * kernel, + fan_out=out_channels * kernel * kernel * kernel, + ) + self.weight = ( + torch.nn.Parameter( + weight_init( + [out_channels, in_channels, kernel, kernel, kernel], **init_kwargs + ) + * init_weight + ) + if kernel + else None + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) + if kernel and bias + else None + ) + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = (f.ger(f).unsqueeze(2) * f.view(1, 1, -1)).unsqueeze(0).unsqueeze( + 1 + ) / f.sum().pow(3) # for 3D, should be ^3 + self.register_buffer("resample_filter", f.contiguous() if up or down else None) + + def forward(self, x): + w = self.weight.to(x.dtype) if self.weight is not None else None + b = self.bias.to(x.dtype) if self.bias is not None else None + f = ( + self.resample_filter.to(x.dtype) + if self.resample_filter is not None + else None + ) + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.fused_resample and self.up and w is not None: + x = torch.nn.functional.conv_transpose3d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=max(f_pad - w_pad, 0), + ) + x = torch.nn.functional.conv3d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.fused_resample and self.down and w is not None: + x = torch.nn.functional.conv3d(x, w, padding=w_pad + f_pad) + x = torch.nn.functional.conv3d( + x, + f.tile([self.out_channels, 1, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) + else: + if self.up: + x = torch.nn.functional.conv_transpose3d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if self.down: + x = torch.nn.functional.conv3d( + x, + f.tile([self.in_channels, 1, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if w is not None: + x = torch.nn.functional.conv3d(x, w, padding=w_pad) + if b is not None: + x = x.add_(b.reshape(1, -1, 1, 1, 1)) + return x + + +class GroupNorm(torch.nn.Module): + """ + A custom Group Normalization layer implementation. + + Group Normalization (GN) divides the channels of the input tensor into groups and + normalizes the features within each group independently. It does not require the + batch size as in Batch Normalization, making itsuitable for batch sizes of any size + or even for batch-free scenarios. + + Parameters + ---------- + num_channels : int + Number of channels in the input tensor. + num_groups : int, optional + Desired number of groups to divide the input channels, by default 32. + This might be adjusted based on the `min_channels_per_group`. + min_channels_per_group : int, optional + Minimum channels required per group. This ensures that no group has fewer + channels than this number. By default 4. + eps : float, optional + A small number added to the variance to prevent division by zero, by default + 1e-5. + + Notes + ----- + If `num_channels` is not divisible by `num_groups`, the actual number of groups + might be adjusted to satisfy the `min_channels_per_group` condition. + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + min_channels_per_group: int = 4, + eps: float = 1e-5, + ): + super().__init__() + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + + def forward(self, x): + if self.training: + # Use default torch implementation of GroupNorm for training + # This does not support channels last memory format + x = torch.nn.functional.group_norm( + x, + num_groups=self.num_groups, + weight=self.weight.to(x.dtype), + bias=self.bias.to(x.dtype), + eps=self.eps, + ) + else: + # Use custom GroupNorm implementation that supports channels last + # memory layout for inference + dtype = x.dtype + x = x.float() + x = rearrange(x, "b (g c) d h w -> b g c d h w", g=self.num_groups) + + mean = x.mean(dim=[2, 3, 4, 5], keepdim=True) # added 5th dim + var = x.var(dim=[2, 3, 4, 5], keepdim=True) + + x = (x - mean) * (var + self.eps).rsqrt() + x = rearrange(x, "b g c d h w -> b (g c) d h w") + + weight = rearrange(self.weight, "c -> 1 c 1 1 1") + bias = rearrange(self.bias, "c -> 1 c 1 1 1") + x = x * weight + bias + + x = x.type(dtype) + return x + + +class UNetBlock3D(torch.nn.Module): + """ + Unified U-Net block with optional up/downsampling and self-attention. Represents + the union of all features employed by the DDPM++, NCSN++, and ADM architectures. + + Parameters: + ----------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + emb_channels : int + Number of embedding channels. + up : bool, optional + If True, applies upsampling in the forward pass. By default False. + down : bool, optional + If True, applies downsampling in the forward pass. By default False. + attention : bool, optional + If True, enables the self-attention mechanism in the block. By default False. + num_heads : int, optional + Number of attention heads. If None, defaults to `out_channels // 64`. + channels_per_head : int, optional + Number of channels per attention head. By default 64. + dropout : float, optional + Dropout probability. By default 0.0. + skip_scale : float, optional + Scale factor applied to skip connections. By default 1.0. + eps : float, optional + Epsilon value used for normalization layers. By default 1e-5. + resample_filter : List[int], optional + Filter for resampling layers. By default [1, 1]. + resample_proj : bool, optional + If True, resampling projection is enabled. By default False. + adaptive_scale : bool, optional + If True, uses adaptive scaling in the forward pass. By default True. + init : dict, optional + Initialization parameters for convolutional and linear layers. + init_zero : dict, optional + Initialization parameters with zero weights for certain layers. By default + {'init_weight': 0}. + init_attn : dict, optional + Initialization parameters specific to attention mechanism layers. + Defaults to 'init' if not provided. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + up: bool = False, + down: bool = False, + attention: bool = False, + num_heads: int = None, + channels_per_head: int = 64, + dropout: float = 0.0, + skip_scale: float = 1.0, + eps: float = 1e-5, + resample_filter: List[int] = [1, 1], + resample_proj: bool = False, + adaptive_scale: bool = True, + init: Dict[str, Any] = dict(), + init_zero: Dict[str, Any] = dict(init_weight=0), + init_attn: Any = None, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.emb_channels = emb_channels + self.num_heads = ( + 0 + if not attention + else num_heads + if num_heads is not None + else out_channels // channels_per_head + ) + self.dropout = dropout + self.skip_scale = skip_scale + self.adaptive_scale = adaptive_scale + + self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.conv0 = Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel=3, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + self.affine = Linear( + in_features=emb_channels, + out_features=out_channels * (2 if adaptive_scale else 1), + **init, + ) + self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + self.conv1 = Conv3d( + in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + ) + + self.skip = None + if out_channels != in_channels or up or down: + kernel = 1 if resample_proj or out_channels != in_channels else 0 + self.skip = Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel=kernel, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + + if self.num_heads: + self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.qkv = Conv3d( + in_channels=out_channels, + out_channels=out_channels * 3, + kernel=1, + **(init_attn if init_attn is not None else init), + ) + self.proj = Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel=1, + **init_zero, + ) + + def forward(self, x, emb): + # torch.cuda.nvtx.range_push("UNetBlock3D") + orig = x + x = self.conv0(silu(self.norm0(x))) + + params = self.affine(emb).unsqueeze(2).unsqueeze(3).unsqueeze(4).to(x.dtype) + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = silu(self.norm1(x.add_(params))) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) + ) + x = x.add_(self.skip(orig) if self.skip is not None else orig) + x = x * self.skip_scale + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(2) + ) + w = AttentionOp.apply(q, k) + a = torch.einsum("nqk,nck->ncq", w, v) + x = self.proj(a.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + # torch.cuda.nvtx.range_pop() + return x + + +@dataclass +class MetaData(ModelMetaData): + name: str = "SongUNet3D" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class SongUNet3D(Module): + """ + 3D U-Net diffusion backbone for volumetric data generation. + + This architecture extends the DDPM++ and NCSN++ models to 3D volumetric data, + implementing a U-Net variant with optional self-attention, embeddings, and + encoder-decoder components for generating 3D volumes. + + The model supports both conditional and unconditional generation with flexible + architectural choices for encoder/decoder types, embedding types, and attention + mechanisms. It can be configured for various 3D diffusion tasks including medical + imaging, scientific simulations, and volumetric content generation. + + Architecture Overview + --------------------- + The model processes 3D volumetric inputs through: + + 1. **Embedding Generation**: Maps noise levels, class labels, and augmentation + labels to embeddings that condition the generation process. + + 2. **U-Net Encoder**: A hierarchical encoder with multiple levels, where each level: + - Downsamples spatial resolution by 2x (D, H, W dimensions) + - Applies ``num_blocks`` residual blocks with conditioning + - Optionally applies 3D self-attention at specified resolutions + - Caches features for skip connections + + 3. **U-Net Decoder**: Mirror of the encoder that: + - Upsamples spatial resolution by 2x at each level + - Combines features via skip connections from encoder + - Produces the final denoised 3D volume + + Conditioning Mechanism + ---------------------- + - **Noise labels**: Condition on diffusion timestep/noise level + - **Class labels**: Optional vector-valued class conditioning + - **Augmentation labels**: Optional data augmentation conditioning + - **Image conditioning**: Concatenate conditioning volumes to input channels + + Parameters + ---------- + img_resolution : Union[List[int], int] + Spatial resolution of the volumetric data. Can be a single int for uniform + resolution (D=H=W) or a list [D, H, W] for non-uniform dimensions. + Note: Model can process different resolutions at inference, except when + ``additive_pos_embed=True``. + in_channels : int + Number of input channels. Includes both latent channels and any additional + channels for image-based conditioning. For unconditional models, should + equal ``out_channels``. + out_channels : int + Number of output channels. Should match the number of channels in the + latent state being denoised. + label_dim : int, optional + Dimension of vector-valued class labels for conditional generation. + Set to 0 for unconditional generation, by default 0. + augment_dim : int, optional + Dimension of vector-valued augmentation labels. Set to 0 for no + augmentation conditioning, by default 0. + model_channels : int, optional + Base channel multiplier for the network. Determines the number of + channels at the first level, by default 128. + channel_mult : List[int], optional + Channel multipliers at each U-Net level. Length determines the number + of levels. At level i, channels = ``channel_mult[i] * model_channels``, + by default [1, 2, 2, 2]. + channel_mult_emb : int, optional + Multiplier for embedding vector channels. Embedding dimension is + ``model_channels * channel_mult_emb``, by default 4. + num_blocks : int, optional + Number of residual blocks at each U-Net level, by default 4. + attn_resolutions : List[int], optional + Spatial resolutions at which to apply 3D self-attention. Attention is + applied when the feature map resolution matches these values exactly, + by default [16]. + dropout : float, optional + Dropout probability for intermediate activations in U-Net blocks, + by default 0.10. + label_dropout : float, optional + Dropout probability for class labels, typically used for classifier-free + guidance during training, by default 0.0. + embedding_type : str, optional + Noise level embedding type. Options: 'positional' (DDPM++), 'fourier' + (NCSN++), or 'zero' (no embedding), by default "positional". + channel_mult_noise : int, optional + Channel multiplier for noise level embeddings. Noise embedding dimension + is ``model_channels * channel_mult_noise``, by default 1. + encoder_type : str, optional + Encoder architecture variant. Options: 'standard' (DDPM++), 'residual' + (NCSN++), or 'skip' (skip connections), by default "standard". + decoder_type : str, optional + Decoder architecture variant. Options: 'standard' or 'skip' (skip + connections), by default "standard". + resample_filter : List[int], optional + 1D filter coefficients for resampling operations. Use [1, 1] for DDPM++ + or [1, 3, 3, 1] for NCSN++, by default [1, 1]. + checkpoint_level : int, optional + Number of levels to use gradient checkpointing. Higher values trade + memory for computation. 0 disables checkpointing, by default 0. + additive_pos_embed : bool, optional + If True, adds learnable positional embeddings encoding spatial position + (separate from temporal diffusion embeddings). When enabled, input + resolution must match ``img_resolution``, by default False. + + Raises + ------ + ValueError + If ``embedding_type`` is not one of ['fourier', 'positional', 'zero']. + ValueError + If ``encoder_type`` is not one of ['standard', 'skip', 'residual']. + ValueError + If ``decoder_type`` is not one of ['standard', 'skip']. + + Note + ---- + This is a 3D extension of the SongUNet architecture. The primary differences + from the 2D version are: + - All convolutions and attention operations work on 3D volumes (B, C, D, H, W) + - Resampling filters are constructed as 3D separable filters + - Self-attention operates on flattened 3D spatial dimensions + + See Also + -------- + SongUNet : 2D variant of this architecture for image generation. + EDMPrecond3D : Preconditioning wrapper for 3D diffusion models. + + References + ---------- + .. [1] Nichol, A. Q., & Dhariwal, P. (2021). Improved denoising diffusion + probabilistic models. ICML 2021. + .. [2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., + & Poole, B. (2021). Score-based generative modeling through stochastic + differential equations. ICLR 2021. + + Examples + -------- + >>> # Create unconditional 3D diffusion model for 64^3 volumes + >>> model = SongUNet3D( + ... img_resolution=64, + ... in_channels=4, + ... out_channels=4, + ... model_channels=128, + ... channel_mult=[1, 2, 2, 2], + ... num_blocks=4, + ... ) + >>> + >>> # Forward pass with noise conditioning + >>> x = torch.randn(2, 4, 64, 64, 64) # Noisy volumes + >>> noise_labels = torch.randn(2, 128) # Noise level embeddings + >>> denoised = model(x, noise_labels) + >>> denoised.shape + torch.Size([2, 4, 64, 64, 64]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [16], + dropout: float = 0.10, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + checkpoint_level: int = 0, + additive_pos_embed: bool = False, + ): + valid_embedding_types = ["fourier", "positional", "zero"] + if embedding_type not in valid_embedding_types: + raise ValueError( + f"Invalid embedding_type: {embedding_type}. Must be one of {valid_embedding_types}." + ) + + valid_encoder_types = ["standard", "skip", "residual"] + if encoder_type not in valid_encoder_types: + raise ValueError( + f"Invalid encoder_type: {encoder_type}. Must be one of {valid_encoder_types}." + ) + + valid_decoder_types = ["standard", "skip"] + if decoder_type not in valid_decoder_types: + raise ValueError( + f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." + ) + + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + self.embedding_type = embedding_type + emb_channels = model_channels * channel_mult_emb + self.emb_channels = emb_channels + noise_channels = model_channels * channel_mult_noise + + init = dict(init_mode="xavier_uniform") + init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) + init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) + + block_kwargs = dict( + emb_channels=emb_channels, + num_heads=1, + dropout=dropout, + skip_scale=np.sqrt(0.5), + eps=1e-6, + resample_filter=resample_filter, + resample_proj=True, + adaptive_scale=False, + init=init, + init_zero=init_zero, + init_attn=init_attn, + ) + + # Handle image resolution (now 3D) + self.img_resolution = img_resolution + if isinstance(img_resolution, int): + self.img_shape_z = self.img_shape_y = self.img_shape_x = img_resolution + elif len(img_resolution) == 2: + self.img_shape_y, self.img_shape_x = img_resolution + self.img_shape_z = img_resolution[0] # Default to same as y + else: + self.img_shape_z, self.img_shape_y, self.img_shape_x = img_resolution[:3] + + # Set checkpoint threshold based on resolution + max_dimension = max(self.img_shape_x, self.img_shape_y, self.img_shape_z) + self.checkpoint_threshold = (max_dimension >> checkpoint_level) + 1 + + # Optional additive learned position embed after the first conv + self.additive_pos_embed = additive_pos_embed + if self.additive_pos_embed: + self.spatial_emb = torch.nn.Parameter( + torch.randn( + 1, + model_channels, + self.img_shape_z, + self.img_shape_y, + self.img_shape_x, + ) + ) + torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02) + + # Mapping + if self.embedding_type != "zero": + self.map_noise = ( + PositionalEmbedding(num_channels=noise_channels, endpoint=True) + if embedding_type == "positional" + else FourierEmbedding(num_channels=noise_channels) + ) + self.map_label = ( + Linear(in_features=label_dim, out_features=noise_channels, **init) + if label_dim + else None + ) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=noise_channels, + bias=False, + **init, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=noise_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + + # Encoder + self.enc = torch.nn.ModuleDict() + cout = in_channels + caux = in_channels + for level, mult in enumerate(channel_mult): + res = self.img_shape_y >> level + if level == 0: + cin = cout + cout = model_channels + self.enc[f"{res}x{res}_conv"] = Conv3d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock3D( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + if encoder_type == "skip": + self.enc[f"{res}x{res}_aux_down"] = Conv3d( + in_channels=caux, + out_channels=caux, + kernel=0, + down=True, + resample_filter=resample_filter, + ) + self.enc[f"{res}x{res}_aux_skip"] = Conv3d( + in_channels=caux, out_channels=cout, kernel=1, **init + ) + if encoder_type == "residual": + self.enc[f"{res}x{res}_aux_residual"] = Conv3d( + in_channels=caux, + out_channels=cout, + kernel=3, + down=True, + resample_filter=resample_filter, + fused_resample=True, + **init, + ) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = res in attn_resolutions + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock3D( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + skips = [ + block.out_channels for name, block in self.enc.items() if "aux" not in name + ] + + # Decoder + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = self.img_shape_y >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock3D( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock3D( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock3D( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = idx == num_blocks and res in attn_resolutions + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock3D( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + if decoder_type == "skip" or level == 0: + if decoder_type == "skip" and level < len(channel_mult) - 1: + self.dec[f"{res}x{res}_aux_up"] = Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel=0, + up=True, + resample_filter=resample_filter, + ) + self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( + num_channels=cout, eps=1e-6 + ) + self.dec[f"{res}x{res}_aux_conv"] = Conv3d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, noise_labels, class_labels=None, augment_labels=None): + # Mapping + if self.embedding_type != "zero": + emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = F.silu(self.map_layer0(emb)) + emb = F.silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) + + # Encoder + skips = [] + aux = x + for name, block in self.enc.items(): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: + x = block(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + if isinstance(block, UNetBlock3D): + if x.shape[-1] > self.checkpoint_threshold: + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + else: + x = block(x) + skips.append(x) + + # Decoder + aux = None + tmp = None + for name, block in self.dec.items(): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(F.silu(tmp)) + aux = tmp if aux is None else tmp + aux + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + if (x.shape[-1] > self.checkpoint_threshold and "_block" in name) or ( + x.shape[-1] > (self.checkpoint_threshold / 2) and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + return aux + + +if __name__ == "__main__": + # Example usage + model = SongUNet3D( + img_resolution=8, + in_channels=3, + out_channels=3, + label_dim=0, + augment_dim=0, + model_channels=16, + channel_mult=[1, 2, 2, 2], + channel_mult_emb=4, + num_blocks=4, + attn_resolutions=[16], + dropout=0.10, + label_dropout=0.0, + embedding_type="positional", + channel_mult_noise=1, + encoder_type="standard", + decoder_type="standard", + resample_filter=[1, 1], + checkpoint_level=0, + additive_pos_embed=False, + ) + print("Model created successfully.") + + x = torch.randn(1, 3, 8, 8, 8) # Example input + noise_labels = torch.randn([1]) # Example noise labels + + out = model(x, noise_labels) + print(out) diff --git a/test/models/data/ddmpp_unet3d_output.pth b/test/models/data/ddmpp_unet3d_output.pth new file mode 100644 index 0000000000000000000000000000000000000000..f5b301ce997a7eef19a08499557b3c5e26f0cb94 GIT binary patch literal 5821 zcmbVQ4Oq-s-=5MYL}RwiXh=;hE$MU4@1zZB+q6a6Wl33mNAxje^0io3La6jH$VeN) zZj{9ToM&g^@zt`nttr+Fg<7eU4Jq{!J*TzLUU+!lecpf9ef^uc&$<8i{hM=5UALdl z06iYhz<~E3*C?JoFFZU(Dh*1EjgUEo2gN1Iq=_=S@DN#ut#oO$r@JmMAYi@j+IXH1 zFH9Df5VqJhC{Y#_y=-t$LPTg{RCIWdEFyMUTtd)t$HDPB*4F+5*+Z7Gl_85F;&oqc z8SEJy7b2Ve^3B2V1N?Q^8vA%X{{ieCCQXcgZzfMq7O(%g56{|1$4A%4p#Mh}KhT=J zKa;2H9C>B4HVOn5Vl7o5TP2a_IL%AJlE= z&)tuA`x(od@OD3cjx;DFI%-jDOhl~A(|z3EwFHuW15;PT-_vCe@5j3hjOIDL)j;Sf zSwvj_z&+jV{_b=CJ+J5}nJhY@-^0eqcSh>7=lA2?9!zrAT~Azv%|)4^q?_3TxSet*9#zcS04nM9FWY_*#&kdHo&JJCc`I|$xzKthMQJ< zQKgG67M!>yDGepCB&Qa}jjV&qBkF+PUJL4OJLuM(fG>+QlHRWf1g@zCqpn(b*ij4H zRkffOX$Q}qjK>d6G?J?e2`E<8f>~!RjAFP^RtqCS?ZAie0S4D32T};wr`5ulzFP3_ zt%cm0THx)n1FN3#cl00&LCD5GKIaM&-fxp;ma}(nuH@;0N1h>Oo(L9_SS7L9VSHe-(s@4Q^%FW zN%tmncv_*x!wTXbpUo9GXK^dtvbcZv%DGL3DpGc}4(g^@p?L)gvU6!R=g+v@X<1z8 zLOG{DL`9DOS_ipnEzy)%1)1WO&0V^i$=SDNa^KZvbGkQ`WJ+@#NE9+?xi?bsds7aK zTf7^BBX>jowH&xmwi$U?@Nw&BYk*g@!=czT8!CrqgC0K{_?j$8ePP2{XP$*orjF>! zENie93*l*s5PnG$!o&B3U}WJ6zD_)R)JK9YHO@rAp(42aKm_9(M6e`G1RF&M(HyP^ zbRLwWtIOKZat9S0Pg23c)hf^qQ$f4=J+!eM@Pbc%M1BL>(32r5px%u4P{Eu_O1QDE z8JS)OY&GXc#Am#-kqU%fDoAx#fnm84dRH`~$&7bCJaI8xS zo}aX!3%>zYMEihc^mfpR;Nv55K3=P+WG|5&12%hSCfn#I_3|uuBNCSRs5C=?bT3@bK>698u#9ACxju1b&A^ z;CWC4683xQ`~o!fau2v{Dn~ynT2b1EDwy)M3NEw!Cv{a&QPP4sdjR|HFGqFyt*GaT z66QFn;1-*c^rRA!x|`5c=6~OTa^%c-mkuQ)*s7qEtw}Cbg8B0%#H{0A9m~R-g2~TJ@aOmI zVR)AsX88lo^#y#KK`;^UW1j{vt{aLQZ#betnKeA$DuhiPLdbd`gegmf5X$n;VflNS z9nsY;A9PA2g6@4Hh&n8SJiZ8Af(j7-QV(P$l_SN@RwTAmLE~~2c&}8!H#`;OooPYV z{{wg#Ek^@&T9Ijw5{|L#ZAb;dET8mo6H>8!io$Z#Jh&Alb}FHPW&he<1-d7caHz8h z4ZQ>SpDyL7^g=5-!)99ei3%QiviWG0P{#6iv-}R7-Y|ycAHwqIuHxe_Wqf>WA|C_G zKYEB8ws}ww`yQ*oH~_Fe%YJnx;0#;9(Y_5(SUVJt+T(1 zgd6Pq*Lm>p>cbAmc(NCoI@T22b4=m-PE*LYFoh8=`_a_o9?0&!fF2aIppJt|NWP~8 zHwLs(2|E`zp?OyUONW-DiIn7KyvGQS z9>d2Hh7GrkaK`3aEZ>s`2% zUIS8E1H#f8*qc-Xg`@9)|Je_*>HhO5sBbzDKPil@k-}+>6dp~If?CCc5gQ93>uxCo zAGrg2uL>~#g#d>!7*z{!y;m66ZfOLlEQJffci>Qw0Nec{z_kqB-wLoWDHOIYYy{n^ zQg9Br15b(t*twEvsM+^3*oJzYYZBRe@k=5(w+po&Y_)Vc2H#N$y6;NK!s&JGv>( zKr>JV>L^Y}tJb8zp6KD&`1>A7xJVxjJC_PR25BH4mImp$sgTgP4R&m^#bvJDl6-|8 zLWfgf)~!?!)~CYx#8mJu+6MU>ZLza?x1{Wl9x`D3qnoMF$JU%(o(djEw?Qf6jn{Qc zX3TyM{dOi5#y&}fi(RQOJ3SRfp5F%hzOltZoe)x)Y|WKSS^}wsOCY0o34D7r3P_(C zgz8CHJ^2w>-#rT>?+?N&+Xmqy>~38#2zOfCfR54CU{K*lM2q~0SxpWiojFLaBL~@6 zG3?kW;Tv72*qXPPmlyNgRvh(4&#VN(T1_a&Sz(#-&JX(UV(i zi1U?>lWiySVb`;KnAevNmreFTR+$w_isc|rFdOb~bwyvcuZ76owXl`94)SKMgP5qX zD04;wik@7_$#6QUSdWQKDJFGCFP$R`=qfQiWw?=+95%s69>*f-xN|6y ze-3RPd=AA5E}*ihVw&|dnd%)f!2?Bc=wrKch&SRKy7BxhxqrZohn!TM(Cq%6d{!1}_acwWlMC0lUfJNe*;6loP$7-qg6(h#tRXL>J#SqQ`zT zq8B@X{zE3EUfYW$H#44+BtbP<&#NYLdetOli<&$r@uC3(GU#_JTiiT7eHNm{BU-#yZh68{)laXpEq)MldH_sr-uCoQos(GtgD zTH>)uL*hrrQ2EbER8p6TbQ!-MX-R@mOXO^gY>S2jjftUtjQ1$aL>*Vm=t@^DDY0Vu z4BzKz$PH**#5-+ zXC0yDx-8V((!NSg_;CgCA8$@qD9TBpmp82&%BPMDkqkC0|2&po`?Zv=e{oo1aN#Lw zoLWu(F`=3`+Ex=)iJE+W%8Pcrmq8~)WT8teziqOX>}B~)qP4{Av4-@p{Anz|-<3=h z(q=|aI%-J{%f@FJdp2vxzELrB4a=uu^D%2NqnBBJ-H}?7%W#45>5MOAd@swNS8YZc zCo^43E$KgJd-FAi|ACA-FsHJI*lRFicjZH$JAq|lm^`^ka!wA zBP+A3Nla=r>D^vUB8AoDR=XEn{XqsD$o%z<7(;Q5mcTDs@_B`plzpisv8ge%{(ce- ztH?mx_+x0<4lQxnrX|)}wWPgELu>K_=;;#5b~t1lum4 zE;rWE`t(eRl~p25ZnvRn_ibonn+^3SvZ43BOrirVld0m&Bn%&PlciABe5ZE=$%hfpa=G_ zzb~6|736uEg2)+y4HYD$YX~iBpHIzE4I8#B)hAK!h;~)J#pNg3N z^6}}UgHw<>R6$<+z<5^$IX1+Q&J+dEUiA;)bKo*@m8;3bJ!-NnPfd1Ns>$zPnb7xh z=g}^=>9~BoC7Dw+gY2xEK@7Ds$f2|u#7}aD+)ewO_BRh(3a6Fw_^ftdcz5$4;^DV8 z4}aA_PXA@~@J9vWG`QbEb>5;c|L^`~Ja6n<9oYU+ek_Yfh++-U>X7~hEGR16(_Q>G z3^^d|=ru0T3_om>h<<{8O1T)_y_#chho0_kd zRij_i6tJSfpUrp~KpNc5Tx|NXaH`t#(zDfH^cJpMJI ztak{#{Q&y;= literal 0 HcmV?d00001 diff --git a/test/models/data/ncsnpp_unet3d_output.pth b/test/models/data/ncsnpp_unet3d_output.pth new file mode 100644 index 0000000000000000000000000000000000000000..996e759f19075f29e87550e62e901a4590a3f4aa GIT binary patch literal 5828 zcmbVQ3s{V4`+udH44anioUkk%rgVDmUtvkwL=HtW< zo6;sz`;^LOCuzgcLUG^u4Y9_QJ>i`@a7(*L}U$``pj-dw$P--_N|y^$QGW zug@?>M$EspZj1pVi%OKm$A>4$VifMt;c-cd_#{QY=txDRbNqCvpHDkxLYm%OITOG{ zDdG~Mra6ZvDPpCG!tjKc$w{%&=x{}hEHN%2JlRbsZ)ay0D&QAMn`SkfODUnATfx$E6cl7wh+w9{6Y#Fx?6PP?p5fj%IaX%mD z|8m`b#Y-BiP)K9iIDGfz4NF6Q{WkkJ2W!Ul!yMYGvY(5ekJJA~AjzeSc7|opZ2NSY zv@)Hh)Th&om+92w+aS$;_m!H9TV9dXMiCAy_QsHEZ-mu*Bd^LE7Ed7+^5EU&eqCxfzq+G{aO#@vhY@pO0P3+l> zMXY|#A=+eLMD4fLQ;Tr}-9Q8R_H3Z1y&tfbr!Qn36o=@`rXT41sCsHBtf!kt>j}H- zY0P(xY*F2ORxQR?yg!&ToJ$i}mG+d$=ou%y4<5Eqk`C@j`cb4?^C|_!h z@g?G|SCTJX`F=kevE>KX=3UMhx7Zq*C=1N5wZH_%5>MM%qTx<1$-4QHkC_aES1rfI zvdx%#cQZP?*o;NDH^bd50Ny=@Lts1>4SV|IX>IPEyQG%W*g&c$zXEy{Lj(Rx&igw0wQYSM7&Oe}`q;_%>@2nkJMj61dzJWAzhjMk1uP1|u_5 zjNO~XC`%EeZml&O*6V1UJBMQ%9I;9sjNA8vVQwr&&4XZcak9pRGu1TZx;M@UYDqED z3M;l-q06sU@Zs&nLp8OY38IVJ(kQTGo`&hXX2*dgi*_(KGIubahV7`|c~6tHB3tt; z#TAP0Oz`-!J_05fpih1g)=4zfLt$XXo?jg1&{z7(Uk zr5Md$S!3aaYWjlTe|)!E5=2?y)=?`+o2_7X-wMCKR@04|Alg%zM!)=)rb$cazJp1< zSH|RCDPxK!lzH3S({!1Zt*MH4#g4fq*jlHL`Qr?5Hp2kX2?mgt<dGI$>4q|VVW<43xsX~NtX=Mc98_Ve6){O@S|eUzHrRR+kiWU$v+ zfagj9inj?cGfRM*>MU9;PocOtIifohAh7rZ*4#XStM^Y}<#&}TE-WZ%gs@kBZsm^ ziko&yteLFDYbif&oD#q79EyVx=7^dgM|rdq-`vT@3R@+__DW1QRl+QEDAs;rj()km z_%Xx;W4;N1|DFJ>Im7=iR|UY$;0772-#{;HjNn(4O0D}>l6mb)I(KR%ZN9ORMiiy7 zF0&5rJQn(bcHCP@SN6@Mb$PRBRpl(2(Qg(B77QnsgYD_Y!eG=M@xp8M5EMTgf(80M z7=C*QOcq=4XXXhBB1fXN_9P{DE+OYDC6r>eosw>q(4=!)X-u4ih6Rm8%IcGpeyo^U z)|Jqv6D9OZQ3=VHZY7_4V)}H@NPs&@)0P+0tgsSNt|_4vDJ67!`c|skASU-7BavWp zl1#c6)7)#tG~Kp@sACEFinda02QgV&nZjT1M-RQ~{_L9VgV`xR3}G$K4`I{(@vBFV z+;Vr#kTTjnW}v1f*skJ^kzK_kUEhkv@V*s$mkg~)EbdgHD{`XnnOU?fTtwStB5Ii} zBL7Vy3jTHu>s{qk;kh-NiUy9r%nsJ5w6jKPKWkXJTO;ex5o!wIsKIjqNor)6S?$2* z+X2^aIKZ#L0VBFK(!iGqw5I(6+VQgtvF9Bytl0r*=ZI&o9bgmCNP3GCX?Np%8h=KH zuD?0p0N?-E#1Z3pGfQeDpEHT%RX?9H%4Hb2*8vCX9N^yQfIsd!VC=9)nvtGJb=l!$ z-?R#Ud^QEQJf~nt;1t~Rp915`{+Mfao(g?E>6aWwBwR5-!F|5g8ADXJH{{RG4vLIB zM-wM+r|$i3(q{P5Hjytq9NuCBPWww4A(M4)S~Ir59lF|OPkCmYN$ z^0Fz$?(0rR?>wfLbF<+U6NXRX!{9w93@g`%!N=DR1DtcoQuCO4&C5ohEDXuB!;rru z3_Z)j(BSU}?a*A>UGSLf=V#+kY8V`O-y(GwM(M&}9`1+RlX7X;?8ju7oQ(;RFl>wp zgD{yl9>2?IKQs=@rMo7bF>zrjj)ZK&wCGLvF>VtgXKsS|JQbSen1LByOoiqW*aS_6 z7jF>}lW{YCGDh_C#B#Tx2u{(``k6wsiVA6#YazWFT}boh6;fn&9(h;_EAoG?rfusv zl=Tteq^Afof<^FHEJE*jX3#snky?c{bhu1~+G`vJHgRY%mLYGT3^XGWd#?yE@{x`X z>{lV9gA8XZWC$G~!&I3JpDc+)*aHFlTXZz>unN&VWjNt1Lv*kV$LGrsk{^j)PXyR? zK}Xk0RS4z3vzK3SxY~=awGbKRWJKbtIsq8#>txYu7fdTkaqK`T_M9%o?>9$n>E z%|7s(_MDBm?SOq1K#GH!Ezz&TMEeYK>@A&q>%29DWuEe z@@RXEuwvExYKl(cV0+pT0S%5wf8mI*!6MlBn88xIk#c|1QPFx8{IWP4T*l!{C5Nex zIWX>#DB2-_a)XZgZ&Tsrb`Je3IK*DyP}@z0E<+-*{wDzzY}Judxe5=DaHy!|koJT_ zwucPP!I4l@3GjH8j=JWlFfE^hT>*z($2dSJ!`y+982`NhF(@b zC~Y}-4(0EIN)Ar-5m*r~z%Ri%nzTR#;_ord#DR<8&|1PF=4u4mB?%A`siRlfDttYb zgV$&d#WOiPImAKyC;}U22w)}F(ZHE1TzBQL$&tf&F^2_e4#TfRAS+yeuU229$#=C_ z8(xZ8QKe`~E=5I7DHbeJBWSS?a%cR(e)7EoWFf$&;Xqadpf?wYUu_Tfp8m)gtfhrg zA)Xf(P~Ex$ay?W)^E(vM#4qy5s!mujxmz^}+H*)Ta6}iz5q+#3VH4{J^DI-WcF!lF ztB#J0R$=axfN0!19dQ3H;nf9#)qR3~C*aBe zdo-UIhOFgRsV<=_;t!moqZ`lB@v3w5t>t-o>A9NnB)OVS7JXs+c`9yh4MyGKU_9jS zA-i7(!?w30F0y0k%n2uKZl}UYgG|hLlz|&BGZ5gAi5-Kd;8|29`5kk@KMYlvXq1UP zyni0=Kj@f=t4>oe*t3%Kes#hfp$e}%XM$rO~{sYFwU zOx$eB!06{0koC^Q(_T~Xim9YgpSWSsqpfH)F2V9w#b5;`aPCurB=HIy%n_imzKlNq z#tzH+2BTxgU@SEYhRh=v5if>9VQUA|(}SsL$pwmi-bAfenkc>fBWfJ+h^$*%X|r@E zF^?ver`9G>n&=xcJswM*F5l3cfazqjS4A6sS5w-mW)JuKHYavGIhPBpvA`ELn7 z?@1B$dby9T=CpY1XJB>`vNxw}^st`$#mk)k7mtlf3gzDnD6H%9e8eyzZpvXGJ8d-bc}{$?P_kpEvRW zbuxWGqyjWLPPu7Dx++Ezrg>`Xpb#-_1bm8AUJl$AN z{HYQ}K7Cqf$A_uS0t4E$5&l~_d|OiM*5-Jp{hvDG`%rJIVZGmhX~zp>AB1{Ok$WHU zZF#HjJBZ6ZiuiAu+WVky>s7YzfOg>{^Ff+_>tyd^ywTcN3(4jk zAd{{iA@x`JzK`^FvK`++%J~@5hs!W9z|c_i7l3Vhr@x`v#Nf8}ucyPG^~5lqpS|hl b=l$uLZ1(2;jW}x1CJ*v+1WeobH)H<^D Date: Sun, 9 Nov 2025 18:15:53 -0500 Subject: [PATCH 2/2] updated changelog --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08c1f7cf8f..d2dcf35999 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.3.0a0] - 2025-11-10 + +### Added + +- Added 3D-EDMPrecond wrapper class with 3D version of SongUnet for diffusion models: + - `SongUNet3D`: 3D U-Net diffusion backbone extending DDPM++ and NCSN++ to volumetric data + - New test to cover the SongUnet3D + - `EDMPrecond3D`: 3D preconditioning wrapper for volumetric diffusion models + ## [1.3.0a0] - 2025-XX-YY ### Added