Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions DeepLearning/training/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Config:
focal_gamma: float = 2.0
prior_ema: float = 0.01
use_drw: bool = False
use_logit_adjustment: bool = False
use_logit_adjustment: bool = False # disables alpha_t weighting when True
use_batch_alpha: bool = False
use_penultimate_logits: bool = False

Expand Down Expand Up @@ -847,6 +847,11 @@ def _train_one_epoch(self, epoch: int):
alpha_t = alpha_combined
gamma_t = self.cfg.focal_gamma

if self.cfg.use_logit_adjustment:
# When applying logit adjustment, disable inverse-frequency alpha
# weighting to avoid over-correcting class imbalance.
alpha_t = tf.ones_like(alpha_t)

lambda_t = ramp if self.cfg.use_logit_adjustment else 0.0
prior_t = tf.identity(self.class_prior)

Expand Down Expand Up @@ -899,10 +904,19 @@ def _validate_one_epoch(self, epoch: int):
if self.val_dataset is None:
return
self.validator.reset()

step_f = tf.cast(self.global_step, tf.float32)
if self.cfg.use_drw:
post = tf.nn.relu(step_f - float(self.cfg.warmup_steps))
ramp = tf.minimum(1.0, post / float(self.cfg.drw_warmup_steps))
else:
ramp = 1.0
lambda_t = ramp if self.cfg.use_logit_adjustment else 0.0

prog = Progbar(self.val_steps, stateful_metrics=None, verbose=1, unit_name="val_step")
for step, batch in enumerate(self.val_dataset.take(self.val_steps), start=1):
x, y = batch
self.validator.update(x, y)
self.validator.update(x, y, self.class_prior, lambda_t)
prog.update(step)
metrics = self.validator.result()
gstep = int(self.global_step.numpy())
Expand All @@ -929,6 +943,12 @@ def _validate_one_epoch(self, epoch: int):
for step, batch in enumerate(self.val_dataset.take(self.val_steps), start=1):
x, y = batch
logits = self.model(x, training=False)
logits = tf.cast(logits, tf.float32)
if self.cfg.use_logit_adjustment:
log_adj = lambda_t * tf.math.log(
1.0 / tf.clip_by_value(self.class_prior, 1e-6, 1.0)
)
logits += log_adj
h = tf.shape(logits)[1] // 2
w = tf.shape(logits)[2] // 2
half = self.validator.window_size // 2
Expand Down Expand Up @@ -1049,10 +1069,19 @@ def _evaluate_baseline(self):

tf.print("[INFO] Evaluating baseline (pre‑training) model…")
self.validator.reset()

step_f = tf.cast(self.global_step, tf.float32)
if self.cfg.use_drw:
post = tf.nn.relu(step_f - float(self.cfg.warmup_steps))
ramp = tf.minimum(1.0, post / float(self.cfg.drw_warmup_steps))
else:
ramp = 1.0
lambda_t = ramp if self.cfg.use_logit_adjustment else 0.0

prog = Progbar(self.val_steps, stateful_metrics=None, verbose=1, unit_name="val_step")
for step, batch in enumerate(self.val_dataset.take(self.val_steps), start=1):
x, y = batch
self.validator.update(x, y)
self.validator.update(x, y, self.class_prior, lambda_t)
prog.update(step)

metrics = self.validator.result()
Expand Down
46 changes: 35 additions & 11 deletions DeepLearning/training/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from __future__ import annotations

from typing import Callable

import tensorflow as tf


Expand Down Expand Up @@ -100,8 +98,13 @@ def extract_labels(one_hot):
labels = tf.argmax(counts, axis=-1, output_type=tf.int32)
return labels, had_fg if self.skip_empty else tf.ones_like(labels, tf.bool)

def batch_confusion(y_true_onehot, images):
def batch_confusion(y_true_onehot, images, class_prior, lambda_t):
logits = self._infer(images)
logits = tf.cast(logits, tf.float32)
log_adj = lambda_t * tf.math.log(
1.0 / tf.clip_by_value(class_prior, 1e-6, 1.0)
)
logits += log_adj
h = tf.shape(logits)[1] // 2
w = tf.shape(logits)[2] // 2
half = ws // 2
Expand Down Expand Up @@ -129,25 +132,46 @@ def non_empty_case():
return tf.cond(tf.size(true_labels) > 0, non_empty_case, empty_case)

@tf.function(jit_compile=False)
def update_step(images, one_hot):
batch_cm = batch_confusion(one_hot, images)
def update_step(images, one_hot, class_prior, lambda_t):
batch_cm = batch_confusion(one_hot, images, class_prior, lambda_t)
self.cm_var.assign_add(batch_cm)

self._update_step = update_step

# ------------------------------------------------------------------
def update(self, images, one_hot) -> None:
"""Accumulate the confusion matrix for a batch."""
def update(
self,
images,
one_hot,
class_prior: tf.Tensor | None = None,
lambda_t: float = 0.0,
) -> None:
"""Accumulate the confusion matrix for a batch.

Parameters
----------
images: Tensor of shape `[B, H, W, C]` containing input images.
one_hot: Tensor of shape `[B, H, W, num_classes]` with one‑hot labels.
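class_prior: Optional prior probabilities for logit adjustment. When `None`, a uniform prior of `1 / self.C` per class is used.
lambda_t: Scaling factor for logit adjustment. The default `0.0` leaves the logits unadjusted.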
"""

if class_prior is None:
class_prior = tf.fill([self.C], 1.0 / self.C)
class_prior = tf.convert_to_tensor(class_prior, dtype=tf.float32)
lambda_t = tf.convert_to_tensor(lambda_t, dtype=tf.float32)

if isinstance(self.strategy, tf.distribute.Strategy) and not isinstance(
self.strategy, tf.distribute.OneDeviceStrategy
):
def replica_fn(imgs, y):
self._update_step(imgs, y)
def replica_fn(imgs, y, prior, lamb):
self._update_step(imgs, y, prior, lamb)

self.strategy.run(replica_fn, args=(images, one_hot))
self.strategy.run(
replica_fn, args=(images, one_hot, class_prior, lambda_t)
)
else:
self._update_step(images, one_hot)
self._update_step(images, one_hot, class_prior, lambda_t)

# ------------------------------------------------------------------
def result(self) -> dict[str, tf.Tensor]:
Expand Down