Merged

Changes from all commits
31 commits
3b6d95f  Simplify serialization logic to be dependent on only the file_state, … (wli51, Nov 20, 2025)
6046769  Add crop dataset class for handling image crops, also introduce the s… (wli51, Nov 20, 2025)
697119b  Removing redundant method overrides (wli51, Nov 20, 2025)
31dd8c9  Remove unused imports from crop_dataset.py (wli51, Nov 22, 2025)
37bcfb7  Remove unused import of PurePath from base_dataset.py (wli51, Dec 10, 2025)
7b29be9  Change import order in crop_manifest.py for consistency in attempt to… (wli51, Dec 10, 2025)
d3698ac  Merge branch 'main' into dev-add-patch-dataset (wli51, Dec 10, 2025)
4506dc5  Add parameter description for kwargs in CropManifest constructor. (wli51, Dec 16, 2025)
1dbb8b4  Merge branch 'dev-add-patch-dataset' of https://github.com/wli51/virt… (wli51, Dec 16, 2025)
056e786  Enhance parameter descriptions in BaseImageDataset and CropImageDatas… (wli51, Dec 16, 2025)
10f1e4d  Add pandera as a dependency for enhanced data validation (wli51, Dec 16, 2025)
88e416e  Implement input validation for dataset initialization and enhance Dat… (wli51, Dec 16, 2025)
31a1b32  Change test organization to better reflect package structure. (wli51, Dec 16, 2025)
319873f  Add tests for input validation in make_file_index_schema function (wli51, Dec 16, 2025)
5017460  Enhance CropImageDataset and CropFileState configuration handling by … (wli51, Dec 16, 2025)
ee80d04  Add input validation for CropImageDataset and CropFileState configura… (wli51, Dec 16, 2025)
6717ec8  Add test suite for CropImageDataset with comprehensive initialization… (wli51, Dec 16, 2025)
0fbffb2  Add validation for 'file_state' in BaseImageDataset configuration (wli51, Dec 16, 2025)
c3a5c04  Add input validation for CropManifest and FileState configurations; E… (wli51, Dec 16, 2025)
603fe45  Update CHANGELOG.md for version 0.4.3; Document added Crop dataset fe… (wli51, Dec 17, 2025)
f0183d6  Add input validation test for whitespace-only file index cell. (wli51, Dec 17, 2025)
87c2350  Improve test coverage for the dataset subpackage. (wli51, Dec 17, 2025)
50e69be  Refactor input validation imports and optimize DataFrame checks in ma… (wli51, Dec 17, 2025)
c6e5ee8  Add get_image_dimensions helper method to DatasetManifest for retriev… (wli51, Dec 17, 2025)
ee5e9c6  Add crop_generator and ds_utils modules for crop coordinate generation (wli51, Dec 17, 2025)
366cb65  Add from_base_dataset class method to CropImageDataset auto-cropping … (wli51, Dec 17, 2025)
b6b7257  Add unit tests for crop_generator and ds_utils utility functions (wli51, Dec 17, 2025)
58197d6  Refactored visualization module for plotting model predictions and d… (wli51, Dec 17, 2025)
857c5c1  Add transforms support to BaseImageDataset and CropImageDataset for i… (wli51, Dec 17, 2025)
5ed4ace  Refactor example to use the newly added crop dataset for training (wli51, Dec 17, 2025)
921d410  Re-run notebook (wli51, Dec 17, 2025)
20 changes: 20 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

---

## [0.4.3] - 2025-12-16

### Added

#### Crop dataset (`virtual_stain_flow/datasets/`):

Allows the dataset to dynamically return user-specified crops extracted from the full images. Supports serialization and deserialization to facilitate reproducibility (see the sketch below).

- **`CropImageDataset`** (`crop_dataset.py`): Dataset class for serving image crops based on a `CropManifest`. Extends `BaseImageDataset` with crop-specific state management and lazy loading via `CropFileState`.
- **`CropManifest`** (`ds_engine/crop_manifest.py`): Immutable collection of crop definitions wrapping a `DatasetManifest` for file access. Supports serialization/deserialization and factory construction from coordinate specifications.
- **`Crop`** (`ds_engine/crop_manifest.py`): Dataclass defining a single crop region with manifest index, position (x, y), and dimensions (width, height).
- **`CropIndexState`** (`ds_engine/crop_manifest.py`): Mutable state tracker for the currently active crop region.
- **`CropFileState`** (`ds_engine/crop_manifest.py`): Lazy image loading backend that wraps `FileState` to load full images and dynamically extract crop regions on demand.
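
A minimal usage sketch, mirroring the updated example notebook; the `file_index` DataFrame, channel keys, and normalization settings are illustrative, and indexing is assumed to yield an (input, target) tensor pair:

```python
from virtual_stain_flow.datasets.base_dataset import BaseImageDataset
from virtual_stain_flow.datasets.crop_dataset import CropImageDataset
from virtual_stain_flow.transforms.normalizations import MaxScaleNormalize

# Full-image dataset built from a file index (one row per field of view,
# one column per channel key holding the image path).
dataset = BaseImageDataset(
    file_index=file_index,
    check_exists=True,
    pil_image_mode="I;16",
    input_channel_keys=["ch7"],   # brightfield input
    target_channel_keys=["ch5"],  # Hoechst target
)

# Derive a crop dataset that lazily serves 256 x 256 center crops.
cropped_dataset = CropImageDataset.from_base_dataset(
    dataset,
    crop_size=256,
    transforms=MaxScaleNormalize(normalization_factor="16bit"),
)
x, y = cropped_dataset[0]  # assumed to return an (input, target) tensor pair
```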

### Removed

#### All obsolete dataset classes

---

## [0.4.2] - 2025-11-17

### Added
388 changes: 163 additions & 225 deletions examples/2.training_with_logging_example.ipynb

Large diffs are not rendered by default.

264 changes: 77 additions & 187 deletions examples/nbconverted/2.training_with_logging_example.py
@@ -12,25 +12,27 @@


import re
import json
import pathlib
from typing import List, Tuple
from typing import List


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from PIL import Image
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from mlflow.tracking import MlflowClient

from virtual_stain_flow.datasets.base_dataset import BaseImageDataset
from virtual_stain_flow.datasets.crop_dataset import CropImageDataset
from virtual_stain_flow.transforms.normalizations import MaxScaleNormalize
from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
from virtual_stain_flow.vsf_logging.callbacks.PlotCallback import PlotPredictionCallback
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.evaluation.visualization import plot_dataset_grid


# ## Pathing and Additional utils
@@ -74,216 +76,101 @@ def _collect_field_prefixes(
break
return prefixes


def _load_single_channel(
def build_file_index(
plate_dir: pathlib.Path,
field_prefix: str,
channel: int,
normalize: bool = True,
) -> np.ndarray:
"""
Load a single channel image for a given field prefix and channel index.

:param plate_dir: Directory containing TIFF files for one JUMP plate
:param field_prefix: Prefix like 'r01c01f01p01'.
:param channel: Channel index, e.g. 5 for Hoechst, 7 for BF mid-z.
:param normalize: If True, convert to float32 and divide by dtype max
:return: Image array of shape (H, W), float32.
"""
fname = f"{field_prefix}-ch{channel:d}sk1fk1fl1.tiff"
path = plate_dir / fname
if not path.exists():
raise FileNotFoundError(f"Expected file not found: {path}")

arr = np.array(Image.open(path)) # typically uint16

if normalize:
if np.issubdtype(arr.dtype, np.integer):
info = np.iinfo(arr.dtype)
arr = arr.astype("float32") / float(info.max)
else:
arr = arr.astype("float32")
else:
arr = arr.astype("float32")

return arr # (H, W), float32


def load_jump_bf_hoechst(
plate_dir: str | pathlib.Path,
max_fields: int = 32,
bf_channel: int = 7,
dna_channel: int = 5,
normalize: bool = True,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
max_fields: int = 16,
) -> pd.DataFrame:
"""
Load a small BF->Hoechst subset from a CPJUMP1 plate.

:param plate_dir: Directory containing TIFF files for one JUMP plate
:param max_fields: Maximum number of fields to load
:param bf_channel: Channel index for BF mid-z (default 7)
:param dna_channel: Channel index for Hoechst (default 5)
:param normalize: If True, convert to float32 and divide by dtype max
Helper function to build a file index that specifies how images relate
across channels and fields of view (FOVs). The result can be supplied
directly to BaseImageDataset to create a dataset with the correct
image pairs.
"""
plate_dir = pathlib.Path(plate_dir)

if not plate_dir.exists() or not plate_dir.is_dir():
raise FileNotFoundError(
f"Plate directory {plate_dir} does not exist or is not a directory."
)

prefixes = _collect_field_prefixes(plate_dir, max_fields=max_fields)
if not prefixes:
raise RuntimeError(f"No valid JUMP image files found in {plate_dir}")

bf_list: list[np.ndarray] = []
dna_list: list[np.ndarray] = []
used_prefixes: list[str] = []

for prefix in prefixes:
try:
bf = _load_single_channel(
plate_dir, prefix, bf_channel, normalize=normalize
)
dna = _load_single_channel(
plate_dir, prefix, dna_channel, normalize=normalize
)
except FileNotFoundError:
# Skip incomplete fields (missing channels)
continue

# Add channel axis: (1, H, W)
bf_list.append(bf[None, ...])
dna_list.append(dna[None, ...])
used_prefixes.append(prefix)
fields = _collect_field_prefixes(
plate_dir,
max_fields=max_fields,
)

if not bf_list:
raise RuntimeError(
f"No complete BF + DNA pairs found in {plate_dir} "
f"for bf_channel={bf_channel}, dna_channel={dna_channel}"
)
file_index_list = []
for field in fields:
sample = {}
for chan in plate_dir.glob(f"**/{field}*.tiff"):
match = FIELD_RE.match(chan.name)
if match and match.groups()[1]:
sample[f"ch{match.groups()[1]}"] = str(chan)

X = np.stack(bf_list, axis=0) # (N, 1, H, W)
Y = np.stack(dna_list, axis=0) # (N, 1, H, W)
file_index_list.append(sample)

return X, Y, used_prefixes
file_index = pd.DataFrame(file_index_list)
file_index.dropna(how='all', inplace=True)
if file_index.empty:
raise ValueError(f"No files found in {plate_dir} matching the expected pattern.")

return file_index.loc[:, sorted(file_index.columns)]
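
# Illustrative result (paths hypothetical): build_file_index returns one row per
# field of view, with one column per channel code captured by FIELD_RE, e.g.
#   ch5 -> <plate_dir>/r01c01f01p01-ch5sk1fk1fl1.tiff   (Hoechst)
#   ch7 -> <plate_dir>/r01c01f01p01-ch7sk1fk1fl1.tiff   (mid-z brightfield)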

# Dataset object for training

# In[3]:


class SimpleDataset(Dataset):
"""
Simple dataset for demo purposes.
Loads images from disk, crops the center, and returns as tensors.
"""
def __init__(self, X: np.ndarray, Y: np.ndarray, crop_size: int = 256):
self.X = X
self.Y = Y
self.crop_size = crop_size

def __len__(self):
return len(self.X)

def __getitem__(self, idx):
x = self.X[idx, 0, :, :]
y = self.Y[idx, 0, :, :]

# Get image dimensions
height, width = x.shape

# Calculate crop coordinates for center
left = (width - self.crop_size) // 2
top = (height - self.crop_size) // 2
right = left + self.crop_size
bottom = top + self.crop_size

# Crop center
x_crop = x[top:bottom,left:right]
y_crop = y[top:bottom,left:right]

# Convert to tensor
x_tensor = torch.from_numpy(x_crop).unsqueeze(0) # Add channel dimension
y_tensor = torch.from_numpy(y_crop).unsqueeze(0) # Add channel dimension

return x_tensor, y_tensor


# ## Load subsetted demo data

# In[ ]:


# Load very small subset of CJUMP1, BF and Hoechst channel as input-target pairs
# for demo purposes
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1 for details
X, Y, prefixes = load_jump_bf_hoechst(
plate_dir=DATA_PATH,
# retrieve up to 64 fields (different positions of images)
# this results in a very small sample size good for demo purposes
# for better training results, increase this number/load the full dataset
max_fields=64,
bf_channel=7, # mid-z BF for CPJUMP1
dna_channel=5, # Hoechst
)

# Print and visualize first 3 images from the loaded data
print("X (BF):", X.shape, X.dtype) # (N, 1, H, W)
print("Y (DNA):", Y.shape, Y.dtype) # (N, 1, H, W)
print("First few fields:", prefixes[:5])

panel_width = 3
indices = [1, 2, 3]
fig, ax = plt.subplots(len(indices), 2, figsize=(panel_width * 2, panel_width * len(indices)))

for i, j in enumerate(indices):
input, target = X[j], Y[j]
ax[i][0].imshow(input[0], cmap='gray')
ax[i][0].set_title(f'No.{j} Input')
ax[i][0].axis('off')
ax[i][1].imshow(target[0], cmap='gray')
ax[i][1].set_title(f'No.{j} Target')
ax[i][1].axis('off')
plt.tight_layout()
plt.show()
file_index = build_file_index(DATA_PATH, max_fields=64)
print(file_index.head())


# ## Create dataset that returns tensors needed for training, and visualize several patches

# In[5]:
# In[4]:


# Create dataset instance
dataset = SimpleDataset(X, Y, crop_size=256)
print(f"Dataset created with {len(dataset)} samples")
# Create a dataset with Brightfield as input and Hoechst as target
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1
# for which channel codes correspond to which channel
dataset = BaseImageDataset(
file_index=file_index,
check_exists=True,
pil_image_mode="I;16",
input_channel_keys=["ch7"],
target_channel_keys=["ch5"],
)
print(f"Dataset length: {len(dataset)}")
print(
f"Input channels: {dataset.input_channel_keys}, target channels: {dataset._target_channel_keys}"
)
plot_dataset_grid(
dataset=dataset,
indices=[0,1,2,3],
wspace=0.025,
hspace=0.05
)

# Plot the first 5 samples from the dataset
fig, axes = plt.subplots(5, 2, figsize=(8, 16))

for i in range(5):
brightfield, dna = dataset[i]
brightfield = brightfield.numpy().squeeze()
dna = dna.numpy().squeeze()
# ## Generate a cropped dataset by taking the center 256 x 256 square using built-in utilities.
# Also visualize the first few crops

# Plot brightfield image
axes[i, 0].imshow(brightfield.squeeze(), cmap='gray')
axes[i, 0].set_title(f'Sample {i} - Brightfield')
axes[i, 0].axis('off')
# In[5]:

# Plot DNA image
axes[i, 1].imshow(dna.squeeze(), cmap='gray')
axes[i, 1].set_title(f'Sample {i} - DNA')
axes[i, 1].axis('off')

plt.tight_layout()
plt.show()
cropped_dataset = CropImageDataset.from_base_dataset(
dataset,
crop_size=256,
transforms=MaxScaleNormalize(
normalization_factor='16bit'
)
)
plot_dataset_grid(
dataset=cropped_dataset,
indices=[0,1,2,3],
wspace=0.025,
hspace=0.05
)


# ## Configure and train

# In[ ]:
# In[6]:


## Hyperparameters
@@ -303,7 +190,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Batch with DataLoader
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(cropped_dataset, batch_size=batch_size, shuffle=True)

# Model & Optimizer
fully_conv_unet = UNet(
@@ -325,11 +212,14 @@ def __getitem__(self, idx):
# plots to the training.
plot_callback = PlotPredictionCallback(
name="plot_callback_with_train_data",
dataset=dataset,
dataset=cropped_dataset,
indices=[0,1,2,3,4], # first 5 samples
plot_metrics=[torch.nn.L1Loss()],
every_n_epochs=5,
show_plot=False
# kwargs passed to plotting backend
show_plot=False, # don't show plot in notebook
wspace=0.025, # small spacing between subplots
hspace=0.05 # small spacing between subplots
)

# MLflow Logger
@@ -381,7 +271,7 @@ def __getitem__(self, idx):

# ### Display the last logged prediction plot artifact

# In[ ]:
# In[7]:


# Create MLflow client
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
"jupyter",
"notebook",
"tifffile",
"pandera[pandas]",
]

[project.optional-dependencies]
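
The new `pandera[pandas]` dependency backs the file-index validation added in this PR (e.g. the `make_file_index_schema` helper named in the commit list). A minimal sketch of that style of check, assuming illustrative column names and rules rather than the package's actual schema:

```python
import pandas as pd
import pandera as pa

# Hypothetical schema: each channel column must hold non-empty,
# non-whitespace path strings.
file_index_schema = pa.DataFrameSchema(
    {
        "ch5": pa.Column(str, pa.Check(lambda s: s.str.strip().str.len() > 0)),
        "ch7": pa.Column(str, pa.Check(lambda s: s.str.strip().str.len() > 0)),
    },
    strict=False,  # tolerate additional channel columns
)

file_index = pd.DataFrame(
    {
        "ch5": ["/data/plate/r01c01f01p01-ch5sk1fk1fl1.tiff"],
        "ch7": ["/data/plate/r01c01f01p01-ch7sk1fk1fl1.tiff"],
    }
)
validated = file_index_schema.validate(file_index)  # raises SchemaError on failure
```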