diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index f710cf6..cfdeaa6 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -844,22 +844,18 @@ def make_full_model_RGGB(params, model_name=None): class ModelWrapperNoRes(nn.Module): def __init__(self, **kwargs): super().__init__() - - self.gamma = 1 if 'gamma' in kwargs: - self.gamma = kwargs.pop('gamma') + kwargs.pop('gamma') self.demosaicer = DemosaicingFromRGGB() self.model = Restorer( **kwargs ) - def forward(self, rggb, cond): - rggb = rggb.clip(0, 1) ** (1. / self.gamma) + def forward(self, rggb, cond, *args): debayered = self.demosaicer(rggb, cond) - debayered = debayered.clip(0, 1) ** (1. / self.gamma) output = self.model(rggb, cond) - output = (debayered + output).clip(0, 1) ** (self.gamma) + output = (debayered + output) return output diff --git a/src/training/RawDatasetDNG_load_into_memory.py b/src/training/RawDatasetDNG_load_into_memory.py index bafd62f..a1a5f1e 100644 --- a/src/training/RawDatasetDNG_load_into_memory.py +++ b/src/training/RawDatasetDNG_load_into_memory.py @@ -16,6 +16,8 @@ from pathlib import Path from RawHandler.RawHandler import RawHandler +from .align_images import apply_alignment + class RawDatasetDNG(Dataset): def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False,