From 48ecc683a317d45483142cc4260d7aac8b75d60c Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 25 May 2026 09:29:10 +0800 Subject: [PATCH 1/2] fix: miscalculation of num_steps when using num_epoch and lmdb --- deepmd/dpmodel/utils/lmdb_data.py | 21 +++++++++++++------ deepmd/pt/train/training.py | 4 ++-- deepmd/pt/utils/lmdb_dataset.py | 4 ++-- source/tests/pt/test_lmdb_dataloader.py | 28 +++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/utils/lmdb_data.py b/deepmd/dpmodel/utils/lmdb_data.py index 2f7d8f9836..97899544c9 100644 --- a/deepmd/dpmodel/utils/lmdb_data.py +++ b/deepmd/dpmodel/utils/lmdb_data.py @@ -751,11 +751,17 @@ def set_noise(self, noise_settings: dict[str, Any]) -> None: @property def index(self) -> list[int]: """Number of batches per system (single system).""" - return [max(1, self.nframes // self.batch_size)] + return [self.total_batch] @property def total_batch(self) -> int: - return self.index[0] + if self.mixed_batch: + return math.ceil(self.nframes / self.batch_size) if self.nframes else 0 + total = 0 + for nloc, indices in self._nloc_groups.items(): + bs = self.get_batch_size_for_nloc(nloc) + total += (len(indices) + bs - 1) // bs + return total @property def batch_sizes(self) -> list[int]: @@ -1304,10 +1310,13 @@ def _partition_batches(self, all_batches: list[list[int]]) -> list[list[int]]: def __len__(self) -> int: """Number of batches for this rank.""" - total = 0 - for nloc, indices in self._reader.nloc_groups.items(): - bs = self._reader.get_batch_size_for_nloc(nloc) - total += (len(indices) + bs - 1) // bs + total = len( + SameNlocBatchSampler( + self._reader, + shuffle=False, + block_targets=self._block_targets, + ) + ) return math.ceil(total / self._world_size) @property diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f166971dfe..49b87138a7 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -649,7 +649,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: if self.num_epoch <= 0: raise ValueError("training.num_epoch must be positive.") if isinstance(training_data, LmdbDataset): - total_numb_batch = training_data.total_batch + total_numb_batch = len(self.training_dataloader) else: sampler_weights = to_numpy_array( self.training_dataloader.sampler.weights @@ -678,7 +678,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ) for model_key in self.model_keys: if isinstance(training_data[model_key], LmdbDataset): - per_task_total.append(training_data[model_key].total_batch) + per_task_total.append(len(self.training_dataloader[model_key])) else: sampler_weights = to_numpy_array( self.training_dataloader[model_key].sampler.weights diff --git a/deepmd/pt/utils/lmdb_dataset.py b/deepmd/pt/utils/lmdb_dataset.py index 067b420da9..07b41cba98 100644 --- a/deepmd/pt/utils/lmdb_dataset.py +++ b/deepmd/pt/utils/lmdb_dataset.py @@ -311,11 +311,11 @@ def set_noise(self, noise_settings: dict[str, Any]) -> None: @property def index(self) -> list[int]: - return self._reader.index + return [self.total_batch] @property def total_batch(self) -> int: - return self._reader.total_batch + return len(self._batch_sampler) @property def batch_sizes(self) -> list[int]: diff --git a/source/tests/pt/test_lmdb_dataloader.py b/source/tests/pt/test_lmdb_dataloader.py index ebb505706d..3d0bd5a7e4 100644 --- a/source/tests/pt/test_lmdb_dataloader.py +++ b/source/tests/pt/test_lmdb_dataloader.py @@ -624,6 +624,34 @@ def test_dataset_auto_prob_iteration(self, auto_prob_lmdb): count = sum(len(batch) for batch in ds._batch_sampler) assert count > 300 # expanded + def test_total_batch_matches_auto_prob_sampler(self, auto_prob_lmdb): + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + assert ds.total_batch == len(ds._batch_sampler) + + def test_distributed_len_includes_auto_prob_expansion(self, auto_prob_lmdb): + import math + + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + global_batches = len(ds._batch_sampler) + dist_sampler = DistributedSameNlocBatchSampler( + ds._reader, + rank=0, + world_size=2, + shuffle=False, + block_targets=ds._block_targets, + ) + assert len(dist_sampler) == math.ceil(global_batches / 2) + class TestMergeLmdbSystemIds: """Test merge_lmdb propagates frame_system_ids.""" From e64bfae597a33576f77d54e16c024fff7d78dcee Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 3 Jun 2026 17:52:08 +0800 Subject: [PATCH 2/2] fix --- deepmd/dpmodel/utils/lmdb_data.py | 16 +++++----- deepmd/pt/utils/lmdb_dataset.py | 3 ++ source/tests/pt/test_lmdb_dataloader.py | 42 +++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/deepmd/dpmodel/utils/lmdb_data.py b/deepmd/dpmodel/utils/lmdb_data.py index 97899544c9..9642024548 100644 --- a/deepmd/dpmodel/utils/lmdb_data.py +++ b/deepmd/dpmodel/utils/lmdb_data.py @@ -1275,6 +1275,13 @@ def __init__( self._seed = seed if seed is not None else 0 self._epoch = 0 self._block_targets = block_targets + self._total_batches = len( + SameNlocBatchSampler( + self._reader, + shuffle=False, + block_targets=self._block_targets, + ) + ) def set_epoch(self, epoch: int) -> None: """Set epoch for deterministic cross-rank shuffling. @@ -1310,14 +1317,7 @@ def _partition_batches(self, all_batches: list[list[int]]) -> list[list[int]]: def __len__(self) -> int: """Number of batches for this rank.""" - total = len( - SameNlocBatchSampler( - self._reader, - shuffle=False, - block_targets=self._block_targets, - ) - ) - return math.ceil(total / self._world_size) + return math.ceil(self._total_batches / self._world_size) @property def rank(self) -> int: diff --git a/deepmd/pt/utils/lmdb_dataset.py b/deepmd/pt/utils/lmdb_dataset.py index 07b41cba98..b7f0e17735 100644 --- a/deepmd/pt/utils/lmdb_dataset.py +++ b/deepmd/pt/utils/lmdb_dataset.py @@ -311,6 +311,9 @@ def set_noise(self, noise_settings: dict[str, Any]) -> None: @property def index(self) -> list[int]: + """Number of batches per logical LMDB dataset.""" + if not self._block_targets: + return self._reader.index return [self.total_batch] @property diff --git a/source/tests/pt/test_lmdb_dataloader.py b/source/tests/pt/test_lmdb_dataloader.py index 3d0bd5a7e4..d568d1ab0a 100644 --- a/source/tests/pt/test_lmdb_dataloader.py +++ b/source/tests/pt/test_lmdb_dataloader.py @@ -12,6 +12,9 @@ import pytest import torch +from deepmd.dpmodel.utils import ( + lmdb_data, +) from deepmd.dpmodel.utils.lmdb_data import ( DistributedSameNlocBatchSampler, LmdbDataReader, @@ -632,6 +635,8 @@ def test_total_batch_matches_auto_prob_sampler(self, auto_prob_lmdb): auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", ) assert ds.total_batch == len(ds._batch_sampler) + assert ds.index == [ds.total_batch] + assert ds.index != ds._reader.index def test_distributed_len_includes_auto_prob_expansion(self, auto_prob_lmdb): import math @@ -652,6 +657,43 @@ def test_distributed_len_includes_auto_prob_expansion(self, auto_prob_lmdb): ) assert len(dist_sampler) == math.ceil(global_batches / 2) + def test_distributed_len_reuses_cached_total(self, auto_prob_lmdb, monkeypatch): + import math + + calls = 0 + real_sampler = lmdb_data.SameNlocBatchSampler + + class CountingSameNlocBatchSampler(real_sampler): + def __init__(self, *args, **kwargs): + nonlocal calls + calls += 1 + super().__init__(*args, **kwargs) + + monkeypatch.setattr( + lmdb_data, + "SameNlocBatchSampler", + CountingSameNlocBatchSampler, + ) + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + dist_sampler = DistributedSameNlocBatchSampler( + ds._reader, + rank=0, + world_size=2, + shuffle=False, + block_targets=ds._block_targets, + ) + + assert calls == 1 + expected_len = math.ceil(len(ds._batch_sampler) / 2) + assert len(dist_sampler) == expected_len + assert len(dist_sampler) == expected_len + assert calls == 1 + class TestMergeLmdbSystemIds: """Test merge_lmdb propagates frame_system_ids."""