Skip to content

Commit e3ac3df

Browse files
committed
Warn when the validation-based features are used with very small validation set sizes
1 parent 30fac54 commit e3ac3df

File tree

4 files changed

+150
-4
lines changed

4 files changed

+150
-4
lines changed

docs/source/usage/torch_datasets.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,13 @@ When providing a single ``dataset`` parameter to ``train()``, the trainer automa
385385
# Trainer handles split automatically
386386
pot.train(dataset=my_dataset, config=config) # Uses testpercent
387387
388+
.. note::
389+
390+
When ``testpercent > 0``, validation-driven features such as
391+
``use_scheduler=True`` and ``save_best=True`` become active. For very small
392+
validation splits, prefer disabling those features or creating an explicit
393+
train/test split with enough validation structures for stable monitoring.
394+
388395
Manual Splitting
389396
~~~~~~~~~~~~~~~~
390397

docs/source/usage/torch_training.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ structures held out for validation. The
131131
:class:`~aenet.io.train.TrainOut` object containing training history,
132132
statistics, and plotting helpers.
133133

134+
.. note::
135+
136+
Setting ``testpercent > 0`` does more than hold out structures. It also
137+
enables any validation-driven controls in your configuration, such as
138+
``use_scheduler=True`` and ``save_best=True``. On very small validation
139+
splits, these controls can react to noisy metrics and change the training
140+
behavior qualitatively.
141+
134142

135143
Force Training
136144
--------------
@@ -384,6 +392,10 @@ Checkpointing & Model Saving
384392
Save the model with the best validation loss as ``best_model.pt``.
385393
Requires ``testpercent > 0`` to compute validation loss.
386394

395+
For very small validation sets, the selected checkpoint can be unstable.
396+
In that case prefer ``save_best=False`` or supply a larger or explicit
397+
validation split.
398+
387399
**Resuming Training**
388400

389401
To resume training from a checkpoint, pass the checkpoint path to
@@ -442,6 +454,9 @@ adjusting the learning rate for optimal performance.
442454
.. note::
443455

444456
The scheduler requires ``testpercent > 0`` to monitor validation loss.
457+
With only a few validation structures, the monitored loss can be too noisy
458+
for stable plateau detection. In that case prefer ``use_scheduler=False``
459+
or a larger or explicit validation split.
445460

446461

447462
Force Training Parameters

src/aenet/torch_training/tests/test_trainer_smoke.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
from pathlib import Path
33

44
import numpy as np
5-
import torch
6-
75
import pytest
6+
import torch
87

8+
from aenet.torch_featurize import ChebyshevDescriptor
99
from aenet.torch_training import (
10-
TorchTrainingConfig,
1110
Structure,
1211
TorchANNPotential,
12+
TorchTrainingConfig,
1313
)
14-
from aenet.torch_featurize import ChebyshevDescriptor
1514

1615

1716
def make_simple_structures_H_two():
@@ -120,6 +119,77 @@ def test_energy_only_smoke(tmp_path: Path):
120119
and name.endswith(".pt") for name in files)
121120

122121

122+
@pytest.mark.cpu
def test_warns_for_scheduler_with_tiny_validation_set():
    """A 50% split of two structures leaves one validation structure;
    enabling the scheduler on it must raise the tiny-validation warning."""
    data = make_simple_structures_H_two()
    desc = make_descriptor_H(dtype=torch.float64)
    pot = TorchANNPotential(arch=make_arch_H(desc), descriptor=desc)

    config = TorchTrainingConfig(
        iterations=1,
        method=None,
        testpercent=50,
        force_weight=0.0,
        atomic_energies={"H": 0.0},
        memory_mode="cpu",
        device="cpu",
        save_energies=False,
        checkpoint_dir=None,
        checkpoint_interval=0,
        max_checkpoints=None,
        save_best=False,
        use_scheduler=True,
        show_progress=False,
    )

    # The warning message names the offending option and the split size.
    expected = r"use_scheduler=True with a validation set of only 1 structure"
    with pytest.warns(UserWarning, match=expected):
        pot.train(structures=data, config=config)
155+
156+
157+
@pytest.mark.cpu
def test_warns_for_save_best_with_tiny_validation_set(tmp_path: Path):
    """A 50% split of two structures leaves one validation structure;
    enabling best-checkpoint saving on it must raise the warning."""
    data = make_simple_structures_H_two()
    desc = make_descriptor_H(dtype=torch.float64)
    pot = TorchANNPotential(arch=make_arch_H(desc), descriptor=desc)

    config = TorchTrainingConfig(
        iterations=1,
        method=None,
        testpercent=50,
        force_weight=0.0,
        atomic_energies={"H": 0.0},
        memory_mode="cpu",
        device="cpu",
        save_energies=False,
        # save_best needs a checkpoint directory to be active at all.
        checkpoint_dir=str(tmp_path / "ckpts"),
        checkpoint_interval=1,
        max_checkpoints=None,
        save_best=True,
        use_scheduler=False,
        show_progress=False,
    )

    # The warning message names the offending option and the split size.
    expected = r"save_best=True with a validation set of only 1 structure"
    with pytest.warns(UserWarning, match=expected):
        pot.train(structures=data, config=config)
191+
192+
123193
@pytest.mark.cpu
124194
def test_force_training_smoke(tmp_path: Path):
125195
structures = make_simple_structures_H_two()

src/aenet/torch_training/trainer.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,51 @@ def _resolve_device(config: TorchTrainingConfig) -> torch.device:
5757
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
5858

5959

60+
# Validation sets smaller than this (but non-empty) trigger a warning when
# validation-driven features are enabled.
_SMALL_VALIDATION_WARNING_THRESHOLD = 10


def _warn_on_small_validation_set(
    *,
    n_val: int,
    use_scheduler: bool,
    save_best: bool,
) -> None:
    """
    Emit ``UserWarning``s when validation-driven controls run on a tiny split.

    Parameters
    ----------
    n_val : int
        Number of validation structures.
    use_scheduler : bool
        Whether ReduceLROnPlateau monitoring is enabled for this run.
    save_best : bool
        Whether best-checkpoint selection is enabled for this run.
    """
    # Only small-but-nonempty validation sets are worth warning about:
    # n_val == 0 means the features are inert, and large sets are fine.
    if not (0 < n_val < _SMALL_VALIDATION_WARNING_THRESHOLD):
        return

    noun = "structures" if n_val != 1 else "structure"

    if use_scheduler:
        scheduler_msg = (
            "use_scheduler=True with a validation set of only "
            f"{n_val} {noun} can make ReduceLROnPlateau react to noisy "
            "metrics. Consider use_scheduler=False, a larger validation "
            "split, or an explicit train/test split."
        )
        warnings.warn(scheduler_msg, UserWarning)

    if save_best:
        save_best_msg = (
            "save_best=True with a validation set of only "
            f"{n_val} {noun} can select a checkpoint from a noisy "
            "validation loss. Consider save_best=False, a larger "
            "validation split, or an explicit train/test split."
        )
        warnings.warn(save_best_msg, UserWarning)
103+
104+
60105
def _iter_progress(iterable, enable: bool, desc: str):
61106
"""
62107
Wrap an iterable with tqdm progress bar if enabled and available.
@@ -788,6 +833,15 @@ def train(
788833
else None
789834
)
790835

836+
n_val = int(len(test_ds)) if test_ds is not None else 0
837+
_warn_on_small_validation_set(
838+
n_val=n_val,
839+
use_scheduler=bool(config.use_scheduler) and (test_loader is not None),
840+
save_best=bool(config.save_best)
841+
and (config.checkpoint_dir is not None)
842+
and (test_loader is not None),
843+
)
844+
791845
# Initialize normalization manager
792846
normalize_features = bool(getattr(config, "normalize_features", True))
793847
normalize_energy = bool(getattr(config, "normalize_energy", True))

0 commit comments

Comments
 (0)