diff --git a/README.md b/README.md
index 26f5941..7bb61d8 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
| **PyPI** | **Versions** | **Docs & License** | **Tests** | **Reference** |
|:---------:|:-------------:|:------------------:|:----------:|:--------------:|
-| [](https://pypi.org/project/opensr-srgan/) | 
 | [](https://srgan.opensr.eu)
 | [](https://github.com/ESAOpenSR/SRGAN/actions/workflows/ci.yml)
[](https://codecov.io/github/ESAOpenSR/SRGAN) | [](https://arxiv.org/abs/2511.10461)
[](https://doi.org/10.5281/zenodo.17590993)
+| [](https://pypi.org/project/opensr-srgan/) | 
 | [](https://srgan.opensr.eu)
 | [](https://github.com/ESAOpenSR/SRGAN/actions/workflows/ci.yml)
[](https://codecov.io/github/ESAOpenSR/SRGAN) | [](https://arxiv.org/abs/2511.10461)
[](https://doi.org/10.5281/zenodo.17590993)
@@ -49,7 +49,7 @@ All key knobs are exposed via YAML in the `opensr_srgan/configs` folder:
* **EMA smoothing:** Enable `Training.EMA.enabled` to keep a shadow copy of the generator. Decay values in the 0.995–0.9999 range balance responsiveness with stability and are swapped in automatically for validation/inference.
* **Spectral normalization:** Optional for the SRGAN discriminator via `Discriminator.use_spectral_norm` to better control its Lipschitz constant and stabilize adversarial updates. [Miyato et al., 2018](https://arxiv.org/abs/1802.05957)
* **Wasserstein critic + R1 penalty:** Switch `Training.Losses.adv_loss_type: wasserstein` to enable a critic objective and pair it with the configurable `Training.Losses.r1_gamma` gradient penalty on real images for smoother discriminator updates. [Arjovsky et al., 2017](https://arxiv.org/abs/1701.07875); [Mescheder et al., 2018](https://arxiv.org/abs/1801.04406)
-* **Relativistic average GAN (BCE):** Set `Training.Losses.relativistic_average_d: true` to train D/G on relative real-vs-fake logits instead of absolute logits. This is supported in both Lightning training paths (PL1 and PL2).
+* **Relativistic average GAN (BCE):** Set `Training.Losses.relativistic_average_d: true` to train D/G on relative real-vs-fake logits instead of absolute logits. This is supported in the Lightning 2+ manual-optimization training path.
The schedule and ramp make training **easier, safer, and more reproducible**.
---
diff --git a/docs/architecture.md b/docs/architecture.md
index 046e8fb..ad9bccc 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -2,9 +2,9 @@
This document outlines how ESA OpenSR organises its super-resolution GAN, the major components that make up the model, and how each piece interacts during training and inference.
-## Vackground
+## Background
-OpenSR-SRGAN follows the single-image super-resolution (SISR) formulation in which the generator learns a mapping from a low-resolution observation $x$ to a plausible high-resolution reconstruction $x'$. The generator head widens the receptive field, a configurable trunk of $N$ residual-style blocks extracts features, and an upsampling tail increases spatial resolution. The residual fusion keeps skip connections active so the network focuses on high-frequency corrections rather than relearning the full signal:
+OpenSR-SRGAN follows the single-image super-resolution (SISR) formulation in which the generator learns a mapping from a low-resolution observation \(x\) to a plausible high-resolution reconstruction \(x'\). The generator head widens the receptive field, a configurable trunk of \(N\) residual-style blocks extracts features, and an upsampling tail increases spatial resolution. The residual fusion keeps skip connections active so the network focuses on high-frequency corrections rather than relearning the full signal:
$$
x' = \mathrm{Upsample}\!\left( \mathrm{Conv}_{\text{tail}}\!\left(\mathrm{Body}(x_{\text{head}}) + x_{\text{head}}\right)\! \right).
$$
@@ -24,9 +24,10 @@ Because every generator variant (residual, RCAB, RRDB, large-kernel attention, E
total-variation terms. Adversarial supervision uses `torch.nn.BCEWithLogitsLoss` with optional label smoothing.
* **Optimiser scheduling.** `configure_optimizers()` returns paired Adam optimisers (generator + discriminator) with
`ReduceLROnPlateau` schedulers that monitor a configurable validation metric.
-* **Training orchestration.** `training_step()` alternates discriminator (`optimizer_idx == 0`) and generator (`optimizer_idx ==
- 1`) updates. During the warm-up period configured by `Training.pretrain_g_only`, discriminator weights are frozen via
- `on_train_batch_start()` and a dedicated `pretraining_training_step()` computes purely content-driven updates.
+* **Training orchestration.** `setup_lightning()` binds `training_step_PL2()` and enables manual optimisation
+ (`automatic_optimization = False`). Each step performs explicit discriminator and generator optimiser updates; during the
+ warm-up period configured by `Training.pretrain_g_only`, the generator runs content-driven updates while discriminator metrics
+ are logged without stepping discriminator weights.
* **Validation and logging.** `validation_step()` computes the same content metrics, logs discriminator diagnostics, and pushes
qualitative image panels to Weights & Biases according to `Logging.num_val_images`.
* **Inference pipeline.** `predict_step()` automatically normalises Sentinel-2 style 0–10000 inputs, runs the generator,
@@ -53,7 +54,7 @@ The generator zoo lives under `opensr_srgan/model/generators/` and can be select
* **Stochastic GAN generator (`cgan_generator.py`).** Extends the flexible generator with conditioning inputs and latent noise,
enabling experiments where auxiliary metadata influences the super-resolution output.
* **ESRGAN generator (`esrgan.py`).** Implements the RRDBNet trunk introduced with ESRGAN, exposing `n_blocks`, `growth_channels`,
- and `res_scale` so you can dial in deeper receptive fields and sharper textures. The implementation supports original features like Relativistic Average GAN (RaGAN) and the codebase allows to perform two step training phase (content-oriented pretraining of generator followed by adversarial training with Discriminator) as originally proposed by ESRGAN authors.
+ and `res_scale` so you can dial in deeper receptive fields and sharper textures. The implementation supports original features like Relativistic Average GAN (RaGAN), and the codebase allows a two-step training phase (content-oriented pretraining of the generator followed by adversarial training with the discriminator), as originally proposed by the ESRGAN authors.
* **Advanced variants (`SRGAN_advanced.py`).** Provides additional block implementations and compatibility aliases exposed in
`__init__.py` for backwards compatibility.
diff --git a/docs/configuration.md b/docs/configuration.md
index 71e2485..fc627de 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -75,7 +75,7 @@ If you need to reuse the same function for both directions (for example
| Key | Default | Description |
| --- | --- | --- |
| `in_bands` | 6 | Number of input channels expected by the generator and discriminator. |
-| `continue_training` | `False` | Path to a Lightning checkpoint for resuming training (`ckpt_path` on PL ≥ 2, `resume_from_checkpoint` on PL < 2). |
+| `continue_training` | `False` | Path to a Lightning checkpoint for resuming training (`Trainer.fit(..., ckpt_path=...)`). |
| `load_checkpoint` | `False` | Path to a checkpoint used solely for weight initialisation (no training state restored). |
## Training
@@ -109,12 +109,13 @@ stable validation imagery. The EMA is fully optional and controlled through the
| `adv_loss_beta` | `1e-3` | Target weight applied to the adversarial term after ramp-up. |
| `adv_loss_schedule` | `cosine` | Ramp shape (`linear` or `cosine`). |
| `adv_loss_type` | `bce` | Adversarial objective (`bce` for classic SRGAN logits, `wasserstein` for a non-saturating critic-style loss). |
-| `relativistic_average_d` | `False` | BCE-only switch for relativistic-average GAN training (real/fake logits are compared against each other's batch mean). Supported in both PL1 and PL2 training-step implementations. |
+| `relativistic_average_d` | `False` | BCE-only switch for relativistic-average GAN training (real/fake logits are compared against each other's batch mean). Supported in the Lightning 2+ manual-optimization training-step implementation. |
| `r1_gamma` | `0.0` | Strength of the R1 gradient penalty applied to real images (useful with Wasserstein critics). |
| `l1_weight` | `1.0` | Weight of the pixelwise L1 loss. |
| `sam_weight` | `0.05` | Weight of the spectral angle mapper loss. |
| `perceptual_weight` | `0.1` | Weight of the perceptual feature loss. |
| `perceptual_metric` | `vgg` | Backbone used for perceptual features (`vgg` or `lpips`). |
+| `fixed_idx` | unset | Optional fixed 3-band indices used by perceptual loss when `in_bands > 3` (recommended: `[0, 1, 2]` for RGB+NIR setups). |
| `tv_weight` | `0.0` | Total variation regularisation strength. |
| `max_val` | `1.0` | Peak value assumed by PSNR/SSIM computations. |
| `ssim_win` | `11` | Window size for SSIM metrics. Must be an odd integer. |
@@ -133,6 +134,7 @@ stable validation imagery. The EMA is fully optional and controlled through the
| `growth_channels` | `32` | ESRGAN-only: growth channels inside each RRDB block. |
| `res_scale` | `0.2` | Residual scaling used by stochastic/ESRGAN variants. |
| `out_channels` | `Model.in_bands` | ESRGAN-only: override the number of output bands. |
+| `use_icnr` | `True` | ESRGAN-only: apply ICNR initialization to pre-PixelShuffle convolutions to reduce checkerboard artifacts in low-frequency regions. |
## Discriminator
@@ -200,6 +202,9 @@ The trainer instantiates independent Adam optimisers for the generator and discr
Weight decay exclusions are handled automatically: batch/instance/group-norm layers and bias parameters are filtered into a no-decay group so regularisation only touches convolutional kernels and dense weights. This mirrors best practices for GAN training and keeps normalisation statistics stable.
+For ESRGAN runs on multispectral datasets, a practical stability baseline is:
+`Generator.use_icnr: true`, `Training.Losses.fixed_idx: [0, 1, 2]`, and `optim_d_lr <= 0.5 * optim_g_lr`.
+
## Schedulers
Both optimisers share the same configuration keys because they use `torch.optim.lr_scheduler.ReduceLROnPlateau`.
diff --git a/docs/index.md b/docs/index.md
index 45686aa..0e25b44 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -4,7 +4,7 @@
| **PyPI** | **Versions** | **Docs & License** | **Tests** | **Reference** |
|:---------:|:-------------:|:------------------:|:----------:|:--------------:|
-| [](https://pypi.org/project/opensr-srgan/) | 
 | [](https://srgan.opensr.eu)
 | [](https://github.com/ESAOpenSR/SRGAN/actions/workflows/ci.yml)
[](https://codecov.io/github/ESAOpenSR/SRGAN) | [](https://arxiv.org/abs/2511.10461)
[](https://doi.org/10.5281/zenodo.17590993) |
+| [](https://pypi.org/project/opensr-srgan/) | 
 | [](https://srgan.opensr.eu)
 | [](https://github.com/ESAOpenSR/SRGAN/actions/workflows/ci.yml)
[](https://codecov.io/github/ESAOpenSR/SRGAN) | [](https://arxiv.org/abs/2511.10461)
[](https://doi.org/10.5281/zenodo.17590993) |

@@ -12,7 +12,7 @@ OpenSR-SRGAN is a comprehensive toolkit for training and evaluating super-resolu
that make adversarial optimisation tractable—generator warm-up phases, learning-rate scheduling, adversarial-weight ramping, and more. All options are driven by concise YAML configuration files so you can explore new architectures or datasets without
rewriting pipelines.
-Whether you are reproducing published results, exploring new remote-sensing modalities, or are trying to esablish some benchmarks, OpenSR-SRGAN gives you a clear and extensible foundation for multispectral super-resolution research.
+Whether you are reproducing published results, exploring new remote-sensing modalities, or trying to establish benchmarks, OpenSR-SRGAN gives you a clear and extensible foundation for multispectral super-resolution research.
> This repository and the configs represent the experiences that were made with SR-GAN training for remote sensing imagery. It's neither complete nor claims to perform SOTA SR, but it implements all tweaks and tips that make training SR-GANs easier.
@@ -25,7 +25,7 @@ Whether you are reproducing published results, exploring new remote-sensing moda
* **Battle-tested training loop.** PyTorch Lightning handles mixed precision, gradient accumulation, multi-GPU training, and
restartable checkpoints while the repo layers in GAN-specific tweaks such as adversarial weight ramping and learning-rate
restarts.
-* **Lightning 1.x ↔ 2.x compatibility.** The trainer detects the installed Lightning version at runtime and routes to the appropriate automatic- or manual-optimisation step so your configs run unchanged across releases.
+* **Lightning 2+ training path.** Training uses manual optimization (`automatic_optimization=False`) and a single GAN training-step implementation.
* **Remote-sensing aware defaults.** Normalisation, histogram matching, spectral-band handling, and Sentinel-2 SAFE ingestion are
ready-made for 10 m and 20 m bands and easily extendable to other sensors.
@@ -82,7 +82,7 @@ Whether you are reproducing published results, exploring new remote-sensing moda
* [Results](results.md) showcases results for some generator/discriminator and dataset combinations.
## ESA OpenSR
-OpenSR-SRGAN is part of the ESA [OpenSR](https://www.opensr.eu) ecosystem — an open framework for trustworthy super-resolution of multispectral satellite imagery. Within this initiative, this repository serves as the adversarial benchmark suite: it provides standardized GAN architectures, training procedures, and evaluation utilities that complement the other model types implemented in the project (diffusion, transformers, regression) and interfaces with from companion packages such as opensr-utils.
+OpenSR-SRGAN is part of the ESA [OpenSR](https://www.opensr.eu) ecosystem — an open framework for trustworthy super-resolution of multispectral satellite imagery. Within this initiative, this repository serves as the adversarial benchmark suite: it provides standardized GAN architectures, training procedures, and evaluation utilities that complement the other model types implemented in the project (diffusion, transformers, regression), and it interfaces with companion packages such as opensr-utils.
## Citation
diff --git a/docs/trainer-details.md b/docs/trainer-details.md
index 7e6a5bb..4264eef 100644
--- a/docs/trainer-details.md
+++ b/docs/trainer-details.md
@@ -1,164 +1,70 @@
# Trainer details
-This page walks through the control flow that powers adversarial optimisation in OpenSR-SRGAN. It cross-references the exact helper functions in the codebase so you can trace which checks run on every batch, how pretraining vs. adversarial steps are chosen, and how the PyTorch Lightning integration remains compatible with both 1.x and 2.x releases.
+This page describes the training control flow used by OpenSR-SRGAN on PyTorch Lightning 2+.
-## Version-aware bootstrap
+## Bootstrap sequence
-1. **Detect the installed Lightning release.** `SRGAN_model` stores the parsed semantic version via `self.pl_version = tuple(int(x) for x in pl.__version__.split("."))`. 【F:opensr_srgan/model/SRGAN.py†L66-L67】
-2. **Bind the appropriate training step.** `setup_lightning()` switches between the automatic-optimisation `training_step_PL1` helper (Lightning 1.x) and the manual-optimisation `training_step_PL2` clone (Lightning 2.x) while asserting the required optimisation mode. 【F:opensr_srgan/model/SRGAN.py†L191-L206】
-3. **Assemble Trainer keyword arguments.** `build_lightning_kwargs()` mirrors the version choice when it prepares the `Trainer` arguments: pre-2.0 builds receive `resume_from_checkpoint`, whereas 2.x runs use `ckpt_path`. It also normalises device selection (`Training.device`, `Training.gpus`) and strategy flags so multi-GPU training works consistently. 【F:opensr_srgan/utils/build_trainer_kwargs.py†L10-L122】
-4. **Resume or continue training.** When `Model.continue_training` points to a checkpoint path the trainer will resume in-place, preserving optimiser state, EMA buffers, and step counters. A fresh run keeps the value at `False`. 【F:opensr_srgan/train.py†L36-L63】
+1. **Validate Lightning version.** `SRGAN_model.setup_lightning()` enforces Lightning `>= 2.0`.
+2. **Enable manual optimization.** The model sets `automatic_optimization = False`.
+3. **Bind training-step helper.** `training_step_PL2` is attached as the active training-step implementation.
+4. **Build trainer kwargs.** `build_lightning_kwargs()` normalises accelerator/devices and prepares `fit_kwargs` (including `ckpt_path` when resuming).
-These checks ensure you can retrain a model on Lightning 1.9 or 2.2 with the same configuration file—no manual flag-flipping required.
+## Training-step anatomy
-## Training step anatomy (Lightning 1.x)
-
-The legacy automatic-optimisation path receives `(batch, batch_idx, optimizer_idx)` and splits discriminator and generator logic by the `optimizer_idx` flag:
+`training_step_PL2(self, batch, batch_idx)` always performs manual optimizer control.
```python
-# opensr_srgan/model/training_step_PL.py
+opt_d, opt_g = self.optimizers()
pretrain_phase = self._pretrain_check()
-if optimizer_idx == 1:
- self.log("training/pretrain_phase", float(pretrain_phase), sync_dist=True)
-
-if pretrain_phase:
- if optimizer_idx == 1:
- content_loss, metrics = self.content_loss_criterion.return_loss(sr_imgs, hr_imgs)
- self._log_generator_content_loss(content_loss)
- adv_weight = self._compute_adv_loss_weight()
- self._log_adv_loss_weight(adv_weight)
- return content_loss
- else:
- dummy = torch.zeros((), device=device, dtype=dtype, requires_grad=True)
- self.log("discriminator/adversarial_loss", dummy, sync_dist=True)
- return dummy
+self.log("training/pretrain_phase", float(pretrain_phase), sync_dist=True)
```
-* `_pretrain_check()` compares `self.global_step` against `Training.g_pretrain_steps` to decide whether the generator-only warm-up is active (`g_pretrain_steps: -1` keeps this phase active indefinitely). 【F:opensr_srgan/model/training_step_PL.py†L10-L46】
-* The pretraining branch logs the instantaneous adversarial weight even though it stays unused until GAN training begins. This keeps dashboards continuous when you review historical runs.
-* The discriminator receives a zero-valued tensor with `requires_grad=True` so Lightning's closure executes without mutating weights. Dummy logs (`discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`) remain pinned to zero for clarity.
-
-Once `_pretrain_check()` flips to `False`, the function splits into discriminator and generator updates:
+### Pretraining branch
-* **Discriminator (`optimizer_idx == 0`).** Real and fake logits are compared against smoothed targets, and the resulting BCE components are summed into `discriminator/adversarial_loss`. If `Training.Losses.relativistic_average_d: true` (BCE mode), both terms are computed on relativistic logits (`D(real)-mean(D(fake))`, `D(fake)-mean(D(real))`) and additional relativistic confidence logs are emitted. 【F:opensr_srgan/model/training_step_PL.py†L135-L195】
-* **Generator (`optimizer_idx == 1`).** The generator measures content metrics once, reuses them for logging, queries the adversarial signal, and multiplies it with `_adv_loss_weight()` before combining both parts into `generator/total_loss`. In BCE + relativistic mode, generator adversarial loss is averaged from `BCE(D(fake)-mean(D(real)), 1)` and `BCE(D(real)-mean(D(fake)), 0)`. 【F:opensr_srgan/model/training_step_PL.py†L203-L247】
+When `pretrain_phase` is active:
-With `Training.Losses.adv_loss_type: wasserstein`, the same branches apply but swap the BCE terms for a critic objective: the discriminator minimises `mean(fake) - mean(real)` (plus any configured R1 penalty), and the generator minimises `-mean(D(G(x)))`. Logged probabilities remain sigmoid-squashed critic scores to keep dashboards comparable. Configure `Training.Losses.r1_gamma` to activate the real-image R1 gradient penalty popularised by Mescheder et al. for stabilising Wasserstein critics, and toggle `Discriminator.use_spectral_norm` when you want Miyato et al.'s spectral normalisation to enforce a tighter Lipschitz bound on SRGAN discriminators. 【F:opensr_srgan/model/training_step_PL.py†L129-L247】
+1. Discriminator metrics are logged as zeros (no discriminator optimizer step).
+2. Generator content loss is computed and logged.
+3. The generator optimizer performs `zero_grad -> manual_backward -> step`.
+4. EMA updates after the generator step when enabled and active.
-## Training step anatomy (Lightning 2.x)
-
-Lightning 2.x requires manual optimisation to alternate between optimisers. `training_step_PL2` mirrors the structure of the 1.x helper but drives the two optimisers explicitly:
-
-```python
-# opensr_srgan/model/training_step_PL.py
-opt_d, opt_g = self.optimizers()
-pretrain_phase = self._pretrain_check()
-self.log("training/pretrain_phase", float(pretrain_phase), sync_dist=True)
+### Adversarial branch
-if pretrain_phase:
- zero = torch.tensor(0.0, device=hr_imgs.device, dtype=hr_imgs.dtype)
- self.log("discriminator/adversarial_loss", zero, sync_dist=True)
- content_loss, metrics = self.content_loss_criterion.return_loss(sr_imgs, hr_imgs)
- self._log_generator_content_loss(content_loss)
- self._log_adv_loss_weight(_adv_weight())
- opt_g.zero_grad(); self.manual_backward(content_loss); opt_g.step()
- if self.ema and self.global_step >= self._ema_update_after_step:
- self.ema.update(self.generator)
- return content_loss
-```
+When pretraining is finished:
-The adversarial branch toggles each optimiser in turn, accumulates identical logs to the PL1.x path (including optional relativistic BCE metrics via `Training.Losses.relativistic_average_d`), and performs the EMA update after every generator step. 【F:opensr_srgan/model/training_step_PL.py†L336-L458】
+1. **Discriminator update**
+1. Compute `D(hr)` and `D(sr.detach())`.
+1. Apply BCE or Wasserstein objective (+ optional R1 penalty).
+1. Log discriminator losses/probabilities.
+1. Run `manual_backward` and `opt_d.step()`.
+2. **Generator update**
+1. Compute content loss + metrics.
+1. Compute adversarial generator objective from `D(sr)`.
+1. Apply ramped adversarial weight (`training/adv_loss_weight`).
+1. Log `generator/content_loss`, `generator/adversarial_loss`, `generator/total_loss`.
+1. Run `manual_backward` and `opt_g.step()`.
+3. EMA updates after the generator step when enabled and active.
-## Adversarial weight schedule
+## Resume behavior
-Both training-step variants call `_adv_loss_weight()` (or `_compute_adv_loss_weight()` in older modules) to retrieve the ramped coefficient that blends the adversarial and content terms. The helper logs `training/adv_loss_weight` so you can confirm whether the ramp has reached its configured `Training.Losses.adv_loss_beta`. During pretraining this value stays at zero; afterwards it climbs toward the configured maximum.
+`Model.continue_training` is passed through `build_lightning_kwargs()` and forwarded as:
-## Retraining and checkpoint flow
+```python
+trainer.fit(model, datamodule=pl_datamodule, ckpt_path=resume_ckpt)
+```
-When you relaunch an experiment with `Model.continue_training` set to the saved checkpoint path, Lightning restores optimiser states, EMA buffers, and global step counters before the next batch runs. The same logic works on both Lightning branches because the resume argument is threaded through `build_lightning_kwargs()` according to the detected version. 【F:opensr_srgan/train.py†L36-L90】【F:opensr_srgan/utils/build_trainer_kwargs.py†L16-L122】
+This restores optimizer/scheduler state, EMA state, and global step before continuing.
-## Summary of runtime checks
+## Runtime checks summary
| Check | Source | Purpose |
| --- | --- | --- |
-| PyTorch Lightning version | `SRGAN_model.setup_lightning()` | Select PL1 vs. PL2 training-step implementation and toggle manual optimisation. |
-| Continue training? | `train.py` (`Model.continue_training`) | Resume checkpoints with schedulers/EMA intact. |
-| Pretraining active? | `_pretrain_check()` | Gate between content-only updates and full GAN updates. |
-| Adversarial weight value | `_adv_loss_weight()` / `_compute_adv_loss_weight()` | Log instantaneous GAN weight and blend it into `generator/total_loss`. |
-| EMA ready to update? | `self.global_step >= self._ema_update_after_step` | Delay shadow-weight updates until the warm-up step threshold. |
-
-Keeping these checkpoints visible in the logs and documentation makes it easy to understand what happens when the trainer toggles between warm-up, adversarial learning, and resumed runs.
-
-## Branch map (full text)
-
-The textual workflow in `opensr_srgan/model/training_workflow.txt` mirrors the branches and logging described above. It is reproduced verbatim below so you can scan the entire decision tree without leaving the docs:
-
-```text
-ENTRY: trainer.fit(model, datamodule)
-│
-├─ PRELUDE (opensr_srgan/train.py)
-│ ├─ Load config (OmegaConf) and resolve device list (`Training.gpus`).
-│ ├─ Check `Model.load_checkpoint` / `Model.continue_training` to decide between fresh training vs. retraining from a checkpoint.
-│ ├─ Call `build_lightning_kwargs()` → detects PyTorch Lightning version, normalises accelerator/devices, and routes resume arguments (`resume_from_checkpoint` for PL<2, `ckpt_path` for PL≥2).
-│ └─ Instantiate `Trainer(**trainer_kwargs)` and invoke `trainer.fit(..., **fit_kwargs)`.
-│
-├─ SRGAN_model.setup_lightning()
-│ ├─ Parse `pl.__version__` into `self.pl_version`.
-│ ├─ IF `self.pl_version >= (2,0,0)`
-│ │ ├─ Set `automatic_optimization = False` (manual optimisation required by PL2).
-│ │ └─ Bind `training_step_PL2` as the active `training_step` implementation.
-│ └─ ELSE (PL1.x)
-│ ├─ Ensure `automatic_optimization is True`.
-│ └─ Bind `training_step_PL1` (optimizer_idx-based training).
-│
-└─ ACTIVE TRAINING STEP (batch, batch_idx[, optimizer_idx])
- │
- ├─ 1) Forward + metrics (no grad for logging reuse)
- │ ├─ (lr, hr) = batch
- │ ├─ sr = G(lr)
- │ └─ metrics = content_loss.return_metrics(sr, hr)
- │ └─ LOG: `train_metrics/*` (L1, SAM, perceptual, TV, PSNR, SSIM)
- │
- ├─ 2) Phase checks
- │ ├─ `pretrain = _pretrain_check()` # compare global_step vs. `Training.g_pretrain_steps`
- │ ├─ LOG: `training/pretrain_phase` (on G step for PL1, per-batch for PL2)
- │ └─ `adv_weight = _adv_loss_weight()` or `_compute_adv_loss_weight()` # ramp toward `Training.Losses.adv_loss_beta`
- │ └─ LOG: `training/adv_loss_weight`
- │
- ├─ 3) IF `pretrain` True (Generator warm-up)
- │ ├─ Generator path
- │ │ ├─ Compute `(content_loss, metrics) = content_loss.return_loss(sr, hr)`
- │ │ ├─ LOG: `generator/content_loss`
- │ │ ├─ Reuse metrics for logging (`train_metrics/*`)
- │ │ ├─ LOG: `training/adv_loss_weight` (even though weight is 0 during warm-up)
- │ │ └─ RETURN/STEP on `content_loss` only (PL1 returns scalar; PL2 manual_backward + `opt_g.step()`)
- │ └─ Discriminator path
- │ ├─ LOG zeros for `discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`, `discriminator/adversarial_loss`
- │ └─ Return dummy zero tensor with `requires_grad=True` (PL1) or skip optimisation but keep logs (PL2)
- │
- └─ 4) ELSE `pretrain` False (Full GAN training)
- │
- ├─ 4A) Discriminator update
- │ ├─ hr_logits = D(hr)
- │ ├─ sr_logits = D(sr.detach())
- │ ├─ real_target = adv_target (0.9 with label smoothing else 1.0)
- │ ├─ fake_target = 0.0
- │ ├─ loss_real = BCEWithLogits(hr_logits, real_target)
- │ ├─ loss_fake = BCEWithLogits(sr_logits, fake_target)
- │ ├─ d_loss = loss_real + loss_fake
- │ ├─ LOG: `discriminator/adversarial_loss`
- │ ├─ LOG: `discriminator/D(y)_prob` = sigmoid(hr_logits).mean()
- │ ├─ LOG: `discriminator/D(G(x))_prob` = sigmoid(sr_logits).mean()
- │ └─ Optimise D (return `d_loss` in PL1; manual backward + `opt_d.step()` in PL2)
- │
- └─ 4B) Generator update
- ├─ (content_loss, metrics) = content_loss.return_loss(sr, hr)
- ├─ LOG: `generator/content_loss`
- ├─ sr_logits = D(sr)
- ├─ g_adv = BCEWithLogits(sr_logits, target=1.0)
- ├─ LOG: `generator/adversarial_loss` = g_adv
- ├─ total_loss = content_loss + adv_weight * g_adv
- ├─ LOG: `generator/total_loss`
- ├─ Optimise G (return `total_loss` in PL1; manual backward + `opt_g.step()` in PL2)
- └─ IF EMA enabled AND `global_step >= _ema_update_after_step`: update shadow weights (`EMA/update_after_step`, `EMA/is_active` logs)
-```
+| Lightning `>= 2.0` | `SRGAN_model.setup_lightning()` | Reject unsupported runtime versions. |
+| Manual optimization enabled | `setup_lightning()` | Ensure GAN optimizer alternation is explicit. |
+| Pretraining active? | `_pretrain_check()` | Gate between content-only and adversarial training. |
+| Adversarial weight | `_adv_loss_weight()` | Log and apply the current GAN loss multiplier. |
+| EMA active? | `self.global_step >= self._ema_update_after_step` | Delay EMA updates until configured step. |
+
+## Workflow map
+
+See `opensr_srgan/model/training_workflow.txt` for the full text branch map aligned to the current implementation.
diff --git a/docs/training-guideline.md b/docs/training-guideline.md
index b77aab4..3d59a08 100644
--- a/docs/training-guideline.md
+++ b/docs/training-guideline.md
@@ -4,7 +4,7 @@ This section goes over the most important metrics and settings to achieve a bala
## Best Practices
-It is recommended to use the training warmups and schedulers as explained above. The following images present how these rpactices are reflected in the logs.
+It is recommended to use the training warmups and schedulers as explained above. The following images present how these practices are reflected in the logs.
### Objectives and loss composition
@@ -16,11 +16,11 @@ Each coefficient maps directly to the `Training.Losses` block in the configurati
### Exponential Moving Average (EMA)
-For smoother validation curves and more stable inference, the trainer can maintain an exponential moving average of the generator parameters. After each optimisation step, the EMA weights $\theta_{\text{EMA}}$ are updated toward the current generator state $\theta$:
+For smoother validation curves and more stable inference, the trainer can maintain an exponential moving average of the generator parameters. After each optimisation step, the EMA weights \(\theta_{\text{EMA}}\) are updated toward the current generator state \(\theta\):
$$
\theta_{\text{EMA}}^{(t)} = \beta \, \theta_{\text{EMA}}^{(t-1)} + (1 - \beta)\, \theta^{(t)},
$$
-where the decay $\beta \in [0,1)$ controls how much history is retained. During validation and inference, the EMA snapshot replaces the live weights so that predictions are less sensitive to short-term oscillations. The final super-resolved output therefore comes from the smoothed generator,
+where the decay \(\beta \in [0,1)\) controls how much history is retained. During validation and inference, the EMA snapshot replaces the live weights so that predictions are less sensitive to short-term oscillations. The final super-resolved output therefore comes from the smoothed generator,
$$
\hat{y}_{\text{SR}} = G(x; \theta_{\text{EMA}}),
$$
@@ -28,11 +28,13 @@ which empirically reduces adversarial artefacts and improves perceptual consiste
#### Generator LR Warmup
-When starting to train, the learning rate slowly raises from 0 to the indicated value. This prevents exploding gradients after a random initialization of the weights when training the model from scratch. The length of the LR warmup is defined with the `Schedulers.g_warmup_steps` parameter in the config. Wether the increase is linear or more smooth is defined with the `Schedulers.g_warmup_type` setting, ideally this should be set to `cosine`.
+When starting to train, the learning rate slowly rises from 0 to the indicated value. This prevents exploding gradients after random initialization of the weights when training the model from scratch. The length of the LR warmup is defined with the `Schedulers.g_warmup_steps` parameter in the config. Whether the increase is linear or smoother is defined with the `Schedulers.g_warmup_type` setting; ideally this should be set to `cosine`.

#### Generator Pre-training
-After the loss stabilizes, the generator continues to be trained while the discriminator sits idle. This prevents the discriminator form overpowering the generator in early stages of the training, where the generator output is easily identifyable as synthetic. The binary flag `training/pretrain_phase` is logged to indicate wether the model is still in pretraining or not. Wether the pretraining is enabled or not is defined with the `Training.pretrain_g_only` parameter in the config, the parameter `Training.g_pretrain_steps` defines how many steps this pretraining takes in total. The parameter `Training.g_warmup_steps` decides how many training steps (batches) this smooth LR increase takes, setting it to `0` turns it off.
+After the loss stabilizes, the generator continues to be trained while the discriminator sits idle. This prevents the discriminator from overpowering the generator in early training stages, where the generator output is still easily identifiable as synthetic. The binary flag `training/pretrain_phase` is logged to indicate whether the model is still in pretraining. Whether pretraining is enabled is defined with the `Training.pretrain_g_only` parameter in the config; `Training.g_pretrain_steps` defines how many steps this pretraining takes in total. The parameter `Training.g_warmup_steps` defines how many training steps (batches) the smooth LR increase lasts; setting it to `0` turns it off.
+
+During this generator-only pretraining window, the optimization target is hardwired to plain L1 loss only. Once pretraining ends, the normal configured content-loss mix (L1/SAM/perceptual/TV) is used again.

#### Discriminator
@@ -40,8 +42,8 @@ Once the `training/pretrain_phase` flag is `0`, pretraining of the generator is

#### Continued Training
-As training continues, the generator is trying to fool the discriminator and the discriminator is trying to distinguish between true/synthetic, we monitor the overall loss of the models independantly. When the overall loss metric of one model reaches a plateau, we reduce it's learning rate in order to optimnally train the model.
-. The patience, LR decrease factor inc ase of plateau and the metric to be used for these LR schedulers are all defined individually for $G$ and $D$ in the `Schedulers.` section of the config file.
+As training continues, the generator tries to fool the discriminator, while the discriminator tries to distinguish between real and synthetic samples. We monitor the overall loss of both models independently. When the overall loss metric of one model reaches a plateau, we reduce its learning rate to train the model optimally.
+. The patience, LR decrease factor in case of a plateau, and the metric used for these LR schedulers are all defined individually for \(G\) and \(D\) in the `Schedulers` section of the config file.
The schedulers now expose a `cooldown` period and `min_lr` floor. Cooldown waits a configurable number of epochs before watching for the next plateau, preventing back-to-back reductions, while `min_lr` guarantees that the optimiser never stalls at zero. Use these knobs to keep the momentum of long trainings without overshooting into vanishing updates.
@@ -49,8 +51,15 @@ The schedulers now expose a `cooldown` period and `min_lr` floor. Cooldown waits
Both optimisers use a two-time-scale update rule (TTUR) so the discriminator defaults to a slower learning rate than the generator. The bundled Adam configuration mirrors popular GAN recipes with betas set to `(0.0, 0.99)` and `eps=1e-7`, ensuring the generator reacts quickly to discriminator feedback without building up stale momentum. Weight decay is automatically restricted to convolutional and dense kernels—normalisation layers and biases are excluded—so regularisation never interferes with running statistics. Finally, `gradient_clip_val` applies global norm clipping when set above zero; values between `0.5` and `1.0` work well when discriminator spikes cause unstable updates.
+#### ESRGAN checkerboard mitigation (10m defaults)
+
+If you observe faint checkerboard textures, especially in flat/low-frequency areas, start with:
+- `Generator.use_icnr: True` to initialise PixelShuffle pre-convolutions with ICNR.
+- `Optimizers.optim_d_lr <= 0.5 * optim_g_lr` to keep discriminator pressure in check.
+- `Training.Losses.fixed_idx: [0, 1, 2]` for 4-band inputs so VGG perceptual loss uses RGB consistently.
+
#### Final stages of the Training
-With further progression of the training, it is important not only to monitor the absolute reconstruction quality of the generator, but also to keep an eye on the balance between the generator and discriminator. Ideally, we try to reach the Nash equilibrium, where the discriminator can not distinguish between real and synthetic anymore, meaning the super-resolution is (at least fdor the discriminator) indistinguishable from the real high-resolution image. This equilibrium is achieved when both $D(y)$ and $D(G(x))$ approach `0.5`.
+With further progression of training, it is important not only to monitor the absolute reconstruction quality of the generator, but also to keep an eye on the balance between the generator and discriminator. Ideally, we try to reach the Nash equilibrium, where the discriminator cannot distinguish between real and synthetic anymore, meaning the super-resolution is (at least for the discriminator) indistinguishable from the real high-resolution image. This equilibrium is achieved when both \(D(y)\) and \(D(G(x))\) approach `0.5`.


diff --git a/docs/training.md b/docs/training.md
index 8f7bce6..73e27ff 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -2,10 +2,10 @@
`opensr_srgan/train.py` is the canonical entry point for ESA OpenSR experiments. It ties together configuration loading, model instantiation, dataset selection, logging, and callbacks. This page explains how the script is organised and how to customise the training loop.
-!!! note "PyTorch Lightning 1.x and 2.x compatible"
- The training stack now adapts automatically to the installed PyTorch Lightning release. `SRGAN_model.setup_lightning()` inspects `pytorch_lightning.__version__`, binds the legacy automatic-optimisation `training_step_PL1()` when running on 1.x, and switches to the manual-optimisation `training_step_PL2()` helper on 2.x where GAN training requires `automatic_optimization = False`. `opensr_srgan.utils.build_trainer_kwargs.build_lightning_kwargs()` mirrors this by emitting the correct `Trainer` arguments—`resume_from_checkpoint` for 1.x, `ckpt_path` for 2.x—so both Lightning branches resume, log, and step optimisers identically. See [Trainer Details](trainer-details.md) for a step-by-step breakdown of the warm-up checks, adversarial updates, and EMA lifecycle.
+!!! note "PyTorch Lightning 2+ only"
+ The training stack uses a single manual-optimisation path. `SRGAN_model.setup_lightning()` enforces Lightning >= 2.0 and binds `training_step_PL2()` where GAN training runs with `automatic_optimization = False`. `opensr_srgan.utils.build_trainer_kwargs.build_lightning_kwargs()` forwards resume checkpoints through `Trainer.fit(..., ckpt_path=...)`. See [Trainer Details](trainer-details.md) for a step-by-step breakdown of warm-up checks, adversarial updates, and EMA lifecycle.
-This section is a more technical overview, [Training Guideline](training-guideline.md) gives a more broad overview how to sirveill the training process.
+This section is a more technical overview; [Training Guideline](training-guideline.md) provides a broader overview of how to monitor the training process.
## Data module construction
In order to train, you need a dataset. `Data.dataset_type` decides which dataset to use and wraps them in a `LightningDataModule`. Should you implement your own, you will need to add it to the dataset_selector.py file with the settings of your choice (see [Data](data.md)). Optionally, the selector instantiates `ExampleDataset` by default—perfect for smoke tests after downloading the sample data, a dataset of 200 RGB-NIR image pairs. The module inherits batch sizes, worker counts, and prefetching parameters from the configuration and prints a summary including dataset size.
@@ -33,7 +33,7 @@ Both entry points accept the same configuration file. The CLI exposes a single o
GPU assignment is handled directly in the configuration. Set `Training.gpus` to a list of device indices (for example `[0, 1, 2, 3]`) to enable multi-GPU training; a single value such as `[0]` keeps the run on one card. When more than one device is listed the trainer automatically activates PyTorch Lightning's Distributed Data Parallel (DDP) backend for significantly faster epochs.
## Initialisation steps - Overview
-The code performs the following, no matter if the script is launched form the CLI or through the import.
+The code performs the following, regardless of whether the script is launched from the CLI or via import.
1. **Import dependencies.** Torch, PyTorch Lightning, OmegaConf, and logging backends are loaded up-front.
2. **Parse arguments.** `argparse` reads the configuration path and ensures the file exists.
3. **Load configuration.** `OmegaConf.load()` parses the YAML file into an object used throughout the run.
@@ -66,13 +66,13 @@ dashboard quickly reveals which subsystem is active at any given step.
| `discriminator/adversarial_loss` | Binary cross-entropy loss of the discriminator on real vs. fake batches. | Drops below ~0.7 as the discriminator learns; continues trending down when D keeps up. |
| `discriminator/D(y)_prob` | Mean discriminator confidence that HR inputs are real. | Rises toward 0.8–1.0 during stable training. |
| `discriminator/D(G(x))_prob` | Mean discriminator confidence that SR predictions are real. | Starts low (~0.0–0.2) and climbs toward 0.5 as the generator improves. |
-| `train_metrics/l1` | Mean absolute error between SR and HR tensors. | Decreases toward 0 as reconstructions sharpen. |
+| `train_metrics/l1` | Mean absolute error between SR and HR tensors. In generator-only pretraining this is the hardwired optimization target. | Decreases toward 0 as reconstructions sharpen. |
| `train_metrics/sam` | Spectral angle mapper (radians) averaged over pixels. | Falls toward 0; values <0.1 indicate strong spectral fidelity. |
| `train_metrics/perceptual` | Perceptual distance (VGG or LPIPS) on selected RGB bands. | Decreases as textures align; exact range depends on the chosen metric. |
| `train_metrics/tv` | Total variation penalty capturing SR smoothness. | Remains small; near-zero means little high-frequency noise. |
| `train_metrics/psnr` | Peak signal-to-noise ratio (dB) on normalised tensors. | Climbs above 20 dB early; mature models reach 25–35 dB depending on data. |
| `train_metrics/ssim` | Structural Similarity Index (0–1). | Increases toward 1.0; >0.8 is typical for converged runs. |
-| `generator/content_loss` | Weighted content portion of the generator objective. | Mirrors the trend of `train_metrics/*` losses and should steadily decline. |
+| `generator/content_loss` | Generator objective used for the current phase: hardwired L1 during generator-only pretraining, then weighted content loss afterwards. | Should steadily decline and remain stable when adversarial training starts. |
| `generator/total_loss` | Sum of content and adversarial terms used to update the generator. | Tracks `generator/content_loss` early, then stabilises once adversarial weight ramps in. |
| `val_metrics/l1` | Validation MAE. | Should roughly match `train_metrics/l1`; lower is better. |
| `val_metrics/sam` | Validation SAM. | Mirrors the training trend; values <0.1 rad indicate good spectra. |
diff --git a/opensr_srgan/__init__.py b/opensr_srgan/__init__.py
index 667ad05..34975d6 100644
--- a/opensr_srgan/__init__.py
+++ b/opensr_srgan/__init__.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING, Any
try: # pragma: no cover - import shim for Python <3.8
@@ -22,6 +23,19 @@
except PackageNotFoundError: # pragma: no cover - local source tree fallback
__version__ = "0.0.0"
+
+def _silence_known_lightning_deprecations() -> None:
+ """Suppress noisy third-party deprecations we cannot patch locally."""
+ warnings.filterwarnings(
+ "ignore",
+ message=r".*isinstance\(treespec,\s*LeafSpec\)\s*is deprecated.*",
+ category=DeprecationWarning,
+ module=r"pytorch_lightning\.utilities\._pytree",
+ )
+
+
+_silence_known_lightning_deprecations()
+
if TYPE_CHECKING: # pragma: no cover - type checkers only
from .model.SRGAN import SRGAN_model as SRGANModel
from .train import train as _train
diff --git a/opensr_srgan/configs/config_10m.yaml b/opensr_srgan/configs/config_10m.yaml
index 0d32d06..ceecefd 100644
--- a/opensr_srgan/configs/config_10m.yaml
+++ b/opensr_srgan/configs/config_10m.yaml
@@ -17,7 +17,7 @@ Data:
prefetch_factor: 2 # Samples prefetched per worker (2 is stable default)
# Dataset configuration
- dataset_type: 'SISR_WW' # Choose dataset type: ['cv', 'SPOT6', 'S2_6b', 'SISR_WW']
+ dataset_type: 'SEN2NAIP' # Choose dataset type: ['cv', 'SPOT6', 'S2_6b', 'SISR_WW']
normalization: 'normalise_10k' # Normalization strategy for data processing
@@ -36,16 +36,16 @@ Model:
Training:
# --- Hardware Setup
device: "cuda" # Runtime device backend: ['cuda', 'cpu']
- gpus: [2,3] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
+ gpus: [0,1] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
# --- General Training Setup
max_epochs: 9999 # Maximum number of training epochs
- val_check_interval: 0.25 # Validate at x percent of epoch (float) or every N steps (int)
- limit_val_batches: 250 # Limit number of validation batches
+ val_check_interval: 1. # Validate at x percent of epoch (float) or every N Epochs (int)
+ limit_val_batches: 100 # Limit number of validation batches
# --- Pretraining and adversarial setup ---
pretrain_g_only: True # Train generator only for initial phase
- g_pretrain_steps: 20000 # Number of generator-only warmup steps (-1 keeps pretraining active)
- adv_loss_ramp_steps: 5000 # Gradual adversarial weight ramp steps
+ g_pretrain_steps: 5000 # Number of generator-only warmup steps (-1 keeps pretraining active)
+ adv_loss_ramp_steps: 500 # Gradual adversarial weight ramp steps
label_smoothing: True # Discriminator target smoothing (1.0 → 0.9)
EMA:
@@ -64,9 +64,10 @@ Training:
# --- Content loss components (GeneratorContentLoss) ---
l1_weight: 1.0 # L1 loss over all bands
- sam_weight: 0.05 # Spectral Angle Mapper loss
- perceptual_weight: 0.2 # Perceptual similarity term weight
- perceptual_metric: 'vgg' # ['vgg', 'lpips'] - LPIPS requires pip install lpips
+ sam_weight: 0.00 # Spectral Angle Mapper loss
+ perceptual_weight: 0.05 # Perceptual similarity term weight (keep moderate for 4-band data)
+ perceptual_metric: 'vgg' # ['vgg', 'lpips'] - LPIPS requires !pip install lpips
+ fixed_idx: [0, 1, 2] # Use RGB channels for perceptual loss on multispectral data
tv_weight: 0.0 # Total Variation regularization (optional)
# --- Metric evaluation settings ---
@@ -79,7 +80,7 @@ Training:
# ---------------------------------------------------------------------------- #
# See Docs for archtecture details and suggestions
Generator:
- model_type: 'SRResNet' # Generator family: ['SRResNet', 'stochastic_gan', 'esrgan']
+ model_type: 'esrgan' # Generator family: ['SRResNet', 'stochastic_gan', 'esrgan']
block_type: 'rrdb' # SRResNet block variant: ['standard', 'res', 'rcab', 'rrdb', 'lka']
large_kernel_size: 9 # Kernel for head and tail conv layers (SRResNet/stochastic)
small_kernel_size: 3 # Kernel for intermediate blocks (SRResNet/stochastic)
@@ -88,9 +89,10 @@ Generator:
scaling_factor: 4 # Upscaling factor (e.g., 2×, 4×, 8×)
growth_channels: 32 # ESRGAN-specific RRDB growth channels (ignored otherwise)
res_scale: 0.2 # Residual scaling used by stochastic/ESRGAN variants
+ use_icnr: True # Initialize PixelShuffle convs with ICNR to reduce checkerboard artifacts
Discriminator:
- model_type: 'standard' # Discriminator architecture selector ['standard', 'patchgan', 'esrgan']
+ model_type: 'esrgan' # Discriminator architecture selector ['standard', 'patchgan', 'esrgan']
n_blocks: 8 # Convolutional depth for SRGAN/PatchGAN (ignored by ESRGAN)
use_spectral_norm: False # Apply spectral normalization to SRGAN discriminator layers for stability
base_channels: 64 # ESRGAN discriminator base feature width (ignored otherwise)
@@ -100,9 +102,9 @@ Discriminator:
# 🧮 OPTIMIZATION SETTINGS
# ---------------------------------------------------------------------------- #
Optimizers:
- optim_g_lr: 1e-4 # Learning rate for Generator
- optim_d_lr: 1e-6 # Learning rate for Discriminator
- gradient_clip_val: 1.0 # Gradient clipping value (0 disables clipping)
+ optim_g_lr: 1e-5 # Learning rate for Generator
+ optim_d_lr: 5e-5 # Learning rate for Discriminator (TTUR: slower than generator)
+ gradient_clip_val: 0 # Gradient clipping value (0 disables clipping)
betas: [0.0, 0.99] # optional
eps: 1.0e-7 # optional
weight_decay_g: 0.0 # optional
@@ -112,16 +114,14 @@ Optimizers:
# 📉 SCHEDULERS AND EARLY STOPPING
# ---------------------------------------------------------------------------- #
Schedulers:
- g_warmup_steps: 1000 # Generator warmup LR curve duration in steps (0 disables warmup)
- g_warmup_type: 'cosine' # Generator warmup curve: ['cosine', 'linear']
+ g_warmup_steps: 0 # Generator warmup LR curve duration in steps (0 disables warmup)
+ g_warmup_type: 'linear' # Generator warmup curve: ['cosine', 'linear']
metric_g: 'val_metrics/l1' # Metric monitored for Generator LR scheduler
metric_d: 'discriminator/adversarial_loss' # Metric monitored for Discriminator LR scheduler
patience_g: 10 # Patience (epochs) for Generator LR scheduler
patience_d: 10 # Patience (epochs) for Discriminator LR scheduler
factor_g: 0.5 # LR reduction factor for Generator
factor_d: 0.5 # LR reduction factor for Discriminator
- verbose: True # Enable scheduler logging output
-
# ============================================================================ #
# 🧾 LOGGING SETTINGS
diff --git a/opensr_srgan/configs/config_20m.yaml b/opensr_srgan/configs/config_20m.yaml
index d15b884..897acb2 100644
--- a/opensr_srgan/configs/config_20m.yaml
+++ b/opensr_srgan/configs/config_20m.yaml
@@ -121,8 +121,6 @@ Schedulers:
patience_d: 50 # Patience (epochs) for Discriminator LR scheduler
factor_g: 0.5 # LR reduction factor for Generator
factor_d: 0.5 # LR reduction factor for Discriminator
- verbose: True # Enable scheduler logging output
-
# ============================================================================ #
# 🧾 LOGGING SETTINGS
diff --git a/opensr_srgan/configs/config_playgound.yaml b/opensr_srgan/configs/config_playgound.yaml
index d6d9abd..a6f8df0 100644
--- a/opensr_srgan/configs/config_playgound.yaml
+++ b/opensr_srgan/configs/config_playgound.yaml
@@ -121,8 +121,6 @@ Schedulers:
patience_d: 10 # Patience (epochs) for Discriminator LR scheduler
factor_g: 0.5 # LR reduction factor for Generator
factor_d: 0.5 # LR reduction factor for Discriminator
- verbose: True # Enable scheduler logging output
-
# ============================================================================ #
# 🧾 LOGGING SETTINGS
diff --git a/opensr_srgan/configs/config_training_example.yaml b/opensr_srgan/configs/config_training_example.yaml
index ed41358..9eb704c 100644
--- a/opensr_srgan/configs/config_training_example.yaml
+++ b/opensr_srgan/configs/config_training_example.yaml
@@ -120,8 +120,6 @@ Schedulers:
patience_d: 10 # Patience (epochs) for Discriminator LR scheduler
factor_g: 0.5 # LR reduction factor for Generator
factor_d: 0.5 # LR reduction factor for Discriminator
- verbose: True # Enable scheduler logging output
-
# ============================================================================ #
# 🧾 LOGGING SETTINGS
diff --git a/opensr_srgan/data/dataset_selector.py b/opensr_srgan/data/dataset_selector.py
index 65213b2..45cc346 100644
--- a/opensr_srgan/data/dataset_selector.py
+++ b/opensr_srgan/data/dataset_selector.py
@@ -2,6 +2,7 @@
LRHR_FOLDER_DATASET_ROOT = "data/"
+SEN2NAIP_TACO_FILE = "/data1/datasets/SEN2NAIP/sen2naipv2-crosssensor.taco"
def select_dataset(config):
@@ -41,6 +42,12 @@ def select_dataset(config):
ds_train = ExampleDataset(folder=path, phase="train")
ds_val = ExampleDataset(folder=path, phase="val")
+ elif str(dataset_selection).lower() == "sen2naip":
+ from opensr_srgan.data.sen2naip.sen2naip_dataset import SEN2NAIP
+
+ ds_train = SEN2NAIP(config=config, phase="train", taco_file=SEN2NAIP_TACO_FILE)
+ ds_val = SEN2NAIP(config=config, phase="val", taco_file=SEN2NAIP_TACO_FILE)
+
elif dataset_selection == "LRHRFolderDataset":
from opensr_srgan.data.lrhr_folder.lrhr_folder_dataset import LRHRFolderDataset
diff --git a/opensr_srgan/data/sen2naip/download_sen2naip.sh b/opensr_srgan/data/sen2naip/download_sen2naip.sh
new file mode 100644
index 0000000..445c4ab
--- /dev/null
+++ b/opensr_srgan/data/sen2naip/download_sen2naip.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# =========================
+# Configuration
+# =========================
+VENV_PATH="/work/envs/srgan"
+DATA_ROOT="/data1/datasets"
+DATASET_NAME="SEN2NAIP"
+FILENAME="sen2naipv2-crosssensor.taco"
+URL="https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2/resolve/main/${FILENAME}"
+
+TARGET_DIR="${DATA_ROOT}/${DATASET_NAME}"
+
+# =========================
+# Activate environment
+# =========================
+source "${VENV_PATH}/bin/activate"
+
+# =========================
+# Prepare directory
+# =========================
+mkdir -p "${TARGET_DIR}"
+echo "Downloading to: ${TARGET_DIR}"
+sleep 2
+
+# =========================
+# Download
+# =========================
+aria2c \
+ --max-connection-per-server=8 \
+ --split=8 \
+ --dir="${TARGET_DIR}" \
+ --out="${FILENAME}" \
+ "${URL}"
+
+echo "Download finished: ${TARGET_DIR}/${FILENAME}"
diff --git a/opensr_srgan/data/sen2naip/sen2naip_dataset.py b/opensr_srgan/data/sen2naip/sen2naip_dataset.py
new file mode 100644
index 0000000..3a90cb5
--- /dev/null
+++ b/opensr_srgan/data/sen2naip/sen2naip_dataset.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+import importlib
+import sys
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+
+from opensr_srgan.data.utils.normalizer import Normalizer
+
+
+class SEN2NAIP(torch.utils.data.Dataset):
+ """SEN2NAIP cross-sensor dataset loader backed by a Taco manifest."""
+
+ def __init__(self, config: Any, phase=None, taco_file=None):
+ if config is None:
+ raise ValueError("SEN2NAIP requires a config object.")
+
+ if isinstance(config, (str, Path)):
+ from omegaconf import OmegaConf
+
+ config = OmegaConf.load(config)
+
+ data_cfg = getattr(config, "Data", None)
+ if data_cfg is None:
+ raise ValueError("SEN2NAIP requires config.Data.")
+
+ if taco_file is None:
+ taco_file = getattr(data_cfg, "sen2naip_taco_file", None)
+ if phase is None:
+ phase = getattr(data_cfg, "sen2naip_phase", "train")
+ val_fraction = getattr(data_cfg, "sen2naip_val_fraction", 0.1)
+
+ if not taco_file:
+ raise ValueError("SEN2NAIP requires config.Data.sen2naip_taco_file.")
+ if phase not in {"train", "val"}:
+ raise ValueError(f"Unknown phase '{phase}'. Expected one of: train, val.")
+ if not (0.0 < val_fraction < 1.0):
+ raise ValueError(
+ f"val_fraction must be in (0, 1), received {val_fraction}."
+ )
+
+ try:
+ import tacoreader
+ except ImportError as exc:
+ raise ImportError(
+ "SEN2NAIP requires 'tacoreader'. Install it via "
+ "'pip install tacoreader==0.6.5'."
+ ) from exc
+
+ self.dataset = tacoreader.load(taco_file)
+ self.normalizer = Normalizer(config)
+
+ total = len(self.dataset)
+ if total < 2:
+ raise ValueError(
+ "SEN2NAIP dataset requires at least 2 samples to create train/val splits."
+ )
+ split_idx = int(round(total * (1.0 - val_fraction)))
+ split_idx = min(max(split_idx, 1), total - 1)
+
+ if phase == "train":
+ self.indices = list(range(0, split_idx))
+ else:
+ self.indices = list(range(split_idx, total))
+
+ def __len__(self):
+ return len(self.indices)
+
+ @staticmethod
+ def _to_tensor(data: np.ndarray) -> torch.Tensor:
+ tensor = torch.from_numpy(data).float()
+ if tensor.ndim == 2:
+ tensor = tensor.unsqueeze(0)
+ return tensor
+
+ def __getitem__(self, idx):
+ rio = sys.modules.get("rasterio")
+ if rio is None:
+ rio = importlib.import_module("rasterio")
+
+ sample = self.dataset.read(self.indices[idx])
+ lr_path = sample.read(0)
+ hr_path = sample.read(1)
+
+ with rio.open(lr_path) as src, rio.open(hr_path) as dst:
+ lr_data = src.read()
+ hr_data = dst.read()
+
+ lr = self._to_tensor(lr_data)
+ hr = self._to_tensor(hr_data)
+
+ lr = self.normalizer.normalize(lr)
+ hr = self.normalizer.normalize(hr)
+
+ return lr, hr
+
+
+if __name__ == "__main__":
+ ds = SEN2NAIP(
+ config="opensr_srgan/configs/config_10m.yaml",
+ phase="train",
+ taco_file="/data1/datasets/SEN2NAIP/sen2naipv2-crosssensor.taco",
+ )
diff --git a/opensr_srgan/model/SRGAN.py b/opensr_srgan/model/SRGAN.py
index 8b89306..ec44ed8 100644
--- a/opensr_srgan/model/SRGAN.py
+++ b/opensr_srgan/model/SRGAN.py
@@ -4,7 +4,6 @@
from contextlib import nullcontext
from pathlib import Path
from types import MethodType
-from typing import Optional
import numpy as np
import pytorch_lightning as pl
@@ -40,7 +39,7 @@ class SRGAN_model(pl.LightningModule):
------------
- **Backbone flexibility:** Select generator/discriminator architectures via config.
- **Training modes:** Generator pretraining, adversarial training, and LR warm-up.
- - **PL compatibility:** Automatic optimization for PL < 2.0; manual optimization for PL ≥ 2.0.
+ - **Lightning 2+ training:** Manual optimization for multi-optimizer GAN updates.
- **EMA support:** Optional EMA tracking with delayed activation and device placement.
- **Metrics & logging:** Content/perceptual metrics, LR logging, and optional W&B image logs.
- **Inference helpers:** Normalization/denormalization for 0–10000 reflectance and histogram matching.
@@ -77,10 +76,10 @@ class SRGAN_model(pl.LightningModule):
Behavior & versioning
---------------------
- - **PL ≥ 2.0**: Manual optimization (`automatic_optimization = False`). The bound
- `training_step_PL2` performs explicit `zero_grad/step` calls and handles EMA updates.
- - **PL < 2.0**: Automatic optimization. The legacy `training_step_PL1` is used, and
- `optimizer_step` coordinates stepping and EMA after generator updates.
+ - **PyTorch Lightning ≥ 2.0** is required.
+ - Manual optimization (`automatic_optimization = False`) is used for GAN updates.
+ - The bound `training_step_PL2` helper performs explicit `zero_grad/step` calls and
+ handles EMA updates.
Created attributes (non-exhaustive)
-----------------------------------
@@ -143,11 +142,10 @@ def __init__(self, config="config.yaml", mode="train"):
# ======================================================================
# SECTION: Set Variables
- # Purpose: Set config and mode variables model-wide, including PL version.
+ # Purpose: Set config and mode variables model-wide.
# ======================================================================
self.config = config
self.mode = mode
- self.pl_version = tuple(int(x) for x in pl.__version__.split("."))
self.normalizer = Normalizer(self.config)
# ======================================================================
@@ -182,7 +180,7 @@ def __init__(self, config="config.yaml", mode="train"):
# ======================================================================
# SECTION: Set up Training Strategy
- # Purpose: Depending on PL version, set up optimizers, schedulers, etc.
+ # Purpose: Configure Lightning hooks and optimization mode.
# ======================================================================
self.setup_lightning() # dynamically builds and attaches generator + discriminator
@@ -311,48 +309,25 @@ def get_models(self, mode):
)
def setup_lightning(self):
- """Configure PyTorch Lightning behavior based on the detected version.
-
- This method ensures compatibility between different versions of
- PyTorch Lightning (PL) by setting appropriate optimization modes
- and binding the correct training step implementation.
-
- - For PL ≥ 2.0: Enables **manual optimization**, required for GAN training.
- - For PL < 2.0: Uses **automatic optimization** and the legacy training step.
-
- The selected training step function (`training_step_PL1` or `training_step_PL2`)
- is dynamically attached to the model as `_training_step_implementation`.
+ """Configure the Lightning 2+ training hooks for manual optimization.
Raises:
- AssertionError: If `automatic_optimization` is incorrectly set for PL < 2.0.
- RuntimeError: If the detected PyTorch Lightning version is unsupported.
-
- Attributes:
- automatic_optimization (bool): Indicates whether Lightning manages
- optimizer steps automatically.
- _training_step_implementation (Callable): Bound training step function
- corresponding to the active PL version.
+ RuntimeError: If the installed PyTorch Lightning major version is < 2.
"""
- # Check for PL version - Define PL Hooks accordingly
- if self.pl_version >= (2, 0, 0):
- self.automatic_optimization = False # manual optimization for PL 2.x
- # Set up Training Step
- from opensr_srgan.model.training_step_PL import training_step_PL2
-
- self._training_step_implementation = MethodType(training_step_PL2, self)
- elif self.pl_version < (2, 0, 0):
- assert (
- self.automatic_optimization is True
- ), "For PL <2.0, automatic_optimization must be True."
- # Set up Training Step
- from opensr_srgan.model.training_step_PL import training_step_PL1
-
- self._training_step_implementation = MethodType(training_step_PL1, self)
- else:
+ major_token = str(pl.__version__).split(".", 1)[0]
+ major_digits = "".join(ch for ch in major_token if ch.isdigit())
+ major_version = int(major_digits) if major_digits else 0
+ if major_version < 2:
raise RuntimeError(
- f"Unsupported PyTorch Lightning version: {pl.__version__}"
+ "OpenSR-SRGAN requires PyTorch Lightning >= 2.0. "
+ f"Found version: {pl.__version__}."
)
+ self.automatic_optimization = False # manual optimization for Lightning 2+
+ from opensr_srgan.model.training_step_PL import training_step_PL2
+
+ self._training_step_implementation = MethodType(training_step_PL2, self)
+
def initialize_ema(self):
"""Initialize the Exponential Moving Average (EMA) mechanism for the generator.
@@ -460,97 +435,9 @@ def predict_step(self, lr_imgs):
sr_imgs = sr_imgs.cpu().detach() # detach from graph for inference output
return sr_imgs
- def training_step(
- self, batch, batch_idx, optimizer_idx: Optional[int] = None, *args
- ):
- """Dispatch the correct training step implementation based on PyTorch Lightning version.
-
- This method acts as a compatibility layer between different PyTorch Lightning
- versions that handle multi-optimizer GAN training differently.
-
- - For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
- - For PL < 2.0: Automatic optimization is used, and the optimizer index is passed
- to handle generator/discriminator updates separately.
-
- Args:
- batch (Any): A batch of training data (input tensors and targets as defined by the DataModule).
- batch_idx (int): Index of the current batch within the epoch.
- optimizer_idx (int | None, optional): Index of the active optimizer (0 for generator,
- 1 for discriminator) when using PL < 2.0.
- *args: Additional arguments that may be passed by older Lightning versions.
-
- Returns:
- Any: The output of the active training step implementation, loss value.
- """
- # Depending on PL version, and depending on the manual optimization
- if self.pl_version >= (2, 0, 0):
- # In PL2.x, optimizer_idx is not passed, manual optimization is performed
- return self._training_step_implementation(batch, batch_idx) # no optim_idx
- else:
- # In Pl1.x, optimizer_idx arrives twice and is passed on
- return self._training_step_implementation(
- batch, batch_idx, optimizer_idx
- ) # pass optim_idx
-
- def optimizer_step(
- self,
- epoch,
- batch_idx,
- optimizer,
- optimizer_idx=None,
- optimizer_closure=None,
- **kwargs, # absorbs on_tpu/using_lbfgs/etc across PL versions
- ):
- """Custom optimizer step handling for PL 1.x automatic optimization.
-
- This method ensures correct behavior across different PyTorch Lightning
- versions and training modes. It is invoked automatically during training
- in PL < 2.0 when `automatic_optimization=True`. For PL ≥ 2.0, where manual
- optimization is used, this function is effectively bypassed.
-
- - In **PL ≥ 2.0 (manual optimization)**: The optimizer step is explicitly
- called within `training_step_PL2()`, including EMA updates.
- - In **PL < 2.0 (automatic optimization)**: This function manages optimizer
- stepping, gradient zeroing, and optional EMA updates after generator steps.
-
- Args:
- epoch (int): Current training epoch.
- batch_idx (int): Index of the current batch.
- optimizer (torch.optim.Optimizer): The active optimizer instance.
- optimizer_idx (int, optional): Index of the optimizer being stepped
- (e.g., 0 for discriminator, 1 for generator).
- optimizer_closure (Callable, optional): Closure for re-evaluating the
- model and loss before optimizer step (used with some optimizers).
- **kwargs: Additional arguments passed by PL depending on backend
- (e.g., TPU flags, LBFGS options).
-
- Notes:
- - EMA updates are performed only after generator steps (optimizer_idx == 1).
- - The update starts after `self._ema_update_after_step` global steps.
-
- """
- # If we're in manual optimization (PL >=2 path), do nothing special.
- if not self.automatic_optimization:
- # In manual mode we call opt.step()/zero_grad() in training_step_PL2.
- # In manual mode, we update EMA weights manually in training step too.
- return super().optimizer_step(
- epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs
- )
-
- # ---- PL 1.x auto-optimization path ----
- if optimizer_closure is not None:
- optimizer.step(closure=optimizer_closure)
- else:
- optimizer.step()
- optimizer.zero_grad()
-
- # EMA after the generator step (assumes G is optimizer_idx == 1)
- if (
- self.ema is not None
- and optimizer_idx == 1
- and self.global_step >= self._ema_update_after_step
- ):
- self.ema.update(self.generator)
+ def training_step(self, batch, batch_idx):
+ """Run one Lightning 2+ manual-optimization training step."""
+ return self._training_step_implementation(batch, batch_idx)
@torch.no_grad()
def validation_step(self, batch, batch_idx):
@@ -746,7 +633,7 @@ def on_test_epoch_end(self):
def configure_optimizers(self):
"""
- Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).
+ Robust optimizers & schedulers for GANs (Lightning 2+ manual optimization).
- TTUR by default (D lr <= G lr)
- Adam with GAN-friendly betas/eps
@@ -845,7 +732,6 @@ def _adam(params, lr):
threshold_mode="rel",
cooldown=int(getattr(cfg_sch, "cooldown", 0)),
min_lr=float(getattr(cfg_sch, "min_lr", 1e-7)),
- verbose=bool(getattr(cfg_sch, "verbose", False)),
)
# D can have its own factor/patience; fall back to G’s if not set
sched_kwargs_d = dict(sched_kwargs)
@@ -882,16 +768,21 @@ def _adam(params, lr):
warmup_steps = int(getattr(cfg_sch, "g_warmup_steps", 0))
warmup_type = str(getattr(cfg_sch, "g_warmup_type", "none")).lower()
if warmup_steps > 0 and warmup_type in {"linear", "cosine"}:
+ cfg_g_lr = float(getattr(cfg_opt, "optim_g_lr", 1e-4))
+ min_warmup_lr = 0.05 * cfg_g_lr
def _g_warmup_lambda(step: int) -> float:
if step >= warmup_steps:
return 1.0
t = (step + 1) / max(1, warmup_steps)
- return (
+ raw = (
t
if warmup_type == "linear"
else 0.5 * (1.0 - math.cos(math.pi * t))
)
+ # Clamp using an LR floor derived from configured generator LR.
+ target_lr = max(raw * cfg_g_lr, min_warmup_lr)
+ return target_lr / cfg_g_lr
warmup_g = torch.optim.lr_scheduler.LambdaLR(
optimizer_g, lr_lambda=_g_warmup_lambda
diff --git a/opensr_srgan/model/generators/esrgan.py b/opensr_srgan/model/generators/esrgan.py
index c2b21bb..2ffe241 100644
--- a/opensr_srgan/model/generators/esrgan.py
+++ b/opensr_srgan/model/generators/esrgan.py
@@ -95,6 +95,7 @@ def __init__(
growth_channels: int = 32,
res_scale: float = 0.2,
scale: int = 4,
+ use_icnr: bool = True,
) -> None:
super().__init__()
@@ -117,10 +118,27 @@ def __init__(
self.conv_first = nn.Conv2d(in_channels, n_features, 3, padding=1)
self.body = nn.Sequential(*body_blocks)
self.conv_body = nn.Conv2d(n_features, n_features, 3, padding=1)
- self.upsampler = nn.Identity() if scale == 1 else make_upsampler(n_features, scale)
+ self.upsampler = (
+ nn.Identity()
+ if scale == 1
+ else make_upsampler(n_features, scale, use_icnr=use_icnr)
+ )
self.conv_hr = nn.Conv2d(n_features, n_features, 3, padding=1)
self.activation = nn.LeakyReLU(0.2, inplace=True)
self.conv_last = nn.Conv2d(n_features, out_channels, 3, padding=1)
+ self._init_esrgan_weights()
+
+ def _init_esrgan_weights(self) -> None:
+ """Apply ESRGAN-style small initialization for stable early training."""
+ for module in self.modules():
+ if isinstance(module, nn.Conv2d):
+ # ICNR-upsample convs are already initialized in make_upsampler.
+ if module in self.upsampler.modules():
+ continue
+ nn.init.kaiming_normal_(module.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu")
+ module.weight.data.mul_(0.1)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
def forward(self, x: Tensor) -> Tensor:
"""
diff --git a/opensr_srgan/model/generators/factory.py b/opensr_srgan/model/generators/factory.py
index 066dca0..4cbcd8c 100644
--- a/opensr_srgan/model/generators/factory.py
+++ b/opensr_srgan/model/generators/factory.py
@@ -291,6 +291,7 @@ def build_generator(config: Any) -> nn.Module:
growth_channels = int(getattr(generator_cfg, "growth_channels", 32))
res_scale = float(getattr(generator_cfg, "res_scale", 0.2))
out_channels = int(getattr(generator_cfg, "out_channels", in_channels))
+ use_icnr = bool(getattr(generator_cfg, "use_icnr", True))
_warn_overridden_options(
"Generator",
@@ -312,6 +313,7 @@ def build_generator(config: Any) -> nn.Module:
growth_channels=growth_channels,
res_scale=res_scale,
scale=scale,
+ use_icnr=use_icnr,
)
raise ValueError(
diff --git a/opensr_srgan/model/model_blocks/__init__.py b/opensr_srgan/model/model_blocks/__init__.py
index a15e474..6acb176 100644
--- a/opensr_srgan/model/model_blocks/__init__.py
+++ b/opensr_srgan/model/model_blocks/__init__.py
@@ -5,6 +5,7 @@
import math
import torch
from torch import nn
+from torch.nn import init
from .EMA import ExponentialMovingAverage
@@ -251,14 +252,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.res_scale * y
-def make_upsampler(n_channels: int, scale: int) -> nn.Sequential:
+def _icnr_(weight: torch.Tensor, scale: int = 2) -> None:
+ """Apply ICNR initialization to a pre-pixel-shuffle convolution kernel."""
+
+ out_channels, in_channels, k1, k2 = weight.shape
+ if out_channels % (scale**2) != 0:
+ raise ValueError("ICNR requires out_channels divisible by scale**2.")
+ subkernel = torch.empty(
+ out_channels // (scale**2),
+ in_channels,
+ k1,
+ k2,
+ device=weight.device,
+ dtype=weight.dtype,
+ )
+ init.kaiming_normal_(subkernel)
+ subkernel = subkernel.repeat_interleave(scale**2, dim=0)
+ with torch.no_grad():
+ weight.copy_(subkernel)
+
+
+def make_upsampler(n_channels: int, scale: int, *, use_icnr: bool = False) -> nn.Sequential:
"""Create a pixel-shuffle upsampler matching the flexible generator implementation."""
stages: list[nn.Module] = []
for _ in range(int(math.log2(scale))):
+ conv = nn.Conv2d(n_channels, n_channels * 4, 3, padding=1)
+ if use_icnr:
+ _icnr_(conv.weight, scale=2)
+ if conv.bias is not None:
+ init.zeros_(conv.bias)
stages.extend(
[
- nn.Conv2d(n_channels, n_channels * 4, 3, padding=1),
+ conv,
nn.PixelShuffle(2),
nn.PReLU(),
]
diff --git a/opensr_srgan/model/training_step_PL.py b/opensr_srgan/model/training_step_PL.py
index 5f425d9..4044ea2 100644
--- a/opensr_srgan/model/training_step_PL.py
+++ b/opensr_srgan/model/training_step_PL.py
@@ -1,340 +1,14 @@
import torch
-def training_step_PL1(self, batch, batch_idx, optimizer_idx):
- """One training step for PL < 2.0 using automatic optimization and multi-optimizers.
-
- Implements GAN training with two optimizers (D first, then G) and a
- pretraining gate. During the **pretraining phase**, only the generator
- (optimizer_idx == 1) is optimized with content loss; the discriminator
- branch returns a dummy loss and logs zeros. During **adversarial training**,
- the discriminator minimizes BCE on real HR vs. fake SR logits, and the
- generator minimizes content loss plus a ramped adversarial loss.
-
- Args:
- batch (Tuple[torch.Tensor, torch.Tensor]): `(lr_imgs, hr_imgs)` with shape `(B, C, H, W)`.
- batch_idx (int): Global batch index for the current epoch.
- optimizer_idx (int): Active optimizer index provided by Lightning:
- - `0`: Discriminator step.
- - `1`: Generator step.
-
- Returns:
- torch.Tensor:
- - **Pretraining**:
- - `optimizer_idx == 1`: content loss tensor for the generator.
- - `optimizer_idx == 0`: dummy scalar tensor with `requires_grad=True`.
- - **Adversarial training**:
- - `optimizer_idx == 0`: discriminator BCE loss (real + fake).
- - `optimizer_idx == 1`: generator total loss = content + λ_adv · BCE(G).
-
- Logged Metrics (selection):
- - `"training/pretrain_phase"`: 1.0 during pretraining (logged on G step).
- - `"train_metrics/*"`: content metrics from the content loss criterion.
- - `"generator/content_loss"`, `"generator/adversarial_loss"`, `"generator/total_loss"`.
- - `"discriminator/adversarial_loss"`, `"discriminator/D(y)_prob"`,
- `"discriminator/D(G(x))_prob"`.
- - `"training/adv_loss_weight"`: current λ_adv from the ramp scheduler.
-
- Notes:
- - Discriminator step uses `sr_imgs.detach()` to prevent G gradients.
- - Adversarial loss weight λ_adv ramps from 0 → `adv_loss_beta` per configured schedule.
- - Assumes optimizers are ordered as `[D, G]` in `configure_optimizers()`.
- """
-
- # -------- CREATE SR DATA --------
- lr_imgs, hr_imgs = batch # unpack LR/HR tensors from dataloader batch
- sr_imgs = self.forward(
- lr_imgs
- ) # forward pass of the generator to produce SR from LR
-
- # Default to standard GAN loss if adv_loss_type is not defined (e.g., lightweight
- # harnesses in tests). The real model sets this attribute during init.
- use_wasserstein = getattr(self, "adv_loss_type", "gan") == "wasserstein"
- use_relativistic = bool(getattr(self, "relativistic_average_d", False))
-
- # ======================================================================
- # SECTION: Pretraining phase gate
- # Purpose: decide if we are in the content-only pretrain stage.
- # ======================================================================
-
- # -------- DETERMINE PRETRAINING --------
- pretrain_phase = (
- self._pretrain_check()
- ) # check schedule: True => content-only pretraining
- if optimizer_idx == 1: # log whether pretraining is active or not
- self.log(
- "training/pretrain_phase",
- float(pretrain_phase),
- prog_bar=False,
- sync_dist=True,
- ) # log once per G step to track phase state
-
- # ======================================================================
- # SECTION: Pretraining branch (delegated)
- # Purpose: during pretrain, only content loss for G and dummy logging for D.
- # ======================================================================
-
- # -------- IF PRETRAIN: delegate --------
- if pretrain_phase:
- # run pretrain step separately and return loss here
- if optimizer_idx == 1:
- content_loss, metrics = self.content_loss_criterion.return_loss(
- sr_imgs, hr_imgs
- ) # compute perceptual/content loss (e.g., VGG or L1)
- self._log_generator_content_loss(
- content_loss
- ) # log content loss for G (consistent args)
- for key, value in metrics.items():
- self.log(
- f"train_metrics/{key}", value, sync_dist=True
- ) # reuse computed metrics for logging
-
- # Ensure adversarial weight is logged even when not used during pretraining
- adv_weight = self._compute_adv_loss_weight()
- self._log_adv_loss_weight(adv_weight)
- return content_loss # return loss for optimizer step (G only)
-
- # ======================================================================
- # SECTION: Discriminator (D) pretraining step
- # Purpose: no real training — just log zeros and return dummy loss to satisfy closure.
- # ======================================================================
- elif optimizer_idx == 0:
- device, dtype = (
- hr_imgs.device,
- hr_imgs.dtype,
- ) # get tensor device and dtype for consistency
- zero = torch.tensor(
- 0.0, device=device, dtype=dtype
- ) # define reusable zero tensor
-
- # --- Log dummy discriminator "opinions" (always zero during pretrain) ---
- self.log(
- "discriminator/D(y)_prob", zero, prog_bar=True, sync_dist=True
- ) # fake real-prob (always 0)
- self.log(
- "discriminator/D(G(x))_prob", zero, prog_bar=True, sync_dist=True
- ) # fake fake-prob (always 0)
-
- # --- Create dummy scalar loss (ensures PL closure runs) ---
- dummy = torch.zeros(
- (), device=device, dtype=dtype, requires_grad=True
- ) # dummy value with grad for optimizer compatibility
- self.log(
- "discriminator/adversarial_loss", dummy, sync_dist=True
- ) # log dummy adversarial loss (always 0)
- return dummy
- # -------- END PRETRAIN --------
-
- # ======================================================================
- # SECTION: Adversarial training — Discriminator step
- # Purpose: update D to distinguish HR (real) vs SR (fake).
- # ======================================================================
-
- # -------- Normal Train: Discriminator Step --------
- if optimizer_idx == 0:
- r1_gamma = getattr(self, "r1_gamma", 0.0) # default to 0 for
- hr_imgs.requires_grad_(r1_gamma > 0) # enable grad for R1 penalty if needed
-
- # run discriminator and get loss between pred labels and true labels
- hr_discriminated = self.discriminator(hr_imgs) # D(real): logits for HR images
- sr_discriminated = self.discriminator(
- sr_imgs.detach()
- ) # detach so G doesn’t get gradients from D’s step
-
- # Check for WS GAN loss
- if use_wasserstein: # Wasserstein GAN loss
- loss_real = -hr_discriminated.mean()
- loss_fake = sr_discriminated.mean()
- else:
- # Standard GAN loss (BCE)
- real_target = torch.full_like(
- hr_discriminated, self.adv_target
- ) # get labels/fuzzy labels
- fake_target = torch.zeros_like(
- sr_discriminated
- ) # zeros, since generative prediction
- if self.relativistic_average_d:
- # Relativistic Average GAN loss
-
- # Calculate real and fake means
- real_mean = hr_discriminated.mean()
- fake_mean = sr_discriminated.mean()
-
- sr_discriminated_rel = sr_discriminated - real_mean
- hr_discriminated_rel = hr_discriminated - fake_mean
-
- loss_real = self.adversarial_loss_criterion(
- hr_discriminated_rel, real_target
- ) # BCEWithLogitsLoss for D(y)
-
- loss_fake = self.adversarial_loss_criterion(
- sr_discriminated_rel, fake_target
- ) # BCEWithLogitsLoss for D(G(x))
- else: # Standard GAN loss without relativistic average
- # Binary Cross-Entropy loss
- loss_real = self.adversarial_loss_criterion(
- hr_discriminated, real_target
- ) * 0.5 # BCEWithLogitsLoss for D(y)
-
- loss_fake = self.adversarial_loss_criterion(
- sr_discriminated, fake_target
- ) * 0.5 # BCEWithLogitsLoss for D(G(x))
-
- # R1 Gradient Penalty (if enabled)
- r1_penalty = torch.zeros((), device=hr_imgs.device, dtype=hr_imgs.dtype)
- if r1_gamma > 0:
- grad_real = torch.autograd.grad(
- outputs=hr_discriminated.sum(),
- inputs=hr_imgs,
- create_graph=True,
- retain_graph=True,
- )[0]
- grad_penalty = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
- r1_penalty = 0.5 * r1_gamma * grad_penalty.mean()
-
- # Sum up losses
- adversarial_loss = (
- loss_real + loss_fake + r1_penalty
- ) # add 0s for R1 if disabled
- self.log(
- "discriminator/adversarial_loss", adversarial_loss, sync_dist=True
- ) # log weighted loss
- self.log(
- "discriminator/r1_penalty", r1_penalty.detach(), sync_dist=True
- ) # log R1 penalty regarless, is 0 when turned off
-
- # [LOG-B] Always log D opinions: real probs in normal training
- with torch.no_grad():
- d_real_prob = torch.sigmoid(
- hr_discriminated
- ).mean() # estimate mean real probability
- d_fake_prob = torch.sigmoid(
- sr_discriminated
- ).mean() # estimate mean fake probability
- self.log(
- "discriminator/D(y)_prob", d_real_prob, prog_bar=True, sync_dist=True
- ) # log D(real) confidence
- self.log(
- "discriminator/D(G(x))_prob", d_fake_prob, prog_bar=True, sync_dist=True
- ) # log D(fake) confidence
-
-
- if self.relativistic_average_d:
- # Previous log of D opinions are not useful in RaGAN, log relativistic ones
- with torch.no_grad():
- real_mean = hr_discriminated.mean()
- fake_mean = sr_discriminated.mean()
-
- sr_discriminated_rel = sr_discriminated - real_mean
- hr_discriminated_rel = hr_discriminated - fake_mean
-
- d_real_prob_rel = torch.sigmoid(
- hr_discriminated_rel
- ).mean() # estimate mean real probability
- d_fake_prob_rel = torch.sigmoid(
- sr_discriminated_rel
- ).mean() # estimate mean fake probability
-
- self.log(
- "train_metrics/discriminator/D(y)_prob_relativistic",
- d_real_prob_rel,
- prog_bar=True,
- sync_dist=True,
- ) # log D(real) confidence
- self.log(
- "train_metrics/discriminator/D(G(x))_prob_relativistic",
- d_fake_prob_rel,
- prog_bar=True,
- sync_dist=True,
- ) # log D(fake) confidence
-
- # return weighted discriminator loss
- return adversarial_loss # PL will use this to step the D optimizer
-
- # ======================================================================
- # SECTION: Adversarial training — Generator step
- # Purpose: update G to minimize content loss + (weighted) adversarial loss.
- # ======================================================================
-
- # -------- Normal Train: Generator Step --------
- if optimizer_idx == 1:
-
- """1. Get VGG space loss"""
- # encode images
- content_loss, metrics = self.content_loss_criterion.return_loss(
- sr_imgs, hr_imgs
- ) # perceptual/content criterion (e.g., VGG)
- self._log_generator_content_loss(
- content_loss
- ) # log content loss for G (consistent args)
- for key, value in metrics.items():
- self.log(
- f"train_metrics/{key}", value, sync_dist=True
- ) # log detailed metrics without extra forward passes
-
- """ 2. Get Discriminator Opinion and loss """
- # run discriminator and get loss between pred labels and true labels
- sr_discriminated = self.discriminator(
- sr_imgs
- ) # D(SR): logits for generator outputs
- if use_wasserstein: # Wasserstein GAN loss
- adversarial_loss = -sr_discriminated.mean()
- else:
- if self.relativistic_average_d:
- # Relativistic Average GAN loss for G
-
- # Calculate real mean
- with torch.no_grad():
- hr_discriminated = self.discriminator(hr_imgs)
- real_mean = hr_discriminated.mean()
-
- fake_mean = sr_discriminated.mean()
- sr_discriminated_rel = sr_discriminated - real_mean
- hr_discriminated_rel = hr_discriminated - fake_mean
-
- loss_fake = self.adversarial_loss_criterion(
- sr_discriminated_rel, torch.ones_like(sr_discriminated)
- ) # now target is 1.0 for G loss
-
- loss_real = self.adversarial_loss_criterion(
- hr_discriminated_rel, torch.zeros_like(hr_discriminated)
- ) # now target is 0.0 for G loss
-
- adversarial_loss = (loss_fake + loss_real) / 2.0
-
- else:
- adversarial_loss = self.adversarial_loss_criterion(
- sr_discriminated, torch.ones_like(sr_discriminated)
- ) # now target is 1.0 for G loss
- self.log(
- "generator/adversarial_loss", adversarial_loss, sync_dist=True
- ) # log unweighted adversarial loss
-
- """ 3. Weight the losses"""
- adv_weight = (
- self._adv_loss_weight()
- ) # get adversarial weight based on current step
- adversarial_loss_weighted = (
- adversarial_loss * adv_weight
- ) # weight adversarial loss
- total_loss = content_loss + adversarial_loss_weighted # total content loss
- self.log(
- "generator/total_loss", total_loss, sync_dist=True
- ) # log combined objective (content + λ_adv * adv)
-
- # return Generator loss
- return total_loss
-
-
def training_step_PL2(self, batch, batch_idx):
- """Manual-optimization training step for PyTorch Lightning ≥ 2.0.
+ """Manual-optimization training step for PyTorch Lightning >= 2.
- Mirrors the PL1.x logic with explicit optimizer control:
+ Performs two explicit optimizer updates per batch:
- **Pretraining phase**: Discriminator logs dummies; Generator is optimized with
- content loss only (no adversarial term), and EMA optionally updates.
+ hardwired L1 loss only (no adversarial term), and EMA optionally updates.
- **Adversarial phase**: Performs a Discriminator step (real vs. fake BCE),
- followed by a Generator step (content + λ_adv · BCE against ones). Uses the
- same log keys and ordering as the PL1.x path.
+ followed by a Generator step (content + λ_adv · BCE against ones).
Assumptions:
- `self.automatic_optimization` is `False` (manual opt).
@@ -356,18 +30,10 @@ def training_step_PL2(self, batch, batch_idx):
- `"generator/content_loss"`, `"generator/adversarial_loss"`, `"generator/total_loss"`
- `"discriminator/adversarial_loss"`, `"discriminator/D(y)_prob"`, `"discriminator/D(G(x))_prob"`
- `"training/adv_loss_weight"` (λ_adv from ramp schedule)
-
- Raises:
- AssertionError: If PL version < 2.0 or `automatic_optimization` is True.
"""
- assert self.pl_version >= (
- 2,
- 0,
- 0,
- ), "training_step_PL2 requires PyTorch Lightning >= 2.x."
assert (
self.automatic_optimization is False
- ), "training_step_PL2 requires manual_optimization."
+ ), "training_step_PL2 requires manual optimization."
# -------- CREATE SR DATA --------
lr_imgs, hr_imgs = batch
@@ -405,26 +71,24 @@ def _maybe_clip_gradients(module, optimizer=None):
# SECTION: Pretraining phase gate
# ======================================================================
pretrain_phase = self._pretrain_check()
- # in PL1.x you logged this only on G-step; here we log once per batch
self.log(
"training/pretrain_phase", float(pretrain_phase), prog_bar=False, sync_dist=True
)
# ======================================================================
- # SECTION: Pretraining branch (content-only on G; D logs dummies)
+ # SECTION: Pretraining branch (L1-only on G; D logs dummies)
# ======================================================================
if pretrain_phase:
- # --- D dummy logs (no step) to mimic your optimizer_idx==0 branch ---
+ # --- D dummy logs (no step during pretraining) ---
with torch.no_grad():
zero = torch.tensor(0.0, device=hr_imgs.device, dtype=hr_imgs.dtype)
self.log("discriminator/D(y)_prob", zero, prog_bar=True, sync_dist=True)
self.log("discriminator/D(G(x))_prob", zero, prog_bar=True, sync_dist=True)
self.log("discriminator/adversarial_loss", zero, sync_dist=True)
- # --- G step: content loss only (identical to your optimizer_idx==1 pretrain) ---
- content_loss, metrics = self.content_loss_criterion.return_loss(
- sr_imgs, hr_imgs
- )
+ # --- G step: hardwired L1-only pretraining loss ---
+ content_loss = torch.nn.functional.l1_loss(sr_imgs, hr_imgs)
+ metrics = {"l1": content_loss.detach()}
self._log_generator_content_loss(content_loss)
for key, value in metrics.items():
self.log(f"train_metrics/{key}", value, sync_dist=True)
@@ -446,7 +110,6 @@ def _maybe_clip_gradients(module, optimizer=None):
if self.ema is not None and self.global_step >= self._ema_update_after_step:
self.ema.update(self.generator)
- # return same scalar you’d have returned in PL1.x (content loss)
return content_loss
# ======================================================================
@@ -460,7 +123,6 @@ def _maybe_clip_gradients(module, optimizer=None):
r1_gamma = getattr(self, "r1_gamma", 0.0)
hr_imgs.requires_grad_(r1_gamma > 0) # enable grad for R1 penalty if needed
-
hr_discriminated = self.discriminator(hr_imgs) # D(y)
sr_discriminated = self.discriminator(sr_imgs.detach()) # D(G(x)) w/o grad to G
@@ -482,7 +144,7 @@ def _maybe_clip_gradients(module, optimizer=None):
loss_real = self.adversarial_loss_criterion(hr_discriminated_rel, real_target)
loss_fake = self.adversarial_loss_criterion(sr_discriminated_rel, fake_target)
else:
- # Keep PL1 and PL2 loss scales consistent for non-relativistic BCE.
+ # Keep loss scales consistent for non-relativistic BCE.
loss_real = self.adversarial_loss_criterion(hr_discriminated, real_target) * 0.5
loss_fake = self.adversarial_loss_criterion(sr_discriminated, fake_target) * 0.5
@@ -599,5 +261,4 @@ def _maybe_clip_gradients(module, optimizer=None):
if self.ema is not None and self.global_step >= self._ema_update_after_step:
self.ema.update(self.generator)
- # return same scalar you return in PL1.x G path
return total_loss
diff --git a/opensr_srgan/model/training_workflow.txt b/opensr_srgan/model/training_workflow.txt
index 8b1175f..040792e 100644
--- a/opensr_srgan/model/training_workflow.txt
+++ b/opensr_srgan/model/training_workflow.txt
@@ -1,67 +1,55 @@
-ENTRY: trainer.fit(model, datamodule)
+ENTRY: trainer.fit(model, datamodule[, ckpt_path])
│
├─ PRELUDE (opensr_srgan/train.py)
-│ ├─ Load config (OmegaConf) and resolve device list (`Training.gpus`).
-│ ├─ Check `Model.load_checkpoint` / `Model.continue_training` to decide between fresh training vs. retraining from a checkpoint.
-│ ├─ Call `build_lightning_kwargs()` → detects PyTorch Lightning version, normalises accelerator/devices, and routes resume arguments (`resume_from_checkpoint` for PL<2, `ckpt_path` for PL≥2).
-│ └─ Instantiate `Trainer(**trainer_kwargs)` and invoke `trainer.fit(..., **fit_kwargs)`.
+│ ├─ Load config (OmegaConf).
+│ ├─ Validate checkpoint intent:
+│ │ ├─ `Model.load_checkpoint` -> weight-only initialization.
+│ │ └─ `Model.continue_training` -> full resume via `ckpt_path`.
+│ ├─ Call `build_lightning_kwargs()` -> normalize accelerator/devices and build `fit_kwargs`.
+│ └─ Instantiate `Trainer(**trainer_kwargs)` and call `trainer.fit(..., **fit_kwargs)`.
│
├─ SRGAN_model.setup_lightning()
-│ ├─ Parse `pl.__version__` into `self.pl_version`.
-│ ├─ IF `self.pl_version >= (2,0,0)`
-│ │ ├─ Set `automatic_optimization = False` (manual optimisation required by PL2).
-│ │ └─ Bind `training_step_PL2` as the active `training_step` implementation.
-│ └─ ELSE (PL1.x)
-│ ├─ Ensure `automatic_optimization is True`.
-│ └─ Bind `training_step_PL1` (optimizer_idx-based training).
+│ ├─ Validate installed PyTorch Lightning major version is >= 2.
+│ ├─ Set `automatic_optimization = False` (manual optimization).
+│ └─ Bind `training_step_PL2` as active `training_step`.
│
-└─ ACTIVE TRAINING STEP (batch, batch_idx[, optimizer_idx])
+└─ ACTIVE TRAINING STEP (batch, batch_idx)
│
- ├─ 1) Forward + metrics (no grad for logging reuse)
+ ├─ 1) Forward
│ ├─ (lr, hr) = batch
- │ ├─ sr = G(lr)
- │ └─ metrics = content_loss.return_metrics(sr, hr)
- │ └─ LOG: `train_metrics/*` (L1, SAM, perceptual, TV, PSNR, SSIM)
+ │ └─ sr = G(lr)
│
├─ 2) Phase checks
- │ ├─ `pretrain = _pretrain_check()` # compare global_step vs. `Training.g_pretrain_steps`
- │ ├─ LOG: `training/pretrain_phase` (on G step for PL1, per-batch for PL2)
- │ └─ `adv_weight = _adv_loss_weight()` or `_compute_adv_loss_weight()` # ramp toward `Training.Losses.adv_loss_beta`
+ │ ├─ `pretrain = _pretrain_check()`
+ │ ├─ LOG: `training/pretrain_phase`
+ │ └─ `adv_weight = _adv_loss_weight()`
│ └─ LOG: `training/adv_loss_weight`
│
- ├─ 3) IF `pretrain` True (Generator warm-up)
- │ ├─ Generator path
- │ │ ├─ Compute `(content_loss, metrics) = content_loss.return_loss(sr, hr)`
- │ │ ├─ LOG: `generator/content_loss`
- │ │ ├─ Reuse metrics for logging (`train_metrics/*`)
- │ │ ├─ LOG: `training/adv_loss_weight` (even though weight is 0 during warm-up)
- │ │ └─ RETURN/STEP on `content_loss` only (PL1 returns scalar; PL2 manual_backward + `opt_g.step()`)
- │ └─ Discriminator path
- │ ├─ LOG zeros for `discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`, `discriminator/adversarial_loss`
- │ └─ Return dummy zero tensor with `requires_grad=True` (PL1) or skip optimisation but keep logs (PL2)
+ ├─ 3) IF `pretrain` True (Generator warm-up)
+ │ ├─ LOG zeros for discriminator diagnostics
+ │ │ (`discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`, `discriminator/adversarial_loss`)
+ │ ├─ Compute `(content_loss, metrics) = content_loss.return_loss(sr, hr)`
+ │ ├─ LOG: `generator/content_loss` and `train_metrics/*`
+ │ ├─ Optimize G manually (`zero_grad` -> `manual_backward` -> `step`)
+ │ ├─ Optional EMA update after G step
+ │ └─ RETURN `content_loss`
│
- └─ 4) ELSE `pretrain` False (Full GAN training)
+ └─ 4) ELSE (Full GAN training)
│
├─ 4A) Discriminator update
│ ├─ hr_logits = D(hr)
│ ├─ sr_logits = D(sr.detach())
- │ ├─ real_target = adv_target (0.9 with label smoothing else 1.0)
- │ ├─ fake_target = 0.0
- │ ├─ loss_real = BCEWithLogits(hr_logits, real_target)
- │ ├─ loss_fake = BCEWithLogits(sr_logits, fake_target)
- │ ├─ d_loss = loss_real + loss_fake
- │ ├─ LOG: `discriminator/adversarial_loss`
- │ ├─ LOG: `discriminator/D(y)_prob` = sigmoid(hr_logits).mean()
- │ ├─ LOG: `discriminator/D(G(x))_prob` = sigmoid(sr_logits).mean()
- │ └─ Optimise D (return `d_loss` in PL1; manual backward + `opt_d.step()` in PL2)
+ │ ├─ Compute D loss (BCE or Wasserstein + optional R1)
+ │ ├─ LOG: `discriminator/adversarial_loss`, `discriminator/r1_penalty`,
+ │ │ `discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`
+ │ └─ Optimize D manually (`zero_grad` -> `manual_backward` -> `step`)
│
└─ 4B) Generator update
- ├─ (content_loss, metrics) = content_loss.return_loss(sr, hr)
- ├─ LOG: `generator/content_loss`
- ├─ sr_logits = D(sr)
- ├─ g_adv = BCEWithLogits(sr_logits, target=1.0)
- ├─ LOG: `generator/adversarial_loss` = g_adv
+ ├─ Compute `(content_loss, metrics) = content_loss.return_loss(sr, hr)`
+ ├─ g_adv = adversarial objective from D(sr)
├─ total_loss = content_loss + adv_weight * g_adv
- ├─ LOG: `generator/total_loss`
- ├─ Optimise G (return `total_loss` in PL1; manual backward + `opt_g.step()` in PL2)
- └─ IF EMA enabled AND `global_step >= _ema_update_after_step`: update shadow weights (`EMA/update_after_step`, `EMA/is_active` logs)
+ ├─ LOG: `generator/content_loss`, `generator/adversarial_loss`, `generator/total_loss`,
+ │ and `train_metrics/*`
+ ├─ Optimize G manually (`zero_grad` -> `manual_backward` -> `step`)
+ ├─ Optional EMA update after G step
+ └─ RETURN `total_loss`
diff --git a/opensr_srgan/train.py b/opensr_srgan/train.py
index 430dce1..7850b76 100644
--- a/opensr_srgan/train.py
+++ b/opensr_srgan/train.py
@@ -9,13 +9,11 @@
import datetime
import os
from pathlib import Path
-from multiprocessing import freeze_support
import torch
import wandb
from omegaconf import OmegaConf
import pytorch_lightning as pl
-from pytorch_lightning import Trainer
def train(config):
@@ -30,7 +28,7 @@ def train(config):
Notes
-----
- - Supports both PL < 2.0 and >= 2.0 via `build_lightning_kwargs`.
+ - Requires PyTorch Lightning >= 2.0.
- If `Model.load_checkpoint` is set, weights are loaded before training.
- If `Model.continue_training` is set, training resumes from that checkpoint.
- Setting both `Model.load_checkpoint` and `Model.continue_training` is invalid.
@@ -49,11 +47,6 @@ def train(config):
)
#############################################################################################################
- # Get devices
- cuda_devices = config.Training.gpus
- cuda_strategy = "ddp" if len(cuda_devices) > 1 else None
-
- #############################################################################################################
" LOAD MODEL "
#############################################################################################################
model_load_checkpoint = getattr(config.Model, "load_checkpoint", False)
@@ -78,9 +71,7 @@ def _checkpoint_is_set(value) -> bool:
if _checkpoint_is_set(model_load_checkpoint):
model.load_weights_from_checkpoint(model_load_checkpoint, strict=False)
- resume_from_checkpoint_variable = (
- resume_checkpoint if _checkpoint_is_set(resume_checkpoint) else None
- )
+ resume_ckpt = resume_checkpoint if _checkpoint_is_set(resume_checkpoint) else None
#############################################################################################################
""" GET DATA """
@@ -152,17 +143,17 @@ def _checkpoint_is_set(value) -> bool:
#############################################################################################################
""" Set Args for Training and Start Training """
- """ make it robust for both PL<2.0 and PL>=2.0 """
+ """ Build trainer kwargs and launch training """
#############################################################################################################
from opensr_srgan.utils.build_trainer_kwargs import build_lightning_kwargs
trainer_kwargs, fit_kwargs = (
- build_lightning_kwargs( # get kwargs depending on PL version
+ build_lightning_kwargs(
config=config,
logger=wandb_logger,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stop_callback,
- resume_ckpt=resume_from_checkpoint_variable,
+ resume_ckpt=resume_ckpt,
)
)
diff --git a/opensr_srgan/utils/build_trainer_kwargs.py b/opensr_srgan/utils/build_trainer_kwargs.py
index 722404e..8180913 100644
--- a/opensr_srgan/utils/build_trainer_kwargs.py
+++ b/opensr_srgan/utils/build_trainer_kwargs.py
@@ -1,11 +1,8 @@
-# put this near the top of your module (e.g., in opensr_srgan/train.py)
-import os
import inspect
from collections.abc import Sequence
import torch
import pytorch_lightning as pl
-from packaging.version import Version
def build_lightning_kwargs(
@@ -15,19 +12,15 @@ def build_lightning_kwargs(
early_stop_callback,
resume_ckpt: str | None = None,
):
- """Return Trainer/fit keyword arguments compatible with Lightning < 2 and ≥ 2.
+ """Return Trainer/fit keyword arguments for Lightning 2+.
Builds two dictionaries:
- 1) ``trainer_kwargs`` — safe, version-aware arguments for ``pytorch_lightning.Trainer``.
- 2) ``fit_kwargs`` — arguments for ``Trainer.fit`` (e.g., ``ckpt_path`` on PL ≥ 2).
+ 1) ``trainer_kwargs`` — arguments for ``pytorch_lightning.Trainer``.
+ 2) ``fit_kwargs`` — arguments for ``Trainer.fit`` (e.g., ``ckpt_path``).
The helper normalizes device configuration (CPU/GPU, DDP when multiple GPUs),
- removes deprecated/None entries, and maps the legacy resume API:
- - PL < 2: uses ``resume_from_checkpoint`` in ``trainer_kwargs``.
- - PL ≥ 2: uses ``ckpt_path`` in ``fit_kwargs`` (if supported by signature).
-
- It also clears the legacy environment variable
- ``PL_TRAINER_RESUME_FROM_CHECKPOINT`` to avoid non-deterministic resume behavior.
+ removes ``None`` entries, and filters kwargs against the active Lightning
+ signatures to avoid passing unsupported arguments.
Args:
config: OmegaConf-like config with ``Training`` fields:
@@ -44,7 +37,7 @@ def build_lightning_kwargs(
Returns:
Tuple[Dict[str, Any], Dict[str, Any]]:
- trainer_kwargs: Dict for ``pl.Trainer(**trainer_kwargs)``.
- - fit_kwargs: Dict for ``trainer.fit(..., **fit_kwargs)`` (may be empty).
+ - fit_kwargs: Dict for ``trainer.fit(..., **fit_kwargs)``.
Raises:
ValueError: If ``Training.device`` is not one of {"auto","cpu","cuda","gpu"}.
@@ -52,27 +45,11 @@ def build_lightning_kwargs(
Notes:
- CPU runs force ``devices=1`` and no strategy.
- GPU runs honor ``Training.gpus``; DDP is enabled when requesting >1 device.
- - All ``None`` values are pruned; kwargs are filtered to match the current
- ``Trainer.__init__`` and ``Trainer.fit`` signatures to stay future-proof.
+ - Resume checkpoints are forwarded through ``Trainer.fit(ckpt_path=...)``.
"""
# ---------------------------------------------------------------------
- # 1) Version detection and environment cleanup
- # ---------------------------------------------------------------------
- # Determine whether the installed Lightning version is 2.x or newer.
- # The behaviour of ``resume_from_checkpoint`` changed between major
- # versions, so we compute this once and use the flag later when assembling
- # the kwargs.
- is_v2 = Version(pl.__version__) >= Version("2.0.0")
-
- # Lightning < 2 used an environment variable to infer the checkpoint path
- # when resuming. The variable is ignored (and in some cases triggers
- # warnings) on newer versions, so we proactively remove it to provide a
- # deterministic behaviour across environments.
- os.environ.pop("PL_TRAINER_RESUME_FROM_CHECKPOINT", None)
-
- # ---------------------------------------------------------------------
- # 2) Parse device configuration from the OmegaConf config
+ # 1) Parse device configuration from the OmegaConf config
# ---------------------------------------------------------------------
# ``Training.gpus`` may be specified either as an integer (e.g. ``2``) or a
# sequence (e.g. ``[0, 1]``). We keep the raw object so it can be passed to
@@ -122,10 +99,20 @@ def _count_devices(devices):
strategy = None
else:
devices = devices_cfg if ndev else 1
- strategy = "ddp" if ndev > 1 else None
+ if ndev > 1:
+ # GAN manual optimization updates only one optimizer branch at a time,
+ # so DDP must track unused params on each step.
+ find_unused = bool(
+ getattr(config.Training, "find_unused_parameters", True)
+ )
+ strategy = (
+ "ddp_find_unused_parameters_true" if find_unused else "ddp"
+ )
+ else:
+ strategy = None
# ---------------------------------------------------------------------
- # 3) Assemble the base Trainer kwargs shared across Lightning versions
+ # 2) Assemble the base Trainer kwargs
# ---------------------------------------------------------------------
trainer_kwargs = dict(
accelerator=accelerator,
@@ -134,7 +121,7 @@ def _count_devices(devices):
val_check_interval=config.Training.val_check_interval,
limit_val_batches=config.Training.limit_val_batches,
max_epochs=config.Training.max_epochs,
- log_every_n_steps=100,
+ log_every_n_steps=50,
logger=[logger],
callbacks=[checkpoint_callback, early_stop_callback],
gradient_clip_val=config.Optimizers.gradient_clip_val,
@@ -146,12 +133,6 @@ def _count_devices(devices):
# kwargs.
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}
- # ---------------------------------------------------------------------
- # 4) Add compatibility shims for pre-Lightning 2 releases
- # ---------------------------------------------------------------------
- if not is_v2 and resume_ckpt:
- trainer_kwargs["resume_from_checkpoint"] = resume_ckpt
-
# Some Lightning releases occasionally deprecate constructor arguments. To
# ensure we do not pass stale options we filter the dictionary so it only
# contains parameters that are still accepted by ``Trainer.__init__``.
@@ -159,11 +140,10 @@ def _count_devices(devices):
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if k in init_sig}
# ---------------------------------------------------------------------
- # 5) ``Trainer.fit`` keyword arguments (Lightning >= 2)
+ # 3) ``Trainer.fit`` keyword arguments
# ---------------------------------------------------------------------
fit_kwargs = {}
- if is_v2 and resume_ckpt:
- # ``ckpt_path`` is the new name for ``resume_from_checkpoint``.
+ if resume_ckpt:
fit_sig = inspect.signature(pl.Trainer.fit).parameters
if "ckpt_path" in fit_sig:
fit_kwargs["ckpt_path"] = resume_ckpt
diff --git a/pyproject.toml b/pyproject.toml
index 072451f..b657fe6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,10 +13,8 @@ authors = [
license = { file = "LICENSE" }
requires-python = ">=3.10,!=3.11.*,<3.13"
dependencies = [
- "torch==1.13.1; python_version < \"3.11\"",
- "torch>=2.1; python_version >= \"3.12\"",
- "pytorch-lightning==1.9.*; python_version < \"3.11\"",
- "pytorch-lightning>=2.1,<3.0; python_version >= \"3.12\"",
+ "torch>=2.1",
+ "pytorch-lightning>=2.1,<3.0",
"torchvision",
"numpy",
"kornia",
diff --git a/requirements.txt b/requirements.txt
index dac3d95..1417e71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,5 @@
-# NOTE: Torch 1.13.1 wheels are provided for Python 3.10 and earlier. Use a Python 3.10.x
-# interpreter when creating the environment to satisfy these pinned dependencies.
-# If PyPI cannot locate torch/torchvision, install them from the official
-# PyTorch wheel index (CPU or CUDA builds) before re-running
-# `pip install -r requirements.txt`.
-
-torch==1.13.1; python_version < "3.11"
-torch>=2.1; python_version >= "3.12"
-
-pytorch-lightning==1.9.*; python_version < "3.11"
-pytorch-lightning>=2.1,<3.0; python_version >= "3.12"
+torch>=2.1
+pytorch-lightning>=2.1,<3.0
torchvision
numpy
@@ -26,4 +17,4 @@ huggingface_hub # for loading models from the Hugging Face Hub
# Optional extras (enable additional metrics/data loaders,inference)
#lpips # for when training with LPIPS loss
#tacoreader # for SEN2NAIP dataset
-#opensr_utils # for inference utilities
\ No newline at end of file
+#opensr_utils # for inference utilities
diff --git a/tests/test_data/test_dataset_selector.py b/tests/test_data/test_dataset_selector.py
index 9bcbf6d..74ec22a 100644
--- a/tests/test_data/test_dataset_selector.py
+++ b/tests/test_data/test_dataset_selector.py
@@ -141,6 +141,21 @@ def __getitem__(self, idx):
return value, value + 1
+class _StubSEN2NAIPDataset(Dataset):
+ created_args = []
+
+ def __init__(self, config, **kwargs):
+ self.__class__.created_args.append((config, kwargs))
+ self._data = torch.arange(4, dtype=torch.float32)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __getitem__(self, idx):
+ value = self._data[idx]
+ return value, value + 1
+
+
def _install_module(monkeypatch, name, is_package=False, **attrs):
module = types.ModuleType(name)
if is_package:
@@ -208,6 +223,43 @@ def test_select_dataset_lrhr_folder_branch_missing_root_raises(monkeypatch, tmp_
dataset_selector.select_dataset(config)
+def test_select_dataset_sen2naip_uses_hardcoded_taco_and_training_config(monkeypatch):
+ _StubSEN2NAIPDataset.created_args.clear()
+ _install_module(monkeypatch, "opensr_srgan.data.sen2naip", is_package=True)
+ _install_module(
+ monkeypatch,
+ "opensr_srgan.data.sen2naip.sen2naip_dataset",
+ SEN2NAIP=_StubSEN2NAIPDataset,
+ )
+
+ monkeypatch.setattr(dataset_selector, "SEN2NAIP_TACO_FILE", "/tmp/hardcoded.taco")
+
+ config = _make_config(
+ dataset_type="sen2naip",
+ normalization="normalise_10k",
+ sen2naip_val_fraction=0.2,
+ train_batch_size=2,
+ val_batch_size=2,
+ num_workers=0,
+ )
+
+ datamodule = dataset_selector.select_dataset(config)
+ train_loader = datamodule.train_dataloader()
+ train_batch = next(iter(train_loader))
+
+ assert len(_StubSEN2NAIPDataset.created_args) == 2
+ train_cfg, train_kwargs = _StubSEN2NAIPDataset.created_args[0]
+ val_cfg, val_kwargs = _StubSEN2NAIPDataset.created_args[1]
+ assert train_cfg is config
+ assert val_cfg is config
+ assert train_kwargs["phase"] == "train"
+ assert val_kwargs["phase"] == "val"
+ assert train_kwargs["taco_file"] == "/tmp/hardcoded.taco"
+ assert val_kwargs["taco_file"] == "/tmp/hardcoded.taco"
+ assert isinstance(train_batch, (list, tuple))
+ assert len(train_batch) == 2
+
+
def test_dataset_selector_module_main_guard(monkeypatch):
import opensr_srgan.data.example_data.example_dataset as example_module
from omegaconf import OmegaConf
diff --git a/tests/test_data/test_sen2naip_dataset.py b/tests/test_data/test_sen2naip_dataset.py
new file mode 100644
index 0000000..4b40d73
--- /dev/null
+++ b/tests/test_data/test_sen2naip_dataset.py
@@ -0,0 +1,82 @@
+import sys
+import types
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+import torch
+
+from opensr_srgan.data.sen2naip.sen2naip_dataset import SEN2NAIP
+
+
+class _FakeSample:
+ def read(self, idx):
+ return "lr.tif" if idx == 0 else "hr.tif"
+
+
+class _FakeTacoDataset:
+ def __len__(self):
+ return 4
+
+ def read(self, _idx):
+ return _FakeSample()
+
+
+class _FakeRaster:
+ def __init__(self, path):
+ self.path = path
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc, tb):
+ return False
+
+ def read(self):
+ if self.path == "lr.tif":
+ return np.full((1, 2, 2), 5000.0, dtype=np.float32)
+ return np.full((1, 2, 2), 10000.0, dtype=np.float32)
+
+
+@pytest.fixture(autouse=True)
+def _stub_external_modules(monkeypatch):
+ tacoreader_mod = types.ModuleType("tacoreader")
+ tacoreader_mod.load = lambda _path: _FakeTacoDataset()
+ monkeypatch.setitem(sys.modules, "tacoreader", tacoreader_mod)
+
+ rasterio_mod = types.ModuleType("rasterio")
+ rasterio_mod.open = lambda path: _FakeRaster(path)
+ monkeypatch.setitem(sys.modules, "rasterio", rasterio_mod)
+
+
+def test_sen2naip_uses_config_normalization():
+ cfg = SimpleNamespace(
+ Data=SimpleNamespace(
+ normalization="normalise_10k",
+ sen2naip_taco_file="/tmp/fake.taco",
+ sen2naip_phase="train",
+ sen2naip_val_fraction=0.25,
+ )
+ )
+ ds = SEN2NAIP(cfg)
+
+ lr, hr = ds[0]
+
+ assert ds.normalizer.method == "normalise_10k"
+ assert isinstance(lr, torch.Tensor)
+ assert isinstance(hr, torch.Tensor)
+ assert torch.allclose(lr, torch.full_like(lr, 0.5))
+ assert torch.allclose(hr, torch.full_like(hr, 1.0))
+
+
+def test_sen2naip_requires_taco_file():
+ cfg = SimpleNamespace(
+ Data=SimpleNamespace(
+ normalization="identity",
+ sen2naip_phase="train",
+ sen2naip_val_fraction=0.25,
+ )
+ )
+
+ with pytest.raises(ValueError, match="sen2naip_taco_file"):
+ SEN2NAIP(cfg)
diff --git a/tests/test_models/test_SRGAN_train.py b/tests/test_models/test_SRGAN_train.py
index 8c94f0b..a3fb534 100644
--- a/tests/test_models/test_SRGAN_train.py
+++ b/tests/test_models/test_SRGAN_train.py
@@ -1,5 +1,6 @@
import types
+import pytest
import torch
from opensr_srgan.model import SRGAN
@@ -79,33 +80,23 @@ def _sample_batch():
return lr, hr
-def test_setup_lightning_selects_training_step_branches():
+def test_setup_lightning_configures_manual_step_for_pl2(monkeypatch):
+ monkeypatch.setattr(SRGAN.pl, "__version__", "2.2.0")
model = SRGAN.SRGAN_model.__new__(SRGAN.SRGAN_model)
- model.pl_version = (2, 0, 0)
- model.automatic_optimization = True
model.setup_lightning()
assert model.automatic_optimization is False
assert model._training_step_implementation.__func__ is training_step_PL.training_step_PL2
- model = SRGAN.SRGAN_model.__new__(SRGAN.SRGAN_model)
- model.pl_version = (1, 9, 0)
- model.automatic_optimization = True
- model.setup_lightning()
- assert model.automatic_optimization is True
- assert model._training_step_implementation.__func__ is training_step_PL.training_step_PL1
-
-def test_training_step_pl1_handles_pretraining_branch():
- harness = TrainingHarness(pretrain=True)
- loss = training_step_PL.training_step_PL1(harness, _sample_batch(), batch_idx=0, optimizer_idx=1)
- assert torch.is_tensor(loss)
- assert harness.logged["training/pretrain_phase"] == 1.0
- assert "generator/content_loss" in harness.logged
+def test_setup_lightning_rejects_pre_v2(monkeypatch):
+ monkeypatch.setattr(SRGAN.pl, "__version__", "1.9.5")
+ model = SRGAN.SRGAN_model.__new__(SRGAN.SRGAN_model)
+ with pytest.raises(RuntimeError, match="requires PyTorch Lightning >= 2.0"):
+ model.setup_lightning()
def test_training_step_pl2_runs_manual_optimization():
harness = TrainingHarness(pretrain=False)
- harness.pl_version = (2, 0, 0)
harness.automatic_optimization = False
loss = training_step_PL.training_step_PL2(harness, _sample_batch(), batch_idx=0)
diff --git a/tests/test_models/test_generators.py b/tests/test_models/test_generators.py
index d337421..db780cb 100644
--- a/tests/test_models/test_generators.py
+++ b/tests/test_models/test_generators.py
@@ -101,6 +101,7 @@ def test_conditional_alias_points_to_stochastic_generator():
"n_blocks": 23,
"scaling_factor": 4,
"growth_channels": 32,
+ "use_icnr": True,
},
ESRGANGenerator,
),
diff --git a/tests/test_training/test_training_step.py b/tests/test_training/test_training_step.py
index ad4407c..9135fbc 100644
--- a/tests/test_training/test_training_step.py
+++ b/tests/test_training/test_training_step.py
@@ -1,5 +1,6 @@
import types
+import pytest
import torch
from opensr_srgan.model import SRGAN
@@ -84,10 +85,9 @@ def _sample_batch():
return lr, hr
-def test_setup_lightning_selects_training_step_branches():
+def test_setup_lightning_configures_manual_step_for_pl2(monkeypatch):
+ monkeypatch.setattr(SRGAN.pl, "__version__", "2.2.0")
model = SRGAN.SRGAN_model.__new__(SRGAN.SRGAN_model)
- model.pl_version = (2, 0, 0)
- model.automatic_optimization = True
model.setup_lightning()
assert model.automatic_optimization is False
assert (
@@ -95,30 +95,16 @@ def test_setup_lightning_selects_training_step_branches():
is training_step_PL.training_step_PL2
)
- model = SRGAN.SRGAN_model.__new__(SRGAN.SRGAN_model)
- model.pl_version = (1, 9, 0)
- model.automatic_optimization = True
- model.setup_lightning()
- assert model.automatic_optimization is True
- assert (
- model._training_step_implementation.__func__
- is training_step_PL.training_step_PL1
- )
-
-def test_training_step_pl1_handles_pretraining_branch():
- harness = TrainingHarness(pretrain=True)
- loss = training_step_PL.training_step_PL1(
- harness, _sample_batch(), batch_idx=0, optimizer_idx=1
- )
- assert torch.is_tensor(loss)
- assert harness.logged["training/pretrain_phase"] == 1.0
- assert "generator/content_loss" in harness.logged
+def test_setup_lightning_rejects_pre_v2(monkeypatch):
+ monkeypatch.setattr(SRGAN.pl, "__version__", "1.9.5")
+ model = SRGAN.SRGAN_model.__new__(SRGAN.SRGAN_model)
+ with pytest.raises(RuntimeError, match="requires PyTorch Lightning >= 2.0"):
+ model.setup_lightning()
def test_training_step_pl2_runs_manual_optimization():
harness = TrainingHarness(pretrain=False)
- harness.pl_version = (2, 0, 0)
harness.automatic_optimization = False
loss = training_step_PL.training_step_PL2(harness, _sample_batch(), batch_idx=0)
diff --git a/tests/test_utils/test_trainer_kwargs.py b/tests/test_utils/test_trainer_kwargs.py
index b29b361..6099c39 100644
--- a/tests/test_utils/test_trainer_kwargs.py
+++ b/tests/test_utils/test_trainer_kwargs.py
@@ -3,9 +3,6 @@
import sys
import types
-import importlib.util
-from pathlib import Path
-
import pytest
from omegaconf import OmegaConf
@@ -29,7 +26,6 @@ def __init__(
log_every_n_steps=None,
logger=None,
callbacks=None,
- resume_from_checkpoint=None,
) -> None:
pass
@@ -92,8 +88,8 @@ def test_cpu_device_forces_single_device(monkeypatch):
assert fit_kwargs == {}
-def test_multi_gpu_enables_ddp(monkeypatch):
- """Multiple GPUs trigger the DDP strategy when CUDA is requested."""
+def test_multi_gpu_enables_ddp_with_unused_param_detection(monkeypatch):
+ """Multiple GPUs default to DDP with find-unused-parameters enabled."""
monkeypatch.setattr(
"opensr_srgan.utils.build_trainer_kwargs.torch.cuda.is_available",
@@ -104,7 +100,7 @@ def test_multi_gpu_enables_ddp(monkeypatch):
assert trainer_kwargs["accelerator"] == "gpu"
assert trainer_kwargs["devices"] == [0, 1]
- assert trainer_kwargs["strategy"] == "ddp"
+ assert trainer_kwargs["strategy"] == "ddp_find_unused_parameters_true"
def test_auto_device_respects_cuda_availability(monkeypatch):
@@ -131,57 +127,46 @@ def test_invalid_device_raises():
_call_builder(config)
-def test_integer_gpu_count_and_resume_for_pre_v2(monkeypatch):
- """PL<2 uses ``resume_from_checkpoint`` in Trainer kwargs."""
+def test_integer_gpu_count_enables_ddp_and_resume_ckpt_path():
+ """Integer GPU counts still enable DDP and resume via ``ckpt_path``."""
- monkeypatch.setattr("opensr_srgan.utils.build_trainer_kwargs.pl.__version__", "1.9.5")
- class _TrainerV1:
- def __init__(
- self,
- *,
- accelerator=None,
- strategy=None,
- devices=None,
- val_check_interval=None,
- limit_val_batches=None,
- max_epochs=None,
- log_every_n_steps=None,
- logger=None,
- callbacks=None,
- gradient_clip_val=None,
- resume_from_checkpoint=None,
- ) -> None:
- pass
+ config = _make_config(device="cuda", gpus=2)
+ trainer_kwargs, fit_kwargs = _call_builder(config, resume_ckpt="resume.ckpt")
+
+ assert trainer_kwargs["devices"] == 2
+ assert trainer_kwargs["strategy"] == "ddp_find_unused_parameters_true"
+ assert fit_kwargs == {"ckpt_path": "resume.ckpt"}
- def fit(self, *args, **kwargs):
- return None
+
+def test_multi_gpu_can_disable_unused_param_detection(monkeypatch):
+ """Config can opt out and use plain DDP strategy."""
monkeypatch.setattr(
- "opensr_srgan.utils.build_trainer_kwargs.pl.Trainer", _TrainerV1
+ "opensr_srgan.utils.build_trainer_kwargs.torch.cuda.is_available",
+ lambda: True,
)
- config = _make_config(device="cuda", gpus=2)
- trainer_kwargs, fit_kwargs = _call_builder(config, resume_ckpt="resume.ckpt")
+ config = _make_config(
+ device="cuda",
+ gpus=[0, 1],
+ find_unused_parameters=False,
+ )
+ trainer_kwargs, _ = _call_builder(config)
- assert trainer_kwargs["devices"] == 2
assert trainer_kwargs["strategy"] == "ddp"
- assert trainer_kwargs["resume_from_checkpoint"] == "resume.ckpt"
- assert fit_kwargs == {}
-def test_v2_resume_uses_fit_ckpt_path(monkeypatch):
+def test_v2_resume_uses_fit_ckpt_path():
"""PL>=2 forwards resume checkpoints via ``Trainer.fit(ckpt_path=...)``."""
- monkeypatch.setattr("opensr_srgan.utils.build_trainer_kwargs.pl.__version__", "2.2.0")
config = _make_config(device="cuda", gpus=[0])
_, fit_kwargs = _call_builder(config, resume_ckpt="resume.ckpt")
assert fit_kwargs == {"ckpt_path": "resume.ckpt"}
-def test_non_sequence_gpu_config_falls_back_to_single_device(monkeypatch):
+def test_non_sequence_gpu_config_falls_back_to_single_device():
"""Unexpected ``gpus`` values trigger the safe single-device fallback."""
- monkeypatch.setattr("opensr_srgan.utils.build_trainer_kwargs.pl.__version__", "2.2.0")
config = _make_config(device="cuda", gpus="not-a-device-list")
trainer_kwargs, _ = _call_builder(config)