Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions polyid/polyid.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def load_models(
, by default None
nmodels : Union[int, list], optional
Used to import models 0 through nmodels if value passed is an int. If value
passed is a list, then those models will be imported. Should be a list of
passed is a list, then those models will be imported. Should be a list of
integers. by default None

Returns
Expand Down Expand Up @@ -535,8 +535,9 @@ def load_models(

if nmodels == None:
for model_folder in model_folders:
model_path = Path(model_folder) / (model_folder.rsplit("/")[-1] + ".h5")
data_path = Path(model_folder) / (model_folder.rsplit("/")[-1] + "_data.pk")
model_folder_path = Path(model_folder)
model_path = model_folder_path / (model_folder_path.stem + ".h5")
data_path = model_folder_path / (model_folder_path.stem + "_data.pk")
mm.models.append(
SingleModel.load_model(
model_path, data_path, custom_objects=custom_objects_dict
Expand Down Expand Up @@ -690,7 +691,8 @@ def make_aggregate_predictions(
df_prediction = self.make_predictions(df_prediction)

# Required columns
groupby_columns = ["pm", "distribution"]
groupby_columns = ["pm", "polymer", "distribution"]
#groupby_columns = ["pm", "distribution"]
groupby_columns.extend(additional_groupby)
if not groupby_pm or "pm" not in df_prediction.columns:
groupby_columns.remove("pm")
Expand Down Expand Up @@ -756,7 +758,7 @@ def split_data(

data = self.df_polymer.copy()

# Assign a data_id column to aid in kfolds; data id acts as a unique identifer for creating stratified splits. It should be a unique integer.
        # Assign a data_id column to aid in kfolds; data id acts as a unique identifier for creating stratified splits. It should be a unique integer.
if "data_id" not in data:
data["data_id"] = 0
#replacing nans for logical statements later
Expand All @@ -765,12 +767,12 @@ def split_data(

if "pm" in row:
idxs = data[
(data["smiles_monomer"] == row.smiles_monomer) &
(data["smiles_monomer"] == row.smiles_monomer) &
(data["distribution"] == row.distribution)&
(data["pm"]==row.pm)
].index.tolist()

else:
else:
idxs = data[(data["smiles_monomer"] == row.smiles_monomer) & (
data["distribution"] == row.distribution
)].index.tolist()
Expand Down Expand Up @@ -1017,7 +1019,7 @@ def load_training_data(cls, fname: Union[Path, str]) -> MultiModel:

class RenameUnpickler(pk.Unpickler):
"""For handling the renaming of modules previously named polyml. Backwards compatibilty for older models.

example:

params = RenameUnpickler.load_pickle(filepath_of_picklefile)
Expand Down