Add semi-supervised losses and test #1551
surajyadav-research wants to merge 4 commits into google-deepmind:main from
Conversation
Hi @rdyro,
vroulet left a comment
Looks pretty good, thanks. Just add references and use correct headers.
@@ -0,0 +1,173 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
lambda_u: Weight for unlabeled term.

Returns:
  Scalar FixMatch loss.
Add a reference to the paper (be careful about the formatting; see e.g. how it is done in the docstring of adam).
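For illustration, a minimal sketch of the kind of RST `References:` block used in optax docstrings such as `adam`'s. The function here is a stub (the real signature is omitted); the cited paper is the original FixMatch publication:

```python
def fixmatch_loss(*args, **kwargs):
  r"""Computes the FixMatch semi-supervised loss (stub; signature omitted).

  References:
    Sohn et al., `FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence <https://arxiv.org/abs/2001.07685>`_, 2020
  """
  raise NotImplementedError
```

The Sphinx docs build picks up the `References:` section and renders the backtick-wrapped title as a link.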
Returns:
  Scalar MixMatch loss.
"""
lambda_u=lambda_u,
)
self._assert_allclose(got, expected, dtype)
Use the wrapper with self.subTest(...) (see tests in other files). That way, if one check fails, the others still run, so we get all the information at once. Update all tests to follow that pattern.
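A minimal sketch of the `self.subTest(...)` pattern being asked for; the loss values here are stand-ins, not the real FixMatch computation:

```python
import unittest

import numpy as np


class FixMatchLossSubTestExample(unittest.TestCase):
  """Groups independent checks so one failure does not hide the others."""

  def test_loss_properties(self):
    got = np.float32(0.5)       # stand-in for the computed loss
    expected = np.float32(0.5)  # stand-in for the reference value
    with self.subTest('matches_reference'):
      np.testing.assert_allclose(got, expected, rtol=1e-6)
    with self.subTest('is_finite'):
      self.assertTrue(np.isfinite(got))
```

Each `with self.subTest(...)` block is reported separately, so an assertion failure inside one block does not stop the remaining blocks from executing.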
class FixMatchLossTest(parameterized.TestCase):
  @staticmethod
  def _assert_allclose(got, expected, dtype):
Make these not class methods but private functions at the file level, since they are used in both tests, I believe.
@vroulet Thank you for reviewing the code. I'll push the updated changes ASAP.
Hi @vroulet, |
#1550
- `fixmatch_loss`: tests over `B/U/C`, `confidence_threshold`, `lambda_u`, and `dtype` (incl. `bfloat16`); also checks the output is finite.
  - `vmap` correctness: `jax.vmap(fixmatch_loss)` matches the `lax.map` (per-item loop) output; also checks finiteness.
  - `lambda_u = 0` supervised-only: returns exactly the supervised cross-entropy when the unsupervised weight is zero.
  - `U=0`: behaves as supervised-only; finite.
  - unsupervised loss term is non-zero and finite.
  - `bfloat16` run: smoke test that it runs in `bfloat16` and returns finite output.
- `mixmatch_loss`: tests over `B/U/C`, `lambda_u`, and `dtype` (incl. `bfloat16`); checks finiteness.
  - `vmap` correctness: `jax.vmap(mixmatch_loss)` matches `lax.map`; checks finiteness.
  - `lambda_u = 0` supervised-only: returns the supervised cross-entropy when the unsupervised weight is zero.
  - gradient w.r.t. `unlabeled_targets` is zero (targets are treated as constants).
  - when `unlabeled_targets == softmax(unlabeled_logits)`, the unsupervised loss becomes ~0, so the total ≈ supervised-only.
  - `bfloat16` run: smoke test that it runs in `bfloat16` and returns finite output.