Consolidate open PRs: MIM inference, Hessian solver, gradient accumulation, and AMP #10
thinksyncs merged 7 commits into main
Conversation
…entropy loss and geometric consistency
Pull request overview
This pull request consolidates four separate PRs (#6, #7, #8, #9) into a single comprehensive update that adds significant new functionality to the YOLOZU RT-DETR pose estimation system. The consolidation strategy prevents merge conflicts by applying changes in a coordinated manner.
Changes:
- MIM Reconstruction Branch (PR #6): Adds geometry-aligned masked image modeling for test-time adaptation with ~100K additional parameters, including `RenderTeacher` and `DecoderMIM` modules
- Hessian Solver (PR #7): Implements Gauss–Newton optimization with Levenberg–Marquardt damping for per-detection refinement of depth, rotation, and offset predictions
- Gradient Accumulation & AMP (PR #8): Adds gradient accumulation support and automatic mixed precision (AMP) for efficient large-batch training
- Lint Documentation (PR #9): Documents and fixes lint errors from the other PRs (removes an unused import)
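For intuition about the solver in PR #7, here is a minimal one-dimensional sketch of Gauss–Newton with Levenberg–Marquardt damping. This is illustrative only: the actual solver refines depth, rotation, and offset jointly per detection, and `gauss_newton_lm` is a hypothetical name, not the PR's API.

```python
def gauss_newton_lm(residual, jacobian, x0, damping=1e-3, iters=10):
    """One-dimensional Gauss-Newton iteration with Levenberg-Marquardt damping.

    Minimizes residual(x)**2 via the damped normal-equation update
    x -= J*r / (J*J + damping); the damping term keeps the step bounded
    when the Jacobian is small.
    """
    x = x0
    for _ in range(iters):
        r = residual(x)
        j = jacobian(x)
        x -= (j * r) / (j * j + damping)
    return x


# Example: solve x**2 = 2 starting from x = 1.0.
root = gauss_newton_lm(lambda x: x * x - 2.0, lambda x: 2.0 * x, 1.0)
```

In the multi-parameter case the scalar `j * j + damping` becomes the damped normal matrix `J^T J + damping * I`, solved per detection.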
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `yolozu/calibration/hessian_solver.py` | New Hessian-based solver for iterative regression refinement using second-order optimization |
| `yolozu/calibration/__init__.py` | Exports for `HessianSolverConfig` and refinement functions |
| `tools/refine_predictions_hessian.py` | CLI tool for batch refinement of predictions with configurable solver parameters |
| `rtdetr_pose/rtdetr_pose/model.py` | Adds `RenderTeacher`, `DecoderMIM` modules and MIM branch to `RTDETRPose` with backward compatibility |
| `rtdetr_pose/rtdetr_pose/losses.py` | New loss functions: `mim_reconstruction_loss` and `entropy_loss` for geometric consistency |
| `rtdetr_pose/tools/train_minimal.py` | Gradient accumulation and AMP integration with proper gradient scaling and clipping |
| `tests/test_hessian_solver.py` | Comprehensive test suite (9 tests) for Hessian solver |
| `tests/test_mim_reconstruction.py` | Comprehensive test suite (10 tests) for MIM components |
| `rtdetr_pose/tests/test_train_minimal_integration.py` | Integration tests for gradient accumulation and AMP |
| `rtdetr_pose/tests/test_train_minimal_grad_accum_amp.py` | Unit tests for new training arguments |
| `tools/example_mim_inference.py` | Example script demonstrating MIM usage with test-time adaptation |
| `docs/hessian_solver.md` | Documentation for Hessian solver API and usage |
| `docs/mim_inference.md` | Documentation for MIM branch usage and test-time training |
| `train_setting.yaml` | Configuration examples for new gradient accumulation and AMP features |
| `README.md` | Updated feature highlights with Hessian solver reference |
| `LINT_FIXES_NEEDED.md` | Documents lint fixes applied during merge |
| `SECURITY_SUMMARY.md` | Security assessment of all changes |
| `IMPLEMENTATION_SUMMARY.md` | Detailed implementation summary for MIM feature |
```python
# Perform optimizer step only at accumulation boundaries
# steps is 0-indexed within each epoch, so we use (steps + 1) for the check
if (steps + 1) % accum_steps == 0:
    if scaler is not None:
        # Unscale gradients before clipping
        if args.clip_grad_norm and float(args.clip_grad_norm) > 0:
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.clip_grad_norm))
        scaler.step(optim)
        scaler.update()
        optim.zero_grad(set_to_none=True)
    else:
        if args.clip_grad_norm and float(args.clip_grad_norm) > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.clip_grad_norm))
        optim.step()
        optim.zero_grad(set_to_none=True)
```
The optimizer step is only performed when (steps + 1) % accum_steps == 0. However, if the epoch ends before reaching an accumulation boundary, gradients will remain accumulated without being applied. This means the last few batches of an epoch might not contribute to parameter updates if they don't reach an accumulation boundary. Consider adding logic to perform a final optimizer step at the end of each epoch if there are accumulated gradients remaining.
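One way to see the dropped-batch issue is to enumerate which batch indices trigger an optimizer step. The helper below is hypothetical (not part of the PR); it mirrors the `(steps + 1) % accum_steps == 0` boundary check and adds the suggested end-of-epoch flush:

```python
def optimizer_step_indices(num_batches, accum_steps, flush_at_epoch_end=True):
    """Return the 0-indexed batch positions at which the optimizer steps.

    Mirrors the `(steps + 1) % accum_steps == 0` check from the training
    loop; `flush_at_epoch_end` models the suggested fix of performing one
    final step when the epoch ends mid-accumulation.
    """
    steps = [i for i in range(num_batches) if (i + 1) % accum_steps == 0]
    if flush_at_epoch_end and num_batches % accum_steps != 0:
        steps.append(num_batches - 1)
    return steps
```

With 10 batches and `accum_steps=4`, the unflushed variant steps only at batches 3 and 7, so the gradients from batches 8 and 9 never reach the parameters; the flushed variant adds a step at batch 9.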
```python
# Initialize GradScaler for AMP if enabled
scaler = None
if args.use_amp:
    if device.startswith("cuda"):
        scaler = torch.cuda.amp.GradScaler()
        print("amp_enabled=True device=cuda")
    else:
        print("amp_warning: --use-amp requires CUDA device; AMP disabled")
```
When gradient accumulation is used with checkpoint resumption, there's a potential issue: if training is resumed and the checkpoint had accumulated gradients, those gradients might still be present. The code doesn't explicitly zero gradients at the start of training after checkpoint loading. Consider adding optim.zero_grad(set_to_none=True) after initializing the scaler (around line 1718) to ensure a clean gradient state when starting/resuming training with gradient accumulation.
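The stale-gradient concern can be illustrated with a toy scalar accumulator (`ToyOptim` and `first_update_grad` are hypothetical names for illustration; the real fix would simply call `optim.zero_grad(set_to_none=True)` on the torch optimizer after initialization):

```python
class ToyOptim:
    """Toy stand-in for an optimizer that accumulates a single scalar gradient."""

    def __init__(self):
        self.grad = 0.0

    def accumulate(self, g):
        self.grad += g

    def zero_grad(self):
        self.grad = 0.0


def first_update_grad(stale_grad, microbatch_grads, zero_on_resume):
    """Gradient applied at the first optimizer step after resuming training.

    `stale_grad` models leftover accumulated gradients from before the
    checkpoint; zeroing on resume keeps them out of the first update.
    """
    opt = ToyOptim()
    opt.grad = stale_grad
    if zero_on_resume:
        opt.zero_grad()
    for g in microbatch_grads:
        opt.accumulate(g)
    return opt.grad
```

Without zeroing, the first post-resume update silently folds the stale gradient into the fresh microbatch gradients.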
```python
# With mask
loss_masked = mim_reconstruction_loss(recon_feat, teacher_feat, mask=mask)
self.assertEqual(loss_masked.shape, torch.Size([]))
self.assertTrue(loss_masked.item() >= 0.0)
```
`assertTrue(a >= b)` cannot provide an informative message. Using `assertGreaterEqual(a, b)` instead will give more informative messages.
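The difference in failure messages can be seen directly with a throwaway `unittest.TestCase` instance:

```python
import unittest


class _Probe(unittest.TestCase):
    """Minimal TestCase used only to call assertion methods directly."""

    def runTest(self):
        pass


tc = _Probe()

# assertTrue sees only the already-collapsed boolean, so the operands are lost.
try:
    tc.assertTrue(-1.0 >= 0.0)
except AssertionError as e:
    msg_assert_true = str(e)

# assertGreaterEqual receives both operands and reports them on failure.
try:
    tc.assertGreaterEqual(-1.0, 0.0)
except AssertionError as e:
    msg_greater_equal = str(e)
```

`msg_assert_true` is just "False is not true", while `msg_greater_equal` names both values being compared.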
```python
self.assertTrue(loss_masked.item() >= 0.0)

# Without mask
loss_full = mim_reconstruction_loss(recon_feat, teacher_feat, mask=None)
self.assertEqual(loss_full.shape, torch.Size([]))
self.assertTrue(loss_full.item() >= 0.0)
```
`assertTrue(a >= b)` cannot provide an informative message. Using `assertGreaterEqual(a, b)` instead will give more informative messages.
Suggested change:

```diff
-self.assertTrue(loss_masked.item() >= 0.0)
+self.assertGreaterEqual(loss_masked.item(), 0.0)
 # Without mask
 loss_full = mim_reconstruction_loss(recon_feat, teacher_feat, mask=None)
 self.assertEqual(loss_full.shape, torch.Size([]))
-self.assertTrue(loss_full.item() >= 0.0)
+self.assertGreaterEqual(loss_full.item(), 0.0)
```
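For intuition about what these tests exercise, here is a hypothetical pure-Python mirror of a masked reconstruction loss. The real `mim_reconstruction_loss` operates on feature tensors; `masked_mse` and its list-based signature are assumptions made for illustration only.

```python
def masked_mse(recon, target, mask=None):
    """Sketch of a masked reconstruction loss: mean squared error averaged
    only over positions where mask is truthy (all positions if mask is None).
    Returns 0.0 when no position is selected, matching the non-negative
    scalar the tests above assert on.
    """
    pairs = [
        (r, t)
        for i, (r, t) in enumerate(zip(recon, target))
        if mask is None or mask[i]
    ]
    if not pairs:
        return 0.0
    return sum((r - t) ** 2 for r, t in pairs) / len(pairs)
```

Masking restricts the average to reconstructed (masked-out) positions, which is why the tests check both the `mask=mask` and `mask=None` paths.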
```python
loss = entropy_loss(logits)

self.assertEqual(loss.shape, torch.Size([]))
self.assertTrue(loss.item() >= 0.0)
```
`assertTrue(a >= b)` cannot provide an informative message. Using `assertGreaterEqual(a, b)` instead will give more informative messages.
Suggested change:

```diff
-self.assertTrue(loss.item() >= 0.0)
+self.assertGreaterEqual(loss.item(), 0.0)
```
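For reference, the quantity an entropy loss of this kind typically computes, the entropy of the softmax distribution over logits, can be sketched for a single logit row. This is a hypothetical pure-Python mirror; the real `entropy_loss` works on tensors and averages over detections.

```python
import math


def entropy_of_logits(logits):
    """Entropy of the softmax distribution over one row of logits.

    Uses the max-shift trick for numerical stability. Minimizing this
    quantity sharpens predictions during test-time adaptation.
    """
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)
```

Uniform logits give the maximum entropy `log(num_classes)`; a strongly peaked row gives entropy near zero, which is why the test only asserts non-negativity.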
```python
self.assertIn("loss", result)
self.assertIn("loss_mim", result)
self.assertIn("loss_entropy", result)
self.assertTrue(result["loss"].item() > 0.0)
```
`assertTrue(a > b)` cannot provide an informative message. Using `assertGreater(a, b)` instead will give more informative messages.
Suggested change:

```diff
-self.assertTrue(result["loss"].item() > 0.0)
+self.assertGreater(result["loss"].item(), 0.0)
```
```python
    r_flat = [float(x) for row in gt_rotation for x in row]
    if len(r_flat) == 9:
        r_gt_tensor = torch.tensor(r_flat, dtype=dtype, device=device).reshape(3, 3)
except (TypeError, ValueError):
```
'except' clause does nothing but pass and there is no explanatory comment.
Suggested change:

```diff
 except (TypeError, ValueError):
+    # If ground-truth rotation is malformed, silently skip rotation supervision.
+    # This keeps behavior consistent with treating missing gt_rotation as no constraint.
```
```python
if isinstance(t_gt, (list, tuple)) and len(t_gt) >= 3:
    try:
        gt_depth = float(t_gt[2])  # Z component.
    except (TypeError, ValueError):
```
'except' clause does nothing but pass and there is no explanatory comment.
```python
if isinstance(r_gt, (list, tuple)) and len(r_gt) == 3:
    try:
        gt_rotation = [[float(x) for x in row] for row in r_gt]
    except (TypeError, ValueError):
```
'except' clause does nothing but pass and there is no explanatory comment.
Merges PRs #6, #7, #8, and #9 to prevent conflicts. Total delta: 18 files, +2,888 lines.
Changes

PR #6: Masked reconstruction branch (1,270 lines)

- `RenderTeacher` and `DecoderMIM` modules for geometry-aligned masked image modeling in inference
- `mim_reconstruction_loss()` and `entropy_loss()` for test-time adaptation
- `RTDETRPose.forward()` accepts `geom_input` (mask + normalized depth), `feature_mask`, and a `return_mim` flag

PR #7: Hessian solver (1,319 lines)

- `HessianSolverConfig` and `refine_predictions_hessian()` API
- `tools/refine_predictions_hessian.py` CLI

PR #8: Gradient accumulation + AMP (252 lines)

- `--gradient-accumulation-steps N` scales the loss by 1/N and defers the optimizer step to accumulation boundaries
- `--use-amp` enables torch.cuda.amp with proper gradient unscaling before clipping
- Removes an unused `json` import in a test file

PR #9: Lint documentation (50 lines)
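The 1/N loss scaling in PR #8 makes N accumulated microbatch gradients sum to the full-batch gradient. A scalar sketch with a linear model `y_hat = w * x` (hypothetical helper names, assuming equal-sized microbatches):

```python
def full_batch_grad(xs, ys, w):
    """d/dw of the mean squared error for y_hat = w * x over the whole batch."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(xs, ys, w, accum_steps):
    """Sum of per-microbatch gradients, each scaled by 1/accum_steps.

    Mirrors scaling the loss by 1/N before backward() and deferring the
    optimizer step to the accumulation boundary.
    """
    size = len(xs) // accum_steps
    total = 0.0
    for k in range(accum_steps):
        mb_x = xs[k * size:(k + 1) * size]
        mb_y = ys[k * size:(k + 1) * size]
        total += full_batch_grad(mb_x, mb_y, w) / accum_steps
    return total
```

With unequal microbatch sizes the two quantities diverge slightly, which is one reason implementations often fix the microbatch size or renormalize at the boundary.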
Implementation notes