Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions deepmd/dpmodel/utils/lmdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,11 +751,17 @@ def set_noise(self, noise_settings: dict[str, Any]) -> None:
@property
def index(self) -> list[int]:
"""Number of batches per system (single system)."""
return [max(1, self.nframes // self.batch_size)]
return [self.total_batch]

@property
def total_batch(self) -> int:
return self.index[0]
if self.mixed_batch:
return math.ceil(self.nframes / self.batch_size) if self.nframes else 0
total = 0
for nloc, indices in self._nloc_groups.items():
bs = self.get_batch_size_for_nloc(nloc)
total += (len(indices) + bs - 1) // bs
return total

@property
def batch_sizes(self) -> list[int]:
Expand Down Expand Up @@ -1269,6 +1275,13 @@ def __init__(
self._seed = seed if seed is not None else 0
self._epoch = 0
self._block_targets = block_targets
self._total_batches = len(
SameNlocBatchSampler(
self._reader,
shuffle=False,
block_targets=self._block_targets,
)
)

def set_epoch(self, epoch: int) -> None:
"""Set epoch for deterministic cross-rank shuffling.
Expand Down Expand Up @@ -1304,11 +1317,7 @@ def _partition_batches(self, all_batches: list[list[int]]) -> list[list[int]]:

def __len__(self) -> int:
"""Number of batches for this rank."""
total = 0
for nloc, indices in self._reader.nloc_groups.items():
bs = self._reader.get_batch_size_for_nloc(nloc)
total += (len(indices) + bs - 1) // bs
return math.ceil(total / self._world_size)
return math.ceil(self._total_batches / self._world_size)

@property
def rank(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
if self.num_epoch <= 0:
raise ValueError("training.num_epoch must be positive.")
if isinstance(training_data, LmdbDataset):
total_numb_batch = training_data.total_batch
total_numb_batch = len(self.training_dataloader)
else:
sampler_weights = to_numpy_array(
self.training_dataloader.sampler.weights
Expand Down Expand Up @@ -678,7 +678,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
)
for model_key in self.model_keys:
if isinstance(training_data[model_key], LmdbDataset):
per_task_total.append(training_data[model_key].total_batch)
per_task_total.append(len(self.training_dataloader[model_key]))
else:
sampler_weights = to_numpy_array(
self.training_dataloader[model_key].sampler.weights
Expand Down
7 changes: 5 additions & 2 deletions deepmd/pt/utils/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,14 @@ def set_noise(self, noise_settings: dict[str, Any]) -> None:

@property
def index(self) -> list[int]:
return self._reader.index
"""Number of batches per logical LMDB dataset."""
if not self._block_targets:
return self._reader.index
return [self.total_batch]

@property
def total_batch(self) -> int:
return self._reader.total_batch
return len(self._batch_sampler)
Comment on lines +317 to +321

@property
def batch_sizes(self) -> list[int]:
Expand Down
70 changes: 70 additions & 0 deletions source/tests/pt/test_lmdb_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import pytest
import torch

from deepmd.dpmodel.utils import (
lmdb_data,
)
from deepmd.dpmodel.utils.lmdb_data import (
DistributedSameNlocBatchSampler,
LmdbDataReader,
Expand Down Expand Up @@ -624,6 +627,73 @@
count = sum(len(batch) for batch in ds._batch_sampler)
assert count > 300 # expanded

def test_total_batch_matches_auto_prob_sampler(self, auto_prob_lmdb):
ds = LmdbDataset(
auto_prob_lmdb,
type_map=["O", "H"],
batch_size=4,
auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5",
)
assert ds.total_batch == len(ds._batch_sampler)
Comment on lines +630 to +637
assert ds.index == [ds.total_batch]
assert ds.index != ds._reader.index

def test_distributed_len_includes_auto_prob_expansion(self, auto_prob_lmdb):
import math

ds = LmdbDataset(
auto_prob_lmdb,
type_map=["O", "H"],
batch_size=4,
auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5",
)
global_batches = len(ds._batch_sampler)
dist_sampler = DistributedSameNlocBatchSampler(
ds._reader,
rank=0,
world_size=2,
shuffle=False,
block_targets=ds._block_targets,
)
assert len(dist_sampler) == math.ceil(global_batches / 2)

def test_distributed_len_reuses_cached_total(self, auto_prob_lmdb, monkeypatch):
import math

calls = 0
real_sampler = lmdb_data.SameNlocBatchSampler

class CountingSameNlocBatchSampler(real_sampler):
def __init__(self, *args, **kwargs):
nonlocal calls
calls += 1
super().__init__(*args, **kwargs)

monkeypatch.setattr(
lmdb_data,
"SameNlocBatchSampler",
CountingSameNlocBatchSampler,
)
ds = LmdbDataset(
auto_prob_lmdb,
type_map=["O", "H"],
batch_size=4,
auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5",
)
dist_sampler = DistributedSameNlocBatchSampler(
ds._reader,
rank=0,
world_size=2,
shuffle=False,
block_targets=ds._block_targets,
)

assert calls == 1
expected_len = math.ceil(len(ds._batch_sampler) / 2)
assert len(dist_sampler) == expected_len
assert len(dist_sampler) == expected_len
assert calls == 1

Check warning

Code scanning / CodeQL

Redundant comparison Warning test

Test is always true, because of
this condition
.


class TestMergeLmdbSystemIds:
"""Test merge_lmdb propagates frame_system_ids."""
Expand Down
Loading