Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


---

## [0.4.5] - 2026-04-30

### Added

#### MLflow auto logging enhancements (`virtual_stain_flow/vsf_logging/`):

- The logger now records model architecture tags by capturing each model config's `class_path` and setting `model.<idx>.class_path` at train start.
- The loss-group auto logging routine `_log_loss_groups_config_and_tags` logs loss item names and weights as MLflow tags and persists the full loss group configuration as a JSON config artifact.

---

## [0.4.4] - 2026-04-23
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "virtual_stain_flow"
version = "0.4.0"
version = "0.4.5"
description = "For developing virtual staining models"
requires-python = ">=3.9"
dependencies = [
Expand Down
38 changes: 36 additions & 2 deletions src/virtual_stain_flow/engine/loss_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
"""

from dataclasses import dataclass
from typing import Optional, Union, Tuple, Dict, Sequence, List
from typing import Optional, Union, Tuple, Dict, Sequence, List, Any

import torch

from .loss_utils import BaseLoss, _get_loss_name, _scalar_from_ctx
from .context import Context, ContextValue
from .names import PREDS, TARGETS
from .progress import Progress

Scalar = Union[int, float, bool]

Expand Down Expand Up @@ -79,6 +80,7 @@ def __post_init__(self):
def __call__(
self,
train: bool,
progress: Optional[Progress] = None,
context: Optional[Context] = None,
**inputs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -88,6 +90,8 @@ def __call__(
skipped during validation.

:param train: Whether the model is in training mode.
:param progress: Optional Progress object containing scheduling state (epoch, step, etc.)
for dynamic weight scheduling.
:param context: Optional Context object containing tensors.
Comment thread
wli51 marked this conversation as resolved.
:param inputs: Keyword arguments containing all necessary inputs for the
loss computation.
Expand Down Expand Up @@ -117,6 +121,21 @@ def __call__(

return raw, raw * _scalar_from_ctx(self.weight, inputs)

def get_config(self) -> Dict[str, Any]:
    """
    Serialize this loss item into a plain dictionary.

    Intended for experiment logging or checkpointing; captures the wrapped
    module's class name plus every scheduling/enabling knob on the item.
    Note: ``args`` and ``weight`` are returned as-is, so non-primitive
    values (e.g. context-driven weights) pass through unchanged.
    """
    return dict(
        module=self.module.__class__.__name__,
        args=self.args,
        key=self.key,
        weight=self.weight,
        enabled=self.enabled,
        compute_at_val=self.compute_at_val,
        device=str(self.device),
    )


@dataclass
class LossGroup:
"""
Expand All @@ -137,13 +156,16 @@ def item_names(self) -> List[Optional[str]]:
def __call__(
self,
train: bool,
progress: Optional[Progress] = None,
context: Optional[Context] = None,
**inputs: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
"""
Compute the total loss and individual loss values.

:param train: Whether the model is in training mode.
:param progress: Optional Progress object containing scheduling state (epoch, step, etc.)
for dynamic weight scheduling.
:param context: Optional Context object containing tensors.
:input inputs: Keyword arguments containing all necessary inputs for the
loss computations.
Expand All @@ -156,8 +178,20 @@ def __call__(
logs: Dict[str, float] = {}

for item in self.items:
raw, weighted = item(train, context=context, **inputs)
raw, weighted = item(
train,
progress=progress,
context=context,
**inputs
)
logs[item.key] = raw.item() # type: ignore
total += weighted

return total, logs

def get_config(self) -> List[Dict[str, Any]]:
    """
    Serialize the group into a list of per-item configuration dicts.

    Delegates to each contained item's ``get_config`` and preserves the
    original item order; suitable for logging or checkpointing.
    """
    configs: List[Dict[str, Any]] = []
    for loss_item in self.items:
        configs.append(loss_item.get_config())
    return configs
Comment thread
wli51 marked this conversation as resolved.
37 changes: 37 additions & 0 deletions src/virtual_stain_flow/engine/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
progress.py

Progress tracking for loss weight scheduling.

Provides a centralized abstraction for training progress state that can be used
by loss schedulers operating at different granularities (epoch, step, etc.).
Designed to be minimal and extensible without overcomplicating the current API.
"""

from dataclasses import dataclass


@dataclass
class Progress:
    """
    Mutable record of how far training has advanced.

    Exposes the two scheduling axes currently supported — ``epoch`` and
    ``step`` — so loss-weight schedulers can query one object instead of
    separate counters. Further granularities can be added later without
    changing this public surface.
    """
    epoch: int = 0  # zero-based index of the current training epoch
    step: int = 0   # global step count, accumulated across epochs

    def set_epoch(self, epoch: int) -> None:
        """Record *epoch* as the current epoch number."""
        self.epoch = epoch

    def set_step(self, step: int) -> None:
        """
        Record *step* as the current global step.

        Callers are expected to accumulate this across epochs so the value
        can drive step-based weight scheduling.
        """
        self.step = step
12 changes: 12 additions & 0 deletions src/virtual_stain_flow/trainers/AbstractTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .trainer_protocol import TrainerProtocol
from ..metrics.AbstractMetrics import AbstractMetrics
from ..engine.progress import Progress
from ..datasets.data_split import default_random_split


Expand Down Expand Up @@ -113,6 +114,9 @@ def _init_state(

# Epoch state
self._epoch = 0

# Progress tracking for loss weight scheduling
self._progress = Progress(epoch=0, step=0)

# Loss and metrics state
self._train_losses = defaultdict(list)
Expand Down Expand Up @@ -232,6 +236,8 @@ def train_epoch(self):
phase="Train"
)

self._progress.set_step(self._progress.step + 1)

batch_loss = self.train_step(inputs, targets)
for key, value in batch_loss.items():
losses[key].append(value)
Expand Down Expand Up @@ -512,6 +518,11 @@ def metrics(self):
def epoch(self):
return self._epoch

@property
def progress(self) -> Progress:
"""Returns the Progress object tracking training state (epoch, step, etc.)"""
return self._progress

@property
def train_losses(self):
return self._train_losses
Expand Down Expand Up @@ -548,6 +559,7 @@ def early_stop_counter(self, value: int):
@epoch.setter
def epoch(self, value: int):
self._epoch = value
self._progress.set_epoch(value)

"""
Update loss and metrics
Expand Down
13 changes: 13 additions & 0 deletions src/virtual_stain_flow/trainers/logging_gan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def train_step(
)
disc_weighted_total, disc_logs = self._discriminator_loss_group(
train=True,
progress=self.progress,
context=disc_ctx
)
disc_weighted_total.backward()
Expand All @@ -133,6 +134,7 @@ def train_step(
)
gen_weighted_total, gen_logs = self._generator_loss_group(
train=True,
progress=self.progress,
context=gen_ctx
)
gen_weighted_total.backward()
Expand All @@ -142,12 +144,14 @@ def train_step(
gen_logs = {}

self._global_step += 1
self._progress.set_step(self._global_step)

# if generator logs are not computed this step (due to skipped update),
# compute from discriminator context
if not gen_logs:
_, gen_logs = self._generator_loss_group(
train=True,
progress=self.progress,
context=ctx
)

Expand Down Expand Up @@ -176,17 +180,26 @@ def evaluate_step(
)
_, gen_logs = self._generator_loss_group(
train=False,
progress=self.progress,
context=ctx
)
_, disc_logs = self._discriminator_loss_group(
train=False,
progress=self.progress,
context=ctx
)

for _, metric in self.metrics.items():
metric.update(*ctx.as_metric_args(), validation=True)

return gen_logs | disc_logs

@property
def loss_groups(self) -> Dict[str, LossGroup]:
return {
'generator': self._generator_loss_group,
'discriminator': self._discriminator_loss_group
}

def save_model(
self,
Expand Down
16 changes: 14 additions & 2 deletions src/virtual_stain_flow/trainers/logging_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def train_step(
targets=targets
)

weighted_total, logs = self._loss_group(train=True, context=ctx)
weighted_total, logs = self._loss_group(
train=True,
progress=self.progress,
context=ctx
)
weighted_total.backward()
self._forward_group.step()

Expand Down Expand Up @@ -133,12 +137,20 @@ def evaluate_step(
targets=targets
)

_, logs = self._loss_group(train=False, context=ctx)
_, logs = self._loss_group(
train=False,
progress=self.progress,
context=ctx
)

for _, metric in self.metrics.items():
metric.update(*ctx.as_metric_args(), validation=True)

return logs

@property
def loss_groups(self) -> Dict[str, LossGroup]:
return {'main': self._loss_group}

def save_model(
self,
Expand Down
Loading
Loading