12 changes: 7 additions & 5 deletions config_collab.yaml
@@ -17,25 +17,27 @@ num_epochs_pretraining: 270
num_epochs_finetuning: 100
val_split: 0.2
random_seed: 42
- cosine_annealing: False
+ cosine_annealing: True
iso_range: [0, 999999]

# --- Experiment Settings ---
experiment: NAF_test
mlflow_experiment: NAFNet_variations

# --- Run Configuration ---:
- run_name: NAF_ps
+ run_name: NAF_debayer_in
run_path: NAF_deep_test_align
model_params:
- chans: [64, 128, 256, 512, 512, 512]
+ # chans: [64, 128, 256, 512, 512, 512]
+ chans: [32, 64, 128, 256, 256, 256]

enc_blk_nums: [2, 2, 2, 3, 4]
middle_blk_num: 12
dec_blk_nums: [2, 2, 2, 2, 2]
cond_input: 1
- in_channels: 4
+ in_channels: 3
out_channels: 3
- rggb: True
+ rggb: False
use_CondFuserV2: False
use_add: False
use_CondFuserV3: False
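
For orientation, an illustrative sketch (not taken from this PR) of how the updated model_params block might be consumed by the new factory added to Cond_NAF.py below. It assumes PyYAML, uses the import path implied by this diff's file layout, and assumes Restorer's signature accepts the remaining keys (e.g. rggb, use_CondFuserV2), which is outside this diff:

import yaml
from src.Restorer.Cond_NAF import make_full_model_RGGB_DemoIn  # path as laid out in this PR

with open("config_collab.yaml") as f:
    cfg = yaml.safe_load(f)

model = make_full_model_RGGB_DemoIn(cfg["model_params"])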
26 changes: 26 additions & 0 deletions src/Restorer/Cond_NAF.py
@@ -900,3 +900,29 @@ def make_full_model_PS(params, model_name=None):
model.load_state_dict(state_dict)
return model



class ModelWrapperDemoIn(nn.Module):
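# Wraps the Restorer so it runs on an already-demosaiced input: the RGGB
# mosaic is debayered first, the restorer predicts a residual in RGB space,
# and the debayered image is added back to form the final output.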
def __init__(self, **kwargs):
super().__init__()
if 'gamma' in kwargs:
kwargs.pop('gamma')

self.demosaicer = DemosaicingFromRGGB()
self.model = Restorer(
**kwargs
)

def forward(self, rggb, cond, *args):
debayered = self.demosaicer(rggb, cond)
output = self.model(debayered, cond)
output = (debayered + output)
return output


def make_full_model_RGGB_DemoIn(params, model_name=None):
model = ModelWrapperDemoIn(**params)
if model_name is not None:
state_dict = torch.load(model_name, map_location="cpu")
model.load_state_dict(state_dict)
return model
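
Continuing the sketch above, a hedged example of the intended checkpoint round trip through the new factory; the checkpoint file name is hypothetical:

import torch

model = make_full_model_RGGB_DemoIn(cfg["model_params"])
torch.save(model.state_dict(), "naf_debayer_in.pth")  # hypothetical file name
restored = make_full_model_RGGB_DemoIn(cfg["model_params"], model_name="naf_debayer_in.pth")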
278 changes: 278 additions & 0 deletions src/Restorer/Cond_NAF_SSID.py
@@ -0,0 +1,278 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

'''
Simple Baselines for Image Restoration

@article{chen2022simple,
title={Simple Baselines for Image Restoration},
author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
journal={arXiv preprint arXiv:2204.04676},
year={2022}
}
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2dAdjusted(nn.Module):
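# Variant of LayerNorm2d (below) whose output is rescaled to externally
# supplied target statistics (target_mu, target_var) before the learned
# affine weight and bias are applied.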
def __init__(self, channels, eps=1e-6):
super(LayerNorm2dAdjusted, self).__init__()
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
self.eps = eps

def forward(self, x, target_mu, target_var):
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)

y = (x - mu) / torch.sqrt(var + self.eps)

y = y * torch.sqrt(target_var + self.eps) + target_mu

weight_view = self.weight.view(1, self.weight.size(0), 1, 1)
bias_view = self.bias.view(1, self.bias.size(0), 1, 1)

y = weight_view * y + bias_view
return y

class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super(LayerNorm2d, self).__init__()
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
self.eps = eps

def forward(self, x):
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)

y = (x - mu) / torch.sqrt(var + self.eps)

weight_view = self.weight.view(1, self.weight.size(0), 1, 1)
bias_view = self.bias.view(1, self.bias.size(0), 1, 1)

y = weight_view * y + bias_view
return y

class ChannelAttention(nn.Module):
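# Simplified channel attention (SCA): global average pooling followed by a
# 1x1 convolution, yielding one attention weight per channel.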
def __init__(self, dims):
super().__init__()
self.sca = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels=dims, out_channels=dims, kernel_size=1, padding=0, stride=1,
groups=1, bias=True),
)

def forward(self, x):
return self.sca(x)

class CondFuser(nn.Module):
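# Skip-connection fusion: concatenates the decoder feature with the encoder
# skip, weights the concatenation with channel attention, then splits it and
# sums the two halves; used instead of plain addition when use_add is False.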
def __init__(self, chan):
super().__init__()
self.cca = ChannelAttention(chan * 2)

def forward(self, x1, x2):
x = torch.cat([x1, x2], dim=1)
x = self.cca(x) * x

x1, x2 = x.chunk(2, dim=1)
return x1 + x2

class CondFuserAdd(nn.Module):
def __init__(self, chan):
super().__init__()

def forward(self, x1, x2):
return x1 + x2

class SimpleGate(nn.Module):
def forward(self, x):
x1, x2 = x.chunk(2, dim=1)
return x1 * x2

class NAFBlock(nn.Module):
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
super().__init__()
dw_channel = c * DW_Expand
self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
bias=True)
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

# Simplified Channel Attention
self.sca = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
groups=1, bias=True),
)

# SimpleGate
self.sg = SimpleGate()

ffn_channel = FFN_Expand * c
self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

self.norm1 = LayerNorm2d(c)
self.norm2 = LayerNorm2d(c)

self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

def forward(self, inp):
x = inp

x = self.norm1(x)

x = self.conv1(x)
x = self.conv2(x)
x = self.sg(x)
x = x * self.sca(x)
x = self.conv3(x)

x = self.dropout1(x)

y = inp + x * self.beta

x = self.conv4(self.norm2(y))
x = self.sg(x)
x = self.conv5(x)

x = self.dropout2(x)

return y + x * self.gamma


class NAFNet(nn.Module):
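# NAFNet U-Net following the original architecture, except that the
# encoder/decoder skip connections are combined through CondFuser (or
# CondFuserAdd when use_add is True) rather than a bare addition.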

def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[],
use_add = False):
super().__init__()

self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
bias=True)
self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
bias=True)

self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.middle_blks = nn.ModuleList()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.merges = nn.ModuleList()

chan = width
for num in enc_blk_nums:
self.encoders.append(
nn.Sequential(
*[NAFBlock(chan) for _ in range(num)]
)
)
self.downs.append(
nn.Conv2d(chan, 2*chan, 2, 2)
)
chan = chan * 2

self.middle_blks = \
nn.Sequential(
*[NAFBlock(chan) for _ in range(middle_blk_num)]
)

for num in dec_blk_nums:
self.ups.append(
nn.Sequential(
nn.Conv2d(chan, chan * 2, 1, bias=False),
nn.PixelShuffle(2)
)
)
chan = chan // 2
self.decoders.append(
nn.Sequential(
*[NAFBlock(chan) for _ in range(num)]
)
)

if use_add:
self.merges.append(CondFuserAdd(chan))
else:
self.merges.append(CondFuser(chan))

self.padder_size = 2 ** len(self.encoders)

def forward(self, inp):
B, C, H, W = inp.shape
inp = self.check_image_size(inp)

x = self.intro(inp)

encs = []

for encoder, down in zip(self.encoders, self.downs):
x = encoder(x)
encs.append(x)
x = down(x)

x = self.middle_blks(x)

for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]):
x = up(x)
x = merge(x, enc_skip)
x = decoder(x)

x = self.ending(x)
x = x + inp

return x[:, :, :H, :W]

def check_image_size(self, x):
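# Pad H and W up to the next multiple of 2**len(encoders) so every stride-2
# downsampling stage divides evenly; forward() crops back to the original size.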
_, _, h, w = x.size()
mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
return x

# NOTE: Local_Base (the test-time local-statistics base class from the NAFNet
# codebase) is not imported in this file; it must be made available, e.g. via
# the project's local_arch module, for NAFNetLocal to be usable.
class NAFNetLocal(Local_Base, NAFNet):
def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
Local_Base.__init__(self)
NAFNet.__init__(self, *args, **kwargs)

N, C, H, W = train_size
base_size = (int(H * 1.5), int(W * 1.5))

self.eval()
with torch.no_grad():
self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


if __name__ == '__main__':
img_channel = 3
width = 32

# enc_blks = [2, 2, 4, 8]
# middle_blk_num = 12
# dec_blks = [2, 2, 2, 2]

enc_blks = [1, 1, 1, 28]
middle_blk_num = 1
dec_blks = [1, 1, 1, 1]

net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)


inp_shape = (3, 256, 256)

from ptflops import get_model_complexity_info

macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)

params = float(params[:-3])
macs = float(macs[:-4])

print(macs, params)
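
# Illustrative follow-on check (not in the original script): spatial sizes
# need not be multiples of 2**len(enc_blks); check_image_size pads the input
# and forward() crops the output back to the requested size.
x = torch.randn(1, 3, 250, 251)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 250, 251])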