Skip to content

Commit 4e840ac

Browse files
fix data module + add tests
1 parent 966114f commit 4e840ac

6 files changed

Lines changed: 383 additions & 484 deletions

File tree

docs/source/_rst/data/data_module.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ DataModule
22
======================
33
.. currentmodule:: pina.data.data_module
44

5-
.. autoclass:: pina._src.data.data_module.PinaDataModule
5+
.. autoclass:: pina._src.data.data_module.DataModule
66
:members:
77
:show-inheritance:

pina/__init__.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
"""
2-
PINA: Physics-Informed Neural Analysis.
3-
42
A specialized framework for Scientific Machine Learning (SciML), providing
53
tools for Physics-Informed Neural Networks (PINNs), Neural Operators,
64
and data-driven physical modeling.
@@ -10,12 +8,31 @@
108
"LabelTensor",
119
"Trainer",
1210
"Condition",
13-
"PinaDataModule",
11+
"DataModule",
1412
"Graph",
1513
]
1614

1715
from pina._src.core.label_tensor import LabelTensor
1816
from pina._src.core.graph import Graph
1917
from pina._src.core.trainer import Trainer
2018
from pina._src.condition.condition import Condition
21-
from pina._src.data.data_module import PinaDataModule
19+
from pina._src.data.data_module import DataModule
20+
21+
22+
# Back-compatibility with version 0.2, to be removed soon
23+
import warnings
24+
25+
_DEPRECATED_IMPORTS = {"PinaDataModule": "DataModule"}
26+
27+
28+
def __getattr__(name):
29+
if name in _DEPRECATED_IMPORTS:
30+
31+
warnings.warn(
32+
f"Importing '{name}' from 'pina' is deprecated; use "
33+
f"pina.{_DEPRECATED_IMPORTS[name]} instead.",
34+
DeprecationWarning,
35+
stacklevel=2,
36+
)
37+
38+
return globals()[_DEPRECATED_IMPORTS[name]]

pina/_src/core/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
warnings.formatwarning = custom_warning_format
1919
warnings.filterwarnings("always", category=UserWarning)
2020

21+
# TODO: add checks on training, val and test sizes
22+
# TODO: rimuovi tutti i check inutili a cascata in tutto il data module
23+
2124

2225
class Trainer(lightning.pytorch.Trainer):
2326
"""
@@ -69,7 +72,7 @@ def __init__(
6972
:param bool automatic_batching: If ``True``, automatic PyTorch batching
7073
is performed, otherwise the items are retrieved from the dataset
7174
all at once. For further details, see the
72-
:class:`~pina.data.data_module.PinaDataModule` class. Default is
75+
:class:`~pina.data.data_module.DataModule` class. Default is
7376
``False``.
7477
:param int num_workers: The number of worker threads for data loading.
7578
Default is ``0`` (serial loading).
@@ -248,7 +251,7 @@ def _create_datamodule(
248251
"are sampled. The Trainer got the following:\n"
249252
f"{error_message}"
250253
)
251-
self.data_module = PinaDataModule(
254+
self.data_module = DataModule(
252255
self.solver.problem,
253256
train_size=train_size,
254257
test_size=test_size,

0 commit comments

Comments
 (0)