diff --git a/ml/train_model.py b/ml/train_model.py index 65c27324..bbe7c13a 100644 --- a/ml/train_model.py +++ b/ml/train_model.py @@ -3,43 +3,28 @@ ## This script trains machine learning models (GP, NN, or ensemble_NN) ## using simulation and experimental data from MongoDB and saves trained models back to the database import time - -import_start_time = time.time() - import tempfile import argparse import torch -from botorch.models.transforms.input import AffineInputTransform -from botorch.models import MultiTaskGP, SingleTaskGP, ModelListGP -from botorch.fit import fit_gpytorch_mll -from gpytorch.kernels import ScaleKernel, MaternKernel import pymongo import os import re import yaml -from lume_model.models import TorchModel -from lume_model.models.ensemble import NNEnsemble -from lume_model.models.gp_model import GPModel -from lume_model.variables import ScalarVariable -from lume_model.variables import DistributionVariable -from sklearn.model_selection import train_test_split import sys import pandas as pd -from gpytorch.mlls import ExactMarginalLogLikelihood -sys.path.append(".") -from Neural_Net_Classes import CombinedNN as CombinedNN +start_time = time.time() -# measure the time it took to import everything -import_end_time = time.time() -elapsed_time = import_end_time - import_start_time -print(f"Imports took {elapsed_time:.1f} seconds.") +# Device is set when first needed (after model-specific imports) +device = None -# Automatically select device for training of GP -device = "cuda" if torch.cuda.is_available() else "cpu" -print("Device selected: ", device) -start_time = time.time() +def get_device(): + global device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + print("Device selected: ", device) + return device def parse_arguments(): @@ -136,6 +121,8 @@ def split_data(df_exp, df_sim, variables, model_type): else: return df_sim[variables] else: + from sklearn.model_selection import train_test_split + # Split exp and sim data into training and validation data with 80:20 ratio, selected randomly sim_train_df, sim_val_df = train_test_split( df_sim, test_size=0.2, random_state=None, shuffle=True @@ -153,6 +140,8 @@ def split_data(df_exp, df_sim, variables, model_type): def build_transforms(n_inputs, X_train, n_outputs, y_train): + from botorch.models.transforms.input import AffineInputTransform + input_transform = AffineInputTransform( n_inputs, coefficient=X_train.std(axis=0), offset=X_train.mean(axis=0) ) @@ -177,6 +166,9 @@ def train_nn_ensemble( exp_y_val, device, ): + sys.path.append(".") + from Neural_Net_Classes import CombinedNN as CombinedNN + if model_type == "NN": num_models = 1 elif model_type == "ensemble_NN": @@ -214,6 +206,11 @@ def build_torch_model_from_nn( output_transform, output_names, ): + from botorch.models.transforms.input import AffineInputTransform + from lume_model.models import TorchModel + from lume_model.models.ensemble import NNEnsemble + from lume_model.variables import DistributionVariable, ScalarVariable + torch_models = [] for model_nn in ensemble: @@ -265,6 +262,13 @@ def build_torch_model_from_nn( def train_gp( norm_df_train, input_names, output_names, input_transform, output_transform, device ): + from botorch.fit import fit_gpytorch_mll + from botorch.models import ModelListGP, MultiTaskGP, SingleTaskGP + from gpytorch.kernels import MaternKernel, ScaleKernel + from gpytorch.mlls import ExactMarginalLogLikelihood + from lume_model.models.gp_model import GPModel + from lume_model.variables import DistributionVariable, ScalarVariable + # Create separate GP models for each output to handle NaN values in the training data gp_models = [] @@ -477,6 +481,7 @@ def write_model(model, model_type, experiment, db): # Neural Net and Ensemble Creation and training ###################################################### if model_type != "GP": + device = get_device() ( norm_df_val, norm_expt_inputs_val, @@ -526,6 +531,7 @@ def write_model(model, model_type, experiment, db): # Gaussian Process Creation and training ############################################################### else: + device = get_device() model = train_gp( norm_df_train, input_names,