From 22ea79ce1275a7b4c46e1a97afb3d52360c577e6 Mon Sep 17 00:00:00 2001 From: ash1ra Date: Tue, 17 Feb 2026 12:05:54 +0200 Subject: [PATCH 1/5] feat: create `create_logger` function --- utils.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/utils.py b/utils.py index 5034573..3de9c58 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,52 @@ +import logging +import sys +from datetime import datetime +from logging.handlers import RotatingFileHandler +from pathlib import Path + from einops import rearrange from torch import Tensor +def create_logger( + log_level: str, + log_file_name: str, + max_log_file_size: int = 5 * 1024 * 1024, + backup_count: int = 10, +) -> logging.Logger: + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(message)s", + datefmt="%d.%m.%Y %H:%M:%S", + ) + + logger.handlers.clear() + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + console_handler.setLevel(getattr(logging, log_level.upper())) + logger.addHandler(console_handler) + + Path("logs").mkdir(parents=True, exist_ok=True) + current_date = datetime.now().strftime("%d-%m-%Y-%H-%M-%S") + log_file_name = f"logs/{log_file_name}_{current_date}.log" + + file_handler = RotatingFileHandler( + filename=log_file_name, + maxBytes=max_log_file_size, + backupCount=backup_count, + ) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +logger = create_logger(log_level="INFO", log_file_name="HAT") + + def split_img_into_windows(img_tensor: Tensor, window_size: int) -> Tensor: return rearrange( img_tensor, From b56ad46fdd630314d02dbfa3f2fc0f15808f3b15 Mon Sep 17 00:00:00 2001 From: ash1ra Date: Tue, 17 Feb 2026 12:05:06 +0200 Subject: [PATCH 2/5] feat: create `prepare_data.py` file with MATLAB-like bicubic interpolation function --- config.py | 21 +++++ prepare_data.py | 211 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 config.py create mode 100644 prepare_data.py diff --git a/config.py b/config.py new file mode 100644 index 0000000..d6f1dea --- /dev/null +++ b/config.py @@ -0,0 +1,21 @@ +from pathlib import Path +from typing import Literal, TypeAlias + + +DeviceType: TypeAlias = Literal["cuda", "cpu"] + + +# Training settings +SCALING_FACTOR = 4 + +# Dataset pathes +PRETRAIN_DATASET_PATH = Path("data/ImageNet.txt") +TRAIN_DATASET_PATH = Path("data/DF2K") +VAL_DATASET_PATH = Path("data/DIV2K_val") +TEST_DATASET_PATHS = [ + Path("data/Set5"), + Path("data/Set14"), + Path("data/BSDS100"), + Path("data/Urban100"), + Path("data/Manga109"), +] diff --git a/prepare_data.py b/prepare_data.py new file mode 100644 index 0000000..58898c4 --- /dev/null +++ b/prepare_data.py @@ -0,0 +1,211 @@ +# Ported from BasicSR (matlab_functions.py) to ensure academic reproducibility. +# Implements MATLAB-like bicubic interpolation required for standard SR benchmarks (Set5, Set14). +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/matlab_functions.py +import concurrent.futures +from functools import partial +from pathlib import Path +from typing import Literal, Optional + +import numpy as np +from PIL import Image +from tqdm import tqdm + +import config + + +def cubic(x: np.ndarray) -> np.ndarray: + abs_x = np.abs(x) + abs_x2 = abs_x**2 + abs_x3 = abs_x**3 + + return (1.5 * abs_x3 - 2.5 * abs_x2 + 1) * ((abs_x <= 1).astype(type(abs_x))) + ( + -0.5 * abs_x3 + 2.5 * abs_x2 - 4 * abs_x + 2 + ) * (((abs_x > 1) * (abs_x <= 2)).astype(type(abs_x))) + + +def calculate_weights_indices( + in_length: int, + out_length: int, + scale: float, + kernel_width: int, + antialiasing: bool, +) -> tuple[np.ndarray, np.ndarray]: + if (scale < 1) and antialiasing: + kernel_width: int | float = kernel_width / scale + + x = np.linspace(1, out_length, out_length) + u = x / scale + 0.5 * (1 - 1 / scale) + left = np.floor(u - kernel_width / 2) + p = int(np.ceil(kernel_width)) + 2 + + indices = left.reshape(int(out_length), 1) + np.linspace(0, p - 1, p).reshape(1, int(p)) + + distance_to_center = u.reshape(int(out_length), 1) - indices + + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + weights_sum = np.sum(weights, 1).reshape(int(out_length), 1) + weights /= weights_sum + + weights_zero_idx = np.where(weights_sum == 0) + if len(weights_zero_idx[0]) > 0: + weights[weights_zero_idx, 0] = 1 + + padded_indices = indices.astype(int) + padded_indices -= 1 + + padded_indices = np.abs(padded_indices) + padded_indices = np.where(padded_indices < in_length, padded_indices, 2 * in_length - 1 - padded_indices) + padded_indices = np.clip(padded_indices, 0, in_length - 1) + + return weights, padded_indices + + +def imresize(img: np.ndarray, scale: float, antialiasing: bool = True) -> np.ndarray: + if scale == 1: + return img + + if len(img.shape) == 3: + input_img_height, input_img_width, input_img_num_channels = img.shape + else: + input_img_height, input_img_width = img.shape + input_img_num_channels = 1 + + output_img_height = int(np.ceil(input_img_height * scale)) + output_img_width = int(np.ceil(input_img_width * scale)) + + kernel_width = 4 + + height_weights, height_indices = calculate_weights_indices( + in_length=input_img_height, + out_length=output_img_height, + scale=scale, + kernel_width=kernel_width, + antialiasing=antialiasing, + ) + + width_weights, width_indices = calculate_weights_indices( + in_length=input_img_width, + out_length=output_img_width, + scale=scale, + kernel_width=kernel_width, + antialiasing=antialiasing, + ) + + img_aug = np.zeros((output_img_height, input_img_width, input_img_num_channels), dtype=np.float32) + + for channel in range(input_img_num_channels): + channel_data = img[:, :, channel] if input_img_num_channels > 1 else img + pixels = channel_data[height_indices] + img_aug[:, :, channel] = np.sum(height_weights[:, :, None] * pixels, axis=1) + + output_img = np.zeros((output_img_height, output_img_width, input_img_num_channels), dtype=np.float32) + + for channel in range(input_img_num_channels): + channel_data = img_aug[:, :, channel] + pixels = channel_data[:, width_indices] + output_img[:, :, channel] = np.sum(width_weights[None, :, :] * pixels, axis=2) + + output_img = np.clip(output_img, 0, 255) + + return np.round(output_img).astype(np.uint8) + + +def process_single_img( + img_path: Path, + hr_dir: Path, + lr_dir: Path, + scaling_factor: Literal[2, 4, 8], +) -> None: + img_name = f"{img_path.stem}.png" + hr_img_path = hr_dir / img_name + lr_img_path = lr_dir / img_name + + if hr_img_path.exists() and lr_img_path.exists(): + if hr_img_path.stat().st_size > 0 and lr_img_path.stat().st_size > 0: + return + + try: + with Image.open(img_path) as img: + img = img.convert("RGB") + + img_width, img_height = img.size + + remainder_w = img_width % scaling_factor + remainder_h = img_height % scaling_factor + + if remainder_w != 0 or remainder_h != 0: + img = img.crop((0, 0, img_width - remainder_w, img_height - remainder_h)) + + img.save(hr_img_path, compress_level=1) + + img_np = np.array(img) + + lr_img_np = imresize(img_np, scale=1 / scaling_factor) + + lr_img = Image.fromarray(lr_img_np) + + lr_img.save(lr_img_path, compress_level=1) + except Exception as e: + print(f"[Error] Failed to process '{img_path.name}': {e}") + + +def prepare_data( + input_data_path: Path, + output_data_path: Path, + scaling_factor: Literal[2, 4, 8], + num_workers: Optional[int] = None, +) -> None: + print(f"[Data] Preparing data from '{input_data_path}'") + + output_data_path.mkdir(parents=True, exist_ok=True) + + hr_dir_output_path = output_data_path / "HR" + hr_dir_output_path.mkdir(parents=True, exist_ok=True) + + lr_dir_output_path = output_data_path / f"LR_x{scaling_factor}" + lr_dir_output_path.mkdir(parents=True, exist_ok=True) + + if input_data_path.exists(): + if input_data_path.is_dir(): + img_paths = sorted([p for p in input_data_path.glob("*") if p.suffix.lower() in [".png", ".jpg", ".jpeg"]]) + elif input_data_path.is_file(): + with open(input_data_path, "r") as f: + img_paths = sorted([Path(line.strip()) for line in f if line.strip()]) + else: + raise FileNotFoundError(f"[Error] Input path '{input_data_path}' not found.") + + print(f"[Data] Found {len(img_paths)} images. Processing...") + + worker = partial( + process_single_img, hr_dir=hr_dir_output_path, lr_dir=lr_dir_output_path, scaling_factor=scaling_factor + ) + + with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: + list(tqdm(executor.map(worker, img_paths), total=len(img_paths), desc="Processing images...")) + + print(f"[Data] Processing completed. Output saved to '{output_data_path}'.\n") + + +if __name__ == "__main__": + prepare_data( + input_data_path=config.TRAIN_DATASET_PATH.with_suffix(".txt"), + output_data_path=config.TRAIN_DATASET_PATH, + scaling_factor=config.SCALING_FACTOR, + ) + + prepare_data( + input_data_path=config.VAL_DATASET_PATH.with_suffix(".txt"), + output_data_path=config.VAL_DATASET_PATH, + scaling_factor=config.SCALING_FACTOR, + ) + + for test_dataset_path in config.TEST_DATASET_PATHS: + prepare_data( + input_data_path=test_dataset_path.with_suffix(".txt"), + output_data_path=test_dataset_path, + scaling_factor=config.SCALING_FACTOR, + ) From 8b556c32826a4f7acb282dbe588c168890f69f18 Mon Sep 17 00:00:00 2001 From: ash1ra Date: Tue, 17 Feb 2026 12:06:15 +0200 Subject: [PATCH 3/5] feat: create `dataset.py` file and `SRDataset` class --- dataset.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 dataset.py diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..09ac08e --- /dev/null +++ b/dataset.py @@ -0,0 +1,125 @@ +import random +from pathlib import Path + +import torch +from PIL import Image +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.io import ImageReadMode, decode_image +from torchvision.transforms import v2 as transforms + +from utils import logger + + +class SRDataset(Dataset): + def __init__( + self, + data_path: Path, + scaling_factor: int, + patch_size: int, + test_mode: bool = False, + dev_mode: bool = False, + ) -> None: + self.scaling_factor = scaling_factor + self.patch_size = patch_size + self.test_mode = test_mode + self.dev_mode = dev_mode + + self.hr_dir_path = data_path / "HR" + self.lr_dir_path = data_path / f"LR_x{scaling_factor}" + + if not self.hr_dir_path.exists() or not self.lr_dir_path.exists(): + raise FileNotFoundError(f"[Data] Datasets directories not found in '{data_path}'") + + hr_img_names = {hr_img.name for hr_img in self.hr_dir_path.glob("*.png")} + lr_img_names = {lr_img.name for lr_img in self.lr_dir_path.glob("*.png")} + + self.img_names = sorted(list(hr_img_names & lr_img_names)) + + valid_images = [] + for img_name in self.img_names: + lr_path = self.lr_dir_path / img_name + + with Image.open(lr_path) as img: + img_width, img_height = img.size + if img_width > self.patch_size and img_height > self.patch_size: + valid_images.append(img_name) + else: + logger.warning(f"Dropped {img_name} ({img_width}x{img_height})") + + self.img_names = valid_images + + if len(self.img_names) == 0: + raise FileNotFoundError( + f"[Data] No matching files found between '{self.hr_dir_path}' and '{self.lr_dir_path}'" + ) + + if len(hr_img_names) != len(lr_img_names): + logger.warning( + f"[Data] Count mismatch! HR: {len(hr_img_names)}, LR: {len(lr_img_names)}. " + f"Proceeding with {len(self.img_names)} common files." + ) + + if dev_mode: + self.img_names = self.img_names[: int(len(self.img_names) * 0.1)] + + self.normalize = transforms.ToDtype(torch.float32, scale=True) + + def __len__(self) -> int: + return len(self.img_names) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + img_name = self.img_names[index] + + hr_img_path = str(self.hr_dir_path / img_name) + lr_img_path = str(self.lr_dir_path / img_name) + + try: + hr_img_tensor = decode_image(hr_img_path, mode=ImageReadMode.RGB) + lr_img_tensor = decode_image(lr_img_path, mode=ImageReadMode.RGB) + except Exception as e: + if self.test_mode: + logger.error(f"[Data] Error loading '{img_name}'.") + raise e + else: + logger.warning(f"[Data] Failed to load '{img_name}'. Skipping and resampling.") + return self.__getitem__(random.randint(0, len(self) - 1)) + + hr_img_tensor = self.normalize(hr_img_tensor) + lr_img_tensor = self.normalize(lr_img_tensor) + + if self.test_mode: + return {"hr": hr_img_tensor, "lr": lr_img_tensor} + + _, img_height, img_width = lr_img_tensor.shape + + if img_height <= self.patch_size or img_width <= self.patch_size: + logger.warning( + f"[Data] Image '{img_name}' ({img_height}x{img_width}) is smaller than patch size ({self.patch_size}). Skipped." + ) + return self.__getitem__(random.randint(0, len(self) - 1)) + + lr_y = random.randint(0, img_height - self.patch_size) + lr_x = random.randint(0, img_width - self.patch_size) + + hr_y = lr_y * self.scaling_factor + hr_x = lr_x * self.scaling_factor + hr_patch_size = self.patch_size * self.scaling_factor + + lr_patch = lr_img_tensor[:, lr_y : lr_y + self.patch_size, lr_x : lr_x + self.patch_size] + hr_patch = hr_img_tensor[:, hr_y : hr_y + hr_patch_size, hr_x : hr_x + hr_patch_size] + + if random.random() < 0.5: + lr_patch = transforms.functional.hflip(lr_patch) + hr_patch = transforms.functional.hflip(hr_patch) + + if random.random() < 0.5: + lr_patch = transforms.functional.vflip(lr_patch) + hr_patch = transforms.functional.vflip(hr_patch) + + k = random.randint(0, 3) + if k > 0: + lr_patch = torch.rot90(lr_patch, k, [1, 2]) + hr_patch = torch.rot90(hr_patch, k, [1, 2]) + + return {"hr": hr_patch, "lr": lr_patch} From 244f8e20a3f1ac1e66e36705bcd886c896054e40 Mon Sep 17 00:00:00 2001 From: ash1ra Date: Wed, 18 Feb 2026 11:30:40 +0200 Subject: [PATCH 4/5] feat: create `DynamicPairDataset` class for same-task pretraining technique --- dataset.py => datasets.py | 89 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) rename dataset.py => datasets.py (58%) diff --git a/dataset.py b/datasets.py similarity index 58% rename from dataset.py rename to datasets.py index 09ac08e..36fc3e7 100644 --- a/dataset.py +++ b/datasets.py @@ -2,16 +2,103 @@ from pathlib import Path import torch +from einops import rearrange from PIL import Image from torch import Tensor from torch.utils.data import Dataset from torchvision.io import ImageReadMode, decode_image from torchvision.transforms import v2 as transforms +from prepare_data import imresize from utils import logger -class SRDataset(Dataset): +class DynamicPairDataset(Dataset): + def __init__( + self, + data_path: Path, + scaling_factor: int, + patch_size: int, + test_mode: bool = False, + dev_mode: bool = False, + ) -> None: + self.scaling_factor = scaling_factor + self.patch_size = patch_size + self.test_mode = test_mode + self.dev_mode = dev_mode + + with open(data_path, "r") as f: + self.img_paths = [Path(line.strip()) for line in f if line.strip()] + + if len(self.img_paths) == 0: + raise FileNotFoundError("Images not found at specified path.") + + if dev_mode: + self.img_paths = self.img_paths[: int(len(self.img_paths) * 0.1)] + + self.normalize = transforms.ToDtype(torch.float32, scale=True) + + def __len__(self) -> int: + return len(self.img_paths) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + img_path = self.img_paths[index] + + try: + hr_img_tensor = decode_image(str(img_path), mode=ImageReadMode.RGB) + except Exception: + return self.__getitem__(random.randint(0, len(self) - 1)) + + if self.test_mode: + _, img_height, img_width = hr_img_tensor.shape + + img_height_new = img_height - img_height % self.scaling_factor + img_width_new = img_width - img_width % self.scaling_factor + hr_img_tensor = hr_img_tensor[:, :img_height_new, :img_width_new] + + hr_patch_np = rearrange(hr_img_tensor, "c h w -> h w c").numpy() + lr_patch_np = imresize(hr_patch_np, scale=1 / self.scaling_factor) + lr_patch = rearrange(torch.from_numpy(lr_patch_np), "h w c -> c h w") + + return { + "hr": self.normalize(hr_img_tensor), + "lr": self.normalize(lr_patch), + } + + _, img_height, img_width = hr_img_tensor.shape + hr_patch_size = self.patch_size * self.scaling_factor + + if img_height < hr_patch_size or img_width < hr_patch_size: + logger.warning( + f"[Data] Image '{img_path.name}' ({img_height}x{img_width}) is smaller than patch size ({hr_patch_size}). Skipped." + ) + return self.__getitem__(random.randint(0, len(self) - 1)) + + y = random.randint(0, img_height - hr_patch_size) + x = random.randint(0, img_width - hr_patch_size) + hr_patch = hr_img_tensor[:, y : y + hr_patch_size, x : x + hr_patch_size] + + if random.random() < 0.5: + hr_patch = transforms.functional.hflip(hr_patch) + + if random.random() < 0.5: + hr_patch = transforms.functional.vflip(hr_patch) + + k = random.randint(0, 3) + if k > 0: + hr_patch = torch.rot90(hr_patch, k, [1, 2]) + + hr_patch_np = rearrange(hr_patch, "c h w -> h w c").numpy() + lr_patch_np = imresize(hr_patch_np, scale=1 / self.scaling_factor) + lr_patch = rearrange(torch.from_numpy(lr_patch_np), "h w c -> c h w") + + return { + "hr": self.normalize(hr_patch), + "lr": self.normalize(lr_patch), + } + + +class StaticPairDataset(Dataset): def __init__( self, data_path: Path, From 1cd8612c6fe247e1bc01c9913f90579fe0f8c095 Mon Sep 17 00:00:00 2001 From: ash1ra Date: Tue, 17 Feb 2026 14:34:16 +0200 Subject: [PATCH 5/5] test: create tests for datasets --- tests/test_dataset.py | 298 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 tests/test_dataset.py diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..10c169f --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,298 @@ +from pathlib import Path +from datasets import StaticPairDataset, DynamicPairDataset +import pytest +from PIL import Image + + +DATA_CONFIG = [ + # scaling_factor, patch_size + (2, 16), + (4, 32), + (8, 8), + (3, 16), +] + + +@pytest.fixture +def fake_root_dir(tmp_path: Path) -> Path: + return tmp_path / "data" + + +def create_fake_data_static( + root: Path, + scaling_factor: int, + num_imgs: int = 10, + imgs_size: tuple[int, int] = (200, 200), +) -> None: + hr_dir = root / "HR" + lr_dir = root / f"LR_x{scaling_factor}" + + hr_dir.mkdir(parents=True, exist_ok=True) + lr_dir.mkdir(parents=True, exist_ok=True) + + for i in range(num_imgs): + Image.new("RGB", imgs_size, color="red").save(hr_dir / f"img_{i}.png") + Image.new("RGB", (imgs_size[0] // scaling_factor, imgs_size[1] // scaling_factor), color="blue").save( + lr_dir / f"img_{i}.png" + ) + + +def create_fake_data_dynamic( + root: Path, + num_imgs: int = 10, + imgs_size: tuple[int, int] = (200, 200), +) -> Path: + img_dir = root / "Dynamic" + img_dir.mkdir(parents=True, exist_ok=True) + + list_file_path = root / "train_list.txt" + + paths = [] + for i in range(num_imgs): + path = img_dir / f"img_{i}.png" + Image.new("RGB", imgs_size, color="green").save(path) + paths.append(str(path.absolute())) + + with open(list_file_path, "w") as f: + f.write("\n".join(paths)) + + return list_file_path + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_initialization(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + Image.new("RGB", (100, 100)).save(fake_root_dir / "HR" / "hr_orphan.png") + Image.new("RGB", (25, 25)).save(fake_root_dir / f"LR_x{scaling_factor}" / "lr_orphan.png") + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + assert len(dataset) == 10 + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_output_shape(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + imgs_tensor = dataset[0] + lr_img_tensor = imgs_tensor["lr"] + hr_img_tensor = imgs_tensor["hr"] + + assert lr_img_tensor.shape == (3, patch_size, patch_size) + assert hr_img_tensor.shape == (3, patch_size * scaling_factor, patch_size * scaling_factor) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_normalization(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + imgs_tensor = dataset[0] + + assert imgs_tensor["hr"].max() <= 1.0 + assert imgs_tensor["hr"].min() >= 0.0 + + assert imgs_tensor["lr"].max() <= 1.0 + assert imgs_tensor["lr"].min() >= 0.0 + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_small_image(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + Image.new("RGB", (5, 5)).save(fake_root_dir / "HR" / "tiny.png") + Image.new("RGB", (5, 5)).save(fake_root_dir / f"LR_x{scaling_factor}" / "tiny.png") + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + for i in range(len(dataset)): + imgs_tensor = dataset[i] + assert imgs_tensor["lr"].shape == (3, patch_size, patch_size) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_augmentations(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + for _ in range(50): + imgs_tensor = dataset[0] + assert imgs_tensor["hr"].shape == (3, patch_size * scaling_factor, patch_size * scaling_factor) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_test_mode(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + img_height = 200 + img_width = 200 + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + test_mode=True, + ) + + imgs_tensor = dataset[0] + + assert imgs_tensor["hr"].shape == (3, img_height, img_width) + assert imgs_tensor["lr"].shape == (3, img_height // scaling_factor, img_width // scaling_factor) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_static_dataset_dev_mode(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + create_fake_data_static(root=fake_root_dir, scaling_factor=scaling_factor) + + dataset = StaticPairDataset( + data_path=fake_root_dir, + scaling_factor=scaling_factor, + patch_size=patch_size, + dev_mode=True, + ) + + assert len(dataset) == 1 + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_initialization(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + assert len(dataset) == 10 + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_output_shape(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + imgs_tensor = dataset[0] + lr_img_tensor = imgs_tensor["lr"] + hr_img_tensor = imgs_tensor["hr"] + + assert lr_img_tensor.shape == (3, patch_size, patch_size) + assert hr_img_tensor.shape == (3, patch_size * scaling_factor, patch_size * scaling_factor) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_normalization(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + imgs_tensor = dataset[0] + + assert imgs_tensor["hr"].max() <= 1.0 + assert imgs_tensor["hr"].min() >= 0.0 + + assert imgs_tensor["lr"].max() <= 1.0 + assert imgs_tensor["lr"].min() >= 0.0 + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_small_image(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + Image.new("RGB", (5, 5)).save(fake_root_dir / "Dynamic" / "tiny.png") + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + for i in range(len(dataset)): + imgs_tensor = dataset[i] + assert imgs_tensor["lr"].shape == (3, patch_size, patch_size) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_augmentations(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + ) + + for _ in range(50): + imgs_tensor = dataset[0] + assert imgs_tensor["hr"].shape == (3, patch_size * scaling_factor, patch_size * scaling_factor) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_test_mode(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + test_mode=True, + ) + + img_height = 200 + img_width = 200 + + expected_img_height = img_height - img_height % scaling_factor + expected_img_width = img_width - img_width % scaling_factor + + imgs_tensor = dataset[0] + + assert imgs_tensor["hr"].shape == (3, expected_img_height, expected_img_width) + assert imgs_tensor["lr"].shape == (3, expected_img_height // scaling_factor, expected_img_width // scaling_factor) + + +@pytest.mark.parametrize("scaling_factor, patch_size", DATA_CONFIG) +def test_dynamic_dataset_dev_mode(fake_root_dir: Path, scaling_factor: int, patch_size: int) -> None: + list_file = create_fake_data_dynamic(root=fake_root_dir) + + dataset = DynamicPairDataset( + data_path=list_file, + scaling_factor=scaling_factor, + patch_size=patch_size, + dev_mode=True, + ) + + assert len(dataset) == 1