diff --git a/config_collab.yaml b/config_collab.yaml index 49f4931..658c616 100644 --- a/config_collab.yaml +++ b/config_collab.yaml @@ -17,7 +17,7 @@ num_epochs_pretraining: 270 num_epochs_finetuning: 100 val_split: 0.2 random_seed: 42 -cosine_annealing: False +cosine_annealing: True iso_range: [0, 999999] # --- Experiment Settings --- @@ -25,17 +25,19 @@ experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_ps +run_name: NAF_debayer_in run_path: NAF_deep_test_align model_params: - chans: [64, 128, 256, 512, 512, 512] + # chans: [64, 128, 256, 512, 512, 512] + chans: [32, 64, 128, 256, 256, 256] + enc_blk_nums: [2, 2, 2, 3, 4] middle_blk_num: 12 dec_blk_nums: [2, 2, 2, 2, 2] cond_input: 1 - in_channels: 4 + in_channels: 3 out_channels: 3 - rggb: True + rggb: False use_CondFuserV2: False use_add: False use_CondFuserV3: False diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index 432e0bf..1125db8 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -900,3 +900,29 @@ def make_full_model_PS(params, model_name=None): model.load_state_dict(state_dict) return model + + +class ModelWrapperDemoIn(nn.Module): + def __init__(self, **kwargs): + super().__init__() + if 'gamma' in kwargs: + kwargs.pop('gamma') + + self.demosaicer = DemosaicingFromRGGB() + self.model = Restorer( + **kwargs + ) + + def forward(self, rggb, cond, *args): + debayered = self.demosaicer(rggb, cond) + output = self.model(debayered, cond) + output = (debayered + output) + return output + + +def make_full_model_RGGB_DemoIn(params, model_name=None): + model = ModelWrapperDemoIn(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model diff --git a/src/Restorer/Cond_NAF_SSID.py b/src/Restorer/Cond_NAF_SSID.py new file mode 100644 index 0000000..b1f718a --- /dev/null +++ b/src/Restorer/Cond_NAF_SSID.py @@ -0,0 +1,278 @@ +# 
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

'''
Simple Baselines for Image Restoration

@article{chen2022simple,
  title={Simple Baselines for Image Restoration},
  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
  journal={arXiv preprint arXiv:2204.04676},
  year={2022}
}
'''

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2dAdjusted(nn.Module):
    # Channel-wise LayerNorm for NCHW tensors (normalizes over dim 1) that
    # re-targets the normalized activations to a caller-supplied mean and
    # variance before applying the learned per-channel affine transform.
    def __init__(self, channels, eps=1e-6):
        # BUG FIX: the original called super(LayerNorm2d, self).__init__()
        # — a copy-paste from LayerNorm2d below. Since this class is not a
        # subtype of LayerNorm2d, that call raised
        # "TypeError: super(type, obj): obj must be an instance or subtype
        # of type" on instantiation.
        super().__init__()
        self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x, target_mu, target_var):
        """Normalize x over dim 1, rescale to (target_mu, target_var),
        then apply the learned per-channel weight/bias."""
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)

        y = (x - mu) / torch.sqrt(var + self.eps)
        # Re-inject the target statistics (eps kept for symmetry with the
        # normalization above).
        y = y * torch.sqrt(target_var + self.eps) + target_mu

        weight_view = self.weight.view(1, self.weight.size(0), 1, 1)
        bias_view = self.bias.view(1, self.bias.size(0), 1, 1)

        y = weight_view * y + bias_view
        return y


class LayerNorm2d(nn.Module):
    # LayerNorm over the channel dimension of NCHW tensors with learned
    # per-channel affine parameters.
    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)

        y = (x - mu) / torch.sqrt(var + self.eps)

        weight_view = self.weight.view(1, self.weight.size(0), 1, 1)
        bias_view = self.bias.view(1, self.bias.size(0), 1, 1)

        y = weight_view * y + bias_view
        return y


class ChannelAttention(nn.Module):
    # Simplified channel attention (SCA): global average pool followed by a
    # 1x1 conv, yielding one scaling factor per channel.
    def __init__(self, dims):
        super().__init__()
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dims, out_channels=dims, kernel_size=1,
                      padding=0, stride=1, groups=1, bias=True),
        )

    def forward(self, x):
        return self.sca(x)


class CondFuser(nn.Module):
    # Fuses two equal-shaped feature maps: concat along channels, gate with
    # channel attention, split back, and add the halves.
    def __init__(self, chan):
        super().__init__()
        self.cca = ChannelAttention(chan * 2)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)
        x = self.cca(x) * x

        x1, x2 = x.chunk(2, dim=1)
        return x1 + x2


class CondFuserAdd(nn.Module):
    # Parameter-free fusion: plain element-wise addition. `chan` is accepted
    # only to keep the constructor interchangeable with CondFuser.
    def __init__(self, chan):
        super().__init__()

    def forward(self, x1, x2):
        return x1 + x2


class SimpleGate(nn.Module):
    # NAFNet's activation-free gate: split channels in half and multiply
    # the halves element-wise (halves the channel count).
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):
    """NAFNet building block: LN -> 1x1 conv -> 3x3 depthwise conv ->
    SimpleGate -> SCA -> 1x1 conv (spatial branch), then LN -> 1x1 ->
    SimpleGate -> 1x1 (FFN branch). Each branch is added back to its input
    scaled by a learned per-channel factor (beta / gamma, init 0)."""

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel,
                               kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel,
                               kernel_size=3, padding=1, stride=1,
                               groups=dw_channel, bias=True)  # depthwise
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c,
                               kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)

        # Simplified Channel Attention (operates on the post-gate width).
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2,
                      out_channels=dw_channel // 2, kernel_size=1, padding=0,
                      stride=1, groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel,
                               kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c,
                               kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. \
            else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. \
            else nn.Identity()

        # Residual scales, initialized to zero so each block starts as
        # an identity mapping.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class NAFNet(nn.Module):
    """U-shaped NAFNet: intro conv, encoder stages (each followed by a 2x
    strided-conv downsample), middle blocks, decoder stages (each preceded
    by a PixelShuffle upsample) with skip connections merged by CondFuser
    (or plain addition when use_add=True), and an ending conv with a global
    residual to the input."""

    def __init__(self, img_channel=3, width=16, middle_blk_num=1,
                 enc_blk_nums=[], dec_blk_nums=[], use_add=False):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width,
                               kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel,
                                kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()  # NOTE: reassigned below
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.merges = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

            if use_add:
                self.merges.append(CondFuserAdd(chan))
            else:
                self.merges.append(CondFuser(chan))

        # Input H/W must be padded to a multiple of this.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Decode with skip connections, deepest encoder output first.
        for decoder, up, merge, enc_skip in zip(self.decoders, self.ups,
                                                self.merges, encs[::-1]):
            x = up(x)
            x = merge(x, enc_skip)
            x = decoder(x)

        x = self.ending(x)
        x = x + inp  # global residual (on the padded input)

        return x[:, :, :H, :W]  # crop back to the original size

    def check_image_size(self, x):
        # Zero-pad right/bottom so H and W are multiples of padder_size.
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


class NAFNetLocal(Local_Base, NAFNet):
    # NOTE(review): Local_Base is not imported anywhere in this file — this
    # class raises NameError at import time. It presumably comes from the
    # project's local_arch module (as in the upstream NAFNet code); confirm
    # and add the import.
    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False,
                 **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        N, C, H, W = train_size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size,
                         fast_imp=fast_imp)


if __name__ == '__main__':
    img_channel = 3
    width = 32

    # enc_blks = [2, 2, 4, 8]
    # middle_blk_num = 12
    # dec_blks = [2, 2, 2, 2]

    enc_blks = [1, 1, 1, 28]
    middle_blk_num = 1
    dec_blks = [1, 1, 1, 1]

    net = NAFNet(img_channel=img_channel, width=width,
                 middle_blk_num=middle_blk_num,
                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)

    inp_shape = (3, 256, 256)

    from ptflops import get_model_complexity_info

    macs, params = get_model_complexity_info(net, inp_shape, verbose=False,
                                             print_per_layer_stat=False)

    # ptflops returns strings like "12.34 GMac" / "1.23 M" — strip the units.
    params = float(params[:-3])
    macs = float(macs[:-4])

    print(macs, params)