
Add semi-supervised losses and test #1551

Open
surajyadav-research wants to merge 4 commits into google-deepmind:main from surajyadav-research:semi-sl

Conversation

@surajyadav-research

#1550

fixmatch_loss — tests (see the sketch after this list)

  • Random batched matches reference: compares output to a float32 reference implementation (hard/soft supervised labels), across different B/U/C, confidence_threshold, lambda_u, and dtype (incl. bfloat16); also checks output is finite.
  • vmap correctness: jax.vmap(fixmatch_loss) matches lax.map (per-item loop) output; also checks finiteness.
  • Permutation invariance: shuffling labeled batch order and unlabeled batch order doesn’t change the loss (with threshold set so ordering shouldn’t matter); checks finiteness.
  • Numerical stability (extreme logits): loss stays finite for very large-magnitude logits; gradients w.r.t. labeled logits and strong unlabeled logits are finite.
  • lambda_u = 0 supervised-only: returns exactly supervised cross-entropy when unsupervised weight is zero.
  • Confidence threshold edges:
    • too high (e.g., >1) → no pseudo-labels used → supervised-only
    • 0.0 → unsupervised term included (loss ≥ supervised loss)
  • Empty unlabeled batch: if U=0, behaves as supervised-only; finite.
  • Gradient flows through strong logits: gradient w.r.t. strong logits (us) is non-zero and finite.
  • bfloat16 run: smoke test that it runs in bfloat16 and returns finite output.
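For orientation, here is a minimal sketch of the semantics the tests above assume. The name `fixmatch_loss_sketch` and the exact signature are illustrative guesses, not the PR's actual code; it assumes integer labels and mean reduction.

```python
import jax
import jax.numpy as jnp
import optax


def fixmatch_loss_sketch(
    labeled_logits,   # [B, C] logits for labeled examples
    labels,           # [B] integer labels (soft labels would use
                      # optax.softmax_cross_entropy instead)
    weak_logits,      # [U, C] logits for weakly augmented unlabeled examples
    strong_logits,    # [U, C] logits for strongly augmented unlabeled examples
    confidence_threshold=0.95,
    lambda_u=1.0,
):
  # Supervised term: standard cross-entropy on the labeled batch.
  sup = optax.softmax_cross_entropy_with_integer_labels(
      labeled_logits, labels
  ).mean()

  # Pseudo-labels come from the weak branch and are treated as constants.
  probs = jax.nn.softmax(jax.lax.stop_gradient(weak_logits), axis=-1)
  pseudo_labels = jnp.argmax(probs, axis=-1)
  # Only confident predictions contribute; a threshold > 1 masks everything
  # and a threshold of 0 keeps everything, matching the edge-case tests.
  mask = (jnp.max(probs, axis=-1) >= confidence_threshold).astype(sup.dtype)

  # Unsupervised term: cross-entropy of the strong branch against the
  # pseudo-labels. With U = 0 the sum is empty, so the term is 0 and the
  # loss reduces to the supervised term; gradients flow only through
  # strong_logits.
  per_example = optax.softmax_cross_entropy_with_integer_labels(
      strong_logits, pseudo_labels
  )
  unsup = (mask * per_example).sum() / max(weak_logits.shape[0], 1)

  return sup + lambda_u * unsup
```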

mixmatch_loss — tests (see the sketch after this list)

  • Random batched matches reference: compares output to a float32 reference implementation (hard/soft supervised labels), different B/U/C, lambda_u, and dtype (incl. bfloat16); checks finiteness.
  • vmap correctness: jax.vmap(mixmatch_loss) matches lax.map; checks finiteness.
  • Permutation invariance: shuffling labeled and unlabeled batches doesn’t change the loss; checks finiteness.
  • Numerical stability (extreme logits): loss finite for huge logits; gradients w.r.t. labeled logits and unlabeled logits are finite.
  • lambda_u = 0 supervised-only: returns supervised cross-entropy when unsupervised weight is zero.
  • Stop-gradient on unlabeled targets: gradient w.r.t. unlabeled_targets is zero (targets are treated as constants).
  • Unsupervised term zero when targets match probs: if unlabeled_targets == softmax(unlabeled_logits), unsupervised loss becomes ~0, so total ≈ supervised-only.
  • bfloat16 run: smoke test that it runs in bfloat16 and returns finite output.
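A similar sketch for the MixMatch semantics these tests imply. The squared-error unsupervised term is inferred from the "zero when targets match probs" test (a cross-entropy term would not vanish there); again, the name and signature are guesses.

```python
import jax
import jax.numpy as jnp
import optax


def mixmatch_loss_sketch(
    labeled_logits,     # [B, C] logits for labeled examples
    labels,             # [B, C] one-hot or soft supervised targets
    unlabeled_logits,   # [U, C] logits for unlabeled examples
    unlabeled_targets,  # [U, C] guessed label distributions
    lambda_u=1.0,
):
  # Supervised term: cross-entropy against (possibly soft) labels.
  sup = optax.softmax_cross_entropy(labeled_logits, labels).mean()

  # Unsupervised term: squared error between predicted probabilities and
  # the guessed targets. stop_gradient makes the targets constants, so the
  # gradient w.r.t. unlabeled_targets is zero, and the term vanishes when
  # unlabeled_targets == softmax(unlabeled_logits).
  targets = jax.lax.stop_gradient(unlabeled_targets)
  preds = jax.nn.softmax(unlabeled_logits, axis=-1)
  unsup = jnp.mean(jnp.sum(jnp.square(preds - targets), axis=-1))

  return sup + lambda_u * unsup
```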

@surajyadav-research
Author

Hi @rdyro,
Could you please review this implementation when you have time? I included a few tests to check the robustness of the losses; let me know if any of them seem unnecessary and I'll remove them.

Collaborator

@vroulet vroulet left a comment


Looks pretty good, thanks. Just add references and use the correct headers.

Comment thread optax/losses/_semi_supervised.py Outdated
@@ -0,0 +1,173 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
Collaborator


2026

Comment thread optax/losses/_semi_supervised.py Outdated
lambda_u: Weight for unlabeled term.

Returns:
Scalar FixMatch loss.
Collaborator


Add reference to paper (be careful about formatting, see e.g. how it is done in the docstring of adam)
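For context, optax docstrings cite papers in an RST `References:` section, as in adam's docstring. For fixmatch_loss the entry would presumably be the FixMatch paper, formatted roughly like this (the surrounding docstring text is a placeholder):

```python
def fixmatch_loss(*args, **kwargs):
  r"""FixMatch semi-supervised loss (docstring formatting sketch only).

  References:
    Sohn et al., `FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence <https://arxiv.org/abs/2001.07685>`_, 2020
  """
```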


Returns:
Scalar MixMatch loss.
"""
Collaborator


Same here, add reference
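The corresponding entry here would presumably be the MixMatch paper, in the same style:

```python
def mixmatch_loss(*args, **kwargs):
  r"""MixMatch semi-supervised loss (docstring formatting sketch only).

  References:
    Berthelot et al., `MixMatch: A Holistic Approach to Semi-Supervised
    Learning <https://arxiv.org/abs/1905.02249>`_, 2019
  """
```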

Comment thread optax/losses/_semi_supervised_test.py Outdated
lambda_u=lambda_u,
)

self._assert_allclose(got, expected, dtype)
Collaborator


Use the wrapper with self.subTest(...) (see tests in other files). This way, if one case fails, the others still run, so we get all the information at once.

Update all tests to use that pattern.
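A sketch of the requested pattern; the test body below is a trivial stand-in, not the PR's actual assertions:

```python
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized


class FixMatchLossTest(parameterized.TestCase):

  def test_matches_reference(self):
    for dtype in (jnp.float32, jnp.bfloat16):
      # Each subtest passes or fails independently, so one failing dtype
      # no longer hides the results for the remaining dtypes.
      with self.subTest(dtype=dtype):
        got = jnp.zeros((), dtype=dtype)  # stand-in for the loss call
        np.testing.assert_allclose(np.asarray(got, np.float32), 0.0)
```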

Comment thread optax/losses/_semi_supervised_test.py Outdated

class FixMatchLossTest(parameterized.TestCase):
@staticmethod
def _assert_allclose(got, expected, dtype):
Collaborator


Make these not class methods but private functions at the file level, since I believe they are used in both tests.
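A sketch of the suggested refactor, with hypothetical tolerances (the real helper's thresholds may differ):

```python
import jax.numpy as jnp
import numpy as np


# Module-level private helper, shared by FixMatchLossTest and
# MixMatchLossTest instead of being duplicated as a @staticmethod.
def _assert_allclose(got, expected, dtype):
  # Loosen tolerances for low-precision dtypes such as bfloat16.
  tol = 1e-2 if dtype == jnp.bfloat16 else 1e-5
  np.testing.assert_allclose(
      np.asarray(got, dtype=np.float32),
      np.asarray(expected, dtype=np.float32),
      rtol=tol,
      atol=tol,
  )
```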

@surajyadav-research
Author

@vroulet Thank you for reviewing the code. I’ll push the updated changes ASAP.

@surajyadav-research
Author

Hi @vroulet,
I've pushed all the requested changes. Whenever you have time, could you please review them?

