Skip to content

Commit 97987c1

Browse files
Davide-Miottindem0
authored andcommitted
Fix DatabaseScaler plugin
1 parent 5b5f9e3 commit 97987c1

2 files changed

Lines changed: 195 additions & 60 deletions

File tree

ezyrb/plugin/scaler.py

Lines changed: 93 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Module for Scaler plugin """
1+
"""Module for Scaler plugin"""
22

33
from .plugin import Plugin
44

@@ -19,9 +19,9 @@ class DatabaseScaler(Plugin):
1919
applied at the full order ('full') or at the reduced one ('reduced').
2020
:param {'parameters', 'snapshots'} params: define if the rescaling has to
2121
be applied to the parameters or to the snapshots.
22-
22+
2323
:Example:
24-
24+
2525
>>> from ezyrb import ReducedOrderModel as ROM
2626
>>> from ezyrb import POD, RBF, Database
2727
>>> from ezyrb.plugin import DatabaseScaler
@@ -33,10 +33,11 @@ class DatabaseScaler(Plugin):
3333
>>> rom = ROM(db, pod, rbf, plugins=[scaler])
3434
>>> rom.fit()
3535
"""
36+
3637
def __init__(self, scaler, mode, target) -> None:
3738
"""
3839
Initialize the DatabaseScaler plugin.
39-
40+
4041
:param scaler: Scaler object with fit, transform, and inverse_transform methods.
4142
:param str mode: 'full' or 'reduced' - where to apply the scaling.
4243
:param str target: 'parameters' or 'snapshots' - what to scale.
@@ -46,7 +47,7 @@ def __init__(self, scaler, mode, target) -> None:
4647
self.scaler = scaler
4748
self.mode = mode
4849
self.target = target
49-
50+
5051
@property
5152
def target(self):
5253
"""
@@ -58,7 +59,7 @@ def target(self):
5859

5960
@target.setter
6061
def target(self, new_target):
61-
if new_target not in ['snapshots', 'parameters']:
62+
if new_target not in ["snapshots", "parameters"]:
6263
raise ValueError
6364

6465
self._target = new_target
@@ -74,102 +75,152 @@ def mode(self):
7475

7576
@mode.setter
7677
def mode(self, new_mode):
77-
if new_mode not in ['full', 'reduced']:
78+
if new_mode not in ["full", "reduced"]:
7879
raise ValueError
7980

8081
self._mode = new_mode
8182

8283
def _select_matrix(self, db):
8384
"""
8485
Helper function to select the proper matrix to rescale.
85-
86+
8687
:param Database db: The database object.
8788
:return: The selected matrix (parameters or snapshots).
8889
"""
89-
return getattr(db, f'{self.target}_matrix')
90+
return getattr(db, f"{self.target}_matrix")
9091

91-
def rom_preprocessing(self, rom):
92+
# =========================================================================
93+
# MODE = 'FULL' - Scaling applied at full order (before reduction or after prediction)
94+
# =========================================================================
95+
96+
def fit_before_reduction(self, rom):
9297
"""
93-
Apply scaling to the reduced database before ROM processing.
94-
98+
Apply scaling before POD reduction when mode='full'.
99+
Scales the full-order database before reduction.
100+
95101
:param ReducedOrderModel rom: The ROM instance.
96102
"""
97-
if self.mode != 'reduced':
103+
if self.mode != "full":
98104
return
99105

100-
db = rom._reduced_database
106+
db = rom.train_full_database
101107

102108
self.scaler.fit(self._select_matrix(db))
103109

104-
if self.target == 'parameters':
110+
if self.target == "parameters":
105111
new_db = type(db)(
106112
self.scaler.transform(self._select_matrix(db)),
107-
db.snapshots_matrix
113+
db.snapshots_matrix,
108114
)
109115
else:
110116
new_db = type(db)(
111117
db.parameters_matrix,
112118
self.scaler.transform(self._select_matrix(db)),
113119
)
114120

115-
rom._reduced_database = new_db
121+
rom.train_full_database = new_db
116122

117-
def fom_preprocessing(self, rom):
118-
if self.mode != 'full':
119-
return
123+
def predict_postprocessing(self, rom):
124+
"""
125+
Inverse transform scaled data after prediction when mode='full'.
126+
Restores original scale to the full-order predicted database.
120127
121-
db = rom._full_database
128+
:param ReducedOrderModel rom: The ROM instance.
129+
"""
130+
if self.mode != "full":
131+
return
122132

123-
self.scaler.fit(self._select_matrix(db))
133+
db = rom.predicted_full_database
124134

125-
if self.target == 'parameters':
135+
if self.target == "parameters":
126136
new_db = type(db)(
127-
self.scaler.transform(self._select_matrix(db)),
128-
db.snapshots_matrix
137+
self.scaler.inverse_transform(self._select_matrix(db)),
138+
db.snapshots_matrix,
129139
)
130140
else:
131141
new_db = type(db)(
132142
db.parameters_matrix,
133-
self.scaler.transform(self._select_matrix(db)),
143+
self.scaler.inverse_transform(self._select_matrix(db)),
134144
)
135145

136-
rom._full_database = new_db
146+
rom.predicted_full_database = new_db
147+
148+
# =========================================================================
149+
# MODE = 'REDUCED' - Scaling applied at reduced order (before/after approximation)
150+
# =========================================================================
137151

138-
def fom_postprocessing(self, rom):
152+
def fit_before_approximation(self, rom):
153+
"""
154+
Apply scaling before approximation training when mode='reduced'.
155+
Scales the reduced database before approximation training.
139156
140-
if self.mode != 'full':
157+
:param ReducedOrderModel rom: The ROM instance.
158+
"""
159+
if self.mode != "reduced":
141160
return
142161

143-
db = rom._full_database
162+
db = rom.train_reduced_database
144163

145-
if self.target == 'parameters':
164+
self.scaler.fit(self._select_matrix(db))
165+
166+
if self.target == "parameters":
146167
new_db = type(db)(
147-
self.scaler.inverse_transform(self._select_matrix(db)),
148-
db.snapshots_matrix
168+
self.scaler.transform(self._select_matrix(db)),
169+
db.snapshots_matrix,
149170
)
150171
else:
151172
new_db = type(db)(
152173
db.parameters_matrix,
153-
self.scaler.inverse_transform(self._select_matrix(db)),
174+
self.scaler.transform(self._select_matrix(db)),
154175
)
155176

156-
rom._full_database = new_db
177+
rom.train_reduced_database = new_db
157178

158-
def rom_postprocessing(self, rom):
159-
if self.mode != 'reduced':
179+
def predict_after_approximation(self, rom):
180+
"""
181+
Inverse transform scaled data after approximation when mode='reduced'.
182+
Restores original scale to the reduced predicted database.
183+
184+
:param ReducedOrderModel rom: The ROM instance.
185+
"""
186+
if self.mode != "reduced":
160187
return
161188

162-
db = rom._reduced_database
189+
db = rom.predict_reduced_database
163190

164-
if self.target == 'parameters':
191+
if self.target == "parameters":
165192
new_db = type(db)(
166193
self.scaler.inverse_transform(self._select_matrix(db)),
167-
db.snapshots_matrix
194+
db.snapshots_matrix,
168195
)
169196
else:
170197
new_db = type(db)(
171198
db.parameters_matrix,
172199
self.scaler.inverse_transform(self._select_matrix(db)),
173200
)
174-
175-
rom._reduced_database = new_db
201+
202+
rom.predict_reduced_database = new_db
203+
204+
# =========================================================================
205+
# PREDICT - Scaling input parameters before approximation (both modes)
206+
# =========================================================================
207+
208+
def predict_before_approximation(self, rom):
209+
"""
210+
Transform (scale) input parameters before approximation if target='parameters'.
211+
This ensures parameters are scaled to match the training data.
212+
Applied during prediction for both 'full' and 'reduced' modes.
213+
214+
:param ReducedOrderModel rom: The ROM instance.
215+
"""
216+
if self.target != "parameters":
217+
return
218+
219+
db = rom.predict_reduced_database
220+
transformed_params = self.scaler.transform(self._select_matrix(db))
221+
222+
# During prediction, snapshots are None (not yet predicted)
223+
# Database constructor handles None snapshots: creates [None] * len(parameters)
224+
new_db = type(db)(transformed_params, None)
225+
226+
rom.predict_reduced_database = new_db

tests/test_scaler.py

Lines changed: 102 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,124 @@
88

99
from sklearn.preprocessing import StandardScaler, MinMaxScaler
1010

11-
snapshots = np.load('tests/test_datasets/p_snapshots.npy').T
12-
pred_sol_tst = np.load('tests/test_datasets/p_predsol.npy').T
13-
pred_sol_gpr = np.load('tests/test_datasets/p_predsol_gpr.npy').T
14-
param = np.array([[-.5, -.5], [.5, -.5], [.5, .5], [-.5, .5]])
11+
snapshots = np.load("tests/test_datasets/p_snapshots.npy").T
12+
pred_sol_tst = np.load("tests/test_datasets/p_predsol.npy").T
13+
pred_sol_gpr = np.load("tests/test_datasets/p_predsol_gpr.npy").T
14+
param = np.array([[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]])
1515

1616

1717
def test_constructor():
1818
pod = POD()
1919
import torch
20+
2021
rbf = RBF()
21-
#rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
22+
# rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
2223
db = Database(param, snapshots.T)
2324
# rom = ROM(db, pod, rbf, plugins=[DatabaseScaler(StandardScaler(), 'full', 'snapshots')])
24-
rom = ROM(db, pod, rbf, plugins=[
25-
DatabaseScaler(StandardScaler(), 'reduced', 'parameters'),
26-
DatabaseScaler(StandardScaler(), 'reduced', 'snapshots')
27-
])
25+
rom = ROM(
26+
db,
27+
pod,
28+
rbf,
29+
plugins=[
30+
DatabaseScaler(StandardScaler(), "reduced", "parameters"),
31+
DatabaseScaler(StandardScaler(), "reduced", "snapshots"),
32+
],
33+
)
34+
rom.fit()
35+
assert rom is not None
36+
37+
38+
def test_scaler_reduced_snapshots():
39+
"""Test that StandardScaler on reduced snapshots produces mean=0 and std=1"""
40+
pod = POD()
41+
rbf = RBF()
42+
db = Database(param, snapshots.T)
43+
rom = ROM(
44+
db,
45+
pod,
46+
rbf,
47+
plugins=[DatabaseScaler(StandardScaler(), "reduced", "snapshots")],
48+
)
49+
rom.fit()
50+
51+
# Check that the scaled reduced snapshots have mean ≈ 0 and std ≈ 1
52+
scaled_snapshots = rom.train_reduced_database.snapshots_matrix
53+
np.testing.assert_allclose(np.mean(scaled_snapshots, axis=0), 0, atol=1e-7)
54+
np.testing.assert_allclose(np.std(scaled_snapshots, axis=0), 1, atol=1e-7)
55+
56+
57+
def test_scaler_reduced_parameters():
58+
"""Test that StandardScaler on reduced parameters produces mean=0 and std=1"""
59+
pod = POD()
60+
rbf = RBF()
61+
db = Database(param, snapshots.T)
62+
rom = ROM(
63+
db,
64+
pod,
65+
rbf,
66+
plugins=[DatabaseScaler(StandardScaler(), "reduced", "parameters")],
67+
)
68+
rom.fit()
69+
70+
# Check that the scaled reduced parameters have mean ≈ 0 and std ≈ 1
71+
scaled_params = rom.train_reduced_database.parameters_matrix
72+
np.testing.assert_allclose(np.mean(scaled_params, axis=0), 0, atol=1e-7)
73+
np.testing.assert_allclose(np.std(scaled_params, axis=0), 1, atol=1e-7)
74+
75+
76+
def test_scaler_full_snapshots():
77+
"""Test that StandardScaler on full snapshots produces mean=0 and std=1"""
78+
pod = POD()
79+
rbf = RBF()
80+
db = Database(param, snapshots.T)
81+
rom = ROM(
82+
db,
83+
pod,
84+
rbf,
85+
plugins=[DatabaseScaler(StandardScaler(), "full", "snapshots")],
86+
)
2887
rom.fit()
29-
30-
88+
89+
# Check that the scaled full snapshots have mean ≈ 0 and std ≈ 1
90+
scaled_snapshots = rom.train_full_database.snapshots_matrix
91+
np.testing.assert_allclose(np.mean(scaled_snapshots, axis=0), 0, atol=2e-6)
92+
np.testing.assert_allclose(np.std(scaled_snapshots, axis=0), 1, atol=2e-6)
93+
94+
95+
def test_scaler_full_parameters():
96+
"""Test that StandardScaler on full parameters produces mean=0 and std=1"""
97+
pod = POD()
98+
rbf = RBF()
99+
db = Database(param, snapshots.T)
100+
rom = ROM(
101+
db,
102+
pod,
103+
rbf,
104+
plugins=[DatabaseScaler(StandardScaler(), "full", "parameters")],
105+
)
106+
rom.fit()
107+
108+
# Check that the scaled full parameters have mean ≈ 0 and std ≈ 1
109+
scaled_params = rom.train_full_database.parameters_matrix
110+
np.testing.assert_allclose(np.mean(scaled_params, axis=0), 0, atol=2e-6)
111+
np.testing.assert_allclose(np.std(scaled_params, axis=0), 1, atol=2e-6)
31112

32113

33114
def test_values():
34115
pod = POD()
35116
rbf = RBF()
36117
db = Database(param, snapshots.T)
37-
rom = ROM(db, pod, rbf, plugins=[
38-
DatabaseScaler(StandardScaler(), 'reduced', 'snapshots'),
39-
DatabaseScaler(StandardScaler(), 'full', 'parameters')
40-
])
118+
rom = ROM(
119+
db,
120+
pod,
121+
rbf,
122+
plugins=[
123+
DatabaseScaler(StandardScaler(), "reduced", "snapshots"),
124+
DatabaseScaler(StandardScaler(), "full", "parameters"),
125+
],
126+
)
41127
rom.fit()
42128
test_param = param[2]
43129
truth_sol = db.snapshots_matrix[2]
44130
predicted_sol = rom.predict(test_param)[0]
45-
np.testing.assert_allclose(predicted_sol, truth_sol,
46-
rtol=1e-5, atol=1e-5)
47-
131+
np.testing.assert_allclose(predicted_sol, truth_sol, rtol=1e-5, atol=1e-5)

0 commit comments

Comments
 (0)