From 5cbdf05be60ce5475ddb0d9a8f54c39fc5d190cf Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 25 Apr 2026 12:09:57 -0600
Subject: [PATCH 01/12] Refactored current loss group by 1) adding a currently
 unused but potentially useful optional parameter, epoch, that will help
 future addition of loss weight scheduling functionality; and 2) adding
 get_config functions for auto loss config logging

---
 src/virtual_stain_flow/engine/loss_group.py | 35 +++++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/src/virtual_stain_flow/engine/loss_group.py b/src/virtual_stain_flow/engine/loss_group.py
index 16b3a74..72e79c3 100644
--- a/src/virtual_stain_flow/engine/loss_group.py
+++ b/src/virtual_stain_flow/engine/loss_group.py
@@ -22,7 +22,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Optional, Union, Tuple, Dict, Sequence, List
+from typing import Optional, Union, Tuple, Dict, Sequence, List, Any
 
 import torch
 
@@ -79,6 +79,7 @@ def __post_init__(self):
     def __call__(
         self,
         train: bool,
+        epoch: Optional[int] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -88,6 +89,7 @@ def __call__(
         skipped during validation.
 
         :param train: Whether the model is in training mode.
+        :param epoch: Optional epoch number to determine the weight from the schedule.
         :param context: Optional Context object containing tensors.
         :param inputs: Keyword arguments containing all necessary inputs
             for the loss computation.
@@ -117,6 +119,21 @@ def __call__(
 
         return raw, raw * _scalar_from_ctx(self.weight, inputs)
 
+    def get_config(self) -> Dict[str, Any]:
+        """
+        Get the configuration of the LossItem for logging or checkpointing.
+        """
+        return {
+            'module': self.module.__class__.__name__,
+            'args': self.args,
+            'key': self.key,
+            'weight': self.weight,
+            'enabled': self.enabled,
+            'compute_at_val': self.compute_at_val,
+            'device': str(self.device)
+        }
+
+
 @dataclass
 class LossGroup:
     """
@@ -137,6 +154,7 @@ def item_names(self) -> List[Optional[str]]:
     def __call__(
         self,
         train: bool,
+        epoch: Optional[int] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
@@ -144,6 +162,7 @@ def __call__(
         Compute the total loss and individual loss values.
 
         :param train: Whether the model is in training mode.
+        :param epoch: Optional epoch number to determine the weight from the schedule.
        :param context: Optional Context object containing tensors.
        :input inputs: Keyword arguments containing all necessary inputs
            for the loss computations.
@@ -156,8 +175,20 @@ def __call__(
         logs: Dict[str, float] = {}
 
         for item in self.items:
-            raw, weighted = item(train, context=context, **inputs)
+            raw, weighted = item(
+                train,
+                epoch=epoch,
+                context=context,
+                **inputs
+            )
             logs[item.key] = raw.item() # type: ignore
             total += weighted
 
         return total, logs
+
+    def get_config(self) -> List[Dict[str, Any]]:
+        """
+        Get the configuration of the LossGroup for logging or checkpointing.
+        """
+
+        return [item.get_config() for item in self.items]

From d891e757668a8904f914e757668a8904b1766f37 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 25 Apr 2026 12:11:57 -0600
Subject: [PATCH 02/12] Refactored current realizations of logging trainers to
 1) pass self.epoch as a parameter to the modified loss group (currently
 unused, but useful for the future); and 2) add a loss group logging property
 for loss auto logging

---
 .../trainers/logging_gan_trainer.py |  7 +++++++
 .../trainers/logging_trainer.py     | 16 ++++++++++++--
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/virtual_stain_flow/trainers/logging_gan_trainer.py b/src/virtual_stain_flow/trainers/logging_gan_trainer.py
index 8e5fe69..d50c437 100644
--- a/src/virtual_stain_flow/trainers/logging_gan_trainer.py
+++ b/src/virtual_stain_flow/trainers/logging_gan_trainer.py
@@ -187,6 +187,13 @@ def evaluate_step(
             metric.update(*ctx.as_metric_args(), validation=True)
 
         return gen_logs | disc_logs
+
+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {
+            'generator': self._generator_loss_group,
+            'discriminator': self._discriminator_loss_group
+        }
 
     def save_model(
         self,

diff --git a/src/virtual_stain_flow/trainers/logging_trainer.py b/src/virtual_stain_flow/trainers/logging_trainer.py
index 6442dea..27e9f09 100644
--- a/src/virtual_stain_flow/trainers/logging_trainer.py
+++ b/src/virtual_stain_flow/trainers/logging_trainer.py
@@ -105,7 +105,11 @@ def train_step(
             targets=targets
         )
 
-        weighted_total, logs = self._loss_group(train=True, context=ctx)
+        weighted_total, logs = self._loss_group(
+            train=True,
+            epoch=self.epoch,
+            context=ctx
+        )
         weighted_total.backward()
 
         self._forward_group.step()
@@ -133,12 +137,20 @@ def evaluate_step(
             targets=targets
         )
 
-        _, logs = self._loss_group(train=False, context=ctx)
+        _, logs = self._loss_group(
+            train=False,
+            epoch=self.epoch,
+            context=ctx
+        )
 
         for _, metric in self.metrics.items():
             metric.update(*ctx.as_metric_args(), validation=True)
 
         return logs
+
+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {'main': self._loss_group}
 
     def save_model(
         self,

From 7e0eae75e38c1834fa645c41f6bac21830210e42 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 25 Apr 2026 12:13:41 -0600
Subject: [PATCH 03/12] Add auto loss group logging functionality to
 MlflowLogger

---
 .../vsf_logging/MlflowLogger.py | 88 +++++++++++++++++++
 1 file changed, 88 insertions(+)

diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
index a46c61d..9bfe22b 100644
--- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
+++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
@@ -222,6 +222,8 @@ def on_train_start(self):
             except Exception as e:
                 print(f"Fail to log model config as artifact: {e}")
 
+        self._log_loss_groups_config_and_tags()
+
         for callback in self.callbacks:
 
             # TODO consider if we want hasattr checks
@@ -502,6 +504,92 @@ def _save_model_weights(
             artifact_path=artifact_path
         )
 
+    def _get_loss_groups(self) -> Dict[str, Any]:
+        """
+        Discover loss groups attached to the bound trainer.
+        """
+
+        if self.trainer is None:
+            return {}
+
+        loss_groups: Dict[str, Any] = {}
+
+        explicit_groups = getattr(self.trainer, 'loss_groups', None)
+        if isinstance(explicit_groups, dict):
+            for group_name, group in explicit_groups.items():
+                if hasattr(group, 'get_config'):
+                    loss_groups[str(group_name)] = group
+
+        fallback_attrs = {
+            'main': '_loss_group',
+            'generator': '_generator_loss_group',
+            'discriminator': '_discriminator_loss_group'
+        }
+        for group_name, attr in fallback_attrs.items():
+            if group_name in loss_groups:
+                continue
+            group = getattr(self.trainer, attr, None)
+            if group is not None and hasattr(group, 'get_config'):
+                loss_groups[group_name] = group
+
+        return loss_groups
+
+    def _log_loss_groups_config_and_tags(self) -> None:
+        """
+        Log loss item names and weights as flat tags and full loss group
+        configuration as config artifacts.
+        """
+
+        loss_groups = self._get_loss_groups()
+        if not loss_groups:
+            return None
+
+        for group_name, group in loss_groups.items():
+            try:
+                group_config = group.get_config()
+            except Exception as e:
+                print(
+                    f"Could not get loss group config for logging "
+                    f"({group_name}): {e}"
+                )
+                continue
+
+            if not isinstance(group_config, list):
+                continue
+
+            for idx, item_cfg in enumerate(group_config):
+                if not isinstance(item_cfg, dict):
+                    continue
+
+                if 'key' in item_cfg and item_cfg['key'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.name",
+                        str(item_cfg['key'])
+                    )
+
+                if 'weight' in item_cfg and item_cfg['weight'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.weight",
+                        str(item_cfg['weight'])
+                    )
+
+            try:
+                self.log_config(
+                    tag=f"loss_group_{group_name}",
+                    config={
+                        'group_name': group_name,
+                        'items': group_config
+                    },
+                    stage=None
+                )
+            except Exception as e:
+                print(
+                    f"Failed to log loss group config as artifact "
+                    f"({group_name}): {e}"
+                )
+
+        return None
+
     def log_config(
         self,
         tag: str,

From b124f321bd1dc82d58d0c250205837b15433e8d8 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 25 Apr 2026 12:50:08 -0600
Subject: [PATCH 04/12] Refactor test fixtures by moving fixtures scoped only
 to the trainer tests into the global conftest, and modify
 test_abstract_trainer tests to use fixtures instead of relying on a helper
 class import. This will benefit future test additions by exposing more
 useful fixtures

---
 tests/conftest.py                       | 302 +++++++++++++++++++++++-
 tests/trainers/conftest.py              | 208 ----------------
 tests/trainers/test_abstract_trainer.py |   7 +-
 3 files changed, 305 insertions(+), 212 deletions(-)
 delete mode 100644 tests/trainers/conftest.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 5859a44..b635c69 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,16 +2,21 @@
 """
 Testing fixtures meant to be shared across the whole package
 """
 
+import json
+import importlib
 import pathlib
+from types import SimpleNamespace
 
 import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset
 
+from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer
+from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
 from virtual_stain_flow.vsf_logging import MlflowLogger
 
-# ----- Mock virtual_stain_flow components ----- #
 
+# ----- Logger test doubles ----- #
 
 class DummyLogger(MlflowLogger):
     """
@@ -75,6 +80,8 @@ def dummy_logger():
     return DummyLogger()
 
 
+# ----- Model/optimizer fixtures ----- #
+
 class MockModelWithSaveWeights(torch.nn.Module):
     """
     Mock model that implements save_weights method for testing.
@@ -115,7 +122,7 @@ def mock_optimizer(mock_model_with_save): return torch.optim.Adam(mock_model_with_save.parameters(), lr=0.001) -# ----- Fixtures for simulating minimal training ----- # +# ----- Dataset/dataloader fixtures ----- # class MinimalDataset(Dataset): """Minimal torch.utils.data.Dataset to test training.""" @@ -222,7 +229,6 @@ def empty_dataloader(): return DataLoader(dataset, batch_size=2, shuffle=False) - @pytest.fixture def image_train_loader(image_dataset): """Create a train dataloader with image data.""" @@ -243,6 +249,8 @@ def image_val_loader(image_dataset): return DataLoader(val_dataset, batch_size=4, shuffle=False) +# ----- Generic training fixtures ----- # + @pytest.fixture def simple_loss(): """Create a simple MSE loss function.""" @@ -278,3 +286,291 @@ def mock_metric(): def dataset_for_splitting(): """Create a larger dataset suitable for train/val/test splitting.""" return MinimalDataset(num_samples=100, input_size=4, target_size=2) + + +# ----- Trainer fixtures ----- # + +class MinimalTrainerRealization(AbstractTrainer): + """ + Minimal concrete realization of AbstractTrainer for testing. + Tracks method calls and provides controllable step behavior. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.train_step_calls = [] + self.evaluate_step_calls = [] + self.on_epoch_start_called = False + self.on_epoch_end_called = False + + class DummyProgressBar: + def set_postfix_str(self, *args, **kwargs): + pass + + self._epoch_pbar = DummyProgressBar() # type: ignore + + def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: + self.train_step_calls.append({ + 'inputs_shape': inputs.shape, + 'targets_shape': targets.shape, + }) + + return { + 'loss_a': torch.tensor(0.5), + 'loss_b': torch.tensor(0.3), + } + + def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: + self.evaluate_step_calls.append({ + 'inputs_shape': inputs.shape, + 'targets_shape': targets.shape, + }) + + return { + 'loss_a': torch.tensor(0.4), + 'loss_b': torch.tensor(0.2), + } + + def save_model(self, save_path, file_name_prefix=None, file_name_suffix=None, + file_ext='.pth', best_model=True): + return None + + +@pytest.fixture +def minimal_trainer_cls(): + """Expose the minimal concrete trainer class for tests needing custom init.""" + return MinimalTrainerRealization + + +@pytest.fixture +def trainer_with_loaders(minimal_model, minimal_optimizer, train_dataloader, val_dataloader): + """ + Create a MinimalTrainerRealization with train and validation loaders. + """ + trainer = MinimalTrainerRealization( + model=minimal_model, + optimizer=minimal_optimizer, + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2, + device=torch.device('cpu') + ) + return trainer + + +@pytest.fixture +def trainer_with_empty_val_loader(minimal_model, minimal_optimizer, train_dataloader, empty_dataloader): + """ + Create a MinimalTrainerRealization with empty validation loader. + """ + trainer = MinimalTrainerRealization( + model=minimal_model, + optimizer=minimal_optimizer, + train_loader=train_dataloader, + val_loader=empty_dataloader, + batch_size=2, + device=torch.device('cpu') + ) + return trainer + + +@pytest.fixture +def single_generator_trainer(minimal_model, minimal_optimizer, simple_loss, train_dataloader, val_dataloader): + """ + Create a SingleGeneratorTrainer with a single loss function. 
+ """ + trainer = SingleGeneratorTrainer( + model=minimal_model, + optimizer=minimal_optimizer, + losses=simple_loss, + device=torch.device('cpu'), + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2 + ) + return trainer + + +@pytest.fixture +def multi_loss_trainer(minimal_model, minimal_optimizer, multiple_losses, train_dataloader, val_dataloader): + """ + Create a SingleGeneratorTrainer with multiple loss functions. + """ + trainer = SingleGeneratorTrainer( + model=minimal_model, + optimizer=minimal_optimizer, + losses=multiple_losses, + device=torch.device('cpu'), + loss_weights=[0.5, 0.5], + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2 + ) + return trainer + + +@pytest.fixture +def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, image_val_loader): + """ + Create a SingleGeneratorTrainer with conv model for full training tests. + """ + trainer = SingleGeneratorTrainer( + model=conv_model, + optimizer=conv_optimizer, + losses=simple_loss, + device=torch.device('cpu'), + train_loader=image_train_loader, + val_loader=image_val_loader, + batch_size=4, + early_termination_metric='MSELoss' + ) + return trainer + + +@pytest.fixture +def simple_discriminator(): + """ + Simple discriminator model for GAN testing. + Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1) + """ + import torch.nn as nn + + class SimpleDiscriminator(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1) + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(16, 1) + + def forward(self, x): + x = torch.relu(self.conv(x)) + x = self.pool(x).flatten(1) + return self.fc(x) + + return SimpleDiscriminator() + + +@pytest.fixture +def discriminator_optimizer(simple_discriminator): + """Create an optimizer for the discriminator.""" + return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001) + + +@pytest.fixture +def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer, + simple_loss, image_train_loader, image_val_loader): + """ + Create a LoggingWGANTrainer for testing. 
+ """ + from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer + + trainer = LoggingWGANTrainer( + generator=conv_model, + discriminator=simple_discriminator, + generator_optimizer=conv_optimizer, + discriminator_optimizer=discriminator_optimizer, + generator_losses=simple_loss, + device=torch.device('cpu'), + train_loader=image_train_loader, + val_loader=image_val_loader, + batch_size=4, + n_discriminator_steps=3 + ) + return trainer + + +# ----- MLflow patch fixture ----- # + +@pytest.fixture +def patched_mlflow(monkeypatch): + """Patch MLflow module methods used by MlflowLogger and capture calls.""" + + captured = { + 'tags': {}, + 'artifacts': [], + 'active_run_id': None, + } + + mlflow_logger_module = importlib.import_module( + 'virtual_stain_flow.vsf_logging.MlflowLogger' + ) + + def fake_get_experiment_by_name(_name): + return None + + def fake_create_experiment(_name): + return 'exp-1' + + def fake_start_run(*args, **kwargs): + run_id = 'run-123' + captured['active_run_id'] = run_id + return SimpleNamespace(info=SimpleNamespace(run_id=run_id)) + + def fake_active_run(): + run_id = captured['active_run_id'] + if run_id is None: + return None + return SimpleNamespace(info=SimpleNamespace(run_id=run_id)) + + def fake_end_run(): + captured['active_run_id'] = None + + def fake_set_tag(key, value): + captured['tags'][key] = value + + def fake_log_artifact(file_path, artifact_path=None): + file_content = None + try: + with open(file_path, 'r', encoding='utf-8') as f: + file_content = json.load(f) + except Exception: + file_content = None + + captured['artifacts'].append({ + 'file_path': file_path, + 'artifact_path': artifact_path, + 'content': file_content, + }) + + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'get_experiment_by_name', + fake_get_experiment_by_name, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'create_experiment', + fake_create_experiment, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'start_run', + fake_start_run, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'active_run', + fake_active_run, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'end_run', + fake_end_run, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'set_tag', + fake_set_tag, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'log_artifact', + fake_log_artifact, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'log_params', + lambda *_args, **_kwargs: None, + ) + + return captured diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py deleted file mode 100644 index cd5e443..0000000 --- a/tests/trainers/conftest.py +++ /dev/null @@ -1,208 +0,0 @@ -""" -Fixtures for trainer tests -""" - -import pytest -import torch - -from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer -from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer - - -class MinimalTrainerRealization(AbstractTrainer): - """ - Minimal concrete realization of AbstractTrainer for testing. - Tracks method calls and provides controllable step behavior. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Track method calls for testing - self.train_step_calls = [] - self.evaluate_step_calls = [] - self.on_epoch_start_called = False - self.on_epoch_end_called = False - - # Create a dummy progress bar property that does nothing beyond - # allowing set_postfix_str calls - class DummyProgressBar: - def set_postfix_str(self, *args, **kwargs): - pass - - self._epoch_pbar = DummyProgressBar() # type: ignore - - def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: - """ - Minimal train step that returns a dict of losses. - Stores call information for verification. - """ - - self.train_step_calls.append({ - 'inputs_shape': inputs.shape, - 'targets_shape': targets.shape, - }) - - # Return scalar tensor losses (simulating real losses) - return { - 'loss_a': torch.tensor(0.5), - 'loss_b': torch.tensor(0.3), - } - - def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: - """ - Minimal evaluate step that returns a dict of losses. - Stores call information for verification. - """ - - self.evaluate_step_calls.append({ - 'inputs_shape': inputs.shape, - 'targets_shape': targets.shape, - }) - - # Return scalar tensor losses (simulating real losses) - return { - 'loss_a': torch.tensor(0.4), - 'loss_b': torch.tensor(0.2), - } - - def save_model(self, save_path, file_name_prefix=None, file_name_suffix=None, - file_ext='.pth', best_model=True): - """Minimal save_model implementation.""" - return None - - -@pytest.fixture -def trainer_with_loaders(minimal_model, minimal_optimizer, train_dataloader, val_dataloader): - """ - Create a MinimalTrainerRealization with train and validation loaders. - """ - trainer = MinimalTrainerRealization( - model=minimal_model, - optimizer=minimal_optimizer, - train_loader=train_dataloader, - val_loader=val_dataloader, - batch_size=2, - device=torch.device('cpu') - ) - return trainer - - -@pytest.fixture -def trainer_with_empty_val_loader(minimal_model, minimal_optimizer, train_dataloader, empty_dataloader): - """ - Create a MinimalTrainerRealization with empty validation loader. - """ - trainer = MinimalTrainerRealization( - model=minimal_model, - optimizer=minimal_optimizer, - train_loader=train_dataloader, - val_loader=empty_dataloader, - batch_size=2, - device=torch.device('cpu') - ) - return trainer - - -@pytest.fixture -def single_generator_trainer(minimal_model, minimal_optimizer, simple_loss, train_dataloader, val_dataloader): - """ - Create a SingleGeneratorTrainer with a single loss function. - """ - trainer = SingleGeneratorTrainer( - model=minimal_model, - optimizer=minimal_optimizer, - losses=simple_loss, - device=torch.device('cpu'), - train_loader=train_dataloader, - val_loader=val_dataloader, - batch_size=2 - ) - return trainer - - -@pytest.fixture -def multi_loss_trainer(minimal_model, minimal_optimizer, multiple_losses, train_dataloader, val_dataloader): - """ - Create a SingleGeneratorTrainer with multiple loss functions. - """ - trainer = SingleGeneratorTrainer( - model=minimal_model, - optimizer=minimal_optimizer, - losses=multiple_losses, - device=torch.device('cpu'), - loss_weights=[0.5, 0.5], - train_loader=train_dataloader, - val_loader=val_dataloader, - batch_size=2 - ) - return trainer - - -@pytest.fixture -def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, image_val_loader): - """ - Create a SingleGeneratorTrainer with conv model for full training tests. 
- """ - trainer = SingleGeneratorTrainer( - model=conv_model, - optimizer=conv_optimizer, - losses=simple_loss, - device=torch.device('cpu'), - train_loader=image_train_loader, - val_loader=image_val_loader, - batch_size=4, - early_termination_metric='MSELoss' - ) - return trainer - - -@pytest.fixture -def simple_discriminator(): - """ - Simple discriminator model for GAN testing. - Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1) - """ - import torch.nn as nn - - class SimpleDiscriminator(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1) - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Linear(16, 1) - - def forward(self, x): - x = torch.relu(self.conv(x)) - x = self.pool(x).flatten(1) - return self.fc(x) - - return SimpleDiscriminator() - - -@pytest.fixture -def discriminator_optimizer(simple_discriminator): - """Create an optimizer for the discriminator.""" - return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001) - - -@pytest.fixture -def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer, - simple_loss, image_train_loader, image_val_loader): - """ - Create a LoggingWGANTrainer for testing. - """ - from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer - - trainer = LoggingWGANTrainer( - generator=conv_model, - discriminator=simple_discriminator, - generator_optimizer=conv_optimizer, - discriminator_optimizer=discriminator_optimizer, - generator_losses=simple_loss, - device=torch.device('cpu'), - train_loader=image_train_loader, - val_loader=image_val_loader, - batch_size=4, - n_discriminator_steps=3 - ) - return trainer diff --git a/tests/trainers/test_abstract_trainer.py b/tests/trainers/test_abstract_trainer.py index 5d9555a..62d6e38 100644 --- a/tests/trainers/test_abstract_trainer.py +++ b/tests/trainers/test_abstract_trainer.py @@ -5,7 +5,12 @@ import pytest import torch -from conftest import MinimalTrainerRealization + +@pytest.fixture(autouse=True) +def _bind_minimal_trainer_cls(minimal_trainer_cls): + """Bind concrete trainer class from fixture to avoid direct conftest imports.""" + global MinimalTrainerRealization + MinimalTrainerRealization = minimal_trainer_cls class TestTrainEpochBatchIteration: From 2ae005c0fec92aac45075d3bfb54ab1766f37e76 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 25 Apr 2026 12:50:33 -0600 Subject: [PATCH 05/12] Add tests for automatic loss-group config logging in MlflowLogger --- .../test_mlflow_logger_loss_config.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/vsf_logging/test_mlflow_logger_loss_config.py diff --git a/tests/vsf_logging/test_mlflow_logger_loss_config.py b/tests/vsf_logging/test_mlflow_logger_loss_config.py new file mode 100644 index 0000000..296940f --- /dev/null +++ b/tests/vsf_logging/test_mlflow_logger_loss_config.py @@ -0,0 +1,80 @@ +""" +Tests for automatic loss-group config logging in MlflowLogger. 
+""" + + +class TestMlflowLoggerLossConfigLogging: + + def test_on_train_start_logs_single_trainer_loss_tags_and_config( + self, + patched_mlflow, + single_generator_trainer, + ): + from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + + captured = patched_mlflow + + logger = MlflowLogger( + name='logger', + experiment_name='exp', + ) + logger.bind_trainer(single_generator_trainer) + + logger.on_train_start() + + assert captured['tags']['loss.main.0.name'] == 'MSELoss' + assert captured['tags']['loss.main.0.weight'] == '1.0' + + loss_group_artifacts = [ + artifact + for artifact in captured['artifacts'] + if artifact['content'] is not None + and artifact['content'].get('group_name') == 'main' + ] + + assert len(loss_group_artifacts) == 1 + artifact = loss_group_artifacts[0] + assert artifact['artifact_path'] == 'configs' + assert len(artifact['content']['items']) == 1 + assert artifact['content']['items'][0]['key'] == 'MSELoss' + assert artifact['content']['items'][0]['weight'] == 1.0 + + logger.end_run() + + def test_on_train_start_logs_wgan_loss_tags_and_configs( + self, + patched_mlflow, + wgan_trainer, + ): + from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + + captured = patched_mlflow + + logger = MlflowLogger( + name='logger', + experiment_name='exp', + ) + logger.bind_trainer(wgan_trainer) + + logger.on_train_start() + + assert captured['tags']['loss.generator.0.name'] == 'MSELoss' + assert captured['tags']['loss.generator.0.weight'] == '1.0' + assert captured['tags']['loss.generator.1.name'] == 'AdversarialLoss' + assert captured['tags']['loss.generator.1.weight'] == '1.0' + + assert captured['tags']['loss.discriminator.0.name'] == 'WassersteinLoss' + assert captured['tags']['loss.discriminator.0.weight'] == '1.0' + assert captured['tags']['loss.discriminator.1.name'] == 'GradientPenaltyLoss' + assert captured['tags']['loss.discriminator.1.weight'] == '10.0' + + group_names = { + artifact['content']['group_name'] + for artifact in captured['artifacts'] + if artifact['content'] is not None + and 'group_name' in artifact['content'] + } + + assert group_names == {'generator', 'discriminator'} + + logger.end_run() From 48bed604df46fac1e23aac5532f463499483bee0 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 30 Apr 2026 11:42:15 -0600 Subject: [PATCH 06/12] Add Progress class abstraction for more versatile weight scheduling --- src/virtual_stain_flow/engine/progress.py | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/virtual_stain_flow/engine/progress.py diff --git a/src/virtual_stain_flow/engine/progress.py b/src/virtual_stain_flow/engine/progress.py new file mode 100644 index 0000000..e6024af --- /dev/null +++ b/src/virtual_stain_flow/engine/progress.py @@ -0,0 +1,37 @@ +""" +progress.py + +Progress tracking for loss weight scheduling. + +Provides a centralized abstraction for training progress state that can be used +by loss schedulers operating at different granularities (epoch, step, etc.). +Designed to be minimal and extensible without overcomplicating the current API. +""" + +from dataclasses import dataclass + + +@dataclass +class Progress: + """ + Tracks training progress for loss weight scheduling. + + Provides centralized access to scheduling state including epoch and step, + with room for future custom scheduling granularities. + """ + epoch: int = 0 + step: int = 0 + + def set_epoch(self, epoch: int) -> None: + """ + Update the current epoch number. 
+        """
+        self.epoch = epoch
+
+    def set_step(self, step: int) -> None:
+        """
+        Update the current step number.
+        Intended to be accumulated across epochs for a global step count
+        that can be used for step-based scheduling.
+        """
+        self.step = step

From adada9299820a6698f369eca64672270d93227ce Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Thu, 30 Apr 2026 11:43:27 -0600
Subject: [PATCH 07/12] Refactored loss item, loss group, and trainers to use
 and update the newly added `Progress` object in place of an epoch number.

---
 src/virtual_stain_flow/engine/loss_group.py        | 13 ++++++-----
 src/virtual_stain_flow/trainers/AbstractTrainer.py | 12 ++++++++++++
 .../trainers/logging_gan_trainer.py                |  6 ++++++
 src/virtual_stain_flow/trainers/logging_trainer.py |  4 ++--
 4 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/src/virtual_stain_flow/engine/loss_group.py b/src/virtual_stain_flow/engine/loss_group.py
index 72e79c3..09f3684 100644
--- a/src/virtual_stain_flow/engine/loss_group.py
+++ b/src/virtual_stain_flow/engine/loss_group.py
@@ -29,6 +29,7 @@
 from .loss_utils import BaseLoss, _get_loss_name, _scalar_from_ctx
 from .context import Context, ContextValue
 from .names import PREDS, TARGETS
+from .progress import Progress
 
 Scalar = Union[int, float, bool]
 
@@ -79,7 +80,7 @@ def __post_init__(self):
     def __call__(
         self,
         train: bool,
-        epoch: Optional[int] = None,
+        progress: Optional[Progress] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -89,7 +90,8 @@ def __call__(
         skipped during validation.
 
         :param train: Whether the model is in training mode.
-        :param epoch: Optional epoch number to determine the weight from the schedule.
+        :param progress: Optional Progress object containing scheduling state (epoch, step, etc.)
+            for dynamic weight scheduling.
         :param context: Optional Context object containing tensors.
         :param inputs: Keyword arguments containing all necessary inputs
             for the loss computation.
@@ -154,7 +156,7 @@ def item_names(self) -> List[Optional[str]]:
     def __call__(
         self,
         train: bool,
-        epoch: Optional[int] = None,
+        progress: Optional[Progress] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
@@ -162,7 +164,8 @@ def __call__(
         Compute the total loss and individual loss values.
 
         :param train: Whether the model is in training mode.
-        :param epoch: Optional epoch number to determine the weight from the schedule.
+        :param progress: Optional Progress object containing scheduling state (epoch, step, etc.)
+            for dynamic weight scheduling.
        :param context: Optional Context object containing tensors.
        :input inputs: Keyword arguments containing all necessary inputs
            for the loss computations.
@@ -177,7 +180,7 @@ def __call__( for item in self.items: raw, weighted = item( train, - epoch=epoch, + progress=progress, context=context, **inputs ) diff --git a/src/virtual_stain_flow/trainers/AbstractTrainer.py b/src/virtual_stain_flow/trainers/AbstractTrainer.py index 6c64acf..36c25ae 100644 --- a/src/virtual_stain_flow/trainers/AbstractTrainer.py +++ b/src/virtual_stain_flow/trainers/AbstractTrainer.py @@ -14,6 +14,7 @@ from .trainer_protocol import TrainerProtocol from ..metrics.AbstractMetrics import AbstractMetrics +from ..engine.progress import Progress from ..datasets.data_split import default_random_split @@ -113,6 +114,9 @@ def _init_state( # Epoch state self._epoch = 0 + + # Progress tracking for loss weight scheduling + self._progress = Progress(epoch=0, step=0) # Loss and metrics state self._train_losses = defaultdict(list) @@ -232,6 +236,8 @@ def train_epoch(self): phase="Train" ) + self._progress.set_step(self._progress.step + 1) + batch_loss = self.train_step(inputs, targets) for key, value in batch_loss.items(): losses[key].append(value) @@ -512,6 +518,11 @@ def metrics(self): def epoch(self): return self._epoch + @property + def progress(self) -> Progress: + """Returns the Progress object tracking training state (epoch, step, etc.)""" + return self._progress + @property def train_losses(self): return self._train_losses @@ -548,6 +559,7 @@ def early_stop_counter(self, value: int): @epoch.setter def epoch(self, value: int): self._epoch = value + self._progress.set_epoch(value) """ Update loss and metrics diff --git a/src/virtual_stain_flow/trainers/logging_gan_trainer.py b/src/virtual_stain_flow/trainers/logging_gan_trainer.py index d50c437..d959457 100644 --- a/src/virtual_stain_flow/trainers/logging_gan_trainer.py +++ b/src/virtual_stain_flow/trainers/logging_gan_trainer.py @@ -110,6 +110,7 @@ def train_step( ) disc_weighted_total, disc_logs = self._discriminator_loss_group( train=True, + progress=self.progress, context=disc_ctx ) disc_weighted_total.backward() @@ -133,6 +134,7 @@ def train_step( ) gen_weighted_total, gen_logs = self._generator_loss_group( train=True, + progress=self.progress, context=gen_ctx ) gen_weighted_total.backward() @@ -142,12 +144,14 @@ def train_step( gen_logs = {} self._global_step += 1 + self._progress.set_step(self._global_step) # if generator logs are not computed this step (due to skipped update), # compute from discriminator context if not gen_logs: _, gen_logs = self._generator_loss_group( train=True, + progress=self.progress, context=ctx ) @@ -176,10 +180,12 @@ def evaluate_step( ) _, gen_logs = self._generator_loss_group( train=False, + progress=self.progress, context=ctx ) _, disc_logs = self._discriminator_loss_group( train=False, + progress=self.progress, context=ctx ) diff --git a/src/virtual_stain_flow/trainers/logging_trainer.py b/src/virtual_stain_flow/trainers/logging_trainer.py index 27e9f09..ff20c8d 100644 --- a/src/virtual_stain_flow/trainers/logging_trainer.py +++ b/src/virtual_stain_flow/trainers/logging_trainer.py @@ -107,7 +107,7 @@ def train_step( weighted_total, logs = self._loss_group( train=True, - epoch=self.epoch, + progress=self.progress, context=ctx ) weighted_total.backward() @@ -139,7 +139,7 @@ def evaluate_step( _, logs = self._loss_group( train=False, - epoch=self.epoch, + progress=self.progress, context=ctx ) From d8fb72c9493ac5670231f424b4786e227e02a587 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 30 Apr 2026 11:48:20 -0600 Subject: [PATCH 08/12] Update logging method documentation for clarity --- 
 src/virtual_stain_flow/vsf_logging/MlflowLogger.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
index 9bfe22b..d2d9917 100644
--- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
+++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py
@@ -536,8 +536,9 @@ def _get_loss_groups(self) -> Dict[str, Any]:
 
     def _log_loss_groups_config_and_tags(self) -> None:
         """
-        Log loss item names and weights as flat tags and full loss group
-        configuration as config artifacts.
+        Log loss item names and weights as flattened string mlflow tags and
+        full loss group configuration (loss name, loss weight, whether loss
+        is active during validation etc.) as mlflow config artifacts.
         """
 
         loss_groups = self._get_loss_groups()

From 47e8b8902ee32195fc4b912d88aa003be56b0615 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Thu, 30 Apr 2026 14:01:59 -0600
Subject: [PATCH 09/12] Add docstrings to test methods for clarity on logging
 behavior

---
 tests/vsf_logging/test_mlflow_logger_loss_config.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tests/vsf_logging/test_mlflow_logger_loss_config.py b/tests/vsf_logging/test_mlflow_logger_loss_config.py
index 296940f..a86b5d0 100644
--- a/tests/vsf_logging/test_mlflow_logger_loss_config.py
+++ b/tests/vsf_logging/test_mlflow_logger_loss_config.py
@@ -10,6 +10,11 @@ def test_on_train_start_logs_single_trainer_loss_tags_and_config(
         patched_mlflow,
         single_generator_trainer,
     ):
+        """
+        Test for correct on start logging of loss name and weight as
+        mlflow tags and full loss group config as mlflow artifacts
+        for a single-loss trainer.
+        """
         from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
 
         captured = patched_mlflow
@@ -46,6 +51,12 @@ def test_on_train_start_logs_wgan_loss_tags_and_configs(
         patched_mlflow,
         wgan_trainer,
     ):
+        """
+        Test for correct on start logging of loss name and weight as
+        mlflow tags and full loss group config as mlflow artifacts
+        for a WGAN trainer with multiple losses in both generator and
+        discriminator.
+        """
         from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
 
         captured = patched_mlflow

From 4fc7a65ac5e7d57035141d3757670b49c00d30da Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Thu, 30 Apr 2026 14:43:23 -0600
Subject: [PATCH 10/12] Remove older tests for MlflowLogger focusing only on
 loss logging behavior against one type of model across two trainers.
 Replace with newly added separate test files for UNet, ConvNeXtUNet, and a
 WGAN model with a UNet generator, testing both loss logging and model
 logging.
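
The new suites can also be exercised in isolation; a minimal sketch, assuming
pytest is installed and this is run from the repository root (the "-k" filter
simply matches the new module names):

```python
# Minimal sketch (assumed invocation, not part of this patch):
# run only the new model-config logging test modules.
import pytest

pytest.main(["tests/vsf_logging", "-k", "model_config"])
```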
--- .../test_mlflow_logger_loss_config.py | 91 -------------- .../test_mlflow_logger_model_config_unet.py | 95 ++++++++++++++ .../test_mlflow_logger_model_config_unext.py | 89 ++++++++++++++ ...st_mlflow_logger_model_config_wgan_unet.py | 116 ++++++++++++++++++ 4 files changed, 300 insertions(+), 91 deletions(-) delete mode 100644 tests/vsf_logging/test_mlflow_logger_loss_config.py create mode 100644 tests/vsf_logging/test_mlflow_logger_model_config_unet.py create mode 100644 tests/vsf_logging/test_mlflow_logger_model_config_unext.py create mode 100644 tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py diff --git a/tests/vsf_logging/test_mlflow_logger_loss_config.py b/tests/vsf_logging/test_mlflow_logger_loss_config.py deleted file mode 100644 index a86b5d0..0000000 --- a/tests/vsf_logging/test_mlflow_logger_loss_config.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -Tests for automatic loss-group config logging in MlflowLogger. -""" - - -class TestMlflowLoggerLossConfigLogging: - - def test_on_train_start_logs_single_trainer_loss_tags_and_config( - self, - patched_mlflow, - single_generator_trainer, - ): - """ - Test for correct on start logging of loss name and weight as - mlflow tags and full loss group config as mlflow artifacts - for a single-loss trainer. - """ - from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger - - captured = patched_mlflow - - logger = MlflowLogger( - name='logger', - experiment_name='exp', - ) - logger.bind_trainer(single_generator_trainer) - - logger.on_train_start() - - assert captured['tags']['loss.main.0.name'] == 'MSELoss' - assert captured['tags']['loss.main.0.weight'] == '1.0' - - loss_group_artifacts = [ - artifact - for artifact in captured['artifacts'] - if artifact['content'] is not None - and artifact['content'].get('group_name') == 'main' - ] - - assert len(loss_group_artifacts) == 1 - artifact = loss_group_artifacts[0] - assert artifact['artifact_path'] == 'configs' - assert len(artifact['content']['items']) == 1 - assert artifact['content']['items'][0]['key'] == 'MSELoss' - assert artifact['content']['items'][0]['weight'] == 1.0 - - logger.end_run() - - def test_on_train_start_logs_wgan_loss_tags_and_configs( - self, - patched_mlflow, - wgan_trainer, - ): - """ - Test for correct on start logging of loss name and weight as - mlflow tags and full loss group config as mlflow artifacts - for a WGAN trainer with multiple losses in both generator and - discriminator. 
- """ - from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger - - captured = patched_mlflow - - logger = MlflowLogger( - name='logger', - experiment_name='exp', - ) - logger.bind_trainer(wgan_trainer) - - logger.on_train_start() - - assert captured['tags']['loss.generator.0.name'] == 'MSELoss' - assert captured['tags']['loss.generator.0.weight'] == '1.0' - assert captured['tags']['loss.generator.1.name'] == 'AdversarialLoss' - assert captured['tags']['loss.generator.1.weight'] == '1.0' - - assert captured['tags']['loss.discriminator.0.name'] == 'WassersteinLoss' - assert captured['tags']['loss.discriminator.0.weight'] == '1.0' - assert captured['tags']['loss.discriminator.1.name'] == 'GradientPenaltyLoss' - assert captured['tags']['loss.discriminator.1.weight'] == '10.0' - - group_names = { - artifact['content']['group_name'] - for artifact in captured['artifacts'] - if artifact['content'] is not None - and 'group_name' in artifact['content'] - } - - assert group_names == {'generator', 'discriminator'} - - logger.end_run() diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_unet.py b/tests/vsf_logging/test_mlflow_logger_model_config_unet.py new file mode 100644 index 0000000..a61f786 --- /dev/null +++ b/tests/vsf_logging/test_mlflow_logger_model_config_unet.py @@ -0,0 +1,95 @@ +""" +Tests for UNet model config logging in MlflowLogger. +""" + +import itertools + +import pytest +import torch + +from virtual_stain_flow.models.unet import UNet +from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer +from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + +UNET_VARIANTS = list( + itertools.product( + ["conv", "maxpool"], + ["convt", "bilinear"], + [4, 5], + [32, 64], + ) +) + + +@pytest.mark.parametrize( + "encoder_down_block,decoder_up_block,depth,base_channels", + UNET_VARIANTS, +) +def test_on_train_start_logs_unet_model_config_and_loss_items( + patched_mlflow, + simple_loss, + train_dataloader, + val_dataloader, + encoder_down_block, + decoder_up_block, + depth, + base_channels, +): + model = UNet( + in_channels=1, + out_channels=1, + base_channels=base_channels, + depth=depth, + encoder_down_block=encoder_down_block, + decoder_up_block=decoder_up_block, + act_type="sigmoid", + _num_units=2, + ) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + trainer = SingleGeneratorTrainer( + model=model, + optimizer=optimizer, + losses=simple_loss, + device=torch.device("cpu"), + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2, + ) + + logger = MlflowLogger( + name="logger", + experiment_name="exp", + ) + logger.bind_trainer(trainer) + + logger.on_train_start() + + captured = patched_mlflow + + model_artifacts = [ + artifact + for artifact in captured["artifacts"] + if artifact["content"] is not None + and "class_path" in artifact["content"] + ] + assert model_artifacts + + model_config = next( + artifact["content"] + for artifact in model_artifacts + if artifact["content"]["class_path"].endswith("UNet") + ) + assert "init" in model_config + + assert captured["tags"]["loss.main.0.name"] == "MSELoss" + assert captured["tags"]["loss.main.0.weight"] == "1.0" + + loss_group_artifacts = [ + artifact + for artifact in captured["artifacts"] + if artifact["content"] is not None + and artifact["content"].get("group_name") == "main" + ] + assert len(loss_group_artifacts) == 1 + + logger.end_run() diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_unext.py 
b/tests/vsf_logging/test_mlflow_logger_model_config_unext.py new file mode 100644 index 0000000..d23bf8f --- /dev/null +++ b/tests/vsf_logging/test_mlflow_logger_model_config_unext.py @@ -0,0 +1,89 @@ +""" +Tests for ConvNeXtUNet model config logging in MlflowLogger. +""" + +import itertools + +import pytest +import torch + +from virtual_stain_flow.models.unext import ConvNeXtUNet +from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer +from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + +UNEXT_VARIANTS = list( + itertools.product( + ["pixelshuffle", "convt"], + ["convnext", "conv2d"], + ) +) + + +@pytest.mark.parametrize( + "decoder_up_block,decoder_compute_block", + UNEXT_VARIANTS, +) +def test_on_train_start_logs_unext_model_config_and_loss_items( + patched_mlflow, + simple_loss, + train_dataloader, + val_dataloader, + decoder_up_block, + decoder_compute_block, +): + model = ConvNeXtUNet( + in_channels=1, + out_channels=1, + decoder_up_block=decoder_up_block, + decoder_compute_block=decoder_compute_block, + act_type="sigmoid", + _num_units=2, + ) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + trainer = SingleGeneratorTrainer( + model=model, + optimizer=optimizer, + losses=simple_loss, + device=torch.device("cpu"), + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2, + ) + + logger = MlflowLogger( + name="logger", + experiment_name="exp", + ) + logger.bind_trainer(trainer) + + logger.on_train_start() + + captured = patched_mlflow + + model_artifacts = [ + artifact + for artifact in captured["artifacts"] + if artifact["content"] is not None + and "class_path" in artifact["content"] + ] + assert model_artifacts + + model_config = next( + artifact["content"] + for artifact in model_artifacts + if artifact["content"]["class_path"].endswith("ConvNeXtUNet") + ) + assert "init" in model_config + + assert captured["tags"]["loss.main.0.name"] == "MSELoss" + assert captured["tags"]["loss.main.0.weight"] == "1.0" + + loss_group_artifacts = [ + artifact + for artifact in captured["artifacts"] + if artifact["content"] is not None + and artifact["content"].get("group_name") == "main" + ] + assert len(loss_group_artifacts) == 1 + + logger.end_run() diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py b/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py new file mode 100644 index 0000000..6d657f6 --- /dev/null +++ b/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py @@ -0,0 +1,116 @@ +""" +Tests for WGAN model config logging in MlflowLogger with UNet generator. 
+""" + +import itertools + +import pytest +import torch + +from virtual_stain_flow.models.discriminator import GlobalDiscriminator +from virtual_stain_flow.models.unet import UNet +from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer +from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + +UNET_VARIANTS = list( + itertools.product( + ["conv", "maxpool"], + ["convt", "bilinear"], + [4, 5], + [32, 64], + ) +) + + +@pytest.mark.parametrize( + "encoder_down_block,decoder_up_block,depth,base_channels", + UNET_VARIANTS, +) +def test_on_train_start_logs_wgan_unet_model_configs_and_loss_items( + patched_mlflow, + simple_loss, + train_dataloader, + val_dataloader, + encoder_down_block, + decoder_up_block, + depth, + base_channels, +): + generator = UNet( + in_channels=1, + out_channels=1, + base_channels=base_channels, + depth=depth, + encoder_down_block=encoder_down_block, + decoder_up_block=decoder_up_block, + act_type="sigmoid", + _num_units=2, + ) + discriminator = GlobalDiscriminator( + n_in_channels=1, + n_in_filters=1, + out_activation=None, + _conv_depth=4, + _leaky_relu_alpha=0.2, + _batch_norm=False, + _pool_before_fc=False, + ) + generator_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3) + discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3) + + trainer = LoggingWGANTrainer( + generator=generator, + discriminator=discriminator, + generator_optimizer=generator_optimizer, + discriminator_optimizer=discriminator_optimizer, + generator_losses=simple_loss, + device=torch.device("cpu"), + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2, + n_discriminator_steps=3, + ) + + logger = MlflowLogger( + name="logger", + experiment_name="exp", + ) + logger.bind_trainer(trainer) + + logger.on_train_start() + + captured = patched_mlflow + + model_artifacts = [ + artifact + for artifact in captured["artifacts"] + if artifact["content"] is not None + and "class_path" in artifact["content"] + ] + assert model_artifacts + + generator_config = next( + artifact["content"] + for artifact in model_artifacts + if artifact["content"]["class_path"].endswith("UNet") + ) + assert "init" in generator_config + + discriminator_config = next( + artifact["content"] + for artifact in model_artifacts + if artifact["content"]["class_path"].endswith("GlobalDiscriminator") + ) + assert "init" in discriminator_config + + assert captured["tags"]["loss.generator.0.name"] == "MSELoss" + assert captured["tags"]["loss.generator.0.weight"] == "1.0" + assert captured["tags"]["loss.generator.1.name"] == "AdversarialLoss" + assert captured["tags"]["loss.generator.1.weight"] == "1.0" + + assert captured["tags"]["loss.discriminator.0.name"] == "WassersteinLoss" + assert captured["tags"]["loss.discriminator.0.weight"] == "1.0" + assert captured["tags"]["loss.discriminator.1.name"] == "GradientPenaltyLoss" + assert captured["tags"]["loss.discriminator.1.weight"] == "10.0" + + logger.end_run() From 58147eedb0579dec67adaf650729999918228fa2 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 30 Apr 2026 14:52:47 -0600 Subject: [PATCH 11/12] Add model class path logging as tag and tests as it is a minimal but potentially useful enhancement fitting the theme of this PR --- src/virtual_stain_flow/vsf_logging/MlflowLogger.py | 8 +++++++- tests/vsf_logging/test_mlflow_logger_model_config_unet.py | 2 ++ .../vsf_logging/test_mlflow_logger_model_config_unext.py | 2 ++ .../test_mlflow_logger_model_config_wgan_unet.py | 3 +++ 4 files changed, 14 
insertions(+), 1 deletion(-) diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py index d2d9917..a9164da 100644 --- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py +++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py @@ -204,7 +204,7 @@ def on_train_start(self): else: models = [] - for model in models: + for idx, model in enumerate(models): if isinstance(model, BaseModel) and hasattr(model, 'to_config'): try: @@ -213,6 +213,12 @@ def on_train_start(self): print(f"Could not get model config for logging: {e}") config = None if config: + class_path = config.get("class_path") + if class_path: + mlflow.set_tag( + f"model.{idx}.class_path", + str(class_path) + ) try: self.log_config( tag=model.__class__.__name__, diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_unet.py b/tests/vsf_logging/test_mlflow_logger_model_config_unet.py index a61f786..9b5e8f0 100644 --- a/tests/vsf_logging/test_mlflow_logger_model_config_unet.py +++ b/tests/vsf_logging/test_mlflow_logger_model_config_unet.py @@ -81,6 +81,8 @@ def test_on_train_start_logs_unet_model_config_and_loss_items( ) assert "init" in model_config + assert captured["tags"]["model.0.class_path"].endswith("UNet") + assert captured["tags"]["loss.main.0.name"] == "MSELoss" assert captured["tags"]["loss.main.0.weight"] == "1.0" diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_unext.py b/tests/vsf_logging/test_mlflow_logger_model_config_unext.py index d23bf8f..f9d1cad 100644 --- a/tests/vsf_logging/test_mlflow_logger_model_config_unext.py +++ b/tests/vsf_logging/test_mlflow_logger_model_config_unext.py @@ -75,6 +75,8 @@ def test_on_train_start_logs_unext_model_config_and_loss_items( ) assert "init" in model_config + assert captured["tags"]["model.0.class_path"].endswith("ConvNeXtUNet") + assert captured["tags"]["loss.main.0.name"] == "MSELoss" assert captured["tags"]["loss.main.0.weight"] == "1.0" diff --git a/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py b/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py index 6d657f6..612861c 100644 --- a/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py +++ b/tests/vsf_logging/test_mlflow_logger_model_config_wgan_unet.py @@ -103,6 +103,9 @@ def test_on_train_start_logs_wgan_unet_model_configs_and_loss_items( ) assert "init" in discriminator_config + assert captured["tags"]["model.0.class_path"].endswith("UNet") + assert captured["tags"]["model.1.class_path"].endswith("GlobalDiscriminator") + assert captured["tags"]["loss.generator.0.name"] == "MSELoss" assert captured["tags"]["loss.generator.0.weight"] == "1.0" assert captured["tags"]["loss.generator.1.name"] == "AdversarialLoss" From 13fc3f7e93316d1b6fe4ba08f652be687d300e3b Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 30 Apr 2026 15:00:35 -0600 Subject: [PATCH 12/12] Update version to 0.4.5 in pyproject.toml and document MLflow auto logging enhancements in CHANGELOG.md --- CHANGELOG.md | 12 ++++++++++++ pyproject.toml | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eba8c81..3f4fa68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ All notable chagnes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+
+---
+
+## [0.4.5] - 2026-04-30
+
+### Added
+
+#### MLflow auto logging enhancements (`virtual_stain_flow/vsf_logging/`):
+
+- The logger now records model architecture tags by capturing each model config's `class_path` and setting `model.<idx>.class_path` at train start.
+- The loss-group auto logging routine `_log_loss_groups_config_and_tags` logs loss item names and weights as MLflow tags and persists the full loss group configuration as a JSON config artifact.
+
 ---
 
 ## [0.4.4] - 2026-04-23

diff --git a/pyproject.toml b/pyproject.toml
index a376133..1dfdd8a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "virtual_stain_flow"
-version = "0.4.0"
+version = "0.4.5"
 description = "For developing virtual staining models"
 requires-python = ">=3.9"
 dependencies = [
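
---

For reference, a sketch of the tag layout this series produces at train start.
The key formats come from the logger code above (`model.<idx>.class_path`,
`loss.<group>.<idx>.name`, `loss.<group>.<idx>.weight`); the values below are
illustrative assumptions for a WGAN run, not captured output:

```python
# Illustrative only: key formats match the logger; values are assumed.
expected_tags = {
    "model.0.class_path": "virtual_stain_flow.models.unet.UNet",
    "model.1.class_path": "virtual_stain_flow.models.discriminator.GlobalDiscriminator",
    "loss.generator.0.name": "MSELoss",
    "loss.generator.0.weight": "1.0",
    "loss.discriminator.1.name": "GradientPenaltyLoss",
    "loss.discriminator.1.weight": "10.0",
}
```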