11import random
22from torch .utils .data import Dataset
3+ import torch
34
45from RawHandler .RawHandler import RawHandler
5- from RawHandler .utils import align_images
6+
7+ import re
68
79
810class RawDataset (Dataset ):
@@ -25,8 +27,8 @@ def __getitem__(self, idx):
2527 # Crop and align
2628 H , W = noisy_rh .raw .shape [- 2 :]
2729 half_crop = self .crop_size // 2
28- H_center = random .randint (0 + half_crop , H - half_crop )
29- W_center = random .randint (0 + half_crop , W - half_crop )
30+ H_center = random .randint (0 + half_crop * 2 , H - half_crop * 2 )
31+ W_center = random .randint (0 + half_crop * 2 , W - half_crop * 2 )
3032 crop = (
3133 H_center - half_crop ,
3234 H_center + half_crop ,
@@ -36,16 +38,33 @@ def __getitem__(self, idx):
3638 if self .offsets is None :
3739 offset = (0 , 0 , 0 , 0 )
3840 else :
39- offset = self .offsets [idx ]
40- offset = align_images (noisy_rh , gt_rh , crop , offset = (0 , 0 , 0 , 0 ))
41+ offset = self .offsets [idx ][0 ][0 ]
4142
4243 # Adjust exposure
43- noisy_rggb = noisy_rh .as_rggb (dims = crop )
44- noisy_rgb = noisy_rh .as_rgb (dims = crop )
45- clean_rgb = gt_rh .as_rgb (dims = crop )
44+ gain = (
45+ noisy_rh .adjust_bayer_bw_levels (dims = crop ).mean ()
46+ / gt_rh .adjust_bayer_bw_levels (dims = crop ).mean ()
47+ )
48+ gt_rh .gain = gain
49+
50+ # offset = align_images(noisy_rh, gt_rh, crop, offset=offset, step_sizes=[2])
51+
52+ noisy_rggb = noisy_rh .as_rggb_colorspace (dims = crop , colorspace = "AdobeRGB" )
53+ noisy_rgb = noisy_rh .as_rgb_colorspace (dims = crop , colorspace = "AdobeRGB" )
54+ clean_rgb = gt_rh .as_rgb_colorspace (dims = crop + offset , colorspace = "AdobeRGB" )
55+
56+ iso = re .findall ("_ISO([0-9]+)_" , noisy_file )
57+ if len (iso ) == 1 :
58+ iso = int (iso [0 ])
59+ else :
60+ iso = - 100
61+
62+ iso_conditioning = iso / 65535
4663
4764 if self .transform :
48- noisy_rggb = self .transform (noisy_rggb )
49- noisy_rgb = self .transform (noisy_rgb )
50- clean_rgb = self .transform (clean_rgb )
51- return noisy_rggb , noisy_rgb , clean_rgb , offset
65+ noisy_rggb = self .transform (noisy_rggb .transpose (1 , 2 , 0 ))
66+ noisy_rgb = self .transform (noisy_rgb .transpose (1 , 2 , 0 ))
67+ clean_rgb = self .transform (clean_rgb .transpose (1 , 2 , 0 ))
68+ iso_conditioning = torch .tensor ([iso_conditioning ])
69+
70+ return noisy_rggb , noisy_rgb , clean_rgb , offset , iso_conditioning
0 commit comments