diff --git a/polyid/polyid.py b/polyid/polyid.py index 097c3ec..1b18fdd 100644 --- a/polyid/polyid.py +++ b/polyid/polyid.py @@ -499,7 +499,7 @@ def load_models( , by default None nmodels : Union[int, list], optional Used to import models 0 through nmodels if value passed is an int. If value - passed is a list, then those models will be imported. Should be a list of + passed is a list, then those models will be imported. Should be a list of integers. by default None Returns @@ -535,8 +535,9 @@ def load_models( if nmodels == None: for model_folder in model_folders: - model_path = Path(model_folder) / (model_folder.rsplit("/")[-1] + ".h5") - data_path = Path(model_folder) / (model_folder.rsplit("/")[-1] + "_data.pk") + model_folder_path = Path(model_folder) + model_path = model_folder_path / (model_folder_path.name + ".h5") + data_path = model_folder_path / (model_folder_path.name + "_data.pk") mm.models.append( SingleModel.load_model( model_path, data_path, custom_objects=custom_objects_dict @@ -690,7 +691,8 @@ def make_aggregate_predictions( df_prediction = self.make_predictions(df_prediction) # Required columns - groupby_columns = ["pm", "distribution"] + groupby_columns = ["pm", "polymer", "distribution"] + #groupby_columns = ["pm", "distribution"] groupby_columns.extend(additional_groupby) if not groupby_pm or "pm" not in df_prediction.columns: groupby_columns.remove("pm") @@ -756,7 +758,7 @@ def split_data( data = self.df_polymer.copy() - # Assign a data_id column to aid in kfolds; data id acts as a unique identifer for creating stratified splits. It should be a unique integer. + # Assign a data_id column to aid in kfolds; data id acts as a unique identifier for creating stratified splits. It should be a unique integer. 
if "data_id" not in data: data["data_id"] = 0 #replacing nans for logical statements later @@ -765,12 +767,12 @@ def split_data( if "pm" in row: idxs = data[ - (data["smiles_monomer"] == row.smiles_monomer) & + (data["smiles_monomer"] == row.smiles_monomer) & (data["distribution"] == row.distribution)& (data["pm"]==row.pm) ].index.tolist() - else: + else: idxs = data[(data["smiles_monomer"] == row.smiles_monomer) & ( data["distribution"] == row.distribution )].index.tolist() @@ -1017,7 +1019,7 @@ def load_training_data(cls, fname: Union[Path, str]) -> MultiModel: class RenameUnpickler(pk.Unpickler): """For handling the renaming of modules previously named polyml. Backwards compatibilty for older models. - + example: params = RenameUnpickler.load_pickle(filepath_of_picklefile)