From 81b35eafc453793d37068ae00c8d4565c640a6af Mon Sep 17 00:00:00 2001 From: Rafael Date: Sun, 30 Nov 2025 23:40:23 +0200 Subject: [PATCH] new: add regression --- frame/evaluate.py | 74 +++++++++++++++------- frame/explain.py | 18 ++++-- frame/source/explain/__init__.py | 14 ++++- frame/source/models/__init__.py | 9 ++- frame/source/train/__init__.py | 11 +++- frame/source/train/epoch.py | 77 +++++++++++++++-------- frame/source/train/metrics.py | 101 +++++++++++++++++++++++++++++++ frame/train.py | 17 ++++-- frame/tune.py | 43 +++++++------ parameters.yaml | 1 + 10 files changed, 285 insertions(+), 80 deletions(-) create mode 100644 frame/source/train/metrics.py diff --git a/frame/evaluate.py b/frame/evaluate.py index 0ba2077..893e9ce 100644 --- a/frame/evaluate.py +++ b/frame/evaluate.py @@ -8,7 +8,7 @@ from tqdm import tqdm from sklearn import metrics -from frame.source import models +from frame.source import models, train from torch_geometric.loader import DataLoader device = "cuda" if torch.cuda.is_available() else "cpu" @@ -31,6 +31,7 @@ def main(): path_checkpoint = config["path_checkpoint"] model_name = config.get("model", "gat").lower() batch_size = config.get("batch_size", 64) + task = config.get("task", "classification").lower() # * Initialize name = config["name"] @@ -68,31 +69,58 @@ def main(): batch=data.batch) # * Read prediction values - detach = torch.sigmoid(model_out).cpu().detach() - pred_lbl = (detach >= 0.5).int() - pred = list(torch.ravel(detach).cpu().detach().numpy()) + if task == "classification": + detach = torch.sigmoid(model_out).cpu().detach() + pred = list(torch.ravel(detach).cpu().detach().numpy()) + pred_lbl = (detach >= 0.5).int() + agg_lbl += pred_lbl.flatten().tolist() + else: + detach = model_out.cpu().detach() + pred = list(torch.ravel(detach).cpu().detach().numpy()) + pred_lbl = None # * Save prediction values agg_pred += pred - agg_lbl += pred_lbl.flatten().tolist() agg_true += data.y.flatten().tolist() # * Get metrics - acc 
= metrics.accuracy_score(agg_true, agg_lbl) - acc_bal = metrics.balanced_accuracy_score(agg_true, agg_lbl) - f1 = metrics.f1_score(agg_true, agg_lbl, zero_division=0) - prec = metrics.precision_score(agg_true, agg_lbl, zero_division=0) - rec = metrics.recall_score(agg_true, agg_lbl, zero_division=0) - mcc = metrics.matthews_corrcoef(agg_true, agg_lbl) - roc_auc = metrics.roc_auc_score(agg_true, agg_pred) - avg_prec = metrics.average_precision_score(agg_true, agg_pred) - - print(f"\n========= {name}" - f"\n{'Accuracy:':<19}{round(acc, 3)}" - f"\n{'Balanced Accuracy:':<19}{round(acc_bal, 3)}" - f"\n{'F1:':<19}{round(f1, 3)}" - f"\n{'MCC:':<19}{round(mcc, 3)}" - f"\n{'Precision:':<19}{round(prec, 3)}" - f"\n{'Recall:':<19}{round(rec, 3)}" - f"\n{'Avg. Precision:':<19}{round(avg_prec, 3)}" - f"\n{'ROC-AUC:':<19}{round(roc_auc, 3)}\n") + if task == "classification": + acc = metrics.accuracy_score(agg_true, agg_lbl) + acc_bal = metrics.balanced_accuracy_score(agg_true, agg_lbl) + f1 = metrics.f1_score(agg_true, agg_lbl, zero_division=0) + prec = metrics.precision_score(agg_true, agg_lbl, zero_division=0) + rec = metrics.recall_score(agg_true, agg_lbl, zero_division=0) + mcc = metrics.matthews_corrcoef(agg_true, agg_lbl) + roc_auc = metrics.roc_auc_score(agg_true, agg_pred) + avg_prec = metrics.average_precision_score(agg_true, agg_pred) + + print(f"\n========= {name}" + f"\n{'Accuracy:':<19}{round(acc, 3)}" + f"\n{'Balanced Accuracy:':<19}{round(acc_bal, 3)}" + f"\n{'F1:':<19}{round(f1, 3)}" + f"\n{'MCC:':<19}{round(mcc, 3)}" + f"\n{'Precision:':<19}{round(prec, 3)}" + f"\n{'Recall:':<19}{round(rec, 3)}" + f"\n{'Avg. 
Precision:':<19}{round(avg_prec, 3)}" + f"\n{'ROC-AUC:':<19}{round(roc_auc, 3)}\n") + + else: + r2 = metrics.r2_score(agg_true, agg_pred) + rmse = metrics.root_mean_squared_error(agg_true, agg_pred) + mae = metrics.mean_absolute_error(agg_true, agg_pred) + + rto_r2, _ = train.reg_through_origin(agg_true, agg_pred) + ccc = train.concordance_correlation(agg_true, agg_pred) + roy_c = train.roy_criteria(agg_true, agg_pred, inverse=False) + roy_c_inv = train.roy_criteria(agg_true, agg_pred, inverse=True) + delta = train.golbraikh_tropsha(agg_true, agg_pred) + + print(f"\n========= {name}" + f"\n{'R²:':<19}{round(r2, 3)}" + f"\n{'RMSE:':<19}{round(rmse, 3)}" + f"\n{'MAE:':<19}{round(mae, 3)}" + f"\n{'RTO R²:':<19}{round(rto_r2, 3)}" + f"\n{'CCC:':<19}{round(ccc, 3)}" + f"\n{'Roy Criteria:':<19}{round(roy_c, 3)}" + f"\n{'Roy C. Inverse:':<19}{round(roy_c_inv, 3)}" + f"\n{'Delta:':<19}{round(delta, 3)}\n") diff --git a/frame/explain.py b/frame/explain.py index 914a739..644abfc 100644 --- a/frame/explain.py +++ b/frame/explain.py @@ -32,6 +32,7 @@ def main(): path_checkpoint = config["path_checkpoint"] model_name = config.get("model", "gat").lower() batch_size = config.get("batch_size", 64) + task = config.get("task", "classification").lower() # * Initialize name = config["name"] @@ -62,12 +63,16 @@ def main(): model.load_state_dict(torch.load(path_checkpoint)) model.eval() + if task == "classification": + mode = "multiclass_classification" + else: + mode = "regression" explainer = Explainer(model=model, algorithm=CaptumExplainer("IntegratedGradients"), explanation_type="model", edge_mask_type="object", node_mask_type="attributes", - model_config=dict(mode="multiclass_classification", + model_config=dict(mode=mode, task_level="graph", return_type="raw")) @@ -81,9 +86,14 @@ def main(): batch=data.batch) # * Read prediction values - detach = torch.sigmoid(model_out).cpu().detach() - pred_lbl = (detach >= 0.5).int() - pred = list(torch.ravel(detach).cpu().detach().numpy()) + 
if task == "classification": + detach = torch.sigmoid(model_out).cpu().detach() + pred = list(torch.ravel(detach).cpu().detach().numpy()) + pred_lbl = (detach >= 0.5).int() + else: + detach = model_out.cpu().detach() + pred = list(torch.ravel(detach).cpu().detach().numpy()) + pred_lbl = [None] * detach.shape[0] # * Explain explanation = explainer(data.x.float(), data.edge_index, diff --git a/frame/source/explain/__init__.py b/frame/source/explain/__init__.py index fd4a56f..b06bebe 100644 --- a/frame/source/explain/__init__.py +++ b/frame/source/explain/__init__.py @@ -55,12 +55,15 @@ def retrieve_info(self, graphs): def _info_atom(self, graphs): batch_num = self.batch.unique() masks = [self.mask[self.batch == b] for b in batch_num] + pred_label = "" for idx in range(len(masks)): data = graphs[idx] real_label = int(data.y.cpu().numpy()[0]) pred = self.pred[idx] - pred_label = self.pred_lbl[idx].numpy()[0] + + if self.pred_lbl[idx] is not None: + pred_label = self.pred_lbl[idx].numpy()[0] text = (f"{data.idx},{data.smiles},{real_label}," f"{pred_label},{pred:.3f}\n") @@ -72,14 +75,17 @@ def _info_atom(self, graphs): def _info_fragment(self, graphs): batch_num = self.batch.unique() masks = [self.mask[self.batch == b] for b in batch_num] + pred_label = "" for idx, node_mask in enumerate(masks): data = graphs[idx] real_label = int(data.y.cpu().numpy()[0]) pred = self.pred[idx] - pred_label = self.pred_lbl[idx].numpy()[0] fragments = np.array(data.frag) + if self.pred_lbl[idx] is not None: + pred_label = self.pred_lbl[idx].numpy()[0] + mask_list = node_mask.cpu().numpy().tolist() mask_list = [[f"{m:.3f}" for m in mask] for mask in mask_list] @@ -97,12 +103,14 @@ def _info_fragment(self, graphs): def plot_explanations(self, graphs): batch_num = self.batch.unique() masks = [self.mask[self.batch == b] for b in batch_num] + pred_label = "" for idx, node_mask in enumerate(masks): data = graphs[idx] name = data.idx pred = self.pred[idx] - pred_label = 
self.pred_lbl[idx].numpy()[0] + if self.pred_lbl[idx] is not None: + pred_label = self.pred_lbl[idx].numpy()[0] if self.loader == "default": self._explain_atom(data, node_mask, pred, pred_label, name) diff --git a/frame/source/models/__init__.py b/frame/source/models/__init__.py index f17262f..6b37e97 100644 --- a/frame/source/models/__init__.py +++ b/frame/source/models/__init__.py @@ -6,6 +6,7 @@ def model_setup(model_name, config): + task = config["task"] model = select_model(model_name, config) base_optimizer = torch.optim.Adam(model.parameters(), @@ -19,8 +20,12 @@ def model_setup(model_name, config): T_max=100, eta_min=1e-6) - bce_weight = config["bce_weight"] - lossfn = torch.nn.BCEWithLogitsLoss(pos_weight=bce_weight).to(device) + if task == "classification": + bce_weight = config["bce_weight"] + lossfn = torch.nn.BCEWithLogitsLoss(pos_weight=bce_weight).to(device) + + else: + lossfn = torch.nn.MSELoss() return model, optimizer, scheduler, lossfn diff --git a/frame/source/train/__init__.py b/frame/source/train/__init__.py index 925fcd2..43ed1b5 100644 --- a/frame/source/train/__init__.py +++ b/frame/source/train/__init__.py @@ -1,8 +1,17 @@ from frame.source.train.optimizer import Lookahead from frame.source.train.epoch import train_epoch, valid_epoch +from frame.source.train.metrics import (reg_through_origin, + concordance_correlation, + roy_criteria, + golbraikh_tropsha) __all__ = ["train_epoch", "valid_epoch", - "Lookahead"] + "Lookahead", + + "reg_through_origin", + "concordance_correlation", + "roy_criteria", + "golbraikh_tropsha"] diff --git a/frame/source/train/epoch.py b/frame/source/train/epoch.py index 1b78ee6..1be88d0 100644 --- a/frame/source/train/epoch.py +++ b/frame/source/train/epoch.py @@ -5,6 +5,8 @@ from sklearn import metrics import torch.backends.cudnn as cudnn +from frame.source.train import metrics as reg_metrics + random.seed(8) np.random.seed(8) @@ -49,7 +51,7 @@ def train_epoch(model, optim, scheduler, lossfn, loader): 
@torch.no_grad() -def valid_epoch(model, loader): +def valid_epoch(model, task, loader): model.eval() true = [] @@ -66,34 +68,62 @@ batch=batch.batch) # * Read prediction values - detach = torch.sigmoid(out).cpu().detach() - discretized = (detach >= 0.5).int() - batch_pred = list(torch.ravel(detach).cpu().detach().numpy()) - batch_label = list(torch.ravel(discretized).cpu().detach().numpy()) - batch_true = list(torch.ravel(batch.y).cpu().detach().numpy()) + if task == "classification": + detach = torch.sigmoid(out).cpu().detach() + discretized = (detach >= 0.5).int() + batch_pred = list(torch.ravel(detach).cpu().detach().numpy()) + batch_label = list(torch.ravel(discretized).cpu().detach().numpy()) + label = label + batch_label + + else: + detach = out.cpu().detach() + batch_pred = list(torch.ravel(detach).cpu().detach().numpy()) + + batch_true = list(torch.ravel(batch.y).cpu().detach().numpy()) + true = true + batch_true pred = pred + batch_pred - label = label + batch_label # * Get metrics - acc = metrics.accuracy_score(true, label) - acc_bal = metrics.balanced_accuracy_score(true, label) - f1 = metrics.f1_score(true, label, zero_division=0) - prec = metrics.precision_score(true, label, zero_division=0) - rec = metrics.recall_score(true, label, zero_division=0) - mcc = metrics.matthews_corrcoef(true, label) - roc_auc = metrics.roc_auc_score(true, pred) - avg_prec = metrics.average_precision_score(true, pred) - - result = {"acc": round(acc, 3), - "acc_bal": round(acc_bal, 3), - "f1": round(f1, 3), - "prec": round(prec, 3), - "rec": round(rec, 3), - "mcc": round(mcc, 3), - "avg_prec": round(avg_prec, 3), - "roc_auc": round(roc_auc, 3)} + if task == "classification": + acc = metrics.accuracy_score(true, label) + acc_bal = metrics.balanced_accuracy_score(true, label) + f1 = metrics.f1_score(true, label, zero_division=0) + prec = metrics.precision_score(true, label, zero_division=0) + rec = metrics.recall_score(true, label, zero_division=0) + mcc = metrics.matthews_corrcoef(true, label) + roc_auc =
metrics.roc_auc_score(true, pred) + avg_prec = metrics.average_precision_score(true, pred) + + result = {"optim": round(mcc, 3), + "acc": round(acc, 3), + "acc_bal": round(acc_bal, 3), + "f1": round(f1, 3), + "prec": round(prec, 3), + "rec": round(rec, 3), + "mcc": round(mcc, 3), + "avg_prec": round(avg_prec, 3), + "roc_auc": round(roc_auc, 3)} + else: + r2 = metrics.r2_score(true, pred) + rmse = metrics.root_mean_squared_error(true, pred) + mae = metrics.mean_absolute_error(true, pred) + + rto_r2, _ = reg_metrics.reg_through_origin(true, pred) + ccc = reg_metrics.concordance_correlation(true, pred) + roy_c = reg_metrics.roy_criteria(true, pred, inverse=False) + roy_c_inv = reg_metrics.roy_criteria(true, pred, inverse=True) + delta = reg_metrics.golbraikh_tropsha(true, pred) + + result = {"optim": round(ccc, 3), + "r2": round(r2, 3), + "rmse": round(rmse, 3), + "mae": round(mae, 3), + "rto_r2": round(rto_r2, 3), + "ccc": round(ccc, 3), + "roy_c": round(roy_c, 3), + "roy_c_inv": round(roy_c_inv, 3), + "delta": round(delta, 3)} return result diff --git a/frame/source/train/metrics.py b/frame/source/train/metrics.py new file mode 100644 index 0000000..010b5b6 --- /dev/null +++ b/frame/source/train/metrics.py @@ -0,0 +1,101 @@ +import math +import numpy as np +from sklearn import metrics +from sklearn.linear_model import LinearRegression + + +def reg_through_origin(y_true, y_pred): + """Regression Through Origin (RTO) coefficient of determination + + Args: + y_true (np.array): True labels + y_pred (np.array): Predicted labels + + Returns: + float: RTO coefficient of determination + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + true = y_true.reshape(-1, 1) + pred = y_pred.reshape(-1, 1) + + regression = LinearRegression(fit_intercept=False) + rto = regression.fit(pred, true) + + rto_r2 = rto.score(pred, true) + slope = regression.coef_ + + return rto_r2, float(slope) + + +def concordance_correlation(y_true, y_pred): + """Concordance Correlation 
Coefficient (CCC) + https://doi.org/10.2307/2532051 + + Args: + y_true (np.array): True labels + y_pred (np.array): Predicted labels + + Returns: + float: Coefficient + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + mean_true = y_true.mean() + mean_pred = y_pred.mean() + + vx, cov_xy, cov_xy, vy = np.cov(y_true, y_pred, bias=True).flat + ccc = 2 * cov_xy / (vx + vy + (mean_true - mean_pred) ** 2) + + return ccc + + +def roy_criteria(y_true, y_pred, inverse=False): + """Proposed criteria by Roy based on regression through origin (RTO) + https://doi.org/10.1016/j.ejps.2014.05.019 + + Args: + y_true (np.array): True labels + y_pred (np.array): Predicted labels + + Returns: + float: Roy criteria + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + + if inverse: + rto, _ = reg_through_origin(y_pred, y_true) + r2 = metrics.r2_score(y_pred, y_true) + + roy = r2 * (1 - math.sqrt(abs(r2 - rto))) + + else: + rto, _ = reg_through_origin(y_true, y_pred) + r2 = metrics.r2_score(y_true, y_pred) + + roy = r2 * (1 - math.sqrt(abs(r2 - rto))) + + return roy + + +def golbraikh_tropsha(y_true, y_pred): + """Proposed criteria by Alexander Golbraikh and Alexander Tropsha + https://doi.org/10.1016/S1093-3263(01)00123-1 + + Args: + y_true (np.array): True labels + y_pred (np.array): Predicted labels + + Returns: + float: Golbraikh and Tropsha criteria + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + + rto_0, _ = reg_through_origin(y_true, y_pred) + rto_1, _ = reg_through_origin(y_pred, y_true) + + delta = abs(rto_0 - rto_1) + + return delta diff --git a/frame/train.py b/frame/train.py index 99da1ef..4a3295a 100644 --- a/frame/train.py +++ b/frame/train.py @@ -23,6 +23,7 @@ def run(params, dataset): size = params["Data"].get("batch_size", 32) patience = params["Data"].get("patience", 5) model_name = params["Data"].get("model", "gat").lower() + task = params["Data"].get("task", "classification").lower() project_dir = 
params["Data"]["project_dir"] @@ -36,6 +37,7 @@ def run(params, dataset): config["feat_size"] = params["Data"]["feat_size"] config["edge_dim"] = params["Data"]["edge_dim"] config["bce_weight"] = params["Data"]["bce_weight"] + config["task"] = task params["Data"]["trial"] = None # * Prepare dataloader @@ -58,12 +60,12 @@ def run(params, dataset): best_model_state = None for epoch in tqdm(range(epochs), ncols=120, desc="Training"): _ = train.train_epoch(model, optim, schdlr, lossfn, train_loader) - val_metrics = train.valid_epoch(model, valid_loader) + val_metrics = train.valid_epoch(model, task, valid_loader) # Early stopping check - if val_metrics["mcc"] > best_metric: + if val_metrics["optim"] > best_metric: patience_counter = 0 - best_metric = val_metrics["mcc"] + best_metric = val_metrics["optim"] best_model_state = copy.deepcopy(model.state_dict()) else: patience_counter += 1 @@ -76,9 +78,12 @@ def run(params, dataset): model.load_state_dict(best_model_state) os.makedirs(project_dir, exist_ok=True) torch.save(best_model_state, str(project_dir / "best_model.pt")) - results = train.valid_epoch(model, valid_loader) + results = train.valid_epoch(model, task, valid_loader) - print(f"MCC: {results['mcc']}") + if task == "classification": + print(f"MCC: {results['mcc']}") + else: + print(f"CCC: {results['ccc']}") def main(): @@ -91,7 +96,7 @@ def main(): # * Initialize name = params["Data"]["name"] if name.lower() == "none": - name = str(uuid.uuid4()).split["-"][0] + name = str(uuid.uuid4()).split("-")[0] params["Data"]["name"] = name cwd = Path(os.getcwd()) diff --git a/frame/tune.py b/frame/tune.py index 1ec2589..476ae15 100644 --- a/frame/tune.py +++ b/frame/tune.py @@ -30,6 +30,7 @@ def objective(trial, params, dataset): patience = params["Data"].get("patience", 5) model_name = params["Data"].get("model", "gat").lower() max_retries = params["Data"].get("max_retries", 5) + task = params["Data"].get("task", "classification").lower() project_dir = 
params["Data"]["project_dir"] @@ -38,6 +39,7 @@ def objective(trial, params, dataset): config["feat_size"] = params["Data"]["feat_size"] config["edge_dim"] = params["Data"]["edge_dim"] config["bce_weight"] = params["Data"]["bce_weight"] + config["task"] = task params["Data"]["trial"] = trial # * Prepare dataloader @@ -58,7 +60,7 @@ def objective(trial, params, dataset): retries = 0 while retries < max_retries: try: - best_metric = -1.0 + best_metric = -1000 patience_counter = 0 best_model_state = None @@ -66,12 +68,12 @@ def objective(trial, params, dataset): start = time.time() _ = train.train_epoch(model, optim, schdlr, lossfn, train_loader) - val_metrics = train.valid_epoch(model, valid_loader) + val_metrics = train.valid_epoch(model, task, valid_loader) # Early stopping check - if val_metrics["mcc"] > best_metric: + if val_metrics["optim"] > best_metric: patience_counter = 0 - best_metric = val_metrics["mcc"] + best_metric = val_metrics["optim"] best_model_state = copy.deepcopy(model.state_dict()) else: patience_counter += 1 @@ -87,7 +89,7 @@ def objective(trial, params, dataset): trial_dir = project_dir / f"trial_{trial.number}" os.makedirs(trial_dir, exist_ok=True) torch.save(best_model_state, str(trial_dir / "best_model.pt")) - results = train.valid_epoch(model, valid_loader) + results = train.valid_epoch(model, task, valid_loader) # Get model complexity n_params = filter(lambda p: p.requires_grad, model.parameters()) @@ -98,7 +100,7 @@ def objective(trial, params, dataset): trial.set_user_attr("fit_time", float(round(fit_time, 3))) trial.set_user_attr("metrics", results) - return results["mcc"] + return results["optim"] except torch.cuda.OutOfMemoryError: retries += 1 @@ -111,15 +113,20 @@ def objective(trial, params, dataset): raise optuna.exceptions.TrialPruned() -def get_dataframe(study): +def get_dataframe(study, task): records = [] for trial in study.trials: record = {"trial": trial.number, "optim": trial.value} - dummy = {"acc": np.nan, "acc_bal": 
np.nan, "f1": np.nan, - "prec": np.nan, "rec": np.nan, "mcc": np.nan, - "avg_prec": np.nan, "roc_auc": np.nan} + if task == "classification": + dummy = {"optim": np.nan, "acc": np.nan, "acc_bal": np.nan, + "f1": np.nan, "prec": np.nan, "rec": np.nan, + "mcc": np.nan, "avg_prec": np.nan, "roc_auc": np.nan} + else: + dummy = {"optim": np.nan, "r2": np.nan, "rmse": np.nan, + "mae": np.nan, "rto_r2": np.nan, "ccc": np.nan, + "roy_c": np.nan, "roy_c_inv": np.nan, "delta": np.nan} # Get user attrs val_metrics = trial.user_attrs.get("metrics", dummy) @@ -146,9 +153,10 @@ def main(): params = yaml.safe_load(stream) # * Initialize + task = params["Data"]["task"].lower() name = params["Data"]["name"] if name.lower() == "none": - name = str(uuid.uuid4()).split["-"][0] + name = str(uuid.uuid4()).split("-")[0] params["Data"]["name"] = name cwd = Path(os.getcwd()) @@ -180,15 +188,16 @@ def main(): n_trials=1) if study.study_name == name: - df = get_dataframe(study) + df = get_dataframe(study, task) df.to_csv(project_dir / f"{name}.csv", index=False) # plot parallel plot - header = ["optim", "acc", "acc_bal", "f1", - "prec", "rec", "mcc", "avg_prec", - "roc_auc", "n_params", "fit_time"] + header = ["optim", "acc", "acc_bal", "f1", "prec", "rec", + "mcc", "avg_prec", "roc_auc", "r2", "rmse", + "mae", "rto_r2", "ccc", "roy_c", "roy_c_inv", + "delta", "n_params", "fit_time"] feats = [col for col in list(df.columns) if col not in header] - feats = feats + ["mcc"] + feats = feats + ["optim"] dimensions = [] for col in feats: @@ -197,7 +206,7 @@ def main(): range=[col_values.min(), col_values.max()]) dimensions.append(dim) - fig = go.Figure(data=go.Parcoords(line=dict(color=df["mcc"], + fig = go.Figure(data=go.Parcoords(line=dict(color=df["optim"], colorscale="viridis", showscale=False), dimensions=dimensions)) diff --git a/parameters.yaml b/parameters.yaml index 56d3800..b3dc327 100755 --- a/parameters.yaml +++ b/parameters.yaml @@ -1,5 +1,6 @@ Data: name: None + task: Classification
path_csv: "path/to/csv_file.csv" path_joblib: "path/to/joblib_file.joblib" path_checkpoint: "path/to/checkpoint.pt"