fix: miscalculation of num_steps when using num_epoch and lmdb #5488
OutisLi wants to merge 1 commit into
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adjust LMDB batch counting to reflect auto-probability expansion and align distributed/training step calculations, with added regression tests.
Changes:
- Update `total_batch`/`index` semantics to track sampler-expanded batch counts.
- Use DataLoader length for LR step calculations when training on `LmdbDataset`.
- Add tests covering `total_batch` alignment and distributed sampler length with auto-prob expansion.
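The effect of the LR-step change can be pictured with a minimal sketch. It assumes the step count is the epoch budget multiplied by the number of batches the loader yields per epoch; `steps_from_epochs` and the toy numbers are hypothetical and are not taken from `training.py`.

```python
import math


def steps_from_epochs(num_epoch: float, batches_per_epoch: int) -> int:
    """Hypothetical helper: convert an epoch budget into optimizer steps."""
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    return math.ceil(num_epoch * batches_per_epoch)


# Toy numbers (assumed): the reader reports 10 batches, while the sampler
# yields 16 batches per epoch once auto-prob expansion is applied.
reader_batches = 10
loader_batches = 16

steps_old = steps_from_epochs(5, reader_batches)  # based on the reader count
steps_new = steps_from_epochs(5, loader_batches)  # based on the loader length
```

With these assumed numbers the reader-based count undershoots the schedule by 30 steps, which is the kind of mismatch the PR title describes.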
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| source/tests/pt/test_lmdb_dataloader.py | Adds tests validating expanded batch counts and distributed sampler __len__() behavior. |
| deepmd/pt/utils/lmdb_dataset.py | Changes index/total_batch to be derived from the batch sampler length (including expansion). |
| deepmd/pt/train/training.py | Uses len(training_dataloader) to compute batch counts for LR scheduling with LMDB datasets. |
| deepmd/dpmodel/utils/lmdb_data.py | Updates total_batch calculation and adjusts distributed sampler __len__() to include expansion via SameNlocBatchSampler. |
             return [self.total_batch]

         @property
         def total_batch(self) -> int:
    -        return self._reader.total_batch
    +        return len(self._batch_sampler)
         """Number of batches for this rank."""
    -    total = 0
    -    for nloc, indices in self._reader.nloc_groups.items():
    -        bs = self._reader.get_batch_size_for_nloc(nloc)
    -        total += (len(indices) + bs - 1) // bs
    +    total = len(
    +        SameNlocBatchSampler(
    +            self._reader,
    +            shuffle=False,
    +            block_targets=self._block_targets,
    +        )
    +    )
         return math.ceil(total / self._world_size)
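To see why the removed per-group formula can disagree with the sampler length, here is a minimal sketch. It assumes a single fixed batch size and models expansion as repeating each same-`nloc` group a whole number of times; both simplifications, and the names `naive_batch_count` and `expanded_batch_count`, are hypothetical and do not describe the actual `SameNlocBatchSampler` logic.

```python
def naive_batch_count(group_sizes: dict, batch_size: int) -> int:
    """Sum of ceil(n / batch_size) over same-nloc groups (the removed formula)."""
    return sum((n + batch_size - 1) // batch_size for n in group_sizes.values())


def expanded_batch_count(group_sizes: dict, repeats: dict, batch_size: int) -> int:
    """Same sum after each group is repeated repeats[nloc] times (assumed model)."""
    return sum(
        (n * repeats.get(nloc, 1) + batch_size - 1) // batch_size
        for nloc, n in group_sizes.items()
    )


# Toy data (assumed): nloc -> number of frames in that group.
group_sizes = {3: 10, 6: 7}

naive = naive_batch_count(group_sizes, batch_size=4)
expanded = expanded_batch_count(group_sizes, repeats={3: 3}, batch_size=4)
```

Here the naive count is 5 while the expanded count is 10, so any step budget derived from the naive count would cover only half of an expanded epoch.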
    +def test_total_batch_matches_auto_prob_sampler(self, auto_prob_lmdb):
    +    ds = LmdbDataset(
    +        auto_prob_lmdb,
    +        type_map=["O", "H"],
    +        batch_size=4,
    +        auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5",
    +    )
    +    assert ds.total_batch == len(ds._batch_sampler)
📝 Walkthrough
The PR refactors batch count computation across the LMDB data pipeline.
Changes
Batch Count Estimation Alignment
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/utils/lmdb_data.py`:
- Around line 1311-1320: The current __len__() returns ceil(total /
self._world_size) for every rank which mismatches the strided partitioning in
__iter__(); change __len__() to compute a rank-aware batch count by getting
total = len(SameNlocBatchSampler(self._reader, shuffle=False,
block_targets=self._block_targets)), then compute base = total //
self._world_size and remainder = total % self._world_size and return base + (1
if self._rank < remainder else 0) so that __len__() matches the actual number of
batches produced by the __iter__() strided partitioning (alternatively, adjust
_partition_batches() to pad/repeat batches so every rank emits
ceil(total/world_size) and keep current __len__()).
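The rank-aware count proposed above can be checked against a strided split with a small sketch. `rank_aware_len` is a hypothetical stand-alone function, not the sampler's `__len__()`, and the batch list is toy data.

```python
def rank_aware_len(total: int, rank: int, world_size: int) -> int:
    """Number of items rank `rank` receives from seq[rank::world_size]."""
    base, remainder = divmod(total, world_size)
    return base + (1 if rank < remainder else 0)


# Toy data (assumed): 10 global batches split across 4 ranks.
all_batches = list(range(10))
world_size = 4

lengths = [rank_aware_len(len(all_batches), r, world_size) for r in range(world_size)]
strided = [len(all_batches[r::world_size]) for r in range(world_size)]
```

Both lists come out as `[3, 3, 2, 2]`, whereas `ceil(10 / 4)` would report 3 for every rank.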
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8ecc555b-6099-4600-a3c5-6d331cce1d8d
📒 Files selected for processing (4)
deepmd/dpmodel/utils/lmdb_data.py
deepmd/pt/train/training.py
deepmd/pt/utils/lmdb_dataset.py
source/tests/pt/test_lmdb_dataloader.py
     def __len__(self) -> int:
         """Number of batches for this rank."""
    -    total = 0
    -    for nloc, indices in self._reader.nloc_groups.items():
    -        bs = self._reader.get_batch_size_for_nloc(nloc)
    -        total += (len(indices) + bs - 1) // bs
    +    total = len(
    +        SameNlocBatchSampler(
    +            self._reader,
    +            shuffle=False,
    +            block_targets=self._block_targets,
    +        )
    +    )
         return math.ceil(total / self._world_size)
Keep `__len__()` consistent with the batches this rank actually yields.
`__iter__()` uses strided partitioning (`all_batches[self._rank :: self._world_size]`), so ranks after the remainder get fewer batches when the global count is not divisible by `world_size`. `__len__()` now returns `ceil(total / world_size)` for every rank, which overstates shorter ranks and no longer matches the iterator.
That matters now that `deepmd/pt/train/training.py` lines 652 and 681 derive `num_epoch -> num_steps` from `len(self.training_dataloader)`. Either pad/repeat in `_partition_batches()` so each rank really emits `ceil(total / world_size)` batches, or make `__len__()` rank-aware and resolve a shared epoch length elsewhere.
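The pad/repeat alternative can be sketched as follows. `partition_with_padding` is a hypothetical stand-in for `_partition_batches()`, whose real signature is not shown here; the wrap-around padding is an assumption, similar in spirit to how `torch.utils.data.DistributedSampler` pads when `drop_last` is false.

```python
import math


def partition_with_padding(all_batches: list, rank: int, world_size: int) -> list:
    """Strided split, padded by wrapping around to the start of the list.

    Every rank receives exactly ceil(len(all_batches) / world_size) batches,
    at the cost of a few batches being seen twice per epoch.
    """
    if not all_batches:
        return []
    per_rank = math.ceil(len(all_batches) / world_size)
    padded_len = per_rank * world_size
    # Reuse leading batches to fill the tail up to a multiple of world_size.
    padded = [all_batches[i % len(all_batches)] for i in range(padded_len)]
    return padded[rank::world_size]


# Toy data (assumed): 10 global batches split across 4 ranks.
shards = [partition_with_padding(list(range(10)), r, 4) for r in range(4)]
```

With this variant every shard has length 3, so a `__len__()` that returns `ceil(total / world_size)` would match the iterator on all ranks, and all ranks would take the same number of steps per epoch.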
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff            @@
##           master    #5488   +/-   ##
=======================================
  Coverage   81.36%   81.36%
=======================================
  Files         868      868
  Lines       96567    96568    +1
  Branches     4233     4234    +1
=======================================
+ Hits        78570    78573    +3
+ Misses      16697    16695    -2
  Partials     1300     1300
Summary by CodeRabbit
Bug Fixes
Tests