|
15 | 15 | from ..core.deps import get_current_user |
16 | 16 | from ..core.rate_limit import limiter |
17 | 17 | from ..models.db_models import User, Dataset, DatasetFile, ProjectDataset, Project, DatasetVersion |
18 | | -from ..models.schemas import DatasetResponse, DatasetUpdate, DatasetFileRef |
| 18 | +from ..models.schemas import DatasetResponse, DatasetUpdate, DatasetFileRef, YFromMetadataRequest |
19 | 19 | from ..services import storage, audit |
20 | 20 |
|
21 | 21 | _log = logging.getLogger(__name__) |
@@ -554,6 +554,236 @@ async def preview_file( |
554 | 554 | } |
555 | 555 |
|
556 | 556 |
|
| 557 | +# --------------------------------------------------------------------------- |
| 558 | +# Metadata column inspection & y-from-metadata |
| 559 | +# --------------------------------------------------------------------------- |
| 560 | + |
| 561 | +def _find_metadata_file(files) -> Optional[Path]: |
| 562 | + """Find the metadata file among dataset files.""" |
| 563 | + for f in files: |
| 564 | + name = f.filename.lower() |
| 565 | + if f.role == "metadata" or "metadata" in name or "meta" in name: |
| 566 | + p = Path(f.disk_path) |
| 567 | + if p.exists(): |
| 568 | + return p |
| 569 | + return None |
| 570 | + |
| 571 | + |
| 572 | +def _parse_metadata_columns(meta_path: Path) -> list[dict]: |
| 573 | + """Parse a metadata TSV and return column descriptors with types and stats.""" |
| 574 | + sample = meta_path.read_text(errors="replace")[:4096] |
| 575 | + delimiter = "\t" if "\t" in sample else "," |
| 576 | + |
| 577 | + all_rows = [] |
| 578 | + with open(meta_path, "r", errors="replace") as f: |
| 579 | + reader = csv.reader(f, delimiter=delimiter) |
| 580 | + for line in reader: |
| 581 | + all_rows.append(line) |
| 582 | + |
| 583 | + if len(all_rows) < 2: |
| 584 | + return [] |
| 585 | + |
| 586 | + header = all_rows[0] |
| 587 | + data_rows = all_rows[1:] |
| 588 | + columns = [] |
| 589 | + |
| 590 | + for col_idx, col_name in enumerate(header): |
| 591 | + if col_idx == 0: |
| 592 | + continue # skip sample ID column |
| 593 | + values = [] |
| 594 | + for row in data_rows: |
| 595 | + if col_idx < len(row) and row[col_idx].strip(): |
| 596 | + values.append(row[col_idx].strip()) |
| 597 | + |
| 598 | + if not values: |
| 599 | + continue |
| 600 | + |
| 601 | + # Try to detect numeric vs categorical |
| 602 | + numeric_vals = [] |
| 603 | + for v in values: |
| 604 | + try: |
| 605 | + numeric_vals.append(float(v)) |
| 606 | + except (ValueError, TypeError): |
| 607 | + pass |
| 608 | + |
| 609 | + if len(numeric_vals) > len(values) * 0.8: |
| 610 | + # Numeric column |
| 611 | + columns.append({ |
| 612 | + "name": col_name, |
| 613 | + "type": "numeric", |
| 614 | + "min": round(min(numeric_vals), 6), |
| 615 | + "max": round(max(numeric_vals), 6), |
| 616 | + "n_values": len(numeric_vals), |
| 617 | + "n_missing": len(data_rows) - len(numeric_vals), |
| 618 | + }) |
| 619 | + else: |
| 620 | + # Categorical column |
| 621 | + unique_vals = sorted(set(values)) |
| 622 | + columns.append({ |
| 623 | + "name": col_name, |
| 624 | + "type": "categorical", |
| 625 | + "values": unique_vals[:50], # cap at 50 unique values |
| 626 | + "n_unique": len(unique_vals), |
| 627 | + "n_values": len(values), |
| 628 | + "n_missing": len(data_rows) - len(values), |
| 629 | + }) |
| 630 | + |
| 631 | + return columns |
| 632 | + |
| 633 | + |
@router.get("/{dataset_id}/metadata-columns")
async def get_metadata_columns(
    dataset_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return metadata column names, inferred types, and summary stats.

    Numeric columns are candidate regression targets; categorical columns are
    candidate classification targets.

    Raises 404 when the dataset does not exist (or is not owned by the
    caller) or when it contains no metadata file.
    """
    query = (
        select(Dataset)
        .where(Dataset.id == dataset_id, Dataset.user_id == user.id)
        .options(selectinload(Dataset.files))
    )
    dataset = (await db.execute(query)).scalar_one_or_none()
    if dataset is None:
        raise HTTPException(status_code=404, detail="Dataset not found")

    meta_path = _find_metadata_file(dataset.files)
    if meta_path is None:
        raise HTTPException(
            status_code=404,
            detail="No metadata file found in this dataset. Upload a file with role 'metadata'.",
        )

    return {"columns": _parse_metadata_columns(meta_path)}
| 662 | + |
| 663 | + |
def _sniff_delimiter(path) -> str:
    """Guess the delimiter of *path*: tab if one appears in the first 4 KiB, else comma."""
    # Bounded read — avoids loading a potentially huge matrix file into
    # memory just to detect the separator.
    with open(path, "r", errors="replace") as f:
        sample = f.read(4096)
    return "\t" if "\t" in sample else ","


@router.post("/{dataset_id}/y-from-metadata")
async def generate_y_from_metadata(
    dataset_id: str,
    body: YFromMetadataRequest,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a y file from a metadata column, matching samples with the X file.

    The requested column is aligned against the sample names of the matching
    X file (xtrain for ytrain, otherwise xtest), written as a two-column TSV,
    registered as a new dataset file, recorded in a version snapshot, and the
    dataset is auto-scanned when possible.

    Raises:
        HTTPException 404: dataset or metadata file not found.
        HTTPException 400: missing X file, empty metadata, unknown column,
            or no overlapping samples.
    """
    result = await db.execute(
        select(Dataset)
        .where(Dataset.id == dataset_id, Dataset.user_id == user.id)
        .options(selectinload(Dataset.files))
    )
    dataset = result.scalar_one_or_none()
    if not dataset:
        raise HTTPException(status_code=404, detail="Dataset not found")

    # Find metadata file
    meta_path = _find_metadata_file(dataset.files)
    if not meta_path:
        raise HTTPException(status_code=404, detail="No metadata file found in this dataset.")

    # The y role determines which X file supplies the sample ordering.
    x_role = "xtrain" if body.file_role == "ytrain" else "xtest"
    x_file = next((f for f in dataset.files if f.role == x_role), None)
    if not x_file or not Path(x_file.disk_path).exists():
        raise HTTPException(
            status_code=400,
            detail=f"No {x_role} file found. Upload an X file first.",
        )

    # Read metadata table (TSV or CSV).
    with open(meta_path, "r", errors="replace") as f:
        meta_rows = list(csv.reader(f, delimiter=_sniff_delimiter(meta_path)))

    if len(meta_rows) < 2:
        raise HTTPException(status_code=400, detail="Metadata file is empty or has no data rows.")

    meta_header = meta_rows[0]
    if body.column not in meta_header:
        raise HTTPException(
            status_code=400,
            detail=f"Column '{body.column}' not found in metadata. Available: {meta_header[1:]}",
        )
    col_idx = meta_header.index(body.column)

    # Sample ID (first column) -> value; rows with a blank ID or blank value
    # are skipped. Duplicate IDs keep the last occurrence.
    meta_map = {
        row[0].strip(): row[col_idx].strip()
        for row in meta_rows[1:]
        if len(row) > col_idx and row[0].strip() and row[col_idx].strip()
    }

    # Read only the X header row: with features in rows, sample names are the
    # column headers past the first cell.
    with open(x_file.disk_path, "r", errors="replace") as f:
        x_header = next(csv.reader(f, delimiter=_sniff_delimiter(x_file.disk_path)))
    x_sample_names = [s.strip() for s in x_header[1:]]

    # Match X samples against metadata, preserving X order.
    matched = {s: meta_map[s] for s in x_sample_names if s in meta_map}
    missing = [s for s in x_sample_names if s not in meta_map]

    if not matched:
        raise HTTPException(
            status_code=400,
            detail="No matching samples between X file and metadata.",
        )

    # Write y file as TSV: sample_id\tvalue, in X-file sample order.
    lines = ["sample_id\t" + body.column]
    lines.extend(f"{sample}\t{matched[sample]}" for sample in x_sample_names if sample in matched)
    y_content = "\n".join(lines) + "\n"

    # Register the file first so the flush assigns an id for the storage path,
    # then backfill disk_path once the bytes are persisted.
    filename = f"{body.file_role}_{body.column}.tsv"
    ds_file = DatasetFile(
        dataset_id=dataset.id,
        filename=filename,
        role=body.file_role,
        disk_path="",
    )
    db.add(ds_file)
    await db.flush()

    disk_path = storage.save_user_dataset_file(user.id, ds_file.id, filename, y_content.encode("utf-8"))
    ds_file.disk_path = disk_path

    # Fix: the note previously read "Generate (unknown) from metadata" —
    # record the actual role and source column in the version history.
    await _create_version_snapshot(
        db, dataset_id, user.id,
        note=f"Generate {body.file_role} from metadata column '{body.column}'",
    )

    # Auto-scan if xtrain + ytrain now present
    await _try_auto_scan(db, dataset)

    return {
        "file": DatasetFileRef(id=ds_file.id, filename=ds_file.filename, role=ds_file.role).model_dump(),
        "matched_samples": len(matched),
        "missing_samples": len(missing),
        "total_x_samples": len(x_sample_names),
    }
| 785 | + |
| 786 | + |
557 | 787 | # --------------------------------------------------------------------------- |
558 | 788 | # Project assignment |
559 | 789 | # --------------------------------------------------------------------------- |
|
0 commit comments