diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1bfc2a5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +/weights* +/input* +/results* + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md index 7265ac6..9a39fcb 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,27 @@ +Adjustment to the original repository, which allows to run it on weaker cards.
+Uses chunking/patching.
+ +Currently, the original repo's installation steps will result in wrong dependencies.
+So instead, install like this: `pip install --upgrade-strategy only-if-needed -r requirements.txt` + +Or you can download the full setup from here https://github.com/IgorAherne/Shadow_R/releases/tag/latest
+This way you won't need to run any pip installs. + +If you need the neural nets, get them from the Google Drive of the original repo
+If the drive is inaccessible, you can get them from here too [Release](https://github.com/IgorAherne/Shadow_R/releases/tag/original_weights) + +Launch via `python ./test.py --chunk_size 512` or `--chunk_size 256` etc.
+ +Arguments and their default values (see test.py):
+`--input_dir = /ShadowDataset/test/` where the input images are
+(pass e.g. `--input_dir ./input/` to point at your own folder)
+`--output_dir = output/` where the results are saved
+`--chunk_size = 512` size of the sliding window, used to split the work into smaller pieces that fit in VRAM. Careful: might create seams
+`--overlap = 64` overlap between adjacent windows, used to hide possible seams
+ + +Original repo description: +  
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f7a60ba --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +# Custom PyTorch index for CUDA 11.3 +--extra-index-url https://download.pytorch.org/whl/cu113 + +# Core PyTorch ecosystem - pinned versions +torch==1.11.0+cu113 +torchvision==0.12.0+cu113 + +# Core dependencies with strict versions +numpy==1.24.3 # Compatible with PyTorch 1.11 +typing-extensions==4.12.2 # Required by PyTorch + +# Deep Learning Framework Dependencies +pytorch-lightning==1.9.0 # Compatible with torch 1.11 +torchmetrics==0.11.4 # Compatible with torch 1.11 +timm==1.0.11 +einops==0.8.0 + +# Vision and Image Processing +kornia==0.7.4 +kornia-rs==0.1.7 +Pillow==11.0.0 +opencv-python==4.10.0.84 +scikit-image==0.24.0 + +# Scientific Computing +scipy==1.14.1 +scikit-learn==1.5.2 +matplotlib==3.9.3 + +# Additional Required Dependencies +joblib==1.4.2 # For scikit-learn +threadpoolctl==3.5.0 # For scikit-learn +networkx==3.4.2 # For scikit-image +tifffile==2024.9.20 # For scikit-image +imageio==2.36.1 # For scikit-image +PyYAML==6.0.2 # For pytorch-lightning +fsspec==2024.10.0 # For pytorch-lightning +packaging==24.2 +tqdm==4.67.1 + +# Ensure proper package building +setuptools>=63.2.0 +wheel>=0.37.0 diff --git a/test.py b/test.py index 5371a7c..5dfb514 100644 --- a/test.py +++ b/test.py @@ -9,65 +9,164 @@ from torchvision import transforms from test_dataset import dehaze_test_dataset from model import final_net +import torch.nn.functional as F +from PIL import Image +import numpy as np + +def process_chunk(model, chunk, device): + with torch.no_grad(): + chunk = chunk.to(device) + return model(chunk) + +def split_image(img_tensor, window_size=512, overlap=64): + """Split image into overlapping chunks.""" + _, _, h, w = img_tensor.shape + chunks = [] + positions = [] + + for y in range(0, h, window_size - overlap): + for x in range(0, w, window_size - overlap): + # Calculate chunk boundaries + y1 = y + x1 = x + 
y2 = min(y + window_size, h) + x2 = min(x + window_size, w) + + # Extract chunk + chunk = img_tensor[:, :, y1:y2, x1:x2] + + # Pad if necessary to maintain consistent size + if chunk.shape[2] < window_size or chunk.shape[3] < window_size: + ph = window_size - chunk.shape[2] + pw = window_size - chunk.shape[3] + chunk = F.pad(chunk, (0, pw, 0, ph)) + + chunks.append(chunk) + positions.append((y1, y2, x1, x2)) + + return chunks, positions - -parser = argparse.ArgumentParser(description='Shadow') -parser.add_argument('--test_dir', type=str, default='./ShadowDataset/test/') -parser.add_argument('--output_dir', type=str, default='results/') -parser.add_argument('-test_batch_size', help='Set the testing batch size', default=1, type=int) -args = parser.parse_args() -output_dir =args.output_dir -if not os.path.exists(output_dir + '/'): - os.makedirs(output_dir + '/', exist_ok=True) -test_dir = args.test_dir -test_batch_size = args.test_batch_size - -test_dataset = dehaze_test_dataset(test_dir) -test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0) - -device = 'cuda:0' -print(device) - -model = final_net() - -try: - model.remove_model.load_state_dict(torch.load(os.path.join('weights', 'shadowremoval.pkl'), map_location='cpu'), strict=True) - print('loading removal_model success') -except: - print('loading removal_model error') - - -try: - model.enhancement_model.load_state_dict(torch.load(os.path.join('weights', 'refinement.pkl'), map_location='cpu'), strict=True) - print('loading enhancement model success') -except: - print('loading enhancement model error') - -model = model.to(device) - -total_time = 0 -with torch.no_grad(): +def merge_chunks(chunks, positions, original_size, window_size=512, overlap=64): + """Merge overlapping chunks with linear blending.""" + h, w = original_size + result = torch.zeros((1, 3, h, w), device=chunks[0].device) + weights = torch.zeros((1, 3, h, w), device=chunks[0].device) + + for chunk, 
(y1, y2, x1, x2) in zip(chunks, positions): + # Calculate actual chunk size (might be smaller at edges) + chunk_h = y2 - y1 + chunk_w = x2 - x1 + + # Extract the valid portion of the chunk (without padding) + valid_chunk = chunk[:, :, :chunk_h, :chunk_w] + + # Create weight mask + weight = torch.ones_like(valid_chunk) + + # Apply linear blending in overlap regions + if overlap > 0: + for i in range(overlap): + weight_value = i / overlap + # Blend left edge if not at image boundary + if x1 > 0 and i < chunk_w: + weight[:, :, :, i] *= weight_value + # Blend right edge if not at image boundary + if x2 < w and chunk_w - i - 1 >= 0: + weight[:, :, :, -(i + 1)] *= weight_value + # Blend top edge if not at image boundary + if y1 > 0 and i < chunk_h: + weight[:, :, i, :] *= weight_value + # Blend bottom edge if not at image boundary + if y2 < h and chunk_h - i - 1 >= 0: + weight[:, :, -(i + 1), :] *= weight_value + + result[:, :, y1:y2, x1:x2] += valid_chunk * weight + weights[:, :, y1:y2, x1:x2] += weight + + # Normalize by weights to complete blending + valid_mask = weights > 0 + result[valid_mask] = result[valid_mask] / weights[valid_mask] + + return result + + + +def main(): + parser = argparse.ArgumentParser(description='Shadow') + parser.add_argument('--input_dir', type=str, default='/ShadowDataset/test/') + parser.add_argument('--output_dir', type=str, default='output/') + parser.add_argument('--chunk_size', type=int, default=512, help='Size of sliding window') + parser.add_argument('--overlap', type=int, default=64, help='Overlap between windows') + args = parser.parse_args() + + # Ensure paths end with slash + args.input_dir = os.path.join(args.input_dir, '') + args.output_dir = os.path.join(args.output_dir, '') + print('') + print(f'input_dir: {args.input_dir}') + print(f'output_dir: {args.output_dir}') + print(f'chunk size: {args.chunk_size}. 
If algorithm is stuck, reduce chunk size to fit inside VRAM.') + print('') + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + + test_dataset = dehaze_test_dataset(args.input_dir) + test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0) + + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + print(f"Using device: {device}") + + model = final_net() + + try: + model.remove_model.load_state_dict(torch.load(os.path.join('weights', 'shadowremoval.pkl'), map_location='cpu'), strict=True) + print('Loading removal_model success') + except: + print('Loading removal_model error') + return + + try: + model.enhancement_model.load_state_dict(torch.load(os.path.join('weights', 'refinement.pkl'), map_location='cpu'), strict=True) + print('Loading enhancement model success') + except: + print('Loading enhancement model error') + return + + model = model.to(device) model.eval() - - start = time.time() - for batch_idx, (input, name) in enumerate(test_loader): - print(name[0]) - input = input.to(device) - frame_out = model(input) - frame_out = frame_out.to(device) - name = re.findall("\d+",str(name)) - imwrite(frame_out, os.path.join(output_dir, str(name[0])+'.png'), range=(0, 1)) - - - - - - - - - - - - - + total_time = 0 + with torch.no_grad(): + for batch_idx, (input_img, name) in enumerate(test_loader): + print(f"Processing {name[0]}") + + # Get image dimensions + _, _, h, w = input_img.shape + print(f"Image size: {w}x{h}") + + # Split image into chunks + chunks, positions = split_image(input_img, args.chunk_size, args.overlap) + print(f"Split into {len(chunks)} chunks") + + # Process each chunk + processed_chunks = [] + for i, chunk in enumerate(chunks): + print(f"Processing chunk {i+1}/{len(chunks)}") + processed_chunk = process_chunk(model, chunk, device) + processed_chunks.append(processed_chunk) + torch.cuda.empty_cache() # Clear GPU memory after each chunk + + # Merge chunks + result = 
merge_chunks(processed_chunks, positions, (h, w), args.chunk_size, args.overlap) + + # Save result + name = re.findall(r"\d+", str(name)) + save_path = os.path.join(args.output_dir, f"{name[0]}.png") + print(f"Saving result to {save_path}") + imwrite(result, save_path, range=(0, 1)) + + torch.cuda.empty_cache() + +if __name__ == '__main__': + main() diff --git a/test_dataset.py b/test_dataset.py index a27e618..8c58454 100644 --- a/test_dataset.py +++ b/test_dataset.py @@ -8,12 +8,11 @@ def __init__(self, test_dir): self.transform = transforms.Compose([transforms.ToTensor()]) self.list_test_hazy=[] - self.root_hazy=os.path.join(test_dir, 'LQ/') + self.root_hazy=test_dir for i in os.listdir(self.root_hazy): self.list_test_hazy.append(i) #self.root_hazy = os.path.join(test_dir) - self.file_len = len(self.list_test_hazy)