From 48c21421bbc76c62f50c546d9a898861aa410ed5 Mon Sep 17 00:00:00 2001 From: rlaplaza Date: Fri, 15 May 2026 20:45:54 +0200 Subject: [PATCH 1/8] feat(uq): top-k meta-model uncertainty decomposition Enable opt-in ensemble over GENERATE Raw_data candidates with decomposed within/between uncertainty and RobertModel return modes. Co-authored-by: Cursor --- docs/API/robert.api.rst | 55 ++++++++- robert/api.py | 54 ++++++++- robert/argument_parser.py | 4 + robert/predict.py | 8 +- robert/predict_utils.py | 39 +++++++ robert/utils.py | 232 ++++++++++++++++++++++++++++++++++++++ tests/test_uq_meta.py | 92 +++++++++++++++ 7 files changed, 476 insertions(+), 8 deletions(-) create mode 100644 tests/test_uq_meta.py diff --git a/docs/API/robert.api.rst b/docs/API/robert.api.rst index d9476d9..6adcedc 100644 --- a/docs/API/robert.api.rst +++ b/docs/API/robert.api.rst @@ -7,10 +7,12 @@ VERIFY, PREDICT) and exposes ``fit`` / ``predict`` / ``score`` on columns aligned with the pipeline: - ``{y}_pred``: point prediction from the **selected estimator refit on all training - data** (deployment-style mean). + data** (deployment-style mean), or a **weighted average** / **weighted vote** when + meta-model uncertainty is enabled (see below). - ``{y}_pred_sd``: per-row **standard deviation across repeated cross-validation predictions** (disagreement between refits on overlapping training folds; related - to epistemic instability, not a calibrated predictive distribution). + to epistemic instability, not a calibrated predictive distribution). With meta UQ + enabled, this column is set to ``{y}_pred_uq_total``. - ``{y}_pred_conformal_hw`` (**regression only**): a single **symmetric interval half-width** from split-style conformal calibration (absolute residuals on a held-out calibration slice of the training set when large enough, otherwise @@ -20,16 +22,61 @@ columns aligned with the pipeline: ``conformal_calib_frac``, and ``conformal_coverage`` in :class:`~robert.api.RobertModel` kwargs. For **classification**, this column is present but filled with NaN; ``{y}_pred_sd`` reflects **vote spread** across CV refits, not class probabilities. +- ``{y}_pred_uq_model`` (**meta UQ, opt-in**): within-model component (mean CV spread + across the top-k candidates; regression uses variance decomposition). +- ``{y}_pred_uq_meta`` (**meta UQ, opt-in**): between-model component (spread of + top-k point predictions). +- ``{y}_pred_uq_total`` (**meta UQ, opt-in**): combined uncertainty (regression: + :math:`\sqrt{\mathrm{E}[\sigma^2] + \mathrm{Var}(\hat y)}`; classification uses a + heuristic combining vote spread and between-model disagreement). ``predict`` returns ``{y}_pred`` values aligned to input rows. Uncertainty: - ``return_std=True`` is equivalent to ``return_uncertainty="cv_sd"`` and returns ``(y, sd_cv)``. -- ``return_uncertainty="conformal"`` (regression only) returns ``(y, half_width)``. -- ``return_uncertainty="both"`` returns ``(y, sd_cv, half_width)``. +- ``return_uncertainty="conformal"`` (**regression only**) returns ``(y, half_width)``. +- ``return_uncertainty="both"`` (**regression only**) returns ``(y, sd_cv, half_width)``. +- ``return_uncertainty="meta"`` returns ``(y, uq_meta)`` (requires ``uq_enable_meta=True``). +- ``return_uncertainty="total"`` returns ``(y, uq_total)`` (requires ``uq_enable_meta=True``). +- ``return_uncertainty="decomposed"`` returns ``(y, uq_model, uq_meta, uq_total)`` + (requires ``uq_enable_meta=True``). - If both ``return_std`` and ``return_uncertainty`` are set, ``return_uncertainty`` wins and a warning is issued. +Configuration (meta-model kwargs) +--------------------------------- + +- ``uq_enable_meta`` (``False``), ``uq_top_k_models`` (``3``), + ``uq_model_weighting`` (``"score_weighted"`` or ``"uniform"``). + +Meta-model uncertainty +---------------------- + +Enable with ``uq_enable_meta=True`` on :class:`~robert.api.RobertModel``. PREDICT +re-runs up to ``uq_top_k_models`` estimators ranked by GENERATE +``combined_{error_type}`` scores in ``GENERATE/Raw_data``, then combines predictions +with ``uq_model_weighting``. **Regression** uses a weighted mean and law-of-total-variance +decomposition. **Classification** uses a weighted vote and heuristic uncertainty +components (not class probabilities). If Raw_data has no candidates, CV spread is +used with zero meta component and a warning is issued. + +Example (meta UQ): + +.. code-block:: python + + model_meta = RobertModel( + problem_type="reg", + workdir="./robert_run_meta", + model=["RF", "GB"], + uq_enable_meta=True, + uq_top_k_models=2, + ) + model_meta.fit(X.iloc[:25], y.iloc[:25]) + y_meta, uq_between = model_meta.predict(X.iloc[25:], return_uncertainty="meta") + y_dec, uq_m, uq_b, uq_t = model_meta.predict( + X.iloc[25:], return_uncertainty="decomposed" + ) + Pipeline semantics ------------------ diff --git a/robert/api.py b/robert/api.py index 5f054ee..bf72e3c 100644 --- a/robert/api.py +++ b/robert/api.py @@ -176,6 +176,9 @@ class RobertModel(BaseEstimator): :param kwargs: Additional ROBERT options (keys in ``robert.argument_parser.var_dict``), e.g. ``model``, ``n_iter``. Regression uncertainty tuning includes ``conformal_enable``, ``conformal_calib_frac``, and ``conformal_coverage``. + Top-k meta-model uncertainty (opt-in) uses ``uq_enable_meta``, + ``uq_top_k_models``, and ``uq_model_weighting`` (``"score_weighted"`` or + ``"uniform"``). """ def __init__( @@ -491,11 +494,20 @@ def predict( self, X: Union[pd.DataFrame, np.ndarray], return_std: bool = False, - return_uncertainty: Literal[False, "cv_sd", "conformal", "both"] = False, + return_uncertainty: Literal[ + False, + "cv_sd", + "conformal", + "both", + "meta", + "total", + "decomposed", + ] = False, ) -> Union[ np.ndarray, Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray], + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ]: if not self.is_fitted_: raise RuntimeError("Call fit before predict.") @@ -503,7 +515,15 @@ def predict( assert workdir is not None if return_uncertainty is not False: - umode: Literal[False, "cv_sd", "conformal", "both"] = return_uncertainty + umode: Literal[ + False, + "cv_sd", + "conformal", + "both", + "meta", + "total", + "decomposed", + ] = return_uncertainty if return_std: warnings.warn( "return_uncertainty is set; return_std is ignored.", @@ -572,6 +592,9 @@ def predict( pred_col = f"{y_target}_pred" sd_col = f"{y_target}_pred_sd" hw_col = f"{y_target}_pred_conformal_hw" + uq_model_col = f"{y_target}_pred_uq_model" + uq_meta_col = f"{y_target}_pred_uq_meta" + uq_total_col = f"{y_target}_pred_uq_total" if pred_col not in result_df.columns: raise RuntimeError(f"Column {pred_col!r} missing in {csv_path}") @@ -621,12 +644,37 @@ def predict( f"Column {hw_col!r} has no finite values; disable conformal " "or use a larger training set." ) + if umode in ("meta", "total", "decomposed"): + if not bool(self._rob_kwargs.get("uq_enable_meta", False)): + raise ValueError( + "return_uncertainty='meta', 'total', or 'decomposed' requires " + "uq_enable_meta=True when constructing RobertModel." + ) + for col in (uq_model_col, uq_meta_col, uq_total_col): + if col not in result_df.columns: + raise RuntimeError( + f"Column {col!r} missing in {csv_path}; refit with " + "uq_enable_meta=True or run predict after enabling meta UQ." + ) + y_uq_model = ordered[uq_model_col].to_numpy(dtype=float) + y_uq_meta = ordered[uq_meta_col].to_numpy(dtype=float) + y_uq_total = ordered[uq_total_col].to_numpy(dtype=float) + if not (np.isfinite(y_uq_total).all() and (y_uq_total >= 0).all()): + raise RuntimeError( + f"Column {uq_total_col!r} has invalid uncertainty values." + ) if umode == "cv_sd": return y_pred, y_sd if umode == "conformal": return y_pred, y_hw - return y_pred, y_sd, y_hw + if umode == "both": + return y_pred, y_sd, y_hw + if umode == "meta": + return y_pred, y_uq_meta + if umode == "total": + return y_pred, y_uq_total + return y_pred, y_uq_model, y_uq_meta, y_uq_total def score( self, diff --git a/robert/argument_parser.py b/robert/argument_parser.py index a1dea3c..5318131 100644 --- a/robert/argument_parser.py +++ b/robert/argument_parser.py @@ -65,6 +65,10 @@ "conformal_enable": True, "conformal_calib_frac": 0.15, "conformal_coverage": 0.9, + # Top-k meta-model uncertainty (Python API / optional PREDICT path). + "uq_enable_meta": False, + "uq_top_k_models": 3, + "uq_model_weighting": "score_weighted", } diff --git a/robert/predict.py b/robert/predict.py index 8874de4..e26b774 100644 --- a/robert/predict.py +++ b/robert/predict.py @@ -42,9 +42,11 @@ print_predict, pearson_map_predict ) -from robert.utils import (load_variables, +from robert.utils import ( + load_variables, load_db_n_params, load_n_predict, + apply_meta_uq_ensemble, finish_print, print_pfi, PFI_plot, @@ -89,6 +91,10 @@ def __init__(self, **kwargs): # get results from training, test and external test (if any) Xy_data = load_n_predict(self, model_data, Xy_data, BO_opt=False) + if getattr(self.args, "uq_enable_meta", False): + Xy_data = apply_meta_uq_ensemble( + self, Xy_data, model_data, params_dir + ) # save predictions for all sets path_n_suffix, name_points, Xy_data = save_predictions(self,Xy_data,model_data,suffix_title) diff --git a/robert/predict_utils.py b/robert/predict_utils.py index 10e70ad..dacdfb6 100644 --- a/robert/predict_utils.py +++ b/robert/predict_utils.py @@ -44,6 +44,28 @@ def test_csv(self,Xy_test_df,descs_model,params_df): return X_test_df, y_test_df +def _uq_columns_for_split(Xy_data, y_col, split): + """Optional meta-UQ and auto-UQ columns for a train/test/external split.""" + prefix = f"y_pred_{split}" + out = {} + for suffix in ("uq_model", "uq_meta", "uq_total"): + key = f"{prefix}_{suffix}" + if key in Xy_data: + out[f"{y_col}_pred_{suffix}"] = Xy_data[key] + auto_key = f"{prefix}_uq_auto" + if auto_key in Xy_data: + out[f"{y_col}_pred_uq_auto"] = Xy_data[auto_key] + return out + + +def _uq_auto_source_column(Xy_data): + """Constant source label for the selected auto uncertainty candidate.""" + selected = Xy_data.get("uq_auto_selected") + if not selected: + return None + return str(selected) + + def plot_predictions(self, params_dict, Xy_data, path_n_suffix): ''' Plot graphs of predicted vs actual values for train, validation and test sets @@ -109,10 +131,15 @@ def save_predictions(self,Xy_data,model_data,suffix_title): Xy_train[y_col] = y_train_values Xy_train[f"{y_col}_pred"] = y_pred_train_values Xy_train[f"{y_col}_pred_sd"] = Xy_data['y_pred_train_sd'] + for col_name, col_vals in _uq_columns_for_split(Xy_data, y_col, "train").items(): + Xy_train[col_name] = col_vals hw_scalar = float(Xy_data.get("conformal_half_width", float("nan"))) if model_data["type"].lower() != "reg": hw_scalar = float("nan") Xy_train[f"{y_col}_pred_conformal_hw"] = [hw_scalar] * len(Xy_train) + auto_src = _uq_auto_source_column(Xy_data) + if auto_src is not None: + Xy_train[f"{y_col}_pred_uq_auto_source"] = [auto_src] * len(Xy_train) # For test set y_test_values = Xy_data['y_test'].tolist() @@ -124,7 +151,11 @@ def save_predictions(self,Xy_data,model_data,suffix_title): Xy_test[y_col] = y_test_values Xy_test[f"{y_col}_pred"] = y_pred_test_values Xy_test[f"{y_col}_pred_sd"] = Xy_data['y_pred_test_sd'] + for col_name, col_vals in _uq_columns_for_split(Xy_data, y_col, "test").items(): + Xy_test[col_name] = col_vals Xy_test[f"{y_col}_pred_conformal_hw"] = [hw_scalar] * len(Xy_test) + if auto_src is not None: + Xy_test[f"{y_col}_pred_uq_auto_source"] = [auto_src] * len(Xy_test) df_results = pd.concat([Xy_train, Xy_test], axis=0) @@ -165,7 +196,15 @@ def save_predictions(self,Xy_data,model_data,suffix_title): Xy_external[f"{model_data['y']}_pred"] = y_pred_external_values Xy_external[f"{model_data['y']}_pred_sd"] = Xy_data['y_pred_external_sd'] + for col_name, col_vals in _uq_columns_for_split( + Xy_data, model_data["y"], "external" + ).items(): + Xy_external[col_name] = col_vals Xy_external[f"{model_data['y']}_pred_conformal_hw"] = [hw_scalar] * len(Xy_external) + if auto_src is not None: + Xy_external[f"{model_data['y']}_pred_uq_auto_source"] = [ + auto_src + ] * len(Xy_external) path_external = Path(os.getcwd()).joinpath('PREDICT/csv_test/') Path(path_external).mkdir(exist_ok=True, parents=True) diff --git a/robert/utils.py b/robert/utils.py index 56016ae..20ffea3 100644 --- a/robert/utils.py +++ b/robert/utils.py @@ -1796,6 +1796,238 @@ def correct_hidden_layers(params): return params +def _raw_data_dir_from_best_params(params_dir): + """Map ``GENERATE/Best_model/{PFI|No_PFI}`` to ``GENERATE/Raw_data/{PFI|No_PFI}``.""" + p = Path(str(params_dir)) + parts = list(p.parts) + if "Best_model" in parts: + idx = parts.index("Best_model") + parts[idx] = "Raw_data" + return Path(os.getcwd()).joinpath(*parts) + if "Raw_data" in parts: + return Path(os.getcwd()).joinpath(*parts) + return Path(os.getcwd()).joinpath("GENERATE", "Raw_data", p.name) + + +def discover_top_k_model_candidates(raw_data_dir, top_k, weighting="score_weighted"): + """ + List up to ``top_k`` model parameter CSV paths ranked by GENERATE combined score. + + Returns a list of dicts with keys ``path``, ``model``, ``score``, ``error_type``, + ``weight`` (normalized, higher is better for classification metrics). + """ + raw_path = Path(raw_data_dir) + if not raw_path.is_dir(): + return [] + + entries = [] + for csv_file in sorted(glob.glob(str(raw_path.joinpath("*.csv")))): + if csv_file.endswith("_db.csv"): + continue + try: + results = pd.read_csv(csv_file, encoding="utf-8") + except (OSError, ValueError, KeyError): + continue + if results.empty or "error_type" not in results.columns: + continue + err_type = str(results["error_type"].iloc[0]) + comb_col = f"combined_{err_type}" + if comb_col not in results.columns: + continue + score = float(results[comb_col].iloc[0]) + model_name = Path(csv_file).stem.split("_")[0] + entries.append( + { + "path": csv_file, + "model": model_name, + "score": score, + "error_type": err_type, + } + ) + + if not entries: + return [] + + lower_is_better = entries[0]["error_type"].lower() in ("mae", "rmse") + entries.sort(key=lambda e: e["score"], reverse=not lower_is_better) + entries = entries[: max(1, int(top_k))] + + if weighting == "uniform" or len(entries) == 1: + w = 1.0 / len(entries) + for e in entries: + e["weight"] = w + return entries + + scores = np.array([e["score"] for e in entries], dtype=float) + if lower_is_better: + # Convert errors to weights (smaller error -> larger weight). + shifted = scores.max() - scores + 1e-12 + raw_w = shifted + else: + raw_w = np.maximum(scores, 0.0) + 1e-12 + raw_w = raw_w / raw_w.sum() + for e, w in zip(entries, raw_w): + e["weight"] = float(w) + return entries + + +def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type): + """ + Combine per-model predictions and within-model CV SD into decomposition scalars. + + Regression uses the law of total variance (within + between). Classification + uses a heuristic on vote-spread (within) and between-model label disagreement. + + Returns + ------- + y_point, uq_model, uq_meta, uq_total + Point prediction and nonnegative uncertainty components per row. + """ + preds = np.asarray(preds_stack, dtype=float) + sds = np.asarray(sd_stack, dtype=float) + w = np.asarray(weights, dtype=float).ravel() + if preds.ndim != 2: + preds = preds.reshape(1, -1) + if sds.ndim != 2: + sds = sds.reshape(1, -1) + if w.size != preds.shape[0]: + w = np.ones(preds.shape[0], dtype=float) / preds.shape[0] + w = w / w.sum() + + if problem_type.lower() == "reg": + y_point = np.average(preds, axis=0, weights=w) + within_var = np.average(sds ** 2, axis=0, weights=w) + mean_pred = np.average(preds, axis=0, weights=w) + between_var = np.average((preds - mean_pred) ** 2, axis=0, weights=w) + uq_model = np.sqrt(np.maximum(within_var, 0.0)) + uq_meta = np.sqrt(np.maximum(between_var, 0.0)) + uq_total = np.sqrt(np.maximum(within_var + between_var, 0.0)) + return y_point, uq_model, uq_meta, uq_total + + # Classification: preds are class labels; use float spread heuristics. + preds_f = preds.astype(float) + y_point = np.array( + [int(round(np.average(preds_f[:, i], weights=w))) for i in range(preds_f.shape[1])], + dtype=int, + ) + uq_model = np.average(sds, axis=0, weights=w) + mean_pred = np.average(preds_f, axis=0, weights=w) + uq_meta = np.sqrt( + np.maximum(np.average((preds_f - mean_pred) ** 2, axis=0, weights=w), 0.0) + ) + uq_total = np.sqrt(np.maximum(uq_model ** 2 + uq_meta ** 2, 0.0)) + return y_point, uq_model, uq_meta, uq_total + + +def _snapshot_prediction_fields(Xy_data, prefix): + """Copy prediction-related lists for one split prefix (train/test/external).""" + keys = [f"y_pred_{prefix}", f"y_pred_{prefix}_sd"] + return {k: list(Xy_data[k]) for k in keys if k in Xy_data} + + +def _restore_prediction_fields(Xy_data, prefix, snap): + for k, v in snap.items(): + Xy_data[k] = v + + +def apply_meta_uq_ensemble(self, Xy_data, model_data, params_dir): + """ + Run top-k models from GENERATE Raw_data and attach meta UQ fields on ``Xy_data``. + + No-op unless ``self.args.uq_enable_meta`` is true. Re-runs ``load_n_predict`` for + each candidate while restoring base point predictions between models. Writes + ``y_pred_{split}_uq_{model,meta,total}`` and sets ``y_pred_{split}_sd`` to total. + Emits :class:`UserWarning` when Raw_data has no or only one candidate. + """ + if not bool(getattr(self.args, "uq_enable_meta", False)): + return Xy_data + + top_k = int(getattr(self.args, "uq_top_k_models", 3)) + weighting = str(getattr(self.args, "uq_model_weighting", "score_weighted")) + raw_dir = _raw_data_dir_from_best_params(params_dir) + candidates = discover_top_k_model_candidates(raw_dir, top_k, weighting) + + if not candidates: + warnings.warn( + f"Meta-model uncertainty requested but no candidates found in {raw_dir!s}; " + "using single-model CV spread only.", + UserWarning, + stacklevel=2, + ) + _attach_single_model_uq_fallback(Xy_data, model_data) + return Xy_data + + if len(candidates) < 2: + warnings.warn( + "Meta-model uncertainty requested but fewer than two candidate models " + f"in {raw_dir!s}; meta component set to zero.", + UserWarning, + stacklevel=2, + ) + + problem_type = model_data["type"] + splits = ["train", "test"] + if "X_external" in Xy_data: + splits.append("external") + + base_snap_full = { + split: _snapshot_prediction_fields(Xy_data, split) for split in splits + } + preds_by_split = {split: [] for split in splits} + sd_by_split = {split: [] for split in splits} + + for cand in candidates: + cand_data = load_params(self, cand["path"]) + cand_data["repeat_kfolds"] = model_data.get( + "repeat_kfolds", cand_data.get("repeat_kfolds") + ) + cand_data["kfold"] = model_data.get("kfold", cand_data.get("kfold")) + xy_run = load_n_predict(self, cand_data, Xy_data, BO_opt=False) + for split in splits: + pred_key = f"y_pred_{split}" + sd_key = f"y_pred_{split}_sd" + if pred_key in xy_run: + preds_by_split[split].append(np.asarray(xy_run[pred_key], dtype=float)) + sd_by_split[split].append(np.asarray(xy_run[sd_key], dtype=float)) + for split in splits: + _restore_prediction_fields(Xy_data, split, base_snap_full[split]) + + weights = [c["weight"] for c in candidates] + for split in splits: + if not preds_by_split[split]: + continue + pred_key = f"y_pred_{split}" + sd_key = f"y_pred_{split}_sd" + stack_p = np.vstack(preds_by_split[split]) + stack_s = np.vstack(sd_by_split[split]) + y_pt, uq_m, uq_meta, uq_tot = aggregate_meta_uq_decomposition( + stack_p, stack_s, weights, problem_type + ) + if problem_type.lower() == "clas": + y_pt = np.asarray([int(v) for v in y_pt], dtype=int) + else: + y_pt = np.asarray(y_pt, dtype=float) + Xy_data[pred_key] = y_pt.tolist() + Xy_data[f"y_pred_{split}_uq_model"] = np.asarray(uq_m, dtype=float).tolist() + Xy_data[f"y_pred_{split}_uq_meta"] = np.asarray(uq_meta, dtype=float).tolist() + Xy_data[f"y_pred_{split}_uq_total"] = np.asarray(uq_tot, dtype=float).tolist() + Xy_data[sd_key] = Xy_data[f"y_pred_{split}_uq_total"] + + return Xy_data + + +def _attach_single_model_uq_fallback(Xy_data, model_data): + """When meta ensemble cannot run, map CV SD to uq_model and zero meta.""" + for split in ("train", "test", "external"): + sd_key = f"y_pred_{split}_sd" + if sd_key not in Xy_data: + continue + sd = np.asarray(Xy_data[sd_key], dtype=float) + Xy_data[f"y_pred_{split}_uq_model"] = sd.tolist() + Xy_data[f"y_pred_{split}_uq_meta"] = [0.0] * len(sd) + Xy_data[f"y_pred_{split}_uq_total"] = sd.tolist() + + def _conformal_abs_residual_quantile(abs_residuals, coverage): """ Finite-sample quantile for split-style conformal symmetric intervals diff --git a/tests/test_uq_meta.py b/tests/test_uq_meta.py new file mode 100644 index 0000000..8efcc0e --- /dev/null +++ b/tests/test_uq_meta.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +"""Tests for top-k meta-model uncertainty helpers and API.""" + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from robert import RobertModel +from robert.utils import ( + aggregate_meta_uq_decomposition, + discover_top_k_model_candidates, +) + +_REG_CSV = Path(__file__).resolve().parent / "Robert_example.csv" + +_FAST = { + "model": ["RF", "GB"], + "n_iter": 2, + "init_points": 2, + "repeat_kfolds": 2, + "kfold": 3, + "pfi_epochs": 1, + "seed": 42, + "uq_enable_meta": True, + "uq_top_k_models": 2, + "uq_model_weighting": "uniform", +} + + +def test_aggregate_meta_uq_regression_decomposition(): + """Total variance equals within + between for regression.""" + preds = np.array([[1.0, 2.0], [3.0, 4.0], [2.0, 3.0]]) + sds = np.array([[0.1, 0.2], [0.3, 0.4], [0.2, 0.3]]) + weights = np.array([1 / 3, 1 / 3, 1 / 3]) + y, uq_m, uq_meta, uq_tot = aggregate_meta_uq_decomposition( + preds, sds, weights, "reg" + ) + assert y.shape == (2,) + assert np.all(uq_tot >= uq_m - 1e-9) + assert np.all(uq_tot >= uq_meta - 1e-9) + assert np.all(uq_m >= 0) and np.all(uq_meta >= 0) + + +def test_discover_top_k_empty_dir(tmp_path): + assert discover_top_k_model_candidates(tmp_path, 3) == [] + + +def test_fit_predict_meta_uncertainty_modes(tmp_path): + df = pd.read_csv(_REG_CSV, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + n_fit = 25 + model = RobertModel( + problem_type="reg", + filter_mode="no_pfi", + workdir=tmp_path, + names="Name", + **_FAST, + ) + model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) + X_hold = X.iloc[n_fit:].drop_duplicates(subset=["Name"], keep="first") + y_hat, uq_meta = model.predict(X_hold, return_uncertainty="meta") + assert y_hat.shape == uq_meta.shape + assert np.isfinite(uq_meta).all() and (uq_meta >= 0).all() + _, uq_total = model.predict(X_hold, return_uncertainty="total") + y_d, uq_m, uq_meta2, uq_tot = model.predict( + X_hold, return_uncertainty="decomposed" + ) + assert y_d.shape == uq_m.shape == uq_meta2.shape == uq_tot.shape + assert np.all(uq_tot >= uq_m - 1e-9) + assert np.allclose(uq_tot, uq_total, rtol=1e-5, atol=1e-5) + + +def test_meta_uncertainty_requires_enable_flag(tmp_path): + df = pd.read_csv(_REG_CSV, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + fast = {k: v for k, v in _FAST.items() if k != "uq_enable_meta"} + model = RobertModel( + problem_type="reg", + filter_mode="no_pfi", + workdir=tmp_path, + names="Name", + **fast, + ) + model.fit(X.iloc[:20], y.iloc[:20]) + X_hold = X.iloc[20:].drop_duplicates(subset=["Name"], keep="first") + with pytest.raises(ValueError, match="uq_enable_meta"): + model.predict(X_hold, return_uncertainty="total") From 64b0a9d858b3e764e08f56eb5402b3f5a104e010 Mon Sep 17 00:00:00 2001 From: rlaplaza Date: Fri, 15 May 2026 20:46:16 +0200 Subject: [PATCH 2/8] feat(uq): auto uncertainty selection for regression Score cv_sd, conformal, and meta_total on OOF residuals, fit optional global scaler, and expose auto / auto_decomposed predict modes. Co-authored-by: Cursor --- docs/API/API_Reference.rst | 1 + docs/API/robert.api.rst | 55 ++++- docs/API/robert.uq_auto.rst | 5 + robert/api.py | 45 +++- robert/argument_parser.py | 8 + robert/predict.py | 5 + robert/uq_auto.py | 450 ++++++++++++++++++++++++++++++++++++ tests/test_uq_auto.py | 127 ++++++++++ 8 files changed, 691 insertions(+), 5 deletions(-) create mode 100644 docs/API/robert.uq_auto.rst create mode 100644 robert/uq_auto.py create mode 100644 tests/test_uq_auto.py diff --git a/docs/API/API_Reference.rst b/docs/API/API_Reference.rst index 7771eaa..c7e7b3d 100644 --- a/docs/API/API_Reference.rst +++ b/docs/API/API_Reference.rst @@ -25,4 +25,5 @@ Other modules robert.api robert.generate_utils robert.predict_utils + robert.uq_auto robert.utils diff --git a/docs/API/robert.api.rst b/docs/API/robert.api.rst index 6adcedc..551abf7 100644 --- a/docs/API/robert.api.rst +++ b/docs/API/robert.api.rst @@ -29,6 +29,10 @@ columns aligned with the pipeline: - ``{y}_pred_uq_total`` (**meta UQ, opt-in**): combined uncertainty (regression: :math:`\sqrt{\mathrm{E}[\sigma^2] + \mathrm{Var}(\hat y)}`; classification uses a heuristic combining vote spread and between-model disagreement). +- ``{y}_pred_uq_auto`` (**auto UQ, regression, opt-in**): calibrated sigma-like scale + from automatic candidate selection (see **Auto uncertainty** below). +- ``{y}_pred_uq_auto_source`` (**auto UQ, opt-in**): name of the selected candidate + (``cv_sd``, ``conformal``, or ``meta_total``). ``predict`` returns ``{y}_pred`` values aligned to input rows. Uncertainty: @@ -40,14 +44,29 @@ columns aligned with the pipeline: - ``return_uncertainty="total"`` returns ``(y, uq_total)`` (requires ``uq_enable_meta=True``). - ``return_uncertainty="decomposed"`` returns ``(y, uq_model, uq_meta, uq_total)`` (requires ``uq_enable_meta=True``). +- ``return_uncertainty="auto"`` (**regression only**) enables auto uncertainty for + that predict call and returns ``(y, uq_auto)`` from ``{y}_pred_uq_auto``. +- ``return_uncertainty="auto_decomposed"`` (**regression only**) returns + ``(y, uq_auto, metadata)`` where ``metadata`` is loaded from + ``PREDICT/uq_auto_metadata.json`` when present. - If both ``return_std`` and ``return_uncertainty`` are set, ``return_uncertainty`` wins and a warning is issued. -Configuration (meta-model kwargs) ---------------------------------- +Configuration (uncertainty kwargs) +------------------------------------ -- ``uq_enable_meta`` (``False``), ``uq_top_k_models`` (``3``), +Defaults are defined in ``robert.argument_parser.var_dict``: + +- **Conformal:** ``conformal_enable`` (``True``), ``conformal_calib_frac`` (``0.15``), + ``conformal_coverage`` (``0.9``). +- **Meta-model:** ``uq_enable_meta`` (``False``), ``uq_top_k_models`` (``3``), ``uq_model_weighting`` (``"score_weighted"`` or ``"uniform"``). +- **Auto (regression):** ``uq_auto_enable`` (``False``), + ``uq_auto_candidates`` (``["cv_sd", "conformal", "meta_total"]``), + ``uq_auto_scaler`` (``"global_multiplicative"``; also ``"none"`` or ``"isotonic"``), + ``uq_auto_metric_weights`` (coverage / sharpness / NLL weights), + ``uq_auto_min_samples`` (``12``), ``uq_auto_random_state`` (``0``), + ``uq_auto_clas_mode`` (``"error"`` — raises if auto is requested for classification). Meta-model uncertainty ---------------------- @@ -77,6 +96,34 @@ Example (meta UQ): X.iloc[25:], return_uncertainty="decomposed" ) +Auto uncertainty (Bayesian optimization) +---------------------------------------- + +For uncertainty-aware acquisition (e.g. expected improvement with a surrogate +variance), use ``return_uncertainty="auto"`` on **regression** tasks. Auto mode +scores candidates ``cv_sd``, ``conformal``, and ``meta_total`` (when available) on +training out-of-fold absolute residuals, fits an optional scaler (``uq_auto_scaler``), +and writes ``{y}_pred_uq_auto``. Enable with ``uq_auto_enable=True``, or rely on +``return_uncertainty="auto"`` / ``"auto_decomposed"`` to enable it per predict call. +Legacy columns ``{y}_pred_sd`` and conformal half-width are unchanged when auto mode +runs. Lower-level helpers live in :mod:`robert.uq_auto`. + +Example (auto UQ): + +.. code-block:: python + + model_auto = RobertModel( + problem_type="reg", + workdir="./robert_run_auto", + model=["RF"], + uq_auto_enable=True, + ) + model_auto.fit(X.iloc[:25], y.iloc[:25]) + y_bo, sigma = model_auto.predict(X.iloc[25:], return_uncertainty="auto") + y_bo2, sigma2, meta = model_auto.predict( + X.iloc[25:], return_uncertainty="auto_decomposed" + ) + Pipeline semantics ------------------ @@ -125,5 +172,5 @@ Example r2 = model.score(X.iloc[25:], y.iloc[25:]) .. autoclass:: robert.api.RobertModel - :members: + :members: fit, predict, score, get_params, set_params :no-inherited-members: diff --git a/docs/API/robert.uq_auto.rst b/docs/API/robert.uq_auto.rst new file mode 100644 index 0000000..29b037e --- /dev/null +++ b/docs/API/robert.uq_auto.rst @@ -0,0 +1,5 @@ +uq_auto +======= + +.. automodule:: robert.uq_auto + :members: diff --git a/robert/api.py b/robert/api.py index bf72e3c..f1191db 100644 --- a/robert/api.py +++ b/robert/api.py @@ -9,6 +9,7 @@ from __future__ import annotations import glob +import json import os import tempfile import uuid @@ -178,7 +179,10 @@ class RobertModel(BaseEstimator): ``conformal_enable``, ``conformal_calib_frac``, and ``conformal_coverage``. Top-k meta-model uncertainty (opt-in) uses ``uq_enable_meta``, ``uq_top_k_models``, and ``uq_model_weighting`` (``"score_weighted"`` or - ``"uniform"``). + ``"uniform"``). Auto uncertainty (regression, opt-in) uses ``uq_auto_enable``, + ``uq_auto_candidates``, ``uq_auto_scaler``, ``uq_auto_metric_weights``, + ``uq_auto_min_samples``, and ``uq_auto_random_state``; ``return_uncertainty`` + ``"auto"`` or ``"auto_decomposed"`` enables auto mode for that predict call. """ def __init__( @@ -502,12 +506,15 @@ def predict( "meta", "total", "decomposed", + "auto", + "auto_decomposed", ] = False, ) -> Union[ np.ndarray, Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + Tuple[np.ndarray, np.ndarray, dict], ]: if not self.is_fitted_: raise RuntimeError("Call fit before predict.") @@ -523,6 +530,8 @@ def predict( "meta", "total", "decomposed", + "auto", + "auto_decomposed", ] = return_uncertainty if return_std: warnings.warn( @@ -579,6 +588,13 @@ def predict( base["csv_test"] = pred_name base["params_dir"] = "GENERATE/Best_model" base["names"] = self.names_col_ + if umode in ("auto", "auto_decomposed"): + if self.problem_type != "reg": + raise ValueError( + "return_uncertainty='auto' or 'auto_decomposed' is only " + "supported for problem_type='reg'." + ) + base["uq_auto_enable"] = True with _noninteractive_mpl(), _chdir(workdir): predict_module(**base) @@ -595,6 +611,8 @@ def predict( uq_model_col = f"{y_target}_pred_uq_model" uq_meta_col = f"{y_target}_pred_uq_meta" uq_total_col = f"{y_target}_pred_uq_total" + uq_auto_col = f"{y_target}_pred_uq_auto" + uq_auto_src_col = f"{y_target}_pred_uq_auto_source" if pred_col not in result_df.columns: raise RuntimeError(f"Column {pred_col!r} missing in {csv_path}") @@ -644,6 +662,27 @@ def predict( f"Column {hw_col!r} has no finite values; disable conformal " "or use a larger training set." ) + if umode in ("auto", "auto_decomposed"): + if uq_auto_col not in result_df.columns: + raise RuntimeError( + f"Column {uq_auto_col!r} missing in {csv_path}; ensure " + "uq_auto_enable=True or use return_uncertainty='auto'." + ) + y_uq_auto = ordered[uq_auto_col].to_numpy(dtype=float) + if not (np.isfinite(y_uq_auto).all() and (y_uq_auto >= 0).all()): + raise RuntimeError( + f"Column {uq_auto_col!r} has invalid uncertainty values." + ) + auto_metadata: dict[str, Any] = {} + meta_path = workdir / "PREDICT" / "uq_auto_metadata.json" + if meta_path.is_file(): + with meta_path.open(encoding="utf-8") as fh: + auto_metadata = json.load(fh) + elif uq_auto_src_col in result_df.columns: + src_vals = ordered[uq_auto_src_col].dropna().unique() + if len(src_vals): + auto_metadata["selected"] = str(src_vals[0]) + if umode in ("meta", "total", "decomposed"): if not bool(self._rob_kwargs.get("uq_enable_meta", False)): raise ValueError( @@ -670,6 +709,10 @@ def predict( return y_pred, y_hw if umode == "both": return y_pred, y_sd, y_hw + if umode == "auto": + return y_pred, y_uq_auto + if umode == "auto_decomposed": + return y_pred, y_uq_auto, auto_metadata if umode == "meta": return y_pred, y_uq_meta if umode == "total": diff --git a/robert/argument_parser.py b/robert/argument_parser.py index 5318131..5e54f8e 100644 --- a/robert/argument_parser.py +++ b/robert/argument_parser.py @@ -69,6 +69,14 @@ "uq_enable_meta": False, "uq_top_k_models": 3, "uq_model_weighting": "score_weighted", + # Auto uncertainty selection (regression; opt-in). + "uq_auto_enable": False, + "uq_auto_candidates": ["cv_sd", "conformal", "meta_total"], + "uq_auto_scaler": "global_multiplicative", + "uq_auto_metric_weights": {"coverage": 1.0, "sharpness": 0.25, "nll": 0.5}, + "uq_auto_min_samples": 12, + "uq_auto_random_state": 0, + "uq_auto_clas_mode": "error", } diff --git a/robert/predict.py b/robert/predict.py index e26b774..4d6af92 100644 --- a/robert/predict.py +++ b/robert/predict.py @@ -42,6 +42,7 @@ print_predict, pearson_map_predict ) +from robert.uq_auto import apply_auto_uq from robert.utils import ( load_variables, load_db_n_params, @@ -95,6 +96,10 @@ def __init__(self, **kwargs): Xy_data = apply_meta_uq_ensemble( self, Xy_data, model_data, params_dir ) + if getattr(self.args, "uq_auto_enable", False): + Xy_data = apply_auto_uq( + self, Xy_data, model_data, params_dir + ) # save predictions for all sets path_n_suffix, name_points, Xy_data = save_predictions(self,Xy_data,model_data,suffix_title) diff --git a/robert/uq_auto.py b/robert/uq_auto.py new file mode 100644 index 0000000..6bc6349 --- /dev/null +++ b/robert/uq_auto.py @@ -0,0 +1,450 @@ +""" +Automatic uncertainty selection and calibration for regression. + +Evaluates multiple uncertainty candidates on out-of-fold training residuals, +fits optional global multiplicative scalers, and selects the best-calibrated +candidate for deployment on all prediction splits. User-facing API details are +in ``docs/API/robert.api.rst``. +""" + +from __future__ import annotations + +import json +import warnings +from pathlib import Path +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple + +import numpy as np +from scipy.stats import norm +from sklearn.model_selection import train_test_split + +# Candidate identifiers (stable strings stored in CSV metadata). +CANDIDATE_CV_SD = "cv_sd" +CANDIDATE_CONFORMAL = "conformal" +CANDIDATE_META_TOTAL = "meta_total" + +DEFAULT_CANDIDATES = (CANDIDATE_CV_SD, CANDIDATE_CONFORMAL, CANDIDATE_META_TOTAL) +DEFAULT_METRIC_WEIGHTS = {"coverage": 1.0, "sharpness": 0.25, "nll": 0.5} +TIE_BREAK_ORDER = (CANDIDATE_CV_SD, CANDIDATE_CONFORMAL, CANDIDATE_META_TOTAL) + +VALID_SCALERS = ("none", "global_multiplicative", "isotonic") + + +def _as_float_array(x: Sequence[float]) -> np.ndarray: + return np.asarray(x, dtype=float).ravel() + + +def _normalize_metric_weights(weights: Optional[Mapping[str, float]]) -> Dict[str, float]: + base = dict(DEFAULT_METRIC_WEIGHTS) + if weights is not None: + for key in base: + if key in weights: + base[key] = float(weights[key]) + total = sum(base.values()) + if total <= 0: + return dict(DEFAULT_METRIC_WEIGHTS) + return {k: v / total for k, v in base.items()} + + +def fit_uncertainty_scaler( + method: str, + u_raw: np.ndarray, + abs_residuals: np.ndarray, +) -> Dict[str, Any]: + """ + Fit a post-hoc scaler mapping raw uncertainty to calibrated sigma-like values. + + Parameters + ---------- + method + ``none``, ``global_multiplicative``, or ``isotonic``. + u_raw, abs_residuals + Aligned nonnegative uncertainty and absolute residual arrays. + """ + method = str(method).lower() + if method not in VALID_SCALERS: + raise ValueError( + f"Unknown uq_auto_scaler {method!r}; expected one of {VALID_SCALERS}." + ) + + u = np.maximum(_as_float_array(u_raw), 0.0) + r = np.maximum(_as_float_array(abs_residuals), 0.0) + mask = np.isfinite(u) & np.isfinite(r) + u, r = u[mask], r[mask] + + if method == "none": + return {"method": "none", "alpha": 1.0} + + if u.size < 2: + return {"method": method, "alpha": 1.0} + + if method == "global_multiplicative": + denom = float(np.dot(u, u)) + if denom <= 1e-12: + alpha = 1.0 + else: + alpha = float(np.dot(u, r) / denom) + alpha = max(alpha, 1e-6) + return {"method": "global_multiplicative", "alpha": alpha} + + # isotonic: piecewise-linear map on sorted unique u values + order = np.argsort(u) + u_sorted = u[order] + r_sorted = r[order] + # PAV-style simple isotonic: cumulative max of running medians in bins + uniq_u, inv = np.unique(u_sorted, return_inverse=True) + r_mean = np.zeros_like(uniq_u) + counts = np.zeros_like(uniq_u) + for idx, val in enumerate(r_sorted): + b = inv[idx] + r_mean[b] += val + counts[b] += 1 + r_mean = r_mean / np.maximum(counts, 1) + # enforce monotonicity in u + for i in range(1, len(r_mean)): + r_mean[i] = max(r_mean[i], r_mean[i - 1]) + return { + "method": "isotonic", + "knots_u": uniq_u.tolist(), + "knots_r": r_mean.tolist(), + } + + +def apply_uncertainty_scaler( + method: str, + u_raw: np.ndarray, + params: Mapping[str, Any], +) -> np.ndarray: + """Apply fitted scaler parameters to raw uncertainty.""" + u = np.maximum(_as_float_array(u_raw), 0.0) + method = str(params.get("method", method)).lower() + + if method == "none": + return u + + if method == "global_multiplicative": + alpha = float(params.get("alpha", 1.0)) + return np.maximum(u * alpha, 0.0) + + if method == "isotonic": + knots_u = np.asarray(params.get("knots_u", []), dtype=float) + knots_r = np.asarray(params.get("knots_r", []), dtype=float) + if knots_u.size == 0: + return u + return np.maximum(np.interp(u, knots_u, knots_r, left=knots_r[0], right=knots_r[-1]), 0.0) + + raise ValueError(f"Unknown scaler method in params: {method!r}") + + +def _coverage_error(abs_resid: np.ndarray, sigma: np.ndarray, coverage: float) -> float: + sigma = np.maximum(sigma, 1e-12) + # Gaussian half-width factor for symmetric interval + alpha = 1.0 - coverage + z = float(norm.ppf(1.0 - alpha / 2.0)) + inside = abs_resid <= z * sigma + emp_cov = float(np.mean(inside)) if inside.size else 0.0 + return abs(emp_cov - coverage) + + +def _gaussian_nll(abs_resid: np.ndarray, sigma: np.ndarray) -> float: + sigma = np.maximum(sigma, 1e-12) + # NLL for Laplace-like on abs residual under Gaussian proxy + var = sigma ** 2 + return float(np.mean(0.5 * np.log(2.0 * np.pi * var) + 0.5 * (abs_resid ** 2) / var)) + + +def _sharpness(sigma: np.ndarray) -> float: + return float(np.mean(np.maximum(sigma, 0.0))) + + +def score_uncertainty_candidate( + u_scaled: np.ndarray, + abs_resid: np.ndarray, + coverage: float, + metric_weights: Optional[Mapping[str, float]] = None, +) -> float: + """ + Lower is better. Composite of coverage error, mean width, and Gaussian NLL. + """ + w = _normalize_metric_weights(metric_weights) + u_s = np.maximum(_as_float_array(u_scaled), 1e-12) + r = _as_float_array(abs_resid) + cov_err = _coverage_error(r, u_s, coverage) + sharp = _sharpness(u_s) + nll = _gaussian_nll(r, u_s) + return w["coverage"] * cov_err + w["sharpness"] * sharp + w["nll"] * nll + + +def _oof_mean_train(Xy_data: Mapping[str, Any]) -> np.ndarray: + preds_all = Xy_data.get("y_pred_train_all", []) + return np.array([float(np.mean(p)) if len(p) else np.nan for p in preds_all], dtype=float) + + +def _train_abs_residuals(Xy_data: Mapping[str, Any]) -> np.ndarray: + y_true = _as_float_array(Xy_data["y_train"]) + y_oof = _oof_mean_train(Xy_data) + return np.abs(y_true - y_oof) + + +def _raw_candidate_train( + candidate: str, + Xy_data: Mapping[str, Any], +) -> Optional[np.ndarray]: + n = len(Xy_data["y_train"]) + if candidate == CANDIDATE_CV_SD: + key = "y_pred_train_uq_model" + if key in Xy_data: + return _as_float_array(Xy_data[key]) + if "y_pred_train_sd" in Xy_data: + return _as_float_array(Xy_data["y_pred_train_sd"]) + return None + if candidate == CANDIDATE_CONFORMAL: + hw = float(Xy_data.get("conformal_half_width", float("nan"))) + if not np.isfinite(hw) or hw < 0: + return None + return np.full(n, hw, dtype=float) + if candidate == CANDIDATE_META_TOTAL: + key = "y_pred_train_uq_total" + if key not in Xy_data: + return None + return _as_float_array(Xy_data[key]) + return None + + +def _raw_candidate_split( + candidate: str, + split: str, + Xy_data: Mapping[str, Any], +) -> Optional[np.ndarray]: + if split == "train": + n = len(Xy_data["y_train"]) + elif split == "test": + n = len(Xy_data["y_test"]) + elif split == "external": + if "X_external" not in Xy_data: + return None + n = len(Xy_data["X_external"]) + else: + return None + + if candidate == CANDIDATE_CV_SD: + uq_key = f"y_pred_{split}_uq_model" + sd_key = f"y_pred_{split}_sd" + if uq_key in Xy_data: + return _as_float_array(Xy_data[uq_key]) + if sd_key in Xy_data: + return _as_float_array(Xy_data[sd_key]) + return None + if candidate == CANDIDATE_CONFORMAL: + hw = float(Xy_data.get("conformal_half_width", float("nan"))) + if not np.isfinite(hw) or hw < 0: + return None + return np.full(n, hw, dtype=float) + if candidate == CANDIDATE_META_TOTAL: + key = f"y_pred_{split}_uq_total" + if key not in Xy_data: + return None + return _as_float_array(Xy_data[key]) + return None + + +def _available_candidates( + candidates: Sequence[str], + Xy_data: Mapping[str, Any], + problem_type: str, +) -> List[str]: + reg = problem_type.lower() == "reg" + out: List[str] = [] + for cand in candidates: + if cand == CANDIDATE_CONFORMAL and not reg: + continue + raw = _raw_candidate_train(cand, Xy_data) + if raw is None or not np.isfinite(raw).any(): + continue + out.append(cand) + return out + + +def evaluate_uq_candidates( + Xy_data: Mapping[str, Any], + args: Any, + problem_type: str, +) -> Dict[str, Any]: + """ + Score candidates on training OOF residuals with an inner hold-out split. + + Returns dict with keys: selected, scaler_params, candidate_scores, coverage, n_eval. + """ + candidates_cfg = getattr(args, "uq_auto_candidates", None) or list(DEFAULT_CANDIDATES) + if isinstance(candidates_cfg, str): + candidates_cfg = [c.strip() for c in candidates_cfg.split(",") if c.strip()] + + scaler_method = str(getattr(args, "uq_auto_scaler", "global_multiplicative")) + coverage = float(getattr(args, "conformal_coverage", 0.9)) + min_samples = int(getattr(args, "uq_auto_min_samples", 12)) + seed = int(getattr(args, "uq_auto_random_state", getattr(args, "seed", 0))) + metric_weights = getattr(args, "uq_auto_metric_weights", None) + + abs_resid = _train_abs_residuals(Xy_data) + available = _available_candidates(candidates_cfg, Xy_data, problem_type) + + if not available: + raise RuntimeError( + "Auto uncertainty mode found no valid candidates on training data." + ) + + n = abs_resid.size + if n < min_samples: + warnings.warn( + f"Training size {n} < uq_auto_min_samples={min_samples}; " + "using unscaled best-effort selection.", + UserWarning, + stacklevel=2, + ) + + candidate_scores: Dict[str, float] = {} + scaler_by_candidate: Dict[str, Dict[str, Any]] = {} + + for cand in available: + u_raw = _raw_candidate_train(cand, Xy_data) + if u_raw is None: + continue + if n >= max(min_samples, 8): + idx = np.arange(n) + fit_ix, eval_ix = train_test_split( + idx, + test_size=0.25, + random_state=seed, + shuffle=True, + ) + params = fit_uncertainty_scaler( + scaler_method, u_raw[fit_ix], abs_resid[fit_ix] + ) + u_scaled_eval = apply_uncertainty_scaler(scaler_method, u_raw[eval_ix], params) + score = score_uncertainty_candidate( + u_scaled_eval, abs_resid[eval_ix], coverage, metric_weights + ) + else: + params = fit_uncertainty_scaler(scaler_method, u_raw, abs_resid) + u_scaled_eval = apply_uncertainty_scaler(scaler_method, u_raw, params) + score = score_uncertainty_candidate( + u_scaled_eval, abs_resid, coverage, metric_weights + ) + + candidate_scores[cand] = score + scaler_by_candidate[cand] = params + + if not candidate_scores: + raise RuntimeError("Auto uncertainty scoring produced no valid candidates.") + + best_score = min(candidate_scores.values()) + tied = [c for c, s in candidate_scores.items() if abs(s - best_score) < 1e-12] + selected = next(c for c in TIE_BREAK_ORDER if c in tied) + + # Refit scaler on full training OOF for deployment + u_full = _raw_candidate_train(selected, Xy_data) + assert u_full is not None + final_params = fit_uncertainty_scaler(scaler_method, u_full, abs_resid) + + return { + "selected": selected, + "scaler_method": scaler_method, + "scaler_params": final_params, + "candidate_scores": candidate_scores, + "coverage_target": coverage, + "n_eval": int(n), + "available_candidates": available, + } + + +def apply_auto_uq( + self: Any, + Xy_data: Dict[str, Any], + model_data: Mapping[str, Any], + params_dir: str, +) -> Dict[str, Any]: + """ + Select and attach auto-calibrated uncertainty columns to ``Xy_data``. + + Writes ``y_pred_{split}_uq_auto`` and run-level metadata under PREDICT/. + """ + _ = params_dir + if not bool(getattr(self.args, "uq_auto_enable", False)): + return Xy_data + + if model_data["type"].lower() != "reg": + clas_mode = str(getattr(self.args, "uq_auto_clas_mode", "error")).lower() + if clas_mode == "error": + raise ValueError( + "uq_auto_enable is only supported for regression (problem_type='reg') " + "in this release." + ) + warnings.warn( + "Auto uncertainty for classification is not calibrated; skipping.", + UserWarning, + stacklevel=2, + ) + return Xy_data + + selection = evaluate_uq_candidates(Xy_data, self.args, model_data["type"]) + selected = selection["selected"] + scaler_method = selection["scaler_method"] + params = selection["scaler_params"] + + splits = ["train", "test"] + if "X_external" in Xy_data: + splits.append("external") + + for split in splits: + u_raw = _raw_candidate_split(selected, split, Xy_data) + if u_raw is None: + continue + u_auto = apply_uncertainty_scaler(scaler_method, u_raw, params) + Xy_data[f"y_pred_{split}_uq_auto"] = np.asarray(u_auto, dtype=float).tolist() + + Xy_data["uq_auto_selected"] = selected + Xy_data["uq_auto_scaler_params"] = params + Xy_data["uq_auto_metadata"] = selection + + meta_path = Path("PREDICT") / "uq_auto_metadata.json" + meta_path.parent.mkdir(parents=True, exist_ok=True) + serializable = { + k: (v if not isinstance(v, dict) else dict(v)) + for k, v in selection.items() + } + with meta_path.open("w", encoding="utf-8") as fh: + json.dump(serializable, fh, indent=2) + + return Xy_data + + +def get_bo_ready_prediction_bundle( + Xy_data: Mapping[str, Any], + split: str, + y_col: str, +) -> Dict[str, Any]: + """ + Return mean, sigma, and provenance for Bayesian-optimization consumers. + + Uses auto uncertainty when present, otherwise CV SD. + """ + pred_key = f"y_pred_{split}" + auto_key = f"y_pred_{split}_uq_auto" + sd_key = f"y_pred_{split}_sd" + + mean = np.asarray(Xy_data[pred_key], dtype=float) + if auto_key in Xy_data: + sigma = np.asarray(Xy_data[auto_key], dtype=float) + source = str(Xy_data.get("uq_auto_selected", "auto")) + elif sd_key in Xy_data: + sigma = np.asarray(Xy_data[sd_key], dtype=float) + source = "cv_sd_fallback" + else: + raise KeyError(f"No uncertainty available for split {split!r}.") + + return { + "y_col": y_col, + "mean": mean, + "sigma": np.maximum(sigma, 0.0), + "provenance": source, + } diff --git a/tests/test_uq_auto.py b/tests/test_uq_auto.py new file mode 100644 index 0000000..4ad8299 --- /dev/null +++ b/tests/test_uq_auto.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python + +"""Tests for automatic uncertainty selection and calibration.""" + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from robert import RobertModel +from robert.uq_auto import ( + CANDIDATE_CV_SD, + apply_uncertainty_scaler, + evaluate_uq_candidates, + fit_uncertainty_scaler, + score_uncertainty_candidate, +) + +_REG_CSV = Path(__file__).resolve().parent / "Robert_example.csv" + +_FAST = { + "model": ["RF"], + "n_iter": 2, + "init_points": 2, + "repeat_kfolds": 2, + "kfold": 3, + "pfi_epochs": 1, + "seed": 42, + "uq_auto_enable": True, +} + + +class _Args: + uq_auto_candidates = ["cv_sd", "conformal"] + uq_auto_scaler = "global_multiplicative" + uq_auto_metric_weights = None + uq_auto_min_samples = 5 + uq_auto_random_state = 0 + conformal_coverage = 0.9 + seed = 0 + + +def test_global_multiplicative_scaler_monotone(): + u = np.array([0.5, 1.0, 2.0]) + r = np.array([0.6, 1.2, 2.4]) + params = fit_uncertainty_scaler("global_multiplicative", u, r) + scaled = apply_uncertainty_scaler("global_multiplicative", u, params) + assert np.all(scaled >= 0) + assert scaled[0] < scaled[1] < scaled[2] + + +def test_none_scaler_identity(): + u = np.array([1.0, 2.0, 3.0]) + params = fit_uncertainty_scaler("none", u, u) + out = apply_uncertainty_scaler("none", u, params) + assert np.allclose(out, u) + + +def test_score_prefers_calibrated_scale(): + abs_res = np.array([1.0, 1.0, 1.0, 1.0]) + bad = np.full(4, 10.0) + good = np.full(4, 1.0) + assert score_uncertainty_candidate(good, abs_res, 0.9) < score_uncertainty_candidate( + bad, abs_res, 0.9 + ) + + +def test_evaluate_uq_candidates_deterministic(): + rng = np.random.default_rng(0) + n = 30 + y = rng.normal(size=n) + oof = y + rng.normal(scale=0.2, size=n) + Xy = { + "y_train": y.tolist(), + "y_pred_train_all": [[v] for v in oof], + "y_pred_train_sd": (0.3 + 0.1 * rng.random(n)).tolist(), + "conformal_half_width": 0.5, + } + sel1 = evaluate_uq_candidates(Xy, _Args(), "reg") + sel2 = evaluate_uq_candidates(Xy, _Args(), "reg") + assert sel1["selected"] == sel2["selected"] + assert sel1["selected"] in (CANDIDATE_CV_SD, "conformal") + + +def test_fit_predict_auto_uncertainty(tmp_path): + df = pd.read_csv(_REG_CSV, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + n_fit = 25 + model = RobertModel( + problem_type="reg", + filter_mode="no_pfi", + workdir=tmp_path, + names="Name", + **_FAST, + ) + model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) + X_hold = X.iloc[n_fit:].drop_duplicates(subset=["Name"], keep="first") + y_hat, u_auto = model.predict(X_hold, return_uncertainty="auto") + assert y_hat.shape == u_auto.shape + assert np.isfinite(u_auto).all() and (u_auto >= 0).all() + y_d, u2, meta = model.predict(X_hold, return_uncertainty="auto_decomposed") + assert y_d.shape == u2.shape + assert isinstance(meta, dict) + assert meta.get("selected") in ("cv_sd", "conformal", "meta_total") + meta_path = tmp_path / "PREDICT" / "uq_auto_metadata.json" + assert meta_path.is_file() + + +def test_auto_classification_raises(tmp_path): + clas_csv = Path(__file__).resolve().parent / "Robert_example_clas.csv" + df = pd.read_csv(clas_csv, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + fast = {k: v for k, v in _FAST.items() if k != "uq_auto_enable"} + model = RobertModel( + problem_type="clas", + filter_mode="no_pfi", + workdir=tmp_path, + names="Name", + **fast, + ) + model.fit(X.iloc[:20], y.iloc[:20]) + X_hold = X.iloc[20:].drop_duplicates(subset=["Name"], keep="first") + with pytest.raises(ValueError, match="auto"): + model.predict(X_hold, return_uncertainty="auto") From ae261d66e00664e041bc2ecd4598b3821e6ad45c Mon Sep 17 00:00:00 2001 From: rlaplaza Date: Sat, 16 May 2026 17:54:18 +0200 Subject: [PATCH 3/8] feat(models): add XGBoost and voting ensemble support Add xgboost dependency and XGB/VR model paths in CURATE/GENERATE, with CURATE tests for XGB and voting regressor outputs. Co-authored-by: Cursor --- environment/env.yaml | 1 + robert/curate.py | 15 +- robert/generate.py | 44 +++-- robert/utils.py | 395 ++++++++++++++++++++++++++++++++---------- setup.py | 1 + tests/test_1curate.py | 26 +++ 6 files changed, 375 insertions(+), 107 deletions(-) diff --git a/environment/env.yaml b/environment/env.yaml index a37dde5..be4c2e7 100644 --- a/environment/env.yaml +++ b/environment/env.yaml @@ -15,6 +15,7 @@ dependencies: - pip: - aqme==2.0.0 - robert==2.1.0 + - xgboost==2.1.4 - PySide6==6.9.2 - PySide6-Addons==6.9.2 - PySide6-Essentials==6.9.2 diff --git a/robert/curate.py b/robert/curate.py index 179352d..c064668 100644 --- a/robert/curate.py +++ b/robert/curate.py @@ -62,8 +62,16 @@ import time import os import pandas as pd -from robert.utils import (load_variables, finish_print, load_database, pearson_map, - check_clas_problem, categorical_transform, correlation_filter) +from robert.utils import ( + load_variables, + finish_print, + load_database, + pearson_map, + check_clas_problem, + categorical_transform, + correlation_filter, + should_plot_curate_pearson, +) class curate: @@ -115,7 +123,8 @@ def __init__(self, **kwargs): _ = self.save_curate(csv_df) # create Pearson heatmap (use the general filtered dataframe) - _ = pearson_map(self,csv_df,'curate') + if should_plot_curate_pearson(self.args): + _ = pearson_map(self, csv_df, "curate") # finish the printing of the CURATE info file _ = finish_print(self,start_time,'CURATE') diff --git a/robert/generate.py b/robert/generate.py index 192832b..640a4da 100644 --- a/robert/generate.py +++ b/robert/generate.py @@ -27,6 +27,8 @@ 4. 'NN' (MLP neural network) 5. 'GP' (Gaussian Process) 6. 'AdaB' (AdaBoost) + 7. 'VR' (Voting Regressor/Classifier with RF, GB and NN base estimators) + 8. 'XGB' (XGBoost regressor/classifier; opt-in, not in default model list) custom_params : str, default=None Define new parameters for the ML models used in the hyperoptimization workflow. The path to the folder containing all the yaml files should be specified (i.e. custom_params='YAML_FOLDER') @@ -50,7 +52,7 @@ Number of initial points for Bayesian optimization (exploration) n_iter : int, default=10 Number of iterations for Bayesian optimization (exploitation) - expect_improv : int, default=0.05 + expect_improv : float, default=0.05 Expected improvement for Bayesian optimization pfi_filter : bool, default=True Activate the PFI filter of descriptors. @@ -90,11 +92,12 @@ import os import time from robert.utils import ( - load_variables, + load_variables, finish_print, load_database, check_clas_problem, - prepare_sets + prepare_sets, + should_plot_generate_heatmap, ) from robert.generate_utils import ( BO_workflow, @@ -181,13 +184,32 @@ def __init__(self, **kwargs): # apply the PFI descriptor filter if it's activated if self.args.pfi_filter: - # load database, discard user-defined descriptors and perform data checks - csv_df, csv_X, csv_y = load_database(self,csv_to_load,"generate",print_info=False) + # Reuse processed dataframe: BO adds a Set column in-memory; drop it and + # rebuild X/y splits without re-reading the curated CSV from disk. + csv_df_pfi = csv_df.drop(columns=["Set"], errors="ignore") + cols_to_ignore = [ + col for col in self.args.ignore if col in csv_df_pfi.columns + ] + csv_df_ignore = csv_df_pfi.drop(cols_to_ignore, axis=1) + csv_X = csv_df_ignore.drop([self.args.y], axis=1) + csv_y = csv_df_ignore[self.args.y] + csv_X = csv_X[sorted([col for col in csv_X.columns])] # standardizes and separates an external test set - Xy_data = prepare_sets(self,csv_df,csv_X,csv_y,None,self.args.names,None,None,None,BO_opt=True) - - _ = PFI_workflow(self, csv_df, ML_model, Xy_data) + Xy_data = prepare_sets( + self, + csv_df_pfi, + csv_X, + csv_y, + None, + self.args.names, + None, + None, + None, + BO_opt=True, + ) + + _ = PFI_workflow(self, csv_df_pfi, ML_model, Xy_data) # Restore the original csv_name self.args.csv_name = original_csv_name @@ -199,13 +221,15 @@ def __init__(self, **kwargs): _ = detect_best(f'{dir_csv}/No_PFI') # create heatmap plot(s) - _ = heatmap_workflow(self,"No_PFI") + if should_plot_generate_heatmap(self.args): + _ = heatmap_workflow(self, "No_PFI") # detect best and create heatmap for PFI models if self.args.pfi_filter: try: # if no models were found _ = detect_best(f'{dir_csv}/PFI') - _ = heatmap_workflow(self,"PFI") + if should_plot_generate_heatmap(self.args): + _ = heatmap_workflow(self, "PFI") except UnboundLocalError: pass diff --git a/robert/utils.py b/robert/utils.py index 33fc3f6..88c4c78 100644 --- a/robert/utils.py +++ b/robert/utils.py @@ -11,6 +11,8 @@ import yaml import ast import shutil +import importlib +from contextlib import contextmanager from pathlib import Path import pandas as pd import numpy as np @@ -19,13 +21,10 @@ # This prevents numerical differences between Windows/Ubuntu in parallel operations os.environ["LOKY_MAX_CPU_COUNT"] = "1" from matplotlib import pyplot as plt -import importlib import matplotlib.patches as mpatches import matplotlib.colors as mcolor from matplotlib.legend_handler import HandlerPatch from matplotlib.ticker import FormatStrFormatter -import shap -import seaborn as sb from scipy import stats from importlib.resources import files # sklearnex was deactivated in ROBERT v2.1 because it only accelerated RF @@ -46,7 +45,10 @@ RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier, + VotingRegressor, + VotingClassifier, ) +from xgboost import XGBClassifier, XGBRegressor from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier from sklearn.neural_network import MLPRegressor, MLPClassifier from sklearn.linear_model import LinearRegression @@ -57,12 +59,48 @@ from sklearn.inspection import permutation_importance from sklearn.exceptions import ConvergenceWarning from robert.argument_parser import set_options, var_dict -from bayes_opt import BayesianOptimization -from bayes_opt import acquisition import warnings # this avoids warnings from sklearn warnings.filterwarnings("ignore") +@contextmanager +def _mpl_plot_context(): + """Reload pyplot once per plotting batch (threading workaround).""" + importlib.reload(plt) + yield + + +def plot_verbosity_level(args) -> int: + """Return clipped plot verbosity in {0, 1, 2}; default matches legacy (all plots).""" + try: + v = int(getattr(args, "plot_verbosity", var_dict["plot_verbosity"])) + except (TypeError, ValueError): + v = int(var_dict["plot_verbosity"]) + return max(0, min(2, v)) + + +def should_plot_curate_pearson(args) -> bool: + return plot_verbosity_level(args) >= 1 + + +def should_plot_generate_heatmap(args) -> bool: + return plot_verbosity_level(args) >= 1 + + +def should_plot_verify_metrics(args) -> bool: + return plot_verbosity_level(args) >= 1 + + +def should_plot_predict_results(args) -> bool: + """Main PREDICT figures (e.g. Results_*); tied to predict_diagnostics for backward compatibility.""" + return bool(getattr(args, "predict_diagnostics", True)) and plot_verbosity_level(args) >= 1 + + +def should_plot_predict_deep_diagnostics(args) -> bool: + """SHAP, PFI, Pearson heatmap, outliers, distribution plots.""" + return bool(getattr(args, "predict_diagnostics", True)) and plot_verbosity_level(args) >= 2 + + robert_version = "2.1.0" time_run = time.strftime("%Y/%m/%d %H:%M:%S", time.localtime()) robert_ref = "Dalmau, D.; Alegre Requena, J. V. WIREs Comput Mol Sci. 2024, 14, e1733." @@ -77,18 +115,38 @@ def load_from_yaml(self): txt_yaml = f"\no Importing ROBERT parameters from {self.varfile}" error_yaml = False + param_list = {} # Variables will be updated from YAML file try: - if os.path.exists(self.varfile): - if os.path.basename(Path(self.varfile)).split('.')[1] in ["yaml", "yml", "txt"]: - with open(self.varfile, "r") as file: - try: - param_list = yaml.load(file, Loader=yaml.SafeLoader) - except (yaml.scanner.ScannerError,yaml.parser.ParserError): - txt_yaml = f'\nx Error while reading {self.varfile}. Edit the yaml file and try again (i.e. use ":" instead of "=" to specify variables)' - error_yaml = True - if not error_yaml: + if not os.path.exists(self.varfile): + return ( + self, + "\nx The specified yaml file containing parameters was not found! Make sure that the valid params file is in the folder where you are running the code.", + ) + base_name = os.path.basename(Path(self.varfile)) + ext = base_name.rsplit(".", 1)[-1] if "." in base_name else "" + if ext in ["yaml", "yml", "txt"]: + with open(self.varfile, "r") as file: + try: + loaded = yaml.load(file, Loader=yaml.SafeLoader) + param_list = loaded if isinstance(loaded, dict) else {} + except (yaml.scanner.ScannerError,yaml.parser.ParserError): + txt_yaml = f'\nx Error while reading {self.varfile}. Edit the yaml file and try again (i.e. use ":" instead of "=" to specify variables)' + error_yaml = True + else: + txt_yaml = ( + f"\nx Unsupported parameter file extension for {self.varfile!r}. " + "Use .yaml, .yml, or .txt." + ) + error_yaml = True + if not error_yaml and param_list: for param in param_list: + if param not in var_dict: + print( + f"Warning! YAML key [{param}] is not a recognized ROBERT option; " + "ignored. See online documentation for valid option names." + ) + continue if hasattr(self, param): if getattr(self, param) != param_list[param]: setattr(self, param, param_list[param]) @@ -173,6 +231,10 @@ def command_line_args(exe_type,sys_args): "seed", "init_points", "n_iter", + "plot_verbosity", + "uq_top_k_models", + "uq_auto_min_samples", + "uq_auto_random_state", ] float_args = [ 'pfi_threshold', @@ -182,7 +244,9 @@ def command_line_args(exe_type,sys_args): 'test_set', 'desc_thres', 'alpha', - 'expect_improv' + 'expect_improv', + 'conformal_calib_frac', + 'conformal_coverage', ] for arg in var_dict: @@ -236,7 +300,7 @@ def command_line_args(exe_type,sys_args): --discard "[COL1,COL2,etc]" (default=[]) : CSV columns that will be removed * Affecting data curation in CURATE: - --kfold INT (default='auto') : number of folds for k-fold cross-validation of the RFECV feature selector. If 'auto', the program does a LOOCV for databases with less than 50 points, and 5-fold CV for larger databases + --kfold INT (default=5) : number of folds for RFECV feature selection during curation (RepeatedKFold with repeat_kfolds; see docs) --categorical "onehot" or "numbers" (default="onehot") : type of conversion for categorical variables --corr_filter_x BOOL (default=True) : activate/disable the correlation filter of descriptors X @@ -246,16 +310,27 @@ def command_line_args(exe_type,sys_args): --pfi_max INT (default=0) : number of features to keep in the PFI models * Affecting tests, VERIFY: - --kfold INT (default='auto') : number of folds for k-fold cross-validation. If 'auto', the program does a LOOCV for databases with less than 50 points, and 5-fold CV for larger databases + --repeat_kfolds INT (default=10) : repetitions for repeated k-fold CV during verification * Affecting predictions, PREDICT: --t_value INT (default=2) : t-value threshold to identify outliers --shap_show INT (default=10) : maximum number of descriptors shown in the SHAP plot + --plot_verbosity INT (default=2) : 0=no figures, 1=workflow summaries, 2=full diagnostics (see docs) + --predict_diagnostics BOOL (default=True) : when False, skip SHAP/PFI/heatmap/outlier/distribution plots + --conformal_enable BOOL (default=True) : regression split-conformal intervals (half-width column in outputs) + --conformal_calib_frac FLOAT (default=0.15) : fraction of data for conformal calibration + --conformal_coverage FLOAT (default=0.9) : target coverage for conformal intervals + --uq_enable_meta BOOL (default=False) : enable meta-model uncertainty (Python API and advanced PREDICT) + --uq_top_k_models INT (default=3) : top-k GENERATE models for meta UQ + --uq_auto_enable BOOL (default=False) : automatic uncertainty selection (regression; see docs/API) + + Note: structured options (e.g. uq_auto_candidates, uq_auto_metric_weights) are easiest to set + via a YAML varfile or the Python API (RobertModel); see https://robert.readthedocs.io * Affecting SMILES workflows, AQME: - --qdescp_keywords STR (default="") : extra keywords in QDESCP (i.e. "--qdescp_atoms [Ir] --alpb h2o") - --csearch_keywords STR (default="--sample 50") : extra keywords in CSEARCH + --qdescp_keywords STR (default="") : extra tokens appended to the internal AQME QDESCP command (e.g. xTB/atom selections) --descp_lvl (default="interpret") "interpret", "denovo" or "full" : type of descriptor calculation + Advanced conformer/CSEARCH settings follow AQME; pass compatible AQME flags inside --qdescp_keywords only where they apply to the same AQME invocation, or run AQME standalone if you need a custom CSEARCH workflow. o How to cite ROBERT: @@ -437,7 +512,7 @@ def load_variables(kwargs, robert_module): if robert_module.upper() in ['CURATE','GENERATE']: if self.type.lower() == 'clas': - if ('MVL' or 'mvl') in self.model: + if any(m.upper() == "MVL" for m in self.model): self.model = [x if x.upper() != 'MVL' else 'AdaB' for x in self.model] models_gen = [] # use capital letters in all the models @@ -753,8 +828,8 @@ def correlation_filter(self, csv_df): scoring = get_scoring_key(self.args.type,self.args.error_type) # Use different strategies for models without feature_importances_ - if model.upper() in ['NN', 'GP']: - # For NN and GP, use a simpler approach: select top features by correlation with y + if model.upper() in ['NN', 'GP', 'VR']: + # For NN, GP and VR, use a simpler approach: select top features by correlation with y # after initial fit, then use permutation importance to rank them # Train the model once on all features @@ -804,7 +879,7 @@ def correlation_filter(self, csv_df): # For MVL, use absolute coefficients as importance feature_importances = np.abs(selector.estimator_.coef_) else: - # RF, GB, ADAB have feature_importances_ + # RF, GB, ADAB, XGB have feature_importances_ feature_importances = selector.estimator_.feature_importances_ # Round importances to reduce floating point variance @@ -891,7 +966,22 @@ def load_minimal_model(model): 'GP': { 'n_restarts_optimizer': 30, }, + 'XGB': { + 'n_estimators': 30, + 'learning_rate': 0.1, + 'max_depth': 10, + 'min_child_weight': 1, + 'subsample': 1.0, + 'colsample_bytree': 1.0, + 'reg_alpha': 0.0, + 'reg_lambda': 1.0, + }, 'MVL': { + }, + 'VR': { + 'w_rf': 1.0, + 'w_gb': 1.0, + 'w_nn': 1.0, } } @@ -1006,8 +1096,8 @@ def sanity_checks(self, type_checks, module, columns_csv): self.split = 'rnd' for model_type in self.model: - if model_type.upper() not in ['RF','MVL','GB','GP','ADAB','NN'] or len(self.model) == 0: - self.log.write(f"\nx The model option used is not valid! Options: 'RF', 'MVL', 'GB', 'ADAB', 'NN'") + if model_type.upper() not in ['RF','MVL','GB','GP','ADAB','NN','XGB','VR'] or len(self.model) == 0: + self.log.write(f"\nx The model option used is not valid! Options: 'RF', 'MVL', 'GB', 'GP', 'ADAB', 'NN', 'XGB', 'VR'") curate_valid = False if model_type.upper() == 'MVL' and self.type.lower() == 'clas': self.log.write(f"\nx Multivariate linear models (MVL in the model_type option) are not compatible with classificaton!") @@ -1181,10 +1271,17 @@ def load_database(self,csv_load,module,print_info=True,external_test=False): external_test = True txt_load = '' - # this part fixes CSV files that use ";" as separator + # Semicolon-separated "CSV" from Excel: peek at the first rows before reading the whole file. + _scan_limit = 64 + head_lines = [] with open(csv_load, 'r', encoding='utf-8') as file: - lines = file.readlines() - if lines[1].count(';') > 1: + for _, line in zip(range(_scan_limit), file): + head_lines.append(line) + semicolon_issue = len(head_lines) >= 2 and head_lines[1].count(';') > 1 + if semicolon_issue: + with open(csv_load, 'r', encoding='utf-8') as file: + lines = file.readlines() + if semicolon_issue: new_csv_name = os.path.basename(csv_load).split('.csv')[0].split('.CSV')[0]+'_original.csv' shutil.move(csv_load, Path(os.path.dirname(csv_load)).joinpath(new_csv_name)) new_csv_file = open(csv_load, "w") @@ -1563,6 +1660,8 @@ def generate_lhs_points(pbounds, n_points, random_state=None): def BO_optimizer(self,bo_data,Xy_data): + from bayes_opt import BayesianOptimization, acquisition + # Define an acquisition function for Bayesian optimization _ = acquisition.ExpectedImprovement(xi=self.args.expect_improv) @@ -1650,7 +1749,22 @@ def BO_hyperparams(model_name): }, 'GP': { 'n_restarts_optimizer': (0, 100), - } + }, + 'XGB': { + 'n_estimators': (10, 100), + 'learning_rate': (0.01, 0.3), + 'max_depth': (3, 20), + 'min_child_weight': (1, 10), + 'subsample': (0.7, 1.0), + 'colsample_bytree': (0.25, 1.0), + 'reg_alpha': (0, 1.0), + 'reg_lambda': (0, 1.0), + }, + 'VR': { + 'w_rf': (0.1, 5.0), + 'w_gb': (0.1, 5.0), + 'w_nn': (0.1, 5.0), + }, } return model_BO_params[model_name] @@ -1675,7 +1789,7 @@ def model_adjust_params(self,model_name,params): ''' - if model_name != 'MVL': + if model_name not in ['MVL', 'VR']: params['random_state'] = self.args.seed if model_name in ['RF','GB']: @@ -1684,6 +1798,11 @@ def model_adjust_params(self,model_name,params): params['min_samples_split'] = round(params['min_samples_split']) params['min_samples_leaf'] = round(params['min_samples_leaf']) + elif model_name == 'XGB': + params['n_estimators'] = round(params['n_estimators']) + params['max_depth'] = round(params['max_depth']) + params['min_child_weight'] = round(params['min_child_weight']) + elif model_name == 'NN': # add solver first params['solver'] = 'lbfgs' @@ -1697,6 +1816,17 @@ def model_adjust_params(self,model_name,params): elif model_name == 'GP': params['n_restarts_optimizer'] = round(params['n_restarts_optimizer']) + elif model_name == 'VR': + # VR only optimizes ensemble weights; base estimators receive deterministic seeds. + if all(weight_key in params for weight_key in ['w_rf', 'w_gb', 'w_nn']): + params['weights'] = [ + float(params.pop('w_rf')), + float(params.pop('w_gb')), + float(params.pop('w_nn')), + ] + elif 'weights' in params: + params['weights'] = [float(weight) for weight in params['weights']] + return params @@ -1721,6 +1851,14 @@ def load_model(self, model_name, **params): else: loaded_model = GradientBoostingClassifier(**params) + elif model_name == 'XGB': + if 'n_jobs' not in params: + params['n_jobs'] = 1 + if self.args.type.lower() == 'reg': + loaded_model = XGBRegressor(**params) + else: + loaded_model = XGBClassifier(**params) + elif model_name == 'NN': # create the hidden layers architecture first params = setup_hidden_layers(params) @@ -1744,6 +1882,26 @@ def load_model(self, model_name, **params): elif model_name == 'MVL': loaded_model = LinearRegression(**params) + + elif model_name == 'VR': + weights = params.pop('weights', [1.0, 1.0, 1.0]) + weights = [float(weight) for weight in weights] + seed = self.args.seed + + if self.args.type.lower() == 'reg': + voting_estimators = [ + ('rf', RandomForestRegressor(n_estimators=100, random_state=seed, n_jobs=1)), + ('gb', GradientBoostingRegressor(random_state=seed)), + ('nn', MLPRegressor(hidden_layer_sizes=(50,), max_iter=500, solver='lbfgs', random_state=seed)), + ] + loaded_model = VotingRegressor(estimators=voting_estimators, weights=weights) + else: + voting_estimators = [ + ('rf', RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=1)), + ('gb', GradientBoostingClassifier(random_state=seed)), + ('nn', MLPClassifier(hidden_layer_sizes=(50,), max_iter=500, solver='lbfgs', random_state=seed)), + ] + loaded_model = VotingClassifier(estimators=voting_estimators, weights=weights) return loaded_model @@ -1808,6 +1966,27 @@ def _raw_data_dir_from_best_params(params_dir): return Path(os.getcwd()).joinpath("GENERATE", "Raw_data", p.name) +PARAMS_DIR_BEST_MODEL_MARK = "GENERATE/Best_model" + + +def path_generate_best_model(workdir=None, *subparts: str) -> Path: + """ + Path to ``GENERATE/Best_model`` (optionally ``/No_PFI``, ``/PFI``, ...). + + Parameters + ---------- + workdir : pathlib.Path or str, optional + Root directory (default: current working directory). + *subparts : str + Additional path segments under Best_model (e.g. ``"No_PFI"``). + """ + root = Path(workdir) if workdir is not None else Path.cwd() + p = root / "GENERATE" / "Best_model" + if subparts: + p = p.joinpath(*subparts) + return p + + def discover_top_k_model_candidates(raw_data_dir, top_k, weighting="score_weighted"): """ List up to ``top_k`` model parameter CSV paths ranked by GENERATE combined score. @@ -2122,7 +2301,7 @@ def _apply_full_refit_split_conformal(self, model_data, Xy_data, loaded_model, y hw = _conformal_abs_residual_quantile(abs_res, cov) Xy_data["conformal_half_width"] = hw - return Xy_data + return Xy_data, m_full def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False): @@ -2138,9 +2317,10 @@ def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False): y_cv_mean_train = np.asarray(Xy_data["y_pred_train"], dtype=float) if not BO_opt: - Xy_data = _apply_full_refit_split_conformal( + Xy_data, fitted_model = _apply_full_refit_split_conformal( self, model_data, Xy_data, loaded_model, y_cv_mean_train ) + Xy_data["_fitted_model"] = fitted_model # combine all the predictions from the repeated CV (metrics of the train set) y_all_list,y_pred_all_list = [],[] @@ -2488,37 +2668,49 @@ def create_heatmap(self,csv_df,suffix,path_raw): """ Graph the heatmap """ + import seaborn as sb + + with _mpl_plot_context(): + csv_df = csv_df.sort_index(ascending=False) + sb.set(font_scale=1.2, style="ticks") + _, ax = plt.subplots(figsize=(7.45, 6)) + cmap_blues_75_percent_512 = [ + mcolor.rgb2hex(c) for c in plt.cm.Blues(np.linspace(0, 0.8, 512)) + ] + csv_df.replace([np.inf, -np.inf], np.nan, inplace=True) + ax = sb.heatmap( + csv_df, + annot=True, + linewidth=1, + cmap=cmap_blues_75_percent_512, + cbar_kws={"label": f"Combined {self.args.error_type.upper()}"}, + mask=csv_df.isnull(), + ) + fontsize = 14 + ax.set_xlabel("ML Model", fontsize=fontsize) + ax.set_ylabel("", fontsize=fontsize) + ax.tick_params(axis="x", which="major", labelsize=fontsize) + ax.tick_params( + axis="y", which="both", left=False, right=False, labelleft=False + ) + title_fig = f"Heatmap ML models {suffix}" + plt.title(title_fig, y=1.04, fontsize=fontsize, fontweight="bold") + sb.despine(top=False, right=False) + name_fig = "_".join(title_fig.split()) + plt.savefig( + f"{path_raw.joinpath(name_fig)}.png", dpi=300, bbox_inches="tight" + ) - importlib.reload(plt) # needed to avoid threading issues - csv_df = csv_df.sort_index(ascending=False) - sb.set(font_scale=1.2, style='ticks') - _, ax = plt.subplots(figsize=(7.45,6)) - cmap_blues_75_percent_512 = [mcolor.rgb2hex(c) for c in plt.cm.Blues(np.linspace(0, 0.8, 512))] - # Replace inf values with NaN for proper heatmap visualization - csv_df.replace([np.inf, -np.inf], np.nan, inplace=True) - ax = sb.heatmap(csv_df, annot=True, linewidth=1, cmap=cmap_blues_75_percent_512, cbar_kws={'label': f'Combined {self.args.error_type.upper()}'}, mask=csv_df.isnull()) - fontsize = 14 - ax.set_xlabel("ML Model",fontsize=fontsize) - ax.set_ylabel("",fontsize=fontsize) - ax.tick_params(axis='x', which='major', labelsize=fontsize) - ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False) - title_fig = f'Heatmap ML models {suffix}' - plt.title(title_fig, y=1.04, fontsize = fontsize, fontweight="bold") - sb.despine(top=False, right=False) - name_fig = '_'.join(title_fig.split()) - plt.savefig(f'{path_raw.joinpath(name_fig)}.png', dpi=300, bbox_inches='tight') - - path_reduced = '/'.join(f'{path_raw}'.replace('\\','/').split('/')[-2:]) - self.args.log.write(f'\no {name_fig} succesfully created in {path_reduced}') + path_reduced = "/".join(f"{path_raw}".replace("\\", "/").split("/")[-2:]) + self.args.log.write(f"\no {name_fig} succesfully created in {path_reduced}") def graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style,csv_test=False,print_fun=True,sd_graph=False): ''' Plot regression graphs of predicted vs actual values for train, validation and test sets ''' + import seaborn as sb - # Create graph - importlib.reload(plt) # needed to avoid threading issues sb.set(style="ticks") _, ax = plt.subplots(figsize=(7.45,6)) @@ -2692,8 +2884,6 @@ def graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix,csv_test=False,pr Plot a confusion matrix with the prediction vs actual values ''' - importlib.reload(plt) # needed to avoid threading issues - # Check if we need to use original class labels for display display_labels = None if 'class_0_label' in params_dict and 'class_1_label' in params_dict: @@ -2750,19 +2940,21 @@ def graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix,csv_test=False,pr self.args.log.write(f" - Graph in: {path_reduced}") -def shap_analysis(self,Xy_data,model_data,path_n_suffix): +def shap_analysis(self, Xy_data, model_data, path_n_suffix, fitted_model=None): ''' Plots and prints the results of the SHAP analysis ''' + import shap - importlib.reload(plt) # needed to avoid threading issues - _, _ = plt.subplots(figsize=(7.45,6)) + _, _ = plt.subplots(figsize=(7.45, 6)) shap_plot_file = f'{os.path.dirname(path_n_suffix)}/SHAP_{os.path.basename(path_n_suffix)}.png' - # load and fit the ML model - loaded_model = load_model(self, model_data['model'], **model_data['params']) - loaded_model.fit(Xy_data['X_train_scaled'], Xy_data['y_train']) + if fitted_model is None: + loaded_model = load_model(self, model_data["model"], **model_data["params"]) + loaded_model.fit(Xy_data["X_train_scaled"], Xy_data["y_train"]) + else: + loaded_model = fitted_model # run the SHAP analysis and save the plot explainer = shap.Explainer(loaded_model.predict, Xy_data['X_train_scaled'], seed=model_data['seed']) @@ -2813,17 +3005,17 @@ def shap_analysis(self,Xy_data,model_data,path_n_suffix): plt.savefig(f'{shap_plot_file}', dpi=300, bbox_inches='tight') -def PFI_plot(self,Xy_data,model_data,path_n_suffix): +def PFI_plot(self, Xy_data, model_data, path_n_suffix, fitted_model=None): ''' Plots and prints the results of the PFI analysis ''' - - importlib.reload(plt) # needed to avoid threading issues pfi_plot_file = f'{os.path.dirname(path_n_suffix)}/PFI_{os.path.basename(path_n_suffix)}.png' - # load and fit the ML model - loaded_model = load_model(self, model_data['model'], **model_data['params']) - loaded_model.fit(Xy_data['X_train_scaled'], Xy_data['y_train']) + if fitted_model is None: + loaded_model = load_model(self, model_data["model"], **model_data["params"]) + loaded_model.fit(Xy_data["X_train_scaled"], Xy_data["y_train"]) + else: + loaded_model = fitted_model # select scoring function for PFI analysis based on the error type scoring, _, error_type = scoring_n_score(self,model_data,Xy_data,loaded_model) @@ -2868,8 +3060,8 @@ def outlier_plot(self,Xy_data,path_n_suffix,name_points,graph_style): ''' Plots and prints the results of the outlier analysis ''' + import seaborn as sb - importlib.reload(plt) # needed to avoid threading issues # detect outliers outliers_data, print_outliers = outlier_filter(self, Xy_data, name_points) @@ -3009,9 +3201,8 @@ def distribution_plot(self,Xy_data,path_n_suffix,params_dict): ''' Plots histogram (reg) or bin plot (clas). ''' + import seaborn as sb - # make graph - importlib.reload(plt) # needed to avoid threading issues sb.set(style="ticks") _, ax = plt.subplots(figsize=(7.45,6)) @@ -3183,12 +3374,12 @@ def get_prediction_results(model_data,y,y_pred_all): def get_error_labels(model_type): """ Returns the three error metric labels for the given model type. - + Parameters ---------- model_type : str The type of model: 'reg' for regression or 'clas' for classification - + Returns ------- tuple of str @@ -3200,12 +3391,37 @@ def get_error_labels(model_type): 'reg': ('r2', 'mae', 'rmse'), 'clas': ('acc', 'f1', 'mcc') } - + model_type_lower = model_type.lower() - + return error_labels[model_type_lower] +def _select_descriptors(self, df, descriptors, module): + """Subset *df* to model descriptors, applying categorical_transform if needed.""" + try: + return df[descriptors] + except KeyError: + try: + self.args.log.write( + "\n x There are missing descriptors in the test set! " + "Looking for categorical variables converted from CURATE" + ) + df = categorical_transform(self, df, module) + out = df[descriptors] + self.args.log.write( + " o The missing descriptors were successfully created" + ) + return out + except KeyError: + self.args.log.write( + " x There are still missing descriptors in the test set! " + f"The following descriptors are needed: {descriptors}" + ) + self.args.log.finalize() + sys.exit() + + def load_db_n_params(self,params_dir,suffix,suffix_title,module,print_load): ''' Loads the parameters and Xy databases from a folder, add scaled X data and print information @@ -3220,25 +3436,15 @@ def load_db_n_params(self,params_dir,suffix,suffix_title,module,print_load): csv_X = csv_X.drop(columns=['Set']) # keep only the descriptors used in the model - csv_X = csv_X[model_data['X_descriptors']] + csv_X = _select_descriptors(self, csv_X, model_data["X_descriptors"], module) # load and adjust external set (if any) csv_external_df, csv_X_external,csv_y_external = None,None,None if self.args.csv_test != '': csv_external_df,csv_X_external,csv_y_external = load_database(self,self.args.csv_test,'predict',external_test=True) - try: - csv_X_external = csv_X_external[model_data['X_descriptors']] - except KeyError: - # this might fail if the initial categorical variables have not been transformed - try: - self.args.log.write(f"\n x There are missing descriptors in the test set! Looking for categorical variables converted from CURATE") - csv_X_external = categorical_transform(self,csv_X_external,'predict') - csv_X_external = csv_X_external[model_data['X_descriptors']] - self.args.log.write(f" o The missing descriptors were successfully created") - except KeyError: - self.args.log.write(f" x There are still missing descriptors in the test set! The following descriptors are needed: {model_data['X_descriptors']}") - self.args.log.finalize() - sys.exit() + csv_X_external = _select_descriptors( + self, csv_X_external, model_data["X_descriptors"], "predict" + ) # split tests Xy_data = prepare_sets(self,csv_df,csv_X,csv_y,test_points,model_data['names'],csv_external_df,csv_X_external,csv_y_external,BO_opt=False) @@ -3380,9 +3586,9 @@ def pearson_map(self,csv_df_pearson,module,params_dir=None): ''' Creates Pearson heatmap ''' + import seaborn as sb - importlib.reload(plt) # needed to avoid threading issues - if module.lower() == 'curate': # only represent the final descriptors in CURATE + if module.lower() == "curate": # only represent the final descriptors in CURATE csv_df_pearson = csv_df_pearson.drop([self.args.y] + self.args.ignore, axis=1) corr_matrix = csv_df_pearson.corr() @@ -3457,11 +3663,12 @@ def plot_metrics(model_data,suffix_title,verify_metrics,verify_results): ''' Creates a plot with the results of the flawed models in VERIFY ''' + import seaborn as sb - importlib.reload(plt) # needed to avoid threading issues + importlib.reload(plt) sb.reset_defaults() sb.set(style="ticks") - _, ax = plt.subplots(figsize=(7.45,6)) + _, ax = plt.subplots(figsize=(7.45, 6)) # define names csv_name = os.path.basename(model_data['model']).split('_db.csv')[0] diff --git a/setup.py b/setup.py index 7381922..ecfaff1 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ "seaborn==0.13.2", "scipy==1.15.0", "scikit-learn==1.7.2", + "xgboost==2.1.4", "bayesian-optimization==3.1.0", "numba==0.62.1", "shap==0.49.1", diff --git a/tests/test_1curate.py b/tests/test_1curate.py index 374a9e0..6b7b56c 100644 --- a/tests/test_1curate.py +++ b/tests/test_1curate.py @@ -36,6 +36,8 @@ "missing_input" ), # test that if the --names, --y or --csv_name options are empty, a prompt pops up and asks for them ("rfecv"), # test for the RFECV feature, default + ("xgb"), # test XGB model-specific CURATE output + ("vr"), # test VR model-specific CURATE output ("standard"), # standard test ("standard_cmd"), # standard test through command line ], @@ -442,3 +444,27 @@ def _check_standard(): accepted_vars = ["V_Bur", "dist", "rando1", "rando2", "rando3", "rando4"] for var in accepted_vars: assert var in db_final.columns + + elif test_job == "xgb": + curate_kwargs["model"] = ["XGB"] + _ = curate(**curate_kwargs) + + db_xgb = pd.read_csv(f"{path_curate}/Robert_example_CURATE_XGB.csv") + assert len(db_xgb["Name"]) == 37 + assert "Target_values" in db_xgb.columns + assert "Name" in db_xgb.columns + assert "xtest" not in db_xgb.columns + n_descps = len(db_xgb.columns) - 2 + assert n_descps < (len(db_xgb) / 3) + + elif test_job == "vr": + curate_kwargs["model"] = ["VR"] + _ = curate(**curate_kwargs) + + db_vr = pd.read_csv(f"{path_curate}/Robert_example_CURATE_VR.csv") + assert len(db_vr["Name"]) == 37 + assert "Target_values" in db_vr.columns + assert "Name" in db_vr.columns + assert "xtest" not in db_vr.columns + n_descps = len(db_vr.columns) - 2 + assert n_descps < (len(db_vr) / 3) From b635c3c5179a445a6c5d19118b098d8987221ab3 Mon Sep 17 00:00:00 2001 From: rlaplaza Date: Sat, 16 May 2026 17:54:20 +0200 Subject: [PATCH 4/8] feat(plots): add plot_verbosity and predict_diagnostics Gate matplotlib output by plot_verbosity (0-2) and predict_diagnostics, share Best_model iteration via iter_best_model_dirs, and reuse fitted models for SHAP/PFI diagnostics in PREDICT. Co-authored-by: Cursor --- robert/api.py | 35 ++++--- robert/aqme.py | 53 ++++++----- robert/argument_parser.py | 5 + robert/evaluate.py | 4 +- robert/predict.py | 86 +++++++++++------- robert/predict_utils.py | 186 ++++++++++++++++++++++---------------- robert/report.py | 2 +- robert/report_utils.py | 5 +- robert/uq_auto.py | 2 +- robert/verify.py | 43 +++------ 10 files changed, 242 insertions(+), 179 deletions(-) diff --git a/robert/api.py b/robert/api.py index f1191db..7b149cf 100644 --- a/robert/api.py +++ b/robert/api.py @@ -26,12 +26,6 @@ from sklearn.metrics import accuracy_score, r2_score from robert.argument_parser import options_add, var_dict -from robert.curate import curate -from robert.generate import generate -from robert.predict import predict as predict_module -from robert.report import report as report_module -from robert.utils import load_params -from robert.verify import verify _NAME_COL = "__robert_name__" _DEFAULT_Y = "__robert_y__" @@ -175,14 +169,19 @@ class RobertModel(BaseEstimator): :param y_column: If ``fit(X)`` is called with ``y is None``, name of the target column in ``X`` (DataFrame only). :param kwargs: Additional ROBERT options (keys in ``robert.argument_parser.var_dict``), - e.g. ``model``, ``n_iter``. Regression uncertainty tuning includes + e.g. ``model``, ``n_iter``, ``plot_verbosity`` (``0``–``2``): ``0`` skips + matplotlib artifacts; ``1`` keeps CURATE/GENERATE/VERIFY summary plots and main + PREDICT result plots when ``predict_diagnostics`` is True; ``2`` additionally + enables SHAP/PFI/Pearson/outlier/distribution diagnostics when + ``predict_diagnostics`` is True. Regression uncertainty tuning includes ``conformal_enable``, ``conformal_calib_frac``, and ``conformal_coverage``. Top-k meta-model uncertainty (opt-in) uses ``uq_enable_meta``, ``uq_top_k_models``, and ``uq_model_weighting`` (``"score_weighted"`` or ``"uniform"``). Auto uncertainty (regression, opt-in) uses ``uq_auto_enable``, ``uq_auto_candidates``, ``uq_auto_scaler``, ``uq_auto_metric_weights``, - ``uq_auto_min_samples``, and ``uq_auto_random_state``; ``return_uncertainty`` - ``"auto"`` or ``"auto_decomposed"`` enables auto mode for that predict call. + ``uq_auto_min_samples``, ``uq_auto_random_state``, and ``uq_auto_clas_mode`` + (``"error"`` by default); ``return_uncertainty`` ``"auto"`` or + ``"auto_decomposed"`` enables auto mode for that predict call. """ def __init__( @@ -395,7 +394,9 @@ def _build_train_frame(self, X_df: pd.DataFrame, y_series: pd.Series) -> pd.Data def _read_model_snapshot(self, workdir: Path) -> dict[str, Any]: sub = "PFI" if self.filter_mode == "pfi" else "No_PFI" - folder = workdir / "GENERATE" / "Best_model" / sub + from robert.utils import load_params, path_generate_best_model + + folder = path_generate_best_model(workdir, sub) params_path = _find_params_csv(folder) seed = int(self._rob_kwargs.get("seed", var_dict["seed"])) adapter = _ParamsAdapter(seed, self.problem_type) @@ -423,6 +424,13 @@ def fit( base["command_line"] = False base["csv_test"] = "" + from robert.curate import curate + from robert.generate import generate + from robert.predict import predict as predict_module + from robert.report import report as report_module + from robert.verify import verify + from robert.utils import path_generate_best_model + with _noninteractive_mpl(), _chdir(workdir): curate( csv_name=train_rel, @@ -441,8 +449,8 @@ def fit( if self.run_report: report_module(**base) - best_sub = workdir / "GENERATE" / "Best_model" / ( - "PFI" if self.filter_mode == "pfi" else "No_PFI" + best_sub = path_generate_best_model( + workdir, "PFI" if self.filter_mode == "pfi" else "No_PFI" ) if self.filter_mode == "pfi" and not best_sub.is_dir(): raise RuntimeError( @@ -588,6 +596,7 @@ def predict( base["csv_test"] = pred_name base["params_dir"] = "GENERATE/Best_model" base["names"] = self.names_col_ + base["predict_diagnostics"] = False if umode in ("auto", "auto_decomposed"): if self.problem_type != "reg": raise ValueError( @@ -596,6 +605,8 @@ def predict( ) base["uq_auto_enable"] = True + from robert.predict import predict as predict_module + with _noninteractive_mpl(), _chdir(workdir): predict_module(**base) diff --git a/robert/aqme.py b/robert/aqme.py index 0dd67ba..e385636 100644 --- a/robert/aqme.py +++ b/robert/aqme.py @@ -149,7 +149,6 @@ def run_csearch_qdescp(self,csv_target,aqme_test=False): order = csv_temp['code_name'].tolist() # Sort the rows in 'AQME-ROBERT_{aqme_indv_name}.csv' based on the order - df_temp = pd.read_csv(f'AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv', encoding='utf-8') df_temp = df_temp.sort_values(by='code_name', key=lambda x: x.map({v: i for i, v in enumerate(order)})) # Fill missing values with corresponding SMILES row @@ -188,13 +187,12 @@ def run_csearch_qdescp(self,csv_target,aqme_test=False): self.args.log.write(f"\nx The initial AQME descriptor protocol did not create any CSV output!") sys.exit() - # remove atomic properties if no SMARTS patterns were selected in qdescp + # remove atomic properties if no SMARTS patterns were selected in qdescp, + # and drop AQME argument columns from CSV inputs (single read/write) if 'qdescp_atoms' not in self.args.qdescp_keywords: - _ = filter_atom_prop(aqme_db,csv_df) - - # remove arguments from CSV inputs in AQME - _ = filter_aqme_args(aqme_db) - + _ = filter_atom_prop_and_aqme_args(aqme_db, csv_df, strip_atom_lists=True) + else: + _ = filter_atom_prop_and_aqme_args(aqme_db, csv_df, strip_atom_lists=False) # delete AQME_indiv*.csv files for file in glob.glob('*QME_indiv*.csv'): os.remove(file) @@ -226,34 +224,41 @@ def init_aqme(self): sys.exit() +def filter_atom_prop_and_aqme_args(aqme_db, csv_df, *, strip_atom_lists): + """ + Drop atomic list descriptors when no --qdescp_atoms was used, and remove + columns that duplicate AQME CSV inputs (single pass over the dataframe). + """ + aqme_df = pd.read_csv(aqme_db, encoding='utf-8') + if strip_atom_lists: + for column in list(aqme_df.columns): + if column == 'DBSTEP_Vbur': + aqme_df = aqme_df.drop(column, axis=1) + # remove lists of atomic properties (skip columns from AQME arguments) + elif aqme_df[column].dtype == object and column.lower() not in aqme_args: + first_cell = aqme_df[column].iloc[0] if len(aqme_df) else None + if first_cell is not None and '[' in str(first_cell) and column not in csv_df.columns: + aqme_df = aqme_df.drop(column, axis=1) + for column in list(aqme_df.columns): + if column.lower() in aqme_args: + aqme_df = aqme_df.drop(column, axis=1) + os.remove(aqme_db) + aqme_df.to_csv(f'{aqme_db}', index=None, header=True) + + def filter_atom_prop(aqme_db, csv_df): ''' Function that filters off atomic properties if no atom was selected in the --qdescp_atoms option ''' - aqme_df = pd.read_csv(aqme_db, encoding='utf-8') - for column in aqme_df.columns: - if column == 'DBSTEP_Vbur': - aqme_df = aqme_df.drop(column, axis=1) - # remove lists of atomic properties (skip columns from AQME arguments) - elif aqme_df[column].dtype == object and column.lower() not in aqme_args: - if '[' in aqme_df[column][0] and column not in csv_df.columns: - aqme_df = aqme_df.drop(column, axis=1) - os.remove(aqme_db) - _ = aqme_df.to_csv(f'{aqme_db}', index=None, header=True) + filter_atom_prop_and_aqme_args(aqme_db, csv_df, strip_atom_lists=True) def filter_aqme_args(aqme_db): ''' Function that filters off AQME arguments in CSV inputs ''' - - aqme_df = pd.read_csv(aqme_db, encoding='utf-8') - for column in aqme_df.columns: - if column.lower() in aqme_args: - aqme_df = aqme_df.drop(column, axis=1) - os.remove(aqme_db) - _ = aqme_df.to_csv(f'{aqme_db}', index = None, header=True) + filter_atom_prop_and_aqme_args(aqme_db, pd.DataFrame(), strip_atom_lists=False) def move_aqme(): diff --git a/robert/argument_parser.py b/robert/argument_parser.py index 5e54f8e..ef25c38 100644 --- a/robert/argument_parser.py +++ b/robert/argument_parser.py @@ -77,6 +77,11 @@ "uq_auto_min_samples": 12, "uq_auto_random_state": 0, "uq_auto_clas_mode": "error", + # When False, skip SHAP/PFI/heatmap/outlier/distribution plots (PREDICT only). + "predict_diagnostics": True, + # Plot verbosity: 0 = no figures, 1 = workflow summaries (CURATE/GENERATE/VERIFY + + # main PREDICT result plots when predict_diagnostics is True), 2 = full diagnostics. + "plot_verbosity": 2, } diff --git a/robert/evaluate.py b/robert/evaluate.py index 478dfa1..c298c52 100644 --- a/robert/evaluate.py +++ b/robert/evaluate.py @@ -43,7 +43,7 @@ import time import pandas as pd from pathlib import Path -from robert.utils import load_variables, finish_print, load_database, prepare_sets +from robert.utils import load_variables, finish_print, load_database, prepare_sets, path_generate_best_model from robert.generate_utils import set_sets @@ -95,7 +95,7 @@ def save_generate(self,csv_df,Xy_data): ''' # copy database with Set column - generate_folder = Path('GENERATE/Best_model/No_PFI') + generate_folder = path_generate_best_model(None, "No_PFI") if os.path.exists(generate_folder): shutil.rmtree(generate_folder) Path(generate_folder).mkdir(exist_ok=True, parents=True) diff --git a/robert/predict.py b/robert/predict.py index 4d6af92..f2a17d0 100644 --- a/robert/predict.py +++ b/robert/predict.py @@ -37,11 +37,13 @@ import os import time -from robert.predict_utils import (plot_predictions, +from robert.predict_utils import ( + iter_best_model_dirs, + plot_predictions, save_predictions, print_predict, - pearson_map_predict - ) + pearson_map_predict, +) from robert.uq_auto import apply_auto_uq from robert.utils import ( load_variables, @@ -54,6 +56,9 @@ shap_analysis, outlier_plot, distribution_plot, + _mpl_plot_context, + should_plot_predict_results, + should_plot_predict_deep_diagnostics, ) class predict: @@ -73,16 +78,12 @@ def __init__(self, **kwargs): # load default and user-specified variables self.args = load_variables(kwargs, "predict") - # if params_dir = '', the program performs the tests for the No_PFI and PFI folders - if 'GENERATE/Best_model' in self.args.params_dir: - params_dirs = [f'{self.args.params_dir}/No_PFI',f'{self.args.params_dir}/PFI'] - suffixes = ['(with no PFI filter)','(with PFI filter)'] - suffix_titles = ['No_PFI','PFI'] - else: - params_dirs = [self.args.params_dir] - suffix = ['custom'] + run_prediction_plots = should_plot_predict_results(self.args) + run_deep_diagnostics = should_plot_predict_deep_diagnostics(self.args) - for (params_dir,suffix,suffix_title) in zip(params_dirs,suffixes,suffix_titles): + for params_dir, suffix, suffix_title in iter_best_model_dirs( + self.args.params_dir + ): if os.path.exists(params_dir): _ = print_pfi(self,params_dir) @@ -102,28 +103,47 @@ def __init__(self, **kwargs): ) # save predictions for all sets - path_n_suffix, name_points, Xy_data = save_predictions(self,Xy_data,model_data,suffix_title) - - # represent y vs predicted y - colors = plot_predictions(self,model_data,Xy_data,path_n_suffix) + path_n_suffix, name_points, Xy_data = save_predictions( + self, Xy_data, model_data, suffix_title + ) # print results - _ = print_predict(self,Xy_data,model_data,suffix_title) - - # SHAP analysis - _ = shap_analysis(self,Xy_data,model_data,path_n_suffix) - - # PFI analysis - _ = PFI_plot(self,Xy_data,model_data,path_n_suffix) - - # create Pearson heatmap - _ = pearson_map_predict(self,Xy_data,params_dir) - - # Outlier analysis - if model_data['type'].lower() == 'reg': - _ = outlier_plot(self,Xy_data,path_n_suffix,name_points,colors) - - # y distribution - _ = distribution_plot(self,Xy_data,path_n_suffix,model_data) + _ = print_predict(self, Xy_data, model_data, suffix_title) + + if run_prediction_plots or run_deep_diagnostics: + fitted = Xy_data.get("_fitted_model") + with _mpl_plot_context(): + colors = None + if run_prediction_plots: + colors = plot_predictions( + self, model_data, Xy_data, path_n_suffix + ) + if run_deep_diagnostics: + _ = shap_analysis( + self, + Xy_data, + model_data, + path_n_suffix, + fitted_model=fitted, + ) + _ = PFI_plot( + self, + Xy_data, + model_data, + path_n_suffix, + fitted_model=fitted, + ) + _ = pearson_map_predict(self, Xy_data, params_dir) + if model_data["type"].lower() == "reg": + _ = outlier_plot( + self, + Xy_data, + path_n_suffix, + name_points, + colors, + ) + _ = distribution_plot( + self, Xy_data, path_n_suffix, model_data + ) _ = finish_print(self,start_time,'PREDICT') diff --git a/robert/predict_utils.py b/robert/predict_utils.py index 80b4636..83e66b3 100644 --- a/robert/predict_utils.py +++ b/robert/predict_utils.py @@ -3,46 +3,36 @@ #####################################################. import os -import sys from pathlib import Path import pandas as pd import numpy as np from robert.utils import ( - categorical_transform, get_graph_style, pearson_map, graph_reg, graph_clas, get_error_labels, - ) + PARAMS_DIR_BEST_MODEL_MARK, +) -def test_csv(self,Xy_test_df,descs_model,params_df): +def iter_best_model_dirs(params_dir): """ - Separates the test databases into X and y. This allows to merge test external databases that - contain different columns with internal test databases coming from GENERATE + Yield (params_dir, suffix, suffix_title) for default Best_model layout or a + single custom folder. """ - - y_test_df = pd.DataFrame() - - try: - X_test_df = Xy_test_df[descs_model] - except KeyError: - # this might fail if the initial categorical variables have not been transformed - try: - self.args.log.write(f"\n x There are missing descriptors in the test set! Looking for categorical variables converted from CURATE") - Xy_test_df = categorical_transform(self,Xy_test_df,'predict') - X_test_df = Xy_test_df[descs_model] - self.args.log.write(f" o The missing descriptors were successfully created") - except KeyError: - self.args.log.write(f" x There are still missing descriptors in the test set! The following descriptors are needed: {descs_model}") - self.args.log.finalize() - sys.exit() - - if params_df['y'][0] in Xy_test_df: - y_test_df = Xy_test_df[params_df['y'][0]] - - return X_test_df, y_test_df + if PARAMS_DIR_BEST_MODEL_MARK in str(params_dir).replace("\\", "/"): + params_dirs = [ + f"{params_dir}/No_PFI", + f"{params_dir}/PFI", + ] + suffixes = ["(with no PFI filter)", "(with PFI filter)"] + suffix_titles = ["No_PFI", "PFI"] + else: + params_dirs = [params_dir] + suffixes = ["(custom)"] + suffix_titles = ["custom"] + return zip(params_dirs, suffixes, suffix_titles) def _uq_columns_for_split(Xy_data, y_col, split): @@ -67,6 +57,47 @@ def _uq_auto_source_column(Xy_data): return str(selected) +def _reconvert_values(values, reconvert_labels, class_mapping_reverse): + if not reconvert_labels: + return values + return [class_mapping_reverse[int(y)] for y in values] + + +def _append_split_columns( + df, + Xy_data, + model_data, + split, + *, + reconvert_labels, + class_mapping_reverse, + hw_scalar, + auto_src, +): + """Add y, predictions, SD, UQ, and conformal columns for one split.""" + y_col = model_data["y"] + y_key = f"y_{split}" + pred_key = f"y_pred_{split}" + sd_key = f"y_pred_{split}_sd" + + y_values = _reconvert_values( + Xy_data[y_key].tolist(), reconvert_labels, class_mapping_reverse + ) + pred_values = _reconvert_values( + Xy_data[pred_key], reconvert_labels, class_mapping_reverse + ) + + df[y_col] = y_values + df[f"{y_col}_pred"] = pred_values + df[f"{y_col}_pred_sd"] = Xy_data[sd_key] + for col_name, col_vals in _uq_columns_for_split(Xy_data, y_col, split).items(): + df[col_name] = col_vals + df[f"{y_col}_pred_conformal_hw"] = [hw_scalar] * len(df) + if auto_src is not None: + df[f"{y_col}_pred_uq_auto_source"] = [auto_src] * len(df) + return df + + def plot_predictions(self, params_dict, Xy_data, path_n_suffix): ''' Plot graphs of predicted vs actual values for train, validation and test sets @@ -122,41 +153,31 @@ def save_predictions(self,Xy_data,model_data,suffix_title): # Store y values and predictions, reconverting if needed y_col = model_data['y'] - # For training set - y_train_values = Xy_data['y_train'].tolist() - y_pred_train_values = Xy_data['y_pred_train'] - if reconvert_labels: - y_train_values = [class_mapping_reverse[int(y)] for y in y_train_values] - y_pred_train_values = [class_mapping_reverse[int(y)] for y in y_pred_train_values] - - Xy_train[y_col] = y_train_values - Xy_train[f"{y_col}_pred"] = y_pred_train_values - Xy_train[f"{y_col}_pred_sd"] = Xy_data['y_pred_train_sd'] - for col_name, col_vals in _uq_columns_for_split(Xy_data, y_col, "train").items(): - Xy_train[col_name] = col_vals hw_scalar = float(Xy_data.get("conformal_half_width", float("nan"))) if model_data["type"].lower() != "reg": hw_scalar = float("nan") - Xy_train[f"{y_col}_pred_conformal_hw"] = [hw_scalar] * len(Xy_train) auto_src = _uq_auto_source_column(Xy_data) - if auto_src is not None: - Xy_train[f"{y_col}_pred_uq_auto_source"] = [auto_src] * len(Xy_train) - - # For test set - y_test_values = Xy_data['y_test'].tolist() - y_pred_test_values = Xy_data['y_pred_test'] - if reconvert_labels: - y_test_values = [class_mapping_reverse[int(y)] for y in y_test_values] - y_pred_test_values = [class_mapping_reverse[int(y)] for y in y_pred_test_values] - - Xy_test[y_col] = y_test_values - Xy_test[f"{y_col}_pred"] = y_pred_test_values - Xy_test[f"{y_col}_pred_sd"] = Xy_data['y_pred_test_sd'] - for col_name, col_vals in _uq_columns_for_split(Xy_data, y_col, "test").items(): - Xy_test[col_name] = col_vals - Xy_test[f"{y_col}_pred_conformal_hw"] = [hw_scalar] * len(Xy_test) - if auto_src is not None: - Xy_test[f"{y_col}_pred_uq_auto_source"] = [auto_src] * len(Xy_test) + + Xy_train = _append_split_columns( + Xy_train, + Xy_data, + model_data, + "train", + reconvert_labels=reconvert_labels, + class_mapping_reverse=class_mapping_reverse, + hw_scalar=hw_scalar, + auto_src=auto_src, + ) + Xy_test = _append_split_columns( + Xy_test, + Xy_data, + model_data, + "test", + reconvert_labels=reconvert_labels, + class_mapping_reverse=class_mapping_reverse, + hw_scalar=hw_scalar, + auto_src=auto_src, + ) df_results = pd.concat([Xy_train, Xy_test], axis=0) @@ -196,28 +217,29 @@ def save_predictions(self,Xy_data,model_data,suffix_title): # saves prediction for external test in --csv_test Xy_external = pd.DataFrame(Xy_data['names_external']) - for col in Xy_data['X_external']: - Xy_external[col] = Xy_data['X_external'][col].tolist() - Xy_external[col] = Xy_data['X_external'][col].tolist() - - # Reconvert external set labels if needed - if 'y_external' in Xy_data: - y_external_values = Xy_data['y_external'].tolist() - if reconvert_labels: - y_external_values = [class_mapping_reverse[int(y)] for y in y_external_values] - Xy_external[model_data['y']] = y_external_values - - y_pred_external_values = Xy_data['y_pred_external'] - if reconvert_labels: - y_pred_external_values = [class_mapping_reverse[int(y)] for y in y_pred_external_values] - - Xy_external[f"{model_data['y']}_pred"] = y_pred_external_values - Xy_external[f"{model_data['y']}_pred_sd"] = Xy_data['y_pred_external_sd'] + for col in Xy_data["X_external"]: + Xy_external[col] = Xy_data["X_external"][col].tolist() + + if "y_external" in Xy_data: + y_external_values = _reconvert_values( + Xy_data["y_external"].tolist(), + reconvert_labels, + class_mapping_reverse, + ) + Xy_external[model_data["y"]] = y_external_values + + pred_values = _reconvert_values( + Xy_data["y_pred_external"], reconvert_labels, class_mapping_reverse + ) + Xy_external[f"{model_data['y']}_pred"] = pred_values + Xy_external[f"{model_data['y']}_pred_sd"] = Xy_data["y_pred_external_sd"] for col_name, col_vals in _uq_columns_for_split( Xy_data, model_data["y"], "external" ).items(): Xy_external[col_name] = col_vals - Xy_external[f"{model_data['y']}_pred_conformal_hw"] = [hw_scalar] * len(Xy_external) + Xy_external[f"{model_data['y']}_pred_conformal_hw"] = [hw_scalar] * len( + Xy_external + ) if auto_src is not None: Xy_external[f"{model_data['y']}_pred_uq_auto_source"] = [ auto_src @@ -247,10 +269,22 @@ def save_predictions(self,Xy_data,model_data,suffix_title): return path_n_suffix, name_points, Xy_data +def _ensure_pred_range_stats(Xy_data): + """Set y-range summary keys when diagnostic plots were skipped.""" + if "pred_min" in Xy_data: + return + pred_min = min(min(Xy_data["y_train"]), min(Xy_data["y_test"])) + pred_max = max(max(Xy_data["y_train"]), max(Xy_data["y_test"])) + Xy_data["pred_min"] = pred_min + Xy_data["pred_max"] = pred_max + Xy_data["pred_range"] = float(np.abs(pred_max - pred_min)) + + def print_predict(self,Xy_data,model_data,suffix_title): ''' Prints results of the predictions for all the sets ''' + _ensure_pred_range_stats(Xy_data) print_results = ( "\n o Summary of results " diff --git a/robert/report.py b/robert/report.py index 8a167c0..a7431f0 100644 --- a/robert/report.py +++ b/robert/report.py @@ -6,7 +6,7 @@ Directory to create the output file(s). varfile : str, default=None Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). - report_modules : list of str, default=['CURATE','GENERATE','VERIFY','PREDICT'] + report_modules : list of str, default=['AQME','CURATE','GENERATE','VERIFY','PREDICT'] List of the modules to include in the report. debug_report : bool, default=False Debug mode using during the pytests of report.py diff --git a/robert/report_utils.py b/robert/report_utils.py index 11585be..5899127 100644 --- a/robert/report_utils.py +++ b/robert/report_utils.py @@ -698,7 +698,8 @@ def get_col_text(type_thres): 'CV: cross-validation', 'F1 score: balanced F-score', 'GB: gradient boosting', - 'GP: gaussian process' + 'GP: gaussian process', + 'XGB: extreme gradient boosting' ] elif type_thres == 'abbrev_2': @@ -755,9 +756,11 @@ def get_col_transpa(params_dict,suffix,section,spacing): models_dict = {'RF': f'RandomForest{model_type}', 'MVL': 'LinearRegression', 'GB': f'GradientBoosting{model_type}', + 'XGB': f'XGB{model_type}', 'NN': f'MLP{model_type}', 'GP': f'GaussianProcess{model_type}', 'ADAB': f'AdaBoost{model_type}', + 'VR': f'Voting{model_type}', } col_info,sklearn_model = '','' diff --git a/robert/uq_auto.py b/robert/uq_auto.py index 6bc6349..0a0ea99 100644 --- a/robert/uq_auto.py +++ b/robert/uq_auto.py @@ -418,7 +418,7 @@ def apply_auto_uq( return Xy_data -def get_bo_ready_prediction_bundle( +def _get_bo_ready_prediction_bundle( Xy_data: Mapping[str, Any], split: str, y_col: str, diff --git a/robert/verify.py b/robert/verify.py index ac07ada..9ee9440 100644 --- a/robert/verify.py +++ b/robert/verify.py @@ -25,13 +25,16 @@ import time import numpy as np from statistics import mode -from robert.utils import (load_variables, +from robert.predict_utils import iter_best_model_dirs +from robert.utils import ( + load_variables, load_db_n_params, load_n_predict, finish_print, get_prediction_results, print_pfi, - plot_metrics + plot_metrics, + should_plot_verify_metrics, ) @@ -56,16 +59,9 @@ def __init__(self, **kwargs): # load default and user-specified variables self.args = load_variables(kwargs, "verify") - # if params_dir = '', the program performs the tests for the No_PFI and PFI folders - if 'GENERATE/Best_model' in self.args.params_dir: - params_dirs = [f'{self.args.params_dir}/No_PFI',f'{self.args.params_dir}/PFI'] - suffixes = ['(with no PFI filter)','(with PFI filter)'] - suffix_titles = ['No_PFI','PFI'] - else: - params_dirs = [self.args.params_dir] - suffix = ['custom'] - - for (params_dir,suffix,suffix_title) in zip(params_dirs,suffixes,suffix_titles): + for params_dir, suffix, suffix_title in iter_best_model_dirs( + self.args.params_dir + ): if os.path.exists(params_dir): _ = print_pfi(self,params_dir) @@ -89,21 +85,15 @@ def __init__(self, **kwargs): verify_results[f'f1_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'f1_train_sorted_CV']] verify_results[f'mcc_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'mcc_train_sorted_CV']] - # load the Xy databse and model parameters + # Reload once for flawed-model tests (fresh splits consistent with CSV on disk). Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",False) # calculate scores for the y-mean test verify_results = self.ymean_test(verify_results,Xy_data,model_data) - # load the Xy databse and model parameters - Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",False) - # calculate scores for the y-shuffle test verify_results = self.yshuffle_test(verify_results,Xy_data,model_data) - # load the Xy databse and model parameters - Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",False) - # one-hot test (check that if a value isnt 0, the value assigned is 1) verify_results = self.onehot_test(verify_results,Xy_data,model_data) @@ -111,7 +101,10 @@ def __init__(self, **kwargs): results_print,verify_results,verify_metrics = self.analyze_tests(verify_results) # plot a bar graph with the results - print_ver = plot_metrics(model_data,suffix_title,verify_metrics,verify_results) + if should_plot_verify_metrics(self.args): + print_ver = plot_metrics(model_data,suffix_title,verify_metrics,verify_results) + else: + print_ver = "\n o VERIFY plot skipped (plot_verbosity)" # print and save results _ = self.print_verify(results_print,verify_results,print_ver,model_data) @@ -162,15 +155,7 @@ def onehot_test(self,verify_results,Xy_data,model_data): ''' Xy_onehot = Xy_data.copy() - for desc in Xy_onehot['X_train']: - new_vals = [] - for val in Xy_onehot['X_train'][desc]: - if val == 0: - new_vals.append(0) - else: - new_vals.append(1) - Xy_onehot['X_train_scaled'][desc] = new_vals - + Xy_onehot['X_train_scaled'] = Xy_onehot['X_train_scaled'].copy() for desc in Xy_onehot['X_train']: new_vals = [] for val in Xy_onehot['X_train'][desc]: From d094c51f6d8bf1b4c81291c893bc60440cf7fbe1 Mon Sep 17 00:00:00 2001 From: rlaplaza Date: Sat, 16 May 2026 17:54:25 +0200 Subject: [PATCH 5/8] test: consolidate UQ tests and extend API coverage Merge UQ auto/meta coverage into test_8uq.py, add fast_robert_kwargs fixture, custom predict fixtures, and API/GENERATE test updates. Co-authored-by: Cursor --- tests/conftest.py | 16 ++ tests/fixtures/custom_predict_model/RF.csv | 2 + tests/fixtures/custom_predict_model/RF_db.csv | 38 ++++ tests/test_2generate.py | 50 +++-- tests/test_7api.py | 178 +++++++++++++++--- tests/{test_uq_auto.py => test_8uq.py} | 108 +++++++++-- tests/test_uq_meta.py | 92 --------- 7 files changed, 336 insertions(+), 148 deletions(-) create mode 100644 tests/fixtures/custom_predict_model/RF.csv create mode 100644 tests/fixtures/custom_predict_model/RF_db.csv rename tests/{test_uq_auto.py => test_8uq.py} (50%) delete mode 100644 tests/test_uq_meta.py diff --git a/tests/conftest.py b/tests/conftest.py index 78f36f1..34aeb51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ import os import sys +import pytest + def pytest_configure(config): """ @@ -24,3 +26,17 @@ def pytest_configure(config): os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") os.environ.setdefault("MPLBACKEND", "Agg") + + +@pytest.fixture +def fast_robert_kwargs(): + """Reduced CV/BO settings for faster integration tests.""" + return { + "model": ["RF"], + "n_iter": 2, + "init_points": 2, + "repeat_kfolds": 2, + "kfold": 3, + "pfi_epochs": 1, + "seed": 42, + } diff --git a/tests/fixtures/custom_predict_model/RF.csv b/tests/fixtures/custom_predict_model/RF.csv new file mode 100644 index 0000000..8da016b --- /dev/null +++ b/tests/fixtures/custom_predict_model/RF.csv @@ -0,0 +1,2 @@ +model,type,kfold,repeat_kfolds,seed,error_type,y,names,X_descriptors,params,combined_rmse,split +RF,reg,5,10,0,rmse,Target_values,Name,"[""x10"", ""x2"", ""x5"", ""x7"", ""x9""]","{""n_estimators"": 59, ""max_depth"": 16, ""min_samples_split"": 7, ""min_samples_leaf"": 4, ""min_weight_fraction_leaf"": 0.021182739966945238, ""max_features"": 0.7344205847999921, ""ccp_alpha"": 0.004375872112626925, ""max_samples"": 0.9188297505865598, ""random_state"": 0}",0.5518457913196202,even diff --git a/tests/fixtures/custom_predict_model/RF_db.csv b/tests/fixtures/custom_predict_model/RF_db.csv new file mode 100644 index 0000000..24637eb --- /dev/null +++ b/tests/fixtures/custom_predict_model/RF_db.csv @@ -0,0 +1,38 @@ +x10,x2,x5,x7,x9,Name,Target_values,Set +2,113.0647278,89.46292114,0,2,6,0.308776518,Training +3,110.7593079,59.81459808,0,1,36,0.321084552,Training +3,115.2292938,70.45233154,0,1,37,0.329517076,Training +3,111.0383148,55.76971817,0,1,33,0.39762592,Test +3,111.6276932,59.93067169,0,1,35,0.410599034,Training +2,112.8844147,62.99021912,0,1,12,0.44305174,Training +2,110.4608917,58.43679047,0,1,10,0.515551116,Training +2,113.9221497,69.55827332,0,1,11,0.515551116,Training +2,111.9355087,78.45265198,0,2,28,0.515838419,Test +2,111.7983551,76.89291382,1,2,30,0.837770564,Training +1,112.0635071,88.85021973,0,1,3,0.860667495,Training +1,111.7115936,70.3069458,0,0,29,0.916777853,Training +1,111.0540771,86.16723633,1,1,31,0.942297022,Training +1,111.8824539,88.58656311,1,1,5,0.952469022,Test +1,110.8601837,86.97593689,1,1,32,0.965415203,Training +1,110.7689209,79.076828,1,1,7,0.982261986,Training +1,110.4682465,66.3848114,1,0,24,0.998633019,Training +1,112.2061539,67.16915131,1,0,26,1.006818535,Training +1,112.101181,69.10329437,0,0,4,1.026447198,Test +1,113.392807,69.00031281,1,0,25,1.031375085,Training +1,110.5679703,78.03665161,1,1,9,1.047746118,Training +1,111.5555725,78.75331879,1,1,8,1.129601283,Training +0,110.8476944,76.10786438,1,0,19,1.16234335,Training +1,111.0308151,84.12934875,2,1,34,1.178714383,Test +0,110.9802704,77.10182953,1,0,27,1.694401925,Training +0,110.9270401,89.87553406,1,0,1,1.854766065,Training +0,110.719635,89.83821106,2,0,15,1.883806797,Training +0,110.7536316,89.6808548,2,0,17,1.902644865,Training +0,111.2339783,82.79759216,2,0,20,1.989080521,Test +0,110.4220352,78.39561462,2,0,23,2.011592734,Training +0,110.6553116,78.65235138,1,0,2,2.034511341,Training +0,111.3189087,89.78244019,2,0,16,2.091025544,Training +0,110.3923798,78.41550446,2,0,14,2.138192104,Training +0,110.5127563,77.87947083,2,0,13,2.221390241,Test +0,110.959137,78.64178467,2,0,18,2.312908191,Training +0,110.9408264,78.66916656,2,0,22,2.391560251,Training +0,110.3707199,78.09907532,2,0,21,2.801329141,Training diff --git a/tests/test_2generate.py b/tests/test_2generate.py index f3e45ef..4c91a9f 100644 --- a/tests/test_2generate.py +++ b/tests/test_2generate.py @@ -4,6 +4,8 @@ # Testing GENERATE with pytest # ######################################################. +import ast +import json import os import sys import glob @@ -33,6 +35,15 @@ def _read_best_model_pair(best_dir): ) +def _parse_x_descriptors(raw): + """Parse X_descriptors from GENERATE CSV (JSON or Python literal list).""" + text = raw if isinstance(raw, str) else str(raw) + try: + return json.loads(text) + except json.JSONDecodeError: + return ast.literal_eval(text) + + def _log_line_metric_close(line, prefix, expected, *, rel_tol=0.05, abs_tol=0.02): """ True if line contains prefix and the numeric token immediately after prefix @@ -64,6 +75,11 @@ def _log_line_metric_close(line, prefix, expected, *, rel_tol=0.05, abs_tol=0.02 ( "reduced_adab" ), # test for other GP model (important since PFI filter tries to discard all the descriptors) + ( + "reduced_xgb" + ), # test for XGB model (tree booster with feature importances) + ("reduced_vr"), # test for Voting Regressor model + ("reduced_vr_clas"), # test Voting Classifier workflow ("reduced_clas"), # test for clasification models ("standard"), # standard test ], @@ -83,7 +99,7 @@ def test_GENERATE(test_job): shutil.rmtree(os.path.join(path_main, "GENERATE_clas")) # runs the program with the different tests - if test_job == "reduced_clas": + if test_job in ["reduced_clas", "reduced_vr_clas"]: csv_name = os.path.join("tests", "Robert_example_clas.csv") else: csv_name = os.path.join("CURATE", "Robert_example_CURATE.csv") @@ -110,12 +126,16 @@ def test_GENERATE(test_job): generate_kwargs = {"generate": True, "csv_name": csv_name, "y": "Target_values"} if test_job != "standard": # add model - if test_job not in ["reduced_gp", "reduced_adab"]: + if test_job not in ["reduced_gp", "reduced_adab", "reduced_xgb", "reduced_vr", "reduced_vr_clas"]: generate_kwargs["model"] = ["RF"] elif test_job == "reduced_gp": generate_kwargs["model"] = ["GP"] elif test_job == "reduced_adab": generate_kwargs["model"] = ["Adab"] + elif test_job == "reduced_xgb": + generate_kwargs["model"] = ["XGB"] + elif test_job in ["reduced_vr", "reduced_vr_clas"]: + generate_kwargs["model"] = ["VR"] # adjust cmd for tests if test_job == "reduced_noPFI": @@ -125,7 +145,7 @@ def test_GENERATE(test_job): elif test_job == "reduced_kfold": generate_kwargs["kfold"] = 10 generate_kwargs["repeat_kfolds"] = 5 - elif test_job in ["reduced_clas"]: + elif test_job in ["reduced_clas", "reduced_vr_clas"]: generate_kwargs["type"] = "clas" generate_kwargs["init_points"] = 1 @@ -150,7 +170,7 @@ def test_GENERATE(test_job): else: indeces = [7, 8, 9, 10] assert "- 37 datapoints" in outlines[indeces[0]] - if test_job == "reduced_clas": + if test_job in ["reduced_clas", "reduced_vr_clas"]: assert "- 9 accepted descriptors" in outlines[indeces[1]] else: assert "- 11 accepted descriptors" in outlines[indeces[1]] @@ -256,9 +276,12 @@ def test_GENERATE(test_job): if test_job == "reduced_noPFI": assert finding_line == 3.5 assert reproducibility == 1 - elif test_job in ["reduced", "reduced_cmd"]: + elif test_job in ["reduced", "reduced_cmd", "reduced_vr"]: assert finding_line == 4 - assert reproducibility == 2 + if test_job in ["reduced", "reduced_cmd"]: + assert reproducibility == 2 + else: + assert reproducibility == 0 if test_job == "standard": assert finding_line == 11 assert reproducibility == 8 @@ -291,6 +314,7 @@ def test_GENERATE(test_job): "reduced_PFImax", "reduced_gp", "reduced_adab", + "reduced_xgb", "reduced_clas", "standard", ]: @@ -311,7 +335,7 @@ def test_GENERATE(test_job): "x9", "ynoise", ] - elif test_job == "reduced_adab": + elif test_job in ["reduced_xgb", "reduced_adab"]: desc_list = [ "Csub-Csub", "Csub-H", @@ -358,6 +382,8 @@ def test_GENERATE(test_job): desc_list = ["x10"] elif test_job == "reduced_gp": desc_list = ["x5", "Csub-Csub", "x7", "x10", "x8", "Csub-H"] + elif test_job == "reduced_xgb": + desc_list = ["x10", "x7"] elif test_job == "reduced_adab": desc_list = ["x10", "x9"] elif test_job == "reduced_clas": @@ -391,9 +417,10 @@ def test_GENERATE(test_job): for i, expected_col in enumerate(expected_cols): assert db_best.columns[i] == expected_col + stored_descs = _parse_x_descriptors(params_best["X_descriptors"][0]) for var in desc_list: - assert var in params_best["X_descriptors"][0] - assert len(desc_list) == len(params_best["X_descriptors"][0].split(",")) + assert var in stored_descs + assert len(desc_list) == len(stored_descs) if test_job == "reduced_clas": metric_bo = "mcc" @@ -435,9 +462,10 @@ def test_GENERATE(test_job): ) # Check that the default metric for classification models is MCC - if test_job == "reduced_clas": + if test_job in ["reduced_clas", "reduced_vr_clas"]: + model_name = "RF" if test_job == "reduced_clas" else "VR" csv_clas = glob.glob( - os.path.join(path_generate, "Best_model", "PFI", "RF_PFI.csv") + os.path.join(path_generate, "Best_model", "PFI", f"{model_name}_PFI.csv") ) df = pd.read_csv(csv_clas[0]) if "error_type" in df.columns: diff --git a/tests/test_7api.py b/tests/test_7api.py index f9c4863..e481533 100644 --- a/tests/test_7api.py +++ b/tests/test_7api.py @@ -4,12 +4,16 @@ # Testing API with pytest # ######################################################. -"""Tests for :class:`robert.api.RobertModel` and related helpers.""" +"""Tests for RobertModel, YAML config, plot verbosity, and custom PREDICT paths.""" +import os +import shutil import subprocess import sys +import tempfile import warnings from pathlib import Path +from types import SimpleNamespace import numpy as np import pandas as pd @@ -17,20 +21,32 @@ from robert import RobertModel from robert.api import _resolve_prediction_id_column +from robert.argument_parser import set_options +from robert.predict import predict +from robert.utils import ( + load_from_yaml, + plot_verbosity_level, + should_plot_curate_pearson, + should_plot_generate_heatmap, + should_plot_predict_deep_diagnostics, + should_plot_predict_results, + should_plot_verify_metrics, +) _REPO = Path(__file__).resolve().parent.parent _REG_CSV = _REPO / "tests" / "Robert_example.csv" _CLAS_CSV = _REPO / "tests" / "Robert_example_clas.csv" +_FIXTURE_MODEL = _REPO / "tests" / "fixtures" / "custom_predict_model" -_FAST = { - "model": ["RF"], - "n_iter": 2, - "init_points": 2, - "repeat_kfolds": 2, - "kfold": 3, - "pfi_epochs": 1, - "seed": 42, -} + +@pytest.fixture +def custom_model_dir(tmp_path): + """Minimal GENERATE-style folder (params CSV + _db.csv).""" + dest = tmp_path / "custom_model" + dest.mkdir() + shutil.copy(_FIXTURE_MODEL / "RF.csv", dest / "RF.csv") + shutil.copy(_FIXTURE_MODEL / "RF_db.csv", dest / "RF_db.csv") + return dest def _holdout_for_predict(X: pd.DataFrame, n_fit: int) -> pd.DataFrame: @@ -39,7 +55,119 @@ def _holdout_for_predict(X: pd.DataFrame, n_fit: int) -> pd.DataFrame: return tail.drop_duplicates(subset=["Name"], keep="first") -def test_fit_predict_regression(tmp_path): +# --- YAML --- + + +def test_yaml_unknown_key_warns_and_known_key_applies(capsys): + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, encoding="utf-8") as f: + f.write("not_a_robert_option: 1\nseed: 99\n") + path = f.name + try: + opts = set_options({}) + opts.varfile = path + load_from_yaml(opts) + captured = capsys.readouterr() + text = captured.out + captured.err + assert "not_a_robert_option" in text + assert opts.seed == 99 + finally: + os.unlink(path) + + +def test_yaml_missing_file_message(): + opts = set_options({}) + opts.varfile = os.path.join(tempfile.gettempdir(), "robert_nonexistent_params_xyz.yaml") + _, msg = load_from_yaml(opts) + assert "not found" in msg.lower() + + +# --- PREDICT / plot verbosity --- + + +def test_predict_custom_params_dir( + custom_model_dir, fast_robert_kwargs, tmp_path, monkeypatch +): + """Custom params_dir must not raise NameError on suffixes/suffix_titles.""" + monkeypatch.chdir(tmp_path) + predict( + params_dir=str(custom_model_dir), + predict_diagnostics=False, + command_line=False, + **fast_robert_kwargs, + ) + assert (tmp_path / "PREDICT" / "RF_custom.csv").is_file() + + +def test_plot_verbosity_level_defaults_and_bounds(): + assert plot_verbosity_level(SimpleNamespace()) == 2 + assert plot_verbosity_level(SimpleNamespace(plot_verbosity="bogus")) == 2 + assert plot_verbosity_level(SimpleNamespace(plot_verbosity=-5)) == 0 + assert plot_verbosity_level(SimpleNamespace(plot_verbosity=99)) == 2 + + +def test_should_plot_predict_tiers(): + off = SimpleNamespace(predict_diagnostics=False, plot_verbosity=2) + assert not should_plot_predict_results(off) + assert not should_plot_predict_deep_diagnostics(off) + + mid = SimpleNamespace(predict_diagnostics=True, plot_verbosity=1) + assert should_plot_predict_results(mid) + assert not should_plot_predict_deep_diagnostics(mid) + + full = SimpleNamespace(predict_diagnostics=True, plot_verbosity=2) + assert should_plot_predict_results(full) + assert should_plot_predict_deep_diagnostics(full) + + +def test_stage_flags_match_levels(): + low = SimpleNamespace(plot_verbosity=0) + mid = SimpleNamespace(plot_verbosity=1) + assert not should_plot_curate_pearson(low) + assert should_plot_curate_pearson(mid) + assert not should_plot_generate_heatmap(low) + assert should_plot_generate_heatmap(mid) + assert not should_plot_verify_metrics(low) + assert should_plot_verify_metrics(mid) + + +def test_predict_plot_verbosity_zero_skips_pngs( + custom_model_dir, tmp_path, monkeypatch, fast_robert_kwargs +): + monkeypatch.chdir(tmp_path) + predict( + params_dir=str(custom_model_dir), + predict_diagnostics=True, + plot_verbosity=0, + command_line=False, + **fast_robert_kwargs, + ) + predict_root = tmp_path / "PREDICT" + pngs = list(predict_root.rglob("*.png")) + assert pngs == [] + assert (predict_root / "RF_custom.csv").is_file() + + +def test_predict_plot_verbosity_one_skips_shap( + custom_model_dir, tmp_path, monkeypatch, fast_robert_kwargs +): + monkeypatch.chdir(tmp_path) + predict( + params_dir=str(custom_model_dir), + predict_diagnostics=True, + plot_verbosity=1, + command_line=False, + **fast_robert_kwargs, + ) + predict_root = tmp_path / "PREDICT" + shap_pngs = list(predict_root.rglob("SHAP*.png")) + assert not shap_pngs + assert list(predict_root.rglob("*.png")) + + +# --- RobertModel API --- + + +def test_fit_predict_regression(tmp_path, fast_robert_kwargs): df = pd.read_csv(_REG_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) y = df["Target_values"] @@ -49,7 +177,7 @@ def test_fit_predict_regression(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names="Name", - **_FAST, + **fast_robert_kwargs, ) model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) assert model.is_fitted_ @@ -69,7 +197,7 @@ def test_fit_predict_regression(tmp_path): assert sd_b.shape == hw_b.shape -def test_fit_predict_classification(tmp_path): +def test_fit_predict_classification(tmp_path, fast_robert_kwargs): df = pd.read_csv(_CLAS_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) y = df["Target_values"] @@ -79,7 +207,7 @@ def test_fit_predict_classification(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names="Name", - **_FAST, + **fast_robert_kwargs, ) model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) X_hold = _holdout_for_predict(X, n_fit) @@ -91,24 +219,24 @@ def test_fit_predict_classification(tmp_path): model.predict(X_hold, return_uncertainty="conformal") -def test_deprecated_type_filter_kwargs(tmp_path): +def test_deprecated_type_filter_kwargs(tmp_path, fast_robert_kwargs): with pytest.warns(DeprecationWarning, match="problem_type"): RobertModel( type="reg", workdir=tmp_path, filter_mode="no_pfi", - **_FAST, + **fast_robert_kwargs, ) with pytest.warns(DeprecationWarning, match="filter_mode"): RobertModel( problem_type="reg", workdir=tmp_path, filter="no_pfi", - **_FAST, + **fast_robert_kwargs, ) -def test_names_col_matches_model_data_after_fit(tmp_path): +def test_names_col_matches_model_data_after_fit(tmp_path, fast_robert_kwargs): """``names_col_`` must match GENERATE params so PREDICT and API agree on the id column.""" df = pd.read_csv(_REG_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) @@ -118,13 +246,13 @@ def test_names_col_matches_model_data_after_fit(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names="Name", - **_FAST, + **fast_robert_kwargs, ) model.fit(X.iloc[:20], y.iloc[:20]) assert model.names_col_ == str(model.model_data_["names"]) -def test_names_col_default_matches_model_after_fit(tmp_path): +def test_names_col_default_matches_model_after_fit(tmp_path, fast_robert_kwargs): """Auto-inserted ``__robert_name__`` must match what CURATE stores in params.""" df = pd.read_csv(_REG_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) @@ -134,7 +262,7 @@ def test_names_col_default_matches_model_after_fit(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names=None, - **_FAST, + **fast_robert_kwargs, ) model.fit(X.iloc[:15], y.iloc[:15]) assert model.names_col_ == str(model.model_data_["names"]) @@ -153,7 +281,7 @@ def test_resolve_prediction_id_column_prefers_exact_then_model_then_casefold(): _resolve_prediction_id_column(df3, "Name", "missing") -def test_predict_row_order_matches_input_order(tmp_path): +def test_predict_row_order_matches_input_order(tmp_path, fast_robert_kwargs): """``predict`` matches ``X`` row order vs PREDICT CSV row order.""" df = pd.read_csv(_REG_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) @@ -164,7 +292,7 @@ def test_predict_row_order_matches_input_order(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names="Name", - **_FAST, + **fast_robert_kwargs, ) model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) X_hold = _holdout_for_predict(X, n_fit) @@ -179,7 +307,7 @@ def test_predict_row_order_matches_input_order(tmp_path): assert np.allclose(pred_natural, pred_realigned) -def test_fit_accepts_unused_fit_params(tmp_path): +def test_fit_accepts_unused_fit_params(tmp_path, fast_robert_kwargs): """Sklearn Pipeline may pass extra fit kwargs; they should not raise.""" df = pd.read_csv(_REG_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) @@ -189,7 +317,7 @@ def test_fit_accepts_unused_fit_params(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names="Name", - **_FAST, + **fast_robert_kwargs, ) with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) diff --git a/tests/test_uq_auto.py b/tests/test_8uq.py similarity index 50% rename from tests/test_uq_auto.py rename to tests/test_8uq.py index 4ad8299..d9ff00c 100644 --- a/tests/test_uq_auto.py +++ b/tests/test_8uq.py @@ -1,6 +1,10 @@ #!/usr/bin/env python -"""Tests for automatic uncertainty selection and calibration.""" +######################################################. +# Testing UQ with pytest # +######################################################. + +"""Tests for meta-model and automatic uncertainty quantification.""" from pathlib import Path @@ -16,19 +20,14 @@ fit_uncertainty_scaler, score_uncertainty_candidate, ) +from robert.utils import ( + aggregate_meta_uq_decomposition, + discover_top_k_model_candidates, +) -_REG_CSV = Path(__file__).resolve().parent / "Robert_example.csv" - -_FAST = { - "model": ["RF"], - "n_iter": 2, - "init_points": 2, - "repeat_kfolds": 2, - "kfold": 3, - "pfi_epochs": 1, - "seed": 42, - "uq_auto_enable": True, -} +_REPO = Path(__file__).resolve().parent.parent +_REG_CSV = _REPO / "tests" / "Robert_example.csv" +_CLAS_CSV = _REPO / "tests" / "Robert_example_clas.csv" class _Args: @@ -41,6 +40,76 @@ class _Args: seed = 0 +# --- meta-model UQ --- + + +def test_aggregate_meta_uq_regression_decomposition(): + """Total variance equals within + between for regression.""" + preds = np.array([[1.0, 2.0], [3.0, 4.0], [2.0, 3.0]]) + sds = np.array([[0.1, 0.2], [0.3, 0.4], [0.2, 0.3]]) + weights = np.array([1 / 3, 1 / 3, 1 / 3]) + y, uq_m, uq_meta, uq_tot = aggregate_meta_uq_decomposition( + preds, sds, weights, "reg" + ) + assert y.shape == (2,) + assert np.all(uq_tot >= uq_m - 1e-9) + assert np.all(uq_tot >= uq_meta - 1e-9) + assert np.all(uq_m >= 0) and np.all(uq_meta >= 0) + + +def test_discover_top_k_empty_dir(tmp_path): + assert discover_top_k_model_candidates(tmp_path, 3) == [] + + +def test_fit_predict_meta_uncertainty_modes(tmp_path, fast_robert_kwargs): + df = pd.read_csv(_REG_CSV, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + n_fit = 25 + model = RobertModel( + problem_type="reg", + filter_mode="no_pfi", + workdir=tmp_path, + names="Name", + uq_enable_meta=True, + uq_top_k_models=2, + uq_model_weighting="uniform", + **{**fast_robert_kwargs, "model": ["RF", "GB"]}, + ) + model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) + X_hold = X.iloc[n_fit:].drop_duplicates(subset=["Name"], keep="first") + y_hat, uq_meta = model.predict(X_hold, return_uncertainty="meta") + assert y_hat.shape == uq_meta.shape + assert np.isfinite(uq_meta).all() and (uq_meta >= 0).all() + _, uq_total = model.predict(X_hold, return_uncertainty="total") + y_d, uq_m, uq_meta2, uq_tot = model.predict( + X_hold, return_uncertainty="decomposed" + ) + assert y_d.shape == uq_m.shape == uq_meta2.shape == uq_tot.shape + assert np.all(uq_tot >= uq_m - 1e-9) + assert np.allclose(uq_tot, uq_total, rtol=1e-5, atol=1e-5) + + +def test_meta_uncertainty_requires_enable_flag(tmp_path, fast_robert_kwargs): + df = pd.read_csv(_REG_CSV, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + model = RobertModel( + problem_type="reg", + filter_mode="no_pfi", + workdir=tmp_path, + names="Name", + **fast_robert_kwargs, + ) + model.fit(X.iloc[:20], y.iloc[:20]) + X_hold = X.iloc[20:].drop_duplicates(subset=["Name"], keep="first") + with pytest.raises(ValueError, match="uq_enable_meta"): + model.predict(X_hold, return_uncertainty="total") + + +# --- automatic UQ --- + + def test_global_multiplicative_scaler_monotone(): u = np.array([0.5, 1.0, 2.0]) r = np.array([0.6, 1.2, 2.4]) @@ -83,7 +152,7 @@ def test_evaluate_uq_candidates_deterministic(): assert sel1["selected"] in (CANDIDATE_CV_SD, "conformal") -def test_fit_predict_auto_uncertainty(tmp_path): +def test_fit_predict_auto_uncertainty(tmp_path, fast_robert_kwargs): df = pd.read_csv(_REG_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) y = df["Target_values"] @@ -93,7 +162,8 @@ def test_fit_predict_auto_uncertainty(tmp_path): filter_mode="no_pfi", workdir=tmp_path, names="Name", - **_FAST, + **fast_robert_kwargs, + uq_auto_enable=True, ) model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) X_hold = X.iloc[n_fit:].drop_duplicates(subset=["Name"], keep="first") @@ -108,18 +178,16 @@ def test_fit_predict_auto_uncertainty(tmp_path): assert meta_path.is_file() -def test_auto_classification_raises(tmp_path): - clas_csv = Path(__file__).resolve().parent / "Robert_example_clas.csv" - df = pd.read_csv(clas_csv, encoding="utf-8") +def test_auto_classification_raises(tmp_path, fast_robert_kwargs): + df = pd.read_csv(_CLAS_CSV, encoding="utf-8") X = df.drop(columns=["Target_values"]) y = df["Target_values"] - fast = {k: v for k, v in _FAST.items() if k != "uq_auto_enable"} model = RobertModel( problem_type="clas", filter_mode="no_pfi", workdir=tmp_path, names="Name", - **fast, + **fast_robert_kwargs, ) model.fit(X.iloc[:20], y.iloc[:20]) X_hold = X.iloc[20:].drop_duplicates(subset=["Name"], keep="first") diff --git a/tests/test_uq_meta.py b/tests/test_uq_meta.py deleted file mode 100644 index 8efcc0e..0000000 --- a/tests/test_uq_meta.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python - -"""Tests for top-k meta-model uncertainty helpers and API.""" - -from pathlib import Path - -import numpy as np -import pandas as pd -import pytest - -from robert import RobertModel -from robert.utils import ( - aggregate_meta_uq_decomposition, - discover_top_k_model_candidates, -) - -_REG_CSV = Path(__file__).resolve().parent / "Robert_example.csv" - -_FAST = { - "model": ["RF", "GB"], - "n_iter": 2, - "init_points": 2, - "repeat_kfolds": 2, - "kfold": 3, - "pfi_epochs": 1, - "seed": 42, - "uq_enable_meta": True, - "uq_top_k_models": 2, - "uq_model_weighting": "uniform", -} - - -def test_aggregate_meta_uq_regression_decomposition(): - """Total variance equals within + between for regression.""" - preds = np.array([[1.0, 2.0], [3.0, 4.0], [2.0, 3.0]]) - sds = np.array([[0.1, 0.2], [0.3, 0.4], [0.2, 0.3]]) - weights = np.array([1 / 3, 1 / 3, 1 / 3]) - y, uq_m, uq_meta, uq_tot = aggregate_meta_uq_decomposition( - preds, sds, weights, "reg" - ) - assert y.shape == (2,) - assert np.all(uq_tot >= uq_m - 1e-9) - assert np.all(uq_tot >= uq_meta - 1e-9) - assert np.all(uq_m >= 0) and np.all(uq_meta >= 0) - - -def test_discover_top_k_empty_dir(tmp_path): - assert discover_top_k_model_candidates(tmp_path, 3) == [] - - -def test_fit_predict_meta_uncertainty_modes(tmp_path): - df = pd.read_csv(_REG_CSV, encoding="utf-8") - X = df.drop(columns=["Target_values"]) - y = df["Target_values"] - n_fit = 25 - model = RobertModel( - problem_type="reg", - filter_mode="no_pfi", - workdir=tmp_path, - names="Name", - **_FAST, - ) - model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) - X_hold = X.iloc[n_fit:].drop_duplicates(subset=["Name"], keep="first") - y_hat, uq_meta = model.predict(X_hold, return_uncertainty="meta") - assert y_hat.shape == uq_meta.shape - assert np.isfinite(uq_meta).all() and (uq_meta >= 0).all() - _, uq_total = model.predict(X_hold, return_uncertainty="total") - y_d, uq_m, uq_meta2, uq_tot = model.predict( - X_hold, return_uncertainty="decomposed" - ) - assert y_d.shape == uq_m.shape == uq_meta2.shape == uq_tot.shape - assert np.all(uq_tot >= uq_m - 1e-9) - assert np.allclose(uq_tot, uq_total, rtol=1e-5, atol=1e-5) - - -def test_meta_uncertainty_requires_enable_flag(tmp_path): - df = pd.read_csv(_REG_CSV, encoding="utf-8") - X = df.drop(columns=["Target_values"]) - y = df["Target_values"] - fast = {k: v for k, v in _FAST.items() if k != "uq_enable_meta"} - model = RobertModel( - problem_type="reg", - filter_mode="no_pfi", - workdir=tmp_path, - names="Name", - **fast, - ) - model.fit(X.iloc[:20], y.iloc[:20]) - X_hold = X.iloc[20:].drop_duplicates(subset=["Name"], keep="first") - with pytest.raises(ValueError, match="uq_enable_meta"): - model.predict(X_hold, return_uncertainty="total") From 153c8d577c1e97a7966b91488f4605c0ce6f69dd Mon Sep 17 00:00:00 2001 From: rlaplaza Date: Sat, 16 May 2026 17:54:29 +0200 Subject: [PATCH 6/8] docs: document plot verbosity, UQ defaults, and evaluate module Update API/module docs for plot_verbosity and UQ options, add evaluate RST pages, and refresh defaults and workflow examples. Co-authored-by: Cursor --- docs/.readthedocs.yaml | 2 +- docs/API/API_Reference.rst | 3 +- docs/API/robert.api.rst | 59 +++++++++++++++++-- docs/API/robert.evaluate.rst | 8 +++ docs/API/robert.uq_auto.rst | 21 +++++++ docs/Examples/full_workflow/smiles_vaskas.rst | 9 ++- .../full_workflow/smiles_workflow.rst | 9 ++- docs/Examples/modules/prediction.rst | 4 ++ docs/Examples/modules/screening.rst | 4 ++ docs/Misc/abbreviations.rst | 3 +- docs/Modules/curate.rst | 4 +- docs/Modules/evaluate.rst | 49 +++++++++++++++ docs/Modules/generate.rst | 2 +- docs/Modules/predict.rst | 11 +++- docs/README.rst | 20 ++++--- docs/Report/score.rst | 10 ++-- docs/Technical/defaults.rst | 46 +++++++++++++++ docs/Tutorials/videos.rst | 2 +- docs/conf.py | 38 +++++++++--- docs/index.rst | 1 + docs/requirements.txt | 1 - 21 files changed, 272 insertions(+), 34 deletions(-) create mode 100644 docs/API/robert.evaluate.rst create mode 100644 docs/Modules/evaluate.rst diff --git a/docs/.readthedocs.yaml b/docs/.readthedocs.yaml index 6cfa356..6df5658 100644 --- a/docs/.readthedocs.yaml +++ b/docs/.readthedocs.yaml @@ -9,7 +9,7 @@ version: 2 sphinx: builder: html configuration: docs/conf.py - fail_on_warning: false + fail_on_warning: true build: os: ubuntu-20.04 diff --git a/docs/API/API_Reference.rst b/docs/API/API_Reference.rst index c7e7b3d..2bf134e 100644 --- a/docs/API/API_Reference.rst +++ b/docs/API/API_Reference.rst @@ -9,8 +9,10 @@ Main modules .. toctree:: :maxdepth: 1 + robert.api robert.aqme robert.curate + robert.evaluate robert.generate robert.predict robert.verify @@ -22,7 +24,6 @@ Other modules .. toctree:: :maxdepth: 1 - robert.api robert.generate_utils robert.predict_utils robert.uq_auto diff --git a/docs/API/robert.api.rst b/docs/API/robert.api.rst index 551abf7..d5752b7 100644 --- a/docs/API/robert.api.rst +++ b/docs/API/robert.api.rst @@ -52,26 +52,76 @@ columns aligned with the pipeline: - If both ``return_std`` and ``return_uncertainty`` are set, ``return_uncertainty`` wins and a warning is issued. +Supported models +---------------- + +Pass ``model`` as a list of algorithm codes (same as the CLI ``--model`` option). +Defaults are ``["RF", "GB", "NN", "MVL"]`` for regression and ``["RF", "GB", "NN", "AdaB"]`` +for classification (when ``auto_type`` switches the problem type). + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Code + - Backend + * - RF, GB, NN, MVL + - scikit-learn (default screening set for regression includes MVL instead of AdaB) + * - GP, AdaB, VR + - scikit-learn (opt-in; AdaB replaces MVL in the default classification set) + * - **XGB** + - XGBoost (:class:`~xgboost.XGBRegressor` / :class:`~xgboost.XGBClassifier`), opt-in; + hyperoptimized with in-code Bayesian bounds (no packaged ``model_params/XGB_params.yaml``) + +Example with XGB: + +.. code-block:: python + + model_xgb = RobertModel( + problem_type="reg", + workdir="./robert_run_xgb", + model=["RF", "XGB"], + n_iter=2, + init_points=2, + ) + model_xgb.fit(X_train, y_train) + Configuration (uncertainty kwargs) ------------------------------------ -Defaults are defined in ``robert.argument_parser.var_dict``: +Defaults are defined in ``robert.argument_parser.var_dict``. Every key in that +dictionary can be passed as a :class:`~robert.api.RobertModel` keyword argument or +set in a YAML varfile (``varfile=FILE.yaml``). For XGBoost, install-time dependency +availability (``xgboost``) is distinct from runtime model selection: include +``"XGB"`` in ``model`` to screen XGBoost. + +**CLI vs API / YAML.** ``python -m robert --help`` documents ``conformal_enable``, +``conformal_calib_frac``, ``conformal_coverage``, ``uq_enable_meta``, +``uq_top_k_models``, and ``uq_auto_enable``. Set ``uq_model_weighting`` and the +remaining ``uq_auto_*`` keys (candidates, scaler, metric weights, min samples, +random state, clas mode) via a YAML varfile or :class:`~robert.api.RobertModel` +keyword arguments (see list below). - **Conformal:** ``conformal_enable`` (``True``), ``conformal_calib_frac`` (``0.15``), ``conformal_coverage`` (``0.9``). - **Meta-model:** ``uq_enable_meta`` (``False``), ``uq_top_k_models`` (``3``), ``uq_model_weighting`` (``"score_weighted"`` or ``"uniform"``). +- **PREDICT diagnostics:** ``predict_diagnostics`` (``True``). When ``False``, PREDICT + skips SHAP, PFI, Pearson heatmap, outlier, and distribution plots (and y-vs-pred + graphs). :class:`~robert.api.RobertModel.predict` sets this to ``False`` automatically. +- **Plot verbosity:** ``plot_verbosity`` (``2``). Higher values emit more diagnostic + figures during PREDICT when ``predict_diagnostics`` is ``True`` (see CLI help / ``var_dict``). - **Auto (regression):** ``uq_auto_enable`` (``False``), ``uq_auto_candidates`` (``["cv_sd", "conformal", "meta_total"]``), ``uq_auto_scaler`` (``"global_multiplicative"``; also ``"none"`` or ``"isotonic"``), - ``uq_auto_metric_weights`` (coverage / sharpness / NLL weights), + ``uq_auto_metric_weights`` (default ``{"coverage": 1.0, "sharpness": 0.25, "nll": 0.5}``), ``uq_auto_min_samples`` (``12``), ``uq_auto_random_state`` (``0``), ``uq_auto_clas_mode`` (``"error"`` — raises if auto is requested for classification). Meta-model uncertainty ---------------------- -Enable with ``uq_enable_meta=True`` on :class:`~robert.api.RobertModel``. PREDICT +Enable with ``uq_enable_meta=True`` on :class:`~robert.api.RobertModel`. PREDICT re-runs up to ``uq_top_k_models`` estimators ranked by GENERATE ``combined_{error_type}`` scores in ``GENERATE/Raw_data``, then combines predictions with ``uq_model_weighting``. **Regression** uses a weighted mean and law-of-total-variance @@ -160,9 +210,10 @@ Example model = RobertModel( problem_type="reg", workdir="./robert_run", - model=["RF"], + model=["RF", "XGB"], n_iter=2, init_points=2, + conformal_enable=True, ) model.fit(X.iloc[:25], y.iloc[:25]) preds = model.predict(X.iloc[25:]) diff --git a/docs/API/robert.evaluate.rst b/docs/API/robert.evaluate.rst new file mode 100644 index 0000000..3b26b18 --- /dev/null +++ b/docs/API/robert.evaluate.rst @@ -0,0 +1,8 @@ +EVALUATE +======== + +.. autoclass:: robert.evaluate.evaluate + :members: + +.. automodule:: robert.evaluate + :noindex: diff --git a/docs/API/robert.uq_auto.rst b/docs/API/robert.uq_auto.rst index 29b037e..403e4db 100644 --- a/docs/API/robert.uq_auto.rst +++ b/docs/API/robert.uq_auto.rst @@ -1,5 +1,26 @@ uq_auto ======= +Automatic uncertainty selection for **regression**. :mod:`robert.uq_auto` scores +multiple uncertainty candidates on training out-of-fold absolute residuals, fits an +optional post-hoc scaler, and writes calibrated ``{y}_pred_uq_auto`` values plus +metadata under ``PREDICT/uq_auto_metadata.json``. + +**Candidates** (configured via ``uq_auto_candidates``): + +* ``cv_sd`` — spread across repeated CV refits. +* ``conformal`` — constant split-conformal half-width (requires ``conformal_enable``). +* ``meta_total`` — combined meta-model uncertainty (requires ``uq_enable_meta`` during fit). + +**Scalers** (``uq_auto_scaler``): ``none``, ``global_multiplicative`` (default), or +``isotonic``. + +**Selection** uses a weighted composite of coverage, sharpness, and Gaussian NLL on +an inner hold-out split (see ``uq_auto_metric_weights``). When scores tie, preference +order is ``cv_sd`` → ``conformal`` → ``meta_total``. + +User-facing configuration, output columns, and ``predict(..., return_uncertainty="auto")`` +are documented in :doc:`robert.api`. + .. automodule:: robert.uq_auto :members: diff --git a/docs/Examples/full_workflow/smiles_vaskas.rst b/docs/Examples/full_workflow/smiles_vaskas.rst index 6b18e92..1b284f9 100644 --- a/docs/Examples/full_workflow/smiles_vaskas.rst +++ b/docs/Examples/full_workflow/smiles_vaskas.rst @@ -136,6 +136,13 @@ By default, the workflow sets: * :code:`--names code_name` (name of the column containing the names of the datapoints) +.. note:: + + Default model screening in this workflow uses RF/GB/NN/MVL. XGBoost is available as opt-in + by setting :code:`--model "[...,XGB]"`. For Python API usage and uncertainty controls + (``conformal_*``, ``uq_enable_meta``/``uq_top_k_models``/``uq_model_weighting``, ``uq_auto_*``), + see :doc:`../../API/robert.api`. + Execution time and versions +++++++++++++++++++++++++++ @@ -160,7 +167,7 @@ Results :target: ../../_static/AQME-ROBERT_vaska_short.csv :width: 30 -* The workflow starts with a CSEARCH-RDKit conformer sampling (using RDKit by default, although CREST is also available if :code:`--csearch_keywords "--program crest"` is added). +* The workflow starts with a CSEARCH-RDKit conformer sampling (using RDKit by default). For CREST or other CSEARCH engines/settings, use the `AQME `__ CLI or API directly, then continue ROBERT with the produced descriptor CSV—ROBERT exposes ``--qdescp_keywords`` for the internal QDESCP step, not a separate CSEARCH flag. * Then, QDESCP is used to generate more than 200 RDKit and xTB Boltzmann-averaged molecular descriptors (using xTB geometry optimizations and different single-point calculations). diff --git a/docs/Examples/full_workflow/smiles_workflow.rst b/docs/Examples/full_workflow/smiles_workflow.rst index 5461338..76a0599 100644 --- a/docs/Examples/full_workflow/smiles_workflow.rst +++ b/docs/Examples/full_workflow/smiles_workflow.rst @@ -100,6 +100,13 @@ By default, the workflow sets: * :code:`--names code_name` (name of the column containing the names of the datapoints) +.. note:: + + Default model screening in this workflow uses RF/GB/NN/MVL. XGBoost is available as opt-in + by setting :code:`--model "[...,XGB]"`. For Python API usage and uncertainty controls + (``conformal_*``, ``uq_enable_meta``/``uq_top_k_models``/``uq_model_weighting``, ``uq_auto_*``), + see :doc:`../../API/robert.api`. + Execution time and versions +++++++++++++++++++++++++++ @@ -124,7 +131,7 @@ Results :target: ../../_static/AQME-ROBERT_solubility_short.csv :width: 30 -* The workflow starts with a CSEARCH-RDKit conformer sampling (using RDKit by default, although CREST is also available if :code:`--csearch_keywords "--program crest"` is added). +* The workflow starts with a CSEARCH-RDKit conformer sampling (using RDKit by default). For CREST or other CSEARCH engines/settings, use the `AQME `__ CLI or API directly, then continue ROBERT with the produced descriptor CSV—ROBERT exposes ``--qdescp_keywords`` for the internal QDESCP step, not a separate CSEARCH flag. * Then, QDESCP is used to generate more than 200 RDKit and xTB Boltzmann-averaged molecular descriptors (using xTB geometry optimizations and different single-point calculations). diff --git a/docs/Examples/modules/prediction.rst b/docs/Examples/modules/prediction.rst index 631599f..93a6c3e 100644 --- a/docs/Examples/modules/prediction.rst +++ b/docs/Examples/modules/prediction.rst @@ -43,6 +43,10 @@ Executing the job * :code:`--predict`: Use only the PREDICT module. +Prediction CSVs can include uncertainty columns (``{y}_pred_sd``, conformal +half-width, meta-model, and auto UQ). Configure conformal, meta, and auto knobs +via CLI/YAML or :class:`~robert.api.RobertModel`; see :doc:`../../API/robert.api`. + Execution time ++++++++++++++ diff --git a/docs/Examples/modules/screening.rst b/docs/Examples/modules/screening.rst index 327ff82..611af52 100644 --- a/docs/Examples/modules/screening.rst +++ b/docs/Examples/modules/screening.rst @@ -36,6 +36,10 @@ Executing the job * :code:`--generate`: Use only the GENERATE module. +To include XGBoost in screening (opt-in), add for example +:code:`--model "[RF,XGB]"`. Default screening uses RF, GB, NN, and MVL. See +:doc:`../../Modules/generate` and the :doc:`Python API <../../API/robert.api>`. + Execution time ++++++++++++++ diff --git a/docs/Misc/abbreviations.rst b/docs/Misc/abbreviations.rst index d92dc09..9a465c6 100644 --- a/docs/Misc/abbreviations.rst +++ b/docs/Misc/abbreviations.rst @@ -14,8 +14,9 @@ A to M N to Z **GB:** gradient boosting **RND:** random **GP:** gaussian process **SHAP:** Shapley additive explanations **KN:** k-nearest neighbors **VR:** voting regressor -**MAE:** root-mean-square error +**MAE:** mean absolute error **MCC:** Matthew's correlation coefficient **ML:** machine learning **MVL:** multivariate lineal models +**XGB:** extreme gradient boosting =========================================== ======================================= \ No newline at end of file diff --git a/docs/Modules/curate.rst b/docs/Modules/curate.rst index 6cc93ac..8040ab9 100644 --- a/docs/Modules/curate.rst +++ b/docs/Modules/curate.rst @@ -23,7 +23,7 @@ Automated protocols * Filters off variables with very low correlation to the target values (noise, with R\ :sup:`2` lower than 0.001). (Disabled by default) * Filters off duplicates. * Converts categorical descriptors into one-hot or numerical descriptors. -* At the end of the curation process, the program uses scikit-learn’s RFECV with repeated K-fold cross-validation to select the most relevant descriptors. The selection can average importances across multiple models (e.g., RF, GB, ADAB) and ensures that the number of descriptors is reduced to one third of the datapoints. +* At the end of the curation process, the program uses scikit-learn’s RFECV with repeated K-fold cross-validation to select the most relevant descriptors. The selection can average importances across multiple models (e.g., RF, GB, ADAB, XGB) and ensures that the number of descriptors is reduced to one third of the datapoints. When ``model`` includes XGB, CURATE also writes model-specific outputs (for example ``*_CURATE_XGB.csv``) using the same RFECV and ``feature_importances_`` path as RF, GB, and AdaB. Technical information +++++++++++++++++++++ @@ -35,6 +35,8 @@ numerical or one-hot encoding values. For example, consider a variable that repr * “Numbers”: It assigns numerical values (e.g., 1, 2, 3, 4) to describe the different C atom types. * “Onehot”: It creates a separate descriptor for each C atom type using 0s and 1s to indicate their presence. +When the ``model`` option lists multiple algorithms (including opt-in models such as XGB), CURATE can emit one curated CSV per model type in addition to the main curated file. Use the same ``model`` syntax as in GENERATE (CLI ``--model "[RF,XGB]"`` or :doc:`RobertModel <../API/robert.api>` kwargs). + Example +++++++ diff --git a/docs/Modules/evaluate.rst b/docs/Modules/evaluate.rst new file mode 100644 index 0000000..43f496d --- /dev/null +++ b/docs/Modules/evaluate.rst @@ -0,0 +1,49 @@ +.. evaluate-modules-start + +Evaluate a pre-specified model +------------------------------ + +Overview of the EVALUATE module ++++++++++++++++++++++++++++++++ + +The EVALUATE module skips GENERATE model screening. It prepares the +``GENERATE/Best_model/No_PFI`` folder with a user-chosen sklearn model so later +modules (VERIFY, PREDICT, REPORT) can run as in a standard workflow. + +Input required +++++++++++++++ + +A curated CSV (typically from CURATE) with descriptors and a target column ``y``. + +Automated protocols ++++++++++++++++++++ + +* Loads and standardizes the database (same path as GENERATE). +* Writes ``GENERATE/Best_model/No_PFI/{eval_model}.csv`` (model metadata) and + ``{eval_model}_db.csv`` (database with train/test ``Set`` column). +* Creates an ``EVALUATE/`` log folder; does **not** produce GENERATE heatmaps or + Raw_data screening outputs. + +Technical information ++++++++++++++++++++++ + +* **Supported models:** ``eval_model='MVL'`` (multivariate linear regression via + sklearn ``LinearRegression``) is the only option today. +* **Problem type:** regression (``type='reg'``). Classification support is planned. +* **CLI:** ``python -m robert --csv_name FILE.csv --evaluate`` plus ``--y``, + ``--names``, and optional ``--eval_model``, ``--kfold``, ``--repeat_kfolds``. +* **Typical workflow:** CURATE → EVALUATE → VERIFY → PREDICT (or full workflow + with ``--evaluate`` instead of GENERATE). +* **UQ / XGB:** EVALUATE does not screen XGB or other GENERATE models. Uncertainty + columns in PREDICT still follow :doc:`../API/robert.api` when enabled downstream. + +Example ++++++++ + +A minimal regression CSV is in the repository at ``tests/Evaluate_test.csv``: + +.. code:: shell + + python -m robert --csv_name tests/Evaluate_test.csv --evaluate --y Target_values --names Name + +.. evaluate-modules-end diff --git a/docs/Modules/generate.rst b/docs/Modules/generate.rst index 8c427e6..f835de6 100644 --- a/docs/Modules/generate.rst +++ b/docs/Modules/generate.rst @@ -40,7 +40,7 @@ The GENERATE module performs an exploration of various ML algorithms. It uses bu The software automatically generates a heatmap displaying RMSE values obtained from hyperoptimized algorithms. Furthermore, it performs permutation feature importance (PFI) analysis to identify the most influential descriptors and generate new models with only those descriptors. Users have the flexibility to fine-tune the PFI filter threshold using the "PFI_threshold" parameter. By default, this threshold removes features that contribute less than 4% to the model's R2. While this filter is activated by default, users can deactivate it by setting the "pfi_filter" option to False. -Users can choose between different modes for data splitting using the "split" option: (KN, RND, STRATIFIED, EVEN, EXTRA_Q1, EXTRA_Q5). The selection of ML algorithms during screening is tuned through the "model" parameter, offering a range of popular options such as Random Forests (RF), Multivariate Linear Models (MVL), Gradient Boosting (GB), Gaussian Process (GP), AdaBoost Regressor (AdaB), MLP Regressor Neural Network (NN), and Voting Regressor (VR). +Users can choose between different modes for data splitting using the "split" option: (KN, RND, STRATIFIED, EVEN, EXTRA_Q1, EXTRA_Q5). The selection of ML algorithms during screening is tuned through the ``model`` parameter. By default, ROBERT screens RF, GB, NN, and MVL (regression) or RF, GB, NN, and AdaB (classification). Additional options include Gaussian Process (GP), Voting Regressor/Classifier (VR), and **XGB** (XGBoost regressor/classifier). XGB is opt-in (for example ``--model "[XGB]"`` or ``--model "[RF,XGB]"``); it is hyperoptimized with in-code Bayesian bounds rather than a packaged YAML file. The same ``model`` list is accepted by the :doc:`Python API <../API/robert.api>` via :class:`~robert.api.RobertModel`, where uncertainty kwargs are documented in detail. 20% of the data is used as a test set before hyperoptimization. This algorithm ensures an even distribution of data points across the range of y values, facilitating a balanced evaluation of predictions across the low, mid, and high y-value ranges. diff --git a/docs/Modules/predict.rst b/docs/Modules/predict.rst index d0f7143..3aac34e 100644 --- a/docs/Modules/predict.rst +++ b/docs/Modules/predict.rst @@ -29,7 +29,16 @@ Technical information +++++++++++++++++++++ The PREDICT module uses models obtained in the GENERATE module to compute various metrics, including R2, MAE, and RMSE (regression), and accuracy, F1 score, and MCC (classification). This module also enables predictions for an external test dataset, incorporating predictor metrics when measured y-values are available. In cases where measured y-values are absent, the module shows predicted y-values in the resulting PDF report and within the csv_test folder created inside the PREDICT main folder. -Furthermore, this module conducts feature importance analysis through PFI and SHAP methods, which analyze how descriptors impact model performance. The PREDICT module also identifies outliers by measuring the absolute errors between predicted and measured y values. The detection of outliers is based on the “t_value” option, defaulted to two and measured in SD units. This default t-value identifies outliers in predictions exhibiting errors surpassing two SDs (approx. 5% of a normal population). +Furthermore, it conducts feature importance analysis through PFI and SHAP methods when ``predict_diagnostics`` is True and ``plot_verbosity`` is high enough to emit diagnostic figures (see CLI help / ``var_dict``). It also identifies outliers using the ``t_value`` option (default 2, in approximate SD units). + +Prediction CSVs and the :doc:`Python API <../API/robert.api>` can include uncertainty columns alongside ``{y}_pred``: + +* ``{y}_pred_sd`` — spread across repeated CV refits (overwritten by ``{y}_pred_uq_total`` when meta UQ is enabled). +* ``{y}_pred_conformal_hw`` — split-conformal half-width (regression; NaN for classification). +* ``{y}_pred_uq_model``, ``{y}_pred_uq_meta``, ``{y}_pred_uq_total`` — meta-model decomposition (opt-in). +* ``{y}_pred_uq_auto``, ``{y}_pred_uq_auto_source`` — auto-selected uncertainty (regression, opt-in). + +Full semantics, configuration kwargs, and ``predict(..., return_uncertainty=...)`` modes are documented in :doc:`../API/robert.api`. This includes conformal knobs (``conformal_*``), meta-UQ knobs (``uq_enable_meta``, ``uq_top_k_models``, ``uq_model_weighting``), and auto-UQ knobs (``uq_auto_*``). Example +++++++ diff --git a/docs/README.rst b/docs/README.rst index 43624f5..502d814 100644 --- a/docs/README.rst +++ b/docs/README.rst @@ -73,11 +73,18 @@ standards for cheminformatics studies, including: Requires the `AQME program `__. * **Data curation**, including filters for correlated descriptors, noise, and duplicates, as well as conversion of categorical descriptors. + * **EVALUATE**, to run VERIFY/PREDICT on a pre-specified linear model (``MVL``) without + GENERATE screening (see :doc:`Modules/evaluate`). * **Model selection**, including the comparison of multiple hyperoptimized models using multiple cross-validation techniques. This approach mitigates overfitting in low-data regimes. + The default screening set is RF/GB/NN/MVL for regression (or RF/GB/NN/AdaB for classification); + add ``XGB`` explicitly with ``--model`` (CLI) or ``model=[..., "XGB"]`` (Python API). * **Prediction** of external test sets, as well as SHAP and PFI feature analysis. * **VERIFY tests** to assess the predictive ability of the models, including y-shuffle, y-mean, and one-hot encoding tests. + * **Python API** (:class:`~robert.api.RobertModel`): sklearn-style ``fit``, ``predict``, and + ``score`` on DataFrames, with optional split-conformal, meta-model, and auto uncertainty + quantification. See :doc:`API/robert.api`. The code has been designed for: @@ -107,7 +114,7 @@ In a nutshell, ROBERT and all its dependencies can be installed automatically us :width: 140 :align: middle -**1.** Download the environment file `env.yaml `__ by clicking this button on GitHub |download|. +**1.** Download the environment file `env.yaml `__ by clicking this button on GitHub |download|. **2.** Open an Anaconda Prompt (Windows) or a terminal (macOS/Linux) and navigate to the folder where you saved ``env.yaml``: @@ -135,7 +142,7 @@ In a nutshell, ROBERT and all its dependencies can be installed automatically us No additional manual installation is required. **Alternative installation** -=========================== +============================ In a nutshell, ROBERT and its dependencies are installed as follows: @@ -202,10 +209,6 @@ You need a terminal with Python to install and run ROBERT. These are some sugges If you prefer a faster and easier installation, you can use the preconfigured **YAML environment file**. This method automatically installs Python, ROBERT, and all required dependencies. -.. |download| image:: /Modules/images/download.png - :width: 140 - :align: middle - **1.** Install `Anaconda with Python 3 `__ for your operating system (Windows, macOS or Linux). Alternatively, if you're familiar with conda installers, you can install `Miniconda with Python 3 `__ @@ -241,7 +244,7 @@ you can install `Miniconda with Python 3 `_. + For video tutorials on how to use easyROB, check out our `easyROB video tutorials `_. @@ -382,6 +385,7 @@ Python and Python libraries * seaborn * scipy * scikit-learn +* xgboost (dependency for optional ``XGB`` model screening; not part of the default model list) * hyperopt * numba * shap diff --git a/docs/Report/score.rst b/docs/Report/score.rst index b503ebc..2dad7d0 100644 --- a/docs/Report/score.rst +++ b/docs/Report/score.rst @@ -211,11 +211,11 @@ The following examples might help clarify these points: Differences in the RMSE/MCC obtained across the five folds of a sorted 5-fold CV (where target values, y, are sorted from minimum to maximum and not shuffled during CV). First, the minimum RMSE/mMCC among the five folds is identified. Then, the differences between each fold’s RMSE/MCC and this minimum RMSE/MCC are evaluated -+------------+--------------------------------------------------+ -| Points | Condition | -+============+==================================================+ -| • 1 | Every two folds with RMSE/MCC ≤ 1.25*min RMSE/MCC | -+------------+--------------------------------------------------+ ++------------+---------------------------------------------------+ +| Points | Condition | ++============+===================================================+ +| 1 | Every two folds with RMSE/MCC ≤ 1.25*min RMSE/MCC | ++------------+---------------------------------------------------+ Score ranges ++++++++++++ diff --git a/docs/Technical/defaults.rst b/docs/Technical/defaults.rst index 793a1eb..70a3ea4 100644 --- a/docs/Technical/defaults.rst +++ b/docs/Technical/defaults.rst @@ -42,3 +42,49 @@ REPORT .. automodule:: robert.report :noindex: + +Uncertainty and model selection +------------------------------- + +Defaults below are defined in ``robert.argument_parser.var_dict``. Full semantics, +output columns, and ``predict(..., return_uncertainty=...)`` modes are in +:doc:`../API/robert.api`. + +.. list-table:: + :header-rows: 1 + :widths: 28 72 + + * - Parameter + - Default / notes + * - ``model`` + - ``["RF", "GB", "NN", "MVL"]`` (reg) or ``["RF", "GB", "NN", "AdaB"]`` (clas); add ``"XGB"`` explicitly + * - ``conformal_enable`` + - ``True`` (regression split-conformal half-width column) + * - ``conformal_calib_frac`` + - ``0.15`` + * - ``conformal_coverage`` + - ``0.9`` + * - ``uq_enable_meta`` + - ``False`` + * - ``uq_top_k_models`` + - ``3`` + * - ``uq_model_weighting`` + - ``"score_weighted"`` (or ``"uniform"``) + * - ``uq_auto_enable`` + - ``False`` (regression auto selection) + * - ``uq_auto_candidates`` + - ``["cv_sd", "conformal", "meta_total"]`` + * - ``uq_auto_scaler`` + - ``"global_multiplicative"`` (also ``"none"``, ``"isotonic"``) + * - ``uq_auto_metric_weights`` + - ``coverage: 1.0``, ``sharpness: 0.25``, ``nll: 0.5`` + * - ``uq_auto_min_samples`` + - ``12`` + * - ``uq_auto_random_state`` + - ``0`` + * - ``uq_auto_clas_mode`` + - ``"error"`` (raises if auto UQ requested for classification) + * - ``predict_diagnostics`` + - ``True`` (SHAP, PFI, plots; ``RobertModel.predict`` forces ``False``) + * - ``plot_verbosity`` + - ``2`` (higher → more diagnostic figures when diagnostics enabled) diff --git a/docs/Tutorials/videos.rst b/docs/Tutorials/videos.rst index b659a94..7512970 100644 --- a/docs/Tutorials/videos.rst +++ b/docs/Tutorials/videos.rst @@ -18,7 +18,7 @@ Full workflow from CSV EasyROB (Graphical User Interface) -============================= +================================== A series of short tutorials showing how to install and use **easyROB** — no coding required. diff --git a/docs/conf.py b/docs/conf.py index 4e8f56f..50fe54b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,11 +11,17 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import os +import re import sys -# Ensure that modules can be imported without installing ROBERT -sys.path.insert(0, os.path.abspath('..')) +from pathlib import Path +# Ensure that modules can be imported without installing ROBERT +sys.path.insert(0, os.path.abspath('..')) +_root = Path(__file__).resolve().parents[1] +_setup = (_root / "setup.py").read_text(encoding="utf-8") +_version_match = re.search(r'version\s*=\s*["\']([^"\']+)["\']', _setup) +version = _version_match.group(1) if _version_match else "unknown" # -- Project information ----------------------------------------------------- @@ -24,7 +30,7 @@ author = '2023, Juan V. Alegre Requena, David Dalmau Ginesta' # The full version, including alpha/beta/rc tags -release = 'v1.0' +release = version # -- General configuration --------------------------------------------------- @@ -32,10 +38,28 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', - 'sphinx_design', - ] -# Add any paths that contain templates here, relative to this directory. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx_design', + 'sphinx.ext.intersphinx', + 'sphinx.ext.napoleon', +] + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'pandas': ('https://pandas.pydata.org/docs/', None), + 'sklearn': ('https://scikit-learn.org/stable/', None), + 'xgboost': ('https://xgboost.readthedocs.io/en/stable/', None), +} + +# Shared image substitutions (used in README partial includes) +rst_prolog = """ +.. |download| image:: /Modules/images/download.png + :width: 140 + :align: middle +""" + html_theme_options = { 'collapse_navigation': False, } diff --git a/docs/index.rst b/docs/index.rst index f0f20f5..9a10961 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,6 +57,7 @@ Special acknowledgments Modules/database Modules/curate + Modules/evaluate Modules/generate Modules/predict Modules/verify diff --git a/docs/requirements.txt b/docs/requirements.txt index 36b0864..cbf1e36 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,2 @@ sphinx sphinx-rtd-theme -sphinx-autoapi \ No newline at end of file From 36eff21c72c0a1fdfbbea771e516abfa0764156e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Laplaza?= <30357710+rlaplaza@users.noreply.github.com> Date: Mon, 18 May 2026 21:29:20 +0200 Subject: [PATCH 7/8] Updates, fixes, linting, tests in windows and mac and linux. (#77) * ci: add Ruff lint and format checks Pin Ruff defaults in pyproject.toml, reformat Python sources, fix default lint violations, and gate CircleCI on ruff check and ruff format --check. * feat: API scores, VR tuning, and multi-platform CI Add RobertModel.robert_scores(), VR hyperparameter BO support, and v2.2.0 test extras. Extend tests (plot metrics, VR/BO, API) and docs. Refactor CircleCI to run the shared conda suite on Linux, Windows, and macOS. --- .circleci/config.yml | 395 +- .../API_workflow/robert_api_full_workflow.py | 94 + build_easyrob/buildScript.py | 257 +- build_easyrob/config_files/post_install.py | 12 +- build_easyrob/logger_config.py | 2 +- docs/API/robert.api.rst | 14 +- docs/Misc/versions.rst | 9 + docs/Modules/generate.rst | 2 +- docs/Tutorials/md_to_rst.py | 35 +- docs/conf.py | 44 +- pyproject.toml | 19 + pytest.ini | 1 + requirements-test.txt | 7 + robert/__main__.py | 8 +- robert/api.py | 92 +- robert/aqme.py | 266 +- robert/argument_parser.py | 84 +- robert/curate.py | 142 +- robert/evaluate.py | 86 +- robert/generate.py | 100 +- robert/generate_utils.py | 225 +- robert/gui_easyrob/easyrob.py | 3 +- robert/gui_easyrob/easyrob_launcher.py | 11 +- robert/gui_easyrob/main/window.py | 674 ++- robert/gui_easyrob/tabs/advanced_options.py | 47 +- robert/gui_easyrob/tabs/aqme.py | 435 +- robert/gui_easyrob/tabs/images.py | 8 +- robert/gui_easyrob/tabs/molssi.py | 34 +- robert/gui_easyrob/tabs/predictions.py | 23 +- robert/gui_easyrob/tabs/results.py | 56 +- robert/gui_easyrob/utils/aqme_utils.py | 8 +- robert/gui_easyrob/utils/molssi_utils.py | 122 +- robert/gui_easyrob/utils/predictions_utils.py | 228 +- robert/gui_easyrob/utils/utils_gui.py | 91 +- robert/gui_easyrob/version.py | 7 +- robert/predict.py | 25 +- robert/predict_utils.py | 238 +- robert/report.py | 1174 +++--- robert/report_utils.py | 1180 +++--- robert/robert.py | 61 +- robert/uq_auto.py | 29 +- robert/utils.py | 3754 ++++++++++------- robert/verify.py | 279 +- setup.py | 17 +- tests/__init__.py | 1 + tests/conftest.py | 113 + tests/test_2generate.py | 16 +- tests/test_3verify.py | 94 +- tests/test_5aqme_n_full.py | 12 + tests/test_7api.py | 57 +- tests/test_8uq.py | 10 +- tests/test_easyrob.py | 569 ++- tests/test_plot_metrics.py | 36 + tests/test_vr_bo.py | 81 + 54 files changed, 7023 insertions(+), 4364 deletions(-) create mode 100644 Examples/API_workflow/robert_api_full_workflow.py create mode 100644 pyproject.toml create mode 100644 requirements-test.txt create mode 100644 tests/__init__.py create mode 100644 tests/test_plot_metrics.py create mode 100644 tests/test_vr_bo.py diff --git a/.circleci/config.yml b/.circleci/config.yml index c5ec364..74aa66b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,41 +2,189 @@ # See: https://circleci.com/docs/2.0/configuration-reference version: 2.1 -# Orbs are reusable packages of CircleCI configuration that you may share across projects, enabling you to create encapsulated, parameterized commands, jobs, and executors that can be used across multiple projects. -# See: https://circleci.com/docs/2.0/orb-intro/ orbs: - # The python orb contains a set of prepackaged CircleCI configuration you can use repeatedly in your configuration files - # Orb commands and jobs help you with common scripting around a language/tool - # so you dont have to copy and paste it everywhere. - # See the orb documentation here: https://circleci.com/developer/orbs/orb/circleci/python python: circleci/python@2.1.1 codecov: codecov/codecov@3.2.4 -# Define a job to be invoked later in a workflow. -# See: https://circleci.com/docs/2.0/configuration-reference/#jobs +commands: + install-miniconda: + description: Install Miniconda on machine executors (Windows/macOS) + steps: + - run: + name: Install Miniconda + command: | + set -euo pipefail + MINICONDA_DIR="${HOME}/miniconda3" + if [ -x "${MINICONDA_DIR}/Scripts/conda.exe" ] || [ -x "${MINICONDA_DIR}/bin/conda" ]; then + echo "Miniconda already present at ${MINICONDA_DIR}" + else + case "$(uname -s)" in + MINGW*|MSYS*|CYGWIN*) + INSTALLER_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe" + curl -fsSL -o miniconda-installer.exe "${INSTALLER_URL}" + ./miniconda-installer.exe /InstallationType=JustMe /RegisterPython=0 /S /D="${MINICONDA_DIR}" + rm -f miniconda-installer.exe + ;; + Darwin) + ARCH="$(uname -m)" + if [ "${ARCH}" = "arm64" ]; then + INSTALLER_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh" + else + INSTALLER_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh" + fi + curl -fsSL -o miniconda-installer.sh "${INSTALLER_URL}" + bash miniconda-installer.sh -b -p "${MINICONDA_DIR}" + rm -f miniconda-installer.sh + ;; + *) + echo "install-miniconda: unsupported OS $(uname -s)" + exit 1 + ;; + esac + fi + { + echo "export MINICONDA_DIR=${MINICONDA_DIR}" + case "$(uname -s)" in + MINGW*|MSYS*|CYGWIN*) + echo "export PATH=\"${MINICONDA_DIR}:${MINICONDA_DIR}/Library/bin:${MINICONDA_DIR}/Scripts:\${PATH}\"" + ;; + Darwin) + echo "export PATH=\"${MINICONDA_DIR}/bin:\${PATH}\"" + ;; + esac + } >> "${BASH_ENV}" + # shellcheck source=/dev/null + source "${BASH_ENV}" + conda --version + + run-conda-test-suite: + description: Conda env, deps, and pytest suite (mirrors Linux build-and-test) + parameters: + upload_coverage: + type: boolean + default: false + steps: + - run: + name: Setup conda environment, install deps, and run tests + command: | + set -euo pipefail + # shellcheck source=/dev/null + [ -f "${BASH_ENV}" ] && source "${BASH_ENV}" + + eval "$(conda shell.bash hook)" + conda update -y conda + conda create -n cheminf python=3.12 -y + conda activate cheminf + + python -m pip install --upgrade pip + + # Core stack shared across Linux/Windows/macOS (matches Linux job). + CONDA_PKGS=( + openbabel=3.1.1 + xtb=6.7.1 + glib + gtk3 + pango + ) + case "$(uname -s)" in + Linux) + CONDA_PKGS+=( + libgfortran=14.2.0 + libxkbcommon + mscorefonts + fonts-conda-forge + ) + ;; + Darwin) + CONDA_PKGS+=(libgfortran=14.2.0) + ;; + esac + conda install -y -c conda-forge "${CONDA_PKGS[@]}" + + case "$(uname -s)" in + Linux) + export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:${LD_LIBRARY_PATH:-}" + ;; + Darwin) + export DYLD_FALLBACK_LIBRARY_PATH="$CONDA_PREFIX/lib:${DYLD_FALLBACK_LIBRARY_PATH:-}" + ;; + MINGW*|MSYS*|CYGWIN*) + export PATH="$CONDA_PREFIX/Library/bin:$CONDA_PREFIX/Scripts:$PATH" + ;; + esac + + pip install aqme==2.0.0 + pip install . + pip install pytest pytest-cov pytest-qt + + set +e + + python -m pytest \ + -v \ + --ignore=tests/test_easyrob.py \ + --cov=robert \ + --cov-report= + + CORE_EXIT_CODE=$? + if [ "$CORE_EXIT_CODE" -ne 0 ]; then + echo "[CI] Core tests failed with exit code $CORE_EXIT_CODE" + set -e + exit "$CORE_EXIT_CODE" + fi + + export QT_QPA_PLATFORM=offscreen + export QT_OPENGL=software + export QT_RHI=software + export QTWEBENGINE_DISABLE_SANDBOX=1 + export QTWEBENGINE_CHROMIUM_FLAGS="--no-sandbox --disable-gpu" + + if [ "$(uname -s)" = "Linux" ] && command -v xvfb-run >/dev/null 2>&1; then + xvfb-run -s "-screen 0 1920x1080x24" \ + python -m pytest \ + -vv \ + --cov=robert \ + --cov-append \ + --cov-report= \ + tests/test_easyrob.py + else + python -m pytest \ + -vv \ + --cov=robert \ + --cov-append \ + --cov-report= \ + tests/test_easyrob.py + fi + + GUI_EXIT_CODE=$? + set -e + + if [ "$GUI_EXIT_CODE" -eq 139 ] || [ "$GUI_EXIT_CODE" -eq 134 ]; then + echo "[CI WARNING] GUI tests completed, but Qt/PySide6 triggered a crash on interpreter shutdown (exit code $GUI_EXIT_CODE)." + echo "[CI WARNING] Treating exit code $GUI_EXIT_CODE as success because all tests passed." + GUI_EXIT_CODE=0 + fi + + if [ "$GUI_EXIT_CODE" -ne 0 ]; then + echo "[CI] GUI tests failed with exit code $GUI_EXIT_CODE" + exit "$GUI_EXIT_CODE" + fi + + if [ "<< parameters.upload_coverage >>" = "true" ]; then + coverage xml + rm -rf /tmp/* + cp coverage.xml /tmp/ + cp -r robert /tmp/ + fi + jobs: - build-and-test: # This is the name of the job, feel free to change it to better match what you're trying to do! - # These next lines defines a Docker executors: https://circleci.com/docs/2.0/executor-types/ - # You can specify an image from Dockerhub or use one of the convenience images from CircleCI's Developer Hub - # A list of available CircleCI Docker convenience images are available here: https://circleci.com/developer/images/image/cimg/python - # The executor is the environment in which the steps below will be executed - below will use a python 3.10.2 container - # Change the version below to your required version of python + build-and-test: docker: - image: continuumio/miniconda3 - # working_directory: /root/project - # Checkout the code as the first step. This is a dedicated CircleCI step. - # The python orb's install-packages step will install the dependencies from a Pipfile via Pipenv by default. - # Here we're making sure we use just use the system-wide pip. By default it uses the project root's requirements.txt. - # Then run your tests! - # CircleCI will report the results back to your VCS provider. steps: - checkout - python/install-packages: pkg-manager: pip - # app-dir: ~/project/package-directory/ # If your requirements.txt isn't in the root directory. - # pip-dependency-file: test-requirements.txt # if you have a different name for your requirements file, maybe one that combines your runtime and test requirements. - # Install system-level dependencies required by Qt/PySide6 and WeasyPrint - run: name: Install system dependencies for Qt/PySide6 and WeasyPrint command: | @@ -108,148 +256,19 @@ jobs: fonts-dejavu-extra \ ca-certificates - # ICU version can vary between base images; try the most recent, then fall back. apt-get install -y libicu72 || \ apt-get install -y libicu70 || \ apt-get install -y libicu67 || true - # Create conda environment, install ROBERT/easyROB + test stack, run tests & generate coverage - - run: - name: Run core and GUI tests with combined coverage - command: | - # ------------------------------------------------------------------ - # Proper conda initialization (non-interactive shell) - # ------------------------------------------------------------------ - eval "$(conda shell.bash hook)" - - # Optional: updating conda can be slow; keep it if we want latest solver. - conda update -y conda - - # Create and activate project-specific environment - conda create -n cheminf python=3.12 -y - conda activate cheminf - - # ------------------------------------------------------------------ - # Install project dependencies (ROBERT, AQME, GUI stack) - # ------------------------------------------------------------------ - python -m pip install --upgrade pip - - # Core scientific and native dependencies from conda-forge: - # - openbabel / xtb / libgfortran: chemistry-related tooling - # - glib / gtk3 / pango / mscorefonts: required by WeasyPrint for PDF generation - # - libxkbcommon: required by Qt for keyboard handling - conda install -y -c conda-forge \ - openbabel=3.1.1 \ - xtb=6.7.1 \ - libgfortran=14.2.0 \ - glib \ - gtk3 \ - pango \ - libxkbcommon \ - mscorefonts \ - fonts-conda-forge + - run-conda-test-suite: + upload_coverage: true - # Ensure the dynamic linker can find native libraries installed inside the conda environment. - # WeasyPrint loads GLib/Pango/Cairo through dlopen(), and without this path the loader may miss - # libraries such as libgobject-2.0.so.0 that live in $CONDA_PREFIX/lib. - export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:${LD_LIBRARY_PATH}" - - # AQME (external dependency used in some workflows) - pip install aqme==2.0.0 - - # Install ROBERT/easyROB from the current repository as a package. - # This also installs WeasyPrint==63.1 from PyPI via setup.py (install_requires). - pip install . - - # Do NOT uninstall robert here - GUI tests import it as an installed package. - - # ------------------------------------------------------------------ - # Test tooling - # ------------------------------------------------------------------ - pip install pytest pytest-cov pytest-qt - - # IMPORTANT: - # CircleCI executes 'run' steps with 'set -e' (exit on first non-zero code). - # We temporarily disable it so we can inspect pytest/xvfb-run exit codes - # and treat the known Qt segfault (139) as a special case. - set +e - - # ------------------------------------------------------------------ - # Run core (non-GUI) tests first - # ------------------------------------------------------------------ - python -m pytest \ - -v \ - --ignore=tests/test_easyrob.py \ - --cov=robert \ - --cov-report= - - # CORE_EXIT_CODE=$? - # if [ "$CORE_EXIT_CODE" -ne 0 ]; then - # echo "[CI] Core tests failed with exit code $CORE_EXIT_CODE" - # exit "$CORE_EXIT_CODE" - # fi - - # ------------------------------------------------------------------ - # Configure Qt/QtWebEngine for headless execution under Xvfb - # ------------------------------------------------------------------ - export QT_QPA_PLATFORM=offscreen - export QT_OPENGL=software - export QT_RHI=software - export QTWEBENGINE_DISABLE_SANDBOX=1 - export QTWEBENGINE_CHROMIUM_FLAGS="--no-sandbox --disable-gpu" - - # ------------------------------------------------------------------ - # Run GUI tests under Xvfb, capturing the exit code - # ------------------------------------------------------------------ - xvfb-run -s "-screen 0 1920x1080x24" \ - python -m pytest \ - -vv \ - --cov=robert \ - --cov-append \ - --cov-report= \ - tests/test_easyrob.py - - - GUI_EXIT_CODE=$? - - # Re-enable 'exit on error' for any subsequent commands. - set -e - - # PySide6/Qt sometimes causes a segmentation fault (exit code 139) on interpreter shutdown - # after all tests have passed. In that case we treat the run as successful but emit a warning. - - if [ "$GUI_EXIT_CODE" -eq 139 ] || [ "$GUI_EXIT_CODE" -eq 134 ]; then - echo "[CI WARNING] GUI tests completed, but Qt/PySide6 triggered a crash on interpreter shutdown (exit code $GUI_EXIT_CODE)." - echo "[CI WARNING] This is usually caused by threads or Qt resources not being fully cleaned up (e.g., QThread/QThreadPool/QProcess still running)." - echo "[CI WARNING] Treating exit code $GUI_EXIT_CODE as success because all tests passed and coverage.xml was generated." - GUI_EXIT_CODE=0 - fi - - # Any other non-zero exit code is treated as a real failure. - if [ "$GUI_EXIT_CODE" -ne 0 ]; then - echo "[CI] GUI tests failed with exit code $GUI_EXIT_CODE" - exit "$GUI_EXIT_CODE" - fi - - # Generate a single combined XML report after both pytest runs - coverage xml - - # ------------------------------------------------------------------ - # Save coverage and workspace artifacts (only if tests were successful) - # ------------------------------------------------------------------ - rm -r /tmp/* - cp coverage.xml /tmp/ - cp -r robert /tmp/ - - # Persist coverage report and source tree for the Codecov job - persist_to_workspace: root: /tmp paths: - coverage.xml - robert - # Same dependency stack and pytest suite as build-and-test, on Windows (win-64). - # No workspace: codecov stays on the Linux job only. build-and-test-windows: machine: image: windows-server-2022-gui:current @@ -257,51 +276,59 @@ jobs: shell: bash.exe steps: - checkout - - run: - name: Conda env, install deps, pytest (Windows) - command: | - set -euo pipefail - conda update -y conda - eval "$(conda shell.bash hook)" - conda create -y -n cheminf python=3.12 - conda activate cheminf - python -m pip install --upgrade pip - conda install -y -c conda-forge openbabel=3.1.1 - conda install -y -c conda-forge xtb=6.7.1 - conda install -y -c conda-forge glib - pip install weasyprint - conda install -y -c conda-forge gtk3 - conda install -y -c conda-forge pango - pip install aqme==2.0.0 - pip install . - pip uninstall -y robert - pip install pytest - pip install pytest-cov - python -m pytest -v --cov=robert --cov-report=xml --cov-report=term - - # the codecov orb doesn't work with the miniconda docker, so the coverage report - # needs to be stored and loaded in a new job using a python docker + - install-miniconda + - run-conda-test-suite: + upload_coverage: false + + build-and-test-macos: + macos: + xcode: "16.2.0" + resource_class: macos.m1.medium.gen1 + steps: + - checkout + - install-miniconda + - run-conda-test-suite: + upload_coverage: false codecov-coverage: docker: - image: cimg/python:3.12 steps: - # Attach artifacts from the previous job workspace - attach_workspace: at: /tmp - - # Checkout is required so Codecov can correctly map coverage to the repo - checkout - - # Upload the coverage report generated in the build-and-test job - codecov/upload: file: /tmp/coverage.xml token: CODECOV_TOKEN + lint: + docker: + - image: cimg/python:3.12 + steps: + - checkout + - run: + name: Install Ruff + command: python -m pip install 'ruff==0.14.5' + - run: + name: Ruff check + command: ruff check . + - run: + name: Ruff format check + command: ruff format --check . + workflows: sample: jobs: - - build-and-test + - lint + - build-and-test: + requires: + - lint + - build-and-test-windows: + requires: + - lint + - build-and-test-macos: + requires: + - lint - codecov-coverage: requires: - build-and-test diff --git a/Examples/API_workflow/robert_api_full_workflow.py b/Examples/API_workflow/robert_api_full_workflow.py new file mode 100644 index 0000000..40b0e0e --- /dev/null +++ b/Examples/API_workflow/robert_api_full_workflow.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +""" +Full ROBERT workflow via the Python API. + +Runs CURATE, GENERATE (Bayesian hyperparameter optimization), VERIFY, +PREDICT, and REPORT; prints sklearn and ROBERT scores and artifact paths. + +Usage (from repository root):: + + python Examples/API_workflow/robert_api_full_workflow.py + +Requires WeasyPrint system libraries for PDF generation (see ROBERT docs). +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pandas as pd + +_REPO = Path(__file__).resolve().parents[2] +if str(_REPO) not in sys.path: + sys.path.insert(0, str(_REPO)) + +from robert import RobertModel # noqa: E402 + +_CSV = _REPO / "tests" / "Robert_example.csv" +_WORKDIR = Path(__file__).resolve().parent / "robert_api_run" + + +def main() -> None: + df = pd.read_csv(_CSV, encoding="utf-8") + X = df.drop(columns=["Target_values"]) + y = df["Target_values"] + n_fit = 25 + + model = RobertModel( + problem_type="reg", + filter_mode="no_pfi", + workdir=_WORKDIR, + names="Name", + report=True, + plot_verbosity=2, + model=["RF", "GB"], + n_iter=2, + init_points=2, + repeat_kfolds=2, + kfold=3, + seed=42, + ) + + print("Fitting (CURATE → GENERATE → VERIFY → PREDICT → REPORT)...") + model.fit(X.iloc[:n_fit], y.iloc[:n_fit]) + + info = model.best_model_info() + print(f"Best model: {info['model']} descriptors: {len(info['descriptors'])}") + + X_test = X.iloc[n_fit:].drop_duplicates(subset=["Name"], keep="first") + y_test = y.loc[X_test.index] + preds = model.predict(X_test) + r2 = model.score(X_test, y_test) + print(f"Holdout R² (sklearn): {r2:.3f}") + print(f"Holdout predictions shape: {preds.shape}") + + scores = model.robert_scores() + print(f"\nROBERT score ({scores['suffix']}): {scores['robert_score']}") + print("Sub-scores:") + for name, value in scores["components"].items(): + print(f" - {name}: {value}") + + workdir = Path(model.workdir_) + print("\nKey outputs:") + for pattern in ( + "CURATE/*.png", + "GENERATE/Heatmap*.png", + "VERIFY/VERIFY_tests_*.png", + "PREDICT/*.png", + "ROBERT_report.pdf", + ): + matches = sorted(workdir.glob(pattern)) + for path in matches[:3]: + print(f" {path.relative_to(workdir)}") + if len(matches) > 3: + print(f" ... ({len(matches)} files matching {pattern})") + + if scores["pdf_path"]: + print(f"\nReport PDF: {scores['pdf_path']}") + else: + print("\nx ROBERT_report.pdf was not created (WeasyPrint may be missing).") + + +if __name__ == "__main__": + main() diff --git a/build_easyrob/buildScript.py b/build_easyrob/buildScript.py index be0ad5a..ffb124c 100644 --- a/build_easyrob/buildScript.py +++ b/build_easyrob/buildScript.py @@ -4,7 +4,6 @@ import logging import shutil from pathlib import Path -from typing import Tuple from logger_config import setup_logger setup_logger() @@ -14,15 +13,15 @@ def define_env(env_name: str, yaml_path: Path, force_recreate: bool = False) -> bool: """Define or verify a conda environment. - + Args: env_name: Name of the conda environment yaml_path: Path to the YAML file with environment definition force_recreate: If True, recreate the environment even if it exists - + Returns: bool: True if the environment is ready to use, False if there was an error - + Raises: FileNotFoundError: If conda or YAML file is not found subprocess.CalledProcessError: If a conda command fails @@ -30,13 +29,13 @@ def define_env(env_name: str, yaml_path: Path, force_recreate: bool = False) -> if not yaml_path.exists(): log.error(f"Configuration file {yaml_path} does not exist") return False - + try: log.info("Checking environment...") result = subprocess.run( ["conda", "env", "list"], capture_output=True, text=True, check=True ) - + env_exists = any( line.startswith(env_name) or f"{env_name} " in line for line in result.stdout.splitlines() @@ -48,19 +47,34 @@ def define_env(env_name: str, yaml_path: Path, force_recreate: bool = False) -> else: log.warning(f"Environment '{env_name}' already exists.") while True: - response = input("Do you want to continue using it? [Y/N]: ").strip().lower() + response = ( + input("Do you want to continue using it? [Y/N]: ") + .strip() + .lower() + ) log.info(f"User response: {response}") match response: case "y": return True case "n": while True: - response = input("Do you want delete it and re-create? [Y/N]: ").strip().lower() + response = ( + input( + "Do you want delete it and re-create? [Y/N]: " + ) + .strip() + .lower() + ) log.info(f"User response: {response}") match response: case "y": - log.info(f"Removing existing environment '{env_name}'...") - subprocess.run(["conda", "env", "remove", "-n", env_name], check=True) + log.info( + f"Removing existing environment '{env_name}'..." + ) + subprocess.run( + ["conda", "env", "remove", "-n", env_name], + check=True, + ) break case "n": log.info("Process aborted by user.") @@ -75,11 +89,11 @@ def define_env(env_name: str, yaml_path: Path, force_recreate: bool = False) -> ["conda", "env", "create", "-n", env_name, "-f", str(yaml_path)], check=True, capture_output=True, - text=True + text=True, ) log.info(f"Environment '{env_name}' created successfully.") return True - + except FileNotFoundError as fe: log.error(f"Conda is not installed or not in PATH.\n{fe}") return False @@ -89,17 +103,17 @@ def define_env(env_name: str, yaml_path: Path, force_recreate: bool = False) -> def pyinstaller_build( - env_name: str, - spec_path: Path, - dist_dir: Path, + env_name: str, + spec_path: Path, + dist_dir: Path, work_dir: Path, clean: bool = True, debug: bool = False, one_file: bool = False, - windowed: bool = False + windowed: bool = False, ) -> bool: """Build the executable using PyInstaller. - + Args: env_name: Name of the conda environment to use spec_path: Path to the .spec file @@ -109,14 +123,14 @@ def pyinstaller_build( debug: If True, include debugging information one_file: If True, generate a single executable file windowed: If True, hide console on Windows - + Returns: bool: True if build was successful, False if there was an error """ if not spec_path.exists(): log.error(f"Spec file {spec_path} does not exist") return False - + log.info(f"Building executable with PyInstaller in environment '{env_name}'...") env = os.environ.copy() @@ -132,7 +146,7 @@ def pyinstaller_build( "-m", "PyInstaller", ] - + if clean: command.append("--clean") if debug: @@ -141,13 +155,15 @@ def pyinstaller_build( command.append("--onefile") if windowed: command.append("--windowed") - - command.extend([ - "--noconfirm", - f"--distpath={dist_dir}", - f"--workpath={work_dir}", - str(spec_path), - ]) + + command.extend( + [ + "--noconfirm", + f"--distpath={dist_dir}", + f"--workpath={work_dir}", + str(spec_path), + ] + ) try: process = subprocess.Popen( @@ -170,7 +186,7 @@ def pyinstaller_build( log.info("Build completed successfully.") log.info("Please check 'pyinstaller.log' for details and warnings") return True - + except subprocess.CalledProcessError as e: log.error("Build failed. Check 'pyinstaller.log' for more details.") log.error(f"Command: {e.cmd}") @@ -180,7 +196,7 @@ def pyinstaller_build( def conda_pack(env_name: str, output_dir: Path): log.info(f"Packing Conda environment '{env_name}' with conda-pack...") - + archive_name = f"{env_name}.tar.gz" output_path = output_dir / archive_name unpacked_dir = output_dir / "robert_env_unpacked" @@ -204,13 +220,12 @@ def conda_pack(env_name: str, output_dir: Path): str(output_path), ] - process = subprocess.run(command, check=True, capture_output=True, text=True) - + subprocess.run(command, check=True, capture_output=True, text=True) + # Descomprimir el archivo log.info("Descomprimimiendo el entorno empaquetado...") subprocess.run( - ["tar", "-xzf", str(output_path), "-C", str(unpacked_dir)], - check=True + ["tar", "-xzf", str(output_path), "-C", str(unpacked_dir)], check=True ) # Ejecutar conda-unpack en el entorno descomprimido @@ -219,14 +234,14 @@ def conda_pack(env_name: str, output_dir: Path): unpack_script = unpacked_dir / "Scripts" / "conda-unpack.exe" else: unpack_script = unpacked_dir / "bin" / "conda-unpack" - + subprocess.run([str(unpack_script)], check=True, cwd=str(unpacked_dir)) # Eliminar el archivo tar.gz original ya que no lo necesitaremos output_path.unlink() - + log.info("Environment unpacked successfully.") - + except subprocess.CalledProcessError as e: log.error("Failed during environment packing or unpacking.") log.error(f"Command: {e.cmd}") @@ -237,7 +252,7 @@ def conda_pack(env_name: str, output_dir: Path): # Default paths used across the build process def get_build_paths(root: Path) -> tuple[Path, Path, Path, Path]: """Get standard paths for building. - + Returns: tuple containing: - dist_dir: Distribution directory @@ -248,10 +263,14 @@ def get_build_paths(root: Path) -> tuple[Path, Path, Path, Path]: dist_dir = root / "distribution" tmp_dir = root / "tmp" spec_file = root / "config_files" / "pyinstallerBuild.spec" - + # Only Windows needs an installer config file - installer_config = root / "config_files" / "win32_installer.iss" if sys.platform == "win32" else None - + installer_config = ( + root / "config_files" / "win32_installer.iss" + if sys.platform == "win32" + else None + ) + return dist_dir, tmp_dir, spec_file, installer_config @@ -287,74 +306,97 @@ def build_innosetup_installer(iss_path: Path): def parse_args(): """Parse command line arguments.""" import argparse - - parser = argparse.ArgumentParser(description='Build tool for EasyRob') - + + parser = argparse.ArgumentParser(description="Build tool for EasyRob") + # Frontend (EasyRob GUI) environment arguments - frontend_group = parser.add_argument_group('Frontend Environment') - frontend_group.add_argument('--frontend-env', default='easyrob_env', - help='Name of the frontend conda environment (default: easyrob_env)') - frontend_group.add_argument('--frontend-yaml', type=Path, - help='Path to the frontend environment YAML file') - + frontend_group = parser.add_argument_group("Frontend Environment") + frontend_group.add_argument( + "--frontend-env", + default="easyrob_env", + help="Name of the frontend conda environment (default: easyrob_env)", + ) + frontend_group.add_argument( + "--frontend-yaml", type=Path, help="Path to the frontend environment YAML file" + ) + # Backend (Robert) environment arguments - backend_group = parser.add_argument_group('Backend Environment') - backend_group.add_argument('--backend-env', default='robert_env', - help='Name of the backend conda environment (default: robert_env)') - backend_group.add_argument('--backend-yaml', type=Path, - help='Path to the backend environment YAML file') - + backend_group = parser.add_argument_group("Backend Environment") + backend_group.add_argument( + "--backend-env", + default="robert_env", + help="Name of the backend conda environment (default: robert_env)", + ) + backend_group.add_argument( + "--backend-yaml", type=Path, help="Path to the backend environment YAML file" + ) + # Environment management arguments - env_group = parser.add_argument_group('Environment Management') - env_group.add_argument('--force-recreate', action='store_true', - help='Force recreation of environments even if they exist') - env_group.add_argument('--skip-env', action='store_true', - help='Skip environment creation/verification') - + env_group = parser.add_argument_group("Environment Management") + env_group.add_argument( + "--force-recreate", + action="store_true", + help="Force recreation of environments even if they exist", + ) + env_group.add_argument( + "--skip-env", action="store_true", help="Skip environment creation/verification" + ) + # Build arguments - build_group = parser.add_argument_group('Build Options') - build_group.add_argument('--debug', action='store_true', - help='Enable debug mode in PyInstaller') - build_group.add_argument('--one-file', action='store_true', - help='Generate a single executable file') - build_group.add_argument('--windowed', action='store_true', - help='Hide console on Windows') - build_group.add_argument('--skip-installer', action='store_true', - help='Skip installer creation') - + build_group = parser.add_argument_group("Build Options") + build_group.add_argument( + "--debug", action="store_true", help="Enable debug mode in PyInstaller" + ) + build_group.add_argument( + "--one-file", action="store_true", help="Generate a single executable file" + ) + build_group.add_argument( + "--windowed", action="store_true", help="Hide console on Windows" + ) + build_group.add_argument( + "--skip-installer", action="store_true", help="Skip installer creation" + ) + # macOS specific arguments if sys.platform == "darwin": parser.add_argument( - "--version", - default="0.5.0", - help="Version number for macOS package" + "--version", default="0.5.0", help="Version number for macOS package" ) - + return parser.parse_args() + def main(): """Main build script function.""" try: args = parse_args() root = Path(__file__).parent - + # Get standard build paths dist_dir, tmp_dir, spec_file, installer_config = get_build_paths(root) - + # Create required directories dist_dir.mkdir(parents=True, exist_ok=True) tmp_dir.mkdir(parents=True, exist_ok=True) - + # Define environments - if not define_env(args.frontend_env, root/"config_files"/"frontend_env.yaml", args.force_recreate): + if not define_env( + args.frontend_env, + root / "config_files" / "frontend_env.yaml", + args.force_recreate, + ): return 1 - - if not define_env(args.backend_env, root/"config_files"/"backend_env.yaml", args.force_recreate): + + if not define_env( + args.backend_env, + root / "config_files" / "backend_env.yaml", + args.force_recreate, + ): return 1 - + # Pack conda environment conda_pack(args.backend_env, dist_dir) - + # Build with PyInstaller if not pyinstaller_build( args.frontend_env, @@ -362,61 +404,70 @@ def main(): dist_dir, tmp_dir, clean=True, - windowed=args.windowed + windowed=args.windowed, ): return 1 - + # Create platform-specific installer if not args.skip_installer: - if sys.platform == 'win32': + if sys.platform == "win32": if not build_innosetup_installer(installer_config): log.error("Error creating Windows installer") return 1 - elif sys.platform == 'darwin': + elif sys.platform == "darwin": if not build_macos_package(dist_dir, "easyROB", args.version): log.error("Error creating macOS package") return 1 - + return 0 - + except Exception as e: log.exception(f"Build failed: {e}") return 1 + def build_macos_package(dist_dir: Path, app_name: str, version: str = "0.5.0"): """Build a macOS .pkg installer - + Args: dist_dir: Directory containing the .app bundle app_name: Name of the application version: Version string for the package - + Returns: bool: True if build was successful, False if there was an error """ log.info("Building macOS package...") - + try: app_path = dist_dir / "build" / f"{app_name}.app" pkg_dir = dist_dir / "installer" pkg_dir.mkdir(parents=True, exist_ok=True) - + # Build the package - subprocess.run([ - "pkgbuild", - "--root", str(app_path), - "--install-location", f"/Applications/{app_name}.app", - "--identifier", f"com.robert.{app_name.lower()}", - "--version", version, - str(pkg_dir / f"{app_name.lower()}_installer.pkg") - ], check=True) - + subprocess.run( + [ + "pkgbuild", + "--root", + str(app_path), + "--install-location", + f"/Applications/{app_name}.app", + "--identifier", + f"com.robert.{app_name.lower()}", + "--version", + version, + str(pkg_dir / f"{app_name.lower()}_installer.pkg"), + ], + check=True, + ) + log.info("macOS package built successfully") return True - + except subprocess.CalledProcessError as e: log.error(f"Error building macOS package: {e}") return False + if __name__ == "__main__": sys.exit(main()) diff --git a/build_easyrob/config_files/post_install.py b/build_easyrob/config_files/post_install.py index 98275b7..4a20f3f 100644 --- a/build_easyrob/config_files/post_install.py +++ b/build_easyrob/config_files/post_install.py @@ -3,6 +3,7 @@ import subprocess from pathlib import Path + def run_postinstall(): """ Execute platform-specific post-installation steps. @@ -11,12 +12,12 @@ def run_postinstall(): - Setting execution permissions - Running the script to unpack the conda environment """ - if sys.platform == 'darwin': - if getattr(sys, 'frozen', False): + if sys.platform == "darwin": + if getattr(sys, "frozen", False): # Get the app bundle path when running as a frozen application app_path = Path(sys._MEIPASS) - install_script = app_path / 'Contents' / 'Resources' / 'postinstall.sh' - + install_script = app_path / "Contents" / "Resources" / "postinstall.sh" + if install_script.exists(): try: # Set execute permissions @@ -26,5 +27,6 @@ def run_postinstall(): except subprocess.CalledProcessError as e: print(f"Post-installation error: {e}") -if __name__ == '__main__': + +if __name__ == "__main__": run_postinstall() diff --git a/build_easyrob/logger_config.py b/build_easyrob/logger_config.py index ad657b2..2e00cb6 100644 --- a/build_easyrob/logger_config.py +++ b/build_easyrob/logger_config.py @@ -25,7 +25,7 @@ def setup_logger(): # PyInstaller logger pyinstaller_logger = logging.getLogger("pyinstaller_logger") pyinstaller_logger.setLevel(logging.INFO) - pyinstaller_logger.propagate = False + pyinstaller_logger.propagate = False pyinstaller_log_path = Path("pyinstaller.log") pyinstaller_file_handler = logging.FileHandler(pyinstaller_log_path, mode="w") diff --git a/docs/API/robert.api.rst b/docs/API/robert.api.rst index d5752b7..19beb00 100644 --- a/docs/API/robert.api.rst +++ b/docs/API/robert.api.rst @@ -68,7 +68,7 @@ for classification (when ``auto_type`` switches the problem type). * - RF, GB, NN, MVL - scikit-learn (default screening set for regression includes MVL instead of AdaB) * - GP, AdaB, VR - - scikit-learn (opt-in; AdaB replaces MVL in the default classification set) + - scikit-learn (opt-in; AdaB replaces MVL in the default classification set; VR optimizes ensemble weights and RF/GB/NN member hyperparameters) * - **XGB** - XGBoost (:class:`~xgboost.XGBRegressor` / :class:`~xgboost.XGBClassifier`), opt-in; hyperoptimized with in-code Bayesian bounds (no packaged ``model_params/XGB_params.yaml``) @@ -222,6 +222,16 @@ Example preds3, sd_cv2, hw2 = model.predict(X.iloc[25:], return_uncertainty="both") r2 = model.score(X.iloc[25:], y.iloc[25:]) +Full workflow (CSV → CURATE → GENERATE → VERIFY → PREDICT → PDF report) +------------------------------------------------------------------------- + +A runnable script that loads a CSV, runs the full pipeline with ``report=True``, +and prints the ROBERT score with sub-scores is in +``Examples/API_workflow/robert_api_full_workflow.py``. + +After ``fit`` with ``report=True``, use :meth:`~robert.api.RobertModel.robert_scores` +to read the same score components as ``ROBERT_report.pdf`` without parsing the PDF. + .. autoclass:: robert.api.RobertModel - :members: fit, predict, score, get_params, set_params + :members: fit, predict, score, robert_scores, get_params, set_params :no-inherited-members: diff --git a/docs/Misc/versions.rst b/docs/Misc/versions.rst index 2bfd39b..2a6755c 100644 --- a/docs/Misc/versions.rst +++ b/docs/Misc/versions.rst @@ -4,6 +4,15 @@ Versions ======== +Version 2.2.0 [`url `__] + - Python API: :class:`~robert.api.RobertModel` with ``fit``, ``predict``, ``score``, and :meth:`~robert.api.RobertModel.robert_scores` + - Opt-in XGBoost screening (``XGB``) with in-code Bayesian optimization bounds + - Voting Regressor/Classifier (``VR``): Bayesian optimization of ensemble weights and member RF/GB/NN hyperparameters + - Regression uncertainty: conformal intervals, meta-model UQ, and auto-UQ (``uq_auto``) + - ``plot_verbosity`` and ``predict_diagnostics`` control diagnostic figure generation + - Matplotlib threading workaround consolidated in ``_mpl_plot_context``; figures closed after saving + - VERIFY metrics plot handles degenerate axis limits when all test scores are equal + Version 2.1.1 [`url `__] - Adding RMSE values for each fold to calculate t- and Wilconxon tests - Add code for BO function diff --git a/docs/Modules/generate.rst b/docs/Modules/generate.rst index f835de6..fb18d08 100644 --- a/docs/Modules/generate.rst +++ b/docs/Modules/generate.rst @@ -40,7 +40,7 @@ The GENERATE module performs an exploration of various ML algorithms. It uses bu The software automatically generates a heatmap displaying RMSE values obtained from hyperoptimized algorithms. Furthermore, it performs permutation feature importance (PFI) analysis to identify the most influential descriptors and generate new models with only those descriptors. Users have the flexibility to fine-tune the PFI filter threshold using the "PFI_threshold" parameter. By default, this threshold removes features that contribute less than 4% to the model's R2. While this filter is activated by default, users can deactivate it by setting the "pfi_filter" option to False. -Users can choose between different modes for data splitting using the "split" option: (KN, RND, STRATIFIED, EVEN, EXTRA_Q1, EXTRA_Q5). The selection of ML algorithms during screening is tuned through the ``model`` parameter. By default, ROBERT screens RF, GB, NN, and MVL (regression) or RF, GB, NN, and AdaB (classification). Additional options include Gaussian Process (GP), Voting Regressor/Classifier (VR), and **XGB** (XGBoost regressor/classifier). XGB is opt-in (for example ``--model "[XGB]"`` or ``--model "[RF,XGB]"``); it is hyperoptimized with in-code Bayesian bounds rather than a packaged YAML file. The same ``model`` list is accepted by the :doc:`Python API <../API/robert.api>` via :class:`~robert.api.RobertModel`, where uncertainty kwargs are documented in detail. +Users can choose between different modes for data splitting using the "split" option: (KN, RND, STRATIFIED, EVEN, EXTRA_Q1, EXTRA_Q5). The selection of ML algorithms during screening is tuned through the ``model`` parameter. By default, ROBERT screens RF, GB, NN, and MVL (regression) or RF, GB, NN, and AdaB (classification). Additional options include Gaussian Process (GP), Voting Regressor/Classifier (VR), and **XGB** (XGBoost regressor/classifier). XGB is opt-in (for example ``--model "[XGB]"`` or ``--model "[RF,XGB]"``); it is hyperoptimized with in-code Bayesian bounds rather than a packaged YAML file. **VR** combines RF, GB, and NN base estimators; Bayesian optimization tunes ensemble weights and member-model hyperparameters (prefixed ``rf_``, ``gb_``, and ``nn_`` bounds). The same ``model`` list is accepted by the :doc:`Python API <../API/robert.api>` via :class:`~robert.api.RobertModel`, where uncertainty kwargs are documented in detail. 20% of the data is used as a test set before hyperoptimization. This algorithm ensures an even distribution of data points across the range of y values, facilitating a balanced evaluation of predictions across the low, mid, and high y-value ranges. diff --git a/docs/Tutorials/md_to_rst.py b/docs/Tutorials/md_to_rst.py index 20ffe75..a0457f1 100644 --- a/docs/Tutorials/md_to_rst.py +++ b/docs/Tutorials/md_to_rst.py @@ -1,48 +1,45 @@ import os import re + def clean_text(text): text = text.replace("

", "\n\n") text = text.replace("
", "\n") # bold text = re.sub( - r"<\s*b\s*>(.*?)<\s*/\s*b\s*>", - r"**\1**", - text, - flags=re.DOTALL | re.IGNORECASE + r"<\s*b\s*>(.*?)<\s*/\s*b\s*>", r"**\1**", text, flags=re.DOTALL | re.IGNORECASE ) # italic text = re.sub( - r"<\s*i\s*>(.*?)<\s*/\s*i\s*>", - r"*\1*", - text, - flags=re.DOTALL | re.IGNORECASE + r"<\s*i\s*>(.*?)<\s*/\s*i\s*>", r"*\1*", text, flags=re.DOTALL | re.IGNORECASE ) return text.strip() + def sort_key(filename): name = filename.replace(".png", "") - numbers = re.findall(r'\d+', name) + numbers = re.findall(r"\d+", name) return tuple(int(n) for n in numbers) + def indent_block(text, spaces=3): prefix = " " * spaces return "\n".join(prefix + line if line.strip() else "" for line in text.split("\n")) + def convert(md_path, img_folder, out_path, prefix): print(f"\n===== {prefix.upper()} =====") with open(md_path, "r", encoding="utf-8") as f: content = f.read() - steps = [clean_text(s) for s in content.split('---') if s.strip()] + steps = [clean_text(s) for s in content.split("---") if s.strip()] images = sorted( - [f for f in os.listdir(img_folder) if f.endswith(".png")], - key=sort_key + [f for f in os.listdir(img_folder) if f.endswith(".png")], key=sort_key ) print(f"Steps: {len(steps)}") @@ -51,7 +48,7 @@ def convert(md_path, img_folder, out_path, prefix): rst = "" # ONE container per tutorial - rst += f".. container:: step\n\n" + rst += ".. container:: step\n\n" for i, step in enumerate(steps): step_id = i + 1 @@ -75,10 +72,10 @@ def convert(md_path, img_folder, out_path, prefix): buttons = '
\n' if step_id > 1: - buttons += f' \n' + buttons += f" \n" if step_id < len(steps): - buttons += f' \n' + buttons += f" \n" buttons += "
\n" @@ -100,6 +97,10 @@ def convert(md_path, img_folder, out_path, prefix): # RUN ALL convert("chemdraw.md", "tutorial_images/chemdraw", "chemdraw.rst", "chemdraw") convert("csv.md", "tutorial_images/csv", "csv.rst", "csv") -convert("descriptors.md", "tutorial_images/descriptors", "descriptors.rst", "descriptors") +convert( + "descriptors.md", "tutorial_images/descriptors", "descriptors.rst", "descriptors" +) convert("overview.md", "tutorial_images/overview", "overview.rst", "overview") -convert("predictions.md", "tutorial_images/predictions", "predictions.rst", "predictions") \ No newline at end of file +convert( + "predictions.md", "tutorial_images/predictions", "predictions.rst", "predictions" +) diff --git a/docs/conf.py b/docs/conf.py index 50fe54b..d5585dd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ from pathlib import Path # Ensure that modules can be imported without installing ROBERT -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) _root = Path(__file__).resolve().parents[1] _setup = (_root / "setup.py").read_text(encoding="utf-8") @@ -25,9 +25,9 @@ # -- Project information ----------------------------------------------------- -project = 'robert' -copyright = '2023, Juan V. Alegre Requena, David Dalmau Ginesta' -author = '2023, Juan V. Alegre Requena, David Dalmau Ginesta' +project = "robert" +copyright = "2023, Juan V. Alegre Requena, David Dalmau Ginesta" +author = "2023, Juan V. Alegre Requena, David Dalmau Ginesta" # The full version, including alpha/beta/rc tags release = version @@ -39,18 +39,18 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx_design', - 'sphinx.ext.intersphinx', - 'sphinx.ext.napoleon', + "sphinx.ext.autodoc", + "sphinx_design", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", ] intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'pandas': ('https://pandas.pydata.org/docs/', None), - 'sklearn': ('https://scikit-learn.org/stable/', None), - 'xgboost': ('https://xgboost.readthedocs.io/en/stable/', None), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), + "sklearn": ("https://scikit-learn.org/stable/", None), + "xgboost": ("https://xgboost.readthedocs.io/en/stable/", None), } # Shared image substitutions (used in README partial includes) @@ -61,25 +61,25 @@ """ html_theme_options = { - 'collapse_navigation': False, + "collapse_navigation": False, } -# Avoid paths in class names i.e. +# Avoid paths in class names i.e. # class robert.robert.curate.curate -> class curate add_module_names = False # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # Build htlm steps for tutorials -html_static_path = ['_static'] -html_js_files = ['steps.js'] -html_css_files = ['custom.css'] +html_static_path = ["_static"] +html_js_files = ["steps.js"] +html_css_files = ["custom.css"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # Disable smartquotes which might transform '--' into a different character smartquotes = False @@ -88,9 +88,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..49dfc1e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "robert" + +[tool.ruff] +target-version = "py311" +line-length = 88 +extend-exclude = [ + ".git", + "build", + "dist", + "*.egg-info", + "Examples/**/*.ipynb", +] + +[tool.ruff.lint] +per-file-ignores = { "robert/gui_easyrob/easyrob_launcher.py" = ["E402"], "tests/test_easyrob.py" = ["E402"] } + +[tool.ruff.format] +quote-style = "double" diff --git a/pytest.ini b/pytest.ini index 27eec68..77f7860 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] testpaths = tests python_files = test_*.py +qt_api = pyside6 diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..17f32d3 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,7 @@ +# Test dependencies (install after robert: pip install -e . && pip install -r requirements-test.txt) +pytest>=7.0 +pytest-cov>=4.0 +pytest-qt>=4.0 + +# Optional: AQME integration tests (also matches CI) +# aqme==2.0.0 diff --git a/robert/__main__.py b/robert/__main__.py index aff3b55..dba5902 100644 --- a/robert/__main__.py +++ b/robert/__main__.py @@ -12,9 +12,11 @@ # If we are running from a wheel, add the wheel to sys.path # This allows the usage python pip-*.whl/pip install pip-*.whl -if __package__ != 'robert': - print('ROBERT is not installed! Use: pip install robert or conda install -y -c conda-forge robert.') +if __package__ != "robert": + print( + "ROBERT is not installed! Use: pip install robert or conda install -y -c conda-forge robert." + ) -if __name__ == '__main__': +if __name__ == "__main__": robert.main() sys.exit() diff --git a/robert/api.py b/robert/api.py index 7b149cf..ce95194 100644 --- a/robert/api.py +++ b/robert/api.py @@ -83,7 +83,9 @@ def _suffix_title(filter_mode: str) -> str: def _find_params_csv(best_subdir: Path) -> Path: if not best_subdir.is_dir(): raise FileNotFoundError(str(best_subdir)) - csvs = sorted(p for p in best_subdir.glob("*.csv") if not p.name.endswith("_db.csv")) + csvs = sorted( + p for p in best_subdir.glob("*.csv") if not p.name.endswith("_db.csv") + ) if len(csvs) != 1: raise RuntimeError( "Expected exactly one parameter CSV in " @@ -124,7 +126,9 @@ def _resolve_prediction_id_column( ) -def _resolve_predict_csv(workdir: Path, pred_stem: str, model_code: str, suffix: str) -> str: +def _resolve_predict_csv( + workdir: Path, pred_stem: str, model_code: str, suffix: str +) -> str: """Path to PREDICT output CSV for the external set (exact file, else sorted glob).""" csv_dir = workdir / "PREDICT" / "csv_test" exact = csv_dir / f"{pred_stem}_{model_code}_{suffix}.csv" @@ -204,7 +208,9 @@ def __init__( stacklevel=2, ) if problem_type != "reg": - raise ValueError("Pass only one of 'problem_type' or deprecated 'type'.") + raise ValueError( + "Pass only one of 'problem_type' or deprecated 'type'." + ) problem_type = kwargs.pop("type") # type: ignore[assignment] if "filter" in kwargs: warnings.warn( @@ -213,7 +219,9 @@ def __init__( stacklevel=2, ) if filter_mode != "pfi": - raise ValueError("Pass only one of 'filter_mode' or deprecated 'filter'.") + raise ValueError( + "Pass only one of 'filter_mode' or deprecated 'filter'." + ) filter_mode = kwargs.pop("filter") # type: ignore[assignment] self.problem_type = problem_type @@ -387,7 +395,9 @@ def _coerce_xy( y_series.name = y_col return X_df, y_series, names_col - def _build_train_frame(self, X_df: pd.DataFrame, y_series: pd.Series) -> pd.DataFrame: + def _build_train_frame( + self, X_df: pd.DataFrame, y_series: pd.Series + ) -> pd.DataFrame: out = X_df.copy() out[y_series.name] = y_series.values return out @@ -472,7 +482,11 @@ def fit( self.feature_names_in_ = None model_names = str(model_data.get("names") or "") if model_names != names_col: - if model_names and names_col and model_names.casefold() == names_col.casefold(): + if ( + model_names + and names_col + and model_names.casefold() == names_col.casefold() + ): warnings.warn( f"Names column casing normalized to saved model column {model_names!r} " f"(was {names_col!r}).", @@ -570,9 +584,7 @@ def predict( missing = [c for c in descriptors if c not in X_df.columns] if missing: tail = "..." if len(missing) > 10 else "" - raise ValueError( - f"Missing descriptor columns: {missing[:10]!r}{tail}" - ) + raise ValueError(f"Missing descriptor columns: {missing[:10]!r}{tail}") pred_id = uuid.uuid4().hex[:12] pred_name = f"_robert_predict_{pred_id}.csv" @@ -730,6 +742,68 @@ def predict( return y_pred, y_uq_total return y_pred, y_uq_model, y_uq_meta, y_uq_total + def robert_scores( + self, + suffix: Optional[Literal["No PFI", "PFI"]] = None, + ) -> dict[str, Any]: + """ + Return the ROBERT report score and sub-scores from VERIFY/PREDICT outputs. + + Requires a prior :meth:`fit` that ran VERIFY and PREDICT (and REPORT if a + PDF is expected). Reads ``*_data.dat`` files in :attr:`workdir_`. + """ + if not self.is_fitted_: + raise RuntimeError("Call fit before robert_scores.") + workdir = self.workdir_ + if workdir is None: + raise RuntimeError("workdir is not set.") + + if suffix is None: + suffix = "PFI" if self.filter_mode == "pfi" else "No PFI" + + from robert.report_utils import calc_score, repro_info + + modules = ["CURATE", "GENERATE", "VERIFY", "PREDICT"] + with _chdir(workdir): + _, _, _, _, _, dat_files = repro_info(modules) + if "PREDICT" not in dat_files or "VERIFY" not in dat_files: + raise RuntimeError( + "PREDICT/VERIFY outputs missing in workdir; " + "run fit() with the full pipeline first." + ) + data_score: dict[str, Any] = {} + data_score = calc_score(dat_files, suffix, self.problem_type, data_score) + + score_key = f"robert_score_{suffix}" + if self.problem_type == "reg": + component_keys = [ + "cv_score_combined", + "test_score_combined", + "cv_sd_score", + "diff_scaled_rmse_score", + "flawed_mod_score", + "sorted_cv_score", + ] + else: + component_keys = [ + "cv_score_combined", + "test_score_combined", + "flawed_mod_score", + "sorted_cv_score", + "diff_mcc_score", + "descp_score", + ] + components = { + key: data_score.get(f"{key}_{suffix}", 0) for key in component_keys + } + pdf_path = workdir / "ROBERT_report.pdf" + return { + "suffix": suffix, + "robert_score": int(data_score.get(score_key, 0)), + "components": components, + "pdf_path": str(pdf_path) if pdf_path.is_file() else None, + } + def score( self, X: Union[pd.DataFrame, np.ndarray], diff --git a/robert/aqme.py b/robert/aqme.py index e385636..ed4338e 100644 --- a/robert/aqme.py +++ b/robert/aqme.py @@ -3,14 +3,14 @@ ---------- csv_name : str, default='' - Name of the CSV file containing the database with SMILES and code_name columns. A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). + Name of the CSV file containing the database with SMILES and code_name columns. A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). destination : str, default=None, Directory to create the output file(s). varfile : str, default=None - Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). + Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). y : str, default='' - Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). - qdescp_keywords : str, default='' + Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). + qdescp_keywords : str, default='' Add extra keywords to the AQME-QDESCP run (i.e. qdescp_keywords="--qdescp_atoms ['Ir']") descp_lvl : str, default='interpret' Type of descriptor to be used in the AQME-ROBERT workflow. Options are 'interpret', 'denovo' or 'full'. @@ -29,13 +29,21 @@ import sys from pathlib import Path import pandas as pd -from robert.utils import (load_variables, - finish_print, - load_database - ) +from robert.utils import load_variables, finish_print, load_database # list of potential arguments from CSV inputs in AQME -aqme_args = ['charge','mult','complex_type','geom','constraints_atoms','constraints_dist','constraints_angle','constraints_dihedral','sample'] +aqme_args = [ + "charge", + "mult", + "complex_type", + "geom", + "constraints_atoms", + "constraints_dist", + "constraints_angle", + "constraints_dihedral", + "sample", +] + class aqme: """ @@ -48,7 +56,6 @@ class aqme: """ def __init__(self, **kwargs): - start_time = time.time() # load default and user-specified variables @@ -61,27 +68,26 @@ def __init__(self, **kwargs): self = self.run_csearch_qdescp(self.args.csv_name) # run an AQME workflow for the test set (if any) - if self.args.csv_test != '': - _ = self.run_csearch_qdescp(self.args.csv_test,aqme_test=True) + if self.args.csv_test != "": + _ = self.run_csearch_qdescp(self.args.csv_test, aqme_test=True) # move AQME output files (remove previous runs as well) _ = move_aqme() # finish the printing of the AQME info file - _ = finish_print(self,start_time,'AQME') - + _ = finish_print(self, start_time, "AQME") - def run_csearch_qdescp(self,csv_target,aqme_test=False): - ''' + def run_csearch_qdescp(self, csv_target, aqme_test=False): + """ Runs CSEARCH and QDESCP jobs in AQME - ''' - + """ + # load database just to perform data checks (i.e. no need to run AQME if the specified y is not # in the database, since the program would crush in the subsequent CURATE job) - job_type = 'aqme' - path_sdf = Path(f'{os.getcwd()}/CSEARCH/sdf_temp') + job_type = "aqme" + path_sdf = Path(f"{os.getcwd()}/CSEARCH/sdf_temp") if aqme_test: - job_type = 'aqme_test' + job_type = "aqme_test" if not path_sdf.exists(): path_sdf.mkdir(exist_ok=True, parents=True) @@ -90,137 +96,199 @@ def run_csearch_qdescp(self,csv_target,aqme_test=False): if path_sdf.exists(): shutil.rmtree(path_sdf) path_sdf.mkdir(exist_ok=True, parents=True) - for sdf_file in glob.glob(f'{os.getcwd()}/CSEARCH/*.sdf'): + for sdf_file in glob.glob(f"{os.getcwd()}/CSEARCH/*.sdf"): new_sdf = path_sdf.joinpath(os.path.basename(sdf_file)) if os.path.exists(new_sdf): os.remove(new_sdf) shutil.move(sdf_file, new_sdf) - - # Load database - csv_df,_,_ = load_database(self,csv_target,job_type,print_info=False) - # avoid running calcs with special signs (i.e. *) - for name_csv_indiv in csv_df['code_name']: - if '*' in f'{name_csv_indiv}': - self.args.log.write(f"\nx WARNING! The names provided in the CSV contain * (i.e. {name_csv_indiv}). Please, remove all the * characters.") + # Load database + csv_df, _, _ = load_database(self, csv_target, job_type, print_info=False) + + # avoid running calcs with special signs (i.e. *) + for name_csv_indiv in csv_df["code_name"]: + if "*" in f"{name_csv_indiv}": + self.args.log.write( + f"\nx WARNING! The names provided in the CSV contain * (i.e. {name_csv_indiv}). Please, remove all the * characters." + ) self.args.log.finalize() sys.exit() # find if there is more than one SMILES column in the CSV file for column in csv_df.columns: if "SMILES" == column.upper() or "SMILES_" in column.upper(): - self.args.ignore.append(column) # create individual csv file for each SMILES column - csv_temp = csv_df[['code_name', column] + [col for col in csv_df.columns if col.lower() in aqme_args]] - csv_temp.columns = ['code_name', 'SMILES'] + [col for col in csv_temp.columns if col.lower() in aqme_args] - + csv_temp = csv_df[ + ["code_name", column] + + [col for col in csv_df.columns if col.lower() in aqme_args] + ] + csv_temp.columns = ["code_name", "SMILES"] + [ + col for col in csv_temp.columns if col.lower() in aqme_args + ] + if column.upper() == "SMILES": smi_suffix = None - csv_temp.to_csv('AQME_indiv.csv', index=False) - aqme_indv_name = 'AQME_indiv' + csv_temp.to_csv("AQME_indiv.csv", index=False) + aqme_indv_name = "AQME_indiv" else: smi_suffix = column.split("_")[1] - csv_temp['code_name'] = csv_temp['code_name'].astype(str) + '_' + smi_suffix - csv_temp.to_csv(f'AQME_indiv_{smi_suffix}.csv', index=False) - aqme_indv_name = f'AQME_indiv_{smi_suffix}' + csv_temp["code_name"] = ( + csv_temp["code_name"].astype(str) + "_" + smi_suffix + ) + csv_temp.to_csv(f"AQME_indiv_{smi_suffix}.csv", index=False) + aqme_indv_name = f"AQME_indiv_{smi_suffix}" # run AQME-QDESCP to generate descriptors - cmd_qdescp = ['python','-u', '-m', 'aqme', '--qdescp', '--input', f'{aqme_indv_name}.csv', '--program', 'xtb', '--csv_name', f'{aqme_indv_name}.csv', '--nprocs', f'{self.args.nprocs}', '--robert'] + cmd_qdescp = [ + "python", + "-u", + "-m", + "aqme", + "--qdescp", + "--input", + f"{aqme_indv_name}.csv", + "--program", + "xtb", + "--csv_name", + f"{aqme_indv_name}.csv", + "--nprocs", + f"{self.args.nprocs}", + "--robert", + ] _ = self.run_aqme(cmd_qdescp, self.args.qdescp_keywords) if smi_suffix is not None: # Change column names by adding suffix try: - df_temp = pd.read_csv(f'AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv', encoding='utf-8') + df_temp = pd.read_csv( + f"AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv", + encoding="utf-8", + ) except FileNotFoundError: - self.args.log.write("x WARNING! ROBERT stopped due to a problem with the AQME job. Please, check the previous AQME warnings.") + self.args.log.write( + "x WARNING! ROBERT stopped due to a problem with the AQME job. Please, check the previous AQME warnings." + ) sys.exit() - df_temp.columns = [f'{col}_{smi_suffix}' if col not in ['code_name','SMILES'] and col not in aqme_args else col for col in df_temp.columns] - df_temp.to_csv(f'AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv', index=False) + df_temp.columns = [ + f"{col}_{smi_suffix}" + if col not in ["code_name", "SMILES"] and col not in aqme_args + else col + for col in df_temp.columns + ] + df_temp.to_csv( + f"AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv", + index=False, + ) # Check if there are missing rows in the AQME-ROBERT_{aqme_indv_name}.csv if len(df_temp) < len(csv_temp): - missing_rows = csv_temp.loc[~csv_temp['code_name'].isin(df_temp['code_name'])] - missing_rows[['code_name', 'SMILES']].to_csv(f'AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv', mode='a', header=False, index=False) + missing_rows = csv_temp.loc[ + ~csv_temp["code_name"].isin(df_temp["code_name"]) + ] + missing_rows[["code_name", "SMILES"]].to_csv( + f"AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv", + mode="a", + header=False, + index=False, + ) # Get the order of code_name in aqme_indv_name - order = csv_temp['code_name'].tolist() + order = csv_temp["code_name"].tolist() # Sort the rows in 'AQME-ROBERT_{aqme_indv_name}.csv' based on the order - df_temp = df_temp.sort_values(by='code_name', key=lambda x: x.map({v: i for i, v in enumerate(order)})) + df_temp = df_temp.sort_values( + by="code_name", + key=lambda x: x.map({v: i for i, v in enumerate(order)}), + ) # Fill missing values with corresponding SMILES row - df_temp = df_temp.fillna(df_temp.groupby('SMILES').transform('first')) + df_temp = df_temp.fillna( + df_temp.groupby("SMILES").transform("first") + ) - df_temp.to_csv(f'AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv', index=False) + df_temp.to_csv( + f"AQME-ROBERT_{self.args.descp_lvl}_{aqme_indv_name}.csv", + index=False, + ) # return SDF files after csv_test if aqme_test: - for sdf_file in glob.glob(f'{path_sdf}/*.sdf'): - new_sdf = Path(f'{os.getcwd()}/CSEARCH').joinpath(os.path.basename(sdf_file)) + for sdf_file in glob.glob(f"{path_sdf}/*.sdf"): + new_sdf = Path(f"{os.getcwd()}/CSEARCH").joinpath( + os.path.basename(sdf_file) + ) shutil.move(sdf_file, new_sdf) shutil.rmtree(path_sdf) # if AQME-ROBERT_AQME_indiv_n.csv >0 in folder: - if len(glob.glob(f'AQME-ROBERT_{self.args.descp_lvl}_AQME_indiv*.csv')) > 0: - + if len(glob.glob(f"AQME-ROBERT_{self.args.descp_lvl}_AQME_indiv*.csv")) > 0: df_concat = pd.DataFrame() - # Read and concatenate CSV files - for file in sorted(glob.glob(f'AQME-ROBERT_{self.args.descp_lvl}_AQME_indiv*.csv'), key=os.path.getmtime,reverse=True): - columns_to_drop = ['code_name', 'SMILES'] + aqme_args - df_temp = pd.read_csv(file, encoding='utf-8') - columns_to_drop = [col for col in columns_to_drop if col in df_temp.columns] + # Read and concatenate CSV files + for file in sorted( + glob.glob(f"AQME-ROBERT_{self.args.descp_lvl}_AQME_indiv*.csv"), + key=os.path.getmtime, + reverse=True, + ): + columns_to_drop = ["code_name", "SMILES"] + aqme_args + df_temp = pd.read_csv(file, encoding="utf-8") + columns_to_drop = [ + col for col in columns_to_drop if col in df_temp.columns + ] df_temp = df_temp.drop(columns=columns_to_drop) df_concat = pd.concat([df_temp, df_concat], axis=1) df_concat = pd.concat([csv_df, df_concat], axis=1) - df_concat.to_csv(f'AQME-ROBERT_{self.args.descp_lvl}_{csv_target}', index=False) - + df_concat.to_csv( + f"AQME-ROBERT_{self.args.descp_lvl}_{csv_target}", index=False + ) # if no qdesc_atom is set, only keep molecular properties and discard atomic properties - aqme_db = f'AQME-ROBERT_{self.args.descp_lvl}_{csv_target}' + aqme_db = f"AQME-ROBERT_{self.args.descp_lvl}_{csv_target}" # ensure that the AQME database was successfully created if not os.path.exists(aqme_db): - self.args.log.write(f"\nx The initial AQME descriptor protocol did not create any CSV output!") + self.args.log.write( + "\nx The initial AQME descriptor protocol did not create any CSV output!" + ) sys.exit() - + # remove atomic properties if no SMARTS patterns were selected in qdescp, # and drop AQME argument columns from CSV inputs (single read/write) - if 'qdescp_atoms' not in self.args.qdescp_keywords: + if "qdescp_atoms" not in self.args.qdescp_keywords: _ = filter_atom_prop_and_aqme_args(aqme_db, csv_df, strip_atom_lists=True) else: _ = filter_atom_prop_and_aqme_args(aqme_db, csv_df, strip_atom_lists=False) # delete AQME_indiv*.csv files - for file in glob.glob('*QME_indiv*.csv'): + for file in glob.glob("*QME_indiv*.csv"): os.remove(file) - + # this returns stores options just in case csv_test is included return self - def run_aqme(self,command,extra_keywords): - ''' + def run_aqme(self, command, extra_keywords): + """ Function that runs the AQME jobs - ''' + """ - if extra_keywords != '': + if extra_keywords != "": for keyword in extra_keywords.split(): command.append(keyword) subprocess.run(command) - def init_aqme(self): - ''' + """ Checks whether AQME is installed - ''' - - try: - from aqme.qprep import qprep - except ModuleNotFoundError: - self.args.log.write("x AQME is not installed (required for the --aqme option)! The program is typically installed within 2-5 minutes (https://aqme.readthedocs.io, see the Installation section)") + """ + + import importlib.util + + if importlib.util.find_spec("aqme.qprep") is None: + self.args.log.write( + "x AQME is not installed (required for the --aqme option)! The program is typically installed within 2-5 minutes (https://aqme.readthedocs.io, see the Installation section)" + ) sys.exit() @@ -229,48 +297,52 @@ def filter_atom_prop_and_aqme_args(aqme_db, csv_df, *, strip_atom_lists): Drop atomic list descriptors when no --qdescp_atoms was used, and remove columns that duplicate AQME CSV inputs (single pass over the dataframe). """ - aqme_df = pd.read_csv(aqme_db, encoding='utf-8') + aqme_df = pd.read_csv(aqme_db, encoding="utf-8") if strip_atom_lists: for column in list(aqme_df.columns): - if column == 'DBSTEP_Vbur': + if column == "DBSTEP_Vbur": aqme_df = aqme_df.drop(column, axis=1) # remove lists of atomic properties (skip columns from AQME arguments) elif aqme_df[column].dtype == object and column.lower() not in aqme_args: first_cell = aqme_df[column].iloc[0] if len(aqme_df) else None - if first_cell is not None and '[' in str(first_cell) and column not in csv_df.columns: + if ( + first_cell is not None + and "[" in str(first_cell) + and column not in csv_df.columns + ): aqme_df = aqme_df.drop(column, axis=1) for column in list(aqme_df.columns): if column.lower() in aqme_args: aqme_df = aqme_df.drop(column, axis=1) os.remove(aqme_db) - aqme_df.to_csv(f'{aqme_db}', index=None, header=True) + aqme_df.to_csv(f"{aqme_db}", index=None, header=True) def filter_atom_prop(aqme_db, csv_df): - ''' + """ Function that filters off atomic properties if no atom was selected in the --qdescp_atoms option - ''' - + """ + filter_atom_prop_and_aqme_args(aqme_db, csv_df, strip_atom_lists=True) def filter_aqme_args(aqme_db): - ''' + """ Function that filters off AQME arguments in CSV inputs - ''' + """ filter_atom_prop_and_aqme_args(aqme_db, pd.DataFrame(), strip_atom_lists=False) def move_aqme(): - ''' + """ Move raw data from AQME-CSEARCH and -QDESCP runs into the AQME folder - ''' - - for file in glob.glob(f'*'): - if 'CSEARCH' in file or 'QDESCP' in file: - if os.path.exists(f'AQME/{file}'): - if len(os.path.basename(Path(file)).split('.')) == 1: - shutil.rmtree(f'AQME/{file}') + """ + + for file in glob.glob("*"): + if "CSEARCH" in file or "QDESCP" in file: + if os.path.exists(f"AQME/{file}"): + if len(os.path.basename(Path(file)).split(".")) == 1: + shutil.rmtree(f"AQME/{file}") else: - os.remove(f'AQME/{file}') - shutil.move(file, f'AQME/{file}') \ No newline at end of file + os.remove(f"AQME/{file}") + shutil.move(file, f"AQME/{file}") diff --git a/robert/argument_parser.py b/robert/argument_parser.py index ef25c38..d695d5c 100644 --- a/robert/argument_parser.py +++ b/robert/argument_parser.py @@ -8,7 +8,7 @@ var_dict = { "varfile": None, "command_line": False, - "extra_cmd": '', + "extra_cmd": "", "curate": False, "generate": False, "predict": False, @@ -19,47 +19,47 @@ "evaluate": False, "seed": 0, "destination": None, - "csv_name" : '', - "csv_test": '', - "y" : '', - "discard" : [], - "ignore" : [], - "categorical" : "onehot", - "corr_filter_x" : True, - "corr_filter_y" : False, - "std" : True, - "desc_thres" : 25, - "thres_y" : 0.001, - "thres_x" : 0.7, - "test_set" : 0.2, - "auto_test" : True, + "csv_name": "", + "csv_test": "", + "y": "", + "discard": [], + "ignore": [], + "categorical": "onehot", + "corr_filter_x": True, + "corr_filter_y": False, + "std": True, + "desc_thres": 25, + "thres_y": 0.001, + "thres_x": 0.7, + "test_set": 0.2, + "auto_test": True, "auto_type": True, "auto_fill": True, - "model" : ['RF','GB','NN','MVL'], - "eval_model" : 'MVL', - "custom_params" : None, - "type" : "reg", - "split" : "auto", + "model": ["RF", "GB", "NN", "MVL"], + "eval_model": "MVL", + "custom_params": None, + "type": "reg", + "split": "auto", "nprocs": 8, - "error_type" : "rmse", - "pfi_epochs" : 5, - "pfi_threshold" : 0.2, - "pfi_filter" : True, - "pfi_max" : 0, - "init_points" : 10, - "n_iter" : 10, - "expect_improv" : 0.05, - "kfold" : 5, - "repeat_kfolds" : 10, - "alpha" : 0.05, - "params_dir" : '', - "t_value" : 2, - "shap_show" : 10, - "pfi_show" : 10, - "names" : '', - "qdescp_keywords" : '', + "error_type": "rmse", + "pfi_epochs": 5, + "pfi_threshold": 0.2, + "pfi_filter": True, + "pfi_max": 0, + "init_points": 10, + "n_iter": 10, + "expect_improv": 0.05, + "kfold": 5, + "repeat_kfolds": 10, + "alpha": 0.05, + "params_dir": "", + "t_value": 2, + "shap_show": 10, + "pfi_show": 10, + "names": "", + "qdescp_keywords": "", "descp_lvl": "interpret", - "report_modules" : ['AQME','CURATE','GENERATE','VERIFY','PREDICT'], + "report_modules": ["AQME", "CURATE", "GENERATE", "VERIFY", "PREDICT"], "debug_report": False, # Split conformal (regression): symmetric interval half-width. "conformal_enable": True, @@ -103,7 +103,13 @@ def set_options(kwargs): elif key.lower() in var_dict: vars(options)[key.lower()] = kwargs[key.lower()] else: - print("Warning! Option: [", key,":",kwargs[key],"] provided but no option exists, try the online documentation to see available options for each module.",) + print( + "Warning! Option: [", + key, + ":", + kwargs[key], + "] provided but no option exists, try the online documentation to see available options for each module.", + ) sys.exit() return options diff --git a/robert/curate.py b/robert/curate.py index c064668..0d7686e 100644 --- a/robert/curate.py +++ b/robert/curate.py @@ -3,9 +3,9 @@ ---------- csv_name : str, default='' - Name of the CSV file containing the database. A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). + Name of the CSV file containing the database. A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). y : str, default='' - Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). + Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). discard : list, default=[] List containing the columns of the input CSV file that will not be included as descriptors in the curated CSV file (i.e. "['name','SMILES']"). @@ -18,11 +18,11 @@ destination : str, default=None, Directory to create the output file(s). varfile : str, default=None - Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). + Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). categorical : str, default='onehot' Mode to convert data from columns with categorical variables. As an example, a variable containing 4 types of C atoms (i.e. primary, secondary, tertiary, quaternary) will be converted into categorical - variables. Options: + variables. Options: 1. 'onehot' (for one-hot encoding, ROBERT will create a descriptor for each type of C atom using 0s and 1s to indicate whether the C type is present) 2. 'numbers' (to describe the C atoms with numbers: 1, 2, 3, 4). @@ -31,13 +31,13 @@ of the descriptors with other descriptors (x filter). corr_filter_y : bool, default=False Activate the correlation filters of descriptors, based on the correlation - of the descriptors with the y values (y filter, for noise). This filter is only + of the descriptors with the y values (y filter, for noise). This filter is only suggested for MVL. desc_thres : float, default=25 Threshold for the descriptor-to-datapoints ratio to loose the correlation filter. By default, the correlation filter is loosen if there are 25 times more datapoints than descriptors. thres_x : float, default=0.7 - Thresolhold to discard descriptors based on high R**2 correlation with other descriptors (i.e. + Thresolhold to discard descriptors based on high R**2 correlation with other descriptors (i.e. if thres_x=0.7, variables that show R**2 > 0.7 will be discarded). thres_y : float, default=0.001 Thresolhold to discard descriptors with poor correlation with the y values based on R**2 (i.e. @@ -45,7 +45,7 @@ seed : int, default=0 Random seed used in RFECV feature selector and other protocols. kfold : int, default=5 - Number of random data splits for the cross-validation of the RFECV feature selector. + Number of random data splits for the cross-validation of the RFECV feature selector. repeat_kfolds : int, default=10 Number of repetitions for the k-fold cross-validation of the RFECV feature selector. auto_type : bool, default=True @@ -85,28 +85,27 @@ class curate: """ def __init__(self, **kwargs): - start_time = time.time() # load default and user-specified variables self.args = load_variables(kwargs, "curate") # load database, discard user-defined descriptors and perform data checks - csv_df,_,_ = load_database(self,self.args.csv_name,"curate") + csv_df, _, _ = load_database(self, self.args.csv_name, "curate") # adjust options of classification problems and detects whether the right type of problem was used - self = check_clas_problem(self,csv_df) + self = check_clas_problem(self, csv_df) if not self.args.evaluate: # transform categorical descriptors - csv_df = categorical_transform(self,csv_df,'curate') + csv_df = categorical_transform(self, csv_df, "curate") # apply duplicate filters (i.e., duplication of datapoints or descriptors) csv_df = self.dup_filter(csv_df) # apply the correlation filters and returns the database without correlated descriptors if self.args.corr_filter_x or self.args.corr_filter_y: - csv_df_result = correlation_filter(self,csv_df) + csv_df_result = correlation_filter(self, csv_df) # Check if result is a tuple (model-specific CSVs) or single dataframe if isinstance(csv_df_result, tuple): csv_df, csv_df_per_model = csv_df_result @@ -115,7 +114,7 @@ def __init__(self, **kwargs): csv_df_per_model = None else: csv_df_per_model = None - + # save the curated CSVs (one per model if RFECV is applied) if csv_df_per_model is not None: _ = self.save_curate((csv_df, csv_df_per_model)) @@ -127,27 +126,26 @@ def __init__(self, **kwargs): _ = pearson_map(self, csv_df, "curate") # finish the printing of the CURATE info file - _ = finish_print(self,start_time,'CURATE') - + _ = finish_print(self, start_time, "CURATE") - def dup_filter(self,csv_df_dup): - ''' + def dup_filter(self, csv_df_dup): + """ Removes duplicated datapoints and descriptors - ''' + """ - txt_dup = f'\no Duplication filters activated' - txt_dup += f'\n Excluded datapoints:' + txt_dup = "\no Duplication filters activated" + txt_dup += "\n Excluded datapoints:" # remove duplicated entries datapoint_drop = [] - for i,datapoint in enumerate(csv_df_dup.duplicated()): + for i, datapoint in enumerate(csv_df_dup.duplicated()): if datapoint: datapoint_drop.append(i) for datapoint in datapoint_drop: - txt_dup += f'\n - Datapoint number {datapoint}' + txt_dup += f"\n - Datapoint number {datapoint}" if len(datapoint_drop) == 0: - txt_dup += f'\n - No datapoints were removed' + txt_dup += "\n - No datapoints were removed" csv_df_dup = csv_df_dup.drop(datapoint_drop, axis=0) @@ -156,64 +154,88 @@ def dup_filter(self,csv_df_dup): return csv_df_dup - - def save_curate(self,csv_df): - ''' + def save_curate(self, csv_df): + """ Saves the curated database and options used in CURATE - ''' + """ + + csv_basename = os.path.basename(f"{self.args.csv_name}").split(".")[0] - csv_basename = os.path.basename(f'{self.args.csv_name}').split('.')[0] - # Check if csv_df is a tuple (csv_df_filtered, csv_df_per_model) from correlation_filter if isinstance(csv_df, tuple): csv_df_filtered, csv_df_per_model = csv_df - + # Save model-specific curated databases (sorted for reproducibility) for model in csv_df_per_model: - csv_curate_name = f'{csv_basename}_CURATE_{model}.csv' + csv_curate_name = f"{csv_basename}_CURATE_{model}.csv" csv_curate_name = self.args.destination.joinpath(csv_curate_name) # Sort rows by y value for reproducibility - csv_df_to_save = csv_df_per_model[model].reset_index(drop=True).sort_values(by=self.args.y).reset_index(drop=True) - _ = csv_df_to_save.to_csv(f'{csv_curate_name}', index=None, header=True) + csv_df_to_save = ( + csv_df_per_model[model] + .reset_index(drop=True) + .sort_values(by=self.args.y) + .reset_index(drop=True) + ) + _ = csv_df_to_save.to_csv(f"{csv_curate_name}", index=None, header=True) # y values to predict, considering that ROBERT will work with multiple values of y in future versions - if not isinstance(self.args.y,list): + if not isinstance(self.args.y, list): count_y = 1 else: count_y = len(self.args.y) - - txt_csv = f'\n o Model {model}: {len(csv_df_per_model[model].columns)-len(self.args.ignore)-count_y} descriptors remaining:\n' - txt_csv += ' ' + ', '.join(f'{var}' for var in csv_df_per_model[model].columns if var not in self.args.ignore and var != self.args.y) + + txt_csv = f"\n o Model {model}: {len(csv_df_per_model[model].columns) - len(self.args.ignore) - count_y} descriptors remaining:\n" + txt_csv += " " + ", ".join( + f"{var}" + for var in csv_df_per_model[model].columns + if var not in self.args.ignore and var != self.args.y + ) self.args.log.write(txt_csv) - - self.args.log.write(f'\no Model-specific curated databases were stored in {self.args.destination}') - + + self.args.log.write( + f"\no Model-specific curated databases were stored in {self.args.destination}" + ) + # Save general curated database (for reference/Pearson map, sorted for reproducibility) - csv_curate_name_general = f'{csv_basename}_CURATE.csv' - csv_curate_name_general = self.args.destination.joinpath(csv_curate_name_general) - csv_df_to_save_general = csv_df_filtered.reset_index(drop=True).sort_values(by=self.args.y).reset_index(drop=True) - _ = csv_df_to_save_general.to_csv(f'{csv_curate_name_general}', index=None, header=True) + csv_curate_name_general = f"{csv_basename}_CURATE.csv" + csv_curate_name_general = self.args.destination.joinpath( + csv_curate_name_general + ) + csv_df_to_save_general = ( + csv_df_filtered.reset_index(drop=True) + .sort_values(by=self.args.y) + .reset_index(drop=True) + ) + _ = csv_df_to_save_general.to_csv( + f"{csv_curate_name_general}", index=None, header=True + ) else: # Original behavior: save single curated database - csv_curate_name_general = f'{csv_basename}_CURATE.csv' - csv_curate_name_general = self.args.destination.joinpath(csv_curate_name_general) - _ = csv_df.to_csv(f'{csv_curate_name_general}', index=None, header=True) - path_reduced = '/'.join(f'{csv_curate_name_general}'.replace('\\','/').split('/')[-2:]) - self.args.log.write(f'\no The curated database was stored in {path_reduced}.') + csv_curate_name_general = f"{csv_basename}_CURATE.csv" + csv_curate_name_general = self.args.destination.joinpath( + csv_curate_name_general + ) + _ = csv_df.to_csv(f"{csv_curate_name_general}", index=None, header=True) + path_reduced = "/".join( + f"{csv_curate_name_general}".replace("\\", "/").split("/")[-2:] + ) + self.args.log.write( + f"\no The curated database was stored in {path_reduced}." + ) # Save important options used in CURATE - options_name = f'CURATE_options.csv' + options_name = "CURATE_options.csv" options_name = self.args.destination.joinpath(options_name) options_df = pd.DataFrame() - options_df['y'] = [self.args.y] - options_df['ignore'] = [self.args.ignore] - options_df['names'] = [self.args.names] - options_df['csv_name'] = [csv_curate_name_general] - + options_df["y"] = [self.args.y] + options_df["ignore"] = [self.args.ignore] + options_df["names"] = [self.args.names] + options_df["csv_name"] = [csv_curate_name_general] + # Save class label mapping if it exists (for classification with string labels) - if hasattr(self.args, 'class_0_label'): - options_df['class_0_label'] = [self.args.class_0_label] - options_df['class_1_label'] = [self.args.class_1_label] - - _ = options_df.to_csv(f'{options_name}', index=None, header=True) + if hasattr(self.args, "class_0_label"): + options_df["class_0_label"] = [self.args.class_0_label] + options_df["class_1_label"] = [self.args.class_1_label] + + _ = options_df.to_csv(f"{options_name}", index=None, header=True) diff --git a/robert/evaluate.py b/robert/evaluate.py index c298c52..5ac54a0 100644 --- a/robert/evaluate.py +++ b/robert/evaluate.py @@ -7,16 +7,16 @@ csv_name : str, default='' Name of the CSV file containing all the points used in the model (combining train + valid + test). - A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). + A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). y : str, default='' - Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). + Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). names : str, default='' Column of the names for each datapoint. Names are used to print outliers. eval_model : str, default='MVL' - ML models that can be evaluated (for now, only models from sklearn are accepted, more options will be added): + ML models that can be evaluated (for now, only models from sklearn are accepted, more options will be added): 1. 'MVL' (Multivariate lineal models, LinearRegression() in sklearn) type : str, default='reg' - Type of the pedictions. Options: + Type of the pedictions. Options: 1. 'reg' (Regressor) 2. 'clas' (Classifier) seed : int, default=0 @@ -28,7 +28,7 @@ +++++++++++++++++++++++++ kfold : int, default=5 - Number of random data splits for the cross-validation of the models. + Number of random data splits for the cross-validation of the models. repeat_kfolds : int, default=10 Number of repetitions for the k-fold cross-validation of the models. @@ -43,7 +43,13 @@ import time import pandas as pd from pathlib import Path -from robert.utils import load_variables, finish_print, load_database, prepare_sets, path_generate_best_model +from robert.utils import ( + load_variables, + finish_print, + load_database, + prepare_sets, + path_generate_best_model, +) from robert.generate_utils import set_sets @@ -58,7 +64,6 @@ class evaluate: """ def __init__(self, **kwargs): - start_time = time.time() # load default and user-specified variables @@ -68,31 +73,44 @@ def __init__(self, **kwargs): _ = self.clean_eval() # load database, discard user-defined descriptors and perform data checks - csv_df, csv_X, csv_y = load_database(self,self.args.csv_name,"generate",print_info=False) + csv_df, csv_X, csv_y = load_database( + self, self.args.csv_name, "generate", print_info=False + ) # standardizes and separates an external test set - Xy_data = prepare_sets(self,csv_df,csv_X,csv_y,None,self.args.names,None,None,None,BO_opt=True) + Xy_data = prepare_sets( + self, + csv_df, + csv_X, + csv_y, + None, + self.args.names, + None, + None, + None, + BO_opt=True, + ) # saves database and model params in the /GENERATE/Best_model/No_PFI folder - _ = self.save_generate(csv_df,Xy_data) + _ = self.save_generate(csv_df, Xy_data) # finish the printing of the EVALUATE info file - _ = finish_print(self,start_time,'EVALUATE') + _ = finish_print(self, start_time, "EVALUATE") def clean_eval(self): - ''' + """ Cleans folders from previous runs - ''' + """ - for folder in ['CURATE','GENERATE','VERIFY','PREDICT']: - eval_folder = f'{Path(os.getcwd()).joinpath(folder)}' + for folder in ["CURATE", "GENERATE", "VERIFY", "PREDICT"]: + eval_folder = f"{Path(os.getcwd()).joinpath(folder)}" if os.path.exists(eval_folder): shutil.rmtree(eval_folder) - def save_generate(self,csv_df,Xy_data): - ''' + def save_generate(self, csv_df, Xy_data): + """ Saves database and model params in the /GENERATE/Best_model/No_PFI folder - ''' + """ # copy database with Set column generate_folder = path_generate_best_model(None, "No_PFI") @@ -101,21 +119,25 @@ def save_generate(self,csv_df,Xy_data): Path(generate_folder).mkdir(exist_ok=True, parents=True) # include the Set column to differentiate between train and test sets (and external test, if any) - csv_df = set_sets(csv_df,Xy_data) + csv_df = set_sets(csv_df, Xy_data) - _ = csv_df.to_csv(f'{generate_folder}/{self.args.eval_model}_db.csv', index = None, header=True) + _ = csv_df.to_csv( + f"{generate_folder}/{self.args.eval_model}_db.csv", index=None, header=True + ) # save all the parameters of the model df_params = pd.DataFrame() - df_params['kfold'] = [self.args.kfold] - df_params['repeat_kfolds'] = [self.args.repeat_kfolds] - df_params['model'] = [self.args.eval_model] - df_params['type'] = [self.args.type] - df_params['seed'] = [self.args.seed] - df_params['y'] = [self.args.y] - df_params['names'] = [self.args.names] - df_params['error_type'] = [self.args.error_type] - df_params['params'] = '{}' - df_params['X_descriptors'] = [list(Xy_data['X_descriptors'])] - - _ = df_params.to_csv(f'{generate_folder}/{self.args.eval_model}.csv', index = None, header=True) + df_params["kfold"] = [self.args.kfold] + df_params["repeat_kfolds"] = [self.args.repeat_kfolds] + df_params["model"] = [self.args.eval_model] + df_params["type"] = [self.args.type] + df_params["seed"] = [self.args.seed] + df_params["y"] = [self.args.y] + df_params["names"] = [self.args.names] + df_params["error_type"] = [self.args.error_type] + df_params["params"] = "{}" + df_params["X_descriptors"] = [list(Xy_data["X_descriptors"])] + + _ = df_params.to_csv( + f"{generate_folder}/{self.args.eval_model}.csv", index=None, header=True + ) diff --git a/robert/generate.py b/robert/generate.py index 640a4da..fba519a 100644 --- a/robert/generate.py +++ b/robert/generate.py @@ -3,9 +3,9 @@ ---------- csv_name : str, default='' - Name of the CSV file containing the database. A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). + Name of the CSV file containing the database. A path can be provided (i.e. 'C:/Users/FOLDER/FILE.csv'). y : str, default='' - Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). + Name of the column containing the response variable in the input CSV file (i.e. 'solubility'). discard : list, default=[] List containing the columns of the input CSV file that will not be included as descriptors in the curated CSV file (i.e. ['name','SMILES']). @@ -16,11 +16,11 @@ destination : str, default=None Directory to create the output file(s). varfile : str, default=None - Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). + Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). auto_type : bool, default=True If there are only two y values, the program automatically changes the type of problem to classification. - model : list, default=['RF','GB','NN','MVL'] (regression) and default=['RF','GB','NN','AdaB'] (classification) - ML models available: + model : list, default=['RF','GB','NN','MVL'] (regression) and default=['RF','GB','NN','AdaB'] (classification) + ML models available: 1. 'RF' (Random forest) 2. 'MVL' (Multivariate lineal models) 3. 'GB' (Gradient boosting) @@ -33,7 +33,7 @@ Define new parameters for the ML models used in the hyperoptimization workflow. The path to the folder containing all the yaml files should be specified (i.e. custom_params='YAML_FOLDER') type : str, default='reg' - Type of the pedictions. Options: + Type of the pedictions. Options: 1. 'reg' (Regressor) 2. 'clas' (Classifier) seed : int, default=0 @@ -71,7 +71,7 @@ hyperoptimization, and PREDICT will use the points as test set during ROBERT workflows. Select --test_set 0 to use only training and validation. kfold : int, default=5 - Number of random data splits for the cross-validation of the models. + Number of random data splits for the cross-validation of the models. repeat_kfolds : int, default=10 Number of repetitions for the k-fold cross-validation of the models. split : str, default= 'even' (regression) or 'rnd' (classification) @@ -82,7 +82,7 @@ 4. 'KN': uses a k-means approach to select representative samples for training (good for intrapolation, bad for extrapolation). 5. 'extra_q1': selects the 20% lowest values. 6. 'extra_q5': selects the 20% highest values. - + """ #####################################################. # This file stores the GENERATE class # @@ -103,7 +103,7 @@ BO_workflow, PFI_workflow, heatmap_workflow, - detect_best + detect_best, ) @@ -118,66 +118,88 @@ class generate: """ def __init__(self, **kwargs): - start_time = time.time() # load default and user-specified variables self.args = load_variables(kwargs, "generate") # load database, discard user-defined descriptors and perform data checks - csv_df, _, _ = load_database(self,self.args.csv_name,"generate") + csv_df, _, _ = load_database(self, self.args.csv_name, "generate") # changes type to classification if there are only two different y values - if self.args.type.lower() == 'reg' and self.args.auto_type: - self = check_clas_problem(self,csv_df) - + if self.args.type.lower() == "reg" and self.args.auto_type: + self = check_clas_problem(self, csv_df) + # scan different ML models txt_heatmap = f"\no Starting heatmap scan with {len(self.args.model)} ML models ({self.args.model})." # scan different training partition sizes cycle = 1 - txt_heatmap += f'\n Heatmap generation:' + txt_heatmap += "\n Heatmap generation:" self.args.log.write(txt_heatmap) # scan different ML models - self.args.log.write(f''' o Starting BO-based hyperoptimization using the combined target: + self.args.log.write(f""" o Starting BO-based hyperoptimization using the combined target: \n 1. 50% = {self.args.error_type.upper()} from a {self.args.repeat_kfolds}x repeated {self.args.kfold}-fold CV (interpoplation) \n 2. 50% = {self.args.error_type.upper()} from the bottom or top (worst performing) fold in a sorted {self.args.kfold}-fold CV (extrapolation) - \n''') + \n""") for ML_model in self.args.model: - - self.args.log.write(f' - {cycle}/{len(self.args.model)} - ML model: {ML_model} ') + self.args.log.write( + f" - {cycle}/{len(self.args.model)} - ML model: {ML_model} " + ) # Try to load model-specific curated CSV first, fall back to general CSV # Get the base name from the original csv_name (remove path if any) - if 'CURATE' in str(self.args.csv_name): + if "CURATE" in str(self.args.csv_name): # If csv_name is already a CURATE file, extract the original base name - csv_basename = os.path.basename(f'{self.args.csv_name}').replace('_CURATE.csv', '').replace('.csv', '') + csv_basename = ( + os.path.basename(f"{self.args.csv_name}") + .replace("_CURATE.csv", "") + .replace(".csv", "") + ) else: - csv_basename = os.path.basename(f'{self.args.csv_name}').split('.')[0] - - curate_folder = self.args.initial_dir.joinpath('CURATE') - csv_model_specific = curate_folder.joinpath(f'{csv_basename}_CURATE_{ML_model}.csv') - + csv_basename = os.path.basename(f"{self.args.csv_name}").split(".")[0] + + curate_folder = self.args.initial_dir.joinpath("CURATE") + csv_model_specific = curate_folder.joinpath( + f"{csv_basename}_CURATE_{ML_model}.csv" + ) + # Store the original csv_name temporarily original_csv_name = self.args.csv_name - + if os.path.exists(csv_model_specific): csv_to_load = csv_model_specific # Temporarily update csv_name to the model-specific CSV self.args.csv_name = str(csv_model_specific) - self.args.log.write(f' o Using model-specific curated database: {os.path.basename(csv_model_specific)}') + self.args.log.write( + f" o Using model-specific curated database: {os.path.basename(csv_model_specific)}" + ) else: csv_to_load = self.args.csv_name - self.args.log.write(f' x Using general database (model-specific not found): {os.path.basename(self.args.csv_name)}') - + self.args.log.write( + f" x Using general database (model-specific not found): {os.path.basename(self.args.csv_name)}" + ) + # load database, discard user-defined descriptors and perform data checks - csv_df, csv_X, csv_y = load_database(self,csv_to_load,"generate",print_info=False) - + csv_df, csv_X, csv_y = load_database( + self, csv_to_load, "generate", print_info=False + ) # standardizes and separates an external test set - Xy_data = prepare_sets(self,csv_df,csv_X,csv_y,None,self.args.names,None,None,None,BO_opt=True) + Xy_data = prepare_sets( + self, + csv_df, + csv_X, + csv_y, + None, + self.args.names, + None, + None, + None, + BO_opt=True, + ) # hyperopt process for ML models _ = BO_workflow(self, Xy_data, csv_df, ML_model) @@ -210,15 +232,15 @@ def __init__(self, **kwargs): ) _ = PFI_workflow(self, csv_df_pfi, ML_model, Xy_data) - + # Restore the original csv_name self.args.csv_name = original_csv_name cycle += 1 # detects best combinations - dir_csv = self.args.destination.joinpath(f"Raw_data") - _ = detect_best(f'{dir_csv}/No_PFI') + dir_csv = self.args.destination.joinpath("Raw_data") + _ = detect_best(f"{dir_csv}/No_PFI") # create heatmap plot(s) if should_plot_generate_heatmap(self.args): @@ -226,11 +248,11 @@ def __init__(self, **kwargs): # detect best and create heatmap for PFI models if self.args.pfi_filter: - try: # if no models were found - _ = detect_best(f'{dir_csv}/PFI') + try: # if no models were found + _ = detect_best(f"{dir_csv}/PFI") if should_plot_generate_heatmap(self.args): _ = heatmap_workflow(self, "PFI") except UnboundLocalError: pass - _ = finish_print(self,start_time,'GENERATE') + _ = finish_print(self, start_time, "GENERATE") diff --git a/robert/generate_utils.py b/robert/generate_utils.py index 5e3f7a9..005289b 100644 --- a/robert/generate_utils.py +++ b/robert/generate_utils.py @@ -14,85 +14,93 @@ create_heatmap, BO_optimizer, BO_metrics, - model_adjust_params - ) + model_adjust_params, +) # hyperopt workflow def BO_workflow(self, Xy_data, csv_df, ML_model): - ''' + """ Load hyperparameter space and perform a Bayesian optimization - ''' - - bo_data = {'model': ML_model.upper(), - 'type': self.args.type.lower(), - 'kfold': self.args.kfold, - 'repeat_kfolds': self.args.repeat_kfolds, - 'seed': self.args.seed, - 'error_type': self.args.error_type.lower(), - 'y': self.args.y, - 'names': self.args.names, - 'X_descriptors': Xy_data['X_descriptors']} - - if ML_model.upper() != 'MVL': - bo_data['params'], bo_data[f"combined_{bo_data['error_type']}"] = BO_optimizer(self,bo_data,Xy_data) - bo_data['params'] = model_adjust_params(self, bo_data['model'], bo_data['params']) + """ + + bo_data = { + "model": ML_model.upper(), + "type": self.args.type.lower(), + "kfold": self.args.kfold, + "repeat_kfolds": self.args.repeat_kfolds, + "seed": self.args.seed, + "error_type": self.args.error_type.lower(), + "y": self.args.y, + "names": self.args.names, + "X_descriptors": Xy_data["X_descriptors"], + } + + if ML_model.upper() != "MVL": + bo_data["params"], bo_data[f"combined_{bo_data['error_type']}"] = BO_optimizer( + self, bo_data, Xy_data + ) + bo_data["params"] = model_adjust_params( + self, bo_data["model"], bo_data["params"] + ) else: - bo_data['params'] = {} # no need to format params + bo_data["params"] = {} # no need to format params bo_data = BO_metrics(self, bo_data, Xy_data) metric_combined = bo_data[f"combined_{bo_data['error_type']}"] - self.args.log.write(f" o Combined {bo_data['error_type'].upper()} for {bo_data['model']} (no BO needed) (no PFI filter): {metric_combined:.2}") + self.args.log.write( + f" o Combined {bo_data['error_type'].upper()} for {bo_data['model']} (no BO needed) (no PFI filter): {metric_combined:.2}" + ) # include the Set column to differentiate between train and test sets (and external test, if any) - csv_df = set_sets(csv_df,Xy_data) + csv_df = set_sets(csv_df, Xy_data) # save csv files with model params and with Xy datapoints db_name = self.args.destination.joinpath(f"Raw_data/No_PFI/{ML_model}_db") params_name = self.args.destination.joinpath(f"Raw_data/No_PFI/{ML_model.upper()}") - _ = csv_df.to_csv(f'{db_name}.csv', index = None, header=True) - + _ = csv_df.to_csv(f"{db_name}.csv", index=None, header=True) + # Convert params dict to string to avoid serialization issues bo_data_to_save = bo_data.copy() - if 'params' in bo_data_to_save: - bo_data_to_save['params'] = json.dumps(bo_data_to_save['params']) - if 'X_descriptors' in bo_data_to_save: - bo_data_to_save['X_descriptors'] = json.dumps(bo_data_to_save['X_descriptors']) - + if "params" in bo_data_to_save: + bo_data_to_save["params"] = json.dumps(bo_data_to_save["params"]) + if "X_descriptors" in bo_data_to_save: + bo_data_to_save["X_descriptors"] = json.dumps(bo_data_to_save["X_descriptors"]) + # Save class label mapping if it exists (for classification with string labels) - if hasattr(self.args, 'class_0_label'): - bo_data_to_save['class_0_label'] = self.args.class_0_label - bo_data_to_save['class_1_label'] = self.args.class_1_label + if hasattr(self.args, "class_0_label"): + bo_data_to_save["class_0_label"] = self.args.class_0_label + bo_data_to_save["class_1_label"] = self.args.class_1_label # Save split type - bo_data_to_save['split'] = self.args.split - + bo_data_to_save["split"] = self.args.split + bo_data_df = pd.DataFrame([bo_data_to_save]) - _ = bo_data_df.to_csv(f'{params_name}.csv', index = None, header=True) + _ = bo_data_df.to_csv(f"{params_name}.csv", index=None, header=True) return bo_data def PFI_workflow(self, csv_df, ML_model, Xy_data): - ''' + """ Filter off parameters with low PFI (not relevant in the model) - ''' + """ # convert df to dict, then adjust params to a valid format name_csv_hyperopt = f"Raw_data/No_PFI/{ML_model}" - path_csv = self.args.destination.joinpath(f'{name_csv_hyperopt}.csv') - PFI_dict = load_params(self,path_csv) + path_csv = self.args.destination.joinpath(f"{name_csv_hyperopt}.csv") + PFI_dict = load_params(self, path_csv) - PFI_discard_cols,descp_cols_pfi = PFI_filter(self, Xy_data, PFI_dict) + PFI_discard_cols, descp_cols_pfi = PFI_filter(self, Xy_data, PFI_dict) - desc_keep = calc_desc_keep(self,Xy_data,PFI_discard_cols) + desc_keep = calc_desc_keep(self, Xy_data, PFI_discard_cols) - discard_idx, descriptors_PFI = [],[] + discard_idx, descriptors_PFI = [], [] # if no descriptors pass the filter, just choose them based on importance until having the number of descps from desc_keep if len(PFI_discard_cols) == len(descp_cols_pfi): PFI_discard_cols = [] - for _,column in enumerate(descp_cols_pfi): + for _, column in enumerate(descp_cols_pfi): if column not in PFI_discard_cols and len(descriptors_PFI) < desc_keep: descriptors_PFI.append(column) else: @@ -100,137 +108,149 @@ def PFI_workflow(self, csv_df, ML_model, Xy_data): # only use the descriptors that passed the PFI filter Xy_data_PFI = Xy_data.copy() - Xy_data_PFI['X_train_scaled'] = Xy_data['X_train_scaled'].drop(discard_idx, axis=1) + Xy_data_PFI["X_train_scaled"] = Xy_data["X_train_scaled"].drop(discard_idx, axis=1) - PFI_dict['X_descriptors'] = descriptors_PFI + PFI_dict["X_descriptors"] = descriptors_PFI - # updates the model's error and descriptors used from the corresponding No_PFI CSV file + # updates the model's error and descriptors used from the corresponding No_PFI CSV file # (the other parameters remain the same) PFI_dict = BO_metrics(self, PFI_dict, Xy_data_PFI) metric_combined = PFI_dict[f"combined_{PFI_dict['error_type']}"] - self.args.log.write(f" o Combined {PFI_dict['error_type'].upper()} for {PFI_dict['model']} (with PFI filter): {metric_combined:.2}") + self.args.log.write( + f" o Combined {PFI_dict['error_type'].upper()} for {PFI_dict['model']} (with PFI filter): {metric_combined:.2}" + ) # save CSV file - _ = save_pfi_csv(self,csv_df,name_csv_hyperopt,PFI_dict,Xy_data_PFI,ML_model) + _ = save_pfi_csv(self, csv_df, name_csv_hyperopt, PFI_dict, Xy_data_PFI, ML_model) -def calc_desc_keep(self,Xy_data,PFI_discard_cols): - ''' +def calc_desc_keep(self, Xy_data, PFI_discard_cols): + """ Calculate number of descriptors to keep in the PFI model - ''' - + """ + # generate new X datasets and store the descriptors used for the PFI-filtered model - desc_keep = len(Xy_data['X_train_scaled'].columns) - - # if the filter does not remove any descriptors based on the PFI threshold, or the + desc_keep = len(Xy_data["X_train_scaled"].columns) + + # if the filter does not remove any descriptors based on the PFI threshold, or the # proportion of descriptors:total datapoints is higher than 1:3, then the filter takes # the minimum value of 1 and 2: # 1. 25% less descriptors than the No PFI original model # 2. Proportion of 1:3 of descriptors:total datapoints (training + validation) - total_points = len(Xy_data['y_train']) - n_descp_PFI = desc_keep-len(PFI_discard_cols) + total_points = len(Xy_data["y_train"]) + n_descp_PFI = desc_keep - len(PFI_discard_cols) # determine how many points will be kept if desc_keep > 1: if self.args.pfi_max > 0: desc_keep = self.args.pfi_max - elif n_descp_PFI > 0.2*total_points or n_descp_PFI >= (0.75*desc_keep) or n_descp_PFI == 0: - option_one = round(0.75*len(Xy_data['X_train_scaled'].columns)) - option_two = round(0.2*total_points) - option_three = round(len(Xy_data['X_train_scaled'].columns)-1) # for databases with two or three descriptors - desc_keep = min(option_one,option_two,option_three) + elif ( + n_descp_PFI > 0.2 * total_points + or n_descp_PFI >= (0.75 * desc_keep) + or n_descp_PFI == 0 + ): + option_one = round(0.75 * len(Xy_data["X_train_scaled"].columns)) + option_two = round(0.2 * total_points) + option_three = round( + len(Xy_data["X_train_scaled"].columns) - 1 + ) # for databases with two or three descriptors + desc_keep = min(option_one, option_two, option_three) return desc_keep -def save_pfi_csv(self,csv_df,name_csv_hyperopt,PFI_dict,Xy_data_PFI,ML_model): - ''' - Saves CSV files with PFI models and information - ''' +def save_pfi_csv(self, csv_df, name_csv_hyperopt, PFI_dict, Xy_data_PFI, ML_model): + """ + Saves CSV files with PFI models and information + """ + + name_csv_hyperopt_PFI = name_csv_hyperopt.replace("No_PFI", "PFI") + path_csv_PFI = self.args.destination.joinpath(f"{name_csv_hyperopt_PFI}_PFI") - name_csv_hyperopt_PFI = name_csv_hyperopt.replace('No_PFI','PFI') - path_csv_PFI = self.args.destination.joinpath(f'{name_csv_hyperopt_PFI}_PFI') - # Save class label mapping if it exists (for classification with string labels) - if hasattr(self.args, 'class_0_label'): - PFI_dict['class_0_label'] = self.args.class_0_label - PFI_dict['class_1_label'] = self.args.class_1_label - + if hasattr(self.args, "class_0_label"): + PFI_dict["class_0_label"] = self.args.class_0_label + PFI_dict["class_1_label"] = self.args.class_1_label + csv_PFI_df = pd.DataFrame([PFI_dict]) - _ = csv_PFI_df.to_csv(f'{path_csv_PFI}.csv', index = None, header=True) + _ = csv_PFI_df.to_csv(f"{path_csv_PFI}.csv", index=None, header=True) # include the Set column to differentiate between train and test sets (and external test, if any) - csv_df = set_sets(csv_df,Xy_data_PFI) + csv_df = set_sets(csv_df, Xy_data_PFI) # save the csv file - if os.path.exists(self.args.destination.joinpath(f"Raw_data/PFI/{ML_model}_PFI.csv")): + if os.path.exists( + self.args.destination.joinpath(f"Raw_data/PFI/{ML_model}_PFI.csv") + ): db_name = self.args.destination.joinpath(f"Raw_data/PFI/{ML_model}_PFI_db") - _ = csv_df.to_csv(f'{db_name}.csv', index = None, header=True) + _ = csv_df.to_csv(f"{db_name}.csv", index=None, header=True) -def set_sets(csv_df,Xy_data): +def set_sets(csv_df, Xy_data): """ Set a new column for the sets, including test set (if any) """ set_column = [] n_points = len(csv_df[csv_df.columns[0]]) - for i in range(0,n_points): - if i in Xy_data['test_points']: - set_column.append('Test') + for i in range(0, n_points): + if i in Xy_data["test_points"]: + set_column.append("Test") else: - set_column.append('Training') + set_column.append("Training") - csv_df['Set'] = set_column + csv_df["Set"] = set_column return csv_df def detect_best(folder): - ''' + """ Check which combination led to the best results - ''' + """ # detect files - file_list = glob.glob(f'{folder}/*.csv') + file_list = glob.glob(f"{folder}/*.csv") errors = [] for file in file_list: - if '_db' not in file: - results_model = pd.read_csv(f'{file}', encoding='utf-8') - training_error = results_model[f"combined_{results_model['error_type'][0]}"][0] + if "_db" not in file: + results_model = pd.read_csv(f"{file}", encoding="utf-8") + training_error = results_model[ + f"combined_{results_model['error_type'][0]}" + ][0] errors.append(training_error) else: errors.append(np.nan) # detect best result and copy files to the Best_model folder - if results_model['error_type'][0].lower() in ['mae','rmse']: + if results_model["error_type"][0].lower() in ["mae", "rmse"]: min_idx = errors.index(np.nanmin(errors)) else: min_idx = errors.index(np.nanmax(errors)) best_name = file_list[min_idx] - best_db = f'{os.path.dirname(file_list[min_idx])}/{os.path.basename(file_list[min_idx]).split(".csv")[0]}_db.csv' + best_db = f"{os.path.dirname(file_list[min_idx])}/{os.path.basename(file_list[min_idx]).split('.csv')[0]}_db.csv" - shutil.copyfile(f'{best_name}', f'{best_name}'.replace('Raw_data','Best_model')) - shutil.copyfile(f'{best_db}', f'{best_db}'.replace('Raw_data','Best_model')) + shutil.copyfile(f"{best_name}", f"{best_name}".replace("Raw_data", "Best_model")) + shutil.copyfile(f"{best_db}", f"{best_db}".replace("Raw_data", "Best_model")) -def heatmap_workflow(self,folder_hm): +def heatmap_workflow(self, folder_hm): """ Create matrix of ML models, training sizes and errors/precision """ - path_raw = self.args.destination.joinpath(f"Raw_data") - csv_data, model_list = {},[] + path_raw = self.args.destination.joinpath("Raw_data") + csv_data, model_list = {}, [] for csv_file in glob.glob(path_raw.joinpath(f"{folder_hm}/*.csv").as_posix()): - if '_db' not in csv_file: + if "_db" not in csv_file: basename = os.path.basename(csv_file) - csv_model = basename.replace('.','_').split('_')[0] + csv_model = basename.replace(".", "_").split("_")[0] if csv_model not in model_list: - csv_value = pd.read_csv(csv_file, encoding='utf-8') + csv_value = pd.read_csv(csv_file, encoding="utf-8") csv_data[csv_model] = csv_value[f"combined_{self.args.error_type}"][0] # pass dictionary into a dataframe, and sort the models alphabetically csv_df = pd.DataFrame([csv_data]) - + # sort columns in the same order as the optimization df_cols = [] for model in self.args.model: @@ -239,8 +259,7 @@ def heatmap_workflow(self,folder_hm): # plot heatmap if folder_hm == "No_PFI": - suffix = 'No PFI' + suffix = "No PFI" elif folder_hm == "PFI": - suffix = 'PFI' - _ = create_heatmap(self,csv_df,suffix,path_raw) - + suffix = "PFI" + _ = create_heatmap(self, csv_df, suffix, path_raw) diff --git a/robert/gui_easyrob/easyrob.py b/robert/gui_easyrob/easyrob.py index d8a3340..b1d1ff2 100644 --- a/robert/gui_easyrob/easyrob.py +++ b/robert/gui_easyrob/easyrob.py @@ -22,6 +22,7 @@ """ + def get_main_window_class(): """Factory function to retrieve the main application window class.""" try: @@ -29,4 +30,4 @@ def get_main_window_class(): except ImportError: from robert.gui_easyrob.main.window import EasyROB - return EasyROB \ No newline at end of file + return EasyROB diff --git a/robert/gui_easyrob/easyrob_launcher.py b/robert/gui_easyrob/easyrob_launcher.py index a28c1bd..6f66648 100644 --- a/robert/gui_easyrob/easyrob_launcher.py +++ b/robert/gui_easyrob/easyrob_launcher.py @@ -17,6 +17,7 @@ import os import sys + def configure_environment(): """Configure Qt environment for stable rendering.""" os.environ.setdefault("QT_QUICK_BACKEND", "software") @@ -35,9 +36,10 @@ def configure_environment(): configure_environment() -from PySide6.QtCore import Qt, QCoreApplication -from PySide6.QtWidgets import QApplication -from PySide6.QtGui import QPalette +from PySide6.QtCore import Qt, QCoreApplication # noqa: E402 +from PySide6.QtGui import QPalette # noqa: E402 +from PySide6.QtWidgets import QApplication # noqa: E402 + def main(): """Main entry point for the EasyROB application.""" @@ -64,5 +66,6 @@ def main(): sys.exit(app.exec()) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/robert/gui_easyrob/main/window.py b/robert/gui_easyrob/main/window.py index 5b3eb76..b92111e 100644 --- a/robert/gui_easyrob/main/window.py +++ b/robert/gui_easyrob/main/window.py @@ -36,16 +36,21 @@ import webbrowser try: - from version import SOFTWARE_VERSIONS from utils import utils_gui, molssi_utils from tabs import predictions, aqme, advanced_options, molssi, results, images -except ImportError as e: - +except ImportError: from robert.gui_easyrob.version import SOFTWARE_VERSIONS from robert.gui_easyrob.utils import utils_gui, molssi_utils - from robert.gui_easyrob.tabs import predictions, aqme, advanced_options, molssi, results, images + from robert.gui_easyrob.tabs import ( + predictions, + aqme, + advanced_options, + molssi, + results, + images, + ) # ------------------------------------------------------------ @@ -119,27 +124,31 @@ # ------------------------------------------------------------ BASE_DIR = Path(__file__).resolve().parent.parent + class EasyROB(QMainWindow): """Main window for the easyROB application.""" + def __init__(self): super().__init__() self.file_path = "" self.csv_test_path = "" - self.process = None + self.process = None self.available_list = None self.ignore_list = None self.manual_stop = False self.worker = None self._last_loaded_file_path = None - self._molssi_workers = set() # Keep track of MolSSI workers + self._molssi_workers = set() # Keep track of MolSSI workers self.molssi_is_closing = False self.initUI() - self.clear_test_button.setVisible(False) # Hide the button initially - self.molssi_tab.load_test_requested.connect(self.set_csv_test_path) # Connect signal with molssi tab donwload test requested + self.clear_test_button.setVisible(False) # Hide the button initially + self.molssi_tab.load_test_requested.connect( + self.set_csv_test_path + ) # Connect signal with molssi tab donwload test requested def closeEvent(self, event): """Handle the window close event, ensuring proper shutdown of workers.""" - worker = getattr(self, 'worker', None) + worker = getattr(self, "worker", None) if worker is not None and worker.isRunning(): reply = QMessageBox.question( @@ -147,7 +156,7 @@ def closeEvent(self, event): "Exit Confirmation", "ROBERT is still running. Do you want to stop the process and exit?", QMessageBox.Yes | QMessageBox.No, - QMessageBox.No + QMessageBox.No, ) if reply == QMessageBox.No: event.ignore() @@ -211,8 +220,8 @@ def move_to_available(self): self.available_list.addItem(item.text()) # Add back to left list row = self.ignore_list.row(item) # Get correct row index self.ignore_list.takeItem(row) # Remove from right list - - def open_external_url(self,url: str): + + def open_external_url(self, url: str): """Open URL using the system default browser.""" try: webbrowser.open(url, new=2) # new=2 → new tab if possible @@ -241,7 +250,7 @@ def initUI(self): } """ self.setWindowTitle("easyROB") - + # Create main tab widget self.tab_widget = QTabWidget() self.setCentralWidget(self.tab_widget) @@ -275,7 +284,9 @@ def initUI(self): tutorial_btn = QToolButton() tutorial_btn.setText("Tutorial") tutorial_btn.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - tutorial_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_ComputerIcon)) + tutorial_btn.setIcon( + self.style().standardIcon(QStyle.StandardPixmap.SP_ComputerIcon) + ) tutorial_btn.setIconSize(QSize(14, 14)) tutorial_btn.setCursor(Qt.PointingHandCursor) tutorial_btn.setStyleSheet(tool_style) @@ -291,7 +302,9 @@ def initUI(self): youtube_btn.setCursor(Qt.PointingHandCursor) youtube_btn.setStyleSheet(tool_style) youtube_btn.clicked.connect( - lambda: self.open_external_url("https://www.youtube.com/@thealegregroup4964/videos") + lambda: self.open_external_url( + "https://www.youtube.com/@thealegregroup4964/videos" + ) ) # Documentation @@ -311,7 +324,9 @@ def initUI(self): contact_btn = QToolButton() contact_btn.setText("Contact") contact_btn.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - contact_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + contact_btn.setIcon( + self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation) + ) contact_btn.setIconSize(QSize(14, 14)) contact_btn.setCursor(Qt.PointingHandCursor) contact_btn.setStyleSheet(tool_style) @@ -321,7 +336,9 @@ def initUI(self): version_btn = QToolButton() version_btn.setText("Version") version_btn.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - version_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + version_btn.setIcon( + self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView) + ) version_btn.setIconSize(QSize(14, 14)) version_btn.setCursor(Qt.PointingHandCursor) version_btn.setStyleSheet(tool_style) @@ -353,7 +370,9 @@ def initUI(self): # --- Add logo with frame --- with AssetLibrary.Robert_logo_transparent.get_path() as path_logo: pixmap = QPixmap(str(path_logo)) - scaled_pixmap = pixmap.scaled(300, 110, Qt.KeepAspectRatio, Qt.SmoothTransformation) + scaled_pixmap = pixmap.scaled( + 300, 110, Qt.KeepAspectRatio, Qt.SmoothTransformation + ) logo_label = QLabel(self) logo_label.setPixmap(scaled_pixmap) @@ -379,9 +398,9 @@ def initUI(self): "Drag & Drop a CSV file here", self, file_filter="CSV Files (*.csv)", - extensions=(".csv",) + extensions=(".csv",), ) - self.file_label.set_callback(self.set_file_path) + self.file_label.set_callback(self.set_file_path) input_layout.addWidget(self.file_title) input_layout.addWidget(self.file_label) @@ -396,7 +415,7 @@ def initUI(self): "Drag & Drop a external CSV test file here (optional)", self, file_filter="CSV Files (*.csv)", - extensions=(".csv",) + extensions=(".csv",), ) self.csv_test_label.set_callback(self.set_csv_test_path) @@ -417,7 +436,6 @@ def initUI(self): test_layout.addWidget(self.csv_test_title) test_layout.addWidget(test_label_container) - # --- CSV Section with Button in the Middle --- csv_layout = QHBoxLayout() csv_layout.addLayout(input_layout) @@ -425,15 +443,15 @@ def initUI(self): # --- Add All to Main Layout --- main_layout.addLayout(csv_layout) - + # --- Select column for --y --- self.y_label = QLabel("Select Target Column (y)") self.y_label.setStyleSheet("font-size:13px;") main_layout.addWidget(self.y_label) - self.y_dropdown = NoScrollComboBox() + self.y_dropdown = NoScrollComboBox() main_layout.addWidget(self.y_dropdown) self.y_dropdown.setStyleSheet(box_features) - + # --- Select prediction type --- self.type_label = QLabel("Prediction Type") self.type_label.setStyleSheet("font-size:13px;") @@ -442,15 +460,15 @@ def initUI(self): self.type_dropdown.addItems(["Regression", "Classification"]) main_layout.addWidget(self.type_dropdown) self.type_dropdown.setStyleSheet(box_features) - + # --- Select column for --names --- self.names_label = QLabel("Select name column") self.names_label.setStyleSheet("font-size:13px;") main_layout.addWidget(self.names_label) self.names_dropdown = NoScrollComboBox() - main_layout.addWidget(self.names_dropdown) + main_layout.addWidget(self.names_dropdown) self.names_dropdown.setStyleSheet(box_features) - + # Main horizontal layout for column selection column_layout = QHBoxLayout() @@ -470,10 +488,14 @@ def initUI(self): self.add_button = QPushButton(">>") self.add_button.setFixedSize(30, 24) - self.add_button.clicked.connect(self.move_to_selected) # Moves selected items to "Ignored Columns" + self.add_button.clicked.connect( + self.move_to_selected + ) # Moves selected items to "Ignored Columns" self.remove_button = QPushButton("<<") self.remove_button.setFixedSize(30, 24) - self.remove_button.clicked.connect(self.move_to_available) # Moves selected items back to "Available Columns" + self.remove_button.clicked.connect( + self.move_to_available + ) # Moves selected items back to "Available Columns" button_style = """ QPushButton { border: 1px solid palette(mid); @@ -486,13 +508,13 @@ def initUI(self): """ self.add_button.setStyleSheet(button_style) - self.remove_button.setStyleSheet(button_style) + self.remove_button.setStyleSheet(button_style) # Add buttons to the button layout - button_layout.addStretch() + button_layout.addStretch() button_layout.addWidget(self.add_button, alignment=Qt.AlignCenter) button_layout.addWidget(self.remove_button, alignment=Qt.AlignCenter) - button_layout.addStretch() + button_layout.addStretch() # Right side (Ignored Columns) right_layout = QVBoxLayout() @@ -512,32 +534,27 @@ def initUI(self): # Create a container for the column layout and resize it column_container = QWidget() column_container.setLayout(column_layout) - column_container.setFixedHeight(120) + column_container.setFixedHeight(120) # Insert the column container into the main layout main_layout.addWidget(column_container) main_layout.addSpacing(10) # AQME Workflow Checkbox - self.aqme_workflow = QCheckBox("Enable AQME Workflow") + self.aqme_workflow = QCheckBox("Enable AQME Workflow") self.aqme_workflow.setStyleSheet("font-weight: bold; font-size: 14px;") self.aqme_workflow.stateChanged.connect(self.check_aqme_workflow) main_layout.addWidget(self.aqme_workflow) - main_layout.addSpacing(10) + main_layout.addSpacing(10) # Workflow selection dropdown self.workflow_selector = NoScrollComboBox() self.workflow_selector.setStyleSheet("font-weight: bold; font-size: 14px;") # Add options - self.workflow_selector.addItems([ - "Full Workflow", - "CURATE", - "GENERATE", - "PREDICT", - "VERIFY", - "REPORT" - ]) + self.workflow_selector.addItems( + ["Full Workflow", "CURATE", "GENERATE", "PREDICT", "VERIFY", "REPORT"] + ) # Set default selection self.workflow_selector.setCurrentText("Full Workflow") @@ -614,7 +631,6 @@ def initUI(self): } """) - self.run_aqme_button.clicked.connect(self.run_aqme) # --- Stop Button --- @@ -693,9 +709,9 @@ def initUI(self): } """) # Set minimum height for the console output and make it expandable - self.console_output.setMinimumHeight(250) + self.console_output.setMinimumHeight(250) self.console_output.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - + # Create ANSI converter to display colors in the console and special characters self.ansi_converter = Ansi2HTMLConverter(dark_bg=True) # Preserves colors main_layout.addWidget(QLabel("Console Output")) @@ -747,7 +763,9 @@ def initUI(self): # Wrap options tab in a scroll area to handle large content options_scroll = QScrollArea() options_scroll.setWidgetResizable(True) - options_scroll.setMinimumSize(0, 0) # Prevent the scroll area from imposing a big minimum on the window + options_scroll.setMinimumSize( + 0, 0 + ) # Prevent the scroll area from imposing a big minimum on the window options_scroll.setWidget(self.options_tab) # Images tab @@ -765,10 +783,12 @@ def initUI(self): # =============================== self.tab_widget.addTab(self.tab_widget_aqme, "AQME") - self.tab_widget.setTabEnabled(self.tab_widget.indexOf(self.tab_widget_aqme), False) + self.tab_widget.setTabEnabled( + self.tab_widget.indexOf(self.tab_widget_aqme), False + ) self.tab_widget.addTab(options_scroll, "Advanced Options") - + self.tab_widget.addTab(self.molssi_tab, "MolSSI Databases") self.tab_widget.addTab(self.results_tab, "Reports") @@ -781,15 +801,13 @@ def initUI(self): # Start disabled self.tab_widget.setTabEnabled( - self.tab_widget.indexOf(self.predictions_tab), - False + self.tab_widget.indexOf(self.predictions_tab), False ) # React to availability decided by the tab itself self.predictions_tab.availabilityChanged.connect( lambda ok: self.tab_widget.setTabEnabled( - self.tab_widget.indexOf(self.predictions_tab), - ok + self.tab_widget.indexOf(self.predictions_tab), ok ) ) @@ -805,8 +823,7 @@ def show_contact_dialog(self): label = QLabel() label.setTextFormat(Qt.RichText) label.setTextInteractionFlags( - Qt.TextSelectableByMouse | - Qt.LinksAccessibleByMouse + Qt.TextSelectableByMouse | Qt.LinksAccessibleByMouse ) label.setOpenExternalLinks(True) @@ -843,7 +860,7 @@ def load_tutorial(self, name): block.strip().replace("\n", " ") for block in text.split("---") if block.strip() - ] + ] def show_tutorial_dialog(self): """Display workflow tutorial dialog.""" @@ -888,10 +905,7 @@ def _build_tutorial_tabs(self): for folder, title in tutorials: texts = self.load_tutorial(folder) - tabs.addTab( - self.create_tutorial_tab(folder, texts), - title - ) + tabs.addTab(self.create_tutorial_tab(folder, texts), title) self.tutorial_layout.addWidget(tabs) @@ -906,7 +920,7 @@ def create_tutorial_tab(self, folder_name, texts): base = BASE_DIR / "tutorials" / "tutorial_images" / folder_name images = sorted( base.glob(f"{folder_name}_*.png"), - key=lambda p: [int(s) if s.isdigit() else s for s in p.stem.split("_")] + key=lambda p: [int(s) if s.isdigit() else s for s in p.stem.split("_")], ) tab = QWidget() @@ -915,7 +929,6 @@ def create_tutorial_tab(self, folder_name, texts): stacked = QStackedWidget() for i, image_path in enumerate(images): - page = QWidget() page_layout = QHBoxLayout(page) page_layout.setContentsMargins(0, 0, 0, 0) @@ -935,10 +948,7 @@ def create_tutorial_tab(self, folder_name, texts): if not pixmap.isNull(): scaled = pixmap.scaled( - 520, - 520, - Qt.KeepAspectRatio, - Qt.SmoothTransformation + 520, 520, Qt.KeepAspectRatio, Qt.SmoothTransformation ) image_label.setPixmap(scaled) else: @@ -991,7 +1001,7 @@ def create_tutorial_tab(self, folder_name, texts): layout.addLayout(nav_layout) def update_step(): - step_label.setText(f"Step {stacked.currentIndex()+1} / {stacked.count()}") + step_label.setText(f"Step {stacked.currentIndex() + 1} / {stacked.count()}") def next_step(): i = (stacked.currentIndex() + 1) % stacked.count() @@ -1009,7 +1019,7 @@ def prev_step(): update_step() return tab - + def show_version_dialog(self): """Display styled version dialog.""" @@ -1078,10 +1088,14 @@ def check_aqme_workflow(self): # Enable the AQME tab if not already enabled if not self.tab_widget.isTabEnabled(tab_index): self.tab_widget.setTabEnabled(tab_index, True) - QMessageBox.information(self, "AQME Tab Enabled", "AQME tab unlocked to specify AQME parameters.") + QMessageBox.information( + self, + "AQME Tab Enabled", + "AQME tab unlocked to specify AQME parameters.", + ) # Always refresh AQME tab content if file path is available - if hasattr(self, 'file_path') and self.file_path: + if hasattr(self, "file_path") and self.file_path: self.tab_widget_aqme.selected_atoms = [] self.tab_widget_aqme.file_path = self.file_path self.tab_widget_aqme.detect_patterns_and_display() @@ -1136,7 +1150,7 @@ def refresh_tabs(self, file_path): self._refresh_scheduled = True - # Schedule refresh after a short delay for avoid freeze popups + # Schedule refresh after a short delay for avoid freeze popups QTimer.singleShot(50, self._execute_refresh_tabs) def _execute_refresh_tabs(self): @@ -1160,13 +1174,17 @@ def _execute_refresh_tabs(self): def select_file(self): """Opens file dialog to select a CSV file.""" - file_path, _ = QFileDialog.getOpenFileName(self, "Select CSV File", "", "CSV Files (*.csv)") + file_path, _ = QFileDialog.getOpenFileName( + self, "Select CSV File", "", "CSV Files (*.csv)" + ) if file_path: self.set_file_path(file_path) def select_csv_test_file(self): """Opens file dialog to select a test CSV file.""" - file_path, _ = QFileDialog.getOpenFileName(self, "Select Test CSV File", "", "CSV Files (*.csv)") + file_path, _ = QFileDialog.getOpenFileName( + self, "Select Test CSV File", "", "CSV Files (*.csv)" + ) if file_path: self.set_csv_test_path(file_path) @@ -1176,8 +1194,8 @@ def set_file_path(self, file_path: str, force: bool = False): Reloads if the file path changed OR the file was modified (mtime) OR force=True. """ p = Path(file_path) - current_path = getattr(self, 'file_path', None) - current_mtime = getattr(self, '_file_mtime', None) + current_path = getattr(self, "file_path", None) + current_mtime = getattr(self, "_file_mtime", None) # Compute new file's modification time (None if missing) try: @@ -1185,8 +1203,8 @@ def set_file_path(self, file_path: str, force: bool = False): except OSError: new_mtime = None - same_path = (current_path == file_path) - same_mtime = (current_mtime == new_mtime) + same_path = current_path == file_path + same_mtime = current_mtime == new_mtime # If nothing changed and not forced, bail early if same_path and same_mtime and not force: @@ -1209,7 +1227,7 @@ def set_file_path(self, file_path: str, force: bool = False): self.load_csv_columns() self.refresh_tabs(file_path) - # Check for MolSSI descriptors + # Check for MolSSI descriptors if not self._is_molssi_csv(file_path): self.check_molssi_descriptors() @@ -1220,7 +1238,6 @@ def set_file_path(self, file_path: str, force: bool = False): self.check_aqme_workflow() def set_csv_test_path(self, file_path): - """Sets the path for the test CSV file and updates the label.""" self.csv_test_path = file_path file_name = Path(file_path).name @@ -1300,20 +1317,13 @@ def _finished(path): popup.close() popup.deleteLater() - self.molssi_tab.handle_external_download( - path, - context="molssi_test" - ) + self.molssi_tab.handle_external_download(path, context="molssi_test") def _error(msg): popup.close() popup.deleteLater() - QMessageBox.warning( - self, - "MolSSI download failed", - msg - ) + QMessageBox.warning(self, "MolSSI download failed", msg) self._molssi_download_worker.finished.connect(_finished) self._molssi_download_worker.error.connect(_error) @@ -1324,14 +1334,14 @@ def _is_molssi_csv(self, file_path: str) -> bool: Return True if the file path corresponds to a MolSSI-generated CSV. """ return "_molssi_" in Path(file_path).stem - + def check_molssi_descriptors(self): """Check for MolSSI descriptors in the current dataset.""" worker = MolSSIWorker( self.df, self.file_path, should_abort=lambda: self.molssi_is_closing, - debug=True + debug=True, ) self._molssi_workers.add(worker) @@ -1341,7 +1351,7 @@ def check_molssi_descriptors(self): worker.finished.connect(worker.deleteLater) worker.start() - + def on_molssi_finished(self, result): """Called when MolSSI worker finishes.""" @@ -1418,19 +1428,15 @@ def on_molssi_finished(self, result): msg.setText( "Your molecules are fully covered by the MolSSI " f"{result['library']} ({result['data_type']}) database.\n\n" - "In addition to generating descriptors for your current dataset, " "you can also load the complete MolSSI database as an external test set.\n\n" - "This external dataset contains all molecules available in the MolSSI " "database with the same type of descriptors and can be used to:\n\n" "• Evaluate model performance\n" "• Validate predictions\n" "• Explore new candidate molecules\n\n" - "The test dataset will be loaded separately and will NOT modify your " "current dataset.\n\n" - "Do you want to load the full MolSSI dataset as a test set?" ) @@ -1444,9 +1450,7 @@ def on_molssi_finished(self, result): self.current_download_context = "molssi_test" # Start automatic MolSSI download - self.start_molssi_test_download( - library_slug=result["library"] - ) + self.start_molssi_test_download(library_slug=result["library"]) def _update_unified_smiles_context(self): """ @@ -1463,14 +1467,13 @@ def _update_unified_smiles_context(self): if self.csv_test_path: unified_smiles = self.tab_widget_aqme.build_unified_smiles_context( - self.file_path, - self.csv_test_path + self.file_path, self.csv_test_path ) else: unified_smiles = self.tab_widget_aqme.build_unified_smiles_context( self.file_path ) - + self.tab_widget_aqme.unified_smiles = unified_smiles def set_main_chemdraw_path(self, file_path): @@ -1576,7 +1579,7 @@ def check_atomic_descriptors(self, context: str) -> bool: "No atoms are currently selected. This is not an error.\n" "Atom selection may not always be available depending on the structures.\n\n" "To access the AQME tab and enable atom selection, check " - "\"Enable AQME Workflow\".\n\n" + '"Enable AQME Workflow".\n\n' "You can safely continue with the current setup, or go back and select " "atoms if that option is available and relevant.\n\n" "Do you want to continue?" @@ -1597,15 +1600,11 @@ def check_atomic_descriptors(self, context: str) -> bool: ) reply = QMessageBox.question( - self, - title, - message, - QMessageBox.Yes | QMessageBox.No, - QMessageBox.Yes + self, title, message, QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes ) return reply == QMessageBox.Yes - + def handle_predict_with_test(self, run_dir): """Preflight checks for PREDICT with test CSV.""" @@ -1640,17 +1639,12 @@ def _ask_descriptor_strategy(self): "the required descriptors generated without AQME." ) aqme_btn = msg.addButton( - "Generate descriptors with AQME", - QMessageBox.AcceptRole + "Generate descriptors with AQME", QMessageBox.AcceptRole ) existing_btn = msg.addButton( - "Descriptors already present", - QMessageBox.AcceptRole - ) - cancel_btn = msg.addButton( - "Cancel", - QMessageBox.RejectRole + "Descriptors already present", QMessageBox.AcceptRole ) + msg.addButton("Cancel", QMessageBox.RejectRole) msg.exec() @@ -1663,7 +1657,7 @@ def _ask_descriptor_strategy(self): return "existing" return "cancel" - + def _run_aqme_for_test_descriptors(self, run_dir): """ Automatically detects how AQME descriptors were generated @@ -1677,10 +1671,8 @@ def _run_aqme_for_test_descriptors(self, run_dir): # Scenario 1: AQME-ROBERT (integrated workflow) # ================================================== if os.path.exists(aqme_dat): - robert_cmd = self._read_command_from_dat( - aqme_dat, - expected_prefix="Command line used in ROBERT" + aqme_dat, expected_prefix="Command line used in ROBERT" ) # Extract --qdescp_keywords "..." if present @@ -1689,9 +1681,7 @@ def _run_aqme_for_test_descriptors(self, run_dir): if match: qdescp_keywords = match.group(1) - aqme_cmd = self._build_test_aqme_command( - qdescp_keywords=qdescp_keywords - ) + aqme_cmd = self._build_test_aqme_command(qdescp_keywords=qdescp_keywords) return aqme_cmd @@ -1699,15 +1689,11 @@ def _run_aqme_for_test_descriptors(self, run_dir): # Scenario 2: AQME -> ROBERT (separate runs) # ================================================== if os.path.exists(qdescp_dat): - aqme_cmd_original = self._read_command_from_dat( - qdescp_dat, - expected_prefix="Command line used in AQME" + qdescp_dat, expected_prefix="Command line used in AQME" ) - aqme_cmd = self._build_test_aqme_command( - original_command=aqme_cmd_original - ) + aqme_cmd = self._build_test_aqme_command(original_command=aqme_cmd_original) return aqme_cmd @@ -1718,11 +1704,11 @@ def _run_aqme_for_test_descriptors(self, run_dir): self, "AQME information not found", "Could not automatically determine how descriptors were generated.\n\n" - "No AQME metadata files were found for this model." + "No AQME metadata files were found for this model.", ) return None - + def _read_command_from_dat(self, dat_path, expected_prefix): """ Reads a .dat file and extracts the command line @@ -1735,9 +1721,7 @@ def _read_command_from_dat(self, dat_path, expected_prefix): if line.startswith(expected_prefix): return line.split(":", 1)[1].strip() - raise RuntimeError( - f"Expected command line not found in {dat_path}" - ) + raise RuntimeError(f"Expected command line not found in {dat_path}") def _build_test_aqme_command(self, original_command=None, qdescp_keywords=None): """Builds an AQME command for generating descriptors for the test CSV.""" @@ -1760,22 +1744,16 @@ def _build_test_aqme_command(self, original_command=None, qdescp_keywords=None): if original_command: # 1. Remove any leading python executable # Keep everything from "-m aqme" onwards - match = re.search(r'(-m\s+aqme.*)', original_command) + match = re.search(r"(-m\s+aqme.*)", original_command) if not match: raise RuntimeError("Could not locate '-m aqme' in AQME command") aqme_args = match.group(1) # 2. Replace CSV references + aqme_args = re.sub(r'--input\s+"[^"]+"', f'--input "{test_csv}"', aqme_args) aqme_args = re.sub( - r'--input\s+"[^"]+"', - f'--input "{test_csv}"', - aqme_args - ) - aqme_args = re.sub( - r'--csv_name\s+"[^"]+"', - f'--csv_name "{test_csv}"', - aqme_args + r'--csv_name\s+"[^"]+"', f'--csv_name "{test_csv}"', aqme_args ) # 3. Rebuild command with correct python @@ -1788,26 +1766,26 @@ def _build_test_aqme_command(self, original_command=None, qdescp_keywords=None): cmd = ( f'"{python_pointer}" -u -m aqme --qdescp ' f'--input "{test_csv}" ' - f'--program xtb ' + f"--program xtb " f'--csv_name "{test_csv}" ' - f'--robert' + f"--robert" ) if qdescp_keywords: - cmd += f' {qdescp_keywords}' + cmd += f" {qdescp_keywords}" return cmd - + def run_test_aqme(self, aqme_command, run_dir): """Launches an AQME worker for generating descriptors for the test CSV.""" - + self.console_output.append( "Running AQME...
" ) self.progress.setRange(0, 0) self.current_process = "AQME" - self.aqme_role = "test" + self.aqme_role = "test" self.worker = RobertWorker(aqme_command, run_dir) self.worker.output_received.connect(self.console_output.append) @@ -1826,9 +1804,7 @@ def _detect_aqme_output_csv(self): run_dir = os.path.dirname(self.csv_test_path) # Original CSV name without extension) - original_name = os.path.splitext( - os.path.basename(self.csv_test_path) - )[0] + original_name = os.path.splitext(os.path.basename(self.csv_test_path))[0] # Expected output CSV expected_csv = f"AQME-ROBERT_full_{original_name}.csv" @@ -1838,13 +1814,13 @@ def _detect_aqme_output_csv(self): return expected_path return None - + def _write_atom_mapping_dat( self, smarts: str, selected_atoms: list, run_dir: str, - filename: str = "AtomMapping_data.dat" + filename: str = "AtomMapping_data.dat", ): """ Write an atomic mapping contract to disk. @@ -1915,9 +1891,10 @@ def _write_atom_mapping_dat( self.console_output.append( f"WARNING: Failed to write atom mapping dat: {e}" ) + def _check_generate_folder(self, run_dir): """Checks if a GENERATE folder exists in the run directory.""" - + generate_dir = os.path.join(run_dir, "GENERATE") if not os.path.exists(generate_dir): @@ -1926,12 +1903,12 @@ def _check_generate_folder(self, run_dir): "Trained model not found", "No trained model was found in this folder.\n\n" "Prediction with a test CSV requires an existing model " - "generated in a previous run (GENERATE step).\n" + "generated in a previous run (GENERATE step).\n", ) return False return True - + def _validate_robert_workflow(self): """ Validates workflow state and resolves mismatches. @@ -1943,14 +1920,18 @@ def _validate_robert_workflow(self): # --------------------------------------------------- # Detect mismatch: test CSV loaded but not PREDICT # --------------------------------------------------- - if self.csv_test_path and not self.file_path and workflow not in ["PREDICT", "REPORT"]: + if ( + self.csv_test_path + and not self.file_path + and workflow not in ["PREDICT", "REPORT"] + ): reply = QMessageBox.question( self, "Possible workflow mismatch", "You loaded a test CSV but selected 'Full Workflow'.\n\n" "Did you mean to generate predictions instead?", QMessageBox.Yes | QMessageBox.No, - QMessageBox.Yes + QMessageBox.Yes, ) if reply == QMessageBox.Yes: @@ -1961,7 +1942,7 @@ def _validate_robert_workflow(self): self, "Execution stopped", "To run 'Full Workflow' in ROBERT, please load the training CSV " - "and select the appropriate target and name columns." + "and select the appropriate target and name columns.", ) return False @@ -1973,7 +1954,7 @@ def _validate_robert_workflow(self): QMessageBox.warning( self, "WARNING!", - "Please load a training CSV file before running the workflow." + "Please load a training CSV file before running the workflow.", ) return False @@ -1981,12 +1962,9 @@ def _validate_robert_workflow(self): # PREDICT validation # --------------------------------------------------- if workflow == "PREDICT": - if not self.csv_test_path: QMessageBox.warning( - self, - "WARNING!", - "Please select a test CSV file for prediction." + self, "WARNING!", "Please select a test CSV file for prediction." ) return False @@ -1999,12 +1977,11 @@ def _validate_robert_workflow(self): # REPORT validation # --------------------------------------------------- if workflow == "REPORT": - if not self.file_path and not self.csv_test_path: QMessageBox.warning( self, "WARNING!", - "Please load a CSV file to determine the report directory." + "Please load a CSV file to determine the report directory.", ) return False @@ -2018,7 +1995,7 @@ def run_robert(self): # -------------------------------------------------- if not self._validate_robert_workflow(): return - + # -------------------------------------------------- # Init process # -------------------------------------------------- @@ -2033,7 +2010,7 @@ def run_robert(self): "
"
         )
 
-        # Path to run directory 
+        # Path to run directory
         if self.file_path:
             run_dir = os.path.dirname(self.file_path)
         elif self.csv_test_path:
@@ -2048,8 +2025,7 @@ def run_robert(self):
             folders_to_check.extend(["CSEARCH", "QDESCP"])
 
         existing_folders = [
-            f for f in folders_to_check
-            if os.path.exists(os.path.join(run_dir, f))
+            f for f in folders_to_check if os.path.exists(os.path.join(run_dir, f))
         ]
 
         if existing_folders and self.workflow_selector.currentText() == "Full Workflow":
@@ -2061,7 +2037,7 @@ def run_robert(self):
                 "or will be overwritten if the previous run completed successfully.\n\n"
                 "Are you sure you want to continue and delete them?",
                 QMessageBox.Yes | QMessageBox.No,
-                QMessageBox.No
+                QMessageBox.No,
             )
 
             if confirmation == QMessageBox.No:
@@ -2076,8 +2052,8 @@ def run_robert(self):
                         f"[ERROR] Could not delete folder '{folder}': {e}"
                     )
                     self._reset_ui_after_process()
-                    return   
-      
+                    return
+
         # --------------------------------------------------
         # Collect GUI values
         # --------------------------------------------------
@@ -2089,7 +2065,7 @@ def run_robert(self):
                 "WARNING! Invalid parameters. Please fix them before running."
             )
             return
-        
+
         # Rename pdf if full workflow or report selected
         wf_predict = self.workflow_selector.currentText()
         if wf_predict == "Full Workflow" or wf_predict == "REPORT":
@@ -2113,10 +2089,10 @@ def run_robert(self):
 
             train_source_csv = self._get_unmapped_csv(self.file_path)
 
-            self.mapped_train_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
-                train_source_csv,
-                smarts,
-                selected_atoms_for_robert
+            self.mapped_train_csv = (
+                self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+                    train_source_csv, smarts, selected_atoms_for_robert
+                )
             )
 
             is_robert_mapped = True
@@ -2124,18 +2100,16 @@ def run_robert(self):
             if getattr(self, "csv_test_path", None):
                 test_source_csv = self._get_unmapped_csv(self.csv_test_path)
 
-                self.mapped_test_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
-                    test_source_csv,
-                    smarts,
-                    selected_atoms_for_robert
+                self.mapped_test_csv = (
+                    self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+                        test_source_csv, smarts, selected_atoms_for_robert
+                    )
                 )
 
             #  Save atomic mapping contract in .dat
             run_dir = os.path.dirname(self.file_path)
             self._write_atom_mapping_dat(
-                smarts=smarts,
-                selected_atoms=selected_atoms_for_robert,
-                run_dir=run_dir
+                smarts=smarts, selected_atoms=selected_atoms_for_robert, run_dir=run_dir
             )
 
         # --------------------------------------------------
@@ -2151,7 +2125,7 @@ def run_robert(self):
             self.mapped_test_csv
             if is_robert_mapped and self.mapped_test_csv
             else getattr(self, "csv_test_path", None)
-        ) 
+        )
 
         # --------------------------------------------------------------------------
         # AQME-origin CSV check, disable AQME workflow if detected previously runned
@@ -2169,14 +2143,11 @@ def run_robert(self):
             if not self.check_atomic_descriptors("ROBERT"):
                 self._reset_ui_after_process()
                 return
-            
+
         # --------------------------------------------------
         # PREDICT + csv_test preflight
         # --------------------------------------------------
-        if (
-            self.workflow_selector.currentText() == "PREDICT"
-            and self.csv_test_path
-        ):
+        if self.workflow_selector.currentText() == "PREDICT" and self.csv_test_path:
             run_dir = os.path.dirname(self.csv_test_path)
 
             # -----------------------------------------------
@@ -2185,7 +2156,6 @@ def run_robert(self):
             dat_path = os.path.join(run_dir, "AtomMapping_data.dat")
 
             if os.path.isfile(dat_path):
-
                 self.console_output.append(
                     f"[INFO] Atomic mapping contract detected: {dat_path}"
                 )
@@ -2196,15 +2166,14 @@ def run_robert(self):
 
                     # Validate + Apply in one step
                     new_mapped_csv = self._apply_mapping_smarts(
-                        self.csv_test_path,
-                        contract
+                        self.csv_test_path, contract
                     )
 
                     self.console_output.append(
                         f"[INFO] Generated mapped test CSV: {new_mapped_csv}"
                     )
 
-                    # Save original test CSV only 
+                    # Save original test CSV only
                     if not getattr(self, "_original_test_csv_path", None):
                         self._original_test_csv_path = self.csv_test_path
 
@@ -2217,7 +2186,7 @@ def run_robert(self):
                         "Atomic mapping mismatch",
                         f"{e}\n\n"
                         "Prediction cannot continue because atomic descriptors "
-                        "would not be consistent with the trained model."
+                        "would not be consistent with the trained model.",
                     )
                     self._reset_ui_after_process()
                     return
@@ -2253,13 +2222,13 @@ def run_robert(self):
             run_dir = os.path.dirname(selected_file_path)
         elif self.csv_test_path:
             run_dir = os.path.dirname(self.csv_test_path)
-     
+
         self.worker = RobertWorker(command, run_dir)
         self.worker.output_received.connect(self.console_output.append)
         self.worker.error_received.connect(self.console_output.append)
         self.worker.process_finished.connect(self.on_process_finished)
         self.worker.start()
-    
+
     def _read_atom_mapping_dat(self, dat_path):
         """
         Read atomic mapping contract from .dat file.
@@ -2284,13 +2253,17 @@ def _read_atom_mapping_dat(self, dat_path):
         mapping = []
 
         for line in lines:
-
             # SMARTS line
             if line.startswith("SMARTS pattern"):
                 continue  # skip header line
 
-            if smarts is None and not line.startswith("-") and "Pattern atoms" not in line and "Pattern atom" not in line:
-                # First non-header SMARTS candidate 
+            if (
+                smarts is None
+                and not line.startswith("-")
+                and "Pattern atoms" not in line
+                and "Pattern atom" not in line
+            ):
+                # First non-header SMARTS candidate
                 if "[" in line or "#" in line:
                     smarts = line
 
@@ -2303,28 +2276,26 @@ def _read_atom_mapping_dat(self, dat_path):
                 # Extract using regex
                 match = re.search(
                     r"Pattern atom\s+(\d+)\s+→\s+atomMap\s+(\d+)\s+\(Element:\s+(\w+)\)",
-                    line
+                    line,
                 )
                 if match:
                     pattern_idx = int(match.group(1))
                     map_num = int(match.group(2))
                     element = match.group(3)
 
-                    mapping.append({
-                        "pattern_idx": pattern_idx,
-                        "map_num": map_num,
-                        "element": element
-                    })
+                    mapping.append(
+                        {
+                            "pattern_idx": pattern_idx,
+                            "map_num": map_num,
+                            "element": element,
+                        }
+                    )
 
         if smarts is None or pattern_atoms is None or not mapping:
             raise ValueError("Invalid atom_mapping.dat format")
-        
-        return {
-            "smarts": smarts,
-            "pattern_atoms": pattern_atoms,
-            "mapping": mapping
-        }
-    
+
+        return {"smarts": smarts, "pattern_atoms": pattern_atoms, "mapping": mapping}
+
     def _apply_mapping_smarts(self, csv_path, contract):
         """
         Validate and apply atomic mapping contract to CSV.
@@ -2335,10 +2306,7 @@ def _apply_mapping_smarts(self, csv_path, contract):
 
         df = smart_read_csv(csv_path)
 
-        smiles_col = next(
-            (c for c in df.columns if c.lower() == "smiles"),
-            None
-        )
+        smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
 
         if smiles_col is None:
             raise ValueError("CSV has no SMILES column")
@@ -2354,7 +2322,7 @@ def _apply_mapping_smarts(self, csv_path, contract):
 
         if pattern_mol.GetNumAtoms() != expected_pattern_atoms:
             raise ValueError("SMARTS atom count mismatch with contract")
-        
+
         # --------------------------------------------------
         # Step 1: Detect if already mapped correctly
         # --------------------------------------------------
@@ -2367,7 +2335,6 @@ def _apply_mapping_smarts(self, csv_path, contract):
         mapped_smiles = []
 
         for smiles in df[smiles_col].dropna().astype(str):
-
             mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
             if mol is None:
                 raise ValueError(f"Invalid SMILES: {smiles}")
@@ -2445,7 +2412,7 @@ def build_robert_command(self, selected_file_path):
                     command += f' --csv_test "{csv_test}"'
 
             return command
-        
+
         # ==================================================
         # NORMAL WORKFLOW (CURATE / GENERATE / VERIFY)
         # ==================================================
@@ -2466,8 +2433,7 @@ def build_robert_command(self, selected_file_path):
 
         # ---------- IGNORE COLUMNS ----------
         selected_columns = [
-            self.ignore_list.item(i).text()
-            for i in range(self.ignore_list.count())
+            self.ignore_list.item(i).text() for i in range(self.ignore_list.count())
         ]
         if selected_columns:
             formatted_columns = [f"'{col}'" for col in selected_columns]
@@ -2487,21 +2453,21 @@ def build_robert_command(self, selected_file_path):
             command += " --auto_type False"
 
         if self.seed_value:
-            command += f' --seed {self.seed_value}'
+            command += f" --seed {self.seed_value}"
 
         if self.kfold_value:
-            command += f' --kfold {self.kfold_value}'
+            command += f" --kfold {self.kfold_value}"
 
         if self.repeat_kfolds_value:
-            command += f' --repeat_kfolds {self.repeat_kfolds_value}'
+            command += f" --repeat_kfolds {self.repeat_kfolds_value}"
 
         if self.split_value != "even":
-            command += f' --split {self.split_value.lower()}'
+            command += f" --split {self.split_value.lower()}"
 
         # ---------- AQME ----------
         if self.aqme_workflow.isChecked():
-            command += ' --aqme'
-            command += f' --descp_lvl {self.descriptor_level_selected}'
+            command += " --aqme"
+            command += f" --descp_lvl {self.descriptor_level_selected}"
 
             atoms_entries = []
 
@@ -2527,66 +2493,66 @@ def build_robert_command(self, selected_file_path):
 
         # ---------- CURATE ----------
         if self.categorical_value != "onehot":
-            command += f' --categorical {self.categorical_value}'
+            command += f" --categorical {self.categorical_value}"
 
         if not self.corr_filter_x_value:
-            command += ' --corr_filter_x False'
+            command += " --corr_filter_x False"
 
         if self.corr_filter_y_value:
-            command += ' --corr_filter_y True'
+            command += " --corr_filter_y True"
 
         if self.desc_thres_value:
-            command += f' --desc_thres {self.desc_thres_value}'
+            command += f" --desc_thres {self.desc_thres_value}"
 
         if self.thres_x_value:
-            command += f' --thres_x {self.thres_x_value}'
+            command += f" --thres_x {self.thres_x_value}"
 
         if self.thres_y_value:
-            command += f' --thres_y {self.thres_y_value}'
+            command += f" --thres_y {self.thres_y_value}"
 
         # ---------- GENERATE ----------
         if self.selected_models != self.default_models:
-            model_list = "[" + ",".join(
-                f"'{m}'" for m in sorted(self.selected_models)
-            ) + "]"
+            model_list = (
+                "[" + ",".join(f"'{m}'" for m in sorted(self.selected_models)) + "]"
+            )
             command += f' --model "{model_list}"'
 
         if self.error_type_value != self.default_error_type:
-            command += f' --error_type {self.error_type_value}'
+            command += f" --error_type {self.error_type_value}"
 
         if self.init_points_value:
-            command += f' --init_points {self.init_points_value}'
+            command += f" --init_points {self.init_points_value}"
 
         if self.n_iter_value:
-            command += f' --n_iter {self.n_iter_value}'
+            command += f" --n_iter {self.n_iter_value}"
 
         if not self.pfi_filter_value:
             command += " --pfi_filter False"
 
         if self.pfi_epochs_value:
-            command += f' --pfi_epochs {self.pfi_epochs_value}'
+            command += f" --pfi_epochs {self.pfi_epochs_value}"
 
         if self.pfi_threshold_value:
-            command += f' --pfi_threshold {self.pfi_threshold_value}'
+            command += f" --pfi_threshold {self.pfi_threshold_value}"
 
         if self.pfi_max_value:
-            command += f' --pfi_max {self.pfi_max_value}'
+            command += f" --pfi_max {self.pfi_max_value}"
 
         if not self.auto_test_value:
             command += " --auto_test False"
 
         if self.test_set_value:
-            command += f' --test_set {self.test_set_value}'
+            command += f" --test_set {self.test_set_value}"
 
         # ---------- PREDICT OPTIONS (shared flags) ----------
         if self.t_value:
-            command += f' --t_value {self.t_value}'
+            command += f" --t_value {self.t_value}"
 
         if self.shap_show:
-            command += f' --shap_show {self.shap_show}'
+            command += f" --shap_show {self.shap_show}"
 
         if self.pfi_show:
-            command += f' --pfi_show {self.pfi_show}'
+            command += f" --pfi_show {self.pfi_show}"
 
         return command
 
@@ -2601,7 +2567,9 @@ def _collect_robert_gui_values(self):
         self.split_value = self.options_tab.split.currentText().strip()
 
         # ---------- AQME ----------
-        self.descriptor_level_selected = self.tab_widget_aqme.descriptor_level.currentText()
+        self.descriptor_level_selected = (
+            self.tab_widget_aqme.descriptor_level.currentText()
+        )
         self.atoms_selected = self.tab_widget_aqme.atoms.text().strip()
         self.solvent_selected = self.tab_widget_aqme.solvent.currentText()
 
@@ -2617,15 +2585,15 @@ def _collect_robert_gui_values(self):
         type_mode = self.type_dropdown.currentText()
 
         self.default_models = (
-            {"RF", "GB", "NN", "MVL"} if type_mode == "Regression"
+            {"RF", "GB", "NN", "MVL"}
+            if type_mode == "Regression"
             else {"RF", "GB", "NN", "AdaB"}
         )
 
-        self.default_error_type = (
-            "rmse" if type_mode == "Regression" else "mcc"
-        )
+        self.default_error_type = "rmse" if type_mode == "Regression" else "mcc"
         self.selected_models = {
-            model for model, checkbox in self.options_tab.modellist.items()
+            model
+            for model, checkbox in self.options_tab.modellist.items()
             if checkbox.isChecked()
         }
 
@@ -2662,9 +2630,7 @@ def _ensure_test_name_column(self):
             df_test = pd.read_csv(self.csv_test_path)
         except Exception as e:
             QMessageBox.warning(
-                self,
-                "Test dataset error",
-                f"Could not read test CSV file:\n\n{e}"
+                self, "Test dataset error", f"Could not read test CSV file:\n\n{e}"
             )
             return False
 
@@ -2686,7 +2652,7 @@ def _ensure_test_name_column(self):
                 QMessageBox.warning(
                     self,
                     "Test dataset error",
-                    f"Could not update test CSV file:\n\n{e}"
+                    f"Could not update test CSV file:\n\n{e}",
                 )
                 return False
 
@@ -2703,7 +2669,7 @@ def _ensure_test_name_column(self):
             "Incompatible test dataset",
             f"The selected name column '{name_col}' is not present in the test dataset.\n\n"
             "The test CSV does not contain this column, nor a fallback 'code_name' column.\n\n"
-            "Please select a compatible test dataset or change the name column."
+            "Please select a compatible test dataset or change the name column.",
         )
 
         return False
@@ -2722,9 +2688,7 @@ def check_variables_robert(self):
         # ------------------------
         if is_predict and not self.file_path and not self.csv_test_path:
             QMessageBox.warning(
-                self,
-                "Invalid Selection",
-                "Predict requires at least one CSV file."
+                self, "Invalid Selection", "Predict requires at least one CSV file."
             )
             return False
 
@@ -2736,7 +2700,7 @@ def check_variables_robert(self):
                 QMessageBox.warning(
                     self,
                     "Invalid Selection",
-                    "The name column and the target value column cannot be the same. Please select different columns."
+                    "The name column and the target value column cannot be the same. Please select different columns.",
                 )
                 return False
 
@@ -2763,10 +2727,14 @@ def check_variables_robert(self):
         # AQME (skip in PREDICT)
         # ------------------------
         if self.aqme_workflow.isChecked() and not is_predict:
-
             total_columns = []
-            total_columns += [self.available_list.item(i).text() for i in range(self.available_list.count())]
-            total_columns += [self.ignore_list.item(i).text() for i in range(self.ignore_list.count())]
+            total_columns += [
+                self.available_list.item(i).text()
+                for i in range(self.available_list.count())
+            ]
+            total_columns += [
+                self.ignore_list.item(i).text() for i in range(self.ignore_list.count())
+            ]
             lowercase_columns = [col.lower() for col in total_columns]
 
             if not any(col.startswith("smiles") for col in lowercase_columns):
@@ -2853,7 +2821,7 @@ def check_variables_robert(self):
             return False
 
         return True
-    
+
     def build_aqme_command(self, selected_file_path, selected_atoms_override=None):
         """Builds the AQME command based on the GUI selections."""
 
@@ -2873,9 +2841,9 @@ def build_aqme_command(self, selected_file_path, selected_atoms_override=None):
         command = (
             f'"{python_pointer}" -u -m aqme --qdescp '
             f'--input "{csv_name}" '
-            f'--program xtb '
+            f"--program xtb "
             f'--csv_name "{csv_name}" '
-            f'--robert'
+            f"--robert"
         )
 
         # ------------------------
@@ -2946,8 +2914,7 @@ def run_aqme(self):
         # ---------------------------
         folders_to_check = ["AQME", "CSEARCH", "QDESCP", "AQME_RUNS"]
         existing_folders = [
-            f for f in folders_to_check
-            if os.path.exists(os.path.join(run_dir, f))
+            f for f in folders_to_check if os.path.exists(os.path.join(run_dir, f))
         ]
 
         if existing_folders:
@@ -2959,7 +2926,7 @@ def run_aqme(self):
                 "They may be reused or overwritten.\n\n"
                 "Do you want to continue?",
                 QMessageBox.Yes | QMessageBox.No,
-                QMessageBox.No
+                QMessageBox.No,
             )
 
             if confirmation == QMessageBox.No:
@@ -2979,36 +2946,32 @@ def run_aqme(self):
 
             smarts = self.tab_widget_aqme.smarts_targets[0]
 
-            self.mapped_train_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
-                self.file_path,
-                smarts,
-                selected_atoms_for_aqme
+            self.mapped_train_csv = (
+                self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+                    self.file_path, smarts, selected_atoms_for_aqme
+                )
             )
 
             self.is_aqme_mapped = True
 
             if self.csv_test_path:
-                self.mapped_test_csv = self.tab_widget_aqme.generate_mapped_csv_from_smiles(
-                    self.csv_test_path,
-                    smarts,
-                    selected_atoms_for_aqme
+                self.mapped_test_csv = (
+                    self.tab_widget_aqme.generate_mapped_csv_from_smiles(
+                        self.csv_test_path, smarts, selected_atoms_for_aqme
+                    )
                 )
 
             #  Save atomic mapping contract in .dat
             run_dir = os.path.dirname(self.file_path)
             self._write_atom_mapping_dat(
-                smarts=smarts,
-                selected_atoms=selected_atoms_for_aqme,
-                run_dir=run_dir
+                smarts=smarts, selected_atoms=selected_atoms_for_aqme, run_dir=run_dir
             )
 
         # ------------------------------------------------
         # Decide REAL input CSVs for AQME
         # ------------------------------------------------
         train_input_csv = (
-            self.mapped_train_csv
-            if self.is_aqme_mapped
-            else self.file_path
+            self.mapped_train_csv if self.is_aqme_mapped else self.file_path
         )
 
         test_input_csv = (
@@ -3021,26 +2984,20 @@ def run_aqme(self):
         # Build AQME command queue
         # -------------------------
         main_cmd, self.aqme_run_dir = self.build_aqme_command(
-            train_input_csv,
-            selected_atoms_override=selected_atoms_for_aqme
+            train_input_csv, selected_atoms_override=selected_atoms_for_aqme
         )
 
-        self.aqme_command_queue.append({
-            "command": main_cmd,
-            "csv": train_input_csv,
-            "role": "train"
-        })
+        self.aqme_command_queue.append(
+            {"command": main_cmd, "csv": train_input_csv, "role": "train"}
+        )
 
         if test_input_csv:
             test_cmd, _ = self.build_aqme_command(
-                test_input_csv,
-                selected_atoms_override=selected_atoms_for_aqme
+                test_input_csv, selected_atoms_override=selected_atoms_for_aqme
+            )
+            self.aqme_command_queue.append(
+                {"command": test_cmd, "csv": test_input_csv, "role": "test"}
             )
-            self.aqme_command_queue.append({
-                "command": test_cmd,
-                "csv": test_input_csv,
-                "role": "test"
-            })
 
         # ------------
         # Launch AQME
@@ -3052,7 +3009,7 @@ def run_aqme(self):
         self._run_next_aqme()
 
     def _run_next_aqme(self):
-        """ Runs the next AQME command in the queue."""
+        """Runs the next AQME command in the queue."""
         if not self.aqme_command_queue:
             return
 
@@ -3071,23 +3028,25 @@ def stop_process(self):
         """Stops the ROBERT and AQME process safely after user confirmation, non-blocking."""
 
         confirmation = QMessageBox.question(
-            self, 
-            "WARNING!", 
+            self,
+            "WARNING!",
             "Are you sure you want to stop the process?",
-            QMessageBox.Yes | QMessageBox.No, 
-            QMessageBox.No
+            QMessageBox.Yes | QMessageBox.No,
+            QMessageBox.No,
         )
 
         if confirmation == QMessageBox.No:
-            return  
+            return
 
         self.manual_stop = True
 
         if self.worker and self.worker.isRunning():
-            self.console_output.append("
Stopping ROBERT...") + self.console_output.append( + "
Stopping ROBERT..." + ) self.progress.setRange(0, 100) self.stop_button.setDisabled(True) - QTimer.singleShot(0, self.worker.stop) + QTimer.singleShot(0, self.worker.stop) def _on_aqme_step_finished(self, exit_code): """Handles the completion of an AQME step and manages the queue.""" @@ -3163,7 +3122,7 @@ def on_process_finished(self, exit_code): QMessageBox.information( self, "WARNING!", - f"{self.current_process} has been successfully stopped." + f"{self.current_process} has been successfully stopped.", ) self.manual_stop = False self._reset_ui_after_process() @@ -3175,17 +3134,14 @@ def on_process_finished(self, exit_code): # AQME COMPLETION LOGIC # ================================================== if self.current_process == "AQME": - # ================================================== # AQME SUCCESS # ================================================== if exit_code == 0 and "Time QDESCP:" in output_text: - # ============================================= # AQME TEST -> chain directly to ROBERT PREDICT # ============================================= if getattr(self, "aqme_role", None) == "test": - # Reset AQME state to avoid conflicts with future runs self.aqme_role = None @@ -3197,7 +3153,7 @@ def on_process_finished(self, exit_code): "AQME error", "AQME finished successfully, but the expected output CSV was not found.\n\n" "Descriptor generation for the test set completed, but the generated " - "CSV file could not be detected, so prediction cannot continue." + "CSV file could not be detected, so prediction cannot continue.", ) self.manual_stop = False self._reset_ui_after_process() @@ -3212,8 +3168,8 @@ def on_process_finished(self, exit_code): # -------------------------------------------------- # Launch ROBERT prediction (force AQME output as test CSV) # -------------------------------------------------- - - # Save original test CSV only + + # Save original test CSV only if not getattr(self, "_original_test_csv_path", None): self._original_test_csv_path = self.csv_test_path @@ -3230,22 +3186,17 @@ def on_process_finished(self, exit_code): self.worker.error_received.connect(self.console_output.append) self.worker.process_finished.connect(self.on_process_finished) self.worker.start() - return + return # ============================================= # AQME TRAIN -> original popup logic (unchanged) # ============================================= train_run = next( - (r for r in self.aqme_runs if r["role"] == "train"), - None + (r for r in self.aqme_runs if r["role"] == "train"), None ) if not train_run: - QMessageBox.warning( - self, - "WARNING!", - "No AQME train output found." - ) + QMessageBox.warning(self, "WARNING!", "No AQME train output found.") self.manual_stop = False self._reset_ui_after_process() return @@ -3255,8 +3206,12 @@ def on_process_finished(self, exit_code): base_name = os.path.splitext(os.path.basename(aqme_base))[0] aqme_csvs = { - "denovo": os.path.join(base_dir, f"AQME-ROBERT_denovo_{base_name}.csv"), - "interpret": os.path.join(base_dir, f"AQME-ROBERT_interpret_{base_name}.csv"), + "denovo": os.path.join( + base_dir, f"AQME-ROBERT_denovo_{base_name}.csv" + ), + "interpret": os.path.join( + base_dir, f"AQME-ROBERT_interpret_{base_name}.csv" + ), "full": os.path.join(base_dir, f"AQME-ROBERT_full_{base_name}.csv"), } @@ -3316,17 +3271,10 @@ def on_process_finished(self, exit_code): ) btn_interpret = msg.addButton( - "Interpret descriptors (recommended)", - QMessageBox.ActionRole - ) - btn_denovo = msg.addButton( - "DeNovo descriptors", - QMessageBox.ActionRole - ) - btn_full = msg.addButton( - "Full descriptors", - QMessageBox.ActionRole + "Interpret descriptors (recommended)", QMessageBox.ActionRole ) + btn_denovo = msg.addButton("DeNovo descriptors", QMessageBox.ActionRole) + btn_full = msg.addButton("Full descriptors", QMessageBox.ActionRole) msg.addButton("Cancel", QMessageBox.RejectRole) msg.exec() @@ -3356,8 +3304,7 @@ def on_process_finished(self, exit_code): base_name = os.path.splitext(os.path.basename(run["csv"]))[0] output_csv = os.path.join( - base_dir, - f"AQME-ROBERT_{selected_level}_{base_name}.csv" + base_dir, f"AQME-ROBERT_{selected_level}_{base_name}.csv" ) if not os.path.isfile(output_csv): @@ -3386,7 +3333,7 @@ def on_process_finished(self, exit_code): # -------------------------------------------------- for level in ["denovo", "interpret", "full"]: if not user_cancelled and level == selected_level: - continue # keep active CSV where it is + continue # keep active CSV where it is csv_name = f"AQME-ROBERT_{level}_{base_name}.csv" csv_path = Path(self.aqme_run_dir) / csv_name @@ -3399,7 +3346,7 @@ def on_process_finished(self, exit_code): target_path.unlink() shutil.move(str(csv_path), str(target_path)) - + # -------------------------------------------------- # 2) Move mapped CSVs (PER RUN, TRAIN + TEST) # -------------------------------------------------- @@ -3430,7 +3377,7 @@ def on_process_finished(self, exit_code): QMessageBox.warning( self, "WARNING!", - "AQME encountered an issue while finishing. Please check the logs." + "AQME encountered an issue while finishing. Please check the logs.", ) # End of AQME workflow self.manual_stop = False @@ -3454,8 +3401,13 @@ def on_process_finished(self, exit_code): # ------------------------ # Full workflow / REPORT # ------------------------ - if not self.manual_stop and (workflow == "Full Workflow" or workflow == "REPORT"): - if exit_code == 0 and "ROBERT_report.pdf was created successfully" in output_text: + if not self.manual_stop and ( + workflow == "Full Workflow" or workflow == "REPORT" + ): + if ( + exit_code == 0 + and "ROBERT_report.pdf was created successfully" in output_text + ): msg_box = QMessageBox(self) msg_box.setIcon(QMessageBox.Information) msg_box.setWindowTitle("Success!") @@ -3476,7 +3428,7 @@ def on_process_finished(self, exit_code): QMessageBox.warning( self, "WARNING!", - "ROBERT encountered an issue while finishing. Please check the logs." + "ROBERT encountered an issue while finishing. Please check the logs.", ) # ------------------------ @@ -3485,41 +3437,57 @@ def on_process_finished(self, exit_code): elif workflow == "CURATE": if exit_code == 0 and "Time CURATE:" in output_text: QMessageBox.information( - self, "Success", "ROBERT has successfully completed the CURATE step." + self, + "Success", + "ROBERT has successfully completed the CURATE step.", ) else: QMessageBox.warning( - self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs." + self, + "WARNING!", + "ROBERT encountered an issue while finishing. Please check the logs.", ) elif workflow == "GENERATE": if exit_code == 0 and "Time GENERATE:" in output_text: QMessageBox.information( - self, "Success", "ROBERT has successfully completed the GENERATE step." + self, + "Success", + "ROBERT has successfully completed the GENERATE step.", ) else: QMessageBox.warning( - self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs." + self, + "WARNING!", + "ROBERT encountered an issue while finishing. Please check the logs.", ) elif workflow == "PREDICT": if exit_code == 0 and "Time PREDICT:" in output_text: QMessageBox.information( - self, "Success", "ROBERT has successfully completed the PREDICT step." + self, + "Success", + "ROBERT has successfully completed the PREDICT step.", ) else: QMessageBox.warning( - self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs." + self, + "WARNING!", + "ROBERT encountered an issue while finishing. Please check the logs.", ) elif workflow == "VERIFY": if exit_code == 0 and "Time VERIFY:" in output_text: QMessageBox.information( - self, "Success", "ROBERT has successfully completed the VERIFY step." + self, + "Success", + "ROBERT has successfully completed the VERIFY step.", ) else: QMessageBox.warning( - self, "WARNING!", "ROBERT encountered an issue while finishing. Please check the logs." + self, + "WARNING!", + "ROBERT encountered an issue while finishing. Please check the logs.", ) # Restore previous test CSV if overridden for test workflow aqme generation diff --git a/robert/gui_easyrob/tabs/advanced_options.py b/robert/gui_easyrob/tabs/advanced_options.py index c95b60b..5e8042f 100644 --- a/robert/gui_easyrob/tabs/advanced_options.py +++ b/robert/gui_easyrob/tabs/advanced_options.py @@ -28,7 +28,6 @@ # Attempt local imports first (portable mode). If they fail, # fall back to installed package imports. try: - from utils.utils_gui import ( AssetLibrary, QCheckBox, @@ -47,8 +46,7 @@ Qt, ) -except ImportError as e: - +except ImportError: from robert.gui_easyrob.utils.utils_gui import ( AssetLibrary, QCheckBox, @@ -67,14 +65,16 @@ Qt, ) + class AdvancedOptionsTab(QWidget): """Tab for advanced options in the easyROB application.""" + def __init__(self, type_dropdown, tab_widget): super().__init__() self.type = type_dropdown self.tab_widget = tab_widget # Reference to the main QTabWidget main_layout = QVBoxLayout(self) - grid_layout = QGridLayout() + grid_layout = QGridLayout() self.box_features = "QGroupBox { font-weight: bold; }" # Create section boxes @@ -93,14 +93,13 @@ def __init__(self, type_dropdown, tab_widget): # PREDICT (Bottom Row, Full Width) grid_layout.addWidget(predict_box, 2, 0, 1, 2) - # Add the grid layout to the main layout main_layout.addLayout(grid_layout) self.setLayout(main_layout) def go_to_help_section(self, anchor): """Open a documentation section in the browser.""" - + base_url = "https://robert.readthedocs.io/en/latest/Technical/defaults.html" if anchor.upper() == "GENERAL": @@ -139,7 +138,7 @@ def create_general_section(self): self.seed = QLineEdit() self.seed.setPlaceholderText("0") layout.addRow(QLabel("seed:"), self.seed) - + self.kfold = QLineEdit() self.kfold.setPlaceholderText("5") layout.addRow(QLabel("kfold:"), self.kfold) @@ -149,7 +148,7 @@ def create_general_section(self): layout.addRow(QLabel("repeat_kfolds:"), self.repeat_kfolds) self.split = QComboBox() - self.split.addItems([ "even", "RND", "stratified", "KN", "extra_q1", "extra_q5" ]) + self.split.addItems(["even", "RND", "stratified", "KN", "extra_q1", "extra_q5"]) layout.addRow(QLabel("split:"), self.split) # --- Help button at the bottom --- @@ -165,7 +164,7 @@ def create_general_section(self): def create_curate_section(self): """Creates the CURATE section with a box and input fields.""" box = QGroupBox("CURATE") - box.setStyleSheet(self.box_features) + box.setStyleSheet(self.box_features) layout = QFormLayout() # Add new input fields for additional options @@ -206,7 +205,7 @@ def create_curate_section(self): def create_generate_section(self): """Creates the GENERATE section with a box and input fields.""" box = QGroupBox("GENERATE") - box.setStyleSheet(self.box_features) + box.setStyleSheet(self.box_features) layout = QFormLayout() self.model_group = QGroupBox("Models") @@ -220,9 +219,19 @@ def update_model_options(): # Determine which models should be checked by default if self.type.currentText() == "Regression": - default_checked_models = ["RF", "GB", "NN", "MVL"] # Regression defaults + default_checked_models = [ + "RF", + "GB", + "NN", + "MVL", + ] # Regression defaults else: - default_checked_models = ["RF", "GB", "NN", "AdaB"] # Classification defaults + default_checked_models = [ + "RF", + "GB", + "NN", + "AdaB", + ] # Classification defaults # Update check states instead of recreating widgets for model, checkbox in self.modellist.items(): @@ -248,14 +257,14 @@ def update_model_options(): # Error type selection that changes dynamically but is also user-selectable self.error_type = QComboBox() layout.addRow(QLabel("error_type:"), self.error_type) - + def update_error_type(): self.error_type.clear() if self.type.currentText() == "Regression": self.error_type.addItems(["rmse", "mae", "r2"]) else: self.error_type.addItems(["mcc", "f1", "acc"]) - + self.type.currentIndexChanged.connect(update_error_type) update_error_type() # Initialize with the correct default values @@ -308,17 +317,17 @@ def update_error_type(): def create_predict_section(self): """Creates the PREDICT section with a box and input fields.""" box = QGroupBox("PREDICT") - box.setStyleSheet(self.box_features) + box.setStyleSheet(self.box_features) layout = QFormLayout() - + self.t_value = QLineEdit() self.t_value.setPlaceholderText("2") layout.addRow(QLabel("t_value:"), self.t_value) - + self.shap_show = QLineEdit() self.shap_show.setPlaceholderText("10") layout.addRow(QLabel("shap_show:"), self.shap_show) - + self.pfi_show = QLineEdit() self.pfi_show.setPlaceholderText("10") layout.addRow(QLabel("pfi_show:"), self.pfi_show) @@ -331,4 +340,4 @@ def create_predict_section(self): layout.setAlignment(help_button, Qt.AlignRight) box.setLayout(layout) - return box \ No newline at end of file + return box diff --git a/robert/gui_easyrob/tabs/aqme.py b/robert/gui_easyrob/tabs/aqme.py index 14b4d52..526fc28 100644 --- a/robert/gui_easyrob/tabs/aqme.py +++ b/robert/gui_easyrob/tabs/aqme.py @@ -68,7 +68,7 @@ from utils.aqme_utils import ChemDrawFileDialog, MCSProcessWorker -except ImportError as e: +except ImportError: from robert.gui_easyrob.utils.utils_gui import ( AssetLibrary, BytesIO, @@ -114,21 +114,24 @@ import csv from functools import partial + class AQMETab(QWidget): """Tab responsible for AQME-oriented chemistry preparation workflows.""" - def __init__(self, tab_parent=None, main_window=None): + def __init__(self, tab_parent=None, main_window=None): super().__init__(tab_parent) # tab_parent = QTabWidget - self.main_tab_widget = tab_parent # Reference to the main QTabWidget - self.main_window = main_window # Reference to the main window, accessible to csv_df, csv_path, etc... + self.main_tab_widget = tab_parent # Reference to the main QTabWidget + self.main_window = main_window # Reference to the main window, accessible to csv_df, csv_path, etc... self.selected_atoms = [] self.box_features = "QGroupBox { font-weight: bold; }" # === Main vertical layout === main_layout = QVBoxLayout(self) - # --- ChemDraw Button (modern purple style + top spacing) --- - self.chemdraw_button = QPushButton("Generate CSV from ChemDraw Files or SDF file") + # --- ChemDraw Button (modern purple style + top spacing) --- + self.chemdraw_button = QPushButton( + "Generate CSV from ChemDraw Files or SDF file" + ) self.chemdraw_button.setCursor(Qt.PointingHandCursor) self.chemdraw_button.setFixedSize(400, 42) @@ -164,7 +167,6 @@ def __init__(self, tab_parent=None, main_window=None): main_layout.addLayout(button_container) - # === Viewer container with label + viewer stacked === self.mol_viewer_container = QWidget() self.mol_viewer_container.setFixedSize(400, 400) @@ -182,9 +184,11 @@ def __init__(self, tab_parent=None, main_window=None): # Allow text selection self.mol_viewer.setTextInteractionFlags( - Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard + Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard + ) + self.set_mol_viewer_message( + "📄 Select a CSV with a SMILES column to display a common SMARTS pattern." ) - self.set_mol_viewer_message("📄 Select a CSV with a SMILES column to display a common SMARTS pattern.") self.mol_viewer.setFixedSize(400, 400) # === mol_info_label === @@ -200,15 +204,22 @@ def __init__(self, tab_parent=None, main_window=None): border: 1px solid #aaa; """) - self.mol_info_label.setWordWrap(True) + self.mol_info_label.setWordWrap(True) self.mol_info_label.setTextInteractionFlags(Qt.TextSelectableByMouse) self.mol_info_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Maximum) - self.mol_info_label.setMaximumWidth(600) - self.mol_info_label.setAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + self.mol_info_label.setMaximumWidth(600) + self.mol_info_label.setAlignment( + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop + ) # === Set up the molecule viewer === mol_layout.addWidget(self.mol_viewer, 0, 0) - mol_layout.addWidget(self.mol_info_label, 0, 0, alignment=Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft) + mol_layout.addWidget( + self.mol_info_label, + 0, + 0, + alignment=Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft, + ) mol_wrapper_layout = QHBoxLayout() mol_wrapper_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) mol_wrapper_layout.addWidget(self.mol_viewer_container) @@ -216,7 +227,7 @@ def __init__(self, tab_parent=None, main_window=None): # === AQME Box at the bottom === aqme_box = QGroupBox("AQME") - aqme_box.setMaximumHeight(200) + aqme_box.setMaximumHeight(200) aqme_box.setStyleSheet(self.box_features) aqme_layout = QFormLayout() @@ -224,33 +235,35 @@ def __init__(self, tab_parent=None, main_window=None): self.descriptor_level = QComboBox() self.descriptor_level.addItems(["interpret", "denovo", "full"]) self.solvent = QComboBox() - self.solvent.addItems([ - "None", - # "Acetone", - # "Acetonitrile", - # "Aniline", - # "Benzaldehyde", - # "Benzene", - # "CH2Cl2", - # "CHCl3", - # "CS2", - # "Dioxane", - # "DMF", - # "DMSO", - # "Ether", - # "Ethylacetate", - # "Furane", - # "Hexadecane", - # "Hexane", - # "Methanol", - # "Nitromethane", - # "Octanol", - # "Octanol (wet)", - # "Phenol", - # "Toluene", - # "THF", - # "Water" - ]) + self.solvent.addItems( + [ + "None", + # "Acetone", + # "Acetonitrile", + # "Aniline", + # "Benzaldehyde", + # "Benzene", + # "CH2Cl2", + # "CHCl3", + # "CS2", + # "Dioxane", + # "DMF", + # "DMSO", + # "Ether", + # "Ethylacetate", + # "Furane", + # "Hexadecane", + # "Hexane", + # "Methanol", + # "Nitromethane", + # "Octanol", + # "Octanol (wet)", + # "Phenol", + # "Toluene", + # "THF", + # "Water" + ] + ) aqme_layout.addRow(QLabel("QDESCP Atoms:"), self.atoms) aqme_layout.addRow(QLabel("Descriptor Level:"), self.descriptor_level) @@ -302,15 +315,19 @@ def detect_patterns_and_display(self): """Detects patterns in the loaded CSV and displays the first molecule.""" try: - self.csv_df = smart_read_csv(self.file_path) # Store the DataFrame for later use - self.smiles_column = next((col for col in self.csv_df.columns if col.lower() == "smiles"), None) + self.csv_df = smart_read_csv( + self.file_path + ) # Store the DataFrame for later use + self.smiles_column = next( + (col for col in self.csv_df.columns if col.lower() == "smiles"), None + ) self.set_mol_viewer_message("🔬 Detecting common SMARTS pattern...") # === Auto SMARTS detection === self.auto_pattern() - except Exception as e: + except Exception: self.set_mol_viewer_message("❌ Failed to load or process the CSV.") self.mol_info_label.setText("🔬 Info here") @@ -322,17 +339,14 @@ def _on_mcs_success(self, smarts): def _on_mcs_error(self, message): """Handle MCS detection error.""" - self.set_mol_viewer_message( - message, - tooltip="SMARTS pattern detection failed." - ) + self.set_mol_viewer_message(message, tooltip="SMARTS pattern detection failed.") self.mol_info_label.setText("🔬 Info here") def _on_mcs_timeout(self): """Handle MCS detection timeout.""" self.set_mol_viewer_message( "⏱️ Timeout: MCS (Maximum Common Substructure) took too long and was aborted.", - tooltip="SMARTS pattern detection failed." + tooltip="SMARTS pattern detection failed.", ) self.mol_info_label.setText("🔬 Info here") @@ -342,52 +356,33 @@ def build_unified_smiles_context(self, train_csv_path, test_csv_path=None): (FMCS, ambiguity checks, metal detection). """ train_df = smart_read_csv(train_csv_path) - smiles_col = next( - (c for c in train_df.columns if c.lower() == "smiles"), - None - ) + smiles_col = next((c for c in train_df.columns if c.lower() == "smiles"), None) if smiles_col is None: raise ValueError("TRAIN CSV has no SMILES column") - unified_smiles = ( - train_df[smiles_col] - .dropna() - .astype(str) - .tolist() - ) + unified_smiles = train_df[smiles_col].dropna().astype(str).tolist() if test_csv_path: test_df = smart_read_csv(test_csv_path) test_smiles_col = next( - (c for c in test_df.columns if c.lower() == "smiles"), - None + (c for c in test_df.columns if c.lower() == "smiles"), None ) if test_smiles_col is None: raise ValueError("TEST CSV has no SMILES column") unified_smiles.extend( - test_df[test_smiles_col] - .dropna() - .astype(str) - .tolist() + test_df[test_smiles_col].dropna().astype(str).tolist() ) return unified_smiles - + def generate_mapped_csv_from_smiles( - self, - csv_path, - smarts, - selected_atoms, - suffix="_mapped" + self, csv_path, smarts, selected_atoms, suffix="_mapped" ): """Generate a new CSV file with mapped SMILES based on the provided SMARTS pattern""" df = smart_read_csv(csv_path) - smiles_col = next( - (c for c in df.columns if c.lower() == "smiles"), - None - ) + smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None) if smiles_col is None: raise ValueError("CSV has no SMILES column") @@ -447,12 +442,7 @@ def auto_pattern(self): smiles_list = unified_smiles else: # TRAIN only → exploratory - smiles_list = ( - self.csv_df[self.smiles_column] - .dropna() - .astype(str) - .tolist() - ) + smiles_list = self.csv_df[self.smiles_column].dropna().astype(str).tolist() if not smiles_list: self.set_mol_viewer_message( @@ -463,10 +453,7 @@ def auto_pattern(self): # ------------------------------- # Launch MCS worker # ------------------------------- - self.mcs_worker = MCSProcessWorker( - smiles_list, - timeout_ms=60000 - ) + self.mcs_worker = MCSProcessWorker(smiles_list, timeout_ms=60000) self.mcs_worker.finished.connect(self._on_mcs_success) self.mcs_worker.error.connect(self._on_mcs_error) @@ -476,16 +463,57 @@ def auto_pattern(self): def display_molecule(self): """Display a SMARTS molecule and highlight atoms based on user selection.""" - rdkit.rdBase.DisableLog('rdApp.*') + rdkit.rdBase.DisableLog("rdApp.*") rdDepictor.SetPreferCoordGen(True) self.metal_atomic_numbers = { - 3, 11, 19, 37, 55, 87, - 4, 12, 20, 38, 56, 88, - 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - 72, 73, 74, 75, 76, 77, 78, 79, 80, - 13, 49, 50, 81, 82, 83 + 3, + 11, + 19, + 37, + 55, + 87, + 4, + 12, + 20, + 38, + 56, + 88, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 13, + 49, + 50, + 81, + 82, + 83, } try: @@ -556,18 +584,18 @@ def display_molecule(self): highlight_colors = ( {idx: (0.698, 0.4, 1.0) for idx in highlight_atoms} - if highlight_atoms else {} + if highlight_atoms + else {} ) drawer = rdMolDraw2D.MolDraw2DCairo( - self.molecule_image_width, - self.molecule_image_height + self.molecule_image_width, self.molecule_image_height ) drawer.drawOptions().bondLineWidth = 1.5 drawer.DrawMolecule( self.mol, highlightAtoms=list(highlight_atoms), - highlightAtomColors=highlight_colors + highlightAtomColors=highlight_colors, ) drawer.FinishDrawing() @@ -575,8 +603,7 @@ def display_molecule(self): pixmap = QPixmap() pixmap.loadFromData(png_bytes) self.atom_coords = [ - drawer.GetDrawCoords(i) - for i in range(self.mol.GetNumAtoms()) + drawer.GetDrawCoords(i) for i in range(self.mol.GetNumAtoms()) ] if self.mol_viewer: @@ -588,16 +615,16 @@ def display_molecule(self): if self.metal_found and self.multiple_matches_detected: self.mol_info_label.setText( - '🧪 SMARTS pattern loaded. Metal atom(s) automatically selected.
' + "🧪 SMARTS pattern loaded. Metal atom(s) automatically selected.
" '⚠️ Multiple matches were found. ' - 'Atomic descriptors will be generated for the detected metal atom(s). ' - 'Manual atom selection has been disabled to avoid ambiguity.' + "Atomic descriptors will be generated for the detected metal atom(s). " + "Manual atom selection has been disabled to avoid ambiguity." ) elif self.metal_found and not self.selected_atoms: self.mol_info_label.setText( - '🧪 SMARTS pattern loaded. Click to select atoms.
' + "🧪 SMARTS pattern loaded. Click to select atoms.
" '⚠️ No atoms selected. ' - 'Descriptors will only be generated for the detected metal.' + "Descriptors will only be generated for the detected metal." ) else: if highlight_atoms: @@ -606,25 +633,24 @@ def display_molecule(self): ) else: self.mol_info_label.setText( - '🧪 SMARTS pattern loaded. Click to select atoms.
' + "🧪 SMARTS pattern loaded. Click to select atoms.
" '⚠️ WARNING! No atoms selected. ' - 'Atomic descriptors will not be generated.' + "Atomic descriptors will not be generated." ) except Exception as e: - self.set_mol_viewer_message( - "❌ Error displaying molecule.", - tooltip=str(e) - ) + self.set_mol_viewer_message("❌ Error displaying molecule.", tooltip=str(e)) self.mol_info_label.setText("🔬 Info here") def handle_atom_selection(self, atom_idx): """Handle the selection of an atom in the pattern.""" - if not hasattr(self, 'selected_atoms'): + if not hasattr(self, "selected_atoms"): self.selected_atoms = [] - - if getattr(self, 'metal_found', False) and getattr(self, 'multiple_matches_detected', False): + + if getattr(self, "metal_found", False) and getattr( + self, "multiple_matches_detected", False + ): # Prevent manual selection when metal match has been auto-selected due to ambiguity return @@ -641,11 +667,12 @@ def handle_atom_selection(self, atom_idx): self.generate_mapped_smiles( self.smarts_targets[0], self.selected_atoms, - self.csv_df[self.smiles_column].dropna() + self.csv_df[self.smiles_column].dropna(), ) - - def generate_mapped_smiles(self, smarts_pattern, selected_pattern_indices, smiles_list): + def generate_mapped_smiles( + self, smarts_pattern, selected_pattern_indices, smiles_list + ): """ Generate mapped SMILES using a SMARTS pattern and selected atom indices. Updates self.df_mapped_smiles with a copy of the original CSV where 'SMILES' is replaced. @@ -693,42 +720,51 @@ def generate_mapped_smiles(self, smarts_pattern, selected_pattern_indices, smile df_mapped[self.smiles_column] = mapped_smiles self.df_mapped_smiles = df_mapped - def mousePressEvent(self, event: QMouseEvent): """Handle mouse press events to select atoms and crate pattern. The logic is to check if the mouse press event is within the molecule_viewer area.""" if event.button() == Qt.MouseButton.LeftButton: pos = event.position() - if self.mol_viewer_container and self.mol_viewer_container.geometry().contains(pos.toPoint()): + if ( + self.mol_viewer_container + and self.mol_viewer_container.geometry().contains(pos.toPoint()) + ): relative_pos = self.mol_viewer_container.mapFrom(self, pos.toPoint()) x = relative_pos.x() y = relative_pos.y() selected_atom = self.get_atom_at_position(x, y) if selected_atom is not None: self.handle_atom_selection(selected_atom) - self.display_molecule() + self.display_molecule() def get_atom_at_position(self, x, y): - """Get the atom index at the given position by - checking the distance from the atom coordinates. + """Get the atom index at the given position by + checking the distance from the atom coordinates. The atom coordinates are found using RDKit. The logic is to check if the distance between the mouse click and the atom coordinates is less than a threshold.""" - if not hasattr(self, 'atom_coords'): + if not hasattr(self, "atom_coords"): return None elif self.atom_coords is not None: for idx, coord in enumerate(self.atom_coords): - if len(self.smarts_targets[0]) <= 30: # small molecule = bigger click area - if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 300: - return idx - if len(self.smarts_targets[0]) <= 50 and len(self.smarts_targets[0]) > 30: # medium molecule = medium click area - if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 200: - return idx - elif len(self.smarts_targets[0]) > 50 : # big molecule = smaller click area - if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 100: - return idx + if ( + len(self.smarts_targets[0]) <= 30 + ): # small molecule = bigger click area + if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 300: + return idx + if ( + len(self.smarts_targets[0]) <= 50 + and len(self.smarts_targets[0]) > 30 + ): # medium molecule = medium click area + if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 200: + return idx + elif ( + len(self.smarts_targets[0]) > 50 + ): # big molecule = smaller click area + if (coord.x - x) ** 2 + (coord.y - y) ** 2 < 100: + return idx return None def open_chemdraw_popup(self): @@ -744,7 +780,7 @@ def open_chemdraw_popup(self): "• Incorrect or broken bonds
" "• Unconnected fragments or misdrawn connections

" "When everything looks correct, click OK to select your file." - ) + ), ) dialog = ChemDrawFileDialog(self) @@ -754,9 +790,10 @@ def open_chemdraw_popup(self): def load_chemdraw_file(self, main_path): """Opens a ChemDraw file and displays the molecules in a table.""" + def load_mols_from_path(path): """Load molecules from a ChemDraw or SDF file.""" - if path.endswith('.cdxml'): + if path.endswith(".cdxml"): try: mols = MolsFromCDXMLFile(path, sanitize=False, removeHs=False) total_count = len(mols) @@ -764,13 +801,19 @@ def load_mols_from_path(path): for mol in mols: if mol is not None: - fragments = GetMolFrags(mol, asMols=True, sanitizeFrags=False) + fragments = GetMolFrags( + mol, asMols=True, sanitizeFrags=False + ) valid_mols.extend(fragments) valid_count = len(valid_mols) if valid_count == 0: - QMessageBox.warning(self, "CDXML Warning", f"No valid molecules found in the file:\n{path}") + QMessageBox.warning( + self, + "CDXML Warning", + f"No valid molecules found in the file:\n{path}", + ) return [] elif valid_count < total_count: @@ -778,18 +821,20 @@ def load_mols_from_path(path): QMessageBox.warning( self, "CDXML Partial Load", - f"File loaded with partial success.\n{failed_count} out of {total_count} molecules failed sanitization and were skipped." + f"File loaded with partial success.\n{failed_count} out of {total_count} molecules failed sanitization and were skipped.", ) return valid_mols except Exception as e: - QMessageBox.critical(self, "CDXML Read Error", f"Failed to read {path}:\n{str(e)}") + QMessageBox.critical( + self, "CDXML Read Error", f"Failed to read {path}:\n{str(e)}" + ) return [] - elif path.endswith('.sdf'): + elif path.endswith(".sdf"): return [mol for mol in Chem.SDMolSupplier(path) if mol is not None] - + elif path.endswith(".cdx"): QMessageBox.warning( self, @@ -807,7 +852,7 @@ def load_mols_from_path(path): "4. Paste it into a new ChemDraw document.
" "5. Save it as CDXML.

" "This ensures proper structure recognition and full compatibility with easyROB." - ) + ), ) return None @@ -820,7 +865,7 @@ def load_mols_from_path(path): # If the function returned None, it means we already handled a special case (like .cdx) if mols_main is None: return - + # If the function returned an empty list, it means there were no valid molecules if not mols_main: QMessageBox.warning(self, "Error", "No valid molecules found in the file.") @@ -846,7 +891,13 @@ def show_molecule_table_dialog(self, mols): # --- Table Columns --- base_headers = ["Image", "SMILES", "code_name", "target"] extra_columns = ["charge", "mult", "complex_type", "sample", "geom"] - complex_type_options = ["", "squareplanar", "squarepyramidal", "linear", "trigonalplanar"] + complex_type_options = [ + "", + "squareplanar", + "squarepyramidal", + "linear", + "trigonalplanar", + ] # Table widget setup table = QTableWidget(len(mols), len(base_headers)) @@ -867,8 +918,10 @@ def on_header_double_clicked(index): return # Only allow renaming for the 'target' column current_text = table.horizontalHeaderItem(index).text() new_text, ok = QInputDialog.getText( - dialog, "Edit Column Name", - f"Rename column '{current_text}':", text=current_text + dialog, + "Edit Column Name", + f"Rename column '{current_text}':", + text=current_text, ) if ok and new_text.strip(): table.setHorizontalHeaderItem(index, QTableWidgetItem(new_text.strip())) @@ -884,7 +937,11 @@ def on_header_double_clicked(index): img.save(buffer, format="PNG") qimg = QImage.fromData(buffer.getvalue()) label = QLabel() - label.setPixmap(QPixmap.fromImage(qimg).scaled(100, 100, Qt.KeepAspectRatio, Qt.SmoothTransformation)) + label.setPixmap( + QPixmap.fromImage(qimg).scaled( + 100, 100, Qt.KeepAspectRatio, Qt.SmoothTransformation + ) + ) widget = QWidget() hbox = QHBoxLayout() @@ -916,11 +973,14 @@ def toggle_column(col_name, state): Add or remove an extra column based on the corresponding checkbox. Handles special widget for 'complex_type' column. """ + def set_all_column_widths(width): for col in range(table.columnCount()): table.setColumnWidth(col, width) - current_headers = [table.horizontalHeaderItem(i).text() for i in range(table.columnCount())] + current_headers = [ + table.horizontalHeaderItem(i).text() for i in range(table.columnCount()) + ] if state: # Checkbox checked: add column if not present if col_name not in current_headers: idx = table.columnCount() @@ -961,7 +1021,9 @@ def save_to_csv(): Collect all table data and save to a CSV file. Includes validation for required fields, uniqueness, types, and empty checks. """ - headers = [table.horizontalHeaderItem(i).text() for i in range(table.columnCount())] + headers = [ + table.horizontalHeaderItem(i).text() for i in range(table.columnCount()) + ] # --- Mandatory column presence check --- try: @@ -977,13 +1039,21 @@ def save_to_csv(): # Check 'SMILES' not empty item = table.item(row, smiles_idx) if not item or not item.text().strip(): - QMessageBox.warning(dialog, "WARNING!", f"Please fill in all 'SMILES' fields before saving.") + QMessageBox.warning( + dialog, + "WARNING!", + "Please fill in all 'SMILES' fields before saving.", + ) return # Check 'code_name' not empty item = table.item(row, code_name_idx) if not item or not item.text().strip(): - QMessageBox.warning(dialog, "WARNING!", f"Please fill in all 'code_name' fields before saving.") + QMessageBox.warning( + dialog, + "WARNING!", + "Please fill in all 'code_name' fields before saving.", + ) return code_names.append(table.item(row, code_name_idx).text().strip()) @@ -994,10 +1064,14 @@ def save_to_csv(): item = table.item(row, charge_idx) val = item.text().strip() if item else "" if val == "": - QMessageBox.warning(dialog, "WARNING!", f"Column 'charge' cannot be empty.") + QMessageBox.warning( + dialog, "WARNING!", "Column 'charge' cannot be empty." + ) return - if not (val.lstrip('-').isdigit() and '.' not in val): - QMessageBox.warning(dialog, "WARNING!", f"Column 'charge' must be an integer.") + if not (val.lstrip("-").isdigit() and "." not in val): + QMessageBox.warning( + dialog, "WARNING!", "Column 'charge' must be an integer." + ) return # Validate 'mult' column if present (must be int, not empty) @@ -1006,10 +1080,14 @@ def save_to_csv(): item = table.item(row, mult_idx) val = item.text().strip() if item else "" if val == "": - QMessageBox.warning(dialog, "WARNING!", f"Column 'mult' cannot be empty.") + QMessageBox.warning( + dialog, "WARNING!", "Column 'mult' cannot be empty." + ) return - if not (val.lstrip('-').isdigit() and '.' not in val): - QMessageBox.warning(dialog, "WARNING!", f"Column 'mult' must be an integer.") + if not (val.lstrip("-").isdigit() and "." not in val): + QMessageBox.warning( + dialog, "WARNING!", "Column 'mult' must be an integer." + ) return # Validate 'complex_type' if present (must be selected) @@ -1018,11 +1096,12 @@ def save_to_csv(): combo = table.cellWidget(row, complex_type_idx) if combo is not None and combo.currentText().strip() == "": QMessageBox.warning( - dialog, "WARNING!", - f"Column 'complex_type' cannot be empty. Please select a value." + dialog, + "WARNING!", + "Column 'complex_type' cannot be empty. Please select a value.", ) return - + # Validate 'sample' column if present (must be int, not empty) if "sample" in headers: sample_idx = headers.index("sample") @@ -1030,10 +1109,16 @@ def save_to_csv(): item = table.item(row, sample_idx) val = item.text().strip() if item else "" if val == "": - QMessageBox.warning(dialog, "WARNING!", f"Column 'sample' cannot be empty.") + QMessageBox.warning( + dialog, "WARNING!", "Column 'sample' cannot be empty." + ) return - if not (val.lstrip('-').isdigit() and '.' not in val): - QMessageBox.warning(dialog, "WARNING!", f"Column 'sample' must be an integer.") + if not (val.lstrip("-").isdigit() and "." not in val): + QMessageBox.warning( + dialog, + "WARNING!", + "Column 'sample' must be an integer.", + ) return # Validate 'GEOM' column if present (must not be empty) @@ -1043,16 +1128,20 @@ def save_to_csv(): item = table.item(row, geom_idx) val = item.text().strip() if item else "" if val == "": - QMessageBox.warning(dialog, "WARNING!", f"Column 'geom' cannot be empty.") + QMessageBox.warning( + dialog, "WARNING!", "Column 'geom' cannot be empty." + ) return - # --- Uniqueness check for 'code_name' --- - duplicates = [name for name in set(code_names) if code_names.count(name) > 1] + duplicates = [ + name for name in set(code_names) if code_names.count(name) > 1 + ] if duplicates: QMessageBox.warning( - dialog, "WARNING!", - f"The following 'code_name' values are duplicated:\n\n{', '.join(duplicates)}\n\nPlease make them unique before saving." + dialog, + "WARNING!", + f"The following 'code_name' values are duplicated:\n\n{', '.join(duplicates)}\n\nPlease make them unique before saving.", ) return @@ -1061,16 +1150,20 @@ def save_to_csv(): item = table.item(row, self.target_col_index) val = item.text().strip() if item else "" if not val: - QMessageBox.warning(dialog, "WARNING!", f"Target column is empty.") + QMessageBox.warning(dialog, "WARNING!", "Target column is empty.") return try: float(val) except ValueError: - QMessageBox.warning(dialog, "WARNING!", f"Target column must be numeric.") + QMessageBox.warning( + dialog, "WARNING!", "Target column must be numeric." + ) return # --- File dialog to select save path --- - path, _ = QFileDialog.getSaveFileName(dialog, "Save CSV", "", "CSV Files (*.csv)") + path, _ = QFileDialog.getSaveFileName( + dialog, "Save CSV", "", "CSV Files (*.csv)" + ) if not path: return @@ -1097,9 +1190,11 @@ def save_to_csv(): if hasattr(self, "main_window") and self.main_window: self.main_window.set_file_path(path) dialog.accept() - QMessageBox.information(dialog, "Success", "CSV file saved and loaded successfully!") + QMessageBox.information( + dialog, "Success", "CSV file saved and loaded successfully!" + ) save_button.clicked.connect(save_to_csv) dialog.setLayout(layout) - dialog.exec() \ No newline at end of file + dialog.exec() diff --git a/robert/gui_easyrob/tabs/images.py b/robert/gui_easyrob/tabs/images.py index c1ae2c8..bc2bac3 100644 --- a/robert/gui_easyrob/tabs/images.py +++ b/robert/gui_easyrob/tabs/images.py @@ -42,7 +42,7 @@ Qt, ) -except ImportError as e: +except ImportError: from robert.gui_easyrob.utils.utils_gui import ( QApplication, QDesktopServices, @@ -66,6 +66,7 @@ import os import glob + class ImagesTab(QWidget): """Images tab for displaying images from multiple folders as workflow results.""" @@ -98,8 +99,7 @@ def __init__(self, main_tab_widget, image_folders, file_path): help_button.setFixedSize(18, 18) help_button.setStyleSheet("font-size: 11px;") help_button.setToolTip( - "Double-click: Open image\n" - "Right-click: Copy, Save, or Open folder" + "Double-click: Open image\nRight-click: Copy, Save, or Open folder" ) help_button.clicked.connect(self.show_help_dialog) @@ -268,4 +268,4 @@ def show_context_menu(self, position): "Images (*.png *.jpg *.jpeg)", ) if target_path: - QPixmap(self.image_path).save(target_path) \ No newline at end of file + QPixmap(self.image_path).save(target_path) diff --git a/robert/gui_easyrob/tabs/molssi.py b/robert/gui_easyrob/tabs/molssi.py index e8b9fe5..729e977 100644 --- a/robert/gui_easyrob/tabs/molssi.py +++ b/robert/gui_easyrob/tabs/molssi.py @@ -23,6 +23,7 @@ - Designed to keep the main window decoupled from download logic. """ + try: from utils.utils_gui import ( Path, @@ -42,7 +43,7 @@ ) from utils.molssi_utils import ExcelToCSVWorker -except ImportError as e: +except ImportError: from robert.gui_easyrob.utils.utils_gui import ( Path, QFileDialog, @@ -64,6 +65,7 @@ # ---- Standard library ---- import os + class MolSSIDatabasesTab(QWidget): """ Tab widget embedding the MolSSI descriptor databases web interface. @@ -74,9 +76,10 @@ class MolSSIDatabasesTab(QWidget): - Allows saving descriptor files locally - Optionally converts downloaded Excel files to CSV """ - # Signal emitted when a test file download is requested + + # Signal emitted when a test file download is requested load_test_requested = Signal(str) - + def __init__(self, parent=None): super().__init__(parent) @@ -100,6 +103,7 @@ class SingleWindowWebView(QWebEngineView): Custom QWebEngineView that prevents opening external windows. Any request to open a new window is redirected to the same view. """ + def createWindow(self, webWindowType): tmp = QWebEngineView(self) tmp.setAttribute(Qt.WA_DeleteOnClose, True) @@ -109,9 +113,7 @@ def createWindow(self, webWindowType): return tmp # Base URL for MolSSI databases - self.databases_home_url = QUrl( - "https://descriptor-libraries.molssi.org/" - ) + self.databases_home_url = QUrl("https://descriptor-libraries.molssi.org/") # -------------------------------------------------- # Home bar @@ -173,16 +175,10 @@ def _handle_download(self, req: QWebEngineDownloadRequest): def open_dialog(): suggested = ( - req.downloadFileName() - or QUrl(req.url()).fileName() - or "download" + req.downloadFileName() or QUrl(req.url()).fileName() or "download" ) - path, _ = QFileDialog.getSaveFileName( - self, - "Save File", - suggested - ) + path, _ = QFileDialog.getSaveFileName(self, "Save File", suggested) if not path: req.cancel() @@ -239,7 +235,7 @@ def _on_download_completed(self, path): "(for example, an experimental property or value you want to predict).\n\n" "Do you want to convert this Excel file to CSV now?", QMessageBox.Yes | QMessageBox.No, - QMessageBox.Yes + QMessageBox.Yes, ) if reply != QMessageBox.Yes: @@ -274,9 +270,7 @@ def finished(csv_path): return QMessageBox.information( - self, - "Conversion completed", - "Excel converted to CSV successfully." + self, "Conversion completed", "Excel converted to CSV successfully." ) def error(msg): @@ -307,7 +301,7 @@ def _show_download_popup(self): popup.setModal(False) popup.show() return popup - + def load_test_molssi(self, csv_path, source=None): """ Finalizes a MolSSI test dataset. @@ -335,4 +329,4 @@ def load_test_molssi(self, csv_path, source=None): try: Path(source).unlink() except Exception: - pass \ No newline at end of file + pass diff --git a/robert/gui_easyrob/tabs/predictions.py b/robert/gui_easyrob/tabs/predictions.py index 8eebfea..3afe41a 100644 --- a/robert/gui_easyrob/tabs/predictions.py +++ b/robert/gui_easyrob/tabs/predictions.py @@ -23,6 +23,7 @@ - Designed to keep heavy logic delegated to utils.predictions_utils """ + # ------------------------------------------------------------ # Import resolution (local vs installed package) # ------------------------------------------------------------ @@ -60,7 +61,7 @@ get_robert_report_path, ) -except ImportError as e: +except ImportError: from robert.gui_easyrob.utils.utils_gui import ( QFrame, QHBoxLayout, @@ -102,8 +103,10 @@ import pandas as pd import matplotlib.pyplot as plt + class PredictionsTab(QWidget): """Tab for displaying prediction results from ROBERT runs.""" + availabilityChanged = Signal(bool) def __init__(self, parent=None): @@ -147,7 +150,7 @@ def _extract_names_column_from_predict(self): return match.group(1) return None - + def _filter_prediction_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: """ Keeps and orders columns as: @@ -193,7 +196,7 @@ def _filter_prediction_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: df.insert(0, "Image", df[smiles_cols[0]]) return df[ordered_columns] - + def refresh_with_new_path(self, selected_file_path: str): """Refreshes the predictions tab with new data from the selected file path.""" # This is the ONLY base path used by PredictionsTab @@ -231,7 +234,7 @@ def _add_loaded_df(self, key: str, df: pd.DataFrame): info = evaluate_predictions_for_model( self._base_path, df, - key # "PFI" or "No_PFI" + key, # "PFI" or "No_PFI" ) # Extract fragment image @@ -271,7 +274,7 @@ def _show_histogram_menu_header(self, pos, df: pd.DataFrame, header): def _create_table_with_stats(self, df, info, pdf_image): """Creates the main table view with the predictions and the side dashboard with stats and diagnostics.""" - # ---- Container ---- + # ---- Container ---- container = QWidget() container.setStyleSheet("background: palette(window);") @@ -308,7 +311,7 @@ def _create_table_with_stats(self, df, info, pdf_image): lambda pos, d=df, h=header: self._show_header_menu(pos, d, h) ) - # ---- Side Dashboard ---- + # ---- Side Dashboard ---- pdf_path = info["pdf_path"] model_key = info["model"] # "PFI" or "No_PFI" @@ -325,7 +328,7 @@ def _create_table_with_stats(self, df, info, pdf_image): pdf_image=pdf_image, extrapolation_score=extrap_scores.get(model_key), extrapolation_image=extrap_pixmap, - external_plot=external_pixmap + external_plot=external_pixmap, ) # ---- Separator ---- @@ -353,11 +356,11 @@ def _show_header_menu(self, pos, df: pd.DataFrame, header): menu = QMenu(header) - # Sorting actions + # Sorting actions action_sort_asc = menu.addAction("Sort ascending") action_sort_desc = menu.addAction("Sort descending") - menu.addSeparator() + menu.addSeparator() # Histogram action (only for numeric columns) action_hist = None @@ -408,4 +411,4 @@ def _show_histogram(self, series: pd.Series, col_name: str): plt.xlabel(col_name) plt.ylabel("Frequency") plt.grid(False) - plt.show(block=False) \ No newline at end of file + plt.show(block=False) diff --git a/robert/gui_easyrob/tabs/results.py b/robert/gui_easyrob/tabs/results.py index 3d50b6e..063a8e6 100644 --- a/robert/gui_easyrob/tabs/results.py +++ b/robert/gui_easyrob/tabs/results.py @@ -54,7 +54,7 @@ fitz, ) -except ImportError as e: +except ImportError: from robert.gui_easyrob.utils.utils_gui import ( QImage, QLabel, @@ -81,15 +81,17 @@ import os import glob + class ResultsTab(QWidget): """PDF viewer for ROBERT reports.""" + def __init__(self, main_tab_widget, file_path): super().__init__() self.main_tab_widget = main_tab_widget self.base_path = os.path.dirname(file_path) - self.pdf_tabs = {} # {pdf_path: PDFViewer|None} None => placeholder not materialized - self.title_to_path = {} # {basename: full path} + self.pdf_tabs = {} # {pdf_path: PDFViewer|None} None => placeholder not materialized + self.title_to_path = {} # {basename: full path} # Shared thread pool for all PDF viewers self.shared_pool = QThreadPool() @@ -207,20 +209,25 @@ def _index_of_title(self, title: str) -> int: # ------------------------- Worker signals ------------------------- + class RenderSignals(QObject): """Signals for page rendering.""" + finished = Signal(int, float, int, QPixmap) # page_num, zoom, generation, pixmap class MetaSignals(QObject): """Signals for PDF metadata loading.""" + done = Signal(int, list) # page_count, page_sizes # ------------------------- Worker tasks ------------------------- + class RenderTask(QRunnable): """Background render task for a single PDF page (open by path; no big upfront I/O).""" + def __init__(self, pdf_path: str, page_num: int, zoom: float, generation: int): super().__init__() self.pdf_path = pdf_path @@ -238,15 +245,20 @@ def run(self): page = doc.load_page(self.page_num) mat = fitz.Matrix(self.zoom, self.zoom) pix = page.get_pixmap(matrix=mat, alpha=False) - qimg = QImage(pix.samples, pix.width, pix.height, pix.stride, QImage.Format_RGB888).copy() + qimg = QImage( + pix.samples, pix.width, pix.height, pix.stride, QImage.Format_RGB888 + ).copy() qp = QPixmap.fromImage(qimg) self.signals.finished.emit(self.page_num, self.zoom, self.generation, qp) except Exception: - self.signals.finished.emit(self.page_num, self.zoom, self.generation, QPixmap()) + self.signals.finished.emit( + self.page_num, self.zoom, self.generation, QPixmap() + ) class MetaTask(QRunnable): """Load page count and page sizes off the UI thread.""" + def __init__(self, pdf_path: str): super().__init__() self.pdf_path = pdf_path @@ -258,7 +270,9 @@ def run(self): try: with fitz.open(self.pdf_path) as doc: count = len(doc) - sizes = [tuple(doc.load_page(i).rect.br) for i in range(count)] # (w_pts, h_pts) + sizes = [ + tuple(doc.load_page(i).rect.br) for i in range(count) + ] # (w_pts, h_pts) except Exception: count, sizes = 1, [(595, 842)] # Fallback to A4 portrait in points self.signals.done.emit(count, sizes) @@ -266,19 +280,21 @@ def run(self): # ------------------------- PDFViewer (async metadata + visible-only render) ------------------------- + class PDFViewer(QWidget): """Widget to display a PDF inside a scrollable area with zoom control and threading.""" + def __init__(self, pdf_path: str, thread_pool: QThreadPool): super().__init__() self.pdf_path = pdf_path self.current_zoom = 1.2 self.thread_pool = thread_pool # shared - self.image_cache = {} # {(page_num, zoom): QPixmap} - self.labels = [] # one QLabel per page - self.page_sizes = None # [(width_pts, height_pts)] + self.image_cache = {} # {(page_num, zoom): QPixmap} + self.labels = [] # one QLabel per page + self.page_sizes = None # [(width_pts, height_pts)] self.page_count = None - self._renderGeneration = 0 # cancel stale renders + self._renderGeneration = 0 # cancel stale renders self._zoomPending = False self._scrollPending = False @@ -332,7 +348,9 @@ def _apply_metadata(self, page_count: int, page_sizes: list): self._build_placeholders_for_zoom(self.current_zoom) # Now that we know page geometry, hook scroll coalescing - self.scroll_area.verticalScrollBar().valueChanged.connect(self._schedule_visible_render) + self.scroll_area.verticalScrollBar().valueChanged.connect( + self._schedule_visible_render + ) # Initial render: only what's visible + tiny warm self._kick_off_visible_render(force=True, warm=1) @@ -355,7 +373,9 @@ def _apply_zoom_now(self): # Bump generation to discard in-flight renders self._renderGeneration += 1 # Keep only current-zoom cache - self.image_cache = {k: v for k, v in self.image_cache.items() if k[1] == self.current_zoom} + self.image_cache = { + k: v for k, v in self.image_cache.items() if k[1] == self.current_zoom + } # Recompute placeholder heights and clear labels self._build_placeholders_for_zoom(self.current_zoom) # Kick minimal warm-up @@ -465,13 +485,19 @@ def _kick_off_visible_render(self, force: bool = False, warm: int = 0): # ---------- Render completion ---------- @Slot(int, float, int, QPixmap) - def on_page_rendered(self, page_num: int, zoom: float, generation: int, pixmap: QPixmap): + def on_page_rendered( + self, page_num: int, zoom: float, generation: int, pixmap: QPixmap + ): """Handle rendered page: update cache and label if still relevant.""" # Discard outdated renders (other zoom or older generation) or failed pixmaps - if generation != self._renderGeneration or zoom != self.current_zoom or pixmap.isNull(): + if ( + generation != self._renderGeneration + or zoom != self.current_zoom + or pixmap.isNull() + ): return key = (page_num, zoom) self.image_cache[key] = pixmap lbl = self.labels[page_num] lbl.setPixmap(pixmap) - lbl.setText("") \ No newline at end of file + lbl.setText("") diff --git a/robert/gui_easyrob/utils/aqme_utils.py b/robert/gui_easyrob/utils/aqme_utils.py index e17696e..15e190b 100644 --- a/robert/gui_easyrob/utils/aqme_utils.py +++ b/robert/gui_easyrob/utils/aqme_utils.py @@ -46,6 +46,7 @@ # ------------------------------------------------------------ from .utils_gui import DropLabel + class ChemDrawFileDialog(QDialog): """Dialog that collects the main ChemDraw/SDF input file.""" @@ -88,10 +89,13 @@ def set_main_file(self, path): def continue_clicked(self): """Check if a main ChemDraw file has been selected and accept the dialog.""" if not self.main_chemdraw_path: - QMessageBox.warning(self, "Missing File", "Please select a main ChemDraw file.") + QMessageBox.warning( + self, "Missing File", "Please select a main ChemDraw file." + ) return self.accept() + def mcs_process(smiles_list, result_queue): """Find the maximum common substructure for a list of SMILES.""" try: @@ -160,4 +164,4 @@ def _on_timeout(self): if self.process and self.process.is_alive(): self.process.terminate() self.process.join() - self.timeout.emit() \ No newline at end of file + self.timeout.emit() diff --git a/robert/gui_easyrob/utils/molssi_utils.py b/robert/gui_easyrob/utils/molssi_utils.py index e39bdc6..4814ae7 100644 --- a/robert/gui_easyrob/utils/molssi_utils.py +++ b/robert/gui_easyrob/utils/molssi_utils.py @@ -39,6 +39,7 @@ from PySide6.QtCore import QThread, Signal + class MolSSIWorker(QThread): """Background worker responsible for resolving MolSSI descriptors.""" @@ -132,7 +133,7 @@ def canonicalize(smiles): def chunked(lst, size): """Split a list into chunks of a specified size.""" for i in range(0, len(lst), size): - yield lst[i:i + size] + yield lst[i : i + size] def safe_query_batched(smiles, library, data_type, batch_size=200): """Query MolSSI API in batches and handle partial failures.""" @@ -197,11 +198,15 @@ def full_coverage(smiles, df_api): for lib in dft_libraries: df_api = safe_query_batched(smiles_list, lib, "DFT") if full_coverage(smiles_list, df_api): - return _prepare_export(df_work, df_api, smiles_col, lib, "DFT", original_input_columns) + return _prepare_export( + df_work, df_api, smiles_col, lib, "DFT", original_input_columns + ) df_api = safe_query_batched(smiles_list, "kraken", "ML") if full_coverage(smiles_list, df_api): - return _prepare_export(df_work, df_api, smiles_col, "kraken", "ML", original_input_columns) + return _prepare_export( + df_work, df_api, smiles_col, "kraken", "ML", original_input_columns + ) return { "available": False, @@ -225,7 +230,9 @@ def _molssi_test_dataset_available(library_slug): return False -def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_input_columns): +def _prepare_export( + df_work, df_api, smiles_col, library, data_type, original_input_columns +): """Prepare the merged MolSSI export DataFrame for use in easyROB.""" try: df_api = df_api.copy() @@ -237,7 +244,9 @@ def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_in df_merged = df_work.merge(df_api, on="_smiles_canonical", how="left") export_df = df_merged.drop( - columns=[c for c in ["_smiles_canonical", "smiles"] if c in df_merged.columns] + columns=[ + c for c in ["_smiles_canonical", "smiles"] if c in df_merged.columns + ] ) export_df[smiles_col] = export_df["_smiles_original"] export_df = export_df.drop(columns=["_smiles_original"]) @@ -251,7 +260,10 @@ def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_in ] if "molecule_id" in export_df.columns: - only_smiles_input = len(original_input_columns) == 1 and original_input_columns[0].lower() == "smiles" + only_smiles_input = ( + len(original_input_columns) == 1 + and original_input_columns[0].lower() == "smiles" + ) if not only_smiles_input: export_df = export_df.drop(columns=["molecule_id"]) @@ -278,27 +290,68 @@ def _prepare_export(df_work, df_api, smiles_col, library, data_type, original_in def fix_greek_caps_columns(col: str) -> str: """Normalize Greek characters and canonical spelling in MolSSI headers.""" greek_map = { - "α": "alpha", "β": "beta", "γ": "gamma", "δ": "delta", - "ε": "epsilon", "ζ": "zeta", "η": "eta", "θ": "theta", - "ι": "iota", "κ": "kappa", "λ": "lambda", "μ": "mu", - "ν": "nu", "ξ": "xi", "ο": "omicron", "π": "pi", - "ρ": "rho", "σ": "sigma", "τ": "tau", "υ": "upsilon", - "φ": "phi", "χ": "chi", "ψ": "psi", "ω": "omega", - "Α": "alpha", "Β": "beta", "Γ": "gamma", "Δ": "delta", - "Ε": "epsilon", "Ζ": "zeta", "Η": "eta", "Θ": "theta", - "Ι": "iota", "Κ": "kappa", "Λ": "lambda", "Μ": "mu", - "Ν": "nu", "Ξ": "xi", "Ο": "omicron", "Π": "pi", - "Ρ": "rho", "Σ": "sigma", "Τ": "tau", "Υ": "upsilon", - "Φ": "phi", "Χ": "chi", "Ψ": "psi", "Ω": "omega", + "α": "alpha", + "β": "beta", + "γ": "gamma", + "δ": "delta", + "ε": "epsilon", + "ζ": "zeta", + "η": "eta", + "θ": "theta", + "ι": "iota", + "κ": "kappa", + "λ": "lambda", + "μ": "mu", + "ν": "nu", + "ξ": "xi", + "ο": "omicron", + "π": "pi", + "ρ": "rho", + "σ": "sigma", + "τ": "tau", + "υ": "upsilon", + "φ": "phi", + "χ": "chi", + "ψ": "psi", + "ω": "omega", + "Α": "alpha", + "Β": "beta", + "Γ": "gamma", + "Δ": "delta", + "Ε": "epsilon", + "Ζ": "zeta", + "Η": "eta", + "Θ": "theta", + "Ι": "iota", + "Κ": "kappa", + "Λ": "lambda", + "Μ": "mu", + "Ν": "nu", + "Ξ": "xi", + "Ο": "omicron", + "Π": "pi", + "Ρ": "rho", + "Σ": "sigma", + "Τ": "tau", + "Υ": "upsilon", + "Φ": "phi", + "Χ": "chi", + "Ψ": "psi", + "Ω": "omega", } for greek_char, latin in greek_map.items(): col = col.replace(greek_char, latin) - col = re.sub(r"(? Path: """Given the path to a selected file, return the corresponding PREDICT/csv_test directory.""" return Path(selected_file_path).parent / "PREDICT" / "csv_test" + def find_prediction_csvs(selected_file_path: str) -> dict[str, Path]: """Search for prediction CSV files in the PREDICT/csv_test directory related to the selected file.""" predict_dir = get_predict_dir(selected_file_path) @@ -91,10 +93,12 @@ def find_prediction_csvs(selected_file_path: str) -> dict[str, Path]: results["PFI"] = path return results + def get_robert_report_path(selected_file_path: str | Path) -> Path: """Given the path to a selected file, return the corresponding ROBERT_report.pdf file.""" return Path(selected_file_path).parent / "ROBERT_report.pdf" + def find_external_test_pixmaps(base_path: str | Path) -> dict[str, QPixmap]: """Search for external test images in the PREDICT/csv_test directory related to the selected file.""" base_path = Path(base_path) @@ -119,6 +123,7 @@ def find_external_test_pixmaps(base_path: str | Path) -> dict[str, QPixmap]: return results + def extract_scores_from_robert_report(pdf_path: Path) -> dict: """Extract scores from the ROBERT report PDF file.""" result = {"pdf_found": False, "PFI": None, "No_PFI": None} @@ -133,6 +138,7 @@ def extract_scores_from_robert_report(pdf_path: Path) -> dict: return result + def extract_extrapolation_fragment(pdf_path: Path, model_key: str) -> QPixmap | None: """Render the extrapolation block from parsed ROBERT report data.""" details = _extract_extrapolation_details(pdf_path, model_key) @@ -140,6 +146,7 @@ def extract_extrapolation_fragment(pdf_path: Path, model_key: str) -> QPixmap | return None return _render_extrapolation_pixmap(details) + def extract_robert_fragment_image(pdf_path: Path, model_key: str) -> QPixmap | None: """Render the ROBERT score block from parsed report data.""" details = _extract_robert_score_details(pdf_path, model_key) @@ -147,6 +154,7 @@ def extract_robert_fragment_image(pdf_path: Path, model_key: str) -> QPixmap | N return None return _render_robert_score_pixmap(details) + def _get_extrapolation_bbox(page, model_key: str): """Return the PDF area containing the extrapolation block for the requested model.""" if model_key == "No_PFI": @@ -158,7 +166,9 @@ def _get_extrapolation_bbox(page, model_key: str): def _normalize_extrapolation_lines(text: str) -> list[str]: """Collapse noisy PDF whitespace while preserving the content of each line.""" - return [re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip()] + return [ + re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip() + ] def _parse_extrapolation_block(text: str) -> dict | None: @@ -168,7 +178,9 @@ def _parse_extrapolation_block(text: str) -> dict | None: lines = _normalize_extrapolation_lines(text) title_line = next((line for line in lines if "Extrapolation" in line), None) - rmse_line = next((line for line in lines if "[" in line and "]" in line and "%" in line), None) + rmse_line = next( + (line for line in lines if "[" in line and "]" in line and "%" in line), None + ) scoring_line = next((line for line in lines if "Scoring from" in line), None) rule_line = next((line for line in lines if "Every two folds" in line), None) @@ -183,7 +195,11 @@ def _parse_extrapolation_block(text: str) -> dict | None: if rmse_line: values_match = re.search(r"\[(.*?)\]", rmse_line) if values_match: - values = [value.strip() for value in values_match.group(1).split(",") if value.strip()] + values = [ + value.strip() + for value in values_match.group(1).split(",") + if value.strip() + ] clean_title = re.sub(r"\(\s*\d+\s*/\s*\d+\s*\)", "", title_line).strip() @@ -217,7 +233,9 @@ def _extract_extrapolation_details(pdf_path: Path, model_key: str) -> dict | Non return None -def _score_fill_rgb(obtained: int | None, maximum: int | None) -> tuple[float, float, float]: +def _score_fill_rgb( + obtained: int | None, maximum: int | None +) -> tuple[float, float, float]: """Return a fill color for the extrapolation score indicator.""" if obtained is None or maximum in (None, 0): return (0.78, 0.78, 0.78) @@ -295,8 +313,15 @@ def _render_extrapolation_pixmap(details: dict) -> QPixmap | None: fontname="hebo", color=(0.12, 0.20, 0.30), ) - rmse_text = f"[{', '.join(details['rmse_values'])}]" if details["rmse_values"] else "[]" - values_rect = fitz.Rect(margin - 2, title_y + line_gap + 12, width - margin, title_y + line_gap * 2 + 20) + rmse_text = ( + f"[{', '.join(details['rmse_values'])}]" if details["rmse_values"] else "[]" + ) + values_rect = fitz.Rect( + margin - 2, + title_y + line_gap + 12, + width - margin, + title_y + line_gap * 2 + 20, + ) page.draw_rect(values_rect, color=(0.86, 0.86, 0.86), fill=(1, 1, 1), width=0.8) page.insert_text( fitz.Point(margin + 8, title_y + line_gap * 2 + 8), @@ -334,17 +359,24 @@ def _parse_robert_score_block(text: str) -> dict | None: if not text: return None - lines = [re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip()] + lines = [ + re.sub(r"\s+", " ", line).strip() for line in text.splitlines() if line.strip() + ] if not lines: return None - title_line = next((line for line in lines if re.search(r"\bScore\s+\d+\b", line, re.IGNORECASE)), None) + title_line = next( + (line for line in lines if re.search(r"\bScore\s+\d+\b", line, re.IGNORECASE)), + None, + ) if not title_line: joined_text = " ".join(lines) score_match = re.search(r"\bScore\s+(\d+)\b", joined_text, re.IGNORECASE) if not score_match: return None - title_match = re.search(r"(.+?)\s*[.\-·]?\s*Score\s+\d+\b", joined_text, re.IGNORECASE) + title_match = re.search( + r"(.+?)\s*[.\-·]?\s*Score\s+\d+\b", joined_text, re.IGNORECASE + ) title = title_match.group(1).strip() if title_match else "ROBERT Score" return { "title": title, @@ -357,7 +389,9 @@ def _parse_robert_score_block(text: str) -> dict | None: if not score_match: return None - title = re.sub(r"\s*[·\.-]?\s*Score\s+\d+\b.*$", "", title_line, flags=re.IGNORECASE).strip() + title = re.sub( + r"\s*[·\.-]?\s*Score\s+\d+\b.*$", "", title_line, flags=re.IGNORECASE + ).strip() return { "title": title, "score": int(score_match.group(1)), @@ -412,16 +446,46 @@ def _extract_robert_score_details(pdf_path: Path, model_key: str) -> dict | None def _robert_score_style(score: int | None) -> dict: """Return the visual style associated with a ROBERT score.""" if score is None: - return {"label": "UNKNOWN", "segments": 0, "fill": (0.92, 0.90, 0.96), "text": (0.35, 0.35, 0.35)} + return { + "label": "UNKNOWN", + "segments": 0, + "fill": (0.92, 0.90, 0.96), + "text": (0.35, 0.35, 0.35), + } if score <= 0: - return {"label": "VERY WEAK", "segments": 0, "fill": (1.0, 0.42, 0.42), "text": (1.0, 0.42, 0.42)} + return { + "label": "VERY WEAK", + "segments": 0, + "fill": (1.0, 0.42, 0.42), + "text": (1.0, 0.42, 0.42), + } if score <= 3: - return {"label": "VERY WEAK", "segments": score, "fill": (1.0, 0.42, 0.42), "text": (1.0, 0.42, 0.42)} + return { + "label": "VERY WEAK", + "segments": score, + "fill": (1.0, 0.42, 0.42), + "text": (1.0, 0.42, 0.42), + } if score <= 6: - return {"label": "WEAK", "segments": score, "fill": (1.0, 0.79, 0.38), "text": (1.0, 0.79, 0.38)} + return { + "label": "WEAK", + "segments": score, + "fill": (1.0, 0.79, 0.38), + "text": (1.0, 0.79, 0.38), + } if score <= 8: - return {"label": "MODERATE", "segments": score, "fill": (0.60, 0.78, 0.95), "text": (0.60, 0.78, 0.95)} - return {"label": "STRONG", "segments": min(score, 10), "fill": (0.38, 0.60, 0.80), "text": (0.38, 0.60, 0.80)} + return { + "label": "MODERATE", + "segments": score, + "fill": (0.60, 0.78, 0.95), + "text": (0.60, 0.78, 0.95), + } + return { + "label": "STRONG", + "segments": min(score, 10), + "fill": (0.38, 0.60, 0.80), + "text": (0.38, 0.60, 0.80), + } def _render_robert_score_pixmap(details: dict) -> QPixmap | None: @@ -521,7 +585,11 @@ def extract_extrapolation_scores(pdf_path: Path) -> dict: for model_key in ("No_PFI", "PFI"): details = _extract_extrapolation_details(pdf_path, model_key) - if details and details.get("obtained") is not None and details.get("maximum") is not None: + if ( + details + and details.get("obtained") is not None + and details.get("maximum") is not None + ): result[model_key] = { "obtained": details["obtained"], "maximum": details["maximum"], @@ -529,6 +597,7 @@ def extract_extrapolation_scores(pdf_path: Path) -> dict: return result + def extract_prediction_info(df: pd.DataFrame) -> dict: """Extract prediction information from a DataFrame.""" pred_cols = [col for col in df.columns if col.endswith("_pred")] @@ -548,18 +617,27 @@ def extract_prediction_info(df: pd.DataFrame) -> dict: result["has_pred_column"] = True result["pred_column"] = col result["n_unique"] = series.nunique() - result["predictions_identical"] = None if result["n_unique"] == 0 else result["n_unique"] == 1 + result["predictions_identical"] = ( + None if result["n_unique"] == 0 else result["n_unique"] == 1 + ) return result -def evaluate_model_scenario(score: int | None, predictions_identical: bool | None) -> dict: + +def evaluate_model_scenario( + score: int | None, predictions_identical: bool | None +) -> dict: """Evaluate the model scenario based on ROBERT score and predictions.""" almos_link = "https://github.com/MiguelMartzFdez/almos" almos_html = f'ALMOS' result = {"state": "UNKNOWN", "messages": [], "recommendations": []} if score is None: - result["messages"].append("No valid ROBERT score was detected. Model reliability cannot be evaluated.") - result["recommendations"].append("You may verify that ROBERT_report.pdf was generated correctly.") + result["messages"].append( + "No valid ROBERT score was detected. Model reliability cannot be evaluated." + ) + result["recommendations"].append( + "You may verify that ROBERT_report.pdf was generated correctly." + ) return result if predictions_identical is True: @@ -576,7 +654,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non if 0 <= score <= 3: result["state"] = "FAILED" - result["messages"].append(f"ROBERT score is {score}. Model performance is critically low.") + result["messages"].append( + f"ROBERT score is {score}. Model performance is critically low." + ) result["recommendations"].append( "You may avoid using these predictions and rebuild the dataset " f"using Clustering module in {almos_html}." @@ -585,7 +665,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non if 4 <= score <= 6: result["state"] = "WEAK" - result["messages"].append(f"ROBERT score is {score}. The model works, but reliability is limited.") + result["messages"].append( + f"ROBERT score is {score}. The model works, but reliability is limited." + ) result["recommendations"].append( "You may use predictions cautiously and improve robustness " f"through Active Learning module with {almos_html}." @@ -594,7 +676,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non if 7 <= score <= 8: result["state"] = "DECENT" - result["messages"].append(f"ROBERT score is {score}. The model is solid but can still improve.") + result["messages"].append( + f"ROBERT score is {score}. The model is solid but can still improve." + ) result["recommendations"].append( "You may use these predictions while considering further optimization " f"through Active Learning module with {almos_html}." @@ -603,7 +687,9 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non if score > 8: result["state"] = "STRONG" - result["messages"].append(f"ROBERT score is {score}. The model shows strong predictive performance.") + result["messages"].append( + f"ROBERT score is {score}. The model shows strong predictive performance." + ) result["recommendations"].append( "You may confidently use these predictions for candidate prioritization." ) @@ -611,7 +697,10 @@ def evaluate_model_scenario(score: int | None, predictions_identical: bool | Non return result -def evaluate_predictions_for_model(selected_file_path: str | Path, df: pd.DataFrame, model_key: str) -> dict: + +def evaluate_predictions_for_model( + selected_file_path: str | Path, df: pd.DataFrame, model_key: str +) -> dict: """Evaluate predictions for a specific model.""" pdf_path = get_robert_report_path(selected_file_path) scores = extract_scores_from_robert_report(pdf_path) @@ -628,6 +717,7 @@ def evaluate_predictions_for_model(selected_file_path: str | Path, df: pd.DataFr "scenario": scenario, } + def collect_model_info(selected_file_path: str | Path, df: pd.DataFrame) -> dict: """Collect information for all models.""" pdf_path = get_robert_report_path(selected_file_path) @@ -644,9 +734,19 @@ def collect_model_info(selected_file_path: str | Path, df: pd.DataFrame) -> dict "scenario": scenario, } + class PredictionDashboardPanel(QWidget): """A collapsible dashboard panel to display ROBERT prediction evaluation results and diagnostics.""" - def __init__(self, scenario: dict, pdf_image=None, extrapolation_score=None, extrapolation_image=None, external_plot=None, parent=None): + + def __init__( + self, + scenario: dict, + pdf_image=None, + extrapolation_score=None, + extrapolation_image=None, + external_plot=None, + parent=None, + ): super().__init__(parent) self._pdf_image = pdf_image self._extrapolation_score = extrapolation_score @@ -715,7 +815,9 @@ def _build_ui(self, scenario): self._build_status_block(content_layout, scenario) self._build_pdf_snapshot_block(content_layout) - self._build_extrapolation_block(content_layout, self._extrapolation_score, self._extrapolation_image) + self._build_extrapolation_block( + content_layout, self._extrapolation_score, self._extrapolation_image + ) self._build_external_validation_block(content_layout, self._external_plot) content_layout.addStretch() @@ -747,7 +849,9 @@ def _build_pdf_snapshot_block(self, layout): layout.addSpacing(15) container = QWidget() container.setObjectName("dashboardBlock") - container.setStyleSheet("QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }") + container.setStyleSheet( + "QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }" + ) container_layout = QVBoxLayout(container) container_layout.setContentsMargins(14, 14, 14, 14) container_layout.setSpacing(10) @@ -758,13 +862,19 @@ def _build_pdf_snapshot_block(self, layout): image_frame = QWidget() image_frame.setObjectName("imageFrame") - image_frame.setStyleSheet("QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }") + image_frame.setStyleSheet( + "QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }" + ) image_layout = QVBoxLayout(image_frame) image_layout.setContentsMargins(6, 6, 6, 6) image_label = QLabel() image_label.setAlignment(Qt.AlignCenter) - image_label.setPixmap(self._pdf_image.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation)) + image_label.setPixmap( + self._pdf_image.scaledToWidth( + self.expanded_width - 120, Qt.SmoothTransformation + ) + ) image_layout.addWidget(image_label) container_layout.addWidget(image_frame) layout.addWidget(container) @@ -777,7 +887,9 @@ def _build_extrapolation_block(self, layout, score, pixmap): layout.addSpacing(15) container = QWidget() container.setObjectName("dashboardBlock") - container.setStyleSheet("QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }") + container.setStyleSheet( + "QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }" + ) container_layout = QVBoxLayout(container) container_layout.setContentsMargins(14, 14, 14, 14) container_layout.setSpacing(10) @@ -786,7 +898,9 @@ def _build_extrapolation_block(self, layout, score, pixmap): title.setStyleSheet("font-weight: bold; font-size: 13px;") container_layout.addWidget(title) - subtitle = QLabel("Assessment of the model's ability to predict beyond the range of the training data.") + subtitle = QLabel( + "Assessment of the model's ability to predict beyond the range of the training data." + ) subtitle.setWordWrap(True) subtitle.setStyleSheet("font-size: 11px;") container_layout.addWidget(subtitle) @@ -794,12 +908,16 @@ def _build_extrapolation_block(self, layout, score, pixmap): if pixmap: image_frame = QWidget() image_frame.setObjectName("imageFrame") - image_frame.setStyleSheet("QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }") + image_frame.setStyleSheet( + "QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }" + ) image_layout = QVBoxLayout(image_frame) image_layout.setContentsMargins(6, 6, 6, 6) image_label = QLabel() image_label.setAlignment(Qt.AlignCenter) - image_label.setPixmap(pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation)) + image_label.setPixmap( + pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation) + ) image_layout.addWidget(image_label) container_layout.addWidget(image_frame) @@ -848,7 +966,9 @@ def _build_external_validation_block(self, layout, pixmap): layout.addSpacing(15) container = QWidget() container.setObjectName("dashboardBlock") - container.setStyleSheet("QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }") + container.setStyleSheet( + "QWidget#dashboardBlock { border: 1px solid palette(mid); border-radius: 8px; }" + ) container_layout = QVBoxLayout(container) container_layout.setContentsMargins(14, 14, 14, 14) container_layout.setSpacing(10) @@ -857,19 +977,25 @@ def _build_external_validation_block(self, layout, pixmap): title.setStyleSheet("font-weight: bold; font-size: 13px;") container_layout.addWidget(title) - subtitle = QLabel("Predicted vs experimental values for molecules with known target data.") + subtitle = QLabel( + "Predicted vs experimental values for molecules with known target data." + ) subtitle.setWordWrap(True) subtitle.setStyleSheet("font-size: 11px;") container_layout.addWidget(subtitle) image_frame = QWidget() image_frame.setObjectName("imageFrame") - image_frame.setStyleSheet("QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }") + image_frame.setStyleSheet( + "QWidget#imageFrame { border: 1px solid palette(mid); border-radius: 6px; }" + ) image_layout = QVBoxLayout(image_frame) image_layout.setContentsMargins(6, 6, 6, 6) image_label = QLabel() image_label.setAlignment(Qt.AlignCenter) - image_label.setPixmap(pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation)) + image_label.setPixmap( + pixmap.scaledToWidth(self.expanded_width - 120, Qt.SmoothTransformation) + ) image_layout.addWidget(image_label) container_layout.addWidget(image_frame) @@ -891,6 +1017,7 @@ def toggle(self): class PandasTableModel(QAbstractTableModel): """A Qt table model that wraps a pandas DataFrame, with special handling for SMILES rendering and sorting optimization.""" + def __init__(self, df: pd.DataFrame): super().__init__() self._df = df @@ -950,8 +1077,10 @@ def sort_key(self, column: int) -> np.ndarray: self._sort_cache[column] = col.astype(str).to_numpy() return self._sort_cache[column] + class StatsHeader(QHeaderView): """A custom header view that displays column names and basic statistics for numeric columns.""" + def __init__(self, df: pd.DataFrame, orientation, parent=None): super().__init__(orientation, parent) self._df = df @@ -998,11 +1127,15 @@ def paintSection(self, painter, rect, logical_index): bold_font.setBold(True) painter.setFont(bold_font) fm = QFontMetrics(bold_font) - name_height = fm.boundingRect(0, 0, r.width(), 1000, Qt.AlignHCenter | Qt.TextWordWrap, col_name).height() + name_height = fm.boundingRect( + 0, 0, r.width(), 1000, Qt.AlignHCenter | Qt.TextWordWrap, col_name + ).height() name_rect = rect.adjusted(margin, margin, -margin, -margin) name_rect.setHeight(name_height) - painter.drawText(name_rect, Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, col_name) + painter.drawText( + name_rect, Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, col_name + ) if stats is not None: normal_font = painter.font() @@ -1016,12 +1149,18 @@ def paintSection(self, painter, rect, logical_index): ) metrics_rect = r metrics_rect.setTop(name_rect.bottom() + 6) - painter.drawText(metrics_rect, Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, metrics_text) + painter.drawText( + metrics_rect, + Qt.AlignHCenter | Qt.AlignTop | Qt.TextWordWrap, + metrics_text, + ) painter.restore() + class ColumnStatsWidget(QWidget): """A widget that displays basic statistics for numeric columns in a DataFrame.""" + def __init__(self, df: pd.DataFrame, parent=None): super().__init__(parent) layout = QHBoxLayout(self) @@ -1049,10 +1188,13 @@ def __init__(self, df: pd.DataFrame, parent=None): class LoadCsvSignals(QObject): """Signals for the LoadCsvTask.""" + done = Signal(str, pd.DataFrame) + class LoadCsvTask(QRunnable): """A task for loading a CSV file.""" + def __init__(self, key: str, path: Path): super().__init__() self.key = key @@ -1066,8 +1208,10 @@ def run(self): except Exception as exc: print(f"Failed to load CSV {self.path}: {exc}") + class NumericSortProxy(QSortFilterProxyModel): """A proxy model that optimizes sorting for numeric columns by caching sort keys.""" + def lessThan(self, left, right): model = self.sourceModel() keys = model.sort_key(left.column()) diff --git a/robert/gui_easyrob/utils/utils_gui.py b/robert/gui_easyrob/utils/utils_gui.py index 5542454..155f8b2 100644 --- a/robert/gui_easyrob/utils/utils_gui.py +++ b/robert/gui_easyrob/utils/utils_gui.py @@ -29,36 +29,21 @@ # ------------------------------------------------------------ # Standard library # ------------------------------------------------------------ -import csv -import glob import os import platform -import re import shlex -import shutil import subprocess import sys import threading -from functools import partial -from io import BytesIO from pathlib import Path -from importlib.metadata import PackageNotFoundError, version from importlib.resources import as_file, files # ------------------------------------------------------------ # Third-party libraries # ------------------------------------------------------------ import pandas as pd -import matplotlib.pyplot as plt import psutil -import fitz -import rdkit -from rdkit import Chem -from rdkit.Chem import Draw, rdDepictor, rdFMCS -from rdkit.Chem.Draw import rdMolDraw2D -from rdkit.Chem.rdmolfiles import MolsFromCDXMLFile -from rdkit.Chem.rdmolops import GetMolFrags from ansi2html import Ansi2HTMLConverter @@ -66,80 +51,32 @@ # Qt (PySide6) # ------------------------------------------------------------ from PySide6.QtCore import ( - QByteArray, - QEventLoop, - QAbstractTableModel, - QModelIndex, - QObject, - QRunnable, - QRect, - QSize, - QSortFilterProxyModel, QThread, - QThreadPool, - QTimer, Qt, Signal, - Slot, - QUrl, ) from PySide6.QtGui import ( - QDesktopServices, - QFontMetrics, - QIcon, - QImage, - QMouseEvent, - QPalette, - QPixmap, QWheelEvent, ) -from PySide6.QtWebEngineCore import QWebEngineDownloadRequest -from PySide6.QtWebEngineWidgets import QWebEngineView from PySide6.QtWidgets import ( - QApplication, - QCheckBox, QComboBox, - QDialog, QFileDialog, - QFormLayout, QFrame, - QGridLayout, - QGroupBox, - QHBoxLayout, - QHeaderView, - QInputDialog, QLabel, - QLineEdit, - QListWidget, - QMainWindow, - QMenu, - QMessageBox, - QProgressBar, QPushButton, - QScrollArea, - QSizePolicy, - QSlider, - QStackedWidget, - QStatusBar, - QStyle, - QStyleOptionHeader, - QTabWidget, - QTableView, - QTableWidget, - QTableWidgetItem, - QTextEdit, - QToolButton, QVBoxLayout, - QWidget, ) + class DropLabel(QFrame): """Frame-based drop target with an optional file dialog button.""" - def __init__(self, text, parent=None, file_filter="CSV Files (*.csv)", extensions=(".csv",)): + def __init__( + self, text, parent=None, file_filter="CSV Files (*.csv)", extensions=(".csv",) + ): super().__init__(parent) self.file_filter = file_filter self.valid_extensions = extensions @@ -186,7 +123,9 @@ def set_file_type(self, file_filter, extensions): def open_file_dialog(self): """Open a file dialog to select a file.""" - file_path, _ = QFileDialog.getOpenFileName(self, "Select File", "", self.file_filter) + file_path, _ = QFileDialog.getOpenFileName( + self, "Select File", "", self.file_filter + ) if file_path and self.callback: self.set_file_path(file_path) @@ -220,6 +159,7 @@ def setText(self, text): """Set the text of the label.""" self.label.setText(text) + class RobertWorker(QThread): """QThread that runs a subprocess asynchronously and streams real-time output.""" @@ -250,7 +190,8 @@ def run(self): text=True, bufsize=1, universal_newlines=True, - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.CREATE_NO_WINDOW, + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP + | subprocess.CREATE_NO_WINDOW, ) else: self.process = subprocess.Popen( @@ -270,7 +211,9 @@ def read_stdout(): for line in self.process.stdout: if self._stop_requested: break - formatted_line = self.ansi_converter.convert(line.strip(), full=False) + formatted_line = self.ansi_converter.convert( + line.strip(), full=False + ) self.output_received.emit(formatted_line) except Exception as exc: self.error_received.emit(f"Error reading stdout: {exc}") @@ -281,7 +224,9 @@ def read_stderr(): for line in self.process.stderr: if self._stop_requested: break - formatted_line = f'{line.strip()}' + formatted_line = ( + f'{line.strip()}' + ) self.error_received.emit(formatted_line) reset_line = self.ansi_converter.convert("\033[0m", full=False) @@ -339,6 +284,7 @@ def _handle_stop(self): except Exception as exc: self.error_received.emit(f"Error stopping process: {exc}") + def smart_read_csv(filepath): """Read a CSV file with automatic delimiter detection.""" try: @@ -350,6 +296,7 @@ def smart_read_csv(filepath): except (FileNotFoundError, OSError): return None + class NoScrollComboBox(QComboBox): """Combo box that ignores wheel events while the popup is closed.""" @@ -359,6 +306,7 @@ def wheelEvent(self, event: QWheelEvent): else: event.ignore() + class AssetPath: """Resolve asset paths both in development and in frozen distributions.""" @@ -379,6 +327,7 @@ def get_path(self): ) return as_file(files("robert") / "icons" / self._filename) + class AssetLibrary: """Central registry of asset files used by the GUI.""" diff --git a/robert/gui_easyrob/version.py b/robert/gui_easyrob/version.py index 2bf2c8b..1d28c50 100644 --- a/robert/gui_easyrob/version.py +++ b/robert/gui_easyrob/version.py @@ -24,12 +24,14 @@ EASYROB_VERSION = "2.0.0" + def get_python_package_version(pkg): try: return version(pkg) except PackageNotFoundError: return "Not found" + def get_cli_version(cmd): try: result = subprocess.run([cmd, "--version"], capture_output=True, text=True) @@ -37,6 +39,7 @@ def get_cli_version(cmd): except Exception: return "Not found" + def get_xtb_version(): try: result = subprocess.run(["xtb", "--version"], capture_output=True, text=True) @@ -48,6 +51,7 @@ def get_xtb_version(): except Exception: return "Not found" + def get_software_versions(): return { "easyROB": EASYROB_VERSION, @@ -60,4 +64,5 @@ def get_software_versions(): }, } -SOFTWARE_VERSIONS = get_software_versions() \ No newline at end of file + +SOFTWARE_VERSIONS = get_software_versions() diff --git a/robert/predict.py b/robert/predict.py index f2a17d0..55084f6 100644 --- a/robert/predict.py +++ b/robert/predict.py @@ -5,15 +5,15 @@ destination : str, default=None, Directory to create the output file(s). varfile : str, default=None - Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). + Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). params_dir : str, default='' Folder containing the database and parameters of the ML model. csv_test : str, default='' - Name of the CSV file containing the test set (if any). A path can be provided (i.e. - 'C:/Users/FOLDER/FILE.csv'). + Name of the CSV file containing the test set (if any). A path can be provided (i.e. + 'C:/Users/FOLDER/FILE.csv'). t_value : float, default=2 t-value that will be the threshold to identify outliers (check tables for t-values elsewhere). - The higher the t-value the more restrictive the analysis will be (i.e. there will be more + The higher the t-value the more restrictive the analysis will be (i.e. there will be more outliers with t-value=1 than with t-value = 4). alpha : float, default=0.05 Significance level, or probability of making a wrong decision. This parameter is related to @@ -61,6 +61,7 @@ should_plot_predict_deep_diagnostics, ) + class predict: """ Class containing all the functions from the PREDICT module. @@ -72,7 +73,6 @@ class predict: """ def __init__(self, **kwargs): - start_time = time.time() # load default and user-specified variables @@ -85,12 +85,13 @@ def __init__(self, **kwargs): self.args.params_dir ): if os.path.exists(params_dir): - - _ = print_pfi(self,params_dir) + _ = print_pfi(self, params_dir) # load the Xy databse and model parameters - Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",True) # module 'verify' since PREDICT follows similar protocols - + Xy_data, model_data, suffix_title = load_db_n_params( + self, params_dir, suffix, suffix_title, "verify", True + ) # module 'verify' since PREDICT follows similar protocols + # get results from training, test and external test (if any) Xy_data = load_n_predict(self, model_data, Xy_data, BO_opt=False) if getattr(self.args, "uq_enable_meta", False): @@ -98,9 +99,7 @@ def __init__(self, **kwargs): self, Xy_data, model_data, params_dir ) if getattr(self.args, "uq_auto_enable", False): - Xy_data = apply_auto_uq( - self, Xy_data, model_data, params_dir - ) + Xy_data = apply_auto_uq(self, Xy_data, model_data, params_dir) # save predictions for all sets path_n_suffix, name_points, Xy_data = save_predictions( @@ -146,4 +145,4 @@ def __init__(self, **kwargs): self, Xy_data, path_n_suffix, model_data ) - _ = finish_print(self,start_time,'PREDICT') + _ = finish_print(self, start_time, "PREDICT") diff --git a/robert/predict_utils.py b/robert/predict_utils.py index 83e66b3..762cd41 100644 --- a/robert/predict_utils.py +++ b/robert/predict_utils.py @@ -99,60 +99,91 @@ def _append_split_columns( def plot_predictions(self, params_dict, Xy_data, path_n_suffix): - ''' + """ Plot graphs of predicted vs actual values for train, validation and test sets - ''' + """ + + set_types = [ + f"{params_dict['repeat_kfolds']}x {params_dict['kfold']}-fold CV", + "test", + ] - set_types = [f"{params_dict['repeat_kfolds']}x {params_dict['kfold']}-fold CV",'test'] - graph_style = get_graph_style() - - self.args.log.write(f"\n o Saving graphs in:") - if params_dict['type'].lower() == 'reg': + self.args.log.write("\n o Saving graphs in:") + + if params_dict["type"].lower() == "reg": # Plot graph with all sets - _ = graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style) + _ = graph_reg(self, Xy_data, params_dict, set_types, path_n_suffix, graph_style) # Plot CV average ± SD graph of validation or test set - _ = graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style,sd_graph=True) - if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0: + _ = graph_reg( + self, + Xy_data, + params_dict, + set_types, + path_n_suffix, + graph_style, + sd_graph=True, + ) + if ( + "y_external" in Xy_data + and not Xy_data["y_external"].isnull().values.any() + and len(Xy_data["y_external"]) > 0 + ): # Plot CV average ± SD graph of external set - set_type = 'external' - _ = graph_reg(self,Xy_data,params_dict,set_type,path_n_suffix,graph_style,csv_test=True,sd_graph=True) + set_type = "external" + _ = graph_reg( + self, + Xy_data, + params_dict, + set_type, + path_n_suffix, + graph_style, + csv_test=True, + sd_graph=True, + ) - elif params_dict['type'].lower() == 'clas': + elif params_dict["type"].lower() == "clas": for set_type in set_types: - _ = graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix) - if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0: - set_type = 'external' - _ = graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix,csv_test=True) + _ = graph_clas(self, Xy_data, params_dict, set_type, path_n_suffix) + if ( + "y_external" in Xy_data + and not Xy_data["y_external"].isnull().values.any() + and len(Xy_data["y_external"]) > 0 + ): + set_type = "external" + _ = graph_clas( + self, Xy_data, params_dict, set_type, path_n_suffix, csv_test=True + ) return graph_style -def save_predictions(self,Xy_data,model_data,suffix_title): - ''' +def save_predictions(self, Xy_data, model_data, suffix_title): + """ Saves CSV files with the different sets and their predicted results - ''' + """ # Check if we need to reconvert class labels (for classification with string labels) reconvert_labels = False class_mapping_reverse = None - if 'class_0_label' in model_data and 'class_1_label' in model_data: + if "class_0_label" in model_data and "class_1_label" in model_data: reconvert_labels = True class_mapping_reverse = { - 0: model_data['class_0_label'], - 1: model_data['class_1_label'] + 0: model_data["class_0_label"], + 1: model_data["class_1_label"], } # save CV and test results as a single df - Xy_train, Xy_test = pd.DataFrame(Xy_data['names_train']), pd.DataFrame(Xy_data['names_test']) - for col in Xy_data['X_train']: - Xy_train[col] = Xy_data['X_train'][col].tolist() - Xy_test[col] = Xy_data['X_test'][col].tolist() - + Xy_train, Xy_test = ( + pd.DataFrame(Xy_data["names_train"]), + pd.DataFrame(Xy_data["names_test"]), + ) + for col in Xy_data["X_train"]: + Xy_train[col] = Xy_data["X_train"][col].tolist() + Xy_test[col] = Xy_data["X_test"][col].tolist() + # Store y values and predictions, reconverting if needed - y_col = model_data['y'] - hw_scalar = float(Xy_data.get("conformal_half_width", float("nan"))) if model_data["type"].lower() != "reg": hw_scalar = float("nan") @@ -182,19 +213,19 @@ def save_predictions(self,Xy_data,model_data,suffix_title): df_results = pd.concat([Xy_train, Xy_test], axis=0) # add column with sets - train_list = ['CV' for _ in Xy_data['y_train']] - test_list = ['Test' for _ in Xy_data['y_test']] + train_list = ["CV" for _ in Xy_data["y_train"]] + test_list = ["Test" for _ in Xy_data["y_test"]] col_set = train_list + test_list - df_results['Set'] = col_set + df_results["Set"] = col_set # save results as CSV base_csv_name = f"PREDICT/{model_data['model']}_{suffix_title}" base_csv_path = f"{Path(os.getcwd()).joinpath(base_csv_name)}" - path_n_suffix = f'{base_csv_path}' - _ = df_results.to_csv(f'{base_csv_path}.csv', index = None, header=True) - + path_n_suffix = f"{base_csv_path}" + _ = df_results.to_csv(f"{base_csv_path}.csv", index=None, header=True) + # also save results for performance of individual folds (useful for t-tests and Wilcoxon tests between the folds) - error1, error2, error3 = get_error_labels(model_data['type']) + error1, error2, error3 = get_error_labels(model_data["type"]) # df_folds = pd.DataFrame() # df_folds['Fold'] = [f'{i+1}' for i in range(len(Xy_data['idx_valid']))] @@ -210,12 +241,14 @@ def save_predictions(self,Xy_data,model_data,suffix_title): # _ = df_folds.to_csv(f'{path_folds}.csv', index = None, header=True) # prints - print_preds = f' o Saving CSV databases with predictions and their SD in:' - print_preds += f'\n - Predicted results of starting dataset: {base_csv_name}.csv' + print_preds = " o Saving CSV databases with predictions and their SD in:" + print_preds += ( + f"\n - Predicted results of starting dataset: {base_csv_name}.csv" + ) - if self.args.csv_test != '': + if self.args.csv_test != "": # saves prediction for external test in --csv_test - Xy_external = pd.DataFrame(Xy_data['names_external']) + Xy_external = pd.DataFrame(Xy_data["names_external"]) for col in Xy_data["X_external"]: Xy_external[col] = Xy_data["X_external"][col].tolist() @@ -241,30 +274,36 @@ def save_predictions(self,Xy_data,model_data,suffix_title): Xy_external ) if auto_src is not None: - Xy_external[f"{model_data['y']}_pred_uq_auto_source"] = [ - auto_src - ] * len(Xy_external) + Xy_external[f"{model_data['y']}_pred_uq_auto_source"] = [auto_src] * len( + Xy_external + ) - path_external = Path(os.getcwd()).joinpath('PREDICT/csv_test/') + path_external = Path(os.getcwd()).joinpath("PREDICT/csv_test/") Path(path_external).mkdir(exist_ok=True, parents=True) - csv_name_external = f'{os.path.basename(self.args.csv_test).split(".csv")[0]}_{model_data["model"]}_{suffix_title}.csv' + csv_name_external = f"{os.path.basename(self.args.csv_test).split('.csv')[0]}_{model_data['model']}_{suffix_title}.csv" name_external = f"{path_external}/{csv_name_external}" - _ = Xy_external.to_csv(name_external, index = None, header=True) - print_preds += f'\n - External set with predicted results: PREDICT/csv_test/{csv_name_external}' + _ = Xy_external.to_csv(name_external, index=None, header=True) + print_preds += f"\n - External set with predicted results: PREDICT/csv_test/{csv_name_external}" self.args.log.write(print_preds) # store the names of the datapoints name_points = {} - if model_data['names'] != '': - if model_data['names'].lower() in Xy_train: # accounts for upper/lowercase mismatches - model_data['names'] = model_data['names'].lower() - if model_data['names'].upper() in Xy_train: - model_data['names'] = model_data['names'].upper() - if model_data['names'] in Xy_train: - name_points['train'] = df_results[model_data['names']][df_results.Set == 'CV'] - name_points['test'] = df_results[model_data['names']][df_results.Set == 'Test'] + if model_data["names"] != "": + if ( + model_data["names"].lower() in Xy_train + ): # accounts for upper/lowercase mismatches + model_data["names"] = model_data["names"].lower() + if model_data["names"].upper() in Xy_train: + model_data["names"] = model_data["names"].upper() + if model_data["names"] in Xy_train: + name_points["train"] = df_results[model_data["names"]][ + df_results.Set == "CV" + ] + name_points["test"] = df_results[model_data["names"]][ + df_results.Set == "Test" + ] return path_n_suffix, name_points, Xy_data @@ -280,20 +319,17 @@ def _ensure_pred_range_stats(Xy_data): Xy_data["pred_range"] = float(np.abs(pred_max - pred_min)) -def print_predict(self,Xy_data,model_data,suffix_title): - ''' +def print_predict(self, Xy_data, model_data, suffix_title): + """ Prints results of the predictions for all the sets - ''' + """ _ensure_pred_range_stats(Xy_data) - print_results = ( - "\n o Summary of results " - f"{model_data['model']}_{suffix_title}:" - ) + print_results = f"\n o Summary of results {model_data['model']}_{suffix_title}:" # get number of points and proportions - n_train = len(Xy_data['y_train']) - n_test = len(Xy_data['y_test']) + n_train = len(Xy_data["y_train"]) + n_test = len(Xy_data["y_test"]) print_results += ( "\n - Point counts: CV (train+valid.) = " f"{n_train}, held-out test = {n_test}" @@ -303,11 +339,10 @@ def print_predict(self,Xy_data,model_data,suffix_title): prop_train = round(n_train * 100 / total_points) prop_test = round(n_test * 100 / total_points) print_results += ( - f"\n - Proportion CV (train+valid.):test = " - f"{prop_train}:{prop_test}" + f"\n - Proportion CV (train+valid.):test = {prop_train}:{prop_test}" ) - n_descps = len(Xy_data['X_train'].keys()) + n_descps = len(Xy_data["X_train"].keys()) print_results += f"\n - Number of descriptors = {n_descps}" print_results += ( "\n - Proportion (train+valid.) points:descriptors = " @@ -316,55 +351,66 @@ def print_predict(self,Xy_data,model_data,suffix_title): # print results and save dat file CV_type = f"{model_data['repeat_kfolds']}x {model_data['kfold']}-fold CV" - if model_data['type'].lower() == 'reg': + if model_data["type"].lower() == "reg": print_results += f"\n - {CV_type} : R2 = {Xy_data['r2_train']:.2}, MAE = {Xy_data['mae_train']:.2}, RMSE = {Xy_data['rmse_train']:.2}" print_results += f"\n - Test : R2 = {Xy_data['r2_test']:.2}, MAE = {Xy_data['mae_test']:.2}, RMSE = {Xy_data['rmse_test']:.2}" print_results += f"\n - Average SD in test set = {np.mean(Xy_data['y_pred_test_sd']):.2}" print_results += f"\n - y range of dataset (train+valid.) = {float(Xy_data['pred_min']):.2} to {float(Xy_data['pred_max']):.2}, total {float(Xy_data['pred_range']):.2}" - if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0: + if ( + "y_external" in Xy_data + and not Xy_data["y_external"].isnull().values.any() + and len(Xy_data["y_external"]) > 0 + ): print_results += f"\n - External test : R2 = {Xy_data['r2_external']:.2}, MAE = {Xy_data['mae_external']:.2}, RMSE = {Xy_data['rmse_external']:.2}" - elif model_data['type'].lower() == 'clas': + elif model_data["type"].lower() == "clas": print_results += f"\n - {CV_type} : Accur. = {Xy_data['acc_train']:.2}, F1 score = {Xy_data['f1_train']:.2}, MCC = {Xy_data['mcc_train']:.2}" - if 'y_pred_test' in Xy_data and not Xy_data['y_test'].isnull().values.any() and len(Xy_data['y_test']) > 0: + if ( + "y_pred_test" in Xy_data + and not Xy_data["y_test"].isnull().values.any() + and len(Xy_data["y_test"]) > 0 + ): print_results += f"\n - Test : Accur. = {Xy_data['acc_test']:.2}, F1 score = {Xy_data['f1_test']:.2}, MCC = {Xy_data['mcc_test']:.2}" - if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0: + if ( + "y_external" in Xy_data + and not Xy_data["y_external"].isnull().values.any() + and len(Xy_data["y_external"]) > 0 + ): print_results += f"\n - External test : Accur. = {Xy_data['acc_external']:.2}, F1 score = {Xy_data['f1_external']:.2}, MCC = {Xy_data['mcc_external']:.2}" self.args.log.write(print_results) -def pearson_map_predict(self,Xy_data,params_dir): - ''' +def pearson_map_predict(self, Xy_data, params_dir): + """ Plots the Pearson map and analyzes correlation of descriptors. - ''' + """ - X_combined = pd.concat([Xy_data['X_train'], Xy_data['X_test']], axis=0, ignore_index=True) - corr_matrix = pearson_map(self,X_combined,'predict',params_dir=params_dir) + X_combined = pd.concat( + [Xy_data["X_train"], Xy_data["X_test"]], axis=0, ignore_index=True + ) + corr_matrix = pearson_map(self, X_combined, "predict", params_dir=params_dir) - corr_dict = {'descp_1': [], - 'descp_2': [], - 'r': [] - } - for i,descp in enumerate(corr_matrix.columns): - for j,val in enumerate(corr_matrix[descp]): + corr_dict = {"descp_1": [], "descp_2": [], "r": []} + for i, descp in enumerate(corr_matrix.columns): + for j, val in enumerate(corr_matrix[descp]): if i < j and np.abs(val) > 0.8: - corr_dict['descp_1'].append(corr_matrix.columns[i]) - corr_dict['descp_2'].append(corr_matrix.columns[j]) - corr_dict['r'].append(val) + corr_dict["descp_1"].append(corr_matrix.columns[i]) + corr_dict["descp_2"].append(corr_matrix.columns[j]) + corr_dict["r"].append(val) - print_corr = f' Ideally, variables should show low correlations.' # no initial \n, it's a new log.write - if len(corr_dict['descp_1']) == 0: - print_corr += f"\n o Correlations between variables are acceptable" + print_corr = " Ideally, variables should show low correlations." # no initial \n, it's a new log.write + if len(corr_dict["descp_1"]) == 0: + print_corr += "\n o Correlations between variables are acceptable" else: - abs_r_list = list(np.abs(corr_dict['r'])) + abs_r_list = list(np.abs(corr_dict["r"])) abs_max_r = max(abs_r_list) - max_r = corr_dict['r'][abs_r_list.index(abs_max_r)] - max_descp_1 = corr_dict['descp_1'][abs_r_list.index(abs_max_r)] - max_descp_2 = corr_dict['descp_2'][abs_r_list.index(abs_max_r)] + max_r = corr_dict["r"][abs_r_list.index(abs_max_r)] + max_descp_1 = corr_dict["descp_1"][abs_r_list.index(abs_max_r)] + max_descp_2 = corr_dict["descp_2"][abs_r_list.index(abs_max_r)] if abs_max_r > 0.84: - print_corr += f"\n x WARNING! High correlations observed (up to r = {max_r:.2} or R2 = {max_r*max_r:.2}, for {max_descp_1} and {max_descp_2})" + print_corr += f"\n x WARNING! High correlations observed (up to r = {max_r:.2} or R2 = {max_r * max_r:.2}, for {max_descp_1} and {max_descp_2})" elif abs_max_r > 0.71: - print_corr += f"\n x WARNING! Noticeable correlations observed (up to r = {max_r:.2} or R2 = {max_r*max_r:.2}, for {max_descp_1} and {max_descp_2})" + print_corr += f"\n x WARNING! Noticeable correlations observed (up to r = {max_r:.2} or R2 = {max_r * max_r:.2}, for {max_descp_1} and {max_descp_2})" self.args.log.write(print_corr) diff --git a/robert/report.py b/robert/report.py index a7431f0..e995653 100644 --- a/robert/report.py +++ b/robert/report.py @@ -5,7 +5,7 @@ destination : str, default=None, Directory to create the output file(s). varfile : str, default=None - Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). + Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml). report_modules : list of str, default=['AQME','CURATE','GENERATE','VERIFY','PREDICT'] List of the modules to include in the report. debug_report : bool, default=False @@ -23,9 +23,9 @@ import json import platform import pandas as pd -import traceback from pathlib import Path -from robert.utils import (load_variables, +from robert.utils import ( + load_variables, pd_to_dict, ) from robert.report_utils import ( @@ -52,7 +52,7 @@ get_outliers, detect_predictions, get_csv_metrics, - get_csv_pred + get_csv_pred, ) @@ -69,23 +69,26 @@ class report: def __init__(self, **kwargs): # check if there is a problem with weasyprint (required for this module) # Suppress fontconfig warnings during import on Windows - if platform.system() == 'Windows': + if platform.system() == "Windows": import tempfile - temp_stderr = tempfile.TemporaryFile(mode='w+') + + temp_stderr = tempfile.TemporaryFile(mode="w+") old_stderr = os.dup(2) os.dup2(temp_stderr.fileno(), 2) - + try: from weasyprint import HTML except (OSError, ModuleNotFoundError): - if platform.system() == 'Windows': + if platform.system() == "Windows": os.dup2(old_stderr, 2) os.close(old_stderr) temp_stderr.close() - print(f"\nx The REPORT module requires some libraries that are missing, the PDF with the summary of the results has not been created. Try installing the libraries with 'conda install -y -c conda-forge glib gtk3 pango mscorefonts'") + print( + "\nx The REPORT module requires some libraries that are missing, the PDF with the summary of the results has not been created. Try installing the libraries with 'conda install -y -c conda-forge glib gtk3 pango mscorefonts'" + ) sys.exit() finally: - if platform.system() == 'Windows': + if platform.system() == "Windows": os.dup2(old_stderr, 2) os.close(old_stderr) temp_stderr.close() @@ -95,46 +98,56 @@ def __init__(self, **kwargs): eval_only = False # if EVALUATE is activated, no PFI models are generated - path_eval = Path(f'{os.getcwd()}/EVALUATE/EVALUATE_data.dat') + path_eval = Path(f"{os.getcwd()}/EVALUATE/EVALUATE_data.dat") if os.path.exists(path_eval): eval_only = True # get spacing between No PFI and PFI columns - spacing_PFI = f'{(" ")*4}' + spacing_PFI = f"{(' ') * 4}" # Reproducibility section (these functions only gather information, the sections # will be print later in the report) - citation_dat, repro_dat, dat_files, csv_name, robert_version = self.get_repro(eval_only) + citation_dat, repro_dat, dat_files, csv_name, robert_version = self.get_repro( + eval_only + ) - # Transparency section - transpa_dat,params_df = self.get_transparency(spacing_PFI) - pred_type = params_df['type'][0].lower() + # Transparency section + transpa_dat, params_df = self.get_transparency(spacing_PFI) + pred_type = params_df["type"][0].lower() # print header report_html = self.print_header(citation_dat) # print ROBERT score section - score_dat,data_score = self.print_score(dat_files,pred_type,eval_only,spacing_PFI) + score_dat, data_score = self.print_score( + dat_files, pred_type, eval_only, spacing_PFI + ) report_html += score_dat # print warnings in ROBERT score section - warnings_dat,warnings_dict = self.print_warnings(pred_type,eval_only,data_score) + warnings_dat, warnings_dict = self.print_warnings( + pred_type, eval_only, data_score + ) report_html += warnings_dat # print advanced score analysis - report_html += self.print_adv_anal(pred_type,eval_only,spacing_PFI,data_score) + report_html += self.print_adv_anal( + pred_type, eval_only, spacing_PFI, data_score + ) # print y distribution - report_html += self.print_y_distrib(pred_type,eval_only,spacing_PFI,warnings_dict) + report_html += self.print_y_distrib( + pred_type, eval_only, spacing_PFI, warnings_dict + ) # print feature importances - report_html += self.print_features(warnings_dict,eval_only,spacing_PFI) + report_html += self.print_features(warnings_dict, eval_only, spacing_PFI) # print outlier analysis - report_html += self.print_outliers(pred_type,eval_only,spacing_PFI) + report_html += self.print_outliers(pred_type, eval_only, spacing_PFI) # print model screening - report_html += self.print_generate(pred_type,eval_only) + report_html += self.print_generate(pred_type, eval_only) # print reproducibility section report_html += repro_dat @@ -146,7 +159,7 @@ def __init__(self, **kwargs): report_html += self.get_abbrev() # print new predictions - report_html += self.print_predictions(pred_type,eval_only,spacing_PFI) + report_html += self.print_predictions(pred_type, eval_only, spacing_PFI) # print miscellaneous section report_html += self.print_misc() @@ -157,34 +170,35 @@ def __init__(self, **kwargs): # create css with open("report.css", "w", encoding="utf-8") as cssfile: - cssfile.write(css_content(csv_name,robert_version)) + cssfile.write(css_content(csv_name, robert_version)) # Suppress fontconfig warnings from WeasyPrint on Windows # These warnings come from the C library level, so we need to redirect at OS level - if platform.system() == 'Windows': + if platform.system() == "Windows": import tempfile - + # Create a temporary file to redirect stderr - temp_stderr = tempfile.TemporaryFile(mode='w+') + temp_stderr = tempfile.TemporaryFile(mode="w+") old_stderr = os.dup(2) # Duplicate stderr file descriptor os.dup2(temp_stderr.fileno(), 2) # Redirect stderr to temp file - + try: - _ = make_report(report_html,HTML) + _ = make_report(report_html, HTML) finally: os.dup2(old_stderr, 2) # Restore stderr os.close(old_stderr) temp_stderr.close() else: - _ = make_report(report_html,HTML) + _ = make_report(report_html, HTML) # Remove report.css file os.remove("report.css") - - print('\no ROBERT_report.pdf was created successfully in the working directory!') + print( + "\no ROBERT_report.pdf was created successfully in the working directory!" + ) - def print_header(self,citation_dat): + def print_header(self, citation_dat): """ Retrieves the header for the HTML string """ @@ -200,154 +214,172 @@ def print_header(self,citation_dat): return header_lines - - def print_score(self,dat_files,pred_type,eval_only,spacing_PFI): + def print_score(self, dat_files, pred_type, eval_only, spacing_PFI): """ Generates the ROBERT score section """ - + # starts with the icon of ROBERT score - score_dat = '' - score_dat = self.module_lines('score',score_dat) + score_dat = "" + score_dat = self.module_lines("score", score_dat) # calculates the ROBERT scores (R2 is analogous for accuracy in classification) data_score = {} - columns_score,columns_summary = [],[] + columns_score, columns_summary = [], [] # get two columns to combine and print - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_score.append('') + if eval_only and suffix == "PFI": + columns_score.append("") else: # calculate score - data_score = calc_score(dat_files,suffix,pred_type,data_score) + data_score = calc_score(dat_files, suffix, pred_type, data_score) # initial two-column ROBERT score summary - score_info = f"""{spacing}

""" - columns_score.append(get_col_score(score_info,data_score,suffix,spacing,eval_only)) + score_info = f"""{spacing}

""" + columns_score.append( + get_col_score(score_info, data_score, suffix, spacing, eval_only) + ) # Combine both columns score_dat += combine_cols(columns_score) - + # add corresponding images - diff_height = 25 # account for different graph sizes in reg and clas + diff_height = 25 # account for different graph sizes in reg and clas height = 221 - if pred_type == 'clas': + if pred_type == "clas": height += diff_height - score_dat += self.print_img('Results',-5,height,'PREDICT',pred_type,eval_only,diff_names=True) + score_dat += self.print_img( + "Results", -5, height, "PREDICT", pred_type, eval_only, diff_names=True + ) - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_summary.append('') + if eval_only and suffix == "PFI": + columns_summary.append("") else: # metrics of the models - module_file = f'{os.getcwd()}/PREDICT/PREDICT_data.dat' - columns_summary.append(get_metrics(module_file,suffix,spacing)) + module_file = f"{os.getcwd()}/PREDICT/PREDICT_data.dat" + columns_summary.append(get_metrics(module_file, suffix, spacing)) # Combine both columns score_dat += combine_cols(columns_summary) - return score_dat,data_score - + return score_dat, data_score - def print_warnings(self,pred_type,eval_only,data_score): + def print_warnings(self, pred_type, eval_only, data_score): """ Generates the warning boxes in the ROBERT score section """ # load spacing, colors, and line and table formats - space,color_dict,style_lines,warnings_dat = self.get_warning_params() + space, color_dict, style_lines, warnings_dat = self.get_warning_params() # gather the lines from PREDICT where the potential warnings are print warnings_dict = self.get_warning_lines(pred_type) columns_warnings = [] # get two columns to combine and print - warnings_dict['severe_warnings_No PFI'], warnings_dict['severe_warnings_PFI'] = [],[] - warnings_dict['moderate_warnings_No PFI'], warnings_dict['moderate_warnings_PFI'] = [],[] - for suffix in ['No PFI','PFI']: - - if eval_only and suffix == 'PFI': - columns_warnings.append('') + ( + warnings_dict["severe_warnings_No PFI"], + warnings_dict["severe_warnings_PFI"], + ) = [], [] + ( + warnings_dict["moderate_warnings_No PFI"], + warnings_dict["moderate_warnings_PFI"], + ) = [], [] + for suffix in ["No PFI", "PFI"]: + if eval_only and suffix == "PFI": + columns_warnings.append("") else: - if suffix == 'No PFI': + if suffix == "No PFI": margin_left = 0 else: margin_left = 29 # analyze and append warnings - warnings_dict = self.analyze_warnings(data_score,suffix,warnings_dict,pred_type) + warnings_dict = self.analyze_warnings( + data_score, suffix, warnings_dict, pred_type + ) # add table in the corresponding column - warning_print = f''' + warning_print = f""" - -
''' + """ # add severe warnings - warning_print += f''' -

{space}Severe warnings

''' - if len(warnings_dict[f'severe_warnings_{suffix}']) == 0: + warning_print += f""" +

{space}Severe warnings

""" + if len(warnings_dict[f"severe_warnings_{suffix}"]) == 0: warning_print += self.print_line_warning( - 'No severe warnings detected', - style_lines,color_dict['blue'],space) + "No severe warnings detected", + style_lines, + color_dict["blue"], + space, + ) else: - for sev_warning in warnings_dict[f'severe_warnings_{suffix}']: + for sev_warning in warnings_dict[f"severe_warnings_{suffix}"]: warning_print += self.print_line_warning( - sev_warning, - style_lines,color_dict['red'],space) + sev_warning, style_lines, color_dict["red"], space + ) # add moderate warnings - warning_print += f''' -

{space}Moderate warnings

''' - if len(warnings_dict[f'moderate_warnings_{suffix}']) == 0: + warning_print += f""" +

{space}Moderate warnings

""" + if len(warnings_dict[f"moderate_warnings_{suffix}"]) == 0: warning_print += self.print_line_warning( - 'No moderate warnings detected', - style_lines,color_dict['blue'],space) + "No moderate warnings detected", + style_lines, + color_dict["blue"], + space, + ) else: - for mode_warning in warnings_dict[f'moderate_warnings_{suffix}']: + for mode_warning in warnings_dict[f"moderate_warnings_{suffix}"]: warning_print += self.print_line_warning( - mode_warning, - style_lines,color_dict['yellow'],space) + mode_warning, style_lines, color_dict["yellow"], space + ) # add overall assessment - warning_print += self.print_assessment(space,suffix,data_score,style_lines,warnings_dict,color_dict,pred_type) - + warning_print += self.print_assessment( + space, + suffix, + data_score, + style_lines, + warnings_dict, + color_dict, + pred_type, + ) + # end table - warning_print += f'''
''' - + warning_print += """ + """ + columns_warnings.append(warning_print) # Combine both columns warnings_dat += combine_cols(columns_warnings) # page break - warnings_dat += f"""

""" - - return warnings_dat,warnings_dict + warnings_dat += """

""" + return warnings_dat, warnings_dict def get_warning_params(self): - ''' + """ Load spacing, colors, and line and table formats - ''' - - space = ' ' - color_dict = { - 'red': '#c56666', - 'yellow': '#c5c57d', - 'blue': '#9ba5e3' - } + """ + + space = " " + color_dict = {"red": "#c56666", "yellow": "#c5c57d", "blue": "#9ba5e3"} style_lines = '

' # table style - warnings_dat = ''' - ''' + """ - return space,color_dict,style_lines,warnings_dat + return space, color_dict, style_lines, warnings_dat - def analyze_warnings(self,data_score,suffix,warnings_dict,pred_type): - ''' + def analyze_warnings(self, data_score, suffix, warnings_dict, pred_type): + """ Analyze and append warnings - ''' - + """ + # tests from flawed models - if data_score[f'flawed_mod_score_{suffix}'] < 0: - if data_score[f'failed_tests_{suffix}'] > 0: - warnings_dict[f'severe_warnings_{suffix}'].append('Failing required tests (Section B.1)') + if data_score[f"flawed_mod_score_{suffix}"] < 0: + if data_score[f"failed_tests_{suffix}"] > 0: + warnings_dict[f"severe_warnings_{suffix}"].append( + "Failing required tests (Section B.1)" + ) else: - warnings_dict[f'moderate_warnings_{suffix}'].append('Some tests are unclear (Section B.1)') + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Some tests are unclear (Section B.1)" + ) # variation in CV - if pred_type == 'reg': - if data_score[f'cv_sd_score_{suffix}'] == 0: - warnings_dict[f'moderate_warnings_{suffix}'].append('Imprecise predictions (Section B.3b)') - elif pred_type == 'clas': - if data_score[f'diff_mcc_score_{suffix}'] == 0: - warnings_dict[f'moderate_warnings_{suffix}'].append('Imprecise predictions (Section B.3b)') + if pred_type == "reg": + if data_score[f"cv_sd_score_{suffix}"] == 0: + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Imprecise predictions (Section B.3b)" + ) + elif pred_type == "clas": + if data_score[f"diff_mcc_score_{suffix}"] == 0: + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Imprecise predictions (Section B.3b)" + ) # y distribution - if 'WARNING! Your data is not uniform' in warnings_dict[f'y_dist_info_{suffix}']: - if pred_type == 'reg': - warnings_dict[f'moderate_warnings_{suffix}'].append('Uneven y distribution (Section C)') - elif pred_type == 'clas': # it's severe in clasification - warnings_dict[f'severe_warnings_{suffix}'].append('Very uneven class distribution (Section C)') - elif 'WARNING! Your data is slightly not uniform' in warnings_dict[f'y_dist_info_{suffix}']: - if pred_type == 'reg': - warnings_dict[f'moderate_warnings_{suffix}'].append('Slightly uneven y distribution (Section C)') - elif pred_type == 'clas': - warnings_dict[f'moderate_warnings_{suffix}'].append('Uneven class distribution (Section C)') + if ( + "WARNING! Your data is not uniform" + in warnings_dict[f"y_dist_info_{suffix}"] + ): + if pred_type == "reg": + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Uneven y distribution (Section C)" + ) + elif pred_type == "clas": # it's severe in clasification + warnings_dict[f"severe_warnings_{suffix}"].append( + "Very uneven class distribution (Section C)" + ) + elif ( + "WARNING! Your data is slightly not uniform" + in warnings_dict[f"y_dist_info_{suffix}"] + ): + if pred_type == "reg": + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Slightly uneven y distribution (Section C)" + ) + elif pred_type == "clas": + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Uneven class distribution (Section C)" + ) # feature correlation - if 'WARNING! High correlations' in warnings_dict[f'pearson_info_{suffix}']: - warnings_dict[f'moderate_warnings_{suffix}'].append('Highly correlated features (Section D)') - elif 'WARNING! Noticeable correlations' in warnings_dict[f'pearson_info_{suffix}']: - warnings_dict[f'moderate_warnings_{suffix}'].append('Moderately correlated features (Section D)') + if "WARNING! High correlations" in warnings_dict[f"pearson_info_{suffix}"]: + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Highly correlated features (Section D)" + ) + elif ( + "WARNING! Noticeable correlations" + in warnings_dict[f"pearson_info_{suffix}"] + ): + warnings_dict[f"moderate_warnings_{suffix}"].append( + "Moderately correlated features (Section D)" + ) # outliers (threshold is set above 6.5 SD, around 99.9 CI) - if pred_type == 'reg': - if warnings_dict[f'max_sd_{suffix}'] > 6.5: - warnings_dict[f'moderate_warnings_{suffix}'].append('Potential "faulty" outliers (Section E)') + if pred_type == "reg": + if warnings_dict[f"max_sd_{suffix}"] > 6.5: + warnings_dict[f"moderate_warnings_{suffix}"].append( + 'Potential "faulty" outliers (Section E)' + ) return warnings_dict - - def get_warning_lines(self,pred_type): - ''' + def get_warning_lines(self, pred_type): + """ Gather the lines from PREDICT where the potential warnings are print - ''' - + """ + warnings_dict = {} # get lines with warnings from PREDICT - file_pred = f'{os.getcwd()}/PREDICT/PREDICT_data.dat' - with open(file_pred, 'r', encoding='utf-8') as datfile: + file_pred = f"{os.getcwd()}/PREDICT/PREDICT_data.dat" + with open(file_pred, "r", encoding="utf-8") as datfile: lines = datfile.readlines() - pfi_section_pearson = False # to get both No PFI and PFI information + pfi_section_pearson = False # to get both No PFI and PFI information pfi_section_y_dist = False pfi_section_outlier = False - for i,line in enumerate(lines): - if 'Ideally, variables should show low' in line and not pfi_section_pearson: - warnings_dict['pearson_info_No PFI'] = lines[i+1][6:] - pfi_section_pearson = True # the next line found will correspond to the PFI section - elif 'Ideally, variables should show low' in line and pfi_section_pearson: - warnings_dict['pearson_info_PFI'] = lines[i+1][6:] - if 'Ideally, the number of datapoints in' in line and not pfi_section_y_dist: - warnings_dict['y_dist_info_No PFI'] = lines[i+2][6:] + for i, line in enumerate(lines): + if ( + "Ideally, variables should show low" in line + and not pfi_section_pearson + ): + warnings_dict["pearson_info_No PFI"] = lines[i + 1][6:] + pfi_section_pearson = ( + True # the next line found will correspond to the PFI section + ) + elif ( + "Ideally, variables should show low" in line and pfi_section_pearson + ): + warnings_dict["pearson_info_PFI"] = lines[i + 1][6:] + if ( + "Ideally, the number of datapoints in" in line + and not pfi_section_y_dist + ): + warnings_dict["y_dist_info_No PFI"] = lines[i + 2][6:] pfi_section_y_dist = True - elif 'Ideally, the number of datapoints in' in line and pfi_section_y_dist: - warnings_dict['y_dist_info_PFI'] = lines[i+2][6:] - if pred_type == 'reg': - if 'Outliers plot saved' in line and not pfi_section_outlier: + elif ( + "Ideally, the number of datapoints in" in line + and pfi_section_y_dist + ): + warnings_dict["y_dist_info_PFI"] = lines[i + 2][6:] + if pred_type == "reg": + if "Outliers plot saved" in line and not pfi_section_outlier: max_SD = 0 - for j in range(i,len(lines)): - if '-------' in lines[j]: + for j in range(i, len(lines)): + if "-------" in lines[j]: break - elif 'SDs' in lines[j]: + elif "SDs" in lines[j]: sd_line = float(lines[j].split()[2][1:]) if sd_line > max_SD: max_SD = sd_line - warnings_dict['max_sd_No PFI'] = max_SD + warnings_dict["max_sd_No PFI"] = max_SD pfi_section_outlier = True - elif 'Outliers plot saved' in line and pfi_section_outlier: + elif "Outliers plot saved" in line and pfi_section_outlier: max_SD = 0 - for j in range(i,len(lines)): - if '-------' in lines[j]: + for j in range(i, len(lines)): + if "-------" in lines[j]: break - elif 'SDs' in lines[j]: + elif "SDs" in lines[j]: sd_line = float(lines[j].split()[2][1:]) if sd_line > max_SD: max_SD = sd_line - warnings_dict['max_sd_PFI'] = max_SD + warnings_dict["max_sd_PFI"] = max_SD return warnings_dict - - def print_line_warning(self,message,style_lines,color,space): - ''' + def print_line_warning(self, message, style_lines, color, space): + """ Add line with warning - ''' - - return f''' - {style_lines}{space}◉ - {space}{message}

''' - + """ - def print_assessment(self,space,suffix,data_score,style_lines,warnings_dict,color_dict,pred_type): - ''' + return f""" + {style_lines}{space}◉ + {space}{message}

""" + + def print_assessment( + self, + space, + suffix, + data_score, + style_lines, + warnings_dict, + color_dict, + pred_type, + ): + """ Add overall assessment to the ROBERT score section - ''' - - assessment_print = f''' -

{space}Overall assessment

''' + """ - if len(warnings_dict[f'severe_warnings_{suffix}']) > 0 or data_score[f'robert_score_{suffix}'] < 5: - assessment_print += self.print_line_warning( - 'The model is unreliable', - style_lines,color_dict['red'],space) + assessment_print = f""" +

{space}Overall assessment

""" - elif data_score[f'robert_score_{suffix}'] in [9,10]: - if pred_type == 'reg' and len(warnings_dict[f'moderate_warnings_{suffix}']) >= 3: + if ( + len(warnings_dict[f"severe_warnings_{suffix}"]) > 0 + or data_score[f"robert_score_{suffix}"] < 5 + ): + assessment_print += self.print_line_warning( + "The model is unreliable", style_lines, color_dict["red"], space + ) + + elif data_score[f"robert_score_{suffix}"] in [9, 10]: + if ( + pred_type == "reg" + and len(warnings_dict[f"moderate_warnings_{suffix}"]) >= 3 + ): assessment_print += self.print_line_warning( - 'Reliable model, but examine warnings', - style_lines,color_dict['yellow'],space) - elif pred_type == 'clas' and len(warnings_dict[f'moderate_warnings_{suffix}']) >= 2: + "Reliable model, but examine warnings", + style_lines, + color_dict["yellow"], + space, + ) + elif ( + pred_type == "clas" + and len(warnings_dict[f"moderate_warnings_{suffix}"]) >= 2 + ): assessment_print += self.print_line_warning( - 'Reliable model, but examine warnings', - style_lines,color_dict['yellow'],space) + "Reliable model, but examine warnings", + style_lines, + color_dict["yellow"], + space, + ) else: assessment_print += self.print_line_warning( - f'The model seems reliable', - style_lines,color_dict['blue'],space) + "The model seems reliable", style_lines, color_dict["blue"], space + ) - elif data_score[f'robert_score_{suffix}'] in [7,8]: + elif data_score[f"robert_score_{suffix}"] in [7, 8]: assessment_print += self.print_line_warning( - 'Decent model, but it has limitations', - style_lines,color_dict['yellow'],space) + "Decent model, but it has limitations", + style_lines, + color_dict["yellow"], + space, + ) - elif data_score[f'robert_score_{suffix}'] in [5,6]: + elif data_score[f"robert_score_{suffix}"] in [5, 6]: assessment_print += self.print_line_warning( - 'Moderate model, with important limitations', - style_lines,color_dict['yellow'],space) - - return assessment_print + "Moderate model, with important limitations", + style_lines, + color_dict["yellow"], + space, + ) + return assessment_print - def print_adv_anal(self,pred_type,eval_only,spacing_PFI,data_score): + def print_adv_anal(self, pred_type, eval_only, spacing_PFI, data_score): """ Generates the advanced score analysis section """ - adv_score_dat = '' + adv_score_dat = "" - adv_score_dat += self.module_lines('adv_anal',adv_score_dat) + adv_score_dat += self.module_lines("adv_anal", adv_score_dat) # parts of the robert score section - score_sections = ['adv_flawed'] - score_sections.append('adv_flawed_extra') - score_sections.append('adv_predict') - score_sections.append('adv_test') - score_sections.append('adv_diff_test') - score_sections.append('adv_cv_sd') - score_sections.append('adv_cv_diff') - score_sections.append('adv_sorted_cv') + score_sections = ["adv_flawed"] + score_sections.append("adv_flawed_extra") + score_sections.append("adv_predict") + score_sections.append("adv_test") + score_sections.append("adv_diff_test") + score_sections.append("adv_cv_sd") + score_sections.append("adv_cv_diff") + score_sections.append("adv_sorted_cv") for section in score_sections: columns_score = [] # get two columns to combine and print - for suffix in ['No PFI','PFI']: - + for suffix in ["No PFI", "PFI"]: # add spacing of PFI column - if suffix == 'No PFI': - spacing = '' - elif suffix == 'PFI': + if suffix == "No PFI": + spacing = "" + elif suffix == "PFI": spacing = spacing_PFI - if eval_only and suffix == 'PFI': - columns_score.append('') + if eval_only and suffix == "PFI": + columns_score.append("") else: - - if section == 'adv_flawed': + if section == "adv_flawed": # advanced score analysis 1, flawed models - columns_score.append(adv_flawed(self,suffix,data_score,spacing*2)) + columns_score.append( + adv_flawed(self, suffix, data_score, spacing * 2) + ) - elif section == 'adv_predict': + elif section == "adv_predict": # advanced score analysis 2, predictive ability - columns_score.append(adv_predict(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_predict( + self, suffix, data_score, spacing * 2, pred_type + ) + ) - elif section == 'adv_test': + elif section == "adv_test": # advanced score analysis 3 and 3a, predictive ability of CV - columns_score.append(adv_test(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_test(self, suffix, data_score, spacing * 2, pred_type) + ) - elif section == 'adv_cv_sd' and pred_type == 'reg': + elif section == "adv_cv_sd" and pred_type == "reg": # advanced score analysis 3b, SD of CV - columns_score.append(adv_cv_sd(self,suffix,data_score,spacing*2)) + columns_score.append( + adv_cv_sd(self, suffix, data_score, spacing * 2) + ) - elif section == 'adv_diff_test': + elif section == "adv_diff_test": # advanced score analysis 3c, difference bwteen RMSE in test vs CV - columns_score.append(adv_diff_test(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_diff_test( + self, suffix, data_score, spacing * 2, pred_type + ) + ) - elif section == 'adv_sorted_cv': + elif section == "adv_sorted_cv": # advanced score analysis 3d, descriptor proportion - columns_score.append(adv_sorted_cv(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_sorted_cv( + self, suffix, data_score, spacing * 2, pred_type + ) + ) - elif section == 'adv_cv_diff' and pred_type == 'clas': + elif section == "adv_cv_diff" and pred_type == "clas": # advanced score analysis 3b, difference of MCC in model and CV - columns_score.append(adv_cv_diff(self,suffix,data_score,spacing*2,pred_type)) + columns_score.append( + adv_cv_diff( + self, suffix, data_score, spacing * 2, pred_type + ) + ) # Combine both columns adv_score_dat += combine_cols(columns_score) # add corresponding images - section_separator = f'
' + section_separator = '
' - if section == 'adv_flawed': + if section == "adv_flawed": height = 223 - if pred_type == 'clas': + if pred_type == "clas": height -= 15 - adv_score_dat += self.print_img('VERIFY_tests',13,height,'VERIFY',pred_type,eval_only) + adv_score_dat += self.print_img( + "VERIFY_tests", 13, height, "VERIFY", pred_type, eval_only + ) # page break to second page adv_score_dat += '
' - elif section == 'adv_predict': + elif section == "adv_predict": adv_score_dat += section_separator - elif section == 'adv_cv_sd' and pred_type == 'reg': - adv_score_dat += self.print_img('CV_variability',10,221,'PREDICT',pred_type,eval_only) + elif section == "adv_cv_sd" and pred_type == "reg": + adv_score_dat += self.print_img( + "CV_variability", 10, 221, "PREDICT", pred_type, eval_only + ) - elif section == 'adv_cv_diff' and pred_type == 'clas': + elif section == "adv_cv_diff" and pred_type == "clas": adv_score_dat += section_separator - elif section == 'adv_sorted_cv': + elif section == "adv_sorted_cv": adv_score_dat += '

' return adv_score_dat - def print_misc(self): """ Generates the miscellaneous section """ - misc_dat = '' - misc_dat += self.module_lines('misc',misc_dat) + misc_dat = "" + misc_dat += self.module_lines("misc", misc_dat) # get some tips - style_line = '

' # reduces line separation separation - misc_dat += f"""

Some general tips to improve the score

""" - misc_dat += f'

1. Adding meaningful datapoints might help to improve the model. Also, using a uniform population of datapoints across the whole range of y values usually helps to obtain reliable predictions across the whole range. More information about the range of y values used is available in Section C.

' - misc_dat += f'{style_line}2. Adding meaningful descriptors or replacing/deleting the least useful descriptors used might help. Feature importances are gathered in Section D.

' - + style_line = '

' # reduces line separation separation + misc_dat += """

Some general tips to improve the score

""" + misc_dat += '

1. Adding meaningful datapoints might help to improve the model. Also, using a uniform population of datapoints across the whole range of y values usually helps to obtain reliable predictions across the whole range. More information about the range of y values used is available in Section C.

' + misc_dat += f"{style_line}2. Adding meaningful descriptors or replacing/deleting the least useful descriptors used might help. Feature importances are gathered in Section D.

" # how to predict new values misc_dat += f""" @@ -624,70 +748,77 @@ def print_misc(self): misc_dat += '
' return misc_dat - - def print_outliers(self,pred_type,eval_only,spacing_PFI): + def print_outliers(self, pred_type, eval_only, spacing_PFI): """ Generates the outliers section """ - + # starts with the icon of outliers - outlier_dat = '' - outlier_dat = self.module_lines('outliers',outlier_dat,pred_type=pred_type) + outlier_dat = "" + outlier_dat = self.module_lines("outliers", outlier_dat, pred_type=pred_type) - if pred_type == 'reg': + if pred_type == "reg": columns_outlier = [] # get two columns to combine and print - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_outlier.append('') + if eval_only and suffix == "PFI": + columns_outlier.append("") else: # get information about outliers - module_file = f'{os.getcwd()}/PREDICT/PREDICT_data.dat' - columns_outlier.append(get_outliers(module_file,suffix,spacing)) + module_file = f"{os.getcwd()}/PREDICT/PREDICT_data.dat" + columns_outlier.append(get_outliers(module_file, suffix, spacing)) # Combine both columns outlier_dat += combine_cols(columns_outlier) - + # add corresponding images height = 217 - outlier_dat += self.print_img('Outliers',-5,height,'PREDICT',pred_type,eval_only) + outlier_dat += self.print_img( + "Outliers", -5, height, "PREDICT", pred_type, eval_only + ) # add separator line and page break outlier_dat += '
' - outlier_dat += f"""

""" + outlier_dat += """

""" return outlier_dat - - def print_y_distrib(self,pred_type,eval_only,spacing_PFI,warnings_dict): + def print_y_distrib(self, pred_type, eval_only, spacing_PFI, warnings_dict): """ Generates the y distribution section """ - + # starts with the icon of outliers - distrib_dat = '' - distrib_dat = self.module_lines('y_distrib',distrib_dat) - + distrib_dat = "" + distrib_dat = self.module_lines("y_distrib", distrib_dat) + # add corresponding images height = 220 - distrib_dat += self.print_img('y_distribution',-5,height,'PREDICT',pred_type,eval_only) + distrib_dat += self.print_img( + "y_distribution", -5, height, "PREDICT", pred_type, eval_only + ) columns_y_distrib = [] # get two columns to combine and print - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_y_distrib.append('') + if eval_only and suffix == "PFI": + columns_y_distrib.append("") else: # split the sentence into 1 column size and add spacing line by line - y_distrib_sentence = format_lines(warnings_dict[f'y_dist_info_{suffix}'],max_width=55,one_column=True,spacing=spacing) + y_distrib_sentence = format_lines( + warnings_dict[f"y_dist_info_{suffix}"], + max_width=55, + one_column=True, + spacing=spacing, + ) column = f""" -

{spacing*3}y distribution analysis

+

{spacing * 3}y distribution analysis

{y_distrib_sentence}

""" @@ -700,25 +831,30 @@ def print_y_distrib(self,pred_type,eval_only,spacing_PFI,warnings_dict): # add separator line and page break distrib_dat += '
' - distrib_dat += f"""

""" + distrib_dat += """

""" return distrib_dat - - def print_features(self,warnings_dict,eval_only,spacing_PFI): + def print_features(self, warnings_dict, eval_only, spacing_PFI): """ Generates the feature analysis section """ - + # starts with the icon of feature importances - feature_dat = '' - feature_dat = self.module_lines('features',feature_dat) - + feature_dat = "" + feature_dat = self.module_lines("features", feature_dat) + # Add linear model equations - spacing = get_spacing_col('PFI',spacing_PFI) - with open(f'{os.getcwd()}/PREDICT/PREDICT_data.dat', 'r', encoding='utf-8') as file: + spacing = get_spacing_col("PFI", spacing_PFI) + with open( + f"{os.getcwd()}/PREDICT/PREDICT_data.dat", "r", encoding="utf-8" + ) as file: lines = file.readlines() - linear_model_eqs = [lines[i + 1].strip().lstrip('- ') for i, line in enumerate(lines) if 'o Linear model equation' in line] + linear_model_eqs = [ + lines[i + 1].strip().lstrip("- ") + for i, line in enumerate(lines) + if "o Linear model equation" in line + ] if linear_model_eqs: feature_dat += "

Linear model equation_No_PFI

" @@ -728,17 +864,21 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): columns_eq = [] for i, eq in enumerate(linear_model_eqs): if i == 0: - columns_eq.append(f"

{eq}

") + columns_eq.append( + f"

{eq}

" + ) else: - columns_eq.append(f"

{spacing*3}Linear model equation_PFI

{spacing*3}{eq}

") + columns_eq.append( + f"

{spacing * 3}Linear model equation_PFI

{spacing * 3}{eq}

" + ) feature_dat += combine_cols(columns_eq) # add corresponding images - module_path = Path(f'{os.getcwd()}/PREDICT') - - shap_images = glob.glob(f'{module_path}/SHAP_*.png') - pfi_images = glob.glob(f'{module_path}/PFI_*.png') - pearson_images = glob.glob(f'{module_path}/Pearson_*.png') + module_path = Path(f"{os.getcwd()}/PREDICT") + + shap_images = glob.glob(f"{module_path}/SHAP_*.png") + pfi_images = glob.glob(f"{module_path}/PFI_*.png") + pearson_images = glob.glob(f"{module_path}/Pearson_*.png") shap_images = revert_list(shap_images) pfi_images = revert_list(pfi_images) @@ -746,16 +886,18 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): image_pair_list = [shap_images, pfi_images, pearson_images] - margin_top, margin_bottom = -10,30 - for _,image_pair in enumerate(image_pair_list): - if len(image_pair) < 2 and not eval_only: # Pearson graphs aren't created when >30 descriptors + margin_top, margin_bottom = -10, 30 + for _, image_pair in enumerate(image_pair_list): + if ( + len(image_pair) < 2 and not eval_only + ): # Pearson graphs aren't created when >30 descriptors pair_list = f'

Pearson maps not created if >30 descriptors.' - pair_list += f'{(" ")*15}' + pair_list += f"{(' ') * 15}" if len(image_pair) == 1: pair_list += f'

' elif len(image_pair) == 0: - pair_list += f'{(" ")*15}' - pair_list += f'Pearson maps not created if >30 descriptors.

' + pair_list += f"{(' ') * 15}" + pair_list += "Pearson maps not created if >30 descriptors.

" elif eval_only: if len(image_pair) == 1: pair_list = f'

' @@ -763,23 +905,28 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): pair_list = f'

Pearson maps not created if >30 descriptors.

' else: pair_list = f'

' - pair_list += f'{(" ")*22}' + pair_list += f"{(' ') * 22}" pair_list += f'

' feature_dat += pair_list columns_pearson = [] # get two columns to combine and print - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_pearson.append('') + if eval_only and suffix == "PFI": + columns_pearson.append("") else: # split the sentence into 1 column size and add spacing line by line - pearson_sentence = format_lines(warnings_dict[f'pearson_info_{suffix}'],max_width=55,one_column=True,spacing=spacing) + pearson_sentence = format_lines( + warnings_dict[f"pearson_info_{suffix}"], + max_width=55, + one_column=True, + spacing=spacing, + ) column = f""" -

{spacing*3}Correlation analysis

+

{spacing * 3}Correlation analysis

{pearson_sentence} """ columns_pearson.append(column) @@ -789,116 +936,144 @@ def print_features(self,warnings_dict,eval_only,spacing_PFI): # add separator line and page break feature_dat += '
' - feature_dat += f"""

""" + feature_dat += """

""" return feature_dat - - def print_generate(self,pred_type,eval_only): + def print_generate(self, pred_type, eval_only): """ Generates the GENERATE hyperoptimization section """ - + # starts with the icon of feature importances - generate_dat = '' - generate_dat = self.module_lines('generate',generate_dat,eval_only=eval_only) - + generate_dat = "" + generate_dat = self.module_lines("generate", generate_dat, eval_only=eval_only) + # add corresponding images if not eval_only: height = 236 - generate_dat += self.print_img('Heatmap',-5,height,'GENERATE',pred_type,eval_only) + generate_dat += self.print_img( + "Heatmap", -5, height, "GENERATE", pred_type, eval_only + ) generate_dat += '

' return generate_dat - - def get_repro(self,eval_only): + def get_repro(self, eval_only): """ Generates the reproducibility section """ - - version_n_date, citation, command_line, python_version, total_time, dat_files = repro_info(self.args.report_modules) + + ( + version_n_date, + citation, + command_line, + python_version, + total_time, + dat_files, + ) = repro_info(self.args.report_modules) robert_version = version_n_date.split()[2] - if self.args.csv_name == '' or self.args.csv_test == '': - self = get_csv_names(self,command_line) + if self.args.csv_name == "" or self.args.csv_test == "": + self = get_csv_names(self, command_line) + + repro_dat, citation_dat = "", "" - repro_dat,citation_dat = '','' - # version, date and citation citation_dat += f"""


{version_n_date}

How to cite: {citation}

""" - aqme_workflow,aqme_updated = False,True + aqme_workflow, aqme_updated = False, True crest_workflow = False - if '--aqme' in command_line: + if "--aqme" in command_line: original_command = command_line aqme_workflow = True - command_line = command_line.replace('AQME-ROBERT_','') - self.args.csv_name = f'{self.args.csv_name}'.replace('AQME-ROBERT_','') - if self.args.csv_test != '': - self.args.csv_test = f'{self.args.csv_test}'.replace('AQME-ROBERT_','') + command_line = command_line.replace("AQME-ROBERT_", "") + self.args.csv_name = f"{self.args.csv_name}".replace("AQME-ROBERT_", "") + if self.args.csv_test != "": + self.args.csv_test = f"{self.args.csv_test}".replace("AQME-ROBERT_", "") - if '--program crest' in command_line.lower(): + if "--program crest" in command_line.lower(): crest_workflow = True # make the text more compact if --aqme is used (more lines are included) if aqme_workflow: - first_line = f'

' # reduces line separation separation + first_line = '

' # reduces line separation separation else: - first_line = f'

' # reduces line separation separation - reduced_line = f'

' # reduces line separation separation - space = (' ')*4 + first_line = '

' # reduces line separation separation + reduced_line = '

' # reduces line separation separation + space = (" ") * 4 # just in case the command lines are so long - command_line = format_lines(command_line,cmd_line=True) + command_line = format_lines(command_line, cmd_line=True) - # reproducibility section, starts with the icon of reproducibility + # reproducibility section, starts with the icon of reproducibility repro_dat += f"""{first_line}
1. Download these files (the authors should have uploaded the files as supporting information!):

""" - repro_dat += f"""{reduced_line}{space}- CSV database ({self.args.csv_name})

""" - if self.args.csv_test != '': + repro_dat += ( + f"""{reduced_line}{space}- CSV database ({self.args.csv_name})

""" + ) + if self.args.csv_test != "": repro_dat += f"""{reduced_line}{space}- External test set ({self.args.csv_test})

""" if aqme_workflow: try: - path_aqme = Path(f'{os.getcwd()}/AQME/CSEARCH_data.dat') - datfile = open(path_aqme, 'r', errors="replace") + path_aqme = Path(f"{os.getcwd()}/AQME/CSEARCH_data.dat") + datfile = open(path_aqme, "r", errors="replace") outlines = datfile.readlines() aqme_version = outlines[0].split()[2] datfile.close() find_aqme = True - except: + except Exception: find_aqme = False - aqme_version = '0.0' # dummy number - if int(aqme_version.split('.')[0]) in [0,1] and int(aqme_version.split('.')[1]) < 6: + aqme_version = "0.0" # dummy number + if ( + int(aqme_version.split(".")[0]) in [0, 1] + and int(aqme_version.split(".")[1]) < 6 + ): aqme_updated = False repro_dat += f"""{reduced_line}{space}Warning! This workflow might not be exactly reproducible, update to AQME v1.6.0+ (pip install aqme --upgrade)

""" repro_dat += f"""{reduced_line}{space}To obtain the same results, download the descriptor database (AQME-ROBERT_{self.args.csv_name}) and run:

""" repro_line = [] - original_command = original_command.replace(self.args.csv_name,f'AQME-ROBERT_{self.args.csv_name}') - for i,keyword in enumerate(original_command.split('"')): + original_command = original_command.replace( + self.args.csv_name, f"AQME-ROBERT_{self.args.csv_name}" + ) + for i, keyword in enumerate(original_command.split('"')): if i == 0: - if '--aqme' not in keyword and '--qdescp_keywords' not in keyword and '--csearch_keywords' not in keyword: - repro_line.append(keyword) + if ( + "--aqme" not in keyword + and "--qdescp_keywords" not in keyword + and "--csearch_keywords" not in keyword + ): + repro_line.append(keyword) else: - repro_line.append('python -m robert ') + repro_line.append("python -m robert ") if i > 0: - if '--qdescp_keywords' not in original_command.split('"')[i-1] and '--csearch_keywords' not in original_command.split('"')[i-1]: - if '--aqme' not in keyword and '--qdescp_keywords' not in keyword and '--csearch_keywords' not in keyword and keyword != '\n': - repro_line.append(keyword) + if ( + "--qdescp_keywords" + not in original_command.split('"')[i - 1] + and "--csearch_keywords" + not in original_command.split('"')[i - 1] + ): + if ( + "--aqme" not in keyword + and "--qdescp_keywords" not in keyword + and "--csearch_keywords" not in keyword + and keyword != "\n" + ): + repro_line.append(keyword) repro_line = '"'.join(repro_line) repro_line += '"' - if '--names ' not in repro_line: + if "--names " not in repro_line: repro_line += ' --names "code_name"' - repro_line = f'{reduced_line}{space}- Run: {repro_line}' - repro_line = format_lines(repro_line,cmd_line=True) + repro_line = f"{reduced_line}{space}- Run: {repro_line}" + repro_line = format_lines(repro_line, cmd_line=True) repro_dat += f"""{reduced_line}{repro_line}

""" if aqme_workflow and not aqme_updated: # I use a very reduced line in this title because the formatted command_line comes with an extra blank line # (if AQME is not updated the PDF contains a reproducibility warning) - repro_dat += f"""


2. Install and adjust the versions of the following Python modules:

""" + repro_dat += """


2. Install and adjust the versions of the following Python modules:

""" else: repro_dat += f"""{first_line}
2. Install and adjust the versions of the following Python modules:

""" repro_dat += f"""{reduced_line}{space}- Install ROBERT and its dependencies: conda install -y -c conda-forge robert

""" @@ -912,14 +1087,14 @@ def get_repro(self,eval_only): repro_dat += f"""{reduced_line}{space}- Install or adjust AQME version: pip install aqme=={aqme_version}

""" try: - path_xtb = Path(f'{os.getcwd()}/AQME/QDESCP') - xtb_json = glob.glob(f'{path_xtb}/*.json')[0] + path_xtb = Path(f"{os.getcwd()}/AQME/QDESCP") + xtb_json = glob.glob(f"{path_xtb}/*.json")[0] f = open(xtb_json, "r") # Opening JSON file data = json.loads(f.read()) # read file f.close() - xtb_version = data['xtb version'].split()[0] + xtb_version = data["xtb version"].split()[0] find_xtb = True - except: + except Exception: find_xtb = False if not find_xtb: repro_dat += f"""{reduced_line}{space}- xTB is required, but no version was found:

""" @@ -929,7 +1104,11 @@ def get_repro(self,eval_only): if crest_workflow: try: - from importlib.metadata import PackageNotFoundError, version as importlib_version + from importlib.metadata import ( + PackageNotFoundError, + version as importlib_version, + ) + crest_version = importlib_version("crest") find_crest = True except PackageNotFoundError: @@ -940,82 +1119,80 @@ def get_repro(self,eval_only): if find_crest: repro_dat += f"""{reduced_line}{space}- Adjust CREST version: conda install -c conda-forge crest={crest_version})

""" - character_line = '' - if self.args.csv_test != '': - character_line += 's' + character_line = "" + if self.args.csv_test != "": + character_line += "s" repro_dat += f"""{first_line}
3. Run ROBERT using this command line in the folder with the CSV database{character_line}:

{reduced_line}{command_line}

""" # I use a very reduced line in this title because the formatted command_line comes with an extra blank line if aqme_workflow: - repro_dat += f"""


4. Execution time, Python version and OS:

""" + repro_dat += """


4. Execution time, Python version and OS:

""" else: - repro_dat += f"""


4. Execution time, Python version and OS:

""" - + repro_dat += """


4. Execution time, Python version and OS:

""" + # add total execution time repro_dat += f"""{reduced_line}Originally run in Python {python_version} using {platform.system()} {platform.version()}

""" repro_dat += f"""{reduced_line}Total execution time: {total_time} seconds (the number of processors should be specified by the user)

""" # add separator line and page break repro_dat += '
' - repro_dat += f"""

""" + repro_dat += """

""" - repro_dat = self.module_lines('repro',repro_dat) + repro_dat = self.module_lines("repro", repro_dat) return citation_dat, repro_dat, dat_files, self.args.csv_name, robert_version - - def get_transparency(self,spacing_PFI): + def get_transparency(self, spacing_PFI): """ Generates the transparency section """ - transpa_dat = '' - titles_line = f'

' # reduces line separation separation + transpa_dat = "" + titles_line = '

' # reduces line separation separation # add params of the models transpa_dat += f"""{titles_line}
1. Parameters of the scikit-learn models (same keywords as used in scikit-learn):

""" - - model_dat, params_df = self.transpa_model_misc('model_section',spacing_PFI) + + model_dat, params_df = self.transpa_model_misc("model_section", spacing_PFI) transpa_dat += model_dat # add misc params - transpa_dat += f"""


2. ROBERT options, including prediction type (REG or CLAS), folds and repeats used for CV, etc:

""" - - section_dat, params_df = self.transpa_model_misc('misc_section',spacing_PFI) - transpa_dat += section_dat - - transpa_dat = self.module_lines('transpa',transpa_dat) + transpa_dat += """


2. ROBERT options, including prediction type (REG or CLAS), folds and repeats used for CV, etc:

""" + section_dat, params_df = self.transpa_model_misc("misc_section", spacing_PFI) + transpa_dat += section_dat - return transpa_dat,params_df + transpa_dat = self.module_lines("transpa", transpa_dat) + return transpa_dat, params_df - def transpa_model_misc(self,section,spacing_PFI): + def transpa_model_misc(self, section, spacing_PFI): """ Collects the data for model parameters and misc options in the Reproducibility section """ columns_repro = [] - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) # set the parameters for each ML model - params_dir = f'{self.args.params_dir}/{"_".join(suffix.split())}' - files_param = glob.glob(f'{params_dir}/*.csv') + params_dir = f"{self.args.params_dir}/{'_'.join(suffix.split())}" + files_param = glob.glob(f"{params_dir}/*.csv") for file_param in files_param: - if '_db' not in file_param: - params_df = pd.read_csv(file_param, encoding='utf-8') - params_dict = pd_to_dict(params_df) # (using a dict to keep the same format of load_model) + if "_db" not in file_param: + params_df = pd.read_csv(file_param, encoding="utf-8") + params_dict = pd_to_dict( + params_df + ) # (using a dict to keep the same format of load_model) - columns_repro.append(get_col_transpa(params_dict,suffix,section,spacing)) + columns_repro.append(get_col_transpa(params_dict, suffix, section, spacing)) section_dat = combine_cols(columns_repro) section_dat += '

' - return section_dat,params_df - + return section_dat, params_df def get_abbrev(self): """ @@ -1023,132 +1200,136 @@ def get_abbrev(self): """ # starts with the icon of abbreviation - abbrev_dat = '' - abbrev_dat = self.module_lines('abbrev',abbrev_dat) + abbrev_dat = "" + abbrev_dat = self.module_lines("abbrev", abbrev_dat) columns_abbrev = [] - columns_abbrev.append(get_col_text('abbrev_1')) - columns_abbrev.append(get_col_text('abbrev_2')) - columns_abbrev.append(get_col_text('abbrev_3')) + columns_abbrev.append(get_col_text("abbrev_1")) + columns_abbrev.append(get_col_text("abbrev_2")) + columns_abbrev.append(get_col_text("abbrev_3")) abbrev_dat += combine_cols(columns_abbrev) - abbrev_dat +=f'


' + abbrev_dat += '
' - abbrev_dat += f"""

""" + abbrev_dat += """

""" return abbrev_dat - - def print_predictions(self,pred_type,eval_only,spacing_PFI): + def print_predictions(self, pred_type, eval_only, spacing_PFI): """ Generates the outliers section """ - + # detects whether there are predictions from an external set - module_file = f'{os.getcwd()}/PREDICT/PREDICT_data.dat' + module_file = f"{os.getcwd()}/PREDICT/PREDICT_data.dat" csv_test_exists, y_value, names, path_csv_test = detect_predictions(module_file) if csv_test_exists: - pred_dat = '' - pred_dat = self.module_lines('pred',pred_dat,pred_type=pred_type) + pred_dat = "" + pred_dat = self.module_lines("pred", pred_dat, pred_type=pred_type) columns_metrics = [] # add metrics - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_metrics.append('') + if eval_only and suffix == "PFI": + columns_metrics.append("") else: - columns_metrics.append(get_csv_metrics(module_file,suffix,spacing)) + columns_metrics.append( + get_csv_metrics(module_file, suffix, spacing) + ) # Combine both columns pred_dat += combine_cols(columns_metrics) columns_pred = [] # add predictions table - for suffix in ['No PFI','PFI']: - spacing = get_spacing_col(suffix,spacing_PFI) + for suffix in ["No PFI", "PFI"]: + spacing = get_spacing_col(suffix, spacing_PFI) - if eval_only and suffix == 'PFI': - columns_pred.append('') + if eval_only and suffix == "PFI": + columns_pred.append("") else: # add metrics - module_file = f'{os.getcwd()}/PREDICT/PREDICT_data.dat' - columns_pred.append(get_csv_pred(suffix,path_csv_test,y_value,names,spacing)) + module_file = f"{os.getcwd()}/PREDICT/PREDICT_data.dat" + columns_pred.append( + get_csv_pred(suffix, path_csv_test, y_value, names, spacing) + ) # Combine both columns pred_dat += combine_cols(columns_pred) # add corresponding images height = 217 - if pred_type == 'reg': - prefix_img = 'CV_variability' - elif pred_type == 'clas': - prefix_img = 'Results' + if pred_type == "reg": + prefix_img = "CV_variability" + elif pred_type == "clas": + prefix_img = "Results" height += 17 - if len(glob.glob(f'{os.getcwd()}/PREDICT/csv_test/{prefix_img}*.png')) > 0: - pred_dat += self.print_img(prefix_img,-5,height,'PREDICT/csv_test',pred_type,eval_only) + if len(glob.glob(f"{os.getcwd()}/PREDICT/csv_test/{prefix_img}*.png")) > 0: + pred_dat += self.print_img( + prefix_img, -5, height, "PREDICT/csv_test", pred_type, eval_only + ) # add separator line and page break pred_dat += '
' - pred_dat += f"""

""" + pred_dat += """

""" return pred_dat else: - return '' - + return "" - def module_lines(self,module,module_data,pred_type='reg',eval_only=False): + def module_lines(self, module, module_data, pred_type="reg", eval_only=False): """ Returns the line with icon and module for section titles """ - - if module == 'score': - module_name = 'Section A. ROBERT Score' - section_explain = f'

This score is designed to evaluate the models using different metrics.' - elif module == 'adv_anal': - module_name = 'Section B. Advanced Score Analysis' - section_explain = f'

This section explains each component that comprises the ROBERT score. More details here.' - elif module == 'y_distrib': - module_name = 'Section C. Distribution of y Values' - section_explain = f'

This section shows the distribution of y values within the training and validation sets.' - elif module == 'features': - module_name = 'Section D. Feature Importances' - section_explain = f'

This section presents feature importances measured using the validation set.' - elif module == 'outliers': - module_name = 'Section E. Outlier Analysis' - if pred_type == 'clas': - section_explain = f'

This feature is disabled in classification problems.' + + if module == "score": + module_name = "Section A. ROBERT Score" + section_explain = '

This score is designed to evaluate the models using different metrics.' + elif module == "adv_anal": + module_name = "Section B. Advanced Score Analysis" + section_explain = '

This section explains each component that comprises the ROBERT score. More details here.' + elif module == "y_distrib": + module_name = "Section C. Distribution of y Values" + section_explain = '

This section shows the distribution of y values within the training and validation sets.' + elif module == "features": + module_name = "Section D. Feature Importances" + section_explain = '

This section presents feature importances measured using the validation set.' + elif module == "outliers": + module_name = "Section E. Outlier Analysis" + if pred_type == "clas": + section_explain = '

This feature is disabled in classification problems.' else: - section_explain = f'

This section detects outliers using the standard deviation (SD) of errors from the training set.' - elif module == 'generate': - module_name = 'Section F. Model Screening' + section_explain = '

This section detects outliers using the standard deviation (SD) of errors from the training set.' + elif module == "generate": + module_name = "Section F. Model Screening" if eval_only: - section_explain = f'

The screening of models is disabled when using the EVALUATE module.' + section_explain = '

The screening of models is disabled when using the EVALUATE module.' else: - section_explain = f'

This section compares different combinations of hyperoptimized algorithms and partition sizes. The combined error is calculated as the product of the training error, validation error, and cross-validation error.' - elif module == 'repro': - module_name = 'Section G. Reproducibility' - section_explain = f'

This section provides all the instructions to reproduce the results presented.' - elif module == 'transpa': - module_name = 'Section H. Transparency' - section_explain = f'

This section contains important parameters used in scikit-learn models and ROBERT.' - elif module == 'abbrev': - module_name = 'Section I. Abbreviations' - section_explain = f'

Reference section for the abbreviations used.' - elif module == 'pred': - module_name = 'Section J. New Predictions' - section_explain = f'

Predictions of the external test set added with the csv_test option.' - elif module == 'misc': - module_name = 'Miscellaneous' - section_explain = f'

General tips to improve the models and instructions to predict new values.' - - if module not in ['repro','transpa','misc']: + section_explain = '

This section compares different combinations of hyperoptimized algorithms and partition sizes. The combined error is calculated as the product of the training error, validation error, and cross-validation error.' + elif module == "repro": + module_name = "Section G. Reproducibility" + section_explain = '

This section provides all the instructions to reproduce the results presented.' + elif module == "transpa": + module_name = "Section H. Transparency" + section_explain = '

This section contains important parameters used in scikit-learn models and ROBERT.' + elif module == "abbrev": + module_name = "Section I. Abbreviations" + section_explain = '

Reference section for the abbreviations used.' + elif module == "pred": + module_name = "Section J. New Predictions" + section_explain = '

Predictions of the external test set added with the csv_test option.' + elif module == "misc": + module_name = "Miscellaneous" + section_explain = '

General tips to improve the models and instructions to predict new values.' + + if module not in ["repro", "transpa", "misc"]: module_data = format_lines(module_data) - module_data = '

' + module_data + '
' - + module_data = '
' + module_data + "
" + separator_section = '

' title_line = f""" @@ -1160,43 +1341,60 @@ def module_lines(self,module,module_data,pred_type='reg',eval_only=False): {module_data}

""" - - return title_line + return title_line - def print_img(self,file_name,margin_top,height,module,pred_type,eval_only,test_set=False,diff_names=False): + def print_img( + self, + file_name, + margin_top, + height, + module, + pred_type, + eval_only, + test_set=False, + diff_names=False, + ): """ Generates the string that includes couples of images to print """ - - module_path = Path(f'{os.getcwd()}/{module}') + + module_path = Path(f"{os.getcwd()}/{module}") # detect test - set_types = ['train','valid'] + set_types = ["train", "valid"] if test_set: - set_types.append('test') + set_types.append("test") # different names for reg and clas problems, only for results images from PREDICT if diff_names: - if pred_type.lower() == 'reg': - results_images = [str(file_path) for file_path in module_path.rglob(f'{file_name}_*.png')] - elif pred_type.lower() == 'clas': - results_images = [str(file_path) for file_path in module_path.rglob(f'{file_name}_*.png')] + if pred_type.lower() == "reg": + results_images = [ + str(file_path) + for file_path in module_path.rglob(f"{file_name}_*.png") + ] + elif pred_type.lower() == "clas": + results_images = [ + str(file_path) + for file_path in module_path.rglob(f"{file_name}_*.png") + ] # images with no suffixes in the names else: - results_images = [str(file_path) for file_path in module_path.rglob(f'{file_name}_*.png')] + results_images = [ + str(file_path) for file_path in module_path.rglob(f"{file_name}_*.png") + ] # keep the ordering (No_PFI in the left, PFI in the right of the PDF) - results_images = revert_list(results_images) - + results_images = revert_list(results_images) + # add the graphs width = 100 pair_list = f'

' if not eval_only: - pair_list += f'{(" ")*22}' + pair_list += f"{(' ') * 22}" pair_list += f'

' - html_png = f'{pair_list}' + html_png = f"{pair_list}" - return html_png + return html_png diff --git a/robert/report_utils.py b/robert/report_utils.py index 5899127..0cc08c1 100644 --- a/robert/report_utils.py +++ b/robert/report_utils.py @@ -12,114 +12,115 @@ import ast -title_no_pfi = 'No PFI (standard descriptor filter):' -title_pfi = 'PFI (only important descriptors):' +title_no_pfi = "No PFI (standard descriptor filter):" +title_pfi = "PFI (only important descriptors):" -def get_csv_names(self,command_line): + +def get_csv_names(self, command_line): """ Detects the options from a command line or add them from manual inputs """ - - csv_name = '' - if '--csv_name' in command_line: - csv_name = command_line.split('--csv_name')[1].split()[0] + + csv_name = "" + if "--csv_name" in command_line: + csv_name = command_line.split("--csv_name")[1].split()[0] csv_name = remove_quot(csv_name) - - csv_test = '' - if '--csv_test' in command_line: - csv_test = command_line.split('--csv_test')[1].split()[0] + + csv_test = "" + if "--csv_test" in command_line: + csv_test = command_line.split("--csv_test")[1].split()[0] csv_test = remove_quot(csv_test) - if self.args.csv_name == '': + if self.args.csv_name == "": self.args.csv_name = csv_name - if self.args.csv_test == '': + if self.args.csv_test == "": self.args.csv_test = csv_test return self def remove_quot(name): - ''' + """ Remove initial and final quotations from names - ''' - - if name[0] in ['"',"'"]: + """ + + if name[0] in ['"', "'"]: name = name[1:] - if name[-1] in ['"',"'"]: + if name[-1] in ['"', "'"]: name = name[:-1] - + return name -def get_outliers(file,suffix,spacing): +def get_outliers(file, suffix, spacing): """ Retrieve the summary of results from the PREDICT and VERIFY dat files """ - - with open(file, 'r', encoding='utf-8') as datfile: + + with open(file, "r", encoding="utf-8") as datfile: lines = datfile.readlines() - train_outliers,test_outliers = [],[] - for i,line in enumerate(lines): - if suffix == 'No PFI': - if 'o Outliers plot saved' in line and 'No_PFI.png' in line: - train_outliers,test_outliers = locate_outliers(i,lines) - if suffix == 'PFI': - if 'o Outliers plot saved' in line and 'No_PFI.png' not in line: - train_outliers,test_outliers = locate_outliers(i,lines) + train_outliers, test_outliers = [], [] + for i, line in enumerate(lines): + if suffix == "No PFI": + if "o Outliers plot saved" in line and "No_PFI.png" in line: + train_outliers, test_outliers = locate_outliers(i, lines) + if suffix == "PFI": + if "o Outliers plot saved" in line and "No_PFI.png" not in line: + train_outliers, test_outliers = locate_outliers(i, lines) summary = [] # add the outlier part - summary.append(f'\n{spacing*2}Outliers (max. 10 shown)\n') + summary.append(f"\n{spacing * 2}Outliers (max. 10 shown)\n") summary = summary + train_outliers + test_outliers - summary = f'{spacing*2}'.join(summary) + summary = f"{spacing * 2}".join(summary) # add columns - if suffix == 'No PFI': + if suffix == "No PFI": title_col = title_no_pfi - elif suffix == 'PFI': + elif suffix == "PFI": title_col = title_pfi column = f""" -

{spacing*2}{title_col}

+

{spacing * 2}{title_col}

{summary}
""" return column -def get_metrics(file,suffix,spacing): +def get_metrics(file, suffix, spacing): """ Retrieve the summary of results from the PREDICT dat files """ - - with open(file, 'r', encoding='utf-8') as datfile: + + with open(file, "r", encoding="utf-8") as datfile: lines = datfile.readlines() - start_results,stop_results = 0,0 - for i,line in enumerate(lines): - if suffix == 'No PFI': - if 'o Summary of results' in line and 'No_PFI:' in line: - start_results = i+1 - stop_results = i+6 - if suffix == 'PFI': - if 'o Summary of results' in line and 'No_PFI:' not in line: - start_results = i+1 - stop_results = i+6 + start_results, stop_results = 0, 0 + for i, line in enumerate(lines): + if suffix == "No PFI": + if "o Summary of results" in line and "No_PFI:" in line: + start_results = i + 1 + stop_results = i + 6 + if suffix == "PFI": + if "o Summary of results" in line and "No_PFI:" not in line: + start_results = i + 1 + stop_results = i + 6 # add the summary of results of PREDICT - start_results += 4 # skip informaton that aren't metrics + start_results += 4 # skip informaton that aren't metrics summary = [] - for line in lines[start_results:stop_results+1]: - if 'R2' in line: - line = line.replace('R2','R2') + for line in lines[start_results : stop_results + 1]: + if "R2" in line: + line = line.replace("R2", "R2") - if suffix == 'No PFI': + if suffix == "No PFI": summary.append(line[8:]) - elif suffix == 'PFI': - summary.append(f'{spacing}{spacing}{line[8:]}') + elif suffix == "PFI": + summary.append(f"{spacing}{spacing}{line[8:]}") - summary = ''.join(summary) + summary = "".join(summary) column = f"""
{summary}
@@ -128,70 +129,70 @@ def get_metrics(file,suffix,spacing): return column -def get_csv_metrics(file,suffix,spacing): +def get_csv_metrics(file, suffix, spacing): """ Retrieve the csv_test results from the PREDICT dat file """ - - results_line = '' - with open(file, 'r', encoding='utf-8') as datfile: + + results_line = "" + with open(file, "r", encoding="utf-8") as datfile: lines = datfile.readlines() - for i,line in enumerate(lines): - if suffix == 'No PFI': - if 'o Summary of results' in line and 'No_PFI:' in line: - for j in range(i,i+15): - if 'o SHAP' in lines[j]: + for i, line in enumerate(lines): + if suffix == "No PFI": + if "o Summary of results" in line and "No_PFI:" in line: + for j in range(i, i + 15): + if "o SHAP" in lines[j]: break - elif '- External test : ' in lines[j]: + elif "- External test : " in lines[j]: results_line = lines[j][25:] - if suffix == 'PFI': - if 'o Summary of results' in line and 'No_PFI:' not in line: - for j in range(i,i+15): - if 'o SHAP' in lines[j]: + if suffix == "PFI": + if "o Summary of results" in line and "No_PFI:" not in line: + for j in range(i, i + 15): + if "o SHAP" in lines[j]: break - elif '- External test : ' in lines[j]: + elif "- External test : " in lines[j]: results_line = lines[j][25:] # start the csv_test section - metrics_dat = f'

{spacing*2}External test metrics

' + metrics_dat = f'

{spacing * 2}External test metrics

' # add line with model metrics (if any) - if results_line != '': - metrics_dat += f'

{spacing*2}{results_line}

' - + if results_line != "": + metrics_dat += f'

{spacing * 2}{results_line}

' + return metrics_dat - + else: - return '' + return "" -def get_csv_pred(suffix,path_csv_test,y_value,names,spacing): +def get_csv_pred(suffix, path_csv_test, y_value, names, spacing): """ Retrieve the csv_test results from the PREDICT dat file """ - - pred_line = '' - csv_test_folder = f'{os.getcwd()}/{os.path.dirname(path_csv_test)}' - csv_test_list = glob.glob(f'{csv_test_folder}/*.csv') + + pred_line = "" + csv_test_folder = f"{os.getcwd()}/{os.path.dirname(path_csv_test)}" + csv_test_list = glob.glob(f"{csv_test_folder}/*.csv") for file in csv_test_list: - if suffix == 'No PFI': - if '_No_PFI.csv' in file: + if suffix == "No PFI": + if "_No_PFI.csv" in file: csv_test_file = file - if suffix == 'PFI': - if '_No_PFI.csv' not in file and '_PFI.csv' in file: + if suffix == "PFI": + if "_No_PFI.csv" not in file and "_PFI.csv" in file: csv_test_file = file - csv_test_df = pd.read_csv(csv_test_file, encoding='utf-8') + csv_test_df = pd.read_csv(csv_test_file, encoding="utf-8") # start the csv_test section - pred_line = f'

{spacing*2}External test predictions (sorted, max. 20 shown)

' + pred_line = f'

{spacing * 2}External test predictions (sorted, max. 20 shown)

' - if suffix == 'No PFI': - pred_line += f'

{spacing*2}From /PREDICT/csv_test/...No_PFI.csv

' - elif suffix == 'PFI': - pred_line += f'

{spacing*2}From /PREDICT/csv_test/..._PFI.csv

' + if suffix == "No PFI": + pred_line += f'

{spacing * 2}From /PREDICT/csv_test/...No_PFI.csv

' + elif suffix == "PFI": + pred_line += f'

{spacing * 2}From /PREDICT/csv_test/..._PFI.csv

' - pred_line += ''' - ''' + """ y_val_exist = False - if f'{y_value}' in csv_test_df.columns: + if f"{y_value}" in csv_test_df.columns: y_val_exist = True # adjust format of headers names_head = names if len(str(names_head)) > 12: - names_head = f'{str(names_head[:9])}...' + names_head = f"{str(names_head[:9])}..." y_value_head = y_value if len(str(y_value_head)) > 12: - y_value_head = f'{str(y_value_head[:9])}...' + y_value_head = f"{str(y_value_head[:9])}..." - if pred_line != '': - if suffix == 'No PFI': + if pred_line != "": + if suffix == "No PFI": margin_left = 0 else: margin_left = 27 - pred_line += f''' + pred_line += f""" - ''' + """ if y_val_exist: - pred_line += f''' - ''' - if f'{y_value}_pred_sd' in csv_test_df: - pred_line += f''' + pred_line += f""" + """ + if f"{y_value}_pred_sd" in csv_test_df: + pred_line += f""" - ''' + """ else: - pred_line += f''' + pred_line += f""" - ''' - + """ + # retrieve and sort the values if not y_val_exist: - csv_test_df[y_value] = csv_test_df[f'{y_value}_pred'] + csv_test_df[y_value] = csv_test_df[f"{y_value}_pred"] # in clas problems, there are no SD in the predictions (we use a list of 0s) - if f'{y_value}_pred_sd' in csv_test_df: - sd_list = csv_test_df[f'{y_value}_pred_sd'] + if f"{y_value}_pred_sd" in csv_test_df: + sd_list = csv_test_df[f"{y_value}_pred_sd"] else: - sd_list = [0] * len(csv_test_df[f'{y_value}_pred']) - - y_pred_sorted, y_sorted, names_sorted, sd_sorted = (list(t) for t in zip(*sorted(zip(csv_test_df[f'{y_value}_pred'], csv_test_df[y_value], csv_test_df[names], sd_list), reverse=True))) + sd_list = [0] * len(csv_test_df[f"{y_value}_pred"]) + + y_pred_sorted, y_sorted, names_sorted, sd_sorted = ( + list(t) + for t in zip( + *sorted( + zip( + csv_test_df[f"{y_value}_pred"], + csv_test_df[y_value], + csv_test_df[names], + sd_list, + ), + reverse=True, + ) + ) + ) max_table = False if len(y_pred_sorted) > 20: max_table = True count_entries = 0 - for y_val_pred, y_val, name, sd in zip(y_pred_sorted, y_sorted, names_sorted, sd_sorted): + for y_val_pred, y_val, name, sd in zip( + y_pred_sorted, y_sorted, names_sorted, sd_sorted + ): # adjust format of entries if len(str(name)) > 12: - name = f'{str(name[:9])}...' + name = f"{str(name[:9])}..." y_val_pred = round(y_val_pred, 2) y_val = round(y_val, 2) sd = round(sd, 2) - if f'{y_value}_pred_sd' in csv_test_df: - y_val_pred_formatted = f'{y_val_pred} ± {sd}' + if f"{y_value}_pred_sd" in csv_test_df: + y_val_pred_formatted = f"{y_val_pred} ± {sd}" else: - y_val_pred_formatted = f'{y_val_pred}' + y_val_pred_formatted = f"{y_val_pred}" add_entry = True # if there are more than 20 predictions, only 20 values will be shown if max_table and count_entries >= 10: add_entry = False if count_entries == 10: - pred_line += f''' + pred_line += """ - ''' + """ if y_val_exist: - pred_line += f''' - ''' - pred_line += f''' + pred_line += """ + """ + pred_line += """ - ''' + """ elif count_entries >= (len(y_pred_sorted) - 10): add_entry = True if add_entry: - pred_line += f''' + pred_line += f""" - ''' + """ if y_val_exist: - pred_line += f''' - ''' - pred_line += f''' + pred_line += f""" + """ + pred_line += f""" - ''' + """ count_entries += 1 - pred_line += f''' + pred_line += """
{names_head}{names_head}{y_value_head}{y_value_head}{y_value_head}_pred ± sd
{y_value_head}_pred
...............
{name}{name}{y_val}{y_val}{y_val_pred_formatted}
-

''' +

""" return pred_line @@ -302,50 +318,58 @@ def detect_predictions(module_file): """ Check whether there are predictions from an external test set """ - + csv_test_exists = False # summary of the external CSV test set (if any) - y_value, names, path_csv_test = '','','' - with open(module_file, 'r', encoding= 'utf-8') as datfile: + y_value, names, path_csv_test = "", "", "" + with open(module_file, "r", encoding="utf-8") as datfile: lines = datfile.readlines() - for _,line in enumerate(lines): - if '- Target value:' in line: - y_value = ' '.join(line.split(':')[1:]).strip() - elif '- Names:' in line: + for _, line in enumerate(lines): + if "- Target value:" in line: + y_value = " ".join(line.split(":")[1:]).strip() + elif "- Names:" in line: names = line.split()[-1] - elif 'External set with predicted results:' in line: + elif "External set with predicted results:" in line: path_csv_test = line.split()[-1] csv_test_exists = True return csv_test_exists, y_value, names, path_csv_test -def locate_outliers(i,lines): +def locate_outliers(i, lines): """ Returns the start and end of the PREDICT summary in the dat file """ - - train_outliers,test_outliers = [],[] + + train_outliers, test_outliers = [], [] len_line = 54 - for j in range(i+1,len(lines)): - if 'Train:' in lines[j]: - for k in range(j,len(lines)): - if 'Test:' in lines[k]: + for j in range(i + 1, len(lines)): + if "Train:" in lines[j]: + for k in range(j, len(lines)): + if "Test:" in lines[k]: break - elif len(train_outliers) <= 10: # 10 outliers and the line with the % of outliers + elif ( + len(train_outliers) <= 10 + ): # 10 outliers and the line with the % of outliers if len(lines[k][6:]) > len_line: - outlier_line = f'{lines[k][6:len_line+6]}\n{lines[k][len_line+6:]}' + outlier_line = ( + f"{lines[k][6 : len_line + 6]}\n{lines[k][len_line + 6 :]}" + ) else: outlier_line = lines[k][6:] train_outliers.append(outlier_line) - elif 'Test:' in lines[j]: - for k in range(j,len(lines)): + elif "Test:" in lines[j]: + for k in range(j, len(lines)): if len(lines[k].split()) == 0: break - elif len(test_outliers) <= 10: # 10 outliers and the line with the % of outliers + elif ( + len(test_outliers) <= 10 + ): # 10 outliers and the line with the % of outliers if len(lines[k][6:]) > len_line: - outlier_line = f'{lines[k][6:len_line+6]}\n{lines[k][len_line+6:]}' + outlier_line = ( + f"{lines[k][6 : len_line + 6]}\n{lines[k][len_line + 6 :]}" + ) else: outlier_line = lines[k][6:] test_outliers.append(outlier_line) @@ -353,15 +377,15 @@ def locate_outliers(i,lines): if len(lines[j].split()) == 0: break - return train_outliers,test_outliers + return train_outliers, test_outliers + - def combine_cols(columns): """ Makes a string with multi-column lines """ - - column_data = '' + + column_data = "" for column in columns: column_data += f'
{column}
' @@ -379,8 +403,8 @@ def revert_list(list_tuple): Reverts the order of a list of two components """ - if len(list_tuple) == 2 and 'No_PFI' in list_tuple[1]: - new_sort = [] # for some reason reverse() gives a weird issue when reverting lists + if len(list_tuple) == 2 and "No_PFI" in list_tuple[1]: + new_sort = [] # for some reason reverse() gives a weird issue when reverting lists new_sort.append(list_tuple[1]) new_sort.append(list_tuple[0]) list_tuple = new_sort @@ -388,31 +412,33 @@ def revert_list(list_tuple): return list_tuple -def get_col_score(score_info,data_score,suffix,spacing,eval_only): +def get_col_score(score_info, data_score, suffix, spacing, eval_only): """ Gather the information regarding the score of the No PFI and PFI models """ - + ML_line_format = f'

{spacing}' part_line_format = f'

{spacing}' - score_title = f'''  ·  Score {data_score[f'robert_score_{suffix}']}''' - if suffix == 'No PFI': - caption = f'{spacing}{title_no_pfi.replace(":",score_title)}' + score_title = ( + f"""  ·  Score {data_score[f"robert_score_{suffix}"]}""" + ) + if suffix == "No PFI": + caption = f"{spacing}{title_no_pfi.replace(':', score_title)}" - elif suffix == 'PFI': - caption = f'{spacing}{title_pfi.replace(":",score_title)}' + elif suffix == "PFI": + caption = f"{spacing}{title_pfi.replace(':', score_title)}" - partitions_ratio = data_score['proportion_ratio_print'].split('- Proportion ')[1] + partitions_ratio = data_score["proportion_ratio_print"].split("- Proportion ")[1] if not eval_only: - title_line = f'{caption}' + title_line = f"{caption}" else: - title_line = 'Summary and score of your model (No PFI)' + title_line = "Summary and score of your model (No PFI)" column = f"""

{title_line}

- {ML_line_format}Model = {data_score['ML_model']}  ·  {partitions_ratio}

- {part_line_format}Points(train+validation):descriptors = {data_score[f'points_descp_ratio_{suffix}']}

+ {ML_line_format}Model = {data_score["ML_model"]}  ·  {partitions_ratio}

+ {part_line_format}Points(train+validation):descriptors = {data_score[f"points_descp_ratio_{suffix}"]}

{score_info}

""" @@ -420,17 +446,17 @@ def get_col_score(score_info,data_score,suffix,spacing,eval_only): return column -def adv_flawed(self,suffix,data_score,spacing): +def adv_flawed(self, suffix, data_score, spacing): """ Gather the advanced analysis of flawed models """ - score_flawed = data_score[f'flawed_mod_score_{suffix}'] + score_flawed = data_score[f"flawed_mod_score_{suffix}"] if score_flawed == 0: - flaw_result = f'The model predicts right for the right reasons.' + flaw_result = "The model predicts right for the right reasons." else: - flaw_result = f'Warning! The model probably has important flaws.' + flaw_result = "Warning! The model probably has important flaws." # adds a bit more space if there is no test set score_adv_flawed = f'

{spacing}' @@ -443,7 +469,7 @@ def adv_flawed(self,suffix,data_score,spacing): return column -def adv_predict(self,suffix,data_score,spacing,pred_type): +def adv_predict(self, suffix, data_score, spacing, pred_type): """ Gather the advanced analysis of predictive ability @@ -454,19 +480,19 @@ def adv_predict(self,suffix,data_score,spacing,pred_type): if 0.30 < MCC <= 0.50 => 1, else => 0 """ - score_predict = data_score.get(f'cv_score_combined_{suffix}', 0) + score_predict = data_score.get(f"cv_score_combined_{suffix}", 0) cv_type = data_score.get(f"cv_type_{suffix}", "10x 5-fold CV") - if pred_type == 'reg': - predict_image = f'{self.args.path_icons}/score_w_2_{score_predict}.jpg' - metric_type = ['Scaled RMSE','R2'] - scaled_rmse_cv = data_score.get(f'scaled_rmse_cv_{suffix}', 0) - r2_cv = data_score.get(f'r2_cv_{suffix}', 0) + if pred_type == "reg": + predict_image = f"{self.args.path_icons}/score_w_2_{score_predict}.jpg" + metric_type = ["Scaled RMSE", "R2"] + scaled_rmse_cv = data_score.get(f"scaled_rmse_cv_{suffix}", 0) + r2_cv = data_score.get(f"r2_cv_{suffix}", 0) - predict_result = f'{metric_type[0]} ({cv_type}) = {scaled_rmse_cv}%.' - predict_result += f'
{spacing}{metric_type[1]} ({cv_type}) = {r2_cv}.' - thres_line = 'Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1.' - thres_line += f'
{spacing}R2 < 0.5: -2, R2 < 0.7: -1' + predict_result = f"{metric_type[0]} ({cv_type}) = {scaled_rmse_cv}%." + predict_result += f"
{spacing}{metric_type[1]} ({cv_type}) = {r2_cv}." + thres_line = "Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1." + thres_line += f"
{spacing}R2 < 0.5: -2, R2 < 0.7: -1" init_sep = f'

{spacing}' score_adv_pred = f'

{spacing}' column = f"""{init_sep}2. CV predictions of the model  ({score_predict} / 2  score)

@@ -475,7 +501,7 @@ def adv_predict(self,suffix,data_score,spacing,pred_type): return column else: # Classification: award up to 3 points - mcc_cv = data_score.get(f'r2_cv_{suffix}', 0) + mcc_cv = data_score.get(f"r2_cv_{suffix}", 0) if mcc_cv > 0.75: display_score = 3 @@ -486,9 +512,9 @@ def adv_predict(self,suffix,data_score,spacing,pred_type): else: display_score = 0 - predict_image = f'{self.args.path_icons}/score_w_3_{display_score}.jpg' - metric_type = ['MCC'] - predict_result = f'{metric_type[0]} ({cv_type}) = {mcc_cv}.' + predict_image = f"{self.args.path_icons}/score_w_3_{display_score}.jpg" + metric_type = ["MCC"] + predict_result = f"{metric_type[0]} ({cv_type}) = {mcc_cv}." thres_line = "MCC >0.75: +3; 0.50-0.75: +2; 0.30-0.50: +1" init_sep = f'

{spacing}' @@ -499,20 +525,20 @@ def adv_predict(self,suffix,data_score,spacing,pred_type): return column -def adv_test(self,suffix,data_score,spacing,pred_type): +def adv_test(self, suffix, data_score, spacing, pred_type): """ Gather the advanced analysis of predictive ability with the test set """ - score_test = data_score.get(f'test_score_combined_{suffix}', 0) + score_test = data_score.get(f"test_score_combined_{suffix}", 0) - if pred_type == 'reg': - test_image = f'{self.args.path_icons}/score_w_2_{score_test}.jpg' - metric_type = ['Scaled RMSE','R2'] - predict_result = f'{metric_type[0]} (test set) = {data_score.get(f"scaled_rmse_test_{suffix}", 0)}%.' - predict_result += f'
{spacing}{metric_type[1]} (test set) = {data_score.get(f"r2_test_{suffix}", 0)}.' - thres_line = 'Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1.' - thres_line += f'
{spacing}R2 < 0.5: -2, R2 < 0.7: -1' + if pred_type == "reg": + test_image = f"{self.args.path_icons}/score_w_2_{score_test}.jpg" + metric_type = ["Scaled RMSE", "R2"] + predict_result = f"{metric_type[0]} (test set) = {data_score.get(f'scaled_rmse_test_{suffix}', 0)}%." + predict_result += f"
{spacing}{metric_type[1]} (test set) = {data_score.get(f'r2_test_{suffix}', 0)}." + thres_line = "Scaled RMSE ≤ 10%: +2, Scaled RMSE ≤ 20%: +1." + thres_line += f"
{spacing}R2 < 0.5: -2, R2 < 0.7: -1" score_adv_cv = f'

{spacing}' column = f"""{score_adv_cv}
{spacing}3. Predictive ability & overfitting

{spacing}3a. Predictions test set  ({score_test} / 2  score)

@@ -532,10 +558,10 @@ def adv_test(self,suffix,data_score,spacing,pred_type): else: display_score = 0 - test_image = f'{self.args.path_icons}/score_w_3_{display_score}.jpg' - metric_type = ['MCC'] - predict_result = f'{metric_type[0]} (test set) = {test_mcc}.' - thres_line = ('MCC >0.75: +3; 0.50-0.75: +2; 0.30-0.50: +1') + test_image = f"{self.args.path_icons}/score_w_3_{display_score}.jpg" + metric_type = ["MCC"] + predict_result = f"{metric_type[0]} (test set) = {test_mcc}." + thres_line = "MCC >0.75: +3; 0.50-0.75: +2; 0.30-0.50: +1" score_adv_cv = f'

{spacing}' column = f"""{score_adv_cv}
{spacing}3. Predictive ability & overfitting

@@ -545,34 +571,34 @@ def adv_test(self,suffix,data_score,spacing,pred_type): return column -def adv_diff_test(self,suffix,data_score,spacing,pred_type): +def adv_diff_test(self, suffix, data_score, spacing, pred_type): """ Gather the advanced analysis of difference in model performance between CV and test set. For regression, we compare scaled RMSE. For classification, we compare Δ MCC. """ - - if pred_type == 'reg': + + if pred_type == "reg": # Regression: use diff_scaled_rmse_score - score_diff_test = data_score[f'diff_scaled_rmse_score_{suffix}'] - diff_test_image = f'{self.args.path_icons}/score_w_2_{score_diff_test}.jpg' - - diff_result = f'RMSE in test is {round(data_score[f"factor_scaled_rmse_{suffix}"],2)}*scaled RMSE (CV).' - - thres_line = 'Scaled RMSE (test) ≤ 1.25*scaled RMSE (CV): +2.' - thres_line += f'
{spacing}Scaled RMSE (test) ≤ 1.50*scaled RMSE (CV): +1.' + score_diff_test = data_score[f"diff_scaled_rmse_score_{suffix}"] + diff_test_image = f"{self.args.path_icons}/score_w_2_{score_diff_test}.jpg" + + diff_result = f"RMSE in test is {round(data_score[f'factor_scaled_rmse_{suffix}'], 2)}*scaled RMSE (CV)." + + thres_line = "Scaled RMSE (test) ≤ 1.25*scaled RMSE (CV): +2." + thres_line += f"
{spacing}Scaled RMSE (test) ≤ 1.50*scaled RMSE (CV): +1." else: # Classification: use diff_mcc_score instead - score_diff_test = data_score.get(f'diff_mcc_score_{suffix}', 0) - diff_test_image = f'{self.args.path_icons}/score_w_2_{score_diff_test}.jpg' - + score_diff_test = data_score.get(f"diff_mcc_score_{suffix}", 0) + diff_test_image = f"{self.args.path_icons}/score_w_2_{score_diff_test}.jpg" + # Calculate the absolute difference between CV MCC and test MCC - mcc_cv = data_score.get(f'r2_cv_{suffix}', 0) - mcc_test = data_score.get(f'r2_test_{suffix}', 0) + mcc_cv = data_score.get(f"r2_cv_{suffix}", 0) + mcc_test = data_score.get(f"r2_test_{suffix}", 0) diff_mcc = round(abs(mcc_test - mcc_cv), 2) - - diff_result = f'The ΔMCC between CV and test is {diff_mcc}.' - - thres_line = 'ΔMCC ≤ 0.15: +2, ΔMCC ≤ 0.30: +1' + + diff_result = f"The ΔMCC between CV and test is {diff_mcc}." + + thres_line = "ΔMCC ≤ 0.15: +2, ΔMCC ≤ 0.30: +1" score_adv_diff = f'

{spacing}' column = f"""

{spacing}3b. Prediction accuracy test vs CV  ({score_diff_test} / 2  ROBERT Score)

@@ -583,22 +609,24 @@ def adv_diff_test(self,suffix,data_score,spacing,pred_type): return column -def adv_cv_sd(self,suffix,data_score,spacing): +def adv_cv_sd(self, suffix, data_score, spacing): """ Gather the advanced analysis of test predictions regarding variation """ - score_cv_sd = data_score[f'cv_sd_score_{suffix}'] - cv_r2_image = f'{self.args.path_icons}/score_w_2_{score_cv_sd}.jpg' - y_range_covered = round(data_score[f"cv_range_cov_{suffix}"]*100) - cv_4sd = round(data_score[f"cv_4sd_{suffix}"],1) + score_cv_sd = data_score[f"cv_sd_score_{suffix}"] + cv_r2_image = f"{self.args.path_icons}/score_w_2_{score_cv_sd}.jpg" + y_range_covered = round(data_score[f"cv_range_cov_{suffix}"] * 100) + cv_4sd = round(data_score[f"cv_4sd_{suffix}"], 1) if score_cv_sd == 0: - cv_sd_result = f'High variation, 4*SD = {cv_4sd} ({y_range_covered}% y-range).' + cv_sd_result = f"High variation, 4*SD = {cv_4sd} ({y_range_covered}% y-range)." elif score_cv_sd == 1: - cv_sd_result = f'Moderate variation, 4*SD = {cv_4sd} ({y_range_covered}% y-range).' + cv_sd_result = ( + f"Moderate variation, 4*SD = {cv_4sd} ({y_range_covered}% y-range)." + ) elif score_cv_sd == 2: - cv_sd_result = f'Low variation, 4*SD = {cv_4sd} ({y_range_covered}% y-range).' + cv_sd_result = f"Low variation, 4*SD = {cv_4sd} ({y_range_covered}% y-range)." score_adv_pred = f'

{spacing}' column = f"""

{spacing}3c. Avg. standard deviation (SD)  ({score_cv_sd} / 2  ROBERT Score)

@@ -608,40 +636,40 @@ def adv_cv_sd(self,suffix,data_score,spacing): return column -def adv_cv_diff(self,suffix,data_score,spacing,pred_type,test_set=False): +def adv_cv_diff(self, suffix, data_score, spacing, pred_type, test_set=False): """ Gather the advanced analysis of cross-validation regarding variation """ - if pred_type == 'clas': + if pred_type == "clas": # Skip entirely for classification return "" - score_cv_diff = data_score.get(f'r2_diff_score_{suffix}', 0) - cv_diff_image = f'{self.args.path_icons}/score_w_2_{score_cv_diff}.jpg' - cv_diff = round(data_score.get(f'r2_diff_{suffix}', 0), 2) + score_cv_diff = data_score.get(f"r2_diff_score_{suffix}", 0) + cv_diff_image = f"{self.args.path_icons}/score_w_2_{score_cv_diff}.jpg" + cv_diff = round(data_score.get(f"r2_diff_{suffix}", 0), 2) if test_set: - sd_set = 'test' + sd_set = "test" else: - sd_set = 'valid.' + sd_set = "valid." # Build R2 difference text if score_cv_diff == 0: - cv_diff_result = f'High variation ({sd_set} and CV), ΔR² = {cv_diff}.' + cv_diff_result = f"High variation ({sd_set} and CV), ΔR² = {cv_diff}." elif score_cv_diff == 1: - cv_diff_result = f'Moderate variation ({sd_set} and CV), ΔR² = {cv_diff}.' + cv_diff_result = f"Moderate variation ({sd_set} and CV), ΔR² = {cv_diff}." elif score_cv_diff == 2: - cv_diff_result = f'Low variation ({sd_set} and CV), ΔR² = {cv_diff}.' - + cv_diff_result = f"Low variation ({sd_set} and CV), ΔR² = {cv_diff}." + metric_label = "ΔR²" - threshold_text = f'{metric_label} 0.15-0.30: +1, {metric_label} < 0.15: +2.' - title = f"3b. R² difference (model vs CV)" + threshold_text = f"{metric_label} 0.15-0.30: +1, {metric_label} < 0.15: +2." + title = "3b. R² difference (model vs CV)" score_adv_pred = f'

{spacing}' column = f"""

{spacing}{title}  ({score_cv_diff} / 2  ROBERT Score)

{score_adv_pred}{cv_diff_result}
{spacing}· Scoring from 0 to 2 ·
{spacing}{threshold_text}
- """ + """ return column @@ -653,24 +681,27 @@ def adv_sorted_cv(self, suffix, data_score, spacing, pred_type): based on how many folds stay near the maximum, indicating consistent performance. """ - score_sorted = data_score.get(f'sorted_cv_score_{suffix}', 0) - sorted_cv_image = f'{self.args.path_icons}/score_w_2_{score_sorted}.jpg' + score_sorted = data_score.get(f"sorted_cv_score_{suffix}", 0) + sorted_cv_image = f"{self.args.path_icons}/score_w_2_{score_sorted}.jpg" - error_keyword = "rmse" if pred_type.lower() == 'reg' else "mcc" + error_keyword = "rmse" if pred_type.lower() == "reg" else "mcc" - # "3d. Extrapolation/Consistency (sorted CV)" - if f'scaled_{error_keyword}_sorted_{suffix}' not in data_score: - data_score[f'scaled_{error_keyword}_sorted_{suffix}'] = [] + if f"scaled_{error_keyword}_sorted_{suffix}" not in data_score: + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] = [] - if pred_type == 'reg': - title_cap = '3d. Extrapolation' - sorted_rmse = [f'{val}%' for val in data_score[f'scaled_{error_keyword}_sorted_{suffix}']] + if pred_type == "reg": + title_cap = "3d. Extrapolation" + sorted_rmse = [ + f"{val}%" for val in data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ] else: - title_cap = '3c. Consistency' - sorted_rmse = [f'{val}' for val in data_score[f'scaled_{error_keyword}_sorted_{suffix}']] + title_cap = "3c. Consistency" + sorted_rmse = [ + f"{val}" for val in data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ] - sorted_rmse_str = str(sorted_rmse).replace("'", '') + sorted_rmse_str = str(sorted_rmse).replace("'", "") column = f"""

{spacing}{title_cap} (sorted CV)  ({score_sorted} / 2  ROBERT Score)

@@ -687,42 +718,45 @@ def get_col_text(type_thres): Gather the information regarding the thresholds used in the score and abbreviation sections """ - reduced_line = '

' # reduces line separation separation + reduced_line = '

' # reduces line separation separation first_line = '

' - if type_thres == 'abbrev_1': - abbrev_list = ['ACC: accuracy', - 'ADAB: AdaBoost', - 'CSV: comma separated values', - 'CLAS: classification', - 'CV: cross-validation', - 'F1 score: balanced F-score', - 'GB: gradient boosting', - 'GP: gaussian process', - 'XGB: extreme gradient boosting' + if type_thres == "abbrev_1": + abbrev_list = [ + "ACC: accuracy", + "ADAB: AdaBoost", + "CSV: comma separated values", + "CLAS: classification", + "CV: cross-validation", + "F1 score: balanced F-score", + "GB: gradient boosting", + "GP: gaussian process", + "XGB: extreme gradient boosting", ] - elif type_thres == 'abbrev_2': - abbrev_list = ['KN: k-nearest neighbors', - 'MAE: root-mean-square error', - "MCC: Matthew's correl. coefficient", - 'ML: machine learning', - 'MVL: multivariate lineal models', - 'NN: neural network', - 'PFI: permutation feature importance', - 'R2: coefficient of determination' + elif type_thres == "abbrev_2": + abbrev_list = [ + "KN: k-nearest neighbors", + "MAE: root-mean-square error", + "MCC: Matthew's correl. coefficient", + "ML: machine learning", + "MVL: multivariate lineal models", + "NN: neural network", + "PFI: permutation feature importance", + "R2: coefficient of determination", ] - elif type_thres == 'abbrev_3': - abbrev_list = ['REG: Regression', - 'RF: random forest', - 'RMSE: root mean square error', - 'RND: random', - 'SHAP: Shapley additive explanations', - 'VR: voting regressor', + elif type_thres == "abbrev_3": + abbrev_list = [ + "REG: Regression", + "RF: random forest", + "RMSE: root mean square error", + "RND: random", + "SHAP: Shapley additive explanations", + "VR: voting regressor", ] - column = '' - for i,ele in enumerate(abbrev_list): + column = "" + for i, ele in enumerate(abbrev_list): if i == 0: column += f"""{first_line}{ele}

""" @@ -733,112 +767,128 @@ def get_col_text(type_thres): return column -def get_col_transpa(params_dict,suffix,section,spacing): +def get_col_transpa(params_dict, suffix, section, spacing): """ Gather the information regarding the model parameters represented in the Reproducibility section """ - first_line = f'

{spacing*2}' # reduces line separation separation - reduced_line = f'

{spacing*2}' # reduces line separation separation - - if suffix == 'No PFI': - caption = f'{title_no_pfi}' - - elif suffix == 'PFI': - caption = f'{title_pfi}' - - excluded_params = [f"combined_{params_dict['error_type']}", 'train', 'X_descriptors', 'y', 'error_train', 'cv_error', 'names'] - misc_params = ['type','error_type','split','kfold','repeat_kfolds','seed'] - if params_dict['type'] == 'reg': - model_type = 'Regressor' - elif params_dict['type'] == 'clas': - model_type = 'Classifier' - models_dict = {'RF': f'RandomForest{model_type}', - 'MVL': 'LinearRegression', - 'GB': f'GradientBoosting{model_type}', - 'XGB': f'XGB{model_type}', - 'NN': f'MLP{model_type}', - 'GP': f'GaussianProcess{model_type}', - 'ADAB': f'AdaBoost{model_type}', - 'VR': f'Voting{model_type}', - } - - col_info,sklearn_model = '','' - for _,ele in enumerate(params_dict.keys()): + first_line = f'

{spacing * 2}' # reduces line separation separation + reduced_line = f'

{spacing * 2}' # reduces line separation separation + + if suffix == "No PFI": + caption = f"{title_no_pfi}" + + elif suffix == "PFI": + caption = f"{title_pfi}" + + excluded_params = [ + f"combined_{params_dict['error_type']}", + "train", + "X_descriptors", + "y", + "error_train", + "cv_error", + "names", + ] + misc_params = ["type", "error_type", "split", "kfold", "repeat_kfolds", "seed"] + if params_dict["type"] == "reg": + model_type = "Regressor" + elif params_dict["type"] == "clas": + model_type = "Classifier" + models_dict = { + "RF": f"RandomForest{model_type}", + "MVL": "LinearRegression", + "GB": f"GradientBoosting{model_type}", + "XGB": f"XGB{model_type}", + "NN": f"MLP{model_type}", + "GP": f"GaussianProcess{model_type}", + "ADAB": f"AdaBoost{model_type}", + "VR": f"Voting{model_type}", + } + + col_info, sklearn_model = "", "" + for _, ele in enumerate(params_dict.keys()): if ele not in excluded_params: - if ele == 'model' and section == 'model_section': + if ele == "model" and section == "model_section": sklearn_model = models_dict[params_dict[ele].upper()] sklearn_model = f"""{first_line}sklearn model: {sklearn_model}

""" - elif section == 'model_section' and ele.lower() not in misc_params: - if ele == 'params': - model_params = ast.literal_eval(params_dict['params']) + elif section == "model_section" and ele.lower() not in misc_params: + if ele == "params": + model_params = ast.literal_eval(params_dict["params"]) for param in model_params: - col_info += f"""{reduced_line}{param}: {model_params[param]}

""" - elif section == 'misc_section' and ele.lower() in misc_params: - if col_info == '': + col_info += ( + f"""{reduced_line}{param}: {model_params[param]}

""" + ) + elif section == "misc_section" and ele.lower() in misc_params: + if col_info == "": col_info += f"""{first_line}{ele}: {params_dict[ele]}

""" else: col_info += f"""{reduced_line}{ele}: {params_dict[ele]}

""" - - column = f"""

{spacing*2}{caption}

+ + column = f"""

{spacing * 2}{caption}

{sklearn_model}{col_info} """ return column -def calc_score(dat_files,suffix,pred_type,data_score): - ''' +def calc_score(dat_files, suffix, pred_type, data_score): + """ Calculates ROBERT score - ''' + """ - data_score = get_predict_scores(dat_files['PREDICT'],suffix,pred_type,data_score) + data_score = get_predict_scores(dat_files["PREDICT"], suffix, pred_type, data_score) - data_score = get_verify_scores(dat_files['VERIFY'],suffix,pred_type,data_score) + data_score = get_verify_scores(dat_files["VERIFY"], suffix, pred_type, data_score) - if pred_type == 'reg': - robert_score = data_score.get(f'cv_score_combined_{suffix}', 0) + data_score.get(f'test_score_combined_{suffix}', 0) \ - + data_score.get(f'cv_sd_score_{suffix}', 0) + data_score.get(f'diff_scaled_rmse_score_{suffix}', 0) \ - + data_score.get(f'flawed_mod_score_{suffix}', 0) + data_score.get(f'sorted_cv_score_{suffix}', 0) + if pred_type == "reg": + robert_score = ( + data_score.get(f"cv_score_combined_{suffix}", 0) + + data_score.get(f"test_score_combined_{suffix}", 0) + + data_score.get(f"cv_sd_score_{suffix}", 0) + + data_score.get(f"diff_scaled_rmse_score_{suffix}", 0) + + data_score.get(f"flawed_mod_score_{suffix}", 0) + + data_score.get(f"sorted_cv_score_{suffix}", 0) + ) # Adjustment to avoid negative values if robert_score < 0: robert_score = 0 # Assign the final value - data_score[f'robert_score_{suffix}'] = robert_score + data_score[f"robert_score_{suffix}"] = robert_score - elif pred_type == 'clas': + elif pred_type == "clas": # Calculate the difference between CV MCC and test MCC - mcc_cv = data_score.get(f'r2_cv_{suffix}', 0) - mcc_test = data_score.get(f'r2_test_{suffix}', 0) + mcc_cv = data_score.get(f"r2_cv_{suffix}", 0) + mcc_test = data_score.get(f"r2_test_{suffix}", 0) diff_mcc = round(np.abs(mcc_test - mcc_cv), 2) # Assign a score based on the MCC gap (e.g., ±2, ±1, 0) - data_score[f'diff_mcc_score_{suffix}'] = 0 + data_score[f"diff_mcc_score_{suffix}"] = 0 if diff_mcc < 0.15: - data_score[f'diff_mcc_score_{suffix}'] += 2 + data_score[f"diff_mcc_score_{suffix}"] += 2 elif diff_mcc <= 0.30: - data_score[f'diff_mcc_score_{suffix}'] += 1 + data_score[f"diff_mcc_score_{suffix}"] += 1 # Sum scores similarly to regression: robert_score = ( - data_score.get(f'cv_score_combined_{suffix}', 0) - + data_score.get(f'test_score_combined_{suffix}', 0) - + data_score.get(f'flawed_mod_score_{suffix}', 0) - + data_score.get(f'sorted_cv_score_{suffix}', 0) - + data_score.get(f'diff_mcc_score_{suffix}', 0) - + data_score.get(f'descp_score_{suffix}', 0) + data_score.get(f"cv_score_combined_{suffix}", 0) + + data_score.get(f"test_score_combined_{suffix}", 0) + + data_score.get(f"flawed_mod_score_{suffix}", 0) + + data_score.get(f"sorted_cv_score_{suffix}", 0) + + data_score.get(f"diff_mcc_score_{suffix}", 0) + + data_score.get(f"descp_score_{suffix}", 0) ) # Adjustment to avoid negative values if robert_score < 0: robert_score = 0 # Assign the final value - data_score[f'robert_score_{suffix}'] = robert_score + data_score[f"robert_score_{suffix}"] = robert_score return data_score - -def get_verify_scores(dat_verify,suffix,pred_type,data_score): + +def get_verify_scores(dat_verify, suffix, pred_type, data_score): """ Calculates scores that come from the VERIFY module (VERIFY tests) """ @@ -847,175 +897,271 @@ def get_verify_scores(dat_verify,suffix,pred_type,data_score): flawed_score = 0 failed_tests = 0 sorted_cv_score = 0 - for i,line in enumerate(dat_verify): + for i, line in enumerate(dat_verify): # set starting points for No PFI and PFI models - if suffix == 'No PFI': - if '------- ' in line and '(No PFI)' in line: + if suffix == "No PFI": + if "------- " in line and "(No PFI)" in line: start_data = True - elif '------- ' in line and 'with PFI' in line: + elif "------- " in line and "with PFI" in line: start_data = False - if suffix == 'PFI': - if '------- ' in line and 'with PFI' in line: + if suffix == "PFI": + if "------- " in line and "with PFI" in line: start_data = True - + if start_data: - error_keyword = "rmse" if pred_type.lower() == 'reg' else "mcc" + error_keyword = "rmse" if pred_type.lower() == "reg" else "mcc" if f"Original {error_keyword.upper()} (" in line: - for j in range(i+1,i+4): # y-mean, y-shuffle and onehot tests - if 'UNCLEAR' in dat_verify[j]: + for j in range(i + 1, i + 4): # y-mean, y-shuffle and onehot tests + if "UNCLEAR" in dat_verify[j]: flawed_score -= 1 - elif 'FAILED' in dat_verify[j]: + elif "FAILED" in dat_verify[j]: flawed_score -= 2 failed_tests += 1 - if '- Sorted ' in dat_verify[i+4]: - sorted_cv_results = dat_verify[i+4].split(f'{error_keyword.upper()} = ')[-1] + if "- Sorted " in dat_verify[i + 4]: + sorted_cv_results = dat_verify[i + 4].split( + f"{error_keyword.upper()} = " + )[-1] sorted_cv_results = ast.literal_eval(sorted_cv_results) - if pred_type.lower() == 'reg': - data_score[f'scaled_{error_keyword}_sorted_{suffix}'] = [round((val/data_score[f'y_range_{suffix}'])*100,2) for val in sorted_cv_results] + if pred_type.lower() == "reg": + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] = [ + round((val / data_score[f"y_range_{suffix}"]) * 100, 2) + for val in sorted_cv_results + ] else: - data_score[f'scaled_{error_keyword}_sorted_{suffix}'] = sorted_cv_results # no scaling for MCC + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] = ( + sorted_cv_results # no scaling for MCC + ) # define min and max values - data_score[f'min_scaled_{error_keyword}_{suffix}'] = min(data_score[f'scaled_{error_keyword}_sorted_{suffix}']) - idx_min_scaled_rmse = data_score[f'scaled_{error_keyword}_sorted_{suffix}'].index(data_score[f'min_scaled_{error_keyword}_{suffix}']) - data_score[f'max_scaled_{error_keyword}_{suffix}'] = max(data_score[f'scaled_{error_keyword}_sorted_{suffix}']) - idx_max_scaled_rmse = data_score[f'scaled_{error_keyword}_sorted_{suffix}'].index(data_score[f'max_scaled_{error_keyword}_{suffix}']) - - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'] = [] - for idx,err in enumerate(data_score[f'scaled_{error_keyword}_sorted_{suffix}']): - if pred_type.lower() == 'reg': + data_score[f"min_scaled_{error_keyword}_{suffix}"] = min( + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ) + idx_min_scaled_rmse = data_score[ + f"scaled_{error_keyword}_sorted_{suffix}" + ].index(data_score[f"min_scaled_{error_keyword}_{suffix}"]) + data_score[f"max_scaled_{error_keyword}_{suffix}"] = max( + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ) + idx_max_scaled_rmse = data_score[ + f"scaled_{error_keyword}_sorted_{suffix}" + ].index(data_score[f"max_scaled_{error_keyword}_{suffix}"]) + + data_score[f"scaled_{error_keyword}_results_sorted_{suffix}"] = [] + for idx, err in enumerate( + data_score[f"scaled_{error_keyword}_sorted_{suffix}"] + ): + if pred_type.lower() == "reg": if idx == idx_min_scaled_rmse: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('min') - elif err <= (data_score[f'min_scaled_{error_keyword}_{suffix}']*1.25): - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('pass') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("min") + elif err <= ( + data_score[f"min_scaled_{error_keyword}_{suffix}"] + * 1.25 + ): + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("pass") else: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('fail') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("fail") else: if idx == idx_max_scaled_rmse: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('max') - elif err >= (data_score[f'max_scaled_{error_keyword}_{suffix}']*0.75): - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('pass') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("max") + elif err >= ( + data_score[f"max_scaled_{error_keyword}_{suffix}"] + * 0.75 + ): + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("pass") else: - data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].append('fail') + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].append("fail") - sorted_cv_score = int(data_score[f'scaled_{error_keyword}_results_sorted_{suffix}'].count('pass')/2) + sorted_cv_score = int( + data_score[ + f"scaled_{error_keyword}_results_sorted_{suffix}" + ].count("pass") + / 2 + ) # adjust max 1 point for flawed tests if flawed_score > 1: flawed_score = 1 - + # stores data - data_score[f'flawed_mod_score_{suffix}'] = flawed_score - data_score[f'failed_tests_{suffix}'] = failed_tests - data_score[f'sorted_cv_score_{suffix}'] = sorted_cv_score + data_score[f"flawed_mod_score_{suffix}"] = flawed_score + data_score[f"failed_tests_{suffix}"] = failed_tests + data_score[f"sorted_cv_score_{suffix}"] = sorted_cv_score return data_score -def get_predict_scores(dat_predict,suffix,pred_type,data_score): +def get_predict_scores(dat_predict, suffix, pred_type, data_score): """ Calculates scores that come from the PREDICT module (R2 or accuracy, datapoints:descriptors ratio, outlier proportion) """ start_data = False - data_score[f'rmse_score_{suffix}'] = 0 - data_score[f'cv_type_{suffix}'] = "10x 5-fold CV" - - for i,line in enumerate(dat_predict): + data_score[f"rmse_score_{suffix}"] = 0 + data_score[f"cv_type_{suffix}"] = "10x 5-fold CV" + for i, line in enumerate(dat_predict): # set starting points for No PFI and PFI models - if suffix == 'No PFI': - if '------- ' in line and '(No PFI)' in line: + if suffix == "No PFI": + if "------- " in line and "(No PFI)" in line: start_data = True - elif '------- ' in line and 'with PFI' in line: + elif "------- " in line and "with PFI" in line: start_data = False - if suffix == 'PFI': - if '------- ' in line and 'with PFI' in line: + if suffix == "PFI": + if "------- " in line and "with PFI" in line: start_data = True - + if start_data: # model type - if line.startswith(' - Model:'): - data_score['ML_model'] = line.split()[-1] + if line.startswith(" - Model:"): + data_score["ML_model"] = line.split()[-1] # R2 and proportion - if 'o Summary of results' in line: - data_score['proportion_ratio_print'] = dat_predict[i+2] - data_score[f'points_descp_ratio_{suffix}'] = dat_predict[i+4].split()[-1] + if "o Summary of results" in line: + data_score["proportion_ratio_print"] = dat_predict[i + 2] + data_score[f"points_descp_ratio_{suffix}"] = dat_predict[i + 4].split()[ + -1 + ] # scaled RMSE/MCC from test (if any) or validation - if pred_type == 'reg': - if '-fold CV : R2 =' in dat_predict[i+5]: - data_score[f'rmse_cv_{suffix}'] = float(dat_predict[i+5].split()[-1]) - data_score[f"cv_type_{suffix}"] = ' '.join([ele for ele in dat_predict[i+5].split()[1:4]]) - data_score[f'r2_cv_{suffix}'] = float(dat_predict[i+5].split(',')[0].split()[-1]) - if 'Test : R2 =' in dat_predict[i+6]: - data_score[f'rmse_test_{suffix}'] = float(dat_predict[i+6].split()[-1]) - data_score[f'r2_test_{suffix}'] = float(dat_predict[i+6].split(',')[0].split()[-1]) - if '- y range of dataset' in dat_predict[i+8]: - data_score[f'y_range_{suffix}'] = float(dat_predict[i+8].split()[-1]) - - data_score[f'scaled_rmse_cv_{suffix}'] = round((data_score[f'rmse_cv_{suffix}']/data_score[f'y_range_{suffix}'])*100,2) - data_score[f'scaled_rmse_test_{suffix}'] = round((data_score[f'rmse_test_{suffix}']/data_score[f'y_range_{suffix}'])*100,2) - - data_score[f'cv_score_rmse_{suffix}'] = score_rmse_mcc(pred_type,data_score[f'scaled_rmse_cv_{suffix}']) - data_score[f'test_score_rmse_{suffix}'] = score_rmse_mcc(pred_type,data_score[f'scaled_rmse_test_{suffix}']) + if pred_type == "reg": + if "-fold CV : R2 =" in dat_predict[i + 5]: + data_score[f"rmse_cv_{suffix}"] = float( + dat_predict[i + 5].split()[-1] + ) + data_score[f"cv_type_{suffix}"] = " ".join( + [ele for ele in dat_predict[i + 5].split()[1:4]] + ) + data_score[f"r2_cv_{suffix}"] = float( + dat_predict[i + 5].split(",")[0].split()[-1] + ) + if "Test : R2 =" in dat_predict[i + 6]: + data_score[f"rmse_test_{suffix}"] = float( + dat_predict[i + 6].split()[-1] + ) + data_score[f"r2_test_{suffix}"] = float( + dat_predict[i + 6].split(",")[0].split()[-1] + ) + if "- y range of dataset" in dat_predict[i + 8]: + data_score[f"y_range_{suffix}"] = float( + dat_predict[i + 8].split()[-1] + ) + + data_score[f"scaled_rmse_cv_{suffix}"] = round( + ( + data_score[f"rmse_cv_{suffix}"] + / data_score[f"y_range_{suffix}"] + ) + * 100, + 2, + ) + data_score[f"scaled_rmse_test_{suffix}"] = round( + ( + data_score[f"rmse_test_{suffix}"] + / data_score[f"y_range_{suffix}"] + ) + * 100, + 2, + ) + + data_score[f"cv_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score[f"scaled_rmse_cv_{suffix}"] + ) + data_score[f"test_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score[f"scaled_rmse_test_{suffix}"] + ) # get penalties for R2 - data_score[f'cv_penalty_r2_{suffix}'] = calc_penalty_r2(data_score[f'r2_cv_{suffix}']) - data_score[f'test_penalty_r2_{suffix}'] = calc_penalty_r2(data_score[f'r2_test_{suffix}']) + data_score[f"cv_penalty_r2_{suffix}"] = calc_penalty_r2( + data_score[f"r2_cv_{suffix}"] + ) + data_score[f"test_penalty_r2_{suffix}"] = calc_penalty_r2( + data_score[f"r2_test_{suffix}"] + ) # combined scores RMSE/R2 (min 0) - data_score[f'cv_score_combined_{suffix}'] = data_score[f'cv_score_rmse_{suffix}'] + data_score[f'cv_penalty_r2_{suffix}'] - if data_score[f'cv_score_combined_{suffix}'] < 0: - data_score[f'cv_score_combined_{suffix}'] = 0 - data_score[f'test_score_combined_{suffix}'] = data_score[f'test_score_rmse_{suffix}'] + data_score[f'test_penalty_r2_{suffix}'] - if data_score[f'test_score_combined_{suffix}'] < 0: - data_score[f'test_score_combined_{suffix}'] = 0 + data_score[f"cv_score_combined_{suffix}"] = ( + data_score[f"cv_score_rmse_{suffix}"] + + data_score[f"cv_penalty_r2_{suffix}"] + ) + if data_score[f"cv_score_combined_{suffix}"] < 0: + data_score[f"cv_score_combined_{suffix}"] = 0 + data_score[f"test_score_combined_{suffix}"] = ( + data_score[f"test_score_rmse_{suffix}"] + + data_score[f"test_penalty_r2_{suffix}"] + ) + if data_score[f"test_score_combined_{suffix}"] < 0: + data_score[f"test_score_combined_{suffix}"] = 0 diff_score = 0 # relative difference between RMSE from test and CV - data_score[f'factor_scaled_rmse_{suffix}'] = data_score[f'scaled_rmse_test_{suffix}'] / data_score[f'scaled_rmse_cv_{suffix}'] - if data_score[f'factor_scaled_rmse_{suffix}'] <= 1.25: + data_score[f"factor_scaled_rmse_{suffix}"] = ( + data_score[f"scaled_rmse_test_{suffix}"] + / data_score[f"scaled_rmse_cv_{suffix}"] + ) + if data_score[f"factor_scaled_rmse_{suffix}"] <= 1.25: diff_score += 2 - elif data_score[f'factor_scaled_rmse_{suffix}'] <= 1.5: + elif data_score[f"factor_scaled_rmse_{suffix}"] <= 1.5: diff_score += 1 - data_score[f'diff_scaled_rmse_score_{suffix}'] = diff_score - - elif pred_type == 'clas': # Process classification: using MCC extracted from CV and Test results + data_score[f"diff_scaled_rmse_score_{suffix}"] = diff_score + + elif ( + pred_type == "clas" + ): # Process classification: using MCC extracted from CV and Test results # Extract MCC from the 10x 5-fold CV line - if '5-fold' in dat_predict[i+5]: - parts = dat_predict[i+5].split(',') + if "5-fold" in dat_predict[i + 5]: + parts = dat_predict[i + 5].split(",") mcc_cv = None for part in parts: - if 'MCC' in part: - mcc_cv = float(part.split('=')[-1]) + if "MCC" in part: + mcc_cv = float(part.split("=")[-1]) break if mcc_cv is not None: - data_score[f'r2_cv_{suffix}'] = mcc_cv # storing MCC in a key keyed as r2_cv for consistency + data_score[f"r2_cv_{suffix}"] = ( + mcc_cv # storing MCC in a key keyed as r2_cv for consistency + ) # Extract MCC from the Test line - if '- Test :' in dat_predict[i+6]: - parts = dat_predict[i+6].split(',') + if "- Test :" in dat_predict[i + 6]: + parts = dat_predict[i + 6].split(",") mcc_test = None for part in parts: - if 'MCC' in part: - mcc_test = float(part.split('=')[-1]) + if "MCC" in part: + mcc_test = float(part.split("=")[-1]) break if mcc_test is not None: - data_score[f'r2_test_{suffix}'] = mcc_test + data_score[f"r2_test_{suffix}"] = mcc_test # Compute CV and Test scores using the classification thresholds in score_rmse_mcc - data_score[f'cv_score_rmse_{suffix}'] = score_rmse_mcc(pred_type, data_score.get(f'r2_cv_{suffix}', 0)) - data_score[f'test_score_rmse_{suffix}'] = score_rmse_mcc(pred_type, data_score.get(f'r2_test_{suffix}', 0)) - + data_score[f"cv_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score.get(f"r2_cv_{suffix}", 0) + ) + data_score[f"test_score_rmse_{suffix}"] = score_rmse_mcc( + pred_type, data_score.get(f"r2_test_{suffix}", 0) + ) + # For classification, the combined score is simply the score from MCC (no additional penalty) - data_score[f'cv_score_combined_{suffix}'] = data_score[f'cv_score_rmse_{suffix}'] - data_score[f'test_score_combined_{suffix}'] = data_score[f'test_score_rmse_{suffix}'] + data_score[f"cv_score_combined_{suffix}"] = data_score[ + f"cv_score_rmse_{suffix}" + ] + data_score[f"test_score_combined_{suffix}"] = data_score[ + f"test_score_rmse_{suffix}" + ] # SD from CV - if pred_type == 'reg': - if '- Average SD in test set' in line: + if pred_type == "reg": + if "- Average SD in test set" in line: cv_sd = float(line.split()[-1]) - cv_4sd = 4*cv_sd - y_range_covered = cv_4sd/data_score[f'y_range_{suffix}'] + cv_4sd = 4 * cv_sd + y_range_covered = cv_4sd / data_score[f"y_range_{suffix}"] cv_sd_score = 0 if y_range_covered <= 0.25: @@ -1025,42 +1171,42 @@ def get_predict_scores(dat_predict,suffix,pred_type,data_score): data_score[f"cv_4sd_{suffix}"] = cv_4sd data_score[f"cv_range_cov_{suffix}"] = y_range_covered - data_score[f'cv_sd_score_{suffix}'] = cv_sd_score + data_score[f"cv_sd_score_{suffix}"] = cv_sd_score return data_score -def score_rmse_mcc(pred_type,scaledrmse_mcc_val): - ''' +def score_rmse_mcc(pred_type, scaledrmse_mcc_val): + """ Calculate scores for R2 and MCC using predetermined thresholds - + For regression (scaled RMSE): 0-2 points For classification (MCC): 0-3 points - ''' + """ r2_mcc_score = 0 - if pred_type == 'reg': # scaled RMSE + if pred_type == "reg": # scaled RMSE if scaledrmse_mcc_val <= 10: r2_mcc_score += 2 elif scaledrmse_mcc_val <= 20: r2_mcc_score += 1 - else: # MCC + else: # MCC if scaledrmse_mcc_val > 0.75: r2_mcc_score += 3 elif scaledrmse_mcc_val > 0.5: r2_mcc_score += 2 elif scaledrmse_mcc_val > 0.3: r2_mcc_score += 1 - + return r2_mcc_score def calc_penalty_r2(r2_val): - ''' + """ Calculate scores for R2 and MCC using predetermined thresholds - ''' + """ penalty_r2 = 0 @@ -1068,7 +1214,7 @@ def calc_penalty_r2(r2_val): penalty_r2 -= 2 elif r2_val < 0.7: penalty_r2 -= 1 - + return penalty_r2 @@ -1077,35 +1223,38 @@ def repro_info(modules): Retrieves variables used in the Reproducibility section """ - version_n_date, citation, command_line = '','','' - python_version, total_time = '',0 + version_n_date, citation, command_line = "", "", "" + python_version, total_time = "", 0 dat_files = {} for module in modules: - path_file = Path(f'{os.getcwd()}/{module}/{module}_data.dat') + path_file = Path(f"{os.getcwd()}/{module}/{module}_data.dat") if os.path.exists(path_file): - datfile = open(path_file, 'r', encoding= 'utf-8', errors="replace") + datfile = open(path_file, "r", encoding="utf-8", errors="replace") txt_file = [] for line in datfile: txt_file.append(line) - if 'Time' in line and 'seconds' in line: + if "Time" in line and "seconds" in line: total_time += float(line.split()[2]) - if 'How to cite: ' in line: - citation = line.split('How to cite: ')[1] - if 'ROBERT v' == line[:8]: + if "How to cite: " in line: + citation = line.split("How to cite: ")[1] + if "ROBERT v" == line[:8]: version_n_date = line - if 'Command line used in ROBERT: ' in line: - if '--csv_name' not in command_line: # ensures that the value for --csv_name is stored - command_line = line.split('Command line used in ROBERT: ')[1] - total_time = round(total_time,2) + if "Command line used in ROBERT: " in line: + if ( + "--csv_name" not in command_line + ): # ensures that the value for --csv_name is stored + command_line = line.split("Command line used in ROBERT: ")[1] + total_time = round(total_time, 2) dat_files[module] = txt_file datfile.close() - + try: import platform + python_version = platform.python_version() - except: - python_version = '(version could not be determined)' - + except Exception: + python_version = "(version could not be determined)" + return version_n_date, citation, command_line, python_version, total_time, dat_files @@ -1120,7 +1269,9 @@ def make_report(report_html, HTML): try: os.remove(outfile) except PermissionError: - print('\nx ROBERT_report.pdf is open! Please, close the PDF file and run ROBERT again with --report (i.e., "python -m robert --report").') + print( + '\nx ROBERT_report.pdf is open! Please, close the PDF file and run ROBERT again with --report (i.e., "python -m robert --report").' + ) sys.exit() pdf = make_pdf(report_html, HTML, css_files) _ = Path(outfile).write_bytes(pdf) @@ -1136,7 +1287,7 @@ def make_pdf(html, HTML, css_files): return htmldoc -def css_content(csv_name,robert_version): +def css_content(csv_name, robert_version): """ Obtain ROBERT version and CSV name to use it on top of the PDF report """ @@ -1266,50 +1417,63 @@ def css_content(csv_name,robert_version): return css_content -def format_lines(module_data, max_width=122, cmd_line=False, one_column=False, spacing=''): +def format_lines( + module_data, max_width=122, cmd_line=False, one_column=False, spacing="" +): """ Reads a file and returns a formatted string between two markers """ formatted_lines = [] - lines = module_data.split('\n') - for i,line in enumerate(lines): - if 'R2' in line: - line = line.replace('R2','R2') + lines = module_data.split("\n") + for i, line in enumerate(lines): + if "R2" in line: + line = line.replace("R2", "R2") if cmd_line: - formatted_line = textwrap.fill(line, width=max_width-5, subsequent_indent='') + formatted_line = textwrap.fill( + line, width=max_width - 5, subsequent_indent="" + ) else: - formatted_line = textwrap.fill(line, width=max_width, subsequent_indent='') + formatted_line = textwrap.fill(line, width=max_width, subsequent_indent="") if i > 0: - formatted_lines.append(f'
\n{formatted_line}
') + formatted_lines.append( + f'
\n{formatted_line}
' + ) else: - formatted_lines.append(f'
{formatted_line}
\n') + formatted_lines.append( + f'
{formatted_line}
\n' + ) # for two columns if not one_column: - return ''.join(formatted_lines) - + return "".join(formatted_lines) + # for one column - one_col_lines = '' - for line in ''.join(formatted_lines).split('\n'): - if line.startswith('
') and line != '
':
-            one_col_lines += line.replace('
',f'
{spacing*3}')
-        elif not line.startswith('<'):
-            one_col_lines += f'\n{spacing*3}{line}'
+    one_col_lines = ""
+    for line in "".join(formatted_lines).split("\n"):
+        if (
+            line.startswith('
')
+            and line != '
'
+        ):
+            one_col_lines += line.replace(
+                '
',
+                f'
{spacing * 3}',
+            )
+        elif not line.startswith("<"):
+            one_col_lines += f"\n{spacing * 3}{line}"
         else:
-            one_col_lines += f'\n{line}'
+            one_col_lines += f"\n{line}"
     return one_col_lines
 
 
-
-def get_spacing_col(suffix,spacing_PFI):
-    '''
+def get_spacing_col(suffix, spacing_PFI):
+    """
     Assign spacing of column
-    '''
-    
-    if suffix == 'No PFI':
-        spacing = ''
-    elif suffix == 'PFI':
+    """
+
+    if suffix == "No PFI":
+        spacing = ""
+    elif suffix == "PFI":
         spacing = spacing_PFI
-    
-    return spacing
\ No newline at end of file
+
+    return spacing
diff --git a/robert/robert.py b/robert/robert.py
index 7e370d5..5602b4d 100644
--- a/robert/robert.py
+++ b/robert/robert.py
@@ -35,16 +35,16 @@
 from robert.report import report
 from robert.aqme import aqme
 from robert.evaluate import evaluate
-from robert.utils import (command_line_args,missing_inputs)
+from robert.utils import command_line_args, missing_inputs
 
 
-def main(exe_type='command',sys_args=None):
+def main(exe_type="command", sys_args=None):
     """
     Main function of ROBERT, acts as the starting point when the program is run through a terminal
     """
 
     # load user-defined arguments from command line
-    args = command_line_args(exe_type,sys_args)
+    args = command_line_args(exe_type, sys_args)
     args.command_line = True
 
     if not args.evaluate:
@@ -53,7 +53,7 @@ def main(exe_type='command',sys_args=None):
         if not args.curate and not args.generate and not args.predict:
             if not args.cheers and not args.verify and not args.report:
                 full_workflow = True
-        
+
         if args.aqme:
             full_workflow = True
 
@@ -63,19 +63,23 @@ def main(exe_type='command',sys_args=None):
         # save the csv_name, y and names values from full workflows
         if full_workflow:
             # remove the EVALUATE folder to avoid issues when generating the report PDF
-            eval_folder = Path(f'{os.getcwd()}/EVALUATE')
+            eval_folder = Path(f"{os.getcwd()}/EVALUATE")
             if os.path.exists(eval_folder):
                 shutil.rmtree(eval_folder)
 
-            args = missing_inputs(args,'full_workflow',print_err=True)
+            args = missing_inputs(args, "full_workflow", print_err=True)
 
         # AQME
         if args.aqme:
             aqme(**vars(args))
             # set the path to the database created by AQME to continue in the full_workflow
-            args.csv_name = Path(os.path.dirname(args.csv_name)).joinpath(f'AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_name)}')
-            if args.csv_test != '':
-                args.csv_test = Path(os.path.dirname(args.csv_test)).joinpath(f'AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_test)}')
+            args.csv_name = Path(os.path.dirname(args.csv_name)).joinpath(
+                f"AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_name)}"
+            )
+            if args.csv_test != "":
+                args.csv_test = Path(os.path.dirname(args.csv_test)).joinpath(
+                    f"AQME-ROBERT_{args.descp_lvl}_{os.path.basename(args.csv_test)}"
+                )
 
         # CURATE
         if args.curate or full_workflow:
@@ -83,9 +87,9 @@ def main(exe_type='command',sys_args=None):
 
         if full_workflow:
             # this ensures GENERATE communicates with CURATE (see the load_variables() function in utils.py)
-            args.y = ''
-            args.discard = [] # avoids an error since the variable(s) are removed in CURATE
-            args.csv_name = '' # force GENERATE to use the curated database
+            args.y = ""
+            args.discard = []  # avoids an error since the variable(s) are removed in CURATE
+            args.csv_name = ""  # force GENERATE to use the curated database
 
         # GENERATE
         if args.generate or full_workflow:
@@ -102,10 +106,12 @@ def main(exe_type='command',sys_args=None):
         # REPORT
         if args.report or full_workflow:
             report(**vars(args))
-        
+
         # CHEERS
         if args.cheers:
-            print('o  This module was designed to thank ROBERT Paton, who was a mentor to me throughout my years at Colorado State University, and who introduced me to the field of cheminformatics. Cheers mate!\n')
+            print(
+                "o  This module was designed to thank ROBERT Paton, who was a mentor to me throughout my years at Colorado State University, and who introduced me to the field of cheminformatics. Cheers mate!\n"
+            )
 
     # EVALUATE, only evaluates models
     else:
@@ -116,7 +122,7 @@ def main(exe_type='command',sys_args=None):
         curate(**vars(args))
 
         # Ignore the Set column created in EVALUATE inside the CSV of the GENERATE folder
-        args.ignore.append('Set')
+        args.ignore.append("Set")
 
         # VERIFY
         verify(**vars(args))
@@ -134,13 +140,26 @@ def set_aqme_args(args):
     """
 
     if os.path.exists(args.csv_name):
-        aqme_df = pd.read_csv(args.csv_name, encoding='utf-8')
+        aqme_df = pd.read_csv(args.csv_name, encoding="utf-8")
     else:
-        print(f'\nx  The path of your CSV file doesn\'t exist! You specified: {args.csv_name}')
+        print(
+            f"\nx  The path of your CSV file doesn't exist! You specified: {args.csv_name}"
+        )
         sys.exit()
 
     # list of potential arguments from CSV inputs in AQME
-    aqme_args = ['smiles','charge','mult','complex_type','geom','constraints_atoms','constraints_dist','constraints_angle','constraints_dihedral','sample']
+    aqme_args = [
+        "smiles",
+        "charge",
+        "mult",
+        "complex_type",
+        "geom",
+        "constraints_atoms",
+        "constraints_dist",
+        "constraints_angle",
+        "constraints_dihedral",
+        "sample",
+    ]
 
     # ignore the names and SMILES of the molecules
     remove = []
@@ -149,10 +168,10 @@ def set_aqme_args(args):
             remove.append(column)
     for column in remove:
         args.ignore.remove(column)
-    if 'code_name' in args.ignore:
-        args.ignore.remove('code_name')
+    if "code_name" in args.ignore:
+        args.ignore.remove("code_name")
     for column in aqme_df.columns:
-        if column.lower() == 'code_name' and args.names == '':
+        if column.lower() == "code_name" and args.names == "":
             args.names = column
 
     return args
diff --git a/robert/uq_auto.py b/robert/uq_auto.py
index 0a0ea99..c9e5588 100644
--- a/robert/uq_auto.py
+++ b/robert/uq_auto.py
@@ -12,7 +12,7 @@
 import json
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Mapping, Optional, Sequence
 
 import numpy as np
 from scipy.stats import norm
@@ -34,7 +34,9 @@ def _as_float_array(x: Sequence[float]) -> np.ndarray:
     return np.asarray(x, dtype=float).ravel()
 
 
-def _normalize_metric_weights(weights: Optional[Mapping[str, float]]) -> Dict[str, float]:
+def _normalize_metric_weights(
+    weights: Optional[Mapping[str, float]],
+) -> Dict[str, float]:
     base = dict(DEFAULT_METRIC_WEIGHTS)
     if weights is not None:
         for key in base:
@@ -131,7 +133,9 @@ def apply_uncertainty_scaler(
         knots_r = np.asarray(params.get("knots_r", []), dtype=float)
         if knots_u.size == 0:
             return u
-        return np.maximum(np.interp(u, knots_u, knots_r, left=knots_r[0], right=knots_r[-1]), 0.0)
+        return np.maximum(
+            np.interp(u, knots_u, knots_r, left=knots_r[0], right=knots_r[-1]), 0.0
+        )
 
     raise ValueError(f"Unknown scaler method in params: {method!r}")
 
@@ -149,8 +153,8 @@ def _coverage_error(abs_resid: np.ndarray, sigma: np.ndarray, coverage: float) -
 def _gaussian_nll(abs_resid: np.ndarray, sigma: np.ndarray) -> float:
     sigma = np.maximum(sigma, 1e-12)
     # NLL for Laplace-like on abs residual under Gaussian proxy
-    var = sigma ** 2
-    return float(np.mean(0.5 * np.log(2.0 * np.pi * var) + 0.5 * (abs_resid ** 2) / var))
+    var = sigma**2
+    return float(np.mean(0.5 * np.log(2.0 * np.pi * var) + 0.5 * (abs_resid**2) / var))
 
 
 def _sharpness(sigma: np.ndarray) -> float:
@@ -177,7 +181,9 @@ def score_uncertainty_candidate(
 
 def _oof_mean_train(Xy_data: Mapping[str, Any]) -> np.ndarray:
     preds_all = Xy_data.get("y_pred_train_all", [])
-    return np.array([float(np.mean(p)) if len(p) else np.nan for p in preds_all], dtype=float)
+    return np.array(
+        [float(np.mean(p)) if len(p) else np.nan for p in preds_all], dtype=float
+    )
 
 
 def _train_abs_residuals(Xy_data: Mapping[str, Any]) -> np.ndarray:
@@ -275,7 +281,9 @@ def evaluate_uq_candidates(
 
     Returns dict with keys: selected, scaler_params, candidate_scores, coverage, n_eval.
     """
-    candidates_cfg = getattr(args, "uq_auto_candidates", None) or list(DEFAULT_CANDIDATES)
+    candidates_cfg = getattr(args, "uq_auto_candidates", None) or list(
+        DEFAULT_CANDIDATES
+    )
     if isinstance(candidates_cfg, str):
         candidates_cfg = [c.strip() for c in candidates_cfg.split(",") if c.strip()]
 
@@ -320,7 +328,9 @@ def evaluate_uq_candidates(
             params = fit_uncertainty_scaler(
                 scaler_method, u_raw[fit_ix], abs_resid[fit_ix]
             )
-            u_scaled_eval = apply_uncertainty_scaler(scaler_method, u_raw[eval_ix], params)
+            u_scaled_eval = apply_uncertainty_scaler(
+                scaler_method, u_raw[eval_ix], params
+            )
             score = score_uncertainty_candidate(
                 u_scaled_eval, abs_resid[eval_ix], coverage, metric_weights
             )
@@ -409,8 +419,7 @@ def apply_auto_uq(
     meta_path = Path("PREDICT") / "uq_auto_metadata.json"
     meta_path.parent.mkdir(parents=True, exist_ok=True)
     serializable = {
-        k: (v if not isinstance(v, dict) else dict(v))
-        for k, v in selection.items()
+        k: (v if not isinstance(v, dict) else dict(v)) for k, v in selection.items()
     }
     with meta_path.open("w", encoding="utf-8") as fh:
         json.dump(serializable, fh, indent=2)
diff --git a/robert/utils.py b/robert/utils.py
index 88c4c78..716a388 100644
--- a/robert/utils.py
+++ b/robert/utils.py
@@ -21,21 +21,30 @@
 # This prevents numerical differences between Windows/Ubuntu in parallel operations
 os.environ["LOKY_MAX_CPU_COUNT"] = "1"
 from matplotlib import pyplot as plt
+import seaborn as sb
+from bayes_opt import BayesianOptimization
 import matplotlib.patches as mpatches
 import matplotlib.colors as mcolor
 from matplotlib.legend_handler import HandlerPatch
 from matplotlib.ticker import FormatStrFormatter
 from scipy import stats
 from importlib.resources import files
+
 # sklearnex was deactivated in ROBERT v2.1 because it only accelerated RF
 # try:
 #     from sklearnex import patch_sklearn
 #     patch_sklearn(verbose=True)
 # except (ModuleNotFoundError,ImportError):
 #     pass
-from sklearn.metrics import (mean_absolute_error, mean_squared_error,
-                             matthews_corrcoef, accuracy_score, f1_score, make_scorer,
-                             ConfusionMatrixDisplay)
+from sklearn.metrics import (
+    mean_absolute_error,
+    mean_squared_error,
+    matthews_corrcoef,
+    accuracy_score,
+    f1_score,
+    make_scorer,
+    ConfusionMatrixDisplay,
+)
 from sklearn.feature_selection import RFECV
 from sklearn.preprocessing import StandardScaler
 from sklearn.ensemble import (
@@ -47,19 +56,27 @@
     AdaBoostClassifier,
     VotingRegressor,
     VotingClassifier,
-    )
+)
 from xgboost import XGBClassifier, XGBRegressor
 from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
 from sklearn.neural_network import MLPRegressor, MLPClassifier
 from sklearn.linear_model import LinearRegression
 from sklearn.impute import KNNImputer
 from sklearn.base import clone
-from sklearn.model_selection import train_test_split, cross_val_score, StratifiedShuffleSplit, RepeatedKFold, KFold, StratifiedKFold
+from sklearn.model_selection import (
+    train_test_split,
+    cross_val_score,
+    StratifiedShuffleSplit,
+    RepeatedKFold,
+    KFold,
+    StratifiedKFold,
+)
 from sklearn.cluster import KMeans
 from sklearn.inspection import permutation_importance
 from sklearn.exceptions import ConvergenceWarning
 from robert.argument_parser import set_options, var_dict
-import warnings # this avoids warnings from sklearn
+import warnings  # this avoids warnings from sklearn
+
 warnings.filterwarnings("ignore")
 
 
@@ -93,12 +110,18 @@ def should_plot_verify_metrics(args) -> bool:
 
 def should_plot_predict_results(args) -> bool:
     """Main PREDICT figures (e.g. Results_*); tied to predict_diagnostics for backward compatibility."""
-    return bool(getattr(args, "predict_diagnostics", True)) and plot_verbosity_level(args) >= 1
+    return (
+        bool(getattr(args, "predict_diagnostics", True))
+        and plot_verbosity_level(args) >= 1
+    )
 
 
 def should_plot_predict_deep_diagnostics(args) -> bool:
     """SHAP, PFI, Pearson heatmap, outliers, distribution plots."""
-    return bool(getattr(args, "predict_diagnostics", True)) and plot_verbosity_level(args) >= 2
+    return (
+        bool(getattr(args, "predict_diagnostics", True))
+        and plot_verbosity_level(args) >= 2
+    )
 
 
 robert_version = "2.1.0"
@@ -130,7 +153,7 @@ def load_from_yaml(self):
                 try:
                     loaded = yaml.load(file, Loader=yaml.SafeLoader)
                     param_list = loaded if isinstance(loaded, dict) else {}
-                except (yaml.scanner.ScannerError,yaml.parser.ParserError):
+                except (yaml.scanner.ScannerError, yaml.parser.ParserError):
                     txt_yaml = f'\nx  Error while reading {self.varfile}. Edit the yaml file and try again (i.e. use ":" instead of "=" to specify variables)'
                     error_yaml = True
         else:
@@ -186,12 +209,12 @@ def finalize(self):
         self.log.close()
 
 
-def command_line_args(exe_type,sys_args):
+def command_line_args(exe_type, sys_args):
     """
     Load default and user-defined arguments specified through command lines. Arrguments are loaded as a dictionary
     """
 
-    if exe_type == 'exe':
+    if exe_type == "exe":
         # Simulate sys.argv for use in an executable environment
         sys.argv = ["launcher.exe"]
         for k, v in sys_args.items():
@@ -220,14 +243,14 @@ def command_line_args(exe_type,sys_args):
         "report_modules",
     ]
     int_args = [
-        'pfi_epochs',
-        'epochs',
-        'nprocs',
-        'pfi_max',
-        'kfold',
-        'repeat_kfolds',
-        'shap_show',
-        'pfi_show',
+        "pfi_epochs",
+        "epochs",
+        "nprocs",
+        "pfi_max",
+        "kfold",
+        "repeat_kfolds",
+        "shap_show",
+        "pfi_show",
         "seed",
         "init_points",
         "n_iter",
@@ -237,16 +260,16 @@ def command_line_args(exe_type,sys_args):
         "uq_auto_random_state",
     ]
     float_args = [
-        'pfi_threshold',
-        't_value',
-        'thres_x',
-        'thres_y',
-        'test_set',
-        'desc_thres',
-        'alpha',
-        'expect_improv',
-        'conformal_calib_frac',
-        'conformal_coverage',
+        "pfi_threshold",
+        "t_value",
+        "thres_x",
+        "thres_y",
+        "test_set",
+        "desc_thres",
+        "alpha",
+        "expect_improv",
+        "conformal_calib_frac",
+        "conformal_coverage",
     ]
 
     for arg in var_dict:
@@ -346,12 +369,12 @@ def command_line_args(exe_type,sys_args):
             sys.exit()
         else:
             # this "if" allows to use * to select multiple files in multiple OS
-            if arg_name.lower() == 'files' and value.find('*') > -1:
+            if arg_name.lower() == "files" and value.find("*") > -1:
                 kwargs[arg_name] = glob.glob(value)
             else:
                 # converts the string parameters from command line to the right format
                 if arg_name in bool_args:
-                    value = True                    
+                    value = True
                 elif arg_name.lower() in list_args:
                     value = format_lists(value)
                 elif arg_name.lower() in int_args:
@@ -374,18 +397,20 @@ def command_line_args(exe_type,sys_args):
 
 
 def format_lists(value):
-    '''
+    """
     Transforms strings into a list
-    '''
+    """
 
     if not isinstance(value, list):
         try:
             value = ast.literal_eval(value)
         except (SyntaxError, ValueError):
             # this line fixes issues when using "[X]" or ["X"] instead of "['X']" when using lists
-            value = value.replace('[',']').replace(',',']').replace("'",']').split(']')
-            while('' in value):
-                value.remove('')
+            value = (
+                value.replace("[", "]").replace(",", "]").replace("'", "]").split("]")
+            )
+            while "" in value:
+                value.remove("")
 
     # remove extra spaces that sometimes are included by mistake
     value = [ele.strip() if isinstance(ele, str) else ele for ele in value]
@@ -407,34 +432,44 @@ def load_variables(kwargs, robert_module):
         self, txt_yaml = load_from_yaml(self)
 
     # check if user used .csv in csv_name
-    if not os.path.exists(f"{self.csv_name}") and os.path.exists(f'{self.csv_name}.csv'):
-        self.csv_name = f'{self.csv_name}.csv'
+    if not os.path.exists(f"{self.csv_name}") and os.path.exists(
+        f"{self.csv_name}.csv"
+    ):
+        self.csv_name = f"{self.csv_name}.csv"
 
     # check if user used .csv in csv_test
-    if self.csv_test and not os.path.exists(f"{self.csv_test}") and os.path.exists(f'{self.csv_test}.csv'):
-        self.csv_test = f'{self.csv_test}.csv'
+    if (
+        self.csv_test
+        and not os.path.exists(f"{self.csv_test}")
+        and os.path.exists(f"{self.csv_test}.csv")
+    ):
+        self.csv_test = f"{self.csv_test}.csv"
 
     # check for spaces in csv file names
     if " " in str(self.csv_name):
-        print("\nx  ERROR: The input CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'my_data.csv' instead of 'my data.csv'.")
+        print(
+            "\nx  ERROR: The input CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'my_data.csv' instead of 'my data.csv'."
+        )
         sys.exit()
     if self.csv_test and " " in str(self.csv_test):
-        print("\nx  ERROR: The test CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'test_data.csv' instead of 'test data.csv'.")
+        print(
+            "\nx  ERROR: The test CSV file name contains spaces. Please remove spaces from the file name and try again. Spaces in file names can cause problems. Example: use 'test_data.csv' instead of 'test data.csv'."
+        )
         sys.exit()
 
     if robert_module != "command":
         self.initial_dir = Path(os.getcwd())
 
         # adds --names to --ignore
-        if self.names not in self.ignore and self.names != '':
+        if self.names not in self.ignore and self.names != "":
             self.ignore.append(self.names)
 
         # creates destination folder
-        if robert_module.upper() != 'REPORT':
-            self = destination_folder(self,robert_module)
+        if robert_module.upper() != "REPORT":
+            self = destination_folder(self, robert_module)
 
             # start a log file
-            logger_1 = 'ROBERT'
+            logger_1 = "ROBERT"
             logger_1, logger_2 = robert_module.upper(), "data"
 
             if txt_yaml not in [
@@ -448,134 +483,168 @@ def load_variables(kwargs, robert_module):
                 sys.exit()
 
             self.log = Logger(self.destination / logger_1, logger_2)
-            self.log.write(f"ROBERT v {robert_version} {time_run} \nHow to cite: {robert_ref}\n")
+            self.log.write(
+                f"ROBERT v {robert_version} {time_run} \nHow to cite: {robert_ref}\n"
+            )
 
             if self.command_line:
-                cmd_print = ''
+                cmd_print = ""
                 cmd_args = sys.argv[1:]
-                if self.extra_cmd != '':
+                if self.extra_cmd != "":
                     for arg in self.extra_cmd.split():
                         cmd_args.append(arg)
-                for i,elem in enumerate(cmd_args):
-                    if elem[0] in ['"',"'"]:
+                for i, elem in enumerate(cmd_args):
+                    if elem[0] in ['"', "'"]:
                         elem = elem[1:]
-                    if elem[-1] in ['"',"'"]:
+                    if elem[-1] in ['"', "'"]:
                         elem = elem[:-1]
-                    if elem != '-h' and elem.split('--')[-1] not in var_dict:
+                    if elem != "-h" and elem.split("--")[-1] not in var_dict:
                         # parse single elements of the list as strings (otherwise the commands cannot be reproduced)
-                        if '--qdescp_atoms' in elem:
+                        if "--qdescp_atoms" in elem:
                             new_arg = []
-                            list_qdescp = elem.replace(', ',',').replace(' ,',',').split()
-                            for j,qdescp_elem in enumerate(list_qdescp):
-                                if list_qdescp[j-1] == '--qdescp_atoms':
+                            list_qdescp = (
+                                elem.replace(", ", ",").replace(" ,", ",").split()
+                            )
+                            for j, qdescp_elem in enumerate(list_qdescp):
+                                if list_qdescp[j - 1] == "--qdescp_atoms":
                                     qdescp_elem = qdescp_elem[1:-1]
                                     new_elem = []
-                                    for smarts_strings in qdescp_elem.split(','):
-                                        new_elem.append(f'{smarts_strings}'.replace("'",''))
-                                    new_arg.append(f'{new_elem}'.replace(" ",""))
+                                    for smarts_strings in qdescp_elem.split(","):
+                                        new_elem.append(
+                                            f"{smarts_strings}".replace("'", "")
+                                        )
+                                    new_arg.append(f"{new_elem}".replace(" ", ""))
                                 else:
                                     new_arg.append(qdescp_elem)
-                            new_arg = ' '.join(new_arg)
+                            new_arg = " ".join(new_arg)
                             elem = new_arg
-                        if cmd_args[i-1].split('--')[-1] in var_dict: # check if the previous word is an arg
+                        if (
+                            cmd_args[i - 1].split("--")[-1] in var_dict
+                        ):  # check if the previous word is an arg
                             cmd_print += f'"{elem}'
-                        if i == len(cmd_args)-1 or cmd_args[i+1].split('--')[-1] in var_dict: # check if the next word is an arg, or last word in command
-                            cmd_print += f'"'
+                        if (
+                            i == len(cmd_args) - 1
+                            or cmd_args[i + 1].split("--")[-1] in var_dict
+                        ):  # check if the next word is an arg, or last word in command
+                            cmd_print += '"'
                     else:
-                        cmd_print += f'{elem}'
-                    if i != len(cmd_args)-1:
-                        cmd_print += ' '
+                        cmd_print += f"{elem}"
+                    if i != len(cmd_args) - 1:
+                        cmd_print += " "
 
-                self.log.write(f"Command line used in ROBERT: python -m robert {cmd_print}\n")
+                self.log.write(
+                    f"Command line used in ROBERT: python -m robert {cmd_print}\n"
+                )
 
-        elif robert_module.upper() == 'REPORT':
+        elif robert_module.upper() == "REPORT":
             self.path_icons = files("robert").joinpath("report")
 
         # sklearnex was deactivated in ROBERT v2.1 because it only accelerated RF
         # using or not the intelex accelerator might affect the results
         # if robert_module.upper() in ['GENERATE','VERIFY','PREDICT']:
-            # try:
-            #     import sklearnex
-            #     pass
-            # except (ModuleNotFoundError,ImportError):
-            #     self.log.write(f"\nx  WARNING! The scikit-learn-intelex accelerator is not installed, the results might vary if it is installed and the execution times might become much longer (if available, use 'pip install scikit-learn-intelex')")
+        # try:
+        #     import sklearnex
+        #     pass
+        # except (ModuleNotFoundError,ImportError):
+        #     self.log.write(f"\nx  WARNING! The scikit-learn-intelex accelerator is not installed, the results might vary if it is installed and the execution times might become much longer (if available, use 'pip install scikit-learn-intelex')")
 
-        if robert_module.upper() in ['GENERATE', 'VERIFY']:
+        if robert_module.upper() in ["GENERATE", "VERIFY"]:
             # adjust the default value of error_type for classification
-            if self.type.lower() == 'clas':
-                if self.error_type not in ['acc', 'mcc', 'f1']:
-                    self.error_type = 'mcc'
+            if self.type.lower() == "clas":
+                if self.error_type not in ["acc", "mcc", "f1"]:
+                    self.error_type = "mcc"
 
-        if robert_module.upper() in ['PREDICT','VERIFY','REPORT']:
-            if self.params_dir == '':
-                self.params_dir = 'GENERATE/Best_model'
+        if robert_module.upper() in ["PREDICT", "VERIFY", "REPORT"]:
+            if self.params_dir == "":
+                self.params_dir = "GENERATE/Best_model"
 
-        if robert_module.upper() in ['CURATE','GENERATE']:
-            if self.type.lower() == 'clas':
+        if robert_module.upper() in ["CURATE", "GENERATE"]:
+            if self.type.lower() == "clas":
                 if any(m.upper() == "MVL" for m in self.model):
-                    self.model = [x if x.upper() != 'MVL' else 'AdaB' for x in self.model]
-            
-            models_gen = [] # use capital letters in all the models
+                    self.model = [
+                        x if x.upper() != "MVL" else "AdaB" for x in self.model
+                    ]
+
+            models_gen = []  # use capital letters in all the models
             for model_type in self.model:
                 models_gen.append(model_type.upper())
             self.model = models_gen
 
-        if robert_module.upper() == 'CURATE':
-            self.log.write(f"\no  Starting data curation with the CURATE module")
+        if robert_module.upper() == "CURATE":
+            self.log.write("\no  Starting data curation with the CURATE module")
+
+        elif robert_module.upper() == "GENERATE":
+            self.log.write(
+                "\no  Starting generation of ML models with the GENERATE module"
+            )
 
-        elif robert_module.upper() == 'GENERATE':
-            self.log.write(f"\no  Starting generation of ML models with the GENERATE module")
-            
             # Check if the folders exist and if they do, delete and replace them
-            folder_names = [self.initial_dir.joinpath('GENERATE/Best_model/No_PFI'), self.initial_dir.joinpath('GENERATE/Raw_data/No_PFI')]
+            folder_names = [
+                self.initial_dir.joinpath("GENERATE/Best_model/No_PFI"),
+                self.initial_dir.joinpath("GENERATE/Raw_data/No_PFI"),
+            ]
             if self.pfi_filter:
-                folder_names.append(self.initial_dir.joinpath('GENERATE/Best_model/PFI'))
-                folder_names.append(self.initial_dir.joinpath('GENERATE/Raw_data/PFI'))
+                folder_names.append(
+                    self.initial_dir.joinpath("GENERATE/Best_model/PFI")
+                )
+                folder_names.append(self.initial_dir.joinpath("GENERATE/Raw_data/PFI"))
             _ = create_folders(folder_names)
 
             # if there are missing options, look for them from a previous CURATE job (if any)
             options_dict = {
-                'y': self.y,
-                'names': self.names,
-                'ignore': self.ignore,
-                'csv_name': self.csv_name
+                "y": self.y,
+                "names": self.names,
+                "ignore": self.ignore,
+                "csv_name": self.csv_name,
             }
-            curate_folder = Path(self.initial_dir).joinpath('CURATE')
-            curate_csv = f'{curate_folder}/CURATE_options.csv'
+            curate_folder = Path(self.initial_dir).joinpath("CURATE")
+            curate_csv = f"{curate_folder}/CURATE_options.csv"
             if os.path.exists(curate_csv):
-                curate_df = pd.read_csv(curate_csv, encoding='utf-8')
+                curate_df = pd.read_csv(curate_csv, encoding="utf-8")
 
                 for option in options_dict:
-                    if options_dict[option] == '':
-                        if option == 'y':
-                            self.y = curate_df['y'][0]
-                        elif option == 'names':
-                            self.names = curate_df['names'][0]
-                        elif option == 'ignore':
-                            self.ignore = curate_df['ignore'][0]
-                            self.ignore  = format_lists(self.ignore)
-                        elif option == 'csv_name':
-                            self.csv_name = curate_df['csv_name'][0]
-                
-                # Load class labels if they exist (for classification with string labels)
-                if 'class_0_label' in curate_df.columns and 'class_1_label' in curate_df.columns:
-                    self.class_0_label = curate_df['class_0_label'][0]
-                    self.class_1_label = curate_df['class_1_label'][0]
+                    if options_dict[option] == "":
+                        if option == "y":
+                            self.y = curate_df["y"][0]
+                        elif option == "names":
+                            self.names = curate_df["names"][0]
+                        elif option == "ignore":
+                            self.ignore = curate_df["ignore"][0]
+                            self.ignore = format_lists(self.ignore)
+                        elif option == "csv_name":
+                            self.csv_name = curate_df["csv_name"][0]
 
-        elif robert_module.upper() in ['PREDICT','VERIFY']:
-            if robert_module.upper() == 'PREDICT':
-                self.log.write(f"\no  Representation of predictions and analysis of ML models with the PREDICT module")
-            elif robert_module.upper() == 'VERIFY':
-                self.log.write(f"\no  Starting tests to verify the prediction ability of the ML models with the VERIFY module")
+                # Load class labels if they exist (for classification with string labels)
+                if (
+                    "class_0_label" in curate_df.columns
+                    and "class_1_label" in curate_df.columns
+                ):
+                    self.class_0_label = curate_df["class_0_label"][0]
+                    self.class_1_label = curate_df["class_1_label"][0]
+
+        elif robert_module.upper() in ["PREDICT", "VERIFY"]:
+            if robert_module.upper() == "PREDICT":
+                self.log.write(
+                    "\no  Representation of predictions and analysis of ML models with the PREDICT module"
+                )
+            elif robert_module.upper() == "VERIFY":
+                self.log.write(
+                    "\no  Starting tests to verify the prediction ability of the ML models with the VERIFY module"
+                )
 
-            if '' in [self.names,self.y,self.csv_name]:
+            if "" in [self.names, self.y, self.csv_name]:
                 # tries to get names from GENERATE
-                if 'GENERATE/Best_model' in self.params_dir:
-                    params_dirs = [f'{self.params_dir}/No_PFI',f'{self.params_dir}/PFI']
+                if "GENERATE/Best_model" in self.params_dir:
+                    params_dirs = [
+                        f"{self.params_dir}/No_PFI",
+                        f"{self.params_dir}/PFI",
+                    ]
                 else:
                     params_dirs = [self.params_dir]
                 self.args = self
-                _,_,_,model_data,csv_name = load_dfs(self,params_dirs[0],'predict',sanity_check=True)
+                _, _, _, model_data, csv_name = load_dfs(
+                    self, params_dirs[0], "predict", sanity_check=True
+                )
 
                 self.names = model_data["names"]
                 self.y = model_data["y"]
@@ -587,48 +656,58 @@ def load_variables(kwargs, robert_module):
                 if "type" in model_data:
                     self.type = model_data["type"]
 
-        elif robert_module.upper() in ['AQME', 'AQME_TEST']: 
+        elif robert_module.upper() in ["AQME", "AQME_TEST"]:
             # Check if the csv has 2 columns named smiles or smiles_Suffix. The file is read as text because pandas assigns automatically
             # .1 to duplicate columns. (i.e. SMILES and SMILES.1 if there are two columns named SMILES)
-            unique_columns=[]
-            with open(self.csv_name, 'r') as datfile:
+            unique_columns = []
+            with open(self.csv_name, "r") as datfile:
                 lines = datfile.readlines()
-                for column in lines[0].split(','):
+                for column in lines[0].split(","):
                     if column in unique_columns:
-                        print(f"\nWARNING! The CSV file contains duplicate columns ({column}). Please, rename or remove these columns. If you want to use more than one SMILES column, use _Suffix (i.e. SMILES_1, SMILES_2, ...)")
+                        print(
+                            f"\nWARNING! The CSV file contains duplicate columns ({column}). Please, rename or remove these columns. If you want to use more than one SMILES column, use _Suffix (i.e. SMILES_1, SMILES_2, ...)"
+                        )
                         sys.exit()
                     else:
                         unique_columns.append(column)
-            
+
             # Check if there is a column with the name "smiles" or "smiles_" followed by any characters
             if not any(col.lower().startswith("smiles") for col in unique_columns):
-                print("\nWARNING! The CSV file does not contain a column with the name 'smiles' or a column starting with 'smiles_'. Please make sure the column exists.")
+                print(
+                    "\nWARNING! The CSV file does not contain a column with the name 'smiles' or a column starting with 'smiles_'. Please make sure the column exists."
+                )
                 sys.exit()
 
             # Check if there are duplicate names in code_names in the csv file.
-            df = pd.read_csv(self.csv_name, encoding='utf-8')
-            unique_entries=[]
-            for entry in df['code_name']:
+            df = pd.read_csv(self.csv_name, encoding="utf-8")
+            unique_entries = []
+            for entry in df["code_name"]:
                 if entry in unique_entries:
-                    print(f"\nWARNING! The code_name column in the CSV file contains duplicate entries ({entry}). Please, rename or remove these entries.")
+                    print(
+                        f"\nWARNING! The code_name column in the CSV file contains duplicate entries ({entry}). Please, rename or remove these entries."
+                    )
                     sys.exit()
                 else:
                     unique_entries.append(entry)
 
-            self.log.write(f"\no  Starting the generation of AQME descriptors with the AQME module")
+            self.log.write(
+                "\no  Starting the generation of AQME descriptors with the AQME module"
+            )
 
         # initial sanity checks
-        if robert_module.upper() != 'REPORT':
-            _ = sanity_checks(self, 'initial', robert_module, None)
+        if robert_module.upper() != "REPORT":
+            _ = sanity_checks(self, "initial", robert_module, None)
 
     return self
 
 
-def destination_folder(self,dest_module):
+def destination_folder(self, dest_module):
     if self.destination is None:
         self.destination = Path(self.initial_dir).joinpath(dest_module.upper())
     else:
-        self.log.write(f"\nx  The destination option has not been implemented yet! Please, remove it from your input and stay tuned.")
+        self.log.write(
+            "\nx  The destination option has not been implemented yet! Please, remove it from your input and stay tuned."
+        )
         sys.exit()
         # this part does not work for know
         # if Path(f"{self.destination}").exists():
@@ -643,37 +722,45 @@ def destination_folder(self,dest_module):
     return self
 
 
-def missing_inputs(self,module,print_err=False):
+def missing_inputs(self, module, print_err=False):
     """
     Gives the option to input missing variables in the terminal
     """
 
-    if module.lower() not in ['predict','verify','report','aqme_test']:
-        if self.csv_name == '':
-            self = check_csv_option(self,'csv_name',print_err)
+    if module.lower() not in ["predict", "verify", "report", "aqme_test"]:
+        if self.csv_name == "":
+            self = check_csv_option(self, "csv_name", print_err)
 
-    if module.lower() not in ['predict','verify','report','aqme_test']:
-        if self.y == '':
+    if module.lower() not in ["predict", "verify", "report", "aqme_test"]:
+        if self.y == "":
             if print_err:
-                print(f'\nx  Specify a y value (column name) with the y option! (i.e. y="solubility")')
+                print(
+                    '\nx  Specify a y value (column name) with the y option! (i.e. y="solubility")'
+                )
             else:
-                self.log.write(f'\nx  Specify a y value (column name) with the y option! (i.e. y="solubility")')
-            self.y = input('Enter the column with y values: ')
-            self.extra_cmd += f' --y {self.y}'
+                self.log.write(
+                    '\nx  Specify a y value (column name) with the y option! (i.e. y="solubility")'
+                )
+            self.y = input("Enter the column with y values: ")
+            self.extra_cmd += f" --y {self.y}"
             if not print_err:
                 self.log.write(f"   -  y option set to {self.y} by the user")
 
-    if module.lower() in ['full_workflow','predict','curate','generate','evaluate']:
-        if self.names == '':
+    if module.lower() in ["full_workflow", "predict", "curate", "generate", "evaluate"]:
+        if self.names == "":
             if print_err:
-                print(f'\nx  Specify the column with the entry names! (i.e. names="code_name")')
+                print(
+                    '\nx  Specify the column with the entry names! (i.e. names="code_name")'
+                )
             else:
-                self.log.write(f'\nx  Specify the column with the entry names! (i.e. names="code_name")')
-            self.names = input('Enter the column with the entry names: ')
-            self.extra_cmd += f' --names {self.names}'
+                self.log.write(
+                    '\nx  Specify the column with the entry names! (i.e. names="code_name")'
+                )
+            self.names = input("Enter the column with the entry names: ")
+            self.extra_cmd += f" --names {self.names}"
             if not print_err:
                 self.log.write(f"   -  names option set to {self.names} by the user")
-        if self.names != '' and self.names not in self.ignore:
+        if self.names != "" and self.names not in self.ignore:
             self.ignore.append(self.names)
 
     return self
@@ -683,7 +770,7 @@ def correlation_filter(self, csv_df):
     """
     Discards a) correlated variables and b) variables that do not correlate with the y values, based
     on R**2 values c) reduces the number of descriptors to one third of the datapoints using RFECV.
-    
+
     REPRODUCIBILITY GUARANTEES:
     - Columns are sorted alphabetically before any operation
     - Rows are sorted by y value to ensure consistent ordering
@@ -691,81 +778,102 @@ def correlation_filter(self, csv_df):
     - RFECV descriptor selection uses sorted feature importances with alphabetical tie-breaking
     """
 
-    txt_corr = ''
-    
+    txt_corr = ""
+
     # Sort columns alphabetically and rows by y value for reproducibility
-    descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore and col != self.args.y]
+    descriptor_cols = [
+        col
+        for col in csv_df.columns
+        if col not in self.args.ignore and col != self.args.y
+    ]
     descriptor_cols_sorted = sorted(descriptor_cols)
-    other_cols = [col for col in csv_df.columns if col in self.args.ignore or col == self.args.y]
+    other_cols = [
+        col for col in csv_df.columns if col in self.args.ignore or col == self.args.y
+    ]
     csv_df = csv_df[descriptor_cols_sorted + other_cols].copy()
-    csv_df = csv_df.reset_index(drop=True).sort_values(by=self.args.y, kind='stable').reset_index(drop=True)
+    csv_df = (
+        csv_df.reset_index(drop=True)
+        .sort_values(by=self.args.y, kind="stable")
+        .reset_index(drop=True)
+    )
 
     # loosen correlation filters if there are too few descriptors
-    n_descps = len(csv_df.columns)-len(self.args.ignore)-1 # all columns - ignored - y
-    txt_corr += f'\no  Correlation filter activated with these thresholds: thres_x = {self.args.thres_x}'
+    n_descps = (
+        len(csv_df.columns) - len(self.args.ignore) - 1
+    )  # all columns - ignored - y
+    txt_corr += f"\no  Correlation filter activated with these thresholds: thres_x = {self.args.thres_x}"
     if self.args.corr_filter_y:
-        txt_corr += f', thres_y = {self.args.thres_y}'
+        txt_corr += f", thres_y = {self.args.thres_y}"
 
     descriptors_drop = []
-    txt_corr += f'\n   Excluded descriptors:'
-    
+    txt_corr += "\n   Excluded descriptors:"
+
     # First pass: remove constant descriptors and those with low correlation to y
-    for _,column in enumerate(csv_df.columns):
-        if column not in descriptors_drop and column not in self.args.ignore and column != self.args.y:
+    for _, column in enumerate(csv_df.columns):
+        if (
+            column not in descriptors_drop
+            and column not in self.args.ignore
+            and column != self.args.y
+        ):
             # Remove descriptors where all values are the same
             if len(set(csv_df[column])) == 1:
                 descriptors_drop.append(column)
-                txt_corr += f'\n   - {column}: all the values are the same'
+                txt_corr += f"\n   - {column}: all the values are the same"
 
             # Remove descriptors with low correlation to the response values
             if self.args.corr_filter_y:
                 # Calculate correlation with y for remaining descriptors
                 if column not in descriptors_drop:
-                    res_y = stats.linregress(csv_df[column],csv_df[self.args.y])
+                    res_y = stats.linregress(csv_df[column], csv_df[self.args.y])
                     rsquared_y = res_y.rvalue**2
                 if rsquared_y < self.args.thres_y:
                     descriptors_drop.append(column)
-                    txt_corr += f'\n   - {column}: R**2 = {rsquared_y:.2} with the {self.args.y} values'
+                    txt_corr += f"\n   - {column}: R**2 = {rsquared_y:.2} with the {self.args.y} values"
 
     self.args.log.write(txt_corr)
 
     # Second pass: remove highly correlated descriptors (always removing the most correlated first)
-    txt_corr = ''
+    txt_corr = ""
     csv_df_filtered = csv_df.drop(descriptors_drop, axis=1)
     csv_df_X_filtered = csv_df_filtered.drop([self.args.y] + self.args.ignore, axis=1)
-    if self.args.corr_filter_x and len(csv_df_X_filtered.columns) > 1:        
+    if self.args.corr_filter_x and len(csv_df_X_filtered.columns) > 1:
         # Calculate R2 correlation matrix between descriptors
         corr_matrix = csv_df_X_filtered.corr().abs()
-        corr_matrix_r2 = corr_matrix ** 2
-        upper = corr_matrix_r2.where(np.triu(np.ones(corr_matrix_r2.shape), k=1).astype(bool))
-        
+        corr_matrix_r2 = corr_matrix**2
+        upper = corr_matrix_r2.where(
+            np.triu(np.ones(corr_matrix_r2.shape), k=1).astype(bool)
+        )
+
         # Calculate R2 of each descriptor with y (for deciding which one to drop)
         r2_with_y = {}
         for col in csv_df_X_filtered.columns:
             res_y = stats.linregress(csv_df_filtered[col], csv_df_filtered[self.args.y])
             r2_with_y[col] = res_y.rvalue**2
-        
+
         # Iteratively remove the most correlated descriptors
         while True:
             # Find the maximum R2 correlation
             max_r2 = upper.max().max()
-            if max_r2 <= self.args.thres_x or str(max_r2).lower() == 'nan':
+            if max_r2 <= self.args.thres_x or str(max_r2).lower() == "nan":
                 break
-            
+
             # Get ALL pairs with maximum correlation, round to avoid floating point issues
             upper_rounded = upper.round(10)
             max_r2_rounded = upper_rounded.max().max()
             row_idx, col_idx = np.where(upper_rounded == max_r2_rounded)
-            
+
             # Sort pairs alphabetically for deterministic selection
-            pairs = [(upper.index[row_idx[i]], upper.columns[col_idx[i]]) for i in range(len(row_idx))]
+            pairs = [
+                (upper.index[row_idx[i]], upper.columns[col_idx[i]])
+                for i in range(len(row_idx))
+            ]
             pairs.sort()
             col_name_1, col_name_2 = pairs[0]
-            
+
             # Drop the descriptor with lower R2 to y, round for comparison
             r2_1 = round(r2_with_y[col_name_1], 10)
             r2_2 = round(r2_with_y[col_name_2], 10)
-            
+
             if r2_1 == r2_2:
                 # Tied on R2 with y, drop alphabetically later one
                 drop_col = col_name_1 if col_name_1 > col_name_2 else col_name_2
@@ -776,10 +884,10 @@ def correlation_filter(self, csv_df):
             else:
                 drop_col = col_name_2
                 keep_col = col_name_1
-            
+
             descriptors_drop.append(drop_col)
-            txt_corr += f'\n   - {drop_col} removed (R2 = {max_r2:.2f} with {keep_col}), kept more predictive descriptor'
-            
+            txt_corr += f"\n   - {drop_col} removed (R2 = {max_r2:.2f} with {keep_col}), kept more predictive descriptor"
+
             upper = upper.drop(index=drop_col, columns=drop_col)
             del r2_with_y[drop_col]
 
@@ -787,31 +895,31 @@ def correlation_filter(self, csv_df):
         to_drop = [col for col in csv_df_X_filtered.columns if col not in upper.columns]
 
         if to_drop:
-            txt_corr += f'\n   - Total: {len(to_drop)} descriptors removed due to high correlation with other descriptors'
+            txt_corr += f"\n   - Total: {len(to_drop)} descriptors removed due to high correlation with other descriptors"
 
     # drop descriptors that did not pass the filters
     csv_df_filtered = csv_df.drop(descriptors_drop, axis=1)
 
     if len(descriptors_drop) == 0:
-        txt_corr += f'\n   -  No descriptors were removed'
+        txt_corr += "\n   -  No descriptors were removed"
 
     self.args.log.write(txt_corr)
 
     # Check if descriptors are more than one third of datapoints
-    txt_corr = ''
+    txt_corr = ""
     descriptors_used = {}
     csv_df_per_model = {}
 
     num_descriptors = round(len(csv_df[self.args.y]) / 3)
     if n_descps > num_descriptors:
-        cv_type = f'{self.args.repeat_kfolds}x {self.args.kfold}_fold_cv'
-        txt_corr += f'\no  There are more descriptors than one-third of the data points. A Recursive Feature Elimination with Cross-Validation (RFECV) or permutation feature importance (PFI) using {cv_type} will be performed to select the most relevant descriptors for each model'
+        cv_type = f"{self.args.repeat_kfolds}x {self.args.kfold}_fold_cv"
+        txt_corr += f"\no  There are more descriptors than one-third of the data points. A Recursive Feature Elimination with Cross-Validation (RFECV) or permutation feature importance (PFI) using {cv_type} will be performed to select the most relevant descriptors for each model"
         self.args.log.write(txt_corr)
-        txt_corr = ''
+        txt_corr = ""
 
         # Perform RFECV for each model specified by the user
         X_df = csv_df_filtered.drop([self.args.y] + self.args.ignore, axis=1)
-        X_scaled_df,_ = scale_df(X_df,None)
+        X_scaled_df, _ = scale_df(X_df, None)
         y_df = csv_df_filtered[self.args.y]
 
         for model in sorted(self.args.model):
@@ -822,98 +930,143 @@ def correlation_filter(self, csv_df):
             estimator = load_model(self, model, **rfecv_params)
 
             # Repeated kfold-CV type
-            cv_model = RepeatedKFold(n_splits=self.args.kfold, n_repeats=self.args.repeat_kfolds, random_state=self.args.seed)
+            cv_model = RepeatedKFold(
+                n_splits=self.args.kfold,
+                n_repeats=self.args.repeat_kfolds,
+                random_state=self.args.seed,
+            )
 
             # Select scoring function for RFECV analysis based on the error type
-            scoring = get_scoring_key(self.args.type,self.args.error_type)
+            scoring = get_scoring_key(self.args.type, self.args.error_type)
 
             # Use different strategies for models without feature_importances_
-            if model.upper() in ['NN', 'GP', 'VR']:
+            if model.upper() in ["NN", "GP", "VR"]:
                 # For NN, GP and VR, use a simpler approach: select top features by correlation with y
                 # after initial fit, then use permutation importance to rank them
-                
+
                 # Train the model once on all features
                 estimator.fit(X_scaled_df, y_df)
-                
+
                 # Use permutation importance to rank features
-                perm_result = permutation_importance(estimator, X_scaled_df, y_df, 
-                                      n_repeats=self.args.pfi_epochs, 
-                                      random_state=self.args.seed,
-                                      scoring=scoring,
-                                      n_jobs=1)  # Force single thread for reproducibility
-                
+                perm_result = permutation_importance(
+                    estimator,
+                    X_scaled_df,
+                    y_df,
+                    n_repeats=self.args.pfi_epochs,
+                    random_state=self.args.seed,
+                    scoring=scoring,
+                    n_jobs=1,
+                )  # Force single thread for reproducibility
+
                 # Round to reduce floating point variance
-                feature_importances = np.round(perm_result.importances_mean, decimals=10)
-                
+                feature_importances = np.round(
+                    perm_result.importances_mean, decimals=10
+                )
+
                 # Create list of (importance, name) tuples for ALL features
-                importance_with_names = [(feature_importances[i], X_scaled_df.columns[i]) 
-                                        for i in range(len(feature_importances))]
-                
+                importance_with_names = [
+                    (feature_importances[i], X_scaled_df.columns[i])
+                    for i in range(len(feature_importances))
+                ]
+
                 # Sort by importance (descending) and break ties alphabetically for determinism
-                importance_with_names.sort(key=lambda x: (-x[0], x[1]))  # Sort by importance DESC, then name ASC
-                
+                importance_with_names.sort(
+                    key=lambda x: (-x[0], x[1])
+                )  # Sort by importance DESC, then name ASC
+
                 # Select top num_descriptors features (or all with positive importance if fewer)
-                positive_features = [(imp, name) for imp, name in importance_with_names if imp > 0]
+                positive_features = [
+                    (imp, name) for imp, name in importance_with_names if imp > 0
+                ]
                 if len(positive_features) > num_descriptors:
-                    descriptors_used[model] = [name for _, name in positive_features[:num_descriptors]]
+                    descriptors_used[model] = [
+                        name for _, name in positive_features[:num_descriptors]
+                    ]
                 else:
                     descriptors_used[model] = [name for _, name in positive_features]
-                
+
                 # Sort final list alphabetically for consistent ordering in output
                 descriptors_used[model] = sorted(descriptors_used[model])
-                
-                txt_corr += f'\n   - {model}: {len(descriptors_used[model])} descriptors selected (using PFI)'
-            
+
+                txt_corr += f"\n   - {model}: {len(descriptors_used[model])} descriptors selected (using PFI)"
+
             else:
                 # MVL, RF, GB, ADAB: use RFECV with feature importances
                 # Set step=1 for most stable/deterministic feature elimination
-                selector = RFECV(estimator, scoring=scoring, min_features_to_select=2, cv=cv_model, step=1, n_jobs=1)
+                selector = RFECV(
+                    estimator,
+                    scoring=scoring,
+                    min_features_to_select=2,
+                    cv=cv_model,
+                    step=1,
+                    n_jobs=1,
+                )
                 selector.fit(X_scaled_df, y_df)
-                
+
                 # Get selected features
                 selected_mask = selector.support_
                 selected_features_list = list(X_scaled_df.columns[selected_mask])
-                
+
                 # Get feature importances for selected features only
-                if model.upper() == 'MVL':
+                if model.upper() == "MVL":
                     # For MVL, use absolute coefficients as importance
                     feature_importances = np.abs(selector.estimator_.coef_)
-                else: 
+                else:
                     # RF, GB, ADAB, XGB have feature_importances_
                     feature_importances = selector.estimator_.feature_importances_
-                
+
                 # Round importances to reduce floating point variance
                 feature_importances = np.round(feature_importances, decimals=10)
-                
+
                 # Create (importance, name) pairs for selected features
-                importance_with_names = list(zip(feature_importances, selected_features_list))
-                
+                importance_with_names = list(
+                    zip(feature_importances, selected_features_list)
+                )
+
                 # Sort by importance (descending) with alphabetical tie-breaking for determinism
-                importance_with_names.sort(key=lambda x: (-x[0], x[1]))  # Sort by importance DESC, then name ASC
-                
+                importance_with_names.sort(
+                    key=lambda x: (-x[0], x[1])
+                )  # Sort by importance DESC, then name ASC
+
                 # Select top num_descriptors (or all if fewer selected)
                 n_to_select = min(num_descriptors, len(importance_with_names))
-                descriptors_used[model] = [name for _, name in importance_with_names[:n_to_select]]
-                
+                descriptors_used[model] = [
+                    name for _, name in importance_with_names[:n_to_select]
+                ]
+
                 # Sort final list alphabetically for consistent ordering in output
                 descriptors_used[model] = sorted(descriptors_used[model])
-                
-                txt_corr += f'\n   - {model}: {len(descriptors_used[model])} descriptors selected (using RFECV)'
-            
+
+                txt_corr += f"\n   - {model}: {len(descriptors_used[model])} descriptors selected (using RFECV)"
+
             # Create model-specific dataframe with sorted columns for reproducibility
             keep_cols = descriptors_used[model] + [self.args.y] + self.args.ignore
-            keep_cols = list(dict.fromkeys(keep_cols))  # Remove duplicates preserving order
+            keep_cols = list(
+                dict.fromkeys(keep_cols)
+            )  # Remove duplicates preserving order
             keep_cols = [col for col in keep_cols if col in csv_df_filtered.columns]
-            
+
             # Sort descriptor columns alphabetically, keep y and ignore at the end
-            descriptor_cols = [col for col in keep_cols if col not in self.args.ignore and col != self.args.y]
-            other_cols = [col for col in keep_cols if col in self.args.ignore or col == self.args.y]
-            sorted_cols = sorted(descriptor_cols) + sorted([col for col in other_cols if col in self.args.ignore]) + [self.args.y]
-            
+            descriptor_cols = [
+                col
+                for col in keep_cols
+                if col not in self.args.ignore and col != self.args.y
+            ]
+            other_cols = [
+                col
+                for col in keep_cols
+                if col in self.args.ignore or col == self.args.y
+            ]
+            sorted_cols = (
+                sorted(descriptor_cols)
+                + sorted([col for col in other_cols if col in self.args.ignore])
+                + [self.args.y]
+            )
+
             csv_df_per_model[model] = csv_df_filtered[sorted_cols].copy()
 
     else:
-        txt_corr += f'\n   x The RFECV filter was not applied, there are less descriptors than one-third of the data points ({len(csv_df_filtered.columns)-len(self.args.ignore)-1} <= {num_descriptors})'
+        txt_corr += f"\n   x The RFECV filter was not applied, there are less descriptors than one-third of the data points ({len(csv_df_filtered.columns) - len(self.args.ignore) - 1} <= {num_descriptors})"
         # If RFECV is not applied, all models use the same filtered dataframe
         for model in self.args.model:
             csv_df_per_model[model] = csv_df_filtered
@@ -925,126 +1078,150 @@ def correlation_filter(self, csv_df):
 
 
 def load_minimal_model(model):
-    '''
+    """
     Load the parameters of the minimalist models used for REFCV
-    '''
+    """
 
     minimal_params = {
-        'RF' : {
-        'n_estimators': 30,
-        'max_depth': 10,
-        'min_samples_split': 2,
-        'min_samples_leaf': 1,
-        'min_weight_fraction_leaf': 0,
-        'max_features': 1,
-        'ccp_alpha': 0.0,
-        'max_samples': None
-        },
-        'GB': {
-        'n_estimators': 30,
-        'learning_rate': 0.1,
-        'max_depth': 10,
-        'min_samples_split': 2,
-        'min_samples_leaf': 1,
-        'subsample': 1.0,
-        'max_features': None,
-        'validation_fraction': 0.2,
-        'min_weight_fraction_leaf': 0.0,
-        'ccp_alpha': 0.0
+        "RF": {
+            "n_estimators": 30,
+            "max_depth": 10,
+            "min_samples_split": 2,
+            "min_samples_leaf": 1,
+            "min_weight_fraction_leaf": 0,
+            "max_features": 1,
+            "ccp_alpha": 0.0,
+            "max_samples": None,
         },
-        'NN': {
-        'hidden_layer_1': 4,
-        'hidden_layer_2': 4,
-        'max_iter': 200,
-        'alpha': 0.01,
-        'tol': 0.0001
+        "GB": {
+            "n_estimators": 30,
+            "learning_rate": 0.1,
+            "max_depth": 10,
+            "min_samples_split": 2,
+            "min_samples_leaf": 1,
+            "subsample": 1.0,
+            "max_features": None,
+            "validation_fraction": 0.2,
+            "min_weight_fraction_leaf": 0.0,
+            "ccp_alpha": 0.0,
         },
-        'ADAB': {
-        'learning_rate': 1.0,
-        'n_estimators': 30
+        "NN": {
+            "hidden_layer_1": 4,
+            "hidden_layer_2": 4,
+            "max_iter": 200,
+            "alpha": 0.01,
+            "tol": 0.0001,
         },
-        'GP': {
-        'n_restarts_optimizer': 30,
+        "ADAB": {"learning_rate": 1.0, "n_estimators": 30},
+        "GP": {
+            "n_restarts_optimizer": 30,
         },
-        'XGB': {
-        'n_estimators': 30,
-        'learning_rate': 0.1,
-        'max_depth': 10,
-        'min_child_weight': 1,
-        'subsample': 1.0,
-        'colsample_bytree': 1.0,
-        'reg_alpha': 0.0,
-        'reg_lambda': 1.0,
+        "XGB": {
+            "n_estimators": 30,
+            "learning_rate": 0.1,
+            "max_depth": 10,
+            "min_child_weight": 1,
+            "subsample": 1.0,
+            "colsample_bytree": 1.0,
+            "reg_alpha": 0.0,
+            "reg_lambda": 1.0,
         },
-        'MVL': {
-        },
-        'VR': {
-        'w_rf': 1.0,
-        'w_gb': 1.0,
-        'w_nn': 1.0,
-        }
+        "MVL": {},
+    }
+    minimal_params["VR"] = {
+        "w_rf": 1.0,
+        "w_gb": 1.0,
+        "w_nn": 1.0,
+        **{f"rf_{key}": value for key, value in minimal_params["RF"].items()},
+        **{f"gb_{key}": value for key, value in minimal_params["GB"].items()},
+        **{f"nn_{key}": value for key, value in minimal_params["NN"].items()},
     }
 
     return minimal_params[model]
 
-def mcc_scorer_clf(y_true,y_pred):
+
+def _round_vr_member_params(params):
+    """Round integer hyperparameters for VR member models (rf_*, gb_*, nn_*)."""
+    rf_int = {"n_estimators", "max_depth", "min_samples_split", "min_samples_leaf"}
+    gb_int = {"n_estimators", "max_depth", "min_samples_split", "min_samples_leaf"}
+    nn_int = {"max_iter", "hidden_layer_1", "hidden_layer_2"}
+    for key in list(params.keys()):
+        if key.startswith("rf_"):
+            if key[3:] in rf_int:
+                params[key] = round(params[key])
+        elif key.startswith("gb_"):
+            if key[3:] in gb_int:
+                params[key] = round(params[key])
+        elif key.startswith("nn_"):
+            if key[3:] in nn_int:
+                params[key] = round(params[key])
+
+
+def _pop_vr_member_params(params, prefix, defaults):
+    """Extract ``prefix_*`` keys into a member-model parameter dict."""
+    member = dict(defaults)
+    for key in list(params.keys()):
+        if key.startswith(f"{prefix}_"):
+            member[key[len(prefix) + 1 :]] = params.pop(key)
+    return member
+
+
+def mcc_scorer_clf(y_true, y_pred):
     """Forces classification predictions to integer for MCC."""
     # Even if .predict() returns floats, coerce them to integer:
     y_pred = np.round(y_pred).astype(int)
-    
+
     return matthews_corrcoef(y_true, y_pred)
 
-def get_scoring_key(problem_type,error_type):
-    '''
+
+def get_scoring_key(problem_type, error_type):
+    """
     Load scoring function for evaluating models
-    '''
+    """
 
-    if problem_type.lower() == 'reg':
+    if problem_type.lower() == "reg":
         scoring = {
-            'rmse': 'neg_root_mean_squared_error',
-            'mae': 'neg_median_absolute_error',
-            'r2': 'r2'
+            "rmse": "neg_root_mean_squared_error",
+            "mae": "neg_median_absolute_error",
+            "r2": "r2",
         }.get(error_type)
     else:
         # For classification
-        if error_type == 'mcc':
+        if error_type == "mcc":
             # Use the custom MCC scorer that ensures integer predictions
             scoring = make_scorer(mcc_scorer_clf)
         else:
-            scoring = {
-                'f1': 'f1',
-                'acc': 'accuracy'
-            }.get(error_type)
-   
+            scoring = {"f1": "f1", "acc": "accuracy"}.get(error_type)
+
     return scoring
 
 
-def check_csv_option(self,csv_option,print_err):
-    '''
+def check_csv_option(self, csv_option, print_err):
+    """
     Checks missing values in input CSV options
-    '''
-    
-    if csv_option == 'csv_name':
-        line_print = f'\nx  Specify the CSV name for the {csv_option} option!'
-    elif csv_option == 'csv_train':
-        line_print = f'\nx  Specify the CSV name containing the TRAINING set!'
-    elif csv_option == 'csv_valid':
-        line_print = f'\nx  Specify the CSV name containing the VALIDATION set!'
+    """
+
+    if csv_option == "csv_name":
+        line_print = f"\nx  Specify the CSV name for the {csv_option} option!"
+    elif csv_option == "csv_train":
+        line_print = "\nx  Specify the CSV name containing the TRAINING set!"
+    elif csv_option == "csv_valid":
+        line_print = "\nx  Specify the CSV name containing the VALIDATION set!"
 
     if print_err:
         print(line_print)
     else:
         self.log.write(line_print)
-    val_option = input('Enter the name of your CSV file: ')
-    self.extra_cmd += f' --{csv_option} {val_option}'
+    val_option = input("Enter the name of your CSV file: ")
+    self.extra_cmd += f" --{csv_option} {val_option}"
     if not print_err:
         self.log.write(f"   -  {csv_option} option set to {val_option} by the user")
 
-    if csv_option == 'csv_name':
-        self.csv_name = val_option    
-    elif csv_option == 'csv_train':
+    if csv_option == "csv_name":
+        self.csv_name = val_option
+    elif csv_option == "csv_train":
         self.csv_train = val_option
-    elif csv_option == 'csv_valid':
+    elif csv_option == "csv_valid":
         self.csv_valid = val_option
 
     return self
@@ -1057,116 +1234,183 @@ def sanity_checks(self, type_checks, module, columns_csv):
 
     curate_valid = True
     # adds manual inputs missing from the command line
-    self = missing_inputs(self,module)
+    self = missing_inputs(self, module)
 
-    if module.lower() == 'evaluate':
-        curate_valid = locate_csv(self,self.csv_name,curate_valid)
+    if module.lower() == "evaluate":
+        curate_valid = locate_csv(self, self.csv_name, curate_valid)
 
-        if self.eval_model.lower() not in ['mvl']:
-            self.log.write(f"\nx  The eval_model option used is not valid! Options: 'MVL' (more options will be added soon)")
+        if self.eval_model.lower() not in ["mvl"]:
+            self.log.write(
+                "\nx  The eval_model option used is not valid! Options: 'MVL' (more options will be added soon)"
+            )
             curate_valid = False
 
-        if self.type.lower() not in ['reg']:
-            self.log.write(f"\nx  The type option used is not valid in EVALUATE! Options: 'reg' (the 'clas' option will be added soon)")
+        if self.type.lower() not in ["reg"]:
+            self.log.write(
+                "\nx  The type option used is not valid in EVALUATE! Options: 'reg' (the 'clas' option will be added soon)"
+            )
             curate_valid = False
 
-    elif type_checks == 'initial' and module.lower() not in ['verify','predict']:
-
-        curate_valid = locate_csv(self,self.csv_name,curate_valid)
+    elif type_checks == "initial" and module.lower() not in ["verify", "predict"]:
+        curate_valid = locate_csv(self, self.csv_name, curate_valid)
 
-        if module.lower() == 'curate':
-            if self.categorical.lower() not in ['onehot','numbers']:
-                self.log.write(f"\nx  The categorical option used is not valid! Options: 'onehot', 'numbers'")
+        if module.lower() == "curate":
+            if self.categorical.lower() not in ["onehot", "numbers"]:
+                self.log.write(
+                    "\nx  The categorical option used is not valid! Options: 'onehot', 'numbers'"
+                )
                 curate_valid = False
 
-            for thres,thres_name in zip([self.thres_x,self.thres_y],['thres_x','thres_y']):
+            for thres, thres_name in zip(
+                [self.thres_x, self.thres_y], ["thres_x", "thres_y"]
+            ):
                 if float(thres) > 1 or float(thres) < 0:
-                    self.log.write(f"\nx  The {thres_name} option should be between 0 and 1!")
+                    self.log.write(
+                        f"\nx  The {thres_name} option should be between 0 and 1!"
+                    )
                     curate_valid = False
-        
-        elif module.lower() == 'generate':
-            if self.split.lower() not in ['kn','rnd','stratified','even','extra_q1','extra_q5','auto']:
-                self.log.write(f"\nx  The split option used is not valid! Options: 'KN', 'RND'")
+
+        elif module.lower() == "generate":
+            if self.split.lower() not in [
+                "kn",
+                "rnd",
+                "stratified",
+                "even",
+                "extra_q1",
+                "extra_q5",
+                "auto",
+            ]:
+                self.log.write(
+                    "\nx  The split option used is not valid! Options: 'KN', 'RND'"
+                )
                 curate_valid = False
 
-            if self.split == 'auto':
-                if self.type.lower() == 'reg':
-                    self.split = 'even'
-                elif self.type.lower() == 'clas':
-                    self.split = 'rnd'
+            if self.split == "auto":
+                if self.type.lower() == "reg":
+                    self.split = "even"
+                elif self.type.lower() == "clas":
+                    self.split = "rnd"
 
             for model_type in self.model:
-                if model_type.upper() not in ['RF','MVL','GB','GP','ADAB','NN','XGB','VR'] or len(self.model) == 0:
-                    self.log.write(f"\nx  The model option used is not valid! Options: 'RF', 'MVL', 'GB', 'GP', 'ADAB', 'NN', 'XGB', 'VR'")
+                if (
+                    model_type.upper()
+                    not in ["RF", "MVL", "GB", "GP", "ADAB", "NN", "XGB", "VR"]
+                    or len(self.model) == 0
+                ):
+                    self.log.write(
+                        "\nx  The model option used is not valid! Options: 'RF', 'MVL', 'GB', 'GP', 'ADAB', 'NN', 'XGB', 'VR'"
+                    )
                     curate_valid = False
-                if model_type.upper() == 'MVL' and self.type.lower() == 'clas':
-                    self.log.write(f"\nx  Multivariate linear models (MVL in the model_type option) are not compatible with classificaton!")                 
+                if model_type.upper() == "MVL" and self.type.lower() == "clas":
+                    self.log.write(
+                        "\nx  Multivariate linear models (MVL in the model_type option) are not compatible with classificaton!"
+                    )
                     curate_valid = False
 
-            if self.type.lower() not in ['reg','clas']:
-                self.log.write(f"\nx  The type option used is not valid! Options: 'reg', 'clas'")
+            if self.type.lower() not in ["reg", "clas"]:
+                self.log.write(
+                    "\nx  The type option used is not valid! Options: 'reg', 'clas'"
+                )
                 curate_valid = False
 
-    if type_checks == 'initial' and module.lower() in ['generate','verify','predict','report']:
-
-        if type_checks == 'initial' and module.lower() in ['generate','verify']:
-            if self.type.lower() == 'reg' and self.error_type.lower() not in ['rmse','mae','r2']:
-                self.log.write(f"\nx  The error_type option is not valid! Options for regression: 'rmse', 'mae', 'r2'")
+    if type_checks == "initial" and module.lower() in [
+        "generate",
+        "verify",
+        "predict",
+        "report",
+    ]:
+        if type_checks == "initial" and module.lower() in ["generate", "verify"]:
+            if self.type.lower() == "reg" and self.error_type.lower() not in [
+                "rmse",
+                "mae",
+                "r2",
+            ]:
+                self.log.write(
+                    "\nx  The error_type option is not valid! Options for regression: 'rmse', 'mae', 'r2'"
+                )
                 curate_valid = False
 
-            if self.type.lower() == 'clas' and self.error_type.lower() not in ['mcc','f1','acc']:
-                self.log.write(f"\nx  The error_type option is not valid! Options for classification: 'mcc', 'f1', 'acc'")
+            if self.type.lower() == "clas" and self.error_type.lower() not in [
+                "mcc",
+                "f1",
+                "acc",
+            ]:
+                self.log.write(
+                    "\nx  The error_type option is not valid! Options for classification: 'mcc', 'f1', 'acc'"
+                )
                 curate_valid = False
 
-        if module.lower() in ['verify','predict']:
+        if module.lower() in ["verify", "predict"]:
             if os.getcwd() in f"{self.params_dir}":
                 path_db = self.params_dir
             else:
                 path_db = f"{Path(os.getcwd()).joinpath(self.params_dir)}"
 
             if not os.path.exists(path_db):
-                self.log.write(f'\nx  The path of your CSV files doesn\'t exist! Set the folder containing the two CSV files with 1) the parameters of the model and 2) the Xy database with the params_dir option')
+                self.log.write(
+                    "\nx  The path of your CSV files doesn't exist! Set the folder containing the two CSV files with 1) the parameters of the model and 2) the Xy database with the params_dir option"
+                )
                 curate_valid = False
 
-        if module.lower() == 'predict':
+        if module.lower() == "predict":
             if self.t_value < 0:
                 self.log.write(f"\nx  t_value ({self.t_value}) should be higher 0!")
                 curate_valid = False
 
-            if self.csv_test != '':
+            if self.csv_test != "":
                 if os.getcwd() in f"{self.csv_test}":
                     path_test = self.csv_test
                 else:
                     path_test = f"{Path(os.getcwd()).joinpath(self.csv_test)}"
                 if not os.path.exists(path_test):
-                    self.log.write(f'\nx  The path of your CSV file with the test set doesn\'t exist! You specified: {self.csv_test}')
+                    self.log.write(
+                        f"\nx  The path of your CSV file with the test set doesn't exist! You specified: {self.csv_test}"
+                    )
                     curate_valid = False
 
-        if module.lower() == 'report':
+        if module.lower() == "report":
             if len(self.report_modules) == 0:
-                self.log.write(f'\nx  No modules were provided in the report_modules option! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT"')
+                self.log.write(
+                    '\nx  No modules were provided in the report_modules option! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT"'
+                )
                 curate_valid = False
 
             for module in self.report_modules:
-                if module.upper() not in ['CURATE','GENERATE','VERIFY','PREDICT','AQME']:
-                    self.log.write(f'\nx  Module {module} specified in the report_modules option is not a valid module! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT", "AQME"')
+                if module.upper() not in [
+                    "CURATE",
+                    "GENERATE",
+                    "VERIFY",
+                    "PREDICT",
+                    "AQME",
+                ]:
+                    self.log.write(
+                        f'\nx  Module {module} specified in the report_modules option is not a valid module! Options: "CURATE", "GENERATE", "VERIFY", "PREDICT", "AQME"'
+                    )
                     curate_valid = False
-  
-    elif type_checks == 'csv_db':
-        if module.lower() != 'predict':
+
+    elif type_checks == "csv_db":
+        if module.lower() != "predict":
             if self.y not in columns_csv:
-                if self.y.lower() in columns_csv: # accounts for upper/lowercase mismatches
+                if (
+                    self.y.lower() in columns_csv
+                ):  # accounts for upper/lowercase mismatches
                     self.y = self.y.lower()
                 elif self.y.upper() in columns_csv:
                     self.y = self.y.upper()
                 else:
-                    self.log.write(f"\nx  The y option specified ({self.y}) is not a column in the csv selected ({self.csv_name})! If you are using command lines, make sure you add quotation marks like --y \"VALUE\"")
+                    self.log.write(
+                        f'\nx  The y option specified ({self.y}) is not a column in the csv selected ({self.csv_name})! If you are using command lines, make sure you add quotation marks like --y "VALUE"'
+                    )
                     curate_valid = False
 
-            for option,option_name in zip([self.discard,self.ignore],['discard','ignore']):
+            for option, option_name in zip(
+                [self.discard, self.ignore], ["discard", "ignore"]
+            ):
                 for val in option:
                     if val not in columns_csv:
-                        self.log.write(f"\nx  Descriptor {val} specified in the {option_name} option is not a column in the csv selected ({self.csv_name})!")
+                        self.log.write(
+                            f"\nx  Descriptor {val} specified in the {option_name} option is not a column in the csv selected ({self.csv_name})!"
+                        )
                         curate_valid = False
 
     if not curate_valid:
@@ -1174,50 +1418,58 @@ def sanity_checks(self, type_checks, module, columns_csv):
         sys.exit()
 
 
-def locate_csv(self,csv_input,curate_valid):
-    '''
+def locate_csv(self, csv_input, curate_valid):
+    """
     Assesses whether the input CSV databases can be located
-    '''
+    """
 
-    path_csv = ''
+    path_csv = ""
     if os.path.exists(f"{csv_input}"):
         path_csv = csv_input
     elif os.path.exists(f"{Path(os.getcwd()).joinpath(csv_input)}"):
         path_csv = f"{Path(os.getcwd()).joinpath(csv_input)}"
-    if not os.path.exists(path_csv) or csv_input == '':
-        self.log.write(f'\nx  The path of your CSV file doesn\'t exist! You specified: --csv_name {csv_input}')
+    if not os.path.exists(path_csv) or csv_input == "":
+        self.log.write(
+            f"\nx  The path of your CSV file doesn't exist! You specified: --csv_name {csv_input}"
+        )
         curate_valid = False
-    
+
     return curate_valid
 
 
-def check_clas_problem(self,csv_df):
-    '''
+def check_clas_problem(self, csv_df):
+    """
     Changes type to classification if there are only two different y values.
     Automatically converts any pair of values (strings or numbers) to 0 and 1.
     Stores the original labels for later reconversion in outputs.
-    '''
+    """
 
     # changes type to classification if there are only two different y values
-    if self.args.type.lower() == 'reg' and self.args.auto_type:
+    if self.args.type.lower() == "reg" and self.args.auto_type:
         num_unique = len(set(csv_df[self.args.y]))
         if num_unique == 2:
-            self.args.type = 'clas'
-            if self.args.error_type not in ['acc', 'mcc', 'f1']:
-                self.args.error_type = 'mcc'
-            if ('MVL' or 'mvl') in self.args.model:
-                self.args.model = [x if x.upper() != 'MVL' else 'ADAB' for x in self.args.model]
+            self.args.type = "clas"
+            if self.args.error_type not in ["acc", "mcc", "f1"]:
+                self.args.error_type = "mcc"
+            if ("MVL" or "mvl") in self.args.model:
+                self.args.model = [
+                    x if x.upper() != "MVL" else "ADAB" for x in self.args.model
+                ]
 
             unique_vals = list(set(csv_df[self.args.y]))
-            y_val_detect = f'{unique_vals[0]} and {unique_vals[1]}'
-            self.args.log.write(f'\no  Only two different y values were detected ({y_val_detect})! The program will consider classification models (same effect as using "--type clas"). This option can be disabled with "--auto_type False"')
+            y_val_detect = f"{unique_vals[0]} and {unique_vals[1]}"
+            self.args.log.write(
+                f'\no  Only two different y values were detected ({y_val_detect})! The program will consider classification models (same effect as using "--type clas"). This option can be disabled with "--auto_type False"'
+            )
 
-    if self.args.type.lower() == 'clas':
+    if self.args.type.lower() == "clas":
         if len(set(csv_df[self.args.y])) == 2:
-            unique_values = sorted(list(set(csv_df[self.args.y])))  # Sort alphabetically for consistency
-            
+            unique_values = sorted(
+                list(set(csv_df[self.args.y]))
+            )  # Sort alphabetically for consistency
+
             # Check if values are already 0 and 1
-            if set([str(v) for v in unique_values]) == {'0', '1'}:
+            if set([str(v) for v in unique_values]) == {"0", "1"}:
                 # Already in correct format, just ensure they're integers
                 csv_df[self.args.y] = csv_df[self.args.y].astype(int)
             else:
@@ -1225,169 +1477,224 @@ def check_clas_problem(self,csv_df):
                 # Store original labels for reconversion in outputs
                 self.args.class_0_label = str(unique_values[0])
                 self.args.class_1_label = str(unique_values[1])
-                
+
                 # Create mapping dictionaries
                 self.args.class_mapping = {unique_values[0]: 0, unique_values[1]: 1}
-                self.args.class_mapping_reverse = {0: unique_values[0], 1: unique_values[1]}
-                
+                self.args.class_mapping_reverse = {
+                    0: unique_values[0],
+                    1: unique_values[1],
+                }
+
                 # Convert values in dataframe
                 csv_df[self.args.y] = csv_df[self.args.y].map(self.args.class_mapping)
-                
-                self.args.log.write(f'\no  Classification labels converted: {self.args.class_0_label} → 0, {self.args.class_1_label} → 1')
-                self.args.log.write(f'   Original labels will be restored in output files')
-        
+
+                self.args.log.write(
+                    f"\no  Classification labels converted: {self.args.class_0_label} → 0, {self.args.class_1_label} → 1"
+                )
+                self.args.log.write(
+                    "   Original labels will be restored in output files"
+                )
+
         # Check that each class has at least 5 points
         class_counts = csv_df[self.args.y].value_counts()
         min_class_count = class_counts.min()
         min_class_label = class_counts.idxmin()
-        
+
         if min_class_count < 5:
             # Get original label if available
-            if hasattr(self.args, 'class_mapping_reverse') and min_class_label in self.args.class_mapping_reverse:
+            if (
+                hasattr(self.args, "class_mapping_reverse")
+                and min_class_label in self.args.class_mapping_reverse
+            ):
                 original_label = self.args.class_mapping_reverse[min_class_label]
             else:
                 original_label = min_class_label
-            
+
             # Convert class_counts to dict with regular Python ints
             class_dist = {int(k): int(v) for k, v in class_counts.items()}
-            
-            self.args.log.write(f'\nx  Insufficient data for classification! One of the classes has only {min_class_count} datapoints (class "{original_label}")')
-            self.args.log.write(f'   Each class must have at least 5 datapoints to ensure robust train/validation/test splits')
-            self.args.log.write(f'   Current distribution: {class_dist}')
-            self.args.log.write(f'   Please add more datapoints for the minority class or consider a different approach')
+
+            self.args.log.write(
+                f'\nx  Insufficient data for classification! One of the classes has only {min_class_count} datapoints (class "{original_label}")'
+            )
+            self.args.log.write(
+                "   Each class must have at least 5 datapoints to ensure robust train/validation/test splits"
+            )
+            self.args.log.write(f"   Current distribution: {class_dist}")
+            self.args.log.write(
+                "   Please add more datapoints for the minority class or consider a different approach"
+            )
             self.args.log.finalize()
             sys.exit()
 
     return self
-    
 
-def load_database(self,csv_load,module,print_info=True,external_test=False):
-    '''
+
+def load_database(self, csv_load, module, print_info=True, external_test=False):
+    """
     Loads either a Xy (params=False) or a parameter (params=True) database from a CSV file
-    '''
-    
+    """
+
     # adjust external set in AQME workflows
-    if module.lower() == 'aqme_test':
+    if module.lower() == "aqme_test":
         external_test = True
 
-    txt_load = ''
+    txt_load = ""
     # Semicolon-separated "CSV" from Excel: peek at the first rows before reading the whole file.
     _scan_limit = 64
     head_lines = []
-    with open(csv_load, 'r', encoding='utf-8') as file:
+    with open(csv_load, "r", encoding="utf-8") as file:
         for _, line in zip(range(_scan_limit), file):
             head_lines.append(line)
-    semicolon_issue = len(head_lines) >= 2 and head_lines[1].count(';') > 1
+    semicolon_issue = len(head_lines) >= 2 and head_lines[1].count(";") > 1
     if semicolon_issue:
-        with open(csv_load, 'r', encoding='utf-8') as file:
+        with open(csv_load, "r", encoding="utf-8") as file:
             lines = file.readlines()
     if semicolon_issue:
-        new_csv_name = os.path.basename(csv_load).split('.csv')[0].split('.CSV')[0]+'_original.csv'
+        new_csv_name = (
+            os.path.basename(csv_load).split(".csv")[0].split(".CSV")[0]
+            + "_original.csv"
+        )
         shutil.move(csv_load, Path(os.path.dirname(csv_load)).joinpath(new_csv_name))
         new_csv_file = open(csv_load, "w")
         for line in lines:
-            line = line.replace(',','.')
-            line = line.replace(';',',')
+            line = line.replace(",", ".")
+            line = line.replace(";", ",")
             # line = line.replace(':',',')
             new_csv_file.write(line)
         new_csv_file.close()
-        txt_load += f'\nx  WARNING! The original database was not a valid CSV (i.e., formatting issues from Microsoft Excel?). A new database using commas as separators was created and used instead, and the original database was stored as {new_csv_name}. To prevent this issue from happening again, you should use commas as separators: https://support.edapp.com/change-csv-separator.\n\n'
+        txt_load += f"\nx  WARNING! The original database was not a valid CSV (i.e., formatting issues from Microsoft Excel?). A new database using commas as separators was created and used instead, and the original database was stored as {new_csv_name}. To prevent this issue from happening again, you should use commas as separators: https://support.edapp.com/change-csv-separator.\n\n"
 
-    csv_df = pd.read_csv(csv_load, encoding='utf-8')
+    csv_df = pd.read_csv(csv_load, encoding="utf-8")
 
     # Missing data handling: robust strategy for columns and rows (optional KNN imputer)
     target_col = self.args.y
-    descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore+self.args.discard and col != self.args.y]
+    descriptor_cols = [
+        col
+        for col in csv_df.columns
+        if col not in self.args.ignore + self.args.discard and col != self.args.y
+    ]
     min_count = int(0.9 * len(csv_df))
 
     # Remove columns with <90% data
-    cols_to_drop = [col for col in descriptor_cols if csv_df[col].notna().sum() < min_count]
+    cols_to_drop = [
+        col for col in descriptor_cols if csv_df[col].notna().sum() < min_count
+    ]
     if cols_to_drop:
         csv_df = csv_df.drop(columns=cols_to_drop)
-        if module.lower() == 'curate':
+        if module.lower() == "curate":
             txt_load += f"\n   - Removed {len(cols_to_drop)} column(s) with <90% data\n"
-        descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore+self.args.discard and col != self.args.y]
-    
+        descriptor_cols = [
+            col
+            for col in csv_df.columns
+            if col not in self.args.ignore + self.args.discard and col != self.args.y
+        ]
+
     # Remove rows with <50% data
-    rows_too_missing = csv_df[descriptor_cols].isna().sum(axis=1) > (0.5 * len(descriptor_cols))
+    rows_too_missing = csv_df[descriptor_cols].isna().sum(axis=1) > (
+        0.5 * len(descriptor_cols)
+    )
     if rows_too_missing.any():
         n_removed_rows = rows_too_missing.sum()
         csv_df = csv_df[~rows_too_missing].reset_index(drop=True)
-        descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore+self.args.discard and col != self.args.y]
-        if module.lower() == 'curate':
+        descriptor_cols = [
+            col
+            for col in csv_df.columns
+            if col not in self.args.ignore + self.args.discard and col != self.args.y
+        ]
+        if module.lower() == "curate":
             txt_load += f"\n   - Removed {n_removed_rows} row(s) with >50% missing descriptors\n"
-    
+
     # Apply KNN imputer only to numeric columns with missing values (when auto_fill is activated)
     if self.args.auto_fill:
-        numeric_columns = csv_df.select_dtypes(include=['float']).columns.drop(target_col, errors='ignore')
+        numeric_columns = csv_df.select_dtypes(include=["float"]).columns.drop(
+            target_col, errors="ignore"
+        )
         if csv_df[numeric_columns].isna().any().any():
             imputer = KNNImputer(n_neighbors=5)
-            csv_df[numeric_columns] = pd.DataFrame(imputer.fit_transform(csv_df[numeric_columns]), columns=numeric_columns, index=csv_df.index)
-            if module.lower() == 'curate':
-                txt_load += f"\n   - Applied KNN imputer to columns with missing values\n"
+            csv_df[numeric_columns] = pd.DataFrame(
+                imputer.fit_transform(csv_df[numeric_columns]),
+                columns=numeric_columns,
+                index=csv_df.index,
+            )
+            if module.lower() == "curate":
+                txt_load += (
+                    "\n   - Applied KNN imputer to columns with missing values\n"
+                )
     else:
         # Remove columns with ANY missing value
-        cols_with_missing = [col for col in descriptor_cols if csv_df[col].isna().any() and col not in self.args.ignore+self.args.discard]
+        cols_with_missing = [
+            col
+            for col in descriptor_cols
+            if csv_df[col].isna().any()
+            and col not in self.args.ignore + self.args.discard
+        ]
         if cols_with_missing:
             csv_df = csv_df.drop(columns=cols_with_missing)
-            if module.lower() == 'curate':
+            if module.lower() == "curate":
                 txt_load += f"\n   - Removed {len(cols_with_missing)} column(s) with missing values\n"
 
     if print_info:
-        sanity_checks(self.args,'csv_db',module,csv_df.columns)
+        sanity_checks(self.args, "csv_db", module, csv_df.columns)
         csv_df = csv_df.drop(self.args.discard, axis=1)
         total_amount = len(csv_df.columns)
         ignored_descs = len(self.args.ignore)
-        accepted_descs = total_amount - ignored_descs - 1 # the y column is substracted
-        if 'Set' in csv_df.columns: # removes the column that tracks sets
+        accepted_descs = total_amount - ignored_descs - 1  # the y column is substracted
+        if "Set" in csv_df.columns:  # removes the column that tracks sets
             accepted_descs -= 1
             ignored_descs += 1
-        if module.lower() not in ['aqme','aqme_test']:
+        if module.lower() not in ["aqme", "aqme_test"]:
             csv_name = os.path.basename(csv_load)
-            if module.lower() not in ['predict']:
-                txt_load += f'\no  Database {csv_name} loaded successfully, including:'
-                txt_load += f'\n   - {len(csv_df[self.args.y])} datapoints'
-                txt_load += f'\n   - {accepted_descs} accepted descriptors'
-                txt_load += f'\n   - {ignored_descs} ignored descriptors'
-                txt_load += (
-                    f"\n   - {len(self.args.discard)} discarded descriptors"
-                )
+            if module.lower() not in ["predict"]:
+                txt_load += f"\no  Database {csv_name} loaded successfully, including:"
+                txt_load += f"\n   - {len(csv_df[self.args.y])} datapoints"
+                txt_load += f"\n   - {accepted_descs} accepted descriptors"
+                txt_load += f"\n   - {ignored_descs} ignored descriptors"
+                txt_load += f"\n   - {len(self.args.discard)} discarded descriptors"
             else:
                 txt_load += (
-                    f"\no  External set {csv_name} loaded successfully, "
-                    "including:"
+                    f"\no  External set {csv_name} loaded successfully, including:"
                 )
                 txt_load += f"\n   - {len(csv_df)} datapoints (rows)"
             self.args.log.write(txt_load)
             if accepted_descs == 0:
-                self.args.log.write(f"\nx  The aren't any valid descriptors! Check the messages above to see whether the filters have discarded descriptors")
+                self.args.log.write(
+                    "\nx  The aren't any valid descriptors! Check the messages above to see whether the filters have discarded descriptors"
+                )
                 sys.exit()
 
     # Sort columns alphabetically for reproducibility across ALL modules
-    if module.lower() not in ['aqme', 'aqme_test']:
+    if module.lower() not in ["aqme", "aqme_test"]:
         # Get descriptor columns (excluding y and ignore)
-        descriptor_cols = [col for col in csv_df.columns if col not in self.args.ignore and col != self.args.y]
+        descriptor_cols = [
+            col
+            for col in csv_df.columns
+            if col not in self.args.ignore and col != self.args.y
+        ]
         descriptor_cols_sorted = sorted(descriptor_cols)
         # Get other columns (y and ignore)
-        other_cols = [col for col in csv_df.columns if col in self.args.ignore or col == self.args.y]
+        other_cols = [
+            col
+            for col in csv_df.columns
+            if col in self.args.ignore or col == self.args.y
+        ]
         # Reorder dataframe with sorted descriptors + other columns
         sorted_all_cols = descriptor_cols_sorted + other_cols
         csv_df = csv_df[sorted_all_cols]
-    
+
     # ignore user-defined descriptors and assign X and y values (but keeps the original database)
-    if module.lower() == 'generate':
+    if module.lower() == "generate":
         # Only drop columns that actually exist in the dataframe (for model-specific CSVs from CURATE)
         cols_to_ignore = [col for col in self.args.ignore if col in csv_df.columns]
         # Also drop 'Set' column if it exists (for model-specific CSVs)
-        if 'Set' in csv_df.columns and 'Set' not in cols_to_ignore:
-            cols_to_ignore.append('Set')
+        if "Set" in csv_df.columns and "Set" not in cols_to_ignore:
+            cols_to_ignore.append("Set")
         csv_df_ignore = csv_df.drop(cols_to_ignore, axis=1)
         csv_X = csv_df_ignore.drop([self.args.y], axis=1)
         csv_y = csv_df_ignore[self.args.y]
-        
+
         # Columns are already sorted from above, just extract them
         csv_X = csv_X[sorted([col for col in csv_X.columns])]
-        
+
     else:
         if external_test and self.args.y not in csv_df.columns:
             csv_X = csv_df
@@ -1396,53 +1703,53 @@ def load_database(self,csv_load,module,print_info=True,external_test=False):
             csv_X = csv_df.drop([self.args.y], axis=1)
             csv_y = csv_df[self.args.y]
 
-    return csv_df,csv_X,csv_y
+    return csv_df, csv_X, csv_y
 
 
-def categorical_transform(self,csv_df,module):
-    ''' converts all columns with strings into categorical values (one hot encoding
+def categorical_transform(self, csv_df, module):
+    """converts all columns with strings into categorical values (one hot encoding
     by default, can be set to numerical 1,2,3... with categorical = True).
     Troubleshooting! For one-hot encoding, don't use variable names that are
     also column headers! i.e. DESCRIPTOR "C_atom" contain C2 as a value,
     but C2 is already a header of a different column in the database. Same applies
     for multiple columns containing the same variable names.
-    '''
+    """
 
-    if module.lower() == 'curate':
-        txt_categor = f'\no  Analyzing categorical variables'
+    if module.lower() == "curate":
+        txt_categor = "\no  Analyzing categorical variables"
 
-    descriptors_to_drop, categorical_vars, new_categor_desc = [],[],[]
+    descriptors_to_drop, categorical_vars, new_categor_desc = [], [], []
     for column in csv_df.columns:
         if column not in self.args.ignore and column != self.args.y:
-            if(csv_df[column].dtype == 'object'):
+            if csv_df[column].dtype == "object":
                 descriptors_to_drop.append(column)
                 categorical_vars.append(column)
-                if self.args.categorical.lower() == 'numbers':
-                    csv_df[column] = csv_df[column].astype('category')
+                if self.args.categorical.lower() == "numbers":
+                    csv_df[column] = csv_df[column].astype("category")
                     csv_df[column] = csv_df[column].cat.codes
                 else:
-                    _ = csv_df[column].unique() # is this necessary?
+                    _ = csv_df[column].unique()  # is this necessary?
                     categor_descs = pd.get_dummies(csv_df[column])
                     csv_df = csv_df.drop(column, axis=1)
                     csv_df = pd.concat([csv_df, categor_descs], axis=1)
                     for desc in categor_descs:
                         new_categor_desc.append(desc)
 
-    if module.lower() == 'curate':
+    if module.lower() == "curate":
         if len(categorical_vars) == 0:
-            txt_categor += f'\n   - No categorical variables were found'
+            txt_categor += "\n   - No categorical variables were found"
         else:
-            if self.args.categorical.lower() == 'numbers':
-                txt_categor += f'\n   A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option:\n'
-                txt_categor += '\n'.join(f'   - {var}' for var in categorical_vars)
+            if self.args.categorical.lower() == "numbers":
+                txt_categor += f"\n   A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option:\n"
+                txt_categor += "\n".join(f"   - {var}" for var in categorical_vars)
             else:
-                txt_categor += f'\n   A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option'
-                txt_categor += f'\n   Initial descriptors:\n'
-                txt_categor += '\n'.join(f'   - {var}' for var in categorical_vars)
-                txt_categor += f'\n   Generated descriptors:\n'
-                txt_categor += '\n'.join(f'   - {var}' for var in new_categor_desc)
+                txt_categor += f"\n   A total of {len(categorical_vars)} categorical variables were converted using the {self.args.categorical} mode in the categorical option"
+                txt_categor += "\n   Initial descriptors:\n"
+                txt_categor += "\n".join(f"   - {var}" for var in categorical_vars)
+                txt_categor += "\n   Generated descriptors:\n"
+                txt_categor += "\n".join(f"   - {var}" for var in new_categor_desc)
 
-        self.args.log.write(f'{txt_categor}')
+        self.args.log.write(f"{txt_categor}")
 
     return csv_df
 
@@ -1454,142 +1761,182 @@ def create_folders(folder_names):
         folder.mkdir(exist_ok=True, parents=True)
 
 
-def finish_print(self,start_time,module):
+def finish_print(self, start_time, module):
     elapsed_time = round(time.time() - start_time, 2)
     self.args.log.write(f"\nTime {module.upper()}: {elapsed_time} seconds\n")
     self.args.log.finalize()
 
 
-def scale_df(csv_X,csv_X_external):
-    '''
+def scale_df(csv_X, csv_X_external):
+    """
     Scale the X matrix for the training set and the external test set (if any)
-    '''
-    
+    """
+
     scaler = StandardScaler()
     _ = scaler.fit(csv_X)
     X_scaled = scaler.transform(csv_X)
-    X_scaled_df = pd.DataFrame(X_scaled, columns = csv_X.columns)
+    X_scaled_df = pd.DataFrame(X_scaled, columns=csv_X.columns)
 
     X_scaled_external_df = None
     if csv_X_external is not None:
         X_scaled_external = scaler.transform(csv_X_external)
-        X_scaled_external_df = pd.DataFrame(X_scaled_external, columns = csv_X_external.columns)
-
-    return X_scaled_df,X_scaled_external_df
-
+        X_scaled_external_df = pd.DataFrame(
+            X_scaled_external, columns=csv_X_external.columns
+        )
 
-def Xy_split(csv_df,csv_X,X_scaled_df,csv_y,csv_external_df,csv_X_external,X_scaled_external_df,csv_y_external,test_points,column_names):
-    '''
+    return X_scaled_df, X_scaled_external_df
+
+
+def Xy_split(
+    csv_df,
+    csv_X,
+    X_scaled_df,
+    csv_y,
+    csv_external_df,
+    csv_X_external,
+    X_scaled_external_df,
+    csv_y_external,
+    test_points,
+    column_names,
+):
+    """
     Returns a dictionary with the database divided into train and validation
-    '''
+    """
 
-    Xy_data =  {}
+    Xy_data = {}
 
     if len(test_points) == 0:
-        Xy_data['X_train'] = csv_X
-        Xy_data['X_train_scaled'] = X_scaled_df
-        Xy_data['y_train'] = csv_y
-        Xy_data['names_train'] = csv_df[column_names]
+        Xy_data["X_train"] = csv_X
+        Xy_data["X_train_scaled"] = X_scaled_df
+        Xy_data["y_train"] = csv_y
+        Xy_data["names_train"] = csv_df[column_names]
 
     else:
-        Xy_data['X_train'] = csv_X.drop(test_points)
-        Xy_data['X_train_scaled'] = X_scaled_df.drop(test_points)
-        Xy_data['y_train'] = csv_y.drop(test_points)
-        Xy_data['X_test'] = csv_X.iloc[test_points]
-        Xy_data['X_test_scaled'] = X_scaled_df.iloc[test_points]
-        Xy_data['y_test'] = csv_y.iloc[test_points]
-        Xy_data['names_train'] = csv_df.drop(test_points)[column_names]
-        Xy_data['names_test'] = csv_df.iloc[test_points][column_names]
+        Xy_data["X_train"] = csv_X.drop(test_points)
+        Xy_data["X_train_scaled"] = X_scaled_df.drop(test_points)
+        Xy_data["y_train"] = csv_y.drop(test_points)
+        Xy_data["X_test"] = csv_X.iloc[test_points]
+        Xy_data["X_test_scaled"] = X_scaled_df.iloc[test_points]
+        Xy_data["y_test"] = csv_y.iloc[test_points]
+        Xy_data["names_train"] = csv_df.drop(test_points)[column_names]
+        Xy_data["names_test"] = csv_df.iloc[test_points][column_names]
 
-    Xy_data['test_points'] = test_points
+    Xy_data["test_points"] = test_points
 
     if X_scaled_external_df is not None:
-        Xy_data['X_external'] = csv_X_external
-        Xy_data['X_external_scaled'] = X_scaled_external_df
+        Xy_data["X_external"] = csv_X_external
+        Xy_data["X_external_scaled"] = X_scaled_external_df
         if csv_y_external is not None:
-            Xy_data['y_external'] = csv_y_external 
-        Xy_data['names_external'] = csv_external_df[column_names]
+            Xy_data["y_external"] = csv_y_external
+        Xy_data["names_external"] = csv_external_df[column_names]
 
     return Xy_data
 
 
-def test_select(self,X_scaled,csv_y):
-    '''
+def test_select(self, X_scaled, csv_y):
+    """
     Selection of test set (if any)
-    '''
+    """
 
     # adjusts size of the test_set to include at least 4 points regardless of the number of datapoints
     test_input_size = round(self.args.test_set * len(csv_y))
     min_test_size = 4
-    selected_size = max(test_input_size,min_test_size)
+    selected_size = max(test_input_size, min_test_size)
 
     # in the future, we'll adapt other data splitting techniques for classificaiton problems with 3+ target values
-    if self.args.type == 'clas':
+    if self.args.type == "clas":
         if len(set(csv_y)) != 2:
-            self.args.split = 'RND' 
+            self.args.split = "RND"
 
-    if self.args.split.upper() == 'KN':
+    if self.args.split.upper() == "KN":
         # k-neighbours data split
 
         # selects representative training points for each target value in classification problems
-        if self.args.type == 'clas':
+        if self.args.type == "clas":
             class_0_idx = list(csv_y[csv_y == 0].index)
             class_1_idx = list(csv_y[csv_y == 1].index)
-            class_0_test_size = round((len(class_0_idx)/len(csv_y))*selected_size)
-            class_1_test_size = selected_size-class_0_test_size
+            class_0_test_size = round((len(class_0_idx) / len(csv_y)) * selected_size)
+            class_1_test_size = selected_size - class_0_test_size
             class_0_train_size = len(class_0_idx) - class_0_test_size
             class_1_train_size = len(class_1_idx) - class_1_test_size
 
-            # the k-means function internally selects the training points to be as diverse as possible, 
+            # the k-means function internally selects the training points to be as diverse as possible,
             # but it returns the test points
-            test_class_0 = k_means(self,X_scaled.iloc[class_0_idx],csv_y,class_0_train_size,self.args.seed,class_0_idx)
-            test_class_1 = k_means(self,X_scaled.iloc[class_1_idx],csv_y,class_1_train_size,self.args.seed,class_1_idx)
-            test_points = test_class_0+test_class_1
+            test_class_0 = k_means(
+                self,
+                X_scaled.iloc[class_0_idx],
+                csv_y,
+                class_0_train_size,
+                self.args.seed,
+                class_0_idx,
+            )
+            test_class_1 = k_means(
+                self,
+                X_scaled.iloc[class_1_idx],
+                csv_y,
+                class_1_train_size,
+                self.args.seed,
+                class_1_idx,
+            )
+            test_points = test_class_0 + test_class_1
 
         else:
             idx_list = csv_y.index
-            training_size = len(csv_y)-selected_size
-            test_points = k_means(self,X_scaled,csv_y,training_size,self.args.seed,idx_list)
+            training_size = len(csv_y) - selected_size
+            test_points = k_means(
+                self, X_scaled, csv_y, training_size, self.args.seed, idx_list
+            )
 
-    elif self.args.split.upper() == 'RND':
+    elif self.args.split.upper() == "RND":
         size = round(selected_size * 100 / (len(csv_y)))
-        _, X_test, _, _ = train_test_split(X_scaled, csv_y, test_size=size/100, random_state=self.args.seed)
+        _, X_test, _, _ = train_test_split(
+            X_scaled, csv_y, test_size=size / 100, random_state=self.args.seed
+        )
         test_points = X_test.index.tolist()
 
-    elif self.args.split.upper() == 'STRATIFIED':
-
+    elif self.args.split.upper() == "STRATIFIED":
         size = np.ceil(selected_size * 100 / (len(csv_y)))
         # Remove the max and min values so they don't end up in the training set
         # Calculate the number of bins based on the number of points
         csv_y_capped = csv_y.drop([csv_y.idxmin(), csv_y.idxmax()])
-        y_binned = pd.qcut(csv_y_capped, q=selected_size, labels=False, duplicates='drop')
-        
+        y_binned = pd.qcut(
+            csv_y_capped, q=selected_size, labels=False, duplicates="drop"
+        )
+
         # Adjust the number of bins until each class has at least 2 members
         while y_binned.value_counts().min() < 2 and selected_size > 2:
             selected_size -= 1
-            y_binned = pd.qcut(csv_y_capped, q=selected_size, labels=False, duplicates='drop')
-        splitter = StratifiedShuffleSplit(n_splits=1, test_size=(100 - size) / 100, random_state=self.args.seed)
+            y_binned = pd.qcut(
+                csv_y_capped, q=selected_size, labels=False, duplicates="drop"
+            )
+        splitter = StratifiedShuffleSplit(
+            n_splits=1, test_size=(100 - size) / 100, random_state=self.args.seed
+        )
         for test_idx, _ in splitter.split(X_scaled, y_binned):
             test_points = test_idx.tolist()
 
-    elif self.args.split.upper() == 'EVEN':
+    elif self.args.split.upper() == "EVEN":
         # Remove the max and min values so they don't end up in the training set
         csv_y_capped = csv_y.drop([csv_y.idxmin(), csv_y.idxmax()])
         # Calculate the number of bins based on the number of points
-        y_binned = pd.qcut(csv_y_capped, q=selected_size, labels=False, duplicates='drop')
+        y_binned = pd.qcut(
+            csv_y_capped, q=selected_size, labels=False, duplicates="drop"
+        )
 
         # Adjust bin count if any bin has fewer than two elements (happens in imbalanced data, see comment below)
         temp_size = selected_size
         while y_binned.value_counts().min() < 2 and temp_size > 2:
             temp_size -= 1
-            y_binned = pd.qcut(csv_y_capped, q=temp_size, labels=False, duplicates='drop')
+            y_binned = pd.qcut(
+                csv_y_capped, q=temp_size, labels=False, duplicates="drop"
+            )
 
         # Determine central validation points for each bin
         test_points = []
         for bin_label in y_binned.unique():
             bin_indices = y_binned[y_binned == bin_label].index
             sorted_indices = sorted(bin_indices, key=lambda idx: csv_y[idx])
-            test_points.append(sorted_indices[round(len(sorted_indices)/2)])
+            test_points.append(sorted_indices[round(len(sorted_indices) / 2)])
 
         # in umbalanced databases, the points cannot be selected entirely even (i.e., if a database
         # contains 10 points in th 0-10 range, and 1000 points in the 10-90 range, choosing 100
@@ -1603,12 +1950,12 @@ def test_select(self,X_scaled,csv_y):
                 test_points.append(new_test_point)
             random_seed += 1
 
-    elif self.args.split.upper() == 'EXTRA_Q1':
+    elif self.args.split.upper() == "EXTRA_Q1":
         # 20% lowest points
         portion = max(1, round(0.2 * len(csv_y)))
         test_points = csv_y.nsmallest(portion).index.tolist()
-    
-    elif self.args.split.upper() == 'EXTRA_Q5':
+
+    elif self.args.split.upper() == "EXTRA_Q5":
         # 20%% highest points
         portion = max(1, round(0.2 * len(csv_y)))
         test_points = csv_y.nlargest(portion).index.tolist()
@@ -1622,20 +1969,20 @@ def generate_lhs_points(pbounds, n_points, random_state=None):
     """
     Generate initial points using Latin Hypercube Sampling for better space coverage.
     LHS ensures uniform distribution across all dimensions of the hyperparameter space.
-    
+
     Args:
         pbounds: Dictionary with parameter bounds from BO_hyperparams
         n_points: Number of initial points to generate
         random_state: Random seed for reproducibility
-    
+
     Returns:
         List of dictionaries with parameter values
     """
     np.random.seed(random_state)
-    
+
     param_names = list(pbounds.keys())
     n_params = len(param_names)
-    
+
     # Generate LHS samples in [0, 1]^n_params
     # Each dimension is divided into n_points intervals, and one point is sampled from each interval
     samples = np.zeros((n_points, n_params))
@@ -1645,7 +1992,7 @@ def generate_lhs_points(pbounds, n_points, random_state=None):
         samples[:, i] = np.random.uniform(intervals[:-1], intervals[1:])
         # Shuffle to break correlation between dimensions
         np.random.shuffle(samples[:, i])
-    
+
     # Scale samples to actual parameter bounds
     initial_points = []
     for sample in samples:
@@ -1655,30 +2002,25 @@ def generate_lhs_points(pbounds, n_points, random_state=None):
             # Scale from [0, 1] to [lower, upper]
             point[param_name] = lower + sample[i] * (upper - lower)
         initial_points.append(point)
-    
-    return initial_points
-
 
-def BO_optimizer(self,bo_data,Xy_data):
-    from bayes_opt import BayesianOptimization, acquisition
+    return initial_points
 
-    # Define an acquisition function for Bayesian optimization
-    _ = acquisition.ExpectedImprovement(xi=self.args.expect_improv)
 
+def BO_optimizer(self, bo_data, Xy_data):
     # Initialize Bayesian optimization
     optimizer = BayesianOptimization(
         f=lambda **p: BO_iteration(self, bo_data, Xy_data, **p),
-        pbounds=BO_hyperparams(bo_data['model']),
+        pbounds=BO_hyperparams(bo_data["model"]),
         verbose=2,
-        random_state=self.args.seed
+        random_state=self.args.seed,
     )
 
     # Generate initial points using Latin Hypercube Sampling for better space coverage
     if self.args.init_points > 0:
         initial_points = generate_lhs_points(
-            pbounds=BO_hyperparams(bo_data['model']),
+            pbounds=BO_hyperparams(bo_data["model"]),
             n_points=self.args.init_points,
-            random_state=self.args.seed
+            random_state=self.args.seed,
         )
         # Probe the initial points
         for params in initial_points:
@@ -1687,145 +2029,148 @@ def BO_optimizer(self,bo_data,Xy_data):
     # Run the optimization (with warnings suppressed for Convergence issues)
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=ConvergenceWarning)
-        optimizer.maximize(init_points=0, n_iter=self.args.n_iter)  # init_points=0 since we already probed LHS points
+        optimizer.maximize(
+            init_points=0, n_iter=self.args.n_iter
+        )  # init_points=0 since we already probed LHS points
 
-    if bo_data['error_type'].upper() in ['RMSE','MAE']:
-        BO_target = -optimizer.max['target']
+    if bo_data["error_type"].upper() in ["RMSE", "MAE"]:
+        BO_target = -optimizer.max["target"]
     else:
-        BO_target = optimizer.max['target']
-    self.args.log.write(f"   o Best combined {bo_data['error_type'].upper()} (target) found in BO for {bo_data['model']} (no PFI filter): {BO_target:.2}")
+        BO_target = optimizer.max["target"]
+    self.args.log.write(
+        f"   o Best combined {bo_data['error_type'].upper()} (target) found in BO for {bo_data['model']} (no PFI filter): {BO_target:.2}"
+    )
 
     # Retrieve best parameters and best result
-    return optimizer.max['params'], BO_target
+    return optimizer.max["params"], BO_target
 
 
 def BO_iteration(self, bo_data, Xy_data, **params):
-    '''
+    """
     Evaluate a model with given parameters using cross-validation.
     Returns the mean negative root mean squared error (higher is better).
-    '''
+    """
 
-    bo_data['params'] = model_adjust_params(self, bo_data['model'], params)
+    bo_data["params"] = model_adjust_params(self, bo_data["model"], params)
     BO_iter_score = load_n_predict(self, bo_data, Xy_data, BO_opt=True)
 
     return BO_iter_score
 
 
 def BO_hyperparams(model_name):
-
     model_BO_params = {
-        'RF' : {
-        'n_estimators': (10, 100),
-        'max_depth': (5, 20),
-        'min_samples_split': (2, 10),
-        'min_samples_leaf': (2, 5),
-        'min_weight_fraction_leaf': (0, 0.05),
-        'max_features': (0.25, 1.0),
-        'ccp_alpha': (0, 0.01),
-        'max_samples': (0.25, 1.0)
+        "RF": {
+            "n_estimators": (10, 100),
+            "max_depth": (5, 20),
+            "min_samples_split": (2, 10),
+            "min_samples_leaf": (2, 5),
+            "min_weight_fraction_leaf": (0, 0.05),
+            "max_features": (0.25, 1.0),
+            "ccp_alpha": (0, 0.01),
+            "max_samples": (0.25, 1.0),
         },
-        'GB': {
-        'n_estimators': (10, 100),
-        'learning_rate': (0.01, 0.3),
-        'max_depth': (5, 20),
-        'min_samples_split': (2, 10),
-        'min_samples_leaf': (2, 5),
-        'subsample': (0.7, 1.0),
-        'max_features': (0.25, 1.0),
-        'validation_fraction': (0.1, 0.3),
-        'min_weight_fraction_leaf': (0, 0.05),
-        'ccp_alpha': (0, 0.01)
+        "GB": {
+            "n_estimators": (10, 100),
+            "learning_rate": (0.01, 0.3),
+            "max_depth": (5, 20),
+            "min_samples_split": (2, 10),
+            "min_samples_leaf": (2, 5),
+            "subsample": (0.7, 1.0),
+            "max_features": (0.25, 1.0),
+            "validation_fraction": (0.1, 0.3),
+            "min_weight_fraction_leaf": (0, 0.05),
+            "ccp_alpha": (0, 0.01),
         },
-        'NN': {
-        'hidden_layer_1': (1, 10),
-        'hidden_layer_2': (0, 10),
-        'max_iter': (200, 500),
-        'alpha': (0.01, 0.1),
-        'tol': (0.00001, 0.0001)
+        "NN": {
+            "hidden_layer_1": (1, 10),
+            "hidden_layer_2": (0, 10),
+            "max_iter": (200, 500),
+            "alpha": (0.01, 0.1),
+            "tol": (0.00001, 0.0001),
         },
-        'ADAB': {
-        'learning_rate': (0.1, 5),
-        'n_estimators': (10, 100)
+        "ADAB": {"learning_rate": (0.1, 5), "n_estimators": (10, 100)},
+        "GP": {
+            "n_restarts_optimizer": (0, 100),
         },
-        'GP': {
-        'n_restarts_optimizer': (0, 100),
-        },
-        'XGB': {
-        'n_estimators': (10, 100),
-        'learning_rate': (0.01, 0.3),
-        'max_depth': (3, 20),
-        'min_child_weight': (1, 10),
-        'subsample': (0.7, 1.0),
-        'colsample_bytree': (0.25, 1.0),
-        'reg_alpha': (0, 1.0),
-        'reg_lambda': (0, 1.0),
-        },
-        'VR': {
-        'w_rf': (0.1, 5.0),
-        'w_gb': (0.1, 5.0),
-        'w_nn': (0.1, 5.0),
+        "XGB": {
+            "n_estimators": (10, 100),
+            "learning_rate": (0.01, 0.3),
+            "max_depth": (3, 20),
+            "min_child_weight": (1, 10),
+            "subsample": (0.7, 1.0),
+            "colsample_bytree": (0.25, 1.0),
+            "reg_alpha": (0, 1.0),
+            "reg_lambda": (0, 1.0),
         },
     }
+    model_BO_params["VR"] = {
+        "w_rf": (0.1, 5.0),
+        "w_gb": (0.1, 5.0),
+        "w_nn": (0.1, 5.0),
+        **{f"rf_{key}": value for key, value in model_BO_params["RF"].items()},
+        **{f"gb_{key}": value for key, value in model_BO_params["GB"].items()},
+        **{f"nn_{key}": value for key, value in model_BO_params["NN"].items()},
+    }
 
     return model_BO_params[model_name]
 
 
 def BO_metrics(self, bo_data, Xy_data):
-    '''
+    """
     Get combined score for repeated k-fold and top-bottom sorted CVs (used in BO)
-    '''
+    """
 
     metric_combined = load_n_predict(self, bo_data, Xy_data, BO_opt=True)
-    if bo_data['error_type'].upper() in ['RMSE','MAE']:
-        metric_combined  = -metric_combined
+    if bo_data["error_type"].upper() in ["RMSE", "MAE"]:
+        metric_combined = -metric_combined
     bo_data[f"combined_{bo_data['error_type']}"] = metric_combined
 
     return bo_data
 
 
-def model_adjust_params(self,model_name,params):
-    '''
+def model_adjust_params(self, model_name, params):
+    """
     Add seed and convert parameters to integers, since they come as floats with decimals in the iterations
 
-    '''
+    """
 
-    if model_name not in ['MVL', 'VR']:
-        params['random_state'] = self.args.seed
+    if model_name not in ["MVL", "VR"]:
+        params["random_state"] = self.args.seed
 
-        if model_name in ['RF','GB']:
-            params['n_estimators'] = round(params['n_estimators'])
-            params['max_depth'] = round(params['max_depth'])
-            params['min_samples_split'] = round(params['min_samples_split'])
-            params['min_samples_leaf'] = round(params['min_samples_leaf'])
+        if model_name in ["RF", "GB"]:
+            params["n_estimators"] = round(params["n_estimators"])
+            params["max_depth"] = round(params["max_depth"])
+            params["min_samples_split"] = round(params["min_samples_split"])
+            params["min_samples_leaf"] = round(params["min_samples_leaf"])
 
-        elif model_name == 'XGB':
-            params['n_estimators'] = round(params['n_estimators'])
-            params['max_depth'] = round(params['max_depth'])
-            params['min_child_weight'] = round(params['min_child_weight'])
+        elif model_name == "XGB":
+            params["n_estimators"] = round(params["n_estimators"])
+            params["max_depth"] = round(params["max_depth"])
+            params["min_child_weight"] = round(params["min_child_weight"])
 
-        elif model_name == 'NN':
+        elif model_name == "NN":
             # add solver first
-            params['solver'] = 'lbfgs'
-            params['max_iter'] = round(params['max_iter'])
-            params['hidden_layer_1'] = round(params['hidden_layer_1'])
-            params['hidden_layer_2'] = round(params['hidden_layer_2'])
-
-        elif model_name == 'ADAB':
-            params['n_estimators'] = round(params['n_estimators'])
-
-        elif model_name == 'GP':
-            params['n_restarts_optimizer'] = round(params['n_restarts_optimizer'])
-
-    elif model_name == 'VR':
-        # VR only optimizes ensemble weights; base estimators receive deterministic seeds.
-        if all(weight_key in params for weight_key in ['w_rf', 'w_gb', 'w_nn']):
-            params['weights'] = [
-                float(params.pop('w_rf')),
-                float(params.pop('w_gb')),
-                float(params.pop('w_nn')),
+            params["solver"] = "lbfgs"
+            params["max_iter"] = round(params["max_iter"])
+            params["hidden_layer_1"] = round(params["hidden_layer_1"])
+            params["hidden_layer_2"] = round(params["hidden_layer_2"])
+
+        elif model_name == "ADAB":
+            params["n_estimators"] = round(params["n_estimators"])
+
+        elif model_name == "GP":
+            params["n_restarts_optimizer"] = round(params["n_restarts_optimizer"])
+
+    elif model_name == "VR":
+        if all(weight_key in params for weight_key in ["w_rf", "w_gb", "w_nn"]):
+            params["weights"] = [
+                float(params.pop("w_rf")),
+                float(params.pop("w_gb")),
+                float(params.pop("w_nn")),
             ]
-        elif 'weights' in params:
-            params['weights'] = [float(weight) for weight in params['weights']]
+        elif "weights" in params:
+            params["weights"] = [float(weight) for weight in params["weights"]]
+        _round_vr_member_params(params)
 
     return params
 
@@ -1835,120 +2180,162 @@ def load_model(self, model_name, **params):
     Load models with their corresponding parameters.
     """
 
-    if model_name == 'RF':
+    if model_name == "RF":
         # Ensure n_jobs=1 for reproducibility if not already in params
-        if 'n_jobs' not in params:
-            params['n_jobs'] = 1
-        if self.args.type.lower() == 'reg':
+        if "n_jobs" not in params:
+            params["n_jobs"] = 1
+        if self.args.type.lower() == "reg":
             loaded_model = RandomForestRegressor(**params)
         else:
             loaded_model = RandomForestClassifier(**params)
 
-    elif model_name == 'GB':
+    elif model_name == "GB":
         # GradientBoosting doesn't have n_jobs parameter, it's already deterministic
-        if self.args.type.lower() == 'reg':
+        if self.args.type.lower() == "reg":
             loaded_model = GradientBoostingRegressor(**params)
         else:
             loaded_model = GradientBoostingClassifier(**params)
 
-    elif model_name == 'XGB':
-        if 'n_jobs' not in params:
-            params['n_jobs'] = 1
-        if self.args.type.lower() == 'reg':
+    elif model_name == "XGB":
+        if "n_jobs" not in params:
+            params["n_jobs"] = 1
+        if self.args.type.lower() == "reg":
             loaded_model = XGBRegressor(**params)
         else:
             loaded_model = XGBClassifier(**params)
 
-    elif model_name == 'NN':
+    elif model_name == "NN":
         # create the hidden layers architecture first
         params = setup_hidden_layers(params)
 
-        if self.args.type.lower() == 'reg':
+        if self.args.type.lower() == "reg":
             loaded_model = MLPRegressor(**params)
         else:
             loaded_model = MLPClassifier(**params)
 
-    elif model_name == 'ADAB':
-        if self.args.type.lower() == 'reg':
+    elif model_name == "ADAB":
+        if self.args.type.lower() == "reg":
             loaded_model = AdaBoostRegressor(**params)
         else:
             loaded_model = AdaBoostClassifier(**params)
 
-    elif model_name == 'GP':        
-        if self.args.type.lower() == 'reg':
+    elif model_name == "GP":
+        if self.args.type.lower() == "reg":
             loaded_model = GaussianProcessRegressor(**params)
         else:
             loaded_model = GaussianProcessClassifier(**params)
 
-    elif model_name == 'MVL':
+    elif model_name == "MVL":
         loaded_model = LinearRegression(**params)
 
-    elif model_name == 'VR':
-        weights = params.pop('weights', [1.0, 1.0, 1.0])
+    elif model_name == "VR":
+        weights = params.pop("weights", [1.0, 1.0, 1.0])
         weights = [float(weight) for weight in weights]
         seed = self.args.seed
+        rf_defaults = {
+            "n_estimators": 100,
+            "max_depth": 10,
+            "min_samples_split": 2,
+            "min_samples_leaf": 1,
+            "min_weight_fraction_leaf": 0,
+            "max_features": 1.0,
+            "ccp_alpha": 0.0,
+            "max_samples": None,
+            "random_state": seed,
+            "n_jobs": 1,
+        }
+        gb_defaults = {
+            "n_estimators": 30,
+            "learning_rate": 0.1,
+            "max_depth": 10,
+            "min_samples_split": 2,
+            "min_samples_leaf": 1,
+            "subsample": 1.0,
+            "max_features": None,
+            "validation_fraction": 0.2,
+            "min_weight_fraction_leaf": 0.0,
+            "ccp_alpha": 0.0,
+            "random_state": seed,
+        }
+        nn_defaults = {
+            "hidden_layer_1": 50,
+            "hidden_layer_2": 0,
+            "max_iter": 500,
+            "alpha": 0.01,
+            "tol": 0.0001,
+            "solver": "lbfgs",
+            "random_state": seed,
+        }
+        rf_params = _pop_vr_member_params(params, "rf", rf_defaults)
+        gb_params = _pop_vr_member_params(params, "gb", gb_defaults)
+        nn_params = _pop_vr_member_params(params, "nn", nn_defaults)
+        nn_params = setup_hidden_layers(nn_params)
 
-        if self.args.type.lower() == 'reg':
+        if self.args.type.lower() == "reg":
             voting_estimators = [
-                ('rf', RandomForestRegressor(n_estimators=100, random_state=seed, n_jobs=1)),
-                ('gb', GradientBoostingRegressor(random_state=seed)),
-                ('nn', MLPRegressor(hidden_layer_sizes=(50,), max_iter=500, solver='lbfgs', random_state=seed)),
+                ("rf", RandomForestRegressor(**rf_params)),
+                ("gb", GradientBoostingRegressor(**gb_params)),
+                ("nn", MLPRegressor(**nn_params)),
             ]
-            loaded_model = VotingRegressor(estimators=voting_estimators, weights=weights)
+            loaded_model = VotingRegressor(
+                estimators=voting_estimators, weights=weights
+            )
         else:
             voting_estimators = [
-                ('rf', RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=1)),
-                ('gb', GradientBoostingClassifier(random_state=seed)),
-                ('nn', MLPClassifier(hidden_layer_sizes=(50,), max_iter=500, solver='lbfgs', random_state=seed)),
+                ("rf", RandomForestClassifier(**rf_params)),
+                ("gb", GradientBoostingClassifier(**gb_params)),
+                ("nn", MLPClassifier(**nn_params)),
             ]
-            loaded_model = VotingClassifier(estimators=voting_estimators, weights=weights)
-    
+            loaded_model = VotingClassifier(
+                estimators=voting_estimators, weights=weights
+            )
+
     return loaded_model
 
 
 def setup_hidden_layers(params):
-    '''
+    """
     Build hidden layer structure from provided parameters
-    '''
+    """
 
     hidden_layer_sizes = []
-    hidden_layer_1 = params.pop('hidden_layer_1')
-    hidden_layer_2 = params.pop('hidden_layer_2')
+    hidden_layer_1 = params.pop("hidden_layer_1")
+    hidden_layer_2 = params.pop("hidden_layer_2")
     if hidden_layer_1 > 0:
         hidden_layer_sizes.append(hidden_layer_1)
     if hidden_layer_2 > 0:
         hidden_layer_sizes.append(hidden_layer_2)
     hidden_layer_sizes = tuple(hidden_layer_sizes) if hidden_layer_sizes else (1,)
 
-    params['hidden_layer_sizes'] = hidden_layer_sizes
+    params["hidden_layer_sizes"] = hidden_layer_sizes
 
     return params
 
 
 def correct_hidden_layers(params):
-    '''
+    """
     Correct for a problem with the 'hidden_layer_sizes' parameter when loading arrays from JSON
-    '''
-    
+    """
+
     layer_arrays = []
 
-    if not isinstance(params['hidden_layer_sizes'],int):
-        if params['hidden_layer_sizes'][0] == '[':
-            params['hidden_layer_sizes'] = params['hidden_layer_sizes'][1:]
-        if params['hidden_layer_sizes'][-1] == ']':
-            params['hidden_layer_sizes'] = params['hidden_layer_sizes'][:-1]
-        if not isinstance(params['hidden_layer_sizes'],list):
-            for _,ele in enumerate(params['hidden_layer_sizes'].split(',')):
-                if ele != '':
+    if not isinstance(params["hidden_layer_sizes"], int):
+        if params["hidden_layer_sizes"][0] == "[":
+            params["hidden_layer_sizes"] = params["hidden_layer_sizes"][1:]
+        if params["hidden_layer_sizes"][-1] == "]":
+            params["hidden_layer_sizes"] = params["hidden_layer_sizes"][:-1]
+        if not isinstance(params["hidden_layer_sizes"], list):
+            for _, ele in enumerate(params["hidden_layer_sizes"].split(",")):
+                if ele != "":
                     layer_arrays.append(int(ele))
         else:
-            for _,ele in enumerate(params['hidden_layer_sizes']):
-                if ele != '':
+            for _, ele in enumerate(params["hidden_layer_sizes"]):
+                if ele != "":
                     layer_arrays.append(int(ele))
     else:
         layer_arrays = ele
 
-    params['hidden_layer_sizes'] = (layer_arrays)
+    params["hidden_layer_sizes"] = layer_arrays
 
     return params
 
@@ -2074,7 +2461,7 @@ def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type
 
     if problem_type.lower() == "reg":
         y_point = np.average(preds, axis=0, weights=w)
-        within_var = np.average(sds ** 2, axis=0, weights=w)
+        within_var = np.average(sds**2, axis=0, weights=w)
         mean_pred = np.average(preds, axis=0, weights=w)
         between_var = np.average((preds - mean_pred) ** 2, axis=0, weights=w)
         uq_model = np.sqrt(np.maximum(within_var, 0.0))
@@ -2085,7 +2472,10 @@ def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type
     # Classification: preds are class labels; use float spread heuristics.
     preds_f = preds.astype(float)
     y_point = np.array(
-        [int(round(np.average(preds_f[:, i], weights=w))) for i in range(preds_f.shape[1])],
+        [
+            int(round(np.average(preds_f[:, i], weights=w)))
+            for i in range(preds_f.shape[1])
+        ],
         dtype=int,
     )
     uq_model = np.average(sds, axis=0, weights=w)
@@ -2093,7 +2483,7 @@ def aggregate_meta_uq_decomposition(preds_stack, sd_stack, weights, problem_type
     uq_meta = np.sqrt(
         np.maximum(np.average((preds_f - mean_pred) ** 2, axis=0, weights=w), 0.0)
     )
-    uq_total = np.sqrt(np.maximum(uq_model ** 2 + uq_meta ** 2, 0.0))
+    uq_total = np.sqrt(np.maximum(uq_model**2 + uq_meta**2, 0.0))
     return y_point, uq_model, uq_meta, uq_total
 
 
@@ -2222,7 +2612,9 @@ def _conformal_abs_residual_quantile(abs_residuals, coverage):
     return float(np.quantile(abs_r, level, method="higher"))
 
 
-def _apply_full_refit_split_conformal(self, model_data, Xy_data, loaded_model, y_cv_mean_train):
+def _apply_full_refit_split_conformal(
+    self, model_data, Xy_data, loaded_model, y_cv_mean_train
+):
     """
     Point predictions from a single estimator refit on all training data.
     For regression, ``conformal_half_width`` uses a held-out calibration split when
@@ -2305,15 +2697,15 @@ def _apply_full_refit_split_conformal(self, model_data, Xy_data, loaded_model, y
 
 
 def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False):
-    '''
+    """
     Load model and calculate errors/precision and predicted values of the ML models
-    '''
+    """
 
     # set the parameters for the ML model and load it
-    loaded_model = load_model(self, model_data['model'], **model_data['params'])
+    loaded_model = load_model(self, model_data["model"], **model_data["params"])
 
     # calculate predicted y values using repeated k-fold CV
-    Xy_data = repeated_kfold_cv(model_data,loaded_model,Xy_data,BO_opt)
+    Xy_data = repeated_kfold_cv(model_data, loaded_model, Xy_data, BO_opt)
 
     y_cv_mean_train = np.asarray(Xy_data["y_pred_train"], dtype=float)
     if not BO_opt:
@@ -2323,38 +2715,59 @@ def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False):
         Xy_data["_fitted_model"] = fitted_model
 
     # combine all the predictions from the repeated CV (metrics of the train set)
-    y_all_list,y_pred_all_list = [],[]
-    for y_val,y_pred_vals in zip(Xy_data['y_train'],Xy_data['y_pred_train_all']):
+    y_all_list, y_pred_all_list = [], []
+    for y_val, y_pred_vals in zip(Xy_data["y_train"], Xy_data["y_pred_train_all"]):
         for y_pred_val in y_pred_vals:
             y_all_list.append(y_val)
             y_pred_all_list.append(y_pred_val)
-        
+
     # get metrics for the different sets
-    error_labels = {'reg': ['r2','mae','rmse'],
-                    'clas': ['acc','f1','mcc']
-                    }
-
-    error1 = error_labels[model_data['type']][0]
-    error2 = error_labels[model_data['type']][1]
-    error3 = error_labels[model_data['type']][2]
-    Xy_data[f'{error1}_train'], Xy_data[f'{error2}_train'], Xy_data[f'{error3}_train'] = get_prediction_results(model_data,y_all_list,y_pred_all_list)
+    error_labels = {"reg": ["r2", "mae", "rmse"], "clas": ["acc", "f1", "mcc"]}
+
+    error1 = error_labels[model_data["type"]][0]
+    error2 = error_labels[model_data["type"]][1]
+    error3 = error_labels[model_data["type"]][2]
+    (
+        Xy_data[f"{error1}_train"],
+        Xy_data[f"{error2}_train"],
+        Xy_data[f"{error3}_train"],
+    ) = get_prediction_results(model_data, y_all_list, y_pred_all_list)
     if not BO_opt:
-        Xy_data[f'{error1}_test'], Xy_data[f'{error2}_test'], Xy_data[f'{error3}_test'] = get_prediction_results(model_data,Xy_data['y_test'],Xy_data['y_pred_test'])
-        if 'y_external' in Xy_data and not Xy_data['y_external'].isnull().values.any() and len(Xy_data['y_external']) > 0:
-            Xy_data[f'{error1}_external'], Xy_data[f'{error2}_external'], Xy_data[f'{error3}_external'] = get_prediction_results(model_data,Xy_data['y_external'],Xy_data['y_pred_external'])
+        (
+            Xy_data[f"{error1}_test"],
+            Xy_data[f"{error2}_test"],
+            Xy_data[f"{error3}_test"],
+        ) = get_prediction_results(
+            model_data, Xy_data["y_test"], Xy_data["y_pred_test"]
+        )
+        if (
+            "y_external" in Xy_data
+            and not Xy_data["y_external"].isnull().values.any()
+            and len(Xy_data["y_external"]) > 0
+        ):
+            (
+                Xy_data[f"{error1}_external"],
+                Xy_data[f"{error2}_external"],
+                Xy_data[f"{error3}_external"],
+            ) = get_prediction_results(
+                model_data, Xy_data["y_external"], Xy_data["y_pred_external"]
+            )
     if BO_opt:
         # calculate sorted CV and its metrics
         # print the target that is above the BO
         # print the final result of the BO just after finishing all the iterations
         Xy_data = sorted_kfold_cv(loaded_model, model_data, Xy_data, error_labels)
-        combined_score = (Xy_data[f'{model_data["error_type"]}_train'] + Xy_data[f'{model_data["error_type"]}_up_bottom']) / 2
+        combined_score = (
+            Xy_data[f"{model_data['error_type']}_train"]
+            + Xy_data[f"{model_data['error_type']}_up_bottom"]
+        ) / 2
 
         # Return if this is part of a verify job
         if verify_job:
             return Xy_data
 
         # Return negative score for MAE/RMSE in BO
-        if model_data["error_type"].lower() in ['mae', 'rmse']:
+        if model_data["error_type"].lower() in ["mae", "rmse"]:
             return -combined_score
         else:
             return combined_score
@@ -2362,110 +2775,139 @@ def load_n_predict(self, model_data, Xy_data, BO_opt=False, verify_job=False):
         return Xy_data
 
 
-def repeated_kfold_cv(model_data,loaded_model,Xy_data,BO_opt):
-    '''
+def repeated_kfold_cv(model_data, loaded_model, Xy_data, BO_opt):
+    """
     Performs a repeated k-fold cross-validation on the Xy dataset
-    '''
+    """
 
     # create a list of lists with the same number of entries as y
-    y_global,y_pred_global = [],[]
-    for _ in range(len(Xy_data['y_train'])):
+    y_global, y_pred_global = [], []
+    for _ in range(len(Xy_data["y_train"])):
         y_pred_global.append([])
         y_global.append([])
 
     y_pred_global_test = []
-    for _ in range(len(Xy_data['y_test'])):
+    for _ in range(len(Xy_data["y_test"])):
         y_pred_global_test.append([])
 
-    y_global_external,y_pred_global_external = [],[]
-    if 'X_external' in Xy_data: # if there is an external test set
-        for _ in range(len(Xy_data['X_external'])):
+    y_global_external, y_pred_global_external = [], []
+    if "X_external" in Xy_data:  # if there is an external test set
+        for _ in range(len(Xy_data["X_external"])):
             y_pred_global_external.append([])
             y_global_external.append([])
 
     # start the repeated CV
-    for CV_repeat in range(int(model_data['repeat_kfolds'])):
-        _,y_pred_global,y_pred_global_test,y_pred_global_external, = kfold_cv(y_global,y_pred_global,
-                                        y_pred_global_test,
-                                        y_pred_global_external,
-                                        model_data,loaded_model,
-                                        Xy_data,CV_repeat,BO_opt=BO_opt)
-
-    y_train_pred, y_train_std = [],[]
+    for CV_repeat in range(int(model_data["repeat_kfolds"])):
+        (
+            _,
+            y_pred_global,
+            y_pred_global_test,
+            y_pred_global_external,
+        ) = kfold_cv(
+            y_global,
+            y_pred_global,
+            y_pred_global_test,
+            y_pred_global_external,
+            model_data,
+            loaded_model,
+            Xy_data,
+            CV_repeat,
+            BO_opt=BO_opt,
+        )
+
+    y_train_pred, y_train_std = [], []
     for y_val in y_pred_global:
-        if model_data['type'].lower() == 'reg':
+        if model_data["type"].lower() == "reg":
             y_train_pred.append(np.mean(y_val))
-        elif model_data['type'].lower() == 'clas':
+        elif model_data["type"].lower() == "clas":
             y_train_pred.append(int(round(np.mean(y_val))))
         y_train_std.append(float(np.std(y_val)))
 
-    Xy_data['y_pred_train_all'] = y_pred_global
-    Xy_data['y_pred_train'] = y_train_pred
-    Xy_data['y_pred_train_sd'] = y_train_std
+    Xy_data["y_pred_train_all"] = y_pred_global
+    Xy_data["y_pred_train"] = y_train_pred
+    Xy_data["y_pred_train_sd"] = y_train_std
 
     if not BO_opt:
-        y_test_pred, y_test_std = [],[]
+        y_test_pred, y_test_std = [], []
         for y_val_test in y_pred_global_test:
-            if model_data['type'].lower() == 'reg':
+            if model_data["type"].lower() == "reg":
                 y_test_pred.append(np.mean(y_val_test))
-            elif model_data['type'].lower() == 'clas':
+            elif model_data["type"].lower() == "clas":
                 y_test_pred.append(int(round(np.mean(y_val_test))))
             y_test_std.append(float(np.std(y_val_test)))
 
-        Xy_data['y_pred_test_all'] = y_pred_global_test
-        Xy_data['y_pred_test'] = y_test_pred
-        Xy_data['y_pred_test_sd'] = y_test_std
+        Xy_data["y_pred_test_all"] = y_pred_global_test
+        Xy_data["y_pred_test"] = y_test_pred
+        Xy_data["y_pred_test_sd"] = y_test_std
 
-        if 'X_external' in Xy_data: # if there is an external test set
-            y_external_pred, y_external_std = [],[]
+        if "X_external" in Xy_data:  # if there is an external test set
+            y_external_pred, y_external_std = [], []
             for y_val_external in y_pred_global_external:
-                if model_data['type'].lower() == 'reg':
+                if model_data["type"].lower() == "reg":
                     y_external_pred.append(np.mean(y_val_external))
-                elif model_data['type'].lower() == 'clas':
+                elif model_data["type"].lower() == "clas":
                     y_external_pred.append(int(round(np.mean(y_val_external))))
                 y_external_std.append(float(np.std(y_val_external)))
 
-            Xy_data['y_pred_external_all'] = y_pred_global_external
-            Xy_data['y_pred_external'] = y_external_pred
-            Xy_data['y_pred_external_sd'] = y_external_std
+            Xy_data["y_pred_external_all"] = y_pred_global_external
+            Xy_data["y_pred_external"] = y_external_pred
+            Xy_data["y_pred_external_sd"] = y_external_std
 
     return Xy_data
 
 
-def kfold_cv(y_global,y_pred_global,
-             y_pred_global_test,
-             y_pred_global_external,
-             model_data,loaded_model,Xy_data,random_state,
-             BO_opt=False,shuffle=True,kfold_cv_type='repeated'):
-    '''
+def kfold_cv(
+    y_global,
+    y_pred_global,
+    y_pred_global_test,
+    y_pred_global_external,
+    model_data,
+    loaded_model,
+    Xy_data,
+    random_state,
+    BO_opt=False,
+    shuffle=True,
+    kfold_cv_type="repeated",
+):
+    """
     Perform a k-fold CV
     Uses StratifiedKFold for classification problems to maintain class distribution
-    '''
+    """
 
     # load CV scheme
-    if model_data['type'].lower() == 'clas':
+    if model_data["type"].lower() == "clas":
         # Use StratifiedKFold for classification to maintain class distribution
-        cv = StratifiedKFold(n_splits=int(model_data['kfold']), shuffle=shuffle, random_state=random_state)
+        cv = StratifiedKFold(
+            n_splits=int(model_data["kfold"]),
+            shuffle=shuffle,
+            random_state=random_state,
+        )
     else:
-        cv = KFold(n_splits=int(model_data['kfold']), shuffle=shuffle, random_state=random_state)
+        cv = KFold(
+            n_splits=int(model_data["kfold"]),
+            shuffle=shuffle,
+            random_state=random_state,
+        )
 
     # # load Xy values and sort using y_train as the sorting reference
-    if kfold_cv_type == 'sorted':
-        X_init,y_init = sort_n_load(Xy_data) # do not use, currently it doesn't sort indices for X_train as well
+    if kfold_cv_type == "sorted":
+        X_init, y_init = sort_n_load(
+            Xy_data
+        )  # do not use, currently it doesn't sort indices for X_train as well
 
     else:
         # convert Xy values of training and validation for CV
-        X_init = np.array(Xy_data['X_train_scaled'])
-        y_init = np.array(Xy_data['y_train'])
+        X_init = np.array(Xy_data["X_train_scaled"])
+        y_init = np.array(Xy_data["y_train"])
 
         # convert Xy values for the test set and external test set (if any)
-        X_test = np.array(Xy_data['X_test_scaled'])
-        if 'X_external_scaled' in Xy_data:
-            X_external = np.array(Xy_data['X_external_scaled'])
+        X_test = np.array(Xy_data["X_test_scaled"])
+        if "X_external_scaled" in Xy_data:
+            X_external = np.array(Xy_data["X_external_scaled"])
 
     ix_training, ix_valid = [], []
     # Loop through each fold and append the training & test indices to the empty lists above
-    if model_data['type'].lower() == 'clas':
+    if model_data["type"].lower() == "clas":
         # For classification, we need to pass y values to ensure stratification
         for fold in cv.split(X_init, y_init):
             ix_training.append(fold[0]), ix_valid.append(fold[1])
@@ -2473,8 +2915,8 @@ def kfold_cv(y_global,y_pred_global,
         for fold in cv.split(X_init):
             ix_training.append(fold[0]), ix_valid.append(fold[1])
 
-    # Loop through each outer fold, and extract predicted vs actual values and SHAP feature analysis 
-    for (train_outer_ix, test_outer_ix) in zip(ix_training, ix_valid): 
+    # Loop through each outer fold, and extract predicted vs actual values and SHAP feature analysis
+    for train_outer_ix, test_outer_ix in zip(ix_training, ix_valid):
         X_train, X_valid = X_init[train_outer_ix, :], X_init[test_outer_ix, :]
         y_train, y_valid = y_init[train_outer_ix], y_init[test_outer_ix]
 
@@ -2482,94 +2924,125 @@ def kfold_cv(y_global,y_pred_global,
         y_pred_valid = fit.predict(X_valid)
         if not BO_opt:
             y_pred_test = fit.predict(X_test)
-            if 'X_external_scaled' in Xy_data:
+            if "X_external_scaled" in Xy_data:
                 y_pred_external = fit.predict(X_external)
-        
-        if kfold_cv_type == 'repeated':
-            for y_val,y_pred_val,idx in zip(y_valid,y_pred_valid,test_outer_ix):
+
+        if kfold_cv_type == "repeated":
+            for y_val, y_pred_val, idx in zip(y_valid, y_pred_valid, test_outer_ix):
                 y_global[idx].append(y_val)
                 y_pred_global[idx].append(y_pred_val)
             if not BO_opt:
-                for idx,y_pred_val_test in enumerate(y_pred_test):
+                for idx, y_pred_val_test in enumerate(y_pred_test):
                     y_pred_global_test[idx].append(y_pred_val_test)
-                if 'X_external_scaled' in Xy_data:
-                    for idx,y_pred_val_external in enumerate(y_pred_external):
+                if "X_external_scaled" in Xy_data:
+                    for idx, y_pred_val_external in enumerate(y_pred_external):
                         y_pred_global_external[idx].append(y_pred_val_external)
 
-        elif kfold_cv_type == 'sorted':
+        elif kfold_cv_type == "sorted":
             y_global.append(y_valid)
-            y_pred_global.append(y_pred_valid) 
+            y_pred_global.append(y_pred_valid)
 
-    return y_global,y_pred_global,y_pred_global_test,y_pred_global_external
+    return y_global, y_pred_global, y_pred_global_test, y_pred_global_external
 
 
 def sort_n_load(Xy_data):
-    '''
+    """
     Sort Xy data values to enhance reproducibility in cases where same databases are loaded
     with different row order, ensuring stable sorting across OS with kind='stable'.
-    '''
-    
-    X_train_scaled = np.array(Xy_data['X_train_scaled'])
-    y_train = np.array(Xy_data['y_train'])
+    """
 
-    sorted_indices = np.argsort(y_train, kind='stable')
+    X_train_scaled = np.array(Xy_data["X_train_scaled"])
+    y_train = np.array(Xy_data["y_train"])
+
+    sorted_indices = np.argsort(y_train, kind="stable")
     sorted_X_train_scaled = X_train_scaled[sorted_indices]
     sorted_y_train = y_train[sorted_indices]
 
     return sorted_X_train_scaled, sorted_y_train
 
 
-def sorted_kfold_cv(loaded_model,model_data,Xy_data,error_labels):
-    '''
+def sorted_kfold_cv(loaded_model, model_data, Xy_data, error_labels):
+    """
     Performs a sorted k-fold cross-validation on the Xy dataset. Returns the average of the two results
-    '''
+    """
 
     # perform sorted 5-fold CV
-    Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv'] = [],[]
-    Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv'],_,_ = kfold_cv(Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv'],
-                                                None,
-                                                None,
-                                                model_data,loaded_model,Xy_data,None,BO_opt=True,shuffle=False,kfold_cv_type='sorted')
-    error1 = error_labels[model_data['type']][0]
-    error2 = error_labels[model_data['type']][1]
-    error3 = error_labels[model_data['type']][2]
-    if model_data['type'].lower() == 'reg':
-        Xy_data[f'{error1}_train_sorted_CV'], Xy_data[f'{error2}_train_sorted_CV'], Xy_data[f'{error3}_train_sorted_CV'] = [],[],[]
-        for y_cv,y_pred_cd in zip(Xy_data['y_sorted_cv'],Xy_data['y_pred_sorted_cv']):
-            r2_train_sorted_CV, mae_train_sorted_CV, rmse_train_sorted_CV = get_prediction_results(model_data,y_cv,y_pred_cd)
-            Xy_data[f'{error1}_train_sorted_CV'].append(r2_train_sorted_CV)
-            Xy_data[f'{error2}_train_sorted_CV'].append(mae_train_sorted_CV)
-            Xy_data[f'{error3}_train_sorted_CV'].append(rmse_train_sorted_CV)
+    Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"] = [], []
+    Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"], _, _ = kfold_cv(
+        Xy_data["y_sorted_cv"],
+        Xy_data["y_pred_sorted_cv"],
+        None,
+        None,
+        model_data,
+        loaded_model,
+        Xy_data,
+        None,
+        BO_opt=True,
+        shuffle=False,
+        kfold_cv_type="sorted",
+    )
+    error1 = error_labels[model_data["type"]][0]
+    error2 = error_labels[model_data["type"]][1]
+    error3 = error_labels[model_data["type"]][2]
+    if model_data["type"].lower() == "reg":
+        (
+            Xy_data[f"{error1}_train_sorted_CV"],
+            Xy_data[f"{error2}_train_sorted_CV"],
+            Xy_data[f"{error3}_train_sorted_CV"],
+        ) = [], [], []
+        for y_cv, y_pred_cd in zip(Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"]):
+            r2_train_sorted_CV, mae_train_sorted_CV, rmse_train_sorted_CV = (
+                get_prediction_results(model_data, y_cv, y_pred_cd)
+            )
+            Xy_data[f"{error1}_train_sorted_CV"].append(r2_train_sorted_CV)
+            Xy_data[f"{error2}_train_sorted_CV"].append(mae_train_sorted_CV)
+            Xy_data[f"{error3}_train_sorted_CV"].append(rmse_train_sorted_CV)
 
         # take the worst performing predictions from the top and bottom folds
-        if model_data["error_type"].lower() in ['mae','rmse']:
-            Xy_data[f'{model_data["error_type"]}_up_bottom'] = max(Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][0], Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][-1])
-            Xy_data['r2_up_bottom'] = min(Xy_data['r2_train_sorted_CV'][0], Xy_data['r2_train_sorted_CV'][-1])
+        if model_data["error_type"].lower() in ["mae", "rmse"]:
+            Xy_data[f"{model_data['error_type']}_up_bottom"] = max(
+                Xy_data[f"{model_data['error_type']}_train_sorted_CV"][0],
+                Xy_data[f"{model_data['error_type']}_train_sorted_CV"][-1],
+            )
+            Xy_data["r2_up_bottom"] = min(
+                Xy_data["r2_train_sorted_CV"][0], Xy_data["r2_train_sorted_CV"][-1]
+            )
         else:  # r2
-            Xy_data[f'{model_data["error_type"]}_up_bottom'] = min(Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][0], Xy_data[f'{model_data["error_type"]}_train_sorted_CV'][-1])
+            Xy_data[f"{model_data['error_type']}_up_bottom"] = min(
+                Xy_data[f"{model_data['error_type']}_train_sorted_CV"][0],
+                Xy_data[f"{model_data['error_type']}_train_sorted_CV"][-1],
+            )
 
     else:  # classification
-        Xy_data[f'{error1}_train_sorted_CV'], Xy_data[f'{error2}_train_sorted_CV'], Xy_data[f'{error3}_train_sorted_CV'] = [],[],[]
-        for y_cv, y_pred_cd in zip(Xy_data['y_sorted_cv'], Xy_data['y_pred_sorted_cv']):
-            acc_fold, f1_fold, mcc_fold = get_prediction_results(model_data, y_cv, y_pred_cd)
-            Xy_data[f'{error1}_train_sorted_CV'].append(acc_fold)
-            Xy_data[f'{error2}_train_sorted_CV'].append(f1_fold)
-            Xy_data[f'{error3}_train_sorted_CV'].append(mcc_fold)
+        (
+            Xy_data[f"{error1}_train_sorted_CV"],
+            Xy_data[f"{error2}_train_sorted_CV"],
+            Xy_data[f"{error3}_train_sorted_CV"],
+        ) = [], [], []
+        for y_cv, y_pred_cd in zip(Xy_data["y_sorted_cv"], Xy_data["y_pred_sorted_cv"]):
+            acc_fold, f1_fold, mcc_fold = get_prediction_results(
+                model_data, y_cv, y_pred_cd
+            )
+            Xy_data[f"{error1}_train_sorted_CV"].append(acc_fold)
+            Xy_data[f"{error2}_train_sorted_CV"].append(f1_fold)
+            Xy_data[f"{error3}_train_sorted_CV"].append(mcc_fold)
 
         # Measure fold stability by difference between best and worst fold
-        Xy_data[f'{model_data["error_type"]}_up_bottom'] = np.mean(np.abs(Xy_data[f'{model_data["error_type"]}_train_sorted_CV']))
+        Xy_data[f"{model_data['error_type']}_up_bottom"] = np.mean(
+            np.abs(Xy_data[f"{model_data['error_type']}_train_sorted_CV"])
+        )
 
     return Xy_data
 
 
-def k_means(self,X_scaled,csv_y,size,seed,idx_list):
-    '''
-    
-    Uses k-means clustering to select the test points to be as diverse as possible, 
+def k_means(self, X_scaled, csv_y, size, seed, idx_list):
+    """
+
+    Uses k-means clustering to select the test points to be as diverse as possible,
     but it returns the test pointsReturns the data points that will be used as training set based on the k-means clustering
-    
-    '''
-    
+
+    """
+
     # number of clusters in the training set from the k-means clustering (based on the
     # training set size specified above)
     X_scaled_array = np.asarray(X_scaled)
@@ -2577,20 +3050,22 @@ def k_means(self,X_scaled,csv_y,size,seed,idx_list):
 
     # to avoid points from the validation set outside the training set, the 2 first training
     # points are automatically set as the 2 points with minimum/maximum response value
-    if self.args.type.lower() == 'reg':
+    if self.args.type.lower() == "reg":
         test_points = []
-        training_idx = [csv_y.idxmin(),csv_y.idxmax()]
+        training_idx = [csv_y.idxmin(), csv_y.idxmax()]
         number_of_clusters -= 2
     else:
         test_points = []
         training_idx = []
-    
+
     # runs the k-means algorithm and keeps the closest point to the center of each cluster
-    kmeans = KMeans(n_clusters=number_of_clusters,random_state=seed)
+    kmeans = KMeans(n_clusters=number_of_clusters, random_state=seed)
     try:
         kmeans.fit(X_scaled_array)
     except ValueError:
-        self.args.log.write("\nx  The K-means clustering process failed! This might be due to having NaN or strings as descriptors (curate the data first with CURATE) or having too few datapoints!")
+        self.args.log.write(
+            "\nx  The K-means clustering process failed! This might be due to having NaN or strings as descriptors (curate the data first with CURATE) or having too few datapoints!"
+        )
         sys.exit()
     centers = kmeans.cluster_centers_
     for i in range(number_of_clusters):
@@ -2599,14 +3074,18 @@ def k_means(self,X_scaled,csv_y,size,seed,idx_list):
             if k not in training_idx:
                 # calculate the Euclidean distance in n-dimensions
                 points_sum = 0
-                for l in range(len(X_scaled_array[0])):
-                    points_sum += (X_scaled_array[:, l][k]-centers[:, l][i])**2
+                for idx_l in range(len(X_scaled_array[0])):
+                    points_sum += (
+                        X_scaled_array[:, idx_l][k] - centers[:, idx_l][i]
+                    ) ** 2
                 if np.sqrt(points_sum) < results_cluster:
                     results_cluster = np.sqrt(points_sum)
                     training_point = k
         training_idx.append(training_point)
 
-    test_idx = [idx for idx in range(len(X_scaled_array[:, 0])) if idx not in training_idx]
+    test_idx = [
+        idx for idx in range(len(X_scaled_array[:, 0])) if idx not in training_idx
+    ]
     test_points = [idx_list[i] for i in test_idx]
     test_points.sort()
 
@@ -2614,62 +3093,86 @@ def k_means(self,X_scaled,csv_y,size,seed,idx_list):
 
 
 def PFI_filter(self, Xy_data, model_data):
-    '''
+    """
     Performs the PFI calculation and returns a list of the descriptors that are not important
-    '''
+    """
 
     # load and fit model
-    loaded_model = load_model(self,model_data['model'],**model_data['params'])
-    loaded_model.fit(Xy_data['X_train_scaled'], Xy_data['y_train'])
+    loaded_model = load_model(self, model_data["model"], **model_data["params"])
+    loaded_model.fit(Xy_data["X_train_scaled"], Xy_data["y_train"])
 
     # select scoring function for PFI analysis based on the error type
-    scoring, score_model, _ = scoring_n_score(self,model_data,Xy_data,loaded_model)
-    
-    perm_importance = permutation_importance(loaded_model, Xy_data['X_train_scaled'], Xy_data['y_train'], scoring=scoring, n_repeats=self.args.pfi_epochs, random_state=self.args.seed, n_jobs=1)
+    scoring, score_model, _ = scoring_n_score(self, model_data, Xy_data, loaded_model)
+
+    perm_importance = permutation_importance(
+        loaded_model,
+        Xy_data["X_train_scaled"],
+        Xy_data["y_train"],
+        scoring=scoring,
+        n_repeats=self.args.pfi_epochs,
+        random_state=self.args.seed,
+        n_jobs=1,
+    )
 
     # transforms the values into a list and sort the PFI values with the descriptor names
-    descp_cols_pfi, PFI_values, PFI_sd = [],[],[]
-    for i,desc in enumerate(Xy_data['X_train_scaled'].columns):
-        descp_cols_pfi.append(desc) # includes lists of descriptors not column names!
+    descp_cols_pfi, PFI_values, PFI_sd = [], [], []
+    for i, desc in enumerate(Xy_data["X_train_scaled"].columns):
+        descp_cols_pfi.append(desc)  # includes lists of descriptors not column names!
         PFI_values.append(perm_importance.importances_mean[i])
         PFI_sd.append(perm_importance.importances_std[i])
-  
-    PFI_values, PFI_sd, descp_cols_pfi = (list(t) for t in zip(*sorted(zip(PFI_values, PFI_sd, descp_cols_pfi), reverse=True)))
+
+    PFI_values, PFI_sd, descp_cols_pfi = (
+        list(t)
+        for t in zip(*sorted(zip(PFI_values, PFI_sd, descp_cols_pfi), reverse=True))
+    )
 
     # PFI filter
     PFI_discard_cols = []
     # the threshold is based either on the RMSE of the model or the importance of the most important descriptor
-    PFI_thres = max([abs(self.args.pfi_threshold*score_model),abs(self.args.pfi_threshold*PFI_values[0])])
+    PFI_thres = max(
+        [
+            abs(self.args.pfi_threshold * score_model),
+            abs(self.args.pfi_threshold * PFI_values[0]),
+        ]
+    )
     for i in range(len(PFI_values)):
         if PFI_values[i] < PFI_thres:
             PFI_discard_cols.append(descp_cols_pfi[i])
 
-    return PFI_discard_cols,descp_cols_pfi
+    return PFI_discard_cols, descp_cols_pfi
 
 
-def scoring_n_score(self,model_data,Xy_data,loaded_model):
-    '''
+def scoring_n_score(self, model_data, Xy_data, loaded_model):
+    """
     Get scoring system and score of the original model with CV
-    '''
+    """
 
-    error_type = model_data['error_type'].lower()
-    scoring = get_scoring_key(model_data['type'],error_type)
-    cv_model = RepeatedKFold(n_splits=self.args.kfold, n_repeats=self.args.repeat_kfolds, random_state=self.args.seed)
-    score_model = cross_val_score(estimator = loaded_model, X=Xy_data['X_train_scaled'], y=Xy_data['y_train'],scoring=scoring, cv =cv_model)
+    error_type = model_data["error_type"].lower()
+    scoring = get_scoring_key(model_data["type"], error_type)
+    cv_model = RepeatedKFold(
+        n_splits=self.args.kfold,
+        n_repeats=self.args.repeat_kfolds,
+        random_state=self.args.seed,
+    )
+    score_model = cross_val_score(
+        estimator=loaded_model,
+        X=Xy_data["X_train_scaled"],
+        y=Xy_data["y_train"],
+        scoring=scoring,
+        cv=cv_model,
+    )
     score_model = score_model.mean()
 
-    if model_data['error_type'].lower() in ['rmse','mae']:
+    if model_data["error_type"].lower() in ["rmse", "mae"]:
         score_model = -score_model
 
     return scoring, score_model, error_type
 
 
-def create_heatmap(self,csv_df,suffix,path_raw):
+def create_heatmap(self, csv_df, suffix, path_raw):
     """
     Graph the heatmap
     """
-    import seaborn as sb
-
     with _mpl_plot_context():
         csv_df = csv_df.sort_index(ascending=False)
         sb.set(font_scale=1.2, style="ticks")
@@ -2690,265 +3193,377 @@ def create_heatmap(self,csv_df,suffix,path_raw):
         ax.set_xlabel("ML Model", fontsize=fontsize)
         ax.set_ylabel("", fontsize=fontsize)
         ax.tick_params(axis="x", which="major", labelsize=fontsize)
-        ax.tick_params(
-            axis="y", which="both", left=False, right=False, labelleft=False
-        )
+        ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)
         title_fig = f"Heatmap ML models {suffix}"
         plt.title(title_fig, y=1.04, fontsize=fontsize, fontweight="bold")
         sb.despine(top=False, right=False)
         name_fig = "_".join(title_fig.split())
-        plt.savefig(
-            f"{path_raw.joinpath(name_fig)}.png", dpi=300, bbox_inches="tight"
-        )
+        plt.savefig(f"{path_raw.joinpath(name_fig)}.png", dpi=300, bbox_inches="tight")
+        plt.close()
 
     path_reduced = "/".join(f"{path_raw}".replace("\\", "/").split("/")[-2:])
     self.args.log.write(f"\no  {name_fig} succesfully created in {path_reduced}")
 
 
-def graph_reg(self,Xy_data,params_dict,set_types,path_n_suffix,graph_style,csv_test=False,print_fun=True,sd_graph=False):
-    '''
+def graph_reg(
+    self,
+    Xy_data,
+    params_dict,
+    set_types,
+    path_n_suffix,
+    graph_style,
+    csv_test=False,
+    print_fun=True,
+    sd_graph=False,
+):
+    """
     Plot regression graphs of predicted vs actual values for train, validation and test sets
-    '''
-    import seaborn as sb
-
+    """
     sb.set(style="ticks")
 
-    _, ax = plt.subplots(figsize=(7.45,6))
+    fig, ax = plt.subplots(figsize=(7.45, 6))
 
     # Set tick sizes
     plt.xticks(fontsize=14)
     plt.yticks(fontsize=14)
-    
+
     error_bars = "test"
 
-    title_graph = graph_title(self,csv_test,sd_graph,error_bars)
+    title_graph = graph_title(self, csv_test, sd_graph, error_bars)
 
     if print_fun:
-        plt.text(0.5, 1.08, f'{title_graph} of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
-            fontsize=14, fontweight='bold', transform = ax.transAxes)
+        plt.text(
+            0.5,
+            1.08,
+            f"{title_graph} of {os.path.basename(path_n_suffix)}",
+            horizontalalignment="center",
+            fontsize=14,
+            fontweight="bold",
+            transform=ax.transAxes,
+        )
 
     # Plot the data
     if not sd_graph:
-        _ = ax.scatter(Xy_data["y_train"], Xy_data["y_pred_train"],
-                    c = graph_style['color_train'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)   
+        _ = ax.scatter(
+            Xy_data["y_train"],
+            Xy_data["y_pred_train"],
+            c=graph_style["color_train"],
+            s=graph_style["dot_size"],
+            edgecolor="k",
+            linewidths=0.8,
+            alpha=graph_style["alpha"],
+            zorder=2,
+        )
 
     if not csv_test:
-        _ = ax.scatter(Xy_data["y_test"], Xy_data["y_pred_test"],
-                    c = graph_style['color_test'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=3)
+        _ = ax.scatter(
+            Xy_data["y_test"],
+            Xy_data["y_pred_test"],
+            c=graph_style["color_test"],
+            s=graph_style["dot_size"],
+            edgecolor="k",
+            linewidths=0.8,
+            alpha=graph_style["alpha"],
+            zorder=3,
+        )
 
     else:
         error_bars = "external"
-        _ = ax.scatter(Xy_data["y_external"], Xy_data["y_pred_external"],
-                        c = graph_style['color_test'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
+        _ = ax.scatter(
+            Xy_data["y_external"],
+            Xy_data["y_pred_external"],
+            c=graph_style["color_test"],
+            s=graph_style["dot_size"],
+            edgecolor="k",
+            linewidths=0.8,
+            alpha=graph_style["alpha"],
+            zorder=2,
+        )
 
-    # average CV ± SD graphs 
+    # average CV ± SD graphs
     if sd_graph:
-        if not csv_test:   
+        if not csv_test:
             # Plot the data with the error bars
-            _ = ax.errorbar(Xy_data[f"y_{error_bars}"], Xy_data[f"y_pred_{error_bars}"], yerr=Xy_data[f"y_pred_{error_bars}_sd"], fmt='none', ecolor="gray", capsize=3, zorder=1)
+            _ = ax.errorbar(
+                Xy_data[f"y_{error_bars}"],
+                Xy_data[f"y_pred_{error_bars}"],
+                yerr=Xy_data[f"y_pred_{error_bars}_sd"],
+                fmt="none",
+                ecolor="gray",
+                capsize=3,
+                zorder=1,
+            )
             # Adjust labels from legend
-            set_types=[error_bars,f'± SD']
+            set_types = [error_bars, "± SD"]
 
         else:
-            _ = ax.errorbar(Xy_data[f"y_{error_bars}"], Xy_data[f"y_pred_{error_bars}"], yerr=Xy_data[f"y_pred_{error_bars}_sd"], fmt='none', ecolor="gray", capsize=3, zorder=1)
-            set_types=['External test',f'± SD']
+            _ = ax.errorbar(
+                Xy_data[f"y_{error_bars}"],
+                Xy_data[f"y_pred_{error_bars}"],
+                yerr=Xy_data[f"y_pred_{error_bars}_sd"],
+                fmt="none",
+                ecolor="gray",
+                capsize=3,
+                zorder=1,
+            )
+            set_types = ["External test", "± SD"]
 
     # legend and regression line with 95% CI considering all possible lines (not CI of the points)
-    if 'CV' in set_types[0]: # CV in VERIFY
+    if "CV" in set_types[0]:  # CV in VERIFY
         legend_coords = (0.70, 0.15)
-    elif len(set_types) == 2: # external test or sets with ± SD
-        if 'External test' in set_types:
+    elif len(set_types) == 2:  # external test or sets with ± SD
+        if "External test" in set_types:
             legend_coords = (0.66, 0.15)
         else:
             legend_coords = (0.735, 0.15)
-    ax.legend(loc='upper center', bbox_to_anchor=legend_coords, 
-            handletextpad=0,
-            fancybox=True, shadow=True, ncol=5, labels=set_types, fontsize=14)
+    ax.legend(
+        loc="upper center",
+        bbox_to_anchor=legend_coords,
+        handletextpad=0,
+        fancybox=True,
+        shadow=True,
+        ncol=5,
+        labels=set_types,
+        fontsize=14,
+    )
 
     Xy_data_df = pd.DataFrame()
     if not sd_graph:
-        line_suff = 'train'
+        line_suff = "train"
     elif not csv_test:
-        line_suff = 'test'
+        line_suff = "test"
     else:
-        line_suff = 'external'
+        line_suff = "external"
 
     Xy_data_df[f"y_{line_suff}"] = Xy_data[f"y_{line_suff}"]
     Xy_data_df[f"y_pred_{line_suff}"] = Xy_data[f"y_pred_{line_suff}"]
     if len(Xy_data_df[f"y_pred_{line_suff}"]) >= 10:
-        _ = sb.regplot(x=f"y_{line_suff}", y=f"y_pred_{line_suff}", data=Xy_data_df, scatter=False, color=".1", 
-                        truncate = True, ax=ax, seed=params_dict['seed'])
+        _ = sb.regplot(
+            x=f"y_{line_suff}",
+            y=f"y_pred_{line_suff}",
+            data=Xy_data_df,
+            scatter=False,
+            color=".1",
+            truncate=True,
+            ax=ax,
+            seed=params_dict["seed"],
+        )
 
     # Title and labels of the axis
-    plt.ylabel(f'Predicted {params_dict["y"]}', fontsize=14)
-    plt.xlabel(f'{params_dict["y"]}', fontsize=14)
+    plt.ylabel(f"Predicted {params_dict['y']}", fontsize=14)
+    plt.xlabel(f"{params_dict['y']}", fontsize=14)
 
     # set axis limits and graph PATH
-    min_value_graph,max_value_graph,reg_plot_file,path_reduced = graph_vars(Xy_data,set_types,csv_test,path_n_suffix,sd_graph)
+    min_value_graph, max_value_graph, reg_plot_file, path_reduced = graph_vars(
+        Xy_data, set_types, csv_test, path_n_suffix, sd_graph
+    )
 
     # track the range of predictions (used in ROBERT score)
-    pred_min = min(min(Xy_data["y_train"]),min(Xy_data["y_test"]))
-    pred_max = max(max(Xy_data["y_train"]),max(Xy_data["y_test"]))
-    pred_range = np.abs(pred_max-pred_min)
-    Xy_data['pred_min'] = pred_min
-    Xy_data['pred_max'] = pred_max
-    Xy_data['pred_range'] = pred_range
+    pred_min = min(min(Xy_data["y_train"]), min(Xy_data["y_test"]))
+    pred_max = max(max(Xy_data["y_train"]), max(Xy_data["y_test"]))
+    pred_range = np.abs(pred_max - pred_min)
+    Xy_data["pred_min"] = pred_min
+    Xy_data["pred_max"] = pred_max
+    Xy_data["pred_range"] = pred_range
 
     # Add gridlines
-    ax.grid(linestyle='--', linewidth=1)
+    ax.grid(linestyle="--", linewidth=1)
 
     # set axis limits
     plt.xlim(min_value_graph, max_value_graph)
     plt.ylim(min_value_graph, max_value_graph)
 
     # save graph
-    plt.savefig(f'{reg_plot_file}', dpi=300, bbox_inches='tight')
+    plt.savefig(f"{reg_plot_file}", dpi=300, bbox_inches="tight")
+    plt.close(fig)
     if print_fun:
         self.args.log.write(f"      -  Graph in: {path_reduced}")
 
 
-def graph_title(self,csv_test,sd_graph,error_bars):
-    '''
+def graph_title(self, csv_test, sd_graph, error_bars):
+    """
     Retrieves the corresponding graph title.
-    '''
+    """
 
     # set title for regular graphs
     if not sd_graph:
         if not csv_test:
             # regular graphs
-            title_graph = f'Predictions CV and test set'
+            title_graph = "Predictions CV and test set"
         else:
-            title_graph = f'{os.path.basename(self.args.csv_test)}'
+            title_graph = f"{os.path.basename(self.args.csv_test)}"
             if len(title_graph) > 30:
-                title_graph = f'{title_graph[:27]}...'
+                title_graph = f"{title_graph[:27]}..."
 
     # set title for averaged CV ± SD graphs
     else:
         if not csv_test:
             sets_title = error_bars
         else:
-            sets_title = 'external test'
+            sets_title = "external test"
 
-        title_graph = f'{sets_title} set ± SD (CV)'
+        title_graph = f"{sets_title} set ± SD (CV)"
 
     return title_graph
 
 
-def graph_vars(Xy_data,set_types,csv_test,path_n_suffix,sd_graph):
-    '''
+def graph_vars(Xy_data, set_types, csv_test, path_n_suffix, sd_graph):
+    """
     Set axis limits for regression plots and PATH to save the graphs
-    '''
+    """
 
     # x and y axis limits for graphs with multiple sets
     if not csv_test:
-        size_space = 0.1*abs(min(Xy_data["y_train"])-max(Xy_data["y_train"]))
-        min_value_graph = min(min(Xy_data["y_train"]),min(Xy_data["y_pred_train"]),min(Xy_data["y_test"]),min(Xy_data["y_pred_test"]))
-        if 'test' in set_types:
-            min_value_graph = min(min_value_graph,min(Xy_data["y_test"]),min(Xy_data["y_pred_test"]))
-        min_value_graph = min_value_graph-size_space
-            
-        max_value_graph = max(max(Xy_data["y_train"]),max(Xy_data["y_pred_train"]),max(Xy_data["y_test"]),max(Xy_data["y_pred_test"]))
-        if 'test' in set_types:
-            max_value_graph = max(max_value_graph,max(Xy_data["y_test"]),max(Xy_data["y_pred_test"]))
-        max_value_graph = max_value_graph+size_space
-
-    else: # limits for graphs with only one set
-        set_type = 'external'
-        size_space = 0.1*abs(min(Xy_data[f'y_{set_type}'])-max(Xy_data[f'y_{set_type}']))
-        min_value_graph = min(min(Xy_data[f'y_{set_type}']),min(Xy_data[f'y_pred_{set_type}']))
-        min_value_graph = min_value_graph-size_space
-        max_value_graph = max(max(Xy_data[f'y_{set_type}']),max(Xy_data[f'y_pred_{set_type}']))
-        max_value_graph = max_value_graph+size_space
+        size_space = 0.1 * abs(min(Xy_data["y_train"]) - max(Xy_data["y_train"]))
+        min_value_graph = min(
+            min(Xy_data["y_train"]),
+            min(Xy_data["y_pred_train"]),
+            min(Xy_data["y_test"]),
+            min(Xy_data["y_pred_test"]),
+        )
+        if "test" in set_types:
+            min_value_graph = min(
+                min_value_graph, min(Xy_data["y_test"]), min(Xy_data["y_pred_test"])
+            )
+        min_value_graph = min_value_graph - size_space
+
+        max_value_graph = max(
+            max(Xy_data["y_train"]),
+            max(Xy_data["y_pred_train"]),
+            max(Xy_data["y_test"]),
+            max(Xy_data["y_pred_test"]),
+        )
+        if "test" in set_types:
+            max_value_graph = max(
+                max_value_graph, max(Xy_data["y_test"]), max(Xy_data["y_pred_test"])
+            )
+        max_value_graph = max_value_graph + size_space
+
+    else:  # limits for graphs with only one set
+        set_type = "external"
+        size_space = 0.1 * abs(
+            min(Xy_data[f"y_{set_type}"]) - max(Xy_data[f"y_{set_type}"])
+        )
+        min_value_graph = min(
+            min(Xy_data[f"y_{set_type}"]), min(Xy_data[f"y_pred_{set_type}"])
+        )
+        min_value_graph = min_value_graph - size_space
+        max_value_graph = max(
+            max(Xy_data[f"y_{set_type}"]), max(Xy_data[f"y_pred_{set_type}"])
+        )
+        max_value_graph = max_value_graph + size_space
 
     # PATH of the graph
     if not csv_test:
         if not sd_graph:
-            reg_plot_file = f'{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}.png'
+            reg_plot_file = f"{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}.png"
         else:
-            reg_plot_file = f'{os.path.dirname(path_n_suffix)}/CV_variability_{os.path.basename(path_n_suffix)}.png'
-        path_reduced = '/'.join(f'{reg_plot_file}'.replace('\\','/').split('/')[-2:])
+            reg_plot_file = f"{os.path.dirname(path_n_suffix)}/CV_variability_{os.path.basename(path_n_suffix)}.png"
+        path_reduced = "/".join(f"{reg_plot_file}".replace("\\", "/").split("/")[-2:])
 
     else:
-        folder_graph = f'{os.path.dirname(path_n_suffix)}/csv_test'
+        folder_graph = f"{os.path.dirname(path_n_suffix)}/csv_test"
         if not sd_graph:
-            reg_plot_file = f'{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png'
+            reg_plot_file = f"{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png"
         else:
-            reg_plot_file = f'{folder_graph}/CV_variability_{os.path.basename(path_n_suffix)}_{set_type}.png'
-        path_reduced = '/'.join(f'{reg_plot_file}'.replace('\\','/').split('/')[-3:])
-    
-    return min_value_graph,max_value_graph,reg_plot_file,path_reduced
+            reg_plot_file = f"{folder_graph}/CV_variability_{os.path.basename(path_n_suffix)}_{set_type}.png"
+        path_reduced = "/".join(f"{reg_plot_file}".replace("\\", "/").split("/")[-3:])
+
+    return min_value_graph, max_value_graph, reg_plot_file, path_reduced
 
 
-def graph_clas(self,Xy_data,params_dict,set_type,path_n_suffix,csv_test=False,print_fun=True):
-    '''
+def graph_clas(
+    self, Xy_data, params_dict, set_type, path_n_suffix, csv_test=False, print_fun=True
+):
+    """
     Plot a confusion matrix with the prediction vs actual values
-    '''
+    """
 
     # Check if we need to use original class labels for display
     display_labels = None
-    if 'class_0_label' in params_dict and 'class_1_label' in params_dict:
-        display_labels = [params_dict['class_0_label'], params_dict['class_1_label']]
+    if "class_0_label" in params_dict and "class_1_label" in params_dict:
+        display_labels = [params_dict["class_0_label"], params_dict["class_1_label"]]
 
     # get confusion matrix
-    if 'CV' in set_type: # CV graphs
-        y_train_binary = np.round(Xy_data[f'y_train']).astype(int)
-        y_pred_train_binary = np.round(Xy_data[f'y_pred_train']).astype(int)
-        matrix = ConfusionMatrixDisplay.from_predictions(y_train_binary, y_pred_train_binary, 
-                                                          normalize=None, cmap='Blues', 
-                                                          display_labels=display_labels)
-    else: # other graphs
-        y_binary = np.round(Xy_data[f'y_{set_type}']).astype(int)
-        y_pred_binary = np.round(Xy_data[f'y_pred_{set_type}']).astype(int)
-        matrix = ConfusionMatrixDisplay.from_predictions(y_binary, y_pred_binary, 
-                                                          normalize=None, cmap='Blues',
-                                                          display_labels=display_labels) 
+    if "CV" in set_type:  # CV graphs
+        y_train_binary = np.round(Xy_data["y_train"]).astype(int)
+        y_pred_train_binary = np.round(Xy_data["y_pred_train"]).astype(int)
+        matrix = ConfusionMatrixDisplay.from_predictions(
+            y_train_binary,
+            y_pred_train_binary,
+            normalize=None,
+            cmap="Blues",
+            display_labels=display_labels,
+        )
+    else:  # other graphs
+        y_binary = np.round(Xy_data[f"y_{set_type}"]).astype(int)
+        y_pred_binary = np.round(Xy_data[f"y_pred_{set_type}"]).astype(int)
+        matrix = ConfusionMatrixDisplay.from_predictions(
+            y_binary,
+            y_pred_binary,
+            normalize=None,
+            cmap="Blues",
+            display_labels=display_labels,
+        )
 
     # transfer it to the same format and size used in reg graphs
-    _, ax = plt.subplots(figsize=(7.45,6))
-    matrix.plot(ax=ax, cmap='Blues')
+    _, ax = plt.subplots(figsize=(7.45, 6))
+    matrix.plot(ax=ax, cmap="Blues")
 
     if print_fun:
-        if 'CV' not in set_type:
-            title_set = f'{set_type} set'
+        if "CV" not in set_type:
+            title_set = f"{set_type} set"
         else:
             title_set = set_type
-        plt.text(0.5, 1.08, f'{title_set} of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
-            fontsize=14, fontweight='bold', transform = ax.transAxes)
+        plt.text(
+            0.5,
+            1.08,
+            f"{title_set} of {os.path.basename(path_n_suffix)}",
+            horizontalalignment="center",
+            fontsize=14,
+            fontweight="bold",
+            transform=ax.transAxes,
+        )
 
-    plt.xlabel(f'Predicted {params_dict["y"]}', fontsize=14)
-    plt.ylabel(f'{params_dict["y"]}', fontsize=14)
+    plt.xlabel(f"Predicted {params_dict['y']}", fontsize=14)
+    plt.ylabel(f"{params_dict['y']}", fontsize=14)
     plt.xticks(fontsize=14)
     plt.yticks(fontsize=14)
 
     # save fig
-    if 'CV' in set_type: # CV graphs
-        clas_plot_file = f'{os.path.dirname(path_n_suffix)}/CV_train_valid_predict_{os.path.basename(path_n_suffix)}.png'
-        path_reduced = '/'.join(f'{clas_plot_file}'.replace('\\','/').split('/')[-2:])
+    if "CV" in set_type:  # CV graphs
+        clas_plot_file = f"{os.path.dirname(path_n_suffix)}/CV_train_valid_predict_{os.path.basename(path_n_suffix)}.png"
+        path_reduced = "/".join(f"{clas_plot_file}".replace("\\", "/").split("/")[-2:])
 
     elif not csv_test:
-        clas_plot_file = f'{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png'
-        path_reduced = '/'.join(f'{clas_plot_file}'.replace('\\','/').split('/')[-2:])
+        clas_plot_file = f"{os.path.dirname(path_n_suffix)}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png"
+        path_reduced = "/".join(f"{clas_plot_file}".replace("\\", "/").split("/")[-2:])
 
     else:
-        folder_graph = f'{os.path.dirname(path_n_suffix)}/csv_test'
-        clas_plot_file = f'{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png'
-        path_reduced = '/'.join(f'{clas_plot_file}'.replace('\\','/').split('/')[-3:])
+        folder_graph = f"{os.path.dirname(path_n_suffix)}/csv_test"
+        clas_plot_file = (
+            f"{folder_graph}/Results_{os.path.basename(path_n_suffix)}_{set_type}.png"
+        )
+        path_reduced = "/".join(f"{clas_plot_file}".replace("\\", "/").split("/")[-3:])
 
-    plt.savefig(f'{clas_plot_file}', dpi=300, bbox_inches='tight')
+    plt.savefig(f"{clas_plot_file}", dpi=300, bbox_inches="tight")
+    plt.close()
 
     if print_fun:
         self.args.log.write(f"      -  Graph in: {path_reduced}")
 
 
 def shap_analysis(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
-    '''
+    """
     Plots and prints the results of the SHAP analysis
-    '''
+    """
     import shap
 
     _, _ = plt.subplots(figsize=(7.45, 6))
 
-    shap_plot_file = f'{os.path.dirname(path_n_suffix)}/SHAP_{os.path.basename(path_n_suffix)}.png'
+    shap_plot_file = (
+        f"{os.path.dirname(path_n_suffix)}/SHAP_{os.path.basename(path_n_suffix)}.png"
+    )
 
     if fitted_model is None:
         loaded_model = load_model(self, model_data["model"], **model_data["params"])
@@ -2957,59 +3572,85 @@ def shap_analysis(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
         loaded_model = fitted_model
 
     # run the SHAP analysis and save the plot
-    explainer = shap.Explainer(loaded_model.predict, Xy_data['X_train_scaled'], seed=model_data['seed'])
+    explainer = shap.Explainer(
+        loaded_model.predict, Xy_data["X_train_scaled"], seed=model_data["seed"]
+    )
     try:
-        shap_values = explainer(Xy_data['X_train_scaled'])
+        shap_values = explainer(Xy_data["X_train_scaled"])
     except ValueError:
-        shap_values = explainer(Xy_data['X_train_scaled'],max_evals=(2*len(Xy_data['X_train_scaled'].columns))+1)
+        shap_values = explainer(
+            Xy_data["X_train_scaled"],
+            max_evals=(2 * len(Xy_data["X_train_scaled"].columns)) + 1,
+        )
 
-    shap_show = [self.args.shap_show,len(Xy_data['X_train_scaled'].columns)]
-    aspect_shap = 25+((min(shap_show)-2)*5)
-    height_shap = 1.2+min(shap_show)/4
+    shap_show = [self.args.shap_show, len(Xy_data["X_train_scaled"].columns)]
+    aspect_shap = 25 + ((min(shap_show) - 2) * 5)
+    height_shap = 1.2 + min(shap_show) / 4
 
     # explainer = shap.TreeExplainer(loaded_model) # in case the standard version doesn't work
-    _ = shap.summary_plot(shap_values, Xy_data['X_train_scaled'], max_display=self.args.shap_show,show=False, plot_size=[7.45,height_shap])
+    _ = shap.summary_plot(
+        shap_values,
+        Xy_data["X_train_scaled"],
+        max_display=self.args.shap_show,
+        show=False,
+        plot_size=[7.45, height_shap],
+    )
 
     # set title
-    plt.title(f'SHAP analysis of {os.path.basename(path_n_suffix)}', fontsize = 14, fontweight="bold")
+    plt.title(
+        f"SHAP analysis of {os.path.basename(path_n_suffix)}",
+        fontsize=14,
+        fontweight="bold",
+    )
 
-    path_reduced = '/'.join(f'{shap_plot_file}'.replace('\\','/').split('/')[-2:])
+    path_reduced = "/".join(f"{shap_plot_file}".replace("\\", "/").split("/")[-2:])
     print_shap = f"\n   o  SHAP plot saved in {path_reduced}"
 
     # collect SHAP values and print
-    desc_list, min_list, max_list = [],[],[]
-    for i,desc in enumerate(Xy_data['X_train_scaled']):
+    desc_list, min_list, max_list = [], [], []
+    for i, desc in enumerate(Xy_data["X_train_scaled"]):
         desc_list.append(desc)
-        val_list_indiv= []
-        for _,val in enumerate(shap_values.values):
+        val_list_indiv = []
+        for _, val in enumerate(shap_values.values):
             val_list_indiv.append(val[i])
         min_indiv = min(val_list_indiv)
         max_indiv = max(val_list_indiv)
         min_list.append(min_indiv)
         max_list.append(max_indiv)
-    
+
     if max(max_list, key=abs) > max(min_list, key=abs):
-        max_list, min_list, desc_list = (list(t) for t in zip(*sorted(zip(max_list, min_list, desc_list), reverse=True)))
+        max_list, min_list, desc_list = (
+            list(t)
+            for t in zip(*sorted(zip(max_list, min_list, desc_list), reverse=True))
+        )
     else:
-        min_list, max_list, desc_list = (list(t) for t in zip(*sorted(zip(min_list, max_list, desc_list), reverse=False)))
+        min_list, max_list, desc_list = (
+            list(t)
+            for t in zip(*sorted(zip(min_list, max_list, desc_list), reverse=False))
+        )
 
-    for i,desc in enumerate(desc_list):
-        print_shap += f"\n      -  {desc} = min: {min_list[i]:.2}, max: {max_list[i]:.2}"
+    for i, desc in enumerate(desc_list):
+        print_shap += (
+            f"\n      -  {desc} = min: {min_list[i]:.2}, max: {max_list[i]:.2}"
+        )
 
     self.args.log.write(print_shap)
 
     # adjust width of the colorbar
     plt.gcf().axes[-1].set_aspect(aspect_shap)
     plt.gcf().axes[-1].set_box_aspect(aspect_shap)
-    
-    plt.savefig(f'{shap_plot_file}', dpi=300, bbox_inches='tight')
+
+    plt.savefig(f"{shap_plot_file}", dpi=300, bbox_inches="tight")
+    plt.close()
 
 
 def PFI_plot(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
-    '''
+    """
     Plots and prints the results of the PFI analysis
-    '''
-    pfi_plot_file = f'{os.path.dirname(path_n_suffix)}/PFI_{os.path.basename(path_n_suffix)}.png'
+    """
+    pfi_plot_file = (
+        f"{os.path.dirname(path_n_suffix)}/PFI_{os.path.basename(path_n_suffix)}.png"
+    )
 
     if fitted_model is None:
         loaded_model = load_model(self, model_data["model"], **model_data["params"])
@@ -3018,144 +3659,204 @@ def PFI_plot(self, Xy_data, model_data, path_n_suffix, fitted_model=None):
         loaded_model = fitted_model
 
     # select scoring function for PFI analysis based on the error type
-    scoring, _, error_type = scoring_n_score(self,model_data,Xy_data,loaded_model)
-
-    perm_importance = permutation_importance(loaded_model, Xy_data['X_train_scaled'], Xy_data['y_train'], scoring=scoring, n_repeats=self.args.pfi_epochs, random_state=model_data['seed'], n_jobs=1)
+    scoring, _, error_type = scoring_n_score(self, model_data, Xy_data, loaded_model)
+
+    perm_importance = permutation_importance(
+        loaded_model,
+        Xy_data["X_train_scaled"],
+        Xy_data["y_train"],
+        scoring=scoring,
+        n_repeats=self.args.pfi_epochs,
+        random_state=model_data["seed"],
+        n_jobs=1,
+    )
 
     # sort descriptors and results from PFI
-    desc_list, PFI_values, PFI_sd = [],[],[]
-    for i,desc in enumerate(Xy_data['X_train_scaled']):
+    desc_list, PFI_values, PFI_sd = [], [], []
+    for i, desc in enumerate(Xy_data["X_train_scaled"]):
         desc_list.append(desc)
         PFI_values.append(perm_importance.importances_mean[i])
         PFI_sd.append(perm_importance.importances_std[i])
 
     # sort from higher to lower values and keep only the top self.args.pfi_show descriptors
-    PFI_values, PFI_sd, desc_list = (list(t) for t in zip(*sorted(zip(PFI_values, PFI_sd, desc_list), reverse=True)))
-    PFI_values_plot = PFI_values[:self.args.pfi_show][::-1]
-    desc_list_plot = desc_list[:self.args.pfi_show][::-1]
+    PFI_values, PFI_sd, desc_list = (
+        list(t) for t in zip(*sorted(zip(PFI_values, PFI_sd, desc_list), reverse=True))
+    )
+    PFI_values_plot = PFI_values[: self.args.pfi_show][::-1]
+    desc_list_plot = desc_list[: self.args.pfi_show][::-1]
 
     # plot and print results
-    _, ax = plt.subplots(figsize=(7.45,6))
+    fig, ax = plt.subplots(figsize=(7.45, 6))
     y_ticks = np.arange(0, len(desc_list_plot))
     ax.barh(desc_list_plot, PFI_values_plot)
-    ax.set_yticks(y_ticks,labels=desc_list_plot,fontsize=14)
-    plt.text(0.5, 1.08, f'Permutation feature importances (PFIs) of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
-        fontsize=14, fontweight='bold', transform = ax.transAxes)
-    ax.set(ylabel=None, xlabel='PFI')
+    ax.set_yticks(y_ticks, labels=desc_list_plot, fontsize=14)
+    plt.text(
+        0.5,
+        1.08,
+        f"Permutation feature importances (PFIs) of {os.path.basename(path_n_suffix)}",
+        horizontalalignment="center",
+        fontsize=14,
+        fontweight="bold",
+        transform=ax.transAxes,
+    )
+    ax.set(ylabel=None, xlabel="PFI")
 
-    plt.savefig(f'{pfi_plot_file}', dpi=300, bbox_inches='tight')
+    plt.savefig(f"{pfi_plot_file}", dpi=300, bbox_inches="tight")
+    plt.close(fig)
 
-    path_reduced = '/'.join(f'{pfi_plot_file}'.replace('\\','/').split('/')[-2:])
+    path_reduced = "/".join(f"{pfi_plot_file}".replace("\\", "/").split("/")[-2:])
     print_PFI = f"\n   o  PFI plot saved in {path_reduced}"
 
-    print_PFI += f'\n      Influence on {error_type.upper()}'
+    print_PFI += f"\n      Influence on {error_type.upper()}"
 
-    for i,desc in enumerate(desc_list):
+    for i, desc in enumerate(desc_list):
         print_PFI += f"\n      -  {desc} = {PFI_values[i]:.2} +- {PFI_sd[i]:.2}"
-    
+
     self.args.log.write(print_PFI)
 
 
-def outlier_plot(self,Xy_data,path_n_suffix,name_points,graph_style):
-    '''
+def outlier_plot(self, Xy_data, path_n_suffix, name_points, graph_style):
+    """
     Plots and prints the results of the outlier analysis
-    '''
-    import seaborn as sb
-
+    """
     # detect outliers
     outliers_data, print_outliers = outlier_filter(self, Xy_data, name_points)
 
     # plot data in SD units
     sb.set(style="ticks")
 
-    _, ax = plt.subplots(figsize=(7.45,6))
-    plt.text(0.5, 1.08, f'Outlier analysis of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
-    fontsize=14, fontweight='bold', transform = ax.transAxes)
+    fig, ax = plt.subplots(figsize=(7.45, 6))
+    plt.text(
+        0.5,
+        1.08,
+        f"Outlier analysis of {os.path.basename(path_n_suffix)}",
+        horizontalalignment="center",
+        fontsize=14,
+        fontweight="bold",
+        transform=ax.transAxes,
+    )
+
+    plt.grid(linestyle="--", linewidth=1)
+    _ = ax.scatter(
+        outliers_data["train_scaled"],
+        outliers_data["train_scaled"],
+        c=graph_style["color_train"],
+        s=graph_style["dot_size"],
+        edgecolor="k",
+        linewidths=0.8,
+        alpha=graph_style["alpha"],
+        zorder=2,
+    )
+    _ = ax.scatter(
+        outliers_data["test_scaled"],
+        outliers_data["test_scaled"],
+        c=graph_style["color_test"],
+        s=graph_style["dot_size"],
+        edgecolor="k",
+        linewidths=0.8,
+        alpha=graph_style["alpha"],
+        zorder=2,
+    )
 
-    plt.grid(linestyle='--', linewidth=1)
-    _ = ax.scatter(outliers_data['train_scaled'], outliers_data['train_scaled'],
-            c = graph_style['color_train'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
-    _ = ax.scatter(outliers_data['test_scaled'], outliers_data['test_scaled'],
-        c = graph_style['color_test'], s = graph_style['dot_size'], edgecolor = 'k', linewidths = 0.8, alpha = graph_style['alpha'], zorder=2)
-    
     # Set styling preferences and graph limits
-    plt.xlabel('SD of the errors',fontsize=14)
+    plt.xlabel("SD of the errors", fontsize=14)
     plt.xticks(fontsize=14)
-    plt.ylabel('SD of the errors',fontsize=14)
+    plt.ylabel("SD of the errors", fontsize=14)
     plt.yticks(fontsize=14)
-    
-    axis_limit = max(outliers_data['train_scaled'], key=abs)
-    if 'test_scaled' in outliers_data:
-        if max(outliers_data['test_scaled'], key=abs) > axis_limit:
-            axis_limit = max(outliers_data['test_scaled'], key=abs)
-    axis_limit = axis_limit+0.5
-    if axis_limit < 2.5: # this fixes a problem when representing rectangles in graphs with low SDs
+
+    axis_limit = max(outliers_data["train_scaled"], key=abs)
+    if "test_scaled" in outliers_data:
+        if max(outliers_data["test_scaled"], key=abs) > axis_limit:
+            axis_limit = max(outliers_data["test_scaled"], key=abs)
+    axis_limit = axis_limit + 0.5
+    if (
+        axis_limit < 2.5
+    ):  # this fixes a problem when representing rectangles in graphs with low SDs
         axis_limit = 2.5
     plt.ylim(-axis_limit, axis_limit)
     plt.xlim(-axis_limit, axis_limit)
 
     # plot rectangles in corners
     diff_tvalue = axis_limit - self.args.t_value
-    Rectangle_top = mpatches.Rectangle(xy=(axis_limit, axis_limit), width=-diff_tvalue, height=-diff_tvalue, facecolor='grey', alpha=0.3)
-    Rectangle_bottom = mpatches.Rectangle(xy=(-(axis_limit), -(axis_limit)), width=diff_tvalue, height=diff_tvalue, facecolor='grey', alpha=0.3)
+    Rectangle_top = mpatches.Rectangle(
+        xy=(axis_limit, axis_limit),
+        width=-diff_tvalue,
+        height=-diff_tvalue,
+        facecolor="grey",
+        alpha=0.3,
+    )
+    Rectangle_bottom = mpatches.Rectangle(
+        xy=(-(axis_limit), -(axis_limit)),
+        width=diff_tvalue,
+        height=diff_tvalue,
+        facecolor="grey",
+        alpha=0.3,
+    )
     ax.add_patch(Rectangle_top)
     ax.add_patch(Rectangle_bottom)
 
     # save plot and print results
-    outliers_plot_file = f'{os.path.dirname(path_n_suffix)}/Outliers_{os.path.basename(path_n_suffix)}.png'
-    plt.savefig(f'{outliers_plot_file}', dpi=300, bbox_inches='tight')
-    
-    path_reduced = '/'.join(f'{outliers_plot_file}'.replace('\\','/').split('/')[-2:])
+    outliers_plot_file = f"{os.path.dirname(path_n_suffix)}/Outliers_{os.path.basename(path_n_suffix)}.png"
+    plt.savefig(f"{outliers_plot_file}", dpi=300, bbox_inches="tight")
+    plt.close(fig)
+
+    path_reduced = "/".join(f"{outliers_plot_file}".replace("\\", "/").split("/")[-2:])
     print_outliers += f"\n   o  Outliers plot saved in {path_reduced}"
 
-    if 'train' not in name_points:
-        print_outliers += f'\n      x  No names option (or var missing in CSV file)! Outlier names will not be shown'
+    if "train" not in name_points:
+        print_outliers += "\n      x  No names option (or var missing in CSV file)! Outlier names will not be shown"
     else:
-        if 'test_scaled' in outliers_data and 'test' not in name_points:
-            print_outliers += f'\n      x  No names option (or var missing in CSV file in the test file)! Outlier names will not be shown'
+        if "test_scaled" in outliers_data and "test" not in name_points:
+            print_outliers += "\n      x  No names option (or var missing in CSV file in the test file)! Outlier names will not be shown"
+
+    print_outliers = outlier_analysis(print_outliers, outliers_data, "train")
+    print_outliers = outlier_analysis(print_outliers, outliers_data, "test")
 
-    print_outliers = outlier_analysis(print_outliers,outliers_data,'train')
-    print_outliers = outlier_analysis(print_outliers,outliers_data,'test')
-    
     self.args.log.write(print_outliers)
 
 
-def outlier_analysis(print_outliers,outliers_data,outliers_set):
-    '''
+def outlier_analysis(print_outliers, outliers_data, outliers_set):
+    """
     Analyzes the outlier results
-    '''
-    
-    if outliers_set == 'train':
-        label_set = 'Train'
-        outliers_label = 'outliers_train'
-        n_points_label = 'train_scaled'
-        outliers_name = 'names_train'
-    elif outliers_set == 'valid':
-        label_set = 'Validation'
-        outliers_label = 'outliers_valid'
-        n_points_label = 'valid_scaled'
-        outliers_name = 'names_valid'
-    elif outliers_set == 'test':
-        label_set = 'Test'
-        outliers_label = 'outliers_test'
-        n_points_label = 'test_scaled'
-        outliers_name = 'names_test'
-
-    per_cent = (len(outliers_data[outliers_label])/len(outliers_data[n_points_label]))*100
+    """
+
+    if outliers_set == "train":
+        label_set = "Train"
+        outliers_label = "outliers_train"
+        n_points_label = "train_scaled"
+        outliers_name = "names_train"
+    elif outliers_set == "valid":
+        label_set = "Validation"
+        outliers_label = "outliers_valid"
+        n_points_label = "valid_scaled"
+        outliers_name = "names_valid"
+    elif outliers_set == "test":
+        label_set = "Test"
+        outliers_label = "outliers_test"
+        n_points_label = "test_scaled"
+        outliers_name = "names_test"
+
+    per_cent = (
+        len(outliers_data[outliers_label]) / len(outliers_data[n_points_label])
+    ) * 100
     print_outliers += f"\n      {label_set}: {len(outliers_data[outliers_label])} outliers out of {len(outliers_data[n_points_label])} datapoints ({per_cent:.1f}%)"
-    for val,name in zip(outliers_data[outliers_label], outliers_data[outliers_name]):
+    for val, name in zip(outliers_data[outliers_label], outliers_data[outliers_name]):
         print_outliers += f"\n      -  {name} ({val:.2} SDs)"
     return print_outliers
 
 
 def outlier_filter(self, Xy_data, name_points):
-    '''
+    """
     Calculates and stores absolute errors in SD units for all the sets
-    '''
-    
+    """
+
     # calculate absolute errors between predicted y and actual values
-    outliers_train = [abs(x-y) for x,y in zip(Xy_data['y_train'],Xy_data['y_pred_train'])]
-    outliers_test = [abs(x-y) for x,y in zip(Xy_data['y_test'],Xy_data['y_pred_test'])]
+    outliers_train = [
+        abs(x - y) for x, y in zip(Xy_data["y_train"], Xy_data["y_pred_train"])
+    ]
+    outliers_test = [
+        abs(x - y) for x, y in zip(Xy_data["y_test"], Xy_data["y_pred_test"])
+    ]
 
     # the errors are scaled using standard deviation units. When the absolute
     # error is larger than the t-value, the point is considered an outlier. All the sets
@@ -3164,31 +3865,35 @@ def outlier_filter(self, Xy_data, name_points):
     outliers_sd = np.std(outliers_train)
 
     outliers_data = {}
-    outliers_data['train_scaled'] = (outliers_train-outliers_mean)/outliers_sd
-    outliers_data['test_scaled'] = (outliers_test-outliers_mean)/outliers_sd
+    outliers_data["train_scaled"] = (outliers_train - outliers_mean) / outliers_sd
+    outliers_data["test_scaled"] = (outliers_test - outliers_mean) / outliers_sd
 
-    print_outliers, naming, naming_test = '', False, False
-    if 'train' in name_points:
+    print_outliers, naming, naming_test = "", False, False
+    if "train" in name_points:
         naming = True
-        if 'test' in name_points:
+        if "test" in name_points:
             naming_test = True
 
-    outliers_data['outliers_train'], outliers_data['names_train'] = detect_outliers(self, outliers_data['train_scaled'], name_points, naming, 'train')
-    outliers_data['outliers_test'], outliers_data['names_test'] = detect_outliers(self, outliers_data['test_scaled'], name_points, naming_test, 'test')
-    
+    outliers_data["outliers_train"], outliers_data["names_train"] = detect_outliers(
+        self, outliers_data["train_scaled"], name_points, naming, "train"
+    )
+    outliers_data["outliers_test"], outliers_data["names_test"] = detect_outliers(
+        self, outliers_data["test_scaled"], name_points, naming_test, "test"
+    )
+
     return outliers_data, print_outliers
 
 
 def detect_outliers(self, outliers_scaled, name_points, naming_detect, set_type):
-    '''
+    """
     Detects and store outliers with their corresponding datapoint names
-    '''
+    """
 
     val_outliers = []
     name_outliers = []
     if naming_detect:
         name_points_list = name_points[set_type].to_list()
-    for i,val in enumerate(outliers_scaled):
+    for i, val in enumerate(outliers_scaled):
         if val > self.args.t_value or val < -self.args.t_value:
             val_outliers.append(val)
             if naming_detect:
@@ -3197,177 +3902,217 @@ def detect_outliers(self, outliers_scaled, name_points, naming_detect, set_type)
     return val_outliers, name_outliers
 
 
-def distribution_plot(self,Xy_data,path_n_suffix,params_dict):
-    '''
+def distribution_plot(self, Xy_data, path_n_suffix, params_dict):
+    """
     Plots histogram (reg) or bin plot (clas).
-    '''
-    import seaborn as sb
-
+    """
     sb.set(style="ticks")
 
-    _, ax = plt.subplots(figsize=(7.45,6))
-    plt.text(0.5, 1.08, f'y-values distribution (CV + test) of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
-    fontsize=14, fontweight='bold', transform = ax.transAxes)
+    fig, ax = plt.subplots(figsize=(7.45, 6))
+    plt.text(
+        0.5,
+        1.08,
+        f"y-values distribution (CV + test) of {os.path.basename(path_n_suffix)}",
+        horizontalalignment="center",
+        fontsize=14,
+        fontweight="bold",
+        transform=ax.transAxes,
+    )
 
-    plt.grid(linestyle='--', linewidth=1)
+    plt.grid(linestyle="--", linewidth=1)
 
     # combine train and validation sets
-    y_combined = pd.concat([Xy_data['y_train'],Xy_data['y_test']], axis=0).reset_index(drop=True)
+    y_combined = pd.concat([Xy_data["y_train"], Xy_data["y_test"]], axis=0).reset_index(
+        drop=True
+    )
 
     # plot histogram, quartile lines and the points in each quartile
-    if params_dict['type'].lower() == 'reg':
-        y_dist_dict,ax = plot_quartiles(y_combined,ax)
-    
+    if params_dict["type"].lower() == "reg":
+        y_dist_dict, ax = plot_quartiles(y_combined, ax)
+
     # plot a bar plot with the count of each y type
-    elif params_dict['type'].lower() == 'clas':
-        y_dist_dict,ax = plot_y_count(y_combined,ax)
+    elif params_dict["type"].lower() == "clas":
+        y_dist_dict, ax = plot_y_count(y_combined, ax)
 
     # set styling preferences and graph limits
-    plt.xlabel(f'{params_dict["y"]} values',fontsize=14)
+    plt.xlabel(f"{params_dict['y']} values", fontsize=14)
     plt.xticks(fontsize=14)
-    plt.ylabel('Frequency',fontsize=14)
+    plt.ylabel("Frequency", fontsize=14)
     plt.yticks(fontsize=14)
 
     # set limits
-    if params_dict['type'].lower() == 'reg':
-        border_y_range = 0.1*np.abs(max(y_combined)-min(y_combined))
-        plt.xlim(min(y_combined)-border_y_range, max(y_combined)+border_y_range)
+    if params_dict["type"].lower() == "reg":
+        border_y_range = 0.1 * np.abs(max(y_combined) - min(y_combined))
+        plt.xlim(min(y_combined) - border_y_range, max(y_combined) + border_y_range)
 
     # save plot and print results
-    orig_distrib_file = f'y_distribution_{os.path.basename(path_n_suffix)}.png'
-    plt.savefig(f'{orig_distrib_file}', dpi=300, bbox_inches='tight')
+    orig_distrib_file = f"y_distribution_{os.path.basename(path_n_suffix)}.png"
+    plt.savefig(f"{orig_distrib_file}", dpi=300, bbox_inches="tight")
+    plt.close(fig)
     # for a VERY weird reason, I need to save the figure in the working directory and then move it into PREDICT
-    final_distrib_file = f'{os.path.dirname(path_n_suffix)}/y_distribution_{os.path.basename(path_n_suffix)}.png'
+    final_distrib_file = f"{os.path.dirname(path_n_suffix)}/y_distribution_{os.path.basename(path_n_suffix)}.png"
     shutil.move(orig_distrib_file, final_distrib_file)
 
-    path_reduced = '/'.join(f'{final_distrib_file}'.replace('\\','/').split('/')[-2:])
+    path_reduced = "/".join(f"{final_distrib_file}".replace("\\", "/").split("/")[-2:])
     print_distrib = f"\n   o  y-values distribution plot saved in {path_reduced}"
 
     # print the quartile results
-    if params_dict['type'].lower() == 'reg':
-        print_distrib += f"\n      Ideally, the number of datapoints in the four quartiles of the y-range should be uniform (25% population in each quartile) to have similar confidence intervals in the predictions across the y-range"
-        quartile_pops = [len(y_dist_dict['q1_points']),len(y_dist_dict['q2_points']),len(y_dist_dict['q3_points']),len(y_dist_dict['q4_points'])]
+    if params_dict["type"].lower() == "reg":
+        print_distrib += "\n      Ideally, the number of datapoints in the four quartiles of the y-range should be uniform (25% population in each quartile) to have similar confidence intervals in the predictions across the y-range"
+        quartile_pops = [
+            len(y_dist_dict["q1_points"]),
+            len(y_dist_dict["q2_points"]),
+            len(y_dist_dict["q3_points"]),
+            len(y_dist_dict["q4_points"]),
+        ]
         print_distrib += f"\n      -  The number of points in each quartile is Q1: {quartile_pops[0]}, Q2: {quartile_pops[1]}, Q3: {quartile_pops[2]}, Q4: {quartile_pops[3]}"
         quartile_min_idx = quartile_pops.index(min(quartile_pops))
         quartile_max_idx = quartile_pops.index(max(quartile_pops))
-        if 4*min(quartile_pops) < max(quartile_pops):
-            print_distrib += f"\n      x  WARNING! Your data is not uniform (Q{quartile_min_idx+1} has {min(quartile_pops)} points while Q{quartile_max_idx+1} has {max(quartile_pops)})"
-        elif 2*min(quartile_pops) < max(quartile_pops):
-            print_distrib += f"\n      x  WARNING! Your data is slightly not uniform (Q{quartile_min_idx+1} has {min(quartile_pops)} points while Q{quartile_max_idx+1} has {max(quartile_pops)})"
+        if 4 * min(quartile_pops) < max(quartile_pops):
+            print_distrib += f"\n      x  WARNING! Your data is not uniform (Q{quartile_min_idx + 1} has {min(quartile_pops)} points while Q{quartile_max_idx + 1} has {max(quartile_pops)})"
+        elif 2 * min(quartile_pops) < max(quartile_pops):
+            print_distrib += f"\n      x  WARNING! Your data is slightly not uniform (Q{quartile_min_idx + 1} has {min(quartile_pops)} points while Q{quartile_max_idx + 1} has {max(quartile_pops)})"
         else:
-            print_distrib += f"\n      o  Your data seems quite uniform"
+            print_distrib += "\n      o  Your data seems quite uniform"
 
-    elif params_dict['type'].lower() == 'clas':
-        print_distrib += f"\n      Ideally, the number of datapoints in each prediction class should be uniform (50% population per class) to have similar reliability in the predictions across classes"
-        distrib_counts = [y_dist_dict['count_labels'][0],y_dist_dict['count_labels'][1]]
+    elif params_dict["type"].lower() == "clas":
+        print_distrib += "\n      Ideally, the number of datapoints in each prediction class should be uniform (50% population per class) to have similar reliability in the predictions across classes"
+        distrib_counts = [
+            y_dist_dict["count_labels"][0],
+            y_dist_dict["count_labels"][1],
+        ]
         print_distrib += f"\n      - The number of points in each class is {y_dist_dict['type_labels'][0]}: {y_dist_dict['count_labels'][0]}, {y_dist_dict['type_labels'][1]}: {y_dist_dict['count_labels'][1]}"
         class_min_idx = distrib_counts.index(min(distrib_counts))
         class_max_idx = distrib_counts.index(max(distrib_counts))
-        if 3*min(distrib_counts) < max(distrib_counts):
+        if 3 * min(distrib_counts) < max(distrib_counts):
             print_distrib += f"\n      x  WARNING! Your data is not uniform (class {y_dist_dict['type_labels'][class_min_idx]} has {min(distrib_counts)} points while class {y_dist_dict['type_labels'][class_max_idx]} has {max(distrib_counts)})"
-        elif 1.5*min(distrib_counts) < max(distrib_counts):
+        elif 1.5 * min(distrib_counts) < max(distrib_counts):
             print_distrib += f"\n      x  WARNING! Your data is slightly not uniform (class {y_dist_dict['type_labels'][class_min_idx]} has {min(distrib_counts)} points while class {y_dist_dict['type_labels'][class_max_idx]} has {max(distrib_counts)})"
         else:
-            print_distrib += f"\n      o  Your data seems quite uniform"
+            print_distrib += "\n      o  Your data seems quite uniform"
 
     self.args.log.write(print_distrib)
 
 
-def plot_quartiles(y_combined,ax):
-    '''
+def plot_quartiles(y_combined, ax):
+    """
     Plot histogram, quartile lines and the points in each quartile.
-    '''
+    """
 
-    bins = max([round(len(y_combined)/5),5]) # at least 5 bins until 25 points
+    bins = max([round(len(y_combined) / 5), 5])  # at least 5 bins until 25 points
     # histogram
-    y_hist, _, _ = ax.hist(y_combined, bins=bins,
-                color='#1f77b4', edgecolor='k', linewidth=1, alpha=1)
+    y_hist, _, _ = ax.hist(
+        y_combined, bins=bins, color="#1f77b4", edgecolor="k", linewidth=1, alpha=1
+    )
 
     # uniformity lines to plot
-    separation_range = np.abs(max(y_combined)-min(y_combined))/4
-    quart_dict = {'line_1': min(y_combined),
-                    'line_2': min(y_combined) + separation_range,
-                    'line_3': min(y_combined) + (2*separation_range),
-                    'line_4': min(y_combined) + (3*separation_range),
-                    'line_5': max(y_combined)}
+    separation_range = np.abs(max(y_combined) - min(y_combined)) / 4
+    quart_dict = {
+        "line_1": min(y_combined),
+        "line_2": min(y_combined) + separation_range,
+        "line_3": min(y_combined) + (2 * separation_range),
+        "line_4": min(y_combined) + (3 * separation_range),
+        "line_5": max(y_combined),
+    }
 
     lines_plot = [quart_dict[line] for line in quart_dict]
-    ax.vlines([lines_plot], ymin=max(y_hist)*1.05, ymax=max(y_hist)*1.3, colors='crimson', linestyles='--')
+    ax.vlines(
+        [lines_plot],
+        ymin=max(y_hist) * 1.05,
+        ymax=max(y_hist) * 1.3,
+        colors="crimson",
+        linestyles="--",
+    )
 
     # points in each quartile
-    quart_dict['q1_points'] = []
-    quart_dict['q2_points'] = []
-    quart_dict['q3_points'] = []
-    quart_dict['q4_points'] = []
+    quart_dict["q1_points"] = []
+    quart_dict["q2_points"] = []
+    quart_dict["q3_points"] = []
+    quart_dict["q4_points"] = []
 
     for val in y_combined:
-        if val < quart_dict['line_2']:
-            quart_dict['q1_points'].append(val)
-        elif quart_dict['line_2'] < val < quart_dict['line_3']:
-            quart_dict['q2_points'].append(val)
-        elif quart_dict['line_3'] < val < quart_dict['line_4']:
-            quart_dict['q3_points'].append(val)
-        elif val >= quart_dict['line_4']:
-            quart_dict['q4_points'].append(val)
+        if val < quart_dict["line_2"]:
+            quart_dict["q1_points"].append(val)
+        elif quart_dict["line_2"] < val < quart_dict["line_3"]:
+            quart_dict["q2_points"].append(val)
+        elif quart_dict["line_3"] < val < quart_dict["line_4"]:
+            quart_dict["q3_points"].append(val)
+        elif val >= quart_dict["line_4"]:
+            quart_dict["q4_points"].append(val)
 
     x_quart = 0.185
     for quart in quart_dict:
-        if 'points' in quart:
-            plt.text(x_quart, 0.845, f'Q{quart[1]}\n{len(quart_dict[quart])} points', horizontalalignment='center',
-                    fontsize=12, transform = ax.transAxes, backgroundcolor='w')
+        if "points" in quart:
+            plt.text(
+                x_quart,
+                0.845,
+                f"Q{quart[1]}\n{len(quart_dict[quart])} points",
+                horizontalalignment="center",
+                fontsize=12,
+                transform=ax.transAxes,
+                backgroundcolor="w",
+            )
             x_quart += 0.209
 
-    return quart_dict,ax
+    return quart_dict, ax
 
 
-def plot_y_count(y_combined,ax):
-    '''
+def plot_y_count(y_combined, ax):
+    """
     Plot a bar plot with the count of each y type.
-    '''
+    """
 
     # get the number of times that each y type is included
     labels_used = set(y_combined)
-    type_labels,count_labels = [],[]
+    type_labels, count_labels = [], []
     for label in labels_used:
         type_labels.append(label)
         count_labels.append(len(y_combined[y_combined == label]))
 
-    _ = ax.bar(type_labels, count_labels, tick_label=type_labels,
-                color='#1f77b4', edgecolor='k', linewidth=1, alpha=1,
-                width=0.4)
+    _ = ax.bar(
+        type_labels,
+        count_labels,
+        tick_label=type_labels,
+        color="#1f77b4",
+        edgecolor="k",
+        linewidth=1,
+        alpha=1,
+        width=0.4,
+    )
 
-    y_dist_dict = {'type_labels': type_labels,
-                   'count_labels': count_labels}
+    y_dist_dict = {"type_labels": type_labels, "count_labels": count_labels}
 
-    return y_dist_dict,ax
+    return y_dist_dict, ax
 
 
-def get_prediction_results(model_data,y,y_pred_all):
-    '''
+def get_prediction_results(model_data, y, y_pred_all):
+    """
     Calculate metrics based on y and y_pred
-    '''
+    """
 
-    if model_data['type'].lower() == 'reg':
-        mae = mean_absolute_error(y,y_pred_all)
-        rmse = np.sqrt(mean_squared_error(y,y_pred_all))
+    if model_data["type"].lower() == "reg":
+        mae = mean_absolute_error(y, y_pred_all)
+        rmse = np.sqrt(mean_squared_error(y, y_pred_all))
         if len(np.unique(y)) > 1 and len(np.unique(y_pred_all)) > 1:
-            res = stats.linregress(y,y_pred_all)
+            res = stats.linregress(y, y_pred_all)
             r2 = res.rvalue**2
         else:
             r2 = 0.0
 
         return r2, mae, rmse
 
-    elif model_data['type'].lower() == 'clas':
+    elif model_data["type"].lower() == "clas":
         # ensure true and predicted labels are integers
-        acc = accuracy_score(y,np.round(y_pred_all).astype(int))
+        acc = accuracy_score(y, np.round(y_pred_all).astype(int))
         # F1 by default uses average='binnary', to deal with predictions with more than 2 ouput values we use average='micro'
         # if len(set(y))==2:
         try:
-            f1_score_val = f1_score(y,np.round(y_pred_all).astype(int))
+            f1_score_val = f1_score(y, np.round(y_pred_all).astype(int))
         except ValueError:
-            f1_score_val = f1_score(y,np.round(y_pred_all).astype(int),average='micro')
-        mcc = matthews_corrcoef(y,np.round(y_pred_all).astype(int))
+            f1_score_val = f1_score(
+                y, np.round(y_pred_all).astype(int), average="micro"
+            )
+        mcc = matthews_corrcoef(y, np.round(y_pred_all).astype(int))
         return acc, f1_score_val, mcc
 
 
@@ -3387,10 +4132,7 @@ def get_error_labels(model_type):
         - Regression: ('r2', 'mae', 'rmse')
         - Classification: ('acc', 'f1', 'mcc')
     """
-    error_labels = {
-        'reg': ('r2', 'mae', 'rmse'),
-        'clas': ('acc', 'f1', 'mcc')
-    }
+    error_labels = {"reg": ("r2", "mae", "rmse"), "clas": ("acc", "f1", "mcc")}
 
     model_type_lower = model_type.lower()
 
@@ -3422,134 +4164,193 @@ def _select_descriptors(self, df, descriptors, module):
             sys.exit()
 
 
-def load_db_n_params(self,params_dir,suffix,suffix_title,module,print_load):
-    '''
+def load_db_n_params(self, params_dir, suffix, suffix_title, module, print_load):
+    """
     Loads the parameters and Xy databases from a folder, add scaled X data and print information
     about the databases
-    '''
+    """
 
     # load databases from CSV
-    csv_df,csv_X,csv_y,model_data,_ = load_dfs(self,params_dir,module,print_info=print_load)
+    csv_df, csv_X, csv_y, model_data, _ = load_dfs(
+        self, params_dir, module, print_info=print_load
+    )
 
     # detect points in the test set
-    test_points = csv_X[csv_X['Set'] == 'Test'].index.tolist()
-    csv_X = csv_X.drop(columns=['Set'])
+    test_points = csv_X[csv_X["Set"] == "Test"].index.tolist()
+    csv_X = csv_X.drop(columns=["Set"])
 
     # keep only the descriptors used in the model
     csv_X = _select_descriptors(self, csv_X, model_data["X_descriptors"], module)
 
     # load and adjust external set (if any)
-    csv_external_df, csv_X_external,csv_y_external = None,None,None
-    if self.args.csv_test != '':
-        csv_external_df,csv_X_external,csv_y_external = load_database(self,self.args.csv_test,'predict',external_test=True)
+    csv_external_df, csv_X_external, csv_y_external = None, None, None
+    if self.args.csv_test != "":
+        csv_external_df, csv_X_external, csv_y_external = load_database(
+            self, self.args.csv_test, "predict", external_test=True
+        )
         csv_X_external = _select_descriptors(
             self, csv_X_external, model_data["X_descriptors"], "predict"
         )
 
     # split tests
-    Xy_data = prepare_sets(self,csv_df,csv_X,csv_y,test_points,model_data['names'],csv_external_df,csv_X_external,csv_y_external,BO_opt=False)
+    Xy_data = prepare_sets(
+        self,
+        csv_df,
+        csv_X,
+        csv_y,
+        test_points,
+        model_data["names"],
+        csv_external_df,
+        csv_X_external,
+        csv_y_external,
+        BO_opt=False,
+    )
 
     # print information of loaded database
     params_name = os.path.basename(params_dir)
     if print_load:
-        _ = load_print(self,params_name,suffix,model_data,Xy_data)
+        _ = load_print(self, params_name, suffix, model_data, Xy_data)
 
     return Xy_data, model_data, suffix_title
 
 
-def prepare_sets(self,csv_df,csv_X,csv_y,test_points,column_names,csv_external_df,csv_X_external,csv_y_external,BO_opt=False):
-    '''
+def prepare_sets(
+    self,
+    csv_df,
+    csv_X,
+    csv_y,
+    test_points,
+    column_names,
+    csv_external_df,
+    csv_X_external,
+    csv_y_external,
+    BO_opt=False,
+):
+    """
     Standardizes and separate test set
-    '''
+    """
 
-    X_scaled_df,X_scaled_external_df = scale_df(csv_X,csv_X_external)
+    X_scaled_df, X_scaled_external_df = scale_df(csv_X, csv_X_external)
 
     # separate test set and save it in the Xy data
     if BO_opt:
-        if self.args.csv_test != '':
+        if self.args.csv_test != "":
             self.args.test_set = 0
-        
+
         if self.args.auto_test:
             if self.args.test_set < 0.2:
                 self.args.test_set = 0.2
-                self.args.log.write(f'\nx  WARNING! The test_set option was set to {self.args.test_set}, this value will be raised to 0.2 to include a meaningful amount of points in the test set. You can bypass this option and include less test points with "--auto_test False".')
+                self.args.log.write(
+                    f'\nx  WARNING! The test_set option was set to {self.args.test_set}, this value will be raised to 0.2 to include a meaningful amount of points in the test set. You can bypass this option and include less test points with "--auto_test False".'
+                )
 
         if self.args.test_set > 0:
-            self.args.log.write(f'\no  Before hyperoptimization, {int(self.args.test_set*100)}% of the data (or 4 points at least) was separated as test set, using an even distribution of data points across the range of y values.')
+            self.args.log.write(
+                f"\no  Before hyperoptimization, {int(self.args.test_set * 100)}% of the data (or 4 points at least) was separated as test set, using an even distribution of data points across the range of y values."
+            )
             try:
-                test_points = test_select(self,X_scaled_df,csv_y)
+                test_points = test_select(self, X_scaled_df, csv_y)
             except TypeError:
-                self.args.log.write(f'   x The data split process failed! This is probably due to using strings/words as values (use --curate to curate the data first)')
+                self.args.log.write(
+                    "   x The data split process failed! This is probably due to using strings/words as values (use --curate to curate the data first)"
+                )
                 sys.exit()
 
     # load predefined sets and save the info in Xy data
-    Xy_data = Xy_split(csv_df,csv_X,X_scaled_df,csv_y,csv_external_df,csv_X_external,X_scaled_external_df,csv_y_external,test_points,column_names)
+    Xy_data = Xy_split(
+        csv_df,
+        csv_X,
+        X_scaled_df,
+        csv_y,
+        csv_external_df,
+        csv_X_external,
+        X_scaled_external_df,
+        csv_y_external,
+        test_points,
+        column_names,
+    )
 
     # also store the descriptors used (the labels disappear after test_select() )
-    Xy_data['X_descriptors'] = csv_X.columns.tolist()
+    Xy_data["X_descriptors"] = csv_X.columns.tolist()
 
     return Xy_data
 
 
-def load_dfs(self,folder_model,module,sanity_check=False,print_info=True):
-    '''
+def load_dfs(self, folder_model, module, sanity_check=False, print_info=True):
+    """
     Loads the parameters and Xy databases from the GENERATE folder as dataframes
-    '''
-    
+    """
+
+    csv_df = pd.DataFrame()
+    csv_X = pd.DataFrame()
+    csv_y = pd.DataFrame()
+    model_data = {}
+    csv_name = ""
+
     if os.getcwd() in f"{folder_model}":
         path_db = folder_model
     else:
         path_db = f"{Path(os.getcwd()).joinpath(folder_model)}"
     if os.path.exists(path_db):
-        csv_files = glob.glob(f'{Path(path_db).joinpath("*.csv")}')
-        csv_files.sort(key=lambda f: f.endswith('_db.csv')) # sort the database file to be the last one, depending on the OS was taking first the dabatase and then the parameters
+        csv_files = glob.glob(f"{Path(path_db).joinpath('*.csv')}")
+        csv_files.sort(
+            key=lambda f: f.endswith("_db.csv")
+        )  # sort the database file to be the last one, depending on the OS was taking first the dabatase and then the parameters
         for csv_file in csv_files:
-            if csv_file.endswith('_db.csv'):
+            if csv_file.endswith("_db.csv"):
                 if not sanity_check:
-                    csv_df,csv_X,csv_y = load_database(self,csv_file,module,print_info=print_info)
+                    csv_df, csv_X, csv_y = load_database(
+                        self, csv_file, module, print_info=print_info
+                    )
                 csv_name = csv_file
             else:
-                csv_df,csv_X,csv_y = pd.DataFrame(),pd.DataFrame(),pd.DataFrame()
+                csv_df, csv_X, csv_y = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
                 # convert df to dict, then adjust params to a valid format
-                model_data = load_params(self,csv_file)
+                model_data = load_params(self, csv_file)
     else:
-        self.args.log.write(f"\nx  The folder with the model and database ({path_db}) does not exist! Did you use the destination=PATH option in the other modules?")
+        self.args.log.write(
+            f"\nx  The folder with the model and database ({path_db}) does not exist! Did you use the destination=PATH option in the other modules?"
+        )
         sys.exit()
 
-    return csv_df,csv_X,csv_y,model_data,csv_name
+    return csv_df, csv_X, csv_y, model_data, csv_name
 
 
-def load_params(self,path_csv):
-    '''
+def load_params(self, path_csv):
+    """
     Load parameters from a CSV and adjust the format
-    '''
-    
-    PFI_df = pd.read_csv(path_csv, encoding='utf-8')
+    """
+
+    PFI_df = pd.read_csv(path_csv, encoding="utf-8")
     PFI_dict = pd_to_dict(PFI_df)
     PFI_dict = dict_formating(PFI_dict)
-    PFI_dict['params'] = model_adjust_params(self, PFI_dict['model'], PFI_dict['params'])
+    PFI_dict["params"] = model_adjust_params(
+        self, PFI_dict["model"], PFI_dict["params"]
+    )
 
     return PFI_dict
 
 
-def load_print(self,params_name,suffix,model_data,Xy_data):
-    '''
+def load_print(self, params_name, suffix, model_data, Xy_data):
+    """
     Print information of the database loaded and type of model used
-    '''
-
-    if '.csv' in params_name:
-        params_name = params_name.split('.csv')[0]
-    txt_load = f'\no  ML model {params_name} {suffix} and Xy database were loaded, including:'
-    txt_load += f'\n   - Target value: {model_data["y"]}'
-    txt_load += f'\n   - Names: {model_data["names"]}'
-    txt_load += f'\n   - Model: {model_data["model"]}'
-    txt_load += f'\n   - k-fold CV: {model_data["kfold"]}'
-    txt_load += f'\n   - Repetitions CV: {model_data["repeat_kfolds"]}'
-    txt_load += f'\n   - Descriptors: {model_data["X_descriptors"]}'
-    txt_load += f'\n   - Training points: {len(Xy_data["y_train"])}'
-    txt_load += f'\n   - Test points: {len(Xy_data["y_test"])}'
-    if 'X_external' in Xy_data:
-        txt_load += f'\n   - External test points: {len(Xy_data["X_external"])}'
+    """
+
+    if ".csv" in params_name:
+        params_name = params_name.split(".csv")[0]
+    txt_load = (
+        f"\no  ML model {params_name} {suffix} and Xy database were loaded, including:"
+    )
+    txt_load += f"\n   - Target value: {model_data['y']}"
+    txt_load += f"\n   - Names: {model_data['names']}"
+    txt_load += f"\n   - Model: {model_data['model']}"
+    txt_load += f"\n   - k-fold CV: {model_data['kfold']}"
+    txt_load += f"\n   - Repetitions CV: {model_data['repeat_kfolds']}"
+    txt_load += f"\n   - Descriptors: {model_data['X_descriptors']}"
+    txt_load += f"\n   - Training points: {len(Xy_data['y_train'])}"
+    txt_load += f"\n   - Test points: {len(Xy_data['y_test'])}"
+    if "X_external" in Xy_data:
+        txt_load += f"\n   - External test points: {len(Xy_data['X_external'])}"
     self.args.log.write(txt_load)
 
 
@@ -3560,239 +4361,316 @@ def pd_to_dict(PFI_df):
     return PFI_df_dict
 
 
-def print_pfi(self,params_dir):
-    if 'No_PFI' in params_dir:
-        self.args.log.write('\n\n------- Starting model with all variables (No PFI) -------')
+def print_pfi(self, params_dir):
+    if "No_PFI" in params_dir:
+        self.args.log.write(
+            "\n\n------- Starting model with all variables (No PFI) -------"
+        )
     else:
-        self.args.log.write('\n\n------- Starting model with PFI filter (only important descriptors used) -------')
+        self.args.log.write(
+            "\n\n------- Starting model with PFI filter (only important descriptors used) -------"
+        )
 
 
 def get_graph_style():
     """
     Retrieves the graph style for regression plots
     """
-    
-    graph_style = {'color_train' : 'b',
-        'color_valid' : 'orange',
-        'color_test' : 'r',
-        'dot_size' : 50,
-        'alpha' : 1 # from 0 (transparent) to 1 (opaque)
-        }
+
+    graph_style = {
+        "color_train": "b",
+        "color_valid": "orange",
+        "color_test": "r",
+        "dot_size": 50,
+        "alpha": 1,  # from 0 (transparent) to 1 (opaque)
+    }
 
     return graph_style
 
 
-def pearson_map(self,csv_df_pearson,module,params_dir=None):
-    '''
+def pearson_map(self, csv_df_pearson, module, params_dir=None):
+    """
     Creates Pearson heatmap
-    '''
-    import seaborn as sb
-
+    """
     if module.lower() == "curate":  # only represent the final descriptors in CURATE
         csv_df_pearson = csv_df_pearson.drop([self.args.y] + self.args.ignore, axis=1)
 
     corr_matrix = csv_df_pearson.corr()
     mask = np.zeros_like(corr_matrix, dtype=bool)
-    mask[np.triu_indices_from(mask)]= True
-    
+    mask[np.triu_indices_from(mask)] = True
+
     # no representatoins when there are more than 30 descriptors
     if len(csv_df_pearson.columns) > 30:
         disable_plot = True
     else:
         disable_plot = False
-        _, ax = plt.subplots(figsize=(7.45,6))
+        fig, ax = plt.subplots(figsize=(7.45, 6))
         size_title = 14
-        size_font = 14-2*((len(csv_df_pearson.columns)/5))
+        size_font = 14 - 2 * (len(csv_df_pearson.columns) / 5)
 
     if disable_plot:
-        if module.lower() == 'curate':
-            self.args.log.write(f'\nx  The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30.')
-        if module.lower() == 'predict':
-            self.args.log.write(f'\n   x  The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30.')
+        if module.lower() == "curate":
+            self.args.log.write(
+                f"\nx  The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30."
+            )
+        if module.lower() == "predict":
+            self.args.log.write(
+                f"\n   x  The Pearson heatmap was not generated because the number of features and the y value ({len(csv_df_pearson.columns)}) is higher than 30."
+            )
 
     else:
-        sb.set(font_scale=1.2, style='ticks')
-
-        _ = sb.heatmap(corr_matrix,
-                        mask = mask,
-                        square = True,
-                        linewidths = .5,
-                        cmap = 'coolwarm',
-                        cbar = False,
-                        cbar_kws = {'shrink': .4,
-                                    'ticks' : [-1, -.5, 0, 0.5, 1]},
-                        vmin = -1,
-                        vmax = 1,
-                        annot = True,
-                        annot_kws = {'size': size_font})
+        sb.set(font_scale=1.2, style="ticks")
+
+        _ = sb.heatmap(
+            corr_matrix,
+            mask=mask,
+            square=True,
+            linewidths=0.5,
+            cmap="coolwarm",
+            cbar=False,
+            cbar_kws={"shrink": 0.4, "ticks": [-1, -0.5, 0, 0.5, 1]},
+            vmin=-1,
+            vmax=1,
+            annot=True,
+            annot_kws={"size": size_font},
+        )
 
         plt.tick_params(labelsize=size_font)
-        #add the column names as labels
-        ax.set_yticklabels(corr_matrix.columns, rotation = 0)
+        # add the column names as labels
+        ax.set_yticklabels(corr_matrix.columns, rotation=0)
         ax.set_xticklabels(corr_matrix.columns)
 
-        title_fig = 'Pearson\'s r heatmap'
-        if module.lower() == 'predict':
-            if os.path.basename(Path(params_dir)) == 'No_PFI':
-                suffix_title = 'No_PFI'
-            elif os.path.basename(Path(params_dir)) == 'PFI':
-                suffix_title = 'PFI'
-            title_fig += f'_{suffix_title}'
+        title_fig = "Pearson's r heatmap"
+        if module.lower() == "predict":
+            if os.path.basename(Path(params_dir)) == "No_PFI":
+                suffix_title = "No_PFI"
+            elif os.path.basename(Path(params_dir)) == "PFI":
+                suffix_title = "PFI"
+            title_fig += f"_{suffix_title}"
 
-        plt.title(title_fig, y=1.04, fontsize = size_title, fontweight="bold")
-        sb.set_style({'xtick.bottom': True}, {'ytick.left': True})
+        plt.title(title_fig, y=1.04, fontsize=size_title, fontweight="bold")
+        sb.set_style({"xtick.bottom": True}, {"ytick.left": True})
 
-        if module.lower() == 'curate':
-            heatmap_name = 'Pearson_heatmap.png'
-        elif module.lower() == 'predict':
-            heatmap_name = f'Pearson_heatmap_{suffix_title}.png'
+        if module.lower() == "curate":
+            heatmap_name = "Pearson_heatmap.png"
+        elif module.lower() == "predict":
+            heatmap_name = f"Pearson_heatmap_{suffix_title}.png"
 
         heatmap_path = self.args.destination.joinpath(heatmap_name)
-        plt.savefig(f'{heatmap_path}', dpi=300, bbox_inches='tight')
+        plt.savefig(f"{heatmap_path}", dpi=300, bbox_inches="tight")
+        plt.close(fig)
 
-        path_reduced = '/'.join(f'{heatmap_path}'.replace('\\','/').split('/')[-2:])
-        if module.lower() == 'curate':
-            self.args.log.write(f'\no  The Pearson heatmap was stored in {path_reduced}.')
-        elif module.lower() == 'predict':
-            self.args.log.write(f'\n   o  The Pearson heatmap was stored in {path_reduced}.')
+        path_reduced = "/".join(f"{heatmap_path}".replace("\\", "/").split("/")[-2:])
+        if module.lower() == "curate":
+            self.args.log.write(
+                f"\no  The Pearson heatmap was stored in {path_reduced}."
+            )
+        elif module.lower() == "predict":
+            self.args.log.write(
+                f"\n   o  The Pearson heatmap was stored in {path_reduced}."
+            )
 
     return corr_matrix
 
 
-def plot_metrics(model_data,suffix_title,verify_metrics,verify_results):
-    '''
+def plot_metrics(model_data, suffix_title, verify_metrics, verify_results):
+    """
     Creates a plot with the results of the flawed models in VERIFY
-    '''
-    import seaborn as sb
-
-    importlib.reload(plt)
-    sb.reset_defaults()
-    sb.set(style="ticks")
-    _, ax = plt.subplots(figsize=(7.45, 6))
-
-    # define names
-    csv_name = os.path.basename(model_data['model']).split('_db.csv')[0]
-    base_csv_name = f'VERIFY/{csv_name}'
-    base_csv_path = f"{Path(os.getcwd()).joinpath(base_csv_name)}"
-    path_n_suffix = f'{base_csv_path}_{suffix_title}'
-
-    # axis limits
-    max_val = max(verify_metrics['metrics'])
-    min_val = min(verify_metrics['metrics'])
-    range_vals = np.abs(max_val - min_val)
-    if verify_results['error_type'].lower() in ['mae','rmse']:
-        max_lim = 1.2*max_val
-        min_lim = 0
-    else:
-        max_lim = max_val + (0.2*range_vals)
-        min_lim = min_val - (0.1*range_vals)
-    plt.ylim(min_lim, max_lim)
-    plt.ylim(min_lim, max_lim)
-
-    # adjust number of significative numbers shown
-    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
-
-    width_bar = 0.55
-    label_count = 0
-    for test_metric,test_name,test_color in zip(verify_metrics['metrics'],verify_metrics['test_names'],verify_metrics['colors']):
-        rects = ax.bar(test_name, test_metric, label=test_name, 
-                width=width_bar, linewidth=1, edgecolor='k', 
-                color=test_color, zorder=2)
-        # plot whether the tests pass or fail
-        if test_name != 'Model':
-            if test_metric >= 0:
-                offset_txt = test_metric+(0.05*range_vals)
+    """
+    with _mpl_plot_context():
+        sb.reset_defaults()
+        sb.set(style="ticks")
+        fig, ax = plt.subplots(figsize=(7.45, 6))
+
+        # define names
+        csv_name = os.path.basename(model_data["model"]).split("_db.csv")[0]
+        base_csv_name = f"VERIFY/{csv_name}"
+        base_csv_path = f"{Path(os.getcwd()).joinpath(base_csv_name)}"
+        path_n_suffix = f"{base_csv_path}_{suffix_title}"
+
+        # axis limits
+        max_val = max(verify_metrics["metrics"])
+        min_val = min(verify_metrics["metrics"])
+        range_vals = np.abs(max_val - min_val)
+        if verify_results["error_type"].lower() in ["mae", "rmse"]:
+            min_lim = 0
+            max_lim = 1.2 * max_val if max_val != 0 else 0.1
+        else:
+            if range_vals == 0:
+                pad = max(abs(max_val) * 0.1, 0.05)
+                min_lim = max_val - pad
+                max_lim = max_val + pad
             else:
-                offset_txt = test_metric-(0.05*range_vals)
-            if test_color == '#1f77b4':
-                txt_bar = 'pass'
-            elif test_color == '#cd5c5c':
-                txt_bar = 'fail'
-            elif test_color == '#c5c57d':
-                txt_bar = 'unclear'
-            ax.text(label_count, offset_txt, txt_bar, color=test_color, 
-                    fontstyle='italic', horizontalalignment='center')
-        label_count += 1
-
-    # Set tick sizes
-    plt.xticks(fontsize=14)
-    plt.yticks(fontsize=14)
-
-    # title and labels of the axis
-    plt.ylabel(f'{verify_results["error_type"].upper()}', fontsize=14)
-
-    plt.text(0.5, 1.08, f'VERIFY tests of {os.path.basename(path_n_suffix)}', horizontalalignment='center',
-        fontsize=14, fontweight='bold', transform = ax.transAxes)
-
-    # add threshold line and arrow indicating passed test direction
-    arrow_length = np.abs(max_lim-min_lim)/11
-    
-    if verify_results['error_type'].lower() in ['mae','rmse']:
-        thres_line = verify_metrics['higher_thres']
-        unclear_thres_line = verify_metrics['unclear_higher_thres']  
-    else:
-        thres_line = verify_metrics['lower_thres']
-        unclear_thres_line = verify_metrics['unclear_lower_thres']
-        arrow_length = -arrow_length
-
-    width = 2
-    xmin = 0.237
-    thres = ax.axhline(thres_line,xmin=xmin, color='black',ls='--', label='thres', zorder=0)
-    thres = ax.axhline(unclear_thres_line,xmin=xmin, color='black',ls='--', label='thres', zorder=0)
-
-    x_arrow = 0.5
-    style = mpatches.ArrowStyle('simple', head_length=4.5*width, head_width=3.5*width, tail_width=width)
-    arrow = mpatches.FancyArrowPatch((x_arrow, thres_line), (x_arrow, thres_line+arrow_length), 
-                            arrowstyle=style, color='k')  # (x1,y1), (x2,y2) vector direction                   
-    ax.add_patch(arrow)
+                max_lim = max_val + (0.2 * range_vals)
+                min_lim = min_val - (0.1 * range_vals)
+        if range_vals == 0 and verify_results["error_type"].lower() in ["mae", "rmse"]:
+            range_vals = max_lim - min_lim
+        ax.set_ylim(min_lim, max_lim)
+
+        # adjust number of significative numbers shown
+        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
+
+        width_bar = 0.55
+        label_count = 0
+        for test_metric, test_name, test_color in zip(
+            verify_metrics["metrics"],
+            verify_metrics["test_names"],
+            verify_metrics["colors"],
+        ):
+            ax.bar(
+                test_name,
+                test_metric,
+                label=test_name,
+                width=width_bar,
+                linewidth=1,
+                edgecolor="k",
+                color=test_color,
+                zorder=2,
+            )
+            # plot whether the tests pass or fail
+            if test_name != "Model":
+                if test_metric >= 0:
+                    offset_txt = test_metric + (0.05 * range_vals)
+                else:
+                    offset_txt = test_metric - (0.05 * range_vals)
+                if test_color == "#1f77b4":
+                    txt_bar = "pass"
+                elif test_color == "#cd5c5c":
+                    txt_bar = "fail"
+                elif test_color == "#c5c57d":
+                    txt_bar = "unclear"
+                ax.text(
+                    label_count,
+                    offset_txt,
+                    txt_bar,
+                    color=test_color,
+                    fontstyle="italic",
+                    horizontalalignment="center",
+                )
+            label_count += 1
+
+        # Set tick sizes
+        ax.tick_params(axis="x", labelsize=14)
+        ax.tick_params(axis="y", labelsize=14)
+
+        # title and labels of the axis
+        ax.set_ylabel(f"{verify_results['error_type'].upper()}", fontsize=14)
+
+        ax.text(
+            0.5,
+            1.08,
+            f"VERIFY tests of {os.path.basename(path_n_suffix)}",
+            horizontalalignment="center",
+            fontsize=14,
+            fontweight="bold",
+            transform=ax.transAxes,
+        )
 
-    # invisible "dummy" arrows to make the graph wider so the real arrows fit in the right place
-    ax.arrow(x_arrow, thres_line, 0, 0, width=0, fc='k', ec='k') # x,y,dx,dy format
+        # add threshold line and arrow indicating passed test direction
+        arrow_length = np.abs(max_lim - min_lim) / 11
 
-    # legend and regression line with 95% CI considering all possible lines (not CI of the points)
-    def make_legend_arrow(legend, orig_handle,
-                        xdescent, ydescent,
-                        width, height, fontsize):
-        p = mpatches.FancyArrow(0, 0.5*height, width, 0, width=1.5, length_includes_head=True, head_width=0.58*height )
-        return p
+        if verify_results["error_type"].lower() in ["mae", "rmse"]:
+            thres_line = verify_metrics["higher_thres"]
+            unclear_thres_line = verify_metrics["unclear_higher_thres"]
+        else:
+            thres_line = verify_metrics["lower_thres"]
+            unclear_thres_line = verify_metrics["unclear_lower_thres"]
+            arrow_length = -arrow_length
+
+        width = 2
+        xmin = 0.237
+        thres = ax.axhline(
+            thres_line, xmin=xmin, color="black", ls="--", label="thres", zorder=0
+        )
+        thres = ax.axhline(
+            unclear_thres_line,
+            xmin=xmin,
+            color="black",
+            ls="--",
+            label="thres",
+            zorder=0,
+        )
 
-    arrow = plt.arrow(0, 0, 0, 0, label='arrow', width=0, fc='k', ec='k') # arrow for the legend
-    plt.figlegend([thres,arrow], [f'Limits: {thres_line:.2} (pass), {unclear_thres_line:.2} (unclear)','Pass test'], handler_map={mpatches.FancyArrow : HandlerPatch(patch_func=make_legend_arrow),},
-                    loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.05),
-                    fancybox=True, shadow=True, fontsize=14)
+        x_arrow = 0.5
+        style = mpatches.ArrowStyle(
+            "simple", head_length=4.5 * width, head_width=3.5 * width, tail_width=width
+        )
+        arrow = mpatches.FancyArrowPatch(
+            (x_arrow, thres_line),
+            (x_arrow, thres_line + arrow_length),
+            arrowstyle=style,
+            color="k",
+        )
+        ax.add_patch(arrow)
+
+        # invisible "dummy" arrows to make the graph wider so the real arrows fit in the right place
+        ax.arrow(x_arrow, thres_line, 0, 0, width=0, fc="k", ec="k")
+
+        def make_legend_arrow(
+            legend, orig_handle, xdescent, ydescent, width, height, fontsize
+        ):
+            p = mpatches.FancyArrow(
+                0,
+                0.5 * height,
+                width,
+                0,
+                width=1.5,
+                length_includes_head=True,
+                head_width=0.58 * height,
+            )
+            return p
+
+        arrow_legend = plt.arrow(0, 0, 0, 0, label="arrow", width=0, fc="k", ec="k")
+        fig.legend(
+            [thres, arrow_legend],
+            [
+                f"Limits: {thres_line:.2} (pass), {unclear_thres_line:.2} (unclear)",
+                "Pass test",
+            ],
+            handler_map={
+                mpatches.FancyArrow: HandlerPatch(patch_func=make_legend_arrow),
+            },
+            loc="lower center",
+            ncol=2,
+            bbox_to_anchor=(0.5, -0.05),
+            fancybox=True,
+            shadow=True,
+            fontsize=14,
+        )
 
-    # Add gridlines
-    ax.grid(linestyle='--', linewidth=1)
+        # Add gridlines
+        ax.grid(linestyle="--", linewidth=1)
 
-    # save plot
-    verify_plot_file = f'{os.path.dirname(path_n_suffix)}/VERIFY_tests_{os.path.basename(path_n_suffix)}.png'
-    plt.savefig(verify_plot_file, dpi=300, bbox_inches='tight')
+        # save plot
+        verify_plot_file = f"{os.path.dirname(path_n_suffix)}/VERIFY_tests_{os.path.basename(path_n_suffix)}.png"
+        plt.savefig(verify_plot_file, dpi=300, bbox_inches="tight")
+        plt.close(fig)
 
-    path_reduced = '/'.join(f'{verify_plot_file}'.replace('\\','/').split('/')[-2:])
+    path_reduced = "/".join(f"{verify_plot_file}".replace("\\", "/").split("/")[-2:])
     print_ver = f"\n   o  VERIFY plot saved in {path_reduced}"
 
     return print_ver
 
 
 def dict_formating(dict_csv):
-    '''
+    """
     Adapt format of dictionaries that come from dataframes loaded from CSV
-    '''
-    
+    """
+
     import json
 
-    if 'X_descriptors' in dict_csv:
+    if "X_descriptors" in dict_csv:
         # Try JSON first (new format), fall back to ast.literal_eval (old format)
         try:
-            dict_csv['X_descriptors'] = json.loads(dict_csv['X_descriptors'])
+            dict_csv["X_descriptors"] = json.loads(dict_csv["X_descriptors"])
         except (json.JSONDecodeError, TypeError):
-            dict_csv['X_descriptors'] = ast.literal_eval(dict_csv['X_descriptors'])
-            
-    if 'params' in dict_csv:
+            dict_csv["X_descriptors"] = ast.literal_eval(dict_csv["X_descriptors"])
+
+    if "params" in dict_csv:
         # Try JSON first (new format), fall back to ast.literal_eval (old format)
         try:
-            dict_csv['params'] = json.loads(dict_csv['params'])
+            dict_csv["params"] = json.loads(dict_csv["params"])
         except (json.JSONDecodeError, TypeError):
-            dict_csv['params'] = ast.literal_eval(dict_csv['params'])
+            dict_csv["params"] = ast.literal_eval(dict_csv["params"])
 
-    return dict_csv
\ No newline at end of file
+    return dict_csv
diff --git a/robert/verify.py b/robert/verify.py
index 9ee9440..4c33e07 100644
--- a/robert/verify.py
+++ b/robert/verify.py
@@ -5,13 +5,13 @@
     destination : str, default=None,
         Directory to create the output file(s).
     varfile : str, default=None
-        Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).  
+        Option to parse the variables using a yaml file (specify the filename, i.e. varfile=FILE.yaml).
     params_dir : str, default=''
         Folder containing the database and parameters of the ML model to analyze.
     seed : int, default=0
         Random seed used in the ML predictor models and other protocols.
     kfold : int, default=5
-        Number of random data splits for the cross-validation of the models. 
+        Number of random data splits for the cross-validation of the models.
     repeat_kfolds : int, default=10
         Number of repetitions for the k-fold cross-validation of the models.
 
@@ -42,6 +42,7 @@
 thres_test_pass = 0.3
 thres_test_unclear = 0.15
 
+
 class verify:
     """
     Class containing all the functions from the VERIFY module.
@@ -53,7 +54,6 @@ class verify:
     """
 
     def __init__(self, **kwargs):
-
         start_time = time.time()
 
         # load default and user-specified variables
@@ -63,194 +63,247 @@ def __init__(self, **kwargs):
             self.args.params_dir
         ):
             if os.path.exists(params_dir):
-
-                _ = print_pfi(self,params_dir)
+                _ = print_pfi(self, params_dir)
 
                 # load the Xy databse and model parameters
-                Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",True)
-                
+                Xy_data, model_data, suffix_title = load_db_n_params(
+                    self, params_dir, suffix, suffix_title, "verify", True
+                )
+
                 # this dictionary will keep the results of the tests
-                verify_results = {'error_type': model_data['error_type']}
+                verify_results = {"error_type": model_data["error_type"]}
 
                 # get data about repeated and sorted CVs
-                Xy_data = load_n_predict(self, model_data, Xy_data, BO_opt=True, verify_job=True)
-                verify_results['CV_score'] = Xy_data[f'{verify_results["error_type"]}_train']
-                verify_results['sorted_CV_score'] = Xy_data[f'{model_data["error_type"]}_train_sorted_CV']
-                if model_data['type'].lower() == 'reg':
-                    verify_results[f'r2_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'r2_train_sorted_CV']]
-                    verify_results[f'mae_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'mae_train_sorted_CV']]
-                    verify_results[f'rmse_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'rmse_train_sorted_CV']]
-                elif model_data['type'].lower() == 'clas':
-                    verify_results[f'acc_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'acc_train_sorted_CV']]
-                    verify_results[f'f1_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'f1_train_sorted_CV']]
-                    verify_results[f'mcc_train_sorted_CV'] = [float(f"{val:.2f}") for val in Xy_data[f'mcc_train_sorted_CV']]
+                Xy_data = load_n_predict(
+                    self, model_data, Xy_data, BO_opt=True, verify_job=True
+                )
+                verify_results["CV_score"] = Xy_data[
+                    f"{verify_results['error_type']}_train"
+                ]
+                verify_results["sorted_CV_score"] = Xy_data[
+                    f"{model_data['error_type']}_train_sorted_CV"
+                ]
+                if model_data["type"].lower() == "reg":
+                    verify_results["r2_train_sorted_CV"] = [
+                        float(f"{val:.2f}") for val in Xy_data["r2_train_sorted_CV"]
+                    ]
+                    verify_results["mae_train_sorted_CV"] = [
+                        float(f"{val:.2f}") for val in Xy_data["mae_train_sorted_CV"]
+                    ]
+                    verify_results["rmse_train_sorted_CV"] = [
+                        float(f"{val:.2f}") for val in Xy_data["rmse_train_sorted_CV"]
+                    ]
+                elif model_data["type"].lower() == "clas":
+                    verify_results["acc_train_sorted_CV"] = [
+                        float(f"{val:.2f}") for val in Xy_data["acc_train_sorted_CV"]
+                    ]
+                    verify_results["f1_train_sorted_CV"] = [
+                        float(f"{val:.2f}") for val in Xy_data["f1_train_sorted_CV"]
+                    ]
+                    verify_results["mcc_train_sorted_CV"] = [
+                        float(f"{val:.2f}") for val in Xy_data["mcc_train_sorted_CV"]
+                    ]
 
                 # Reload once for flawed-model tests (fresh splits consistent with CSV on disk).
-                Xy_data, model_data, suffix_title = load_db_n_params(self,params_dir,suffix,suffix_title,"verify",False)
+                Xy_data, model_data, suffix_title = load_db_n_params(
+                    self, params_dir, suffix, suffix_title, "verify", False
+                )
 
                 # calculate scores for the y-mean test
-                verify_results = self.ymean_test(verify_results,Xy_data,model_data)
+                verify_results = self.ymean_test(verify_results, Xy_data, model_data)
 
                 # calculate scores for the y-shuffle test
-                verify_results = self.yshuffle_test(verify_results,Xy_data,model_data)
+                verify_results = self.yshuffle_test(verify_results, Xy_data, model_data)
 
                 # one-hot test (check that if a value isnt 0, the value assigned is 1)
-                verify_results = self.onehot_test(verify_results,Xy_data,model_data)
+                verify_results = self.onehot_test(verify_results, Xy_data, model_data)
 
                 # analysis of results
-                results_print,verify_results,verify_metrics = self.analyze_tests(verify_results)
+                results_print, verify_results, verify_metrics = self.analyze_tests(
+                    verify_results
+                )
 
                 # plot a bar graph with the results
                 if should_plot_verify_metrics(self.args):
-                    print_ver = plot_metrics(model_data,suffix_title,verify_metrics,verify_results)
+                    print_ver = plot_metrics(
+                        model_data, suffix_title, verify_metrics, verify_results
+                    )
                 else:
                     print_ver = "\n   o  VERIFY plot skipped (plot_verbosity)"
 
                 # print and save results
-                _ = self.print_verify(results_print,verify_results,print_ver,model_data)
-
-        _ = finish_print(self,start_time,'VERIFY')
+                _ = self.print_verify(
+                    results_print, verify_results, print_ver, model_data
+                )
 
+        _ = finish_print(self, start_time, "VERIFY")
 
-    def ymean_test(self,verify_results,Xy_data,model_data):
-        '''
-        Calculate the accuracy of the model when using a flat line of predicted y values. For 
+    def ymean_test(self, verify_results, Xy_data, model_data):
+        """
+        Calculate the accuracy of the model when using a flat line of predicted y values. For
         regression, the mean of the y values is used. For classification, the value that is
         predicted more often is used.
-        '''
+        """
 
-        Xy_ymean = Xy_data.copy()   
-        if model_data['type'].lower() == 'reg':
-            y_mean_array = np.ones(len(Xy_ymean['y_train']))*(Xy_ymean['y_train'].mean())
-            Xy_ymean['r2_train'], Xy_ymean['mae_train'], Xy_ymean['rmse_train'] = get_prediction_results(model_data,Xy_ymean['y_train'],y_mean_array)
-        
-        elif model_data['type'].lower() == 'clas':
-            y_mean_array = np.ones(len(Xy_ymean['y_train']))*mode(Xy_ymean['y_train'])
-            Xy_ymean['acc_train'], Xy_ymean['f1_train'], Xy_ymean['mcc_train'] = get_prediction_results(model_data,Xy_ymean['y_train'],y_mean_array)
+        Xy_ymean = Xy_data.copy()
+        if model_data["type"].lower() == "reg":
+            y_mean_array = np.ones(len(Xy_ymean["y_train"])) * (
+                Xy_ymean["y_train"].mean()
+            )
+            Xy_ymean["r2_train"], Xy_ymean["mae_train"], Xy_ymean["rmse_train"] = (
+                get_prediction_results(model_data, Xy_ymean["y_train"], y_mean_array)
+            )
 
-        verify_results['y_mean'] = Xy_ymean[f'{verify_results["error_type"]}_train']
+        elif model_data["type"].lower() == "clas":
+            y_mean_array = np.ones(len(Xy_ymean["y_train"])) * mode(Xy_ymean["y_train"])
+            Xy_ymean["acc_train"], Xy_ymean["f1_train"], Xy_ymean["mcc_train"] = (
+                get_prediction_results(model_data, Xy_ymean["y_train"], y_mean_array)
+            )
 
-        return verify_results
+        verify_results["y_mean"] = Xy_ymean[f"{verify_results['error_type']}_train"]
 
+        return verify_results
 
-    def yshuffle_test(self,verify_results,Xy_data,model_data):
-        '''
+    def yshuffle_test(self, verify_results, Xy_data, model_data):
+        """
         Calculate the accuracy of the model when the y values are randomly shuffled in the validation set
         For example, a y array of 1.3, 2.1, 4.0, 5.2 might become 2.1, 1.3, 5.2, 4.0.
-        '''
+        """
 
         Xy_yshuffle = Xy_data.copy()
-        Xy_yshuffle['y_train'] = Xy_yshuffle['y_train'].sample(frac=1,random_state=model_data['seed'],axis=0)
+        Xy_yshuffle["y_train"] = Xy_yshuffle["y_train"].sample(
+            frac=1, random_state=model_data["seed"], axis=0
+        )
         Xy_yshuffle = load_n_predict(self, model_data, Xy_yshuffle, BO_opt=False)
 
-        verify_results['y_shuffle'] = Xy_yshuffle[f'{verify_results["error_type"]}_train']
+        verify_results["y_shuffle"] = Xy_yshuffle[
+            f"{verify_results['error_type']}_train"
+        ]
 
         return verify_results
 
-
-    def onehot_test(self,verify_results,Xy_data,model_data):
-        '''
+    def onehot_test(self, verify_results, Xy_data, model_data):
+        """
         Calculate the accuracy of the model when using one-hot models. All X values that are
         not 0 are considered to be 1 (NaN from missing values are converted to 0).
-        '''
+        """
 
         Xy_onehot = Xy_data.copy()
-        Xy_onehot['X_train_scaled'] = Xy_onehot['X_train_scaled'].copy()
-        for desc in Xy_onehot['X_train']:
+        Xy_onehot["X_train_scaled"] = Xy_onehot["X_train_scaled"].copy()
+        for desc in Xy_onehot["X_train"]:
             new_vals = []
-            for val in Xy_onehot['X_train'][desc]:
+            for val in Xy_onehot["X_train"][desc]:
                 if val == 0:
                     new_vals.append(0)
                 else:
                     new_vals.append(1)
-            Xy_onehot['X_train_scaled'][desc] = new_vals
+            Xy_onehot["X_train_scaled"][desc] = new_vals
 
         Xy_onehot = load_n_predict(self, model_data, Xy_onehot, BO_opt=False)
-        verify_results['onehot'] = Xy_onehot[f'{verify_results["error_type"]}_train']
+        verify_results["onehot"] = Xy_onehot[f"{verify_results['error_type']}_train"]
         return verify_results
 
-
-    def analyze_tests(self,verify_results):
-        '''
+    def analyze_tests(self, verify_results):
+        """
         Function to check whether the tests pass and retrieve the corresponding colors:
         1. Blue for passing tests
         2. Red for failing tests
-        '''
+        """
 
-        blue_color = '#1f77b4'
-        red_color = '#cd5c5c'
-        yellow_color = '#c5c57d'
-        colors = [None,None,None]
-        results_print = [None,None,None]
-        metrics = [None,None,None]
+        blue_color = "#1f77b4"
+        red_color = "#cd5c5c"
+        yellow_color = "#c5c57d"
+        colors = [None, None, None]
+        results_print = [None, None, None]
+        metrics = [None, None, None]
 
         # the threshold uses validation results to compare in the tests
-        verify_results['higher_thres'] = (1+thres_test_pass)*verify_results['CV_score']
-        verify_results['unclear_higher_thres'] = (1+thres_test_unclear)*verify_results['CV_score']
-        verify_results['lower_thres'] = (1-thres_test_pass)*verify_results['CV_score']
-        verify_results['unclear_lower_thres'] = (1-thres_test_unclear)*verify_results['CV_score']
+        verify_results["higher_thres"] = (1 + thres_test_pass) * verify_results[
+            "CV_score"
+        ]
+        verify_results["unclear_higher_thres"] = (
+            1 + thres_test_unclear
+        ) * verify_results["CV_score"]
+        verify_results["lower_thres"] = (1 - thres_test_pass) * verify_results[
+            "CV_score"
+        ]
+        verify_results["unclear_lower_thres"] = (
+            1 - thres_test_unclear
+        ) * verify_results["CV_score"]
 
         # determine whether the tests pass
-        test_names = ['y_mean','y_shuffle','onehot']
-        for i,test_ver in enumerate(test_names):
+        test_names = ["y_mean", "y_shuffle", "onehot"]
+        for i, test_ver in enumerate(test_names):
             metrics[i] = verify_results[test_ver]
-            if verify_results['error_type'].lower() in ['mae','rmse']:
-                if verify_results[test_ver] <= verify_results['unclear_higher_thres']:
-                        colors[i] = red_color
-                        results_print[i] = f'\n         x {test_ver}: FAILED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, lower than threshold'
-                elif verify_results[test_ver] <= verify_results['higher_thres']:
-                        colors[i] = yellow_color
-                        results_print[i] = f'\n         - {test_ver}: UNCLEAR, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, higher than original, but close to fail'
+            if verify_results["error_type"].lower() in ["mae", "rmse"]:
+                if verify_results[test_ver] <= verify_results["unclear_higher_thres"]:
+                    colors[i] = red_color
+                    results_print[i] = (
+                        f"\n         x {test_ver}: FAILED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, lower than threshold"
+                    )
+                elif verify_results[test_ver] <= verify_results["higher_thres"]:
+                    colors[i] = yellow_color
+                    results_print[i] = (
+                        f"\n         - {test_ver}: UNCLEAR, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, higher than original, but close to fail"
+                    )
                 else:
-                        colors[i] = blue_color
-                        results_print[i] = f'\n         o {test_ver}: PASSED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, higher than thresholds'
+                    colors[i] = blue_color
+                    results_print[i] = (
+                        f"\n         o {test_ver}: PASSED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, higher than thresholds"
+                    )
 
             else:
-                if verify_results[test_ver] >= verify_results['unclear_lower_thres']:
-                        colors[i] = red_color
-                        results_print[i] = f'\n         x {test_ver}: FAILED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, higher than thresholds'
-                elif verify_results[test_ver] >= verify_results['lower_thres']:
-                        colors[i] = yellow_color
-                        results_print[i] = f'\n         - {test_ver}: UNCLEAR, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, lower than original, but close to fail'
+                if verify_results[test_ver] >= verify_results["unclear_lower_thres"]:
+                    colors[i] = red_color
+                    results_print[i] = (
+                        f"\n         x {test_ver}: FAILED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, higher than thresholds"
+                    )
+                elif verify_results[test_ver] >= verify_results["lower_thres"]:
+                    colors[i] = yellow_color
+                    results_print[i] = (
+                        f"\n         - {test_ver}: UNCLEAR, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, lower than original, but close to fail"
+                    )
                 else:
-                        colors[i] = blue_color
-                        results_print[i] = f'\n         o {test_ver}: PASSED, {verify_results["error_type"].upper()} = {verify_results[test_ver]:.2}, lower than thresholds'
+                    colors[i] = blue_color
+                    results_print[i] = (
+                        f"\n         o {test_ver}: PASSED, {verify_results['error_type'].upper()} = {verify_results[test_ver]:.2}, lower than thresholds"
+                    )
 
-        # store metrics and colors to represent in comparison graph, adding the metrics of the 
+        # store metrics and colors to represent in comparison graph, adding the metrics of the
         # original model first
-        test_names = ['Model'] + test_names
+        test_names = ["Model"] + test_names
         colors = [blue_color] + colors
-        metrics = [verify_results['CV_score']] + metrics
-        verify_metrics = {'test_names': test_names,
-                          'colors': colors,
-                          'metrics': metrics,
-                          'higher_thres': verify_results['higher_thres'],
-                          'lower_thres': verify_results['lower_thres'],
-                          'unclear_higher_thres': verify_results['unclear_higher_thres'],
-                          'unclear_lower_thres': verify_results['unclear_lower_thres'],
-                          }        
-        
-        return results_print,verify_results,verify_metrics
-
-
-    def print_verify(self,results_print,verify_results,print_ver,model_data):
-        '''
+        metrics = [verify_results["CV_score"]] + metrics
+        verify_metrics = {
+            "test_names": test_names,
+            "colors": colors,
+            "metrics": metrics,
+            "higher_thres": verify_results["higher_thres"],
+            "lower_thres": verify_results["lower_thres"],
+            "unclear_higher_thres": verify_results["unclear_higher_thres"],
+            "unclear_lower_thres": verify_results["unclear_lower_thres"],
+        }
+
+        return results_print, verify_results, verify_metrics
+
+    def print_verify(self, results_print, verify_results, print_ver, model_data):
+        """
         Print and store the results of VERIFY
-        '''
+        """
 
-        print_ver += f'\n      Results of flawed models and sorted cross-validation:'
+        print_ver += "\n      Results of flawed models and sorted cross-validation:"
         CV_type = f"{model_data['repeat_kfolds']}x {model_data['kfold']}-fold CV"
         # the printing order should be y-mean, y-shuffle and one-hot
-        if verify_results['error_type'].lower() in ['mae','rmse']:
-            print_ver += f'\n      Original {verify_results["error_type"].upper()} ({CV_type}) {verify_results["CV_score"]:.2} + {int(thres_test_unclear*100)}% & {int(thres_test_pass*100)}% threshold = {verify_results["unclear_higher_thres"]:.2} & {verify_results["higher_thres"]:.2}'
+        if verify_results["error_type"].lower() in ["mae", "rmse"]:
+            print_ver += f"\n      Original {verify_results['error_type'].upper()} ({CV_type}) {verify_results['CV_score']:.2} + {int(thres_test_unclear * 100)}% & {int(thres_test_pass * 100)}% threshold = {verify_results['unclear_higher_thres']:.2} & {verify_results['higher_thres']:.2}"
         else:
-            print_ver += f'\n      Original {verify_results["error_type"].upper()} ({CV_type}) {verify_results["CV_score"]:.2} - {int(thres_test_unclear*100)}% & {int(thres_test_pass*100)}% threshold = {verify_results["unclear_lower_thres"]:.2} & {verify_results["lower_thres"]:.2}'
+            print_ver += f"\n      Original {verify_results['error_type'].upper()} ({CV_type}) {verify_results['CV_score']:.2} - {int(thres_test_unclear * 100)}% & {int(thres_test_pass * 100)}% threshold = {verify_results['unclear_lower_thres']:.2} & {verify_results['lower_thres']:.2}"
         print_ver += results_print[0]
         print_ver += results_print[1]
         print_ver += results_print[2]
-        if model_data['type'].lower() == 'reg':
+        if model_data["type"].lower() == "reg":
             print_ver += f"\n         - Sorted {model_data['kfold']}-fold CV : R2 = {verify_results['r2_train_sorted_CV']}, MAE = {verify_results['mae_train_sorted_CV']}, RMSE = {verify_results['rmse_train_sorted_CV']}"
-        elif model_data['type'].lower() == 'clas':
+        elif model_data["type"].lower() == "clas":
             print_ver += f"\n         - Sorted CV : Accuracy = {verify_results['acc_train_sorted_CV']}, F1 score = {verify_results['f1_train_sorted_CV']}, MCC = {verify_results['mcc_train_sorted_CV']}"
 
         self.args.log.write(print_ver)
diff --git a/setup.py b/setup.py
index ecfaff1..4a8f3fd 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-version = "2.1.0"
+version = "2.2.0"
 
 setup(
     name="robert",
@@ -34,11 +34,11 @@
     url="https://github.com/jvalegre/robert",
     download_url=f"https://github.com/jvalegre/robert/archive/refs/tags/{version}.tar.gz",
     classifiers=[
-        "Development Status :: 5 - Production/Stable", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
-        "Intended Audience :: Developers", # Define that your audience are developers
+        "Development Status :: 5 - Production/Stable",  # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
+        "Intended Audience :: Developers",  # Define that your audience are developers
         "Topic :: Software Development :: Build Tools",
         "License :: OSI Approved :: MIT License",
-        "Programming Language :: Python :: 3.11", # Specify which python versions you want to support
+        "Programming Language :: Python :: 3.11",  # Specify which python versions you want to support
         "Programming Language :: Python :: 3.12",
         "Programming Language :: Python :: 3.13",
         "Programming Language :: Python :: 3.14",
@@ -68,9 +68,16 @@
     ],
     python_requires=">=3.11",
     include_package_data=True,
+    extras_require={
+        "test": [
+            "pytest>=7.0",
+            "pytest-cov>=4.0",
+            "pytest-qt>=4.0",
+        ],
+    },
     entry_points={
         "console_scripts": [
             "easyrob=robert.gui_easyrob.easyrob_launcher:main",
         ],
     },
-)
\ No newline at end of file
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..7c61582
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""ROBERT test package (enables shared imports between test modules)."""
diff --git a/tests/conftest.py b/tests/conftest.py
index 34aeb51..1414922 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,28 @@
 """Pytest configuration for the ROBERT test suite."""
 
+from __future__ import annotations
+
 import os
 import sys
+from contextlib import contextmanager
+from pathlib import Path
 
 import pytest
 
+REPO_ROOT = Path(__file__).resolve().parent.parent
+
+# Module output folders legacy integration tests swap under the repo root.
+ROBERT_MODULE_DIR_NAMES = (
+    "CURATE",
+    "GENERATE",
+    "GENERATE_reg",
+    "GENERATE_clas",
+    "PREDICT",
+    "VERIFY",
+    "AQME",
+    "EVALUATE",
+)
+
 
 def pytest_configure(config):
     """
@@ -28,6 +46,101 @@ def pytest_configure(config):
     os.environ.setdefault("MPLBACKEND", "Agg")
 
 
+def robert_module_dirs(root: Path) -> set[str]:
+    """Names of ROBERT module directories present under ``root``."""
+    return {name for name in ROBERT_MODULE_DIR_NAMES if (root / name).is_dir()}
+
+
+def aqme_installed() -> bool:
+    """Return True when the optional AQME package is importable."""
+    try:
+        import aqme  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
+def restore_regression_generate_layout(root: Path) -> None:
+    """
+    Undo a half-finished clas/reg GENERATE rename left by a failed test.
+
+    After a successful clas test: ``GENERATE`` (reg) + ``GENERATE_clas``.
+    Mid-clas failure may leave: ``GENERATE_reg`` + ``GENERATE`` (clas).
+    """
+    reg_backup = root / "GENERATE_reg"
+    generate = root / "GENERATE"
+    clas = root / "GENERATE_clas"
+    if reg_backup.is_dir() and generate.is_dir():
+        generate.rename(clas)
+        reg_backup.rename(generate)
+    elif reg_backup.is_dir() and not generate.is_dir():
+        reg_backup.rename(generate)
+
+
+@contextmanager
+def clas_generate_layout(root: Path):
+    """
+    Temporarily point ``GENERATE`` at the classification screening outputs.
+
+    Requires ``GENERATE`` (regression) and ``GENERATE_clas`` from
+    ``test_2generate`` (e.g. ``reduced_clas``).
+    """
+    restore_regression_generate_layout(root)
+    generate = root / "GENERATE"
+    clas = root / "GENERATE_clas"
+    if not clas.is_dir():
+        pytest.skip(
+            "GENERATE_clas missing under repo root; run "
+            "tests/test_2generate.py::test_GENERATE[reduced_clas] first."
+        )
+    if not generate.is_dir():
+        pytest.skip(
+            "GENERATE missing under repo root; run GENERATE integration tests first."
+        )
+    generate.rename(root / "GENERATE_reg")
+    clas.rename(generate)
+    try:
+        yield
+    finally:
+        generate.rename(clas)
+        (root / "GENERATE_reg").rename(generate)
+
+
+@pytest.fixture
+def repo_root() -> Path:
+    """Repository root (stable even if the process cwd changes during a test)."""
+    return REPO_ROOT
+
+
+@pytest.fixture(autouse=True)
+def restore_process_cwd():
+    """Restore the process working directory after each test."""
+    cwd_before = os.getcwd()
+    yield
+    try:
+        os.chdir(cwd_before)
+    except OSError:
+        pass
+
+
+@pytest.fixture(autouse=True)
+def drain_qt_thread_pool_after_test():
+    """Avoid 'QThread destroyed while still running' abort on pytest exit."""
+    yield
+    try:
+        from PySide6.QtCore import QCoreApplication, QThreadPool
+    except ImportError:
+        return
+    app = QCoreApplication.instance()
+    if app is None:
+        return
+    pool = QThreadPool.globalInstance()
+    if pool is not None:
+        pool.waitForDone(10_000)
+    app.processEvents()
+
+
 @pytest.fixture
 def fast_robert_kwargs():
     """Reduced CV/BO settings for faster integration tests."""
diff --git a/tests/test_2generate.py b/tests/test_2generate.py
index 4c91a9f..deb847f 100644
--- a/tests/test_2generate.py
+++ b/tests/test_2generate.py
@@ -75,9 +75,7 @@ def _log_line_metric_close(line, prefix, expected, *, rel_tol=0.05, abs_tol=0.02
         (
             "reduced_adab"
         ),  # test for other GP model (important since PFI filter tries to discard all the descriptors)
-        (
-            "reduced_xgb"
-        ),  # test for XGB model (tree booster with feature importances)
+        ("reduced_xgb"),  # test for XGB model (tree booster with feature importances)
         ("reduced_vr"),  # test for Voting Regressor model
         ("reduced_vr_clas"),  # test Voting Classifier workflow
         ("reduced_clas"),  # test for clasification models
@@ -126,7 +124,13 @@ def test_GENERATE(test_job):
         generate_kwargs = {"generate": True, "csv_name": csv_name, "y": "Target_values"}
         if test_job != "standard":
             # add model
-            if test_job not in ["reduced_gp", "reduced_adab", "reduced_xgb", "reduced_vr", "reduced_vr_clas"]:
+            if test_job not in [
+                "reduced_gp",
+                "reduced_adab",
+                "reduced_xgb",
+                "reduced_vr",
+                "reduced_vr_clas",
+            ]:
                 generate_kwargs["model"] = ["RF"]
             elif test_job == "reduced_gp":
                 generate_kwargs["model"] = ["GP"]
@@ -465,7 +469,9 @@ def test_GENERATE(test_job):
         if test_job in ["reduced_clas", "reduced_vr_clas"]:
             model_name = "RF" if test_job == "reduced_clas" else "VR"
             csv_clas = glob.glob(
-                os.path.join(path_generate, "Best_model", "PFI", f"{model_name}_PFI.csv")
+                os.path.join(
+                    path_generate, "Best_model", "PFI", f"{model_name}_PFI.csv"
+                )
             )
             df = pd.read_csv(csv_clas[0])
             if "error_type" in df.columns:
diff --git a/tests/test_3verify.py b/tests/test_3verify.py
index 3640273..96828dd 100644
--- a/tests/test_3verify.py
+++ b/tests/test_3verify.py
@@ -4,18 +4,18 @@
 # 	        Testing VERIFY with pytest 	             #
 ######################################################.
 
-import os
+import subprocess
 import sys
-import glob
+
 import pytest
 import shutil
-import subprocess
-from pathlib import Path
+
 from robert.verify import verify
 
-# saves the working directory
-path_main = os.getcwd()
-path_verify = os.path.join(path_main, "VERIFY")
+from tests.conftest import (
+    clas_generate_layout,
+    restore_regression_generate_layout,
+)
 
 
 # VERIFY tests
@@ -27,60 +27,39 @@
         ("standard_cmd"),  # standard test with command line
     ],
 )
-def test_VERIFY(test_job):
-    # leave the folders as they were initially to run a different batch of tests
-    if os.path.exists(path_verify):
+def test_VERIFY(test_job, repo_root, monkeypatch):
+    monkeypatch.chdir(repo_root)
+    path_verify = repo_root / "VERIFY"
+
+    if path_verify.is_dir():
         shutil.rmtree(path_verify)
-        # remove DAT and CSV files generated by VERIFY
-        dat_files = glob.glob("*.dat")
-        for dat_file in dat_files:
-            if "VERIFY" in dat_file:
-                os.remove(dat_file)
+    for dat_file in repo_root.glob("*.dat"):
+        if "VERIFY" in dat_file.name:
+            dat_file.unlink()
 
-    if test_job == "clas":  # rename folders to use in classification
-        # rename the regression GENERATE folder
-        filepath_reg = Path(path_main) / "GENERATE"
-        filepath_reg.rename(Path(path_main) / "GENERATE_reg")
-        # rename the classification GENERATE folder
-        filepath = Path(path_main) / "GENERATE_clas"
-        filepath.rename(Path(path_main) / "GENERATE")
+    if test_job == "clas":
+        with clas_generate_layout(repo_root):
+            _run_verify(test_job, repo_root, path_verify)
+    else:
+        restore_regression_generate_layout(repo_root)
+        _run_verify(test_job, repo_root, path_verify)
 
-    else:  # in case the clas test fails and the ending rename doesn't happen
-        if os.path.exists(Path(path_main) / "GENERATE_reg"):
-            # rename the classification GENERATE folder
-            filepath = Path(path_main) / "GENERATE"
-            filepath.rename(Path(path_main) / "GENERATE_clas")
-            # rename the regression GENERATE folder
-            filepath_reg = Path(path_main) / "GENERATE_reg"
-            filepath_reg.rename(Path(path_main) / "GENERATE")
 
-    # runs the program with the different tests
+def _run_verify(test_job, repo_root, path_verify):
     if test_job == "standard_cmd":
-        cmd_robert = [
-            sys.executable,
-            "-m",
-            "robert",
-            "--verify",
-        ]
-
-        subprocess.run(cmd_robert)
-
+        cmd_robert = [sys.executable, "-m", "robert", "--verify"]
+        subprocess.run(cmd_robert, cwd=repo_root, check=False)
     else:
-        verify_kwargs = {}
+        verify()
 
-        verify(**verify_kwargs)
-
-    # check that the DAT file is created
-    assert not os.path.exists(os.path.join(path_main, "VERIFY_data.dat"))
-    outfile = open(os.path.join(path_verify, "VERIFY_data.dat"), "r")
-    outlines = outfile.readlines()
-    outfile.close()
+    assert not (repo_root / "VERIFY_data.dat").is_file()
+    verify_dat = path_verify / "VERIFY_data.dat"
+    with verify_dat.open(encoding="utf-8") as outfile:
+        outlines = outfile.readlines()
     assert "ROBERT v" in outlines[0]
     results_line, start_reading = False, False
     for i, line in enumerate(outlines):
-        if (
-            "------- Starting model with PFI filter " in line
-        ):  # focus on the PFI since there is an unclear test
+        if "------- Starting model with PFI filter " in line:
             start_reading = True
         if start_reading:
             if "Results of flawed models and sorted cross-validation:" in line:
@@ -121,14 +100,5 @@ def test_VERIFY(test_job):
                 break
     assert results_line
 
-    # check that the verify plots and DAT file are created
-    assert len(glob.glob(os.path.join(path_verify, "*.png"))) == 2
-    assert len(glob.glob(os.path.join(path_verify, "*.dat"))) == 1
-
-    if test_job == "clas":  # rename folders back to their original names
-        # rename the classification GENERATE folder
-        filepath = Path(path_main) / "GENERATE"
-        filepath.rename(Path(path_main) / "GENERATE_clas")
-        # rename the regression GENERATE folder
-        filepath_reg = Path(path_main) / "GENERATE_reg"
-        filepath_reg.rename(Path(path_main) / "GENERATE")
+    assert len(list(path_verify.glob("*.png"))) == 2
+    assert len(list(path_verify.glob("*.dat"))) == 1
diff --git a/tests/test_5aqme_n_full.py b/tests/test_5aqme_n_full.py
index 4d05448..1c1905b 100644
--- a/tests/test_5aqme_n_full.py
+++ b/tests/test_5aqme_n_full.py
@@ -17,6 +17,15 @@
 path_aqme = os.path.join(path_main, "AQME")
 
 
+def _aqme_installed() -> bool:
+    try:
+        import aqme  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 # AQME and full workflow tests
 @pytest.mark.parametrize(
     "test_job",
@@ -30,6 +39,9 @@
     ],
 )
 def test_AQME(test_job):
+    if test_job in ("aqme", "2smiles_columns") and not _aqme_installed():
+        pytest.skip("AQME is not installed (pip install aqme==2.0.0)")
+
     # reset the folders (to avoid interferences with previous failed tests)
     folders = [
         "CURATE",
diff --git a/tests/test_7api.py b/tests/test_7api.py
index e481533..648731e 100644
--- a/tests/test_7api.py
+++ b/tests/test_7api.py
@@ -12,7 +12,6 @@
 import sys
 import tempfile
 import warnings
-from pathlib import Path
 from types import SimpleNamespace
 
 import numpy as np
@@ -33,12 +32,35 @@
     should_plot_verify_metrics,
 )
 
-_REPO = Path(__file__).resolve().parent.parent
+from tests.conftest import REPO_ROOT, robert_module_dirs
+
+_REPO = REPO_ROOT
 _REG_CSV = _REPO / "tests" / "Robert_example.csv"
 _CLAS_CSV = _REPO / "tests" / "Robert_example_clas.csv"
 _FIXTURE_MODEL = _REPO / "tests" / "fixtures" / "custom_predict_model"
 
 
+@pytest.fixture(autouse=True)
+def api_test_repo_isolation(repo_root):
+    """
+    RobertModel.fit uses os.chdir(workdir); ensure cwd and repo-root artifacts
+    do not leak into legacy integration tests (e.g. test_3verify).
+    """
+    dirs_before = robert_module_dirs(repo_root)
+    yield
+    try:
+        os.chdir(repo_root)
+    except OSError:
+        pass
+    for name in robert_module_dirs(repo_root) - dirs_before:
+        shutil.rmtree(repo_root / name, ignore_errors=True)
+    for dat_file in repo_root.glob("*.dat"):
+        if any(
+            tag in dat_file.name for tag in ("CURATE", "GENERATE", "PREDICT", "VERIFY")
+        ):
+            dat_file.unlink(missing_ok=True)
+
+
 @pytest.fixture
 def custom_model_dir(tmp_path):
     """Minimal GENERATE-style folder (params CSV + _db.csv)."""
@@ -59,7 +81,9 @@ def _holdout_for_predict(X: pd.DataFrame, n_fit: int) -> pd.DataFrame:
 
 
 def test_yaml_unknown_key_warns_and_known_key_applies(capsys):
-    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, encoding="utf-8") as f:
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".yaml", delete=False, encoding="utf-8"
+    ) as f:
         f.write("not_a_robert_option: 1\nseed: 99\n")
         path = f.name
     try:
@@ -76,7 +100,9 @@ def test_yaml_unknown_key_warns_and_known_key_applies(capsys):
 
 def test_yaml_missing_file_message():
     opts = set_options({})
-    opts.varfile = os.path.join(tempfile.gettempdir(), "robert_nonexistent_params_xyz.yaml")
+    opts.varfile = os.path.join(
+        tempfile.gettempdir(), "robert_nonexistent_params_xyz.yaml"
+    )
     _, msg = load_from_yaml(opts)
     assert "not found" in msg.lower()
 
@@ -307,6 +333,29 @@ def test_predict_row_order_matches_input_order(tmp_path, fast_robert_kwargs):
     assert np.allclose(pred_natural, pred_realigned)
 
 
+def test_fit_with_report_and_robert_scores(tmp_path, fast_robert_kwargs):
+    """Full API path with REPORT and robert_scores()."""
+    pytest.importorskip("weasyprint")
+    df = pd.read_csv(_REG_CSV, encoding="utf-8")
+    X = df.drop(columns=["Target_values"])
+    y = df["Target_values"]
+    model = RobertModel(
+        problem_type="reg",
+        filter_mode="no_pfi",
+        workdir=tmp_path,
+        names="Name",
+        report=True,
+        **fast_robert_kwargs,
+    )
+    model.fit(X.iloc[:20], y.iloc[:20])
+    scores = model.robert_scores()
+    assert 0 <= scores["robert_score"] <= 10
+    assert "cv_score_combined" in scores["components"]
+    pdf = tmp_path / "ROBERT_report.pdf"
+    assert pdf.is_file()
+    assert scores["pdf_path"] == str(pdf)
+
+
 def test_fit_accepts_unused_fit_params(tmp_path, fast_robert_kwargs):
     """Sklearn Pipeline may pass extra fit kwargs; they should not raise."""
     df = pd.read_csv(_REG_CSV, encoding="utf-8")
diff --git a/tests/test_8uq.py b/tests/test_8uq.py
index d9ff00c..cfaf80f 100644
--- a/tests/test_8uq.py
+++ b/tests/test_8uq.py
@@ -82,9 +82,7 @@ def test_fit_predict_meta_uncertainty_modes(tmp_path, fast_robert_kwargs):
     assert y_hat.shape == uq_meta.shape
     assert np.isfinite(uq_meta).all() and (uq_meta >= 0).all()
     _, uq_total = model.predict(X_hold, return_uncertainty="total")
-    y_d, uq_m, uq_meta2, uq_tot = model.predict(
-        X_hold, return_uncertainty="decomposed"
-    )
+    y_d, uq_m, uq_meta2, uq_tot = model.predict(X_hold, return_uncertainty="decomposed")
     assert y_d.shape == uq_m.shape == uq_meta2.shape == uq_tot.shape
     assert np.all(uq_tot >= uq_m - 1e-9)
     assert np.allclose(uq_tot, uq_total, rtol=1e-5, atol=1e-5)
@@ -130,9 +128,9 @@ def test_score_prefers_calibrated_scale():
     abs_res = np.array([1.0, 1.0, 1.0, 1.0])
     bad = np.full(4, 10.0)
     good = np.full(4, 1.0)
-    assert score_uncertainty_candidate(good, abs_res, 0.9) < score_uncertainty_candidate(
-        bad, abs_res, 0.9
-    )
+    assert score_uncertainty_candidate(
+        good, abs_res, 0.9
+    ) < score_uncertainty_candidate(bad, abs_res, 0.9)
 
 
 def test_evaluate_uq_candidates_deterministic():
diff --git a/tests/test_easyrob.py b/tests/test_easyrob.py
index bbaec0b..795d168 100644
--- a/tests/test_easyrob.py
+++ b/tests/test_easyrob.py
@@ -29,10 +29,10 @@
 sys.path.insert(0, str(PROJECT_ROOT))
 
 # Third-party imports
-import pandas as pd
-import pytest
-from PySide6.QtCore import Qt, QCoreApplication
-from PySide6.QtWidgets import (
+import pandas as pd  # noqa: E402
+import pytest  # noqa: E402
+from PySide6.QtCore import Qt, QCoreApplication  # noqa: E402
+from PySide6.QtWidgets import (  # noqa: E402
     QListWidgetItem,
     QMessageBox,
     QDialog,
@@ -40,16 +40,17 @@
     QPushButton,
     QTableWidget,
     QTableWidgetItem,
-)
-from rdkit import Chem
+)  # noqa: E402
 
 # Local project imports
-from robert.gui_easyrob.main.window import EasyROB
-import robert.gui_easyrob.easyrob as easyrob_module
-import robert.gui_easyrob.main.window as window_module
-import robert.gui_easyrob.tabs.aqme as aqme_module
-import robert.gui_easyrob.tabs.predictions as predictions_module
-import robert.gui_easyrob.tabs.results as results_module
+from robert.gui_easyrob.main.window import EasyROB  # noqa: E402
+import robert.gui_easyrob.easyrob as easyrob_module  # noqa: E402
+import robert.gui_easyrob.main.window as window_module  # noqa: E402
+import robert.gui_easyrob.tabs.aqme as aqme_module  # noqa: E402
+import robert.gui_easyrob.tabs.predictions as predictions_module  # noqa: E402
+import robert.gui_easyrob.tabs.results as results_module  # noqa: E402
+
+from tests.conftest import aqme_installed  # noqa: E402
 
 # ----------------------------------------------------------------------
 # Constants
@@ -154,7 +155,9 @@ def dump_console_output(title, text):
     print("----- END CONSOLE OUTPUT -----")
 
 
-def process_events_until(predicate, timeout_s, poll_interval_s=WORKFLOW_POLL_INTERVAL_S):
+def process_events_until(
+    predicate, timeout_s, poll_interval_s=WORKFLOW_POLL_INTERVAL_S
+):
     """Process Qt events until a condition becomes true or the timeout expires."""
     elapsed = 0.0
     while elapsed < timeout_s:
@@ -212,10 +215,16 @@ def wait_for_workflow_completion(
         all_dirs_exist = all((output_dir / name).is_dir() for name in expected_dirs)
         pdf_exists = report_pdf.is_file()
         if workflow_started and all_dirs_exist and pdf_exists:
-            print("[OK] Workflow completed (all output folders AND report PDF detected)")
+            print(
+                "[OK] Workflow completed (all output folders AND report PDF detected)"
+            )
             return True, workflow_started, last_console, last_process
 
-        if process is not None and process.poll() is not None and not (all_dirs_exist and pdf_exists):
+        if (
+            process is not None
+            and process.poll() is not None
+            and not (all_dirs_exist and pdf_exists)
+        ):
             print(
                 f"[WARN] Process exited with code {process.returncode} "
                 "but not all outputs (folders + PDF) are present yet."
@@ -242,21 +251,28 @@ def run_full_workflow_and_wait(window, qtbot, output_dir, expected_dirs, report_
         initial_console_text=baseline_text,
     )
     if not started:
-        dump_console_output("[DEBUG] Console at timeout (no start detected):", last_console)
+        dump_console_output(
+            "[DEBUG] Console at timeout (no start detected):", last_console
+        )
         pytest.fail("Workflow did not start within timeout")
     if not completed:
         print("\n[DEBUG] Console at timeout (no completion detected):")
-        print("Existing entries in output_dir:", [path.name for path in output_dir.iterdir()])
+        print(
+            "Existing entries in output_dir:",
+            [path.name for path in output_dir.iterdir()],
+        )
         print("Report PDF exists:", report_pdf.is_file())
         if last_process is not None:
             print("Process return code:", last_process.returncode)
         dump_console_output("[DEBUG] Timed out console snapshot:", last_console)
         pytest.fail("Workflow did not complete within timeout")
 
+
 # ----------------------------------------------------------------------
 # Fixtures
 # ----------------------------------------------------------------------
 
+
 @pytest.fixture(scope="session")
 def test_output_dir():
     """
@@ -310,6 +326,31 @@ def easyrob_window(qtbot, monkeypatch):
     return window
 
 
+@pytest.fixture
+def predictions_tab(qtbot):
+    """PredictionsTab with an active QApplication (required for QWidget construction)."""
+    tab = predictions_module.PredictionsTab()
+    qtbot.addWidget(tab)
+    return tab
+
+
+@pytest.fixture
+def results_tab_mocks(monkeypatch):
+    """Common ResultsTab mocks for unit tests that avoid PDF/WebEngine widgets."""
+    monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
+    monkeypatch.setattr(
+        results_module,
+        "PDFViewer",
+        lambda pdf_path, thread_pool: results_module.QWidget(),
+    )
+
+
+def _results_tab(qtbot, input_csv_path: str):
+    tab = results_module.ResultsTab(None, str(input_csv_path))
+    qtbot.addWidget(tab)
+    return tab
+
+
 # =====================================================
 # Basic Initialization Tests
 # =====================================================
@@ -332,6 +373,7 @@ def test_easyrob_factory_returns_main_window_class():
     """The lightweight factory module returns the real main window class."""
     assert easyrob_module.get_main_window_class() is EasyROB
 
+
 def test_all_tabs_created(easyrob_window):
     """All expected top-level tabs are present."""
     window = easyrob_window
@@ -352,7 +394,9 @@ def test_dropdowns_populated(easyrob_window):
     window = easyrob_window
 
     assert window.type_dropdown.count() == 2
-    items = [window.type_dropdown.itemText(i) for i in range(window.type_dropdown.count())]
+    items = [
+        window.type_dropdown.itemText(i) for i in range(window.type_dropdown.count())
+    ]
     assert "Regression" in items
     assert "Classification" in items
 
@@ -389,8 +433,7 @@ def test_move_to_selected_and_back(easyrob_window):
         for i in range(window.available_list.count())
     ]
     ignore_items = [
-        window.ignore_list.item(i).text()
-        for i in range(window.ignore_list.count())
+        window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
     ]
 
     assert "col3" in available_items
@@ -445,7 +488,9 @@ def test_load_csv_columns(easyrob_window, tmp_path):
     assert set(available_items) == {"ID", "Name", "Target", "Feature1"}
 
 
-def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(easyrob_window, tmp_path):
+def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(
+    easyrob_window, tmp_path
+):
     """SMILES is auto-ignored and code_name is auto-selected when present."""
     window = easyrob_window
 
@@ -467,8 +512,7 @@ def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(easyrob_wind
         for i in range(window.available_list.count())
     }
     ignored_items = {
-        window.ignore_list.item(i).text()
-        for i in range(window.ignore_list.count())
+        window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
     }
 
     assert "SMILES" not in available_items
@@ -476,7 +520,9 @@ def test_load_csv_columns_auto_ignores_smiles_and_prefers_code_name(easyrob_wind
     assert window.names_dropdown.currentText() == "code_name"
 
 
-def test_set_file_path_updates_ui_and_skips_redundant_reload(easyrob_window, tmp_path, monkeypatch):
+def test_set_file_path_updates_ui_and_skips_redundant_reload(
+    easyrob_window, tmp_path, monkeypatch
+):
     """set_file_path updates labels and avoids reloading unchanged files unless forced."""
     window = easyrob_window
     csv_path = tmp_path / "input.csv"
@@ -490,12 +536,36 @@ def test_set_file_path_updates_ui_and_skips_redundant_reload(easyrob_window, tmp
         "check_aqme_workflow": 0,
     }
 
-    monkeypatch.setattr(window, "load_csv_columns", lambda: calls.__setitem__("load_csv_columns", calls["load_csv_columns"] + 1))
-    monkeypatch.setattr(window, "refresh_tabs", lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1))
+    monkeypatch.setattr(
+        window,
+        "load_csv_columns",
+        lambda: calls.__setitem__("load_csv_columns", calls["load_csv_columns"] + 1),
+    )
+    monkeypatch.setattr(
+        window,
+        "refresh_tabs",
+        lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1),
+    )
     monkeypatch.setattr(window, "_is_molssi_csv", lambda file_path: False)
-    monkeypatch.setattr(window, "check_molssi_descriptors", lambda: calls.__setitem__("check_molssi_descriptors", calls["check_molssi_descriptors"] + 1))
-    monkeypatch.setattr(window, "_update_unified_smiles_context", lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1))
-    monkeypatch.setattr(window, "check_aqme_workflow", lambda: calls.__setitem__("check_aqme_workflow", calls["check_aqme_workflow"] + 1))
+    monkeypatch.setattr(
+        window,
+        "check_molssi_descriptors",
+        lambda: calls.__setitem__(
+            "check_molssi_descriptors", calls["check_molssi_descriptors"] + 1
+        ),
+    )
+    monkeypatch.setattr(
+        window,
+        "_update_unified_smiles_context",
+        lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1),
+    )
+    monkeypatch.setattr(
+        window,
+        "check_aqme_workflow",
+        lambda: calls.__setitem__(
+            "check_aqme_workflow", calls["check_aqme_workflow"] + 1
+        ),
+    )
     window.tab_widget_aqme.df_mapped_smiles = object()
 
     window.set_file_path(str(csv_path))
@@ -526,9 +596,23 @@ def test_set_and_clear_csv_test_path_updates_ui(easyrob_window, tmp_path, monkey
     pd.DataFrame({"SMILES": ["C"], "target": [1.0]}).to_csv(csv_path, index=False)
 
     calls = {"update_smiles": 0, "check_aqme_workflow": 0, "refresh_tabs": 0}
-    monkeypatch.setattr(window, "_update_unified_smiles_context", lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1))
-    monkeypatch.setattr(window, "check_aqme_workflow", lambda: calls.__setitem__("check_aqme_workflow", calls["check_aqme_workflow"] + 1))
-    monkeypatch.setattr(window, "refresh_tabs", lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1))
+    monkeypatch.setattr(
+        window,
+        "_update_unified_smiles_context",
+        lambda: calls.__setitem__("update_smiles", calls["update_smiles"] + 1),
+    )
+    monkeypatch.setattr(
+        window,
+        "check_aqme_workflow",
+        lambda: calls.__setitem__(
+            "check_aqme_workflow", calls["check_aqme_workflow"] + 1
+        ),
+    )
+    monkeypatch.setattr(
+        window,
+        "refresh_tabs",
+        lambda file_path: calls.__setitem__("refresh_tabs", calls["refresh_tabs"] + 1),
+    )
 
     window.set_csv_test_path(str(csv_path))
 
@@ -541,7 +625,10 @@ def test_set_and_clear_csv_test_path_updates_ui(easyrob_window, tmp_path, monkey
     window.clear_test_file()
 
     assert window.csv_test_path is None
-    assert window.csv_test_label.label.text() == "Drag & Drop a CSV external test file here (optional)"
+    assert (
+        window.csv_test_label.label.text()
+        == "Drag & Drop a CSV external test file here (optional)"
+    )
     assert window.csv_test_label.toolTip() == ""
     assert window.clear_test_button.isHidden()
     assert calls["update_smiles"] == 2
@@ -569,14 +656,20 @@ def test_reset_ui_after_process_restores_buttons(easyrob_window):
 def test_open_external_url_uses_browser(easyrob_window, monkeypatch):
     """open_external_url delegates to the browser helper."""
     opened = {}
-    monkeypatch.setattr(window_module.webbrowser, "open", lambda url, new=0: opened.update({"url": url, "new": new}))
+    monkeypatch.setattr(
+        window_module.webbrowser,
+        "open",
+        lambda url, new=0: opened.update({"url": url, "new": new}),
+    )
 
     easyrob_window.open_external_url("https://example.com")
 
     assert opened == {"url": "https://example.com", "new": 2}
 
 
-def test_close_event_ignores_when_running_worker_is_not_stopped(easyrob_window, monkeypatch):
+def test_close_event_ignores_when_running_worker_is_not_stopped(
+    easyrob_window, monkeypatch
+):
     """Closing is aborted if ROBERT is still running and the user declines stopping it."""
     window = easyrob_window
 
@@ -591,7 +684,9 @@ class DummyWorker:
         def isRunning(self):
             return True
 
-    monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No)
+    monkeypatch.setattr(
+        window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No
+    )
 
     event = DummyEvent()
     window.worker = DummyWorker()
@@ -646,10 +741,20 @@ def quit(self):
     shutdown_calls = {"n": 0}
     timer_calls = {"n": 0}
 
-    monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes)
+    monkeypatch.setattr(
+        window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes
+    )
     monkeypatch.setattr(window_module, "QEventLoop", lambda: loop)
-    monkeypatch.setattr(window_module.QTimer, "singleShot", lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1))
-    monkeypatch.setattr(window, "_shutdown_molssi_async", lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1))
+    monkeypatch.setattr(
+        window_module.QTimer,
+        "singleShot",
+        lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1),
+    )
+    monkeypatch.setattr(
+        window,
+        "_shutdown_molssi_async",
+        lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1),
+    )
 
     event = DummyEvent()
     window.worker = worker
@@ -664,7 +769,9 @@ def quit(self):
     assert event.ignored == 1
 
 
-def test_close_event_without_running_worker_starts_async_shutdown(easyrob_window, monkeypatch):
+def test_close_event_without_running_worker_starts_async_shutdown(
+    easyrob_window, monkeypatch
+):
     """Closing without an active ROBERT worker still routes through async cleanup."""
     window = easyrob_window
 
@@ -676,7 +783,11 @@ def ignore(self):
             self.ignored += 1
 
     shutdown_calls = {"n": 0}
-    monkeypatch.setattr(window, "_shutdown_molssi_async", lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1))
+    monkeypatch.setattr(
+        window,
+        "_shutdown_molssi_async",
+        lambda: shutdown_calls.__setitem__("n", shutdown_calls["n"] + 1),
+    )
 
     event = DummyEvent()
     window.worker = None
@@ -690,7 +801,11 @@ def ignore(self):
 def test_show_contact_dialog_executes_modal(easyrob_window, monkeypatch):
     """Contact dialog is created and executed."""
     exec_calls = {"n": 0}
-    monkeypatch.setattr(window_module.QDialog, "exec", lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1))
+    monkeypatch.setattr(
+        window_module.QDialog,
+        "exec",
+        lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1),
+    )
 
     easyrob_window.show_contact_dialog()
 
@@ -700,7 +815,11 @@ def test_show_contact_dialog_executes_modal(easyrob_window, monkeypatch):
 def test_show_version_dialog_executes_modal(easyrob_window, monkeypatch):
     """Version dialog is created and executed."""
     exec_calls = {"n": 0}
-    monkeypatch.setattr(window_module.QDialog, "exec", lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1))
+    monkeypatch.setattr(
+        window_module.QDialog,
+        "exec",
+        lambda self: exec_calls.__setitem__("n", exec_calls["n"] + 1),
+    )
 
     easyrob_window.show_version_dialog()
 
@@ -709,6 +828,7 @@ def test_show_version_dialog_executes_modal(easyrob_window, monkeypatch):
 
 def test_show_tutorial_dialog_reuses_visible_dialog(easyrob_window):
     """Reopening the tutorial dialog reuses the visible instance."""
+
     class DummyDialog:
         def __init__(self):
             self.raise_calls = 0
@@ -760,7 +880,11 @@ def test_check_for_pdfs_and_images_runs_with_existing_outputs(easyrob_window, tm
 def test_refresh_tabs_updates_children_and_schedules_once(easyrob_window, monkeypatch):
     """refresh_tabs stores the latest path and coalesces duplicate scheduling."""
     timer_calls = {"n": 0}
-    monkeypatch.setattr(window_module.QTimer, "singleShot", lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1))
+    monkeypatch.setattr(
+        window_module.QTimer,
+        "singleShot",
+        lambda ms, fn: timer_calls.__setitem__("n", timer_calls["n"] + 1),
+    )
 
     easyrob_window._refresh_scheduled = False
     easyrob_window.refresh_tabs("one.csv")
@@ -774,11 +898,31 @@ def test_execute_refresh_tabs_calls_all_child_refreshes(easyrob_window, monkeypa
     """_execute_refresh_tabs fans out the refresh to child tabs and file-based checks."""
     calls = {"results": 0, "images": 0, "predictions": 0, "pdfs": 0, "imgs": 0}
 
-    monkeypatch.setattr(easyrob_window.results_tab, "refresh_with_new_path", lambda p: calls.__setitem__("results", calls["results"] + 1))
-    monkeypatch.setattr(easyrob_window.images_tab, "refresh_with_new_path", lambda p: calls.__setitem__("images", calls["images"] + 1))
-    monkeypatch.setattr(easyrob_window.predictions_tab, "refresh_with_new_path", lambda p: calls.__setitem__("predictions", calls["predictions"] + 1))
-    monkeypatch.setattr(easyrob_window, "check_for_pdfs", lambda p: calls.__setitem__("pdfs", calls["pdfs"] + 1))
-    monkeypatch.setattr(easyrob_window, "check_for_images", lambda p: calls.__setitem__("imgs", calls["imgs"] + 1))
+    monkeypatch.setattr(
+        easyrob_window.results_tab,
+        "refresh_with_new_path",
+        lambda p: calls.__setitem__("results", calls["results"] + 1),
+    )
+    monkeypatch.setattr(
+        easyrob_window.images_tab,
+        "refresh_with_new_path",
+        lambda p: calls.__setitem__("images", calls["images"] + 1),
+    )
+    monkeypatch.setattr(
+        easyrob_window.predictions_tab,
+        "refresh_with_new_path",
+        lambda p: calls.__setitem__("predictions", calls["predictions"] + 1),
+    )
+    monkeypatch.setattr(
+        easyrob_window,
+        "check_for_pdfs",
+        lambda p: calls.__setitem__("pdfs", calls["pdfs"] + 1),
+    )
+    monkeypatch.setattr(
+        easyrob_window,
+        "check_for_images",
+        lambda p: calls.__setitem__("imgs", calls["imgs"] + 1),
+    )
 
     easyrob_window._pending_refresh_path = "demo.csv"
     easyrob_window._refresh_scheduled = True
@@ -790,6 +934,7 @@ def test_execute_refresh_tabs_calls_all_child_refreshes(easyrob_window, monkeypa
 
 def test_stop_process_confirms_and_stops_worker(easyrob_window, monkeypatch):
     """stop_process marks manual stop and schedules worker stop after confirmation."""
+
     class DummySignal:
         def connect(self, fn):
             self.fn = fn
@@ -807,8 +952,14 @@ def stop(self):
 
     worker = DummyWorker()
     timer_calls = {"n": 0}
-    monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes)
-    monkeypatch.setattr(window_module.QTimer, "singleShot", lambda ms, fn: (timer_calls.__setitem__("n", timer_calls["n"] + 1), fn())[1])
+    monkeypatch.setattr(
+        window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes
+    )
+    monkeypatch.setattr(
+        window_module.QTimer,
+        "singleShot",
+        lambda ms, fn: (timer_calls.__setitem__("n", timer_calls["n"] + 1), fn())[1],
+    )
 
     easyrob_window.worker = worker
     easyrob_window.stop_button.setDisabled(False)
@@ -823,7 +974,9 @@ def stop(self):
 
 def test_stop_process_returns_when_user_declines(easyrob_window, monkeypatch):
     """stop_process does nothing if the user rejects the confirmation dialog."""
-    monkeypatch.setattr(window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No)
+    monkeypatch.setattr(
+        window_module.QMessageBox, "question", lambda *args, **kwargs: QMessageBox.No
+    )
 
     easyrob_window.manual_stop = False
     easyrob_window.stop_process()
@@ -1011,14 +1164,24 @@ def test_predictions_filter_dataframe_orders_core_columns(easyrob_window, monkey
             "target_pred_sd": [0.3],
         }
     )
-    monkeypatch.setattr(easyrob_window.predictions_tab, "_extract_names_column_from_predict", lambda: "sample_id")
+    monkeypatch.setattr(
+        easyrob_window.predictions_tab,
+        "_extract_names_column_from_predict",
+        lambda: "sample_id",
+    )
 
     filtered = easyrob_window.predictions_tab._filter_prediction_dataframe(df)
 
-    assert list(filtered.columns) == ["Image", "sample_id", "SMILES", "target_pred", "target_pred_sd"]
+    assert list(filtered.columns) == [
+        "Image",
+        "sample_id",
+        "SMILES",
+        "target_pred",
+        "target_pred_sd",
+    ]
 
 
-def test_predictions_extract_names_column_from_predict(tmp_path):
+def test_predictions_extract_names_column_from_predict(tmp_path, predictions_tab):
     """The names field is extracted from the stored PREDICT command line."""
     predict_dir = tmp_path / "PREDICT"
     predict_dir.mkdir()
@@ -1026,23 +1189,19 @@ def test_predictions_extract_names_column_from_predict(tmp_path):
     dat_path.write_text('--names "code_name"\n', encoding="utf-8")
     (tmp_path / "input.csv").write_text("a,b\n1,2\n", encoding="utf-8")
 
-    tab = predictions_module.PredictionsTab()
-    tab._base_path = str(tmp_path / "input.csv")
+    predictions_tab._base_path = str(tmp_path / "input.csv")
 
-    assert tab._extract_names_column_from_predict() == "code_name"
+    assert predictions_tab._extract_names_column_from_predict() == "code_name"
 
 
-def test_results_tab_detects_and_refreshes_pdf_tabs(tmp_path, monkeypatch):
+def test_results_tab_detects_and_refreshes_pdf_tabs(tmp_path, qtbot, results_tab_mocks):
     """Results tab discovers PDFs and refreshes when a new path is provided."""
-    monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
-    monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
     run_dir = tmp_path / "run"
     run_dir.mkdir()
     first_pdf = run_dir / "ROBERT_report.pdf"
     first_pdf.write_text("pdf", encoding="utf-8")
 
-    tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+    tab = _results_tab(qtbot, run_dir / "input.csv")
 
     assert first_pdf.name in tab.title_to_path
     assert tab.pdf_tab_widget.count() == 1
@@ -1058,9 +1217,9 @@ def test_results_tab_detects_and_refreshes_pdf_tabs(tmp_path, monkeypatch):
     assert second_pdf.name in tab.title_to_path
 
 
-def test_predictions_show_header_menu_sorts_and_histogram(monkeypatch):
+def test_predictions_show_header_menu_sorts_and_histogram(predictions_tab, monkeypatch):
     """Header context menu routes to sort and histogram actions."""
-    tab = predictions_module.PredictionsTab()
+    tab = predictions_tab
     df = pd.DataFrame({"num": [2, 1], "txt": ["b", "a"]})
 
     class DummyHeader:
@@ -1102,7 +1261,11 @@ def exec(self, pos):
 
     monkeypatch.setattr(predictions_module, "QMenu", DummyMenu)
     histogram_calls = {"n": 0}
-    monkeypatch.setattr(tab, "_show_histogram", lambda series, name: histogram_calls.__setitem__("n", histogram_calls["n"] + 1))
+    monkeypatch.setattr(
+        tab,
+        "_show_histogram",
+        lambda series, name: histogram_calls.__setitem__("n", histogram_calls["n"] + 1),
+    )
 
     header = DummyHeader()
     table = DummyTable()
@@ -1117,9 +1280,11 @@ def exec(self, pos):
     assert histogram_calls["n"] == 1
 
 
-def test_predictions_show_histogram_menu_non_numeric_shows_message(monkeypatch):
+def test_predictions_show_histogram_menu_non_numeric_shows_message(
+    predictions_tab, monkeypatch
+):
     """Non-numeric columns show an informational popup instead of plotting."""
-    tab = predictions_module.PredictionsTab()
+    tab = predictions_tab
     df = pd.DataFrame({"txt": ["a", "b"]})
     info_calls = {"n": 0}
 
@@ -1127,48 +1292,121 @@ class DummyHeader:
         def logicalIndexAt(self, pos):
             return 0
 
-    monkeypatch.setattr(predictions_module.QMessageBox, "information", lambda *args, **kwargs: info_calls.__setitem__("n", info_calls["n"] + 1))
+    monkeypatch.setattr(
+        predictions_module.QMessageBox,
+        "information",
+        lambda *args, **kwargs: info_calls.__setitem__("n", info_calls["n"] + 1),
+    )
 
     tab._show_histogram_menu_header(None, df, DummyHeader())
 
     assert info_calls["n"] == 1
 
 
-def test_predictions_show_histogram_uses_matplotlib(monkeypatch):
+def test_predictions_show_histogram_uses_matplotlib(predictions_tab, monkeypatch):
     """Histogram plotting delegates to matplotlib without blocking."""
-    tab = predictions_module.PredictionsTab()
+    tab = predictions_tab
     series = pd.Series([1, 2, 3])
-    calls = {"figure": 0, "title": 0, "xlabel": 0, "ylabel": 0, "grid": 0, "show": 0, "hist": 0}
+    calls = {
+        "figure": 0,
+        "title": 0,
+        "xlabel": 0,
+        "ylabel": 0,
+        "grid": 0,
+        "show": 0,
+        "hist": 0,
+    }
 
-    monkeypatch.setattr(predictions_module.plt, "figure", lambda: calls.__setitem__("figure", calls["figure"] + 1))
-    monkeypatch.setattr(predictions_module.plt, "title", lambda name: calls.__setitem__("title", calls["title"] + 1))
-    monkeypatch.setattr(predictions_module.plt, "xlabel", lambda name: calls.__setitem__("xlabel", calls["xlabel"] + 1))
-    monkeypatch.setattr(predictions_module.plt, "ylabel", lambda name: calls.__setitem__("ylabel", calls["ylabel"] + 1))
-    monkeypatch.setattr(predictions_module.plt, "grid", lambda enabled: calls.__setitem__("grid", calls["grid"] + 1))
-    monkeypatch.setattr(predictions_module.plt, "show", lambda block=False: calls.__setitem__("show", calls["show"] + 1))
-    monkeypatch.setattr(pd.Series, "hist", lambda self, bins=30: calls.__setitem__("hist", calls["hist"] + 1))
+    monkeypatch.setattr(
+        predictions_module.plt,
+        "figure",
+        lambda: calls.__setitem__("figure", calls["figure"] + 1),
+    )
+    monkeypatch.setattr(
+        predictions_module.plt,
+        "title",
+        lambda name: calls.__setitem__("title", calls["title"] + 1),
+    )
+    monkeypatch.setattr(
+        predictions_module.plt,
+        "xlabel",
+        lambda name: calls.__setitem__("xlabel", calls["xlabel"] + 1),
+    )
+    monkeypatch.setattr(
+        predictions_module.plt,
+        "ylabel",
+        lambda name: calls.__setitem__("ylabel", calls["ylabel"] + 1),
+    )
+    monkeypatch.setattr(
+        predictions_module.plt,
+        "grid",
+        lambda enabled: calls.__setitem__("grid", calls["grid"] + 1),
+    )
+    monkeypatch.setattr(
+        predictions_module.plt,
+        "show",
+        lambda block=False: calls.__setitem__("show", calls["show"] + 1),
+    )
+    monkeypatch.setattr(
+        pd.Series,
+        "hist",
+        lambda self, bins=30: calls.__setitem__("hist", calls["hist"] + 1),
+    )
 
     tab._show_histogram(series, "value")
 
-    assert calls == {"figure": 1, "title": 1, "xlabel": 1, "ylabel": 1, "grid": 1, "show": 1, "hist": 1}
+    assert calls == {
+        "figure": 1,
+        "title": 1,
+        "xlabel": 1,
+        "ylabel": 1,
+        "grid": 1,
+        "show": 1,
+        "hist": 1,
+    }
 
 
-def test_predictions_add_loaded_df_replaces_loading_tab(monkeypatch):
+def test_predictions_add_loaded_df_replaces_loading_tab(
+    predictions_tab, qtbot, monkeypatch
+):
     """Loaded prediction data replaces the placeholder tab widget."""
-    tab = predictions_module.PredictionsTab()
+    tab = predictions_tab
     tab._base_path = "demo.csv"
     tab.subtabs.addTab(predictions_module.QLabel("Loading"), "No PFI")
 
     df = pd.DataFrame({"SMILES": ["C"], "target_pred": [1.0]})
     monkeypatch.setattr(tab, "_filter_prediction_dataframe", lambda frame: frame)
-    monkeypatch.setattr(predictions_module, "evaluate_predictions_for_model", lambda base, frame, key: {"pdf_path": "report.pdf", "model": key, "scenario": "demo"})
-    monkeypatch.setattr(predictions_module, "get_robert_report_path", lambda base: "report.pdf")
-    monkeypatch.setattr(predictions_module, "extract_robert_fragment_image", lambda path, key: None)
-    monkeypatch.setattr(predictions_module, "extract_extrapolation_scores", lambda path: {"No_PFI": None})
-    monkeypatch.setattr(predictions_module, "extract_extrapolation_fragment", lambda path, key: None)
-    monkeypatch.setattr(predictions_module, "find_external_test_pixmaps", lambda base: {})
+    monkeypatch.setattr(
+        predictions_module,
+        "evaluate_predictions_for_model",
+        lambda base, frame, key: {
+            "pdf_path": "report.pdf",
+            "model": key,
+            "scenario": "demo",
+        },
+    )
+    monkeypatch.setattr(
+        predictions_module, "get_robert_report_path", lambda base: "report.pdf"
+    )
+    monkeypatch.setattr(
+        predictions_module, "extract_robert_fragment_image", lambda path, key: None
+    )
+    monkeypatch.setattr(
+        predictions_module,
+        "extract_extrapolation_scores",
+        lambda path: {"No_PFI": None},
+    )
+    monkeypatch.setattr(
+        predictions_module, "extract_extrapolation_fragment", lambda path, key: None
+    )
+    monkeypatch.setattr(
+        predictions_module, "find_external_test_pixmaps", lambda base: {}
+    )
     widget = predictions_module.QWidget()
-    monkeypatch.setattr(tab, "_create_table_with_stats", lambda frame, info, pdf_image: widget)
+    qtbot.addWidget(widget)
+    monkeypatch.setattr(
+        tab, "_create_table_with_stats", lambda frame, info, pdf_image: widget
+    )
 
     tab._add_loaded_df("No_PFI", df)
 
@@ -1176,17 +1414,21 @@ def test_predictions_add_loaded_df_replaces_loading_tab(monkeypatch):
     assert tab.subtabs.tabText(0) == "No PFI"
 
 
-def test_predictions_refresh_with_new_path_loads_csvs_synchronously(tmp_path, monkeypatch):
+def test_predictions_refresh_with_new_path_loads_csvs_synchronously(
+    tmp_path, predictions_tab, monkeypatch
+):
     """refresh_with_new_path discovers CSVs and materializes tabs when tasks run synchronously."""
     csv_test_dir = tmp_path / "PREDICT" / "csv_test"
     csv_test_dir.mkdir(parents=True)
 
     no_pfi_path = csv_test_dir / "demo_No_PFI.csv"
     pfi_path = csv_test_dir / "demo_PFI.csv"
-    pd.DataFrame({"SMILES": ["C"], "target_pred": [1.0]}).to_csv(no_pfi_path, index=False)
+    pd.DataFrame({"SMILES": ["C"], "target_pred": [1.0]}).to_csv(
+        no_pfi_path, index=False
+    )
     pd.DataFrame({"SMILES": ["CC"], "target_pred": [2.0]}).to_csv(pfi_path, index=False)
 
-    tab = predictions_module.PredictionsTab()
+    tab = predictions_tab
     created = []
 
     class FakePlaceholder:
@@ -1273,17 +1515,32 @@ def run(self):
     monkeypatch.setattr(
         predictions_module,
         "evaluate_predictions_for_model",
-        lambda base, frame, key: {"pdf_path": "report.pdf", "model": key, "scenario": "demo"},
+        lambda base, frame, key: {
+            "pdf_path": "report.pdf",
+            "model": key,
+            "scenario": "demo",
+        },
+    )
+    monkeypatch.setattr(
+        predictions_module, "get_robert_report_path", lambda base: "report.pdf"
+    )
+    monkeypatch.setattr(
+        predictions_module, "extract_robert_fragment_image", lambda path, key: None
+    )
+    monkeypatch.setattr(
+        predictions_module, "extract_extrapolation_scores", lambda path: {}
+    )
+    monkeypatch.setattr(
+        predictions_module, "extract_extrapolation_fragment", lambda path, key: None
+    )
+    monkeypatch.setattr(
+        predictions_module, "find_external_test_pixmaps", lambda base: {}
     )
-    monkeypatch.setattr(predictions_module, "get_robert_report_path", lambda base: "report.pdf")
-    monkeypatch.setattr(predictions_module, "extract_robert_fragment_image", lambda path, key: None)
-    monkeypatch.setattr(predictions_module, "extract_extrapolation_scores", lambda path: {})
-    monkeypatch.setattr(predictions_module, "extract_extrapolation_fragment", lambda path, key: None)
-    monkeypatch.setattr(predictions_module, "find_external_test_pixmaps", lambda base: {})
     monkeypatch.setattr(
         tab,
         "_create_table_with_stats",
-        lambda frame, info, pdf_image: created.append((info["model"], frame.copy())) or object(),
+        lambda frame, info, pdf_image: created.append((info["model"], frame.copy()))
+        or object(),
     )
 
     tab.refresh_with_new_path(str(tmp_path / "input.csv"))
@@ -1291,21 +1548,23 @@ def run(self):
     assert tab.placeholder.isHidden()
     assert not tab.subtabs.isHidden()
     assert tab.subtabs.count() == 2
-    assert {tab.subtabs.tabText(i) for i in range(tab.subtabs.count())} == {"No PFI", "PFI"}
+    assert {tab.subtabs.tabText(i) for i in range(tab.subtabs.count())} == {
+        "No PFI",
+        "PFI",
+    }
     assert {model for model, _ in created} == {"No_PFI", "PFI"}
 
 
-def test_results_clear_pdf_tabs_removes_placeholders(tmp_path, monkeypatch):
+def test_results_clear_pdf_tabs_removes_placeholders(
+    tmp_path, qtbot, results_tab_mocks
+):
     """clear_pdf_tabs removes tracked tabs and resets internal maps."""
-    monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
-    monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
     run_dir = tmp_path / "run"
     run_dir.mkdir()
     pdf_path = run_dir / "ROBERT_report.pdf"
     pdf_path.write_text("pdf", encoding="utf-8")
 
-    tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+    tab = _results_tab(qtbot, run_dir / "input.csv")
     assert tab.pdf_tab_widget.count() == 1
 
     tab.clear_pdf_tabs()
@@ -1315,15 +1574,14 @@ def test_results_clear_pdf_tabs_removes_placeholders(tmp_path, monkeypatch):
     assert tab.title_to_path == {}
 
 
-def test_results_maybe_materialize_tab_builds_viewer(monkeypatch, tmp_path):
+def test_results_maybe_materialize_tab_builds_viewer(
+    qtbot, results_tab_mocks, monkeypatch, tmp_path
+):
     """Selecting a placeholder PDF tab materializes a real viewer."""
-    monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
-    monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
     run_dir = tmp_path / "run"
     run_dir.mkdir()
     pdf_path = run_dir / "ROBERT_report.pdf"
-    tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+    tab = _results_tab(qtbot, run_dir / "input.csv")
     pdf_path.write_text("pdf", encoding="utf-8")
     tab.clear_pdf_tabs()
     tab.pdf_tabs[str(pdf_path)] = None
@@ -1333,24 +1591,27 @@ def test_results_maybe_materialize_tab_builds_viewer(monkeypatch, tmp_path):
     tab.pdf_tab_widget.blockSignals(False)
 
     viewer = results_module.QWidget()
-    monkeypatch.setattr(tab, "_materialize_pdf_viewer", lambda index, path: tab.pdf_tabs.__setitem__(path, viewer))
+    monkeypatch.setattr(
+        tab,
+        "_materialize_pdf_viewer",
+        lambda index, path: tab.pdf_tabs.__setitem__(path, viewer),
+    )
 
     tab._maybe_materialize_tab(0)
 
     assert tab.pdf_tabs[str(pdf_path)] is viewer
 
 
-def test_results_index_of_title_returns_expected_index(tmp_path, monkeypatch):
+def test_results_index_of_title_returns_expected_index(
+    tmp_path, qtbot, results_tab_mocks
+):
     """Tab titles can be resolved back to their index."""
-    monkeypatch.setattr(results_module.QTimer, "singleShot", lambda ms, fn: None)
-    monkeypatch.setattr(results_module, "PDFViewer", lambda pdf_path, thread_pool: results_module.QWidget())
-
     run_dir = tmp_path / "run"
     run_dir.mkdir()
     pdf_path = run_dir / "ROBERT_report.pdf"
     pdf_path.write_text("pdf", encoding="utf-8")
 
-    tab = results_module.ResultsTab(None, str(run_dir / "input.csv"))
+    tab = _results_tab(qtbot, run_dir / "input.csv")
 
     assert tab._index_of_title(pdf_path.name) == 0
     assert tab._index_of_title("missing.pdf") == -1
@@ -1365,7 +1626,14 @@ def test_workflow_selector_options(easyrob_window):
     """Workflow selector has all required entries and default."""
     window = easyrob_window
 
-    expected_workflows = ["Full Workflow", "CURATE", "GENERATE", "PREDICT", "VERIFY", "REPORT"]
+    expected_workflows = [
+        "Full Workflow",
+        "CURATE",
+        "GENERATE",
+        "PREDICT",
+        "VERIFY",
+        "REPORT",
+    ]
 
     workflow_items = [
         window.workflow_selector.itemText(i)
@@ -1402,6 +1670,7 @@ def test_progress_bar_exists(easyrob_window):
 # REAL End-to-End User Workflow Test (GUI)
 # =====================================================
 
+
 @pytest.mark.parametrize(
     "test_scenario",
     [
@@ -1440,8 +1709,11 @@ def test_full_user_workflow_end_to_end(
         * AQME workflow checkbox enabled.
         * Existing ROBERT folders (if any) are removed before starting.
         * AQME generates a mapped CSV and ROBERT runs with it.
-        * Check predictions tab 
+        * Check predictions tab
     """
+    if test_scenario == "aqme_regression" and not aqme_installed():
+        pytest.skip("AQME is not installed (pip install aqme==2.0.0)")
+
     window = easyrob_window
     config = SCENARIO_CONFIG[test_scenario]
 
@@ -1523,7 +1795,9 @@ def test_full_user_workflow_end_to_end(
     window.move_to_selected()
 
     # Collect ignored items from GUI
-    ignore_items = [window.ignore_list.item(i).text() for i in range(window.ignore_list.count())]
+    ignore_items = [
+        window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
+    ]
     print(f"  • Ignored columns: {ignore_items}")
 
     actual = set(ignore_items)
@@ -1546,8 +1820,7 @@ def test_full_user_workflow_end_to_end(
     extra = actual - expected
 
     assert not missing and not extra, (
-        f"\nMissing items: {missing}"
-        f"\nExtra items: {extra}"
+        f"\nMissing items: {missing}\nExtra items: {extra}"
     )
 
     available_items = [
@@ -1661,9 +1934,16 @@ def test_full_user_workflow_end_to_end(
     # ------------------------------------------------------------------
     if test_scenario == "existing_dirs_stop":
         # Bootstrap the expected output state if it is not already present.
-        if not all((output_dir / d).is_dir() for d in expected_dirs) or not report_pdf.is_file():
-            print("[SETUP] Existing output folders missing; running baseline workflow first...")
-            run_full_workflow_and_wait(window, qtbot, output_dir, expected_dirs, report_pdf)
+        if (
+            not all((output_dir / d).is_dir() for d in expected_dirs)
+            or not report_pdf.is_file()
+        ):
+            print(
+                "[SETUP] Existing output folders missing; running baseline workflow first..."
+            )
+            run_full_workflow_and_wait(
+                window, qtbot, output_dir, expected_dirs, report_pdf
+            )
             QCoreApplication.processEvents()
             assert all((output_dir / d).is_dir() for d in expected_dirs)
             assert report_pdf.is_file()
@@ -1676,12 +1956,16 @@ def test_full_user_workflow_end_to_end(
 
         started = wait_for_workflow_start(window, baseline_text)
         if not started:
-            pytest.fail("Re-run did not start within timeout after existing-folders popup")
+            pytest.fail(
+                "Re-run did not start within timeout after existing-folders popup"
+            )
 
         print("\n[STEP 8] Clicking Stop ROBERT button...")
         qtbot.mouseClick(window.stop_button, Qt.LeftButton)
 
-        print("[STEP 9] Waiting for workflow to stop and GUI to return to idle state...")
+        print(
+            "[STEP 9] Waiting for workflow to stop and GUI to return to idle state..."
+        )
         stopped = process_events_until(
             lambda: (
                 getattr(window, "worker", None) is None
@@ -1706,7 +1990,9 @@ def test_full_user_workflow_end_to_end(
         ), "Expected a stop-process QMessageBox.question"
 
         final_console_text = window.console_output.toPlainText()
-        dump_console_output("[STEP 10] Final console output (existing_dirs_stop):", final_console_text)
+        dump_console_output(
+            "[STEP 10] Final console output (existing_dirs_stop):", final_console_text
+        )
 
         assert window.file_path == str(csv_path)
 
@@ -1756,10 +2042,14 @@ def test_full_user_workflow_end_to_end(
         print(f"[OK] AQME mapped CSV exists: {mapped_csv_path}")
 
         predict_csv_test_dir = output_dir / "PREDICT" / "csv_test"
-        assert predict_csv_test_dir.is_dir(), "PREDICT/csv_test directory was not created"
+        assert predict_csv_test_dir.is_dir(), (
+            "PREDICT/csv_test directory was not created"
+        )
 
         prediction_csvs = predictions_module.find_prediction_csvs(str(csv_path))
-        assert prediction_csvs, "Predictions CSVs were not generated for the external test set"
+        assert prediction_csvs, (
+            "Predictions CSVs were not generated for the external test set"
+        )
         print(f"[OK] Predictions CSVs detected: {sorted(prediction_csvs)}")
 
     print("\n" + "=" * 80)
@@ -1777,6 +2067,9 @@ def test_run_aqme_only_end_to_end(easyrob_window, test_output_dir, qtbot, monkey
     - wait for the subprocess to finish
     - verify AQME outputs were generated
     """
+    if not aqme_installed():
+        pytest.skip("AQME is not installed (pip install aqme==2.0.0)")
+
     window = easyrob_window
     finished = {"exit_code": None}
 
@@ -1840,6 +2133,7 @@ def _on_process_finished_stub(exit_code):
 # ChemDraw → popup → table → CSV → main window test
 # =====================================================
 
+
 def test_open_chemdraw_popup_end_to_end_cdxml(
     easyrob_window, qtbot, monkeypatch, test_output_dir
 ):
@@ -1881,16 +2175,13 @@ def exec(self):
             return QDialog.Accepted
 
     # Patch the symbol that open_chemdraw_popup uses
-    monkeypatch.setattr(
-        aqme_module, "ChemDrawFileDialog", FakeChemDrawFileDialog
-    )
+    monkeypatch.setattr(aqme_module, "ChemDrawFileDialog", FakeChemDrawFileDialog)
 
     # --------------------------------------------------------------
     # 4. Stub QFileDialog.getSaveFileName so CSV is written to tmp_path
     # --------------------------------------------------------------
     csv_path = test_output_dir / "chemdraw_table_output.csv"
 
-
     def _fake_get_save_file_name(*args, **kwargs):
         return (str(csv_path), "CSV Files (*.csv)")
 
@@ -1914,7 +2205,9 @@ def _fake_dialog_exec(self: QDialog):
         table = self.findChild(QTableWidget)
         assert table is not None, "ChemDraw table dialog should contain a QTableWidget."
 
-        headers = [table.horizontalHeaderItem(i).text() for i in range(table.columnCount())]
+        headers = [
+            table.horizontalHeaderItem(i).text() for i in range(table.columnCount())
+        ]
         assert "SMILES" in headers
         assert "code_name" in headers
         assert "target" in headers
@@ -1949,7 +2242,9 @@ def _fake_dialog_exec(self: QDialog):
             if "Save as CSV" in btn.text():
                 save_button = btn
                 break
-        assert save_button is not None, "Could not find 'Save as CSV' button in ChemDraw dialog."
+        assert save_button is not None, (
+            "Could not find 'Save as CSV' button in ChemDraw dialog."
+        )
 
         # Click it → this will call save_to_csv() and then dialog.accept()
         save_button.click()
@@ -1984,7 +2279,6 @@ def _fake_dialog_exec(self: QDialog):
 
     # Columns should be loaded into dropdowns
     y_items = [window.y_dropdown.itemText(i) for i in range(window.y_dropdown.count())]
-    names_items = [window.names_dropdown.itemText(i) for i in range(window.names_dropdown.count())]
 
     assert "SMILES" in y_items
     assert "code_name" in y_items
@@ -1995,8 +2289,7 @@ def _fake_dialog_exec(self: QDialog):
         for i in range(window.available_list.count())
     ]
     ignored_items = [
-        window.ignore_list.item(i).text()
-        for i in range(window.ignore_list.count())
+        window.ignore_list.item(i).text() for i in range(window.ignore_list.count())
     ]
     assert "SMILES" in ignored_items
     assert "code_name" in available_items
diff --git a/tests/test_plot_metrics.py b/tests/test_plot_metrics.py
new file mode 100644
index 0000000..36d176a
--- /dev/null
+++ b/tests/test_plot_metrics.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+
+"""Tests for VERIFY metrics plotting."""
+
+import pytest
+
+from robert.utils import plot_metrics
+
+
+@pytest.fixture
+def verify_plot_env(tmp_path, monkeypatch):
+    monkeypatch.chdir(tmp_path)
+    verify_dir = tmp_path / "VERIFY"
+    verify_dir.mkdir()
+    (verify_dir / "RF_No_PFI").touch()
+    return tmp_path
+
+
+def test_plot_metrics_equal_values(verify_plot_env):
+    """Degenerate axis limits (all metrics equal) must not break plotting."""
+    model_data = {"model": "RF_db.csv"}
+    verify_metrics = {
+        "metrics": [0.5, 0.5, 0.5, 0.5],
+        "test_names": ["Model", "y-mean", "y-shuffle", "one-hot"],
+        "colors": ["#808080", "#1f77b4", "#1f77b4", "#1f77b4"],
+        "higher_thres": 0.3,
+        "unclear_higher_thres": 0.2,
+        "lower_thres": 0.3,
+        "unclear_lower_thres": 0.2,
+    }
+    verify_results = {"error_type": "r2"}
+
+    msg = plot_metrics(model_data, "No_PFI", verify_metrics, verify_results)
+    png = verify_plot_env / "VERIFY" / "VERIFY_tests_RF_No_PFI.png"
+    assert png.is_file()
+    assert "VERIFY plot saved" in msg
diff --git a/tests/test_vr_bo.py b/tests/test_vr_bo.py
new file mode 100644
index 0000000..89bbfda
--- /dev/null
+++ b/tests/test_vr_bo.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+
+"""Tests for Voting Regressor/Classifier Bayesian optimization."""
+
+from types import SimpleNamespace
+
+import numpy as np
+
+from robert.argument_parser import options_add
+from robert.utils import load_minimal_model, load_model, model_adjust_params
+
+
+def _vr_adapter(problem_type="reg"):
+    args = options_add()
+    args.seed = 42
+    args.type = problem_type
+    return SimpleNamespace(args=args)
+
+
+def test_vr_member_hyperparameters_affect_predictions():
+    adapter = _vr_adapter("reg")
+    rng = np.random.RandomState(0)
+    X = rng.rand(24, 6)
+    y = X.sum(axis=1) + rng.randn(24) * 0.05
+
+    params_low = model_adjust_params(adapter, "VR", dict(load_minimal_model("VR")))
+    params_high = model_adjust_params(
+        adapter,
+        "VR",
+        {**load_minimal_model("VR"), "rf_n_estimators": 90, "gb_n_estimators": 90},
+    )
+
+    model_low = load_model(adapter, "VR", **params_low)
+    model_high = load_model(adapter, "VR", **params_high)
+    model_low.fit(X, y)
+    model_high.fit(X, y)
+
+    assert not np.allclose(model_low.predict(X), model_high.predict(X))
+
+
+def test_vr_ensemble_weights_affect_predictions():
+    adapter = _vr_adapter("reg")
+    rng = np.random.RandomState(1)
+    X = rng.rand(24, 6)
+    y = X.sum(axis=1) + rng.randn(24) * 0.05
+
+    base = load_minimal_model("VR")
+    params_a = model_adjust_params(
+        adapter, "VR", {**base, "w_rf": 5.0, "w_gb": 0.2, "w_nn": 0.2}
+    )
+    params_b = model_adjust_params(
+        adapter, "VR", {**base, "w_rf": 0.2, "w_gb": 5.0, "w_nn": 0.2}
+    )
+
+    model_a = load_model(adapter, "VR", **params_a)
+    model_b = load_model(adapter, "VR", **params_b)
+    model_a.fit(X, y)
+    model_b.fit(X, y)
+
+    assert not np.allclose(model_a.predict(X), model_b.predict(X))
+
+
+def test_vr_bo_bounds_include_member_models():
+    from robert.utils import BO_hyperparams
+
+    bounds = BO_hyperparams("VR")
+    assert "w_rf" in bounds
+    assert "rf_n_estimators" in bounds
+    assert "gb_learning_rate" in bounds
+    assert "nn_hidden_layer_1" in bounds
+
+
+def test_vr_classification_loads():
+    adapter = _vr_adapter("clas")
+    y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
+    X = np.random.RandomState(2).rand(len(y), 4)
+    params = model_adjust_params(adapter, "VR", dict(load_minimal_model("VR")))
+    model = load_model(adapter, "VR", **params)
+    model.fit(X, y)
+    preds = model.predict(X)
+    assert preds.shape == y.shape

From df98f82d80d442cffdec4e60bd18c3fef3163845 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rub=C3=A9n=20Laplaza?=
 <30357710+rlaplaza@users.noreply.github.com>
Date: Mon, 18 May 2026 21:49:32 +0200
Subject: [PATCH 8/8] Fixes for circleci tests in windows (#78)

* ci: add Ruff lint and format checks

Pin Ruff defaults in pyproject.toml, reformat Python sources,
fix default lint violations, and gate CircleCI on ruff check
and ruff format --check.

* feat: API scores, VR tuning, and multi-platform CI

Add RobertModel.robert_scores(), VR hyperparameter BO support, and
v2.2.0 test extras. Extend tests (plot metrics, VR/BO, API) and docs.
Refactor CircleCI to run the shared conda suite on Linux, Windows, and
macOS.

* fix: remove incomplete project table from pyproject.toml

The partial [project] section lacked a required version field and broke
pip install during CI; package metadata remains in setup.py.

* fix(ci): use preinstalled Miniconda on Windows executor

CircleCI windows-server-2022-gui images already include Miniconda. The
install-miniconda step only checked $HOME/miniconda3, missed the system
install, and hung trying to reinstall via the silent .exe installer.
---
 .circleci/config.yml | 14 ++++++++------
 pyproject.toml       |  3 ---
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 74aa66b..522b1f0 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -8,22 +8,25 @@ orbs:
 
 commands:
   install-miniconda:
-    description: Install Miniconda on machine executors (Windows/macOS)
+    description: Install Miniconda on machine executors (macOS only; Windows images ship with Miniconda)
     steps:
       - run:
           name: Install Miniconda
           command: |
             set -euo pipefail
+            if command -v conda >/dev/null 2>&1; then
+              echo "Conda already available at $(command -v conda)"
+              conda --version
+              exit 0
+            fi
             MINICONDA_DIR="${HOME}/miniconda3"
             if [ -x "${MINICONDA_DIR}/Scripts/conda.exe" ] || [ -x "${MINICONDA_DIR}/bin/conda" ]; then
               echo "Miniconda already present at ${MINICONDA_DIR}"
             else
               case "$(uname -s)" in
                 MINGW*|MSYS*|CYGWIN*)
-                  INSTALLER_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe"
-                  curl -fsSL -o miniconda-installer.exe "${INSTALLER_URL}"
-                  ./miniconda-installer.exe /InstallationType=JustMe /RegisterPython=0 /S /D="${MINICONDA_DIR}"
-                  rm -f miniconda-installer.exe
+                  echo "install-miniconda: unexpected Windows path without preinstalled conda"
+                  exit 1
                   ;;
                 Darwin)
                   ARCH="$(uname -m)"
@@ -276,7 +279,6 @@ jobs:
     shell: bash.exe
     steps:
       - checkout
-      - install-miniconda
       - run-conda-test-suite:
           upload_coverage: false
 
diff --git a/pyproject.toml b/pyproject.toml
index 49dfc1e..a330a0a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,3 @@
-[project]
-name = "robert"
-
 [tool.ruff]
 target-version = "py311"
 line-length = 88