Skip to content

Commit fccbcfe

Browse files
fix trainer + add tests
1 parent 4e840ac commit fccbcfe

12 files changed

Lines changed: 444 additions & 356 deletions

File tree

pina/_src/callback/refinement/base_refinement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def on_train_start(self, trainer, solver):
9999
# Initialize dataset and compute initial population size
100100
self._dataset = trainer.datamodule.train_datasets
101101
self._initial_population_size = {
102-
cond: self.dataset[cond].length
102+
cond: self.dataset[cond].dataset_length
103103
for cond in self._condition_to_update
104104
}
105105

pina/_src/condition/time_series_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def evaluate(self, batch, solver):
200200
raise ValueError(
201201
"The provided input tensor must have at least 4 dimensions:"
202202
" [trajectories, windows, time_steps, *features]."
203-
f" Got shape {batch["input"].shape}."
203+
f" Got shape {batch['input'].shape}."
204204
)
205205

206206
# Copy the kwargs to avoid modifying the original settings

pina/_src/core/trainer.py

Lines changed: 191 additions & 263 deletions
Large diffs are not rendered by default.

pina/_src/data/aggregator.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ class _Aggregator:
1010
iteration of multiple training conditions within a single training loop.
1111
"""
1212

13-
_AVAIL_BATCHING_MODES = {
14-
"common_batch_size",
15-
"proportional",
16-
"separate_conditions",
17-
}
18-
1913
def __init__(self, dataloaders, batching_mode):
2014
"""
2115
Initialization of the :class:`_Aggregator` class.
@@ -27,17 +21,9 @@ def __init__(self, dataloaders, batching_mode):
2721
uniform batch sizes across conditions, ``"proportional"`` for batch
2822
sizes proportional to dataset sizes, and ``"separate_conditions"``
2923
for iterating through each condition separately.
30-
:raises ValueError: If an invalid batching mode is provided.
3124
:raises NotImplementedError: If the selected batching mode is not yet
3225
implemented.
3326
"""
34-
# Check consistency
35-
if batching_mode not in self._AVAIL_BATCHING_MODES:
36-
raise ValueError(
37-
f"Invalid batching mode '{batching_mode}'. "
38-
f"Available options are: {self._AVAIL_BATCHING_MODES}"
39-
)
40-
4127
# Raise not implemented error for separate_conditions mode
4228
if batching_mode == "separate_conditions":
4329
raise NotImplementedError(

pina/_src/data/creator.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ class _Creator:
1212
behavior to specific training requirements
1313
"""
1414

15-
# Available batching modes
16-
_AVAIL_BATCHING_MODES = {
17-
"common_batch_size",
18-
"proportional",
19-
"separate_conditions",
20-
}
21-
2215
def __init__(
2316
self,
2417
batching_mode,
@@ -53,15 +46,7 @@ def __init__(
5346
:param dict[str, BaseCondition] conditions: The mapping between
5447
condition names and condition objects responsible for data loader
5548
creation.
56-
:raises ValueError: If an invalid batching mode is provided.
5749
"""
58-
# Check consistency
59-
if batching_mode not in self._AVAIL_BATCHING_MODES:
60-
raise ValueError(
61-
f"Invalid batching mode '{batching_mode}'. "
62-
f"Available options are: {self._AVAIL_BATCHING_MODES}"
63-
)
64-
6550
# Initialize attributes
6651
self.batching_mode = batching_mode
6752
self.batch_size = batch_size

pina/_src/data/data_module.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,13 @@ def __init__(
9999
# Move domain discretisation into conditions subsets
100100
self.problem.move_discretisation_into_conditions()
101101

102-
# Verify which splits are zero
103-
self._has_train = train_size > 0
104-
self._has_val = val_size > 0
105-
self._has_test = test_size > 0
102+
# If no splits are defined, use the default dataloaders
103+
if train_size == 0:
104+
self.train_dataloader = super().train_dataloader
105+
if val_size == 0:
106+
self.val_dataloader = super().val_dataloader
107+
if test_size == 0:
108+
self.test_dataloader = super().test_dataloader
106109

107110
# Otherwise, create the condition splits and initialize the creator
108111
self._create_condition_splits(train_size, test_size)
@@ -244,14 +247,6 @@ def train_dataloader(self):
244247
dataloaders.
245248
:rtype: _Aggregator
246249
"""
247-
# If no training split is defined, return the default dataloader
248-
if not self._has_train:
249-
return super().train_dataloader()
250-
251-
# If the training dataloaders have not been created yet, call setup
252-
if not hasattr(self, "train_datasets"):
253-
self.setup("fit")
254-
255250
return _Aggregator(
256251
self.creator(self.train_datasets),
257252
batching_mode=self.batching_mode,
@@ -265,14 +260,6 @@ def val_dataloader(self):
265260
dataloaders.
266261
:rtype: _Aggregator
267262
"""
268-
# If no validation split is defined, return the default dataloader
269-
if not self._has_val:
270-
return super().val_dataloader()
271-
272-
# If the validation dataloaders have not been created yet, call setup
273-
if not hasattr(self, "val_datasets"):
274-
self.setup("fit")
275-
276263
return _Aggregator(
277264
self.creator(self.val_datasets), batching_mode=self.batching_mode
278265
)
@@ -285,14 +272,6 @@ def test_dataloader(self):
285272
dataloaders.
286273
:rtype: _Aggregator
287274
"""
288-
# If no test split is defined, return the default dataloader
289-
if not self._has_test:
290-
return super().test_dataloader()
291-
292-
# If the test dataloaders have not been created yet, call setup
293-
if not hasattr(self, "test_datasets"):
294-
self.setup("test")
295-
296275
return _Aggregator(
297276
self.creator(self.test_datasets),
298277
batching_mode=self.batching_mode,

pina/_src/problem/base_problem.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,3 @@ def are_all_domains_discretised(self):
296296
:rtype: bool
297297
"""
298298
return all(d in self.discretised_domains for d in self.domains)
299-
300-
301-
# Back-compatibility with version 0.2, to be removed soon
302-
class AbstractProblem(BaseProblem):
303-
def __init__(self, *args, **kwargs):
304-
warnings.warn(
305-
"AbstractProblem is deprecated, use BaseProblem instead.",
306-
DeprecationWarning,
307-
stacklevel=2,
308-
)
309-
super().__init__(*args, **kwargs)

pina/problem/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Module for the Problems."""
22

33
__all__ = [
4-
"AbstractProblem", # back-compatibility with version 0.2, to be removed soon
54
"ProblemInterface",
65
"BaseProblem",
76
"SpatialProblem",
@@ -18,4 +17,19 @@
1817
from pina._src.problem.inverse_problem import InverseProblem
1918

2019
# Back-compatibility with version 0.2, to be removed soon
21-
from pina._src.problem.base_problem import AbstractProblem
20+
import warnings
21+
22+
_DEPRECATED_IMPORTS = {"AbstractProblem": "BaseProblem"}
23+
24+
25+
def __getattr__(name):
26+
if name in _DEPRECATED_IMPORTS:
27+
28+
warnings.warn(
29+
f"Importing '{name}' from 'pina.problem' is deprecated; use "
30+
f"pina.problem.{_DEPRECATED_IMPORTS[name]} instead.",
31+
DeprecationWarning,
32+
stacklevel=2,
33+
)
34+
35+
return globals()[_DEPRECATED_IMPORTS[name]]

tests/test_data/test_aggregator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,6 @@ def test_constructor(batching_mode):
6969
# Initialize the aggregator
7070
_Aggregator(dataloaders, batching_mode=batching_mode)
7171

72-
# Should fail if an invalid batching mode is provided
73-
with pytest.raises(ValueError):
74-
_Aggregator(dataloaders, batching_mode="invalid_mode")
75-
7672
# Should raise NotImplementedError for separate_conditions mode
7773
with pytest.raises(NotImplementedError):
7874
_Aggregator(dataloaders, batching_mode="separate_conditions")

tests/test_data/test_creator.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ def create_dataloader(
5151
}
5252

5353

54-
@pytest.mark.parametrize("batching_mode", _Creator._AVAIL_BATCHING_MODES)
54+
@pytest.mark.parametrize(
55+
"batching_mode",
56+
["common_batch_size", "separate_conditions", "proportional"],
57+
)
5558
def test_constructor(batching_mode):
5659

5760
_Creator(
@@ -64,18 +67,6 @@ def test_constructor(batching_mode):
6467
conditions=dataloaders,
6568
)
6669

67-
# Should fail if an invalid batching mode is provided
68-
with pytest.raises(ValueError):
69-
_Creator(
70-
batching_mode="invalid_mode",
71-
batch_size=4,
72-
shuffle=False,
73-
automatic_batching=True,
74-
num_workers=0,
75-
pin_memory=False,
76-
conditions=dataloaders,
77-
)
78-
7970

8071
@pytest.mark.parametrize(
8172
"batching_mode, batch_size, expected_batch_sizes, expected_max_len",

0 commit comments

Comments
 (0)