Skip to content

Commit 18c456f

Browse files
authored
Merge pull request #3 from predomics/feat/metadata-regression
feat: metadata upload, variable selection, regression support
2 parents 3f83eb1 + 98d7256 commit 18c456f

File tree

10 files changed

+405
-7
lines changed

10 files changed

+405
-7
lines changed

backend/app/models/schemas.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ class FitFunction(str, Enum):
3535
mcc = "mcc"
3636
f1_score = "f1_score"
3737
g_mean = "g_mean"
38+
# Regression fit functions
39+
spearman = "spearman"
40+
pearson = "pearson"
41+
rmse = "rmse"
42+
mutual_information = "mutual_information"
43+
44+
45+
REGRESSION_FIT_FUNCTIONS = {"spearman", "pearson", "rmse", "mutual_information"}
3846

3947

4048
class JobStatus(str, Enum):
@@ -375,6 +383,12 @@ class DatasetUpdate(BaseModel):
375383
description: Optional[str] = None
376384

377385

386+
class YFromMetadataRequest(BaseModel):
387+
"""Request to generate a y file from a metadata column."""
388+
column: str
389+
file_role: str = "ytrain"
390+
391+
378392
class ProjectUpdate(BaseModel):
379393
name: Optional[str] = None
380394
description: Optional[str] = None

backend/app/routers/analysis.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,14 +467,18 @@ def _run_job(job_id: str, project_id: str, param_path: str, user_id: str = "") -
467467
raise RuntimeError(f"Worker exited with code {proc.returncode}: {error_msg}")
468468

469469
# Extract best_auc/best_k from results for fast list_jobs
470+
# For regression runs, use the fit value as the primary metric
470471
best_auc_val = None
471472
best_k_val = None
472473
if results_path.exists():
473474
try:
474475
with open(results_path) as rf:
475476
res = json.load(rf)
476477
best = res.get("best_individual", {})
477-
best_auc_val = best.get("auc")
478+
if res.get("regression"):
479+
best_auc_val = best.get("fit")
480+
else:
481+
best_auc_val = best.get("auc")
478482
best_k_val = best.get("k")
479483
except Exception:
480484
pass

backend/app/routers/datasets.py

Lines changed: 231 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..core.deps import get_current_user
1616
from ..core.rate_limit import limiter
1717
from ..models.db_models import User, Dataset, DatasetFile, ProjectDataset, Project, DatasetVersion
18-
from ..models.schemas import DatasetResponse, DatasetUpdate, DatasetFileRef
18+
from ..models.schemas import DatasetResponse, DatasetUpdate, DatasetFileRef, YFromMetadataRequest
1919
from ..services import storage, audit
2020

2121
_log = logging.getLogger(__name__)
@@ -554,6 +554,236 @@ async def preview_file(
554554
}
555555

556556

557+
# ---------------------------------------------------------------------------
558+
# Metadata column inspection & y-from-metadata
559+
# ---------------------------------------------------------------------------
560+
561+
def _find_metadata_file(files) -> Optional[Path]:
562+
"""Find the metadata file among dataset files."""
563+
for f in files:
564+
name = f.filename.lower()
565+
if f.role == "metadata" or "metadata" in name or "meta" in name:
566+
p = Path(f.disk_path)
567+
if p.exists():
568+
return p
569+
return None
570+
571+
572+
def _parse_metadata_columns(meta_path: Path) -> list[dict]:
573+
"""Parse a metadata TSV and return column descriptors with types and stats."""
574+
sample = meta_path.read_text(errors="replace")[:4096]
575+
delimiter = "\t" if "\t" in sample else ","
576+
577+
all_rows = []
578+
with open(meta_path, "r", errors="replace") as f:
579+
reader = csv.reader(f, delimiter=delimiter)
580+
for line in reader:
581+
all_rows.append(line)
582+
583+
if len(all_rows) < 2:
584+
return []
585+
586+
header = all_rows[0]
587+
data_rows = all_rows[1:]
588+
columns = []
589+
590+
for col_idx, col_name in enumerate(header):
591+
if col_idx == 0:
592+
continue # skip sample ID column
593+
values = []
594+
for row in data_rows:
595+
if col_idx < len(row) and row[col_idx].strip():
596+
values.append(row[col_idx].strip())
597+
598+
if not values:
599+
continue
600+
601+
# Try to detect numeric vs categorical
602+
numeric_vals = []
603+
for v in values:
604+
try:
605+
numeric_vals.append(float(v))
606+
except (ValueError, TypeError):
607+
pass
608+
609+
if len(numeric_vals) > len(values) * 0.8:
610+
# Numeric column
611+
columns.append({
612+
"name": col_name,
613+
"type": "numeric",
614+
"min": round(min(numeric_vals), 6),
615+
"max": round(max(numeric_vals), 6),
616+
"n_values": len(numeric_vals),
617+
"n_missing": len(data_rows) - len(numeric_vals),
618+
})
619+
else:
620+
# Categorical column
621+
unique_vals = sorted(set(values))
622+
columns.append({
623+
"name": col_name,
624+
"type": "categorical",
625+
"values": unique_vals[:50], # cap at 50 unique values
626+
"n_unique": len(unique_vals),
627+
"n_values": len(values),
628+
"n_missing": len(data_rows) - len(values),
629+
})
630+
631+
return columns
632+
633+
634+
@router.get("/{dataset_id}/metadata-columns")
635+
async def get_metadata_columns(
636+
dataset_id: str,
637+
user: User = Depends(get_current_user),
638+
db: AsyncSession = Depends(get_db),
639+
):
640+
"""Get metadata column names, types, and summary stats.
641+
642+
Numeric columns can be used as regression targets, categorical as classification targets.
643+
"""
644+
result = await db.execute(
645+
select(Dataset)
646+
.where(Dataset.id == dataset_id, Dataset.user_id == user.id)
647+
.options(selectinload(Dataset.files))
648+
)
649+
dataset = result.scalar_one_or_none()
650+
if not dataset:
651+
raise HTTPException(status_code=404, detail="Dataset not found")
652+
653+
meta_path = _find_metadata_file(dataset.files)
654+
if not meta_path:
655+
raise HTTPException(
656+
status_code=404,
657+
detail="No metadata file found in this dataset. Upload a file with role 'metadata'.",
658+
)
659+
660+
columns = _parse_metadata_columns(meta_path)
661+
return {"columns": columns}
662+
663+
664+
@router.post("/{dataset_id}/y-from-metadata")
665+
async def generate_y_from_metadata(
666+
dataset_id: str,
667+
body: YFromMetadataRequest,
668+
user: User = Depends(get_current_user),
669+
db: AsyncSession = Depends(get_db),
670+
):
671+
"""Generate a y file from a metadata column, matching samples with the X file.
672+
673+
The extracted column is written as a TSV file and registered in the dataset.
674+
"""
675+
result = await db.execute(
676+
select(Dataset)
677+
.where(Dataset.id == dataset_id, Dataset.user_id == user.id)
678+
.options(selectinload(Dataset.files))
679+
)
680+
dataset = result.scalar_one_or_none()
681+
if not dataset:
682+
raise HTTPException(status_code=404, detail="Dataset not found")
683+
684+
# Find metadata file
685+
meta_path = _find_metadata_file(dataset.files)
686+
if not meta_path:
687+
raise HTTPException(status_code=404, detail="No metadata file found in this dataset.")
688+
689+
# Find X file to get sample names
690+
x_role = "xtrain" if body.file_role == "ytrain" else "xtest"
691+
x_file = None
692+
for f in dataset.files:
693+
if f.role == x_role:
694+
x_file = f
695+
break
696+
if not x_file or not Path(x_file.disk_path).exists():
697+
raise HTTPException(
698+
status_code=400,
699+
detail=f"No {x_role} file found. Upload an X file first.",
700+
)
701+
702+
# Read metadata TSV
703+
meta_sample = meta_path.read_text(errors="replace")[:4096]
704+
meta_delim = "\t" if "\t" in meta_sample else ","
705+
meta_rows = []
706+
with open(meta_path, "r", errors="replace") as f:
707+
reader = csv.reader(f, delimiter=meta_delim)
708+
for line in reader:
709+
meta_rows.append(line)
710+
711+
if len(meta_rows) < 2:
712+
raise HTTPException(status_code=400, detail="Metadata file is empty or has no data rows.")
713+
714+
meta_header = meta_rows[0]
715+
if body.column not in meta_header:
716+
raise HTTPException(
717+
status_code=400,
718+
detail=f"Column '{body.column}' not found in metadata. Available: {meta_header[1:]}",
719+
)
720+
col_idx = meta_header.index(body.column)
721+
722+
# Build sample -> value map from metadata (first column = sample ID)
723+
meta_map = {}
724+
for row in meta_rows[1:]:
725+
if len(row) > col_idx and row[0].strip() and row[col_idx].strip():
726+
meta_map[row[0].strip()] = row[col_idx].strip()
727+
728+
# Read X file to get sample names (column headers if features_in_rows, else first column)
729+
x_sample = Path(x_file.disk_path).read_text(errors="replace")[:4096]
730+
x_delim = "\t" if "\t" in x_sample else ","
731+
with open(x_file.disk_path, "r", errors="replace") as f:
732+
x_reader = csv.reader(f, delimiter=x_delim)
733+
x_header = next(x_reader)
734+
735+
# Assume features in rows: sample names are column headers (skip first)
736+
x_sample_names = [s.strip() for s in x_header[1:]]
737+
738+
# Match samples
739+
matched = {}
740+
missing = []
741+
for sample in x_sample_names:
742+
if sample in meta_map:
743+
matched[sample] = meta_map[sample]
744+
else:
745+
missing.append(sample)
746+
747+
if not matched:
748+
raise HTTPException(
749+
status_code=400,
750+
detail="No matching samples between X file and metadata.",
751+
)
752+
753+
# Write y file as TSV: sample_id\tvalue
754+
lines = ["sample_id\t" + body.column]
755+
for sample in x_sample_names:
756+
if sample in matched:
757+
lines.append(f"{sample}\t{matched[sample]}")
758+
y_content = "\n".join(lines) + "\n"
759+
760+
# Register the file in the dataset
761+
filename = f"{body.file_role}_{body.column}.tsv"
762+
ds_file = DatasetFile(
763+
dataset_id=dataset.id,
764+
filename=filename,
765+
role=body.file_role,
766+
disk_path="",
767+
)
768+
db.add(ds_file)
769+
await db.flush()
770+
771+
disk_path = storage.save_user_dataset_file(user.id, ds_file.id, filename, y_content.encode("utf-8"))
772+
ds_file.disk_path = disk_path
773+
774+
await _create_version_snapshot(db, dataset_id, user.id, note=f"Generate {body.file_role} from metadata column '{body.column}'")
775+
776+
# Auto-scan if xtrain + ytrain now present
777+
await _try_auto_scan(db, dataset)
778+
779+
return {
780+
"file": DatasetFileRef(id=ds_file.id, filename=ds_file.filename, role=ds_file.role).model_dump(),
781+
"matched_samples": len(matched),
782+
"missing_samples": len(missing),
783+
"total_x_samples": len(x_sample_names),
784+
}
785+
786+
557787
# ---------------------------------------------------------------------------
558788
# Project assignment
559789
# ---------------------------------------------------------------------------

backend/app/services/engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import yaml
1313

1414
from ..core.config import settings as app_settings
15+
from ..models.schemas import REGRESSION_FIT_FUNCTIONS
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -117,6 +118,9 @@ def _merge(section_key, defaults):
117118
"save_exp": str(Path(output_dir) / "experiment.bin"),
118119
})
119120

121+
# When fit is a regression function, disable class-based feature selection
122+
is_regression = general.get("fit") in REGRESSION_FIT_FUNCTIONS
123+
120124
data_cfg = config.get("data", {})
121125

122126
cv = _merge("cv", {
@@ -212,6 +216,7 @@ def _merge(section_key, defaults):
212216
"feature_maximal_adj_pvalue": data_cfg.get("feature_maximal_adj_pvalue", 0.05),
213217
"feature_minimal_feature_value": data_cfg.get("feature_minimal_feature_value", 0.0),
214218
**({"classes": data_cfg["classes"]} if data_cfg.get("classes") else {}),
219+
**({"feature_maximal_adj_pvalue": 1.0} if is_regression else {}),
215220
},
216221
"cv": cv,
217222
"importance": importance,

backend/app/services/worker.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,14 @@ def main():
430430
param_path = sys.argv[1]
431431
results_path = sys.argv[2]
432432

433+
REGRESSION_FITS = {"spearman", "pearson", "rmse", "mutual_information"}
434+
433435
# Check if this is a sklearn algorithm
434436
with open(param_path) as _f:
435437
_param_yaml = yaml.safe_load(_f)
436438
algo = _param_yaml.get("general", {}).get("algo", "ga")
439+
fit_function = _param_yaml.get("general", {}).get("fit", "auc")
440+
is_regression = fit_function in REGRESSION_FITS
437441

438442
from .sklearn_runner import is_sklearn_algo
439443
if is_sklearn_algo(algo):
@@ -574,6 +578,13 @@ def main():
574578
if stability_data is not None:
575579
results["stability"] = stability_data
576580

581+
# Store regression metadata when using a regression fit function
582+
if is_regression:
583+
results["regression"] = {
584+
"fit_function": fit_function,
585+
"best_fit": metrics.get("fit"),
586+
}
587+
577588
# Clinical integration (if enabled)
578589
clinical_cfg = _param_yaml.get("clinical", {})
579590
if clinical_cfg.get("enabled") and clinical_cfg.get("path"):

frontend/src/data/parameterDefs.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ export const PARAM_DEFS = [
5050
{ value: 'auc', label: 'AUC' }, { value: 'mcc', label: 'MCC' }, { value: 'f1_score', label: 'F1 Score' },
5151
{ value: 'sensitivity', label: 'Sensitivity' }, { value: 'specificity', label: 'Specificity' },
5252
{ value: 'g_mean', label: 'Geometric Mean' }, { value: 'npv', label: 'NPV' }, { value: 'ppv', label: 'PPV' },
53+
{ value: 'spearman', label: 'Spearman correlation' }, { value: 'pearson', label: 'Pearson correlation' },
54+
{ value: 'rmse', label: 'RMSE' }, { value: 'mutual_information', label: 'Mutual Information' },
5355
],
5456
},
5557
{

frontend/src/i18n/locales/en.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,11 @@
642642
"nNeighbors": "Neighbors",
643643
"minDist": "Min distance",
644644
"top": "Top",
645-
"loadingData": "Loading data..."
645+
"loadingData": "Loading data...",
646+
"metadata": "Metadata (optional)",
647+
"selectYVariable": "Select y variable",
648+
"regressionMode": "Regression mode",
649+
"metadataUpload": "Upload metadata"
646650
},
647651
"dataExplore": {
648652
"summary": "Summary",

frontend/src/i18n/locales/fr.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,11 @@
642642
"nNeighbors": "Voisins",
643643
"minDist": "Distance min.",
644644
"top": "Top",
645-
"loadingData": "Chargement des données..."
645+
"loadingData": "Chargement des données...",
646+
"metadata": "Métadonnées (optionnel)",
647+
"selectYVariable": "Sélectionner la variable y",
648+
"regressionMode": "Mode régression",
649+
"metadataUpload": "Importer les métadonnées"
646650
},
647651
"dataExplore": {
648652
"summary": "Résumé",

0 commit comments

Comments
 (0)