Skip to content

Commit 4fbe79e

Browse files
authored
Merge pull request #13 from deepflame-ai/cleanup-canonical-imports-and-h5
Cleanup canonical imports and split data integration from h5_kit
2 parents 7639b21 + 55821bc commit 4fbe79e

File tree

12 files changed

+295
-254
lines changed

12 files changed

+295
-254
lines changed

dfode_kit/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
"df_to_h5": ("dfode_kit.cases.sampling", "df_to_h5"),
3535
"touch_h5": ("dfode_kit.data.io_hdf5", "touch_h5"),
3636
"get_TPY_from_h5": ("dfode_kit.data.io_hdf5", "get_TPY_from_h5"),
37-
"advance_reactor": ("dfode_kit.data_operations.h5_kit", "advance_reactor"),
38-
"load_model": ("dfode_kit.data_operations.h5_kit", "load_model"),
39-
"predict_Y": ("dfode_kit.data_operations.h5_kit", "predict_Y"),
40-
"nn_integrate": ("dfode_kit.data_operations.h5_kit", "nn_integrate"),
41-
"integrate_h5": ("dfode_kit.data_operations.h5_kit", "integrate_h5"),
37+
"advance_reactor": ("dfode_kit.data.integration", "advance_reactor"),
38+
"load_model": ("dfode_kit.data.integration", "load_model"),
39+
"predict_Y": ("dfode_kit.data.integration", "predict_Y"),
40+
"nn_integrate": ("dfode_kit.data.integration", "nn_integrate"),
41+
"integrate_h5": ("dfode_kit.data.integration", "integrate_h5"),
4242
}
4343

4444

dfode_kit/cli/commands/init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _handle_one_d_flame(args):
109109
json_result = {'case_type': 'oneD-flame'} if args.json else None
110110

111111
if args.write_config:
112-
from dfode_kit.df_interface.case_init import dump_plan_json
112+
from dfode_kit.cases.init import dump_plan_json
113113

114114
config_path = dump_plan_json(plan, args.write_config)
115115
if args.json:

dfode_kit/cli/commands/init_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77
from typing import Any
88

9-
from dfode_kit.df_interface.case_init import (
9+
from dfode_kit.cases.init import (
1010
DEFAULT_ONE_D_FLAME_TEMPLATE,
1111
OneDFlameInitInputs,
1212
dump_plan_json,
@@ -117,7 +117,7 @@ def apply_one_d_flame_plan(
117117
overrides = one_d_flame_overrides_from_plan(plan)
118118
cfg = _build_one_d_flame_config(inputs, overrides, quiet=quiet)
119119

120-
from dfode_kit.df_interface.oneDflame_setup import setup_one_d_flame_case
120+
from dfode_kit.cases.deepflame import setup_one_d_flame_case
121121

122122
if quiet:
123123
with redirect_stdout(io.StringIO()):
@@ -139,7 +139,7 @@ def _build_one_d_flame_config(
139139
overrides: dict[str, Any],
140140
quiet: bool = False,
141141
):
142-
from dfode_kit.df_interface.flame_configurations import OneDFreelyPropagatingFlameConfig
142+
from dfode_kit.cases.presets import OneDFreelyPropagatingFlameConfig
143143

144144
cfg = OneDFreelyPropagatingFlameConfig(
145145
mechanism=inputs.mechanism,

dfode_kit/cli/commands/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def add_command_parser(subparsers):
3131

3232
def handle_command(args):
3333
from dfode_kit.data.io_hdf5 import touch_h5
34-
from dfode_kit.df_interface.sample_case import df_to_h5
34+
from dfode_kit.cases.sampling import df_to_h5
3535

3636
print('Handling sample command')
3737
df_to_h5(args.case, args.mech, args.save, include_mesh=args.include_mesh)

dfode_kit/data/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
"require_h5_group",
1111
"touch_h5",
1212
"get_TPY_from_h5",
13+
"advance_reactor",
14+
"load_model",
15+
"predict_Y",
16+
"nn_integrate",
17+
"integrate_h5",
18+
"calculate_error",
1319
]
1420

1521
_ATTRIBUTE_MODULES = {
@@ -21,6 +27,12 @@
2127
"require_h5_group": ("dfode_kit.data.contracts", "require_h5_group"),
2228
"touch_h5": ("dfode_kit.data.io_hdf5", "touch_h5"),
2329
"get_TPY_from_h5": ("dfode_kit.data.io_hdf5", "get_TPY_from_h5"),
30+
"advance_reactor": ("dfode_kit.data.integration", "advance_reactor"),
31+
"load_model": ("dfode_kit.data.integration", "load_model"),
32+
"predict_Y": ("dfode_kit.data.integration", "predict_Y"),
33+
"nn_integrate": ("dfode_kit.data.integration", "nn_integrate"),
34+
"integrate_h5": ("dfode_kit.data.integration", "integrate_h5"),
35+
"calculate_error": ("dfode_kit.data.integration", "calculate_error"),
2436
}
2537

2638

dfode_kit/data/integration.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import h5py
2+
import torch
3+
import numpy as np
4+
import cantera as ct
5+
6+
from dfode_kit.data.contracts import MECHANISM_ATTR, require_h5_attr, read_scalar_field_datasets
7+
from dfode_kit.data.io_hdf5 import get_TPY_from_h5, touch_h5
8+
from dfode_kit.utils import BCT, inverse_BCT
9+
10+
11+
def advance_reactor(gas, state, reactor, reactor_net, time_step):
    """Load a flat thermochemical state into *gas* and advance the reactor net.

    Parameters
    ----------
    gas : cantera.Solution
        Phase object attached to *reactor*; updated in place and returned.
    state : array-like
        Flat vector ``[T, P, Y_0, ..., Y_{n-1}]`` of length ``2 + n_species``.
    reactor, reactor_net :
        Pre-built Cantera reactor and network, reused across calls.
    time_step : float
        Integration horizon; the network clock is rewound to 0 afterwards
        so the same network can process the next sample.

    Returns
    -------
    cantera.Solution
        The same *gas* object holding the advanced state.

    Raises
    ------
    ValueError
        If *state* does not flatten to length ``2 + gas.n_species``.
    """
    flat = state.flatten()

    expected_shape = (2 + gas.n_species,)
    if flat.shape != expected_shape:
        raise ValueError(
            f"Expected state shape {expected_shape}, got {flat.shape}"
        )

    # Push the sampled state into the gas phase: temperature, pressure,
    # then the species mass fractions.
    gas.TPY = flat[0], flat[1], flat[2:]

    # Re-sync the reactor with the new gas state, integrate, then reset the
    # simulation clock for reuse on the next sample.
    reactor.syncState()
    reactor_net.reinitialize()
    reactor_net.advance(time_step)
    reactor_net.set_initial_time(0.0)

    return gas
29+
30+
31+
@torch.no_grad()
def load_model(model_path, device, model_class, model_layers):
    """Build ``model_class(model_layers)``, load weights, move to *device*.

    The checkpoint at *model_path* must be a dict containing a ``'net'``
    entry with the model's ``state_dict``. The model is returned in eval
    mode, already placed on *device*.
    """
    checkpoint = torch.load(model_path, map_location='cpu')

    net = model_class(model_layers)
    net.load_state_dict(checkpoint['net'])

    # Inference-only usage: disable dropout/batch-norm updates, then move.
    net.eval()
    net.to(device=device)

    return net
42+
43+
44+
@torch.no_grad()
def predict_Y(model, model_path, d_arr, mech, device):
    """Predict next-step mass fractions for a batch of ``[T, P, Y...]`` rows.

    Normalisation statistics are read from the checkpoint at *model_path*
    (the same file the model weights came from). The pressure column is
    pinned to 101325 Pa before inference, and the returned mass fractions
    are renormalised to sum to one while the last species column keeps its
    input value.
    """
    gas = ct.Solution(mech)
    n_cols = 2 + gas.n_species
    if d_arr.shape[1] != n_cols:
        raise ValueError(
            f"Expected input with {n_cols} columns, got {d_arr.shape[1]}"
        )

    # Normalisation statistics stored alongside the network weights.
    checkpoint = torch.load(model_path, map_location='cpu')
    in_mean = checkpoint['data_in_mean']
    in_std = checkpoint['data_in_std']
    out_mean = checkpoint['data_target_mean']
    out_std = checkpoint['data_target_std']

    # Work on a non-negative copy and pin the pressure column to 1 atm.
    states = np.clip(d_arr, 0, None)
    states[:, 1] = states[:, 1] * 0 + 101325

    raw_Y = states[:, 2:].copy()
    bct_states = states.copy()
    bct_states[:, 2:] = BCT(bct_states[:, 2:])
    normed = (bct_states - in_mean) / in_std

    batch = torch.from_numpy(normed).float().to(device=device)
    prediction = model(batch)

    # The network predicts a (normalised, BCT-space) increment for every
    # species except the last column; add it onto the current BCT state.
    bct_next = prediction.cpu().numpy() * out_std + out_mean + bct_states[:, 2:-1]

    next_Y = raw_Y.copy()
    next_Y[:, :-1] = inverse_BCT(bct_next)
    # Renormalise so all mass fractions sum to one; the last species column
    # (presumably a balance species such as N2 — confirm against training
    # setup) keeps its input value.
    next_Y[:, :-1] = next_Y[:, :-1] / np.sum(next_Y[:, :-1], axis=1, keepdims=True) * (1 - next_Y[:, -1:])

    return next_Y
80+
81+
82+
@torch.no_grad()
def nn_integrate(orig_arr, model_path, device, model_class, model_layers, time_step, mech, frozen_temperature=510):
    """Advance a batch of ``[T, P, Y...]`` states one *time_step* with the NN.

    Rows with temperature at or below *frozen_temperature* are treated as
    chemically frozen: their composition is carried over unchanged. For the
    remaining rows the network predicts new mass fractions, and the new
    temperature is recovered by an enthalpy/pressure (HP) equilibration in
    Cantera.

    Returns an array with one extra leading column: ``[time_step, T, P, Y...]``.
    """
    model = load_model(model_path, device, model_class, model_layers)

    # Only rows hotter than the freeze-out threshold go through the network.
    mask = orig_arr[:, 0] > frozen_temperature
    infer_arr = orig_arr[mask, :]

    next_Y = predict_Y(model, model_path, infer_arr, mech, device)

    # Prepend a time column; columns become [t, T, P, Y...].
    new_states = np.hstack((np.zeros((orig_arr.shape[0], 1)), orig_arr))
    new_states[:, 0] += time_step
    # NOTE(review): column 2 already holds the original pressure after the
    # hstack, so this assignment appears redundant — confirm intent.
    new_states[:, 2] = orig_arr[:, 1]
    new_states[mask, 3:] = next_Y

    # Two gas objects: one to evaluate the pre-step enthalpy, one to solve
    # for the post-step temperature at constant H and P.
    setter_gas = ct.Solution(mech)
    getter_gas = ct.Solution(mech)
    new_T = np.zeros_like(next_Y[:, 0])

    for idx, (state, next_y) in enumerate(zip(infer_arr, next_Y)):
        try:
            setter_gas.TPY = state[0], state[1], state[2:]
            h = setter_gas.enthalpy_mass

            getter_gas.Y = next_y
            getter_gas.HP = h, state[1]

            new_T[idx] = getter_gas.T

        except ct.CanteraError:
            # NOTE(review): on a failed HP solve the row silently keeps
            # new_T's zero initialiser, so its temperature becomes 0 below —
            # confirm this is the intended fallback.
            continue
    new_states[mask, 1] = new_T

    return new_states
115+
116+
117+
def integrate_h5(
    file_path,
    save_path1,
    save_path2,
    time_step,
    cvode_integration=True,
    nn_integration=False,
    model_settings=None,
):
    """Process scalar-field datasets and save CVODE / NN integration outputs.

    Parameters
    ----------
    file_path : str or Path
        Source HDF5 file; must carry the mechanism attribute and the
        scalar-field datasets read by ``read_scalar_field_datasets``.
    save_path1 : str or Path
        HDF5 file receiving the ``cvode_integration`` group.
    save_path2 : str or Path
        HDF5 file receiving the ``nn_integration`` group.
    time_step : float
        Integration horizon for both the reactor and the NN model.
    cvode_integration : bool
        Run the Cantera CVODE reference integration.
    nn_integration : bool
        Run the neural-network integration (requires *model_settings*).
    model_settings : dict, optional
        Keyword arguments forwarded to :func:`nn_integrate`.

    Raises
    ------
    ValueError
        If *nn_integration* is requested without *model_settings*.
    """
    with h5py.File(file_path, 'r') as f:
        mech = require_h5_attr(f, MECHANISM_ATTR)

    data_dict = read_scalar_field_datasets(file_path)

    if cvode_integration:
        gas = ct.Solution(mech)
        # energy='off' keeps the reactor temperature fixed during advance.
        reactor = ct.Reactor(gas, name='Reactor1', energy='off')
        reactor_net = ct.ReactorNet([reactor])
        reactor_net.rtol, reactor_net.atol = 1e-6, 1e-10

        processed_data_dict = {}

        for name, data in data_dict.items():
            # Output rows are [time_step, T, P, Y...] -> one extra column.
            processed_data = np.empty((data.shape[0], data.shape[1] + 1))
            for i, state in enumerate(data):
                gas = advance_reactor(gas, state, reactor, reactor_net, time_step)

                new_state = np.array([time_step, gas.T, gas.P] + list(gas.Y))

                processed_data[i, :] = new_state

            processed_data_dict[name] = processed_data

        with h5py.File(save_path1, 'a') as f:
            # Fix: drop a stale group first so re-runs do not raise on
            # create_group, mirroring the nn_integration branch below.
            if 'cvode_integration' in f:
                del f['cvode_integration']
            cvode_group = f.create_group('cvode_integration')

            for dataset_name, processed_data in processed_data_dict.items():
                cvode_group.create_dataset(dataset_name, data=processed_data)
                print(f'Saved processed dataset: {dataset_name} in cvode_integration group')

    if nn_integration:
        if model_settings is None:
            raise ValueError("model_settings must be provided for neural network integration.")

        processed_data_dict = {}

        for name, data in data_dict.items():
            try:
                processed_data = nn_integrate(data, **model_settings)
                processed_data_dict[name] = processed_data
            except Exception as e:
                # Best-effort: report and keep going over remaining datasets.
                print(f"Error processing dataset '{name}': {e}")

        with h5py.File(save_path2, 'a') as f:
            if 'nn_integration' in f:
                del f['nn_integration']
            nn_group = f.create_group('nn_integration')

            for dataset_name, processed_data in processed_data_dict.items():
                nn_group.create_dataset(dataset_name, data=processed_data)
                print(f'Saved processed dataset: {dataset_name} in nn_integration group')
178+
179+
180+
def calculate_error(
    mech_path,
    save_path1,
    save_path2,
    error='RMSE'
):
    """Compare CVODE and NN integration outputs stored in two HDF5 files.

    Parameters
    ----------
    mech_path : str or Path
        Cantera mechanism file, used only to map column indices to
        species names in the printed report.
    save_path1 : str or Path
        File holding the ``cvode_integration`` group.
    save_path2 : str or Path
        File holding the ``nn_integration`` group.
    error : str
        Error metric; only ``'RMSE'`` is supported.

    Returns
    -------
    dict
        Per-dataset arrays of per-species RMSE values.

    Raises
    ------
    ValueError
        If *error* names an unsupported metric (previously a silent no-op).
    """
    if error != "RMSE":
        raise ValueError(f"Unsupported error metric: {error!r}")

    gas = ct.Solution(mech_path)

    with h5py.File(save_path1, 'r') as f1, h5py.File(save_path2, 'r') as f2:
        cvode_group = f1['cvode_integration']
        nn_group = f2['nn_integration']

        common_datasets = set(cvode_group.keys()) & set(nn_group.keys())

        sorted_datasets = sorted(common_datasets, key=lambda x: float(x))
        results = {}

        for ds_name in sorted_datasets:
            # Columns 0-2 are [time, T, P]; species mass fractions follow,
            # so column j of the slice corresponds to species j.
            cvode_data = cvode_group[ds_name][:, 3:]
            nn_data = nn_group[ds_name][:, 3:]

            rmse_per_dim = np.sqrt(np.mean((cvode_data - nn_data) ** 2, axis=0))
            results[ds_name] = rmse_per_dim

            print(f"RMSE of dataset: {ds_name}")
            # Fix: the original used enumerate(..., start=1) with
            # species_names[dim_idx - 3], which indexed species at -2, -1,
            # 0, ... and mislabelled every row via negative wrap-around.
            for dim_idx, rmse_val in enumerate(rmse_per_dim):
                species = gas.species_names[dim_idx]
                print(f"  Species {species}: {rmse_val:.6e}")
            print()

    return results

dfode_kit/data_operations/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
_ATTRIBUTE_MODULES = {
2323
"touch_h5": ("dfode_kit.data.io_hdf5", "touch_h5"),
2424
"get_TPY_from_h5": ("dfode_kit.data.io_hdf5", "get_TPY_from_h5"),
25-
"integrate_h5": ("dfode_kit.data_operations.h5_kit", "integrate_h5"),
26-
"load_model": ("dfode_kit.data_operations.h5_kit", "load_model"),
27-
"nn_integrate": ("dfode_kit.data_operations.h5_kit", "nn_integrate"),
28-
"predict_Y": ("dfode_kit.data_operations.h5_kit", "predict_Y"),
29-
"calculate_error": ("dfode_kit.data_operations.h5_kit", "calculate_error"),
25+
"integrate_h5": ("dfode_kit.data.integration", "integrate_h5"),
26+
"load_model": ("dfode_kit.data.integration", "load_model"),
27+
"nn_integrate": ("dfode_kit.data.integration", "nn_integrate"),
28+
"predict_Y": ("dfode_kit.data.integration", "predict_Y"),
29+
"calculate_error": ("dfode_kit.data.integration", "calculate_error"),
3030
"random_perturb": ("dfode_kit.data_operations.augment_data", "random_perturb"),
3131
"label_npy": ("dfode_kit.data_operations.label_data", "label_npy"),
3232
"SCALAR_FIELDS_GROUP": ("dfode_kit.data.contracts", "SCALAR_FIELDS_GROUP"),

dfode_kit/data_operations/augment_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import cantera as ct
33
import time
4-
from dfode_kit.data_operations.h5_kit import advance_reactor
4+
from dfode_kit.data.integration import advance_reactor
55
from dfode_kit.training.formation import formation_calculate
66

77
def single_step(npstate, chem, time_step=1e-6):

0 commit comments

Comments
 (0)