diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1bfc2a5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,166 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+/weights*
+/input*
+/results*
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/README.md b/README.md
index 7265ac6..9a39fcb 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,27 @@
+Adjustment to the original repository, which allows to run it on weaker cards.
+Uses chunking/patching.
+
+Currently, original repo installation steps will result in wrong dependencies.
+So instead, install like this: `pip install --upgrade-strategy only-if-needed -r requirements.txt`
+
+Or you can download the full setup from here https://github.com/IgorAherne/Shadow_R/releases/tag/latest
+This way you won't need to run any pip installs.
+
+If you need neural nets, get them from google drive of the original repo
+If the drive is unaccessible, you can get them from here too [Release](https://github.com/IgorAherne/Shadow_R/releases/tag/original_weights)
+
+Launch via `python ./test.py --chunk_size 512` or `--chunk_size 256` etc
+
+arguments and their default values (see test.py):
+`--test_dir = ./ShadowDataset/test` where the input images are
+`--input_dir = ./input/`
+`--output_dir = ./output/`
+`--chunk_size = 512` size of sliding window, to split the work into smaller pieces, for performance. Careful, might create seams
+`--overlap = 64` overlap among the windows, to hide possible seams
+
+
+Original repo description:
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f7a60ba
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,43 @@
+# Custom PyTorch index for CUDA 11.3
+--extra-index-url https://download.pytorch.org/whl/cu113
+
+# Core PyTorch ecosystem - pinned versions
+torch==1.11.0+cu113
+torchvision==0.12.0+cu113
+
+# Core dependencies with strict versions
+numpy==1.24.3 # Compatible with PyTorch 1.11
+typing-extensions==4.12.2 # Required by PyTorch
+
+# Deep Learning Framework Dependencies
+pytorch-lightning==1.9.0 # Compatible with torch 1.11
+torchmetrics==0.11.4 # Compatible with torch 1.11
+timm==1.0.11
+einops==0.8.0
+
+# Vision and Image Processing
+kornia==0.7.4
+kornia-rs==0.1.7
+Pillow==11.0.0
+opencv-python==4.10.0.84
+scikit-image==0.24.0
+
+# Scientific Computing
+scipy==1.14.1
+scikit-learn==1.5.2
+matplotlib==3.9.3
+
+# Additional Required Dependencies
+joblib==1.4.2 # For scikit-learn
+threadpoolctl==3.5.0 # For scikit-learn
+networkx==3.4.2 # For scikit-image
+tifffile==2024.9.20 # For scikit-image
+imageio==2.36.1 # For scikit-image
+PyYAML==6.0.2 # For pytorch-lightning
+fsspec==2024.10.0 # For pytorch-lightning
+packaging==24.2
+tqdm==4.67.1
+
+# Ensure proper package building
+setuptools>=63.2.0
+wheel>=0.37.0
diff --git a/test.py b/test.py
index 5371a7c..5dfb514 100644
--- a/test.py
+++ b/test.py
@@ -9,65 +9,164 @@
from torchvision import transforms
from test_dataset import dehaze_test_dataset
from model import final_net
+import torch.nn.functional as F
+from PIL import Image
+import numpy as np
+
+def process_chunk(model, chunk, device):
+ with torch.no_grad():
+ chunk = chunk.to(device)
+ return model(chunk)
+
+def split_image(img_tensor, window_size=512, overlap=64):
+ """Split image into overlapping chunks."""
+ _, _, h, w = img_tensor.shape
+ chunks = []
+ positions = []
+
+ for y in range(0, h, window_size - overlap):
+ for x in range(0, w, window_size - overlap):
+ # Calculate chunk boundaries
+ y1 = y
+ x1 = x
+ y2 = min(y + window_size, h)
+ x2 = min(x + window_size, w)
+
+ # Extract chunk
+ chunk = img_tensor[:, :, y1:y2, x1:x2]
+
+ # Pad if necessary to maintain consistent size
+ if chunk.shape[2] < window_size or chunk.shape[3] < window_size:
+ ph = window_size - chunk.shape[2]
+ pw = window_size - chunk.shape[3]
+ chunk = F.pad(chunk, (0, pw, 0, ph))
+
+ chunks.append(chunk)
+ positions.append((y1, y2, x1, x2))
+
+ return chunks, positions
-
-parser = argparse.ArgumentParser(description='Shadow')
-parser.add_argument('--test_dir', type=str, default='./ShadowDataset/test/')
-parser.add_argument('--output_dir', type=str, default='results/')
-parser.add_argument('-test_batch_size', help='Set the testing batch size', default=1, type=int)
-args = parser.parse_args()
-output_dir =args.output_dir
-if not os.path.exists(output_dir + '/'):
- os.makedirs(output_dir + '/', exist_ok=True)
-test_dir = args.test_dir
-test_batch_size = args.test_batch_size
-
-test_dataset = dehaze_test_dataset(test_dir)
-test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)
-
-device = 'cuda:0'
-print(device)
-
-model = final_net()
-
-try:
- model.remove_model.load_state_dict(torch.load(os.path.join('weights', 'shadowremoval.pkl'), map_location='cpu'), strict=True)
- print('loading removal_model success')
-except:
- print('loading removal_model error')
-
-
-try:
- model.enhancement_model.load_state_dict(torch.load(os.path.join('weights', 'refinement.pkl'), map_location='cpu'), strict=True)
- print('loading enhancement model success')
-except:
- print('loading enhancement model error')
-
-model = model.to(device)
-
-total_time = 0
-with torch.no_grad():
+def merge_chunks(chunks, positions, original_size, window_size=512, overlap=64):
+ """Merge overlapping chunks with linear blending."""
+ h, w = original_size
+ result = torch.zeros((1, 3, h, w), device=chunks[0].device)
+ weights = torch.zeros((1, 3, h, w), device=chunks[0].device)
+
+ for chunk, (y1, y2, x1, x2) in zip(chunks, positions):
+ # Calculate actual chunk size (might be smaller at edges)
+ chunk_h = y2 - y1
+ chunk_w = x2 - x1
+
+ # Extract the valid portion of the chunk (without padding)
+ valid_chunk = chunk[:, :, :chunk_h, :chunk_w]
+
+ # Create weight mask
+ weight = torch.ones_like(valid_chunk)
+
+ # Apply linear blending in overlap regions
+ if overlap > 0:
+ for i in range(overlap):
+ weight_value = i / overlap
+ # Blend left edge if not at image boundary
+ if x1 > 0 and i < chunk_w:
+ weight[:, :, :, i] *= weight_value
+ # Blend right edge if not at image boundary
+ if x2 < w and chunk_w - i - 1 >= 0:
+ weight[:, :, :, -(i + 1)] *= weight_value
+ # Blend top edge if not at image boundary
+ if y1 > 0 and i < chunk_h:
+ weight[:, :, i, :] *= weight_value
+ # Blend bottom edge if not at image boundary
+ if y2 < h and chunk_h - i - 1 >= 0:
+ weight[:, :, -(i + 1), :] *= weight_value
+
+ result[:, :, y1:y2, x1:x2] += valid_chunk * weight
+ weights[:, :, y1:y2, x1:x2] += weight
+
+ # Normalize by weights to complete blending
+ valid_mask = weights > 0
+ result[valid_mask] = result[valid_mask] / weights[valid_mask]
+
+ return result
+
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Shadow')
+ parser.add_argument('--input_dir', type=str, default='/ShadowDataset/test/')
+ parser.add_argument('--output_dir', type=str, default='output/')
+ parser.add_argument('--chunk_size', type=int, default=512, help='Size of sliding window')
+ parser.add_argument('--overlap', type=int, default=64, help='Overlap between windows')
+ args = parser.parse_args()
+
+ # Ensure paths end with slash
+ args.input_dir = os.path.join(args.input_dir, '')
+ args.output_dir = os.path.join(args.output_dir, '')
+ print('')
+ print(f'input_dir: {args.input_dir}')
+ print(f'output_dir: {args.output_dir}')
+ print(f'chunk size: {args.chunk_size}. If algorithm is stuck, reduce chunk size to fit inside VRAM.')
+ print('')
+
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ test_dataset = dehaze_test_dataset(args.input_dir)
+ test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0)
+
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ print(f"Using device: {device}")
+
+ model = final_net()
+
+ try:
+ model.remove_model.load_state_dict(torch.load(os.path.join('weights', 'shadowremoval.pkl'), map_location='cpu'), strict=True)
+ print('Loading removal_model success')
+ except:
+ print('Loading removal_model error')
+ return
+
+ try:
+ model.enhancement_model.load_state_dict(torch.load(os.path.join('weights', 'refinement.pkl'), map_location='cpu'), strict=True)
+ print('Loading enhancement model success')
+ except:
+ print('Loading enhancement model error')
+ return
+
+ model = model.to(device)
model.eval()
-
- start = time.time()
- for batch_idx, (input, name) in enumerate(test_loader):
- print(name[0])
- input = input.to(device)
- frame_out = model(input)
- frame_out = frame_out.to(device)
- name = re.findall("\d+",str(name))
- imwrite(frame_out, os.path.join(output_dir, str(name[0])+'.png'), range=(0, 1))
-
-
-
-
-
-
-
-
-
-
-
-
-
+ total_time = 0
+ with torch.no_grad():
+ for batch_idx, (input_img, name) in enumerate(test_loader):
+ print(f"Processing {name[0]}")
+
+ # Get image dimensions
+ _, _, h, w = input_img.shape
+ print(f"Image size: {w}x{h}")
+
+ # Split image into chunks
+ chunks, positions = split_image(input_img, args.chunk_size, args.overlap)
+ print(f"Split into {len(chunks)} chunks")
+
+ # Process each chunk
+ processed_chunks = []
+ for i, chunk in enumerate(chunks):
+ print(f"Processing chunk {i+1}/{len(chunks)}")
+ processed_chunk = process_chunk(model, chunk, device)
+ processed_chunks.append(processed_chunk)
+ torch.cuda.empty_cache() # Clear GPU memory after each chunk
+
+ # Merge chunks
+ result = merge_chunks(processed_chunks, positions, (h, w), args.chunk_size, args.overlap)
+
+ # Save result
+ name = re.findall(r"\d+", str(name))
+ save_path = os.path.join(args.output_dir, f"{name[0]}.png")
+ print(f"Saving result to {save_path}")
+ imwrite(result, save_path, range=(0, 1))
+
+ torch.cuda.empty_cache()
+
+if __name__ == '__main__':
+ main()
diff --git a/test_dataset.py b/test_dataset.py
index a27e618..8c58454 100644
--- a/test_dataset.py
+++ b/test_dataset.py
@@ -8,12 +8,11 @@ def __init__(self, test_dir):
self.transform = transforms.Compose([transforms.ToTensor()])
self.list_test_hazy=[]
- self.root_hazy=os.path.join(test_dir, 'LQ/')
+ self.root_hazy=test_dir
for i in os.listdir(self.root_hazy):
self.list_test_hazy.append(i)
#self.root_hazy = os.path.join(test_dir)
-
self.file_len = len(self.list_test_hazy)