diff --git a/CHANGELOG.md b/CHANGELOG.md
index eba8c81..3f4fa68 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+---
+
+## [0.4.5] - 2026-04-30
+
+### Added
+
+#### MLflow auto logging enhancements (`virtual_stain_flow/vsf_logging/`):
+
+- The logger now records model architecture tags by capturing each model config's `class_path` and setting `model.<idx>.class_path` at train start.
+- The loss-group auto logging routine `_log_loss_groups_config_and_tags` logs loss item names and weights as MLflow tags and persists the full loss group configuration as a JSON config artifact.
+
 ---
 
 ## [0.4.4] - 2026-04-23
diff --git a/pyproject.toml b/pyproject.toml
index a376133..1dfdd8a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "virtual_stain_flow"
-version = "0.4.0"
+version = "0.4.5"
 description = "For developing virtual staining models"
 requires-python = ">=3.9"
 dependencies = [
diff --git a/src/virtual_stain_flow/engine/loss_group.py b/src/virtual_stain_flow/engine/loss_group.py
index 16b3a74..09f3684 100644
--- a/src/virtual_stain_flow/engine/loss_group.py
+++ b/src/virtual_stain_flow/engine/loss_group.py
@@ -22,13 +22,14 @@
 """
 
 from dataclasses import dataclass
-from typing import Optional, Union, Tuple, Dict, Sequence, List
+from typing import Optional, Union, Tuple, Dict, Sequence, List, Any
 
 import torch
 
 from .loss_utils import BaseLoss, _get_loss_name, _scalar_from_ctx
 from .context import Context, ContextValue
 from .names import PREDS, TARGETS
+from .progress import Progress
 
 Scalar = Union[int, float, bool]
 
@@ -79,6 +80,7 @@ def __post_init__(self):
     def __call__(
         self,
         train: bool,
+        progress: Optional[Progress] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -88,6 +90,8 @@ def __call__(
         skipped during validation.
 
         :param train: Whether the model is in training mode.
+        :param progress: Optional Progress object containing scheduling state (epoch, step, etc.)
+            for dynamic weight scheduling.
         :param context: Optional Context object containing tensors.
         :param inputs: Keyword arguments containing all necessary inputs for the loss computation.
@@ -117,6 +121,21 @@ def __call__(
 
         return raw, raw * _scalar_from_ctx(self.weight, inputs)
 
+    def get_config(self) -> Dict[str, Any]:
+        """
+        Get the configuration of the LossItem for logging or checkpointing.
+        """
+        return {
+            'module': self.module.__class__.__name__,
+            'args': self.args,
+            'key': self.key,
+            'weight': self.weight,
+            'enabled': self.enabled,
+            'compute_at_val': self.compute_at_val,
+            'device': str(self.device)
+        }
+
+
 @dataclass
 class LossGroup:
     """
@@ -137,6 +156,7 @@ def item_names(self) -> List[Optional[str]]:
     def __call__(
         self,
         train: bool,
+        progress: Optional[Progress] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
@@ -144,6 +164,8 @@ def __call__(
         Compute the total loss and individual loss values.
 
         :param train: Whether the model is in training mode.
+        :param progress: Optional Progress object containing scheduling state (epoch, step, etc.)
+            for dynamic weight scheduling.
         :param context: Optional Context object containing tensors.
         :param inputs: Keyword arguments containing all necessary inputs for the loss computations.
@@ -156,8 +178,20 @@ def __call__(
         logs: Dict[str, float] = {}
 
         for item in self.items:
-            raw, weighted = item(train, context=context, **inputs)
+            raw, weighted = item(
+                train,
+                progress=progress,
+                context=context,
+                **inputs
+            )
             logs[item.key] = raw.item()  # type: ignore
             total += weighted
 
         return total, logs
+
+    def get_config(self) -> List[Dict[str, Any]]:
+        """
+        Get the configuration of the LossGroup for logging or checkpointing.
+        """
+
+        return [item.get_config() for item in self.items]
diff --git a/src/virtual_stain_flow/engine/progress.py b/src/virtual_stain_flow/engine/progress.py
new file mode 100644
index 0000000..e6024af
--- /dev/null
+++ b/src/virtual_stain_flow/engine/progress.py
@@ -0,0 +1,37 @@
+"""
+progress.py
+
+Progress tracking for loss weight scheduling.
+
+Provides a centralized abstraction for training progress state that can be used
+by loss schedulers operating at different granularities (epoch, step, etc.).
+Designed to be minimal and extensible without overcomplicating the current API.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class Progress:
+    """
+    Tracks training progress for loss weight scheduling.
+
+    Provides centralized access to scheduling state including epoch and step,
+    with room for future custom scheduling granularities.
+    """
+    epoch: int = 0
+    step: int = 0
+
+    def set_epoch(self, epoch: int) -> None:
+        """
+        Update the current epoch number.
+        """
+        self.epoch = epoch
+
+    def set_step(self, step: int) -> None:
+        """
+        Update the current step number.
+        Intended to be accumulated across epochs for a global step count
+        that can be used for step-based scheduling.
+        """
+        self.step = step
diff --git a/src/virtual_stain_flow/trainers/AbstractTrainer.py b/src/virtual_stain_flow/trainers/AbstractTrainer.py
index 6c64acf..36c25ae 100644
--- a/src/virtual_stain_flow/trainers/AbstractTrainer.py
+++ b/src/virtual_stain_flow/trainers/AbstractTrainer.py
@@ -14,6 +14,7 @@
 
 from .trainer_protocol import TrainerProtocol
 from ..metrics.AbstractMetrics import AbstractMetrics
+from ..engine.progress import Progress
 from ..datasets.data_split import default_random_split
 
 
@@ -113,6 +114,9 @@ def _init_state(
 
         # Epoch state
         self._epoch = 0
+
+        # Progress tracking for loss weight scheduling
+        self._progress = Progress(epoch=0, step=0)
 
         # Loss and metrics state
         self._train_losses = defaultdict(list)
@@ -232,6 +236,8 @@ def train_epoch(self):
                 phase="Train"
             )
 
+            self._progress.set_step(self._progress.step + 1)
+
             batch_loss = self.train_step(inputs, targets)
             for key, value in batch_loss.items():
                 losses[key].append(value)
@@ -512,6 +518,11 @@ def metrics(self):
     def epoch(self):
         return self._epoch
 
+    @property
+    def progress(self) -> Progress:
+        """Returns the Progress object tracking training state (epoch, step, etc.)"""
+        return self._progress
+
     @property
     def train_losses(self):
         return self._train_losses
@@ -548,6 +559,7 @@ def early_stop_counter(self, value: int):
     @epoch.setter
     def epoch(self, value: int):
         self._epoch = value
+        self._progress.set_epoch(value)
 
     """
     Update loss and metrics
diff --git a/src/virtual_stain_flow/trainers/logging_gan_trainer.py b/src/virtual_stain_flow/trainers/logging_gan_trainer.py
index 8e5fe69..d959457 100644
--- a/src/virtual_stain_flow/trainers/logging_gan_trainer.py
+++ b/src/virtual_stain_flow/trainers/logging_gan_trainer.py
@@ -110,6 +110,7 @@ def train_step(
             )
             disc_weighted_total, disc_logs = self._discriminator_loss_group(
                 train=True,
+                progress=self.progress,
                 context=disc_ctx
             )
             disc_weighted_total.backward()
@@ -133,6 +134,7 @@ def train_step(
             )
             gen_weighted_total, gen_logs = self._generator_loss_group(
                 train=True,
+                progress=self.progress,
                 context=gen_ctx
             )
             gen_weighted_total.backward()
@@ -142,12 +144,14 @@ def train_step(
             gen_logs = {}
 
         self._global_step += 1
+        self._progress.set_step(self._global_step)
 
         # if generator logs are not computed this step (due to skipped update),
         # compute from discriminator context
         if not gen_logs:
             _, gen_logs = self._generator_loss_group(
                 train=True,
+                progress=self.progress,
                 context=ctx
             )
 
@@ -176,10 +180,12 @@ def evaluate_step(
         )
         _, gen_logs = self._generator_loss_group(
             train=False,
+            progress=self.progress,
             context=ctx
         )
         _, disc_logs = self._discriminator_loss_group(
             train=False,
+            progress=self.progress,
             context=ctx
         )
 
@@ -187,6 +193,13 @@ def evaluate_step(
             metric.update(*ctx.as_metric_args(), validation=True)
 
         return gen_logs | disc_logs
+
+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {
+            'generator': self._generator_loss_group,
+            'discriminator': self._discriminator_loss_group
+        }
 
     def save_model(
         self,
diff --git a/src/virtual_stain_flow/trainers/logging_trainer.py b/src/virtual_stain_flow/trainers/logging_trainer.py
index 6442dea..ff20c8d 100644
--- a/src/virtual_stain_flow/trainers/logging_trainer.py
+++ b/src/virtual_stain_flow/trainers/logging_trainer.py
@@ -105,7 +105,11 @@ def train_step(
             targets=targets
         )
 
-        weighted_total, logs = self._loss_group(train=True, context=ctx)
+        weighted_total, logs = self._loss_group(
+            train=True,
+            progress=self.progress,
+            context=ctx
+        )
         weighted_total.backward()
 
         self._forward_group.step()
@@ -133,12 +137,20 @@ def evaluate_step(
             targets=targets
         )
 
-        _, logs = self._loss_group(train=False, context=ctx)
+        _, logs = self._loss_group(
+            train=False,
+            progress=self.progress,
+            context=ctx
+        )
 
         for _, metric in self.metrics.items():
             metric.update(*ctx.as_metric_args(), validation=True)
 
         return logs
+
+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {'main': self._loss_group}
 
     def save_model(
         self,
diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
index a46c61d..a9164da 100644
--- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
+++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
@@ -204,7 +204,7 @@ def on_train_start(self):
         else:
             models = []
 
-        for model in models:
+        for idx, model in enumerate(models):
             if isinstance(model, BaseModel) and hasattr(model, 'to_config'):
 
                 try:
@@ -213,6 +213,12 @@ def on_train_start(self):
                     print(f"Could not get model config for logging: {e}")
                     config = None
                 if config:
+                    class_path = config.get("class_path")
+                    if class_path:
+                        mlflow.set_tag(
+                            f"model.{idx}.class_path",
+                            str(class_path)
+                        )
                     try:
                         self.log_config(
                             tag=model.__class__.__name__,
@@ -222,6 +228,8 @@ def on_train_start(self):
                     except Exception as e:
                         print(f"Failed to log model config as artifact: {e}")
 
+        self._log_loss_groups_config_and_tags()
+
         for callback in self.callbacks:
 
             # TODO consider if we want hasattr checks
@@ -502,6 +510,93 @@ def _save_model_weights(
             artifact_path=artifact_path
         )
 
+    def _get_loss_groups(self) -> Dict[str, Any]:
+        """
+        Discover loss groups attached to the bound trainer.
+        """
+
+        if self.trainer is None:
+            return {}
+
+        loss_groups: Dict[str, Any] = {}
+
+        explicit_groups = getattr(self.trainer, 'loss_groups', None)
+        if isinstance(explicit_groups, dict):
+            for group_name, group in explicit_groups.items():
+                if hasattr(group, 'get_config'):
+                    loss_groups[str(group_name)] = group
+
+        fallback_attrs = {
+            'main': '_loss_group',
+            'generator': '_generator_loss_group',
+            'discriminator': '_discriminator_loss_group'
+        }
+        for group_name, attr in fallback_attrs.items():
+            if group_name in loss_groups:
+                continue
+            group = getattr(self.trainer, attr, None)
+            if group is not None and hasattr(group, 'get_config'):
+                loss_groups[group_name] = group
+
+        return loss_groups
+
+    def _log_loss_groups_config_and_tags(self) -> None:
+        """
+        Log loss item names and weights as flattened string MLflow tags, and
+        the full loss group configuration (loss name, loss weight, whether the
+        loss is active during validation, etc.) as MLflow config artifacts.
+        """
+
+        loss_groups = self._get_loss_groups()
+        if not loss_groups:
+            return None
+
+        for group_name, group in loss_groups.items():
+            try:
+                group_config = group.get_config()
+            except Exception as e:
+                print(
+                    f"Could not get loss group config for logging "
+                    f"({group_name}): {e}"
+                )
+                continue
+
+            if not isinstance(group_config, list):
+                continue
+
+            for idx, item_cfg in enumerate(group_config):
+                if not isinstance(item_cfg, dict):
+                    continue
+
+                if 'key' in item_cfg and item_cfg['key'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.name",
+                        str(item_cfg['key'])
+                    )
+
+                if 'weight' in item_cfg and item_cfg['weight'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.weight",
+                        str(item_cfg['weight'])
+                    )
+
+            try:
+                self.log_config(
+                    tag=f"loss_group_{group_name}",
+                    config={
+                        'group_name': group_name,
+                        'items': group_config
+                    },
+                    stage=None
+                )
+            except Exception as e:
+                print(
+                    f"Failed to log loss group config as artifact "
+                    f"({group_name}): {e}"
+                )
+
+        return None
+
     def log_config(
         self,
         tag: str,
diff --git a/tests/conftest.py b/tests/conftest.py
index 5859a44..b635c69 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,16 +2,21 @@
 Testing fixtures meant to be shared across the whole package
 """
 
+import json
+import importlib
 import pathlib
+from types import SimpleNamespace
 
 import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset
 
+from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer
+from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
 from virtual_stain_flow.vsf_logging import MlflowLogger
 
 
-# ----- Mock virtual_stain_flow components ----- #
+# ----- Logger test doubles ----- #
 
 class DummyLogger(MlflowLogger):
     """
@@ -75,6 +80,8 @@ def dummy_logger():
     return DummyLogger()
 
 
+# ----- Model/optimizer fixtures ----- #
+
 class MockModelWithSaveWeights(torch.nn.Module):
     """
     Mock model that implements save_weights method for testing.
@@ -115,7 +122,7 @@ def mock_optimizer(mock_model_with_save):
     return torch.optim.Adam(mock_model_with_save.parameters(), lr=0.001)
 
 
-# ----- Fixtures for simulating minimal training ----- #
+# ----- Dataset/dataloader fixtures ----- #
 
 class MinimalDataset(Dataset):
     """Minimal torch.utils.data.Dataset to test training."""
@@ -222,7 +229,6 @@ def empty_dataloader():
     return DataLoader(dataset, batch_size=2, shuffle=False)
 
 
-
 @pytest.fixture
 def image_train_loader(image_dataset):
     """Create a train dataloader with image data."""
@@ -243,6 +249,8 @@ def image_val_loader(image_dataset):
     return DataLoader(val_dataset, batch_size=4, shuffle=False)
 
 
+# ----- Generic training fixtures ----- #
+
 @pytest.fixture
 def simple_loss():
     """Create a simple MSE loss function."""
@@ -278,3 +286,291 @@ def mock_metric():
 def dataset_for_splitting():
     """Create a larger dataset suitable for train/val/test splitting."""
     return MinimalDataset(num_samples=100, input_size=4, target_size=2)
+
+
+# ----- Trainer fixtures ----- #
+
+class MinimalTrainerRealization(AbstractTrainer):
+    """
+    Minimal concrete realization of AbstractTrainer for testing.
+    Tracks method calls and provides controllable step behavior.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.train_step_calls = []
+        self.evaluate_step_calls = []
+        self.on_epoch_start_called = False
+        self.on_epoch_end_called = False
+
+        class DummyProgressBar:
+            def set_postfix_str(self, *args, **kwargs):
+                pass
+
+        self._epoch_pbar = DummyProgressBar()  # type: ignore
+
+    def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict:
+        self.train_step_calls.append({
+            'inputs_shape': inputs.shape,
+            'targets_shape': targets.shape,
+        })
+
+        return {
+            'loss_a': torch.tensor(0.5),
+            'loss_b': torch.tensor(0.3),
+        }
+
+    def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict:
+        self.evaluate_step_calls.append({
+            'inputs_shape': inputs.shape,
+            'targets_shape': targets.shape,
+        })
+
+        return {
+            'loss_a': torch.tensor(0.4),
+            'loss_b': torch.tensor(0.2),
+        }
+
+    def save_model(self, save_path, file_name_prefix=None, file_name_suffix=None,
+                   file_ext='.pth', best_model=True):
+        return None
+
+
+@pytest.fixture
+def minimal_trainer_cls():
+    """Expose the minimal concrete trainer class for tests needing custom init."""
+    return MinimalTrainerRealization
+
+
+@pytest.fixture
+def trainer_with_loaders(minimal_model, minimal_optimizer, train_dataloader, val_dataloader):
+    """
+    Create a MinimalTrainerRealization with train and validation loaders.
+    """
+    trainer = MinimalTrainerRealization(
+        model=minimal_model,
+        optimizer=minimal_optimizer,
+        train_loader=train_dataloader,
+        val_loader=val_dataloader,
+        batch_size=2,
+        device=torch.device('cpu')
+    )
+    return trainer
+
+
+@pytest.fixture
+def trainer_with_empty_val_loader(minimal_model, minimal_optimizer, train_dataloader, empty_dataloader):
+    """
+    Create a MinimalTrainerRealization with an empty validation loader.
+    """
+    trainer = MinimalTrainerRealization(
+        model=minimal_model,
+        optimizer=minimal_optimizer,
+        train_loader=train_dataloader,
+        val_loader=empty_dataloader,
+        batch_size=2,
+        device=torch.device('cpu')
+    )
+    return trainer
+
+
+@pytest.fixture
+def single_generator_trainer(minimal_model, minimal_optimizer, simple_loss, train_dataloader, val_dataloader):
+    """
+    Create a SingleGeneratorTrainer with a single loss function.
+    """
+    trainer = SingleGeneratorTrainer(
+        model=minimal_model,
+        optimizer=minimal_optimizer,
+        losses=simple_loss,
+        device=torch.device('cpu'),
+        train_loader=train_dataloader,
+        val_loader=val_dataloader,
+        batch_size=2
+    )
+    return trainer
+
+
+@pytest.fixture
+def multi_loss_trainer(minimal_model, minimal_optimizer, multiple_losses, train_dataloader, val_dataloader):
+    """
+    Create a SingleGeneratorTrainer with multiple loss functions.
+    """
+    trainer = SingleGeneratorTrainer(
+        model=minimal_model,
+        optimizer=minimal_optimizer,
+        losses=multiple_losses,
+        device=torch.device('cpu'),
+        loss_weights=[0.5, 0.5],
+        train_loader=train_dataloader,
+        val_loader=val_dataloader,
+        batch_size=2
+    )
+    return trainer
+
+
+@pytest.fixture
+def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, image_val_loader):
+    """
+    Create a SingleGeneratorTrainer with conv model for full training tests.
+    """
+    trainer = SingleGeneratorTrainer(
+        model=conv_model,
+        optimizer=conv_optimizer,
+        losses=simple_loss,
+        device=torch.device('cpu'),
+        train_loader=image_train_loader,
+        val_loader=image_val_loader,
+        batch_size=4,
+        early_termination_metric='MSELoss'
+    )
+    return trainer
+
+
+@pytest.fixture
+def simple_discriminator():
+    """
+    Simple discriminator model for GAN testing.
+    Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1)
+    """
+    import torch.nn as nn
+
+    class SimpleDiscriminator(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1)
+            self.pool = nn.AdaptiveAvgPool2d(1)
+            self.fc = nn.Linear(16, 1)
+
+        def forward(self, x):
+            x = torch.relu(self.conv(x))
+            x = self.pool(x).flatten(1)
+            return self.fc(x)
+
+    return SimpleDiscriminator()
+
+
+@pytest.fixture
+def discriminator_optimizer(simple_discriminator):
+    """Create an optimizer for the discriminator."""
+    return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001)
+
+
+@pytest.fixture
+def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer,
+                 simple_loss, image_train_loader, image_val_loader):
+    """
+    Create a LoggingWGANTrainer for testing.
+    """
+    from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer
+
+    trainer = LoggingWGANTrainer(
+        generator=conv_model,
+        discriminator=simple_discriminator,
+        generator_optimizer=conv_optimizer,
+        discriminator_optimizer=discriminator_optimizer,
+        generator_losses=simple_loss,
+        device=torch.device('cpu'),
+        train_loader=image_train_loader,
+        val_loader=image_val_loader,
+        batch_size=4,
+        n_discriminator_steps=3
+    )
+    return trainer
+
+
+# ----- MLflow patch fixture ----- #
+
+@pytest.fixture
+def patched_mlflow(monkeypatch):
+    """Patch MLflow module methods used by MlflowLogger and capture calls."""
+
+    captured = {
+        'tags': {},
+        'artifacts': [],
+        'active_run_id': None,
+    }
+
+    mlflow_logger_module = importlib.import_module(
+        'virtual_stain_flow.vsf_logging.MlflowLogger'
+    )
+
+    def fake_get_experiment_by_name(_name):
+        return None
+
+    def fake_create_experiment(_name):
+        return 'exp-1'
+
+    def fake_start_run(*args, **kwargs):
+        run_id = 'run-123'
+        captured['active_run_id'] = run_id
+        return SimpleNamespace(info=SimpleNamespace(run_id=run_id))
+
+    def fake_active_run():
+        run_id = captured['active_run_id']
+        if run_id is None:
+            return None
+        return SimpleNamespace(info=SimpleNamespace(run_id=run_id))
+
+    def fake_end_run():
+        captured['active_run_id'] = None
+
+    def fake_set_tag(key, value):
+        captured['tags'][key] = value
+
+    def fake_log_artifact(file_path, artifact_path=None):
+        file_content = None
+        try:
+            with open(file_path, 'r', encoding='utf-8') as f:
+                file_content = json.load(f)
+        except Exception:
+            file_content = None
+
+        captured['artifacts'].append({
+            'file_path': file_path,
+            'artifact_path': artifact_path,
+            'content': file_content,
+        })
+
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'get_experiment_by_name',
+        fake_get_experiment_by_name,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'create_experiment',
+        fake_create_experiment,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'start_run',
+        fake_start_run,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'active_run',
+        fake_active_run,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'end_run',
+        fake_end_run,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'set_tag',
+        fake_set_tag,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'log_artifact',
+        fake_log_artifact,
+    )
+    monkeypatch.setattr(
+        mlflow_logger_module.mlflow,
+        'log_params',
+        lambda *_args, **_kwargs: None,
+    )
+
+    return captured
diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py
deleted file mode 100644
index cd5e443..0000000
--- a/tests/trainers/conftest.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""
-Fixtures for trainer tests
-"""
-
-import pytest
-import torch
-
-from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer
-from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
-
-
-class MinimalTrainerRealization(AbstractTrainer):
-    """
-    Minimal concrete realization of AbstractTrainer for testing.
-    Tracks method calls and provides controllable step behavior.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Track method calls for testing
-        self.train_step_calls = []
-        self.evaluate_step_calls = []
-        self.on_epoch_start_called = False
-        self.on_epoch_end_called = False
-
-        # Create a dummy progress bar property that does nothing beyond
-        # allowing set_postfix_str calls
-        class DummyProgressBar:
-            def set_postfix_str(self, *args, **kwargs):
-                pass
-
-        self._epoch_pbar = DummyProgressBar()  # type: ignore
-
-    def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict:
-        """
-        Minimal train step that returns a dict of losses.
-        Stores call information for verification.
-        """
-
-        self.train_step_calls.append({
-            'inputs_shape': inputs.shape,
-            'targets_shape': targets.shape,
-        })
-
-        # Return scalar tensor losses (simulating real losses)
-        return {
-            'loss_a': torch.tensor(0.5),
-            'loss_b': torch.tensor(0.3),
-        }
-
-    def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict:
-        """
-        Minimal evaluate step that returns a dict of losses.
-        Stores call information for verification.
-        """
-
-        self.evaluate_step_calls.append({
-            'inputs_shape': inputs.shape,
-            'targets_shape': targets.shape,
-        })
-
-        # Return scalar tensor losses (simulating real losses)
-        return {
-            'loss_a': torch.tensor(0.4),
-            'loss_b': torch.tensor(0.2),
-        }
-
-    def save_model(self, save_path, file_name_prefix=None, file_name_suffix=None,
-                   file_ext='.pth', best_model=True):
-        """Minimal save_model implementation."""
-        return None
-
-
-@pytest.fixture
-def trainer_with_loaders(minimal_model, minimal_optimizer, train_dataloader, val_dataloader):
-    """
-    Create a MinimalTrainerRealization with train and validation loaders.
-    """
-    trainer = MinimalTrainerRealization(
-        model=minimal_model,
-        optimizer=minimal_optimizer,
-        train_loader=train_dataloader,
-        val_loader=val_dataloader,
-        batch_size=2,
-        device=torch.device('cpu')
-    )
-    return trainer
-
-
-@pytest.fixture
-def trainer_with_empty_val_loader(minimal_model, minimal_optimizer, train_dataloader, empty_dataloader):
-    """
-    Create a MinimalTrainerRealization with empty validation loader.
-    """
-    trainer = MinimalTrainerRealization(
-        model=minimal_model,
-        optimizer=minimal_optimizer,
-        train_loader=train_dataloader,
-        val_loader=empty_dataloader,
-        batch_size=2,
-        device=torch.device('cpu')
-    )
-    return trainer
-
-
-@pytest.fixture
-def single_generator_trainer(minimal_model, minimal_optimizer, simple_loss, train_dataloader, val_dataloader):
-    """
-    Create a SingleGeneratorTrainer with a single loss function.
-    """
-    trainer = SingleGeneratorTrainer(
-        model=minimal_model,
-        optimizer=minimal_optimizer,
-        losses=simple_loss,
-        device=torch.device('cpu'),
-        train_loader=train_dataloader,
-        val_loader=val_dataloader,
-        batch_size=2
-    )
-    return trainer
-
-
-@pytest.fixture
-def multi_loss_trainer(minimal_model, minimal_optimizer, multiple_losses, train_dataloader, val_dataloader):
-    """
-    Create a SingleGeneratorTrainer with multiple loss functions.
-    """
-    trainer = SingleGeneratorTrainer(
-        model=minimal_model,
-        optimizer=minimal_optimizer,
-        losses=multiple_losses,
-        device=torch.device('cpu'),
-        loss_weights=[0.5, 0.5],
-        train_loader=train_dataloader,
-        val_loader=val_dataloader,
-        batch_size=2
-    )
-    return trainer
-
-
-@pytest.fixture
-def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, image_val_loader):
-    """
-    Create a SingleGeneratorTrainer with conv model for full training tests.
-    """
-    trainer = SingleGeneratorTrainer(
-        model=conv_model,
-        optimizer=conv_optimizer,
-        losses=simple_loss,
-        device=torch.device('cpu'),
-        train_loader=image_train_loader,
-        val_loader=image_val_loader,
-        batch_size=4,
-        early_termination_metric='MSELoss'
-    )
-    return trainer
-
-
-@pytest.fixture
-def simple_discriminator():
-    """
-    Simple discriminator model for GAN testing.
-    Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1)
-    """
-    import torch.nn as nn
-
-    class SimpleDiscriminator(nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1)
-            self.pool = nn.AdaptiveAvgPool2d(1)
-            self.fc = nn.Linear(16, 1)
-
-        def forward(self, x):
-            x = torch.relu(self.conv(x))
-            x = self.pool(x).flatten(1)
-            return self.fc(x)
-
-    return SimpleDiscriminator()
-
-
-@pytest.fixture
-def discriminator_optimizer(simple_discriminator):
-    """Create an optimizer for the discriminator."""
-    return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001)
-
-
-@pytest.fixture
-def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer,
-                 simple_loss, image_train_loader, image_val_loader):
-    """
-    Create a LoggingWGANTrainer for testing.
-    """
-    from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer
-
-    trainer = LoggingWGANTrainer(
-        generator=conv_model,
-        discriminator=simple_discriminator,
-        generator_optimizer=conv_optimizer,
-        discriminator_optimizer=discriminator_optimizer,
-        generator_losses=simple_loss,
-        device=torch.device('cpu'),
-        train_loader=image_train_loader,
-        val_loader=image_val_loader,
-        batch_size=4,
-        n_discriminator_steps=3
-    )
-    return trainer
diff --git a/tests/trainers/test_abstract_trainer.py b/tests/trainers/test_abstract_trainer.py
index 5d9555a..62d6e38 100644
--- a/tests/trainers/test_abstract_trainer.py
+++ b/tests/trainers/test_abstract_trainer.py
@@ -5,7 +5,12 @@
 
 import pytest
 import torch
-from conftest import MinimalTrainerRealization
+
+
+@pytest.fixture(autouse=True)
+def _bind_minimal_trainer_cls(minimal_trainer_cls):
+    """Bind concrete trainer class from fixture to avoid direct conftest imports."""
+    global MinimalTrainerRealization
+    MinimalTrainerRealization = minimal_trainer_cls
 
 
 class TestTrainEpochBatchIteration:
diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_unet.py b/tests/vsf_logging/test_mlflow_logger_model_config_unet.py
new file mode 100644
index 0000000..9b5e8f0
--- /dev/null
+++ b/tests/vsf_logging/test_mlflow_logger_model_config_unet.py
@@ -0,0 +1,97 @@
+"""
+Tests for UNet model config logging in MlflowLogger.
+"""
+
+import itertools
+
+import pytest
+import torch
+
+from virtual_stain_flow.models.unet import UNet
+from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
+from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
+
+UNET_VARIANTS = list(
+    itertools.product(
+        ["conv", "maxpool"],
+        ["convt", "bilinear"],
+        [4, 5],
+        [32, 64],
+    )
+)
+
+
+@pytest.mark.parametrize(
+    "encoder_down_block,decoder_up_block,depth,base_channels",
+    UNET_VARIANTS,
+)
+def test_on_train_start_logs_unet_model_config_and_loss_items(
+    patched_mlflow,
+    simple_loss,
+    train_dataloader,
+    val_dataloader,
+    encoder_down_block,
+    decoder_up_block,
+    depth,
+    base_channels,
+):
+    model = UNet(
+        in_channels=1,
+        out_channels=1,
+        base_channels=base_channels,
+        depth=depth,
+        encoder_down_block=encoder_down_block,
+        decoder_up_block=decoder_up_block,
+        act_type="sigmoid",
+        _num_units=2,
+    )
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    trainer = SingleGeneratorTrainer(
+        model=model,
+        optimizer=optimizer,
+        losses=simple_loss,
+        device=torch.device("cpu"),
+        train_loader=train_dataloader,
+        val_loader=val_dataloader,
+        batch_size=2,
+    )
+
+    logger = MlflowLogger(
+        name="logger",
+        experiment_name="exp",
+    )
+    logger.bind_trainer(trainer)
+
+    logger.on_train_start()
+
+    captured = patched_mlflow
+
+    model_artifacts = [
+        artifact
+        for artifact in captured["artifacts"]
+        if artifact["content"] is not None
+        and "class_path" in artifact["content"]
+    ]
+    assert model_artifacts
+
+    model_config = next(
+        artifact["content"]
+        for artifact in model_artifacts
+        if artifact["content"]["class_path"].endswith("UNet")
+    )
+    assert "init" in model_config
+
+    assert captured["tags"]["model.0.class_path"].endswith("UNet")
+
+    assert captured["tags"]["loss.main.0.name"] == "MSELoss"
+    assert captured["tags"]["loss.main.0.weight"] == "1.0"
+
+    loss_group_artifacts = [
+        artifact
+        for artifact in captured["artifacts"]
+        if artifact["content"] is not None
+        and artifact["content"].get("group_name") == "main"
+    ]
+    assert len(loss_group_artifacts) == 1
+
+    logger.end_run()
diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_unext.py b/tests/vsf_logging/test_mlflow_logger_model_config_unext.py
new file mode 100644
index 0000000..f9d1cad
--- /dev/null
+++ b/tests/vsf_logging/test_mlflow_logger_model_config_unext.py
@@ -0,0 +1,91 @@
+"""
+Tests for ConvNeXtUNet model config logging in MlflowLogger.
+"""
+
+import itertools
+
+import pytest
+import torch
+
+from virtual_stain_flow.models.unext import ConvNeXtUNet
+from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
+from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
+
+UNEXT_VARIANTS = list(
+    itertools.product(
+        ["pixelshuffle", "convt"],
+        ["convnext", "conv2d"],
+    )
+)
+
+
+@pytest.mark.parametrize(
+    "decoder_up_block,decoder_compute_block",
+    UNEXT_VARIANTS,
+)
+def test_on_train_start_logs_unext_model_config_and_loss_items(
+    patched_mlflow,
+    simple_loss,
+    train_dataloader,
+    val_dataloader,
+    decoder_up_block,
+    decoder_compute_block,
+):
+    model = ConvNeXtUNet(
+        in_channels=1,
+        out_channels=1,
+        decoder_up_block=decoder_up_block,
+        decoder_compute_block=decoder_compute_block,
+        act_type="sigmoid",
+        _num_units=2,
+    )
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    trainer = SingleGeneratorTrainer(
+        model=model,
+        optimizer=optimizer,
+        losses=simple_loss,
+        device=torch.device("cpu"),
+        train_loader=train_dataloader,
+        val_loader=val_dataloader,
+        batch_size=2,
+    )
+
+    logger = MlflowLogger(
+        name="logger",
+        experiment_name="exp",
+    )
+    logger.bind_trainer(trainer)
+
+    logger.on_train_start()
+
+    captured = patched_mlflow
+
+    model_artifacts = [
+        artifact
+        for artifact in captured["artifacts"]
+        if artifact["content"] is not None
+        and "class_path" in artifact["content"]
+    ]
+    assert model_artifacts
+
+    model_config = next(
+        artifact["content"]
+        for artifact in model_artifacts
+        if artifact["content"]["class_path"].endswith("ConvNeXtUNet")
+    )
+    assert "init" in model_config
+
+    assert captured["tags"]["model.0.class_path"].endswith("ConvNeXtUNet")
+
+    assert captured["tags"]["loss.main.0.name"] == "MSELoss"
+    assert captured["tags"]["loss.main.0.weight"] == "1.0"
+
+    loss_group_artifacts = [
+        artifact
+        for artifact in captured["artifacts"]
+        if artifact["content"] is not None
+        and artifact["content"].get("group_name") == "main"
+    ]
+    assert len(loss_group_artifacts) == 1
+
+    logger.end_run()
diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py b/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py
new file mode 100644
index 0000000..612861c
--- /dev/null
+++ b/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py
@@ -0,0 +1,119 @@
+"""
+Tests for WGAN model config logging in MlflowLogger with UNet generator.
+"""
+
+import itertools
+
+import pytest
+import torch
+
+from virtual_stain_flow.models.discriminator import GlobalDiscriminator
+from virtual_stain_flow.models.unet import UNet
+from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer
+from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
+
+UNET_VARIANTS = list(
+    itertools.product(
+        ["conv", "maxpool"],
+        ["convt", "bilinear"],
+        [4, 5],
+        [32, 64],
+    )
+)
+
+
+@pytest.mark.parametrize(
+    "encoder_down_block,decoder_up_block,depth,base_channels",
+    UNET_VARIANTS,
+)
+def test_on_train_start_logs_wgan_unet_model_configs_and_loss_items(
+    patched_mlflow,
+    simple_loss,
+    train_dataloader,
+    val_dataloader,
+    encoder_down_block,
+    decoder_up_block,
+    depth,
+    base_channels,
+):
+    generator = UNet(
+        in_channels=1,
+        out_channels=1,
+        base_channels=base_channels,
+        depth=depth,
+        encoder_down_block=encoder_down_block,
+        decoder_up_block=decoder_up_block,
+        act_type="sigmoid",
+        _num_units=2,
+    )
+    discriminator = GlobalDiscriminator(
+        n_in_channels=1,
+        n_in_filters=1,
+        out_activation=None,
+        _conv_depth=4,
+        _leaky_relu_alpha=0.2,
+        _batch_norm=False,
+        _pool_before_fc=False,
+    )
+    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)
+    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
+
+    trainer = LoggingWGANTrainer(
+        generator=generator,
+        discriminator=discriminator,
+        generator_optimizer=generator_optimizer,
+        discriminator_optimizer=discriminator_optimizer,
+        generator_losses=simple_loss,
+        device=torch.device("cpu"),
+        train_loader=train_dataloader,
+        val_loader=val_dataloader,
+        batch_size=2,
+        n_discriminator_steps=3,
+    )
+
+    logger = MlflowLogger(
+        name="logger",
+        experiment_name="exp",
+    )
+    logger.bind_trainer(trainer)
+
+    logger.on_train_start()
+
+    captured = patched_mlflow
+
+    model_artifacts = [
+        artifact
+        for artifact in captured["artifacts"]
+        if artifact["content"] is not None
+        and "class_path" in artifact["content"]
+    ]
+    assert model_artifacts
+
+    generator_config = next(
+        artifact["content"]
+        for artifact in model_artifacts
+        if artifact["content"]["class_path"].endswith("UNet")
+    )
+    assert "init" in generator_config
+
+    discriminator_config = next(
+        artifact["content"]
+        for artifact in model_artifacts
+        if artifact["content"]["class_path"].endswith("GlobalDiscriminator")
+    )
+    assert "init" in discriminator_config
+
+    assert captured["tags"]["model.0.class_path"].endswith("UNet")
+    assert captured["tags"]["model.1.class_path"].endswith("GlobalDiscriminator")
+
+    assert captured["tags"]["loss.generator.0.name"] == "MSELoss"
+    assert captured["tags"]["loss.generator.0.weight"] == "1.0"
+    assert captured["tags"]["loss.generator.1.name"] == "AdversarialLoss"
+    assert captured["tags"]["loss.generator.1.weight"] == "1.0"
+
+    assert captured["tags"]["loss.discriminator.0.name"] == "WassersteinLoss"
+    assert captured["tags"]["loss.discriminator.0.weight"] == "1.0"
+    assert captured["tags"]["loss.discriminator.1.name"] == "GradientPenaltyLoss"
+    assert captured["tags"]["loss.discriminator.1.weight"] == "10.0"
+
+    logger.end_run()
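
Note on the Progress plumbing above: AbstractTrainer owns the Progress instance, the `epoch` setter keeps `Progress.epoch` in sync, and the train loops accumulate a global `Progress.step`; `LossGroup.__call__` then threads it into each LossItem. LossItem does not yet consume `progress` in this patch, so the following is only a sketch of the kind of step-based weight schedule the object is meant to enable (`linear_warmup` is hypothetical, not part of the diff):

    from virtual_stain_flow.engine.progress import Progress

    def linear_warmup(progress: Progress, warmup_steps: int = 1000) -> float:
        # Hypothetical schedule: ramp a loss weight from 0 to 1 over the
        # first `warmup_steps` global steps, then hold it at 1.
        return min(progress.step / warmup_steps, 1.0)

    progress = Progress()      # defaults: epoch=0, step=0
    progress.set_epoch(2)      # kept in sync by AbstractTrainer's `epoch` setter
    progress.set_step(500)     # global step, accumulated across epochs
    assert linear_warmup(progress) == 0.5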
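The logger's `_get_loss_groups` first honors an explicit `loss_groups` mapping on the trainer and only then falls back to the private `_loss_group` / `_generator_loss_group` / `_discriminator_loss_group` attributes, so a trainer outside this diff can opt in to the loss auto logging with a one-line property. A minimal sketch (`MyTrainer` and its `_loss_group` attribute are hypothetical):

    from typing import Dict

    from virtual_stain_flow.engine.loss_group import LossGroup
    from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer

    class MyTrainer(AbstractTrainer):
        # Exposing `loss_groups` makes MlflowLogger tag this group's item
        # names/weights and persist its config as a JSON artifact.
        @property
        def loss_groups(self) -> Dict[str, LossGroup]:
            return {'main': self._loss_group}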
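For the SingleGeneratorTrainer-plus-MSE setup exercised by the tests, the run metadata comes out roughly as below. The tag values follow directly from `_log_loss_groups_config_and_tags` and the artifact shape from `LossItem.get_config` / `LossGroup.get_config`; the `class_path` prefix and the `args`/`device` values are illustrative (the tests only assert suffixes and the main-group tags):

    tags = {
        "model.0.class_path": "virtual_stain_flow.models.unet.UNet",  # prefix illustrative
        "loss.main.0.name": "MSELoss",
        "loss.main.0.weight": "1.0",
    }

    # JSON artifact written via log_config(tag="loss_group_main", ...)
    loss_group_main = {
        "group_name": "main",
        "items": [{
            "module": "MSELoss",     # self.module.__class__.__name__
            "args": {},              # illustrative
            "key": "MSELoss",
            "weight": 1.0,
            "enabled": True,
            "compute_at_val": True,  # also computed during validation
            "device": "cpu",         # str(self.device); illustrative
        }],
    }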