From 5cb453b7c632461deb2329e1874928fd884512e6 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 7 Nov 2025 17:19:27 -0500 Subject: [PATCH 1/4] Baseline NAFnet arch --- src/Restorer/Cond_NAF_SSID.py | 202 ++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 src/Restorer/Cond_NAF_SSID.py diff --git a/src/Restorer/Cond_NAF_SSID.py b/src/Restorer/Cond_NAF_SSID.py new file mode 100644 index 0000000..0fbe926 --- /dev/null +++ b/src/Restorer/Cond_NAF_SSID.py @@ -0,0 +1,202 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ + +''' +Simple Baselines for Image Restoration + +@article{chen2022simple, + title={Simple Baselines for Image Restoration}, + author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, + journal={arXiv preprint arXiv:2204.04676}, + year={2022} +} +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from basicsr.models.archs.arch_util import LayerNorm2d +from basicsr.models.archs.local_arch import Local_Base + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class NAFNet(nn.Module): + + def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]): + super().__init__() + + self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + +class NAFNetLocal(Local_Base, NAFNet): + def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs): + Local_Base.__init__(self) + NAFNet.__init__(self, *args, **kwargs) + + N, C, H, W = train_size + base_size = (int(H * 1.5), int(W * 1.5)) + + self.eval() + with torch.no_grad(): + self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) + + +if __name__ == '__main__': + img_channel = 3 + width = 32 + + # enc_blks = [2, 2, 4, 8] + # middle_blk_num = 12 + # dec_blks = [2, 2, 2, 2] + + enc_blks = [1, 1, 1, 28] + middle_blk_num = 1 + dec_blks = [1, 1, 1, 1] + + net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, + enc_blk_nums=enc_blks, dec_blk_nums=dec_blks) + + + inp_shape = (3, 256, 256) + + from ptflops import get_model_complexity_info + + macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False) + + params = float(params[:-3]) + macs = float(macs[:-4]) + + print(macs, params) \ No newline at end of file From f971c22e3d8cb2f9106023947878e889960ff2bc Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 7 Nov 2025 17:20:20 -0500 Subject: [PATCH 2/4] Testing fuser --- src/Restorer/Cond_NAF_SSID.py | 86 +++++++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 5 deletions(-) diff --git a/src/Restorer/Cond_NAF_SSID.py b/src/Restorer/Cond_NAF_SSID.py index 0fbe926..b1f718a 100644 --- a/src/Restorer/Cond_NAF_SSID.py +++ b/src/Restorer/Cond_NAF_SSID.py @@ -16,8 +16,77 @@ import torch import torch.nn as nn import torch.nn.functional as F -from basicsr.models.archs.arch_util import LayerNorm2d -from basicsr.models.archs.local_arch import Local_Base + +class LayerNorm2dAdjusted(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x, target_mu, target_var): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + y = y * torch.sqrt(target_var + self.eps) + target_mu + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class ChannelAttention(nn.Module): + def __init__(self, dims): + super().__init__() + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dims, out_channels=dims, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + def forward(self, x): + return self.sca(x) + +class CondFuser(nn.Module): + def __init__(self, chan): + super().__init__() + self.cca = ChannelAttention(chan * 2) + + def forward(self, x1, x2): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x) * x + + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + +class CondFuserAdd(nn.Module): + def __init__(self, chan): + super().__init__() + + def forward(self, x1, x2): + return x1 + x2 class SimpleGate(nn.Module): def forward(self, x): @@ -82,7 +151,8 @@ def forward(self, inp): class NAFNet(nn.Module): - def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]): + def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], + use_add = False): super().__init__() self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, @@ -95,6 +165,7 @@ def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], d self.middle_blks = nn.ModuleList() self.ups = nn.ModuleList() self.downs = nn.ModuleList() + self.merges = nn.ModuleList() chan = width for num in enc_blk_nums: @@ -127,6 +198,11 @@ def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], d ) ) + if use_add: + self.merges.append(CondFuserAdd(chan)) + else: + self.merges.append(CondFuser(chan)) + self.padder_size = 2 ** len(self.encoders) def forward(self, inp): @@ -144,9 +220,9 @@ def forward(self, inp): x = self.middle_blks(x) - for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]): x = up(x) - x = x + enc_skip + x = merge(x, enc_skip) x = decoder(x) x = self.ending(x) From 7b30eeee3c751db6b6fa4a351c4100a5fcd782e2 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 7 Nov 2025 21:22:41 -0500 Subject: [PATCH 3/4] Added a demosaiced in version --- src/Restorer/Cond_NAF.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index 432e0bf..1125db8 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -900,3 +900,29 @@ def make_full_model_PS(params, model_name=None): model.load_state_dict(state_dict) return model + + +class ModelWrapperDemoIn(nn.Module): + def __init__(self, **kwargs): + super().__init__() + if 'gamma' in kwargs: + kwargs.pop('gamma') + + self.demosaicer = DemosaicingFromRGGB() + self.model = Restorer( + **kwargs + ) + + def forward(self, rggb, cond, *args): + debayered = self.demosaicer(rggb, cond) + output = self.model(debayered, cond) + output = (debayered + output) + return output + + +def make_full_model_RGGB_DemoIn(params, model_name=None): + model = ModelWrapperDemoIn(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model From 9b6bd3e0e60ee084edf9278514e3df12a5fbe264 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 7 Nov 2025 21:26:11 -0500 Subject: [PATCH 4/4] Update to collab config --- config_collab.yaml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/config_collab.yaml b/config_collab.yaml index 49f4931..658c616 100644 --- a/config_collab.yaml +++ b/config_collab.yaml @@ -17,7 +17,7 @@ num_epochs_pretraining: 270 num_epochs_finetuning: 100 val_split: 0.2 random_seed: 42 -cosine_annealing: False +cosine_annealing: True iso_range: [0, 999999] # --- Experiment Settings --- @@ -25,17 +25,19 @@ experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_ps +run_name: NAF_debayer_in run_path: NAF_deep_test_align model_params: - chans: [64, 128, 256, 512, 512, 512] + # chans: [64, 128, 256, 512, 512, 512] + chans: [32, 64, 128, 256, 256, 256] + enc_blk_nums: [2, 2, 2, 3, 4] middle_blk_num: 12 dec_blk_nums: [2, 2, 2, 2, 2] cond_input: 1 - in_channels: 4 + in_channels: 3 out_channels: 3 - rggb: True + rggb: False use_CondFuserV2: False use_add: False use_CondFuserV3: False