From 48c21421bbc76c62f50c546d9a898861aa410ed5 Mon Sep 17 00:00:00 2001
From: rlaplaza
", "\n\n")
text = text.replace("
", "\n")
# bold
text = re.sub(
- r"<\s*b\s*>(.*?)<\s*/\s*b\s*>",
- r"**\1**",
- text,
- flags=re.DOTALL | re.IGNORECASE
+ r"<\s*b\s*>(.*?)<\s*/\s*b\s*>", r"**\1**", text, flags=re.DOTALL | re.IGNORECASE
)
# italic
text = re.sub(
- r"<\s*i\s*>(.*?)<\s*/\s*i\s*>",
- r"*\1*",
- text,
- flags=re.DOTALL | re.IGNORECASE
+ r"<\s*i\s*>(.*?)<\s*/\s*i\s*>", r"*\1*", text, flags=re.DOTALL | re.IGNORECASE
)
return text.strip()
+
def sort_key(filename):
name = filename.replace(".png", "")
- numbers = re.findall(r'\d+', name)
+ numbers = re.findall(r"\d+", name)
return tuple(int(n) for n in numbers)
+
def indent_block(text, spaces=3):
prefix = " " * spaces
return "\n".join(prefix + line if line.strip() else "" for line in text.split("\n"))
+
def convert(md_path, img_folder, out_path, prefix):
print(f"\n===== {prefix.upper()} =====")
with open(md_path, "r", encoding="utf-8") as f:
content = f.read()
- steps = [clean_text(s) for s in content.split('---') if s.strip()]
+ steps = [clean_text(s) for s in content.split("---") if s.strip()]
images = sorted(
- [f for f in os.listdir(img_folder) if f.endswith(".png")],
- key=sort_key
+ [f for f in os.listdir(img_folder) if f.endswith(".png")], key=sort_key
)
print(f"Steps: {len(steps)}")
@@ -51,7 +48,7 @@ def convert(md_path, img_folder, out_path, prefix):
rst = ""
# ONE container per tutorial
- rst += f".. container:: step\n\n"
+ rst += ".. container:: step\n\n"
for i, step in enumerate(steps):
step_id = i + 1
@@ -75,10 +72,10 @@ def convert(md_path, img_folder, out_path, prefix):
buttons = '
"
)
self.progress.setRange(0, 0)
self.current_process = "AQME"
- self.aqme_role = "test"
+ self.aqme_role = "test"
self.worker = RobertWorker(aqme_command, run_dir)
self.worker.output_received.connect(self.console_output.append)
@@ -1826,9 +1804,7 @@ def _detect_aqme_output_csv(self):
run_dir = os.path.dirname(self.csv_test_path)
# Original CSV name without extension)
- original_name = os.path.splitext(
- os.path.basename(self.csv_test_path)
- )[0]
+ original_name = os.path.splitext(os.path.basename(self.csv_test_path))[0]
# Expected output CSV
expected_csv = f"AQME-ROBERT_full_{original_name}.csv"
@@ -1838,13 +1814,13 @@ def _detect_aqme_output_csv(self):
return expected_path
return None
-
+
def _write_atom_mapping_dat(
self,
smarts: str,
selected_atoms: list,
run_dir: str,
- filename: str = "AtomMapping_data.dat"
+ filename: str = "AtomMapping_data.dat",
):
"""
Write an atomic mapping contract to disk.
@@ -1915,9 +1891,10 @@ def _write_atom_mapping_dat(
self.console_output.append(
f"WARNING: Failed to write atom mapping dat: {e}"
)
+
def _check_generate_folder(self, run_dir):
"""Checks if a GENERATE folder exists in the run directory."""
-
+
generate_dir = os.path.join(run_dir, "GENERATE")
if not os.path.exists(generate_dir):
@@ -1926,12 +1903,12 @@ def _check_generate_folder(self, run_dir):
"Trained model not found",
"No trained model was found in this folder.\n\n"
"Prediction with a test CSV requires an existing model "
- "generated in a previous run (GENERATE step).\n"
+ "generated in a previous run (GENERATE step).\n",
)
return False
return True
-
+
def _validate_robert_workflow(self):
"""
Validates workflow state and resolves mismatches.
@@ -1943,14 +1920,18 @@ def _validate_robert_workflow(self):
# ---------------------------------------------------
# Detect mismatch: test CSV loaded but not PREDICT
# ---------------------------------------------------
- if self.csv_test_path and not self.file_path and workflow not in ["PREDICT", "REPORT"]:
+ if (
+ self.csv_test_path
+ and not self.file_path
+ and workflow not in ["PREDICT", "REPORT"]
+ ):
reply = QMessageBox.question(
self,
"Possible workflow mismatch",
"You loaded a test CSV but selected 'Full Workflow'.\n\n"
"Did you mean to generate predictions instead?",
QMessageBox.Yes | QMessageBox.No,
- QMessageBox.Yes
+ QMessageBox.Yes,
)
if reply == QMessageBox.Yes:
@@ -1961,7 +1942,7 @@ def _validate_robert_workflow(self):
self,
"Execution stopped",
"To run 'Full Workflow' in ROBERT, please load the training CSV "
- "and select the appropriate target and name columns."
+ "and select the appropriate target and name columns.",
)
return False
@@ -1973,7 +1954,7 @@ def _validate_robert_workflow(self):
QMessageBox.warning(
self,
"WARNING!",
- "Please load a training CSV file before running the workflow."
+ "Please load a training CSV file before running the workflow.",
)
return False
@@ -1981,12 +1962,9 @@ def _validate_robert_workflow(self):
# PREDICT validation
# ---------------------------------------------------
if workflow == "PREDICT":
-
if not self.csv_test_path:
QMessageBox.warning(
- self,
- "WARNING!",
- "Please select a test CSV file for prediction."
+ self, "WARNING!", "Please select a test CSV file for prediction."
)
return False
@@ -1999,12 +1977,11 @@ def _validate_robert_workflow(self):
# REPORT validation
# ---------------------------------------------------
if workflow == "REPORT":
-
if not self.file_path and not self.csv_test_path:
QMessageBox.warning(
self,
"WARNING!",
- "Please load a CSV file to determine the report directory."
+ "Please load a CSV file to determine the report directory.",
)
return False
@@ -2018,7 +1995,7 @@ def run_robert(self):
# --------------------------------------------------
if not self._validate_robert_workflow():
return
-
+
# --------------------------------------------------
# Init process
# --------------------------------------------------
@@ -2033,7 +2010,7 @@ def run_robert(self):
""
)
- # Path to run directory
+ # Path to run directory
if self.file_path:
run_dir = os.path.dirname(self.file_path)
elif self.csv_test_path:
@@ -2048,8 +2025,7 @@ def run_robert(self):
folders_to_check.extend(["CSEARCH", "QDESCP"])
existing_folders = [
- f for f in folders_to_check
- if os.path.exists(os.path.join(run_dir, f))
+ f for f in folders_to_check if os.path.exists(os.path.join(run_dir, f))
]
if existing_folders and self.workflow_selector.currentText() == "Full Workflow":
@@ -2061,7 +2037,7 @@ def run_robert(self):
"or will be overwritten if the previous run completed successfully.\n\n"
"Are you sure you want to continue and delete them?",
QMessageBox.Yes | QMessageBox.No,
- QMessageBox.No
+ QMessageBox.No,
)
if confirmation == QMessageBox.No:
@@ -2076,8 +2052,8 @@ def run_robert(self):
f"[ERROR] Could not delete folder '{folder}': {e}"
)
self._reset_ui_after_process()
- return
-
+ return
+
# --------------------------------------------------
# Collect GUI values
# --------------------------------------------------
@@ -2089,7 +2065,7 @@ def run_robert(self):
"WARNING! Invalid parameters. Please fix them before running."
)
return
-
+
# Rename pdf if full workflow or report selected
wf_predict = self.workflow_selector.currentText()
if wf_predict == "Full Workflow" or wf_predict == "REPORT":
@@ -2113,10 +2089,10 @@ def run_robert(self):
train_source_csv = self._get_unmapped_csv(self.file_path)
- self.mapped_train_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
- train_source_csv,
- smarts,
- selected_atoms_for_robert
+ self.mapped_train_csv = (
+ self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+ train_source_csv, smarts, selected_atoms_for_robert
+ )
)
is_robert_mapped = True
@@ -2124,18 +2100,16 @@ def run_robert(self):
if getattr(self, "csv_test_path", None):
test_source_csv = self._get_unmapped_csv(self.csv_test_path)
- self.mapped_test_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
- test_source_csv,
- smarts,
- selected_atoms_for_robert
+ self.mapped_test_csv = (
+ self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+ test_source_csv, smarts, selected_atoms_for_robert
+ )
)
# Save atomic mapping contract in .dat
run_dir = os.path.dirname(self.file_path)
self._write_atom_mapping_dat(
- smarts=smarts,
- selected_atoms=selected_atoms_for_robert,
- run_dir=run_dir
+ smarts=smarts, selected_atoms=selected_atoms_for_robert, run_dir=run_dir
)
# --------------------------------------------------
@@ -2151,7 +2125,7 @@ def run_robert(self):
self.mapped_test_csv
if is_robert_mapped and self.mapped_test_csv
else getattr(self, "csv_test_path", None)
- )
+ )
# --------------------------------------------------------------------------
# AQME-origin CSV check, disable AQME workflow if detected previously runned
@@ -2169,14 +2143,11 @@ def run_robert(self):
if not self.check_atomic_descriptors("ROBERT"):
self._reset_ui_after_process()
return
-
+
# --------------------------------------------------
# PREDICT + csv_test preflight
# --------------------------------------------------
- if (
- self.workflow_selector.currentText() == "PREDICT"
- and self.csv_test_path
- ):
+ if self.workflow_selector.currentText() == "PREDICT" and self.csv_test_path:
run_dir = os.path.dirname(self.csv_test_path)
# -----------------------------------------------
@@ -2185,7 +2156,6 @@ def run_robert(self):
dat_path = os.path.join(run_dir, "AtomMapping_data.dat")
if os.path.isfile(dat_path):
-
self.console_output.append(
f"[INFO] Atomic mapping contract detected: {dat_path}"
)
@@ -2196,15 +2166,14 @@ def run_robert(self):
# Validate + Apply in one step
new_mapped_csv = self._apply_mapping_smarts(
- self.csv_test_path,
- contract
+ self.csv_test_path, contract
)
self.console_output.append(
f"[INFO] Generated mapped test CSV: {new_mapped_csv}"
)
- # Save original test CSV only
+ # Save original test CSV only
if not getattr(self, "_original_test_csv_path", None):
self._original_test_csv_path = self.csv_test_path
@@ -2217,7 +2186,7 @@ def run_robert(self):
"Atomic mapping mismatch",
f"{e}\n\n"
"Prediction cannot continue because atomic descriptors "
- "would not be consistent with the trained model."
+ "would not be consistent with the trained model.",
)
self._reset_ui_after_process()
return
@@ -2253,13 +2222,13 @@ def run_robert(self):
run_dir = os.path.dirname(selected_file_path)
elif self.csv_test_path:
run_dir = os.path.dirname(self.csv_test_path)
-
+
self.worker = RobertWorker(command, run_dir)
self.worker.output_received.connect(self.console_output.append)
self.worker.error_received.connect(self.console_output.append)
self.worker.process_finished.connect(self.on_process_finished)
self.worker.start()
-
+
def _read_atom_mapping_dat(self, dat_path):
"""
Read atomic mapping contract from .dat file.
@@ -2284,13 +2253,17 @@ def _read_atom_mapping_dat(self, dat_path):
mapping = []
for line in lines:
-
# SMARTS line
if line.startswith("SMARTS pattern"):
continue # skip header line
- if smarts is None and not line.startswith("-") and "Pattern atoms" not in line and "Pattern atom" not in line:
- # First non-header SMARTS candidate
+ if (
+ smarts is None
+ and not line.startswith("-")
+ and "Pattern atoms" not in line
+ and "Pattern atom" not in line
+ ):
+ # First non-header SMARTS candidate
if "[" in line or "#" in line:
smarts = line
@@ -2303,28 +2276,26 @@ def _read_atom_mapping_dat(self, dat_path):
# Extract using regex
match = re.search(
r"Pattern atom\s+(\d+)\s+→\s+atomMap\s+(\d+)\s+\(Element:\s+(\w+)\)",
- line
+ line,
)
if match:
pattern_idx = int(match.group(1))
map_num = int(match.group(2))
element = match.group(3)
- mapping.append({
- "pattern_idx": pattern_idx,
- "map_num": map_num,
- "element": element
- })
+ mapping.append(
+ {
+ "pattern_idx": pattern_idx,
+ "map_num": map_num,
+ "element": element,
+ }
+ )
if smarts is None or pattern_atoms is None or not mapping:
raise ValueError("Invalid atom_mapping.dat format")
-
- return {
- "smarts": smarts,
- "pattern_atoms": pattern_atoms,
- "mapping": mapping
- }
-
+
+ return {"smarts": smarts, "pattern_atoms": pattern_atoms, "mapping": mapping}
+
def _apply_mapping_smarts(self, csv_path, contract):
"""
Validate and apply atomic mapping contract to CSV.
@@ -2335,10 +2306,7 @@ def _apply_mapping_smarts(self, csv_path, contract):
df = smart_read_csv(csv_path)
- smiles_col = next(
- (c for c in df.columns if c.lower() == "smiles"),
- None
- )
+ smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
if smiles_col is None:
raise ValueError("CSV has no SMILES column")
@@ -2354,7 +2322,7 @@ def _apply_mapping_smarts(self, csv_path, contract):
if pattern_mol.GetNumAtoms() != expected_pattern_atoms:
raise ValueError("SMARTS atom count mismatch with contract")
-
+
# --------------------------------------------------
# Step 1: Detect if already mapped correctly
# --------------------------------------------------
@@ -2367,7 +2335,6 @@ def _apply_mapping_smarts(self, csv_path, contract):
mapped_smiles = []
for smiles in df[smiles_col].dropna().astype(str):
-
mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
if mol is None:
raise ValueError(f"Invalid SMILES: {smiles}")
@@ -2445,7 +2412,7 @@ def build_robert_command(self, selected_file_path):
command += f' --csv_test "{csv_test}"'
return command
-
+
# ==================================================
# NORMAL WORKFLOW (CURATE / GENERATE / VERIFY)
# ==================================================
@@ -2466,8 +2433,7 @@ def build_robert_command(self, selected_file_path):
# ---------- IGNORE COLUMNS ----------
selected_columns = [
- self.ignore_list.item(i).text()
- for i in range(self.ignore_list.count())
+ self.ignore_list.item(i).text() for i in range(self.ignore_list.count())
]
if selected_columns:
formatted_columns = [f"'{col}'" for col in selected_columns]
@@ -2487,21 +2453,21 @@ def build_robert_command(self, selected_file_path):
command += " --auto_type False"
if self.seed_value:
- command += f' --seed {self.seed_value}'
+ command += f" --seed {self.seed_value}"
if self.kfold_value:
- command += f' --kfold {self.kfold_value}'
+ command += f" --kfold {self.kfold_value}"
if self.repeat_kfolds_value:
- command += f' --repeat_kfolds {self.repeat_kfolds_value}'
+ command += f" --repeat_kfolds {self.repeat_kfolds_value}"
if self.split_value != "even":
- command += f' --split {self.split_value.lower()}'
+ command += f" --split {self.split_value.lower()}"
# ---------- AQME ----------
if self.aqme_workflow.isChecked():
- command += ' --aqme'
- command += f' --descp_lvl {self.descriptor_level_selected}'
+ command += " --aqme"
+ command += f" --descp_lvl {self.descriptor_level_selected}"
atoms_entries = []
@@ -2527,66 +2493,66 @@ def build_robert_command(self, selected_file_path):
# ---------- CURATE ----------
if self.categorical_value != "onehot":
- command += f' --categorical {self.categorical_value}'
+ command += f" --categorical {self.categorical_value}"
if not self.corr_filter_x_value:
- command += ' --corr_filter_x False'
+ command += " --corr_filter_x False"
if self.corr_filter_y_value:
- command += ' --corr_filter_y True'
+ command += " --corr_filter_y True"
if self.desc_thres_value:
- command += f' --desc_thres {self.desc_thres_value}'
+ command += f" --desc_thres {self.desc_thres_value}"
if self.thres_x_value:
- command += f' --thres_x {self.thres_x_value}'
+ command += f" --thres_x {self.thres_x_value}"
if self.thres_y_value:
- command += f' --thres_y {self.thres_y_value}'
+ command += f" --thres_y {self.thres_y_value}"
# ---------- GENERATE ----------
if self.selected_models != self.default_models:
- model_list = "[" + ",".join(
- f"'{m}'" for m in sorted(self.selected_models)
- ) + "]"
+ model_list = (
+ "[" + ",".join(f"'{m}'" for m in sorted(self.selected_models)) + "]"
+ )
command += f' --model "{model_list}"'
if self.error_type_value != self.default_error_type:
- command += f' --error_type {self.error_type_value}'
+ command += f" --error_type {self.error_type_value}"
if self.init_points_value:
- command += f' --init_points {self.init_points_value}'
+ command += f" --init_points {self.init_points_value}"
if self.n_iter_value:
- command += f' --n_iter {self.n_iter_value}'
+ command += f" --n_iter {self.n_iter_value}"
if not self.pfi_filter_value:
command += " --pfi_filter False"
if self.pfi_epochs_value:
- command += f' --pfi_epochs {self.pfi_epochs_value}'
+ command += f" --pfi_epochs {self.pfi_epochs_value}"
if self.pfi_threshold_value:
- command += f' --pfi_threshold {self.pfi_threshold_value}'
+ command += f" --pfi_threshold {self.pfi_threshold_value}"
if self.pfi_max_value:
- command += f' --pfi_max {self.pfi_max_value}'
+ command += f" --pfi_max {self.pfi_max_value}"
if not self.auto_test_value:
command += " --auto_test False"
if self.test_set_value:
- command += f' --test_set {self.test_set_value}'
+ command += f" --test_set {self.test_set_value}"
# ---------- PREDICT OPTIONS (shared flags) ----------
if self.t_value:
- command += f' --t_value {self.t_value}'
+ command += f" --t_value {self.t_value}"
if self.shap_show:
- command += f' --shap_show {self.shap_show}'
+ command += f" --shap_show {self.shap_show}"
if self.pfi_show:
- command += f' --pfi_show {self.pfi_show}'
+ command += f" --pfi_show {self.pfi_show}"
return command
@@ -2601,7 +2567,9 @@ def _collect_robert_gui_values(self):
self.split_value = self.options_tab.split.currentText().strip()
# ---------- AQME ----------
- self.descriptor_level_selected = self.tab_widget_aqme.descriptor_level.currentText()
+ self.descriptor_level_selected = (
+ self.tab_widget_aqme.descriptor_level.currentText()
+ )
self.atoms_selected = self.tab_widget_aqme.atoms.text().strip()
self.solvent_selected = self.tab_widget_aqme.solvent.currentText()
@@ -2617,15 +2585,15 @@ def _collect_robert_gui_values(self):
type_mode = self.type_dropdown.currentText()
self.default_models = (
- {"RF", "GB", "NN", "MVL"} if type_mode == "Regression"
+ {"RF", "GB", "NN", "MVL"}
+ if type_mode == "Regression"
else {"RF", "GB", "NN", "AdaB"}
)
- self.default_error_type = (
- "rmse" if type_mode == "Regression" else "mcc"
- )
+ self.default_error_type = "rmse" if type_mode == "Regression" else "mcc"
self.selected_models = {
- model for model, checkbox in self.options_tab.modellist.items()
+ model
+ for model, checkbox in self.options_tab.modellist.items()
if checkbox.isChecked()
}
@@ -2662,9 +2630,7 @@ def _ensure_test_name_column(self):
df_test = pd.read_csv(self.csv_test_path)
except Exception as e:
QMessageBox.warning(
- self,
- "Test dataset error",
- f"Could not read test CSV file:\n\n{e}"
+ self, "Test dataset error", f"Could not read test CSV file:\n\n{e}"
)
return False
@@ -2686,7 +2652,7 @@ def _ensure_test_name_column(self):
QMessageBox.warning(
self,
"Test dataset error",
- f"Could not update test CSV file:\n\n{e}"
+ f"Could not update test CSV file:\n\n{e}",
)
return False
@@ -2703,7 +2669,7 @@ def _ensure_test_name_column(self):
"Incompatible test dataset",
f"The selected name column '{name_col}' is not present in the test dataset.\n\n"
"The test CSV does not contain this column, nor a fallback 'code_name' column.\n\n"
- "Please select a compatible test dataset or change the name column."
+ "Please select a compatible test dataset or change the name column.",
)
return False
@@ -2722,9 +2688,7 @@ def check_variables_robert(self):
# ------------------------
if is_predict and not self.file_path and not self.csv_test_path:
QMessageBox.warning(
- self,
- "Invalid Selection",
- "Predict requires at least one CSV file."
+ self, "Invalid Selection", "Predict requires at least one CSV file."
)
return False
@@ -2736,7 +2700,7 @@ def check_variables_robert(self):
QMessageBox.warning(
self,
"Invalid Selection",
- "The name column and the target value column cannot be the same. Please select different columns."
+ "The name column and the target value column cannot be the same. Please select different columns.",
)
return False
@@ -2763,10 +2727,14 @@ def check_variables_robert(self):
# AQME (skip in PREDICT)
# ------------------------
if self.aqme_workflow.isChecked() and not is_predict:
-
total_columns = []
- total_columns += [self.available_list.item(i).text() for i in range(self.available_list.count())]
- total_columns += [self.ignore_list.item(i).text() for i in range(self.ignore_list.count())]
+ total_columns += [
+ self.available_list.item(i).text()
+ for i in range(self.available_list.count())
+ ]
+ total_columns += [
+ self.ignore_list.item(i).text() for i in range(self.ignore_list.count())
+ ]
lowercase_columns = [col.lower() for col in total_columns]
if not any(col.startswith("smiles") for col in lowercase_columns):
@@ -2853,7 +2821,7 @@ def check_variables_robert(self):
return False
return True
-
+
def build_aqme_command(self, selected_file_path, selected_atoms_override=None):
"""Builds the AQME command based on the GUI selections."""
@@ -2873,9 +2841,9 @@ def build_aqme_command(self, selected_file_path, selected_atoms_override=None):
command = (
f'"{python_pointer}" -u -m aqme --qdescp '
f'--input "{csv_name}" '
- f'--program xtb '
+ f"--program xtb "
f'--csv_name "{csv_name}" '
- f'--robert'
+ f"--robert"
)
# ------------------------
@@ -2946,8 +2914,7 @@ def run_aqme(self):
# ---------------------------
folders_to_check = ["AQME", "CSEARCH", "QDESCP", "AQME_RUNS"]
existing_folders = [
- f for f in folders_to_check
- if os.path.exists(os.path.join(run_dir, f))
+ f for f in folders_to_check if os.path.exists(os.path.join(run_dir, f))
]
if existing_folders:
@@ -2959,7 +2926,7 @@ def run_aqme(self):
"They may be reused or overwritten.\n\n"
"Do you want to continue?",
QMessageBox.Yes | QMessageBox.No,
- QMessageBox.No
+ QMessageBox.No,
)
if confirmation == QMessageBox.No:
@@ -2979,36 +2946,32 @@ def run_aqme(self):
smarts = self.tab_widget_aqme.smarts_targets[0]
- self.mapped_train_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
- self.file_path,
- smarts,
- selected_atoms_for_aqme
+ self.mapped_train_csv = (
+ self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+ self.file_path, smarts, selected_atoms_for_aqme
+ )
)
self.is_aqme_mapped = True
if self.csv_test_path:
- self.mapped_test_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
- self.csv_test_path,
- smarts,
- selected_atoms_for_aqme
+ self.mapped_test_csv = (
+ self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+ self.csv_test_path, smarts, selected_atoms_for_aqme
+ )
)
# Save atomic mapping contract in .dat
run_dir = os.path.dirname(self.file_path)
self._write_atom_mapping_dat(
- smarts=smarts,
- selected_atoms=selected_atoms_for_aqme,
- run_dir=run_dir
+ smarts=smarts, selected_atoms=selected_atoms_for_aqme, run_dir=run_dir
)
# ------------------------------------------------
# Decide REAL input CSVs for AQME
# ------------------------------------------------
train_input_csv = (
- self.mapped_train_csv
- if self.is_aqme_mapped
- else self.file_path
+ self.mapped_train_csv if self.is_aqme_mapped else self.file_path
)
test_input_csv = (
@@ -3021,26 +2984,20 @@ def run_aqme(self):
# Build AQME command queue
# -------------------------
main_cmd, self.aqme_run_dir = self.build_aqme_command(
- train_input_csv,
- selected_atoms_override=selected_atoms_for_aqme
+ train_input_csv, selected_atoms_override=selected_atoms_for_aqme
)
- self.aqme_command_queue.append({
- "command": main_cmd,
- "csv": train_input_csv,
- "role": "train"
- })
+ self.aqme_command_queue.append(
+ {"command": main_cmd, "csv": train_input_csv, "role": "train"}
+ )
if test_input_csv:
test_cmd, _ = self.build_aqme_command(
- test_input_csv,
- selected_atoms_override=selected_atoms_for_aqme
+ test_input_csv, selected_atoms_override=selected_atoms_for_aqme
+ )
+ self.aqme_command_queue.append(
+ {"command": test_cmd, "csv": test_input_csv, "role": "test"}
)
- self.aqme_command_queue.append({
- "command": test_cmd,
- "csv": test_input_csv,
- "role": "test"
- })
# ------------
# Launch AQME
@@ -3052,7 +3009,7 @@ def run_aqme(self):
self._run_next_aqme()
def _run_next_aqme(self):
- """ Runs the next AQME command in the queue."""
+ """Runs the next AQME command in the queue."""
if not self.aqme_command_queue:
return
@@ -3071,23 +3028,25 @@ def stop_process(self):
"""Stops the ROBERT and AQME process safely after user confirmation, non-blocking."""
confirmation = QMessageBox.question(
- self,
- "WARNING!",
+ self,
+ "WARNING!",
"Are you sure you want to stop the process?",
- QMessageBox.Yes | QMessageBox.No,
- QMessageBox.No
+ QMessageBox.Yes | QMessageBox.No,
+ QMessageBox.No,
)
if confirmation == QMessageBox.No:
- return
+ return
self.manual_stop = True
if self.worker and self.worker.isRunning():
- self.console_output.append("
Stopping ROBERT...")
+ self.console_output.append(
+ "
Stopping ROBERT..."
+ )
self.progress.setRange(0, 100)
self.stop_button.setDisabled(True)
- QTimer.singleShot(0, self.worker.stop)
+ QTimer.singleShot(0, self.worker.stop)
def _on_aqme_step_finished(self, exit_code):
"""Handles the completion of an AQME step and manages the queue."""
@@ -3163,7 +3122,7 @@ def on_process_finished(self, exit_code):
QMessageBox.information(
self,
"WARNING!",
- f"{self.current_process} has been successfully stopped."
+ f"{self.current_process} has been successfully stopped.",
)
self.manual_stop = False
self._reset_ui_after_process()
@@ -3175,17 +3134,14 @@ def on_process_finished(self, exit_code):
# AQME COMPLETION LOGIC
# ==================================================
if self.current_process == "AQME":
-
# ==================================================
# AQME SUCCESS
# ==================================================
if exit_code == 0 and "Time QDESCP:" in output_text:
-
# =============================================
# AQME TEST -> chain directly to ROBERT PREDICT
# =============================================
if getattr(self, "aqme_role", None) == "test":
-
# Reset AQME state to avoid conflicts with future runs
self.aqme_role = None
@@ -3197,7 +3153,7 @@ def on_process_finished(self, exit_code):
"AQME error",
"AQME finished successfully, but the expected output CSV was not found.\n\n"
"Descriptor generation for the test set completed, but the generated "
- "CSV file could not be detected, so prediction cannot continue."
+ "CSV file could not be detected, so prediction cannot continue.",
)
self.manual_stop = False
self._reset_ui_after_process()
@@ -3212,8 +3168,8 @@ def on_process_finished(self, exit_code):
# --------------------------------------------------
# Launch ROBERT prediction (force AQME output as test CSV)
# --------------------------------------------------
-
- # Save original test CSV only
+
+ # Save original test CSV only
if not getattr(self, "_original_test_csv_path", None):
self._original_test_csv_path = self.csv_test_path
@@ -3230,22 +3186,17 @@ def on_process_finished(self, exit_code):
self.worker.error_received.connect(self.console_output.append)
self.worker.process_finished.connect(self.on_process_finished)
self.worker.start()
- return
+ return
# =============================================
# AQME TRAIN -> original popup logic (unchanged)
# =============================================
train_run = next(
- (r for r in self.aqme_runs if r["role"] == "train"),
- None
+ (r for r in self.aqme_runs if r["role"] == "train"), None
)
if not train_run:
- QMessageBox.warning(
- self,
- "WARNING!",
- "No AQME train output found."
- )
+ QMessageBox.warning(self, "WARNING!", "No AQME train output found.")
self.manual_stop = False
self._reset_ui_after_process()
return
@@ -3255,8 +3206,12 @@ def on_process_finished(self, exit_code):
base_name = os.path.splitext(os.path.basename(aqme_base))[0]
aqme_csvs = {
- "denovo": os.path.join(base_dir, f"AQME-ROBERT_denovo_{base_name}.csv"),
- "interpret": os.path.join(base_dir, f"AQME-ROBERT_interpret_{base_name}.csv"),
+ "denovo": os.path.join(
+ base_dir, f"AQME-ROBERT_denovo_{base_name}.csv"
+ ),
+ "interpret": os.path.join(
+ base_dir, f"AQME-ROBERT_interpret_{base_name}.csv"
+ ),
"full": os.path.join(base_dir, f"AQME-ROBERT_full_{base_name}.csv"),
}
@@ -3316,17 +3271,10 @@ def on_process_finished(self, exit_code):
)
btn_interpret = msg.addButton(
- "Interpret descriptors (recommended)",
- QMessageBox.ActionRole
- )
- btn_denovo = msg.addButton(
- "DeNovo descriptors",
- QMessageBox.ActionRole
- )
- btn_full = msg.addButton(
- "Full descriptors",
- QMessageBox.ActionRole
+ "Interpret descriptors (recommended)", QMessageBox.ActionRole
)
+ btn_denovo = msg.addButton("DeNovo descriptors", QMessageBox.ActionRole)
+ btn_full = msg.addButton("Full descriptors", QMessageBox.ActionRole)
msg.addButton("Cancel", QMessageBox.RejectRole)
msg.exec()
@@ -3356,8 +3304,7 @@ def on_process_finished(self, exit_code):
base_name = os.path.splitext(os.path.basename(run["csv"]))[0]
output_csv = os.path.join(
- base_dir,
- f"AQME-ROBERT_{selected_level}_{base_name}.csv"
+ base_dir, f"AQME-ROBERT_{selected_level}_{base_name}.csv"
)
if not os.path.isfile(output_csv):
@@ -3386,7 +3333,7 @@ def on_process_finished(self, exit_code):
# --------------------------------------------------
for level in ["denovo", "interpret", "full"]:
if not user_cancelled and level == selected_level:
- continue # keep active CSV where it is
+ continue # keep active CSV where it is
csv_name = f"AQME-ROBERT_{level}_{base_name}.csv"
csv_path = Path(self.aqme_run_dir) / csv_name
@@ -3399,7 +3346,7 @@ def on_process_finished(self, exit_code):
target_path.unlink()
shutil.move(str(csv_path), str(target_path))
-
+
# --------------------------------------------------
# 2) Move mapped CSVs (PER RUN, TRAIN + TEST)
# --------------------------------------------------
@@ -3430,7 +3377,7 @@ def on_process_finished(self, exit_code):
QMessageBox.warning(
self,
"WARNING!",
- "AQME encountered an issue while finishing. Please check the logs."
+ "AQME encountered an issue while finishing. Please check the logs.",
)
# End of AQME workflow
self.manual_stop = False
@@ -3454,8 +3401,13 @@ def on_process_finished(self, exit_code):
# ------------------------
# Full workflow / REPORT
# ------------------------
- if not self.manual_stop and (workflow == "Full Workflow" or workflow == "REPORT"):
- if exit_code == 0 and "ROBERT_report.pdf was created successfully" in output_text:
+ if not self.manual_stop and (
+ workflow == "Full Workflow" or workflow == "REPORT"
+ ):
+ if (
+ exit_code == 0
+ and "ROBERT_report.pdf was created successfully" in output_text
+ ):
msg_box = QMessageBox(self)
msg_box.setIcon(QMessageBox.Information)
msg_box.setWindowTitle("Success!")
@@ -3476,7 +3428,7 @@ def on_process_finished(self, exit_code):
QMessageBox.warning(
self,
"WARNING!",
- "ROBERT encountered an issue while finishing. Please check the logs."
+ "ROBERT encountered an issue while finishing. Please check the logs.",
)
# ------------------------
@@ -3485,41 +3437,57 @@ def on_process_finished(self, exit_code):
elif workflow == "CURATE":
if exit_code == 0 and "Time CURATE:" in output_text:
QMessageBox.information(
- self, "Success", "ROBERT has successfully completed the CURATE step."
+ self,
+ "Success",
+ "ROBERT has successfully completed the CURATE step.",
)
else:
QMessageBox.warning(
- self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs."
+ self,
+ "WARNING!",
+ "ROBERT encountered an issue while finishing. Please check the logs.",
)
elif workflow == "GENERATE":
if exit_code == 0 and "Time GENERATE:" in output_text:
QMessageBox.information(
- self, "Success", "ROBERT has successfully completed the GENERATE step."
+ self,
+ "Success",
+ "ROBERT has successfully completed the GENERATE step.",
)
else:
QMessageBox.warning(
- self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs."
+ self,
+ "WARNING!",
+ "ROBERT encountered an issue while finishing. Please check the logs.",
)
elif workflow == "PREDICT":
if exit_code == 0 and "Time PREDICT:" in output_text:
QMessageBox.information(
- self, "Success", "ROBERT has successfully completed the PREDICT step."
+ self,
+ "Success",
+ "ROBERT has successfully completed the PREDICT step.",
)
else:
QMessageBox.warning(
- self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs."
+ self,
+ "WARNING!",
+ "ROBERT encountered an issue while finishing. Please check the logs.",
)
elif workflow == "VERIFY":
if exit_code == 0 and "Time VERIFY:" in output_text:
QMessageBox.information(
- self, "Success", "ROBERT has successfully completed the VERIFY step."
+ self,
+ "Success",
+ "ROBERT has successfully completed the VERIFY step.",
)
else:
QMessageBox.warning(
- self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs."
+ self,
+ "WARNING!",
+ "ROBERT encountered an issue while finishing. Please check the logs.",
)
# Restore previous test CSV if overridden for test workflow aqme generation
diff --git a/robert/gui_easyrob/tabs/advanced_options.py b/robert/gui_easyrob/tabs/advanced_options.py
index c95b60b..5e8042f 100644
--- a/robert/gui_easyrob/tabs/advanced_options.py
+++ b/robert/gui_easyrob/tabs/advanced_options.py
@@ -28,7 +28,6 @@
# Attempt local imports first (portable mode). If they fail,
# fall back to installed package imports.
try:
-
from utils.utils_gui import (
AssetLibrary,
QCheckBox,
@@ -47,8 +46,7 @@
Qt,
)
-except ImportError as e:
-
+except ImportError:
from robert.gui_easyrob.utils.utils_gui import (
AssetLibrary,
QCheckBox,
@@ -67,14 +65,16 @@
Qt,
)
+
class AdvancedOptionsTab(QWidget):
"""Tab for advanced options in the easyROB application."""
+
def __init__(self, type_dropdown, tab_widget):
super().__init__()
self.type = type_dropdown
self.tab_widget = tab_widget # Reference to the main QTabWidget
main_layout = QVBoxLayout(self)
- grid_layout = QGridLayout()
+ grid_layout = QGridLayout()
self.box_features = "QGroupBox { font-weight: bold; }"
# Create section boxes
@@ -93,14 +93,13 @@ def __init__(self, type_dropdown, tab_widget):
# PREDICT (Bottom Row, Full Width)
grid_layout.addWidget(predict_box, 2, 0, 1, 2)
-
# Add the grid layout to the main layout
main_layout.addLayout(grid_layout)
self.setLayout(main_layout)
def go_to_help_section(self, anchor):
"""Open a documentation section in the browser."""
-
+
base_url = "https://robert.readthedocs.io/en/latest/Technical/defaults.html"
if anchor.upper() == "GENERAL":
@@ -139,7 +138,7 @@ def create_general_section(self):
self.seed = QLineEdit()
self.seed.setPlaceholderText("0")
layout.addRow(QLabel("seed:"), self.seed)
-
+
self.kfold = QLineEdit()
self.kfold.setPlaceholderText("5")
layout.addRow(QLabel("kfold:"), self.kfold)
@@ -149,7 +148,7 @@ def create_general_section(self):
layout.addRow(QLabel("repeat_kfolds:"), self.repeat_kfolds)
self.split = QComboBox()
- self.split.addItems([ "even", "RND", "stratified", "KN", "extra_q1", "extra_q5" ])
+ self.split.addItems(["even", "RND", "stratified", "KN", "extra_q1", "extra_q5"])
layout.addRow(QLabel("split:"), self.split)
# --- Help button at the bottom ---
@@ -165,7 +164,7 @@ def create_general_section(self):
def create_curate_section(self):
"""Creates the CURATE section with a box and input fields."""
box = QGroupBox("CURATE")
- box.setStyleSheet(self.box_features)
+ box.setStyleSheet(self.box_features)
layout = QFormLayout()
# Add new input fields for additional options
@@ -206,7 +205,7 @@ def create_curate_section(self):
def create_generate_section(self):
"""Creates the GENERATE section with a box and input fields."""
box = QGroupBox("GENERATE")
- box.setStyleSheet(self.box_features)
+ box.setStyleSheet(self.box_features)
layout = QFormLayout()
self.model_group = QGroupBox("Models")
@@ -220,9 +219,19 @@ def update_model_options():
# Determine which models should be checked by default
if self.type.currentText() == "Regression":
- default_checked_models = ["RF", "GB", "NN", "MVL"] # Regression defaults
+ default_checked_models = [
+ "RF",
+ "GB",
+ "NN",
+ "MVL",
+ ] # Regression defaults
else:
- default_checked_models = ["RF", "GB", "NN", "AdaB"] # Classification defaults
+ default_checked_models = [
+ "RF",
+ "GB",
+ "NN",
+ "AdaB",
+ ] # Classification defaults
# Update check states instead of recreating widgets
for model, checkbox in self.modellist.items():
@@ -248,14 +257,14 @@ def update_model_options():
# Error type selection that changes dynamically but is also user-selectable
self.error_type = QComboBox()
layout.addRow(QLabel("error_type:"), self.error_type)
-
+
def update_error_type():
self.error_type.clear()
if self.type.currentText() == "Regression":
self.error_type.addItems(["rmse", "mae", "r2"])
else:
self.error_type.addItems(["mcc", "f1", "acc"])
-
+
self.type.currentIndexChanged.connect(update_error_type)
update_error_type() # Initialize with the correct default values
@@ -308,17 +317,17 @@ def update_error_type():
def create_predict_section(self):
"""Creates the PREDICT section with a box and input fields."""
box = QGroupBox("PREDICT")
- box.setStyleSheet(self.box_features)
+ box.setStyleSheet(self.box_features)
layout = QFormLayout()
-
+
self.t_value = QLineEdit()
self.t_value.setPlaceholderText("2")
layout.addRow(QLabel("t_value:"), self.t_value)
-
+
self.shap_show = QLineEdit()
self.shap_show.setPlaceholderText("10")
layout.addRow(QLabel("shap_show:"), self.shap_show)
-
+
self.pfi_show = QLineEdit()
self.pfi_show.setPlaceholderText("10")
layout.addRow(QLabel("pfi_show:"), self.pfi_show)
@@ -331,4 +340,4 @@ def create_predict_section(self):
layout.setAlignment(help_button, Qt.AlignRight)
box.setLayout(layout)
- return box
\ No newline at end of file
+ return box
diff --git a/robert/gui_easyrob/tabs/aqme.py b/robert/gui_easyrob/tabs/aqme.py
index 14b4d52..526fc28 100644
--- a/robert/gui_easyrob/tabs/aqme.py
+++ b/robert/gui_easyrob/tabs/aqme.py
@@ -68,7 +68,7 @@
from utils.aqme_utils import ChemDrawFileDialog, MCSProcessWorker
-except ImportError as e:
+except ImportError:
from robert.gui_easyrob.utils.utils_gui import (
AssetLibrary,
BytesIO,
@@ -114,21 +114,24 @@
import csv
from functools import partial
+
class AQMETab(QWidget):
"""Tab responsible for AQME-oriented chemistry preparation workflows."""
- def __init__(self, tab_parent=None, main_window=None):
+ def __init__(self, tab_parent=None, main_window=None):
super().__init__(tab_parent) # tab_parent = QTabWidget
- self.main_tab_widget = tab_parent # Reference to the main QTabWidget
- self.main_window = main_window # Reference to the main window, accessible to csv_df, csv_path, etc...
+ self.main_tab_widget = tab_parent # Reference to the main QTabWidget
+ self.main_window = main_window # Reference to the main window, accessible to csv_df, csv_path, etc...
self.selected_atoms = []
self.box_features = "QGroupBox { font-weight: bold; }"
# === Main vertical layout ===
main_layout = QVBoxLayout(self)
- # --- ChemDraw Button (modern purple style + top spacing) ---
- self.chemdraw_button = QPushButton("Generate CSV from ChemDraw Files or SDF file")
+ # --- ChemDraw Button (modern purple style + top spacing) ---
+ self.chemdraw_button = QPushButton(
+ "Generate CSV from ChemDraw Files or SDF file"
+ )
self.chemdraw_button.setCursor(Qt.PointingHandCursor)
self.chemdraw_button.setFixedSize(400, 42)
@@ -164,7 +167,6 @@ def __init__(self, tab_parent=None, main_window=None):
main_layout.addLayout(button_container)
-
# === Viewer container with label + viewer stacked ===
self.mol_viewer_container = QWidget()
self.mol_viewer_container.setFixedSize(400, 400)
@@ -182,9 +184,11 @@ def __init__(self, tab_parent=None, main_window=None):
# Allow text selection
self.mol_viewer.setTextInteractionFlags(
- Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard
+ Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard
+ )
+ self.set_mol_viewer_message(
+ "📄 Select a CSV with a SMILES column to display a common SMARTS pattern."
)
- self.set_mol_viewer_message("📄 Select a CSV with a SMILES column to display a common SMARTS pattern.")
self.mol_viewer.setFixedSize(400, 400)
# === mol_info_label ===
@@ -200,15 +204,22 @@ def __init__(self, tab_parent=None, main_window=None):
border: 1px solid #aaa;
""")
- self.mol_info_label.setWordWrap(True)
+ self.mol_info_label.setWordWrap(True)
self.mol_info_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
self.mol_info_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Maximum)
- self.mol_info_label.setMaximumWidth(600)
- self.mol_info_label.setAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop)
+ self.mol_info_label.setMaximumWidth(600)
+ self.mol_info_label.setAlignment(
+ Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop
+ )
# === Set up the molecule viewer ===
mol_layout.addWidget(self.mol_viewer, 0, 0)
- mol_layout.addWidget(self.mol_info_label, 0, 0, alignment=Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft)
+ mol_layout.addWidget(
+ self.mol_info_label,
+ 0,
+ 0,
+ alignment=Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft,
+ )
mol_wrapper_layout = QHBoxLayout()
mol_wrapper_layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
mol_wrapper_layout.addWidget(self.mol_viewer_container)
@@ -216,7 +227,7 @@ def __init__(self, tab_parent=None, main_window=None):
# === AQME Box at the bottom ===
aqme_box = QGroupBox("AQME")
- aqme_box.setMaximumHeight(200)
+ aqme_box.setMaximumHeight(200)
aqme_box.setStyleSheet(self.box_features)
aqme_layout = QFormLayout()
@@ -224,33 +235,35 @@ def __init__(self, tab_parent=None, main_window=None):
self.descriptor_level = QComboBox()
self.descriptor_level.addItems(["interpret", "denovo", "full"])
self.solvent = QComboBox()
- self.solvent.addItems([
- "None",
- # "Acetone",
- # "Acetonitrile",
- # "Aniline",
- # "Benzaldehyde",
- # "Benzene",
- # "CH2Cl2",
- # "CHCl3",
- # "CS2",
- # "Dioxane",
- # "DMF",
- # "DMSO",
- # "Ether",
- # "Ethylacetate",
- # "Furane",
- # "Hexadecane",
- # "Hexane",
- # "Methanol",
- # "Nitromethane",
- # "Octanol",
- # "Octanol (wet)",
- # "Phenol",
- # "Toluene",
- # "THF",
- # "Water"
- ])
+ self.solvent.addItems(
+ [
+ "None",
+ # "Acetone",
+ # "Acetonitrile",
+ # "Aniline",
+ # "Benzaldehyde",
+ # "Benzene",
+ # "CH2Cl2",
+ # "CHCl3",
+ # "CS2",
+ # "Dioxane",
+ # "DMF",
+ # "DMSO",
+ # "Ether",
+ # "Ethylacetate",
+ # "Furane",
+ # "Hexadecane",
+ # "Hexane",
+ # "Methanol",
+ # "Nitromethane",
+ # "Octanol",
+ # "Octanol (wet)",
+ # "Phenol",
+ # "Toluene",
+ # "THF",
+ # "Water"
+ ]
+ )
aqme_layout.addRow(QLabel("QDESCP Atoms:"), self.atoms)
aqme_layout.addRow(QLabel("Descriptor Level:"), self.descriptor_level)
@@ -302,15 +315,19 @@ def detect_patterns_and_display(self):
"""Detects patterns in the loaded CSV and displays the first molecule."""
try:
- self.csv_df = smart_read_csv(self.file_path) # Store the DataFrame for later use
- self.smiles_column = next((col for col in self.csv_df.columns if col.lower() == "smiles"), None)
+ self.csv_df = smart_read_csv(
+ self.file_path
+ ) # Store the DataFrame for later use
+ self.smiles_column = next(
+ (col for col in self.csv_df.columns if col.lower() == "smiles"), None
+ )
self.set_mol_viewer_message("🔬 Detecting common SMARTS pattern...")
# === Auto SMARTS detection ===
self.auto_pattern()
- except Exception as e:
+ except Exception:
self.set_mol_viewer_message("❌ Failed to load or process the CSV.")
self.mol_info_label.setText("🔬 Info here")
@@ -322,17 +339,14 @@ def _on_mcs_success(self, smarts):
def _on_mcs_error(self, message):
"""Handle MCS detection error."""
- self.set_mol_viewer_message(
- message,
- tooltip="SMARTS pattern detection failed."
- )
+ self.set_mol_viewer_message(message, tooltip="SMARTS pattern detection failed.")
self.mol_info_label.setText("🔬 Info here")
def _on_mcs_timeout(self):
"""Handle MCS detection timeout."""
self.set_mol_viewer_message(
"⏱️ Timeout: MCS (Maximum Common Substructure) took too long and was aborted.",
- tooltip="SMARTS pattern detection failed."
+ tooltip="SMARTS pattern detection failed.",
)
self.mol_info_label.setText("🔬 Info here")
@@ -342,52 +356,33 @@ def build_unified_smiles_context(self, train_csv_path, test_csv_path=None):
(FMCS, ambiguity checks, metal detection).
"""
train_df = smart_read_csv(train_csv_path)
- smiles_col = next(
- (c for c in train_df.columns if c.lower() == "smiles"),
- None
- )
+ smiles_col = next((c for c in train_df.columns if c.lower() == "smiles"), None)
if smiles_col is None:
raise ValueError("TRAIN CSV has no SMILES column")
- unified_smiles = (
- train_df[smiles_col]
- .dropna()
- .astype(str)
- .tolist()
- )
+ unified_smiles = train_df[smiles_col].dropna().astype(str).tolist()
if test_csv_path:
test_df = smart_read_csv(test_csv_path)
test_smiles_col = next(
- (c for c in test_df.columns if c.lower() == "smiles"),
- None
+ (c for c in test_df.columns if c.lower() == "smiles"), None
)
if test_smiles_col is None:
raise ValueError("TEST CSV has no SMILES column")
unified_smiles.extend(
- test_df[test_smiles_col]
- .dropna()
- .astype(str)
- .tolist()
+ test_df[test_smiles_col].dropna().astype(str).tolist()
)
return unified_smiles
-
+
def generate_mapped_csv_from_smiles(
- self,
- csv_path,
- smarts,
- selected_atoms,
- suffix="_mapped"
+ self, csv_path, smarts, selected_atoms, suffix="_mapped"
):
"""Generate a new CSV file with mapped SMILES based on the provided SMARTS pattern"""
df = smart_read_csv(csv_path)
- smiles_col = next(
- (c for c in df.columns if c.lower() == "smiles"),
- None
- )
+ smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
if smiles_col is None:
raise ValueError("CSV has no SMILES column")
@@ -447,12 +442,7 @@ def auto_pattern(self):
smiles_list = unified_smiles
else:
# TRAIN only → exploratory
- smiles_list = (
- self.csv_df[self.smiles_column]
- .dropna()
- .astype(str)
- .tolist()
- )
+ smiles_list = self.csv_df[self.smiles_column].dropna().astype(str).tolist()
if not smiles_list:
self.set_mol_viewer_message(
@@ -463,10 +453,7 @@ def auto_pattern(self):
# -------------------------------
# Launch MCS worker
# -------------------------------
- self.mcs_worker = MCSProcessWorker(
- smiles_list,
- timeout_ms=60000
- )
+ self.mcs_worker = MCSProcessWorker(smiles_list, timeout_ms=60000)
self.mcs_worker.finished.connect(self._on_mcs_success)
self.mcs_worker.error.connect(self._on_mcs_error)
@@ -476,16 +463,57 @@ def auto_pattern(self):
def display_molecule(self):
"""Display a SMARTS molecule and highlight atoms based on user selection."""
- rdkit.rdBase.DisableLog('rdApp.*')
+ rdkit.rdBase.DisableLog("rdApp.*")
rdDepictor.SetPreferCoordGen(True)
self.metal_atomic_numbers = {
- 3, 11, 19, 37, 55, 87,
- 4, 12, 20, 38, 56, 88,
- 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
- 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
- 72, 73, 74, 75, 76, 77, 78, 79, 80,
- 13, 49, 50, 81, 82, 83
+ 3,
+ 11,
+ 19,
+ 37,
+ 55,
+ 87,
+ 4,
+ 12,
+ 20,
+ 38,
+ 56,
+ 88,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 13,
+ 49,
+ 50,
+ 81,
+ 82,
+ 83,
}
try:
@@ -556,18 +584,18 @@ def display_molecule(self):
highlight_colors = (
{idx: (0.698, 0.4, 1.0) for idx in highlight_atoms}
- if highlight_atoms else {}
+ if highlight_atoms
+ else {}
)
drawer = rdMolDraw2D.MolDraw2DCairo(
- self.molecule_image_width,
- self.molecule_image_height
+ self.molecule_image_width, self.molecule_image_height
)
drawer.drawOptions().bondLineWidth = 1.5
drawer.DrawMolecule(
self.mol,
highlightAtoms=list(highlight_atoms),
- highlightAtomColors=highlight_colors
+ highlightAtomColors=highlight_colors,
)
drawer.FinishDrawing()
@@ -575,8 +603,7 @@ def display_molecule(self):
pixmap = QPixmap()
pixmap.loadFromData(png_bytes)
self.atom_coords = [
- drawer.GetDrawCoords(i)
- for i in range(self.mol.GetNumAtoms())
+ drawer.GetDrawCoords(i) for i in range(self.mol.GetNumAtoms())
]
if self.mol_viewer:
@@ -588,16 +615,16 @@ def display_molecule(self):
if self.metal_found and self.multiple_matches_detected:
self.mol_info_label.setText(
- '🧪 SMARTS pattern loaded. Metal atom(s) automatically selected.
'
+ "🧪 SMARTS pattern loaded. Metal atom(s) automatically selected.
"
'⚠️ Multiple matches were found. '
- 'Atomic descriptors will be generated for the detected metal atom(s). '
- 'Manual atom selection has been disabled to avoid ambiguity.'
+ "Atomic descriptors will be generated for the detected metal atom(s). "
+ "Manual atom selection has been disabled to avoid ambiguity."
)
elif self.metal_found and not self.selected_atoms:
self.mol_info_label.setText(
- '🧪 SMARTS pattern loaded. Click to select atoms.
'
+ "🧪 SMARTS pattern loaded. Click to select atoms.
"
'⚠️ No atoms selected. '
- 'Descriptors will only be generated for the detected metal.'
+ "Descriptors will only be generated for the detected metal."
)
else:
if highlight_atoms:
@@ -606,25 +633,24 @@ def display_molecule(self):
)
else:
self.mol_info_label.setText(
- '🧪 SMARTS pattern loaded. Click to select atoms.
'
+ "🧪 SMARTS pattern loaded. Click to select atoms.
"
'⚠️ WARNING! No atoms selected. '
- 'Atomic descriptors will not be generated.'
+ "Atomic descriptors will not be generated."
)
except Exception as e:
- self.set_mol_viewer_message(
- "❌ Error displaying molecule.",
- tooltip=str(e)
- )
+ self.set_mol_viewer_message("❌ Error displaying molecule.", tooltip=str(e))
self.mol_info_label.setText("🔬 Info here")
def handle_atom_selection(self, atom_idx):
"""Handle the selection of an atom in the pattern."""
- if not hasattr(self, 'selected_atoms'):
+ if not hasattr(self, "selected_atoms"):
self.selected_atoms = []
-
- if getattr(self, 'metal_found', False) and getattr(self, 'multiple_matches_detected', False):
+
+ if getattr(self, "metal_found", False) and getattr(
+ self, "multiple_matches_detected", False
+ ):
# Prevent manual selection when metal match has been auto-selected due to ambiguity
return
@@ -641,11 +667,12 @@ def handle_atom_selection(self, atom_idx):
self.generate_mapped_smiles(
self.smarts_targets[0],
self.selected_atoms,
- self.csv_df[self.smiles_column].dropna()
+ self.csv_df[self.smiles_column].dropna(),
)
-
- def generate_mapped_smiles(self, smarts_pattern, selected_pattern_indices, smiles_list):
+ def generate_mapped_smiles(
+ self, smarts_pattern, selected_pattern_indices, smiles_list
+ ):
"""
Generate mapped SMILES using a SMARTS pattern and selected atom indices.
Updates self.df_mapped_smiles with a copy of the original CSV where 'SMILES' is replaced.
@@ -693,42 +720,51 @@ def generate_mapped_smiles(self, smarts_pattern, selected_pattern_indices, smile
df_mapped[self.smiles_column] = mapped_smiles
self.df_mapped_smiles = df_mapped
-
def mousePressEvent(self, event: QMouseEvent):
"""Handle mouse press events to select atoms and crate pattern.
The logic is to check if the mouse press event is within the molecule_viewer area."""
if event.button() == Qt.MouseButton.LeftButton:
pos = event.position()
- if self.mol_viewer_container and self.mol_viewer_container.geometry().contains(pos.toPoint()):
+ if (
+ self.mol_viewer_container
+ and self.mol_viewer_container.geometry().contains(pos.toPoint())
+ ):
relative_pos = self.mol_viewer_container.mapFrom(self, pos.toPoint())
x = relative_pos.x()
y = relative_pos.y()
selected_atom = self.get_atom_at_position(x, y)
if selected_atom is not None:
self.handle_atom_selection(selected_atom)
- self.display_molecule()
+ self.display_molecule()
def get_atom_at_position(self, x, y):
- """Get the atom index at the given position by
- checking the distance from the atom coordinates.
+ """Get the atom index at the given position by
+ checking the distance from the atom coordinates.
The atom coordinates are found using RDKit.
The logic is to check if the distance between the mouse click
and the atom coordinates is less than a threshold."""
- if not hasattr(self, 'atom_coords'):
+ if not hasattr(self, "atom_coords"):
return None
elif self.atom_coords is not None:
for idx, coord in enumerate(self.atom_coords):
- if len(self.smarts_targets[0]) <= 30: # small molecule = bigger click area
- if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 300:
- return idx
- if len(self.smarts_targets[0]) <= 50 and len(self.smarts_targets[0]) > 30: # medium molecule = medium click area
- if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 200:
- return idx
- elif len(self.smarts_targets[0]) > 50 : # big molecule = smaller click area
- if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 100:
- return idx
+ if (
+ len(self.smarts_targets[0]) <= 30
+ ): # small molecule = bigger click area
+ if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 300:
+ return idx
+ if (
+ len(self.smarts_targets[0]) <= 50
+ and len(self.smarts_targets[0]) > 30
+ ): # medium molecule = medium click area
+ if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 200:
+ return idx
+ elif (
+ len(self.smarts_targets[0]) > 50
+ ): # big molecule = smaller click area
+ if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 100:
+ return idx
return None
def open_chemdraw_popup(self):
@@ -744,7 +780,7 @@ def open_chemdraw_popup(self):
"• Incorrect or broken bonds
"
"• Unconnected fragments or misdrawn connections
"
"When everything looks correct, click OK to select your file."
- )
+ ),
)
dialog = ChemDrawFileDialog(self)
@@ -754,9 +790,10 @@ def open_chemdraw_popup(self):
def load_chemdraw_file(self, main_path):
"""Opens a ChemDraw file and displays the molecules in a table."""
+
def load_mols_from_path(path):
"""Load molecules from a ChemDraw or SDF file."""
- if path.endswith('.cdxml'):
+ if path.endswith(".cdxml"):
try:
mols = MolsFromCDXMLFile(path, sanitize=False, removeHs=False)
total_count = len(mols)
@@ -764,13 +801,19 @@ def load_mols_from_path(path):
for mol in mols:
if mol is not None:
- fragments = GetMolFrags(mol, asMols=True, sanitizeFrags=False)
+ fragments = GetMolFrags(
+ mol, asMols=True, sanitizeFrags=False
+ )
valid_mols.extend(fragments)
valid_count = len(valid_mols)
if valid_count == 0:
- QMessageBox.warning(self, "CDXML Warning", f"No valid molecules found in the file:\n{path}")
+ QMessageBox.warning(
+ self,
+ "CDXML Warning",
+ f"No valid molecules found in the file:\n{path}",
+ )
return []
elif valid_count < total_count:
@@ -778,18 +821,20 @@ def load_mols_from_path(path):
QMessageBox.warning(
self,
"CDXML Partial Load",
- f"File loaded with partial success.\n{failed_count} out of {total_count} molecules failed sanitization and were skipped."
+ f"File loaded with partial success.\n{failed_count} out of {total_count} molecules failed sanitization and were skipped.",
)
return valid_mols
except Exception as e:
- QMessageBox.critical(self, "CDXML Read Error", f"Failed to read {path}:\n{str(e)}")
+ QMessageBox.critical(
+ self, "CDXML Read Error", f"Failed to read {path}:\n{str(e)}"
+ )
return []
- elif path.endswith('.sdf'):
+ elif path.endswith(".sdf"):
return [mol for mol in Chem.SDMolSupplier(path) if mol is not None]
-
+
elif path.endswith(".cdx"):
QMessageBox.warning(
self,
@@ -807,7 +852,7 @@ def load_mols_from_path(path):
"4. Paste it into a new ChemDraw document.
"
"5. Save it as CDXML.
"
"This ensures proper structure recognition and full compatibility with easyROB."
- )
+ ),
)
return None
@@ -820,7 +865,7 @@ def load_mols_from_path(path):
# If the function returned None, it means we already handled a special case (like .cdx)
if mols_main is None:
return
-
+
# If the function returned an empty list, it means there were no valid molecules
if not mols_main:
QMessageBox.warning(self, "Error", "No valid molecules found in the file.")
@@ -846,7 +891,13 @@ def show_molecule_table_dialog(self, mols):
# --- Table Columns ---
base_headers = ["Image", "SMILES", "code_name", "target"]
extra_columns = ["charge", "mult", "complex_type", "sample", "geom"]
- complex_type_options = ["", "squareplanar", "squarepyramidal", "linear", "trigonalplanar"]
+ complex_type_options = [
+ "",
+ "squareplanar",
+ "squarepyramidal",
+ "linear",
+ "trigonalplanar",
+ ]
# Table widget setup
table = QTableWidget(len(mols), len(base_headers))
@@ -867,8 +918,10 @@ def on_header_double_clicked(index):
return # Only allow renaming for the 'target' column
current_text = table.horizontalHeaderItem(index).text()
new_text, ok = QInputDialog.getText(
- dialog, "Edit Column Name",
- f"Rename column '{current_text}':", text=current_text
+ dialog,
+ "Edit Column Name",
+ f"Rename column '{current_text}':",
+ text=current_text,
)
if ok and new_text.strip():
table.setHorizontalHeaderItem(index, QTableWidgetItem(new_text.strip()))
@@ -884,7 +937,11 @@ def on_header_double_clicked(index):
img.save(buffer, format="PNG")
qimg = QImage.fromData(buffer.getvalue())
label = QLabel()
- label.setPixmap(QPixmap.fromImage(qimg).scaled(100, 100, Qt.KeepAspectRatio, Qt.SmoothTransformation))
+ label.setPixmap(
+ QPixmap.fromImage(qimg).scaled(
+ 100, 100, Qt.KeepAspectRatio, Qt.SmoothTransformation
+ )
+ )
widget = QWidget()
hbox = QHBoxLayout()
@@ -916,11 +973,14 @@ def toggle_column(col_name, state):
Add or remove an extra column based on the corresponding checkbox.
Handles special widget for 'complex_type' column.
"""
+
def set_all_column_widths(width):
for col in range(table.columnCount()):
table.setColumnWidth(col, width)
- current_headers = [table.horizontalHeaderItem(i).text() for i in range(table.columnCount())]
+ current_headers = [
+ table.horizontalHeaderItem(i).text() for i in range(table.columnCount())
+ ]
if state: # Checkbox checked: add column if not present
if col_name not in current_headers:
idx = table.columnCount()
@@ -961,7 +1021,9 @@ def save_to_csv():
Collect all table data and save to a CSV file.
Includes validation for required fields, uniqueness, types, and empty checks.
"""
- headers = [table.horizontalHeaderItem(i).text() for i in range(table.columnCount())]
+ headers = [
+ table.horizontalHeaderItem(i).text() for i in range(table.columnCount())
+ ]
# --- Mandatory column presence check ---
try:
@@ -977,13 +1039,21 @@ def save_to_csv():
# Check 'SMILES' not empty
item = table.item(row, smiles_idx)
if not item or not item.text().strip():
- QMessageBox.warning(dialog, "WARNING!", f"Please fill in all 'SMILES' fields before saving.")
+ QMessageBox.warning(
+ dialog,
+ "WARNING!",
+ "Please fill in all 'SMILES' fields before saving.",
+ )
return
# Check 'code_name' not empty
item = table.item(row, code_name_idx)
if not item or not item.text().strip():
- QMessageBox.warning(dialog, "WARNING!", f"Please fill in all 'code_name' fields before saving.")
+ QMessageBox.warning(
+ dialog,
+ "WARNING!",
+ "Please fill in all 'code_name' fields before saving.",
+ )
return
code_names.append(table.item(row, code_name_idx).text().strip())
@@ -994,10 +1064,14 @@ def save_to_csv():
item = table.item(row, charge_idx)
val = item.text().strip() if item else ""
if val == "":
- QMessageBox.warning(dialog, "WARNING!", f"Column 'charge' cannot be empty.")
+ QMessageBox.warning(
+ dialog, "WARNING!", "Column 'charge' cannot be empty."
+ )
return
- if not (val.lstrip('-').isdigit() and '.' not in val):
- QMessageBox.warning(dialog, "WARNING!", f"Column 'charge' must be an integer.")
+ if not (val.lstrip("-").isdigit() and "." not in val):
+ QMessageBox.warning(
+ dialog, "WARNING!", "Column 'charge' must be an integer."
+ )
return
# Validate 'mult' column if present (must be int, not empty)
@@ -1006,10 +1080,14 @@ def save_to_csv():
item = table.item(row, mult_idx)
val = item.text().strip() if item else ""
if val == "":
- QMessageBox.warning(dialog, "WARNING!", f"Column 'mult' cannot be empty.")
+ QMessageBox.warning(
+ dialog, "WARNING!", "Column 'mult' cannot be empty."
+ )
return
- if not (val.lstrip('-').isdigit() and '.' not in val):
- QMessageBox.warning(dialog, "WARNING!", f"Column 'mult' must be an integer.")
+ if not (val.lstrip("-").isdigit() and "." not in val):
+ QMessageBox.warning(
+ dialog, "WARNING!", "Column 'mult' must be an integer."
+ )
return
# Validate 'complex_type' if present (must be selected)
@@ -1018,11 +1096,12 @@ def save_to_csv():
combo = table.cellWidget(row, complex_type_idx)
if combo is not None and combo.currentText().strip() == "":
QMessageBox.warning(
- dialog, "WARNING!",
- f"Column 'complex_type' cannot be empty. Please select a value."
+ dialog,
+ "WARNING!",
+ "Column 'complex_type' cannot be empty. Please select a value.",
)
return
-
+
# Validate 'sample' column if present (must be int, not empty)
if "sample" in headers:
sample_idx = headers.index("sample")
@@ -1030,10 +1109,16 @@ def save_to_csv():
item = table.item(row, sample_idx)
val = item.text().strip() if item else ""
if val == "":
- QMessageBox.warning(dialog, "WARNING!", f"Column 'sample' cannot be empty.")
+ QMessageBox.warning(
+ dialog, "WARNING!", "Column 'sample' cannot be empty."
+ )
return
- if not (val.lstrip('-').isdigit() and '.' not in val):
- QMessageBox.warning(dialog, "WARNING!", f"Column 'sample' must be an integer.")
+ if not (val.lstrip("-").isdigit() and "." not in val):
+ QMessageBox.warning(
+ dialog,
+ "WARNING!",
+ "Column 'sample' must be an integer.",
+ )
return
# Validate 'GEOM' column if present (must not be empty)
@@ -1043,16 +1128,20 @@ def save_to_csv():
item = table.item(row, geom_idx)
val = item.text().strip() if item else ""
if val == "":
- QMessageBox.warning(dialog, "WARNING!", f"Column 'geom' cannot be empty.")
+ QMessageBox.warning(
+ dialog, "WARNING!", "Column 'geom' cannot be empty."
+ )
return
-
# --- Uniqueness check for 'code_name' ---
- duplicates = [name for name in set(code_names) if code_names.count(name) > 1]
+ duplicates = [
+ name for name in set(code_names) if code_names.count(name) > 1
+ ]
if duplicates:
QMessageBox.warning(
- dialog, "WARNING!",
- f"The following 'code_name' values are duplicated:\n\n{', '.join(duplicates)}\n\nPlease make them unique before saving."
+ dialog,
+ "WARNING!",
+ f"The following 'code_name' values are duplicated:\n\n{', '.join(duplicates)}\n\nPlease make them unique before saving.",
)
return
@@ -1061,16 +1150,20 @@ def save_to_csv():
item = table.item(row, self.target_col_index)
val = item.text().strip() if item else ""
if not val:
- QMessageBox.warning(dialog, "WARNING!", f"Target column is empty.")
+ QMessageBox.warning(dialog, "WARNING!", "Target column is empty.")
return
try:
float(val)
except ValueError:
- QMessageBox.warning(dialog, "WARNING!", f"Target column must be numeric.")
+ QMessageBox.warning(
+ dialog, "WARNING!", "Target column must be numeric."
+ )
return
# --- File dialog to select save path ---
- path, _ = QFileDialog.getSaveFileName(dialog, "Save CSV", "", "CSV Files (*.csv)")
+ path, _ = QFileDialog.getSaveFileName(
+ dialog, "Save CSV", "", "CSV Files (*.csv)"
+ )
if not path:
return
@@ -1097,9 +1190,11 @@ def save_to_csv():
if hasattr(self, "main_window") and self.main_window:
self.main_window.set_file_path(path)
dialog.accept()
- QMessageBox.information(dialog, "Success", "CSV file saved and loaded successfully!")
+ QMessageBox.information(
+ dialog, "Success", "CSV file saved and loaded successfully!"
+ )
save_button.clicked.connect(save_to_csv)
dialog.setLayout(layout)
- dialog.exec()
\ No newline at end of file
+ dialog.exec()
diff --git a/robert/gui_easyrob/tabs/images.py b/robert/gui_easyrob/tabs/images.py
index c1ae2c8..bc2bac3 100644
--- a/robert/gui_easyrob/tabs/images.py
+++ b/robert/gui_easyrob/tabs/images.py
@@ -42,7 +42,7 @@
Qt,
)
-except ImportError as e:
+except ImportError:
from robert.gui_easyrob.utils.utils_gui import (
QApplication,
QDesktopServices,
@@ -66,6 +66,7 @@
import os
import glob
+
class ImagesTab(QWidget):
"""Images tab for displaying images from multiple folders as workflow results."""
@@ -98,8 +99,7 @@ def __init__(self, main_tab_widget, image_folders, file_path):
help_button.setFixedSize(18, 18)
help_button.setStyleSheet("font-size: 11px;")
help_button.setToolTip(
- "Double-click: Open image\n"
- "Right-click: Copy, Save, or Open folder"
+ "Double-click: Open image\nRight-click: Copy, Save, or Open folder"
)
help_button.clicked.connect(self.show_help_dialog)
@@ -268,4 +268,4 @@ def show_context_menu(self, position):
"Images (*.png *.jpg *.jpeg)",
)
if target_path:
- QPixmap(self.image_path).save(target_path)
\ No newline at end of file
+ QPixmap(self.image_path).save(target_path)
diff --git a/robert/gui_easyrob/tabs/molssi.py b/robert/gui_easyrob/tabs/molssi.py
index e8b9fe5..729e977 100644
--- a/robert/gui_easyrob/tabs/molssi.py
+++ b/robert/gui_easyrob/tabs/molssi.py
@@ -23,6 +23,7 @@
- Designed to keep the main window decoupled from download logic.
"""
+
try:
from utils.utils_gui import (
Path,
@@ -42,7 +43,7 @@
)
from utils.molssi_utils import ExcelToCSVWorker
-except ImportError as e:
+except ImportError:
from robert.gui_easyrob.utils.utils_gui import (
Path,
QFileDialog,
@@ -64,6 +65,7 @@
# ---- Standard library ----
import os
+
class MolSSIDatabasesTab(QWidget):
"""
Tab widget embedding the MolSSI descriptor databases web interface.
@@ -74,9 +76,10 @@ class MolSSIDatabasesTab(QWidget):
- Allows saving descriptor files locally
- Optionally converts downloaded Excel files to CSV
"""
- # Signal emitted when a test file download is requested
+
+ # Signal emitted when a test file download is requested
load_test_requested = Signal(str)
-
+
def __init__(self, parent=None):
super().__init__(parent)
@@ -100,6 +103,7 @@ class SingleWindowWebView(QWebEngineView):
Custom QWebEngineView that prevents opening external windows.
Any request to open a new window is redirected to the same view.
"""
+
def createWindow(self, webWindowType):
tmp = QWebEngineView(self)
tmp.setAttribute(Qt.WA_DeleteOnClose, True)
@@ -109,9 +113,7 @@ def createWindow(self, webWindowType):
return tmp
# Base URL for MolSSI databases
- self.databases_home_url = QUrl(
- "https://descriptor-libraries.molssi.org/"
- )
+ self.databases_home_url = QUrl("https://descriptor-libraries.molssi.org/")
# --------------------------------------------------
# Home bar
@@ -173,16 +175,10 @@ def _handle_download(self, req: QWebEngineDownloadRequest):
def open_dialog():
suggested = (
- req.downloadFileName()
- or QUrl(req.url()).fileName()
- or "download"
+ req.downloadFileName() or QUrl(req.url()).fileName() or "download"
)
- path, _ = QFileDialog.getSaveFileName(
- self,
- "Save File",
- suggested
- )
+ path, _ = QFileDialog.getSaveFileName(self, "Save File", suggested)
if not path:
req.cancel()
@@ -239,7 +235,7 @@ def _on_download_completed(self, path):
"(for example, an experimental property or value you want to predict).\n\n"
"Do you want to convert this Excel file to CSV now?",
QMessageBox.Yes | QMessageBox.No,
- QMessageBox.Yes
+ QMessageBox.Yes,
)
if reply != QMessageBox.Yes:
@@ -274,9 +270,7 @@ def finished(csv_path):
return
QMessageBox.information(
- self,
- "Conversion completed",
- "Excel converted to CSV successfully."
+ self, "Conversion completed", "Excel converted to CSV successfully."
)
def error(msg):
@@ -307,7 +301,7 @@ def _show_download_popup(self):
popup.setModal(False)
popup.show()
return popup
-
+
def load_test_molssi(self, csv_path, source=None):
"""
Finalizes a MolSSI test dataset.
@@ -335,4 +329,4 @@ def load_test_molssi(self, csv_path, source=None):
try:
Path(source).unlink()
except Exception:
- pass
\ No newline at end of file
+ pass
diff --git a/robert/gui_easyrob/tabs/predictions.py b/robert/gui_easyrob/tabs/predictions.py
index 8eebfea..3afe41a 100644
--- a/robert/gui_easyrob/tabs/predictions.py
+++ b/robert/gui_easyrob/tabs/predictions.py
@@ -23,6 +23,7 @@
- Designed to keep heavy logic delegated to utils.predictions_utils
"""
+
# ------------------------------------------------------------
# Import resolution (local vs installed package)
# ------------------------------------------------------------
@@ -60,7 +61,7 @@
get_robert_report_path,
)
-except ImportError as e:
+except ImportError:
from robert.gui_easyrob.utils.utils_gui import (
QFrame,
QHBoxLayout,
@@ -102,8 +103,10 @@
import pandas as pd
import matplotlib.pyplot as plt
+
class PredictionsTab(QWidget):
"""Tab for displaying prediction results from ROBERT runs."""
+
availabilityChanged = Signal(bool)
def __init__(self, parent=None):
@@ -147,7 +150,7 @@ def _extract_names_column_from_predict(self):
return match.group(1)
return None
-
+
def _filter_prediction_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Keeps and orders columns as:
@@ -193,7 +196,7 @@ def _filter_prediction_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
df.insert(0, "Image", df[smiles_cols[0]])
return df[ordered_columns]
-
+
def refresh_with_new_path(self, selected_file_path: str):
"""Refreshes the predictions tab with new data from the selected file path."""
# This is the ONLY base path used by PredictionsTab
@@ -231,7 +234,7 @@ def _add_loaded_df(self, key: str, df: pd.DataFrame):
info = evaluate_predictions_for_model(
self._base_path,
df,
- key # "PFI" or "No_PFI"
+ key, # "PFI" or "No_PFI"
)
# Extract fragment image
@@ -271,7 +274,7 @@ def _show_histogram_menu_header(self, pos, df: pd.DataFrame, header):
def _create_table_with_stats(self, df, info, pdf_image):
"""Creates the main table view with the predictions and the side dashboard with stats and diagnostics."""
- # ---- Container ----
+ # ---- Container ----
container = QWidget()
container.setStyleSheet("background: palette(window);")
@@ -308,7 +311,7 @@ def _create_table_with_stats(self, df, info, pdf_image):
lambda pos, d=df, h=header: self._show_header_menu(pos, d, h)
)
- # ---- Side Dashboard ----
+ # ---- Side Dashboard ----
pdf_path = info["pdf_path"]
model_key = info["model"] # "PFI" or "No_PFI"
@@ -325,7 +328,7 @@ def _create_table_with_stats(self, df, info, pdf_image):
pdf_image=pdf_image,
extrapolation_score=extrap_scores.get(model_key),
extrapolation_image=extrap_pixmap,
- external_plot=external_pixmap
+ external_plot=external_pixmap,
)
# ---- Separator ----
@@ -353,11 +356,11 @@ def _show_header_menu(self, pos, df: pd.DataFrame, header):
menu = QMenu(header)
- # Sorting actions
+ # Sorting actions
action_sort_asc = menu.addAction("Sort ascending")
action_sort_desc = menu.addAction("Sort descending")
- menu.addSeparator()
+ menu.addSeparator()
# Histogram action (only for numeric columns)
action_hist = None
@@ -408,4 +411,4 @@ def _show_histogram(self, series: pd.Series, col_name: str):
plt.xlabel(col_name)
plt.ylabel("Frequency")
plt.grid(False)
- plt.show(block=False)
\ No newline at end of file
+ plt.show(block=False)
diff --git a/robert/gui_easyrob/tabs/results.py b/robert/gui_easyrob/tabs/results.py
index 3d50b6e..063a8e6 100644
--- a/robert/gui_easyrob/tabs/results.py
+++ b/robert/gui_easyrob/tabs/results.py
@@ -54,7 +54,7 @@
fitz,
)
-except ImportError as e:
+except ImportError:
from robert.gui_easyrob.utils.utils_gui import (
QImage,
QLabel,
@@ -81,15 +81,17 @@
import os
import glob
+
class ResultsTab(QWidget):
"""PDF viewer for ROBERT reports."""
+
def __init__(self, main_tab_widget, file_path):
super().__init__()
self.main_tab_widget = main_tab_widget
self.base_path = os.path.dirname(file_path)
- self.pdf_tabs = {} # {pdf_path: PDFViewer|None} None => placeholder not materialized
- self.title_to_path = {} # {basename: full path}
+ self.pdf_tabs = {} # {pdf_path: PDFViewer|None} None => placeholder not materialized
+ self.title_to_path = {} # {basename: full path}
# Shared thread pool for all PDF viewers
self.shared_pool = QThreadPool()
@@ -207,20 +209,25 @@ def _index_of_title(self, title: str) -> int:
# ------------------------- Worker signals -------------------------
+
class RenderSignals(QObject):
"""Signals for page rendering."""
+
finished = Signal(int, float, int, QPixmap) # page_num, zoom, generation, pixmap
class MetaSignals(QObject):
"""Signals for PDF metadata loading."""
+
done = Signal(int, list) # page_count, page_sizes
# ------------------------- Worker tasks -------------------------
+
class RenderTask(QRunnable):
"""Background render task for a single PDF page (open by path; no big upfront I/O)."""
+
def __init__(self, pdf_path: str, page_num: int, zoom: float, generation: int):
super().__init__()
self.pdf_path = pdf_path
@@ -238,15 +245,20 @@ def run(self):
page = doc.load_page(self.page_num)
mat = fitz.Matrix(self.zoom, self.zoom)
pix = page.get_pixmap(matrix=mat, alpha=False)
- qimg = QImage(pix.samples, pix.width, pix.height, pix.stride, QImage.Format_RGB888).copy()
+ qimg = QImage(
+ pix.samples, pix.width, pix.height, pix.stride, QImage.Format_RGB888
+ ).copy()
qp = QPixmap.fromImage(qimg)
self.signals.finished.emit(self.page_num, self.zoom, self.generation, qp)
except Exception:
- self.signals.finished.emit(self.page_num, self.zoom, self.generation, QPixmap())
+ self.signals.finished.emit(
+ self.page_num, self.zoom, self.generation, QPixmap()
+ )
class MetaTask(QRunnable):
"""Load page count and page sizes off the UI thread."""
+
def __init__(self, pdf_path: str):
super().__init__()
self.pdf_path = pdf_path
@@ -258,7 +270,9 @@ def run(self):
try:
with fitz.open(self.pdf_path) as doc:
count = len(doc)
- sizes = [tuple(doc.load_page(i).rect.br) for i in range(count)] # (w_pts, h_pts)
+ sizes = [
+ tuple(doc.load_page(i).rect.br) for i in range(count)
+ ] # (w_pts, h_pts)
except Exception:
count, sizes = 1, [(595, 842)] # Fallback to A4 portrait in points
self.signals.done.emit(count, sizes)
@@ -266,19 +280,21 @@ def run(self):
# ------------------------- PDFViewer (async metadata + visible-only render) -------------------------
+
class PDFViewer(QWidget):
"""Widget to display a PDF inside a scrollable area with zoom control and threading."""
+
def __init__(self, pdf_path: str, thread_pool: QThreadPool):
super().__init__()
self.pdf_path = pdf_path
self.current_zoom = 1.2
self.thread_pool = thread_pool # shared
- self.image_cache = {} # {(page_num, zoom): QPixmap}
- self.labels = [] # one QLabel per page
- self.page_sizes = None # [(width_pts, height_pts)]
+ self.image_cache = {} # {(page_num, zoom): QPixmap}
+ self.labels = [] # one QLabel per page
+ self.page_sizes = None # [(width_pts, height_pts)]
self.page_count = None
- self._renderGeneration = 0 # cancel stale renders
+ self._renderGeneration = 0 # cancel stale renders
self._zoomPending = False
self._scrollPending = False
@@ -332,7 +348,9 @@ def _apply_metadata(self, page_count: int, page_sizes: list):
self._build_placeholders_for_zoom(self.current_zoom)
# Now that we know page geometry, hook scroll coalescing
- self.scroll_area.verticalScrollBar().valueChanged.connect(self._schedule_visible_render)
+ self.scroll_area.verticalScrollBar().valueChanged.connect(
+ self._schedule_visible_render
+ )
# Initial render: only what's visible + tiny warm
self._kick_off_visible_render(force=True, warm=1)
@@ -355,7 +373,9 @@ def _apply_zoom_now(self):
# Bump generation to discard in-flight renders
self._renderGeneration += 1
# Keep only current-zoom cache
- self.image_cache = {k: v for k, v in self.image_cache.items() if k[1] == self.current_zoom}
+ self.image_cache = {
+ k: v for k, v in self.image_cache.items() if k[1] == self.current_zoom
+ }
# Recompute placeholder heights and clear labels
self._build_placeholders_for_zoom(self.current_zoom)
# Kick minimal warm-up
@@ -465,13 +485,19 @@ def _kick_off_visible_render(self, force: bool = False, warm: int = 0):
# ---------- Render completion ----------
@Slot(int, float, int, QPixmap)
- def on_page_rendered(self, page_num: int, zoom: float, generation: int, pixmap: QPixmap):
+ def on_page_rendered(
+ self, page_num: int, zoom: float, generation: int, pixmap: QPixmap
+ ):
"""Handle rendered page: update cache and label if still relevant."""
# Discard outdated renders (other zoom or older generation) or failed pixmaps
- if generation != self._renderGeneration or zoom != self.current_zoom or pixmap.isNull():
+ if (
+ generation != self._renderGeneration
+ or zoom != self.current_zoom
+ or pixmap.isNull()
+ ):
return
key = (page_num, zoom)
self.image_cache[key] = pixmap
lbl = self.labels[page_num]
lbl.setPixmap(pixmap)
- lbl.setText("")
\ No newline at end of file
+ lbl.setText("")
diff --git a/robert/gui_easyrob/utils/aqme_utils.py b/robert/gui_easyrob/utils/aqme_utils.py
index e17696e..15e190b 100644
--- a/robert/gui_easyrob/utils/aqme_utils.py
+++ b/robert/gui_easyrob/utils/aqme_utils.py
@@ -46,6 +46,7 @@
# ------------------------------------------------------------
from .utils_gui import DropLabel
+
class ChemDrawFileDialog(QDialog):
"""Dialog that collects the main ChemDraw/SDF input file."""
@@ -88,10 +89,13 @@ def set_main_file(self, path):
def continue_clicked(self):
"""Check if a main ChemDraw file has been selected and accept the dialog."""
if not self.main_chemdraw_path:
- QMessageBox.warning(self, "Missing File", "Please select a main ChemDraw file.")
+ QMessageBox.warning(
+ self, "Missing File", "Please select a main ChemDraw file."
+ )
return
self.accept()
+
def mcs_process(smiles_list, result_queue):
"""Find the maximum common substructure for a list of SMILES."""
try:
@@ -160,4 +164,4 @@ def _on_timeout(self):
if self.process and self.process.is_alive():
self.process.terminate()
self.process.join()
- self.timeout.emit()
\ No newline at end of file
+ self.timeout.emit()
diff --git a/robert/gui_easyrob/utils/molssi_utils.py b/robert/gui_easyrob/utils/molssi_utils.py
index e39bdc6..4814ae7 100644
--- a/robert/gui_easyrob/utils/molssi_utils.py
+++ b/robert/gui_easyrob/utils/molssi_utils.py
@@ -39,6 +39,7 @@
from PySide6.QtCore import QThread, Signal
+
class MolSSIWorker(QThread):
"""Background worker responsible for resolving MolSSI descriptors."""
@@ -132,7 +133,7 @@ def canonicalize(smiles):
def chunked(lst, size):
"""Split a list into chunks of a specified size."""
for i in range(0, len(lst), size):
- yield lst[i:i + size]
+ yield lst[i : i + size]
def safe_query_batched(smiles, library, data_type, batch_size=200):
"""Query MolSSI API in batches and handle partial failures."""
@@ -197,11 +198,15 @@ def full_coverage(smiles, df_api):
for lib in dft_libraries:
df_api = safe_query_batched(smiles_list, lib, "DFT")
if full_coverage(smiles_list, df_api):
- return _prepare_export(df_work, df_api, smiles_col, lib, "DFT", original_input_columns)
+ return _prepare_export(
+ df_work, df_api, smiles_col, lib, "DFT", original_input_columns
+ )
df_api = safe_query_batched(smiles_list, "kraken", "ML")
if full_coverage(smiles_list, df_api):
- return _prepare_export(df_work, df_api, smiles_col, "kraken", "ML", original_input_columns)
+ return _prepare_export(
+ df_work, df_api, smiles_col, "kraken", "ML", original_input_columns
+ )
return {
"available": False,
@@ -225,7 +230,9 @@ def _molssi_test_dataset_available(library_slug):
return False
-def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_input_columns):
+def _prepare_export(
+ df_work, df_api, smiles_col, library, data_type, original_input_columns
+):
"""Prepare the merged MolSSI export DataFrame for use in easyROB."""
try:
df_api = df_api.copy()
@@ -237,7 +244,9 @@ def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_in
df_merged = df_work.merge(df_api, on="_smiles_canonical", how="left")
export_df = df_merged.drop(
- columns=[c for c in ["_smiles_canonical", "smiles"] if c in df_merged.columns]
+ columns=[
+ c for c in ["_smiles_canonical", "smiles"] if c in df_merged.columns
+ ]
)
export_df[smiles_col] = export_df["_smiles_original"]
export_df = export_df.drop(columns=["_smiles_original"])
@@ -251,7 +260,10 @@ def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_in
]
if "molecule_id" in export_df.columns:
- only_smiles_input = len(original_input_columns) == 1 and original_input_columns[0].lower() == "smiles"
+ only_smiles_input = (
+ len(original_input_columns) == 1
+ and original_input_columns[0].lower() == "smiles"
+ )
if not only_smiles_input:
export_df = export_df.drop(columns=["molecule_id"])
@@ -278,27 +290,68 @@ def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_in
def fix_greek_caps_columns(col: str) -> str:
"""Normalize Greek characters and canonical spelling in MolSSI headers."""
greek_map = {
- "α": "alpha", "β": "beta", "γ": "gamma", "δ": "delta",
- "ε": "epsilon", "ζ": "zeta", "η": "eta", "θ": "theta",
- "ι": "iota", "κ": "kappa", "λ": "lambda", "μ": "mu",
- "ν": "nu", "ξ": "xi", "ο": "omicron", "π": "pi",
- "ρ": "rho", "σ": "sigma", "τ": "tau", "υ": "upsilon",
- "φ": "phi", "χ": "chi", "ψ": "psi", "ω": "omega",
- "Α": "alpha", "Β": "beta", "Γ": "gamma", "Δ": "delta",
- "Ε": "epsilon", "Ζ": "zeta", "Η": "eta", "Θ": "theta",
- "Ι": "iota", "Κ": "kappa", "Λ": "lambda", "Μ": "mu",
- "Ν": "nu", "Ξ": "xi", "Ο": "omicron", "Π": "pi",
- "Ρ": "rho", "Σ": "sigma", "Τ": "tau", "Υ": "upsilon",
- "Φ": "phi", "Χ": "chi", "Ψ": "psi", "Ω": "omega",
+ "α": "alpha",
+ "β": "beta",
+ "γ": "gamma",
+ "δ": "delta",
+ "ε": "epsilon",
+ "ζ": "zeta",
+ "η": "eta",
+ "θ": "theta",
+ "ι": "iota",
+ "κ": "kappa",
+ "λ": "lambda",
+ "μ": "mu",
+ "ν": "nu",
+ "ξ": "xi",
+ "ο": "omicron",
+ "π": "pi",
+ "ρ": "rho",
+ "σ": "sigma",
+ "τ": "tau",
+ "υ": "upsilon",
+ "φ": "phi",
+ "χ": "chi",
+ "ψ": "psi",
+ "ω": "omega",
+ "Α": "alpha",
+ "Β": "beta",
+ "Γ": "gamma",
+ "Δ": "delta",
+ "Ε": "epsilon",
+ "Ζ": "zeta",
+ "Η": "eta",
+ "Θ": "theta",
+ "Ι": "iota",
+ "Κ": "kappa",
+ "Λ": "lambda",
+ "Μ": "mu",
+ "Ν": "nu",
+ "Ξ": "xi",
+ "Ο": "omicron",
+ "Π": "pi",
+ "Ρ": "rho",
+ "Σ": "sigma",
+ "Τ": "tau",
+ "Υ": "upsilon",
+ "Φ": "phi",
+ "Χ": "chi",
+ "Ψ": "psi",
+ "Ω": "omega",
}
for greek_char, latin in greek_map.items():
col = col.replace(greek_char, latin)
- col = re.sub(r"(? Path:
"""Given the path to a selected file, return the corresponding PREDICT/csv_test directory."""
return Path(selected_file_path).parent / "PREDICT" / "csv_test"
+
def find_prediction_csvs(selected_file_path: str) -> dict[str, Path]:
"""Search for prediction CSV files in the PREDICT/csv_test directory related to the selected file."""
predict_dir = get_predict_dir(selected_file_path)
@@ -91,10 +93,12 @@ def find_prediction_csvs(selected_file_path: str) -> dict[str, Path]:
results["PFI"] = path
return results
+
def get_robert_report_path(selected_file_path: str | Path) -> Path:
"""Given the path to a selected file, return the corresponding ROBERT_report.pdf file."""
return Path(selected_file_path).parent / "ROBERT_report.pdf"
+
def find_external_test_pixmaps(base_path: str | Path) -> dict[str, QPixmap]:
"""Search for external test images in the PREDICT/csv_test directory related to the selected file."""
base_path = Path(base_path)
@@ -119,6 +123,7 @@ def find_external_test_pixmaps(base_path: str | Path) -> dict[str, QPixmap]:
return results
+
def extract_scores_from_robert_report(pdf_path: Path) -> dict:
"""Extract scores from the ROBERT report PDF file."""
result = {"pdf_found": False, "PFI": None, "No_PFI": None}
@@ -133,6 +138,7 @@ def extract_scores_from_robert_report(pdf_path: Path) -> dict:
return result
+
def extract_extrapolation_fragment(pdf_path: Path, model_key: str) -> QPixmap | None:
"""Render the extrapolation block from parsed ROBERT report data."""
details = _extract_extrapolation_details(pdf_path, model_key)
@@ -140,6 +146,7 @@ def extract_extrapolation_fragment(pdf_path: Path, model_key: str) -> QPixmap |
return None
return _render_extrapolation_pixmap(details)
+
def extract_robert_fragment_image(pdf_path: Path, model_key: str) -> QPixmap | None:
"""Render the ROBERT score block from parsed report data."""
details = _extract_robert_score_details(pdf_path, model_key)
@@ -147,6 +154,7 @@ def extract_robert_fragment_image(pdf_path: Path, model_key: str) -> QPixmap | N
return None
return _render_robert_score_pixmap(details)
+
def _get_extrapolation_bbox(page, model_key: str):
"""Return the PDF area containing the extrapolation block for the requested model."""
if model_key == "No_PFI":
@@ -158,7 +166,9 @@ def _get_extrapolation_bbox(page, model_key: str):
def _normalize_extrapolation_lines(text: str) -> list[str]:
"""Collapse noisy PDF whitespace while preserving the content of each line."""
- return [re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip()]
+ return [
+ re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip()
+ ]
def _parse_extrapolation_block(text: str) -> dict | None:
@@ -168,7 +178,9 @@ def _parse_extrapolation_block(text: str) -> dict | None:
lines = _normalize_extrapolation_lines(text)
title_line = next((line for line in lines if "Extrapolation" in line), None)
- rmse_line = next((line for line in lines if "[" in line and "]" in line and "%" in line), None)
+ rmse_line = next(
+ (line for line in lines if "[" in line and "]" in line and "%" in line), None
+ )
scoring_line = next((line for line in lines if "Scoring from" in line), None)
rule_line = next((line for line in lines if "Every two folds" in line), None)
@@ -183,7 +195,11 @@ def _parse_extrapolation_block(text: str) -> dict | None:
if rmse_line:
values_match = re.search(r"\[(.*?)\]", rmse_line)
if values_match:
- values = [value.strip() for value in values_match.group(1).split(",") if value.strip()]
+ values = [
+ value.strip()
+ for value in values_match.group(1).split(",")
+ if value.strip()
+ ]
clean_title = re.sub(r"\(\s*\d+\s*/\s*\d+\s*\)", "", title_line).strip()
@@ -217,7 +233,9 @@ def _extract_extrapolation_details(pdf_path: Path, model_key: str) -> dict | Non
return None
-def _score_fill_rgb(obtained: int | None, maximum: int | None) -> tuple[float, float, float]:
+def _score_fill_rgb(
+ obtained: int | None, maximum: int | None
+) -> tuple[float, float, float]:
"""Return a fill color for the extrapolation score indicator."""
if obtained is None or maximum in (None, 0):
return (0.78, 0.78, 0.78)
@@ -295,8 +313,15 @@ def _render_extrapolation_pixmap(details: dict) -> QPixmap | None:
fontname="hebo",
color=(0.12, 0.20, 0.30),
)
- rmse_text = f"[{', '.join(details['rmse_values'])}]" if details["rmse_values"] else "[]"
- values_rect = fitz.Rect(margin - 2, title_y + line_gap + 12, width - margin, title_y + line_gap * 2 + 20)
+ rmse_text = (
+ f"[{', '.join(details['rmse_values'])}]" if details["rmse_values"] else "[]"
+ )
+ values_rect = fitz.Rect(
+ margin - 2,
+ title_y + line_gap + 12,
+ width - margin,
+ title_y + line_gap * 2 + 20,
+ )
page.draw_rect(values_rect, color=(0.86, 0.86, 0.86), fill=(1, 1, 1), width=0.8)
page.insert_text(
fitz.Point(margin + 8, title_y + line_gap * 2 + 8),
@@ -334,17 +359,24 @@ def _parse_robert_score_block(text: str) -> dict | None:
if not text:
return None
- lines = [re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip()]
+ lines = [
+ re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip()
+ ]
if not lines:
return None
- title_line = next((line for line in lines if re.search(r"\bScore\s+\d+\b", line, re.IGNORECASE)), None)
+ title_line = next(
+ (line for line in lines if re.search(r"\bScore\s+\d+\b", line, re.IGNORECASE)),
+ None,
+ )
if not title_line:
joined_text = " ".join(lines)
score_match = re.search(r"\bScore\s+(\d+)\b", joined_text, re.IGNORECASE)
if not score_match:
return None
- title_match = re.search(r"(.+?)\s*[.\-·]?\s*Score\s+\d+\b", joined_text, re.IGNORECASE)
+ title_match = re.search(
+ r"(.+?)\s*[.\-·]?\s*Score\s+\d+\b", joined_text, re.IGNORECASE
+ )
title = title_match.group(1).strip() if title_match else "ROBERT Score"
return {
"title": title,
@@ -357,7 +389,9 @@ def _parse_robert_score_block(text: str) -> dict | None:
if not score_match:
return None
- title = re.sub(r"\s*[·\.-]?\s*Score\s+\d+\b.*$", "", title_line, flags=re.IGNORECASE).strip()
+ title = re.sub(
+ r"\s*[·\.-]?\s*Score\s+\d+\b.*$", "", title_line, flags=re.IGNORECASE
+ ).strip()
return {
"title": title,
"score": int(score_match.group(1)),
@@ -412,16 +446,46 @@ def _extract_robert_score_details(pdf_path: Path, model_key: str) -> dict | None
def _robert_score_style(score: int | None) -> dict:
"""Return the visual style associated with a ROBERT score."""
if score is None:
- return {"label": "UNKNOWN", "segments": 0, "fill": (0.92, 0.90, 0.96), "text": (0.35, 0.35, 0.35)}
+ return {
+ "label": "UNKNOWN",
+ "segments": 0,
+ "fill": (0.92, 0.90, 0.96),
+ "text": (0.35, 0.35, 0.35),
+ }
if score <= 0:
- return {"label": "VERY WEAK", "segments": 0, "fill": (1.0, 0.42, 0.42), "text": (1.0, 0.42, 0.42)}
+ return {
+ "label": "VERY WEAK",
+ "segments": 0,
+ "fill": (1.0, 0.42, 0.42),
+ "text": (1.0, 0.42, 0.42),
+ }
if score <= 3:
- return {"label": "VERY WEAK", "segments": score, "fill": (1.0, 0.42, 0.42), "text": (1.0, 0.42, 0.42)}
+ return {
+ "label": "VERY WEAK",
+ "segments": score,
+ "fill": (1.0, 0.42, 0.42),
+ "text": (1.0, 0.42, 0.42),
+ }
if score <= 6:
- return {"label": "WEAK", "segments": score, "fill": (1.0, 0.79, 0.38), "text": (1.0, 0.79, 0.38)}
+ return {
+ "label": "WEAK",
+ "segments": score,
+ "fill": (1.0, 0.79, 0.38),
+ "text": (1.0, 0.79, 0.38),
+ }
if score <= 8:
- return {"label": "MODERATE", "segments": score, "fill": (0.60, 0.78, 0.95), "text": (0.60, 0.78, 0.95)}
- return {"label": "STRONG", "segments": min(score, 10), "fill": (0.38, 0.60, 0.80), "text": (0.38, 0.60, 0.80)}
+ return {
+ "label": "MODERATE",
+ "segments": score,
+ "fill": (0.60, 0.78, 0.95),
+ "text": (0.60, 0.78, 0.95),
+ }
+ return {
+ "label": "STRONG",
+ "segments": min(score, 10),
+ "fill": (0.38, 0.60, 0.80),
+ "text": (0.38, 0.60, 0.80),
+ }
def _render_robert_score_pixmap(details: dict) -> QPixmap | None:
@@ -521,7 +585,11 @@ def extract_extrapolation_scores(pdf_path: Path) -> dict:
for model_key in ("No_PFI", "PFI"):
details = _extract_extrapolation_details(pdf_path, model_key)
- if details and details.get("obtained") is not None and details.get("maximum") is not None:
+ if (
+ details
+ and details.get("obtained") is not None
+ and details.get("maximum") is not None
+ ):
result[model_key] = {
"obtained": details["obtained"],
"maximum": details["maximum"],
@@ -529,6 +597,7 @@ def extract_extrapolation_scores(pdf_path: Path) -> dict:
return result
+
def extract_prediction_info(df: pd.DataFrame) -> dict:
"""Extract prediction information from a DataFrame."""
pred_cols = [col for col in df.columns if col.endswith("_pred")]
@@ -548,18 +617,27 @@ def extract_prediction_info(df: pd.DataFrame) -> dict:
result["has_pred_column"] = True
result["pred_column"] = col
result["n_unique"] = series.nunique()
- result["predictions_identical"] = None if result["n_unique"] == 0 else result["n_unique"] == 1
+ result["predictions_identical"] = (
+ None if result["n_unique"] == 0 else result["n_unique"] == 1
+ )
return result
-def evaluate_model_scenario(score: int | None, predictions_identical: bool | None) -> dict:
+
+def evaluate_model_scenario(
+ score: int | None, predictions_identical: bool | None
+) -> dict:
"""Evaluate the model scenario based on ROBERT score and predictions."""
almos_link = "https://github.com/MiguelMartzFdez/almos"
almos_html = f'ALMOS'
result = {"state": "UNKNOWN", "messages": [], "recommendations": []}
if score is None:
- result["messages"].append("No valid ROBERT score was detected. Model reliability cannot be evaluated.")
- result["recommendations"].append("You may verify that ROBERT_report.pdf was generated correctly.")
+ result["messages"].append(
+ "No valid ROBERT score was detected. Model reliability cannot be evaluated."
+ )
+ result["recommendations"].append(
+ "You may verify that ROBERT_report.pdf was generated correctly."
+ )
return result
if predictions_identical is True:
@@ -576,7 +654,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non
if 0 <= score <= 3:
result["state"] = "FAILED"
- result["messages"].append(f"ROBERT score is {score}. Model performance is critically low.")
+ result["messages"].append(
+ f"ROBERT score is {score}. Model performance is critically low."
+ )
result["recommendations"].append(
"You may avoid using these predictions and rebuild the dataset "
f"using Clustering module in {almos_html}."
@@ -585,7 +665,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non
if 4 <= score <= 6:
result["state"] = "WEAK"
- result["messages"].append(f"ROBERT score is {score}. The model works, but reliability is limited.")
+ result["messages"].append(
+ f"ROBERT score is {score}. The model works, but reliability is limited."
+ )
result["recommendations"].append(
"You may use predictions cautiously and improve robustness "
f"through Active Learning module with {almos_html}."
@@ -594,7 +676,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non
if 7 <= score <= 8:
result["state"] = "DECENT"
- result["messages"].append(f"ROBERT score is {score}. The model is solid but can still improve.")
+ result["messages"].append(
+ f"ROBERT score is {score}. The model is solid but can still improve."
+ )
result["recommendations"].append(
"You may use these predictions while considering further optimization "
f"through Active Learning module with {almos_html}."
@@ -603,7 +687,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non
if score > 8:
result["state"] = "STRONG"
- result["messages"].append(f"ROBERT score is {score}. The model shows strong predictive performance.")
+ result["messages"].append(
+ f"ROBERT score is {score}. The model shows strong predictive performance."
+ )
result["recommendations"].append(
"You may confidently use these predictions for candidate prioritization."
)
@@ -611,7 +697,10 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non
return result
-def evaluate_predictions_for_model(selected_file_path: str | Path, df: pd.DataFrame, model_key: str) -> dict:
+
+def evaluate_predictions_for_model(
+ selected_file_path: str | Path, df: pd.DataFrame, model_key: str
+) -> dict:
"""Evaluate predictions for a specific model."""
pdf_path = get_robert_report_path(selected_file_path)
scores = extract_scores_from_robert_report(pdf_path)
@@ -628,6 +717,7 @@ def evaluate_predictions_for_model(selected_file_path: str | Path, df: pd.DataFr
"scenario": scenario,
}
+
def collect_model_info(selected_file_path: str | Path, df: pd.DataFrame) -> dict:
"""Collect information for all models."""
pdf_path = get_robert_report_path(selected_file_path)
@@ -644,9 +734,19 @@ def collect_model_info(selected_file_path: str | Path, df: pd.DataFrame) -> dict
"scenario": scenario,
}
+
class PredictionDashboardPanel(QWidget):
"""A collapsible dashboard panel to display ROBERT prediction evaluation results and diagnostics."""
- def __init__(self, scenario: dict, pdf_image=None, extrapolation_score=None, extrapolation_image=None, external_plot=None, parent=None):
+
+ def __init__(
+ self,
+ scenario: dict,
+ pdf_image=None,
+ extrapolation_score=None,
+ extrapolation_image=None,
+ external_plot=None,
+ parent=None,
+ ):
super().__init__(parent)
self._pdf_image = pdf_image
self._extrapolation_score = extrapolation_score
@@ -715,7 +815,9 @@ def _build_ui(self, scenario):
self._build_status_block(content_layout, scenario)
self._build_pdf_snapshot_block(content_layout)
- self._build_extrapolation_block(content_layout, self._extrapolation_score, self._extrapolation_image)
+ self._build_extrapolation_block(
+ content_layout, self._extrapolation_score, self._extrapolation_image
+ )
self._build_external_validation_block(content_layout, self._external_plot)
content_layout.addStretch()
@@ -747,7 +849,9 @@ def _build_pdf_snapshot_block(self, layout):
layout.addSpacing(15)
container = QWidget()
container.setObjectName("dashboardBlock")
- container.setStyleSheet("QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }")
+ container.setStyleSheet(
+ "QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }"
+ )
container_layout = QVBoxLayout(container)
container_layout.setContentsMargins(14, 14, 14, 14)
container_layout.setSpacing(10)
@@ -758,13 +862,19 @@ def _build_pdf_snapshot_block(self, layout):
image_frame = QWidget()
image_frame.setObjectName("imageFrame")
- image_frame.setStyleSheet("QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }")
+ image_frame.setStyleSheet(
+ "QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }"
+ )
image_layout = QVBoxLayout(image_frame)
image_layout.setContentsMargins(6, 6, 6, 6)
image_label = QLabel()
image_label.setAlignment(Qt.AlignCenter)
- image_label.setPixmap(self._pdf_image.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation))
+ image_label.setPixmap(
+ self._pdf_image.scaledToWidth(
+ self.expanded_width - 120, Qt.SmoothTransformation
+ )
+ )
image_layout.addWidget(image_label)
container_layout.addWidget(image_frame)
layout.addWidget(container)
@@ -777,7 +887,9 @@ def _build_extrapolation_block(self, layout, score, pixmap):
layout.addSpacing(15)
container = QWidget()
container.setObjectName("dashboardBlock")
- container.setStyleSheet("QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }")
+ container.setStyleSheet(
+ "QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }"
+ )
container_layout = QVBoxLayout(container)
container_layout.setContentsMargins(14, 14, 14, 14)
container_layout.setSpacing(10)
@@ -786,7 +898,9 @@ def _build_extrapolation_block(self, layout, score, pixmap):
title.setStyleSheet("font-weight: bold; font-size: 13px;")
container_layout.addWidget(title)
- subtitle = QLabel("Assessment of the model's ability to predict beyond the range of the training data.")
+ subtitle = QLabel(
+ "Assessment of the model's ability to predict beyond the range of the training data."
+ )
subtitle.setWordWrap(True)
subtitle.setStyleSheet("font-size: 11px;")
container_layout.addWidget(subtitle)
@@ -794,12 +908,16 @@ def _build_extrapolation_block(self, layout, score, pixmap):
if pixmap:
image_frame = QWidget()
image_frame.setObjectName("imageFrame")
- image_frame.setStyleSheet("QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }")
+ image_frame.setStyleSheet(
+ "QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }"
+ )
image_layout = QVBoxLayout(image_frame)
image_layout.setContentsMargins(6, 6, 6, 6)
image_label = QLabel()
image_label.setAlignment(Qt.AlignCenter)
- image_label.setPixmap(pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation))
+ image_label.setPixmap(
+ pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation)
+ )
image_layout.addWidget(image_label)
container_layout.addWidget(image_frame)
@@ -848,7 +966,9 @@ def _build_external_validation_block(self, layout, pixmap):
layout.addSpacing(15)
container = QWidget()
container.setObjectName("dashboardBlock")
- container.setStyleSheet("QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }")
+ container.setStyleSheet(
+ "QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }"
+ )
container_layout = QVBoxLayout(container)
container_layout.setContentsMargins(14, 14, 14, 14)
container_layout.setSpacing(10)
@@ -857,19 +977,25 @@ def _build_external_validation_block(self, layout, pixmap):
title.setStyleSheet("font-weight: bold; font-size: 13px;")
container_layout.addWidget(title)
- subtitle = QLabel("Predicted vs experimental values for molecules with known target data.")
+ subtitle = QLabel(
+ "Predicted vs experimental values for molecules with known target data."
+ )
subtitle.setWordWrap(True)
subtitle.setStyleSheet("font-size: 11px;")
container_layout.addWidget(subtitle)
image_frame = QWidget()
image_frame.setObjectName("imageFrame")
- image_frame.setStyleSheet("QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }")
+ image_frame.setStyleSheet(
+ "QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }"
+ )
image_layout = QVBoxLayout(image_frame)
image_layout.setContentsMargins(6, 6, 6, 6)
image_label = QLabel()
image_label.setAlignment(Qt.AlignCenter)
- image_label.setPixmap(pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation))
+ image_label.setPixmap(
+ pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation)
+ )
image_layout.addWidget(image_label)
container_layout.addWidget(image_frame)
@@ -891,6 +1017,7 @@ def toggle(self):
class PandasTableModel(QAbstractTableModel):
"""A Qt table model that wraps a pandas DataFrame, with special handling for SMILES rendering and sorting optimization."""
+
def __init__(self, df: pd.DataFrame):
super().__init__()
self._df = df
@@ -950,8 +1077,10 @@ def sort_key(self, column: int) -> np.ndarray:
self._sort_cache[column] = col.astype(str).to_numpy()
return self._sort_cache[column]
+
class StatsHeader(QHeaderView):
"""A custom header view that displays column names and basic statistics for numeric columns."""
+
def __init__(self, df: pd.DataFrame, orientation, parent=None):
super().__init__(orientation, parent)
self._df = df
@@ -998,11 +1127,15 @@ def paintSection(self, painter, rect, logical_index):
bold_font.setBold(True)
painter.setFont(bold_font)
fm = QFontMetrics(bold_font)
- name_height = fm.boundingRect(0, 0, r.width(), 1000, Qt.AlignHCenter | Qt.TextWordWrap, col_name).height()
+ name_height = fm.boundingRect(
+ 0, 0, r.width(), 1000, Qt.AlignHCenter | Qt.TextWordWrap, col_name
+ ).height()
name_rect = rect.adjusted(margin, margin, -margin, -margin)
name_rect.setHeight(name_height)
- painter.drawText(name_rect, Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, col_name)
+ painter.drawText(
+ name_rect, Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, col_name
+ )
if stats is not None:
normal_font = painter.font()
@@ -1016,12 +1149,18 @@ def paintSection(self, painter, rect, logical_index):
)
metrics_rect = r
metrics_rect.setTop(name_rect.bottom() + 6)
- painter.drawText(metrics_rect, Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, metrics_text)
+ painter.drawText(
+ metrics_rect,
+ Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap,
+ metrics_text,
+ )
painter.restore()
+
class ColumnStatsWidget(QWidget):
"""A widget that displays basic statistics for numeric columns in a DataFrame."""
+
def __init__(self, df: pd.DataFrame, parent=None):
super().__init__(parent)
layout = QHBoxLayout(self)
@@ -1049,10 +1188,13 @@ def __init__(self, df: pd.DataFrame, parent=None):
class LoadCsvSignals(QObject):
"""Signals for the LoadCsvTask."""
+
done = Signal(str, pd.DataFrame)
+
class LoadCsvTask(QRunnable):
"""A task for loading a CSV file."""
+
def __init__(self, key: str, path: Path):
super().__init__()
self.key = key
@@ -1066,8 +1208,10 @@ def run(self):
except Exception as exc:
print(f"Failed to load CSV {self.path}: {exc}")
+
class NumericSortProxy(QSortFilterProxyModel):
"""A proxy model that optimizes sorting for numeric columns by caching sort keys."""
+
def lessThan(self, left, right):
model = self.sourceModel()
keys = model.sort_key(left.column())
diff --git a/robert/gui_easyrob/utils/utils_gui.py b/robert/gui_easyrob/utils/utils_gui.py
index 5542454..155f8b2 100644
--- a/robert/gui_easyrob/utils/utils_gui.py
+++ b/robert/gui_easyrob/utils/utils_gui.py
@@ -29,36 +29,21 @@
# ------------------------------------------------------------
# Standard library
# ------------------------------------------------------------
-import csv
-import glob
import os
import platform
-import re
import shlex
-import shutil
import subprocess
import sys
import threading
-from functools import partial
-from io import BytesIO
from pathlib import Path
-from importlib.metadata import PackageNotFoundError, version
from importlib.resources import as_file, files
# ------------------------------------------------------------
# Third-party libraries
# ------------------------------------------------------------
import pandas as pd
-import matplotlib.pyplot as plt
import psutil
-import fitz
-import rdkit
-from rdkit import Chem
-from rdkit.Chem import Draw, rdDepictor, rdFMCS
-from rdkit.Chem.Draw import rdMolDraw2D
-from rdkit.Chem.rdmolfiles import MolsFromCDXMLFile
-from rdkit.Chem.rdmolops import GetMolFrags
from ansi2html import Ansi2HTMLConverter
@@ -66,80 +51,32 @@
# Qt (PySide6)
# ------------------------------------------------------------
from PySide6.QtCore import (
- QByteArray,
- QEventLoop,
- QAbstractTableModel,
- QModelIndex,
- QObject,
- QRunnable,
- QRect,
- QSize,
- QSortFilterProxyModel,
QThread,
- QThreadPool,
- QTimer,
Qt,
Signal,
- Slot,
- QUrl,
)
from PySide6.QtGui import (
- QDesktopServices,
- QFontMetrics,
- QIcon,
- QImage,
- QMouseEvent,
- QPalette,
- QPixmap,
QWheelEvent,
)
-from PySide6.QtWebEngineCore import QWebEngineDownloadRequest
-from PySide6.QtWebEngineWidgets import QWebEngineView
from PySide6.QtWidgets import (
- QApplication,
- QCheckBox,
QComboBox,
- QDialog,
QFileDialog,
- QFormLayout,
QFrame,
- QGridLayout,
- QGroupBox,
- QHBoxLayout,
- QHeaderView,
- QInputDialog,
QLabel,
- QLineEdit,
- QListWidget,
- QMainWindow,
- QMenu,
- QMessageBox,
- QProgressBar,
QPushButton,
- QScrollArea,
- QSizePolicy,
- QSlider,
- QStackedWidget,
- QStatusBar,
- QStyle,
- QStyleOptionHeader,
- QTabWidget,
- QTableView,
- QTableWidget,
- QTableWidgetItem,
- QTextEdit,
- QToolButton,
QVBoxLayout,
- QWidget,
)
+
class DropLabel(QFrame):
"""Frame-based drop target with an optional file dialog button."""
- def __init__(self, text, parent=None, file_filter="CSV Files (*.csv)", extensions=(".csv",)):
+ def __init__(
+ self, text, parent=None, file_filter="CSV Files (*.csv)", extensions=(".csv",)
+ ):
super().__init__(parent)
self.file_filter = file_filter
self.valid_extensions = extensions
@@ -186,7 +123,9 @@ def set_file_type(self, file_filter, extensions):
def open_file_dialog(self):
"""Open a file dialog to select a file."""
- file_path, _ = QFileDialog.getOpenFileName(self, "Select File", "", self.file_filter)
+ file_path, _ = QFileDialog.getOpenFileName(
+ self, "Select File", "", self.file_filter
+ )
if file_path and self.callback:
self.set_file_path(file_path)
@@ -220,6 +159,7 @@ def setText(self, text):
"""Set the text of the label."""
self.label.setText(text)
+
class RobertWorker(QThread):
"""QThread that runs a subprocess asynchronously and streams real-time output."""
@@ -250,7 +190,8 @@ def run(self):
text=True,
bufsize=1,
universal_newlines=True,
- creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.CREATE_NO_WINDOW,
+ creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
+ | subprocess.CREATE_NO_WINDOW,
)
else:
self.process = subprocess.Popen(
@@ -270,7 +211,9 @@ def read_stdout():
for line in self.process.stdout:
if self._stop_requested:
break
- formatted_line = self.ansi_converter.convert(line.strip(), full=False)
+ formatted_line = self.ansi_converter.convert(
+ line.strip(), full=False
+ )
self.output_received.emit(formatted_line)
except Exception as exc:
self.error_received.emit(f"Error reading stdout: {exc}")
@@ -281,7 +224,9 @@ def read_stderr():
for line in self.process.stderr:
if self._stop_requested:
break
- formatted_line = f'{line.strip()}'
+ formatted_line = (
+ f'{line.strip()}'
+ )
self.error_received.emit(formatted_line)
reset_line = self.ansi_converter.convert("\033[0m", full=False)
@@ -339,6 +284,7 @@ def _handle_stop(self):
except Exception as exc:
self.error_received.emit(f"Error stopping process: {exc}")
+
def smart_read_csv(filepath):
"""Read a CSV file with automatic delimiter detection."""
try:
@@ -350,6 +296,7 @@ def smart_read_csv(filepath):
except (FileNotFoundError, OSError):
return None
+
class NoScrollComboBox(QComboBox):
"""Combo box that ignores wheel events while the popup is closed."""
@@ -359,6 +306,7 @@ def wheelEvent(self, event: QWheelEvent):
else:
event.ignore()
+
class AssetPath:
"""Resolve asset paths both in development and in frozen distributions."""
@@ -379,6 +327,7 @@ def get_path(self):
)
return as_file(files("robert") / "icons" / self._filename)
+
class AssetLibrary:
"""Central registry of asset files used by the GUI."""
diff --git a/robert/gui_easyrob/version.py b/robert/gui_easyrob/version.py
index 2bf2c8b..1d28c50 100644
--- a/robert/gui_easyrob/version.py
+++ b/robert/gui_easyrob/version.py
@@ -24,12 +24,14 @@
EASYROB_VERSION = "2.0.0"
+
def get_python_package_version(pkg):
try:
return version(pkg)
except PackageNotFoundError:
return "Not found"
+
def get_cli_version(cmd):
try:
result = subprocess.run([cmd, "--version"], capture_output=True, text=True)
@@ -37,6 +39,7 @@ def get_cli_version(cmd):
except Exception:
return "Not found"
+
def get_xtb_version():
try:
result = subprocess.run(["xtb", "--version"], capture_output=True, text=True)
@@ -48,6 +51,7 @@ def get_xtb_version():
except Exception:
return "Not found"
+
def get_software_versions():
return {
"easyROB": EASYROB_VERSION,
@@ -60,4 +64,5 @@ def get_software_versions():
},
}
-SOFTWARE_VERSIONS = get_software_versions()
\ No newline at end of file
+
+SOFTWARE_VERSIONS = get_software_versions()
diff --git a/robert/predict.py b/robert/predict.py
index f2a17d0..55084f6 100644
--- a/robert/predict.py
+++ b/robert/predict.py
@@ -5,15 +5,15 @@
destination : str, default=None,
Directory to create the output file(s).
varfile : str, default=None
- Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
+ Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
params_dir : str, default=''
Folder containing the database and parameters of the ML model.
csv_test : str, default=''
- Name of the CSV file containing the test set (if any). A path can be provided (i.e.
- 'C:/Users/FOLDER/FILE.csv').
+ Name of the CSV file containing the test set (if any). A path can be provided (i.e.
+ 'C:/Users/FOLDER/FILE.csv').
t_value : float, default=2
t-value that will be the threshold to identify outliers (check tables for t-values elsewhere).
- The higher the t-value the more restrictive the analysis will be (i.e. there will be more
+ The higher the t-value the more restrictive the analysis will be (i.e. there will be more
outliers with t-value=1 than with t-value = 4).
alpha : float, default=0.05
Significance level, or probability of making a wrong decision. This parameter is related to
@@ -61,6 +61,7 @@
should_plot_predict_deep_diagnostics,
)
+
class predict:
"""
Class containing all the functions from the PREDICT module.
@@ -72,7 +73,6 @@ class predict:
"""
def __init__(self, **kwargs):
-
start_time = time.time()
# load default and user-specified variables
@@ -85,12 +85,13 @@ def __init__(self, **kwargs):
self.args.params_dir
):
if os.path.exists(params_dir):
-
- _ = print_pfi(self,params_dir)
+ _ = print_pfi(self, params_dir)
# load the Xy databse and model parameters
- Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",True) # module 'verify' since PREDICT follows similar protocols
-
+ Xy_data, model_data, suffix_title = load_db_n_params(
+ self, params_dir, suffix, suffix_title, "verify", True
+ ) # module 'verify' since PREDICT follows similar protocols
+
# get results from training, test and external test (if any)
Xy_data = load_n_predict(self, model_data, Xy_data, BO_opt=False)
if getattr(self.args, "uq_enable_meta", False):
@@ -98,9 +99,7 @@ def __init__(self, **kwargs):
self, Xy_data, model_data, params_dir
)
if getattr(self.args, "uq_auto_enable", False):
- Xy_data = apply_auto_uq(
- self, Xy_data, model_data, params_dir
- )
+ Xy_data = apply_auto_uq(self, Xy_data, model_data, params_dir)
# save predictions for all sets
path_n_suffix, name_points, Xy_data = save_predictions(
@@ -146,4 +145,4 @@ def __init__(self, **kwargs):
self, Xy_data, path_n_suffix, model_data
)
- _ = finish_print(self,start_time,'PREDICT')
+ _ = finish_print(self, start_time, "PREDICT")
diff --git a/robert/predict_utils.py b/robert/predict_utils.py
index 83e66b3..762cd41 100644
--- a/robert/predict_utils.py
+++ b/robert/predict_utils.py
@@ -99,60 +99,91 @@ def _append_split_columns(
def plot_predictions(self, params_dict, Xy_data, path_n_suffix):
- '''
+ """
Plot graphs of predicted vs actual values for train, validation and test sets
- '''
+ """
+
+ set_types = [
+ f"{params_dict['repeat_kfolds']}x {params_dict['kfold']}-fold CV",
+ "test",
+ ]
- set_types = [f"{params_dict['repeat_kfolds']}x {params_dict['kfold']}-fold CV",'test']
-
graph_style = get_graph_style()
-
- self.args.log.write(f"\n o Saving graphs in:")
- if params_dict['type'].lower() == 'reg':
+ self.args.log.write("\n o Saving graphs in:")
+
+ if params_dict["type"].lower() == "reg":
# Plot graph with all sets
- _ = graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style)
+ _ = graph_reg(self, Xy_data, params_dict, set_types, path_n_suffix, graph_style)
# Plot CV average ± SD graph of validation or test set
- _ = graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style,sd_graph=True)
- if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0:
+ _ = graph_reg(
+ self,
+ Xy_data,
+ params_dict,
+ set_types,
+ path_n_suffix,
+ graph_style,
+ sd_graph=True,
+ )
+ if (
+ "y_external" in Xy_data
+ and not Xy_data["y_external"].isnull().values.any()
+ and len(Xy_data["y_external"]) > 0
+ ):
# Plot CV average ± SD graph of external set
- set_type = 'external'
- _ = graph_reg(self,Xy_data,params_dict,set_type,path_n_suffix,graph_style,csv_test=True,sd_graph=True)
+ set_type = "external"
+ _ = graph_reg(
+ self,
+ Xy_data,
+ params_dict,
+ set_type,
+ path_n_suffix,
+ graph_style,
+ csv_test=True,
+ sd_graph=True,
+ )
- elif params_dict['type'].lower() == 'clas':
+ elif params_dict["type"].lower() == "clas":
for set_type in set_types:
- _ = graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix)
- if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0:
- set_type = 'external'
- _ = graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix,csv_test=True)
+ _ = graph_clas(self, Xy_data, params_dict, set_type, path_n_suffix)
+ if (
+ "y_external" in Xy_data
+ and not Xy_data["y_external"].isnull().values.any()
+ and len(Xy_data["y_external"]) > 0
+ ):
+ set_type = "external"
+ _ = graph_clas(
+ self, Xy_data, params_dict, set_type, path_n_suffix, csv_test=True
+ )
return graph_style
-def save_predictions(self,Xy_data,model_data,suffix_title):
- '''
+def save_predictions(self, Xy_data, model_data, suffix_title):
+ """
Saves CSV files with the different sets and their predicted results
- '''
+ """
# Check if we need to reconvert class labels (for classification with string labels)
reconvert_labels = False
class_mapping_reverse = None
- if 'class_0_label' in model_data and 'class_1_label' in model_data:
+ if "class_0_label" in model_data and "class_1_label" in model_data:
reconvert_labels = True
class_mapping_reverse = {
- 0: model_data['class_0_label'],
- 1: model_data['class_1_label']
+ 0: model_data["class_0_label"],
+ 1: model_data["class_1_label"],
}
# save CV and test results as a single df
- Xy_train, Xy_test = pd.DataFrame(Xy_data['names_train']), pd.DataFrame(Xy_data['names_test'])
- for col in Xy_data['X_train']:
- Xy_train[col] = Xy_data['X_train'][col].tolist()
- Xy_test[col] = Xy_data['X_test'][col].tolist()
-
+ Xy_train, Xy_test = (
+ pd.DataFrame(Xy_data["names_train"]),
+ pd.DataFrame(Xy_data["names_test"]),
+ )
+ for col in Xy_data["X_train"]:
+ Xy_train[col] = Xy_data["X_train"][col].tolist()
+ Xy_test[col] = Xy_data["X_test"][col].tolist()
+
# Store y values and predictions, reconverting if needed
- y_col = model_data['y']
-
hw_scalar = float(Xy_data.get("conformal_half_width", float("nan")))
if model_data["type"].lower() != "reg":
hw_scalar = float("nan")
@@ -182,19 +213,19 @@ def save_predictions(self,Xy_data,model_data,suffix_title):
df_results = pd.concat([Xy_train, Xy_test], axis=0)
# add column with sets
- train_list = ['CV' for _ in Xy_data['y_train']]
- test_list = ['Test' for _ in Xy_data['y_test']]
+ train_list = ["CV" for _ in Xy_data["y_train"]]
+ test_list = ["Test" for _ in Xy_data["y_test"]]
col_set = train_list + test_list
- df_results['Set'] = col_set
+ df_results["Set"] = col_set
# save results as CSV
base_csv_name = f"PREDICT/{model_data['model']}_{suffix_title}"
base_csv_path = f"{Path(os.getcwd()).joinpath(base_csv_name)}"
- path_n_suffix = f'{base_csv_path}'
- _ = df_results.to_csv(f'{base_csv_path}.csv', index = None, header=True)
-
+ path_n_suffix = f"{base_csv_path}"
+ _ = df_results.to_csv(f"{base_csv_path}.csv", index=None, header=True)
+
# also save results for performance of individual folds (useful for t-tests and Wilcoxon tests between the folds)
- error1, error2, error3 = get_error_labels(model_data['type'])
+ error1, error2, error3 = get_error_labels(model_data["type"])
# df_folds = pd.DataFrame()
# df_folds['Fold'] = [f'{i+1}' for i in range(len(Xy_data['idx_valid']))]
@@ -210,12 +241,14 @@ def save_predictions(self,Xy_data,model_data,suffix_title):
# _ = df_folds.to_csv(f'{path_folds}.csv', index = None, header=True)
# prints
- print_preds = f' o Saving CSV databases with predictions and their SD in:'
- print_preds += f'\n - Predicted results of starting dataset: {base_csv_name}.csv'
+ print_preds = " o Saving CSV databases with predictions and their SD in:"
+ print_preds += (
+ f"\n - Predicted results of starting dataset: {base_csv_name}.csv"
+ )
- if self.args.csv_test != '':
+ if self.args.csv_test != "":
# saves prediction for external test in --csv_test
- Xy_external = pd.DataFrame(Xy_data['names_external'])
+ Xy_external = pd.DataFrame(Xy_data["names_external"])
for col in Xy_data["X_external"]:
Xy_external[col] = Xy_data["X_external"][col].tolist()
@@ -241,30 +274,36 @@ def save_predictions(self,Xy_data,model_data,suffix_title):
Xy_external
)
if auto_src is not None:
- Xy_external[f"{model_data['y']}_pred_uq_auto_source"] = [
- auto_src
- ] * len(Xy_external)
+ Xy_external[f"{model_data['y']}_pred_uq_auto_source"] = [auto_src] * len(
+ Xy_external
+ )
- path_external = Path(os.getcwd()).joinpath('PREDICT/csv_test/')
+ path_external = Path(os.getcwd()).joinpath("PREDICT/csv_test/")
Path(path_external).mkdir(exist_ok=True, parents=True)
- csv_name_external = f'{os.path.basename(self.args.csv_test).split(".csv")[0]}_{model_data["model"]}_{suffix_title}.csv'
+ csv_name_external = f"{os.path.basename(self.args.csv_test).split('.csv')[0]}_{model_data['model']}_{suffix_title}.csv"
name_external = f"{path_external}/{csv_name_external}"
- _ = Xy_external.to_csv(name_external, index = None, header=True)
- print_preds += f'\n - External set with predicted results: PREDICT/csv_test/{csv_name_external}'
+ _ = Xy_external.to_csv(name_external, index=None, header=True)
+ print_preds += f"\n - External set with predicted results: PREDICT/csv_test/{csv_name_external}"
self.args.log.write(print_preds)
# store the names of the datapoints
name_points = {}
- if model_data['names'] != '':
- if model_data['names'].lower() in Xy_train: # accounts for upper/lowercase mismatches
- model_data['names'] = model_data['names'].lower()
- if model_data['names'].upper() in Xy_train:
- model_data['names'] = model_data['names'].upper()
- if model_data['names'] in Xy_train:
- name_points['train'] = df_results[model_data['names']][df_results.Set == 'CV']
- name_points['test'] = df_results[model_data['names']][df_results.Set == 'Test']
+ if model_data["names"] != "":
+ if (
+ model_data["names"].lower() in Xy_train
+ ): # accounts for upper/lowercase mismatches
+ model_data["names"] = model_data["names"].lower()
+ if model_data["names"].upper() in Xy_train:
+ model_data["names"] = model_data["names"].upper()
+ if model_data["names"] in Xy_train:
+ name_points["train"] = df_results[model_data["names"]][
+ df_results.Set == "CV"
+ ]
+ name_points["test"] = df_results[model_data["names"]][
+ df_results.Set == "Test"
+ ]
return path_n_suffix, name_points, Xy_data
@@ -280,20 +319,17 @@ def _ensure_pred_range_stats(Xy_data):
Xy_data["pred_range"] = float(np.abs(pred_max - pred_min))
-def print_predict(self,Xy_data,model_data,suffix_title):
- '''
+def print_predict(self, Xy_data, model_data, suffix_title):
+ """
Prints results of the predictions for all the sets
- '''
+ """
_ensure_pred_range_stats(Xy_data)
- print_results = (
- "\n o Summary of results "
- f"{model_data['model']}_{suffix_title}:"
- )
+ print_results = f"\n o Summary of results {model_data['model']}_{suffix_title}:"
# get number of points and proportions
- n_train = len(Xy_data['y_train'])
- n_test = len(Xy_data['y_test'])
+ n_train = len(Xy_data["y_train"])
+ n_test = len(Xy_data["y_test"])
print_results += (
"\n - Point counts: CV (train+valid.) = "
f"{n_train}, held-out test = {n_test}"
@@ -303,11 +339,10 @@ def print_predict(self,Xy_data,model_data,suffix_title):
prop_train = round(n_train * 100 / total_points)
prop_test = round(n_test * 100 / total_points)
print_results += (
- f"\n - Proportion CV (train+valid.):test = "
- f"{prop_train}:{prop_test}"
+ f"\n - Proportion CV (train+valid.):test = {prop_train}:{prop_test}"
)
- n_descps = len(Xy_data['X_train'].keys())
+ n_descps = len(Xy_data["X_train"].keys())
print_results += f"\n - Number of descriptors = {n_descps}"
print_results += (
"\n - Proportion (train+valid.) points:descriptors = "
@@ -316,55 +351,66 @@ def print_predict(self,Xy_data,model_data,suffix_title):
# print results and save dat file
CV_type = f"{model_data['repeat_kfolds']}x {model_data['kfold']}-fold CV"
- if model_data['type'].lower() == 'reg':
+ if model_data["type"].lower() == "reg":
print_results += f"\n - {CV_type} : R2 = {Xy_data['r2_train']:.2}, MAE = {Xy_data['mae_train']:.2}, RMSE = {Xy_data['rmse_train']:.2}"
print_results += f"\n - Test : R2 = {Xy_data['r2_test']:.2}, MAE = {Xy_data['mae_test']:.2}, RMSE = {Xy_data['rmse_test']:.2}"
print_results += f"\n - Average SD in test set = {np.mean(Xy_data['y_pred_test_sd']):.2}"
print_results += f"\n - y range of dataset (train+valid.) = {float(Xy_data['pred_min']):.2} to {float(Xy_data['pred_max']):.2}, total {float(Xy_data['pred_range']):.2}"
- if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0:
+ if (
+ "y_external" in Xy_data
+ and not Xy_data["y_external"].isnull().values.any()
+ and len(Xy_data["y_external"]) > 0
+ ):
print_results += f"\n - External test : R2 = {Xy_data['r2_external']:.2}, MAE = {Xy_data['mae_external']:.2}, RMSE = {Xy_data['rmse_external']:.2}"
- elif model_data['type'].lower() == 'clas':
+ elif model_data["type"].lower() == "clas":
print_results += f"\n - {CV_type} : Accur. = {Xy_data['acc_train']:.2}, F1 score = {Xy_data['f1_train']:.2}, MCC = {Xy_data['mcc_train']:.2}"
- if 'y_pred_test' in Xy_data and not Xy_data['y_test'].isnull().values.any() and len(Xy_data['y_test']) > 0:
+ if (
+ "y_pred_test" in Xy_data
+ and not Xy_data["y_test"].isnull().values.any()
+ and len(Xy_data["y_test"]) > 0
+ ):
print_results += f"\n - Test : Accur. = {Xy_data['acc_test']:.2}, F1 score = {Xy_data['f1_test']:.2}, MCC = {Xy_data['mcc_test']:.2}"
- if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0:
+ if (
+ "y_external" in Xy_data
+ and not Xy_data["y_external"].isnull().values.any()
+ and len(Xy_data["y_external"]) > 0
+ ):
print_results += f"\n - External test : Accur. = {Xy_data['acc_external']:.2}, F1 score = {Xy_data['f1_external']:.2}, MCC = {Xy_data['mcc_external']:.2}"
self.args.log.write(print_results)
-def pearson_map_predict(self,Xy_data,params_dir):
- '''
+def pearson_map_predict(self, Xy_data, params_dir):
+ """
Plots the Pearson map and analyzes correlation of descriptors.
- '''
+ """
- X_combined = pd.concat([Xy_data['X_train'], Xy_data['X_test']], axis=0, ignore_index=True)
- corr_matrix = pearson_map(self,X_combined,'predict',params_dir=params_dir)
+ X_combined = pd.concat(
+ [Xy_data["X_train"], Xy_data["X_test"]], axis=0, ignore_index=True
+ )
+ corr_matrix = pearson_map(self, X_combined, "predict", params_dir=params_dir)
- corr_dict = {'descp_1': [],
- 'descp_2': [],
- 'r': []
- }
- for i,descp in enumerate(corr_matrix.columns):
- for j,val in enumerate(corr_matrix[descp]):
+ corr_dict = {"descp_1": [], "descp_2": [], "r": []}
+ for i, descp in enumerate(corr_matrix.columns):
+ for j, val in enumerate(corr_matrix[descp]):
if i < j and np.abs(val) > 0.8:
- corr_dict['descp_1'].append(corr_matrix.columns[i])
- corr_dict['descp_2'].append(corr_matrix.columns[j])
- corr_dict['r'].append(val)
+ corr_dict["descp_1"].append(corr_matrix.columns[i])
+ corr_dict["descp_2"].append(corr_matrix.columns[j])
+ corr_dict["r"].append(val)
- print_corr = f' Ideally, variables should show low correlations.' # no initial \n, it's a new log.write
- if len(corr_dict['descp_1']) == 0:
- print_corr += f"\n o Correlations between variables are acceptable"
+ print_corr = " Ideally, variables should show low correlations." # no initial \n, it's a new log.write
+ if len(corr_dict["descp_1"]) == 0:
+ print_corr += "\n o Correlations between variables are acceptable"
else:
- abs_r_list = list(np.abs(corr_dict['r']))
+ abs_r_list = list(np.abs(corr_dict["r"]))
abs_max_r = max(abs_r_list)
- max_r = corr_dict['r'][abs_r_list.index(abs_max_r)]
- max_descp_1 = corr_dict['descp_1'][abs_r_list.index(abs_max_r)]
- max_descp_2 = corr_dict['descp_2'][abs_r_list.index(abs_max_r)]
+ max_r = corr_dict["r"][abs_r_list.index(abs_max_r)]
+ max_descp_1 = corr_dict["descp_1"][abs_r_list.index(abs_max_r)]
+ max_descp_2 = corr_dict["descp_2"][abs_r_list.index(abs_max_r)]
if abs_max_r > 0.84:
- print_corr += f"\n x WARNING! High correlations observed (up to r = {max_r:.2} or R2 = {max_r*max_r:.2}, for {max_descp_1} and {max_descp_2})"
+ print_corr += f"\n x WARNING! High correlations observed (up to r = {max_r:.2} or R2 = {max_r * max_r:.2}, for {max_descp_1} and {max_descp_2})"
elif abs_max_r > 0.71:
- print_corr += f"\n x WARNING! Noticeable correlations observed (up to r = {max_r:.2} or R2 = {max_r*max_r:.2}, for {max_descp_1} and {max_descp_2})"
+ print_corr += f"\n x WARNING! Noticeable correlations observed (up to r = {max_r:.2} or R2 = {max_r * max_r:.2}, for {max_descp_1} and {max_descp_2})"
self.args.log.write(print_corr)
diff --git a/robert/report.py b/robert/report.py
index a7431f0..e995653 100644
--- a/robert/report.py
+++ b/robert/report.py
@@ -5,7 +5,7 @@
destination : str, default=None,
Directory to create the output file(s).
varfile : str, default=None
- Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
+ Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
report_modules : list of str, default=['AQME','CURATE','GENERATE','VERIFY','PREDICT']
List of the modules to include in the report.
debug_report : bool, default=False
@@ -23,9 +23,9 @@
import json
import platform
import pandas as pd
-import traceback
from pathlib import Path
-from robert.utils import (load_variables,
+from robert.utils import (
+ load_variables,
pd_to_dict,
)
from robert.report_utils import (
@@ -52,7 +52,7 @@
get_outliers,
detect_predictions,
get_csv_metrics,
- get_csv_pred
+ get_csv_pred,
)
@@ -69,23 +69,26 @@ class report:
def __init__(self, **kwargs):
# check if there is a problem with weasyprint (required for this module)
# Suppress fontconfig warnings during import on Windows
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
import tempfile
- temp_stderr = tempfile.TemporaryFile(mode='w+')
+
+ temp_stderr = tempfile.TemporaryFile(mode="w+")
old_stderr = os.dup(2)
os.dup2(temp_stderr.fileno(), 2)
-
+
try:
from weasyprint import HTML
except (OSError, ModuleNotFoundError):
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
os.dup2(old_stderr, 2)
os.close(old_stderr)
temp_stderr.close()
- print(f"\nx The REPORT module requires some libraries that are missing, the PDF with the summary of the results has not been created. Try installing the libraries with 'conda install -y -c conda-forge glib gtk3 pango mscorefonts'")
+ print(
+ "\nx The REPORT module requires some libraries that are missing, the PDF with the summary of the results has not been created. Try installing the libraries with 'conda install -y -c conda-forge glib gtk3 pango mscorefonts'"
+ )
sys.exit()
finally:
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
os.dup2(old_stderr, 2)
os.close(old_stderr)
temp_stderr.close()
@@ -95,46 +98,56 @@ def __init__(self, **kwargs):
eval_only = False
# if EVALUATE is activated, no PFI models are generated
- path_eval = Path(f'{os.getcwd()}/EVALUATE/EVALUATE_data.dat')
+ path_eval = Path(f"{os.getcwd()}/EVALUATE/EVALUATE_data.dat")
if os.path.exists(path_eval):
eval_only = True
# get spacing between No PFI and PFI columns
- spacing_PFI = f'{(" ")*4}'
+ spacing_PFI = f"{(' ') * 4}"
# Reproducibility section (these functions only gather information, the sections
# will be print later in the report)
- citation_dat, repro_dat, dat_files, csv_name, robert_version = self.get_repro(eval_only)
+ citation_dat, repro_dat, dat_files, csv_name, robert_version = self.get_repro(
+ eval_only
+ )
- # Transparency section
- transpa_dat,params_df = self.get_transparency(spacing_PFI)
- pred_type = params_df['type'][0].lower()
+ # Transparency section
+ transpa_dat, params_df = self.get_transparency(spacing_PFI)
+ pred_type = params_df["type"][0].lower()
# print header
report_html = self.print_header(citation_dat)
# print ROBERT score section
- score_dat,data_score = self.print_score(dat_files,pred_type,eval_only,spacing_PFI)
+ score_dat, data_score = self.print_score(
+ dat_files, pred_type, eval_only, spacing_PFI
+ )
report_html += score_dat
# print warnings in ROBERT score section
- warnings_dat,warnings_dict = self.print_warnings(pred_type,eval_only,data_score)
+ warnings_dat, warnings_dict = self.print_warnings(
+ pred_type, eval_only, data_score
+ )
report_html += warnings_dat
# print advanced score analysis
- report_html += self.print_adv_anal(pred_type,eval_only,spacing_PFI,data_score)
+ report_html += self.print_adv_anal(
+ pred_type, eval_only, spacing_PFI, data_score
+ )
# print y distribution
- report_html += self.print_y_distrib(pred_type,eval_only,spacing_PFI,warnings_dict)
+ report_html += self.print_y_distrib(
+ pred_type, eval_only, spacing_PFI, warnings_dict
+ )
# print feature importances
- report_html += self.print_features(warnings_dict,eval_only,spacing_PFI)
+ report_html += self.print_features(warnings_dict, eval_only, spacing_PFI)
# print outlier analysis
- report_html += self.print_outliers(pred_type,eval_only,spacing_PFI)
+ report_html += self.print_outliers(pred_type, eval_only, spacing_PFI)
# print model screening
- report_html += self.print_generate(pred_type,eval_only)
+ report_html += self.print_generate(pred_type, eval_only)
# print reproducibility section
report_html += repro_dat
@@ -146,7 +159,7 @@ def __init__(self, **kwargs):
report_html += self.get_abbrev()
# print new predictions
- report_html += self.print_predictions(pred_type,eval_only,spacing_PFI)
+ report_html += self.print_predictions(pred_type, eval_only, spacing_PFI)
# print miscellaneous section
report_html += self.print_misc()
@@ -157,34 +170,35 @@ def __init__(self, **kwargs):
# create css
with open("report.css", "w", encoding="utf-8") as cssfile:
- cssfile.write(css_content(csv_name,robert_version))
+ cssfile.write(css_content(csv_name, robert_version))
# Suppress fontconfig warnings from WeasyPrint on Windows
# These warnings come from the C library level, so we need to redirect at OS level
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
import tempfile
-
+
# Create a temporary file to redirect stderr
- temp_stderr = tempfile.TemporaryFile(mode='w+')
+ temp_stderr = tempfile.TemporaryFile(mode="w+")
old_stderr = os.dup(2) # Duplicate stderr file descriptor
os.dup2(temp_stderr.fileno(), 2) # Redirect stderr to temp file
-
+
try:
- _ = make_report(report_html,HTML)
+ _ = make_report(report_html, HTML)
finally:
os.dup2(old_stderr, 2) # Restore stderr
os.close(old_stderr)
temp_stderr.close()
else:
- _ = make_report(report_html,HTML)
+ _ = make_report(report_html, HTML)
# Remove report.css file
os.remove("report.css")
-
- print('\no ROBERT_report.pdf was created successfully in the working directory!')
+ print(
+ "\no ROBERT_report.pdf was created successfully in the working directory!"
+ )
- def print_header(self,citation_dat):
+ def print_header(self, citation_dat):
"""
Retrieves the header for the HTML string
"""
@@ -200,154 +214,172 @@ def print_header(self,citation_dat):
return header_lines
-
- def print_score(self,dat_files,pred_type,eval_only,spacing_PFI):
+ def print_score(self, dat_files, pred_type, eval_only, spacing_PFI):
"""
Generates the ROBERT score section
"""
-
+
# starts with the icon of ROBERT score
- score_dat = ''
- score_dat = self.module_lines('score',score_dat)
+ score_dat = ""
+ score_dat = self.module_lines("score", score_dat)
# calculates the ROBERT scores (R2 is analogous for accuracy in classification)
data_score = {}
- columns_score,columns_summary = [],[]
+ columns_score, columns_summary = [], []
# get two columns to combine and print
- for suffix in ['No PFI','PFI']:
- spacing = get_spacing_col(suffix,spacing_PFI)
+ for suffix in ["No PFI", "PFI"]:
+ spacing = get_spacing_col(suffix, spacing_PFI)
- if eval_only and suffix == 'PFI':
- columns_score.append('')
+ if eval_only and suffix == "PFI":
+ columns_score.append("")
else:
# calculate score
- data_score = calc_score(dat_files,suffix,pred_type,data_score)
+ data_score = calc_score(dat_files, suffix, pred_type, data_score)
# initial two-column ROBERT score summary
- score_info = f"""{spacing}![]()
| ''' + | """
# add severe warnings
- warning_print += f'''
- {space}Severe warnings ''' - if len(warnings_dict[f'severe_warnings_{suffix}']) == 0: + warning_print += f""" +{space}Severe warnings """ + if len(warnings_dict[f"severe_warnings_{suffix}"]) == 0: warning_print += self.print_line_warning( - 'No severe warnings detected', - style_lines,color_dict['blue'],space) + "No severe warnings detected", + style_lines, + color_dict["blue"], + space, + ) else: - for sev_warning in warnings_dict[f'severe_warnings_{suffix}']: + for sev_warning in warnings_dict[f"severe_warnings_{suffix}"]: warning_print += self.print_line_warning( - sev_warning, - style_lines,color_dict['red'],space) + sev_warning, style_lines, color_dict["red"], space + ) # add moderate warnings - warning_print += f''' -{space}Moderate warnings ''' - if len(warnings_dict[f'moderate_warnings_{suffix}']) == 0: + warning_print += f""" +{space}Moderate warnings """ + if len(warnings_dict[f"moderate_warnings_{suffix}"]) == 0: warning_print += self.print_line_warning( - 'No moderate warnings detected', - style_lines,color_dict['blue'],space) + "No moderate warnings detected", + style_lines, + color_dict["blue"], + space, + ) else: - for mode_warning in warnings_dict[f'moderate_warnings_{suffix}']: + for mode_warning in warnings_dict[f"moderate_warnings_{suffix}"]: warning_print += self.print_line_warning( - mode_warning, - style_lines,color_dict['yellow'],space) + mode_warning, style_lines, color_dict["yellow"], space + ) # add overall assessment - warning_print += self.print_assessment(space,suffix,data_score,style_lines,warnings_dict,color_dict,pred_type) - + warning_print += self.print_assessment( + space, + suffix, + data_score, + style_lines, + warnings_dict, + color_dict, + pred_type, + ) + # end table - warning_print += f''' |
-
' # table style - warnings_dat = ''' - ''' + """ - return space,color_dict,style_lines,warnings_dat + return space, color_dict, style_lines, warnings_dat - def analyze_warnings(self,data_score,suffix,warnings_dict,pred_type): - ''' + def analyze_warnings(self, data_score, suffix, warnings_dict, pred_type): + """ Analyze and append warnings - ''' - + """ + # tests from flawed models - if data_score[f'flawed_mod_score_{suffix}'] < 0: - if data_score[f'failed_tests_{suffix}'] > 0: - warnings_dict[f'severe_warnings_{suffix}'].append('Failing required tests (Section B.1)') + if data_score[f"flawed_mod_score_{suffix}"] < 0: + if data_score[f"failed_tests_{suffix}"] > 0: + warnings_dict[f"severe_warnings_{suffix}"].append( + "Failing required tests (Section B.1)" + ) else: - warnings_dict[f'moderate_warnings_{suffix}'].append('Some tests are unclear (Section B.1)') + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Some tests are unclear (Section B.1)" + ) # variation in CV - if pred_type == 'reg': - if data_score[f'cv_sd_score_{suffix}'] == 0: - warnings_dict[f'moderate_warnings_{suffix}'].append('Imprecise predictions (Section B.3b)') - elif pred_type == 'clas': - if data_score[f'diff_mcc_score_{suffix}'] == 0: - warnings_dict[f'moderate_warnings_{suffix}'].append('Imprecise predictions (Section B.3b)') + if pred_type == "reg": + if data_score[f"cv_sd_score_{suffix}"] == 0: + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Imprecise predictions (Section B.3b)" + ) + elif pred_type == "clas": + if data_score[f"diff_mcc_score_{suffix}"] == 0: + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Imprecise predictions (Section B.3b)" + ) # y distribution - if 'WARNING! Your data is not uniform' in warnings_dict[f'y_dist_info_{suffix}']: - if pred_type == 'reg': - warnings_dict[f'moderate_warnings_{suffix}'].append('Uneven y distribution (Section C)') - elif pred_type == 'clas': # it's severe in clasification - warnings_dict[f'severe_warnings_{suffix}'].append('Very uneven class distribution (Section C)') - elif 'WARNING! Your data is slightly not uniform' in warnings_dict[f'y_dist_info_{suffix}']: - if pred_type == 'reg': - warnings_dict[f'moderate_warnings_{suffix}'].append('Slightly uneven y distribution (Section C)') - elif pred_type == 'clas': - warnings_dict[f'moderate_warnings_{suffix}'].append('Uneven class distribution (Section C)') + if ( + "WARNING! Your data is not uniform" + in warnings_dict[f"y_dist_info_{suffix}"] + ): + if pred_type == "reg": + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Uneven y distribution (Section C)" + ) + elif pred_type == "clas": # it's severe in clasification + warnings_dict[f"severe_warnings_{suffix}"].append( + "Very uneven class distribution (Section C)" + ) + elif ( + "WARNING! Your data is slightly not uniform" + in warnings_dict[f"y_dist_info_{suffix}"] + ): + if pred_type == "reg": + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Slightly uneven y distribution (Section C)" + ) + elif pred_type == "clas": + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Uneven class distribution (Section C)" + ) # feature correlation - if 'WARNING! High correlations' in warnings_dict[f'pearson_info_{suffix}']: - warnings_dict[f'moderate_warnings_{suffix}'].append('Highly correlated features (Section D)') - elif 'WARNING! Noticeable correlations' in warnings_dict[f'pearson_info_{suffix}']: - warnings_dict[f'moderate_warnings_{suffix}'].append('Moderately correlated features (Section D)') + if "WARNING! High correlations" in warnings_dict[f"pearson_info_{suffix}"]: + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Highly correlated features (Section D)" + ) + elif ( + "WARNING! Noticeable correlations" + in warnings_dict[f"pearson_info_{suffix}"] + ): + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Moderately correlated features (Section D)" + ) # outliers (threshold is set above 6.5 SD, around 99.9 CI) - if pred_type == 'reg': - if warnings_dict[f'max_sd_{suffix}'] > 6.5: - warnings_dict[f'moderate_warnings_{suffix}'].append('Potential "faulty" outliers (Section E)') + if pred_type == "reg": + if warnings_dict[f"max_sd_{suffix}"] > 6.5: + warnings_dict[f"moderate_warnings_{suffix}"].append( + 'Potential "faulty" outliers (Section E)' + ) return warnings_dict - - def get_warning_lines(self,pred_type): - ''' + def get_warning_lines(self, pred_type): + """ Gather the lines from PREDICT where the potential warnings are print - ''' - + """ + warnings_dict = {} # get lines with warnings from PREDICT - file_pred = f'{os.getcwd()}/PREDICT/PREDICT_data.dat' - with open(file_pred, 'r', encoding='utf-8') as datfile: + file_pred = f"{os.getcwd()}/PREDICT/PREDICT_data.dat" + with open(file_pred, "r", encoding="utf-8") as datfile: lines = datfile.readlines() - pfi_section_pearson = False # to get both No PFI and PFI information + pfi_section_pearson = False # to get both No PFI and PFI information pfi_section_y_dist = False pfi_section_outlier = False - for i,line in enumerate(lines): - if 'Ideally, variables should show low' in line and not pfi_section_pearson: - warnings_dict['pearson_info_No PFI'] = lines[i+1][6:] - pfi_section_pearson = True # the next line found will correspond to the PFI section - elif 'Ideally, variables should show low' in line and pfi_section_pearson: - warnings_dict['pearson_info_PFI'] = lines[i+1][6:] - if 'Ideally, the number of datapoints in' in line and not pfi_section_y_dist: - warnings_dict['y_dist_info_No PFI'] = lines[i+2][6:] + for i, line in enumerate(lines): + if ( + "Ideally, variables should show low" in line + and not pfi_section_pearson + ): + warnings_dict["pearson_info_No PFI"] = lines[i + 1][6:] + pfi_section_pearson = ( + True # the next line found will correspond to the PFI section + ) + elif ( + "Ideally, variables should show low" in line and pfi_section_pearson + ): + warnings_dict["pearson_info_PFI"] = lines[i + 1][6:] + if ( + "Ideally, the number of datapoints in" in line + and not pfi_section_y_dist + ): + warnings_dict["y_dist_info_No PFI"] = lines[i + 2][6:] pfi_section_y_dist = True - elif 'Ideally, the number of datapoints in' in line and pfi_section_y_dist: - warnings_dict['y_dist_info_PFI'] = lines[i+2][6:] - if pred_type == 'reg': - if 'Outliers plot saved' in line and not pfi_section_outlier: + elif ( + "Ideally, the number of datapoints in" in line + and pfi_section_y_dist + ): + warnings_dict["y_dist_info_PFI"] = lines[i + 2][6:] + if pred_type == "reg": + if "Outliers plot saved" in line and not pfi_section_outlier: max_SD = 0 - for j in range(i,len(lines)): - if '-------' in lines[j]: + for j in range(i, len(lines)): + if "-------" in lines[j]: break - elif 'SDs' in lines[j]: + elif "SDs" in lines[j]: sd_line = float(lines[j].split()[2][1:]) if sd_line > max_SD: max_SD = sd_line - warnings_dict['max_sd_No PFI'] = max_SD + warnings_dict["max_sd_No PFI"] = max_SD pfi_section_outlier = True - elif 'Outliers plot saved' in line and pfi_section_outlier: + elif "Outliers plot saved" in line and pfi_section_outlier: max_SD = 0 - for j in range(i,len(lines)): - if '-------' in lines[j]: + for j in range(i, len(lines)): + if "-------" in lines[j]: break - elif 'SDs' in lines[j]: + elif "SDs" in lines[j]: sd_line = float(lines[j].split()[2][1:]) if sd_line > max_SD: max_SD = sd_line - warnings_dict['max_sd_PFI'] = max_SD + warnings_dict["max_sd_PFI"] = max_SD return warnings_dict - - def print_line_warning(self,message,style_lines,color,space): - ''' + def print_line_warning(self, message, style_lines, color, space): + """ Add line with warning - ''' - - return f''' - {style_lines}{space}◉ - {space}{message}
''' - + """ - def print_assessment(self,space,suffix,data_score,style_lines,warnings_dict,color_dict,pred_type): - ''' + return f""" + {style_lines}{space}◉ + {space}{message}""" + + def print_assessment( + self, + space, + suffix, + data_score, + style_lines, + warnings_dict, + color_dict, + pred_type, + ): + """ Add overall assessment to the ROBERT score section - ''' - - assessment_print = f''' -{space}Overall assessment
''' + """ - if len(warnings_dict[f'severe_warnings_{suffix}']) > 0 or data_score[f'robert_score_{suffix}'] < 5: - assessment_print += self.print_line_warning( - 'The model is unreliable', - style_lines,color_dict['red'],space) + assessment_print = f""" +{space}Overall assessment
""" - elif data_score[f'robert_score_{suffix}'] in [9,10]: - if pred_type == 'reg' and len(warnings_dict[f'moderate_warnings_{suffix}']) >= 3: + if ( + len(warnings_dict[f"severe_warnings_{suffix}"]) > 0 + or data_score[f"robert_score_{suffix}"] < 5 + ): + assessment_print += self.print_line_warning( + "The model is unreliable", style_lines, color_dict["red"], space + ) + + elif data_score[f"robert_score_{suffix}"] in [9, 10]: + if ( + pred_type == "reg" + and len(warnings_dict[f"moderate_warnings_{suffix}"]) >= 3 + ): assessment_print += self.print_line_warning( - 'Reliable model, but examine warnings', - style_lines,color_dict['yellow'],space) - elif pred_type == 'clas' and len(warnings_dict[f'moderate_warnings_{suffix}']) >= 2: + "Reliable model, but examine warnings", + style_lines, + color_dict["yellow"], + space, + ) + elif ( + pred_type == "clas" + and len(warnings_dict[f"moderate_warnings_{suffix}"]) >= 2 + ): assessment_print += self.print_line_warning( - 'Reliable model, but examine warnings', - style_lines,color_dict['yellow'],space) + "Reliable model, but examine warnings", + style_lines, + color_dict["yellow"], + space, + ) else: assessment_print += self.print_line_warning( - f'The model seems reliable', - style_lines,color_dict['blue'],space) + "The model seems reliable", style_lines, color_dict["blue"], space + ) - elif data_score[f'robert_score_{suffix}'] in [7,8]: + elif data_score[f"robert_score_{suffix}"] in [7, 8]: assessment_print += self.print_line_warning( - 'Decent model, but it has limitations', - style_lines,color_dict['yellow'],space) + "Decent model, but it has limitations", + style_lines, + color_dict["yellow"], + space, + ) - elif data_score[f'robert_score_{suffix}'] in [5,6]: + elif data_score[f"robert_score_{suffix}"] in [5, 6]: assessment_print += self.print_line_warning( - 'Moderate model, with important limitations', - style_lines,color_dict['yellow'],space) - - return assessment_print + "Moderate model, with important limitations", + style_lines, + color_dict["yellow"], + space, + ) + return assessment_print - def print_adv_anal(self,pred_type,eval_only,spacing_PFI,data_score): + def print_adv_anal(self, pred_type, eval_only, spacing_PFI, data_score): """ Generates the advanced score analysis section """ - adv_score_dat = '' + adv_score_dat = "" - adv_score_dat += self.module_lines('adv_anal',adv_score_dat) + adv_score_dat += self.module_lines("adv_anal", adv_score_dat) # parts of the robert score section - score_sections = ['adv_flawed'] - score_sections.append('adv_flawed_extra') - score_sections.append('adv_predict') - score_sections.append('adv_test') - score_sections.append('adv_diff_test') - score_sections.append('adv_cv_sd') - score_sections.append('adv_cv_diff') - score_sections.append('adv_sorted_cv') + score_sections = ["adv_flawed"] + score_sections.append("adv_flawed_extra") + score_sections.append("adv_predict") + score_sections.append("adv_test") + score_sections.append("adv_diff_test") + score_sections.append("adv_cv_sd") + score_sections.append("adv_cv_diff") + score_sections.append("adv_sorted_cv") for section in score_sections: columns_score = [] # get two columns to combine and print - for suffix in ['No PFI','PFI']: - + for suffix in ["No PFI", "PFI"]: # add spacing of PFI column - if suffix == 'No PFI': - spacing = '' - elif suffix == 'PFI': + if suffix == "No PFI": + spacing = "" + elif suffix == "PFI": spacing = spacing_PFI - if eval_only and suffix == 'PFI': - columns_score.append('') + if eval_only and suffix == "PFI": + columns_score.append("") else: - - if section == 'adv_flawed': + if section == "adv_flawed": # advanced score analysis 1, flawed models - columns_score.append(adv_flawed(self,suffix,data_score,spacing*2)) + columns_score.append( + adv_flawed(self, suffix, data_score, spacing * 2) + ) - elif section == 'adv_predict': + elif section == "adv_predict": # advanced score analysis 2, predictive ability - columns_score.append(adv_predict(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_predict( + self, suffix, data_score, spacing * 2, pred_type + ) + ) - elif section == 'adv_test': + elif section == "adv_test": # advanced score analysis 3 and 3a, predictive ability of CV - columns_score.append(adv_test(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_test(self, suffix, data_score, spacing * 2, pred_type) + ) - elif section == 'adv_cv_sd' and pred_type == 'reg': + elif section == "adv_cv_sd" and pred_type == "reg": # advanced score analysis 3b, SD of CV - columns_score.append(adv_cv_sd(self,suffix,data_score,spacing*2)) + columns_score.append( + adv_cv_sd(self, suffix, data_score, spacing * 2) + ) - elif section == 'adv_diff_test': + elif section == "adv_diff_test": # advanced score analysis 3c, difference bwteen RMSE in test vs CV - columns_score.append(adv_diff_test(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_diff_test( + self, suffix, data_score, spacing * 2, pred_type + ) + ) - elif section == 'adv_sorted_cv': + elif section == "adv_sorted_cv": # advanced score analysis 3d, descriptor proportion - columns_score.append(adv_sorted_cv(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_sorted_cv( + self, suffix, data_score, spacing * 2, pred_type + ) + ) - elif section == 'adv_cv_diff' and pred_type == 'clas': + elif section == "adv_cv_diff" and pred_type == "clas": # advanced score analysis 3b, difference of MCC in model and CV - columns_score.append(adv_cv_diff(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_cv_diff( + self, suffix, data_score, spacing * 2, pred_type + ) + ) # Combine both columns adv_score_dat += combine_cols(columns_score) # add corresponding images - section_separator = f'' # reduces line separation separation - misc_dat += f"""
Some general tips to improve the score
""" - misc_dat += f'1. Adding meaningful datapoints might help to improve the model. Also, using a uniform population of datapoints across the whole range of y values usually helps to obtain reliable predictions across the whole range. More information about the range of y values used is available in Section C.
' - misc_dat += f'{style_line}2. Adding meaningful descriptors or replacing/deleting the least useful descriptors used might help. Feature importances are gathered in Section D.' - + style_line = '' # reduces line separation separation + misc_dat += """
Some general tips to improve the score
""" + misc_dat += '1. Adding meaningful datapoints might help to improve the model. Also, using a uniform population of datapoints across the whole range of y values usually helps to obtain reliable predictions across the whole range. More information about the range of y values used is available in Section C.
' + misc_dat += f"{style_line}2. Adding meaningful descriptors or replacing/deleting the least useful descriptors used might help. Feature importances are gathered in Section D." # how to predict new values misc_dat += f""" @@ -624,70 +748,77 @@ def print_misc(self): misc_dat += '{spacing*3}y distribution analysis
+{spacing * 3}y distribution analysis
{y_distrib_sentence} """ @@ -700,25 +831,30 @@ def print_y_distrib(self,pred_type,eval_only,spacing_PFI,warnings_dict): # add separator line and page break distrib_dat += 'Linear model equation_No_PFI
" @@ -728,17 +864,21 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): columns_eq = [] for i, eq in enumerate(linear_model_eqs): if i == 0: - columns_eq.append(f"{eq}
") + columns_eq.append( + f"{eq}
" + ) else: - columns_eq.append(f"{spacing*3}Linear model equation_PFI
{spacing*3}{eq}
") + columns_eq.append( + f"{spacing * 3}Linear model equation_PFI
{spacing * 3}{eq}
" + ) feature_dat += combine_cols(columns_eq) # add corresponding images - module_path = Path(f'{os.getcwd()}/PREDICT') - - shap_images = glob.glob(f'{module_path}/SHAP_*.png') - pfi_images = glob.glob(f'{module_path}/PFI_*.png') - pearson_images = glob.glob(f'{module_path}/Pearson_*.png') + module_path = Path(f"{os.getcwd()}/PREDICT") + + shap_images = glob.glob(f"{module_path}/SHAP_*.png") + pfi_images = glob.glob(f"{module_path}/PFI_*.png") + pearson_images = glob.glob(f"{module_path}/Pearson_*.png") shap_images = revert_list(shap_images) pfi_images = revert_list(pfi_images) @@ -746,16 +886,18 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): image_pair_list = [shap_images, pfi_images, pearson_images] - margin_top, margin_bottom = -10,30 - for _,image_pair in enumerate(image_pair_list): - if len(image_pair) < 2 and not eval_only: # Pearson graphs aren't created when >30 descriptors + margin_top, margin_bottom = -10, 30 + for _, image_pair in enumerate(image_pair_list): + if ( + len(image_pair) < 2 and not eval_only + ): # Pearson graphs aren't created when >30 descriptors pair_list = f'Pearson maps not created if >30 descriptors.'
- pair_list += f'{(" ")*15}'
+ pair_list += f"{(' ') * 15}"
if len(image_pair) == 1:
pair_list += f'
Pearson maps not created if >30 descriptors.
' else: pair_list = f''
- pair_list += f'{(" ")*22}'
+ pair_list += f"{(' ') * 22}"
pair_list += f'
{spacing*3}Correlation analysis
+{spacing * 3}Correlation analysis
{pearson_sentence} """ columns_pearson.append(column) @@ -789,116 +936,144 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): # add separator line and page break feature_dat += '
{version_n_date}
How to cite: {citation}
""" - aqme_workflow,aqme_updated = False,True + aqme_workflow, aqme_updated = False, True crest_workflow = False - if '--aqme' in command_line: + if "--aqme" in command_line: original_command = command_line aqme_workflow = True - command_line = command_line.replace('AQME-ROBERT_','') - self.args.csv_name = f'{self.args.csv_name}'.replace('AQME-ROBERT_','') - if self.args.csv_test != '': - self.args.csv_test = f'{self.args.csv_test}'.replace('AQME-ROBERT_','') + command_line = command_line.replace("AQME-ROBERT_", "") + self.args.csv_name = f"{self.args.csv_name}".replace("AQME-ROBERT_", "") + if self.args.csv_test != "": + self.args.csv_test = f"{self.args.csv_test}".replace("AQME-ROBERT_", "") - if '--program crest' in command_line.lower(): + if "--program crest" in command_line.lower(): crest_workflow = True # make the text more compact if --aqme is used (more lines are included) if aqme_workflow: - first_line = f'' # reduces line separation separation + first_line = '
' # reduces line separation separation else: - first_line = f'
' # reduces line separation separation - reduced_line = f'
' # reduces line separation separation - space = (' ')*4 + first_line = '
' # reduces line separation separation + reduced_line = '
' # reduces line separation separation
+ space = (" ") * 4
# just in case the command lines are so long
- command_line = format_lines(command_line,cmd_line=True)
+ command_line = format_lines(command_line, cmd_line=True)
- # reproducibility section, starts with the icon of reproducibility
+ # reproducibility section, starts with the icon of reproducibility
repro_dat += f"""{first_line}
1. Download these files (the authors should have uploaded the files as supporting information!):
2. Install and adjust the versions of the following Python modules:
2. Install and adjust the versions of the following Python modules:
4. Execution time, Python version and OS:
4. Execution time, Python version and OS:
4. Execution time, Python version and OS:
4. Execution time, Python version and OS:
' # reduces line separation separation + transpa_dat = "" + titles_line = '
' # reduces line separation separation
# add params of the models
transpa_dat += f"""{titles_line}
1. Parameters of the scikit-learn models (same keywords as used in scikit-learn):
2. ROBERT options, including prediction type (REG or CLAS), folds and repeats used for CV, etc:
2. ROBERT options, including prediction type (REG or CLAS), folds and repeats used for CV, etc:
' - return section_dat,params_df - + return section_dat, params_df def get_abbrev(self): """ @@ -1023,132 +1200,136 @@ def get_abbrev(self): """ # starts with the icon of abbreviation - abbrev_dat = '' - abbrev_dat = self.module_lines('abbrev',abbrev_dat) + abbrev_dat = "" + abbrev_dat = self.module_lines("abbrev", abbrev_dat) columns_abbrev = [] - columns_abbrev.append(get_col_text('abbrev_1')) - columns_abbrev.append(get_col_text('abbrev_2')) - columns_abbrev.append(get_col_text('abbrev_3')) + columns_abbrev.append(get_col_text("abbrev_1")) + columns_abbrev.append(get_col_text("abbrev_2")) + columns_abbrev.append(get_col_text("abbrev_3")) abbrev_dat += combine_cols(columns_abbrev) - abbrev_dat +=f'
This score is designed to evaluate the models using different metrics.' - elif module == 'adv_anal': - module_name = 'Section B. Advanced Score Analysis' - section_explain = f'
This section explains each component that comprises the ROBERT score. More details here.' - elif module == 'y_distrib': - module_name = 'Section C. Distribution of y Values' - section_explain = f'
This section shows the distribution of y values within the training and validation sets.' - elif module == 'features': - module_name = 'Section D. Feature Importances' - section_explain = f'
This section presents feature importances measured using the validation set.' - elif module == 'outliers': - module_name = 'Section E. Outlier Analysis' - if pred_type == 'clas': - section_explain = f'
This feature is disabled in classification problems.' + + if module == "score": + module_name = "Section A. ROBERT Score" + section_explain = '
This score is designed to evaluate the models using different metrics.' + elif module == "adv_anal": + module_name = "Section B. Advanced Score Analysis" + section_explain = '
This section explains each component that comprises the ROBERT score. More details here.' + elif module == "y_distrib": + module_name = "Section C. Distribution of y Values" + section_explain = '
This section shows the distribution of y values within the training and validation sets.' + elif module == "features": + module_name = "Section D. Feature Importances" + section_explain = '
This section presents feature importances measured using the validation set.' + elif module == "outliers": + module_name = "Section E. Outlier Analysis" + if pred_type == "clas": + section_explain = '
This feature is disabled in classification problems.' else: - section_explain = f'
This section detects outliers using the standard deviation (SD) of errors from the training set.' - elif module == 'generate': - module_name = 'Section F. Model Screening' + section_explain = '
This section detects outliers using the standard deviation (SD) of errors from the training set.' + elif module == "generate": + module_name = "Section F. Model Screening" if eval_only: - section_explain = f'
The screening of models is disabled when using the EVALUATE module.' + section_explain = '
The screening of models is disabled when using the EVALUATE module.' else: - section_explain = f'
This section compares different combinations of hyperoptimized algorithms and partition sizes. The combined error is calculated as the product of the training error, validation error, and cross-validation error.' - elif module == 'repro': - module_name = 'Section G. Reproducibility' - section_explain = f'
This section provides all the instructions to reproduce the results presented.' - elif module == 'transpa': - module_name = 'Section H. Transparency' - section_explain = f'
This section contains important parameters used in scikit-learn models and ROBERT.' - elif module == 'abbrev': - module_name = 'Section I. Abbreviations' - section_explain = f'
Reference section for the abbreviations used.' - elif module == 'pred': - module_name = 'Section J. New Predictions' - section_explain = f'
Predictions of the external test set added with the csv_test option.' - elif module == 'misc': - module_name = 'Miscellaneous' - section_explain = f'
General tips to improve the models and instructions to predict new values.' - - if module not in ['repro','transpa','misc']: + section_explain = '
This section compares different combinations of hyperoptimized algorithms and partition sizes. The combined error is calculated as the product of the training error, validation error, and cross-validation error.' + elif module == "repro": + module_name = "Section G. Reproducibility" + section_explain = '
This section provides all the instructions to reproduce the results presented.' + elif module == "transpa": + module_name = "Section H. Transparency" + section_explain = '
This section contains important parameters used in scikit-learn models and ROBERT.' + elif module == "abbrev": + module_name = "Section I. Abbreviations" + section_explain = '
Reference section for the abbreviations used.' + elif module == "pred": + module_name = "Section J. New Predictions" + section_explain = '
Predictions of the external test set added with the csv_test option.' + elif module == "misc": + module_name = "Miscellaneous" + section_explain = '
General tips to improve the models and instructions to predict new values.' + + if module not in ["repro", "transpa", "misc"]: module_data = format_lines(module_data) - module_data = '
' + module_data + '
' + module_data + "
'
if not eval_only:
- pair_list += f'{(" ")*22}'
+ pair_list += f"{(' ') * 22}"
pair_list += f'
{spacing*2}{title_col}
+{spacing * 2}{title_col}
{summary}
"""
return column
-def get_metrics(file,suffix,spacing):
+def get_metrics(file, suffix, spacing):
"""
Retrieve the summary of results from the PREDICT dat files
"""
-
- with open(file, 'r', encoding='utf-8') as datfile:
+
+ with open(file, "r", encoding="utf-8") as datfile:
lines = datfile.readlines()
- start_results,stop_results = 0,0
- for i,line in enumerate(lines):
- if suffix == 'No PFI':
- if 'o Summary of results' in line and 'No_PFI:' in line:
- start_results = i+1
- stop_results = i+6
- if suffix == 'PFI':
- if 'o Summary of results' in line and 'No_PFI:' not in line:
- start_results = i+1
- stop_results = i+6
+ start_results, stop_results = 0, 0
+ for i, line in enumerate(lines):
+ if suffix == "No PFI":
+ if "o Summary of results" in line and "No_PFI:" in line:
+ start_results = i + 1
+ stop_results = i + 6
+ if suffix == "PFI":
+ if "o Summary of results" in line and "No_PFI:" not in line:
+ start_results = i + 1
+ stop_results = i + 6
# add the summary of results of PREDICT
- start_results += 4 # skip informaton that aren't metrics
+ start_results += 4 # skip informaton that aren't metrics
summary = []
- for line in lines[start_results:stop_results+1]:
- if 'R2' in line:
- line = line.replace('R2','R2')
+ for line in lines[start_results : stop_results + 1]:
+ if "R2" in line:
+ line = line.replace("R2", "R2")
- if suffix == 'No PFI':
+ if suffix == "No PFI":
summary.append(line[8:])
- elif suffix == 'PFI':
- summary.append(f'{spacing}{spacing}{line[8:]}')
+ elif suffix == "PFI":
+ summary.append(f"{spacing}{spacing}{line[8:]}")
- summary = ''.join(summary)
+ summary = "".join(summary)
column = f"""
{summary}
@@ -128,70 +129,70 @@ def get_metrics(file,suffix,spacing):
return column
-def get_csv_metrics(file,suffix,spacing):
+def get_csv_metrics(file, suffix, spacing):
"""
Retrieve the csv_test results from the PREDICT dat file
"""
-
- results_line = ''
- with open(file, 'r', encoding='utf-8') as datfile:
+
+ results_line = ""
+ with open(file, "r", encoding="utf-8") as datfile:
lines = datfile.readlines()
- for i,line in enumerate(lines):
- if suffix == 'No PFI':
- if 'o Summary of results' in line and 'No_PFI:' in line:
- for j in range(i,i+15):
- if 'o SHAP' in lines[j]:
+ for i, line in enumerate(lines):
+ if suffix == "No PFI":
+ if "o Summary of results" in line and "No_PFI:" in line:
+ for j in range(i, i + 15):
+ if "o SHAP" in lines[j]:
break
- elif '- External test : ' in lines[j]:
+ elif "- External test : " in lines[j]:
results_line = lines[j][25:]
- if suffix == 'PFI':
- if 'o Summary of results' in line and 'No_PFI:' not in line:
- for j in range(i,i+15):
- if 'o SHAP' in lines[j]:
+ if suffix == "PFI":
+ if "o Summary of results" in line and "No_PFI:" not in line:
+ for j in range(i, i + 15):
+ if "o SHAP" in lines[j]:
break
- elif '- External test : ' in lines[j]:
+ elif "- External test : " in lines[j]:
results_line = lines[j][25:]
# start the csv_test section
- metrics_dat = f'{spacing*2}External test metrics
' + metrics_dat = f'{spacing * 2}External test metrics
' # add line with model metrics (if any) - if results_line != '': - metrics_dat += f'{spacing*2}{results_line}
' - + if results_line != "": + metrics_dat += f'{spacing * 2}{results_line}
' + return metrics_dat - + else: - return '' + return "" -def get_csv_pred(suffix,path_csv_test,y_value,names,spacing): +def get_csv_pred(suffix, path_csv_test, y_value, names, spacing): """ Retrieve the csv_test results from the PREDICT dat file """ - - pred_line = '' - csv_test_folder = f'{os.getcwd()}/{os.path.dirname(path_csv_test)}' - csv_test_list = glob.glob(f'{csv_test_folder}/*.csv') + + pred_line = "" + csv_test_folder = f"{os.getcwd()}/{os.path.dirname(path_csv_test)}" + csv_test_list = glob.glob(f"{csv_test_folder}/*.csv") for file in csv_test_list: - if suffix == 'No PFI': - if '_No_PFI.csv' in file: + if suffix == "No PFI": + if "_No_PFI.csv" in file: csv_test_file = file - if suffix == 'PFI': - if '_No_PFI.csv' not in file and '_PFI.csv' in file: + if suffix == "PFI": + if "_No_PFI.csv" not in file and "_PFI.csv" in file: csv_test_file = file - csv_test_df = pd.read_csv(csv_test_file, encoding='utf-8') + csv_test_df = pd.read_csv(csv_test_file, encoding="utf-8") # start the csv_test section - pred_line = f'{spacing*2}External test predictions (sorted, max. 20 shown)
' + pred_line = f'{spacing * 2}External test predictions (sorted, max. 20 shown)
' - if suffix == 'No PFI': - pred_line += f'{spacing*2}From /PREDICT/csv_test/...No_PFI.csv
' - elif suffix == 'PFI': - pred_line += f'{spacing*2}From /PREDICT/csv_test/..._PFI.csv
' + if suffix == "No PFI": + pred_line += f'{spacing * 2}From /PREDICT/csv_test/...No_PFI.csv
' + elif suffix == "PFI": + pred_line += f'{spacing * 2}From /PREDICT/csv_test/..._PFI.csv
' - pred_line += ''' - ''' + """ y_val_exist = False - if f'{y_value}' in csv_test_df.columns: + if f"{y_value}" in csv_test_df.columns: y_val_exist = True # adjust format of headers names_head = names if len(str(names_head)) > 12: - names_head = f'{str(names_head[:9])}...' + names_head = f"{str(names_head[:9])}..." y_value_head = y_value if len(str(y_value_head)) > 12: - y_value_head = f'{str(y_value_head[:9])}...' + y_value_head = f"{str(y_value_head[:9])}..." - if pred_line != '': - if suffix == 'No PFI': + if pred_line != "": + if suffix == "No PFI": margin_left = 0 else: margin_left = 27 - pred_line += f''' + pred_line += f"""| {names_head} | ''' +{names_head} | """ if y_val_exist: - pred_line += f''' -{y_value_head} | ''' - if f'{y_value}_pred_sd' in csv_test_df: - pred_line += f''' + pred_line += f""" +{y_value_head} | """ + if f"{y_value}_pred_sd" in csv_test_df: + pred_line += f"""{y_value_head}_pred ± sd | -{y_value_head}_pred | - ''' - + """ + # retrieve and sort the values if not y_val_exist: - csv_test_df[y_value] = csv_test_df[f'{y_value}_pred'] + csv_test_df[y_value] = csv_test_df[f"{y_value}_pred"] # in clas problems, there are no SD in the predictions (we use a list of 0s) - if f'{y_value}_pred_sd' in csv_test_df: - sd_list = csv_test_df[f'{y_value}_pred_sd'] + if f"{y_value}_pred_sd" in csv_test_df: + sd_list = csv_test_df[f"{y_value}_pred_sd"] else: - sd_list = [0] * len(csv_test_df[f'{y_value}_pred']) - - y_pred_sorted, y_sorted, names_sorted, sd_sorted = (list(t) for t in zip(*sorted(zip(csv_test_df[f'{y_value}_pred'], csv_test_df[y_value], csv_test_df[names], sd_list), reverse=True))) + sd_list = [0] * len(csv_test_df[f"{y_value}_pred"]) + + y_pred_sorted, y_sorted, names_sorted, sd_sorted = ( + list(t) + for t in zip( + *sorted( + zip( + csv_test_df[f"{y_value}_pred"], + csv_test_df[y_value], + csv_test_df[names], + sd_list, + ), + reverse=True, + ) + ) + ) max_table = False if len(y_pred_sorted) > 20: max_table = True count_entries = 0 - for y_val_pred, y_val, name, sd in zip(y_pred_sorted, y_sorted, names_sorted, sd_sorted): + for y_val_pred, y_val, name, sd in zip( + y_pred_sorted, y_sorted, names_sorted, sd_sorted + ): # adjust format of entries if len(str(name)) > 12: - name = f'{str(name[:9])}...' + name = f"{str(name[:9])}..." y_val_pred = round(y_val_pred, 2) y_val = round(y_val, 2) sd = round(sd, 2) - if f'{y_value}_pred_sd' in csv_test_df: - y_val_pred_formatted = f'{y_val_pred} ± {sd}' + if f"{y_value}_pred_sd" in csv_test_df: + y_val_pred_formatted = f"{y_val_pred} ± {sd}" else: - y_val_pred_formatted = f'{y_val_pred}' + y_val_pred_formatted = f"{y_val_pred}" add_entry = True # if there are more than 20 predictions, only 20 values will be shown if max_table and count_entries >= 10: add_entry = False if count_entries == 10: - pred_line += f''' + pred_line += """
| ... | ''' +... | """ if y_val_exist: - pred_line += f''' -... | ''' - pred_line += f''' + pred_line += """ +... | """ + pred_line += """... | -
| {name} | ''' +{name} | """ if y_val_exist: - pred_line += f''' -{y_val} | ''' - pred_line += f''' + pred_line += f""" +{y_val} | """ + pred_line += f"""{y_val_pred_formatted} | -
{spacing}' part_line_format = f'
{spacing}' - score_title = f''' · Score {data_score[f'robert_score_{suffix}']}''' - if suffix == 'No PFI': - caption = f'{spacing}{title_no_pfi.replace(":",score_title)}' + score_title = ( + f""" · Score {data_score[f"robert_score_{suffix}"]}""" + ) + if suffix == "No PFI": + caption = f"{spacing}{title_no_pfi.replace(':', score_title)}" - elif suffix == 'PFI': - caption = f'{spacing}{title_pfi.replace(":",score_title)}' + elif suffix == "PFI": + caption = f"{spacing}{title_pfi.replace(':', score_title)}" - partitions_ratio = data_score['proportion_ratio_print'].split('- Proportion ')[1] + partitions_ratio = data_score["proportion_ratio_print"].split("- Proportion ")[1] if not eval_only: - title_line = f'{caption}' + title_line = f"{caption}" else: - title_line = 'Summary and score of your model (No PFI)' + title_line = "Summary and score of your model (No PFI)" column = f"""
{title_line}
- {ML_line_format}Model = {data_score['ML_model']} · {partitions_ratio} - {part_line_format}Points(train+validation):descriptors = {data_score[f'points_descp_ratio_{suffix}']} + {ML_line_format}Model = {data_score["ML_model"]} · {partitions_ratio} + {part_line_format}Points(train+validation):descriptors = {data_score[f"points_descp_ratio_{suffix}"]}{score_info}
""" @@ -420,17 +446,17 @@ def get_col_score(score_info,data_score,suffix,spacing,eval_only): return column -def adv_flawed(self,suffix,data_score,spacing): +def adv_flawed(self, suffix, data_score, spacing): """ Gather the advanced analysis of flawed models """ - score_flawed = data_score[f'flawed_mod_score_{suffix}'] + score_flawed = data_score[f"flawed_mod_score_{suffix}"] if score_flawed == 0: - flaw_result = f'The model predicts right for the right reasons.' + flaw_result = "The model predicts right for the right reasons." else: - flaw_result = f'Warning! The model probably has important flaws.' + flaw_result = "Warning! The model probably has important flaws." # adds a bit more space if there is no test set score_adv_flawed = f'{spacing}'
@@ -443,7 +469,7 @@ def adv_flawed(self,suffix,data_score,spacing):
return column
-def adv_predict(self,suffix,data_score,spacing,pred_type):
+def adv_predict(self, suffix, data_score, spacing, pred_type):
"""
Gather the advanced analysis of predictive ability
@@ -454,19 +480,19 @@ def adv_predict(self,suffix,data_score,spacing,pred_type):
if 0.30 < MCC <= 0.50 => 1, else => 0
"""
- score_predict = data_score.get(f'cv_score_combined_{suffix}', 0)
+ score_predict = data_score.get(f"cv_score_combined_{suffix}", 0)
cv_type = data_score.get(f"cv_type_{suffix}", "10x 5-fold CV")
- if pred_type == 'reg':
- predict_image = f'{self.args.path_icons}/score_w_2_{score_predict}.jpg'
- metric_type = ['Scaled RMSE','R2']
- scaled_rmse_cv = data_score.get(f'scaled_rmse_cv_{suffix}', 0)
- r2_cv = data_score.get(f'r2_cv_{suffix}', 0)
+ if pred_type == "reg":
+ predict_image = f"{self.args.path_icons}/score_w_2_{score_predict}.jpg"
+ metric_type = ["Scaled RMSE", "R2"]
+ scaled_rmse_cv = data_score.get(f"scaled_rmse_cv_{suffix}", 0)
+ r2_cv = data_score.get(f"r2_cv_{suffix}", 0)
- predict_result = f'{metric_type[0]} ({cv_type}) = {scaled_rmse_cv}%.'
- predict_result += f'
{spacing}{metric_type[1]} ({cv_type}) = {r2_cv}.'
- thres_line = 'Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1.'
- thres_line += f'
{spacing}R2 < 0.5: -2, R2 < 0.7: -1'
+ predict_result = f"{metric_type[0]} ({cv_type}) = {scaled_rmse_cv}%."
+ predict_result += f"
{spacing}{metric_type[1]} ({cv_type}) = {r2_cv}."
+ thres_line = "Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1."
+ thres_line += f"
{spacing}R2 < 0.5: -2, R2 < 0.7: -1"
init_sep = f'
{spacing}' score_adv_pred = f'
{spacing}'
column = f"""{init_sep}2. CV predictions of the model ({score_predict} / 2 )
{spacing}'
@@ -499,20 +525,20 @@ def adv_predict(self,suffix,data_score,spacing,pred_type):
return column
-def adv_test(self,suffix,data_score,spacing,pred_type):
+def adv_test(self, suffix, data_score, spacing, pred_type):
"""
Gather the advanced analysis of predictive ability with the test set
"""
- score_test = data_score.get(f'test_score_combined_{suffix}', 0)
+ score_test = data_score.get(f"test_score_combined_{suffix}", 0)
- if pred_type == 'reg':
- test_image = f'{self.args.path_icons}/score_w_2_{score_test}.jpg'
- metric_type = ['Scaled RMSE','R2']
- predict_result = f'{metric_type[0]} (test set) = {data_score.get(f"scaled_rmse_test_{suffix}", 0)}%.'
- predict_result += f'
{spacing}{metric_type[1]} (test set) = {data_score.get(f"r2_test_{suffix}", 0)}.'
- thres_line = 'Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1.'
- thres_line += f'
{spacing}R2 < 0.5: -2, R2 < 0.7: -1'
+ if pred_type == "reg":
+ test_image = f"{self.args.path_icons}/score_w_2_{score_test}.jpg"
+ metric_type = ["Scaled RMSE", "R2"]
+ predict_result = f"{metric_type[0]} (test set) = {data_score.get(f'scaled_rmse_test_{suffix}', 0)}%."
+ predict_result += f"
{spacing}{metric_type[1]} (test set) = {data_score.get(f'r2_test_{suffix}', 0)}."
+ thres_line = "Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1."
+ thres_line += f"
{spacing}R2 < 0.5: -2, R2 < 0.7: -1"
score_adv_cv = f'
{spacing}'
column = f"""{score_adv_cv}
{spacing}3. Predictive ability & overfitting
{spacing}3a. Predictions test set ({score_test} / 2 )
{spacing}'
column = f"""{score_adv_cv}
{spacing}3. Predictive ability & overfitting
{spacing}' column = f"""
{spacing}3b. Prediction accuracy test vs CV ({score_diff_test} / 2 )
{spacing}' column = f"""
{spacing}3c. Avg. standard deviation (SD) ({score_cv_sd} / 2 )
{spacing}' column = f"""
{spacing}{title} ({score_cv_diff} / 2 )
{spacing}{title_cap} (sorted CV) ({score_sorted} / 2 )
' # reduces line separation separation + reduced_line = '
' # reduces line separation separation first_line = '
' - if type_thres == 'abbrev_1': - abbrev_list = ['ACC: accuracy', - 'ADAB: AdaBoost', - 'CSV: comma separated values', - 'CLAS: classification', - 'CV: cross-validation', - 'F1 score: balanced F-score', - 'GB: gradient boosting', - 'GP: gaussian process', - 'XGB: extreme gradient boosting' + if type_thres == "abbrev_1": + abbrev_list = [ + "ACC: accuracy", + "ADAB: AdaBoost", + "CSV: comma separated values", + "CLAS: classification", + "CV: cross-validation", + "F1 score: balanced F-score", + "GB: gradient boosting", + "GP: gaussian process", + "XGB: extreme gradient boosting", ] - elif type_thres == 'abbrev_2': - abbrev_list = ['KN: k-nearest neighbors', - 'MAE: root-mean-square error', - "MCC: Matthew's correl. coefficient", - 'ML: machine learning', - 'MVL: multivariate lineal models', - 'NN: neural network', - 'PFI: permutation feature importance', - 'R2: coefficient of determination' + elif type_thres == "abbrev_2": + abbrev_list = [ + "KN: k-nearest neighbors", + "MAE: root-mean-square error", + "MCC: Matthew's correl. coefficient", + "ML: machine learning", + "MVL: multivariate lineal models", + "NN: neural network", + "PFI: permutation feature importance", + "R2: coefficient of determination", ] - elif type_thres == 'abbrev_3': - abbrev_list = ['REG: Regression', - 'RF: random forest', - 'RMSE: root mean square error', - 'RND: random', - 'SHAP: Shapley additive explanations', - 'VR: voting regressor', + elif type_thres == "abbrev_3": + abbrev_list = [ + "REG: Regression", + "RF: random forest", + "RMSE: root mean square error", + "RND: random", + "SHAP: Shapley additive explanations", + "VR: voting regressor", ] - column = '' - for i,ele in enumerate(abbrev_list): + column = "" + for i, ele in enumerate(abbrev_list): if i == 0: column += f"""{first_line}{ele}
""" @@ -733,112 +767,128 @@ def get_col_text(type_thres): return column -def get_col_transpa(params_dict,suffix,section,spacing): +def get_col_transpa(params_dict, suffix, section, spacing): """ Gather the information regarding the model parameters represented in the Reproducibility section """ - first_line = f'{spacing*2}' # reduces line separation separation - reduced_line = f'
{spacing*2}' # reduces line separation separation - - if suffix == 'No PFI': - caption = f'{title_no_pfi}' - - elif suffix == 'PFI': - caption = f'{title_pfi}' - - excluded_params = [f"combined_{params_dict['error_type']}", 'train', 'X_descriptors', 'y', 'error_train', 'cv_error', 'names'] - misc_params = ['type','error_type','split','kfold','repeat_kfolds','seed'] - if params_dict['type'] == 'reg': - model_type = 'Regressor' - elif params_dict['type'] == 'clas': - model_type = 'Classifier' - models_dict = {'RF': f'RandomForest{model_type}', - 'MVL': 'LinearRegression', - 'GB': f'GradientBoosting{model_type}', - 'XGB': f'XGB{model_type}', - 'NN': f'MLP{model_type}', - 'GP': f'GaussianProcess{model_type}', - 'ADAB': f'AdaBoost{model_type}', - 'VR': f'Voting{model_type}', - } - - col_info,sklearn_model = '','' - for _,ele in enumerate(params_dict.keys()): + first_line = f'
{spacing * 2}' # reduces line separation separation + reduced_line = f'
{spacing * 2}' # reduces line separation separation + + if suffix == "No PFI": + caption = f"{title_no_pfi}" + + elif suffix == "PFI": + caption = f"{title_pfi}" + + excluded_params = [ + f"combined_{params_dict['error_type']}", + "train", + "X_descriptors", + "y", + "error_train", + "cv_error", + "names", + ] + misc_params = ["type", "error_type", "split", "kfold", "repeat_kfolds", "seed"] + if params_dict["type"] == "reg": + model_type = "Regressor" + elif params_dict["type"] == "clas": + model_type = "Classifier" + models_dict = { + "RF": f"RandomForest{model_type}", + "MVL": "LinearRegression", + "GB": f"GradientBoosting{model_type}", + "XGB": f"XGB{model_type}", + "NN": f"MLP{model_type}", + "GP": f"GaussianProcess{model_type}", + "ADAB": f"AdaBoost{model_type}", + "VR": f"Voting{model_type}", + } + + col_info, sklearn_model = "", "" + for _, ele in enumerate(params_dict.keys()): if ele not in excluded_params: - if ele == 'model' and section == 'model_section': + if ele == "model" and section == "model_section": sklearn_model = models_dict[params_dict[ele].upper()] sklearn_model = f"""{first_line}sklearn model: {sklearn_model}
""" - elif section == 'model_section' and ele.lower() not in misc_params: - if ele == 'params': - model_params = ast.literal_eval(params_dict['params']) + elif section == "model_section" and ele.lower() not in misc_params: + if ele == "params": + model_params = ast.literal_eval(params_dict["params"]) for param in model_params: - col_info += f"""{reduced_line}{param}: {model_params[param]}""" - elif section == 'misc_section' and ele.lower() in misc_params: - if col_info == '': + col_info += ( + f"""{reduced_line}{param}: {model_params[param]}""" + ) + elif section == "misc_section" and ele.lower() in misc_params: + if col_info == "": col_info += f"""{first_line}{ele}: {params_dict[ele]}""" else: col_info += f"""{reduced_line}{ele}: {params_dict[ele]}""" - - column = f"""{spacing*2}{caption}
+ + column = f"""{spacing * 2}{caption}
{sklearn_model}{col_info} """ return column -def calc_score(dat_files,suffix,pred_type,data_score): - ''' +def calc_score(dat_files, suffix, pred_type, data_score): + """ Calculates ROBERT score - ''' + """ - data_score = get_predict_scores(dat_files['PREDICT'],suffix,pred_type,data_score) + data_score = get_predict_scores(dat_files["PREDICT"], suffix, pred_type, data_score) - data_score = get_verify_scores(dat_files['VERIFY'],suffix,pred_type,data_score) + data_score = get_verify_scores(dat_files["VERIFY"], suffix, pred_type, data_score) - if pred_type == 'reg': - robert_score = data_score.get(f'cv_score_combined_{suffix}', 0) + data_score.get(f'test_score_combined_{suffix}', 0) \ - + data_score.get(f'cv_sd_score_{suffix}', 0) + data_score.get(f'diff_scaled_rmse_score_{suffix}', 0) \ - + data_score.get(f'flawed_mod_score_{suffix}', 0) + data_score.get(f'sorted_cv_score_{suffix}', 0) + if pred_type == "reg": + robert_score = ( + data_score.get(f"cv_score_combined_{suffix}", 0) + + data_score.get(f"test_score_combined_{suffix}", 0) + + data_score.get(f"cv_sd_score_{suffix}", 0) + + data_score.get(f"diff_scaled_rmse_score_{suffix}", 0) + + data_score.get(f"flawed_mod_score_{suffix}", 0) + + data_score.get(f"sorted_cv_score_{suffix}", 0) + ) # Adjustment to avoid negative values if robert_score < 0: robert_score = 0 # Assign the final value - data_score[f'robert_score_{suffix}'] = robert_score + data_score[f"robert_score_{suffix}"] = robert_score - elif pred_type == 'clas': + elif pred_type == "clas": # Calculate the difference between CV MCC and test MCC - mcc_cv = data_score.get(f'r2_cv_{suffix}', 0) - mcc_test = data_score.get(f'r2_test_{suffix}', 0) + mcc_cv = data_score.get(f"r2_cv_{suffix}", 0) + mcc_test = data_score.get(f"r2_test_{suffix}", 0) diff_mcc = round(np.abs(mcc_test - mcc_cv), 2) # Assign a score based on the MCC gap (e.g., ±2, ±1, 0) - data_score[f'diff_mcc_score_{suffix}'] = 0 + data_score[f"diff_mcc_score_{suffix}"] = 0 if diff_mcc < 0.15: - data_score[f'diff_mcc_score_{suffix}'] += 2 + data_score[f"diff_mcc_score_{suffix}"] += 2 elif diff_mcc <= 0.30: - data_score[f'diff_mcc_score_{suffix}'] += 1 + data_score[f"diff_mcc_score_{suffix}"] += 1 # Sum scores similarly to regression: robert_score = ( - data_score.get(f'cv_score_combined_{suffix}', 0) - + data_score.get(f'test_score_combined_{suffix}', 0) - + data_score.get(f'flawed_mod_score_{suffix}', 0) - + data_score.get(f'sorted_cv_score_{suffix}', 0) - + data_score.get(f'diff_mcc_score_{suffix}', 0) - + data_score.get(f'descp_score_{suffix}', 0) + data_score.get(f"cv_score_combined_{suffix}", 0) + + data_score.get(f"test_score_combined_{suffix}", 0) + + data_score.get(f"flawed_mod_score_{suffix}", 0) + + data_score.get(f"sorted_cv_score_{suffix}", 0) + + data_score.get(f"diff_mcc_score_{suffix}", 0) + + data_score.get(f"descp_score_{suffix}", 0) ) # Adjustment to avoid negative values if robert_score < 0: robert_score = 0 # Assign the final value - data_score[f'robert_score_{suffix}'] = robert_score + data_score[f"robert_score_{suffix}"] = robert_score return data_score - -def get_verify_scores(dat_verify,suffix,pred_type,data_score): + +def get_verify_scores(dat_verify, suffix, pred_type, data_score): """ Calculates scores that come from the VERIFY module (VERIFY tests) """ @@ -847,175 +897,271 @@ def get_verify_scores(dat_verify,suffix,pred_type,data_score): flawed_score = 0 failed_tests = 0 sorted_cv_score = 0 - for i,line in enumerate(dat_verify): + for i, line in enumerate(dat_verify): # set starting points for No PFI and PFI models - if suffix == 'No PFI': - if '------- ' in line and '(No PFI)' in line: + if suffix == "No PFI": + if "------- " in line and "(No PFI)" in line: start_data = True - elif '------- ' in line and 'with PFI' in line: + elif "------- " in line and "with PFI" in line: start_data = False - if suffix == 'PFI': - if '------- ' in line and 'with PFI' in line: + if suffix == "PFI": + if "------- " in line and "with PFI" in line: start_data = True - + if start_data: - error_keyword = "rmse" if pred_type.lower() == 'reg' else "mcc" + error_keyword = "rmse" if pred_type.lower() == "reg" else "mcc" if f"Original {error_keyword.upper()} (" in line: - for j in range(i+1,i+4): # y-mean, y-shuffle and onehot tests - if 'UNCLEAR' in dat_verify[j]: + for j in range(i + 1, i + 4): # y-mean, y-shuffle and onehot tests + if "UNCLEAR" in dat_verify[j]: flawed_score -= 1 - elif 'FAILED' in dat_verify[j]: + elif "FAILED" in dat_verify[j]: flawed_score -= 2 failed_tests += 1 - if '- Sorted ' in dat_verify[i+4]: - sorted_cv_results = dat_verify[i+4].split(f'{error_keyword.upper()} = ')[-1] + if "- Sorted " in dat_verify[i + 4]: + sorted_cv_results = dat_verify[i + 4].split( + f"{error_keyword.upper()} = " + )[-1] sorted_cv_results = ast.literal_eval(sorted_cv_results) - if pred_type.lower() == 'reg': - data_score[f'scaled_{error_keyword}_sorted_{suffix}'] = [round((val/data_score[f'y_range_{suffix}'])*100,2) for val in sorted_cv_results] + if pred_type.lower() == "reg": + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] = [ + round((val / data_score[f"y_range_{suffix}"]) * 100, 2) + for val in sorted_cv_results + ] else: - data_score[f'scaled_{error_keyword}_sorted_{suffix}'] = sorted_cv_results # no scaling for MCC + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] = ( + sorted_cv_results # no scaling for MCC + ) # define min and max values - data_score[f'min_scaled_{error_keyword}_{suffix}'] = min(data_score[f'scaled_{error_keyword}_sorted_{suffix}']) - idx_min_scaled_rmse = data_score[f'scaled_{error_keyword}_sorted_{suffix}'].index(data_score[f'min_scaled_{error_keyword}_{suffix}']) - data_score[f'max_scaled_{error_keyword}_{suffix}'] = max(data_score[f'scaled_{error_keyword}_sorted_{suffix}']) - idx_max_scaled_rmse = data_score[f'scaled_{error_keyword}_sorted_{suffix}'].index(data_score[f'max_scaled_{error_keyword}_{suffix}']) - - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'] = [] - for idx,err in enumerate(data_score[f'scaled_{error_keyword}_sorted_{suffix}']): - if pred_type.lower() == 'reg': + data_score[f"min_scaled_{error_keyword}_{suffix}"] = min( + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ) + idx_min_scaled_rmse = data_score[ + f"scaled_{error_keyword}_sorted_{suffix}" + ].index(data_score[f"min_scaled_{error_keyword}_{suffix}"]) + data_score[f"max_scaled_{error_keyword}_{suffix}"] = max( + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ) + idx_max_scaled_rmse = data_score[ + f"scaled_{error_keyword}_sorted_{suffix}" + ].index(data_score[f"max_scaled_{error_keyword}_{suffix}"]) + + data_score[f"scaled_{error_keyword}_results_sorted_{suffix}"] = [] + for idx, err in enumerate( + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ): + if pred_type.lower() == "reg": if idx == idx_min_scaled_rmse: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('min') - elif err <= (data_score[f'min_scaled_{error_keyword}_{suffix}']*1.25): - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('pass') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("min") + elif err <= ( + data_score[f"min_scaled_{error_keyword}_{suffix}"] + * 1.25 + ): + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("pass") else: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('fail') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("fail") else: if idx == idx_max_scaled_rmse: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('max') - elif err >= (data_score[f'max_scaled_{error_keyword}_{suffix}']*0.75): - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('pass') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("max") + elif err >= ( + data_score[f"max_scaled_{error_keyword}_{suffix}"] + * 0.75 + ): + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("pass") else: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('fail') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("fail") - sorted_cv_score = int(data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].count('pass')/2) + sorted_cv_score = int( + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].count("pass") + / 2 + ) # adjust max 1 point for flawed tests if flawed_score > 1: flawed_score = 1 - + # stores data - data_score[f'flawed_mod_score_{suffix}'] = flawed_score - data_score[f'failed_tests_{suffix}'] = failed_tests - data_score[f'sorted_cv_score_{suffix}'] = sorted_cv_score + data_score[f"flawed_mod_score_{suffix}"] = flawed_score + data_score[f"failed_tests_{suffix}"] = failed_tests + data_score[f"sorted_cv_score_{suffix}"] = sorted_cv_score return data_score -def get_predict_scores(dat_predict,suffix,pred_type,data_score): +def get_predict_scores(dat_predict, suffix, pred_type, data_score): """ Calculates scores that come from the PREDICT module (R2 or accuracy, datapoints:descriptors ratio, outlier proportion) """ start_data = False - data_score[f'rmse_score_{suffix}'] = 0 - data_score[f'cv_type_{suffix}'] = "10x 5-fold CV" - - for i,line in enumerate(dat_predict): + data_score[f"rmse_score_{suffix}"] = 0 + data_score[f"cv_type_{suffix}"] = "10x 5-fold CV" + for i, line in enumerate(dat_predict): # set starting points for No PFI and PFI models - if suffix == 'No PFI': - if '------- ' in line and '(No PFI)' in line: + if suffix == "No PFI": + if "------- " in line and "(No PFI)" in line: start_data = True - elif '------- ' in line and 'with PFI' in line: + elif "------- " in line and "with PFI" in line: start_data = False - if suffix == 'PFI': - if '------- ' in line and 'with PFI' in line: + if suffix == "PFI": + if "------- " in line and "with PFI" in line: start_data = True - + if start_data: # model type - if line.startswith(' - Model:'): - data_score['ML_model'] = line.split()[-1] + if line.startswith(" - Model:"): + data_score["ML_model"] = line.split()[-1] # R2 and proportion - if 'o Summary of results' in line: - data_score['proportion_ratio_print'] = dat_predict[i+2] - data_score[f'points_descp_ratio_{suffix}'] = dat_predict[i+4].split()[-1] + if "o Summary of results" in line: + data_score["proportion_ratio_print"] = dat_predict[i + 2] + data_score[f"points_descp_ratio_{suffix}"] = dat_predict[i + 4].split()[ + -1 + ] # scaled RMSE/MCC from test (if any) or validation - if pred_type == 'reg': - if '-fold CV : R2 =' in dat_predict[i+5]: - data_score[f'rmse_cv_{suffix}'] = float(dat_predict[i+5].split()[-1]) - data_score[f"cv_type_{suffix}"] = ' '.join([ele for ele in dat_predict[i+5].split()[1:4]]) - data_score[f'r2_cv_{suffix}'] = float(dat_predict[i+5].split(',')[0].split()[-1]) - if 'Test : R2 =' in dat_predict[i+6]: - data_score[f'rmse_test_{suffix}'] = float(dat_predict[i+6].split()[-1]) - data_score[f'r2_test_{suffix}'] = float(dat_predict[i+6].split(',')[0].split()[-1]) - if '- y range of dataset' in dat_predict[i+8]: - data_score[f'y_range_{suffix}'] = float(dat_predict[i+8].split()[-1]) - - data_score[f'scaled_rmse_cv_{suffix}'] = round((data_score[f'rmse_cv_{suffix}']/data_score[f'y_range_{suffix}'])*100,2) - data_score[f'scaled_rmse_test_{suffix}'] = round((data_score[f'rmse_test_{suffix}']/data_score[f'y_range_{suffix}'])*100,2) - - data_score[f'cv_score_rmse_{suffix}'] = score_rmse_mcc(pred_type,data_score[f'scaled_rmse_cv_{suffix}']) - data_score[f'test_score_rmse_{suffix}'] = score_rmse_mcc(pred_type,data_score[f'scaled_rmse_test_{suffix}']) + if pred_type == "reg": + if "-fold CV : R2 =" in dat_predict[i + 5]: + data_score[f"rmse_cv_{suffix}"] = float( + dat_predict[i + 5].split()[-1] + ) + data_score[f"cv_type_{suffix}"] = " ".join( + [ele for ele in dat_predict[i + 5].split()[1:4]] + ) + data_score[f"r2_cv_{suffix}"] = float( + dat_predict[i + 5].split(",")[0].split()[-1] + ) + if "Test : R2 =" in dat_predict[i + 6]: + data_score[f"rmse_test_{suffix}"] = float( + dat_predict[i + 6].split()[-1] + ) + data_score[f"r2_test_{suffix}"] = float( + dat_predict[i + 6].split(",")[0].split()[-1] + ) + if "- y range of dataset" in dat_predict[i + 8]: + data_score[f"y_range_{suffix}"] = float( + dat_predict[i + 8].split()[-1] + ) + + data_score[f"scaled_rmse_cv_{suffix}"] = round( + ( + data_score[f"rmse_cv_{suffix}"] + / data_score[f"y_range_{suffix}"] + ) + * 100, + 2, + ) + data_score[f"scaled_rmse_test_{suffix}"] = round( + ( + data_score[f"rmse_test_{suffix}"] + / data_score[f"y_range_{suffix}"] + ) + * 100, + 2, + ) + + data_score[f"cv_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score[f"scaled_rmse_cv_{suffix}"] + ) + data_score[f"test_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score[f"scaled_rmse_test_{suffix}"] + ) # get penalties for R2 - data_score[f'cv_penalty_r2_{suffix}'] = calc_penalty_r2(data_score[f'r2_cv_{suffix}']) - data_score[f'test_penalty_r2_{suffix}'] = calc_penalty_r2(data_score[f'r2_test_{suffix}']) + data_score[f"cv_penalty_r2_{suffix}"] = calc_penalty_r2( + data_score[f"r2_cv_{suffix}"] + ) + data_score[f"test_penalty_r2_{suffix}"] = calc_penalty_r2( + data_score[f"r2_test_{suffix}"] + ) # combined scores RMSE/R2 (min 0) - data_score[f'cv_score_combined_{suffix}'] = data_score[f'cv_score_rmse_{suffix}'] + data_score[f'cv_penalty_r2_{suffix}'] - if data_score[f'cv_score_combined_{suffix}'] < 0: - data_score[f'cv_score_combined_{suffix}'] = 0 - data_score[f'test_score_combined_{suffix}'] = data_score[f'test_score_rmse_{suffix}'] + data_score[f'test_penalty_r2_{suffix}'] - if data_score[f'test_score_combined_{suffix}'] < 0: - data_score[f'test_score_combined_{suffix}'] = 0 + data_score[f"cv_score_combined_{suffix}"] = ( + data_score[f"cv_score_rmse_{suffix}"] + + data_score[f"cv_penalty_r2_{suffix}"] + ) + if data_score[f"cv_score_combined_{suffix}"] < 0: + data_score[f"cv_score_combined_{suffix}"] = 0 + data_score[f"test_score_combined_{suffix}"] = ( + data_score[f"test_score_rmse_{suffix}"] + + data_score[f"test_penalty_r2_{suffix}"] + ) + if data_score[f"test_score_combined_{suffix}"] < 0: + data_score[f"test_score_combined_{suffix}"] = 0 diff_score = 0 # relative difference between RMSE from test and CV - data_score[f'factor_scaled_rmse_{suffix}'] = data_score[f'scaled_rmse_test_{suffix}'] / data_score[f'scaled_rmse_cv_{suffix}'] - if data_score[f'factor_scaled_rmse_{suffix}'] <= 1.25: + data_score[f"factor_scaled_rmse_{suffix}"] = ( + data_score[f"scaled_rmse_test_{suffix}"] + / data_score[f"scaled_rmse_cv_{suffix}"] + ) + if data_score[f"factor_scaled_rmse_{suffix}"] <= 1.25: diff_score += 2 - elif data_score[f'factor_scaled_rmse_{suffix}'] <= 1.5: + elif data_score[f"factor_scaled_rmse_{suffix}"] <= 1.5: diff_score += 1 - data_score[f'diff_scaled_rmse_score_{suffix}'] = diff_score - - elif pred_type == 'clas': # Process classification: using MCC extracted from CV and Test results + data_score[f"diff_scaled_rmse_score_{suffix}"] = diff_score + + elif ( + pred_type == "clas" + ): # Process classification: using MCC extracted from CV and Test results # Extract MCC from the 10x 5-fold CV line - if '5-fold' in dat_predict[i+5]: - parts = dat_predict[i+5].split(',') + if "5-fold" in dat_predict[i + 5]: + parts = dat_predict[i + 5].split(",") mcc_cv = None for part in parts: - if 'MCC' in part: - mcc_cv = float(part.split('=')[-1]) + if "MCC" in part: + mcc_cv = float(part.split("=")[-1]) break if mcc_cv is not None: - data_score[f'r2_cv_{suffix}'] = mcc_cv # storing MCC in a key keyed as r2_cv for consistency + data_score[f"r2_cv_{suffix}"] = ( + mcc_cv # storing MCC in a key keyed as r2_cv for consistency + ) # Extract MCC from the Test line - if '- Test :' in dat_predict[i+6]: - parts = dat_predict[i+6].split(',') + if "- Test :" in dat_predict[i + 6]: + parts = dat_predict[i + 6].split(",") mcc_test = None for part in parts: - if 'MCC' in part: - mcc_test = float(part.split('=')[-1]) + if "MCC" in part: + mcc_test = float(part.split("=")[-1]) break if mcc_test is not None: - data_score[f'r2_test_{suffix}'] = mcc_test + data_score[f"r2_test_{suffix}"] = mcc_test # Compute CV and Test scores using the classification thresholds in score_rmse_mcc - data_score[f'cv_score_rmse_{suffix}'] = score_rmse_mcc(pred_type, data_score.get(f'r2_cv_{suffix}', 0)) - data_score[f'test_score_rmse_{suffix}'] = score_rmse_mcc(pred_type, data_score.get(f'r2_test_{suffix}', 0)) - + data_score[f"cv_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score.get(f"r2_cv_{suffix}", 0) + ) + data_score[f"test_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score.get(f"r2_test_{suffix}", 0) + ) + # For classification, the combined score is simply the score from MCC (no additional penalty) - data_score[f'cv_score_combined_{suffix}'] = data_score[f'cv_score_rmse_{suffix}'] - data_score[f'test_score_combined_{suffix}'] = data_score[f'test_score_rmse_{suffix}'] + data_score[f"cv_score_combined_{suffix}"] = data_score[ + f"cv_score_rmse_{suffix}" + ] + data_score[f"test_score_combined_{suffix}"] = data_score[ + f"test_score_rmse_{suffix}" + ] # SD from CV - if pred_type == 'reg': - if '- Average SD in test set' in line: + if pred_type == "reg": + if "- Average SD in test set" in line: cv_sd = float(line.split()[-1]) - cv_4sd = 4*cv_sd - y_range_covered = cv_4sd/data_score[f'y_range_{suffix}'] + cv_4sd = 4 * cv_sd + y_range_covered = cv_4sd / data_score[f"y_range_{suffix}"] cv_sd_score = 0 if y_range_covered <= 0.25: @@ -1025,42 +1171,42 @@ def get_predict_scores(dat_predict,suffix,pred_type,data_score): data_score[f"cv_4sd_{suffix}"] = cv_4sd data_score[f"cv_range_cov_{suffix}"] = y_range_covered - data_score[f'cv_sd_score_{suffix}'] = cv_sd_score + data_score[f"cv_sd_score_{suffix}"] = cv_sd_score return data_score -def score_rmse_mcc(pred_type,scaledrmse_mcc_val): - ''' +def score_rmse_mcc(pred_type, scaledrmse_mcc_val): + """ Calculate scores for R2 and MCC using predetermined thresholds - + For regression (scaled RMSE): 0-2 points For classification (MCC): 0-3 points - ''' + """ r2_mcc_score = 0 - if pred_type == 'reg': # scaled RMSE + if pred_type == "reg": # scaled RMSE if scaledrmse_mcc_val <= 10: r2_mcc_score += 2 elif scaledrmse_mcc_val <= 20: r2_mcc_score += 1 - else: # MCC + else: # MCC if scaledrmse_mcc_val > 0.75: r2_mcc_score += 3 elif scaledrmse_mcc_val > 0.5: r2_mcc_score += 2 elif scaledrmse_mcc_val > 0.3: r2_mcc_score += 1 - + return r2_mcc_score def calc_penalty_r2(r2_val): - ''' + """ Calculate scores for R2 and MCC using predetermined thresholds - ''' + """ penalty_r2 = 0 @@ -1068,7 +1214,7 @@ def calc_penalty_r2(r2_val): penalty_r2 -= 2 elif r2_val < 0.7: penalty_r2 -= 1 - + return penalty_r2 @@ -1077,35 +1223,38 @@ def repro_info(modules): Retrieves variables used in the Reproducibility section """ - version_n_date, citation, command_line = '','','' - python_version, total_time = '',0 + version_n_date, citation, command_line = "", "", "" + python_version, total_time = "", 0 dat_files = {} for module in modules: - path_file = Path(f'{os.getcwd()}/{module}/{module}_data.dat') + path_file = Path(f"{os.getcwd()}/{module}/{module}_data.dat") if os.path.exists(path_file): - datfile = open(path_file, 'r', encoding= 'utf-8', errors="replace") + datfile = open(path_file, "r", encoding="utf-8", errors="replace") txt_file = [] for line in datfile: txt_file.append(line) - if 'Time' in line and 'seconds' in line: + if "Time" in line and "seconds" in line: total_time += float(line.split()[2]) - if 'How to cite: ' in line: - citation = line.split('How to cite: ')[1] - if 'ROBERT v' == line[:8]: + if "How to cite: " in line: + citation = line.split("How to cite: ")[1] + if "ROBERT v" == line[:8]: version_n_date = line - if 'Command line used in ROBERT: ' in line: - if '--csv_name' not in command_line: # ensures that the value for --csv_name is stored - command_line = line.split('Command line used in ROBERT: ')[1] - total_time = round(total_time,2) + if "Command line used in ROBERT: " in line: + if ( + "--csv_name" not in command_line + ): # ensures that the value for --csv_name is stored + command_line = line.split("Command line used in ROBERT: ")[1] + total_time = round(total_time, 2) dat_files[module] = txt_file datfile.close() - + try: import platform + python_version = platform.python_version() - except: - python_version = '(version could not be determined)' - + except Exception: + python_version = "(version could not be determined)" + return version_n_date, citation, command_line, python_version, total_time, dat_files @@ -1120,7 +1269,9 @@ def make_report(report_html, HTML): try: os.remove(outfile) except PermissionError: - print('\nx ROBERT_report.pdf is open! Please, close the PDF file and run ROBERT again with --report (i.e., "python -m robert --report").') + print( + '\nx ROBERT_report.pdf is open! Please, close the PDF file and run ROBERT again with --report (i.e., "python -m robert --report").' + ) sys.exit() pdf = make_pdf(report_html, HTML, css_files) _ = Path(outfile).write_bytes(pdf) @@ -1136,7 +1287,7 @@ def make_pdf(html, HTML, css_files): return htmldoc -def css_content(csv_name,robert_version): +def css_content(csv_name, robert_version): """ Obtain ROBERT version and CSV name to use it on top of the PDF report """ @@ -1266,50 +1417,63 @@ def css_content(csv_name,robert_version): return css_content -def format_lines(module_data, max_width=122, cmd_line=False, one_column=False, spacing=''): +def format_lines( + module_data, max_width=122, cmd_line=False, one_column=False, spacing="" +): """ Reads a file and returns a formatted string between two markers """ formatted_lines = [] - lines = module_data.split('\n') - for i,line in enumerate(lines): - if 'R2' in line: - line = line.replace('R2','R2') + lines = module_data.split("\n") + for i, line in enumerate(lines): + if "R2" in line: + line = line.replace("R2", "R2") if cmd_line: - formatted_line = textwrap.fill(line, width=max_width-5, subsequent_indent='') + formatted_line = textwrap.fill( + line, width=max_width - 5, subsequent_indent="" + ) else: - formatted_line = textwrap.fill(line, width=max_width, subsequent_indent='') + formatted_line = textwrap.fill(line, width=max_width, subsequent_indent="") if i > 0: - formatted_lines.append(f'\n{formatted_line}')
+ formatted_lines.append(
+ f'\n{formatted_line}'
+ )
else:
- formatted_lines.append(f'{formatted_line}\n')
+ formatted_lines.append(
+ f'{formatted_line}\n'
+ )
# for two columns
if not one_column:
- return ''.join(formatted_lines)
-
+ return "".join(formatted_lines)
+
# for one column
- one_col_lines = ''
- for line in ''.join(formatted_lines).split('\n'):
- if line.startswith('') and line != '':
- one_col_lines += line.replace('',f'{spacing*3}')
- elif not line.startswith('<'):
- one_col_lines += f'\n{spacing*3}{line}'
+ one_col_lines = ""
+ for line in "".join(formatted_lines).split("\n"):
+ if (
+ line.startswith('')
+ and line != ''
+ ):
+ one_col_lines += line.replace(
+ '',
+ f'{spacing * 3}',
+ )
+ elif not line.startswith("<"):
+ one_col_lines += f"\n{spacing * 3}{line}"
else:
- one_col_lines += f'\n{line}'
+ one_col_lines += f"\n{line}"
return one_col_lines
-
-def get_spacing_col(suffix,spacing_PFI):
- '''
+def get_spacing_col(suffix, spacing_PFI):
+ """
Assign spacing of column
- '''
-
- if suffix == 'No PFI':
- spacing = ''
- elif suffix == 'PFI':
+ """
+
+ if suffix == "No PFI":
+ spacing = ""
+ elif suffix == "PFI":
spacing = spacing_PFI
-
- return spacing
\ No newline at end of file
+
+ return spacing
diff --git a/robert/robert.py b/robert/robert.py
index 7e370d5..5602b4d 100644
--- a/robert/robert.py
+++ b/robert/robert.py
@@ -35,16 +35,16 @@
from robert.report import report
from robert.aqme import aqme
from robert.evaluate import evaluate
-from robert.utils import (command_line_args,missing_inputs)
+from robert.utils import command_line_args, missing_inputs
-def main(exe_type='command',sys_args=None):
+def main(exe_type="command", sys_args=None):
"""
Main function of ROBERT, acts as the starting point when the program is run through a terminal
"""
# load user-defined arguments from command line
- args = command_line_args(exe_type,sys_args)
+ args = command_line_args(exe_type, sys_args)
args.command_line = True
if not args.evaluate:
@@ -53,7 +53,7 @@ def main(exe_type='command',sys_args=None):
if not args.curate and not args.generate and not args.predict:
if not args.cheers and not args.verify and not args.report:
full_workflow = True
-
+
if args.aqme:
full_workflow = True
@@ -63,19 +63,23 @@ def main(exe_type='command',sys_args=None):
# save the csv_name, y and names values from full workflows
if full_workflow:
# remove the EVALUATE folder to avoid issues when generating the report PDF
- eval_folder = Path(f'{os.getcwd()}/EVALUATE')
+ eval_folder = Path(f"{os.getcwd()}/EVALUATE")
if os.path.exists(eval_folder):
shutil.rmtree(eval_folder)
- args = missing_inputs(args,'full_workflow',print_err=True)
+ args = missing_inputs(args, "full_workflow", print_err=True)
# AQME
if args.aqme:
aqme(**vars(args))
# set the path to the database created by AQME to continue in the full_workflow
- args.csv_name = Path(os.path.dirname(args.csv_name)).joinpath(f'AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_name)}')
- if args.csv_test != '':
- args.csv_test = Path(os.path.dirname(args.csv_test)).joinpath(f'AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_test)}')
+ args.csv_name = Path(os.path.dirname(args.csv_name)).joinpath(
+ f"AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_name)}"
+ )
+ if args.csv_test != "":
+ args.csv_test = Path(os.path.dirname(args.csv_test)).joinpath(
+ f"AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_test)}"
+ )
# CURATE
if args.curate or full_workflow:
@@ -83,9 +87,9 @@ def main(exe_type='command',sys_args=None):
if full_workflow:
# this ensures GENERATE communicates with CURATE (see the load_variables() function in utils.py)
- args.y = ''
- args.discard = [] # avoids an error since the variable(s) are removed in CURATE
- args.csv_name = '' # force GENERATE to use the curated database
+ args.y = ""
+ args.discard = [] # avoids an error since the variable(s) are removed in CURATE
+ args.csv_name = "" # force GENERATE to use the curated database
# GENERATE
if args.generate or full_workflow:
@@ -102,10 +106,12 @@ def main(exe_type='command',sys_args=None):
# REPORT
if args.report or full_workflow:
report(**vars(args))
-
+
# CHEERS
if args.cheers:
- print('o This module was designed to thank ROBERT Paton, who was a mentor to me throughout my years at Colorado State University, and who introduced me to the field of cheminformatics. Cheers mate!\n')
+ print(
+ "o This module was designed to thank ROBERT Paton, who was a mentor to me throughout my years at Colorado State University, and who introduced me to the field of cheminformatics. Cheers mate!\n"
+ )
# EVALUATE, only evaluates models
else:
@@ -116,7 +122,7 @@ def main(exe_type='command',sys_args=None):
curate(**vars(args))
# Ignore the Set column created in EVALUATE inside the CSV of the GENERATE folder
- args.ignore.append('Set')
+ args.ignore.append("Set")
# VERIFY
verify(**vars(args))
@@ -134,13 +140,26 @@ def set_aqme_args(args):
"""
if os.path.exists(args.csv_name):
- aqme_df = pd.read_csv(args.csv_name, encoding='utf-8')
+ aqme_df = pd.read_csv(args.csv_name, encoding="utf-8")
else:
- print(f'\nx The path of your CSV file doesn\'t exist! You specified: {args.csv_name}')
+ print(
+ f"\nx The path of your CSV file doesn't exist! You specified: {args.csv_name}"
+ )
sys.exit()
# list of potential arguments from CSV inputs in AQME
- aqme_args = ['smiles','charge','mult','complex_type','geom','constraints_atoms','constraints_dist','constraints_angle','constraints_dihedral','sample']
+ aqme_args = [
+ "smiles",
+ "charge",
+ "mult",
+ "complex_type",
+ "geom",
+ "constraints_atoms",
+ "constraints_dist",
+ "constraints_angle",
+ "constraints_dihedral",
+ "sample",
+ ]
# ignore the names and SMILES of the molecules
remove = []
@@ -149,10 +168,10 @@ def set_aqme_args(args):
remove.append(column)
for column in remove:
args.ignore.remove(column)
- if 'code_name' in args.ignore:
- args.ignore.remove('code_name')
+ if "code_name" in args.ignore:
+ args.ignore.remove("code_name")
for column in aqme_df.columns:
- if column.lower() == 'code_name' and args.names == '':
+ if column.lower() == "code_name" and args.names == "":
args.names = column
return args
diff --git a/robert/uq_auto.py b/robert/uq_auto.py
index 0a0ea99..c9e5588 100644
--- a/robert/uq_auto.py
+++ b/robert/uq_auto.py
@@ -12,7 +12,7 @@
import json
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Mapping, Optional, Sequence
import numpy as np
from scipy.stats import norm
@@ -34,7 +34,9 @@ def _as_float_array(x: Sequence[float]) -> np.ndarray:
return np.asarray(x, dtype=float).ravel()
-def _normalize_metric_weights(weights: Optional[Mapping[str, float]]) -> Dict[str, float]:
+def _normalize_metric_weights(
+ weights: Optional[Mapping[str, float]],
+) -> Dict[str, float]:
base = dict(DEFAULT_METRIC_WEIGHTS)
if weights is not None:
for key in base:
@@ -131,7 +133,9 @@ def apply_uncertainty_scaler(
knots_r = np.asarray(params.get("knots_r", []), dtype=float)
if knots_u.size == 0:
return u
- return np.maximum(np.interp(u, knots_u, knots_r, left=knots_r[0], right=knots_r[-1]), 0.0)
+ return np.maximum(
+ np.interp(u, knots_u, knots_r, left=knots_r[0], right=knots_r[-1]), 0.0
+ )
raise ValueError(f"Unknown scaler method in params: {method!r}")
@@ -149,8 +153,8 @@ def _coverage_error(abs_resid: np.ndarray, sigma: np.ndarray, coverage: float) -
def _gaussian_nll(abs_resid: np.ndarray, sigma: np.ndarray) -> float:
sigma = np.maximum(sigma, 1e-12)
# NLL for Laplace-like on abs residual under Gaussian proxy
- var = sigma ** 2
- return float(np.mean(0.5 * np.log(2.0 * np.pi * var) + 0.5 * (abs_resid ** 2) / var))
+ var = sigma**2
+ return float(np.mean(0.5 * np.log(2.0 * np.pi * var) + 0.5 * (abs_resid**2) / var))
def _sharpness(sigma: np.ndarray) -> float:
@@ -177,7 +181,9 @@ def score_uncertainty_candidate(
def _oof_mean_train(Xy_data: Mapping[str, Any]) -> np.ndarray:
preds_all = Xy_data.get("y_pred_train_all", [])
- return np.array([float(np.mean(p)) if len(p) else np.nan for p in preds_all], dtype=float)
+ return np.array(
+ [float(np.mean(p)) if len(p) else np.nan for p in preds_all], dtype=float
+ )
def _train_abs_residuals(Xy_data: Mapping[str, Any]) -> np.ndarray:
@@ -275,7 +281,9 @@ def evaluate_uq_candidates(
Returns dict with keys: selected, scaler_params, candidate_scores, coverage, n_eval.
"""
- candidates_cfg = getattr(args, "uq_auto_candidates", None) or list(DEFAULT_CANDIDATES)
+ candidates_cfg = getattr(args, "uq_auto_candidates", None) or list(
+ DEFAULT_CANDIDATES
+ )
if isinstance(candidates_cfg, str):
candidates_cfg = [c.strip() for c in candidates_cfg.split(",") if c.strip()]
@@ -320,7 +328,9 @@ def evaluate_uq_candidates(
params = fit_uncertainty_scaler(
scaler_method, u_raw[fit_ix], abs_resid[fit_ix]
)
- u_scaled_eval = apply_uncertainty_scaler(scaler_method, u_raw[eval_ix], params)
+ u_scaled_eval = apply_uncertainty_scaler(
+ scaler_method, u_raw[eval_ix], params
+ )
score = score_uncertainty_candidate(
u_scaled_eval, abs_resid[eval_ix], coverage, metric_weights
)
@@ -409,8 +419,7 @@ def apply_auto_uq(
meta_path = Path("PREDICT") / "uq_auto_metadata.json"
meta_path.parent.mkdir(parents=True, exist_ok=True)
serializable = {
- k: (v if not isinstance(v, dict) else dict(v))
- for k, v in selection.items()
+ k: (v if not isinstance(v, dict) else dict(v)) for k, v in selection.items()
}
with meta_path.open("w", encoding="utf-8") as fh:
json.dump(serializable, fh, indent=2)
diff --git a/robert/utils.py b/robert/utils.py
index 88c4c78..716a388 100644
--- a/robert/utils.py
+++ b/robert/utils.py
@@ -21,21 +21,30 @@
# This prevents numerical differences between Windows/Ubuntu in parallel operations
os.environ["LOKY_MAX_CPU_COUNT"] = "1"
from matplotlib import pyplot as plt
+import seaborn as sb
+from bayes_opt import BayesianOptimization
import matplotlib.patches as mpatches
import matplotlib.colors as mcolor
from matplotlib.legend_handler import HandlerPatch
from matplotlib.ticker import FormatStrFormatter
from scipy import stats
from importlib.resources import files
+
# sklearnex was deactivated in ROBERT v2.1 because it only accelerated RF
# try:
# from sklearnex import patch_sklearn
# patch_sklearn(verbose=True)
# except (ModuleNotFoundError,ImportError):
# pass
-from sklearn.metrics import (mean_absolute_error, mean_squared_error,
- matthews_corrcoef, accuracy_score, f1_score, make_scorer,
- ConfusionMatrixDisplay)
+from sklearn.metrics import (
+ mean_absolute_error,
+ mean_squared_error,
+ matthews_corrcoef,
+ accuracy_score,
+ f1_score,
+ make_scorer,
+ ConfusionMatrixDisplay,
+)
from sklearn.feature_selection import RFECV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import (
@@ -47,19 +56,27 @@
AdaBoostClassifier,
VotingRegressor,
VotingClassifier,
- )
+)
from xgboost import XGBClassifier, XGBRegressor
from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
from sklearn.neural_network import MLPRegressor, MLPClassifier
from sklearn.linear_model import LinearRegression
from sklearn.impute import KNNImputer
from sklearn.base import clone
-from sklearn.model_selection import train_test_split, cross_val_score, StratifiedShuffleSplit, RepeatedKFold, KFold, StratifiedKFold
+from sklearn.model_selection import (
+ train_test_split,
+ cross_val_score,
+ StratifiedShuffleSplit,
+ RepeatedKFold,
+ KFold,
+ StratifiedKFold,
+)
from sklearn.cluster import KMeans
from sklearn.inspection import permutation_importance
from sklearn.exceptions import ConvergenceWarning
from robert.argument_parser import set_options, var_dict
-import warnings # this avoids warnings from sklearn
+import warnings # this avoids warnings from sklearn
+
warnings.filterwarnings("ignore")
@@ -93,12 +110,18 @@ def should_plot_verify_metrics(args) -> bool:
def should_plot_predict_results(args) -> bool:
"""Main PREDICT figures (e.g. Results_*); tied to predict_diagnostics for backward compatibility."""
- return bool(getattr(args, "predict_diagnostics", True)) and plot_verbosity_level(args) >= 1
+ return (
+ bool(getattr(args, "predict_diagnostics", True))
+ and plot_verbosity_level(args) >= 1
+ )
def should_plot_predict_deep_diagnostics(args) -> bool:
"""SHAP, PFI, Pearson heatmap, outliers, distribution plots."""
- return bool(getattr(args, "predict_diagnostics", True)) and plot_verbosity_level(args) >= 2
+ return (
+ bool(getattr(args, "predict_diagnostics", True))
+ and plot_verbosity_level(args) >= 2
+ )
robert_version = "2.1.0"
@@ -130,7 +153,7 @@ def load_from_yaml(self):
try:
loaded = yaml.load(file, Loader=yaml.SafeLoader)
param_list = loaded if isinstance(loaded, dict) else {}
- except (yaml.scanner.ScannerError,yaml.parser.ParserError):
+ except (yaml.scanner.ScannerError, yaml.parser.ParserError):
txt_yaml = f'\nx Error while reading {self.varfile}. Edit the yaml file and try again (i.e. use ":" instead of "=" to specify variables)'
error_yaml = True
else:
@@ -186,12 +209,12 @@ def finalize(self):
self.log.close()
-def command_line_args(exe_type,sys_args):
+def command_line_args(exe_type, sys_args):
"""
Load default and user-defined arguments specified through command lines. Arrguments are loaded as a dictionary
"""
- if exe_type == 'exe':
+ if exe_type == "exe":
# Simulate sys.argv for use in an executable environment
sys.argv = ["launcher.exe"]
for k, v in sys_args.items():
@@ -220,14 +243,14 @@ def command_line_args(exe_type,sys_args):
"report_modules",
]
int_args = [
- 'pfi_epochs',
- 'epochs',
- 'nprocs',
- 'pfi_max',
- 'kfold',
- 'repeat_kfolds',
- 'shap_show',
- 'pfi_show',
+ "pfi_epochs",
+ "epochs",
+ "nprocs",
+ "pfi_max",
+ "kfold",
+ "repeat_kfolds",
+ "shap_show",
+ "pfi_show",
"seed",
"init_points",
"n_iter",
@@ -237,16 +260,16 @@ def command_line_args(exe_type,sys_args):
"uq_auto_random_state",
]
float_args = [
- 'pfi_threshold',
- 't_value',
- 'thres_x',
- 'thres_y',
- 'test_set',
- 'desc_thres',
- 'alpha',
- 'expect_improv',
- 'conformal_calib_frac',
- 'conformal_coverage',
+ "pfi_threshold",
+ "t_value",
+ "thres_x",
+ "thres_y",
+ "test_set",
+ "desc_thres",
+ "alpha",
+ "expect_improv",
+ "conformal_calib_frac",
+ "conformal_coverage",
]
for arg in var_dict:
@@ -346,12 +369,12 @@ def command_line_args(exe_type,sys_args):
sys.exit()
else:
# this "if" allows to use * to select multiple files in multiple OS
- if arg_name.lower() == 'files' and value.find('*') > -1:
+ if arg_name.lower() == "files" and value.find("*") > -1:
kwargs[arg_name] = glob.glob(value)
else:
# converts the string parameters from command line to the right format
if arg_name in bool_args:
- value = True
+ value = True
elif arg_name.lower() in list_args:
value = format_lists(value)
elif arg_name.lower() in int_args:
@@ -374,18 +397,20 @@ def command_line_args(exe_type,sys_args):
def format_lists(value):
- '''
+ """
Transforms strings into a list
- '''
+ """
if not isinstance(value, list):
try:
value = ast.literal_eval(value)
except (SyntaxError, ValueError):
# this line fixes issues when using "[X]" or ["X"] instead of "['X']" when using lists
- value = value.replace('[',']').replace(',',']').replace("'",']').split(']')
- while('' in value):
- value.remove('')
+ value = (
+ value.replace("[", "]").replace(",", "]").replace("'", "]").split("]")
+ )
+ while "" in value:
+ value.remove("")
# remove extra spaces that sometimes are included by mistake
value = [ele.strip() if isinstance(ele, str) else ele for ele in value]
@@ -407,34 +432,44 @@ def load_variables(kwargs, robert_module):
self, txt_yaml = load_from_yaml(self)
# check if user used .csv in csv_name
- if not os.path.exists(f"{self.csv_name}") and os.path.exists(f'{self.csv_name}.csv'):
- self.csv_name = f'{self.csv_name}.csv'
+ if not os.path.exists(f"{self.csv_name}") and os.path.exists(
+ f"{self.csv_name}.csv"
+ ):
+ self.csv_name = f"{self.csv_name}.csv"
# check if user used .csv in csv_test
- if self.csv_test and not os.path.exists(f"{self.csv_test}") and os.path.exists(f'{self.csv_test}.csv'):
- self.csv_test = f'{self.csv_test}.csv'
+ if (
+ self.csv_test
+ and not os.path.exists(f"{self.csv_test}")
+ and os.path.exists(f"{self.csv_test}.csv")
+ ):
+ self.csv_test = f"{self.csv_test}.csv"
# check for spaces in csv file names
if " " in str(self.csv_name):
- print("\nx ERROR: The input CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'my_data.csv' instead of 'my data.csv'.")
+ print(
+ "\nx ERROR: The input CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'my_data.csv' instead of 'my data.csv'."
+ )
sys.exit()
if self.csv_test and " " in str(self.csv_test):
- print("\nx ERROR: The test CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'test_data.csv' instead of 'test data.csv'.")
+ print(
+ "\nx ERROR: The test CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'test_data.csv' instead of 'test data.csv'."
+ )
sys.exit()
if robert_module != "command":
self.initial_dir = Path(os.getcwd())
# adds --names to --ignore
- if self.names not in self.ignore and self.names != '':
+ if self.names not in self.ignore and self.names != "":
self.ignore.append(self.names)
# creates destination folder
- if robert_module.upper() != 'REPORT':
- self = destination_folder(self,robert_module)
+ if robert_module.upper() != "REPORT":
+ self = destination_folder(self, robert_module)
# start a log file
- logger_1 = 'ROBERT'
+ logger_1 = "ROBERT"
logger_1, logger_2 = robert_module.upper(), "data"
if txt_yaml not in [
@@ -448,134 +483,168 @@ def load_variables(kwargs, robert_module):
sys.exit()
self.log = Logger(self.destination / logger_1, logger_2)
- self.log.write(f"ROBERT v {robert_version} {time_run} \nHow to cite: {robert_ref}\n")
+ self.log.write(
+ f"ROBERT v {robert_version} {time_run} \nHow to cite: {robert_ref}\n"
+ )
if self.command_line:
- cmd_print = ''
+ cmd_print = ""
cmd_args = sys.argv[1:]
- if self.extra_cmd != '':
+ if self.extra_cmd != "":
for arg in self.extra_cmd.split():
cmd_args.append(arg)
- for i,elem in enumerate(cmd_args):
- if elem[0] in ['"',"'"]:
+ for i, elem in enumerate(cmd_args):
+ if elem[0] in ['"', "'"]:
elem = elem[1:]
- if elem[-1] in ['"',"'"]:
+ if elem[-1] in ['"', "'"]:
elem = elem[:-1]
- if elem != '-h' and elem.split('--')[-1] not in var_dict:
+ if elem != "-h" and elem.split("--")[-1] not in var_dict:
# parse single elements of the list as strings (otherwise the commands cannot be reproduced)
- if '--qdescp_atoms' in elem:
+ if "--qdescp_atoms" in elem:
new_arg = []
- list_qdescp = elem.replace(', ',',').replace(' ,',',').split()
- for j,qdescp_elem in enumerate(list_qdescp):
- if list_qdescp[j-1] == '--qdescp_atoms':
+ list_qdescp = (
+ elem.replace(", ", ",").replace(" ,", ",").split()
+ )
+ for j, qdescp_elem in enumerate(list_qdescp):
+ if list_qdescp[j - 1] == "--qdescp_atoms":
qdescp_elem = qdescp_elem[1:-1]
new_elem = []
- for smarts_strings in qdescp_elem.split(','):
- new_elem.append(f'{smarts_strings}'.replace("'",''))
- new_arg.append(f'{new_elem}'.replace(" ",""))
+ for smarts_strings in qdescp_elem.split(","):
+ new_elem.append(
+ f"{smarts_strings}".replace("'", "")
+ )
+ new_arg.append(f"{new_elem}".replace(" ", ""))
else:
new_arg.append(qdescp_elem)
- new_arg = ' '.join(new_arg)
+ new_arg = " ".join(new_arg)
elem = new_arg
- if cmd_args[i-1].split('--')[-1] in var_dict: # check if the previous word is an arg
+ if (
+ cmd_args[i - 1].split("--")[-1] in var_dict
+ ): # check if the previous word is an arg
cmd_print += f'"{elem}'
- if i == len(cmd_args)-1 or cmd_args[i+1].split('--')[-1] in var_dict: # check if the next word is an arg, or last word in command
- cmd_print += f'"'
+ if (
+ i == len(cmd_args) - 1
+ or cmd_args[i + 1].split("--")[-1] in var_dict
+ ): # check if the next word is an arg, or last word in command
+ cmd_print += '"'
else:
- cmd_print += f'{elem}'
- if i != len(cmd_args)-1:
- cmd_print += ' '
+ cmd_print += f"{elem}"
+ if i != len(cmd_args) - 1:
+ cmd_print += " "
- self.log.write(f"Command line used in ROBERT: python -m robert {cmd_print}\n")
+ self.log.write(
+ f"Command line used in ROBERT: python -m robert {cmd_print}\n"
+ )
- elif robert_module.upper() == 'REPORT':
+ elif robert_module.upper() == "REPORT":
self.path_icons = files("robert").joinpath("report")
# sklearnex was deactivated in ROBERT v2.1 because it only accelerated RF
# using or not the intelex accelerator might affect the results
# if robert_module.upper() in ['GENERATE','VERIFY','PREDICT']:
- # try:
- # import sklearnex
- # pass
- # except (ModuleNotFoundError,ImportError):
- # self.log.write(f"\nx WARNING! The scikit-learn-intelex accelerator is not installed, the results might vary if it is installed and the execution times might become much longer (if available, use 'pip install scikit-learn-intelex')")
+ # try:
+ # import sklearnex
+ # pass
+ # except (ModuleNotFoundError,ImportError):
+ # self.log.write(f"\nx WARNING! The scikit-learn-intelex accelerator is not installed, the results might vary if it is installed and the execution times might become much longer (if available, use 'pip install scikit-learn-intelex')")
- if robert_module.upper() in ['GENERATE', 'VERIFY']:
+ if robert_module.upper() in ["GENERATE", "VERIFY"]:
# adjust the default value of error_type for classification
- if self.type.lower() == 'clas':
- if self.error_type not in ['acc', 'mcc', 'f1']:
- self.error_type = 'mcc'
+ if self.type.lower() == "clas":
+ if self.error_type not in ["acc", "mcc", "f1"]:
+ self.error_type = "mcc"
- if robert_module.upper() in ['PREDICT','VERIFY','REPORT']:
- if self.params_dir == '':
- self.params_dir = 'GENERATE/Best_model'
+ if robert_module.upper() in ["PREDICT", "VERIFY", "REPORT"]:
+ if self.params_dir == "":
+ self.params_dir = "GENERATE/Best_model"
- if robert_module.upper() in ['CURATE','GENERATE']:
- if self.type.lower() == 'clas':
+ if robert_module.upper() in ["CURATE", "GENERATE"]:
+ if self.type.lower() == "clas":
if any(m.upper() == "MVL" for m in self.model):
- self.model = [x if x.upper() != 'MVL' else 'AdaB' for x in self.model]
-
- models_gen = [] # use capital letters in all the models
+ self.model = [
+ x if x.upper() != "MVL" else "AdaB" for x in self.model
+ ]
+
+ models_gen = [] # use capital letters in all the models
for model_type in self.model:
models_gen.append(model_type.upper())
self.model = models_gen
- if robert_module.upper() == 'CURATE':
- self.log.write(f"\no Starting data curation with the CURATE module")
+ if robert_module.upper() == "CURATE":
+ self.log.write("\no Starting data curation with the CURATE module")
+
+ elif robert_module.upper() == "GENERATE":
+ self.log.write(
+ "\no Starting generation of ML models with the GENERATE module"
+ )
- elif robert_module.upper() == 'GENERATE':
- self.log.write(f"\no Starting generation of ML models with the GENERATE module")
-
# Check if the folders exist and if they do, delete and replace them
- folder_names = [self.initial_dir.joinpath('GENERATE/Best_model/No_PFI'), self.initial_dir.joinpath('GENERATE/Raw_data/No_PFI')]
+ folder_names = [
+ self.initial_dir.joinpath("GENERATE/Best_model/No_PFI"),
+ self.initial_dir.joinpath("GENERATE/Raw_data/No_PFI"),
+ ]
if self.pfi_filter:
- folder_names.append(self.initial_dir.joinpath('GENERATE/Best_model/PFI'))
- folder_names.append(self.initial_dir.joinpath('GENERATE/Raw_data/PFI'))
+ folder_names.append(
+ self.initial_dir.joinpath("GENERATE/Best_model/PFI")
+ )
+ folder_names.append(self.initial_dir.joinpath("GENERATE/Raw_data/PFI"))
_ = create_folders(folder_names)
# if there are missing options, look for them from a previous CURATE job (if any)
options_dict = {
- 'y': self.y,
- 'names': self.names,
- 'ignore': self.ignore,
- 'csv_name': self.csv_name
+ "y": self.y,
+ "names": self.names,
+ "ignore": self.ignore,
+ "csv_name": self.csv_name,
}
- curate_folder = Path(self.initial_dir).joinpath('CURATE')
- curate_csv = f'{curate_folder}/CURATE_options.csv'
+ curate_folder = Path(self.initial_dir).joinpath("CURATE")
+ curate_csv = f"{curate_folder}/CURATE_options.csv"
if os.path.exists(curate_csv):
- curate_df = pd.read_csv(curate_csv, encoding='utf-8')
+ curate_df = pd.read_csv(curate_csv, encoding="utf-8")
for option in options_dict:
- if options_dict[option] == '':
- if option == 'y':
- self.y = curate_df['y'][0]
- elif option == 'names':
- self.names = curate_df['names'][0]
- elif option == 'ignore':
- self.ignore = curate_df['ignore'][0]
- self.ignore = format_lists(self.ignore)
- elif option == 'csv_name':
- self.csv_name = curate_df['csv_name'][0]
-
- # Load class labels if they exist (for classification with string labels)
- if 'class_0_label' in curate_df.columns and 'class_1_label' in curate_df.columns:
- self.class_0_label = curate_df['class_0_label'][0]
- self.class_1_label = curate_df['class_1_label'][0]
+ if options_dict[option] == "":
+ if option == "y":
+ self.y = curate_df["y"][0]
+ elif option == "names":
+ self.names = curate_df["names"][0]
+ elif option == "ignore":
+ self.ignore = curate_df["ignore"][0]
+ self.ignore = format_lists(self.ignore)
+ elif option == "csv_name":
+ self.csv_name = curate_df["csv_name"][0]
- elif robert_module.upper() in ['PREDICT','VERIFY']:
- if robert_module.upper() == 'PREDICT':
- self.log.write(f"\no Representation of predictions and analysis of ML models with the PREDICT module")
- elif robert_module.upper() == 'VERIFY':
- self.log.write(f"\no Starting tests to verify the prediction ability of the ML models with the VERIFY module")
+ # Load class labels if they exist (for classification with string labels)
+ if (
+ "class_0_label" in curate_df.columns
+ and "class_1_label" in curate_df.columns
+ ):
+ self.class_0_label = curate_df["class_0_label"][0]
+ self.class_1_label = curate_df["class_1_label"][0]
+
+ elif robert_module.upper() in ["PREDICT", "VERIFY"]:
+ if robert_module.upper() == "PREDICT":
+ self.log.write(
+ "\no Representation of predictions and analysis of ML models with the PREDICT module"
+ )
+ elif robert_module.upper() == "VERIFY":
+ self.log.write(
+ "\no Starting tests to verify the prediction ability of the ML models with the VERIFY module"
+ )
- if '' in [self.names,self.y,self.csv_name]:
+ if "" in [self.names, self.y, self.csv_name]:
# tries to get names from GENERATE
- if 'GENERATE/Best_model' in self.params_dir:
- params_dirs = [f'{self.params_dir}/No_PFI',f'{self.params_dir}/PFI']
+ if "GENERATE/Best_model" in self.params_dir:
+ params_dirs = [
+ f"{self.params_dir}/No_PFI",
+ f"{self.params_dir}/PFI",
+ ]
else:
params_dirs = [self.params_dir]
self.args = self
- _,_,_,model_data,csv_name = load_dfs(self,params_dirs[0],'predict',sanity_check=True)
+ _, _, _, model_data, csv_name = load_dfs(
+ self, params_dirs[0], "predict", sanity_check=True
+ )
self.names = model_data["names"]
self.y = model_data["y"]
@@ -587,48 +656,58 @@ def load_variables(kwargs, robert_module):
if "type" in model_data:
self.type = model_data["type"]
- elif robert_module.upper() in ['AQME', 'AQME_TEST']:
+ elif robert_module.upper() in ["AQME", "AQME_TEST"]:
# Check if the csv has 2 columns named smiles or smiles_Suffix. The file is read as text because pandas assigns automatically
# .1 to duplicate columns. (i.e. SMILES and SMILES.1 if there are two columns named SMILES)
- unique_columns=[]
- with open(self.csv_name, 'r') as datfile:
+ unique_columns = []
+ with open(self.csv_name, "r") as datfile:
lines = datfile.readlines()
- for column in lines[0].split(','):
+ for column in lines[0].split(","):
if column in unique_columns:
- print(f"\nWARNING! The CSV file contains duplicate columns ({column}). Please, rename or remove these columns. If you want to use more than one SMILES column, use _Suffix (i.e. SMILES_1, SMILES_2, ...)")
+ print(
+ f"\nWARNING! The CSV file contains duplicate columns ({column}). Please, rename or remove these columns. If you want to use more than one SMILES column, use _Suffix (i.e. SMILES_1, SMILES_2, ...)"
+ )
sys.exit()
else:
unique_columns.append(column)
-
+
# Check if there is a column with the name "smiles" or "smiles_" followed by any characters
if not any(col.lower().startswith("smiles") for col in unique_columns):
- print("\nWARNING! The CSV file does not contain a column with the name 'smiles' or a column starting with 'smiles_'. Please make sure the column exists.")
+ print(
+ "\nWARNING! The CSV file does not contain a column with the name 'smiles' or a column starting with 'smiles_'. Please make sure the column exists."
+ )
sys.exit()
# Check if there are duplicate names in code_names in the csv file.
- df = pd.read_csv(self.csv_name, encoding='utf-8')
- unique_entries=[]
- for entry in df['code_name']:
+ df = pd.read_csv(self.csv_name, encoding="utf-8")
+ unique_entries = []
+ for entry in df["code_name"]:
if entry in unique_entries:
- print(f"\nWARNING! The code_name column in the CSV file contains duplicate entries ({entry}). Please, rename or remove these entries.")
+ print(
+ f"\nWARNING! The code_name column in the CSV file contains duplicate entries ({entry}). Please, rename or remove these entries."
+ )
sys.exit()
else:
unique_entries.append(entry)
- self.log.write(f"\no Starting the generation of AQME descriptors with the AQME module")
+ self.log.write(
+ "\no Starting the generation of AQME descriptors with the AQME module"
+ )
# initial sanity checks
- if robert_module.upper() != 'REPORT':
- _ = sanity_checks(self, 'initial', robert_module, None)
+ if robert_module.upper() != "REPORT":
+ _ = sanity_checks(self, "initial", robert_module, None)
return self
-def destination_folder(self,dest_module):
+def destination_folder(self, dest_module):
if self.destination is None:
self.destination = Path(self.initial_dir).joinpath(dest_module.upper())
else:
- self.log.write(f"\nx The destination option has not been implemented yet! Please, remove it from your input and stay tuned.")
+ self.log.write(
+ "\nx The destination option has not been implemented yet! Please, remove it from your input and stay tuned."
+ )
sys.exit()
# this part does not work for know
# if Path(f"{self.destination}").exists():
@@ -643,37 +722,45 @@ def destination_folder(self,dest_module):
return self
-def missing_inputs(self,module,print_err=False):
+def missing_inputs(self, module, print_err=False):
"""
Gives the option to input missing variables in the terminal
"""
- if module.lower() not in ['predict','verify','report','aqme_test']:
- if self.csv_name == '':
- self = check_csv_option(self,'csv_name',print_err)
+ if module.lower() not in ["predict", "verify", "report", "aqme_test"]:
+ if self.csv_name == "":
+ self = check_csv_option(self, "csv_name", print_err)
- if module.lower() not in ['predict','verify','report','aqme_test']:
- if self.y == '':
+ if module.lower() not in ["predict", "verify", "report", "aqme_test"]:
+ if self.y == "":
if print_err:
- print(f'\nx Specify a y value (column name) with the y option! (i.e. y="solubility")')
+ print(
+ '\nx Specify a y value (column name) with the y option! (i.e. y="solubility")'
+ )
else:
- self.log.write(f'\nx Specify a y value (column name) with the y option! (i.e. y="solubility")')
- self.y = input('Enter the column with y values: ')
- self.extra_cmd += f' --y {self.y}'
+ self.log.write(
+ '\nx Specify a y value (column name) with the y option! (i.e. y="solubility")'
+ )
+ self.y = input("Enter the column with y values: ")
+ self.extra_cmd += f" --y {self.y}"
if not print_err:
self.log.write(f" - y option set to {self.y} by the user")
- if module.lower() in ['full_workflow','predict','curate','generate','evaluate']:
- if self.names == '':
+ if module.lower() in ["full_workflow", "predict", "curate", "generate", "evaluate"]:
+ if self.names == "":
if print_err:
- print(f'\nx Specify the column with the entry names! (i.e. names="code_name")')
+ print(
+ '\nx Specify the column with the entry names! (i.e. names="code_name")'
+ )
else:
- self.log.write(f'\nx Specify the column with the entry names! (i.e. names="code_name")')
- self.names = input('Enter the column with the entry names: ')
- self.extra_cmd += f' --names {self.names}'
+ self.log.write(
+ '\nx Specify the column with the entry names! (i.e. names="code_name")'
+ )
+ self.names = input("Enter the column with the entry names: ")
+ self.extra_cmd += f" --names {self.names}"
if not print_err:
self.log.write(f" - names option set to {self.names} by the user")
- if self.names != '' and self.names not in self.ignore:
+ if self.names != "" and self.names not in self.ignore:
self.ignore.append(self.names)
return self
@@ -683,7 +770,7 @@ def correlation_filter(self, csv_df):
"""
Discards a) correlated variables and b) variables that do not correlate with the y values, based
on R**2 values c) reduces the number of descriptors to one third of the datapoints using RFECV.
-
+
REPRODUCIBILITY GUARANTEES:
- Columns are sorted alphabetically before any operation
- Rows are sorted by y value to ensure consistent ordering
@@ -691,81 +778,102 @@ def correlation_filter(self, csv_df):
- RFECV descriptor selection uses sorted feature importances with alphabetical tie-breaking
"""
- txt_corr = ''
-
+ txt_corr = ""
+
# Sort columns alphabetically and rows by y value for reproducibility
- descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore and col != self.args.y]
+ descriptor_cols = [
+ col
+ for col in csv_df.columns
+ if col not in self.args.ignore and col != self.args.y
+ ]
descriptor_cols_sorted = sorted(descriptor_cols)
- other_cols = [col for col in csv_df.columns if col in self.args.ignore or col == self.args.y]
+ other_cols = [
+ col for col in csv_df.columns if col in self.args.ignore or col == self.args.y
+ ]
csv_df = csv_df[descriptor_cols_sorted + other_cols].copy()
- csv_df = csv_df.reset_index(drop=True).sort_values(by=self.args.y, kind='stable').reset_index(drop=True)
+ csv_df = (
+ csv_df.reset_index(drop=True)
+ .sort_values(by=self.args.y, kind="stable")
+ .reset_index(drop=True)
+ )
# loosen correlation filters if there are too few descriptors
- n_descps = len(csv_df.columns)-len(self.args.ignore)-1 # all columns - ignored - y
- txt_corr += f'\no Correlation filter activated with these thresholds: thres_x = {self.args.thres_x}'
+ n_descps = (
+ len(csv_df.columns) - len(self.args.ignore) - 1
+ ) # all columns - ignored - y
+ txt_corr += f"\no Correlation filter activated with these thresholds: thres_x = {self.args.thres_x}"
if self.args.corr_filter_y:
- txt_corr += f', thres_y = {self.args.thres_y}'
+ txt_corr += f", thres_y = {self.args.thres_y}"
descriptors_drop = []
- txt_corr += f'\n Excluded descriptors:'
-
+ txt_corr += "\n Excluded descriptors:"
+
# First pass: remove constant descriptors and those with low correlation to y
- for _,column in enumerate(csv_df.columns):
- if column not in descriptors_drop and column not in self.args.ignore and column != self.args.y:
+ for _, column in enumerate(csv_df.columns):
+ if (
+ column not in descriptors_drop
+ and column not in self.args.ignore
+ and column != self.args.y
+ ):
# Remove descriptors where all values are the same
if len(set(csv_df[column])) == 1:
descriptors_drop.append(column)
- txt_corr += f'\n - {column}: all the values are the same'
+ txt_corr += f"\n - {column}: all the values are the same"
# Remove descriptors with low correlation to the response values
if self.args.corr_filter_y:
# Calculate correlation with y for remaining descriptors
if column not in descriptors_drop:
- res_y = stats.linregress(csv_df[column],csv_df[self.args.y])
+ res_y = stats.linregress(csv_df[column], csv_df[self.args.y])
rsquared_y = res_y.rvalue**2
if rsquared_y < self.args.thres_y:
descriptors_drop.append(column)
- txt_corr += f'\n - {column}: R**2 = {rsquared_y:.2} with the {self.args.y} values'
+ txt_corr += f"\n - {column}: R**2 = {rsquared_y:.2} with the {self.args.y} values"
self.args.log.write(txt_corr)
# Second pass: remove highly correlated descriptors (always removing the most correlated first)
- txt_corr = ''
+ txt_corr = ""
csv_df_filtered = csv_df.drop(descriptors_drop, axis=1)
csv_df_X_filtered = csv_df_filtered.drop([self.args.y] + self.args.ignore, axis=1)
- if self.args.corr_filter_x and len(csv_df_X_filtered.columns) > 1:
+ if self.args.corr_filter_x and len(csv_df_X_filtered.columns) > 1:
# Calculate R2 correlation matrix between descriptors
corr_matrix = csv_df_X_filtered.corr().abs()
- corr_matrix_r2 = corr_matrix ** 2
- upper = corr_matrix_r2.where(np.triu(np.ones(corr_matrix_r2.shape), k=1).astype(bool))
-
+ corr_matrix_r2 = corr_matrix**2
+ upper = corr_matrix_r2.where(
+ np.triu(np.ones(corr_matrix_r2.shape), k=1).astype(bool)
+ )
+
# Calculate R2 of each descriptor with y (for deciding which one to drop)
r2_with_y = {}
for col in csv_df_X_filtered.columns:
res_y = stats.linregress(csv_df_filtered[col], csv_df_filtered[self.args.y])
r2_with_y[col] = res_y.rvalue**2
-
+
# Iteratively remove the most correlated descriptors
while True:
# Find the maximum R2 correlation
max_r2 = upper.max().max()
- if max_r2 <= self.args.thres_x or str(max_r2).lower() == 'nan':
+ if max_r2 <= self.args.thres_x or str(max_r2).lower() == "nan":
break
-
+
# Get ALL pairs with maximum correlation, round to avoid floating point issues
upper_rounded = upper.round(10)
max_r2_rounded = upper_rounded.max().max()
row_idx, col_idx = np.where(upper_rounded == max_r2_rounded)
-
+
# Sort pairs alphabetically for deterministic selection
- pairs = [(upper.index[row_idx[i]], upper.columns[col_idx[i]]) for i in range(len(row_idx))]
+ pairs = [
+ (upper.index[row_idx[i]], upper.columns[col_idx[i]])
+ for i in range(len(row_idx))
+ ]
pairs.sort()
col_name_1, col_name_2 = pairs[0]
-
+
# Drop the descriptor with lower R2 to y, round for comparison
r2_1 = round(r2_with_y[col_name_1], 10)
r2_2 = round(r2_with_y[col_name_2], 10)
-
+
if r2_1 == r2_2:
# Tied on R2 with y, drop alphabetically later one
drop_col = col_name_1 if col_name_1 > col_name_2 else col_name_2
@@ -776,10 +884,10 @@ def correlation_filter(self, csv_df):
else:
drop_col = col_name_2
keep_col = col_name_1
-
+
descriptors_drop.append(drop_col)
- txt_corr += f'\n - {drop_col} removed (R2 = {max_r2:.2f} with {keep_col}), kept more predictive descriptor'
-
+ txt_corr += f"\n - {drop_col} removed (R2 = {max_r2:.2f} with {keep_col}), kept more predictive descriptor"
+
upper = upper.drop(index=drop_col, columns=drop_col)
del r2_with_y[drop_col]
@@ -787,31 +895,31 @@ def correlation_filter(self, csv_df):
to_drop = [col for col in csv_df_X_filtered.columns if col not in upper.columns]
if to_drop:
- txt_corr += f'\n - Total: {len(to_drop)} descriptors removed due to high correlation with other descriptors'
+ txt_corr += f"\n - Total: {len(to_drop)} descriptors removed due to high correlation with other descriptors"
# drop descriptors that did not pass the filters
csv_df_filtered = csv_df.drop(descriptors_drop, axis=1)
if len(descriptors_drop) == 0:
- txt_corr += f'\n - No descriptors were removed'
+ txt_corr += "\n - No descriptors were removed"
self.args.log.write(txt_corr)
# Check if descriptors are more than one third of datapoints
- txt_corr = ''
+ txt_corr = ""
descriptors_used = {}
csv_df_per_model = {}
num_descriptors = round(len(csv_df[self.args.y]) / 3)
if n_descps > num_descriptors:
- cv_type = f'{self.args.repeat_kfolds}x {self.args.kfold}_fold_cv'
- txt_corr += f'\no There are more descriptors than one-third of the data points. A Recursive Feature Elimination with Cross-Validation (RFECV) or permutation feature importance (PFI) using {cv_type} will be performed to select the most relevant descriptors for each model'
+ cv_type = f"{self.args.repeat_kfolds}x {self.args.kfold}_fold_cv"
+ txt_corr += f"\no There are more descriptors than one-third of the data points. A Recursive Feature Elimination with Cross-Validation (RFECV) or permutation feature importance (PFI) using {cv_type} will be performed to select the most relevant descriptors for each model"
self.args.log.write(txt_corr)
- txt_corr = ''
+ txt_corr = ""
# Perform RFECV for each model specified by the user
X_df = csv_df_filtered.drop([self.args.y] + self.args.ignore, axis=1)
- X_scaled_df,_ = scale_df(X_df,None)
+ X_scaled_df, _ = scale_df(X_df, None)
y_df = csv_df_filtered[self.args.y]
for model in sorted(self.args.model):
@@ -822,98 +930,143 @@ def correlation_filter(self, csv_df):
estimator = load_model(self, model, **rfecv_params)
# Repeated kfold-CV type
- cv_model = RepeatedKFold(n_splits=self.args.kfold, n_repeats=self.args.repeat_kfolds, random_state=self.args.seed)
+ cv_model = RepeatedKFold(
+ n_splits=self.args.kfold,
+ n_repeats=self.args.repeat_kfolds,
+ random_state=self.args.seed,
+ )
# Select scoring function for RFECV analysis based on the error type
- scoring = get_scoring_key(self.args.type,self.args.error_type)
+ scoring = get_scoring_key(self.args.type, self.args.error_type)
# Use different strategies for models without feature_importances_
- if model.upper() in ['NN', 'GP', 'VR']:
+ if model.upper() in ["NN", "GP", "VR"]:
# For NN, GP and VR, use a simpler approach: select top features by correlation with y
# after initial fit, then use permutation importance to rank them
-
+
# Train the model once on all features
estimator.fit(X_scaled_df, y_df)
-
+
# Use permutation importance to rank features
- perm_result = permutation_importance(estimator, X_scaled_df, y_df,
- n_repeats=self.args.pfi_epochs,
- random_state=self.args.seed,
- scoring=scoring,
- n_jobs=1) # Force single thread for reproducibility
-
+ perm_result = permutation_importance(
+ estimator,
+ X_scaled_df,
+ y_df,
+ n_repeats=self.args.pfi_epochs,
+ random_state=self.args.seed,
+ scoring=scoring,
+ n_jobs=1,
+ ) # Force single thread for reproducibility
+
# Round to reduce floating point variance
- feature_importances = np.round(perm_result.importances_mean, decimals=10)
-
+ feature_importances = np.round(
+ perm_result.importances_mean, decimals=10
+ )
+
# Create list of (importance, name) tuples for ALL features
- importance_with_names = [(feature_importances[i], X_scaled_df.columns[i])
- for i in range(len(feature_importances))]
-
+ importance_with_names = [
+ (feature_importances[i], X_scaled_df.columns[i])
+ for i in range(len(feature_importances))
+ ]
+
# Sort by importance (descending) and break ties alphabetically for determinism
- importance_with_names.sort(key=lambda x: (-x[0], x[1])) # Sort by importance DESC, then name ASC
-
+ importance_with_names.sort(
+ key=lambda x: (-x[0], x[1])
+ ) # Sort by importance DESC, then name ASC
+
# Select top num_descriptors features (or all with positive importance if fewer)
- positive_features = [(imp, name) for imp, name in importance_with_names if imp > 0]
+ positive_features = [
+ (imp, name) for imp, name in importance_with_names if imp > 0
+ ]
if len(positive_features) > num_descriptors:
- descriptors_used[model] = [name for _, name in positive_features[:num_descriptors]]
+ descriptors_used[model] = [
+ name for _, name in positive_features[:num_descriptors]
+ ]
else:
descriptors_used[model] = [name for _, name in positive_features]
-
+
# Sort final list alphabetically for consistent ordering in output
descriptors_used[model] = sorted(descriptors_used[model])
-
- txt_corr += f'\n - {model}: {len(descriptors_used[model])} descriptors selected (using PFI)'
-
+
+ txt_corr += f"\n - {model}: {len(descriptors_used[model])} descriptors selected (using PFI)"
+
else:
# MVL, RF, GB, ADAB: use RFECV with feature importances
# Set step=1 for most stable/deterministic feature elimination
- selector = RFECV(estimator, scoring=scoring, min_features_to_select=2, cv=cv_model, step=1, n_jobs=1)
+ selector = RFECV(
+ estimator,
+ scoring=scoring,
+ min_features_to_select=2,
+ cv=cv_model,
+ step=1,
+ n_jobs=1,
+ )
selector.fit(X_scaled_df, y_df)
-
+
# Get selected features
selected_mask = selector.support_
selected_features_list = list(X_scaled_df.columns[selected_mask])
-
+
# Get feature importances for selected features only
- if model.upper() == 'MVL':
+ if model.upper() == "MVL":
# For MVL, use absolute coefficients as importance
feature_importances = np.abs(selector.estimator_.coef_)
- else:
+ else:
# RF, GB, ADAB, XGB have feature_importances_
feature_importances = selector.estimator_.feature_importances_
-
+
# Round importances to reduce floating point variance
feature_importances = np.round(feature_importances, decimals=10)
-
+
# Create (importance, name) pairs for selected features
- importance_with_names = list(zip(feature_importances, selected_features_list))
-
+ importance_with_names = list(
+ zip(feature_importances, selected_features_list)
+ )
+
# Sort by importance (descending) with alphabetical tie-breaking for determinism
- importance_with_names.sort(key=lambda x: (-x[0], x[1])) # Sort by importance DESC, then name ASC
-
+ importance_with_names.sort(
+ key=lambda x: (-x[0], x[1])
+ ) # Sort by importance DESC, then name ASC
+
# Select top num_descriptors (or all if fewer selected)
n_to_select = min(num_descriptors, len(importance_with_names))
- descriptors_used[model] = [name for _, name in importance_with_names[:n_to_select]]
-
+ descriptors_used[model] = [
+ name for _, name in importance_with_names[:n_to_select]
+ ]
+
# Sort final list alphabetically for consistent ordering in output
descriptors_used[model] = sorted(descriptors_used[model])
-
- txt_corr += f'\n - {model}: {len(descriptors_used[model])} descriptors selected (using RFECV)'
-
+
+ txt_corr += f"\n - {model}: {len(descriptors_used[model])} descriptors selected (using RFECV)"
+
# Create model-specific dataframe with sorted columns for reproducibility
keep_cols = descriptors_used[model] + [self.args.y] + self.args.ignore
- keep_cols = list(dict.fromkeys(keep_cols)) # Remove duplicates preserving order
+ keep_cols = list(
+ dict.fromkeys(keep_cols)
+ ) # Remove duplicates preserving order
keep_cols = [col for col in keep_cols if col in csv_df_filtered.columns]
-
+
# Sort descriptor columns alphabetically, keep y and ignore at the end
- descriptor_cols = [col for col in keep_cols if col not in self.args.ignore and col != self.args.y]
- other_cols = [col for col in keep_cols if col in self.args.ignore or col == self.args.y]
- sorted_cols = sorted(descriptor_cols) + sorted([col for col in other_cols if col in self.args.ignore]) + [self.args.y]
-
+ descriptor_cols = [
+ col
+ for col in keep_cols
+ if col not in self.args.ignore and col != self.args.y
+ ]
+ other_cols = [
+ col
+ for col in keep_cols
+ if col in self.args.ignore or col == self.args.y
+ ]
+ sorted_cols = (
+ sorted(descriptor_cols)
+ + sorted([col for col in other_cols if col in self.args.ignore])
+ + [self.args.y]
+ )
+
csv_df_per_model[model] = csv_df_filtered[sorted_cols].copy()
else:
- txt_corr += f'\n x The RFECV filter was not applied, there are less descriptors than one-third of the data points ({len(csv_df_filtered.columns)-len(self.args.ignore)-1} <= {num_descriptors})'
+ txt_corr += f"\n x The RFECV filter was not applied, there are less descriptors than one-third of the data points ({len(csv_df_filtered.columns) - len(self.args.ignore) - 1} <= {num_descriptors})"
# If RFECV is not applied, all models use the same filtered dataframe
for model in self.args.model:
csv_df_per_model[model] = csv_df_filtered
@@ -925,126 +1078,150 @@ def correlation_filter(self, csv_df):
def load_minimal_model(model):
- '''
+ """
Load the parameters of the minimalist models used for REFCV
- '''
+ """
minimal_params = {
- 'RF' : {
- 'n_estimators': 30,
- 'max_depth': 10,
- 'min_samples_split': 2,
- 'min_samples_leaf': 1,
- 'min_weight_fraction_leaf': 0,
- 'max_features': 1,
- 'ccp_alpha': 0.0,
- 'max_samples': None
- },
- 'GB': {
- 'n_estimators': 30,
- 'learning_rate': 0.1,
- 'max_depth': 10,
- 'min_samples_split': 2,
- 'min_samples_leaf': 1,
- 'subsample': 1.0,
- 'max_features': None,
- 'validation_fraction': 0.2,
- 'min_weight_fraction_leaf': 0.0,
- 'ccp_alpha': 0.0
+ "RF": {
+ "n_estimators": 30,
+ "max_depth": 10,
+ "min_samples_split": 2,
+ "min_samples_leaf": 1,
+ "min_weight_fraction_leaf": 0,
+ "max_features": 1,
+ "ccp_alpha": 0.0,
+ "max_samples": None,
},
- 'NN': {
- 'hidden_layer_1': 4,
- 'hidden_layer_2': 4,
- 'max_iter': 200,
- 'alpha': 0.01,
- 'tol': 0.0001
+ "GB": {
+ "n_estimators": 30,
+ "learning_rate": 0.1,
+ "max_depth": 10,
+ "min_samples_split": 2,
+ "min_samples_leaf": 1,
+ "subsample": 1.0,
+ "max_features": None,
+ "validation_fraction": 0.2,
+ "min_weight_fraction_leaf": 0.0,
+ "ccp_alpha": 0.0,
},
- 'ADAB': {
- 'learning_rate': 1.0,
- 'n_estimators': 30
+ "NN": {
+ "hidden_layer_1": 4,
+ "hidden_layer_2": 4,
+ "max_iter": 200,
+ "alpha": 0.01,
+ "tol": 0.0001,
},
- 'GP': {
- 'n_restarts_optimizer': 30,
+ "ADAB": {"learning_rate": 1.0, "n_estimators": 30},
+ "GP": {
+ "n_restarts_optimizer": 30,
},
- 'XGB': {
- 'n_estimators': 30,
- 'learning_rate': 0.1,
- 'max_depth': 10,
- 'min_child_weight': 1,
- 'subsample': 1.0,
- 'colsample_bytree': 1.0,
- 'reg_alpha': 0.0,
- 'reg_lambda': 1.0,
+ "XGB": {
+ "n_estimators": 30,
+ "learning_rate": 0.1,
+ "max_depth": 10,
+ "min_child_weight": 1,
+ "subsample": 1.0,
+ "colsample_bytree": 1.0,
+ "reg_alpha": 0.0,
+ "reg_lambda": 1.0,
},
- 'MVL': {
- },
- 'VR': {
- 'w_rf': 1.0,
- 'w_gb': 1.0,
- 'w_nn': 1.0,
- }
+ "MVL": {},
+ }
+ minimal_params["VR"] = {
+ "w_rf": 1.0,
+ "w_gb": 1.0,
+ "w_nn": 1.0,
+ **{f"rf_{key}": value for key, value in minimal_params["RF"].items()},
+ **{f"gb_{key}": value for key, value in minimal_params["GB"].items()},
+ **{f"nn_{key}": value for key, value in minimal_params["NN"].items()},
}
return minimal_params[model]
-def mcc_scorer_clf(y_true,y_pred):
+
+def _round_vr_member_params(params):
+ """Round integer hyperparameters for VR member models (rf_*, gb_*, nn_*)."""
+ rf_int = {"n_estimators", "max_depth", "min_samples_split", "min_samples_leaf"}
+ gb_int = {"n_estimators", "max_depth", "min_samples_split", "min_samples_leaf"}
+ nn_int = {"max_iter", "hidden_layer_1", "hidden_layer_2"}
+ for key in list(params.keys()):
+ if key.startswith("rf_"):
+ if key[3:] in rf_int:
+ params[key] = round(params[key])
+ elif key.startswith("gb_"):
+ if key[3:] in gb_int:
+ params[key] = round(params[key])
+ elif key.startswith("nn_"):
+ if key[3:] in nn_int:
+ params[key] = round(params[key])
+
+
+def _pop_vr_member_params(params, prefix, defaults):
+ """Extract ``prefix_*`` keys into a member-model parameter dict."""
+ member = dict(defaults)
+ for key in list(params.keys()):
+ if key.startswith(f"{prefix}_"):
+ member[key[len(prefix) + 1 :]] = params.pop(key)
+ return member
+
+
+def mcc_scorer_clf(y_true, y_pred):
"""Forces classification predictions to integer for MCC."""
# Even if .predict() returns floats, coerce them to integer:
y_pred = np.round(y_pred).astype(int)
-
+
return matthews_corrcoef(y_true, y_pred)
-def get_scoring_key(problem_type,error_type):
- '''
+
+def get_scoring_key(problem_type, error_type):
+ """
Load scoring function for evaluating models
- '''
+ """
- if problem_type.lower() == 'reg':
+ if problem_type.lower() == "reg":
scoring = {
- 'rmse': 'neg_root_mean_squared_error',
- 'mae': 'neg_median_absolute_error',
- 'r2': 'r2'
+ "rmse": "neg_root_mean_squared_error",
+ "mae": "neg_median_absolute_error",
+ "r2": "r2",
}.get(error_type)
else:
# For classification
- if error_type == 'mcc':
+ if error_type == "mcc":
# Use the custom MCC scorer that ensures integer predictions
scoring = make_scorer(mcc_scorer_clf)
else:
- scoring = {
- 'f1': 'f1',
- 'acc': 'accuracy'
- }.get(error_type)
-
+ scoring = {"f1": "f1", "acc": "accuracy"}.get(error_type)
+
return scoring
-def check_csv_option(self,csv_option,print_err):
- '''
+def check_csv_option(self, csv_option, print_err):
+ """
Checks missing values in input CSV options
- '''
-
- if csv_option == 'csv_name':
- line_print = f'\nx Specify the CSV name for the {csv_option} option!'
- elif csv_option == 'csv_train':
- line_print = f'\nx Specify the CSV name containing the TRAINING set!'
- elif csv_option == 'csv_valid':
- line_print = f'\nx Specify the CSV name containing the VALIDATION set!'
+ """
+
+ if csv_option == "csv_name":
+ line_print = f"\nx Specify the CSV name for the {csv_option} option!"
+ elif csv_option == "csv_train":
+ line_print = "\nx Specify the CSV name containing the TRAINING set!"
+ elif csv_option == "csv_valid":
+ line_print = "\nx Specify the CSV name containing the VALIDATION set!"
if print_err:
print(line_print)
else:
self.log.write(line_print)
- val_option = input('Enter the name of your CSV file: ')
- self.extra_cmd += f' --{csv_option} {val_option}'
+ val_option = input("Enter the name of your CSV file: ")
+ self.extra_cmd += f" --{csv_option} {val_option}"
if not print_err:
self.log.write(f" - {csv_option} option set to {val_option} by the user")
- if csv_option == 'csv_name':
- self.csv_name = val_option
- elif csv_option == 'csv_train':
+ if csv_option == "csv_name":
+ self.csv_name = val_option
+ elif csv_option == "csv_train":
self.csv_train = val_option
- elif csv_option == 'csv_valid':
+ elif csv_option == "csv_valid":
self.csv_valid = val_option
return self
@@ -1057,116 +1234,183 @@ def sanity_checks(self, type_checks, module, columns_csv):
curate_valid = True
# adds manual inputs missing from the command line
- self = missing_inputs(self,module)
+ self = missing_inputs(self, module)
- if module.lower() == 'evaluate':
- curate_valid = locate_csv(self,self.csv_name,curate_valid)
+ if module.lower() == "evaluate":
+ curate_valid = locate_csv(self, self.csv_name, curate_valid)
- if self.eval_model.lower() not in ['mvl']:
- self.log.write(f"\nx The eval_model option used is not valid! Options: 'MVL' (more options will be added soon)")
+ if self.eval_model.lower() not in ["mvl"]:
+ self.log.write(
+ "\nx The eval_model option used is not valid! Options: 'MVL' (more options will be added soon)"
+ )
curate_valid = False
- if self.type.lower() not in ['reg']:
- self.log.write(f"\nx The type option used is not valid in EVALUATE! Options: 'reg' (the 'clas' option will be added soon)")
+ if self.type.lower() not in ["reg"]:
+ self.log.write(
+ "\nx The type option used is not valid in EVALUATE! Options: 'reg' (the 'clas' option will be added soon)"
+ )
curate_valid = False
- elif type_checks == 'initial' and module.lower() not in ['verify','predict']:
-
- curate_valid = locate_csv(self,self.csv_name,curate_valid)
+ elif type_checks == "initial" and module.lower() not in ["verify", "predict"]:
+ curate_valid = locate_csv(self, self.csv_name, curate_valid)
- if module.lower() == 'curate':
- if self.categorical.lower() not in ['onehot','numbers']:
- self.log.write(f"\nx The categorical option used is not valid! Options: 'onehot', 'numbers'")
+ if module.lower() == "curate":
+ if self.categorical.lower() not in ["onehot", "numbers"]:
+ self.log.write(
+ "\nx The categorical option used is not valid! Options: 'onehot', 'numbers'"
+ )
curate_valid = False
- for thres,thres_name in zip([self.thres_x,self.thres_y],['thres_x','thres_y']):
+ for thres, thres_name in zip(
+ [self.thres_x, self.thres_y], ["thres_x", "thres_y"]
+ ):
if float(thres) > 1 or float(thres) < 0:
- self.log.write(f"\nx The {thres_name} option should be between 0 and 1!")
+ self.log.write(
+ f"\nx The {thres_name} option should be between 0 and 1!"
+ )
curate_valid = False
-
- elif module.lower() == 'generate':
- if self.split.lower() not in ['kn','rnd','stratified','even','extra_q1','extra_q5','auto']:
- self.log.write(f"\nx The split option used is not valid! Options: 'KN', 'RND'")
+
+ elif module.lower() == "generate":
+ if self.split.lower() not in [
+ "kn",
+ "rnd",
+ "stratified",
+ "even",
+ "extra_q1",
+ "extra_q5",
+ "auto",
+ ]:
+ self.log.write(
+ "\nx The split option used is not valid! Options: 'KN', 'RND'"
+ )
curate_valid = False
- if self.split == 'auto':
- if self.type.lower() == 'reg':
- self.split = 'even'
- elif self.type.lower() == 'clas':
- self.split = 'rnd'
+ if self.split == "auto":
+ if self.type.lower() == "reg":
+ self.split = "even"
+ elif self.type.lower() == "clas":
+ self.split = "rnd"
for model_type in self.model:
- if model_type.upper() not in ['RF','MVL','GB','GP','ADAB','NN','XGB','VR'] or len(self.model) == 0:
- self.log.write(f"\nx The model option used is not valid! Options: 'RF', 'MVL', 'GB', 'GP', 'ADAB', 'NN', 'XGB', 'VR'")
+ if (
+ model_type.upper()
+ not in ["RF", "MVL", "GB", "GP", "ADAB", "NN", "XGB", "VR"]
+ or len(self.model) == 0
+ ):
+ self.log.write(
+ "\nx The model option used is not valid! Options: 'RF', 'MVL', 'GB', 'GP', 'ADAB', 'NN', 'XGB', 'VR'"
+ )
curate_valid = False
- if model_type.upper() == 'MVL' and self.type.lower() == 'clas':
- self.log.write(f"\nx Multivariate linear models (MVL in the model_type option) are not compatible with classificaton!")
+ if model_type.upper() == "MVL" and self.type.lower() == "clas":
+ self.log.write(
+ "\nx Multivariate linear models (MVL in the model_type option) are not compatible with classificaton!"
+ )
curate_valid = False
- if self.type.lower() not in ['reg','clas']:
- self.log.write(f"\nx The type option used is not valid! Options: 'reg', 'clas'")
+ if self.type.lower() not in ["reg", "clas"]:
+ self.log.write(
+ "\nx The type option used is not valid! Options: 'reg', 'clas'"
+ )
curate_valid = False
- if type_checks == 'initial' and module.lower() in ['generate','verify','predict','report']:
-
- if type_checks == 'initial' and module.lower() in ['generate','verify']:
- if self.type.lower() == 'reg' and self.error_type.lower() not in ['rmse','mae','r2']:
- self.log.write(f"\nx The error_type option is not valid! Options for regression: 'rmse', 'mae', 'r2'")
+ if type_checks == "initial" and module.lower() in [
+ "generate",
+ "verify",
+ "predict",
+ "report",
+ ]:
+ if type_checks == "initial" and module.lower() in ["generate", "verify"]:
+ if self.type.lower() == "reg" and self.error_type.lower() not in [
+ "rmse",
+ "mae",
+ "r2",
+ ]:
+ self.log.write(
+ "\nx The error_type option is not valid! Options for regression: 'rmse', 'mae', 'r2'"
+ )
curate_valid = False
- if self.type.lower() == 'clas' and self.error_type.lower() not in ['mcc','f1','acc']:
- self.log.write(f"\nx The error_type option is not valid! Options for classification: 'mcc', 'f1', 'acc'")
+ if self.type.lower() == "clas" and self.error_type.lower() not in [
+ "mcc",
+ "f1",
+ "acc",
+ ]:
+ self.log.write(
+ "\nx The error_type option is not valid! Options for classification: 'mcc', 'f1', 'acc'"
+ )
curate_valid = False
- if module.lower() in ['verify','predict']:
+ if module.lower() in ["verify", "predict"]:
if os.getcwd() in f"{self.params_dir}":
path_db = self.params_dir
else:
path_db = f"{Path(os.getcwd()).joinpath(self.params_dir)}"
if not os.path.exists(path_db):
- self.log.write(f'\nx The path of your CSV files doesn\'t exist! Set the folder containing the two CSV files with 1) the parameters of the model and 2) the Xy database with the params_dir option')
+ self.log.write(
+ "\nx The path of your CSV files doesn't exist! Set the folder containing the two CSV files with 1) the parameters of the model and 2) the Xy database with the params_dir option"
+ )
curate_valid = False
- if module.lower() == 'predict':
+ if module.lower() == "predict":
if self.t_value < 0:
self.log.write(f"\nx t_value ({self.t_value}) should be higher 0!")
curate_valid = False
- if self.csv_test != '':
+ if self.csv_test != "":
if os.getcwd() in f"{self.csv_test}":
path_test = self.csv_test
else:
path_test = f"{Path(os.getcwd()).joinpath(self.csv_test)}"
if not os.path.exists(path_test):
- self.log.write(f'\nx The path of your CSV file with the test set doesn\'t exist! You specified: {self.csv_test}')
+ self.log.write(
+ f"\nx The path of your CSV file with the test set doesn't exist! You specified: {self.csv_test}"
+ )
curate_valid = False
- if module.lower() == 'report':
+ if module.lower() == "report":
if len(self.report_modules) == 0:
- self.log.write(f'\nx No modules were provided in the report_modules option! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT"')
+ self.log.write(
+ '\nx No modules were provided in the report_modules option! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT"'
+ )
curate_valid = False
for module in self.report_modules:
- if module.upper() not in ['CURATE','GENERATE','VERIFY','PREDICT','AQME']:
- self.log.write(f'\nx Module {module} specified in the report_modules option is not a valid module! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT", "AQME"')
+ if module.upper() not in [
+ "CURATE",
+ "GENERATE",
+ "VERIFY",
+ "PREDICT",
+ "AQME",
+ ]:
+ self.log.write(
+ f'\nx Module {module} specified in the report_modules option is not a valid module! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT", "AQME"'
+ )
curate_valid = False
-
- elif type_checks == 'csv_db':
- if module.lower() != 'predict':
+
+ elif type_checks == "csv_db":
+ if module.lower() != "predict":
if self.y not in columns_csv:
- if self.y.lower() in columns_csv: # accounts for upper/lowercase mismatches
+ if (
+ self.y.lower() in columns_csv
+ ): # accounts for upper/lowercase mismatches
self.y = self.y.lower()
elif self.y.upper() in columns_csv:
self.y = self.y.upper()
else:
- self.log.write(f"\nx The y option specified ({self.y}) is not a column in the csv selected ({self.csv_name})! If you are using command lines, make sure you add quotation marks like --y \"VALUE\"")
+ self.log.write(
+ f'\nx The y option specified ({self.y}) is not a column in the csv selected ({self.csv_name})! If you are using command lines, make sure you add quotation marks like --y "VALUE"'
+ )
curate_valid = False
- for option,option_name in zip([self.discard,self.ignore],['discard','ignore']):
+ for option, option_name in zip(
+ [self.discard, self.ignore], ["discard", "ignore"]
+ ):
for val in option:
if val not in columns_csv:
- self.log.write(f"\nx Descriptor {val} specified in the {option_name} option is not a column in the csv selected ({self.csv_name})!")
+ self.log.write(
+ f"\nx Descriptor {val} specified in the {option_name} option is not a column in the csv selected ({self.csv_name})!"
+ )
curate_valid = False
if not curate_valid:
@@ -1174,50 +1418,58 @@ def sanity_checks(self, type_checks, module, columns_csv):
sys.exit()
-def locate_csv(self,csv_input,curate_valid):
- '''
+def locate_csv(self, csv_input, curate_valid):
+ """
Assesses whether the input CSV databases can be located
- '''
+ """
- path_csv = ''
+ path_csv = ""
if os.path.exists(f"{csv_input}"):
path_csv = csv_input
elif os.path.exists(f"{Path(os.getcwd()).joinpath(csv_input)}"):
path_csv = f"{Path(os.getcwd()).joinpath(csv_input)}"
- if not os.path.exists(path_csv) or csv_input == '':
- self.log.write(f'\nx The path of your CSV file doesn\'t exist! You specified: --csv_name {csv_input}')
+ if not os.path.exists(path_csv) or csv_input == "":
+ self.log.write(
+ f"\nx The path of your CSV file doesn't exist! You specified: --csv_name {csv_input}"
+ )
curate_valid = False
-
+
return curate_valid
-def check_clas_problem(self,csv_df):
- '''
+def check_clas_problem(self, csv_df):
+ """
Changes type to classification if there are only two different y values.
Automatically converts any pair of values (strings or numbers) to 0 and 1.
Stores the original labels for later reconversion in outputs.
- '''
+ """
# changes type to classification if there are only two different y values
- if self.args.type.lower() == 'reg' and self.args.auto_type:
+ if self.args.type.lower() == "reg" and self.args.auto_type:
num_unique = len(set(csv_df[self.args.y]))
if num_unique == 2:
- self.args.type = 'clas'
- if self.args.error_type not in ['acc', 'mcc', 'f1']:
- self.args.error_type = 'mcc'
- if ('MVL' or 'mvl') in self.args.model:
- self.args.model = [x if x.upper() != 'MVL' else 'ADAB' for x in self.args.model]
+ self.args.type = "clas"
+ if self.args.error_type not in ["acc", "mcc", "f1"]:
+ self.args.error_type = "mcc"
+ if ("MVL" or "mvl") in self.args.model:
+ self.args.model = [
+ x if x.upper() != "MVL" else "ADAB" for x in self.args.model
+ ]
unique_vals = list(set(csv_df[self.args.y]))
- y_val_detect = f'{unique_vals[0]} and {unique_vals[1]}'
- self.args.log.write(f'\no Only two different y values were detected ({y_val_detect})! The program will consider classification models (same effect as using "--type clas"). This option can be disabled with "--auto_type False"')
+ y_val_detect = f"{unique_vals[0]} and {unique_vals[1]}"
+ self.args.log.write(
+ f'\no Only two different y values were detected ({y_val_detect})! The program will consider classification models (same effect as using "--type clas"). This option can be disabled with "--auto_type False"'
+ )
- if self.args.type.lower() == 'clas':
+ if self.args.type.lower() == "clas":
if len(set(csv_df[self.args.y])) == 2:
- unique_values = sorted(list(set(csv_df[self.args.y]))) # Sort alphabetically for consistency
-
+ unique_values = sorted(
+ list(set(csv_df[self.args.y]))
+ ) # Sort alphabetically for consistency
+
# Check if values are already 0 and 1
- if set([str(v) for v in unique_values]) == {'0', '1'}:
+ if set([str(v) for v in unique_values]) == {"0", "1"}:
# Already in correct format, just ensure they're integers
csv_df[self.args.y] = csv_df[self.args.y].astype(int)
else:
@@ -1225,169 +1477,224 @@ def check_clas_problem(self,csv_df):
# Store original labels for reconversion in outputs
self.args.class_0_label = str(unique_values[0])
self.args.class_1_label = str(unique_values[1])
-
+
# Create mapping dictionaries
self.args.class_mapping = {unique_values[0]: 0, unique_values[1]: 1}
- self.args.class_mapping_reverse = {0: unique_values[0], 1: unique_values[1]}
-
+ self.args.class_mapping_reverse = {
+ 0: unique_values[0],
+ 1: unique_values[1],
+ }
+
# Convert values in dataframe
csv_df[self.args.y] = csv_df[self.args.y].map(self.args.class_mapping)
-
- self.args.log.write(f'\no Classification labels converted: {self.args.class_0_label} → 0, {self.args.class_1_label} → 1')
- self.args.log.write(f' Original labels will be restored in output files')
-
+
+ self.args.log.write(
+ f"\no Classification labels converted: {self.args.class_0_label} → 0, {self.args.class_1_label} → 1"
+ )
+ self.args.log.write(
+ " Original labels will be restored in output files"
+ )
+
# Check that each class has at least 5 points
class_counts = csv_df[self.args.y].value_counts()
min_class_count = class_counts.min()
min_class_label = class_counts.idxmin()
-
+
if min_class_count < 5:
# Get original label if available
- if hasattr(self.args, 'class_mapping_reverse') and min_class_label in self.args.class_mapping_reverse:
+ if (
+ hasattr(self.args, "class_mapping_reverse")
+ and min_class_label in self.args.class_mapping_reverse
+ ):
original_label = self.args.class_mapping_reverse[min_class_label]
else:
original_label = min_class_label
-
+
# Convert class_counts to dict with regular Python ints
class_dist = {int(k): int(v) for k, v in class_counts.items()}
-
- self.args.log.write(f'\nx Insufficient data for classification! One of the classes has only {min_class_count} datapoints (class "{original_label}")')
- self.args.log.write(f' Each class must have at least 5 datapoints to ensure robust train/validation/test splits')
- self.args.log.write(f' Current distribution: {class_dist}')
- self.args.log.write(f' Please add more datapoints for the minority class or consider a different approach')
+
+ self.args.log.write(
+ f'\nx Insufficient data for classification! One of the classes has only {min_class_count} datapoints (class "{original_label}")'
+ )
+ self.args.log.write(
+ " Each class must have at least 5 datapoints to ensure robust train/validation/test splits"
+ )
+ self.args.log.write(f" Current distribution: {class_dist}")
+ self.args.log.write(
+ " Please add more datapoints for the minority class or consider a different approach"
+ )
self.args.log.finalize()
sys.exit()
return self
-
-def load_database(self,csv_load,module,print_info=True,external_test=False):
- '''
+
+def load_database(self, csv_load, module, print_info=True, external_test=False):
+ """
Loads either a Xy (params=False) or a parameter (params=True) database from a CSV file
- '''
-
+ """
+
# adjust external set in AQME workflows
- if module.lower() == 'aqme_test':
+ if module.lower() == "aqme_test":
external_test = True
- txt_load = ''
+ txt_load = ""
# Semicolon-separated "CSV" from Excel: peek at the first rows before reading the whole file.
_scan_limit = 64
head_lines = []
- with open(csv_load, 'r', encoding='utf-8') as file:
+ with open(csv_load, "r", encoding="utf-8") as file:
for _, line in zip(range(_scan_limit), file):
head_lines.append(line)
- semicolon_issue = len(head_lines) >= 2 and head_lines[1].count(';') > 1
+ semicolon_issue = len(head_lines) >= 2 and head_lines[1].count(";") > 1
if semicolon_issue:
- with open(csv_load, 'r', encoding='utf-8') as file:
+ with open(csv_load, "r", encoding="utf-8") as file:
lines = file.readlines()
if semicolon_issue:
- new_csv_name = os.path.basename(csv_load).split('.csv')[0].split('.CSV')[0]+'_original.csv'
+ new_csv_name = (
+ os.path.basename(csv_load).split(".csv")[0].split(".CSV")[0]
+ + "_original.csv"
+ )
shutil.move(csv_load, Path(os.path.dirname(csv_load)).joinpath(new_csv_name))
new_csv_file = open(csv_load, "w")
for line in lines:
- line = line.replace(',','.')
- line = line.replace(';',',')
+ line = line.replace(",", ".")
+ line = line.replace(";", ",")
# line = line.replace(':',',')
new_csv_file.write(line)
new_csv_file.close()
- txt_load += f'\nx WARNING! The original database was not a valid CSV (i.e., formatting issues from Microsoft Excel?). A new database using commas as separators was created and used instead, and the original database was stored as {new_csv_name}. To prevent this issue from happening again, you should use commas as separators: https://support.edapp.com/change-csv-separator.\n\n'
+ txt_load += f"\nx WARNING! The original database was not a valid CSV (i.e., formatting issues from Microsoft Excel?). A new database using commas as separators was created and used instead, and the original database was stored as {new_csv_name}. To prevent this issue from happening again, you should use commas as separators: https://support.edapp.com/change-csv-separator.\n\n"
- csv_df = pd.read_csv(csv_load, encoding='utf-8')
+ csv_df = pd.read_csv(csv_load, encoding="utf-8")
# Missing data handling: robust strategy for columns and rows (optional KNN imputer)
target_col = self.args.y
- descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore+self.args.discard and col != self.args.y]
+ descriptor_cols = [
+ col
+ for col in csv_df.columns
+ if col not in self.args.ignore + self.args.discard and col != self.args.y
+ ]
min_count = int(0.9 * len(csv_df))
# Remove columns with <90% data
- cols_to_drop = [col for col in descriptor_cols if csv_df[col].notna().sum() < min_count]
+ cols_to_drop = [
+ col for col in descriptor_cols if csv_df[col].notna().sum() < min_count
+ ]
if cols_to_drop:
csv_df = csv_df.drop(columns=cols_to_drop)
- if module.lower() == 'curate':
+ if module.lower() == "curate":
txt_load += f"\n - Removed {len(cols_to_drop)} column(s) with <90% data\n"
- descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore+self.args.discard and col != self.args.y]
-
+ descriptor_cols = [
+ col
+ for col in csv_df.columns
+ if col not in self.args.ignore + self.args.discard and col != self.args.y
+ ]
+
# Remove rows with <50% data
- rows_too_missing = csv_df[descriptor_cols].isna().sum(axis=1) > (0.5 * len(descriptor_cols))
+ rows_too_missing = csv_df[descriptor_cols].isna().sum(axis=1) > (
+ 0.5 * len(descriptor_cols)
+ )
if rows_too_missing.any():
n_removed_rows = rows_too_missing.sum()
csv_df = csv_df[~rows_too_missing].reset_index(drop=True)
- descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore+self.args.discard and col != self.args.y]
- if module.lower() == 'curate':
+ descriptor_cols = [
+ col
+ for col in csv_df.columns
+ if col not in self.args.ignore + self.args.discard and col != self.args.y
+ ]
+ if module.lower() == "curate":
txt_load += f"\n - Removed {n_removed_rows} row(s) with >50% missing descriptors\n"
-
+
# Apply KNN imputer only to numeric columns with missing values (when auto_fill is activated)
if self.args.auto_fill:
- numeric_columns = csv_df.select_dtypes(include=['float']).columns.drop(target_col, errors='ignore')
+ numeric_columns = csv_df.select_dtypes(include=["float"]).columns.drop(
+ target_col, errors="ignore"
+ )
if csv_df[numeric_columns].isna().any().any():
imputer = KNNImputer(n_neighbors=5)
- csv_df[numeric_columns] = pd.DataFrame(imputer.fit_transform(csv_df[numeric_columns]), columns=numeric_columns, index=csv_df.index)
- if module.lower() == 'curate':
- txt_load += f"\n - Applied KNN imputer to columns with missing values\n"
+ csv_df[numeric_columns] = pd.DataFrame(
+ imputer.fit_transform(csv_df[numeric_columns]),
+ columns=numeric_columns,
+ index=csv_df.index,
+ )
+ if module.lower() == "curate":
+ txt_load += (
+ "\n - Applied KNN imputer to columns with missing values\n"
+ )
else:
# Remove columns with ANY missing value
- cols_with_missing = [col for col in descriptor_cols if csv_df[col].isna().any() and col not in self.args.ignore+self.args.discard]
+ cols_with_missing = [
+ col
+ for col in descriptor_cols
+ if csv_df[col].isna().any()
+ and col not in self.args.ignore + self.args.discard
+ ]
if cols_with_missing:
csv_df = csv_df.drop(columns=cols_with_missing)
- if module.lower() == 'curate':
+ if module.lower() == "curate":
txt_load += f"\n - Removed {len(cols_with_missing)} column(s) with missing values\n"
if print_info:
- sanity_checks(self.args,'csv_db',module,csv_df.columns)
+ sanity_checks(self.args, "csv_db", module, csv_df.columns)
csv_df = csv_df.drop(self.args.discard, axis=1)
total_amount = len(csv_df.columns)
ignored_descs = len(self.args.ignore)
- accepted_descs = total_amount - ignored_descs - 1 # the y column is substracted
- if 'Set' in csv_df.columns: # removes the column that tracks sets
+ accepted_descs = total_amount - ignored_descs - 1 # the y column is substracted
+ if "Set" in csv_df.columns: # removes the column that tracks sets
accepted_descs -= 1
ignored_descs += 1
- if module.lower() not in ['aqme','aqme_test']:
+ if module.lower() not in ["aqme", "aqme_test"]:
csv_name = os.path.basename(csv_load)
- if module.lower() not in ['predict']:
- txt_load += f'\no Database {csv_name} loaded successfully, including:'
- txt_load += f'\n - {len(csv_df[self.args.y])} datapoints'
- txt_load += f'\n - {accepted_descs} accepted descriptors'
- txt_load += f'\n - {ignored_descs} ignored descriptors'
- txt_load += (
- f"\n - {len(self.args.discard)} discarded descriptors"
- )
+ if module.lower() not in ["predict"]:
+ txt_load += f"\no Database {csv_name} loaded successfully, including:"
+ txt_load += f"\n - {len(csv_df[self.args.y])} datapoints"
+ txt_load += f"\n - {accepted_descs} accepted descriptors"
+ txt_load += f"\n - {ignored_descs} ignored descriptors"
+ txt_load += f"\n - {len(self.args.discard)} discarded descriptors"
else:
txt_load += (
- f"\no External set {csv_name} loaded successfully, "
- "including:"
+ f"\no External set {csv_name} loaded successfully, including:"
)
txt_load += f"\n - {len(csv_df)} datapoints (rows)"
self.args.log.write(txt_load)
if accepted_descs == 0:
- self.args.log.write(f"\nx The aren't any valid descriptors! Check the messages above to see whether the filters have discarded descriptors")
+ self.args.log.write(
+ "\nx The aren't any valid descriptors! Check the messages above to see whether the filters have discarded descriptors"
+ )
sys.exit()
# Sort columns alphabetically for reproducibility across ALL modules
- if module.lower() not in ['aqme', 'aqme_test']:
+ if module.lower() not in ["aqme", "aqme_test"]:
# Get descriptor columns (excluding y and ignore)
- descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore and col != self.args.y]
+ descriptor_cols = [
+ col
+ for col in csv_df.columns
+ if col not in self.args.ignore and col != self.args.y
+ ]
descriptor_cols_sorted = sorted(descriptor_cols)
# Get other columns (y and ignore)
- other_cols = [col for col in csv_df.columns if col in self.args.ignore or col == self.args.y]
+ other_cols = [
+ col
+ for col in csv_df.columns
+ if col in self.args.ignore or col == self.args.y
+ ]
# Reorder dataframe with sorted descriptors + other columns
sorted_all_cols = descriptor_cols_sorted + other_cols
csv_df = csv_df[sorted_all_cols]
-
+
# ignore user-defined descriptors and assign X and y values (but keeps the original database)
- if module.lower() == 'generate':
+ if module.lower() == "generate":
# Only drop columns that actually exist in the dataframe (for model-specific CSVs from CURATE)
cols_to_ignore = [col for col in self.args.ignore if col in csv_df.columns]
# Also drop 'Set' column if it exists (for model-specific CSVs)
- if 'Set' in csv_df.columns and 'Set' not in cols_to_ignore:
- cols_to_ignore.append('Set')
+ if "Set" in csv_df.columns and "Set" not in cols_to_ignore:
+ cols_to_ignore.append("Set")
csv_df_ignore = csv_df.drop(cols_to_ignore, axis=1)
csv_X = csv_df_ignore.drop([self.args.y], axis=1)
csv_y = csv_df_ignore[self.args.y]
-
+
# Columns are already sorted from above, just extract them
csv_X = csv_X[sorted([col for col in csv_X.columns])]
-
+
else:
if external_test and self.args.y not in csv_df.columns:
csv_X = csv_df
@@ -1396,53 +1703,53 @@ def load_database(self,csv_load,module,print_info=True,external_test=False):
csv_X = csv_df.drop([self.args.y], axis=1)
csv_y = csv_df[self.args.y]
- return csv_df,csv_X,csv_y
+ return csv_df, csv_X, csv_y
-def categorical_transform(self,csv_df,module):
- ''' converts all columns with strings into categorical values (one hot encoding
+def categorical_transform(self, csv_df, module):
+ """converts all columns with strings into categorical values (one hot encoding
by default, can be set to numerical 1,2,3... with categorical = True).
Troubleshooting! For one-hot encoding, don't use variable names that are
also column headers! i.e. DESCRIPTOR "C_atom" contain C2 as a value,
but C2 is already a header of a different column in the database. Same applies
for multiple columns containing the same variable names.
- '''
+ """
- if module.lower() == 'curate':
- txt_categor = f'\no Analyzing categorical variables'
+ if module.lower() == "curate":
+ txt_categor = "\no Analyzing categorical variables"
- descriptors_to_drop, categorical_vars, new_categor_desc = [],[],[]
+ descriptors_to_drop, categorical_vars, new_categor_desc = [], [], []
for column in csv_df.columns:
if column not in self.args.ignore and column != self.args.y:
- if(csv_df[column].dtype == 'object'):
+ if csv_df[column].dtype == "object":
descriptors_to_drop.append(column)
categorical_vars.append(column)
- if self.args.categorical.lower() == 'numbers':
- csv_df[column] = csv_df[column].astype('category')
+ if self.args.categorical.lower() == "numbers":
+ csv_df[column] = csv_df[column].astype("category")
csv_df[column] = csv_df[column].cat.codes
else:
- _ = csv_df[column].unique() # is this necessary?
+ _ = csv_df[column].unique() # is this necessary?
categor_descs = pd.get_dummies(csv_df[column])
csv_df = csv_df.drop(column, axis=1)
csv_df = pd.concat([csv_df, categor_descs], axis=1)
for desc in categor_descs:
new_categor_desc.append(desc)
- if module.lower() == 'curate':
+ if module.lower() == "curate":
if len(categorical_vars) == 0:
- txt_categor += f'\n - No categorical variables were found'
+ txt_categor += "\n - No categorical variables were found"
else:
- if self.args.categorical.lower() == 'numbers':
- txt_categor += f'\n A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option:\n'
- txt_categor += '\n'.join(f' - {var}' for var in categorical_vars)
+ if self.args.categorical.lower() == "numbers":
+ txt_categor += f"\n A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option:\n"
+ txt_categor += "\n".join(f" - {var}" for var in categorical_vars)
else:
- txt_categor += f'\n A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option'
- txt_categor += f'\n Initial descriptors:\n'
- txt_categor += '\n'.join(f' - {var}' for var in categorical_vars)
- txt_categor += f'\n Generated descriptors:\n'
- txt_categor += '\n'.join(f' - {var}' for var in new_categor_desc)
+ txt_categor += f"\n A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option"
+ txt_categor += "\n Initial descriptors:\n"
+ txt_categor += "\n".join(f" - {var}" for var in categorical_vars)
+ txt_categor += "\n Generated descriptors:\n"
+ txt_categor += "\n".join(f" - {var}" for var in new_categor_desc)
- self.args.log.write(f'{txt_categor}')
+ self.args.log.write(f"{txt_categor}")
return csv_df
@@ -1454,142 +1761,182 @@ def create_folders(folder_names):
folder.mkdir(exist_ok=True, parents=True)
-def finish_print(self,start_time,module):
+def finish_print(self, start_time, module):
elapsed_time = round(time.time() - start_time, 2)
self.args.log.write(f"\nTime {module.upper()}: {elapsed_time} seconds\n")
self.args.log.finalize()
-def scale_df(csv_X,csv_X_external):
- '''
+def scale_df(csv_X, csv_X_external):
+ """
Scale the X matrix for the training set and the external test set (if any)
- '''
-
+ """
+
scaler = StandardScaler()
_ = scaler.fit(csv_X)
X_scaled = scaler.transform(csv_X)
- X_scaled_df = pd.DataFrame(X_scaled, columns = csv_X.columns)
+ X_scaled_df = pd.DataFrame(X_scaled, columns=csv_X.columns)
X_scaled_external_df = None
if csv_X_external is not None:
X_scaled_external = scaler.transform(csv_X_external)
- X_scaled_external_df = pd.DataFrame(X_scaled_external, columns = csv_X_external.columns)
-
- return X_scaled_df,X_scaled_external_df
-
+ X_scaled_external_df = pd.DataFrame(
+ X_scaled_external, columns=csv_X_external.columns
+ )
-def Xy_split(csv_df,csv_X,X_scaled_df,csv_y,csv_external_df,csv_X_external,X_scaled_external_df,csv_y_external,test_points,column_names):
- '''
+ return X_scaled_df, X_scaled_external_df
+
+
+def Xy_split(
+ csv_df,
+ csv_X,
+ X_scaled_df,
+ csv_y,
+ csv_external_df,
+ csv_X_external,
+ X_scaled_external_df,
+ csv_y_external,
+ test_points,
+ column_names,
+):
+ """
Returns a dictionary with the database divided into train and validation
- '''
+ """
- Xy_data = {}
+ Xy_data = {}
if len(test_points) == 0:
- Xy_data['X_train'] = csv_X
- Xy_data['X_train_scaled'] = X_scaled_df
- Xy_data['y_train'] = csv_y
- Xy_data['names_train'] = csv_df[column_names]
+ Xy_data["X_train"] = csv_X
+ Xy_data["X_train_scaled"] = X_scaled_df
+ Xy_data["y_train"] = csv_y
+ Xy_data["names_train"] = csv_df[column_names]
else:
- Xy_data['X_train'] = csv_X.drop(test_points)
- Xy_data['X_train_scaled'] = X_scaled_df.drop(test_points)
- Xy_data['y_train'] = csv_y.drop(test_points)
- Xy_data['X_test'] = csv_X.iloc[test_points]
- Xy_data['X_test_scaled'] = X_scaled_df.iloc[test_points]
- Xy_data['y_test'] = csv_y.iloc[test_points]
- Xy_data['names_train'] = csv_df.drop(test_points)[column_names]
- Xy_data['names_test'] = csv_df.iloc[test_points][column_names]
+ Xy_data["X_train"] = csv_X.drop(test_points)
+ Xy_data["X_train_scaled"] = X_scaled_df.drop(test_points)
+ Xy_data["y_train"] = csv_y.drop(test_points)
+ Xy_data["X_test"] = csv_X.iloc[test_points]
+ Xy_data["X_test_scaled"] = X_scaled_df.iloc[test_points]
+ Xy_data["y_test"] = csv_y.iloc[test_points]
+ Xy_data["names_train"] = csv_df.drop(test_points)[column_names]
+ Xy_data["names_test"] = csv_df.iloc[test_points][column_names]
- Xy_data['test_points'] = test_points
+ Xy_data["test_points"] = test_points
if X_scaled_external_df is not None:
- Xy_data['X_external'] = csv_X_external
- Xy_data['X_external_scaled'] = X_scaled_external_df
+ Xy_data["X_external"] = csv_X_external
+ Xy_data["X_external_scaled"] = X_scaled_external_df
if csv_y_external is not None:
- Xy_data['y_external'] = csv_y_external
- Xy_data['names_external'] = csv_external_df[column_names]
+ Xy_data["y_external"] = csv_y_external
+ Xy_data["names_external"] = csv_external_df[column_names]
return Xy_data
-def test_select(self,X_scaled,csv_y):
- '''
+def test_select(self, X_scaled, csv_y):
+ """
Selection of test set (if any)
- '''
+ """
# adjusts size of the test_set to include at least 4 points regardless of the number of datapoints
test_input_size = round(self.args.test_set * len(csv_y))
min_test_size = 4
- selected_size = max(test_input_size,min_test_size)
+ selected_size = max(test_input_size, min_test_size)
# in the future, we'll adapt other data splitting techniques for classificaiton problems with 3+ target values
- if self.args.type == 'clas':
+ if self.args.type == "clas":
if len(set(csv_y)) != 2:
- self.args.split = 'RND'
+ self.args.split = "RND"
- if self.args.split.upper() == 'KN':
+ if self.args.split.upper() == "KN":
# k-neighbours data split
# selects representative training points for each target value in classification problems
- if self.args.type == 'clas':
+ if self.args.type == "clas":
class_0_idx = list(csv_y[csv_y == 0].index)
class_1_idx = list(csv_y[csv_y == 1].index)
- class_0_test_size = round((len(class_0_idx)/len(csv_y))*selected_size)
- class_1_test_size = selected_size-class_0_test_size
+ class_0_test_size = round((len(class_0_idx) / len(csv_y)) * selected_size)
+ class_1_test_size = selected_size - class_0_test_size
class_0_train_size = len(class_0_idx) - class_0_test_size
class_1_train_size = len(class_1_idx) - class_1_test_size
- # the k-means function internally selects the training points to be as diverse as possible,
+ # the k-means function internally selects the training points to be as diverse as possible,
# but it returns the test points
- test_class_0 = k_means(self,X_scaled.iloc[class_0_idx],csv_y,class_0_train_size,self.args.seed,class_0_idx)
- test_class_1 = k_means(self,X_scaled.iloc[class_1_idx],csv_y,class_1_train_size,self.args.seed,class_1_idx)
- test_points = test_class_0+test_class_1
+ test_class_0 = k_means(
+ self,
+ X_scaled.iloc[class_0_idx],
+ csv_y,
+ class_0_train_size,
+ self.args.seed,
+ class_0_idx,
+ )
+ test_class_1 = k_means(
+ self,
+ X_scaled.iloc[class_1_idx],
+ csv_y,
+ class_1_train_size,
+ self.args.seed,
+ class_1_idx,
+ )
+ test_points = test_class_0 + test_class_1
else:
idx_list = csv_y.index
- training_size = len(csv_y)-selected_size
- test_points = k_means(self,X_scaled,csv_y,training_size,self.args.seed,idx_list)
+ training_size = len(csv_y) - selected_size
+ test_points = k_means(
+ self, X_scaled, csv_y, training_size, self.args.seed, idx_list
+ )
- elif self.args.split.upper() == 'RND':
+ elif self.args.split.upper() == "RND":
size = round(selected_size * 100 / (len(csv_y)))
- _, X_test, _, _ = train_test_split(X_scaled, csv_y, test_size=size/100, random_state=self.args.seed)
+ _, X_test, _, _ = train_test_split(
+ X_scaled, csv_y, test_size=size / 100, random_state=self.args.seed
+ )
test_points = X_test.index.tolist()
- elif self.args.split.upper() == 'STRATIFIED':
-
+ elif self.args.split.upper() == "STRATIFIED":
size = np.ceil(selected_size * 100 / (len(csv_y)))
# Remove the max and min values so they don't end up in the training set
# Calculate the number of bins based on the number of points
csv_y_capped = csv_y.drop([csv_y.idxmin(), csv_y.idxmax()])
- y_binned = pd.qcut(csv_y_capped, q=selected_size, labels=False, duplicates='drop')
-
+ y_binned = pd.qcut(
+ csv_y_capped, q=selected_size, labels=False, duplicates="drop"
+ )
+
# Adjust the number of bins until each class has at least 2 members
while y_binned.value_counts().min() < 2 and selected_size > 2:
selected_size -= 1
- y_binned = pd.qcut(csv_y_capped, q=selected_size, labels=False, duplicates='drop')
- splitter = StratifiedShuffleSplit(n_splits=1, test_size=(100 - size) / 100, random_state=self.args.seed)
+ y_binned = pd.qcut(
+ csv_y_capped, q=selected_size, labels=False, duplicates="drop"
+ )
+ splitter = StratifiedShuffleSplit(
+ n_splits=1, test_size=(100 - size) / 100, random_state=self.args.seed
+ )
for test_idx, _ in splitter.split(X_scaled, y_binned):
test_points = test_idx.tolist()
- elif self.args.split.upper() == 'EVEN':
+ elif self.args.split.upper() == "EVEN":
# Remove the max and min values so they don't end up in the training set
csv_y_capped = csv_y.drop([csv_y.idxmin(), csv_y.idxmax()])
# Calculate the number of bins based on the number of points
- y_binned = pd.qcut(csv_y_capped, q=selected_size, labels=False, duplicates='drop')
+ y_binned = pd.qcut(
+ csv_y_capped, q=selected_size, labels=False, duplicates="drop"
+ )
# Adjust bin count if any bin has fewer than two elements (happens in imbalanced data, see comment below)
temp_size = selected_size
while y_binned.value_counts().min() < 2 and temp_size > 2:
temp_size -= 1
- y_binned = pd.qcut(csv_y_capped, q=temp_size, labels=False, duplicates='drop')
+ y_binned = pd.qcut(
+ csv_y_capped, q=temp_size, labels=False, duplicates="drop"
+ )
# Determine central validation points for each bin
test_points = []
for bin_label in y_binned.unique():
bin_indices = y_binned[y_binned == bin_label].index
sorted_indices = sorted(bin_indices, key=lambda idx: csv_y[idx])
- test_points.append(sorted_indices[round(len(sorted_indices)/2)])
+ test_points.append(sorted_indices[round(len(sorted_indices) / 2)])
# in umbalanced databases, the points cannot be selected entirely even (i.e., if a database
# contains 10 points in th 0-10 range, and 1000 points in the 10-90 range, choosing 100
@@ -1603,12 +1950,12 @@ def test_select(self,X_scaled,csv_y):
test_points.append(new_test_point)
random_seed += 1
- elif self.args.split.upper() == 'EXTRA_Q1':
+ elif self.args.split.upper() == "EXTRA_Q1":
# 20% lowest points
portion = max(1, round(0.2 * len(csv_y)))
test_points = csv_y.nsmallest(portion).index.tolist()
-
- elif self.args.split.upper() == 'EXTRA_Q5':
+
+ elif self.args.split.upper() == "EXTRA_Q5":
# 20%% highest points
portion = max(1, round(0.2 * len(csv_y)))
test_points = csv_y.nlargest(portion).index.tolist()
@@ -1622,20 +1969,20 @@ def generate_lhs_points(pbounds, n_points, random_state=None):
"""
Generate initial points using Latin Hypercube Sampling for better space coverage.
LHS ensures uniform distribution across all dimensions of the hyperparameter space.
-
+
Args:
pbounds: Dictionary with parameter bounds from BO_hyperparams
n_points: Number of initial points to generate
random_state: Random seed for reproducibility
-
+
Returns:
List of dictionaries with parameter values
"""
np.random.seed(random_state)
-
+
param_names = list(pbounds.keys())
n_params = len(param_names)
-
+
# Generate LHS samples in [0, 1]^n_params
# Each dimension is divided into n_points intervals, and one point is sampled from each interval
samples = np.zeros((n_points, n_params))
@@ -1645,7 +1992,7 @@ def generate_lhs_points(pbounds, n_points, random_state=None):
samples[:, i] = np.random.uniform(intervals[:-1], intervals[1:])
# Shuffle to break correlation between dimensions
np.random.shuffle(samples[:, i])
-
+
# Scale samples to actual parameter bounds
initial_points = []
for sample in samples:
@@ -1655,30 +2002,25 @@ def generate_lhs_points(pbounds, n_points, random_state=None):
# Scale from [0, 1] to [lower, upper]
point[param_name] = lower + sample[i] * (upper - lower)
initial_points.append(point)
-
- return initial_points
-
-def BO_optimizer(self,bo_data,Xy_data):
- from bayes_opt import BayesianOptimization, acquisition
+ return initial_points
- # Define an acquisition function for Bayesian optimization
- _ = acquisition.ExpectedImprovement(xi=self.args.expect_improv)
+def BO_optimizer(self, bo_data, Xy_data):
# Initialize Bayesian optimization
optimizer = BayesianOptimization(
f=lambda **p: BO_iteration(self, bo_data, Xy_data, **p),
- pbounds=BO_hyperparams(bo_data['model']),
+ pbounds=BO_hyperparams(bo_data["model"]),
verbose=2,
- random_state=self.args.seed
+ random_state=self.args.seed,
)
# Generate initial points using Latin Hypercube Sampling for better space coverage
if self.args.init_points > 0:
initial_points = generate_lhs_points(
- pbounds=BO_hyperparams(bo_data['model']),
+ pbounds=BO_hyperparams(bo_data["model"]),
n_points=self.args.init_points,
- random_state=self.args.seed
+ random_state=self.args.seed,
)
# Probe the initial points
for params in initial_points:
@@ -1687,145 +2029,148 @@ def BO_optimizer(self,bo_data,Xy_data):
# Run the optimization (with warnings suppressed for Convergence issues)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConvergenceWarning)
- optimizer.maximize(init_points=0, n_iter=self.args.n_iter) # init_points=0 since we already probed LHS points
+ optimizer.maximize(
+ init_points=0, n_iter=self.args.n_iter
+ ) # init_points=0 since we already probed LHS points
- if bo_data['error_type'].upper() in ['RMSE','MAE']:
- BO_target = -optimizer.max['target']
+ if bo_data["error_type"].upper() in ["RMSE", "MAE"]:
+ BO_target = -optimizer.max["target"]
else:
- BO_target = optimizer.max['target']
- self.args.log.write(f" o Best combined {bo_data['error_type'].upper()} (target) found in BO for {bo_data['model']} (no PFI filter): {BO_target:.2}")
+ BO_target = optimizer.max["target"]
+ self.args.log.write(
+ f" o Best combined {bo_data['error_type'].upper()} (target) found in BO for {bo_data['model']} (no PFI filter): {BO_target:.2}"
+ )
# Retrieve best parameters and best result
- return optimizer.max['params'], BO_target
+ return optimizer.max["params"], BO_target
def BO_iteration(self, bo_data, Xy_data, **params):
- '''
+ """
Evaluate a model with given parameters using cross-validation.
Returns the mean negative root mean squared error (higher is better).
- '''
+ """
- bo_data['params'] = model_adjust_params(self, bo_data['model'], params)
+ bo_data["params"] = model_adjust_params(self, bo_data["model"], params)
BO_iter_score = load_n_predict(self, bo_data, Xy_data, BO_opt=True)
return BO_iter_score
def BO_hyperparams(model_name):
-
model_BO_params = {
- 'RF' : {
- 'n_estimators': (10, 100),
- 'max_depth': (5, 20),
- 'min_samples_split': (2, 10),
- 'min_samples_leaf': (2, 5),
- 'min_weight_fraction_leaf': (0, 0.05),
- 'max_features': (0.25, 1.0),
- 'ccp_alpha': (0, 0.01),
- 'max_samples': (0.25, 1.0)
+ "RF": {
+ "n_estimators": (10, 100),
+ "max_depth": (5, 20),
+ "min_samples_split": (2, 10),
+ "min_samples_leaf": (2, 5),
+ "min_weight_fraction_leaf": (0, 0.05),
+ "max_features": (0.25, 1.0),
+ "ccp_alpha": (0, 0.01),
+ "max_samples": (0.25, 1.0),
},
- 'GB': {
- 'n_estimators': (10, 100),
- 'learning_rate': (0.01, 0.3),
- 'max_depth': (5, 20),
- 'min_samples_split': (2, 10),
- 'min_samples_leaf': (2, 5),
- 'subsample': (0.7, 1.0),
- 'max_features': (0.25, 1.0),
- 'validation_fraction': (0.1, 0.3),
- 'min_weight_fraction_leaf': (0, 0.05),
- 'ccp_alpha': (0, 0.01)
+ "GB": {
+ "n_estimators": (10, 100),
+ "learning_rate": (0.01, 0.3),
+ "max_depth": (5, 20),
+ "min_samples_split": (2, 10),
+ "min_samples_leaf": (2, 5),
+ "subsample": (0.7, 1.0),
+ "max_features": (0.25, 1.0),
+ "validation_fraction": (0.1, 0.3),
+ "min_weight_fraction_leaf": (0, 0.05),
+ "ccp_alpha": (0, 0.01),
},
- 'NN': {
- 'hidden_layer_1': (1, 10),
- 'hidden_layer_2': (0, 10),
- 'max_iter': (200, 500),
- 'alpha': (0.01, 0.1),
- 'tol': (0.00001, 0.0001)
+ "NN": {
+ "hidden_layer_1": (1, 10),
+ "hidden_layer_2": (0, 10),
+ "max_iter": (200, 500),
+ "alpha": (0.01, 0.1),
+ "tol": (0.00001, 0.0001),
},
- 'ADAB': {
- 'learning_rate': (0.1, 5),
- 'n_estimators': (10, 100)
+ "ADAB": {"learning_rate": (0.1, 5), "n_estimators": (10, 100)},
+ "GP": {
+ "n_restarts_optimizer": (0, 100),
},
- 'GP': {
- 'n_restarts_optimizer': (0, 100),
- },
- 'XGB': {
- 'n_estimators': (10, 100),
- 'learning_rate': (0.01, 0.3),
- 'max_depth': (3, 20),
- 'min_child_weight': (1, 10),
- 'subsample': (0.7, 1.0),
- 'colsample_bytree': (0.25, 1.0),
- 'reg_alpha': (0, 1.0),
- 'reg_lambda': (0, 1.0),
- },
- 'VR': {
- 'w_rf': (0.1, 5.0),
- 'w_gb': (0.1, 5.0),
- 'w_nn': (0.1, 5.0),
+ "XGB": {
+ "n_estimators": (10, 100),
+ "learning_rate": (0.01, 0.3),
+ "max_depth": (3, 20),
+ "min_child_weight": (1, 10),
+ "subsample": (0.7, 1.0),
+ "colsample_bytree": (0.25, 1.0),
+ "reg_alpha": (0, 1.0),
+ "reg_lambda": (0, 1.0),
},
}
+ model_BO_params["VR"] = {
+ "w_rf": (0.1, 5.0),
+ "w_gb": (0.1, 5.0),
+ "w_nn": (0.1, 5.0),
+ **{f"rf_{key}": value for key, value in model_BO_params["RF"].items()},
+ **{f"gb_{key}": value for key, value in model_BO_params["GB"].items()},
+ **{f"nn_{key}": value for key, value in model_BO_params["NN"].items()},
+ }
return model_BO_params[model_name]
def BO_metrics(self, bo_data, Xy_data):
- '''
+ """
Get combined score for repeated k-fold and top-bottom sorted CVs (used in BO)
- '''
+ """
metric_combined = load_n_predict(self, bo_data, Xy_data, BO_opt=True)
- if bo_data['error_type'].upper() in ['RMSE','MAE']:
- metric_combined = -metric_combined
+ if bo_data["error_type"].upper() in ["RMSE", "MAE"]:
+ metric_combined = -metric_combined
bo_data[f"combined_{bo_data['error_type']}"] = metric_combined
return bo_data
-def model_adjust_params(self,model_name,params):
- '''
+def model_adjust_params(self, model_name, params):
+ """
Add seed and convert parameters to integers, since they come as floats with decimals in the iterations
- '''
+ """
- if model_name not in ['MVL', 'VR']:
- params['random_state'] = self.args.seed
+ if model_name not in ["MVL", "VR"]:
+ params["random_state"] = self.args.seed
- if model_name in ['RF','GB']:
- params['n_estimators'] = round(params['n_estimators'])
- params['max_depth'] = round(params['max_depth'])
- params['min_samples_split'] = round(params['min_samples_split'])
- params['min_samples_leaf'] = round(params['min_samples_leaf'])
+ if model_name in ["RF", "GB"]:
+ params["n_estimators"] = round(params["n_estimators"])
+ params["max_depth"] = round(params["max_depth"])
+ params["min_samples_split"] = round(params["min_samples_split"])
+ params["min_samples_leaf"] = round(params["min_samples_leaf"])
- elif model_name == 'XGB':
- params['n_estimators'] = round(params['n_estimators'])
- params['max_depth'] = round(params['max_depth'])
- params['min_child_weight'] = round(params['min_child_weight'])
+ elif model_name == "XGB":
+ params["n_estimators"] = round(params["n_estimators"])
+ params["max_depth"] = round(params["max_depth"])
+ params["min_child_weight"] = round(params["min_child_weight"])
- elif model_name == 'NN':
+ elif model_name == "NN":
# add solver first
- params['solver'] = 'lbfgs'
- params['max_iter'] = round(params['max_iter'])
- params['hidden_layer_1'] = round(params['hidden_layer_1'])
- params['hidden_layer_2'] = round(params['hidden_layer_2'])
-
- elif model_name == 'ADAB':
- params['n_estimators'] = round(params['n_estimators'])
-
- elif model_name == 'GP':
- params['n_restarts_optimizer'] = round(params['n_restarts_optimizer'])
-
- elif model_name == 'VR':
- # VR only optimizes ensemble weights; base estimators receive deterministic seeds.
- if all(weight_key in params for weight_key in ['w_rf', 'w_gb', 'w_nn']):
- params['weights'] = [
- float(params.pop('w_rf')),
- float(params.pop('w_gb')),
- float(params.pop('w_nn')),
+ params["solver"] = "lbfgs"
+ params["max_iter"] = round(params["max_iter"])
+ params["hidden_layer_1"] = round(params["hidden_layer_1"])
+ params["hidden_layer_2"] = round(params["hidden_layer_2"])
+
+ elif model_name == "ADAB":
+ params["n_estimators"] = round(params["n_estimators"])
+
+ elif model_name == "GP":
+ params["n_restarts_optimizer"] = round(params["n_restarts_optimizer"])
+
+ elif model_name == "VR":
+ if all(weight_key in params for weight_key in ["w_rf", "w_gb", "w_nn"]):
+ params["weights"] = [
+ float(params.pop("w_rf")),
+ float(params.pop("w_gb")),
+ float(params.pop("w_nn")),
]
- elif 'weights' in params:
- params['weights'] = [float(weight) for weight in params['weights']]
+ elif "weights" in params:
+ params["weights"] = [float(weight) for weight in params["weights"]]
+ _round_vr_member_params(params)
return params
@@ -1835,120 +2180,162 @@ def load_model(self, model_name, **params):
Load models with their corresponding parameters.
"""
- if model_name == 'RF':
+ if model_name == "RF":
# Ensure n_jobs=1 for reproducibility if not already in params
- if 'n_jobs' not in params:
- params['n_jobs'] = 1
- if self.args.type.lower() == 'reg':
+ if "n_jobs" not in params:
+ params["n_jobs"] = 1
+ if self.args.type.lower() == "reg":
loaded_model = RandomForestRegressor(**params)
else:
loaded_model = RandomForestClassifier(**params)
- elif model_name == 'GB':
+ elif model_name == "GB":
# GradientBoosting doesn't have n_jobs parameter, it's already deterministic
- if self.args.type.lower() == 'reg':
+ if self.args.type.lower() == "reg":
loaded_model = GradientBoostingRegressor(**params)
else:
loaded_model = GradientBoostingClassifier(**params)
- elif model_name == 'XGB':
- if 'n_jobs' not in params:
- params['n_jobs'] = 1
- if self.args.type.lower() == 'reg':
+ elif model_name == "XGB":
+ if "n_jobs" not in params:
+ params["n_jobs"] = 1
+ if self.args.type.lower() == "reg":
loaded_model = XGBRegressor(**params)
else:
loaded_model = XGBClassifier(**params)
- elif model_name == 'NN':
+ elif model_name == "NN":
# create the hidden layers architecture first
params = setup_hidden_layers(params)
- if self.args.type.lower() == 'reg':
+ if self.args.type.lower() == "reg":
loaded_model = MLPRegressor(**params)
else:
loaded_model = MLPClassifier(**params)
- elif model_name == 'ADAB':
- if self.args.type.lower() == 'reg':
+ elif model_name == "ADAB":
+ if self.args.type.lower() == "reg":
loaded_model = AdaBoostRegressor(**params)
else:
loaded_model = AdaBoostClassifier(**params)
- elif model_name == 'GP':
- if self.args.type.lower() == 'reg':
+ elif model_name == "GP":
+ if self.args.type.lower() == "reg":
loaded_model = GaussianProcessRegressor(**params)
else:
loaded_model = GaussianProcessClassifier(**params)
- elif model_name == 'MVL':
+ elif model_name == "MVL":
loaded_model = LinearRegression(**params)
- elif model_name == 'VR':
- weights = params.pop('weights', [1.0, 1.0, 1.0])
+ elif model_name == "VR":
+ weights = params.pop("weights", [1.0, 1.0, 1.0])
weights = [float(weight) for weight in weights]
seed = self.args.seed
+ rf_defaults = {
+ "n_estimators": 100,
+ "max_depth": 10,
+ "min_samples_split": 2,
+ "min_samples_leaf": 1,
+ "min_weight_fraction_leaf": 0,
+ "max_features": 1.0,
+ "ccp_alpha": 0.0,
+ "max_samples": None,
+ "random_state": seed,
+ "n_jobs": 1,
+ }
+ gb_defaults = {
+ "n_estimators": 30,
+ "learning_rate": 0.1,
+ "max_depth": 10,
+ "min_samples_split": 2,
+ "min_samples_leaf": 1,
+ "subsample": 1.0,
+ "max_features": None,
+ "validation_fraction": 0.2,
+ "min_weight_fraction_leaf": 0.0,
+ "ccp_alpha": 0.0,
+ "random_state": seed,
+ }
+ nn_defaults = {
+ "hidden_layer_1": 50,
+ "hidden_layer_2": 0,
+ "max_iter": 500,
+ "alpha": 0.01,
+ "tol": 0.0001,
+ "solver": "lbfgs",
+ "random_state": seed,
+ }
+ rf_params = _pop_vr_member_params(params, "rf", rf_defaults)
+ gb_params = _pop_vr_member_params(params, "gb", gb_defaults)
+ nn_params = _pop_vr_member_params(params, "nn", nn_defaults)
+ nn_params = setup_hidden_layers(nn_params)
- if self.args.type.lower() == 'reg':
+ if self.args.type.lower() == "reg":
voting_estimators = [
- ('rf', RandomForestRegressor(n_estimators=100, random_state=seed, n_jobs=1)),
- ('gb', GradientBoostingRegressor(random_state=seed)),
- ('nn', MLPRegressor(hidden_layer_sizes=(50,), max_iter=500, solver='lbfgs', random_state=seed)),
+ ("rf", RandomForestRegressor(**rf_params)),
+ ("gb", GradientBoostingRegressor(**gb_params)),
+ ("nn", MLPRegressor(**nn_params)),
]
- loaded_model = VotingRegressor(estimators=voting_estimators, weights=weights)
+ loaded_model = VotingRegressor(
+ estimators=voting_estimators, weights=weights
+ )
else:
voting_estimators = [
- ('rf', RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=1)),
- ('gb', GradientBoostingClassifier(random_state=seed)),
- ('nn', MLPClassifier(hidden_layer_sizes=(50,), max_iter=500, solver='lbfgs', random_state=seed)),
+ ("rf", RandomForestClassifier(**rf_params)),
+ ("gb", GradientBoostingClassifier(**gb_params)),
+ ("nn", MLPClassifier(**nn_params)),
]
- loaded_model = VotingClassifier(estimators=voting_estimators, weights=weights)
-
+ loaded_model = VotingClassifier(
+ estimators=voting_estimators, weights=weights
+ )
+
return loaded_model
def setup_hidden_layers(params):
- '''
+ """
Build hidden layer structure from provided parameters
- '''
+ """
hidden_layer_sizes = []
- hidden_layer_1 = params.pop('hidden_layer_1')
- hidden_layer_2 = params.pop('hidden_layer_2')
+ hidden_layer_1 = params.pop("hidden_layer_1")
+ hidden_layer_2 = params.pop("hidden_layer_2")
if hidden_layer_1 > 0:
hidden_layer_sizes.append(hidden_layer_1)
if hidden_layer_2 > 0:
hidden_layer_sizes.append(hidden_layer_2)
hidden_layer_sizes = tuple(hidden_layer_sizes) if hidden_layer_sizes else (1,)
- params['hidden_layer_sizes'] = hidden_layer_sizes
+ params["hidden_layer_sizes"] = hidden_layer_sizes
return params
def correct_hidden_layers(params):
- '''
+ """
Correct for a problem with the 'hidden_layer_sizes' parameter when loading arrays from JSON
- '''
-
+ """
+
layer_arrays = []
- if not isinstance(params['hidden_layer_sizes'],int):
- if params['hidden_layer_sizes'][0] == '[':
- params['hidden_layer_sizes'] = params['hidden_layer_sizes'][1:]
- if params['hidden_layer_sizes'][-1] == ']':
- params['hidden_layer_sizes'] = params['hidden_layer_sizes'][:-1]
- if not isinstance(params['hidden_layer_sizes'],list):
- for _,ele in enumerate(params['hidden_layer_sizes'].split(',')):
- if ele != '':
+ if not isinstance(params["hidden_layer_sizes"], int):
+ if params["hidden_layer_sizes"][0] == "[":
+ params["hidden_layer_sizes"] = params["hidden_layer_sizes"][1:]
+ if params["hidden_layer_sizes"][-1] == "]":
+ params["hidden_layer_sizes"] = params["hidden_layer_sizes"][:-1]
+ if not isinstance(params["hidden_layer_sizes"], list):
+ for _, ele in enumerate(params["hidden_layer_sizes"].split(",")):
+ if ele != "":
layer_arrays.append(int(ele))
else:
- for _,ele in enumerate(params['hidden_layer_sizes']):
- if ele != '':
+ for _, ele in enumerate(params["hidden_layer_sizes"]):
+ if ele != "":
layer_arrays.append(int(ele))
else:
layer_arrays = ele
- params['hidden_layer_sizes'] = (layer_arrays)
+ params["hidden_layer_sizes"] = layer_arrays
return params
@@ -2074,7 +2461,7 @@ def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type
if problem_type.lower() == "reg":
y_point = np.average(preds, axis=0, weights=w)
- within_var = np.average(sds ** 2, axis=0, weights=w)
+ within_var = np.average(sds**2, axis=0, weights=w)
mean_pred = np.average(preds, axis=0, weights=w)
between_var = np.average((preds - mean_pred) ** 2, axis=0, weights=w)
uq_model = np.sqrt(np.maximum(within_var, 0.0))
@@ -2085,7 +2472,10 @@ def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type
# Classification: preds are class labels; use float spread heuristics.
preds_f = preds.astype(float)
y_point = np.array(
- [int(round(np.average(preds_f[:, i], weights=w))) for i in range(preds_f.shape[1])],
+ [
+ int(round(np.average(preds_f[:, i], weights=w)))
+ for i in range(preds_f.shape[1])
+ ],
dtype=int,
)
uq_model = np.average(sds, axis=0, weights=w)
@@ -2093,7 +2483,7 @@ def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type
uq_meta = np.sqrt(
np.maximum(np.average((preds_f - mean_pred) ** 2, axis=0, weights=w), 0.0)
)
- uq_total = np.sqrt(np.maximum(uq_model ** 2 + uq_meta ** 2, 0.0))
+ uq_total = np.sqrt(np.maximum(uq_model**2 + uq_meta**2, 0.0))
return y_point, uq_model, uq_meta, uq_total
@@ -2222,7 +2612,9 @@ def _conformal_abs_residual_quantile(abs_residuals, coverage):
return float(np.quantile(abs_r, level, method="higher"))
-def _apply_full_refit_split_conformal(self, model_data, Xy_data, loaded_model, y_cv_mean_train):
+def _apply_full_refit_split_conformal(
+ self, model_data, Xy_data, loaded_model, y_cv_mean_train
+):
"""
Point predictions from a single estimator refit on all training data.
For regression, ``conformal_half_width`` uses a held-out calibration split when
@@ -2305,15 +2697,15 @@ def _apply_full_refit_split_conformal(self, model_data, Xy_data, loaded_model, y
def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False):
- '''
+ """
Load model and calculate errors/precision and predicted values of the ML models
- '''
+ """
# set the parameters for the ML model and load it
- loaded_model = load_model(self, model_data['model'], **model_data['params'])
+ loaded_model = load_model(self, model_data["model"], **model_data["params"])
# calculate predicted y values using repeated k-fold CV
- Xy_data = repeated_kfold_cv(model_data,loaded_model,Xy_data,BO_opt)
+ Xy_data = repeated_kfold_cv(model_data, loaded_model, Xy_data, BO_opt)
y_cv_mean_train = np.asarray(Xy_data["y_pred_train"], dtype=float)
if not BO_opt:
@@ -2323,38 +2715,59 @@ def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False):
Xy_data["_fitted_model"] = fitted_model
# combine all the predictions from the repeated CV (metrics of the train set)
- y_all_list,y_pred_all_list = [],[]
- for y_val,y_pred_vals in zip(Xy_data['y_train'],Xy_data['y_pred_train_all']):
+ y_all_list, y_pred_all_list = [], []
+ for y_val, y_pred_vals in zip(Xy_data["y_train"], Xy_data["y_pred_train_all"]):
for y_pred_val in y_pred_vals:
y_all_list.append(y_val)
y_pred_all_list.append(y_pred_val)
-
+
# get metrics for the different sets
- error_labels = {'reg': ['r2','mae','rmse'],
- 'clas': ['acc','f1','mcc']
- }
-
- error1 = error_labels[model_data['type']][0]
- error2 = error_labels[model_data['type']][1]
- error3 = error_labels[model_data['type']][2]
- Xy_data[f'{error1}_train'], Xy_data[f'{error2}_train'], Xy_data[f'{error3}_train'] = get_prediction_results(model_data,y_all_list,y_pred_all_list)
+ error_labels = {"reg": ["r2", "mae", "rmse"], "clas": ["acc", "f1", "mcc"]}
+
+ error1 = error_labels[model_data["type"]][0]
+ error2 = error_labels[model_data["type"]][1]
+ error3 = error_labels[model_data["type"]][2]
+ (
+ Xy_data[f"{error1}_train"],
+ Xy_data[f"{error2}_train"],
+ Xy_data[f"{error3}_train"],
+ ) = get_prediction_results(model_data, y_all_list, y_pred_all_list)
if not BO_opt:
- Xy_data[f'{error1}_test'], Xy_data[f'{error2}_test'], Xy_data[f'{error3}_test'] = get_prediction_results(model_data,Xy_data['y_test'],Xy_data['y_pred_test'])
- if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0:
- Xy_data[f'{error1}_external'], Xy_data[f'{error2}_external'], Xy_data[f'{error3}_external'] = get_prediction_results(model_data,Xy_data['y_external'],Xy_data['y_pred_external'])
+ (
+ Xy_data[f"{error1}_test"],
+ Xy_data[f"{error2}_test"],
+ Xy_data[f"{error3}_test"],
+ ) = get_prediction_results(
+ model_data, Xy_data["y_test"], Xy_data["y_pred_test"]
+ )
+ if (
+ "y_external" in Xy_data
+ and not Xy_data["y_external"].isnull().values.any()
+ and len(Xy_data["y_external"]) > 0
+ ):
+ (
+ Xy_data[f"{error1}_external"],
+ Xy_data[f"{error2}_external"],
+ Xy_data[f"{error3}_external"],
+ ) = get_prediction_results(
+ model_data, Xy_data["y_external"], Xy_data["y_pred_external"]
+ )
if BO_opt:
# calculate sorted CV and its metrics
# print the target that is above the BO
# print the final result of the BO just after finishing all the iterations
Xy_data = sorted_kfold_cv(loaded_model, model_data, Xy_data, error_labels)
- combined_score = (Xy_data[f'{model_data["error_type"]}_train'] + Xy_data[f'{model_data["error_type"]}_up_bottom']) / 2
+ combined_score = (
+ Xy_data[f"{model_data['error_type']}_train"]
+ + Xy_data[f"{model_data['error_type']}_up_bottom"]
+ ) / 2
# Return if this is part of a verify job
if verify_job:
return Xy_data
# Return negative score for MAE/RMSE in BO
- if model_data["error_type"].lower() in ['mae', 'rmse']:
+ if model_data["error_type"].lower() in ["mae", "rmse"]:
return -combined_score
else:
return combined_score
@@ -2362,110 +2775,139 @@ def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False):
return Xy_data
-def repeated_kfold_cv(model_data,loaded_model,Xy_data,BO_opt):
- '''
+def repeated_kfold_cv(model_data, loaded_model, Xy_data, BO_opt):
+ """
Performs a repeated k-fold cross-validation on the Xy dataset
- '''
+ """
# create a list of lists with the same number of entries as y
- y_global,y_pred_global = [],[]
- for _ in range(len(Xy_data['y_train'])):
+ y_global, y_pred_global = [], []
+ for _ in range(len(Xy_data["y_train"])):
y_pred_global.append([])
y_global.append([])
y_pred_global_test = []
- for _ in range(len(Xy_data['y_test'])):
+ for _ in range(len(Xy_data["y_test"])):
y_pred_global_test.append([])
- y_global_external,y_pred_global_external = [],[]
- if 'X_external' in Xy_data: # if there is an external test set
- for _ in range(len(Xy_data['X_external'])):
+ y_global_external, y_pred_global_external = [], []
+ if "X_external" in Xy_data: # if there is an external test set
+ for _ in range(len(Xy_data["X_external"])):
y_pred_global_external.append([])
y_global_external.append([])
# start the repeated CV
- for CV_repeat in range(int(model_data['repeat_kfolds'])):
- _,y_pred_global,y_pred_global_test,y_pred_global_external, = kfold_cv(y_global,y_pred_global,
- y_pred_global_test,
- y_pred_global_external,
- model_data,loaded_model,
- Xy_data,CV_repeat,BO_opt=BO_opt)
-
- y_train_pred, y_train_std = [],[]
+ for CV_repeat in range(int(model_data["repeat_kfolds"])):
+ (
+ _,
+ y_pred_global,
+ y_pred_global_test,
+ y_pred_global_external,
+ ) = kfold_cv(
+ y_global,
+ y_pred_global,
+ y_pred_global_test,
+ y_pred_global_external,
+ model_data,
+ loaded_model,
+ Xy_data,
+ CV_repeat,
+ BO_opt=BO_opt,
+ )
+
+ y_train_pred, y_train_std = [], []
for y_val in y_pred_global:
- if model_data['type'].lower() == 'reg':
+ if model_data["type"].lower() == "reg":
y_train_pred.append(np.mean(y_val))
- elif model_data['type'].lower() == 'clas':
+ elif model_data["type"].lower() == "clas":
y_train_pred.append(int(round(np.mean(y_val))))
y_train_std.append(float(np.std(y_val)))
- Xy_data['y_pred_train_all'] = y_pred_global
- Xy_data['y_pred_train'] = y_train_pred
- Xy_data['y_pred_train_sd'] = y_train_std
+ Xy_data["y_pred_train_all"] = y_pred_global
+ Xy_data["y_pred_train"] = y_train_pred
+ Xy_data["y_pred_train_sd"] = y_train_std
if not BO_opt:
- y_test_pred, y_test_std = [],[]
+ y_test_pred, y_test_std = [], []
for y_val_test in y_pred_global_test:
- if model_data['type'].lower() == 'reg':
+ if model_data["type"].lower() == "reg":
y_test_pred.append(np.mean(y_val_test))
- elif model_data['type'].lower() == 'clas':
+ elif model_data["type"].lower() == "clas":
y_test_pred.append(int(round(np.mean(y_val_test))))
y_test_std.append(float(np.std(y_val_test)))
- Xy_data['y_pred_test_all'] = y_pred_global_test
- Xy_data['y_pred_test'] = y_test_pred
- Xy_data['y_pred_test_sd'] = y_test_std
+ Xy_data["y_pred_test_all"] = y_pred_global_test
+ Xy_data["y_pred_test"] = y_test_pred
+ Xy_data["y_pred_test_sd"] = y_test_std
- if 'X_external' in Xy_data: # if there is an external test set
- y_external_pred, y_external_std = [],[]
+ if "X_external" in Xy_data: # if there is an external test set
+ y_external_pred, y_external_std = [], []
for y_val_external in y_pred_global_external:
- if model_data['type'].lower() == 'reg':
+ if model_data["type"].lower() == "reg":
y_external_pred.append(np.mean(y_val_external))
- elif model_data['type'].lower() == 'clas':
+ elif model_data["type"].lower() == "clas":
y_external_pred.append(int(round(np.mean(y_val_external))))
y_external_std.append(float(np.std(y_val_external)))
- Xy_data['y_pred_external_all'] = y_pred_global_external
- Xy_data['y_pred_external'] = y_external_pred
- Xy_data['y_pred_external_sd'] = y_external_std
+ Xy_data["y_pred_external_all"] = y_pred_global_external
+ Xy_data["y_pred_external"] = y_external_pred
+ Xy_data["y_pred_external_sd"] = y_external_std
return Xy_data
-def kfold_cv(y_global,y_pred_global,
- y_pred_global_test,
- y_pred_global_external,
- model_data,loaded_model,Xy_data,random_state,
- BO_opt=False,shuffle=True,kfold_cv_type='repeated'):
- '''
+def kfold_cv(
+ y_global,
+ y_pred_global,
+ y_pred_global_test,
+ y_pred_global_external,
+ model_data,
+ loaded_model,
+ Xy_data,
+ random_state,
+ BO_opt=False,
+ shuffle=True,
+ kfold_cv_type="repeated",
+):
+ """
Perform a k-fold CV
Uses StratifiedKFold for classification problems to maintain class distribution
- '''
+ """
# load CV scheme
- if model_data['type'].lower() == 'clas':
+ if model_data["type"].lower() == "clas":
# Use StratifiedKFold for classification to maintain class distribution
- cv = StratifiedKFold(n_splits=int(model_data['kfold']), shuffle=shuffle, random_state=random_state)
+ cv = StratifiedKFold(
+ n_splits=int(model_data["kfold"]),
+ shuffle=shuffle,
+ random_state=random_state,
+ )
else:
- cv = KFold(n_splits=int(model_data['kfold']), shuffle=shuffle, random_state=random_state)
+ cv = KFold(
+ n_splits=int(model_data["kfold"]),
+ shuffle=shuffle,
+ random_state=random_state,
+ )
# # load Xy values and sort using y_train as the sorting reference
- if kfold_cv_type == 'sorted':
- X_init,y_init = sort_n_load(Xy_data) # do not use, currently it doesn't sort indices for X_train as well
+ if kfold_cv_type == "sorted":
+ X_init, y_init = sort_n_load(
+ Xy_data
+ ) # do not use, currently it doesn't sort indices for X_train as well
else:
# convert Xy values of training and validation for CV
- X_init = np.array(Xy_data['X_train_scaled'])
- y_init = np.array(Xy_data['y_train'])
+ X_init = np.array(Xy_data["X_train_scaled"])
+ y_init = np.array(Xy_data["y_train"])
# convert Xy values for the test set and external test set (if any)
- X_test = np.array(Xy_data['X_test_scaled'])
- if 'X_external_scaled' in Xy_data:
- X_external = np.array(Xy_data['X_external_scaled'])
+ X_test = np.array(Xy_data["X_test_scaled"])
+ if "X_external_scaled" in Xy_data:
+ X_external = np.array(Xy_data["X_external_scaled"])
ix_training, ix_valid = [], []
# Loop through each fold and append the training & test indices to the empty lists above
- if model_data['type'].lower() == 'clas':
+ if model_data["type"].lower() == "clas":
# For classification, we need to pass y values to ensure stratification
for fold in cv.split(X_init, y_init):
ix_training.append(fold[0]), ix_valid.append(fold[1])
@@ -2473,8 +2915,8 @@ def kfold_cv(y_global,y_pred_global,
for fold in cv.split(X_init):
ix_training.append(fold[0]), ix_valid.append(fold[1])
- # Loop through each outer fold, and extract predicted vs actual values and SHAP feature analysis
- for (train_outer_ix, test_outer_ix) in zip(ix_training, ix_valid):
+ # Loop through each outer fold, and extract predicted vs actual values and SHAP feature analysis
+ for train_outer_ix, test_outer_ix in zip(ix_training, ix_valid):
X_train, X_valid = X_init[train_outer_ix, :], X_init[test_outer_ix, :]
y_train, y_valid = y_init[train_outer_ix], y_init[test_outer_ix]
@@ -2482,94 +2924,125 @@ def kfold_cv(y_global,y_pred_global,
y_pred_valid = fit.predict(X_valid)
if not BO_opt:
y_pred_test = fit.predict(X_test)
- if 'X_external_scaled' in Xy_data:
+ if "X_external_scaled" in Xy_data:
y_pred_external = fit.predict(X_external)
-
- if kfold_cv_type == 'repeated':
- for y_val,y_pred_val,idx in zip(y_valid,y_pred_valid,test_outer_ix):
+
+ if kfold_cv_type == "repeated":
+ for y_val, y_pred_val, idx in zip(y_valid, y_pred_valid, test_outer_ix):
y_global[idx].append(y_val)
y_pred_global[idx].append(y_pred_val)
if not BO_opt:
- for idx,y_pred_val_test in enumerate(y_pred_test):
+ for idx, y_pred_val_test in enumerate(y_pred_test):
y_pred_global_test[idx].append(y_pred_val_test)
- if 'X_external_scaled' in Xy_data:
- for idx,y_pred_val_external in enumerate(y_pred_external):
+ if "X_external_scaled" in Xy_data:
+ for idx, y_pred_val_external in enumerate(y_pred_external):
y_pred_global_external[idx].append(y_pred_val_external)
- elif kfold_cv_type == 'sorted':
+ elif kfold_cv_type == "sorted":
y_global.append(y_valid)
- y_pred_global.append(y_pred_valid)
+ y_pred_global.append(y_pred_valid)
- return y_global,y_pred_global,y_pred_global_test,y_pred_global_external
+ return y_global, y_pred_global, y_pred_global_test, y_pred_global_external
def sort_n_load(Xy_data):
- '''
+ """
Sort Xy data values to enhance reproducibility in cases where same databases are loaded
with different row order, ensuring stable sorting across OS with kind='stable'.
- '''
-
- X_train_scaled = np.array(Xy_data['X_train_scaled'])
- y_train = np.array(Xy_data['y_train'])
+ """
- sorted_indices = np.argsort(y_train, kind='stable')
+ X_train_scaled = np.array(Xy_data["X_train_scaled"])
+ y_train = np.array(Xy_data["y_train"])
+
+ sorted_indices = np.argsort(y_train, kind="stable")
sorted_X_train_scaled = X_train_scaled[sorted_indices]
sorted_y_train = y_train[sorted_indices]
return sorted_X_train_scaled, sorted_y_train
-def sorted_kfold_cv(loaded_model,model_data,Xy_data,error_labels):
- '''
+def sorted_kfold_cv(loaded_model, model_data, Xy_data, error_labels):
+ """
Performs a sorted k-fold cross-validation on the Xy dataset. Returns the average of the two results
- '''
+ """
# perform sorted 5-fold CV
- Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv'] = [],[]
- Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv'],_,_ = kfold_cv(Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv'],
- None,
- None,
- model_data,loaded_model,Xy_data,None,BO_opt=True,shuffle=False,kfold_cv_type='sorted')
- error1 = error_labels[model_data['type']][0]
- error2 = error_labels[model_data['type']][1]
- error3 = error_labels[model_data['type']][2]
- if model_data['type'].lower() == 'reg':
- Xy_data[f'{error1}_train_sorted_CV'], Xy_data[f'{error2}_train_sorted_CV'], Xy_data[f'{error3}_train_sorted_CV'] = [],[],[]
- for y_cv,y_pred_cd in zip(Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv']):
- r2_train_sorted_CV, mae_train_sorted_CV, rmse_train_sorted_CV = get_prediction_results(model_data,y_cv,y_pred_cd)
- Xy_data[f'{error1}_train_sorted_CV'].append(r2_train_sorted_CV)
- Xy_data[f'{error2}_train_sorted_CV'].append(mae_train_sorted_CV)
- Xy_data[f'{error3}_train_sorted_CV'].append(rmse_train_sorted_CV)
+ Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"] = [], []
+ Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"], _, _ = kfold_cv(
+ Xy_data["y_sorted_cv"],
+ Xy_data["y_pred_sorted_cv"],
+ None,
+ None,
+ model_data,
+ loaded_model,
+ Xy_data,
+ None,
+ BO_opt=True,
+ shuffle=False,
+ kfold_cv_type="sorted",
+ )
+ error1 = error_labels[model_data["type"]][0]
+ error2 = error_labels[model_data["type"]][1]
+ error3 = error_labels[model_data["type"]][2]
+ if model_data["type"].lower() == "reg":
+ (
+ Xy_data[f"{error1}_train_sorted_CV"],
+ Xy_data[f"{error2}_train_sorted_CV"],
+ Xy_data[f"{error3}_train_sorted_CV"],
+ ) = [], [], []
+ for y_cv, y_pred_cd in zip(Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"]):
+ r2_train_sorted_CV, mae_train_sorted_CV, rmse_train_sorted_CV = (
+ get_prediction_results(model_data, y_cv, y_pred_cd)
+ )
+ Xy_data[f"{error1}_train_sorted_CV"].append(r2_train_sorted_CV)
+ Xy_data[f"{error2}_train_sorted_CV"].append(mae_train_sorted_CV)
+ Xy_data[f"{error3}_train_sorted_CV"].append(rmse_train_sorted_CV)
# take the worst performing predictions from the top and bottom folds
- if model_data["error_type"].lower() in ['mae','rmse']:
- Xy_data[f'{model_data["error_type"]}_up_bottom'] = max(Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][0], Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][-1])
- Xy_data['r2_up_bottom'] = min(Xy_data['r2_train_sorted_CV'][0], Xy_data['r2_train_sorted_CV'][-1])
+ if model_data["error_type"].lower() in ["mae", "rmse"]:
+ Xy_data[f"{model_data['error_type']}_up_bottom"] = max(
+ Xy_data[f"{model_data['error_type']}_train_sorted_CV"][0],
+ Xy_data[f"{model_data['error_type']}_train_sorted_CV"][-1],
+ )
+ Xy_data["r2_up_bottom"] = min(
+ Xy_data["r2_train_sorted_CV"][0], Xy_data["r2_train_sorted_CV"][-1]
+ )
else: # r2
- Xy_data[f'{model_data["error_type"]}_up_bottom'] = min(Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][0], Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][-1])
+ Xy_data[f"{model_data['error_type']}_up_bottom"] = min(
+ Xy_data[f"{model_data['error_type']}_train_sorted_CV"][0],
+ Xy_data[f"{model_data['error_type']}_train_sorted_CV"][-1],
+ )
else: # classification
- Xy_data[f'{error1}_train_sorted_CV'], Xy_data[f'{error2}_train_sorted_CV'], Xy_data[f'{error3}_train_sorted_CV'] = [],[],[]
- for y_cv, y_pred_cd in zip(Xy_data['y_sorted_cv'], Xy_data['y_pred_sorted_cv']):
- acc_fold, f1_fold, mcc_fold = get_prediction_results(model_data, y_cv, y_pred_cd)
- Xy_data[f'{error1}_train_sorted_CV'].append(acc_fold)
- Xy_data[f'{error2}_train_sorted_CV'].append(f1_fold)
- Xy_data[f'{error3}_train_sorted_CV'].append(mcc_fold)
+ (
+ Xy_data[f"{error1}_train_sorted_CV"],
+ Xy_data[f"{error2}_train_sorted_CV"],
+ Xy_data[f"{error3}_train_sorted_CV"],
+ ) = [], [], []
+ for y_cv, y_pred_cd in zip(Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"]):
+ acc_fold, f1_fold, mcc_fold = get_prediction_results(
+ model_data, y_cv, y_pred_cd
+ )
+ Xy_data[f"{error1}_train_sorted_CV"].append(acc_fold)
+ Xy_data[f"{error2}_train_sorted_CV"].append(f1_fold)
+ Xy_data[f"{error3}_train_sorted_CV"].append(mcc_fold)
# Measure fold stability by difference between best and worst fold
- Xy_data[f'{model_data["error_type"]}_up_bottom'] = np.mean(np.abs(Xy_data[f'{model_data["error_type"]}_train_sorted_CV']))
+ Xy_data[f"{model_data['error_type']}_up_bottom"] = np.mean(
+ np.abs(Xy_data[f"{model_data['error_type']}_train_sorted_CV"])
+ )
return Xy_data
-def k_means(self,X_scaled,csv_y,size,seed,idx_list):
- '''
-
- Uses k-means clustering to select the test points to be as diverse as possible,
+def k_means(self, X_scaled, csv_y, size, seed, idx_list):
+ """
+
+ Uses k-means clustering to select the test points to be as diverse as possible,
but it returns the test pointsReturns the data points that will be used as training set based on the k-means clustering
-
- '''
-
+
+ """
+
# number of clusters in the training set from the k-means clustering (based on the
# training set size specified above)
X_scaled_array = np.asarray(X_scaled)
@@ -2577,20 +3050,22 @@ def k_means(self,X_scaled,csv_y,size,seed,idx_list):
# to avoid points from the validation set outside the training set, the 2 first training
# points are automatically set as the 2 points with minimum/maximum response value
- if self.args.type.lower() == 'reg':
+ if self.args.type.lower() == "reg":
test_points = []
- training_idx = [csv_y.idxmin(),csv_y.idxmax()]
+ training_idx = [csv_y.idxmin(), csv_y.idxmax()]
number_of_clusters -= 2
else:
test_points = []
training_idx = []
-
+
# runs the k-means algorithm and keeps the closest point to the center of each cluster
- kmeans = KMeans(n_clusters=number_of_clusters,random_state=seed)
+ kmeans = KMeans(n_clusters=number_of_clusters, random_state=seed)
try:
kmeans.fit(X_scaled_array)
except ValueError:
- self.args.log.write("\nx The K-means clustering process failed! This might be due to having NaN or strings as descriptors (curate the data first with CURATE) or having too few datapoints!")
+ self.args.log.write(
+ "\nx The K-means clustering process failed! This might be due to having NaN or strings as descriptors (curate the data first with CURATE) or having too few datapoints!"
+ )
sys.exit()
centers = kmeans.cluster_centers_
for i in range(number_of_clusters):
@@ -2599,14 +3074,18 @@ def k_means(self,X_scaled,csv_y,size,seed,idx_list):
if k not in training_idx:
# calculate the Euclidean distance in n-dimensions
points_sum = 0
- for l in range(len(X_scaled_array[0])):
- points_sum += (X_scaled_array[:, l][k]-centers[:, l][i])**2
+ for idx_l in range(len(X_scaled_array[0])):
+ points_sum += (
+ X_scaled_array[:, idx_l][k] - centers[:, idx_l][i]
+ ) ** 2
if np.sqrt(points_sum) < results_cluster:
results_cluster = np.sqrt(points_sum)
training_point = k
training_idx.append(training_point)
- test_idx = [idx for idx in range(len(X_scaled_array[:, 0])) if idx not in training_idx]
+ test_idx = [
+ idx for idx in range(len(X_scaled_array[:, 0])) if idx not in training_idx
+ ]
test_points = [idx_list[i] for i in test_idx]
test_points.sort()
@@ -2614,62 +3093,86 @@ def k_means(self,X_scaled,csv_y,size,seed,idx_list):
def PFI_filter(self, Xy_data, model_data):
- '''
+ """
Performs the PFI calculation and returns a list of the descriptors that are not important
- '''
+ """
# load and fit model
- loaded_model = load_model(self,model_data['model'],**model_data['params'])
- loaded_model.fit(Xy_data['X_train_scaled'], Xy_data['y_train'])
+ loaded_model = load_model(self, model_data["model"], **model_data["params"])
+ loaded_model.fit(Xy_data["X_train_scaled"], Xy_data["y_train"])
# select scoring function for PFI analysis based on the error type
- scoring, score_model, _ = scoring_n_score(self,model_data,Xy_data,loaded_model)
-
- perm_importance = permutation_importance(loaded_model, Xy_data['X_train_scaled'], Xy_data['y_train'], scoring=scoring, n_repeats=self.args.pfi_epochs, random_state=self.args.seed, n_jobs=1)
+ scoring, score_model, _ = scoring_n_score(self, model_data, Xy_data, loaded_model)
+
+ perm_importance = permutation_importance(
+ loaded_model,
+ Xy_data["X_train_scaled"],
+ Xy_data["y_train"],
+ scoring=scoring,
+ n_repeats=self.args.pfi_epochs,
+ random_state=self.args.seed,
+ n_jobs=1,
+ )
# transforms the values into a list and sort the PFI values with the descriptor names
- descp_cols_pfi, PFI_values, PFI_sd = [],[],[]
- for i,desc in enumerate(Xy_data['X_train_scaled'].columns):
- descp_cols_pfi.append(desc) # includes lists of descriptors not column names!
+ descp_cols_pfi, PFI_values, PFI_sd = [], [], []
+ for i, desc in enumerate(Xy_data["X_train_scaled"].columns):
+ descp_cols_pfi.append(desc) # includes lists of descriptors not column names!
PFI_values.append(perm_importance.importances_mean[i])
PFI_sd.append(perm_importance.importances_std[i])
-
- PFI_values, PFI_sd, descp_cols_pfi = (list(t) for t in zip(*sorted(zip(PFI_values, PFI_sd, descp_cols_pfi), reverse=True)))
+
+ PFI_values, PFI_sd, descp_cols_pfi = (
+ list(t)
+ for t in zip(*sorted(zip(PFI_values, PFI_sd, descp_cols_pfi), reverse=True))
+ )
# PFI filter
PFI_discard_cols = []
# the threshold is based either on the RMSE of the model or the importance of the most important descriptor
- PFI_thres = max([abs(self.args.pfi_threshold*score_model),abs(self.args.pfi_threshold*PFI_values[0])])
+ PFI_thres = max(
+ [
+ abs(self.args.pfi_threshold * score_model),
+ abs(self.args.pfi_threshold * PFI_values[0]),
+ ]
+ )
for i in range(len(PFI_values)):
if PFI_values[i] < PFI_thres:
PFI_discard_cols.append(descp_cols_pfi[i])
- return PFI_discard_cols,descp_cols_pfi
+ return PFI_discard_cols, descp_cols_pfi
-def scoring_n_score(self,model_data,Xy_data,loaded_model):
- '''
+def scoring_n_score(self, model_data, Xy_data, loaded_model):
+ """
Get scoring system and score of the original model with CV
- '''
+ """
- error_type = model_data['error_type'].lower()
- scoring = get_scoring_key(model_data['type'],error_type)
- cv_model = RepeatedKFold(n_splits=self.args.kfold, n_repeats=self.args.repeat_kfolds, random_state=self.args.seed)
- score_model = cross_val_score(estimator = loaded_model, X=Xy_data['X_train_scaled'], y=Xy_data['y_train'],scoring=scoring, cv =cv_model)
+ error_type = model_data["error_type"].lower()
+ scoring = get_scoring_key(model_data["type"], error_type)
+ cv_model = RepeatedKFold(
+ n_splits=self.args.kfold,
+ n_repeats=self.args.repeat_kfolds,
+ random_state=self.args.seed,
+ )
+ score_model = cross_val_score(
+ estimator=loaded_model,
+ X=Xy_data["X_train_scaled"],
+ y=Xy_data["y_train"],
+ scoring=scoring,
+ cv=cv_model,
+ )
score_model = score_model.mean()
- if model_data['error_type'].lower() in ['rmse','mae']:
+ if model_data["error_type"].lower() in ["rmse", "mae"]:
score_model = -score_model
return scoring, score_model, error_type
-def create_heatmap(self,csv_df,suffix,path_raw):
+def create_heatmap(self, csv_df, suffix, path_raw):
"""
Graph the heatmap
"""
- import seaborn as sb
-
with _mpl_plot_context():
csv_df = csv_df.sort_index(ascending=False)
sb.set(font_scale=1.2, style="ticks")
@@ -2690,265 +3193,377 @@ def create_heatmap(self,csv_df,suffix,path_raw):
ax.set_xlabel("ML Model", fontsize=fontsize)
ax.set_ylabel("", fontsize=fontsize)
ax.tick_params(axis="x", which="major", labelsize=fontsize)
- ax.tick_params(
- axis="y", which="both", left=False, right=False, labelleft=False
- )
+ ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)
title_fig = f"Heatmap ML models {suffix}"
plt.title(title_fig, y=1.04, fontsize=fontsize, fontweight="bold")
sb.despine(top=False, right=False)
name_fig = "_".join(title_fig.split())
- plt.savefig(
- f"{path_raw.joinpath(name_fig)}.png", dpi=300, bbox_inches="tight"
- )
+ plt.savefig(f"{path_raw.joinpath(name_fig)}.png", dpi=300, bbox_inches="tight")
+ plt.close()
path_reduced = "/".join(f"{path_raw}".replace("\\", "/").split("/")[-2:])
self.args.log.write(f"\no {name_fig} succesfully created in {path_reduced}")
-def graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style,csv_test=False,print_fun=True,sd_graph=False):
- '''
+def graph_reg(
+ self,
+ Xy_data,
+ params_dict,
+ set_types,
+ path_n_suffix,
+ graph_style,
+ csv_test=False,
+ print_fun=True,
+ sd_graph=False,
+):
+ """
Plot regression graphs of predicted vs actual values for train, validation and test sets
- '''
- import seaborn as sb
-
+ """
sb.set(style="ticks")
- _, ax = plt.subplots(figsize=(7.45,6))
+ fig, ax = plt.subplots(figsize=(7.45, 6))
# Set tick sizes
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
-
+
error_bars = "test"
- title_graph = graph_title(self,csv_test,sd_graph,error_bars)
+ title_graph = graph_title(self, csv_test, sd_graph, error_bars)
if print_fun:
- plt.text(0.5, 1.08, f'{title_graph} of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
- fontsize=14, fontweight='bold', transform = ax.transAxes)
+ plt.text(
+ 0.5,
+ 1.08,
+ f"{title_graph} of {os.path.basename(path_n_suffix)}",
+ horizontalalignment="center",
+ fontsize=14,
+ fontweight="bold",
+ transform=ax.transAxes,
+ )
# Plot the data
if not sd_graph:
- _ = ax.scatter(Xy_data["y_train"], Xy_data["y_pred_train"],
- c = graph_style['color_train'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
+ _ = ax.scatter(
+ Xy_data["y_train"],
+ Xy_data["y_pred_train"],
+ c=graph_style["color_train"],
+ s=graph_style["dot_size"],
+ edgecolor="k",
+ linewidths=0.8,
+ alpha=graph_style["alpha"],
+ zorder=2,
+ )
if not csv_test:
- _ = ax.scatter(Xy_data["y_test"], Xy_data["y_pred_test"],
- c = graph_style['color_test'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=3)
+ _ = ax.scatter(
+ Xy_data["y_test"],
+ Xy_data["y_pred_test"],
+ c=graph_style["color_test"],
+ s=graph_style["dot_size"],
+ edgecolor="k",
+ linewidths=0.8,
+ alpha=graph_style["alpha"],
+ zorder=3,
+ )
else:
error_bars = "external"
- _ = ax.scatter(Xy_data["y_external"], Xy_data["y_pred_external"],
- c = graph_style['color_test'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
+ _ = ax.scatter(
+ Xy_data["y_external"],
+ Xy_data["y_pred_external"],
+ c=graph_style["color_test"],
+ s=graph_style["dot_size"],
+ edgecolor="k",
+ linewidths=0.8,
+ alpha=graph_style["alpha"],
+ zorder=2,
+ )
- # average CV ± SD graphs
+ # average CV ± SD graphs
if sd_graph:
- if not csv_test:
+ if not csv_test:
# Plot the data with the error bars
- _ = ax.errorbar(Xy_data[f"y_{error_bars}"], Xy_data[f"y_pred_{error_bars}"], yerr=Xy_data[f"y_pred_{error_bars}_sd"], fmt='none', ecolor="gray", capsize=3, zorder=1)
+ _ = ax.errorbar(
+ Xy_data[f"y_{error_bars}"],
+ Xy_data[f"y_pred_{error_bars}"],
+ yerr=Xy_data[f"y_pred_{error_bars}_sd"],
+ fmt="none",
+ ecolor="gray",
+ capsize=3,
+ zorder=1,
+ )
# Adjust labels from legend
- set_types=[error_bars,f'± SD']
+ set_types = [error_bars, "± SD"]
else:
- _ = ax.errorbar(Xy_data[f"y_{error_bars}"], Xy_data[f"y_pred_{error_bars}"], yerr=Xy_data[f"y_pred_{error_bars}_sd"], fmt='none', ecolor="gray", capsize=3, zorder=1)
- set_types=['External test',f'± SD']
+ _ = ax.errorbar(
+ Xy_data[f"y_{error_bars}"],
+ Xy_data[f"y_pred_{error_bars}"],
+ yerr=Xy_data[f"y_pred_{error_bars}_sd"],
+ fmt="none",
+ ecolor="gray",
+ capsize=3,
+ zorder=1,
+ )
+ set_types = ["External test", "± SD"]
# legend and regression line with 95% CI considering all possible lines (not CI of the points)
- if 'CV' in set_types[0]: # CV in VERIFY
+ if "CV" in set_types[0]: # CV in VERIFY
legend_coords = (0.70, 0.15)
- elif len(set_types) == 2: # external test or sets with ± SD
- if 'External test' in set_types:
+ elif len(set_types) == 2: # external test or sets with ± SD
+ if "External test" in set_types:
legend_coords = (0.66, 0.15)
else:
legend_coords = (0.735, 0.15)
- ax.legend(loc='upper center', bbox_to_anchor=legend_coords,
- handletextpad=0,
- fancybox=True, shadow=True, ncol=5, labels=set_types, fontsize=14)
+ ax.legend(
+ loc="upper center",
+ bbox_to_anchor=legend_coords,
+ handletextpad=0,
+ fancybox=True,
+ shadow=True,
+ ncol=5,
+ labels=set_types,
+ fontsize=14,
+ )
Xy_data_df = pd.DataFrame()
if not sd_graph:
- line_suff = 'train'
+ line_suff = "train"
elif not csv_test:
- line_suff = 'test'
+ line_suff = "test"
else:
- line_suff = 'external'
+ line_suff = "external"
Xy_data_df[f"y_{line_suff}"] = Xy_data[f"y_{line_suff}"]
Xy_data_df[f"y_pred_{line_suff}"] = Xy_data[f"y_pred_{line_suff}"]
if len(Xy_data_df[f"y_pred_{line_suff}"]) >= 10:
- _ = sb.regplot(x=f"y_{line_suff}", y=f"y_pred_{line_suff}", data=Xy_data_df, scatter=False, color=".1",
- truncate = True, ax=ax, seed=params_dict['seed'])
+ _ = sb.regplot(
+ x=f"y_{line_suff}",
+ y=f"y_pred_{line_suff}",
+ data=Xy_data_df,
+ scatter=False,
+ color=".1",
+ truncate=True,
+ ax=ax,
+ seed=params_dict["seed"],
+ )
# Title and labels of the axis
- plt.ylabel(f'Predicted {params_dict["y"]}', fontsize=14)
- plt.xlabel(f'{params_dict["y"]}', fontsize=14)
+ plt.ylabel(f"Predicted {params_dict['y']}", fontsize=14)
+ plt.xlabel(f"{params_dict['y']}", fontsize=14)
# set axis limits and graph PATH
- min_value_graph,max_value_graph,reg_plot_file,path_reduced = graph_vars(Xy_data,set_types,csv_test,path_n_suffix,sd_graph)
+ min_value_graph, max_value_graph, reg_plot_file, path_reduced = graph_vars(
+ Xy_data, set_types, csv_test, path_n_suffix, sd_graph
+ )
# track the range of predictions (used in ROBERT score)
- pred_min = min(min(Xy_data["y_train"]),min(Xy_data["y_test"]))
- pred_max = max(max(Xy_data["y_train"]),max(Xy_data["y_test"]))
- pred_range = np.abs(pred_max-pred_min)
- Xy_data['pred_min'] = pred_min
- Xy_data['pred_max'] = pred_max
- Xy_data['pred_range'] = pred_range
+ pred_min = min(min(Xy_data["y_train"]), min(Xy_data["y_test"]))
+ pred_max = max(max(Xy_data["y_train"]), max(Xy_data["y_test"]))
+ pred_range = np.abs(pred_max - pred_min)
+ Xy_data["pred_min"] = pred_min
+ Xy_data["pred_max"] = pred_max
+ Xy_data["pred_range"] = pred_range
# Add gridlines
- ax.grid(linestyle='--', linewidth=1)
+ ax.grid(linestyle="--", linewidth=1)
# set axis limits
plt.xlim(min_value_graph, max_value_graph)
plt.ylim(min_value_graph, max_value_graph)
# save graph
- plt.savefig(f'{reg_plot_file}', dpi=300, bbox_inches='tight')
+ plt.savefig(f"{reg_plot_file}", dpi=300, bbox_inches="tight")
+ plt.close(fig)
if print_fun:
self.args.log.write(f" - Graph in: {path_reduced}")
-def graph_title(self,csv_test,sd_graph,error_bars):
- '''
+def graph_title(self, csv_test, sd_graph, error_bars):
+ """
Retrieves the corresponding graph title.
- '''
+ """
# set title for regular graphs
if not sd_graph:
if not csv_test:
# regular graphs
- title_graph = f'Predictions CV and test set'
+ title_graph = "Predictions CV and test set"
else:
- title_graph = f'{os.path.basename(self.args.csv_test)}'
+ title_graph = f"{os.path.basename(self.args.csv_test)}"
if len(title_graph) > 30:
- title_graph = f'{title_graph[:27]}...'
+ title_graph = f"{title_graph[:27]}..."
# set title for averaged CV ± SD graphs
else:
if not csv_test:
sets_title = error_bars
else:
- sets_title = 'external test'
+ sets_title = "external test"
- title_graph = f'{sets_title} set ± SD (CV)'
+ title_graph = f"{sets_title} set ± SD (CV)"
return title_graph
-def graph_vars(Xy_data,set_types,csv_test,path_n_suffix,sd_graph):
- '''
+def graph_vars(Xy_data, set_types, csv_test, path_n_suffix, sd_graph):
+ """
Set axis limits for regression plots and PATH to save the graphs
- '''
+ """
# x and y axis limits for graphs with multiple sets
if not csv_test:
- size_space = 0.1*abs(min(Xy_data["y_train"])-max(Xy_data["y_train"]))
- min_value_graph = min(min(Xy_data["y_train"]),min(Xy_data["y_pred_train"]),min(Xy_data["y_test"]),min(Xy_data["y_pred_test"]))
- if 'test' in set_types:
- min_value_graph = min(min_value_graph,min(Xy_data["y_test"]),min(Xy_data["y_pred_test"]))
- min_value_graph = min_value_graph-size_space
-
- max_value_graph = max(max(Xy_data["y_train"]),max(Xy_data["y_pred_train"]),max(Xy_data["y_test"]),max(Xy_data["y_pred_test"]))
- if 'test' in set_types:
- max_value_graph = max(max_value_graph,max(Xy_data["y_test"]),max(Xy_data["y_pred_test"]))
- max_value_graph = max_value_graph+size_space
-
- else: # limits for graphs with only one set
- set_type = 'external'
- size_space = 0.1*abs(min(Xy_data[f'y_{set_type}'])-max(Xy_data[f'y_{set_type}']))
- min_value_graph = min(min(Xy_data[f'y_{set_type}']),min(Xy_data[f'y_pred_{set_type}']))
- min_value_graph = min_value_graph-size_space
- max_value_graph = max(max(Xy_data[f'y_{set_type}']),max(Xy_data[f'y_pred_{set_type}']))
- max_value_graph = max_value_graph+size_space
+ size_space = 0.1 * abs(min(Xy_data["y_train"]) - max(Xy_data["y_train"]))
+ min_value_graph = min(
+ min(Xy_data["y_train"]),
+ min(Xy_data["y_pred_train"]),
+ min(Xy_data["y_test"]),
+ min(Xy_data["y_pred_test"]),
+ )
+ if "test" in set_types:
+ min_value_graph = min(
+ min_value_graph, min(Xy_data["y_test"]), min(Xy_data["y_pred_test"])
+ )
+ min_value_graph = min_value_graph - size_space
+
+ max_value_graph = max(
+ max(Xy_data["y_train"]),
+ max(Xy_data["y_pred_train"]),
+ max(Xy_data["y_test"]),
+ max(Xy_data["y_pred_test"]),
+ )
+ if "test" in set_types:
+ max_value_graph = max(
+ max_value_graph, max(Xy_data["y_test"]), max(Xy_data["y_pred_test"])
+ )
+ max_value_graph = max_value_graph + size_space
+
+ else: # limits for graphs with only one set
+ set_type = "external"
+ size_space = 0.1 * abs(
+ min(Xy_data[f"y_{set_type}"]) - max(Xy_data[f"y_{set_type}"])
+ )
+ min_value_graph = min(
+ min(Xy_data[f"y_{set_type}"]), min(Xy_data[f"y_pred_{set_type}"])
+ )
+ min_value_graph = min_value_graph - size_space
+ max_value_graph = max(
+ max(Xy_data[f"y_{set_type}"]), max(Xy_data[f"y_pred_{set_type}"])
+ )
+ max_value_graph = max_value_graph + size_space
# PATH of the graph
if not csv_test:
if not sd_graph:
- reg_plot_file = f'{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}.png'
+ reg_plot_file = f"{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}.png"
else:
- reg_plot_file = f'{os.path.dirname(path_n_suffix)}/CV_variability_{os.path.basename(path_n_suffix)}.png'
- path_reduced = '/'.join(f'{reg_plot_file}'.replace('\\','/').split('/')[-2:])
+ reg_plot_file = f"{os.path.dirname(path_n_suffix)}/CV_variability_{os.path.basename(path_n_suffix)}.png"
+ path_reduced = "/".join(f"{reg_plot_file}".replace("\\", "/").split("/")[-2:])
else:
- folder_graph = f'{os.path.dirname(path_n_suffix)}/csv_test'
+ folder_graph = f"{os.path.dirname(path_n_suffix)}/csv_test"
if not sd_graph:
- reg_plot_file = f'{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png'
+ reg_plot_file = f"{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png"
else:
- reg_plot_file = f'{folder_graph}/CV_variability_{os.path.basename(path_n_suffix)}_{set_type}.png'
- path_reduced = '/'.join(f'{reg_plot_file}'.replace('\\','/').split('/')[-3:])
-
- return min_value_graph,max_value_graph,reg_plot_file,path_reduced
+ reg_plot_file = f"{folder_graph}/CV_variability_{os.path.basename(path_n_suffix)}_{set_type}.png"
+ path_reduced = "/".join(f"{reg_plot_file}".replace("\\", "/").split("/")[-3:])
+
+ return min_value_graph, max_value_graph, reg_plot_file, path_reduced
-def graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix,csv_test=False,print_fun=True):
- '''
+def graph_clas(
+ self, Xy_data, params_dict, set_type, path_n_suffix, csv_test=False, print_fun=True
+):
+ """
Plot a confusion matrix with the prediction vs actual values
- '''
+ """
# Check if we need to use original class labels for display
display_labels = None
- if 'class_0_label' in params_dict and 'class_1_label' in params_dict:
- display_labels = [params_dict['class_0_label'], params_dict['class_1_label']]
+ if "class_0_label" in params_dict and "class_1_label" in params_dict:
+ display_labels = [params_dict["class_0_label"], params_dict["class_1_label"]]
# get confusion matrix
- if 'CV' in set_type: # CV graphs
- y_train_binary = np.round(Xy_data[f'y_train']).astype(int)
- y_pred_train_binary = np.round(Xy_data[f'y_pred_train']).astype(int)
- matrix = ConfusionMatrixDisplay.from_predictions(y_train_binary, y_pred_train_binary,
- normalize=None, cmap='Blues',
- display_labels=display_labels)
- else: # other graphs
- y_binary = np.round(Xy_data[f'y_{set_type}']).astype(int)
- y_pred_binary = np.round(Xy_data[f'y_pred_{set_type}']).astype(int)
- matrix = ConfusionMatrixDisplay.from_predictions(y_binary, y_pred_binary,
- normalize=None, cmap='Blues',
- display_labels=display_labels)
+ if "CV" in set_type: # CV graphs
+ y_train_binary = np.round(Xy_data["y_train"]).astype(int)
+ y_pred_train_binary = np.round(Xy_data["y_pred_train"]).astype(int)
+ matrix = ConfusionMatrixDisplay.from_predictions(
+ y_train_binary,
+ y_pred_train_binary,
+ normalize=None,
+ cmap="Blues",
+ display_labels=display_labels,
+ )
+ else: # other graphs
+ y_binary = np.round(Xy_data[f"y_{set_type}"]).astype(int)
+ y_pred_binary = np.round(Xy_data[f"y_pred_{set_type}"]).astype(int)
+ matrix = ConfusionMatrixDisplay.from_predictions(
+ y_binary,
+ y_pred_binary,
+ normalize=None,
+ cmap="Blues",
+ display_labels=display_labels,
+ )
# transfer it to the same format and size used in reg graphs
- _, ax = plt.subplots(figsize=(7.45,6))
- matrix.plot(ax=ax, cmap='Blues')
+ _, ax = plt.subplots(figsize=(7.45, 6))
+ matrix.plot(ax=ax, cmap="Blues")
if print_fun:
- if 'CV' not in set_type:
- title_set = f'{set_type} set'
+ if "CV" not in set_type:
+ title_set = f"{set_type} set"
else:
title_set = set_type
- plt.text(0.5, 1.08, f'{title_set} of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
- fontsize=14, fontweight='bold', transform = ax.transAxes)
+ plt.text(
+ 0.5,
+ 1.08,
+ f"{title_set} of {os.path.basename(path_n_suffix)}",
+ horizontalalignment="center",
+ fontsize=14,
+ fontweight="bold",
+ transform=ax.transAxes,
+ )
- plt.xlabel(f'Predicted {params_dict["y"]}', fontsize=14)
- plt.ylabel(f'{params_dict["y"]}', fontsize=14)
+ plt.xlabel(f"Predicted {params_dict['y']}", fontsize=14)
+ plt.ylabel(f"{params_dict['y']}", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# save fig
- if 'CV' in set_type: # CV graphs
- clas_plot_file = f'{os.path.dirname(path_n_suffix)}/CV_train_valid_predict_{os.path.basename(path_n_suffix)}.png'
- path_reduced = '/'.join(f'{clas_plot_file}'.replace('\\','/').split('/')[-2:])
+ if "CV" in set_type: # CV graphs
+ clas_plot_file = f"{os.path.dirname(path_n_suffix)}/CV_train_valid_predict_{os.path.basename(path_n_suffix)}.png"
+ path_reduced = "/".join(f"{clas_plot_file}".replace("\\", "/").split("/")[-2:])
elif not csv_test:
- clas_plot_file = f'{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png'
- path_reduced = '/'.join(f'{clas_plot_file}'.replace('\\','/').split('/')[-2:])
+ clas_plot_file = f"{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png"
+ path_reduced = "/".join(f"{clas_plot_file}".replace("\\", "/").split("/")[-2:])
else:
- folder_graph = f'{os.path.dirname(path_n_suffix)}/csv_test'
- clas_plot_file = f'{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png'
- path_reduced = '/'.join(f'{clas_plot_file}'.replace('\\','/').split('/')[-3:])
+ folder_graph = f"{os.path.dirname(path_n_suffix)}/csv_test"
+ clas_plot_file = (
+ f"{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png"
+ )
+ path_reduced = "/".join(f"{clas_plot_file}".replace("\\", "/").split("/")[-3:])
- plt.savefig(f'{clas_plot_file}', dpi=300, bbox_inches='tight')
+ plt.savefig(f"{clas_plot_file}", dpi=300, bbox_inches="tight")
+ plt.close()
if print_fun:
self.args.log.write(f" - Graph in: {path_reduced}")
def shap_analysis(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
- '''
+ """
Plots and prints the results of the SHAP analysis
- '''
+ """
import shap
_, _ = plt.subplots(figsize=(7.45, 6))
- shap_plot_file = f'{os.path.dirname(path_n_suffix)}/SHAP_{os.path.basename(path_n_suffix)}.png'
+ shap_plot_file = (
+ f"{os.path.dirname(path_n_suffix)}/SHAP_{os.path.basename(path_n_suffix)}.png"
+ )
if fitted_model is None:
loaded_model = load_model(self, model_data["model"], **model_data["params"])
@@ -2957,59 +3572,85 @@ def shap_analysis(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
loaded_model = fitted_model
# run the SHAP analysis and save the plot
- explainer = shap.Explainer(loaded_model.predict, Xy_data['X_train_scaled'], seed=model_data['seed'])
+ explainer = shap.Explainer(
+ loaded_model.predict, Xy_data["X_train_scaled"], seed=model_data["seed"]
+ )
try:
- shap_values = explainer(Xy_data['X_train_scaled'])
+ shap_values = explainer(Xy_data["X_train_scaled"])
except ValueError:
- shap_values = explainer(Xy_data['X_train_scaled'],max_evals=(2*len(Xy_data['X_train_scaled'].columns))+1)
+ shap_values = explainer(
+ Xy_data["X_train_scaled"],
+ max_evals=(2 * len(Xy_data["X_train_scaled"].columns)) + 1,
+ )
- shap_show = [self.args.shap_show,len(Xy_data['X_train_scaled'].columns)]
- aspect_shap = 25+((min(shap_show)-2)*5)
- height_shap = 1.2+min(shap_show)/4
+ shap_show = [self.args.shap_show, len(Xy_data["X_train_scaled"].columns)]
+ aspect_shap = 25 + ((min(shap_show) - 2) * 5)
+ height_shap = 1.2 + min(shap_show) / 4
# explainer = shap.TreeExplainer(loaded_model) # in case the standard version doesn't work
- _ = shap.summary_plot(shap_values, Xy_data['X_train_scaled'], max_display=self.args.shap_show,show=False, plot_size=[7.45,height_shap])
+ _ = shap.summary_plot(
+ shap_values,
+ Xy_data["X_train_scaled"],
+ max_display=self.args.shap_show,
+ show=False,
+ plot_size=[7.45, height_shap],
+ )
# set title
- plt.title(f'SHAP analysis of {os.path.basename(path_n_suffix)}', fontsize = 14, fontweight="bold")
+ plt.title(
+ f"SHAP analysis of {os.path.basename(path_n_suffix)}",
+ fontsize=14,
+ fontweight="bold",
+ )
- path_reduced = '/'.join(f'{shap_plot_file}'.replace('\\','/').split('/')[-2:])
+ path_reduced = "/".join(f"{shap_plot_file}".replace("\\", "/").split("/")[-2:])
print_shap = f"\n o SHAP plot saved in {path_reduced}"
# collect SHAP values and print
- desc_list, min_list, max_list = [],[],[]
- for i,desc in enumerate(Xy_data['X_train_scaled']):
+ desc_list, min_list, max_list = [], [], []
+ for i, desc in enumerate(Xy_data["X_train_scaled"]):
desc_list.append(desc)
- val_list_indiv= []
- for _,val in enumerate(shap_values.values):
+ val_list_indiv = []
+ for _, val in enumerate(shap_values.values):
val_list_indiv.append(val[i])
min_indiv = min(val_list_indiv)
max_indiv = max(val_list_indiv)
min_list.append(min_indiv)
max_list.append(max_indiv)
-
+
if max(max_list, key=abs) > max(min_list, key=abs):
- max_list, min_list, desc_list = (list(t) for t in zip(*sorted(zip(max_list, min_list, desc_list), reverse=True)))
+ max_list, min_list, desc_list = (
+ list(t)
+ for t in zip(*sorted(zip(max_list, min_list, desc_list), reverse=True))
+ )
else:
- min_list, max_list, desc_list = (list(t) for t in zip(*sorted(zip(min_list, max_list, desc_list), reverse=False)))
+ min_list, max_list, desc_list = (
+ list(t)
+ for t in zip(*sorted(zip(min_list, max_list, desc_list), reverse=False))
+ )
- for i,desc in enumerate(desc_list):
- print_shap += f"\n - {desc} = min: {min_list[i]:.2}, max: {max_list[i]:.2}"
+ for i, desc in enumerate(desc_list):
+ print_shap += (
+ f"\n - {desc} = min: {min_list[i]:.2}, max: {max_list[i]:.2}"
+ )
self.args.log.write(print_shap)
# adjust width of the colorbar
plt.gcf().axes[-1].set_aspect(aspect_shap)
plt.gcf().axes[-1].set_box_aspect(aspect_shap)
-
- plt.savefig(f'{shap_plot_file}', dpi=300, bbox_inches='tight')
+
+ plt.savefig(f"{shap_plot_file}", dpi=300, bbox_inches="tight")
+ plt.close()
def PFI_plot(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
- '''
+ """
Plots and prints the results of the PFI analysis
- '''
- pfi_plot_file = f'{os.path.dirname(path_n_suffix)}/PFI_{os.path.basename(path_n_suffix)}.png'
+ """
+ pfi_plot_file = (
+ f"{os.path.dirname(path_n_suffix)}/PFI_{os.path.basename(path_n_suffix)}.png"
+ )
if fitted_model is None:
loaded_model = load_model(self, model_data["model"], **model_data["params"])
@@ -3018,144 +3659,204 @@ def PFI_plot(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
loaded_model = fitted_model
# select scoring function for PFI analysis based on the error type
- scoring, _, error_type = scoring_n_score(self,model_data,Xy_data,loaded_model)
-
- perm_importance = permutation_importance(loaded_model, Xy_data['X_train_scaled'], Xy_data['y_train'], scoring=scoring, n_repeats=self.args.pfi_epochs, random_state=model_data['seed'], n_jobs=1)
+ scoring, _, error_type = scoring_n_score(self, model_data, Xy_data, loaded_model)
+
+ perm_importance = permutation_importance(
+ loaded_model,
+ Xy_data["X_train_scaled"],
+ Xy_data["y_train"],
+ scoring=scoring,
+ n_repeats=self.args.pfi_epochs,
+ random_state=model_data["seed"],
+ n_jobs=1,
+ )
# sort descriptors and results from PFI
- desc_list, PFI_values, PFI_sd = [],[],[]
- for i,desc in enumerate(Xy_data['X_train_scaled']):
+ desc_list, PFI_values, PFI_sd = [], [], []
+ for i, desc in enumerate(Xy_data["X_train_scaled"]):
desc_list.append(desc)
PFI_values.append(perm_importance.importances_mean[i])
PFI_sd.append(perm_importance.importances_std[i])
# sort from higher to lower values and keep only the top self.args.pfi_show descriptors
- PFI_values, PFI_sd, desc_list = (list(t) for t in zip(*sorted(zip(PFI_values, PFI_sd, desc_list), reverse=True)))
- PFI_values_plot = PFI_values[:self.args.pfi_show][::-1]
- desc_list_plot = desc_list[:self.args.pfi_show][::-1]
+ PFI_values, PFI_sd, desc_list = (
+ list(t) for t in zip(*sorted(zip(PFI_values, PFI_sd, desc_list), reverse=True))
+ )
+ PFI_values_plot = PFI_values[: self.args.pfi_show][::-1]
+ desc_list_plot = desc_list[: self.args.pfi_show][::-1]
# plot and print results
- _, ax = plt.subplots(figsize=(7.45,6))
+ fig, ax = plt.subplots(figsize=(7.45, 6))
y_ticks = np.arange(0, len(desc_list_plot))
ax.barh(desc_list_plot, PFI_values_plot)
- ax.set_yticks(y_ticks,labels=desc_list_plot,fontsize=14)
- plt.text(0.5, 1.08, f'Permutation feature importances (PFIs) of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
- fontsize=14, fontweight='bold', transform = ax.transAxes)
- ax.set(ylabel=None, xlabel='PFI')
+ ax.set_yticks(y_ticks, labels=desc_list_plot, fontsize=14)
+ plt.text(
+ 0.5,
+ 1.08,
+ f"Permutation feature importances (PFIs) of {os.path.basename(path_n_suffix)}",
+ horizontalalignment="center",
+ fontsize=14,
+ fontweight="bold",
+ transform=ax.transAxes,
+ )
+ ax.set(ylabel=None, xlabel="PFI")
- plt.savefig(f'{pfi_plot_file}', dpi=300, bbox_inches='tight')
+ plt.savefig(f"{pfi_plot_file}", dpi=300, bbox_inches="tight")
+ plt.close(fig)
- path_reduced = '/'.join(f'{pfi_plot_file}'.replace('\\','/').split('/')[-2:])
+ path_reduced = "/".join(f"{pfi_plot_file}".replace("\\", "/").split("/")[-2:])
print_PFI = f"\n o PFI plot saved in {path_reduced}"
- print_PFI += f'\n Influence on {error_type.upper()}'
+ print_PFI += f"\n Influence on {error_type.upper()}"
- for i,desc in enumerate(desc_list):
+ for i, desc in enumerate(desc_list):
print_PFI += f"\n - {desc} = {PFI_values[i]:.2} +- {PFI_sd[i]:.2}"
-
+
self.args.log.write(print_PFI)
-def outlier_plot(self,Xy_data,path_n_suffix,name_points,graph_style):
- '''
+def outlier_plot(self, Xy_data, path_n_suffix, name_points, graph_style):
+ """
Plots and prints the results of the outlier analysis
- '''
- import seaborn as sb
-
+ """
# detect outliers
outliers_data, print_outliers = outlier_filter(self, Xy_data, name_points)
# plot data in SD units
sb.set(style="ticks")
- _, ax = plt.subplots(figsize=(7.45,6))
- plt.text(0.5, 1.08, f'Outlier analysis of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
- fontsize=14, fontweight='bold', transform = ax.transAxes)
+ fig, ax = plt.subplots(figsize=(7.45, 6))
+ plt.text(
+ 0.5,
+ 1.08,
+ f"Outlier analysis of {os.path.basename(path_n_suffix)}",
+ horizontalalignment="center",
+ fontsize=14,
+ fontweight="bold",
+ transform=ax.transAxes,
+ )
+
+ plt.grid(linestyle="--", linewidth=1)
+ _ = ax.scatter(
+ outliers_data["train_scaled"],
+ outliers_data["train_scaled"],
+ c=graph_style["color_train"],
+ s=graph_style["dot_size"],
+ edgecolor="k",
+ linewidths=0.8,
+ alpha=graph_style["alpha"],
+ zorder=2,
+ )
+ _ = ax.scatter(
+ outliers_data["test_scaled"],
+ outliers_data["test_scaled"],
+ c=graph_style["color_test"],
+ s=graph_style["dot_size"],
+ edgecolor="k",
+ linewidths=0.8,
+ alpha=graph_style["alpha"],
+ zorder=2,
+ )
- plt.grid(linestyle='--', linewidth=1)
- _ = ax.scatter(outliers_data['train_scaled'], outliers_data['train_scaled'],
- c = graph_style['color_train'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
- _ = ax.scatter(outliers_data['test_scaled'], outliers_data['test_scaled'],
- c = graph_style['color_test'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
-
# Set styling preferences and graph limits
- plt.xlabel('SD of the errors',fontsize=14)
+ plt.xlabel("SD of the errors", fontsize=14)
plt.xticks(fontsize=14)
- plt.ylabel('SD of the errors',fontsize=14)
+ plt.ylabel("SD of the errors", fontsize=14)
plt.yticks(fontsize=14)
-
- axis_limit = max(outliers_data['train_scaled'], key=abs)
- if 'test_scaled' in outliers_data:
- if max(outliers_data['test_scaled'], key=abs) > axis_limit:
- axis_limit = max(outliers_data['test_scaled'], key=abs)
- axis_limit = axis_limit+0.5
- if axis_limit < 2.5: # this fixes a problem when representing rectangles in graphs with low SDs
+
+ axis_limit = max(outliers_data["train_scaled"], key=abs)
+ if "test_scaled" in outliers_data:
+ if max(outliers_data["test_scaled"], key=abs) > axis_limit:
+ axis_limit = max(outliers_data["test_scaled"], key=abs)
+ axis_limit = axis_limit + 0.5
+ if (
+ axis_limit < 2.5
+ ): # this fixes a problem when representing rectangles in graphs with low SDs
axis_limit = 2.5
plt.ylim(-axis_limit, axis_limit)
plt.xlim(-axis_limit, axis_limit)
# plot rectangles in corners
diff_tvalue = axis_limit - self.args.t_value
- Rectangle_top = mpatches.Rectangle(xy=(axis_limit, axis_limit), width=-diff_tvalue, height=-diff_tvalue, facecolor='grey', alpha=0.3)
- Rectangle_bottom = mpatches.Rectangle(xy=(-(axis_limit), -(axis_limit)), width=diff_tvalue, height=diff_tvalue, facecolor='grey', alpha=0.3)
+ Rectangle_top = mpatches.Rectangle(
+ xy=(axis_limit, axis_limit),
+ width=-diff_tvalue,
+ height=-diff_tvalue,
+ facecolor="grey",
+ alpha=0.3,
+ )
+ Rectangle_bottom = mpatches.Rectangle(
+ xy=(-(axis_limit), -(axis_limit)),
+ width=diff_tvalue,
+ height=diff_tvalue,
+ facecolor="grey",
+ alpha=0.3,
+ )
ax.add_patch(Rectangle_top)
ax.add_patch(Rectangle_bottom)
# save plot and print results
- outliers_plot_file = f'{os.path.dirname(path_n_suffix)}/Outliers_{os.path.basename(path_n_suffix)}.png'
- plt.savefig(f'{outliers_plot_file}', dpi=300, bbox_inches='tight')
-
- path_reduced = '/'.join(f'{outliers_plot_file}'.replace('\\','/').split('/')[-2:])
+ outliers_plot_file = f"{os.path.dirname(path_n_suffix)}/Outliers_{os.path.basename(path_n_suffix)}.png"
+ plt.savefig(f"{outliers_plot_file}", dpi=300, bbox_inches="tight")
+ plt.close(fig)
+
+ path_reduced = "/".join(f"{outliers_plot_file}".replace("\\", "/").split("/")[-2:])
print_outliers += f"\n o Outliers plot saved in {path_reduced}"
- if 'train' not in name_points:
- print_outliers += f'\n x No names option (or var missing in CSV file)! Outlier names will not be shown'
+ if "train" not in name_points:
+ print_outliers += "\n x No names option (or var missing in CSV file)! Outlier names will not be shown"
else:
- if 'test_scaled' in outliers_data and 'test' not in name_points:
- print_outliers += f'\n x No names option (or var missing in CSV file in the test file)! Outlier names will not be shown'
+ if "test_scaled" in outliers_data and "test" not in name_points:
+ print_outliers += "\n x No names option (or var missing in CSV file in the test file)! Outlier names will not be shown"
+
+ print_outliers = outlier_analysis(print_outliers, outliers_data, "train")
+ print_outliers = outlier_analysis(print_outliers, outliers_data, "test")
- print_outliers = outlier_analysis(print_outliers,outliers_data,'train')
- print_outliers = outlier_analysis(print_outliers,outliers_data,'test')
-
self.args.log.write(print_outliers)
-def outlier_analysis(print_outliers,outliers_data,outliers_set):
- '''
+def outlier_analysis(print_outliers, outliers_data, outliers_set):
+ """
Analyzes the outlier results
- '''
-
- if outliers_set == 'train':
- label_set = 'Train'
- outliers_label = 'outliers_train'
- n_points_label = 'train_scaled'
- outliers_name = 'names_train'
- elif outliers_set == 'valid':
- label_set = 'Validation'
- outliers_label = 'outliers_valid'
- n_points_label = 'valid_scaled'
- outliers_name = 'names_valid'
- elif outliers_set == 'test':
- label_set = 'Test'
- outliers_label = 'outliers_test'
- n_points_label = 'test_scaled'
- outliers_name = 'names_test'
-
- per_cent = (len(outliers_data[outliers_label])/len(outliers_data[n_points_label]))*100
+ """
+
+ if outliers_set == "train":
+ label_set = "Train"
+ outliers_label = "outliers_train"
+ n_points_label = "train_scaled"
+ outliers_name = "names_train"
+ elif outliers_set == "valid":
+ label_set = "Validation"
+ outliers_label = "outliers_valid"
+ n_points_label = "valid_scaled"
+ outliers_name = "names_valid"
+ elif outliers_set == "test":
+ label_set = "Test"
+ outliers_label = "outliers_test"
+ n_points_label = "test_scaled"
+ outliers_name = "names_test"
+
+ per_cent = (
+ len(outliers_data[outliers_label]) / len(outliers_data[n_points_label])
+ ) * 100
print_outliers += f"\n {label_set}: {len(outliers_data[outliers_label])} outliers out of {len(outliers_data[n_points_label])} datapoints ({per_cent:.1f}%)"
- for val,name in zip(outliers_data[outliers_label], outliers_data[outliers_name]):
+ for val, name in zip(outliers_data[outliers_label], outliers_data[outliers_name]):
print_outliers += f"\n - {name} ({val:.2} SDs)"
return print_outliers
def outlier_filter(self, Xy_data, name_points):
- '''
+ """
Calculates and stores absolute errors in SD units for all the sets
- '''
-
+ """
+
# calculate absolute errors between predicted y and actual values
- outliers_train = [abs(x-y) for x,y in zip(Xy_data['y_train'],Xy_data['y_pred_train'])]
- outliers_test = [abs(x-y) for x,y in zip(Xy_data['y_test'],Xy_data['y_pred_test'])]
+ outliers_train = [
+ abs(x - y) for x, y in zip(Xy_data["y_train"], Xy_data["y_pred_train"])
+ ]
+ outliers_test = [
+ abs(x - y) for x, y in zip(Xy_data["y_test"], Xy_data["y_pred_test"])
+ ]
# the errors are scaled using standard deviation units. When the absolute
# error is larger than the t-value, the point is considered an outlier. All the sets
@@ -3164,31 +3865,35 @@ def outlier_filter(self, Xy_data, name_points):
outliers_sd = np.std(outliers_train)
outliers_data = {}
- outliers_data['train_scaled'] = (outliers_train-outliers_mean)/outliers_sd
- outliers_data['test_scaled'] = (outliers_test-outliers_mean)/outliers_sd
+ outliers_data["train_scaled"] = (outliers_train - outliers_mean) / outliers_sd
+ outliers_data["test_scaled"] = (outliers_test - outliers_mean) / outliers_sd
- print_outliers, naming, naming_test = '', False, False
- if 'train' in name_points:
+ print_outliers, naming, naming_test = "", False, False
+ if "train" in name_points:
naming = True
- if 'test' in name_points:
+ if "test" in name_points:
naming_test = True
- outliers_data['outliers_train'], outliers_data['names_train'] = detect_outliers(self, outliers_data['train_scaled'], name_points, naming, 'train')
- outliers_data['outliers_test'], outliers_data['names_test'] = detect_outliers(self, outliers_data['test_scaled'], name_points, naming_test, 'test')
-
+ outliers_data["outliers_train"], outliers_data["names_train"] = detect_outliers(
+ self, outliers_data["train_scaled"], name_points, naming, "train"
+ )
+ outliers_data["outliers_test"], outliers_data["names_test"] = detect_outliers(
+ self, outliers_data["test_scaled"], name_points, naming_test, "test"
+ )
+
return outliers_data, print_outliers
def detect_outliers(self, outliers_scaled, name_points, naming_detect, set_type):
- '''
+ """
Detects and store outliers with their corresponding datapoint names
- '''
+ """
val_outliers = []
name_outliers = []
if naming_detect:
name_points_list = name_points[set_type].to_list()
- for i,val in enumerate(outliers_scaled):
+ for i, val in enumerate(outliers_scaled):
if val > self.args.t_value or val < -self.args.t_value:
val_outliers.append(val)
if naming_detect:
@@ -3197,177 +3902,217 @@ def detect_outliers(self, outliers_scaled, name_points, naming_detect, set_type)
return val_outliers, name_outliers
-def distribution_plot(self,Xy_data,path_n_suffix,params_dict):
- '''
+def distribution_plot(self, Xy_data, path_n_suffix, params_dict):
+ """
Plots histogram (reg) or bin plot (clas).
- '''
- import seaborn as sb
-
+ """
sb.set(style="ticks")
- _, ax = plt.subplots(figsize=(7.45,6))
- plt.text(0.5, 1.08, f'y-values distribution (CV + test) of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
- fontsize=14, fontweight='bold', transform = ax.transAxes)
+ fig, ax = plt.subplots(figsize=(7.45, 6))
+ plt.text(
+ 0.5,
+ 1.08,
+ f"y-values distribution (CV + test) of {os.path.basename(path_n_suffix)}",
+ horizontalalignment="center",
+ fontsize=14,
+ fontweight="bold",
+ transform=ax.transAxes,
+ )
- plt.grid(linestyle='--', linewidth=1)
+ plt.grid(linestyle="--", linewidth=1)
# combine train and validation sets
- y_combined = pd.concat([Xy_data['y_train'],Xy_data['y_test']], axis=0).reset_index(drop=True)
+ y_combined = pd.concat([Xy_data["y_train"], Xy_data["y_test"]], axis=0).reset_index(
+ drop=True
+ )
# plot histogram, quartile lines and the points in each quartile
- if params_dict['type'].lower() == 'reg':
- y_dist_dict,ax = plot_quartiles(y_combined,ax)
-
+ if params_dict["type"].lower() == "reg":
+ y_dist_dict, ax = plot_quartiles(y_combined, ax)
+
# plot a bar plot with the count of each y type
- elif params_dict['type'].lower() == 'clas':
- y_dist_dict,ax = plot_y_count(y_combined,ax)
+ elif params_dict["type"].lower() == "clas":
+ y_dist_dict, ax = plot_y_count(y_combined, ax)
# set styling preferences and graph limits
- plt.xlabel(f'{params_dict["y"]} values',fontsize=14)
+ plt.xlabel(f"{params_dict['y']} values", fontsize=14)
plt.xticks(fontsize=14)
- plt.ylabel('Frequency',fontsize=14)
+ plt.ylabel("Frequency", fontsize=14)
plt.yticks(fontsize=14)
# set limits
- if params_dict['type'].lower() == 'reg':
- border_y_range = 0.1*np.abs(max(y_combined)-min(y_combined))
- plt.xlim(min(y_combined)-border_y_range, max(y_combined)+border_y_range)
+ if params_dict["type"].lower() == "reg":
+ border_y_range = 0.1 * np.abs(max(y_combined) - min(y_combined))
+ plt.xlim(min(y_combined) - border_y_range, max(y_combined) + border_y_range)
# save plot and print results
- orig_distrib_file = f'y_distribution_{os.path.basename(path_n_suffix)}.png'
- plt.savefig(f'{orig_distrib_file}', dpi=300, bbox_inches='tight')
+ orig_distrib_file = f"y_distribution_{os.path.basename(path_n_suffix)}.png"
+ plt.savefig(f"{orig_distrib_file}", dpi=300, bbox_inches="tight")
+ plt.close(fig)
# for a VERY weird reason, I need to save the figure in the working directory and then move it into PREDICT
- final_distrib_file = f'{os.path.dirname(path_n_suffix)}/y_distribution_{os.path.basename(path_n_suffix)}.png'
+ final_distrib_file = f"{os.path.dirname(path_n_suffix)}/y_distribution_{os.path.basename(path_n_suffix)}.png"
shutil.move(orig_distrib_file, final_distrib_file)
- path_reduced = '/'.join(f'{final_distrib_file}'.replace('\\','/').split('/')[-2:])
+ path_reduced = "/".join(f"{final_distrib_file}".replace("\\", "/").split("/")[-2:])
print_distrib = f"\n o y-values distribution plot saved in {path_reduced}"
# print the quartile results
- if params_dict['type'].lower() == 'reg':
- print_distrib += f"\n Ideally, the number of datapoints in the four quartiles of the y-range should be uniform (25% population in each quartile) to have similar confidence intervals in the predictions across the y-range"
- quartile_pops = [len(y_dist_dict['q1_points']),len(y_dist_dict['q2_points']),len(y_dist_dict['q3_points']),len(y_dist_dict['q4_points'])]
+ if params_dict["type"].lower() == "reg":
+ print_distrib += "\n Ideally, the number of datapoints in the four quartiles of the y-range should be uniform (25% population in each quartile) to have similar confidence intervals in the predictions across the y-range"
+ quartile_pops = [
+ len(y_dist_dict["q1_points"]),
+ len(y_dist_dict["q2_points"]),
+ len(y_dist_dict["q3_points"]),
+ len(y_dist_dict["q4_points"]),
+ ]
print_distrib += f"\n - The number of points in each quartile is Q1: {quartile_pops[0]}, Q2: {quartile_pops[1]}, Q3: {quartile_pops[2]}, Q4: {quartile_pops[3]}"
quartile_min_idx = quartile_pops.index(min(quartile_pops))
quartile_max_idx = quartile_pops.index(max(quartile_pops))
- if 4*min(quartile_pops) < max(quartile_pops):
- print_distrib += f"\n x WARNING! Your data is not uniform (Q{quartile_min_idx+1} has {min(quartile_pops)} points while Q{quartile_max_idx+1} has {max(quartile_pops)})"
- elif 2*min(quartile_pops) < max(quartile_pops):
- print_distrib += f"\n x WARNING! Your data is slightly not uniform (Q{quartile_min_idx+1} has {min(quartile_pops)} points while Q{quartile_max_idx+1} has {max(quartile_pops)})"
+ if 4 * min(quartile_pops) < max(quartile_pops):
+ print_distrib += f"\n x WARNING! Your data is not uniform (Q{quartile_min_idx + 1} has {min(quartile_pops)} points while Q{quartile_max_idx + 1} has {max(quartile_pops)})"
+ elif 2 * min(quartile_pops) < max(quartile_pops):
+ print_distrib += f"\n x WARNING! Your data is slightly not uniform (Q{quartile_min_idx + 1} has {min(quartile_pops)} points while Q{quartile_max_idx + 1} has {max(quartile_pops)})"
else:
- print_distrib += f"\n o Your data seems quite uniform"
+ print_distrib += "\n o Your data seems quite uniform"
- elif params_dict['type'].lower() == 'clas':
- print_distrib += f"\n Ideally, the number of datapoints in each prediction class should be uniform (50% population per class) to have similar reliability in the predictions across classes"
- distrib_counts = [y_dist_dict['count_labels'][0],y_dist_dict['count_labels'][1]]
+ elif params_dict["type"].lower() == "clas":
+ print_distrib += "\n Ideally, the number of datapoints in each prediction class should be uniform (50% population per class) to have similar reliability in the predictions across classes"
+ distrib_counts = [
+ y_dist_dict["count_labels"][0],
+ y_dist_dict["count_labels"][1],
+ ]
print_distrib += f"\n - The number of points in each class is {y_dist_dict['type_labels'][0]}: {y_dist_dict['count_labels'][0]}, {y_dist_dict['type_labels'][1]}: {y_dist_dict['count_labels'][1]}"
class_min_idx = distrib_counts.index(min(distrib_counts))
class_max_idx = distrib_counts.index(max(distrib_counts))
- if 3*min(distrib_counts) < max(distrib_counts):
+ if 3 * min(distrib_counts) < max(distrib_counts):
print_distrib += f"\n x WARNING! Your data is not uniform (class {y_dist_dict['type_labels'][class_min_idx]} has {min(distrib_counts)} points while class {y_dist_dict['type_labels'][class_max_idx]} has {max(distrib_counts)})"
- elif 1.5*min(distrib_counts) < max(distrib_counts):
+ elif 1.5 * min(distrib_counts) < max(distrib_counts):
print_distrib += f"\n x WARNING! Your data is slightly not uniform (class {y_dist_dict['type_labels'][class_min_idx]} has {min(distrib_counts)} points while class {y_dist_dict['type_labels'][class_max_idx]} has {max(distrib_counts)})"
else:
- print_distrib += f"\n o Your data seems quite uniform"
+ print_distrib += "\n o Your data seems quite uniform"
self.args.log.write(print_distrib)
-def plot_quartiles(y_combined,ax):
- '''
+def plot_quartiles(y_combined, ax):
+ """
Plot histogram, quartile lines and the points in each quartile.
- '''
+ """
- bins = max([round(len(y_combined)/5),5]) # at least 5 bins until 25 points
+ bins = max([round(len(y_combined) / 5), 5]) # at least 5 bins until 25 points
# histogram
- y_hist, _, _ = ax.hist(y_combined, bins=bins,
- color='#1f77b4', edgecolor='k', linewidth=1, alpha=1)
+ y_hist, _, _ = ax.hist(
+ y_combined, bins=bins, color="#1f77b4", edgecolor="k", linewidth=1, alpha=1
+ )
# uniformity lines to plot
- separation_range = np.abs(max(y_combined)-min(y_combined))/4
- quart_dict = {'line_1': min(y_combined),
- 'line_2': min(y_combined) + separation_range,
- 'line_3': min(y_combined) + (2*separation_range),
- 'line_4': min(y_combined) + (3*separation_range),
- 'line_5': max(y_combined)}
+ separation_range = np.abs(max(y_combined) - min(y_combined)) / 4
+ quart_dict = {
+ "line_1": min(y_combined),
+ "line_2": min(y_combined) + separation_range,
+ "line_3": min(y_combined) + (2 * separation_range),
+ "line_4": min(y_combined) + (3 * separation_range),
+ "line_5": max(y_combined),
+ }
lines_plot = [quart_dict[line] for line in quart_dict]
- ax.vlines([lines_plot], ymin=max(y_hist)*1.05, ymax=max(y_hist)*1.3, colors='crimson', linestyles='--')
+ ax.vlines(
+ [lines_plot],
+ ymin=max(y_hist) * 1.05,
+ ymax=max(y_hist) * 1.3,
+ colors="crimson",
+ linestyles="--",
+ )
# points in each quartile
- quart_dict['q1_points'] = []
- quart_dict['q2_points'] = []
- quart_dict['q3_points'] = []
- quart_dict['q4_points'] = []
+ quart_dict["q1_points"] = []
+ quart_dict["q2_points"] = []
+ quart_dict["q3_points"] = []
+ quart_dict["q4_points"] = []
for val in y_combined:
- if val < quart_dict['line_2']:
- quart_dict['q1_points'].append(val)
- elif quart_dict['line_2'] < val < quart_dict['line_3']:
- quart_dict['q2_points'].append(val)
- elif quart_dict['line_3'] < val < quart_dict['line_4']:
- quart_dict['q3_points'].append(val)
- elif val >= quart_dict['line_4']:
- quart_dict['q4_points'].append(val)
+ if val < quart_dict["line_2"]:
+ quart_dict["q1_points"].append(val)
+ elif quart_dict["line_2"] < val < quart_dict["line_3"]:
+ quart_dict["q2_points"].append(val)
+ elif quart_dict["line_3"] < val < quart_dict["line_4"]:
+ quart_dict["q3_points"].append(val)
+ elif val >= quart_dict["line_4"]:
+ quart_dict["q4_points"].append(val)
x_quart = 0.185
for quart in quart_dict:
- if 'points' in quart:
- plt.text(x_quart, 0.845, f'Q{quart[1]}\n{len(quart_dict[quart])} points', horizontalalignment='center',
- fontsize=12, transform = ax.transAxes, backgroundcolor='w')
+ if "points" in quart:
+ plt.text(
+ x_quart,
+ 0.845,
+ f"Q{quart[1]}\n{len(quart_dict[quart])} points",
+ horizontalalignment="center",
+ fontsize=12,
+ transform=ax.transAxes,
+ backgroundcolor="w",
+ )
x_quart += 0.209
- return quart_dict,ax
+ return quart_dict, ax
-def plot_y_count(y_combined,ax):
- '''
+def plot_y_count(y_combined, ax):
+ """
Plot a bar plot with the count of each y type.
- '''
+ """
# get the number of times that each y type is included
labels_used = set(y_combined)
- type_labels,count_labels = [],[]
+ type_labels, count_labels = [], []
for label in labels_used:
type_labels.append(label)
count_labels.append(len(y_combined[y_combined == label]))
- _ = ax.bar(type_labels, count_labels, tick_label=type_labels,
- color='#1f77b4', edgecolor='k', linewidth=1, alpha=1,
- width=0.4)
+ _ = ax.bar(
+ type_labels,
+ count_labels,
+ tick_label=type_labels,
+ color="#1f77b4",
+ edgecolor="k",
+ linewidth=1,
+ alpha=1,
+ width=0.4,
+ )
- y_dist_dict = {'type_labels': type_labels,
- 'count_labels': count_labels}
+ y_dist_dict = {"type_labels": type_labels, "count_labels": count_labels}
- return y_dist_dict,ax
+ return y_dist_dict, ax
-def get_prediction_results(model_data,y,y_pred_all):
- '''
+def get_prediction_results(model_data, y, y_pred_all):
+ """
Calculate metrics based on y and y_pred
- '''
+ """
- if model_data['type'].lower() == 'reg':
- mae = mean_absolute_error(y,y_pred_all)
- rmse = np.sqrt(mean_squared_error(y,y_pred_all))
+ if model_data["type"].lower() == "reg":
+ mae = mean_absolute_error(y, y_pred_all)
+ rmse = np.sqrt(mean_squared_error(y, y_pred_all))
if len(np.unique(y)) > 1 and len(np.unique(y_pred_all)) > 1:
- res = stats.linregress(y,y_pred_all)
+ res = stats.linregress(y, y_pred_all)
r2 = res.rvalue**2
else:
r2 = 0.0
return r2, mae, rmse
- elif model_data['type'].lower() == 'clas':
+ elif model_data["type"].lower() == "clas":
# ensure true and predicted labels are integers
- acc = accuracy_score(y,np.round(y_pred_all).astype(int))
+ acc = accuracy_score(y, np.round(y_pred_all).astype(int))
# F1 by default uses average='binnary', to deal with predictions with more than 2 ouput values we use average='micro'
# if len(set(y))==2:
try:
- f1_score_val = f1_score(y,np.round(y_pred_all).astype(int))
+ f1_score_val = f1_score(y, np.round(y_pred_all).astype(int))
except ValueError:
- f1_score_val = f1_score(y,np.round(y_pred_all).astype(int),average='micro')
- mcc = matthews_corrcoef(y,np.round(y_pred_all).astype(int))
+ f1_score_val = f1_score(
+ y, np.round(y_pred_all).astype(int), average="micro"
+ )
+ mcc = matthews_corrcoef(y, np.round(y_pred_all).astype(int))
return acc, f1_score_val, mcc
@@ -3387,10 +4132,7 @@ def get_error_labels(model_type):
- Regression: ('r2', 'mae', 'rmse')
- Classification: ('acc', 'f1', 'mcc')
"""
- error_labels = {
- 'reg': ('r2', 'mae', 'rmse'),
- 'clas': ('acc', 'f1', 'mcc')
- }
+ error_labels = {"reg": ("r2", "mae", "rmse"), "clas": ("acc", "f1", "mcc")}
model_type_lower = model_type.lower()
@@ -3422,134 +4164,193 @@ def _select_descriptors(self, df, descriptors, module):
sys.exit()
-def load_db_n_params(self,params_dir,suffix,suffix_title,module,print_load):
- '''
+def load_db_n_params(self, params_dir, suffix, suffix_title, module, print_load):
+ """
Loads the parameters and Xy databases from a folder, add scaled X data and print information
about the databases
- '''
+ """
# load databases from CSV
- csv_df,csv_X,csv_y,model_data,_ = load_dfs(self,params_dir,module,print_info=print_load)
+ csv_df, csv_X, csv_y, model_data, _ = load_dfs(
+ self, params_dir, module, print_info=print_load
+ )
# detect points in the test set
- test_points = csv_X[csv_X['Set'] == 'Test'].index.tolist()
- csv_X = csv_X.drop(columns=['Set'])
+ test_points = csv_X[csv_X["Set"] == "Test"].index.tolist()
+ csv_X = csv_X.drop(columns=["Set"])
# keep only the descriptors used in the model
csv_X = _select_descriptors(self, csv_X, model_data["X_descriptors"], module)
# load and adjust external set (if any)
- csv_external_df, csv_X_external,csv_y_external = None,None,None
- if self.args.csv_test != '':
- csv_external_df,csv_X_external,csv_y_external = load_database(self,self.args.csv_test,'predict',external_test=True)
+ csv_external_df, csv_X_external, csv_y_external = None, None, None
+ if self.args.csv_test != "":
+ csv_external_df, csv_X_external, csv_y_external = load_database(
+ self, self.args.csv_test, "predict", external_test=True
+ )
csv_X_external = _select_descriptors(
self, csv_X_external, model_data["X_descriptors"], "predict"
)
# split tests
- Xy_data = prepare_sets(self,csv_df,csv_X,csv_y,test_points,model_data['names'],csv_external_df,csv_X_external,csv_y_external,BO_opt=False)
+ Xy_data = prepare_sets(
+ self,
+ csv_df,
+ csv_X,
+ csv_y,
+ test_points,
+ model_data["names"],
+ csv_external_df,
+ csv_X_external,
+ csv_y_external,
+ BO_opt=False,
+ )
# print information of loaded database
params_name = os.path.basename(params_dir)
if print_load:
- _ = load_print(self,params_name,suffix,model_data,Xy_data)
+ _ = load_print(self, params_name, suffix, model_data, Xy_data)
return Xy_data, model_data, suffix_title
-def prepare_sets(self,csv_df,csv_X,csv_y,test_points,column_names,csv_external_df,csv_X_external,csv_y_external,BO_opt=False):
- '''
+def prepare_sets(
+ self,
+ csv_df,
+ csv_X,
+ csv_y,
+ test_points,
+ column_names,
+ csv_external_df,
+ csv_X_external,
+ csv_y_external,
+ BO_opt=False,
+):
+ """
Standardizes and separate test set
- '''
+ """
- X_scaled_df,X_scaled_external_df = scale_df(csv_X,csv_X_external)
+ X_scaled_df, X_scaled_external_df = scale_df(csv_X, csv_X_external)
# separate test set and save it in the Xy data
if BO_opt:
- if self.args.csv_test != '':
+ if self.args.csv_test != "":
self.args.test_set = 0
-
+
if self.args.auto_test:
if self.args.test_set < 0.2:
self.args.test_set = 0.2
- self.args.log.write(f'\nx WARNING! The test_set option was set to {self.args.test_set}, this value will be raised to 0.2 to include a meaningful amount of points in the test set. You can bypass this option and include less test points with "--auto_test False".')
+ self.args.log.write(
+ f'\nx WARNING! The test_set option was set to {self.args.test_set}, this value will be raised to 0.2 to include a meaningful amount of points in the test set. You can bypass this option and include less test points with "--auto_test False".'
+ )
if self.args.test_set > 0:
- self.args.log.write(f'\no Before hyperoptimization, {int(self.args.test_set*100)}% of the data (or 4 points at least) was separated as test set, using an even distribution of data points across the range of y values.')
+ self.args.log.write(
+ f"\no Before hyperoptimization, {int(self.args.test_set * 100)}% of the data (or 4 points at least) was separated as test set, using an even distribution of data points across the range of y values."
+ )
try:
- test_points = test_select(self,X_scaled_df,csv_y)
+ test_points = test_select(self, X_scaled_df, csv_y)
except TypeError:
- self.args.log.write(f' x The data split process failed! This is probably due to using strings/words as values (use --curate to curate the data first)')
+ self.args.log.write(
+ " x The data split process failed! This is probably due to using strings/words as values (use --curate to curate the data first)"
+ )
sys.exit()
# load predefined sets and save the info in Xy data
- Xy_data = Xy_split(csv_df,csv_X,X_scaled_df,csv_y,csv_external_df,csv_X_external,X_scaled_external_df,csv_y_external,test_points,column_names)
+ Xy_data = Xy_split(
+ csv_df,
+ csv_X,
+ X_scaled_df,
+ csv_y,
+ csv_external_df,
+ csv_X_external,
+ X_scaled_external_df,
+ csv_y_external,
+ test_points,
+ column_names,
+ )
# also store the descriptors used (the labels disappear after test_select() )
- Xy_data['X_descriptors'] = csv_X.columns.tolist()
+ Xy_data["X_descriptors"] = csv_X.columns.tolist()
return Xy_data
-def load_dfs(self,folder_model,module,sanity_check=False,print_info=True):
- '''
+def load_dfs(self, folder_model, module, sanity_check=False, print_info=True):
+ """
Loads the parameters and Xy databases from the GENERATE folder as dataframes
- '''
-
+ """
+
+ csv_df = pd.DataFrame()
+ csv_X = pd.DataFrame()
+ csv_y = pd.DataFrame()
+ model_data = {}
+ csv_name = ""
+
if os.getcwd() in f"{folder_model}":
path_db = folder_model
else:
path_db = f"{Path(os.getcwd()).joinpath(folder_model)}"
if os.path.exists(path_db):
- csv_files = glob.glob(f'{Path(path_db).joinpath("*.csv")}')
- csv_files.sort(key=lambda f: f.endswith('_db.csv')) # sort the database file to be the last one, depending on the OS was taking first the dabatase and then the parameters
+ csv_files = glob.glob(f"{Path(path_db).joinpath('*.csv')}")
+ csv_files.sort(
+ key=lambda f: f.endswith("_db.csv")
+ ) # sort the database file to be the last one, depending on the OS was taking first the dabatase and then the parameters
for csv_file in csv_files:
- if csv_file.endswith('_db.csv'):
+ if csv_file.endswith("_db.csv"):
if not sanity_check:
- csv_df,csv_X,csv_y = load_database(self,csv_file,module,print_info=print_info)
+ csv_df, csv_X, csv_y = load_database(
+ self, csv_file, module, print_info=print_info
+ )
csv_name = csv_file
else:
- csv_df,csv_X,csv_y = pd.DataFrame(),pd.DataFrame(),pd.DataFrame()
+ csv_df, csv_X, csv_y = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
# convert df to dict, then adjust params to a valid format
- model_data = load_params(self,csv_file)
+ model_data = load_params(self, csv_file)
else:
- self.args.log.write(f"\nx The folder with the model and database ({path_db}) does not exist! Did you use the destination=PATH option in the other modules?")
+ self.args.log.write(
+ f"\nx The folder with the model and database ({path_db}) does not exist! Did you use the destination=PATH option in the other modules?"
+ )
sys.exit()
- return csv_df,csv_X,csv_y,model_data,csv_name
+ return csv_df, csv_X, csv_y, model_data, csv_name
-def load_params(self,path_csv):
- '''
+def load_params(self, path_csv):
+ """
Load parameters from a CSV and adjust the format
- '''
-
- PFI_df = pd.read_csv(path_csv, encoding='utf-8')
+ """
+
+ PFI_df = pd.read_csv(path_csv, encoding="utf-8")
PFI_dict = pd_to_dict(PFI_df)
PFI_dict = dict_formating(PFI_dict)
- PFI_dict['params'] = model_adjust_params(self, PFI_dict['model'], PFI_dict['params'])
+ PFI_dict["params"] = model_adjust_params(
+ self, PFI_dict["model"], PFI_dict["params"]
+ )
return PFI_dict
-def load_print(self,params_name,suffix,model_data,Xy_data):
- '''
+def load_print(self, params_name, suffix, model_data, Xy_data):
+ """
Print information of the database loaded and type of model used
- '''
-
- if '.csv' in params_name:
- params_name = params_name.split('.csv')[0]
- txt_load = f'\no ML model {params_name} {suffix} and Xy database were loaded, including:'
- txt_load += f'\n - Target value: {model_data["y"]}'
- txt_load += f'\n - Names: {model_data["names"]}'
- txt_load += f'\n - Model: {model_data["model"]}'
- txt_load += f'\n - k-fold CV: {model_data["kfold"]}'
- txt_load += f'\n - Repetitions CV: {model_data["repeat_kfolds"]}'
- txt_load += f'\n - Descriptors: {model_data["X_descriptors"]}'
- txt_load += f'\n - Training points: {len(Xy_data["y_train"])}'
- txt_load += f'\n - Test points: {len(Xy_data["y_test"])}'
- if 'X_external' in Xy_data:
- txt_load += f'\n - External test points: {len(Xy_data["X_external"])}'
+ """
+
+ if ".csv" in params_name:
+ params_name = params_name.split(".csv")[0]
+ txt_load = (
+ f"\no ML model {params_name} {suffix} and Xy database were loaded, including:"
+ )
+ txt_load += f"\n - Target value: {model_data['y']}"
+ txt_load += f"\n - Names: {model_data['names']}"
+ txt_load += f"\n - Model: {model_data['model']}"
+ txt_load += f"\n - k-fold CV: {model_data['kfold']}"
+ txt_load += f"\n - Repetitions CV: {model_data['repeat_kfolds']}"
+ txt_load += f"\n - Descriptors: {model_data['X_descriptors']}"
+ txt_load += f"\n - Training points: {len(Xy_data['y_train'])}"
+ txt_load += f"\n - Test points: {len(Xy_data['y_test'])}"
+ if "X_external" in Xy_data:
+ txt_load += f"\n - External test points: {len(Xy_data['X_external'])}"
self.args.log.write(txt_load)
@@ -3560,239 +4361,316 @@ def pd_to_dict(PFI_df):
return PFI_df_dict
-def print_pfi(self,params_dir):
- if 'No_PFI' in params_dir:
- self.args.log.write('\n\n------- Starting model with all variables (No PFI) -------')
+def print_pfi(self, params_dir):
+ if "No_PFI" in params_dir:
+ self.args.log.write(
+ "\n\n------- Starting model with all variables (No PFI) -------"
+ )
else:
- self.args.log.write('\n\n------- Starting model with PFI filter (only important descriptors used) -------')
+ self.args.log.write(
+ "\n\n------- Starting model with PFI filter (only important descriptors used) -------"
+ )
def get_graph_style():
"""
Retrieves the graph style for regression plots
"""
-
- graph_style = {'color_train' : 'b',
- 'color_valid' : 'orange',
- 'color_test' : 'r',
- 'dot_size' : 50,
- 'alpha' : 1 # from 0 (transparent) to 1 (opaque)
- }
+
+ graph_style = {
+ "color_train": "b",
+ "color_valid": "orange",
+ "color_test": "r",
+ "dot_size": 50,
+ "alpha": 1, # from 0 (transparent) to 1 (opaque)
+ }
return graph_style
-def pearson_map(self,csv_df_pearson,module,params_dir=None):
- '''
+def pearson_map(self, csv_df_pearson, module, params_dir=None):
+ """
Creates Pearson heatmap
- '''
- import seaborn as sb
-
+ """
if module.lower() == "curate": # only represent the final descriptors in CURATE
csv_df_pearson = csv_df_pearson.drop([self.args.y] + self.args.ignore, axis=1)
corr_matrix = csv_df_pearson.corr()
mask = np.zeros_like(corr_matrix, dtype=bool)
- mask[np.triu_indices_from(mask)]= True
-
+ mask[np.triu_indices_from(mask)] = True
+
# no representatoins when there are more than 30 descriptors
if len(csv_df_pearson.columns) > 30:
disable_plot = True
else:
disable_plot = False
- _, ax = plt.subplots(figsize=(7.45,6))
+ fig, ax = plt.subplots(figsize=(7.45, 6))
size_title = 14
- size_font = 14-2*((len(csv_df_pearson.columns)/5))
+ size_font = 14 - 2 * (len(csv_df_pearson.columns) / 5)
if disable_plot:
- if module.lower() == 'curate':
- self.args.log.write(f'\nx The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30.')
- if module.lower() == 'predict':
- self.args.log.write(f'\n x The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30.')
+ if module.lower() == "curate":
+ self.args.log.write(
+ f"\nx The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30."
+ )
+ if module.lower() == "predict":
+ self.args.log.write(
+ f"\n x The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30."
+ )
else:
- sb.set(font_scale=1.2, style='ticks')
-
- _ = sb.heatmap(corr_matrix,
- mask = mask,
- square = True,
- linewidths = .5,
- cmap = 'coolwarm',
- cbar = False,
- cbar_kws = {'shrink': .4,
- 'ticks' : [-1, -.5, 0, 0.5, 1]},
- vmin = -1,
- vmax = 1,
- annot = True,
- annot_kws = {'size': size_font})
+ sb.set(font_scale=1.2, style="ticks")
+
+ _ = sb.heatmap(
+ corr_matrix,
+ mask=mask,
+ square=True,
+ linewidths=0.5,
+ cmap="coolwarm",
+ cbar=False,
+ cbar_kws={"shrink": 0.4, "ticks": [-1, -0.5, 0, 0.5, 1]},
+ vmin=-1,
+ vmax=1,
+ annot=True,
+ annot_kws={"size": size_font},
+ )
plt.tick_params(labelsize=size_font)
- #add the column names as labels
- ax.set_yticklabels(corr_matrix.columns, rotation = 0)
+ # add the column names as labels
+ ax.set_yticklabels(corr_matrix.columns, rotation=0)
ax.set_xticklabels(corr_matrix.columns)
- title_fig = 'Pearson\'s r heatmap'
- if module.lower() == 'predict':
- if os.path.basename(Path(params_dir)) == 'No_PFI':
- suffix_title = 'No_PFI'
- elif os.path.basename(Path(params_dir)) == 'PFI':
- suffix_title = 'PFI'
- title_fig += f'_{suffix_title}'
+ title_fig = "Pearson's r heatmap"
+ if module.lower() == "predict":
+ if os.path.basename(Path(params_dir)) == "No_PFI":
+ suffix_title = "No_PFI"
+ elif os.path.basename(Path(params_dir)) == "PFI":
+ suffix_title = "PFI"
+ title_fig += f"_{suffix_title}"
- plt.title(title_fig, y=1.04, fontsize = size_title, fontweight="bold")
- sb.set_style({'xtick.bottom': True}, {'ytick.left': True})
+ plt.title(title_fig, y=1.04, fontsize=size_title, fontweight="bold")
+ sb.set_style({"xtick.bottom": True}, {"ytick.left": True})
- if module.lower() == 'curate':
- heatmap_name = 'Pearson_heatmap.png'
- elif module.lower() == 'predict':
- heatmap_name = f'Pearson_heatmap_{suffix_title}.png'
+ if module.lower() == "curate":
+ heatmap_name = "Pearson_heatmap.png"
+ elif module.lower() == "predict":
+ heatmap_name = f"Pearson_heatmap_{suffix_title}.png"
heatmap_path = self.args.destination.joinpath(heatmap_name)
- plt.savefig(f'{heatmap_path}', dpi=300, bbox_inches='tight')
+ plt.savefig(f"{heatmap_path}", dpi=300, bbox_inches="tight")
+ plt.close(fig)
- path_reduced = '/'.join(f'{heatmap_path}'.replace('\\','/').split('/')[-2:])
- if module.lower() == 'curate':
- self.args.log.write(f'\no The Pearson heatmap was stored in {path_reduced}.')
- elif module.lower() == 'predict':
- self.args.log.write(f'\n o The Pearson heatmap was stored in {path_reduced}.')
+ path_reduced = "/".join(f"{heatmap_path}".replace("\\", "/").split("/")[-2:])
+ if module.lower() == "curate":
+ self.args.log.write(
+ f"\no The Pearson heatmap was stored in {path_reduced}."
+ )
+ elif module.lower() == "predict":
+ self.args.log.write(
+ f"\n o The Pearson heatmap was stored in {path_reduced}."
+ )
return corr_matrix
-def plot_metrics(model_data,suffix_title,verify_metrics,verify_results):
- '''
+def plot_metrics(model_data, suffix_title, verify_metrics, verify_results):
+ """
Creates a plot with the results of the flawed models in VERIFY
- '''
- import seaborn as sb
-
- importlib.reload(plt)
- sb.reset_defaults()
- sb.set(style="ticks")
- _, ax = plt.subplots(figsize=(7.45, 6))
-
- # define names
- csv_name = os.path.basename(model_data['model']).split('_db.csv')[0]
- base_csv_name = f'VERIFY/{csv_name}'
- base_csv_path = f"{Path(os.getcwd()).joinpath(base_csv_name)}"
- path_n_suffix = f'{base_csv_path}_{suffix_title}'
-
- # axis limits
- max_val = max(verify_metrics['metrics'])
- min_val = min(verify_metrics['metrics'])
- range_vals = np.abs(max_val - min_val)
- if verify_results['error_type'].lower() in ['mae','rmse']:
- max_lim = 1.2*max_val
- min_lim = 0
- else:
- max_lim = max_val + (0.2*range_vals)
- min_lim = min_val - (0.1*range_vals)
- plt.ylim(min_lim, max_lim)
- plt.ylim(min_lim, max_lim)
-
- # adjust number of significative numbers shown
- ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
-
- width_bar = 0.55
- label_count = 0
- for test_metric,test_name,test_color in zip(verify_metrics['metrics'],verify_metrics['test_names'],verify_metrics['colors']):
- rects = ax.bar(test_name, test_metric, label=test_name,
- width=width_bar, linewidth=1, edgecolor='k',
- color=test_color, zorder=2)
- # plot whether the tests pass or fail
- if test_name != 'Model':
- if test_metric >= 0:
- offset_txt = test_metric+(0.05*range_vals)
+ """
+ with _mpl_plot_context():
+ sb.reset_defaults()
+ sb.set(style="ticks")
+ fig, ax = plt.subplots(figsize=(7.45, 6))
+
+ # define names
+ csv_name = os.path.basename(model_data["model"]).split("_db.csv")[0]
+ base_csv_name = f"VERIFY/{csv_name}"
+ base_csv_path = f"{Path(os.getcwd()).joinpath(base_csv_name)}"
+ path_n_suffix = f"{base_csv_path}_{suffix_title}"
+
+ # axis limits
+ max_val = max(verify_metrics["metrics"])
+ min_val = min(verify_metrics["metrics"])
+ range_vals = np.abs(max_val - min_val)
+ if verify_results["error_type"].lower() in ["mae", "rmse"]:
+ min_lim = 0
+ max_lim = 1.2 * max_val if max_val != 0 else 0.1
+ else:
+ if range_vals == 0:
+ pad = max(abs(max_val) * 0.1, 0.05)
+ min_lim = max_val - pad
+ max_lim = max_val + pad
else:
- offset_txt = test_metric-(0.05*range_vals)
- if test_color == '#1f77b4':
- txt_bar = 'pass'
- elif test_color == '#cd5c5c':
- txt_bar = 'fail'
- elif test_color == '#c5c57d':
- txt_bar = 'unclear'
- ax.text(label_count, offset_txt, txt_bar, color=test_color,
- fontstyle='italic', horizontalalignment='center')
- label_count += 1
-
- # Set tick sizes
- plt.xticks(fontsize=14)
- plt.yticks(fontsize=14)
-
- # title and labels of the axis
- plt.ylabel(f'{verify_results["error_type"].upper()}', fontsize=14)
-
- plt.text(0.5, 1.08, f'VERIFY tests of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
- fontsize=14, fontweight='bold', transform = ax.transAxes)
-
- # add threshold line and arrow indicating passed test direction
- arrow_length = np.abs(max_lim-min_lim)/11
-
- if verify_results['error_type'].lower() in ['mae','rmse']:
- thres_line = verify_metrics['higher_thres']
- unclear_thres_line = verify_metrics['unclear_higher_thres']
- else:
- thres_line = verify_metrics['lower_thres']
- unclear_thres_line = verify_metrics['unclear_lower_thres']
- arrow_length = -arrow_length
-
- width = 2
- xmin = 0.237
- thres = ax.axhline(thres_line,xmin=xmin, color='black',ls='--', label='thres', zorder=0)
- thres = ax.axhline(unclear_thres_line,xmin=xmin, color='black',ls='--', label='thres', zorder=0)
-
- x_arrow = 0.5
- style = mpatches.ArrowStyle('simple', head_length=4.5*width, head_width=3.5*width, tail_width=width)
- arrow = mpatches.FancyArrowPatch((x_arrow, thres_line), (x_arrow, thres_line+arrow_length),
- arrowstyle=style, color='k') # (x1,y1), (x2,y2) vector direction
- ax.add_patch(arrow)
+ max_lim = max_val + (0.2 * range_vals)
+ min_lim = min_val - (0.1 * range_vals)
+ if range_vals == 0 and verify_results["error_type"].lower() in ["mae", "rmse"]:
+ range_vals = max_lim - min_lim
+ ax.set_ylim(min_lim, max_lim)
+
+ # adjust number of significative numbers shown
+ ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
+
+ width_bar = 0.55
+ label_count = 0
+ for test_metric, test_name, test_color in zip(
+ verify_metrics["metrics"],
+ verify_metrics["test_names"],
+ verify_metrics["colors"],
+ ):
+ ax.bar(
+ test_name,
+ test_metric,
+ label=test_name,
+ width=width_bar,
+ linewidth=1,
+ edgecolor="k",
+ color=test_color,
+ zorder=2,
+ )
+ # plot whether the tests pass or fail
+ if test_name != "Model":
+ if test_metric >= 0:
+ offset_txt = test_metric + (0.05 * range_vals)
+ else:
+ offset_txt = test_metric - (0.05 * range_vals)
+ if test_color == "#1f77b4":
+ txt_bar = "pass"
+ elif test_color == "#cd5c5c":
+ txt_bar = "fail"
+ elif test_color == "#c5c57d":
+ txt_bar = "unclear"
+ ax.text(
+ label_count,
+ offset_txt,
+ txt_bar,
+ color=test_color,
+ fontstyle="italic",
+ horizontalalignment="center",
+ )
+ label_count += 1
+
+ # Set tick sizes
+ ax.tick_params(axis="x", labelsize=14)
+ ax.tick_params(axis="y", labelsize=14)
+
+ # title and labels of the axis
+ ax.set_ylabel(f"{verify_results['error_type'].upper()}", fontsize=14)
+
+ ax.text(
+ 0.5,
+ 1.08,
+ f"VERIFY tests of {os.path.basename(path_n_suffix)}",
+ horizontalalignment="center",
+ fontsize=14,
+ fontweight="bold",
+ transform=ax.transAxes,
+ )
- # invisible "dummy" arrows to make the graph wider so the real arrows fit in the right place
- ax.arrow(x_arrow, thres_line, 0, 0, width=0, fc='k', ec='k') # x,y,dx,dy format
+ # add threshold line and arrow indicating passed test direction
+ arrow_length = np.abs(max_lim - min_lim) / 11
- # legend and regression line with 95% CI considering all possible lines (not CI of the points)
- def make_legend_arrow(legend, orig_handle,
- xdescent, ydescent,
- width, height, fontsize):
- p = mpatches.FancyArrow(0, 0.5*height, width, 0, width=1.5, length_includes_head=True, head_width=0.58*height )
- return p
+ if verify_results["error_type"].lower() in ["mae", "rmse"]:
+ thres_line = verify_metrics["higher_thres"]
+ unclear_thres_line = verify_metrics["unclear_higher_thres"]
+ else:
+ thres_line = verify_metrics["lower_thres"]
+ unclear_thres_line = verify_metrics["unclear_lower_thres"]
+ arrow_length = -arrow_length
+
+ width = 2
+ xmin = 0.237
+ thres = ax.axhline(
+ thres_line, xmin=xmin, color="black", ls="--", label="thres", zorder=0
+ )
+ thres = ax.axhline(
+ unclear_thres_line,
+ xmin=xmin,
+ color="black",
+ ls="--",
+ label="thres",
+ zorder=0,
+ )
- arrow = plt.arrow(0, 0, 0, 0, label='arrow', width=0, fc='k', ec='k') # arrow for the legend
- plt.figlegend([thres,arrow], [f'Limits: {thres_line:.2} (pass), {unclear_thres_line:.2} (unclear)','Pass test'], handler_map={mpatches.FancyArrow : HandlerPatch(patch_func=make_legend_arrow),},
- loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.05),
- fancybox=True, shadow=True, fontsize=14)
+ x_arrow = 0.5
+ style = mpatches.ArrowStyle(
+ "simple", head_length=4.5 * width, head_width=3.5 * width, tail_width=width
+ )
+ arrow = mpatches.FancyArrowPatch(
+ (x_arrow, thres_line),
+ (x_arrow, thres_line + arrow_length),
+ arrowstyle=style,
+ color="k",
+ )
+ ax.add_patch(arrow)
+
+ # invisible "dummy" arrows to make the graph wider so the real arrows fit in the right place
+ ax.arrow(x_arrow, thres_line, 0, 0, width=0, fc="k", ec="k")
+
+ def make_legend_arrow(
+ legend, orig_handle, xdescent, ydescent, width, height, fontsize
+ ):
+ p = mpatches.FancyArrow(
+ 0,
+ 0.5 * height,
+ width,
+ 0,
+ width=1.5,
+ length_includes_head=True,
+ head_width=0.58 * height,
+ )
+ return p
+
+ arrow_legend = plt.arrow(0, 0, 0, 0, label="arrow", width=0, fc="k", ec="k")
+ fig.legend(
+ [thres, arrow_legend],
+ [
+ f"Limits: {thres_line:.2} (pass), {unclear_thres_line:.2} (unclear)",
+ "Pass test",
+ ],
+ handler_map={
+ mpatches.FancyArrow: HandlerPatch(patch_func=make_legend_arrow),
+ },
+ loc="lower center",
+ ncol=2,
+ bbox_to_anchor=(0.5, -0.05),
+ fancybox=True,
+ shadow=True,
+ fontsize=14,
+ )
- # Add gridlines
- ax.grid(linestyle='--', linewidth=1)
+ # Add gridlines
+ ax.grid(linestyle="--", linewidth=1)
- # save plot
- verify_plot_file = f'{os.path.dirname(path_n_suffix)}/VERIFY_tests_{os.path.basename(path_n_suffix)}.png'
- plt.savefig(verify_plot_file, dpi=300, bbox_inches='tight')
+ # save plot
+ verify_plot_file = f"{os.path.dirname(path_n_suffix)}/VERIFY_tests_{os.path.basename(path_n_suffix)}.png"
+ plt.savefig(verify_plot_file, dpi=300, bbox_inches="tight")
+ plt.close(fig)
- path_reduced = '/'.join(f'{verify_plot_file}'.replace('\\','/').split('/')[-2:])
+ path_reduced = "/".join(f"{verify_plot_file}".replace("\\", "/").split("/")[-2:])
print_ver = f"\n o VERIFY plot saved in {path_reduced}"
return print_ver
def dict_formating(dict_csv):
- '''
+ """
Adapt format of dictionaries that come from dataframes loaded from CSV
- '''
-
+ """
+
import json
- if 'X_descriptors' in dict_csv:
+ if "X_descriptors" in dict_csv:
# Try JSON first (new format), fall back to ast.literal_eval (old format)
try:
- dict_csv['X_descriptors'] = json.loads(dict_csv['X_descriptors'])
+ dict_csv["X_descriptors"] = json.loads(dict_csv["X_descriptors"])
except (json.JSONDecodeError, TypeError):
- dict_csv['X_descriptors'] = ast.literal_eval(dict_csv['X_descriptors'])
-
- if 'params' in dict_csv:
+ dict_csv["X_descriptors"] = ast.literal_eval(dict_csv["X_descriptors"])
+
+ if "params" in dict_csv:
# Try JSON first (new format), fall back to ast.literal_eval (old format)
try:
- dict_csv['params'] = json.loads(dict_csv['params'])
+ dict_csv["params"] = json.loads(dict_csv["params"])
except (json.JSONDecodeError, TypeError):
- dict_csv['params'] = ast.literal_eval(dict_csv['params'])
+ dict_csv["params"] = ast.literal_eval(dict_csv["params"])
- return dict_csv
\ No newline at end of file
+ return dict_csv
diff --git a/robert/verify.py b/robert/verify.py
index 9ee9440..4c33e07 100644
--- a/robert/verify.py
+++ b/robert/verify.py
@@ -5,13 +5,13 @@
destination : str, default=None,
Directory to create the output file(s).
varfile : str, default=None
- Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
+ Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
params_dir : str, default=''
Folder containing the database and parameters of the ML model to analyze.
seed : int, default=0
Random seed used in the ML predictor models and other protocols.
kfold : int, default=5
- Number of random data splits for the cross-validation of the models.
+ Number of random data splits for the cross-validation of the models.
repeat_kfolds : int, default=10
Number of repetitions for the k-fold cross-validation of the models.
@@ -42,6 +42,7 @@
thres_test_pass = 0.3
thres_test_unclear = 0.15
+
class verify:
"""
Class containing all the functions from the VERIFY module.
@@ -53,7 +54,6 @@ class verify:
"""
def __init__(self, **kwargs):
-
start_time = time.time()
# load default and user-specified variables
@@ -63,194 +63,247 @@ def __init__(self, **kwargs):
self.args.params_dir
):
if os.path.exists(params_dir):
-
- _ = print_pfi(self,params_dir)
+ _ = print_pfi(self, params_dir)
# load the Xy databse and model parameters
- Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",True)
-
+ Xy_data, model_data, suffix_title = load_db_n_params(
+ self, params_dir, suffix, suffix_title, "verify", True
+ )
+
# this dictionary will keep the results of the tests
- verify_results = {'error_type': model_data['error_type']}
+ verify_results = {"error_type": model_data["error_type"]}
# get data about repeated and sorted CVs
- Xy_data = load_n_predict(self, model_data, Xy_data, BO_opt=True, verify_job=True)
- verify_results['CV_score'] = Xy_data[f'{verify_results["error_type"]}_train']
- verify_results['sorted_CV_score'] = Xy_data[f'{model_data["error_type"]}_train_sorted_CV']
- if model_data['type'].lower() == 'reg':
- verify_results[f'r2_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'r2_train_sorted_CV']]
- verify_results[f'mae_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'mae_train_sorted_CV']]
- verify_results[f'rmse_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'rmse_train_sorted_CV']]
- elif model_data['type'].lower() == 'clas':
- verify_results[f'acc_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'acc_train_sorted_CV']]
- verify_results[f'f1_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'f1_train_sorted_CV']]
- verify_results[f'mcc_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'mcc_train_sorted_CV']]
+ Xy_data = load_n_predict(
+ self, model_data, Xy_data, BO_opt=True, verify_job=True
+ )
+ verify_results["CV_score"] = Xy_data[
+ f"{verify_results['error_type']}_train"
+ ]
+ verify_results["sorted_CV_score"] = Xy_data[
+ f"{model_data['error_type']}_train_sorted_CV"
+ ]
+ if model_data["type"].lower() == "reg":
+ verify_results["r2_train_sorted_CV"] = [
+ float(f"{val:.2f}") for val in Xy_data["r2_train_sorted_CV"]
+ ]
+ verify_results["mae_train_sorted_CV"] = [
+ float(f"{val:.2f}") for val in Xy_data["mae_train_sorted_CV"]
+ ]
+ verify_results["rmse_train_sorted_CV"] = [
+ float(f"{val:.2f}") for val in Xy_data["rmse_train_sorted_CV"]
+ ]
+ elif model_data["type"].lower() == "clas":
+ verify_results["acc_train_sorted_CV"] = [
+ float(f"{val:.2f}") for val in Xy_data["acc_train_sorted_CV"]
+ ]
+ verify_results["f1_train_sorted_CV"] = [
+ float(f"{val:.2f}") for val in Xy_data["f1_train_sorted_CV"]
+ ]
+ verify_results["mcc_train_sorted_CV"] = [
+ float(f"{val:.2f}") for val in Xy_data["mcc_train_sorted_CV"]
+ ]
# Reload once for flawed-model tests (fresh splits consistent with CSV on disk).
- Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",False)
+ Xy_data, model_data, suffix_title = load_db_n_params(
+ self, params_dir, suffix, suffix_title, "verify", False
+ )
# calculate scores for the y-mean test
- verify_results = self.ymean_test(verify_results,Xy_data,model_data)
+ verify_results = self.ymean_test(verify_results, Xy_data, model_data)
# calculate scores for the y-shuffle test
- verify_results = self.yshuffle_test(verify_results,Xy_data,model_data)
+ verify_results = self.yshuffle_test(verify_results, Xy_data, model_data)
# one-hot test (check that if a value isnt 0, the value assigned is 1)
- verify_results = self.onehot_test(verify_results,Xy_data,model_data)
+ verify_results = self.onehot_test(verify_results, Xy_data, model_data)
# analysis of results
- results_print,verify_results,verify_metrics = self.analyze_tests(verify_results)
+ results_print, verify_results, verify_metrics = self.analyze_tests(
+ verify_results
+ )
# plot a bar graph with the results
if should_plot_verify_metrics(self.args):
- print_ver = plot_metrics(model_data,suffix_title,verify_metrics,verify_results)
+ print_ver = plot_metrics(
+ model_data, suffix_title, verify_metrics, verify_results
+ )
else:
print_ver = "\n o VERIFY plot skipped (plot_verbosity)"
# print and save results
- _ = self.print_verify(results_print,verify_results,print_ver,model_data)
-
- _ = finish_print(self,start_time,'VERIFY')
+ _ = self.print_verify(
+ results_print, verify_results, print_ver, model_data
+ )
+ _ = finish_print(self, start_time, "VERIFY")
- def ymean_test(self,verify_results,Xy_data,model_data):
- '''
- Calculate the accuracy of the model when using a flat line of predicted y values. For
+ def ymean_test(self, verify_results, Xy_data, model_data):
+ """
+ Calculate the accuracy of the model when using a flat line of predicted y values. For
regression, the mean of the y values is used. For classification, the value that is
predicted more often is used.
- '''
+ """
- Xy_ymean = Xy_data.copy()
- if model_data['type'].lower() == 'reg':
- y_mean_array = np.ones(len(Xy_ymean['y_train']))*(Xy_ymean['y_train'].mean())
- Xy_ymean['r2_train'], Xy_ymean['mae_train'], Xy_ymean['rmse_train'] = get_prediction_results(model_data,Xy_ymean['y_train'],y_mean_array)
-
- elif model_data['type'].lower() == 'clas':
- y_mean_array = np.ones(len(Xy_ymean['y_train']))*mode(Xy_ymean['y_train'])
- Xy_ymean['acc_train'], Xy_ymean['f1_train'], Xy_ymean['mcc_train'] = get_prediction_results(model_data,Xy_ymean['y_train'],y_mean_array)
+ Xy_ymean = Xy_data.copy()
+ if model_data["type"].lower() == "reg":
+ y_mean_array = np.ones(len(Xy_ymean["y_train"])) * (
+ Xy_ymean["y_train"].mean()
+ )
+ Xy_ymean["r2_train"], Xy_ymean["mae_train"], Xy_ymean["rmse_train"] = (
+ get_prediction_results(model_data, Xy_ymean["y_train"], y_mean_array)
+ )
- verify_results['y_mean'] = Xy_ymean[f'{verify_results["error_type"]}_train']
+ elif model_data["type"].lower() == "clas":
+ y_mean_array = np.ones(len(Xy_ymean["y_train"])) * mode(Xy_ymean["y_train"])
+ Xy_ymean["acc_train"], Xy_ymean["f1_train"], Xy_ymean["mcc_train"] = (
+ get_prediction_results(model_data, Xy_ymean["y_train"], y_mean_array)
+ )
- return verify_results
+ verify_results["y_mean"] = Xy_ymean[f"{verify_results['error_type']}_train"]
+ return verify_results
- def yshuffle_test(self,verify_results,Xy_data,model_data):
- '''
+ def yshuffle_test(self, verify_results, Xy_data, model_data):
+ """
Calculate the accuracy of the model when the y values are randomly shuffled in the validation set
For example, a y array of 1.3, 2.1, 4.0, 5.2 might become 2.1, 1.3, 5.2, 4.0.
- '''
+ """
Xy_yshuffle = Xy_data.copy()
- Xy_yshuffle['y_train'] = Xy_yshuffle['y_train'].sample(frac=1,random_state=model_data['seed'],axis=0)
+ Xy_yshuffle["y_train"] = Xy_yshuffle["y_train"].sample(
+ frac=1, random_state=model_data["seed"], axis=0
+ )
Xy_yshuffle = load_n_predict(self, model_data, Xy_yshuffle, BO_opt=False)
- verify_results['y_shuffle'] = Xy_yshuffle[f'{verify_results["error_type"]}_train']
+ verify_results["y_shuffle"] = Xy_yshuffle[
+ f"{verify_results['error_type']}_train"
+ ]
return verify_results
-
- def onehot_test(self,verify_results,Xy_data,model_data):
- '''
+ def onehot_test(self, verify_results, Xy_data, model_data):
+ """
Calculate the accuracy of the model when using one-hot models. All X values that are
not 0 are considered to be 1 (NaN from missing values are converted to 0).
- '''
+ """
Xy_onehot = Xy_data.copy()
- Xy_onehot['X_train_scaled'] = Xy_onehot['X_train_scaled'].copy()
- for desc in Xy_onehot['X_train']:
+ Xy_onehot["X_train_scaled"] = Xy_onehot["X_train_scaled"].copy()
+ for desc in Xy_onehot["X_train"]:
new_vals = []
- for val in Xy_onehot['X_train'][desc]:
+ for val in Xy_onehot["X_train"][desc]:
if val == 0:
new_vals.append(0)
else:
new_vals.append(1)
- Xy_onehot['X_train_scaled'][desc] = new_vals
+ Xy_onehot["X_train_scaled"][desc] = new_vals
Xy_onehot = load_n_predict(self, model_data, Xy_onehot, BO_opt=False)
- verify_results['onehot'] = Xy_onehot[f'{verify_results["error_type"]}_train']
+ verify_results["onehot"] = Xy_onehot[f"{verify_results['error_type']}_train"]
return verify_results
-
- def analyze_tests(self,verify_results):
- '''
+ def analyze_tests(self, verify_results):
+ """
Function to check whether the tests pass and retrieve the corresponding colors:
1. Blue for passing tests
2. Red for failing tests
- '''
+ """
- blue_color = '#1f77b4'
- red_color = '#cd5c5c'
- yellow_color = '#c5c57d'
- colors = [None,None,None]
- results_print = [None,None,None]
- metrics = [None,None,None]
+ blue_color = "#1f77b4"
+ red_color = "#cd5c5c"
+ yellow_color = "#c5c57d"
+ colors = [None, None, None]
+ results_print = [None, None, None]
+ metrics = [None, None, None]
# the threshold uses validation results to compare in the tests
- verify_results['higher_thres'] = (1+thres_test_pass)*verify_results['CV_score']
- verify_results['unclear_higher_thres'] = (1+thres_test_unclear)*verify_results['CV_score']
- verify_results['lower_thres'] = (1-thres_test_pass)*verify_results['CV_score']
- verify_results['unclear_lower_thres'] = (1-thres_test_unclear)*verify_results['CV_score']
+ verify_results["higher_thres"] = (1 + thres_test_pass) * verify_results[
+ "CV_score"
+ ]
+ verify_results["unclear_higher_thres"] = (
+ 1 + thres_test_unclear
+ ) * verify_results["CV_score"]
+ verify_results["lower_thres"] = (1 - thres_test_pass) * verify_results[
+ "CV_score"
+ ]
+ verify_results["unclear_lower_thres"] = (
+ 1 - thres_test_unclear
+ ) * verify_results["CV_score"]
# determine whether the tests pass
- test_names = ['y_mean','y_shuffle','onehot']
- for i,test_ver in enumerate(test_names):
+ test_names = ["y_mean", "y_shuffle", "onehot"]
+ for i, test_ver in enumerate(test_names):
metrics[i] = verify_results[test_ver]
- if verify_results['error_type'].lower() in ['mae','rmse']:
- if verify_results[test_ver] <= verify_results['unclear_higher_thres']:
- colors[i] = red_color
- results_print[i] = f'\n x {test_ver}: FAILED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, lower than threshold'
- elif verify_results[test_ver] <= verify_results['higher_thres']:
- colors[i] = yellow_color
- results_print[i] = f'\n - {test_ver}: UNCLEAR, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, higher than original, but close to fail'
+ if verify_results["error_type"].lower() in ["mae", "rmse"]:
+ if verify_results[test_ver] <= verify_results["unclear_higher_thres"]:
+ colors[i] = red_color
+ results_print[i] = (
+ f"\n x {test_ver}: FAILED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, lower than threshold"
+ )
+ elif verify_results[test_ver] <= verify_results["higher_thres"]:
+ colors[i] = yellow_color
+ results_print[i] = (
+ f"\n - {test_ver}: UNCLEAR, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, higher than original, but close to fail"
+ )
else:
- colors[i] = blue_color
- results_print[i] = f'\n o {test_ver}: PASSED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, higher than thresholds'
+ colors[i] = blue_color
+ results_print[i] = (
+ f"\n o {test_ver}: PASSED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, higher than thresholds"
+ )
else:
- if verify_results[test_ver] >= verify_results['unclear_lower_thres']:
- colors[i] = red_color
- results_print[i] = f'\n x {test_ver}: FAILED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, higher than thresholds'
- elif verify_results[test_ver] >= verify_results['lower_thres']:
- colors[i] = yellow_color
- results_print[i] = f'\n - {test_ver}: UNCLEAR, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, lower than original, but close to fail'
+ if verify_results[test_ver] >= verify_results["unclear_lower_thres"]:
+ colors[i] = red_color
+ results_print[i] = (
+ f"\n x {test_ver}: FAILED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, higher than thresholds"
+ )
+ elif verify_results[test_ver] >= verify_results["lower_thres"]:
+ colors[i] = yellow_color
+ results_print[i] = (
+ f"\n - {test_ver}: UNCLEAR, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, lower than original, but close to fail"
+ )
else:
- colors[i] = blue_color
- results_print[i] = f'\n o {test_ver}: PASSED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, lower than thresholds'
+ colors[i] = blue_color
+ results_print[i] = (
+ f"\n o {test_ver}: PASSED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, lower than thresholds"
+ )
- # store metrics and colors to represent in comparison graph, adding the metrics of the
+ # store metrics and colors to represent in comparison graph, adding the metrics of the
# original model first
- test_names = ['Model'] + test_names
+ test_names = ["Model"] + test_names
colors = [blue_color] + colors
- metrics = [verify_results['CV_score']] + metrics
- verify_metrics = {'test_names': test_names,
- 'colors': colors,
- 'metrics': metrics,
- 'higher_thres': verify_results['higher_thres'],
- 'lower_thres': verify_results['lower_thres'],
- 'unclear_higher_thres': verify_results['unclear_higher_thres'],
- 'unclear_lower_thres': verify_results['unclear_lower_thres'],
- }
-
- return results_print,verify_results,verify_metrics
-
-
- def print_verify(self,results_print,verify_results,print_ver,model_data):
- '''
+ metrics = [verify_results["CV_score"]] + metrics
+ verify_metrics = {
+ "test_names": test_names,
+ "colors": colors,
+ "metrics": metrics,
+ "higher_thres": verify_results["higher_thres"],
+ "lower_thres": verify_results["lower_thres"],
+ "unclear_higher_thres": verify_results["unclear_higher_thres"],
+ "unclear_lower_thres": verify_results["unclear_lower_thres"],
+ }
+
+ return results_print, verify_results, verify_metrics
+
+ def print_verify(self, results_print, verify_results, print_ver, model_data):
+ """
Print and store the results of VERIFY
- '''
+ """
- print_ver += f'\n Results of flawed models and sorted cross-validation:'
+ print_ver += "\n Results of flawed models and sorted cross-validation:"
CV_type = f"{model_data['repeat_kfolds']}x {model_data['kfold']}-fold CV"
# the printing order should be y-mean, y-shuffle and one-hot
- if verify_results['error_type'].lower() in ['mae','rmse']:
- print_ver += f'\n Original {verify_results["error_type"].upper()} ({CV_type}) {verify_results["CV_score"]:.2} + {int(thres_test_unclear*100)}% & {int(thres_test_pass*100)}% threshold = {verify_results["unclear_higher_thres"]:.2} & {verify_results["higher_thres"]:.2}'
+ if verify_results["error_type"].lower() in ["mae", "rmse"]:
+ print_ver += f"\n Original {verify_results['error_type'].upper()} ({CV_type}) {verify_results['CV_score']:.2} + {int(thres_test_unclear * 100)}% & {int(thres_test_pass * 100)}% threshold = {verify_results['unclear_higher_thres']:.2} & {verify_results['higher_thres']:.2}"
else:
- print_ver += f'\n Original {verify_results["error_type"].upper()} ({CV_type}) {verify_results["CV_score"]:.2} - {int(thres_test_unclear*100)}% & {int(thres_test_pass*100)}% threshold = {verify_results["unclear_lower_thres"]:.2} & {verify_results["lower_thres"]:.2}'
+ print_ver += f"\n Original {verify_results['error_type'].upper()} ({CV_type}) {verify_results['CV_score']:.2} - {int(thres_test_unclear * 100)}% & {int(thres_test_pass * 100)}% threshold = {verify_results['unclear_lower_thres']:.2} & {verify_results['lower_thres']:.2}"
print_ver += results_print[0]
print_ver += results_print[1]
print_ver += results_print[2]
- if model_data['type'].lower() == 'reg':
+ if model_data["type"].lower() == "reg":
print_ver += f"\n - Sorted {model_data['kfold']}-fold CV : R2 = {verify_results['r2_train_sorted_CV']}, MAE = {verify_results['mae_train_sorted_CV']}, RMSE = {verify_results['rmse_train_sorted_CV']}"
- elif model_data['type'].lower() == 'clas':
+ elif model_data["type"].lower() == "clas":
print_ver += f"\n - Sorted CV : Accuracy = {verify_results['acc_train_sorted_CV']}, F1 score = {verify_results['f1_train_sorted_CV']}, MCC = {verify_results['mcc_train_sorted_CV']}"
self.args.log.write(print_ver)
diff --git a/setup.py b/setup.py
index ecfaff1..4a8f3fd 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages
-version = "2.1.0"
+version = "2.2.0"
setup(
name="robert",
@@ -34,11 +34,11 @@
url="https://github.com/jvalegre/robert",
download_url=f"https://github.com/jvalegre/robert/archive/refs/tags/{version}.tar.gz",
classifiers=[
- "Development Status :: 5 - Production/Stable", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
- "Intended Audience :: Developers", # Define that your audience are developers
+ "Development Status :: 5 - Production/Stable", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
+ "Intended Audience :: Developers", # Define that your audience are developers
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License",
- "Programming Language :: Python :: 3.11", # Specify which python versions you want to support
+ "Programming Language :: Python :: 3.11", # Specify which python versions you want to support
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
@@ -68,9 +68,16 @@
],
python_requires=">=3.11",
include_package_data=True,
+ extras_require={
+ "test": [
+ "pytest>=7.0",
+ "pytest-cov>=4.0",
+ "pytest-qt>=4.0",
+ ],
+ },
entry_points={
"console_scripts": [
"easyrob=robert.gui_easyrob.easyrob_launcher:main",
],
},
-)
\ No newline at end of file
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..7c61582
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""ROBERT test package (enables shared imports between test modules)."""
diff --git a/tests/conftest.py b/tests/conftest.py
index 34aeb51..1414922 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,28 @@
"""Pytest configuration for the ROBERT test suite."""
+from __future__ import annotations
+
import os
import sys
+from contextlib import contextmanager
+from pathlib import Path
import pytest
+REPO_ROOT = Path(__file__).resolve().parent.parent
+
+# Module output folders legacy integration tests swap under the repo root.
+ROBERT_MODULE_DIR_NAMES = (
+ "CURATE",
+ "GENERATE",
+ "GENERATE_reg",
+ "GENERATE_clas",
+ "PREDICT",
+ "VERIFY",
+ "AQME",
+ "EVALUATE",
+)
+
def pytest_configure(config):
"""
@@ -28,6 +46,101 @@ def pytest_configure(config):
os.environ.setdefault("MPLBACKEND", "Agg")
+def robert_module_dirs(root: Path) -> set[str]:
+ """Names of ROBERT module directories present under ``root``."""
+ return {name for name in ROBERT_MODULE_DIR_NAMES if (root / name).is_dir()}
+
+
+def aqme_installed() -> bool:
+ """Return True when the optional AQME package is importable."""
+ try:
+ import aqme # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
+def restore_regression_generate_layout(root: Path) -> None:
+ """
+ Undo a half-finished clas/reg GENERATE rename left by a failed test.
+
+ After a successful clas test: ``GENERATE`` (reg) + ``GENERATE_clas``.
+ Mid-clas failure may leave: ``GENERATE_reg`` + ``GENERATE`` (clas).
+ """
+ reg_backup = root / "GENERATE_reg"
+ generate = root / "GENERATE"
+ clas = root / "GENERATE_clas"
+ if reg_backup.is_dir() and generate.is_dir():
+ generate.rename(clas)
+ reg_backup.rename(generate)
+ elif reg_backup.is_dir() and not generate.is_dir():
+ reg_backup.rename(generate)
+
+
+@contextmanager
+def clas_generate_layout(root: Path):
+ """
+ Temporarily point ``GENERATE`` at the classification screening outputs.
+
+ Requires ``GENERATE`` (regression) and ``GENERATE_clas`` from
+ ``test_2generate`` (e.g. ``reduced_clas``).
+ """
+ restore_regression_generate_layout(root)
+ generate = root / "GENERATE"
+ clas = root / "GENERATE_clas"
+ if not clas.is_dir():
+ pytest.skip(
+ "GENERATE_clas missing under repo root; run "
+ "tests/test_2generate.py::test_GENERATE[reduced_clas] first."
+ )
+ if not generate.is_dir():
+ pytest.skip(
+ "GENERATE missing under repo root; run GENERATE integration tests first."
+ )
+ generate.rename(root / "GENERATE_reg")
+ clas.rename(generate)
+ try:
+ yield
+ finally:
+ generate.rename(clas)
+ (root / "GENERATE_reg").rename(generate)
+
+
+@pytest.fixture
+def repo_root() -> Path:
+ """Repository root (stable even if the process cwd changes during a test)."""
+ return REPO_ROOT
+
+
+@pytest.fixture(autouse=True)
+def restore_process_cwd():
+ """Restore the process working directory after each test."""
+ cwd_before = os.getcwd()
+ yield
+ try:
+ os.chdir(cwd_before)
+ except OSError:
+ pass
+
+
+@pytest.fixture(autouse=True)
+def drain_qt_thread_pool_after_test():
+ """Avoid 'QThread destroyed while still running' abort on pytest exit."""
+ yield
+ try:
+ from PySide6.QtCore import QCoreApplication, QThreadPool
+ except ImportError:
+ return
+ app = QCoreApplication.instance()
+ if app is None:
+ return
+ pool = QThreadPool.globalInstance()
+ if pool is not None:
+ pool.waitForDone(10_000)
+ app.processEvents()
+
+
@pytest.fixture
def fast_robert_kwargs():
"""Reduced CV/BO settings for faster integration tests."""
diff --git a/tests/test_2generate.py b/tests/test_2generate.py
index 4c91a9f..deb847f 100644
--- a/tests/test_2generate.py
+++ b/tests/test_2generate.py
@@ -75,9 +75,7 @@ def _log_line_metric_close(line, prefix, expected, *, rel_tol=0.05, abs_tol=0.02
(
"reduced_adab"
), # test for other GP model (important since PFI filter tries to discard all the descriptors)
- (
- "reduced_xgb"
- ), # test for XGB model (tree booster with feature importances)
+ ("reduced_xgb"), # test for XGB model (tree booster with feature importances)
("reduced_vr"), # test for Voting Regressor model
("reduced_vr_clas"), # test Voting Classifier workflow
("reduced_clas"), # test for clasification models
@@ -126,7 +124,13 @@ def test_GENERATE(test_job):
generate_kwargs = {"generate": True, "csv_name": csv_name, "y": "Target_values"}
if test_job != "standard":
# add model
- if test_job not in ["reduced_gp", "reduced_adab", "reduced_xgb", "reduced_vr", "reduced_vr_clas"]:
+ if test_job not in [
+ "reduced_gp",
+ "reduced_adab",
+ "reduced_xgb",
+ "reduced_vr",
+ "reduced_vr_clas",
+ ]:
generate_kwargs["model"] = ["RF"]
elif test_job == "reduced_gp":
generate_kwargs["model"] = ["GP"]
@@ -465,7 +469,9 @@ def test_GENERATE(test_job):
if test_job in ["reduced_clas", "reduced_vr_clas"]:
model_name = "RF" if test_job == "reduced_clas" else "VR"
csv_clas = glob.glob(
- os.path.join(path_generate, "Best_model", "PFI", f"{model_name}_PFI.csv")
+ os.path.join(
+ path_generate, "Best_model", "PFI", f"{model_name}_PFI.csv"
+ )
)
df = pd.read_csv(csv_clas[0])
if "error_type" in df.columns:
diff --git a/tests/test_3verify.py b/tests/test_3verify.py
index 3640273..96828dd 100644
--- a/tests/test_3verify.py
+++ b/tests/test_3verify.py
@@ -4,18 +4,18 @@
# Testing VERIFY with pytest #
######################################################.
-import os
+import subprocess
import sys
-import glob
+
import pytest
import shutil
-import subprocess
-from pathlib import Path
+
from robert.verify import verify
-# saves the working directory
-path_main = os.getcwd()
-path_verify = os.path.join(path_main, "VERIFY")
+from tests.conftest import (
+ clas_generate_layout,
+ restore_regression_generate_layout,
+)
# VERIFY tests
@@ -27,60 +27,39 @@
("standard_cmd"), # standard test with command line
],
)
-def test_VERIFY(test_job):
- # leave the folders as they were initially to run a different batch of tests
- if os.path.exists(path_verify):
+def test_VERIFY(test_job, repo_root, monkeypatch):
+ monkeypatch.chdir(repo_root)
+ path_verify = repo_root / "VERIFY"
+
+ if path_verify.is_dir():
shutil.rmtree(path_verify)
- # remove DAT and CSV files generated by VERIFY
- dat_files = glob.glob("*.dat")
- for dat_file in dat_files:
- if "VERIFY" in dat_file:
- os.remove(dat_file)
+ for dat_file in repo_root.glob("*.dat"):
+ if "VERIFY" in dat_file.name:
+ dat_file.unlink()
- if test_job == "clas": # rename folders to use in classification
- # rename the regression GENERATE folder
- filepath_reg = Path(path_main) / "GENERATE"
- filepath_reg.rename(Path(path_main) / "GENERATE_reg")
- # rename the classification GENERATE folder
- filepath = Path(path_main) / "GENERATE_clas"
- filepath.rename(Path(path_main) / "GENERATE")
+ if test_job == "clas":
+ with clas_generate_layout(repo_root):
+ _run_verify(test_job, repo_root, path_verify)
+ else:
+ restore_regression_generate_layout(repo_root)
+ _run_verify(test_job, repo_root, path_verify)
- else: # in case the clas test fails and the ending rename doesn't happen
- if os.path.exists(Path(path_main) / "GENERATE_reg"):
- # rename the classification GENERATE folder
- filepath = Path(path_main) / "GENERATE"
- filepath.rename(Path(path_main) / "GENERATE_clas")
- # rename the regression GENERATE folder
- filepath_reg = Path(path_main) / "GENERATE_reg"
- filepath_reg.rename(Path(path_main) / "GENERATE")
- # runs the program with the different tests
+def _run_verify(test_job, repo_root, path_verify):
if test_job == "standard_cmd":
- cmd_robert = [
- sys.executable,
- "-m",
- "robert",
- "--verify",
- ]
-
- subprocess.run(cmd_robert)
-
+ cmd_robert = [sys.executable, "-m", "robert", "--verify"]
+ subprocess.run(cmd_robert, cwd=repo_root, check=False)
else:
- verify_kwargs = {}
+ verify()
- verify(**verify_kwargs)
-
- # check that the DAT file is created
- assert not os.path.exists(os.path.join(path_main, "VERIFY_data.dat"))
- outfile = open(os.path.join(path_verify, "VERIFY_data.dat"), "r")
- outlines = outfile.readlines()
- outfile.close()
+ assert not (repo_root / "VERIFY_data.dat").is_file()
+ verify_dat = path_verify / "VERIFY_data.dat"
+ with verify_dat.open(encoding="utf-8") as outfile:
+ outlines = outfile.readlines()
assert "ROBERT v" in outlines[0]
results_line, start_reading = False, False
for i, line in enumerate(outlines):
- if (
- "------- Starting model with PFI filter " in line
- ): # focus on the PFI since there is an unclear test
+ if "------- Starting model with PFI filter " in line:
start_reading = True
if start_reading:
if "Results of flawed models and sorted cross-validation:" in line:
@@ -121,14 +100,5 @@ def test_VERIFY(test_job):
break
assert results_line
- # check that the verify plots and DAT file are created
- assert len(glob.glob(os.path.join(path_verify, "*.png"))) == 2
- assert len(glob.glob(os.path.join(path_verify, "*.dat"))) == 1
-
- if test_job == "clas": # rename folders back to their original names
- # rename the classification GENERATE folder
- filepath = Path(path_main) / "GENERATE"
- filepath.rename(Path(path_main) / "GENERATE_clas")
- # rename the regression GENERATE folder
- filepath_reg = Path(path_main) / "GENERATE_reg"
- filepath_reg.rename(Path(path_main) / "GENERATE")
+ assert len(list(path_verify.glob("*.png"))) == 2
+ assert len(list(path_verify.glob("*.dat"))) == 1
diff --git a/tests/test_5aqme_n_full.py b/tests/test_5aqme_n_full.py
index 4d05448..1c1905b 100644
--- a/tests/test_5aqme_n_full.py
+++ b/tests/test_5aqme_n_full.py
@@ -17,6 +17,15 @@
path_aqme = os.path.join(path_main, "AQME")
+def _aqme_installed() -> bool:
+ try:
+ import aqme # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
# AQME and full workflow tests
@pytest.mark.parametrize(
"test_job",
@@ -30,6 +39,9 @@
],
)
def test_AQME(test_job):
+ if test_job in ("aqme", "2smiles_columns") and not _aqme_installed():
+ pytest.skip("AQME is not installed (pip install aqme==2.0.0)")
+
# reset the folders (to avoid interferences with previous failed tests)
folders = [
"CURATE",
diff --git a/tests/test_7api.py b/tests/test_7api.py
index e481533..648731e 100644
--- a/tests/test_7api.py
+++ b/tests/test_7api.py
@@ -12,7 +12,6 @@
import sys
import tempfile
import warnings
-from pathlib import Path
from types import SimpleNamespace
import numpy as np
@@ -33,12 +32,35 @@
should_plot_verify_metrics,
)
-_REPO = Path(__file__).resolve().parent.parent
+from tests.conftest import REPO_ROOT, robert_module_dirs
+
+_REPO = REPO_ROOT
_REG_CSV = _REPO / "tests" / "Robert_example.csv"
_CLAS_CSV = _REPO / "tests" / "Robert_example_clas.csv"
_FIXTURE_MODEL = _REPO / "tests" / "fixtures" / "custom_predict_model"
+@pytest.fixture(autouse=True)
+def api_test_repo_isolation(repo_root):
+ """
+ RobertModel.fit uses os.chdir(workdir); ensure cwd and repo-root artifacts
+ do not leak into legacy integration tests (e.g. test_3verify).
+ """
+ dirs_before = robert_module_dirs(repo_root)
+ yield
+ try:
+ os.chdir(repo_root)
+ except OSError:
+ pass
+ for name in robert_module_dirs(repo_root) - dirs_before:
+ shutil.rmtree(repo_root / name, ignore_errors=True)
+ for dat_file in repo_root.glob("*.dat"):
+ if any(
+ tag in dat_file.name for tag in ("CURATE", "GENERATE", "PREDICT", "VERIFY")
+ ):
+ dat_file.unlink(missing_ok=True)
+
+
@pytest.fixture
def custom_model_dir(tmp_path):
"""Minimal GENERATE-style folder (params CSV + _db.csv)."""
@@ -59,7 +81,9 @@ def _holdout_for_predict(X: pd.DataFrame, n_fit: int) -> pd.DataFrame:
def test_yaml_unknown_key_warns_and_known_key_applies(capsys):
- with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, encoding="utf-8") as f:
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".yaml", delete=False, encoding="utf-8"
+ ) as f:
f.write("not_a_robert_option: 1\nseed: 99\n")
path = f.name
try:
@@ -76,7 +100,9 @@ def test_yaml_unknown_key_warns_and_known_key_applies(capsys):
def test_yaml_missing_file_message():
opts = set_options({})
- opts.varfile = os.path.join(tempfile.gettempdir(), "robert_nonexistent_params_xyz.yaml")
+ opts.varfile = os.path.join(
+ tempfile.gettempdir(), "robert_nonexistent_params_xyz.yaml"
+ )
_, msg = load_from_yaml(opts)
assert "not found" in msg.lower()
@@ -307,6 +333,29 @@ def test_predict_row_order_matches_input_order(tmp_path, fast_robert_kwargs):
assert np.allclose(pred_natural, pred_realigned)
+def test_fit_with_report_and_robert_scores(tmp_path, fast_robert_kwargs):
+ """Full API path with REPORT and robert_scores()."""
+ pytest.importorskip("weasyprint")
+ df = pd.read_csv(_REG_CSV, encoding="utf-8")
+ X = df.drop(columns=["Target_values"])
+ y = df["Target_values"]
+ model = RobertModel(
+ problem_type="reg",
+ filter_mode="no_pfi",
+ workdir=tmp_path,
+ names="Name",
+ report=True,
+ **fast_robert_kwargs,
+ )
+ model.fit(X.iloc[:20], y.iloc[:20])
+ scores = model.robert_scores()
+ assert 0 <= scores["robert_score"] <= 10
+ assert "cv_score_combined" in scores["components"]
+ pdf = tmp_path / "ROBERT_report.pdf"
+ assert pdf.is_file()
+ assert scores["pdf_path"] == str(pdf)
+
+
def test_fit_accepts_unused_fit_params(tmp_path, fast_robert_kwargs):
"""Sklearn Pipeline may pass extra fit kwargs; they should not raise."""
df = pd.read_csv(_REG_CSV, encoding="utf-8")
diff --git a/tests/test_8uq.py b/tests/test_8uq.py
index d9ff00c..cfaf80f 100644
--- a/tests/test_8uq.py
+++ b/tests/test_8uq.py
@@ -82,9 +82,7 @@ def test_fit_predict_meta_uncertainty_modes(tmp_path, fast_robert_kwargs):
assert y_hat.shape == uq_meta.shape
assert np.isfinite(uq_meta).all() and (uq_meta >= 0).all()
_, uq_total = model.predict(X_hold, return_uncertainty="total")
- y_d, uq_m, uq_meta2, uq_tot = model.predict(
- X_hold, return_uncertainty="decomposed"
- )
+ y_d, uq_m, uq_meta2, uq_tot = model.predict(X_hold, return_uncertainty="decomposed")
assert y_d.shape == uq_m.shape == uq_meta2.shape == uq_tot.shape
assert np.all(uq_tot >= uq_m - 1e-9)
assert np.allclose(uq_tot, uq_total, rtol=1e-5, atol=1e-5)
@@ -130,9 +128,9 @@ def test_score_prefers_calibrated_scale():
abs_res = np.array([1.0, 1.0, 1.0, 1.0])
bad = np.full(4, 10.0)
good = np.full(4, 1.0)
- assert score_uncertainty_candidate(good, abs_res, 0.9) < score_uncertainty_candidate(
- bad, abs_res, 0.9
- )
+ assert score_uncertainty_candidate(
+ good, abs_res, 0.9
+ ) < score_uncertainty_candidate(bad, abs_res, 0.9)
def test_evaluate_uq_candidates_deterministic():
diff --git a/tests/test_easyrob.py b/tests/test_easyrob.py
index bbaec0b..795d168 100644
--- a/tests/test_easyrob.py
+++ b/tests/test_easyrob.py
@@ -29,10 +29,10 @@
sys.path.insert(0, str(PROJECT_ROOT))
# Third-party imports
-import pandas as pd
-import pytest
-from PySide6.QtCore import Qt, QCoreApplication
-from PySide6.QtWidgets import (
+import pandas as pd # noqa: E402
+import pytest # noqa: E402
+from PySide6.QtCore import Qt, QCoreApplication # noqa: E402
+from PySide6.QtWidgets import ( # noqa: E402
QListWidgetItem,
QMessageBox,
QDialog,
@@ -40,16 +40,17 @@
QPushButton,
QTableWidget,
QTableWidgetItem,
-)
-from rdkit import Chem
+) # noqa: E402
# Local project imports
-from robert.gui_easyrob.main.window import EasyROB
-import robert.gui_easyrob.easyrob as easyrob_module
-import robert.gui_easyrob.main.window as window_module
-import robert.gui_easyrob.tabs.aqme as aqme_module
-import robert.gui_easyrob.tabs.predictions as predictions_module
-import robert.gui_easyrob.tabs.results as results_module
+from robert.gui_easyrob.main.window import EasyROB # noqa: E402
+import robert.gui_easyrob.easyrob as easyrob_module # noqa: E402
+import robert.gui_easyrob.main.window as window_module # noqa: E402
+import robert.gui_easyrob.tabs.aqme as aqme_module # noqa: E402
+import robert.gui_easyrob.tabs.predictions as predictions_module # noqa: E402
+import robert.gui_easyrob.tabs.results as results_module # noqa: E402
+
+from tests.conftest import aqme_installed # noqa: E402
# ----------------------------------------------------------------------
# Constants
@@ -154,7 +155,9 @@ def dump_console_output(title, text):
print("----- END CONSOLE OUTPUT -----")
-def process_events_until(predicate, timeout_s, poll_interval_s=WORKFLOW_POLL_INTERVAL_S):
+def process_events_until(
+ predicate, timeout_s, poll_interval_s=WORKFLOW_POLL_INTERVAL_S
+):
"""Process Qt events until a condition becomes true or the timeout expires."""
elapsed = 0.0
while elapsed < timeout_s:
@@ -212,10 +215,16 @@ def wait_for_workflow_completion(
all_dirs_exist = all((output_dir / name).is_dir() for name in expected_dirs)
pdf_exists = report_pdf.is_file()
if workflow_started and all_dirs_exist and pdf_exists:
- print("[OK] Workflow completed (all output folders AND report PDF detected)")
+ print(
+ "[OK] Workflow completed (all output folders AND report PDF detected)"
+ )
return True, workflow_started, last_console, last_process
- if process is not None and process.poll() is not None and not (all_dirs_exist and pdf_exists):
+ if (
+ process is not None
+ and process.poll() is not None
+ and not (all_dirs_exist and pdf_exists)
+ ):
print(
f"[WARN] Process exited with code {process.returncode} "
"but not all outputs (folders + PDF) are present yet."
@@ -242,21 +251,28 @@ def run_full_workflow_and_wait(window, qtbot, output_dir, expected_dirs, report_
initial_console_text=baseline_text,
)
if not started:
- dump_console_output("[DEBUG] Console at timeout (no start detected):", last_console)
+ dump_console_output(
+ "[DEBUG] Console at timeout (no start detected):", last_console
+ )
pytest.fail("Workflow did not start within timeout")
if not completed:
print("\n[DEBUG] Console at timeout (no completion detected):")
- print("Existing entries in output_dir:", [path.name for path in output_dir.iterdir()])
+ print(
+ "Existing entries in output_dir:",
+ [path.name for path in output_dir.iterdir()],
+ )
print("Report PDF exists:", report_pdf.is_file())
if last_process is not None:
print("Process return code:", last_process.returncode)
dump_console_output("[DEBUG] Timed out console snapshot:", last_console)
pytest.fail("Workflow did not complete within timeout")
+
# ----------------------------------------------------------------------
# Fixtures
# ----------------------------------------------------------------------
+
@pytest.fixture(scope="session")
def test_output_dir():
"""
@@ -310,6 +326,31 @@ def easyrob_window(qtbot, monkeypatch):
return window
+@pytest.fixture
+def predictions_tab(qtbot):
+ """PredictionsTab with an active QApplication (required for QWidget construction)."""
+ tab = predictions_module.PredictionsTab()
+ qtbot.addWidget(tab)
+ return tab
+
+
+@pytest.fixture
+def results_tab_mocks(monkeypatch):
+ """Common ResultsTab mocks for unit tests that avoid PDF/WebEngine widgets."""
+ monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
+ monkeypatch.setattr(
+ results_module,
+ "PDFViewer",
+ lambda pdf_path, thread_pool: results_module.QWidget(),
+ )
+
+
+def _results_tab(qtbot, input_csv_path: str):
+ tab = results_module.ResultsTab(None, str(input_csv_path))
+ qtbot.addWidget(tab)
+ return tab
+
+
# =====================================================
# Basic Initialization Tests
# =====================================================
@@ -332,6 +373,7 @@ def test_easyrob_factory_returns_main_window_class():
"""The lightweight factory module returns the real main window class."""
assert easyrob_module.get_main_window_class() is EasyROB
+
def test_all_tabs_created(easyrob_window):
"""All expected top-level tabs are present."""
window = easyrob_window
@@ -352,7 +394,9 @@ def test_dropdowns_populated(easyrob_window):
window = easyrob_window
assert window.type_dropdown.count() == 2
- items = [window.type_dropdown.itemText(i) for i in range(window.type_dropdown.count())]
+ items = [
+ window.type_dropdown.itemText(i) for i in range(window.type_dropdown.count())
+ ]
assert "Regression" in items
assert "Classification" in items
@@ -389,8 +433,7 @@ def test_move_to_selected_and_back(easyrob_window):
for i in range(window.available_list.count())
]
ignore_items = [
- window.ignore_list.item(i).text()
- for i in range(window.ignore_list.count())
+ window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
]
assert "col3" in available_items
@@ -445,7 +488,9 @@ def test_load_csv_columns(easyrob_window, tmp_path):
assert set(available_items) == {"ID", "Name", "Target", "Feature1"}
-def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(easyrob_window, tmp_path):
+def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(
+ easyrob_window, tmp_path
+):
"""SMILES is auto-ignored and code_name is auto-selected when present."""
window = easyrob_window
@@ -467,8 +512,7 @@ def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(easyrob_wind
for i in range(window.available_list.count())
}
ignored_items = {
- window.ignore_list.item(i).text()
- for i in range(window.ignore_list.count())
+ window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
}
assert "SMILES" not in available_items
@@ -476,7 +520,9 @@ def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(easyrob_wind
assert window.names_dropdown.currentText() == "code_name"
-def test_set_file_path_updates_ui_and_skips_redundant_reload(easyrob_window, tmp_path, monkeypatch):
+def test_set_file_path_updates_ui_and_skips_redundant_reload(
+ easyrob_window, tmp_path, monkeypatch
+):
"""set_file_path updates labels and avoids reloading unchanged files unless forced."""
window = easyrob_window
csv_path = tmp_path / "input.csv"
@@ -490,12 +536,36 @@ def test_set_file_path_updates_ui_and_skips_redundant_reload(easyrob_window, tmp
"check_aqme_workflow": 0,
}
- monkeypatch.setattr(window, "load_csv_columns", lambda: calls.__setitem__("load_csv_columns", calls["load_csv_columns"] + 1))
- monkeypatch.setattr(window, "refresh_tabs", lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1))
+ monkeypatch.setattr(
+ window,
+ "load_csv_columns",
+ lambda: calls.__setitem__("load_csv_columns", calls["load_csv_columns"] + 1),
+ )
+ monkeypatch.setattr(
+ window,
+ "refresh_tabs",
+ lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1),
+ )
monkeypatch.setattr(window, "_is_molssi_csv", lambda file_path: False)
- monkeypatch.setattr(window, "check_molssi_descriptors", lambda: calls.__setitem__("check_molssi_descriptors", calls["check_molssi_descriptors"] + 1))
- monkeypatch.setattr(window, "_update_unified_smiles_context", lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1))
- monkeypatch.setattr(window, "check_aqme_workflow", lambda: calls.__setitem__("check_aqme_workflow", calls["check_aqme_workflow"] + 1))
+ monkeypatch.setattr(
+ window,
+ "check_molssi_descriptors",
+ lambda: calls.__setitem__(
+ "check_molssi_descriptors", calls["check_molssi_descriptors"] + 1
+ ),
+ )
+ monkeypatch.setattr(
+ window,
+ "_update_unified_smiles_context",
+ lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1),
+ )
+ monkeypatch.setattr(
+ window,
+ "check_aqme_workflow",
+ lambda: calls.__setitem__(
+ "check_aqme_workflow", calls["check_aqme_workflow"] + 1
+ ),
+ )
window.tab_widget_aqme.df_mapped_smiles = object()
window.set_file_path(str(csv_path))
@@ -526,9 +596,23 @@ def test_set_and_clear_csv_test_path_updates_ui(easyrob_window, tmp_path, monkey
pd.DataFrame({"SMILES": ["C"], "target": [1.0]}).to_csv(csv_path, index=False)
calls = {"update_smiles": 0, "check_aqme_workflow": 0, "refresh_tabs": 0}
- monkeypatch.setattr(window, "_update_unified_smiles_context", lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1))
- monkeypatch.setattr(window, "check_aqme_workflow", lambda: calls.__setitem__("check_aqme_workflow", calls["check_aqme_workflow"] + 1))
- monkeypatch.setattr(window, "refresh_tabs", lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1))
+ monkeypatch.setattr(
+ window,
+ "_update_unified_smiles_context",
+ lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1),
+ )
+ monkeypatch.setattr(
+ window,
+ "check_aqme_workflow",
+ lambda: calls.__setitem__(
+ "check_aqme_workflow", calls["check_aqme_workflow"] + 1
+ ),
+ )
+ monkeypatch.setattr(
+ window,
+ "refresh_tabs",
+ lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1),
+ )
window.set_csv_test_path(str(csv_path))
@@ -541,7 +625,10 @@ def test_set_and_clear_csv_test_path_updates_ui(easyrob_window, tmp_path, monkey
window.clear_test_file()
assert window.csv_test_path is None
- assert window.csv_test_label.label.text() == "Drag & Drop a CSV external test file here (optional)"
+ assert (
+ window.csv_test_label.label.text()
+ == "Drag & Drop a CSV external test file here (optional)"
+ )
assert window.csv_test_label.toolTip() == ""
assert window.clear_test_button.isHidden()
assert calls["update_smiles"] == 2
@@ -569,14 +656,20 @@ def test_reset_ui_after_process_restores_buttons(easyrob_window):
def test_open_external_url_uses_browser(easyrob_window, monkeypatch):
"""open_external_url delegates to the browser helper."""
opened = {}
- monkeypatch.setattr(window_module.webbrowser, "open", lambda url, new=0: opened.update({"url": url, "new": new}))
+ monkeypatch.setattr(
+ window_module.webbrowser,
+ "open",
+ lambda url, new=0: opened.update({"url": url, "new": new}),
+ )
easyrob_window.open_external_url("https://example.com")
assert opened == {"url": "https://example.com", "new": 2}
-def test_close_event_ignores_when_running_worker_is_not_stopped(easyrob_window, monkeypatch):
+def test_close_event_ignores_when_running_worker_is_not_stopped(
+ easyrob_window, monkeypatch
+):
"""Closing is aborted if ROBERT is still running and the user declines stopping it."""
window = easyrob_window
@@ -591,7 +684,9 @@ class DummyWorker:
def isRunning(self):
return True
- monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No)
+ monkeypatch.setattr(
+ window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No
+ )
event = DummyEvent()
window.worker = DummyWorker()
@@ -646,10 +741,20 @@ def quit(self):
shutdown_calls = {"n": 0}
timer_calls = {"n": 0}
- monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes)
+ monkeypatch.setattr(
+ window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes
+ )
monkeypatch.setattr(window_module, "QEventLoop", lambda: loop)
- monkeypatch.setattr(window_module.QTimer, "singleShot", lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1))
- monkeypatch.setattr(window, "_shutdown_molssi_async", lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1))
+ monkeypatch.setattr(
+ window_module.QTimer,
+ "singleShot",
+ lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1),
+ )
+ monkeypatch.setattr(
+ window,
+ "_shutdown_molssi_async",
+ lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1),
+ )
event = DummyEvent()
window.worker = worker
@@ -664,7 +769,9 @@ def quit(self):
assert event.ignored == 1
-def test_close_event_without_running_worker_starts_async_shutdown(easyrob_window, monkeypatch):
+def test_close_event_without_running_worker_starts_async_shutdown(
+ easyrob_window, monkeypatch
+):
"""Closing without an active ROBERT worker still routes through async cleanup."""
window = easyrob_window
@@ -676,7 +783,11 @@ def ignore(self):
self.ignored += 1
shutdown_calls = {"n": 0}
- monkeypatch.setattr(window, "_shutdown_molssi_async", lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1))
+ monkeypatch.setattr(
+ window,
+ "_shutdown_molssi_async",
+ lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1),
+ )
event = DummyEvent()
window.worker = None
@@ -690,7 +801,11 @@ def ignore(self):
def test_show_contact_dialog_executes_modal(easyrob_window, monkeypatch):
"""Contact dialog is created and executed."""
exec_calls = {"n": 0}
- monkeypatch.setattr(window_module.QDialog, "exec", lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1))
+ monkeypatch.setattr(
+ window_module.QDialog,
+ "exec",
+ lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1),
+ )
easyrob_window.show_contact_dialog()
@@ -700,7 +815,11 @@ def test_show_contact_dialog_executes_modal(easyrob_window, monkeypatch):
def test_show_version_dialog_executes_modal(easyrob_window, monkeypatch):
"""Version dialog is created and executed."""
exec_calls = {"n": 0}
- monkeypatch.setattr(window_module.QDialog, "exec", lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1))
+ monkeypatch.setattr(
+ window_module.QDialog,
+ "exec",
+ lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1),
+ )
easyrob_window.show_version_dialog()
@@ -709,6 +828,7 @@ def test_show_version_dialog_executes_modal(easyrob_window, monkeypatch):
def test_show_tutorial_dialog_reuses_visible_dialog(easyrob_window):
"""Reopening the tutorial dialog reuses the visible instance."""
+
class DummyDialog:
def __init__(self):
self.raise_calls = 0
@@ -760,7 +880,11 @@ def test_check_for_pdfs_and_images_runs_with_existing_outputs(easyrob_window, tm
def test_refresh_tabs_updates_children_and_schedules_once(easyrob_window, monkeypatch):
"""refresh_tabs stores the latest path and coalesces duplicate scheduling."""
timer_calls = {"n": 0}
- monkeypatch.setattr(window_module.QTimer, "singleShot", lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1))
+ monkeypatch.setattr(
+ window_module.QTimer,
+ "singleShot",
+ lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1),
+ )
easyrob_window._refresh_scheduled = False
easyrob_window.refresh_tabs("one.csv")
@@ -774,11 +898,31 @@ def test_execute_refresh_tabs_calls_all_child_refreshes(easyrob_window, monkeypa
"""_execute_refresh_tabs fans out the refresh to child tabs and file-based checks."""
calls = {"results": 0, "images": 0, "predictions": 0, "pdfs": 0, "imgs": 0}
- monkeypatch.setattr(easyrob_window.results_tab, "refresh_with_new_path", lambda p: calls.__setitem__("results", calls["results"] + 1))
- monkeypatch.setattr(easyrob_window.images_tab, "refresh_with_new_path", lambda p: calls.__setitem__("images", calls["images"] + 1))
- monkeypatch.setattr(easyrob_window.predictions_tab, "refresh_with_new_path", lambda p: calls.__setitem__("predictions", calls["predictions"] + 1))
- monkeypatch.setattr(easyrob_window, "check_for_pdfs", lambda p: calls.__setitem__("pdfs", calls["pdfs"] + 1))
- monkeypatch.setattr(easyrob_window, "check_for_images", lambda p: calls.__setitem__("imgs", calls["imgs"] + 1))
+ monkeypatch.setattr(
+ easyrob_window.results_tab,
+ "refresh_with_new_path",
+ lambda p: calls.__setitem__("results", calls["results"] + 1),
+ )
+ monkeypatch.setattr(
+ easyrob_window.images_tab,
+ "refresh_with_new_path",
+ lambda p: calls.__setitem__("images", calls["images"] + 1),
+ )
+ monkeypatch.setattr(
+ easyrob_window.predictions_tab,
+ "refresh_with_new_path",
+ lambda p: calls.__setitem__("predictions", calls["predictions"] + 1),
+ )
+ monkeypatch.setattr(
+ easyrob_window,
+ "check_for_pdfs",
+ lambda p: calls.__setitem__("pdfs", calls["pdfs"] + 1),
+ )
+ monkeypatch.setattr(
+ easyrob_window,
+ "check_for_images",
+ lambda p: calls.__setitem__("imgs", calls["imgs"] + 1),
+ )
easyrob_window._pending_refresh_path = "demo.csv"
easyrob_window._refresh_scheduled = True
@@ -790,6 +934,7 @@ def test_execute_refresh_tabs_calls_all_child_refreshes(easyrob_window, monkeypa
def test_stop_process_confirms_and_stops_worker(easyrob_window, monkeypatch):
"""stop_process marks manual stop and schedules worker stop after confirmation."""
+
class DummySignal:
def connect(self, fn):
self.fn = fn
@@ -807,8 +952,14 @@ def stop(self):
worker = DummyWorker()
timer_calls = {"n": 0}
- monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes)
- monkeypatch.setattr(window_module.QTimer, "singleShot", lambda ms, fn: (timer_calls.__setitem__("n", timer_calls["n"] + 1), fn())[1])
+ monkeypatch.setattr(
+ window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes
+ )
+ monkeypatch.setattr(
+ window_module.QTimer,
+ "singleShot",
+ lambda ms, fn: (timer_calls.__setitem__("n", timer_calls["n"] + 1), fn())[1],
+ )
easyrob_window.worker = worker
easyrob_window.stop_button.setDisabled(False)
@@ -823,7 +974,9 @@ def stop(self):
def test_stop_process_returns_when_user_declines(easyrob_window, monkeypatch):
"""stop_process does nothing if the user rejects the confirmation dialog."""
- monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No)
+ monkeypatch.setattr(
+ window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No
+ )
easyrob_window.manual_stop = False
easyrob_window.stop_process()
@@ -1011,14 +1164,24 @@ def test_predictions_filter_dataframe_orders_core_columns(easyrob_window, monkey
"target_pred_sd": [0.3],
}
)
- monkeypatch.setattr(easyrob_window.predictions_tab, "_extract_names_column_from_predict", lambda: "sample_id")
+ monkeypatch.setattr(
+ easyrob_window.predictions_tab,
+ "_extract_names_column_from_predict",
+ lambda: "sample_id",
+ )
filtered = easyrob_window.predictions_tab._filter_prediction_dataframe(df)
- assert list(filtered.columns) == ["Image", "sample_id", "SMILES", "target_pred", "target_pred_sd"]
+ assert list(filtered.columns) == [
+ "Image",
+ "sample_id",
+ "SMILES",
+ "target_pred",
+ "target_pred_sd",
+ ]
-def test_predictions_extract_names_column_from_predict(tmp_path):
+def test_predictions_extract_names_column_from_predict(tmp_path, predictions_tab):
"""The names field is extracted from the stored PREDICT command line."""
predict_dir = tmp_path / "PREDICT"
predict_dir.mkdir()
@@ -1026,23 +1189,19 @@ def test_predictions_extract_names_column_from_predict(tmp_path):
dat_path.write_text('--names "code_name"\n', encoding="utf-8")
(tmp_path / "input.csv").write_text("a,b\n1,2\n", encoding="utf-8")
- tab = predictions_module.PredictionsTab()
- tab._base_path = str(tmp_path / "input.csv")
+ predictions_tab._base_path = str(tmp_path / "input.csv")
- assert tab._extract_names_column_from_predict() == "code_name"
+ assert predictions_tab._extract_names_column_from_predict() == "code_name"
-def test_results_tab_detects_and_refreshes_pdf_tabs(tmp_path, monkeypatch):
+def test_results_tab_detects_and_refreshes_pdf_tabs(tmp_path, qtbot, results_tab_mocks):
"""Results tab discovers PDFs and refreshes when a new path is provided."""
- monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
- monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
run_dir = tmp_path / "run"
run_dir.mkdir()
first_pdf = run_dir / "ROBERT_report.pdf"
first_pdf.write_text("pdf", encoding="utf-8")
- tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+ tab = _results_tab(qtbot, run_dir / "input.csv")
assert first_pdf.name in tab.title_to_path
assert tab.pdf_tab_widget.count() == 1
@@ -1058,9 +1217,9 @@ def test_results_tab_detects_and_refreshes_pdf_tabs(tmp_path, monkeypatch):
assert second_pdf.name in tab.title_to_path
-def test_predictions_show_header_menu_sorts_and_histogram(monkeypatch):
+def test_predictions_show_header_menu_sorts_and_histogram(predictions_tab, monkeypatch):
"""Header context menu routes to sort and histogram actions."""
- tab = predictions_module.PredictionsTab()
+ tab = predictions_tab
df = pd.DataFrame({"num": [2, 1], "txt": ["b", "a"]})
class DummyHeader:
@@ -1102,7 +1261,11 @@ def exec(self, pos):
monkeypatch.setattr(predictions_module, "QMenu", DummyMenu)
histogram_calls = {"n": 0}
- monkeypatch.setattr(tab, "_show_histogram", lambda series, name: histogram_calls.__setitem__("n", histogram_calls["n"] + 1))
+ monkeypatch.setattr(
+ tab,
+ "_show_histogram",
+ lambda series, name: histogram_calls.__setitem__("n", histogram_calls["n"] + 1),
+ )
header = DummyHeader()
table = DummyTable()
@@ -1117,9 +1280,11 @@ def exec(self, pos):
assert histogram_calls["n"] == 1
-def test_predictions_show_histogram_menu_non_numeric_shows_message(monkeypatch):
+def test_predictions_show_histogram_menu_non_numeric_shows_message(
+ predictions_tab, monkeypatch
+):
"""Non-numeric columns show an informational popup instead of plotting."""
- tab = predictions_module.PredictionsTab()
+ tab = predictions_tab
df = pd.DataFrame({"txt": ["a", "b"]})
info_calls = {"n": 0}
@@ -1127,48 +1292,121 @@ class DummyHeader:
def logicalIndexAt(self, pos):
return 0
- monkeypatch.setattr(predictions_module.QMessageBox, "information", lambda *args, **kwargs: info_calls.__setitem__("n", info_calls["n"] + 1))
+ monkeypatch.setattr(
+ predictions_module.QMessageBox,
+ "information",
+ lambda *args, **kwargs: info_calls.__setitem__("n", info_calls["n"] + 1),
+ )
tab._show_histogram_menu_header(None, df, DummyHeader())
assert info_calls["n"] == 1
-def test_predictions_show_histogram_uses_matplotlib(monkeypatch):
+def test_predictions_show_histogram_uses_matplotlib(predictions_tab, monkeypatch):
"""Histogram plotting delegates to matplotlib without blocking."""
- tab = predictions_module.PredictionsTab()
+ tab = predictions_tab
series = pd.Series([1, 2, 3])
- calls = {"figure": 0, "title": 0, "xlabel": 0, "ylabel": 0, "grid": 0, "show": 0, "hist": 0}
+ calls = {
+ "figure": 0,
+ "title": 0,
+ "xlabel": 0,
+ "ylabel": 0,
+ "grid": 0,
+ "show": 0,
+ "hist": 0,
+ }
- monkeypatch.setattr(predictions_module.plt, "figure", lambda: calls.__setitem__("figure", calls["figure"] + 1))
- monkeypatch.setattr(predictions_module.plt, "title", lambda name: calls.__setitem__("title", calls["title"] + 1))
- monkeypatch.setattr(predictions_module.plt, "xlabel", lambda name: calls.__setitem__("xlabel", calls["xlabel"] + 1))
- monkeypatch.setattr(predictions_module.plt, "ylabel", lambda name: calls.__setitem__("ylabel", calls["ylabel"] + 1))
- monkeypatch.setattr(predictions_module.plt, "grid", lambda enabled: calls.__setitem__("grid", calls["grid"] + 1))
- monkeypatch.setattr(predictions_module.plt, "show", lambda block=False: calls.__setitem__("show", calls["show"] + 1))
- monkeypatch.setattr(pd.Series, "hist", lambda self, bins=30: calls.__setitem__("hist", calls["hist"] + 1))
+ monkeypatch.setattr(
+ predictions_module.plt,
+ "figure",
+ lambda: calls.__setitem__("figure", calls["figure"] + 1),
+ )
+ monkeypatch.setattr(
+ predictions_module.plt,
+ "title",
+ lambda name: calls.__setitem__("title", calls["title"] + 1),
+ )
+ monkeypatch.setattr(
+ predictions_module.plt,
+ "xlabel",
+ lambda name: calls.__setitem__("xlabel", calls["xlabel"] + 1),
+ )
+ monkeypatch.setattr(
+ predictions_module.plt,
+ "ylabel",
+ lambda name: calls.__setitem__("ylabel", calls["ylabel"] + 1),
+ )
+ monkeypatch.setattr(
+ predictions_module.plt,
+ "grid",
+ lambda enabled: calls.__setitem__("grid", calls["grid"] + 1),
+ )
+ monkeypatch.setattr(
+ predictions_module.plt,
+ "show",
+ lambda block=False: calls.__setitem__("show", calls["show"] + 1),
+ )
+ monkeypatch.setattr(
+ pd.Series,
+ "hist",
+ lambda self, bins=30: calls.__setitem__("hist", calls["hist"] + 1),
+ )
tab._show_histogram(series, "value")
- assert calls == {"figure": 1, "title": 1, "xlabel": 1, "ylabel": 1, "grid": 1, "show": 1, "hist": 1}
+ assert calls == {
+ "figure": 1,
+ "title": 1,
+ "xlabel": 1,
+ "ylabel": 1,
+ "grid": 1,
+ "show": 1,
+ "hist": 1,
+ }
-def test_predictions_add_loaded_df_replaces_loading_tab(monkeypatch):
+def test_predictions_add_loaded_df_replaces_loading_tab(
+ predictions_tab, qtbot, monkeypatch
+):
"""Loaded prediction data replaces the placeholder tab widget."""
- tab = predictions_module.PredictionsTab()
+ tab = predictions_tab
tab._base_path = "demo.csv"
tab.subtabs.addTab(predictions_module.QLabel("Loading"), "No PFI")
df = pd.DataFrame({"SMILES": ["C"], "target_pred": [1.0]})
monkeypatch.setattr(tab, "_filter_prediction_dataframe", lambda frame: frame)
- monkeypatch.setattr(predictions_module, "evaluate_predictions_for_model", lambda base, frame, key: {"pdf_path": "report.pdf", "model": key, "scenario": "demo"})
- monkeypatch.setattr(predictions_module, "get_robert_report_path", lambda base: "report.pdf")
- monkeypatch.setattr(predictions_module, "extract_robert_fragment_image", lambda path, key: None)
- monkeypatch.setattr(predictions_module, "extract_extrapolation_scores", lambda path: {"No_PFI": None})
- monkeypatch.setattr(predictions_module, "extract_extrapolation_fragment", lambda path, key: None)
- monkeypatch.setattr(predictions_module, "find_external_test_pixmaps", lambda base: {})
+ monkeypatch.setattr(
+ predictions_module,
+ "evaluate_predictions_for_model",
+ lambda base, frame, key: {
+ "pdf_path": "report.pdf",
+ "model": key,
+ "scenario": "demo",
+ },
+ )
+ monkeypatch.setattr(
+ predictions_module, "get_robert_report_path", lambda base: "report.pdf"
+ )
+ monkeypatch.setattr(
+ predictions_module, "extract_robert_fragment_image", lambda path, key: None
+ )
+ monkeypatch.setattr(
+ predictions_module,
+ "extract_extrapolation_scores",
+ lambda path: {"No_PFI": None},
+ )
+ monkeypatch.setattr(
+ predictions_module, "extract_extrapolation_fragment", lambda path, key: None
+ )
+ monkeypatch.setattr(
+ predictions_module, "find_external_test_pixmaps", lambda base: {}
+ )
widget = predictions_module.QWidget()
- monkeypatch.setattr(tab, "_create_table_with_stats", lambda frame, info, pdf_image: widget)
+ qtbot.addWidget(widget)
+ monkeypatch.setattr(
+ tab, "_create_table_with_stats", lambda frame, info, pdf_image: widget
+ )
tab._add_loaded_df("No_PFI", df)
@@ -1176,17 +1414,21 @@ def test_predictions_add_loaded_df_replaces_loading_tab(monkeypatch):
assert tab.subtabs.tabText(0) == "No PFI"
-def test_predictions_refresh_with_new_path_loads_csvs_synchronously(tmp_path, monkeypatch):
+def test_predictions_refresh_with_new_path_loads_csvs_synchronously(
+ tmp_path, predictions_tab, monkeypatch
+):
"""refresh_with_new_path discovers CSVs and materializes tabs when tasks run synchronously."""
csv_test_dir = tmp_path / "PREDICT" / "csv_test"
csv_test_dir.mkdir(parents=True)
no_pfi_path = csv_test_dir / "demo_No_PFI.csv"
pfi_path = csv_test_dir / "demo_PFI.csv"
- pd.DataFrame({"SMILES": ["C"], "target_pred": [1.0]}).to_csv(no_pfi_path, index=False)
+ pd.DataFrame({"SMILES": ["C"], "target_pred": [1.0]}).to_csv(
+ no_pfi_path, index=False
+ )
pd.DataFrame({"SMILES": ["CC"], "target_pred": [2.0]}).to_csv(pfi_path, index=False)
- tab = predictions_module.PredictionsTab()
+ tab = predictions_tab
created = []
class FakePlaceholder:
@@ -1273,17 +1515,32 @@ def run(self):
monkeypatch.setattr(
predictions_module,
"evaluate_predictions_for_model",
- lambda base, frame, key: {"pdf_path": "report.pdf", "model": key, "scenario": "demo"},
+ lambda base, frame, key: {
+ "pdf_path": "report.pdf",
+ "model": key,
+ "scenario": "demo",
+ },
+ )
+ monkeypatch.setattr(
+ predictions_module, "get_robert_report_path", lambda base: "report.pdf"
+ )
+ monkeypatch.setattr(
+ predictions_module, "extract_robert_fragment_image", lambda path, key: None
+ )
+ monkeypatch.setattr(
+ predictions_module, "extract_extrapolation_scores", lambda path: {}
+ )
+ monkeypatch.setattr(
+ predictions_module, "extract_extrapolation_fragment", lambda path, key: None
+ )
+ monkeypatch.setattr(
+ predictions_module, "find_external_test_pixmaps", lambda base: {}
)
- monkeypatch.setattr(predictions_module, "get_robert_report_path", lambda base: "report.pdf")
- monkeypatch.setattr(predictions_module, "extract_robert_fragment_image", lambda path, key: None)
- monkeypatch.setattr(predictions_module, "extract_extrapolation_scores", lambda path: {})
- monkeypatch.setattr(predictions_module, "extract_extrapolation_fragment", lambda path, key: None)
- monkeypatch.setattr(predictions_module, "find_external_test_pixmaps", lambda base: {})
monkeypatch.setattr(
tab,
"_create_table_with_stats",
- lambda frame, info, pdf_image: created.append((info["model"], frame.copy())) or object(),
+ lambda frame, info, pdf_image: created.append((info["model"], frame.copy()))
+ or object(),
)
tab.refresh_with_new_path(str(tmp_path / "input.csv"))
@@ -1291,21 +1548,23 @@ def run(self):
assert tab.placeholder.isHidden()
assert not tab.subtabs.isHidden()
assert tab.subtabs.count() == 2
- assert {tab.subtabs.tabText(i) for i in range(tab.subtabs.count())} == {"No PFI", "PFI"}
+ assert {tab.subtabs.tabText(i) for i in range(tab.subtabs.count())} == {
+ "No PFI",
+ "PFI",
+ }
assert {model for model, _ in created} == {"No_PFI", "PFI"}
-def test_results_clear_pdf_tabs_removes_placeholders(tmp_path, monkeypatch):
+def test_results_clear_pdf_tabs_removes_placeholders(
+ tmp_path, qtbot, results_tab_mocks
+):
"""clear_pdf_tabs removes tracked tabs and resets internal maps."""
- monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
- monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
run_dir = tmp_path / "run"
run_dir.mkdir()
pdf_path = run_dir / "ROBERT_report.pdf"
pdf_path.write_text("pdf", encoding="utf-8")
- tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+ tab = _results_tab(qtbot, run_dir / "input.csv")
assert tab.pdf_tab_widget.count() == 1
tab.clear_pdf_tabs()
@@ -1315,15 +1574,14 @@ def test_results_clear_pdf_tabs_removes_placeholders(tmp_path, monkeypatch):
assert tab.title_to_path == {}
-def test_results_maybe_materialize_tab_builds_viewer(monkeypatch, tmp_path):
+def test_results_maybe_materialize_tab_builds_viewer(
+ qtbot, results_tab_mocks, monkeypatch, tmp_path
+):
"""Selecting a placeholder PDF tab materializes a real viewer."""
- monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
- monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
run_dir = tmp_path / "run"
run_dir.mkdir()
pdf_path = run_dir / "ROBERT_report.pdf"
- tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+ tab = _results_tab(qtbot, run_dir / "input.csv")
pdf_path.write_text("pdf", encoding="utf-8")
tab.clear_pdf_tabs()
tab.pdf_tabs[str(pdf_path)] = None
@@ -1333,24 +1591,27 @@ def test_results_maybe_materialize_tab_builds_viewer(monkeypatch, tmp_path):
tab.pdf_tab_widget.blockSignals(False)
viewer = results_module.QWidget()
- monkeypatch.setattr(tab, "_materialize_pdf_viewer", lambda index, path: tab.pdf_tabs.__setitem__(path, viewer))
+ monkeypatch.setattr(
+ tab,
+ "_materialize_pdf_viewer",
+ lambda index, path: tab.pdf_tabs.__setitem__(path, viewer),
+ )
tab._maybe_materialize_tab(0)
assert tab.pdf_tabs[str(pdf_path)] is viewer
-def test_results_index_of_title_returns_expected_index(tmp_path, monkeypatch):
+def test_results_index_of_title_returns_expected_index(
+ tmp_path, qtbot, results_tab_mocks
+):
"""Tab titles can be resolved back to their index."""
- monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
- monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
run_dir = tmp_path / "run"
run_dir.mkdir()
pdf_path = run_dir / "ROBERT_report.pdf"
pdf_path.write_text("pdf", encoding="utf-8")
- tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+ tab = _results_tab(qtbot, run_dir / "input.csv")
assert tab._index_of_title(pdf_path.name) == 0
assert tab._index_of_title("missing.pdf") == -1
@@ -1365,7 +1626,14 @@ def test_workflow_selector_options(easyrob_window):
"""Workflow selector has all required entries and default."""
window = easyrob_window
- expected_workflows = ["Full Workflow", "CURATE", "GENERATE", "PREDICT", "VERIFY", "REPORT"]
+ expected_workflows = [
+ "Full Workflow",
+ "CURATE",
+ "GENERATE",
+ "PREDICT",
+ "VERIFY",
+ "REPORT",
+ ]
workflow_items = [
window.workflow_selector.itemText(i)
@@ -1402,6 +1670,7 @@ def test_progress_bar_exists(easyrob_window):
# REAL End-to-End User Workflow Test (GUI)
# =====================================================
+
@pytest.mark.parametrize(
"test_scenario",
[
@@ -1440,8 +1709,11 @@ def test_full_user_workflow_end_to_end(
* AQME workflow checkbox enabled.
* Existing ROBERT folders (if any) are removed before starting.
* AQME generates a mapped CSV and ROBERT runs with it.
- * Check predictions tab
+ * Check predictions tab
"""
+ if test_scenario == "aqme_regression" and not aqme_installed():
+ pytest.skip("AQME is not installed (pip install aqme==2.0.0)")
+
window = easyrob_window
config = SCENARIO_CONFIG[test_scenario]
@@ -1523,7 +1795,9 @@ def test_full_user_workflow_end_to_end(
window.move_to_selected()
# Collect ignored items from GUI
- ignore_items = [window.ignore_list.item(i).text() for i in range(window.ignore_list.count())]
+ ignore_items = [
+ window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
+ ]
print(f" • Ignored columns: {ignore_items}")
actual = set(ignore_items)
@@ -1546,8 +1820,7 @@ def test_full_user_workflow_end_to_end(
extra = actual - expected
assert not missing and not extra, (
- f"\nMissing items: {missing}"
- f"\nExtra items: {extra}"
+ f"\nMissing items: {missing}\nExtra items: {extra}"
)
available_items = [
@@ -1661,9 +1934,16 @@ def test_full_user_workflow_end_to_end(
# ------------------------------------------------------------------
if test_scenario == "existing_dirs_stop":
# Bootstrap the expected output state if it is not already present.
- if not all((output_dir / d).is_dir() for d in expected_dirs) or not report_pdf.is_file():
- print("[SETUP] Existing output folders missing; running baseline workflow first...")
- run_full_workflow_and_wait(window, qtbot, output_dir, expected_dirs, report_pdf)
+ if (
+ not all((output_dir / d).is_dir() for d in expected_dirs)
+ or not report_pdf.is_file()
+ ):
+ print(
+ "[SETUP] Existing output folders missing; running baseline workflow first..."
+ )
+ run_full_workflow_and_wait(
+ window, qtbot, output_dir, expected_dirs, report_pdf
+ )
QCoreApplication.processEvents()
assert all((output_dir / d).is_dir() for d in expected_dirs)
assert report_pdf.is_file()
@@ -1676,12 +1956,16 @@ def test_full_user_workflow_end_to_end(
started = wait_for_workflow_start(window, baseline_text)
if not started:
- pytest.fail("Re-run did not start within timeout after existing-folders popup")
+ pytest.fail(
+ "Re-run did not start within timeout after existing-folders popup"
+ )
print("\n[STEP 8] Clicking Stop ROBERT button...")
qtbot.mouseClick(window.stop_button, Qt.LeftButton)
- print("[STEP 9] Waiting for workflow to stop and GUI to return to idle state...")
+ print(
+ "[STEP 9] Waiting for workflow to stop and GUI to return to idle state..."
+ )
stopped = process_events_until(
lambda: (
getattr(window, "worker", None) is None
@@ -1706,7 +1990,9 @@ def test_full_user_workflow_end_to_end(
), "Expected a stop-process QMessageBox.question"
final_console_text = window.console_output.toPlainText()
- dump_console_output("[STEP 10] Final console output (existing_dirs_stop):", final_console_text)
+ dump_console_output(
+ "[STEP 10] Final console output (existing_dirs_stop):", final_console_text
+ )
assert window.file_path == str(csv_path)
@@ -1756,10 +2042,14 @@ def test_full_user_workflow_end_to_end(
print(f"[OK] AQME mapped CSV exists: {mapped_csv_path}")
predict_csv_test_dir = output_dir / "PREDICT" / "csv_test"
- assert predict_csv_test_dir.is_dir(), "PREDICT/csv_test directory was not created"
+ assert predict_csv_test_dir.is_dir(), (
+ "PREDICT/csv_test directory was not created"
+ )
prediction_csvs = predictions_module.find_prediction_csvs(str(csv_path))
- assert prediction_csvs, "Predictions CSVs were not generated for the external test set"
+ assert prediction_csvs, (
+ "Predictions CSVs were not generated for the external test set"
+ )
print(f"[OK] Predictions CSVs detected: {sorted(prediction_csvs)}")
print("\n" + "=" * 80)
@@ -1777,6 +2067,9 @@ def test_run_aqme_only_end_to_end(easyrob_window, test_output_dir, qtbot, monkey
- wait for the subprocess to finish
- verify AQME outputs were generated
"""
+ if not aqme_installed():
+ pytest.skip("AQME is not installed (pip install aqme==2.0.0)")
+
window = easyrob_window
finished = {"exit_code": None}
@@ -1840,6 +2133,7 @@ def _on_process_finished_stub(exit_code):
# ChemDraw → popup → table → CSV → main window test
# =====================================================
+
def test_open_chemdraw_popup_end_to_end_cdxml(
easyrob_window, qtbot, monkeypatch, test_output_dir
):
@@ -1881,16 +2175,13 @@ def exec(self):
return QDialog.Accepted
# Patch the symbol that open_chemdraw_popup uses
- monkeypatch.setattr(
- aqme_module, "ChemDrawFileDialog", FakeChemDrawFileDialog
- )
+ monkeypatch.setattr(aqme_module, "ChemDrawFileDialog", FakeChemDrawFileDialog)
# --------------------------------------------------------------
# 4. Stub QFileDialog.getSaveFileName so CSV is written to tmp_path
# --------------------------------------------------------------
csv_path = test_output_dir / "chemdraw_table_output.csv"
-
def _fake_get_save_file_name(*args, **kwargs):
return (str(csv_path), "CSV Files (*.csv)")
@@ -1914,7 +2205,9 @@ def _fake_dialog_exec(self: QDialog):
table = self.findChild(QTableWidget)
assert table is not None, "ChemDraw table dialog should contain a QTableWidget."
- headers = [table.horizontalHeaderItem(i).text() for i in range(table.columnCount())]
+ headers = [
+ table.horizontalHeaderItem(i).text() for i in range(table.columnCount())
+ ]
assert "SMILES" in headers
assert "code_name" in headers
assert "target" in headers
@@ -1949,7 +2242,9 @@ def _fake_dialog_exec(self: QDialog):
if "Save as CSV" in btn.text():
save_button = btn
break
- assert save_button is not None, "Could not find 'Save as CSV' button in ChemDraw dialog."
+ assert save_button is not None, (
+ "Could not find 'Save as CSV' button in ChemDraw dialog."
+ )
# Click it → this will call save_to_csv() and then dialog.accept()
save_button.click()
@@ -1984,7 +2279,6 @@ def _fake_dialog_exec(self: QDialog):
# Columns should be loaded into dropdowns
y_items = [window.y_dropdown.itemText(i) for i in range(window.y_dropdown.count())]
- names_items = [window.names_dropdown.itemText(i) for i in range(window.names_dropdown.count())]
assert "SMILES" in y_items
assert "code_name" in y_items
@@ -1995,8 +2289,7 @@ def _fake_dialog_exec(self: QDialog):
for i in range(window.available_list.count())
]
ignored_items = [
- window.ignore_list.item(i).text()
- for i in range(window.ignore_list.count())
+ window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
]
assert "SMILES" in ignored_items
assert "code_name" in available_items
diff --git a/tests/test_plot_metrics.py b/tests/test_plot_metrics.py
new file mode 100644
index 0000000..36d176a
--- /dev/null
+++ b/tests/test_plot_metrics.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+
+"""Tests for VERIFY metrics plotting."""
+
+import pytest
+
+from robert.utils import plot_metrics
+
+
+@pytest.fixture
+def verify_plot_env(tmp_path, monkeypatch):
+ monkeypatch.chdir(tmp_path)
+ verify_dir = tmp_path / "VERIFY"
+ verify_dir.mkdir()
+ (verify_dir / "RF_No_PFI").touch()
+ return tmp_path
+
+
+def test_plot_metrics_equal_values(verify_plot_env):
+ """Degenerate axis limits (all metrics equal) must not break plotting."""
+ model_data = {"model": "RF_db.csv"}
+ verify_metrics = {
+ "metrics": [0.5, 0.5, 0.5, 0.5],
+ "test_names": ["Model", "y-mean", "y-shuffle", "one-hot"],
+ "colors": ["#808080", "#1f77b4", "#1f77b4", "#1f77b4"],
+ "higher_thres": 0.3,
+ "unclear_higher_thres": 0.2,
+ "lower_thres": 0.3,
+ "unclear_lower_thres": 0.2,
+ }
+ verify_results = {"error_type": "r2"}
+
+ msg = plot_metrics(model_data, "No_PFI", verify_metrics, verify_results)
+ png = verify_plot_env / "VERIFY" / "VERIFY_tests_RF_No_PFI.png"
+ assert png.is_file()
+ assert "VERIFY plot saved" in msg
diff --git a/tests/test_vr_bo.py b/tests/test_vr_bo.py
new file mode 100644
index 0000000..89bbfda
--- /dev/null
+++ b/tests/test_vr_bo.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+
+"""Tests for Voting Regressor/Classifier Bayesian optimization."""
+
+from types import SimpleNamespace
+
+import numpy as np
+
+from robert.argument_parser import options_add
+from robert.utils import load_minimal_model, load_model, model_adjust_params
+
+
+def _vr_adapter(problem_type="reg"):
+ args = options_add()
+ args.seed = 42
+ args.type = problem_type
+ return SimpleNamespace(args=args)
+
+
+def test_vr_member_hyperparameters_affect_predictions():
+ adapter = _vr_adapter("reg")
+ rng = np.random.RandomState(0)
+ X = rng.rand(24, 6)
+ y = X.sum(axis=1) + rng.randn(24) * 0.05
+
+ params_low = model_adjust_params(adapter, "VR", dict(load_minimal_model("VR")))
+ params_high = model_adjust_params(
+ adapter,
+ "VR",
+ {**load_minimal_model("VR"), "rf_n_estimators": 90, "gb_n_estimators": 90},
+ )
+
+ model_low = load_model(adapter, "VR", **params_low)
+ model_high = load_model(adapter, "VR", **params_high)
+ model_low.fit(X, y)
+ model_high.fit(X, y)
+
+ assert not np.allclose(model_low.predict(X), model_high.predict(X))
+
+
+def test_vr_ensemble_weights_affect_predictions():
+ adapter = _vr_adapter("reg")
+ rng = np.random.RandomState(1)
+ X = rng.rand(24, 6)
+ y = X.sum(axis=1) + rng.randn(24) * 0.05
+
+ base = load_minimal_model("VR")
+ params_a = model_adjust_params(
+ adapter, "VR", {**base, "w_rf": 5.0, "w_gb": 0.2, "w_nn": 0.2}
+ )
+ params_b = model_adjust_params(
+ adapter, "VR", {**base, "w_rf": 0.2, "w_gb": 5.0, "w_nn": 0.2}
+ )
+
+ model_a = load_model(adapter, "VR", **params_a)
+ model_b = load_model(adapter, "VR", **params_b)
+ model_a.fit(X, y)
+ model_b.fit(X, y)
+
+ assert not np.allclose(model_a.predict(X), model_b.predict(X))
+
+
+def test_vr_bo_bounds_include_member_models():
+ from robert.utils import BO_hyperparams
+
+ bounds = BO_hyperparams("VR")
+ assert "w_rf" in bounds
+ assert "rf_n_estimators" in bounds
+ assert "gb_learning_rate" in bounds
+ assert "nn_hidden_layer_1" in bounds
+
+
+def test_vr_classification_loads():
+ adapter = _vr_adapter("clas")
+ y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
+ X = np.random.RandomState(2).rand(len(y), 4)
+ params = model_adjust_params(adapter, "VR", dict(load_minimal_model("VR")))
+ model = load_model(adapter, "VR", **params)
+ model.fit(X, y)
+ preds = model.predict(X)
+ assert preds.shape == y.shape
From df98f82d80d442cffdec4e60bd18c3fef3163845 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rub=C3=A9n=20Laplaza?=
<30357710+rlaplaza@users.noreply.github.com>
Date: Mon, 18 May 2026 21:49:32 +0200
Subject: [PATCH 8/8] Fixes for circleci tests in windows (#78)
* ci: add Ruff lint and format checks
Pin Ruff defaults in pyproject.toml, reformat Python sources,
fix default lint violations, and gate CircleCI on ruff check
and ruff format --check.
* feat: API scores, VR tuning, and multi-platform CI
Add RobertModel.robert_scores(), VR hyperparameter BO support, and
v2.2.0 test extras. Extend tests (plot metrics, VR/BO, API) and docs.
Refactor CircleCI to run the shared conda suite on Linux, Windows, and
macOS.
* fix: remove incomplete project table from pyproject.toml
The partial [project] section lacked a required version field and broke
pip install during CI; package metadata remains in setup.py.
* fix(ci): use preinstalled Miniconda on Windows executor
CircleCI windows-server-2022-gui images already include Miniconda. The
install-miniconda step only checked $HOME/miniconda3, missed the system
install, and hung trying to reinstall via the silent .exe installer.
---
.circleci/config.yml | 14 ++++++++------
pyproject.toml | 3 ---
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 74aa66b..522b1f0 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -8,22 +8,25 @@ orbs:
commands:
install-miniconda:
- description: Install Miniconda on machine executors (Windows/macOS)
+ description: Install Miniconda on machine executors (macOS only; Windows images ship with Miniconda)
steps:
- run:
name: Install Miniconda
command: |
set -euo pipefail
+ if command -v conda >/dev/null 2>&1; then
+ echo "Conda already available at $(command -v conda)"
+ conda --version
+ exit 0
+ fi
MINICONDA_DIR="${HOME}/miniconda3"
if [ -x "${MINICONDA_DIR}/Scripts/conda.exe" ] || [ -x "${MINICONDA_DIR}/bin/conda" ]; then
echo "Miniconda already present at ${MINICONDA_DIR}"
else
case "$(uname -s)" in
MINGW*|MSYS*|CYGWIN*)
- INSTALLER_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe"
- curl -fsSL -o miniconda-installer.exe "${INSTALLER_URL}"
- ./miniconda-installer.exe /InstallationType=JustMe /RegisterPython=0 /S /D="${MINICONDA_DIR}"
- rm -f miniconda-installer.exe
+ echo "install-miniconda: unexpected Windows path without preinstalled conda"
+ exit 1
;;
Darwin)
ARCH="$(uname -m)"
@@ -276,7 +279,6 @@ jobs:
shell: bash.exe
steps:
- checkout
- - install-miniconda
- run-conda-test-suite:
upload_coverage: false
diff --git a/pyproject.toml b/pyproject.toml
index 49dfc1e..a330a0a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,3 @@
-[project]
-name = "robert"
-
[tool.ruff]
target-version = "py311"
line-length = 88