From 95839c07cd995fa44c6ba91934c410621a26f5f1 Mon Sep 17 00:00:00 2001 From: xiao312 Date: Tue, 31 Mar 2026 17:29:37 +0800 Subject: [PATCH] refactor: defer cantera imports in data modules --- dfode_kit/data/augment.py | 7 ++++++- dfode_kit/data/integration.py | 8 +++++++- dfode_kit/data/label.py | 3 ++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/dfode_kit/data/augment.py b/dfode_kit/data/augment.py index f4f4e85..f0a5d5d 100644 --- a/dfode_kit/data/augment.py +++ b/dfode_kit/data/augment.py @@ -1,10 +1,11 @@ import time -import cantera as ct import numpy as np def single_step(npstate, chem, time_step=1e-6): + import cantera as ct + gas = ct.Solution(chem) T_old, P_old, Y_old = npstate[0], npstate[1], npstate[2:] gas.TPY = T_old, P_old, Y_old @@ -33,6 +34,8 @@ def random_perturb( inert_idx: int = -1, time_step: float = 1e-6, ) -> np.ndarray: + import cantera as ct + array = array[array[:, 0] > frozenTem] gas = ct.Solution(mech_path) @@ -141,6 +144,8 @@ def label( mech_path: str, time_step: float = 1e-06, ) -> np.ndarray: + import cantera as ct + from dfode_kit.data.integration import advance_reactor gas = ct.Solution(mech_path) diff --git a/dfode_kit/data/integration.py b/dfode_kit/data/integration.py index aedf323..29397ae 100644 --- a/dfode_kit/data/integration.py +++ b/dfode_kit/data/integration.py @@ -1,6 +1,5 @@ import h5py import numpy as np -import cantera as ct from dfode_kit.data.contracts import MECHANISM_ATTR, require_h5_attr, read_scalar_field_datasets from dfode_kit.data.io_hdf5 import get_TPY_from_h5, touch_h5 @@ -43,6 +42,7 @@ def load_model(model_path, device, model_class, model_layers): def predict_Y(model, model_path, d_arr, mech, device): import torch + import cantera as ct gas = ct.Solution(mech) n_species = gas.n_species @@ -81,6 +81,8 @@ def predict_Y(model, model_path, d_arr, mech, device): def nn_integrate(orig_arr, model_path, device, model_class, model_layers, time_step, mech, frozen_temperature=510): + import cantera as ct + model = load_model(model_path, device, model_class, model_layers) mask = orig_arr[:, 0] > frozen_temperature @@ -123,6 +125,8 @@ def integrate_h5( nn_integration=False, model_settings=None, ): + import cantera as ct + """Process scalar-field datasets and save CVODE / NN integration outputs.""" with h5py.File(file_path, 'r') as f: mech = require_h5_attr(f, MECHANISM_ATTR) @@ -183,6 +187,8 @@ def calculate_error( save_path2, error='RMSE' ): + import cantera as ct + gas = ct.Solution(mech_path) with h5py.File(save_path1, 'r') as f1, h5py.File(save_path2, 'r') as f2: diff --git a/dfode_kit/data/label.py b/dfode_kit/data/label.py index 0cba89d..dc63349 100644 --- a/dfode_kit/data/label.py +++ b/dfode_kit/data/label.py @@ -1,6 +1,5 @@ import time -import cantera as ct import numpy as np @@ -9,6 +8,8 @@ def label_npy( time_step, source_path, ): + import cantera as ct + from dfode_kit.data.integration import advance_reactor gas = ct.Solution(mech_path)