diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 0000000..a809403 --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,43 @@ +# This is a basic workflow to help you get started with Actions + +name: unit tests + +# Controls when the workflow will run +on: + # Triggers the workflow on push or pull request events but only for the "main" branch + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# A workflow run is made up of one or more jobs that can run sequentially or in parallel +jobs: + # This workflow contains a single job called "build" + build: + # The type of runner that the job will run on + runs-on: ubuntu-latest + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v4 + + - name: Set up conda environment + uses: conda-incubator/setup-miniconda@v3 + with: + miniforge-variant: Miniforge3 # uses mamba automatically + activate-environment: contextscore + environment-file: environment.yml + auto-activate-base: false + use-mamba: true + cache-environment: true # ← caches the env + cache-downloads: true # ← caches downloaded packages + + - name: Run tests + shell: bash --login {0} + run: | + mkdir tests/output + python -m pytest diff --git a/.gitignore b/.gitignore index 15201ac..81d2c36 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,22 @@ cython_debug/ # PyPI configuration file .pypirc + +# Ignore the output/ folder +output/ +scripts/ + +# VS Code settings +.vscode/launch.json + +# Testing scripts +linktoscripts +truvari_results_Simulated_*/ +conda/contextscore-models/ +tests/fixtures/output.vcf.avinput +tests/fixtures/output.vcf.bed +tests/fixtures/annotations/features.tsv +tests/fixtures/annotations/regions.hg38_multianno.txt + +# Database files +data/ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9b38853 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..196b0b8 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include README.md +include LICENSE +recursive-include data * diff --git a/README.md b/README.md index 8a32729..0743819 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,39 @@ +[![unit tests](https://github.com/WGLab/ContextScore/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/WGLab/ContextScore/actions/workflows/unit-tests.yml) + # ContextScore -Assign confidence scores to SV datasets based on coverage, genomic context, and other important alignment features +

+ContextSV +Filtering step for the ContextSV long-read structural variant (SV) caller, using a Random Forest model trained on SV validation features. It assigns confidence scores to SV datasets based on coverage, genomic context, and other important alignment features, then filters low-confidence SVs to increase the precision of the final callset. Genomic context is determined from annotations using ANNOVAR and UCSC databases. +

+
+ +## Installation +```bash +conda install -c wglab -c bioconda -c conda-forge contextscore + +# Or using mamba (faster dependency resolution): +mamba install -c wglab contextscore +``` + +## ANNOVAR setup +[ANNOVAR](https://annovar.openbioinformatics.org/en/latest/user-guide/download/) is required for annotations and must be installed separately. + +These are the required ANNOVAR components for ContextScore: +- `--annovar`: directory containing `annotate_variation.pl` and `table_annovar.pl` +- `--annovar-db`: ANNOVAR database directory + +## User Workflow +```bash +contextscore --input input.vcf --output scored.vcf --sample-coverage 30 --buildver {hg38,hg19} --threshold 0.2 \ + --annovar /path/to/annovar --annovar-db /path/to/humandb +``` + +## Sources for additional annotations (under `data/` directory): +| File | Source | Description | Link | +| --- | --- | --- | --- | +| `cytobands_hg{19,38}.txt` | UCSC Genome Browser | Cytoband annotations for human genome builds hg19 and hg38 | [UCSC hg19](https://hgdownload.soe.ucsc.edu/goldenPath/hg19/database/cytoBand.txt.gz) / [UCSC hg38](https://hgdownload.soe.ucsc.edu/goldenPath/hg38/database/cytoBand.txt.gz) | +| `hg{19,38}_segmental_duplications.bed` | UCSC Genome Browser | Segmental duplication annotations for human genome builds hg19 and hg38 | [UCSC hg19](https://hgdownload.soe.ucsc.edu/goldenPath/hg19/database/segmentalDuplications.txt.gz) / [UCSC hg38](https://hgdownload.soe.ucsc.edu/goldenPath/hg38/database/segmentalDuplications.txt.gz) | +| `phastcons100way_hg{19,38}.bed` | UCSC Genome Browser | PhastCons conservation scores for human genome builds hg19 and hg38 | [UCSC hg19](https://hgdownload.soe.ucsc.edu/goldenPath/hg19/database/phastCons100way.txt.gz) / [UCSC hg38](https://hgdownload.soe.ucsc.edu/goldenPath/hg38/database/phastCons100way.txt.gz) | +| `simple_repeats_hg{19,38}.bed` | UCSC Genome Browser | Simple repeat annotations for human genome builds hg19 and hg38 | [UCSC hg19](https://hgdownload.soe.ucsc.edu/goldenPath/hg19/database/simpleRepeat.txt.gz) / [UCSC hg38](https://hgdownload.soe.ucsc.edu/goldenPath/hg38/database/simpleRepeat.txt.gz) | +| `fragile_sites_hg38.bed` / `fragile_sites_hg19_liftover.bed` | [HumCFS](https://webs.iiitd.edu.in/raghava/humcfs/download.html) | Fragile site annotations for human genome builds hg38 and hg19 (liftover) | [HumCFS](https://webs.iiitd.edu.in/raghava/humcfs/fragile_site_bed.zip) | + diff --git a/conda/meta.yaml b/conda/meta.yaml new file mode 100644 index 0000000..88d049d --- /dev/null +++ b/conda/meta.yaml @@ -0,0 +1,42 @@ +{% set name = "contextscore" %} +{% set version = "0.1.0" %} + +package: + name: {{ name|lower }} + version: {{ version }} + +source: + path: .. + +build: + number: 0 + skip: true # [win] + script: "{{ PYTHON }} -m pip install . --no-deps -vv" + +requirements: + host: + - python >=3.10,<3.11 + - pip + - setuptools + run: + - python >=3.10,<3.11 + - numpy + - pandas + - scikit-learn =1.6.1 # For consistency with model training environment + - joblib + - bedtools + - contextscore-models + +about: + home: https://github.com/WGLab/ContextScore + summary: Assign confidence scores to structural variant datasets. + description: | + ContextScore prediction package. Model weights are distributed separately + (for example via contextscore-models) and can be provided via --model or + CONTEXTSCORE_MODEL_PATH. 
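+    ANNOVAR is not bundled in this package and must be installed separately
+    for region annotation; see the project README for setup instructions.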
+ license: MIT + license_file: LICENSE + +extra: + recipe-maintainers: + - WGLab diff --git a/contextscore/TrainingAnnotationsSummary.tsv b/contextscore/TrainingAnnotationsSummary.tsv new file mode 100644 index 0000000..2c78f0d --- /dev/null +++ b/contextscore/TrainingAnnotationsSummary.tsv @@ -0,0 +1,2 @@ +True Positives +Total Fragile Sites Telomeres Centromeres Segmental Duplications Conserved Regions diff --git a/contextscore/__init__.py b/contextscore/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contextscore/__main__.py b/contextscore/__main__.py new file mode 100644 index 0000000..a4a5245 --- /dev/null +++ b/contextscore/__main__.py @@ -0,0 +1,5 @@ +from .predict import main + + +if __name__ == '__main__': + main() diff --git a/contextscore/download_tables.py b/contextscore/download_tables.py new file mode 100644 index 0000000..c0e0e18 --- /dev/null +++ b/contextscore/download_tables.py @@ -0,0 +1,61 @@ +import pandas as pd +import pymysql +from pathlib import Path + +def download_ucsc(table_name: str, + genome_version: str = "hg38", + output_file: str = "ucsc_table.bed") -> None: + """ + Downloads the UCSC Simple Repeats table and saves it as a BED file for use with BEDTools. + Note: This function requires access to the UCSC MySQL database. + """ + print("Downloading UCSC " + table_name + " table for " + genome_version + "...") + + # Connect to UCSC MySQL database + conn = pymysql.connect(host="genome-mysql.soe.ucsc.edu", + user="genome", + password="", + database="hg38") # Change to the desired genome version (e.g., hg19, mm10) + + query = f""" + SELECT + chrom AS chr, + chromStart AS start, + chromEnd AS end, + name + FROM + {table_name} + WHERE + chrom IS NOT NULL AND + chromStart IS NOT NULL AND + chromEnd IS NOT NULL + AND + chromStart >= 0 AND + chromEnd > chromStart + AND + chromStart < chromEnd; + """ + df = pd.read_sql(query, conn) + + # Close connection + conn.close() + + # Save as BED file for BEDTools + df.to_csv(output_file, sep="\t", index=False, header=False) + print("Downloaded UCSC " + table_name + " table for " + genome_version + " and saved as " + output_file) + +if __name__ == "__main__": + data_dir = Path(__file__).resolve().parents[1] / "data" + data_dir.mkdir(parents=True, exist_ok=True) + + # Download the UCSC Simple Repeats table for hg38 + simple_repeat_file = str(data_dir / "simple_repeats_hg38.bed") + download_ucsc(table_name="simpleRepeat", + genome_version="hg38", + output_file=simple_repeat_file) + + # Download the UCSC phastCons100way table for hg38 + phastcons_file = str(data_dir / "phastcons100way_hg38.bed") + download_ucsc(table_name="phastCons100way", + genome_version="hg38", + output_file=phastcons_file) diff --git a/contextscore/extract_features.py b/contextscore/extract_features.py new file mode 100644 index 0000000..1e090ee --- /dev/null +++ b/contextscore/extract_features.py @@ -0,0 +1,936 @@ +""" +extract_features.py: Extract features from the input VCF file. + +Usage: + extract_features.py + +Arguments: + Path to the input VCF file. + +Output: + A dataframe with a column for each feature. 
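+
+Requires:
+    bedtools available on PATH, plus a local ANNOVAR installation
+    (annotate_variation.pl and table_annovar.pl) and its database directory,
+    passed in as annovar_path and db_path.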
+""" + +import os +import sys +import logging +import heapq +from pathlib import Path +import numpy as np +import pandas as pd +import subprocess +import tempfile +from io import StringIO + + +_BEDTOOLS_CHECKED = False + + +def get_contextscore_data_file(file_name): + """Resolve a data file path from ContextScore's bundled data directory.""" + env_data_dir = os.environ.get('CONTEXTSCORE_DATA_DIR') + candidate_dirs = [] + + if env_data_dir: + candidate_dirs.append(Path(env_data_dir)) + + # Source tree / editable install layout. + candidate_dirs.append(Path(__file__).resolve().parents[1] / 'data') + + # data_files install layout from setup.py. + candidate_dirs.append(Path(sys.prefix) / 'contextscore' / 'data') + + # Future-proof fallback if data ever moves under the package directory. + candidate_dirs.append(Path(__file__).resolve().parent / 'data') + + for directory in candidate_dirs: + candidate = directory / file_name + if candidate.exists(): + return str(candidate) + + searched_dirs = ', '.join(str(directory) for directory in candidate_dirs) + raise FileNotFoundError( + f'Could not locate required data file {file_name}. Searched: {searched_dirs}. ' + 'Set CONTEXTSCORE_DATA_DIR to override the data directory.' + ) + + +def get_annotation_paths(buildversion): + """Return annotation file paths for the selected genome build.""" + buildversion = str(buildversion).lower() + file_map = { + 'hg38': { + 'fragile_sites': 'fragile_sites_hg38.bed', + 'phastcons': 'phastcons100way_hg38.bed', + 'simple_repeats': 'simple_repeats_hg38.bed', + 'cytobands': 'cytobands_hg38.txt', + }, + 'hg19': { + 'fragile_sites': 'fragile_sites_hg19_liftover.bed', + 'phastcons': 'phastcons100way_hg19.bed', + 'simple_repeats': 'simple_repeats_hg19.bed', + 'cytobands': 'cytobands_hg19.txt', + }, + } + + if buildversion not in file_map: + raise ValueError(f'Unsupported build version: {buildversion}. Please use hg38 or hg19.') + + return { + key: get_contextscore_data_file(file_name) + for key, file_name in file_map[buildversion].items() + } + + +def read_cytoband_file(cytoband_file): + """Get the centromere and telomere regions for each chromosome.""" + cytobands = pd.read_csv(cytoband_file, sep='\t', header=0, names=["chrom", "start", "end", "name", "gieStain"], dtype={"chrom": str, "start": int, "end": int, "name": str, "gieStain": str}) + chrom_dict = {} + for chrom in cytobands['chrom'].unique(): + + # Skip chrM, and other non-standard chromosomes. + if chrom == 'chrM': + continue + + chrom_df = cytobands[cytobands['chrom'] == chrom].sort_values('start') + # Store chromosome boundaries and terminal bands. + chrom_dict[chrom] = { + 'chrom_start': int(chrom_df['start'].min()), + 'chrom_end': int(chrom_df['end'].max()), + 'telomerep_start': int(chrom_df.iloc[0]['start']), + 'telomerep_end': int(chrom_df.iloc[0]['end']), + 'telomereq_start': int(chrom_df.iloc[-1]['start']), + 'telomereq_end': int(chrom_df.iloc[-1]['end']) + } + + # Identify centromeres from cytobands with gieStain == "acen". 
+ acen_df = chrom_df[chrom_df['gieStain'] == 'acen'] + centromere_p = acen_df[acen_df['name'].str.startswith('p', na=False)] + centromere_q = acen_df[acen_df['name'].str.startswith('q', na=False)] + if not centromere_p.empty: + chrom_dict[chrom]['centromerep_start'] = int(centromere_p.iloc[0]['start']) + chrom_dict[chrom]['centromerep_end'] = int(centromere_p.iloc[0]['end']) + if not centromere_q.empty: + chrom_dict[chrom]['centromereq_start'] = int(centromere_q.iloc[0]['start']) + chrom_dict[chrom]['centromereq_end'] = int(centromere_q.iloc[0]['end']) + + # Combined centromere span (union of acen blocks) for distance calculation. + if not acen_df.empty: + chrom_dict[chrom]['centromere_start'] = int(acen_df['start'].min()) + chrom_dict[chrom]['centromere_end'] = int(acen_df['end'].max()) + + return chrom_dict + + +def normalize_chrom_label(chrom): + """Normalize chromosome labels for robust joins/lookups (e.g., 1 vs chr1).""" + if pd.isna(chrom): + return None + chrom_str = str(chrom).strip() + if not chrom_str: + return None + chrom_str = chrom_str[3:] if chrom_str.lower().startswith('chr') else chrom_str + return chrom_str.upper() + +def extract_features(input_bed, annovar_path, db_path, outdiranno, buildversion='hg38', sample_coverage=None): + """Extract the features from the BED file, columns are in the first row: + chrom, start, end, sv_type, sv_length, genotype, read_depth, hmm_llh, aln_type, cluster_size + + Args: + sample_coverage (float): Required. Mean read depth coverage for the sample, used to normalize read_depth. + """ + logging.info('Extracting features from the BED file %s', input_bed) + + if sample_coverage is None or sample_coverage <= 0: + logging.error('sample_coverage is required and must be > 0') + raise ValueError('sample_coverage is required and must be > 0') + + # Get the number of columns in the BED file. + with open(input_bed, 'r') as f: + first_line = f.readline().strip() + num_columns = len(first_line.split('\t')) + logging.info('Number of columns in the BED file: %d', num_columns) + + training_format = False + if num_columns == 12: # Standard training format. + training_format = True + logging.info('Training format detected.') + elif num_columns == 13: # Contains additional 'id' column. + logging.info('Prediction format detected.') + + # Read in the BED file. + if training_format: + bed_df = pd.read_csv(input_bed, sep='\t', header=None, usecols=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + names=['chrom', 'start', 'end', 'sv_type', 'sv_length', 'genotype', 'read_depth', 'hmm_llh', 'aln_type', 'cluster_size', 'cn_state', 'aln_offset'], + dtype={'chrom': str, 'start': np.int32, 'end': np.int32, 'sv_type': str, 'sv_length': np.int32, 'genotype': str, 'read_depth': np.int32, 'hmm_llh': np.float32, 'aln_type': str, 'cluster_size': np.int32, 'cn_state': np.int32, 'aln_offset': np.int32}) + else: + bed_df = pd.read_csv(input_bed, sep='\t', header=None, usecols=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + names=['chrom', 'start', 'end', 'sv_type', 'sv_length', 'genotype', 'read_depth', 'hmm_llh', 'aln_type', 'cluster_size', 'cn_state', 'aln_offset', 'id'], + dtype={'chrom': str, 'start': np.int32, 'end': np.int32, 'sv_type': str, 'sv_length': np.int32, 'genotype': str, 'read_depth': np.int32, 'hmm_llh': np.float32, 'aln_type': str, 'cluster_size': np.int32, 'cn_state': np.int32, 'aln_offset': np.int32, 'id': np.int32}) + + # Normalize SV length to a positive magnitude. 
+ bed_df['sv_length'] = bed_df['sv_length'].abs() + + # Drop the genotype column and cn_state columns (due to redundancy). + bed_df.drop(columns=['genotype', 'cn_state'], inplace=True) + + # Create alignment type feature, 0 for CIGAR alignment types (contains + # CIGAR), 1 for CIGARCLIP (contains CIGARCLIP), 2 for SPLIT alignment (all + # others) + bed_df['call_type'] = bed_df['aln_type'].apply(lambda x: 1 if 'CIGARCLIP' in x else (0 if 'CIGAR' in x else 2)) + bed_df['call_type'] = bed_df['call_type'].astype('category') + + # Drop the original aln_type column. + bed_df.drop(columns=['aln_type'], inplace=True) + + # Create normalized cluster_size feature (cluster_size per 1000 bp of SV length) + # This prevents large sparse SVs from being unfairly penalized + bed_df['cluster_size_per_kb'] = np.where( + bed_df['sv_length'] > 0, + bed_df['cluster_size'] / (bed_df['sv_length'] / 1000.0), + 0 + ) + + # Read depth normalized by sample coverage + bed_df['read_depth_normalized'] = np.where( + sample_coverage > 0, + bed_df['read_depth'] / sample_coverage, + bed_df['read_depth'] + ) + + # Map of SV types to integers + sv_type_map = { + 'DEL': 0, + 'DUP': 1, + 'INV': 2, + 'INS': 3, + 'BND': 4, + 'UNKNOWN': 5 + } + bed_df['sv_type_str'] = bed_df['sv_type'].astype(str) + bed_df['sv_type'] = bed_df['sv_type'].map(sv_type_map).astype('category') + + # Check for missing features + if bed_df.isnull().values.any(): + logging.error('Features are missing.') + + # Get the rows with missing features. + missing_features = bed_df[bed_df.isnull().any(axis=1)] + + # Print the rows with missing features. + logging.error(missing_features) + sys.exit(1) + + # Add annotations to the features. + bed_df = add_annotations(bed_df, input_bed, annovar_path, db_path, outdiranno, buildversion, training_format) + + # Print the number of NaN values + logging.info('Number of NaN values: %d', bed_df.isnull().sum().sum()) + + # ------------------------------------------------------------------- + # Fix the chromosome names to all start with 'chr' if they don't already. + bed_df['chrom'] = bed_df['chrom'].apply(lambda x: 'chr' + x if not x.startswith('chr') else x) + + # Drop telomere and centromere columns (they don't affect predictions). + bed_df.drop(columns=['telomere', 'centromere'], inplace=True) + + # Drop the genotype column from the data. + bed_df = bed_df.drop(columns=['genotype'], errors='ignore') + + # Drop the cn_state column from the data. + bed_df = bed_df.drop(columns=['cn_state'], errors='ignore') + + # Add distance to nearest other SV call, clustered false positives often appear near real SVs. + logging.info('Computing distance to nearest other SV call (same chromosome)...') + bed_df['dist_to_nearest_sv'] = np.nan + for chrom, idx in bed_df.groupby('chrom', sort=False).groups.items(): + chrom_df = bed_df.loc[idx, ['start', 'end']].sort_values(['start', 'end']) + n = chrom_df.shape[0] + + if n <= 1: + continue + + starts = chrom_df['start'].to_numpy(dtype=np.int64) + ends = chrom_df['end'].to_numpy(dtype=np.int64) + + # Previous interval summary. + prev_max_end = np.maximum.accumulate(ends) + prev_max_end_excl = np.empty(n, dtype=np.int64) + prev_max_end_excl[0] = np.iinfo(np.int64).min + prev_max_end_excl[1:] = prev_max_end[:-1] + + # Next interval summary. + next_start_excl = np.empty(n, dtype=np.int64) + next_start_excl[:-1] = starts[1:] + next_start_excl[-1] = np.iinfo(np.int64).max + + # Overlap checks with prior/next intervals. 
+ overlap_prev = prev_max_end_excl > starts + overlap_next = ends > next_start_excl + overlap_any = overlap_prev | overlap_next + + # Gap to closest left/right neighbor (touching intervals yield 0). + left_gap = starts - prev_max_end_excl + right_gap = next_start_excl - ends + + # No-left/no-right sentinels. + left_gap[0] = np.iinfo(np.int64).max + right_gap[-1] = np.iinfo(np.int64).max + + nearest = np.minimum(left_gap, right_gap).astype(np.float64) + nearest[overlap_any] = 0.0 + + # Any remaining sentinel values are undefined (should only happen in degenerate cases). + sentinel = float(np.iinfo(np.int64).max) + nearest[nearest >= sentinel] = np.nan + + bed_df.loc[chrom_df.index, 'dist_to_nearest_sv'] = nearest + + logging.info('Distance to nearest SV calculated. Coverage: %.1f%%', (bed_df['dist_to_nearest_sv'].notna().sum() / len(bed_df) * 100)) + + # Print statistics about the distance to nearest SV feature. + logging.info('Distance to nearest SV - mean: %.2f, median: %.2f, std: %.2f', bed_df['dist_to_nearest_sv'].mean(), bed_df['dist_to_nearest_sv'].median(), bed_df['dist_to_nearest_sv'].std()) + + # Normalize by SV size + bed_df['dist_nearest_sv_per_kb'] = np.where( + bed_df['sv_length'] > 0, + bed_df['dist_to_nearest_sv'] / (bed_df['sv_length'] / 1000.0), + bed_df['dist_to_nearest_sv'] + ) + + # Return the features dataframe. + return bed_df + + +def run_bedtools_intersect(input_bed, table_bed, training_format=False): + """Run bedtools intersect to annotate the BED file.""" + def bed_uses_chr_prefix(path, sample_size=1000): + try: + sample_df = pd.read_csv( + path, + sep='\t', + header=None, + usecols=[0], + nrows=sample_size, + comment='#', + dtype=str, + ) + except Exception: + return None + + if sample_df.empty: + return None + + chroms = sample_df.iloc[:, 0].dropna().astype(str).str.strip() + chroms = chroms[chroms != ''] + if chroms.empty: + return None + + return bool((chroms.str.lower().str.startswith('chr')).mean() >= 0.5) + + def normalize_chrom_name(chrom, target_has_chr): + if pd.isna(chrom): + return chrom + chrom_str = str(chrom).strip() + if not chrom_str: + return chrom_str + + has_chr = chrom_str.lower().startswith('chr') + if target_has_chr and not has_chr: + return f'chr{chrom_str}' + if not target_has_chr and has_chr: + return chrom_str[3:] + return chrom_str + + def write_normalized_bed(path, target_has_chr): + bed_df = pd.read_csv(path, sep='\t', header=None, comment='#', dtype=str) + bed_df.iloc[:, 0] = bed_df.iloc[:, 0].apply(lambda c: normalize_chrom_name(c, target_has_chr)) + + with tempfile.NamedTemporaryFile( + mode='w', + suffix='.bed', + prefix='contextscore_normalized_', + delete=False, + encoding='utf-8', + ) as tmp_file: + temp_path = tmp_file.name + + bed_df.to_csv(temp_path, sep='\t', header=False, index=False) + return temp_path + + # Check if bedtools is installed (once per process). + global _BEDTOOLS_CHECKED + if not _BEDTOOLS_CHECKED: + try: + subprocess.run(["bedtools", "--version"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + _BEDTOOLS_CHECKED = True + except (subprocess.CalledProcessError, FileNotFoundError): + logging.error('bedtools is not installed. Please install bedtools.') + sys.exit(1) + + # Check if the input BED file exists. + if not os.path.exists(input_bed): + logging.error('Input BED file does not exist: %s', input_bed) + sys.exit(1) + + # Check if the table BED file exists. 
+ if not os.path.exists(table_bed): + logging.error('Table BED file does not exist: %s', table_bed) + sys.exit(1) + + intersect_input_bed = input_bed + normalized_temp_bed = None + try: + input_has_chr = bed_uses_chr_prefix(input_bed) + table_has_chr = bed_uses_chr_prefix(table_bed) + if input_has_chr is not None and table_has_chr is not None and input_has_chr != table_has_chr: + normalized_temp_bed = write_normalized_bed(input_bed, table_has_chr) + intersect_input_bed = normalized_temp_bed + logging.info( + 'Normalized chromosome naming for bedtools intersect: %s -> %s prefix.', + 'chr' if input_has_chr else 'no-chr', + 'chr' if table_has_chr else 'no-chr', + ) + except Exception as exc: + logging.warning('Could not normalize chromosome naming before bedtools intersect: %s', exc) + intersect_input_bed = input_bed + + # Run bedtools intersect to annotate the BED file. + cmd = [ + "bedtools", "intersect", + "-a", intersect_input_bed, + "-b", table_bed, + "-wa", "-wb" + ] + logging.info('Running the command to annotate the BED file: %s', " ".join(cmd)) + try: + result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.stderr: + stderr_text = result.stderr.strip() + if 'inconsistent naming convention' in stderr_text: + logging.info('bedtools reported chromosome naming convention warnings; continuing.') + logging.debug('bedtools stderr: %s', stderr_text) + else: + logging.warning('bedtools stderr: %s', stderr_text) + + # Parse the output of bedtools intersect into a pandas DataFrame. + logging.info('Parsing the output of bedtools intersect.') + if training_format: + annotated_bed = pd.read_csv( + StringIO(result.stdout), + sep='\t', + header=None, + names=["chrom", "start", "end", "chr_anno", "start_anno", "end_anno", "name"], + usecols=[0, 1, 2, 12, 13, 14, 15], + dtype={'chrom': str, 'start': np.int32, 'end': np.int32, 'chr_anno': str, 'start_anno': np.int32, 'end_anno': np.int32, 'name': str} + ) + else: + annotated_bed = pd.read_csv( + StringIO(result.stdout), + sep='\t', + header=None, + names=["chrom", "start", "end", "chr_anno", "start_anno", "end_anno", "name"], + usecols=[0, 1, 2, 13, 14, 15, 16],#12, 13, 14, 15], #10, 11, 12, 13], + dtype={'chrom': str, 'start': np.int32, 'end': np.int32, 'chr_anno': str, 'start_anno': np.int32, 'end_anno': np.int32, 'name': str} + ) + + return annotated_bed + + except subprocess.CalledProcessError as e: + logging.error('Error annotating the BED file: %s', e) + if e.stderr: + logging.error('bedtools stderr: %s', e.stderr.strip()) + logging.error('Please check the input and table BED files.') + sys.exit(1) + finally: + if normalized_temp_bed and os.path.exists(normalized_temp_bed): + try: + os.remove(normalized_temp_bed) + except OSError: + logging.warning('Could not remove temporary normalized BED file: %s', normalized_temp_bed) + + +def bed_to_annovar_input(bed_file): + """Convert the BED file to ANNOVAR input format.""" + output_file = bed_file.replace('.bed', '.avinput') + logging.info('Converting the BED file to ANNOVAR input format.') + + # Read the BED file using pandas. ContextScore BED files are headerless. + df = pd.read_csv(bed_file, sep='\t', header=None, usecols=[0, 1, 2], + names=["CHROM", "POS", "END"], + dtype={'CHROM': str, 'POS': np.int32, 'END': np.int32}) + + # Check if the BED file is empty. + logging.info('Number of rows in the BED file: %d', df.shape[0]) + + # The ANNOVAR input format requires the following columns: + # 1. Chromosome + # 2. Start position + # 3. 
End position + # 4. Reference allele + # 5. Alternate allele + # We will use the first three columns from the BED file and placeholder + # columns for the reference and alternate alleles (0, and -) since gnomAD does not + # provide sequence information + + # Create a new dataframe with the required columns. + annovar_df = pd.DataFrame() + annovar_df['chrom'] = df['CHROM'] + annovar_df['start'] = df['POS'] + annovar_df['end'] = df['END'] + annovar_df['ref'] = '0' + annovar_df['alt'] = '-' + + # Save the tab-delimited dataframe to a file. + logging.info('Saving the ANNOVAR input file to %s', output_file) + annovar_df.to_csv(output_file, sep='\t', index=False, header=False) + logging.info('Number of rows in the ANNOVAR input file: %d', annovar_df.shape[0]) + logging.info('Saved the ANNOVAR input file to %s', output_file) + + return output_file + + +def download_annovar_db(annovar_path, db_path, db_name, buildversion='hg38'): + """Download the ANNOVAR database if it does not exist. + + Returns True if successful or database already exists, False if download failed. + """ + logging.info('Downloading the database: %s for build version: %s', db_name, buildversion) + + # Check if database files already exist + expected_files = [ + os.path.join(db_path, f"{buildversion}_{db_name}.txt"), + os.path.join(db_path, f"{buildversion}_{db_name}.txt.idx"), + ] + + if all(os.path.exists(f) for f in expected_files): + logging.info('Database %s already exists, skipping download.', db_name) + return True + + # Ensure the database directory exists + os.makedirs(db_path, exist_ok=True) + + cmd = [ + f"{annovar_path}/annotate_variation.pl", + "-buildver", buildversion, + "-downdb", db_name, + "." # Download to current directory (we'll set cwd=db_path) + ] + + # Run the command to download the database from the db_path directory + # This ensures files are downloaded directly to the correct location + logging.info('Running the command to download the database: %s (in directory: %s)', " ".join(cmd), db_path) + try: + result = subprocess.run(cmd, check=True, capture_output=True, text=True, cwd=db_path) + if result.stdout: + logging.debug('Download stdout: %s', result.stdout) + logging.info('Downloaded the database %s successfully.', db_name) + return True + except subprocess.CalledProcessError as e: + logging.warning('Failed to download the database %s: %s', db_name, e) + if e.stderr: + logging.warning('Error output: %s', e.stderr) + logging.warning('Continuing without this database. Some features may be missing.') + return False + except FileNotFoundError as e: + logging.warning('Failed to download the database %s: %s', db_name, e) + logging.warning('Please verify ANNOVAR is installed and annovar_path is correct.') + logging.warning('Continuing without this database. 
Some features may be missing.') + return False + + +def annotate(annovar_input, annovar_path, db_path, output_dir, buildversion='hg38'): + """Annotate regions.""" + logging.info('Annotating regions using ANNOVAR.') + + annotations_dir = os.path.join(output_dir, 'regions') + logging.info('Creating the output directory: %s', annotations_dir) + cmd = [ + os.path.join(annovar_path, "table_annovar.pl"), + annovar_input, + db_path, + "--buildver", buildversion, + "--out", annotations_dir, + "--remove", + "--protocol", "genomicSuperDups,cytoBand", + "--operation", "r,r", + "--nastring", ".", + "-polish" + ] + + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + logging.error('Error annotating: %s', e) + logging.error('Please check the ANNOVAR path and database path.') + sys.exit(1) + + logging.info('Completed annotations.') + + +def get_cytoband_is_c_t(chrom_dict, chrom, cytoband): + """Check if the cytoband is a telomere or centromere.""" + if chrom not in chrom_dict: + return False, False # Not in any region. + + is_telomere = False + is_centromere = False + # Check if the cytoband annotation indicates telomere or centromere regions. + try: + # Centromeres contain 'acen' + if 'acen' in cytoband: + is_centromere = True + # Telomeres are at the extreme bands - simplistic check for p/q terminal regions + elif 'p11' in cytoband or 'p12' in cytoband or 'p13' in cytoband: # p-arm terminal + is_telomere = True + elif 'q13' in cytoband or 'q14' in cytoband: # q-arm terminal (varies by chromosome) + is_telomere = True + + except TypeError: + pass + + return is_telomere, is_centromere + + +def add_annotations(data, input_bed, annovar_path, db_path, anno_outdir, buildversion='hg38', training_format=False): + """Add annotations to the features.""" + logging.info('Adding annotations to the features.') + + try: + annotation_paths = get_annotation_paths(buildversion) + except (FileNotFoundError, ValueError) as exc: + logging.error('%s', exc) + sys.exit(1) + + # --------------------------------------------------------------- + # Annotate the fragile sites using a BED file from HumCFS. + # https://webs.iiitd.edu.in/raghava/humcfs/download.html + # ANNOVAR instructions are here: + # https://annovar.openbioinformatics.org/en/latest/user-guide/region/ + fragile_sites_bed = annotation_paths['fragile_sites'] + + logging.info('Annotating the fragile sites using the BED file (%s): %s', buildversion, fragile_sites_bed) + fragile_sites_df = run_bedtools_intersect(input_bed, fragile_sites_bed, training_format) + + # Merge the fragile sites annotations with the true positive data. + data['fragile_site'] = data.merge(fragile_sites_df, on=['chrom', 'start', 'end'], how='left')['chr_anno'].notna() + + logging.info('Number of records with fragile sites: %d', data['fragile_site'].sum()) + logging.info('Total number of records: %d', data.shape[0]) + + # --------------------------------------------------------------- + # Annotate conserved regions using a UCSC Table Browser BED file for + # phastCons100way. + phastCons_bed = annotation_paths['phastcons'] + logging.info('Annotating conserved regions using the BED file (%s): %s', buildversion, phastCons_bed) + phastCons_df = run_bedtools_intersect(input_bed, phastCons_bed, training_format) + + # Merge the phastCons annotations with the true positive data. 
+ data['phastCons'] = data.merge(phastCons_df, on=['chrom', 'start', 'end'], how='left')['chr_anno'].notna() + + logging.info('Number of records with conserved regions: %d', data['phastCons'].sum()) + logging.info('Total number of records: %d', data.shape[0]) + + # --------------------------------------------------------------- + # Annotate simple repeats using a UCSC Table Browser BED file for + # simpleRepeat. + simpleRepeat_bed = annotation_paths['simple_repeats'] + logging.info('Annotating simple repeats using the BED file (%s): %s', buildversion, simpleRepeat_bed) + simpleRepeat_df = run_bedtools_intersect(input_bed, simpleRepeat_bed, training_format) + + # Check if record has any simple repeats (boolean indicator). + data['simpleRepeat'] = data.merge(simpleRepeat_df, on=['chrom', 'start', 'end'], how='left')['chr_anno'].notna() + + logging.info('Number of records with simple repeats: %d', data['simpleRepeat'].sum()) + logging.info('Total number of records: %d', data.shape[0]) + + # --------------------------------------------------------------- + # Annotate using ANNOVAR. + + # Download the segmental duplication database + segdup_success = download_annovar_db(annovar_path, db_path, "genomicSuperDups", buildversion) + + # Download the cytoband database + cytoband_success = download_annovar_db(annovar_path, db_path, "cytoBand", buildversion) + + # Set up a dictionary for each chromosome, mapping cytobands to + # centromere and telomere regions. + cytoband_file = annotation_paths['cytobands'] + cytoband_dict = read_cytoband_file(cytoband_file) + + logging.info('Converting the true positive BED file to ANNOVAR input format.') + annovar_file = bed_to_annovar_input(input_bed) + + logging.info('Annotating the SVs using ANNOVAR.') + if not os.path.exists(anno_outdir): + os.makedirs(anno_outdir) + + annotate(annovar_file, annovar_path, db_path, anno_outdir, buildversion) + + anno_file = os.path.join(anno_outdir, 'regions.{}_multianno.txt'.format(buildversion)) + if not os.path.exists(anno_file): + logging.error('ANNOVAR annotation file does not exist: %s', anno_file) + sys.exit(1) + + # Read the ANNOVAR output file. + logging.info('Reading the ANNOVAR output file: %s', anno_file) + anno_df = pd.read_csv(anno_file, sep='\t', header=0, comment='#') + + # Convert chr, start, end to the same data types as the data. + anno_df['Chr'] = anno_df['Chr'].astype(str) + anno_df['Start'] = anno_df['Start'].astype(np.int32) + anno_df['End'] = anno_df['End'].astype(np.int32) + + # Merge the ANNOVAR annotations with the data. + logging.info('Merging ANNOVAR annotations (%d records) with data (%d records)...', anno_df.shape[0], data.shape[0]) + data = data.merge(anno_df, left_on=['chrom', 'start', 'end'], right_on=['Chr', 'Start', 'End'], how='left') + logging.info('ANNOVAR merge completed.') + + # Extract segmental duplication scores. + logging.info('Extracting segmental duplication scores...') + def extract_scores(score_str): + """Extract and return the segmental duplication scores from a string.""" + if pd.isna(score_str) or score_str == '.': + return 0 + # Extract the Score= value from the string. + try: + score = score_str.split('Score=')[1].split(';')[0] + except IndexError: + logging.warning('Score= not found in the string: %s', score_str) + return 0 + return float(score) if score else 0 + + # Extract the segmental duplication scores. + data['segdup'] = data['genomicSuperDups'].apply(extract_scores) + logging.info('Segmental duplication scores extracted. 
Mean: %.3f', data['segdup'].mean()) + + # Extract the cytoband annotations. + logging.info('Processing cytoband annotations for telomere/centromere detection...') + def get_cyto_info(row): + """Get telomere and centromere information for a row.""" + if pd.notna(row['cytoBand']): + return get_cytoband_is_c_t(cytoband_dict, row['chrom'], row['cytoBand']) + + return (False, False) + + cyto_flags = data.apply(get_cyto_info, axis=1, result_type='expand') + data[['telomere', 'centromere']] = cyto_flags + logging.info('Telomere/centromere annotation complete. Telomeres: %d, Centromeres: %d', data['telomere'].sum(), data['centromere'].sum()) + + # Add feature dist_to_telomere, dist_to_centromere using vectorized operations. + logging.info('Computing distances to chromosome telomeres and centromeres...') + chrom_bounds = pd.DataFrame([ + { + 'chrom': chrom, + 'chrom_norm': normalize_chrom_label(chrom), + 'chrom_start': values.get('chrom_start', np.nan), + 'chrom_end': values.get('chrom_end', np.nan), + 'centromere_start': values.get('centromere_start', np.nan), + 'centromere_end': values.get('centromere_end', np.nan) + } + for chrom, values in cytoband_dict.items() + ]) + + data_with_bounds = data.copy() + data_with_bounds['chrom_norm'] = data_with_bounds['chrom'].apply(normalize_chrom_label) + data_with_bounds = data_with_bounds.merge( + chrom_bounds[['chrom_norm', 'chrom_start', 'chrom_end', 'centromere_start', 'centromere_end']], + on='chrom_norm', + how='left' + ) + + starts = data_with_bounds['start'].to_numpy(dtype=np.float64) + ends = data_with_bounds['end'].to_numpy(dtype=np.float64) + chrom_starts = data_with_bounds['chrom_start'].to_numpy(dtype=np.float64) + chrom_ends = data_with_bounds['chrom_end'].to_numpy(dtype=np.float64) + centromere_starts = data_with_bounds['centromere_start'].to_numpy(dtype=np.float64) + centromere_ends = data_with_bounds['centromere_end'].to_numpy(dtype=np.float64) + + # Telomere distance: nearest interval-to-point distance to chromosome start/end. + dist_left_tel = np.minimum(np.abs(starts - chrom_starts), np.abs(ends - chrom_starts)) + dist_right_tel = np.minimum(np.abs(starts - chrom_ends), np.abs(ends - chrom_ends)) + dist_to_telomere = np.minimum(dist_left_tel, dist_right_tel) + tel_valid = (~np.isnan(chrom_starts)) & (~np.isnan(chrom_ends)) + dist_to_telomere[~tel_valid] = np.nan + + # Centromere distance: 0 if overlapping centromere span, else gap to nearest boundary. + cen_valid = (~np.isnan(centromere_starts)) & (~np.isnan(centromere_ends)) + dist_to_centromere = np.full(len(data_with_bounds), np.nan, dtype=np.float64) + left_of_centromere = ends < centromere_starts + right_of_centromere = starts > centromere_ends + overlap_centromere = (~left_of_centromere) & (~right_of_centromere) + dist_to_centromere[cen_valid & left_of_centromere] = (centromere_starts - ends)[cen_valid & left_of_centromere] + dist_to_centromere[cen_valid & right_of_centromere] = (starts - centromere_ends)[cen_valid & right_of_centromere] + dist_to_centromere[cen_valid & overlap_centromere] = 0.0 + + data['dist_to_telomere'] = dist_to_telomere + data['dist_to_centromere'] = dist_to_centromere + tel_zero_pct = (data['dist_to_telomere'] == 0).mean() * 100 + cen_zero_pct = (data['dist_to_centromere'] == 0).mean() * 100 + cen_le1_pct = (data['dist_to_centromere'] <= 1).mean() * 100 + cen_desc = data['dist_to_centromere'].describe(percentiles=[0.5, 0.9, 0.99]) + # Diagnostics for coordinate issues that can indicate malformed records. 
+ out_of_bounds_pct = ((data_with_bounds['start'] < data_with_bounds['chrom_start']) | (data_with_bounds['end'] > data_with_bounds['chrom_end'])).mean() * 100 + logging.info( + 'Telomere/centromere distances calculated. Mean dist_to_telomere: %.2f, Mean dist_to_centromere: %.2f, telomere zeros: %.2f%%, centromere zeros: %.2f%%, out-of-bounds coords: %.2f%%', + data['dist_to_telomere'].mean(), + data['dist_to_centromere'].mean(), + tel_zero_pct, + cen_zero_pct, + out_of_bounds_pct + ) + logging.info( + 'Centromere distance distribution: min=%.2f, p50=%.2f, p90=%.2f, p99=%.2f, max=%.2f, <=1bp: %.2f%%', + cen_desc['min'], + cen_desc['50%'], + cen_desc['90%'], + cen_desc['99%'], + cen_desc['max'], + cen_le1_pct + ) + + # Log-transform long-tailed distance features for model stability. + logging.info('Applying log1p transform to dist_to_telomere and dist_to_centromere...') + data['dist_to_telomere'] = np.log1p(data['dist_to_telomere']) + data['dist_to_centromere'] = np.log1p(data['dist_to_centromere']) + logging.info( + 'Distance log-transform complete. Mean log-dist_to_telomere: %.3f, Mean log-dist_to_centromere: %.3f', + data['dist_to_telomere'].mean(), + data['dist_to_centromere'].mean() + ) + + # Helper function to compute repeat density across entire SV span + def compute_repeat_density_span(data_df, repeat_overlap_df): + """Compute repeat span density as the fraction of the SV covered by simple repeats.""" + repeat_copy = repeat_overlap_df.copy() + repeat_copy['overlap_length'] = repeat_copy['end_anno'] - repeat_copy['start_anno'] + + # Group by original SV coordinates and sum total overlapping lengths + density_df = repeat_copy.groupby(['chrom', 'start', 'end'])['overlap_length'].sum().reset_index() + density_df.columns = ['chrom', 'start', 'end', 'total_repeat_length'] + + # Merge with data and calculate density + merged = data_df.merge(density_df, on=['chrom', 'start', 'end'], how='left') + merged['total_repeat_length'] = merged['total_repeat_length'].fillna(0) + span_length = (merged['end'] - merged['start']).astype(float) + zero_span_count = (span_length <= 0).sum() + if zero_span_count > 0: + logging.info('Found %d SV records with non-positive span; setting repeat_span_density to 0 for these records.', zero_span_count) + valid_span = span_length > 0 + density_values = pd.Series(0.0, index=merged.index) + density_values.loc[valid_span] = merged.loc[valid_span, 'total_repeat_length'] / span_length.loc[valid_span] + density_values = density_values.clip(lower=0, upper=1) + + return density_values + + # Add breakpoint features from both breakpoints (vectorized by chromosome). 
+ logging.info('Computing breakpoint features (segdup and simple repeat at left/right breakpoints)...') + + def point_max_overlap_score(points, starts, ends, scores): + """For each query point, return max score among overlapping intervals.""" + if len(starts) == 0: + return np.zeros(len(points), dtype=np.float64) + + order = np.argsort(starts, kind='mergesort') + starts = starts[order] + ends = ends[order] + scores = scores[order] + + point_order = np.argsort(points, kind='mergesort') + result = np.zeros(len(points), dtype=np.float64) + + active = [] # max-heap via negative score: (-score, interval_end) + interval_idx = 0 + n_intervals = len(starts) + + for point_idx in point_order: + point = points[point_idx] + + while interval_idx < n_intervals and starts[interval_idx] <= point: + heapq.heappush(active, (-scores[interval_idx], ends[interval_idx])) + interval_idx += 1 + + while active and active[0][1] < point: + heapq.heappop(active) + + if active: + result[point_idx] = -active[0][0] + + return result + + def point_in_any_interval(points, starts, ends): + """For each query point, return whether it is covered by any interval.""" + if len(starts) == 0: + return np.zeros(len(points), dtype=bool) + + order = np.argsort(starts, kind='mergesort') + starts_sorted = starts[order] + ends_sorted = ends[order] + max_end_prefix = np.maximum.accumulate(ends_sorted) + + idx = np.searchsorted(starts_sorted, points, side='right') - 1 + covered = np.zeros(len(points), dtype=bool) + valid = idx >= 0 + covered[valid] = max_end_prefix[idx[valid]] >= points[valid] + + return covered + + # Precompute interval arrays by chromosome for fast lookup. + anno_segdup = anno_df[['Chr', 'Start', 'End', 'genomicSuperDups']].copy() + anno_segdup['segdup_score'] = anno_segdup['genomicSuperDups'].apply(extract_scores) + segdup_intervals = { + normalize_chrom_label(chrom): ( + grp['Start'].to_numpy(dtype=np.int64), + grp['End'].to_numpy(dtype=np.int64), + grp['segdup_score'].to_numpy(dtype=np.float64) + ) + for chrom, grp in anno_segdup.groupby('Chr', sort=False) + } + + repeat_intervals = { + normalize_chrom_label(chrom): ( + grp['start_anno'].to_numpy(dtype=np.int64), + grp['end_anno'].to_numpy(dtype=np.int64) + ) + for chrom, grp in simpleRepeat_df.groupby('chrom', sort=False) + } + + # Allocate result columns. + data['segdup_left'] = 0.0 + data['segdup_right'] = 0.0 + data['simpleRepeat_left'] = False + data['simpleRepeat_right'] = False + + logging.info('Computing left breakpoint features...') + for chrom, chrom_idx in data.groupby('chrom', sort=False).groups.items(): + idx = list(chrom_idx) + left_points = data.loc[idx, 'start'].to_numpy(dtype=np.int64) + right_points = data.loc[idx, 'end'].to_numpy(dtype=np.int64) + chrom_norm = normalize_chrom_label(chrom) + + seg_starts, seg_ends, seg_scores = segdup_intervals.get( + chrom_norm, (np.array([], dtype=np.int64), np.array([], dtype=np.int64), np.array([], dtype=np.float64)) + ) + rep_starts, rep_ends = repeat_intervals.get( + chrom_norm, (np.array([], dtype=np.int64), np.array([], dtype=np.int64)) + ) + + data.loc[idx, 'segdup_left'] = point_max_overlap_score(left_points, seg_starts, seg_ends, seg_scores) + data.loc[idx, 'simpleRepeat_left'] = point_in_any_interval(left_points, rep_starts, rep_ends) + + data.loc[idx, 'segdup_right'] = point_max_overlap_score(right_points, seg_starts, seg_ends, seg_scores) + data.loc[idx, 'simpleRepeat_right'] = point_in_any_interval(right_points, rep_starts, rep_ends) + + logging.info('Breakpoint features complete. 
segdup_left mean: %.3f, segdup_right mean: %.3f', data['segdup_left'].mean(), data['segdup_right'].mean()) + + # Calculate repeat span density feature using the simpleRepeat annotations. For each record, calculate the repeat span density as the total overlapping length of all simple repeats divided by the length of the record (end - start). + logging.info('Computing repeat span density (total repeat coverage across SV)...') + data['repeat_span_density'] = compute_repeat_density_span(data, simpleRepeat_df) # across entire SV + logging.info('Repeat span density calculated. Mean: %.3f, Max: %.3f', data['repeat_span_density'].mean(), data['repeat_span_density'].max()) + + # Drop the unnecessary/redundant columns. + data.drop(columns=['Chr', 'Start', 'End', 'cytoBand', 'genomicSuperDups', 'Ref', 'Alt', 'segdup', 'simpleRepeat'], inplace=True) + + logging.info('Number of records after adding annotations: %d', data.shape[0]) + + return data diff --git a/contextscore/predict.py b/contextscore/predict.py new file mode 100644 index 0000000..4b63591 --- /dev/null +++ b/contextscore/predict.py @@ -0,0 +1,599 @@ +""" +scoring_model.py: Score the structural variants using the binary classification +model. + +Usage: + scoring_model.py + +Arguments: + Path to the input VCF file. + Path to the model file. +""" + +import os +import sys +import logging +import argparse +import importlib +import gzip +import re +import tempfile +import numpy as np +import joblib +import pandas as pd + +try: + from .extract_features import extract_features +except ImportError: + from extract_features import extract_features + + +USER_PREFIX = "[ContextScore]" +DEFAULT_MODEL_ENV_VAR = 'CONTEXTSCORE_MODEL_PATH' +DEFAULT_MODEL_INSTALL_PATH = os.path.join( + sys.prefix, + 'share', + 'contextscore', + 'models', + 'contextscore_model.pkl', +) + + +def user_message(message): + """Emit concise, user-facing progress messages.""" + print(f"{USER_PREFIX} {message}") + + +def configure_logging(verbose=False, debug=False): + """Configure logging output level based on user-selected mode.""" + level = logging.DEBUG if debug else (logging.INFO if verbose else logging.WARNING) + logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - %(message)s') + + +def resolve_annovar_paths(annovar_path, annovar_db_path): + """Resolve ANNOVAR paths from CLI flags or environment variables.""" + resolved_path = annovar_path or os.getenv('ANNOVAR_PATH') + resolved_db = annovar_db_path or os.getenv('ANNOVAR_DB_PATH') + return resolved_path, resolved_db + + +def resolve_model_path(model_path): + """Resolve model path from CLI, env var, or default installed location.""" + if model_path: + return model_path, 'cli' + + env_model_path = os.getenv(DEFAULT_MODEL_ENV_VAR) + if env_model_path: + return env_model_path, 'env' + + return DEFAULT_MODEL_INSTALL_PATH, 'default' + + +def validate_annovar_paths(annovar_path, annovar_db_path): + """Validate ANNOVAR installation paths before running feature extraction.""" + if not annovar_path: + raise ValueError( + 'ANNOVAR path is required. Set --annovar or environment variable ANNOVAR_PATH.' + ) + if not annovar_db_path: + raise ValueError( + 'ANNOVAR database path is required. Set --annovar-db or environment variable ANNOVAR_DB_PATH.' 
+ ) + + annotate_variation = os.path.join(annovar_path, 'annotate_variation.pl') + table_annovar = os.path.join(annovar_path, 'table_annovar.pl') + if not os.path.isfile(annotate_variation) or not os.path.isfile(table_annovar): + raise ValueError( + f'Invalid ANNOVAR path: {annovar_path}. Expected annotate_variation.pl and table_annovar.pl in this directory.' + ) + if not os.path.isdir(annovar_db_path): + raise ValueError(f'ANNOVAR database directory does not exist: {annovar_db_path}') + + +def try_import_plotting_libs(): + """Attempt to import plotting libraries without failing prediction flow.""" + try: + plt = importlib.import_module('matplotlib.pyplot') + sns = importlib.import_module('seaborn') + return plt, sns + except ImportError: + return None, None + + +def open_vcf_text(path): + """Open VCF text input, supporting both plain and gzipped files.""" + if str(path).endswith('.gz'): + return gzip.open(path, 'rt', encoding='utf-8') + return open(path, 'r', encoding='utf-8') + + +def canonicalize_chromosome(chrom_value): + """Map CHROM values to canonical chr-prefixed labels; return None if unparseable.""" + if pd.isna(chrom_value): + return None + + chrom_str = str(chrom_value).strip() + if not chrom_str: + return None + + has_chr_prefix = chrom_str.lower().startswith('chr') + token = chrom_str[3:] if has_chr_prefix else chrom_str + token_upper = token.upper() + + if token_upper in {'M', 'MT'}: + return 'chrM' + if token_upper in {'X', 'Y'}: + return f'chr{token_upper}' + if token_upper.isdigit(): + token_num = int(token_upper) + if 1 <= token_num <= 22: + return f'chr{token_num}' + + # Keep non-canonical contigs as-is (e.g., GL*, KI*, NC_*). + if re.fullmatch(r'[A-Za-z0-9_.-]+', chrom_str): + return chrom_str + return None + +def create_bed(input_vcf, output_bed): + """Create a BED file from the input VCF file. Extract the following fields: + 1. Chromosome (CHROM) + 2. Start position (POS) + 3. End position (END) + 4. SV type (SVTYPE) + 5. SV length (SVLEN) + 6. Genotype (GT) + 7. Read depth (DP) + 8. HMM log likelihood (HMM) + 9. Alignment type (ALN) + 10. Cluster size (CLUSTER) + 11. Copy number state (CN) + 12. Read alignment offset (ALNOFFSET) + Args: + input_vcf (str): Path to the input VCF file. + output_bed (str): Path to the output BED file. + """ + logging.info('Reading VCF file: %s', input_vcf) + vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, + names=['CHROM', 'POS', 'INFO', 'FORMAT', 'SAMPLE'], usecols=[0, 1, 7, 8, 9], + dtype={'CHROM': str, 'POS': int, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}) + + # Add a column for the ID field with the VCF row number + vcf_df['id'] = vcf_df.index + + # Normalize CHROM labels for robust annotation intersects. + vcf_df['CHROM_ORIG'] = vcf_df['CHROM'].astype(str) + vcf_df['CHROM'] = vcf_df['CHROM'].apply(canonicalize_chromosome) + invalid_chrom_mask = vcf_df['CHROM'].isna() + skipped_chrom_ids = set(vcf_df.loc[invalid_chrom_mask, 'id'].astype(int).tolist()) + if skipped_chrom_ids: + examples = vcf_df.loc[invalid_chrom_mask, 'CHROM_ORIG'].dropna().astype(str).unique()[:5] + logging.warning( + 'Skipping %d variants with unparseable CHROM labels during annotation/scoring. 
Examples: %s', + len(skipped_chrom_ids), + ', '.join(examples) if len(examples) > 0 else 'N/A', + ) + vcf_df = vcf_df.loc[~invalid_chrom_mask].copy() + + info_df = pd.DataFrame() + info_df['ALN'] = vcf_df['INFO'].str.extract(r'ALN=([^;]+)') + info_df['END'] = vcf_df['INFO'].str.extract(r'END=(\d+)') + info_df['SVTYPE'] = vcf_df['INFO'].str.extract(r'SVTYPE=([^;]+)') + info_df['SVLEN'] = vcf_df['INFO'].str.extract(r'SVLEN=([^;]+)') + info_df['HMM'] = vcf_df['INFO'].str.extract(r'HMM=([^;]+)') + info_df['CLUSTER'] = vcf_df['INFO'].str.extract(r'CLUSTER=([^;]+)') + info_df['CN'] = vcf_df['INFO'].str.extract(r'CN=([^;]+)') + info_df['ALNOFFSET'] = vcf_df['INFO'].str.extract(r'ALNOFFSET=([^;]+)') + + def _extract_sample_field(row, field_name): + sample_value = row.get('SAMPLE') + if pd.isna(sample_value): + return np.nan + + sample_parts = str(sample_value).split(':') + format_value = row.get('FORMAT') + if pd.notna(format_value): + format_parts = str(format_value).split(':') + try: + field_index = format_parts.index(field_name) + except ValueError: + field_index = None + if field_index is not None and field_index < len(sample_parts): + return sample_parts[field_index] + return np.nan + + # Fallback to the original positional interpretation when FORMAT is unavailable. + if field_name == 'GT': + return sample_parts[0] if len(sample_parts) > 0 else np.nan + if field_name == 'DP': + return sample_parts[1] if len(sample_parts) > 1 else np.nan + return np.nan + + # Extract the genotype (GT) and read depth (DP) from the SAMPLE column + sample_df = pd.DataFrame() + sample_df['GT'] = vcf_df.apply(lambda row: _extract_sample_field(row, 'GT'), axis=1) + sample_df['DP'] = pd.to_numeric( + vcf_df.apply(lambda row: _extract_sample_field(row, 'DP'), axis=1), + errors='coerce', + ).fillna(0).astype(int) + + # Create the BED file + bed_df = pd.DataFrame() + bed_df['CHROM'] = vcf_df['CHROM'] + bed_df['START'] = vcf_df['POS'] + bed_df['END'] = info_df['END'] + bed_df['SVTYPE'] = info_df['SVTYPE'] + bed_df['SVLEN'] = info_df['SVLEN'] + bed_df['GT'] = sample_df['GT'] + bed_df['DP'] = sample_df['DP'] + bed_df['HMM'] = info_df['HMM'] + bed_df['ALN'] = info_df['ALN'] + bed_df['CLUSTER'] = info_df['CLUSTER'] + bed_df['CN'] = info_df['CN'] + bed_df['ALNOFFSET'] = info_df['ALNOFFSET'] + bed_df['id'] = vcf_df['id'] + + # Save the BED file + bed_df.to_csv(output_bed, sep='\t', header=False, index=False) + logging.info('Created BED file: %s', output_bed) + return skipped_chrom_ids + +def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, + threshold_del=None, threshold_dup=None, threshold_ins=None, threshold_inv=None, + sample_coverage=None, large_cutoff=10000, annovar_path=None, annovar_db_path=None, + debug_plot=False): + """Score the structural variants using the binary classification model. + + Args: + model (str): Path to the model file. + input_vcf (str): Path to the input VCF file. + output_vcf (str): Path to the output VCF file. + threshold (float): Default threshold for SV types not specified. + threshold_del (float): Optional. Threshold for DEL variants. If None, uses default threshold. + threshold_dup (float): Optional. Threshold for DUP variants. If None, uses default threshold. + threshold_ins (float): Optional. Threshold for INS variants. If None, uses default threshold. + threshold_inv (float): Optional. Threshold for INV variants. If None, uses default threshold. + sample_coverage (float): Required. Mean read depth coverage for the sample. 
+        large_cutoff (int): SV size cutoff in bp; variants larger than this are always kept (default: 10000).
+    """
+    # Build threshold dictionary with type-specific values
+    threshold_by_type = {
+        'DEL': threshold_del if threshold_del is not None else threshold,
+        'DUP': threshold_dup if threshold_dup is not None else threshold,
+        'INS': threshold_ins if threshold_ins is not None else threshold,
+        'INV': threshold_inv if threshold_inv is not None else threshold,
+    }
+
+    prob_threshold = threshold
+    logging.info('Using confidence threshold policy:')
+    for svtype, thr in sorted(threshold_by_type.items()):
+        logging.info('  %s: %.3f', svtype, thr)
+
+    output_dir = os.path.dirname(os.path.abspath(output_vcf)) or '.'
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Create temporary annotation inputs in the output location so read-only input paths work.
+    with tempfile.TemporaryDirectory(prefix='contextscore_', dir=output_dir) as temp_workdir:
+        bed_file = os.path.join(temp_workdir, f"{os.path.splitext(os.path.basename(input_vcf))[0]}.bed")
+        skipped_chrom_ids = create_bed(input_vcf, bed_file)
+        logging.info('Created BED file: %s', bed_file)
+        if skipped_chrom_ids:
+            logging.info('Variants skipped from annotation/scoring due to unparseable CHROM: %d', len(skipped_chrom_ids))
+
+        # Extract the features from the BED file.
+        anno_outdir = os.path.join(temp_workdir, 'annotations')
+        os.makedirs(anno_outdir, exist_ok=True)
+        feature_df = extract_features(bed_file, annovar_path, annovar_db_path, anno_outdir, buildver, sample_coverage=sample_coverage)
+
+    # Load the model
+    logging.info('Loading model from: %s', model)
+    clf = joblib.load(model)
+    logging.info('Model loaded successfully.')
+
+    # Check if the feature extraction was successful
+    if feature_df.empty:
+        logging.error('Feature extraction failed. No features extracted.')
+        sys.exit(1)
+
+    # Separate the ID column and keep variant metadata for downstream evaluation joins.
+    id_col = feature_df.pop('id')
+
+    predictions_meta = pd.DataFrame({
+        'id': id_col.values,
+        'chrom': feature_df['chrom'].astype(str).values if 'chrom' in feature_df.columns else np.nan,
+        'start': pd.to_numeric(feature_df['start'], errors='coerce').astype('Int64').values if 'start' in feature_df.columns else pd.Series([pd.NA] * len(id_col), dtype='Int64').values,
+        'end': pd.to_numeric(feature_df['end'], errors='coerce').astype('Int64').values if 'end' in feature_df.columns else pd.Series([pd.NA] * len(id_col), dtype='Int64').values,
+        'sv_type_str': feature_df['sv_type_str'].astype(str).values if 'sv_type_str' in feature_df.columns else np.nan,
+        'sv_length': pd.to_numeric(feature_df['sv_length'], errors='coerce').astype('Int64').values if 'sv_length' in feature_df.columns else pd.Series([pd.NA] * len(id_col), dtype='Int64').values,
+    })
+    predictions_meta['sv_length_abs'] = predictions_meta['sv_length'].abs()
+
+    # Remove other non-feature columns before prediction.
+    # Keep normalized *_per_kb features; remove raw versions.
+    for col in ['chrom', 'start', 'end', 'sv_type_str', 'cluster_size', 'dist_to_nearest_sv', 'read_depth']:
+        if col in feature_df.columns:
+            feature_df.pop(col)
+
+    # Handle NaNs by filling with 0 (matching training's imputation fallback)
+    logging.info('Handling NaN values in features...')
+    nan_count_before = feature_df.isna().sum().sum()
+    if nan_count_before > 0:
+        logging.info('Found %d NaN values in prediction features. Filling with 0.', nan_count_before)
+        feature_df = feature_df.fillna(0)
+
+    # Convert categorical/object columns to numeric (matching training preprocessing)
+    logging.info('Converting categorical features to numeric...')
+    for col in feature_df.columns:
+        if feature_df[col].dtype == 'category':
+            feature_df[col] = feature_df[col].cat.codes
+        elif feature_df[col].dtype == 'object':
+            feature_df[col] = pd.to_numeric(feature_df[col], errors='coerce')
+
+    # Ensure all columns are float64
+    feature_df = feature_df.fillna(0).astype('float64')
+
+    # Run the model on the features
+    logging.info('Running the model on the features...')
+    y_pred = clf.predict_proba(feature_df)
+
+    # Save per-variant probabilities for downstream threshold tuning.
+    predictions_tsv = os.path.join(output_dir, 'predictions.tsv')
+    predictions_df = predictions_meta.copy()
+    predictions_df['confidence_score'] = y_pred[:, 1]
+    predictions_df.to_csv(predictions_tsv, sep='\t', index=False)
+    logging.info('Saved per-variant predictions to %s', predictions_tsv)
+
+    if debug_plot:
+        plt, sns = try_import_plotting_libs()
+        if plt is None or sns is None:
+            logging.warning('Debug plotting requested but matplotlib/seaborn are not installed. Skipping plot generation.')
+        else:
+            _, ax = plt.subplots()
+            sns.histplot(y_pred[:, 1], bins=20, ax=ax)
+            ax.set_xlabel('Confidence Score')
+            ax.set_ylabel('Count')
+            ax.set_title('Probability Distribution')
+            plot_path = os.path.join(output_dir, 'prob_dist.svg')
+            plt.savefig(plot_path)
+            plt.close()
+            logging.info('Saved debug probability plot to %s', plot_path)
+
+    # Build a lookup dictionary: variant_id → (confidence_score, sv_type) for type-specific filtering
+    variant_lookup = {}
+    for _, row in predictions_df.iterrows():
+        variant_lookup[row['id']] = (row['confidence_score'], row['sv_type_str'])
+
+    logging.info('Built variant lookup with %d entries for type-specific filtering', len(variant_lookup))
+
+    # For backward compatibility, also track variants below the default threshold
+    filtered_indices = np.where(y_pred[:, 1] < prob_threshold)[0]
+    logging.info('Number of variants under the default probability threshold %.2f: %d', prob_threshold, len(filtered_indices))
+
+    # Get the IDs of the filtered variants (for logging/debugging)
+    filtered_ids = id_col.iloc[filtered_indices].values
+    filtered_ids_file = os.path.join(output_dir, 'filtered_ids.txt')
+    np.savetxt(filtered_ids_file, filtered_ids, fmt='%s')
+    logging.info('Saved the filtered IDs (using default threshold) to %s', filtered_ids_file)
+
+    # Path for a VCF file that will hold only the removed (filtered-out) variants
+    removed_svs_vcf = os.path.join(output_dir, 'removed_svs.vcf')
+
+    # Filter the input VCF file based on type-specific thresholds and SV length
+    # Keep all SVs larger than large_cutoff regardless of confidence score; apply type-specific thresholds to the rest
+    logging.info('Filtering the input VCF file using type-specific thresholds and SV length...')
+    logging.info('Policy: Keep all SVs larger than %d bp; apply type-specific thresholds to SVs <= %d bp', large_cutoff, large_cutoff)
+
+    current_record = 0
+    pass_count = 0
+    filter_count = 0
+    total_records = 0
+    type_filter_stats = {}  # Track filtering statistics by type
+
+    with open_vcf_text(input_vcf) as vcf_in, open(output_vcf, 'w', encoding='utf-8') as vcf_out, open(removed_svs_vcf, 'w', encoding='utf-8') as removed_out:
+        for line in vcf_in:
+            if line.startswith('#'):
+                # Write the header lines as they are
+                vcf_out.write(line)
+                removed_out.write(line)
+            else:
+                # Extract SVLEN and SVTYPE from the VCF INFO field
+                info_field = line.split('\t')[7]
+                svlen_match = None
+                svtype_match = None
+
+                for field in info_field.split(';'):
+                    if field.startswith('SVLEN='):
+                        try:
+                            svlen_match = int(field.split('=')[1])
+                        except (ValueError, IndexError):
+                            svlen_match = None
+                    elif field.startswith('SVTYPE='):
+                        try:
+                            svtype_match = field.split('=')[1]
+                        except IndexError:
+                            svtype_match = None
+
+                if current_record in skipped_chrom_ids:
+                    svtype_for_stats = svtype_match if svtype_match else 'UNKNOWN'
+                    if svtype_for_stats not in type_filter_stats:
+                        type_filter_stats[svtype_for_stats] = {'total': 0, 'kept': 0, 'filtered': 0}
+                    type_filter_stats[svtype_for_stats]['total'] += 1
+                    type_filter_stats[svtype_for_stats]['kept'] += 1
+                    vcf_out.write(line)
+                    pass_count += 1
+                    total_records += 1
+                    current_record += 1
+                    continue
+
+                # Get confidence score and sv_type from predictions lookup
+                if current_record in variant_lookup:
+                    confidence_score, predicted_svtype = variant_lookup[current_record]
+                    # Use VCF SVTYPE if available, otherwise use predicted svtype
+                    svtype = svtype_match if svtype_match else predicted_svtype
+                else:
+                    # Variant not in predictions (shouldn't happen, but handle gracefully)
+                    logging.warning('Variant %d not found in predictions lookup, using default threshold', current_record)
+                    confidence_score = 0.0
+                    svtype = svtype_match if svtype_match else 'UNKNOWN'
+
+                # Get the appropriate threshold for this SV type
+                type_threshold = threshold_by_type.get(svtype, prob_threshold)
+
+                # Determine if variant should be kept
+                is_large_sv = svlen_match is not None and abs(svlen_match) > large_cutoff
+                passes_threshold = confidence_score >= type_threshold
+
+                # Keep if: (large SV) OR (passes type-specific threshold)
+                should_keep = is_large_sv or passes_threshold
+
+                # Track statistics by type
+                if svtype not in type_filter_stats:
+                    type_filter_stats[svtype] = {'total': 0, 'kept': 0, 'filtered': 0}
+                type_filter_stats[svtype]['total'] += 1
+
+                if should_keep:
+                    vcf_out.write(line)
+                    pass_count += 1
+                    type_filter_stats[svtype]['kept'] += 1
+                else:
+                    # Write the line to the removed_svs.vcf file if filtered
+                    removed_out.write(line)
+                    filter_count += 1
+                    type_filter_stats[svtype]['filtered'] += 1
+
+                total_records += 1
+                current_record += 1
+
+    logging.info('Filtered the input VCF file and saved it to %s', output_vcf)
+    logging.info('Scoring process completed successfully. Passed %d out of %d records.', pass_count, total_records)
+    logging.info('Removed %d low-confidence records (size <= %d bp or unknown). See %s for details.', filter_count, large_cutoff, removed_svs_vcf)
+
+    # Log filtering statistics by SV type
+    logging.info('Filtering statistics by SV type:')
+    for svtype in sorted(type_filter_stats.keys()):
+        stats = type_filter_stats[svtype]
+        kept_pct = 100.0 * stats['kept'] / stats['total'] if stats['total'] > 0 else 0
+        logging.info('  %s: kept %d/%d (%.1f%%)', svtype, stats['kept'], stats['total'], kept_pct)
+
+    return {
+        'total_records': total_records,
+        'passed_records': pass_count,
+        'filtered_records': filter_count,
+        'output_vcf': output_vcf,
+        'removed_vcf': removed_svs_vcf,
+        'predictions_tsv': predictions_tsv,
+    }
+
+
+def main(argv=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, required=True,
+                        help='Path to the input VCF file.')
+    parser.add_argument('--output', type=str, required=True,
+                        help='Path to the output VCF file.')
+    parser.add_argument('--model', type=str, required=False, default=None,
+                        help='Path to the model file. Optional if CONTEXTSCORE_MODEL_PATH is set or default packaged model is installed.')
+    parser.add_argument('--buildver', type=str, default='hg38',
+                        help='Genome build version (default: hg38).')
+    parser.add_argument('--threshold', type=float, default=0.2,
+                        help='Default threshold for filtering predictions (default: 0.2). Used for SV types without specific thresholds.')
+    parser.add_argument('--threshold-del', type=float, default=None,
+                        help='Threshold for DEL variants (default: uses --threshold value).')
+    parser.add_argument('--threshold-dup', type=float, default=None,
+                        help='Threshold for DUP variants (default: uses --threshold value).')
+    parser.add_argument('--threshold-ins', type=float, default=None,
+                        help='Threshold for INS variants (default: uses --threshold value).')
+    parser.add_argument('--threshold-inv', type=float, default=None,
+                        help='Threshold for INV variants (default: uses --threshold value).')
+    parser.add_argument('--sample-coverage', type=float, required=True,
+                        help='Mean read depth coverage for the sample (required, used to normalize read_depth).')
+    parser.add_argument('--large-cutoff', type=int, default=10000,
+                        help='SV size cutoff in bp; variants larger than this are always kept (default: 10000).')
+    parser.add_argument('--annovar', type=str, default=None,
+                        help='Path to ANNOVAR installation directory. Can also be set via ANNOVAR_PATH.')
+    parser.add_argument('--annovar-db', type=str, default=None,
+                        help='Path to ANNOVAR database directory. Can also be set via ANNOVAR_DB_PATH.')
+    parser.add_argument('--verbose', action='store_true',
+                        help='Show detailed progress logs.')
+    parser.add_argument('--debug', action='store_true',
+                        help='Show debug logs including subprocess details.')
+    parser.add_argument('--debug-plot', action='store_true',
+                        help='Generate probability distribution plot for debugging (optional, requires matplotlib and seaborn).')
+
+    args = parser.parse_args(argv)
+    input_vcf = args.input
+    output_vcf = args.output
+    model, model_source = resolve_model_path(args.model)
+
+    configure_logging(verbose=args.verbose, debug=args.debug)
+    user_message('Starting prediction run')
+
+    # Check if the input VCF file exists
+    if not os.path.isfile(input_vcf):
+        logging.error('Input VCF file does not exist: %s', input_vcf)
+        sys.exit(1)
+
+    # Check if the model file exists
+    if not os.path.isfile(model):
+        logging.error('Model file does not exist: %s', model)
+        user_message('Model path could not be resolved to an existing file.')
+        user_message('Provide --model /path/to/model.pkl, or set CONTEXTSCORE_MODEL_PATH, or install the contextscore-models package.')
+        if model_source == 'default':
+            user_message(f'Default expected path: {DEFAULT_MODEL_INSTALL_PATH}')
+        sys.exit(1)
+
+    # Check if the output directory exists, if not create it
+    output_dir = os.path.dirname(os.path.abspath(output_vcf)) or '.'
+ if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info('Created output directory: %s', output_dir) + + # Check if the input VCF file is a valid VCF file + if not input_vcf.endswith('.vcf') and not input_vcf.endswith('.vcf.gz'): + logging.error('Input file is not a valid VCF file: %s', input_vcf) + sys.exit(1) + if not output_vcf.endswith('.vcf'): + logging.error('Output file must have a .vcf extension: %s', output_vcf) + sys.exit(1) + if not model.endswith('.pkl'): + logging.error('Model file must have a .pkl extension: %s', model) + sys.exit(1) + + logging.info('Using model path from %s: %s', model_source, model) + + # Check the reference genome build version + buildver = args.buildver + if buildver not in ['hg19', 'hg38']: + logging.error('Unsupported genome build version: %s. Supported versions are hg19 and hg38.', buildver) + sys.exit(1) + + annovar_path, annovar_db_path = resolve_annovar_paths(args.annovar, args.annovar_db) + try: + validate_annovar_paths(annovar_path, annovar_db_path) + except ValueError as exc: + logging.error('%s', exc) + user_message('ANNOVAR setup is required before running prediction.') + user_message('Example: contextscore --input sample.vcf --output out.vcf --sample-coverage 30 --annovar /path/to/annovar --annovar-db /path/to/humandb') + user_message('Optional: add --model /path/to/model.pkl to override default model resolution.') + user_message('You can also set ANNOVAR_PATH and ANNOVAR_DB_PATH environment variables.') + sys.exit(2) + + user_message('Running feature extraction and scoring') + + # Run the scoring function + summary = score(model, input_vcf, output_vcf, buildver=buildver, + threshold=args.threshold, sample_coverage=args.sample_coverage, + threshold_del=args.threshold_del, threshold_dup=args.threshold_dup, + threshold_ins=args.threshold_ins, threshold_inv=args.threshold_inv, + large_cutoff=args.large_cutoff, annovar_path=annovar_path, + annovar_db_path=annovar_db_path, + debug_plot=args.debug_plot) + + user_message( + f"Completed. Kept {summary['passed_records']}/{summary['total_records']} variants; filtered {summary['filtered_records']}." + ) + user_message(f"Output VCF: {summary['output_vcf']}") + logging.info('Scoring process completed.') + + +if __name__ == '__main__': + main() diff --git a/contextscore/train_full_model.py b/contextscore/train_full_model.py new file mode 100644 index 0000000..d1f2d64 --- /dev/null +++ b/contextscore/train_full_model.py @@ -0,0 +1,807 @@ +""" +train_model.py - Train the binary classification model and evaluate using per-chromosome cross-validation and 80/20 train/test split. +""" + +import os +import logging +import joblib +import numpy as np +import pandas as pd + +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import RandomForestClassifier +from sklearn.pipeline import Pipeline +from sklearn.model_selection import GridSearchCV, StratifiedKFold +from sklearn.svm import SVC + +from sklearn.metrics import roc_curve, auc + +try: + from .extract_features import extract_features +except ImportError: + from extract_features import extract_features + +# Set up the logger. +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# Manuscript-friendly display labels for model features. 
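+# Note: these labels are only used for presentation (via get_display_feature_name /
+# get_display_feature_names below) when rendering plots and CSV summaries; training and
+# prediction operate on the raw feature column names.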
+FEATURE_DISPLAY_NAMES = { + 'dist_nearest_sv_per_kb': 'Nearest SV distance / kb', + 'cluster_size_per_kb': 'Cluster size / kb', + 'sv_length': 'SV length (bp)', + 'read_depth_normalized': 'Normalized depth', + 'segdup_left': 'SegDup overlap (left)', + 'segdup_right': 'SegDup overlap (right)', + 'dist_to_telomere': 'Telomere distance', + 'dist_to_centromere': 'Centromere distance', + 'call_type': 'Call evidence type', + 'simpleRepeat_left': 'Simple repeat (left)', + 'simpleRepeat_right': 'Simple repeat (right)', + 'sv_type': 'SV type', + 'repeat_span_density': 'Repeat span density', + 'fragile_site': 'Fragile-site overlap', + 'phastCons': 'phastCons score', + 'hmm_llh': 'HMM log-likelihood', + 'aln_offset': 'Alignment offset' +} + +ENABLE_SHAP = False + +if ENABLE_SHAP: + import shap + +def get_display_feature_name(feature_name): + """Map internal feature keys to human-readable labels for plots/tables.""" + return FEATURE_DISPLAY_NAMES.get(feature_name, feature_name.replace('_', ' ')) + + +def get_display_feature_names(feature_names): + """Return human-readable labels in the same order as input feature names.""" + return [get_display_feature_name(name) for name in feature_names] + + +def preprocess_feature_matrix(feature_df): + """Convert mixed-type feature columns to numeric values for model fitting/inference.""" + processed_df = feature_df.copy() + for col in processed_df.columns: + if processed_df[col].dtype == 'category': + processed_df[col] = processed_df[col].cat.codes + elif processed_df[col].dtype == 'object': + processed_df[col] = pd.to_numeric(processed_df[col], errors='coerce') + + return processed_df.fillna(0).astype('float64') + + +def get_cv_splits(y, max_splits=5): + """Choose a valid number of stratified CV folds for the provided labels.""" + class_counts = y.value_counts() + if class_counts.empty or len(class_counts) < 2: + return None + + n_splits = min(max_splits, int(class_counts.min())) + if n_splits < 2: + return None + + return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) + +def balance_tp_fp_datasets(tp_data, fp_data): + """Balance the true positive and false positive datasets by undersampling the lower-count class.""" + tp_count = tp_data.shape[0] + fp_count = fp_data.shape[0] + + if tp_count > fp_count: + logging.info('Balancing the dataset by undersampling the true positives (count = %d) to match the false positives (count = %d)', tp_count, fp_count) + tp_data = tp_data.sample(fp_count, random_state=42) + elif fp_count > tp_count: + logging.info('Balancing the dataset by undersampling the false positives (count = %d) to match the true positives (count = %d)', fp_count, tp_count) + fp_data = fp_data.sample(tp_count, random_state=42) + else: + logging.info('The dataset is already balanced. True positives: %d, False positives: %d', tp_count, fp_count) + + return tp_data, fp_data + + +def impute_missing_values(tp_data, fp_data): + """Impute missing values using TP-referenced statistics to avoid excessive row drops.""" + logging.info('Imputing NaN values using TP-referenced statistics...') + + # Report NaNs by column before imputation. 
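+    # Policy implemented below: boolean-like annotation columns are filled with False,
+    # numeric columns with the TP median (falling back to 0.0), and remaining
+    # categorical/object columns with the TP mode (or 'UNKNOWN' when no mode exists).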
+ tp_nan = tp_data.isna().sum() + fp_nan = fp_data.isna().sum() + tp_nan = tp_nan[tp_nan > 0].sort_values(ascending=False) + fp_nan = fp_nan[fp_nan > 0].sort_values(ascending=False) + if not tp_nan.empty: + logging.info('TP NaN counts by column before imputation: %s', tp_nan.to_dict()) + if not fp_nan.empty: + logging.info('FP NaN counts by column before imputation: %s', fp_nan.to_dict()) + + bool_like_cols = { + 'fragile_site', 'phastCons', 'telomere', 'centromere', + 'simpleRepeat_left', 'simpleRepeat_right' + } + + shared_cols = [col for col in tp_data.columns if col in fp_data.columns and col != 'label'] + for col in shared_cols: + if not (tp_data[col].isna().any() or fp_data[col].isna().any()): + continue + + if col in bool_like_cols: + tp_data[col] = tp_data[col].fillna(False) + fp_data[col] = fp_data[col].fillna(False) + continue + + if pd.api.types.is_numeric_dtype(tp_data[col]): + fill_value = tp_data[col].median(skipna=True) + if pd.isna(fill_value): + fill_value = 0.0 + tp_data[col] = tp_data[col].fillna(fill_value) + fp_data[col] = fp_data[col].fillna(fill_value) + continue + + # Categorical/object fallback: use TP mode, else placeholder. + mode_values = tp_data[col].mode(dropna=True) + fill_value = mode_values.iloc[0] if not mode_values.empty else 'UNKNOWN' + tp_data[col] = tp_data[col].fillna(fill_value) + fp_data[col] = fp_data[col].fillna(fill_value) + + # Report NaNs after imputation. + tp_remaining = int(tp_data.isna().sum().sum()) + fp_remaining = int(fp_data.isna().sum().sum()) + logging.info('NaN imputation complete. Remaining NaNs - TP: %d, FP: %d', tp_remaining, fp_remaining) + + return tp_data, fp_data + + +def stratified_undersample_fp(fp_data, target_count, random_state=42): + """Undersample false positives using stratified sampling to preserve SV type and length distribution. + + Args: + fp_data (pd.DataFrame): False positive data to undersample. + target_count (int): Target number of samples to retain. + random_state (int): Random seed for reproducibility. + + Returns: + pd.DataFrame: Undersampled false positive data. 
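+
+    Example (illustrative numbers): with target_count=1000 and a 'DEL_<1kb' stratum
+    holding 40% of the input FPs, roughly 400 rows are drawn from that stratum before
+    the final size adjustment.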
+ """ + logging.info('Performing stratified undersampling of false positives (count = %d) to target (count = %d)', + fp_data.shape[0], target_count) + + # Create length bins for stratification + fp_data_temp = fp_data.copy() + fp_data_temp['length_bin'] = pd.cut(fp_data_temp['sv_length'], + bins=[0, 1000, 10000, 100000, float('inf')], + labels=['<1kb', '1-10kb', '10-100kb', '>100kb']) + + # Create stratification column combining SV type and length bin + fp_data_temp['stratum'] = fp_data_temp['sv_type'].astype(str) + '_' + fp_data_temp['length_bin'].astype(str) + + # Calculate target sample size per stratum (proportional to original distribution) + stratum_counts = fp_data_temp['stratum'].value_counts() + stratum_fracs = stratum_counts / len(fp_data_temp) + + logging.info('Sampling from %d strata with proportional allocation', len(stratum_counts)) + + # Sample from each stratum proportionally + sampled_dfs = [] + for stratum, frac in stratum_fracs.items(): + stratum_data = fp_data_temp[fp_data_temp['stratum'] == stratum] + n_samples = max(1, int(round(frac * target_count))) # At least 1 sample per stratum + n_samples = min(n_samples, len(stratum_data)) # Can't sample more than available + sampled = stratum_data.sample(n=n_samples, random_state=random_state) + sampled_dfs.append(sampled) + + fp_data_balanced = pd.concat(sampled_dfs, ignore_index=True) + + # Drop temporary columns + fp_data_balanced = fp_data_balanced.drop(columns=['length_bin', 'stratum']) + + # If we're slightly off from target due to rounding, adjust by random sampling + if len(fp_data_balanced) > target_count: + fp_data_balanced = fp_data_balanced.sample(n=target_count, random_state=random_state) + elif len(fp_data_balanced) < target_count: + # Sample additional rows to reach target + n_additional = target_count - len(fp_data_balanced) + additional = fp_data.sample(n=n_additional, random_state=random_state+1) + fp_data_balanced = pd.concat([fp_data_balanced, additional], ignore_index=True) + + logging.info('Stratified undersampling complete. Final count: %d', len(fp_data_balanced)) + + return fp_data_balanced + + +def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp_na12877_grch38, fp_na12877_grch38, tp_na12878_grch38, fp_na12878_grch38, tp_na12879_grch38, fp_na12879_grch38, output_directory, annovar_path, db_path, outdiranno, leave_out="none", split_80_20=False, per_chr_validation=False, sample_coverage_hg002=None, sample_coverage_visor=None, sample_coverage_na12877=None, sample_coverage_na12878=None, sample_coverage_na12879=None): + """Train the binary classification model. + + Args: + sample_coverage_hg002 (float): Required. Mean read depth coverage for HG002 sample. + sample_coverage_visor (float): Required. Mean read depth coverage for Visor sample. + sample_coverage_na12877 (float): Required. Mean read depth coverage for NA12877 sample. + sample_coverage_na12878 (float): Required. Mean read depth coverage for NA12878 sample. + sample_coverage_na12879 (float): Required. Mean read depth coverage for NA12879 sample. 
+ """ + + # --------------------------------------------------------------- + # SV Feature Extraction + # --------------------------------------------------------------- + + # Set paths to none if leave_out is set to the corresponding dataset + no_leave_out = False + if leave_out == "hg002": + logging.info('Leaving out HG002 dataset from training.') + tp_hg002_grch37 = None + fp_hg002_grch37 = None + elif leave_out == "visor": + logging.info('Leaving out Visor dataset from training.') + tp_visor_grch38 = None + fp_visor_grch38 = None + elif leave_out == "platinum": + logging.info('Leaving out Platinum Pedigree datasets (all 3 samples) from training.') + tp_na12877_grch38 = None + fp_na12877_grch38 = None + tp_na12878_grch38 = None + fp_na12878_grch38 = None + tp_na12879_grch38 = None + fp_na12879_grch38 = None + else: + logging.info('Not leaving out any dataset from training.') + no_leave_out = True + + # =============================================================== + # Extract the features from the VCF files. + # =============================================================== + # GRCh38 data. + logging.info('Extracting features from the true positive and false positive VCF files (GRCh38).') + buildversion = 'hg38' + tp_visor_anno = extract_features(tp_visor_grch38, annovar_path, db_path, os.path.join(outdiranno, "tp_visor_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_visor) if tp_visor_grch38 is not None else None + fp_visor_anno = extract_features(fp_visor_grch38, annovar_path, db_path, os.path.join(outdiranno, "fp_visor_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_visor) if fp_visor_grch38 is not None else None + + tp_na12877_anno = extract_features(tp_na12877_grch38, annovar_path, db_path, os.path.join(outdiranno, "tp_na12877_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_na12877) if tp_na12877_grch38 is not None else None + fp_na12877_anno = extract_features(fp_na12877_grch38, annovar_path, db_path, os.path.join(outdiranno, "fp_na12877_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_na12877) if fp_na12877_grch38 is not None else None + + tp_na12878_anno = extract_features(tp_na12878_grch38, annovar_path, db_path, os.path.join(outdiranno, "tp_na12878_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_na12878) if tp_na12878_grch38 is not None else None + fp_na12878_anno = extract_features(fp_na12878_grch38, annovar_path, db_path, os.path.join(outdiranno, "fp_na12878_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_na12878) if fp_na12878_grch38 is not None else None + + tp_na12879_anno = extract_features(tp_na12879_grch38, annovar_path, db_path, os.path.join(outdiranno, "tp_na12879_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_na12879) if tp_na12879_grch38 is not None else None + fp_na12879_anno = extract_features(fp_na12879_grch38, annovar_path, db_path, os.path.join(outdiranno, "fp_na12879_anno_grch38"), buildversion=buildversion, sample_coverage=sample_coverage_na12879) if fp_na12879_grch38 is not None else None + + # HG002 data (GRCh37). 
+ logging.info('Extracting features from the true positive and false positive VCF files (HG002-GRCh37).') + buildversion = 'hg19' + tp_hg002_anno = extract_features(tp_hg002_grch37, annovar_path, db_path, os.path.join(outdiranno, "tp_anno_grch37"), buildversion=buildversion, sample_coverage=sample_coverage_hg002) if tp_hg002_grch37 is not None else None + fp_hg002_anno = extract_features(fp_hg002_grch37, annovar_path, db_path, os.path.join(outdiranno, "fp_anno_grch37"), buildversion=buildversion, sample_coverage=sample_coverage_hg002) if fp_hg002_grch37 is not None else None + + # Concatenate the data from all datasets. + logging.info('Concatenating the data from all datasets.') + tp_data = pd.concat([df for df in [tp_visor_anno, tp_na12877_anno, tp_na12878_anno, tp_na12879_anno, tp_hg002_anno] if df is not None], ignore_index=True) + fp_data = pd.concat([df for df in [fp_visor_anno, fp_na12877_anno, fp_na12878_anno, fp_na12879_anno, fp_hg002_anno] if df is not None], ignore_index=True) + + # --------------------------------------------------------------- + # Data Preprocessing + # --------------------------------------------------------------- + + # Remove duplicate rows from the concatenated data. + tp_count_before = tp_data.shape[0] + tp_data.drop_duplicates(inplace=True) + tp_count_after = tp_data.shape[0] + fp_count_before = fp_data.shape[0] + fp_data.drop_duplicates(inplace=True) + fp_count_after = fp_data.shape[0] + logging.info('Removed %d tp duplicates and %d fp duplicates from the concatenated data. Remaining true positives: %d, remaining false positives: %d', tp_count_before - tp_count_after, fp_count_before - fp_count_after, tp_data.shape[0], fp_data.shape[0]) + + # Add the labels. + tp_data['label'] = 1 + fp_data['label'] = 0 + + # Print the number of true positives and false positives. + logging.info('Number of true labels: %d', tp_data.shape[0]) + logging.info('Number of false labels: %d', fp_data.shape[0]) + + # Impute NaN values from the data using TP-referenced statistics. + tp_data, fp_data = impute_missing_values(tp_data, fp_data) + + # Safety drop for any residual NaNs that could break downstream training. + logging.info('Dropping any residual NaN rows after imputation.') + tp_data = tp_data.dropna() + fp_data = fp_data.dropna() + logging.info('Number of true labels after impute+dropna: %d', tp_data.shape[0]) + logging.info('Number of false labels after impute+dropna: %d', fp_data.shape[0]) + + # Instead of undersampling, use class_weight='balanced' in Random Forest + # to handle class imbalance while preserving all training data. + logging.info('Skipping undersampling - will use class_weight="balanced" instead') + logging.info('Final class counts - TP: %d, FP: %d', tp_data.shape[0], fp_data.shape[0]) + + # Combine the true positive and false positive data. + data = pd.concat([tp_data, fp_data], ignore_index=True) # Ignore the index to realign the indices. + + # Pop the chrom column to use it later for cross-validation. + chrom_col = data.pop('chrom') + + # Drop columns that are not needed for training. + # Keep normalized *_per_kb features; remove raw versions. + data = data.drop(columns=['start', 'end', 'sv_type_str', 'cluster_size', 'dist_to_nearest_sv', 'read_depth'], errors='ignore') + + logging.info('Columns list after preprocessing: %s', data.columns.tolist()) + + # Print duplicate columns if any. 
+ duplicate_columns = data.columns[data.columns.duplicated()].tolist() + if duplicate_columns: + logging.warning('Duplicate columns found: %s', duplicate_columns) + + # Get the features and labels. + features = data.drop(columns=['label']) + labels = data["label"] + + # Print the number of features. + logging.info('Number of features: %d', features.shape[1]) + logging.info('Feature names: %s', features.columns.tolist()) + if split_80_20: + # Split the data into training and testing sets using stratified sampling to maintain the class balance. + logging.info('Splitting the data into training and testing sets using an 80-20 split with stratified sampling to maintain class balance.') + X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42, stratify=labels) + else: + # Use all the data for training and testing. We will use cross-validation to evaluate the model performance. + logging.info('Using all the data for training and testing. Cross-validation will be used to evaluate the model performance.') + X_train, y_train = features, labels + X_test, y_test = features, labels + + # If not 80/20 split, use XGBoost and Random Forest only (highest performing models) to save time. + if split_80_20: + try: + from xgboost import XGBClassifier + except ImportError as exc: + raise ImportError( + 'xgboost is required when --split-80-20 is enabled. Install xgboost to train with this option.' + ) from exc + pipelines = { + "Random_Forest": Pipeline([('classifier', RandomForestClassifier(n_estimators=100, random_state=42))]), + "XGBoost": Pipeline([('classifier', XGBClassifier(n_estimators=100, eval_metric='logloss', random_state=42, enable_categorical=False))]) + } + # pipelines = { + # "Logistic_Regression": Pipeline([('classifier', LogisticRegression(max_iter=1000, random_state=42))]), + # "Random_Forest": Pipeline([('classifier', RandomForestClassifier(n_estimators=100, random_state=42))]), + # "XGBoost": Pipeline([('classifier', XGBClassifier(n_estimators=100, eval_metric='logloss', random_state=42, enable_categorical=False))]) + # } + else: + pipelines = { + "Random_Forest": Pipeline([('classifier', RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced'))]), + } + + param_grids = { + "Logistic_Regression": { + 'classifier__C': [0.01, 0.1, 1, 10], + 'classifier__penalty': ['l1', 'l2'], + 'classifier__solver': ['liblinear'] + }, + "Random_Forest": { + 'classifier__n_estimators': [100, 200], + 'classifier__max_depth': [None, 10, 20], + 'classifier__min_samples_split': [2, 5], + 'classifier__min_samples_leaf': [1, 2] + }, + "XGBoost": { + 'classifier__n_estimators': [150, 250], # Slightly more trees + 'classifier__max_depth': [3, 6], + 'classifier__learning_rate': [0.01, 0.1], + 'classifier__subsample': [0.8, 1] + } + } + + if per_chr_validation: + # ====================================================== + # Evaluate the model using per-chromosome cross-validation, but don't save. + # ====================================================== + logging.info('Evaluating the model using per-chromosome cross-validation.') + + # Remove chrY from the analysis. More than half is missing in GRCh38 and leads to high false positive rates. 
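+        # Note: because chrY is excluded here, the chrY-specific plotting branch further
+        # below only runs if chrY is added back to this list.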
+ chromosomes = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', + 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + + logging.info('Chromosomes: %s', chromosomes) + f1_scores = {} + precision_scores = {} + recall_scores = {} + for model_name, pipeline in pipelines.items(): + # Dictionary with number of SVs in the held-out set for each chromosome. + sv_counts = {chrom: features[chrom_col == chrom].shape[0] for chrom in chromosomes} + logging.info('Number of SVs in the held-out set for each chromosome: %s', sv_counts) + + for chrom in chromosomes: + logging.info('Training the %s model on chromosome %s.', model_name, chrom) + + # Split the data into training and testing sets by chromosome. + X_train_chrom = features[chrom_col != chrom].copy() + y_train_chrom = labels[chrom_col != chrom].copy() + X_test_chrom = features[chrom_col == chrom].copy() + y_test_chrom = labels[chrom_col == chrom].copy() + + logging.info('Training set size: %d, Testing set size: %d', + X_train_chrom.shape[0], X_test_chrom.shape[0]) + + X_train_chrom_processed = preprocess_feature_matrix(X_train_chrom) + X_test_chrom_processed = preprocess_feature_matrix(X_test_chrom) + + fold_cv = get_cv_splits(y_train_chrom) + if fold_cv is None: + logging.warning( + 'Skipping chromosome %s for %s: insufficient class balance for stratified CV.', + chrom, + model_name + ) + continue + + grid_search = GridSearchCV( + estimator=pipeline, + param_grid=param_grids[model_name], + cv=fold_cv, + scoring='precision', + n_jobs=-1 + ) + grid_search.fit(X_train_chrom_processed, y_train_chrom) + best_model = grid_search.best_estimator_ + logging.info( + 'Best hyperparameters for %s on held-out chromosome %s: %s', + model_name, + chrom, + grid_search.best_params_ + ) + + # Get the predicted probabilities for the testing set. + y_test_chrom_prob = best_model.predict_proba(X_test_chrom_processed)[:, 1] + + # Compute the ROC curve and ROC area for the testing set. + fpr_chrom, tpr_chrom, _ = roc_curve(y_test_chrom, y_test_chrom_prob) + roc_auc_chrom = auc(fpr_chrom, tpr_chrom) + logging.info('ROC AUC score for the %s model on chromosome %s: %f', model_name, chrom, roc_auc_chrom) + + # Compute the F1 score for the testing set. + from sklearn.metrics import f1_score + y_test_chrom_pred = (y_test_chrom_prob >= 0.5).astype(int) # Use a threshold of 0.5 for classification. + f1 = f1_score(y_test_chrom, y_test_chrom_pred) + f1_scores[(model_name, chrom)] = f1 + logging.info('F1 score for the %s model on chromosome %s: %f', model_name, chrom, f1) + + # Compute precision and recall for the testing set. + from sklearn.metrics import precision_score, recall_score + precision = precision_score(y_test_chrom, y_test_chrom_pred) + recall = recall_score(y_test_chrom, y_test_chrom_pred) + precision_scores[(model_name, chrom)] = precision + recall_scores[(model_name, chrom)] = recall + logging.info('Precision for the %s model on chromosome %s: %f', model_name, chrom, precision) + logging.info('Recall for the %s model on chromosome %s: %f', model_name, chrom, recall) + + logging.info('Cross-validation analysis completed. F1 scores: %s', f1_scores) + + # Plot the F1 scores for each model and chromosome (one plot per model). 
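+        # matplotlib and seaborn are optional dependencies (the 'plot' extra in setup.py);
+        # the try/except below raises a clear ImportError if they are missing.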
+ logging.info('Plotting the scores for each model and chromosome.') + try: + import matplotlib.pyplot as plt + import seaborn as sns + except ImportError as exc: + raise ImportError( + 'matplotlib and seaborn are required for per-chromosome validation plots.' + ) from exc + metrics = ['F1 Score', 'Precision', 'Recall'] + for model_name in pipelines.keys(): + + # Save a plot with F1, Precision, and Recall scores for chrY + if 'chrY' in chromosomes: + logging.info('Plotting scores for %s model on chrY.', model_name) + + # Create a bar plot for the F1 scores by chromosome. + chry_f1 = f1_scores.get((model_name, 'chrY'), 0) + chry_precision = precision_scores.get((model_name, 'chrY'), 0) + chry_recall = recall_scores.get((model_name, 'chrY'), 0) + plt.figure(figsize=(6, 4)) + + # Plot F1, Precision, and Recall scores for chrY. + sns.barplot(x=['F1 Score', 'Precision', 'Recall'], y=[chry_f1, chry_precision, chry_recall], color='black') + + # plt.xlabel('Metric') + plt.ylabel('Score') + plt.title('%s Scores for %s Model on chrY' % (model_name, model_name)) + plt.xticks(rotation=45) + plt.tight_layout() + # Save the plot to the output directory. + score_plot_path = os.path.join(output_directory, model_name + '_scores_chrY.svg') + plt.savefig(score_plot_path) + plt.close() + logging.info('Saved the scores plot for chrY to %s', score_plot_path) + + for metric, scores in zip(metrics, [f1_scores, precision_scores, recall_scores]): + logging.info('Plotting %s for %s model by chromosome.', metric, model_name) + # Create a bar plot for the F1 scores by chromosome. + model_scores = {chrom: scores[(model_name, chrom)] for chrom in chromosomes if (model_name, chrom) in scores} + plt.figure(figsize=(10, 6)) + ax = sns.barplot(x=list(model_scores.keys()), y=list(model_scores.values()), color='black') + + plt.xlabel('Chromosome') + plt.ylabel(metric) + plt.title('%s for %s Model by Chromosome' % (metric, model_name)) + plt.xticks(rotation=45) + plt.tight_layout() + score_plot_path = os.path.join(output_directory, model_name + '_%s_by_chromosome.svg' % metric.lower().replace(' ', '_')) + plt.savefig(score_plot_path) + plt.close() + logging.info('Saved the %s plot to %s', metric, score_plot_path) + + else: + # ======================================================= + # Train the model using cross-validation and grid search for hyperparameter tuning. + # ======================================================= + cv = get_cv_splits(y_train) + if cv is None: + raise ValueError('Unable to run training: need at least two classes with at least two samples each for stratified CV.') + + for model_name, pipeline in pipelines.items(): + logging.info('Training model class %s', model_name) + model_name_fp = "contextscore_" + model_name.lower() + "_leaveout_" + leave_out + + if split_80_20: + model_name_fp += "_80_20_split" + + # Perform grid search to find the best hyperparameters for the model, optimizing for precision to prioritize reducing false positives. + X_train_processed = preprocess_feature_matrix(X_train) + X_test_processed = preprocess_feature_matrix(X_test) + + grid_search = GridSearchCV(estimator=pipeline, param_grid=param_grids[model_name], cv=cv, scoring='precision', n_jobs=-1) + grid_search.fit(X_train_processed, y_train) + logging.info('Best hyperparameters for %s: %s', model_name, grid_search.best_params_) + + # Get predicted probabilities for the training and testing sets. 
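+            # (Probabilities and ROC plots are only generated below for the 80/20 split;
+            # otherwise the tuned model is saved directly as a pickle.)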
+ best_model = grid_search.best_estimator_ + + # Save plots only for 80-20 split since the ROC curve will be overly optimistic when using all the data for training and testing. + if split_80_20: + try: + import matplotlib.pyplot as plt + except ImportError as exc: + raise ImportError( + 'matplotlib is required when --split-80-20 is enabled to generate ROC plots.' + ) from exc + y_train_prob = best_model.predict_proba(X_train_processed)[:, 1] + y_test_prob = best_model.predict_proba(X_test_processed)[:, 1] + + # Compute the ROC curve and ROC area for the training set. + fpr_train, tpr_train, _ = roc_curve(y_train, y_train_prob) + roc_auc_train = auc(fpr_train, tpr_train) + + # Compute the ROC curve and ROC area for the testing set. + fpr_test, tpr_test, thresholds = roc_curve(y_test, y_test_prob) + roc_auc_test = auc(fpr_test, tpr_test) + + # Print the ROC AUC scores. + logging.info('ROC AUC score for the training set: %f', roc_auc_train) + logging.info('ROC AUC score for the testing set: %f', roc_auc_test) + + # Plot the ROC curve for the training set. + plt.figure() + plt.plot(fpr_train, tpr_train, color='blue', lw=2, label='ROC curve (area = %0.3f)' % roc_auc_train) + # plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc) + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + model_name_label = model_name.replace("_", " ") + plt.title('{} Receiver Operating Characteristic (Training Set)'.format(model_name_label)) + plt.legend(loc='lower right') + roc_plot_path = os.path.join(output_directory, model_name_fp + '_roc_curve_train.svg') + plt.savefig(roc_plot_path) + plt.close() + logging.info('Saved the ROC curve to %s', roc_plot_path) + + # Plot the ROC curve for the testing set. + plt.figure() + plt.plot(fpr_test, tpr_test, color='blue', lw=2, label='ROC curve (area = %0.3f)' % roc_auc_test) + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('{} Receiver Operating Characteristic (Testing Set)'.format(model_name_label)) + plt.legend(loc='lower right') + # Save the plot to the output directory. + roc_plot_path = os.path.join(output_directory, model_name + '_roc_curve_test.svg') + plt.savefig(roc_plot_path) + plt.close() + logging.info('Saved the ROC curve to %s', roc_plot_path) + else: + # Save the model to the output directory as a pickle file. + model_path = os.path.join(output_directory, model_name_fp + '_model.pkl') + joblib.dump(best_model, model_path) + logging.info('Saved the %s model to %s', model_name, model_path) + + logging.info('Completed training and evaluation for %s model.', model_name) + + # Run SHAP if full analysis and no leave-outs (SHAP is slow) + if not split_80_20 and no_leave_out: + logging.info('Running feature importance analysis for %s model.', model_name) + classifier = best_model.named_steps['classifier'] + + # For Random Forest, use both native importance and SHAP (with aggressive sampling) + if model_name == 'Random_Forest': + try: + import matplotlib.pyplot as plt + + # 1. 
Native Random Forest feature importance (instant) + feature_importances = classifier.feature_importances_ + feature_names = X_train.columns.tolist() + display_feature_names = get_display_feature_names(feature_names) + + importance_df = pd.DataFrame({ + 'feature': feature_names, + 'feature_display': display_feature_names, + 'importance': feature_importances + }).sort_values('importance', ascending=True) + + plt.figure(figsize=(10, 8)) + plt.bar(importance_df['feature_display'], importance_df['importance']) + plt.ylabel('Feature Importance') + plt.xlabel('Feature') + plt.title('Random Forest Feature Importance') + plt.xticks(rotation=45, ha='right') + plt.tight_layout() + importance_plot_path = os.path.join(output_directory, model_name_fp + '_feature_importance_plot.svg') + plt.savefig(importance_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logging.info('Saved Random Forest feature-importance plot to %s', importance_plot_path) + + importance_csv_path = os.path.join(output_directory, model_name_fp + '_feature_importances.csv') + importance_df.sort_values('importance', ascending=False).to_csv(importance_csv_path, index=False) + logging.info('Saved feature importances to %s', importance_csv_path) + + if ENABLE_SHAP: + # 2. SHAP analysis with aggressive sampling for efficiency + logging.info('Computing SHAP values for Random Forest (with sampling)...') + X_train_numeric = X_train.copy() + for col in X_train_numeric.columns: + if X_train_numeric[col].dtype == 'object': + X_train_numeric[col] = pd.to_numeric(X_train_numeric[col], errors='coerce') + X_train_numeric = X_train_numeric.fillna(0).astype('float64') + + # Aggressive sampling for RF SHAP: reduce from 148k to ~300 samples + explain_size = min(300, len(X_train_numeric)) + background_size = min(50, len(X_train_numeric) // 100) # ~1% of data + X_explain = shap.sample(X_train_numeric, explain_size, random_state=42) + X_background = shap.sample(X_train_numeric, background_size, random_state=42) + + logging.info('SHAP RF: explain_size=%d, background_size=%d (from %d total)', + explain_size, background_size, len(X_train_numeric)) + + # Use interventional mode for standard SHAP values (not interactions) + explainer = shap.TreeExplainer(classifier) + shap_values = explainer.shap_values(X_explain, check_additivity=False) + + logging.info('SHAP raw output type: %s, raw shape: %s', + type(shap_values), + shap_values.shape if hasattr(shap_values, 'shape') else 'N/A') + + # Handle different output formats + if isinstance(shap_values, list): + # List of arrays for each class + shap_values = shap_values[1] # Use positive class + elif len(shap_values.shape) == 3: + # 3D array: (n_samples, n_features, n_classes) + shap_values = shap_values[:, :, 1] # Select positive class + + logging.info('SHAP debug: shap_values shape=%s (final), X_explain shape=%s', + shap_values.shape, X_explain.shape) + + # Ensure X_explain is explicitly indexed by feature names + X_explain_for_plot = X_explain.reset_index(drop=True) + X_explain_display = X_explain_for_plot.rename(columns=get_display_feature_name) + + # SHAP summary plot + plt.figure(figsize=(12, 8)) + shap.summary_plot(shap_values, X_explain_display, show=False, max_display=15) + shap_plot_path = os.path.join(output_directory, model_name_fp + '_shap_summary_plot.svg') + plt.savefig(shap_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logging.info('Saved SHAP summary plot to %s', shap_plot_path) + + # SHAP bar plot (mean |SHAP|) + plt.figure(figsize=(10, 8)) + shap.summary_plot(shap_values, 
X_explain_display, plot_type='bar', show=False) + bar_plot_path = os.path.join(output_directory, model_name_fp + '_shap_importance_plot.svg') + plt.savefig(bar_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logging.info('Saved SHAP importance plot to %s', bar_plot_path) + + except Exception as exc: + logging.warning('SHAP analysis skipped for %s: %s', model_name, exc) + + # For other models, use SHAP + else: + if ENABLE_SHAP: + logging.info('Computing SHAP values for %s model...', model_name) + # Prepare numeric data for SHAP + X_train_numeric = X_train.copy() + for col in X_train_numeric.columns: + if X_train_numeric[col].dtype == 'object': + X_train_numeric[col] = pd.to_numeric(X_train_numeric[col], errors='coerce') + + X_train_numeric = X_train_numeric.fillna(0).astype('float64') + + # Bound SHAP workload to avoid OOM/core-dump on large full-model runs. + explain_size = min(5000, len(X_train_numeric)) + background_size = min(300, len(X_train_numeric)) + X_explain = shap.sample(X_train_numeric, explain_size, random_state=42) + X_background = shap.sample(X_train_numeric, background_size, random_state=42) + + logging.info( + 'SHAP sampling: explain_size=%d, background_size=%d (from %d training rows)', + len(X_explain), len(X_background), len(X_train_numeric) + ) + + try: + if model_name == 'XGBoost': + explainer = shap.TreeExplainer(classifier, feature_perturbation='tree_path_dependent') + shap_values = explainer.shap_values(X_explain) + elif model_name == 'Logistic_Regression': + explainer = shap.LinearExplainer(classifier, X_background) + shap_values = explainer.shap_values(X_explain) + else: + explainer = shap.Explainer(classifier, X_background) + shap_values = explainer(X_explain) + + # Some SHAP explainers return one array per class. For binary + # classification plots, use positive class values. + if isinstance(shap_values, list) and len(shap_values) > 1: + shap_values_to_plot = shap_values[1] + else: + shap_values_to_plot = shap_values + + X_explain_display = X_explain.rename(columns=get_display_feature_name) + + # 1. Summary plot + plt.figure(figsize=(10, 8)) + shap.summary_plot(shap_values_to_plot, X_explain_display, show=False) + shap_plot_path = os.path.join(output_directory, model_name_fp + '_shap_summary_plot.svg') + plt.savefig(shap_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logging.info('Saved the SHAP summary plot to %s', shap_plot_path) + + # 2. Bar plot showing mean absolute SHAP values (feature importance) + plt.figure(figsize=(10, 8)) + shap.summary_plot(shap_values_to_plot, X_explain_display, plot_type='bar', show=False) + bar_plot_path = os.path.join(output_directory, model_name_fp + '_shap_importance_plot.svg') + plt.savefig(bar_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logging.info('Saved the SHAP importance plot to %s', bar_plot_path) + except Exception as exc: + logging.warning('SHAP analysis failed for %s: %s. Continuing without SHAP outputs.', model_name, exc) + +if __name__ == '__main__': + # Parse the command line arguments. 
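+    # Illustrative invocation (file names and coverage values below are placeholders):
+    #   python contextscore/train_full_model.py \
+    #       --tp_hg002_grch37 tp_hg002.bed --fp_hg002_grch37 fp_hg002.bed \
+    #       --tp_visor_grch38 tp_visor.bed --fp_visor_grch38 fp_visor.bed \
+    #       --tp_na12877_grch38 tp_na12877.bed --fp_na12877_grch38 fp_na12877.bed \
+    #       --tp_na12878_grch38 tp_na12878.bed --fp_na12878_grch38 fp_na12878.bed \
+    #       --tp_na12879_grch38 tp_na12879.bed --fp_na12879_grch38 fp_na12879.bed \
+    #       --outdiranno anno_out/ --outdir model_out/ \
+    #       --annovar /path/to/annovar --annovar_db /path/to/humandb \
+    #       --leave_out none \
+    #       --sample_coverage_hg002 30 --sample_coverage_visor 30 \
+    #       --sample_coverage_na12877 30 --sample_coverage_na12878 30 --sample_coverage_na12879 30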
+ import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--tp_hg002_grch37", required=True, help="Path to the true positive BED file for HG002 in GRCh37") + parser.add_argument("--fp_hg002_grch37", required=True, help="Path to the false positive BED file for HG002 in GRCh37") + parser.add_argument("--tp_visor_grch38", required=True, help="Path to the true positive BED file for Visor in GRCh38") + parser.add_argument("--fp_visor_grch38", required=True, help="Path to the false positive BED file for Visor in GRCh38") + parser.add_argument("--tp_na12877_grch38", required=True, help="Path to the true positive BED file for NA12877 in GRCh38") + parser.add_argument("--fp_na12877_grch38", required=True, help="Path to the false positive BED file for NA12877 in GRCh38") + parser.add_argument("--tp_na12878_grch38", required=True, help="Path to the true positive BED file for NA12878 in GRCh38") + parser.add_argument("--fp_na12878_grch38", required=True, help="Path to the false positive BED file for NA12878 in GRCh38") + parser.add_argument("--tp_na12879_grch38", required=True, help="Path to the true positive BED file for NA12879 in GRCh38") + parser.add_argument("--fp_na12879_grch38", required=True, help="Path to the false positive BED file for NA12879 in GRCh38") + parser.add_argument("--outdiranno", required=True, help="Output directory for saving the ANNOVAR annotations") + parser.add_argument("--outdir", required=True, help="Output directory for saving the model") + parser.add_argument("--annovar", required=True, help="Path to ANNOVAR") + parser.add_argument("--annovar_db", required=True, help="Path to ANNOVAR database") + parser.add_argument("--leave_out", required=True, help="Which dataset to leave out for training") + parser.add_argument("--sample_coverage_hg002", type=float, required=True, help="Mean read depth coverage for HG002 sample (required)") + parser.add_argument("--sample_coverage_visor", type=float, required=True, help="Mean read depth coverage for Visor sample (required)") + parser.add_argument("--sample_coverage_na12877", type=float, required=True, help="Mean read depth coverage for NA12877 sample (required)") + parser.add_argument("--sample_coverage_na12878", type=float, required=True, help="Mean read depth coverage for NA12878 sample (required)") + parser.add_argument("--sample_coverage_na12879", type=float, required=True, help="Mean read depth coverage for NA12879 sample (required)") + parser.add_argument("--split_80_20", action='store_true', help="Whether to split the data into training and testing sets using an 80-20 split. If not specified, all the data will be used for training and testing, and cross-validation will be used to evaluate the model performance.") + parser.add_argument("--per_chr_validation", action='store_true', help="Whether to run per-chromosome cross-validation.") + args = parser.parse_args() + + # Run the program. 
+    logging.info('Training the model, split_80_20 = %s, leave_out = %s, per_chr_validation = %s', args.split_80_20, args.leave_out, args.per_chr_validation)
+    train(args.tp_hg002_grch37, args.fp_hg002_grch37, args.tp_visor_grch38, args.fp_visor_grch38, args.tp_na12877_grch38, args.fp_na12877_grch38, args.tp_na12878_grch38, args.fp_na12878_grch38, args.tp_na12879_grch38, args.fp_na12879_grch38, args.outdir, args.annovar, args.annovar_db, args.outdiranno, args.leave_out, args.split_80_20, args.per_chr_validation, args.sample_coverage_hg002, args.sample_coverage_visor, args.sample_coverage_na12877, args.sample_coverage_na12878, args.sample_coverage_na12879)
+    logging.info('done.')
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..8a741eb
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,14 @@
+name: contextscore
+channels:
+  - wglab
+  - conda-forge
+  - bioconda
+  - defaults
+dependencies:
+  - python=3.10
+  - numpy
+  - pandas
+  - scikit-learn=1.6.1 # For consistency with model training environment
+  - joblib
+  - bedtools
+  - pytest
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..27eec68
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+testpaths = tests
+python_files = test_*.py
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8bc29e3
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,35 @@
+from pathlib import Path
+from setuptools import setup, find_packages
+
+
+PROJECT_ROOT = Path(__file__).resolve().parent
+DATA_FILES = [
+    path.relative_to(PROJECT_ROOT).as_posix()
+    for path in (PROJECT_ROOT / "data").glob("*")
+    if path.is_file()
+]
+
+setup(
+    name="ContextScore",
+    version="0.1.0",
+    packages=find_packages(),
+    include_package_data=True,
+    data_files=[("contextscore/data", DATA_FILES)],
+    install_requires=[
+        "numpy",
+        "pandas",
+        "scikit-learn",
+        "joblib",
+    ],
+    extras_require={
+        "plot": [
+            "matplotlib",
+            "seaborn",
+        ]
+    },
+    entry_points={
+        "console_scripts": [
+            "contextscore=contextscore.predict:main",
+        ]
+    },
+)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..23a629f
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,7 @@
+from pathlib import Path
+import sys
+
+
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
diff --git a/tests/fixtures/example.vcf b/tests/fixtures/example.vcf
new file mode 100644
index 0000000..e69de29
diff --git a/tests/fixtures/output.vcf.gz b/tests/fixtures/output.vcf.gz
new file mode 100644
index 0000000..02bf563
Binary files /dev/null and b/tests/fixtures/output.vcf.gz differ
diff --git a/tests/fixtures/output.vcf.gz.tbi b/tests/fixtures/output.vcf.gz.tbi
new file mode 100644
index 0000000..ef05888
Binary files /dev/null and b/tests/fixtures/output.vcf.gz.tbi differ
diff --git a/tests/test_extract_features_helpers.py b/tests/test_extract_features_helpers.py
new file mode 100644
index 0000000..a4e6af6
--- /dev/null
+++ b/tests/test_extract_features_helpers.py
@@ -0,0 +1,67 @@
+import numpy as np
+import pandas as pd
+
+from contextscore import extract_features as extract_features_module
+from contextscore.extract_features import bed_to_annovar_input, normalize_chrom_label
+
+
+def test_normalize_chrom_label_handles_none_like_values():
+    assert normalize_chrom_label(None) is None
+    assert normalize_chrom_label(np.nan) is None
+    assert normalize_chrom_label(" ") is None
+
+
+def test_normalize_chrom_label_normalizes_prefix_and_case():
+    assert normalize_chrom_label("chr1") == "1"
+    assert normalize_chrom_label("1") == "1"
+    assert normalize_chrom_label("chrx") == "X"
+    assert normalize_chrom_label("x") == "X"
+
+
+def test_bed_to_annovar_input_preserves_first_record_from_headerless_bed(tmp_path):
+    bed_path = tmp_path / "input.bed"
+    bed_path.write_text(
+        "chr1\t100\t150\tINS\t50\t./.\t10\t0\tCIGARINS\t3\t0\t0\t0\n"
+        "chr2\t200\t260\tDEL\t-60\t./.\t11\t0\tCIGARDEL\t4\t0\t0\t1\n",
+        encoding="utf-8",
+    )
+
+    output_path = bed_to_annovar_input(str(bed_path))
+    annovar_df = pd.read_csv(
+        output_path,
+        sep='\t',
+        header=None,
+        names=['chrom', 'start', 'end', 'ref', 'alt'],
+    )
+
+    assert annovar_df[['chrom', 'start', 'end']].values.tolist() == [
+        ['chr1', 100, 150],
+        ['chr2', 200, 260],
+    ]
+
+
+def test_extract_features_preserves_id_zero_from_headerless_prediction_bed(tmp_path, monkeypatch):
+    bed_path = tmp_path / "input.bed"
+    bed_path.write_text(
+        "chr1\t100\t150\tINS\t50\t./.\t10\t0\tCIGARINS\t3\t0\t0\t0\n"
+        "chr1\t300\t370\tDEL\t-70\t./.\t11\t0\tCIGARDEL\t4\t0\t0\t1\n",
+        encoding="utf-8",
+    )
+
+    def fake_add_annotations(data, input_bed, annovar_path, db_path, anno_outdir, buildversion='hg38', training_format=False):
+        annotated = data.copy()
+        annotated['telomere'] = False
+        annotated['centromere'] = False
+        return annotated
+
+    monkeypatch.setattr(extract_features_module, 'add_annotations', fake_add_annotations)
+
+    feature_df = extract_features_module.extract_features(
+        str(bed_path),
+        annovar_path='unused',
+        db_path='unused',
+        outdiranno=str(tmp_path),
+        sample_coverage=30,
+    )
+
+    assert feature_df['id'].tolist() == [0, 1]
diff --git a/tests/test_predict_helpers.py b/tests/test_predict_helpers.py
new file mode 100644
index 0000000..4ea26a3
--- /dev/null
+++ b/tests/test_predict_helpers.py
@@ -0,0 +1,101 @@
+import importlib
+from pathlib import Path
+
+import pytest
+
+from contextscore.predict import (
+    DEFAULT_MODEL_INSTALL_PATH,
+    DEFAULT_MODEL_ENV_VAR,
+    resolve_model_path,
+    resolve_annovar_paths,
+    try_import_plotting_libs,
+    validate_annovar_paths,
+)
+
+
+def test_example_vcf_fixture_exists():
+    fixture_path = Path(__file__).parent / "fixtures" / "example.vcf"
+    assert fixture_path.exists()
+
+
+def test_resolve_annovar_paths_prefers_cli_over_env(monkeypatch):
+    monkeypatch.setenv("ANNOVAR_PATH", "/env/annovar")
+    monkeypatch.setenv("ANNOVAR_DB_PATH", "/env/db")
+
+    annovar_path, annovar_db = resolve_annovar_paths("/cli/annovar", "/cli/db")
+
+    assert annovar_path == "/cli/annovar"
+    assert annovar_db == "/cli/db"
+
+
+def test_resolve_annovar_paths_uses_env_when_cli_missing(monkeypatch):
+    monkeypatch.setenv("ANNOVAR_PATH", "/env/annovar")
+    monkeypatch.setenv("ANNOVAR_DB_PATH", "/env/db")
+
+    annovar_path, annovar_db = resolve_annovar_paths(None, None)
+
+    assert annovar_path == "/env/annovar"
+    assert annovar_db == "/env/db"
+
+
+def test_resolve_model_path_prefers_cli_over_env(monkeypatch):
+    monkeypatch.setenv(DEFAULT_MODEL_ENV_VAR, '/env/model.pkl')
+
+    resolved, source = resolve_model_path('/cli/model.pkl')
+
+    assert resolved == '/cli/model.pkl'
+    assert source == 'cli'
+
+
+def test_resolve_model_path_uses_env_when_cli_missing(monkeypatch):
+    monkeypatch.setenv(DEFAULT_MODEL_ENV_VAR, '/env/model.pkl')
+
+    resolved, source = resolve_model_path(None)
+
+    assert resolved == '/env/model.pkl'
+    assert source == 'env'
+
+
+def test_resolve_model_path_uses_default_when_cli_and_env_missing(monkeypatch):
+    monkeypatch.delenv(DEFAULT_MODEL_ENV_VAR, raising=False)
+
+    resolved, source = resolve_model_path(None)
+
+    assert resolved == DEFAULT_MODEL_INSTALL_PATH
+    assert source == 'default'
+
+
+def test_validate_annovar_paths_requires_path_and_db():
+    with pytest.raises(ValueError, match="ANNOVAR path is required"):
+        validate_annovar_paths(None, "/db")
+
+    with pytest.raises(ValueError, match="ANNOVAR database path is required"):
+        validate_annovar_paths("/annovar", None)
+
+
+def test_validate_annovar_paths_accepts_valid_layout(tmp_path):
+    annovar_dir = tmp_path / "annovar"
+    db_dir = tmp_path / "humandb"
+    annovar_dir.mkdir()
+    db_dir.mkdir()
+
+    (annovar_dir / "annotate_variation.pl").write_text("#!/usr/bin/env perl\n", encoding="utf-8")
+    (annovar_dir / "table_annovar.pl").write_text("#!/usr/bin/env perl\n", encoding="utf-8")
+
+    validate_annovar_paths(str(annovar_dir), str(db_dir))
+
+
+def test_try_import_plotting_libs_graceful_when_missing(monkeypatch):
+    real_import_module = importlib.import_module
+
+    def fake_import_module(name):
+        if name in {"matplotlib.pyplot", "seaborn"}:
+            raise ImportError("not installed")
+        return real_import_module(name)
+
+    monkeypatch.setattr(importlib, "import_module", fake_import_module)
+
+    plt, sns = try_import_plotting_libs()
+
+    assert plt is None
+    assert sns is None
diff --git a/tests/test_predict_io.py b/tests/test_predict_io.py
new file mode 100644
index 0000000..ef6153d
--- /dev/null
+++ b/tests/test_predict_io.py
@@ -0,0 +1,156 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from contextscore import predict
+
+
+FIXTURE_VCF_GZ = Path(__file__).parent / 'fixtures' / 'output.vcf.gz'
+TEST_OUTPUT_DIR = Path(__file__).parent / 'output'
+FILTERED_VCF = TEST_OUTPUT_DIR / 'output_filtered.vcf'
+REMOVED_VCF = TEST_OUTPUT_DIR / 'removed_svs.vcf'
+PREDICTIONS_TSV = TEST_OUTPUT_DIR / 'predictions.tsv'
+FILTERED_IDS = TEST_OUTPUT_DIR / 'filtered_ids.txt'
+
+
+class DummyModel:
+    def predict_proba(self, feature_df):
+        length_signal = feature_df['length_signal'].to_numpy(dtype=float)
+        probabilities = np.where(length_signal >= 200, 0.95, 0.05)
+        return np.column_stack([1.0 - probabilities, probabilities])
+
+
+def _fake_extract_features(bed_file, annovar_path, annovar_db_path, anno_outdir, buildver, sample_coverage=None):
+    bed_df = pd.read_csv(
+        bed_file,
+        sep='\t',
+        header=None,
+        names=['chrom', 'start', 'end', 'sv_type_str', 'sv_length', 'gt', 'dp', 'hmm', 'aln', 'cluster', 'cn', 'alnoffset', 'id'],
+    )
+    sv_length = pd.to_numeric(bed_df['sv_length'], errors='coerce').fillna(0)
+    read_depth = pd.to_numeric(bed_df['dp'], errors='coerce').fillna(0)
+
+    return pd.DataFrame({
+        'id': bed_df['id'].astype(int),
+        'chrom': bed_df['chrom'].astype(str),
+        'start': pd.to_numeric(bed_df['start'], errors='coerce').fillna(0).astype(int),
+        'end': pd.to_numeric(bed_df['end'], errors='coerce').fillna(0).astype(int),
+        'sv_type_str': bed_df['sv_type_str'].astype(str),
+        'sv_length': sv_length.astype(int),
+        'length_signal': sv_length.abs().astype(float),
+        'depth_signal': read_depth.astype(float),
+    })
+
+
+def _count_vcf_records(path):
+    with open(path, 'r', encoding='utf-8') as handle:
+        return sum(1 for line in handle if line.strip() and not line.startswith('#'))
+
+
+def _prepare_output_dir():
+    TEST_OUTPUT_DIR.mkdir(exist_ok=True)
+    for path in [FILTERED_VCF, REMOVED_VCF, PREDICTIONS_TSV, FILTERED_IDS, FIXTURE_VCF_GZ.with_suffix('.bed')]:
+        if path.exists():
+            path.unlink()
+
+
+def test_open_vcf_text_gz():
+    with predict.open_vcf_text(FIXTURE_VCF_GZ) as handle:
+        lines = [line for line in handle]
+    assert len(lines) > 0
+    assert any(line.startswith('#') for line in lines)
+
+
+def test_open_vcf_text_plain(tmp_path):
+    vcf_path = tmp_path / 'test.vcf'
+    vcf_path.write_text(
+        '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\n'
+        'chr1\t100\t.\tA\tT\t.\tPASS\tSVTYPE=DEL;END=200\n',
+        encoding='utf-8',
+    )
+    with predict.open_vcf_text(str(vcf_path)) as handle:
+        lines = [line for line in handle]
+    assert lines[0].startswith('#')
+    assert 'SVTYPE=DEL' in lines[1]
+
+
+def test_vcf_gz_header_and_records():
+    with predict.open_vcf_text(FIXTURE_VCF_GZ) as handle:
+        lines = [line.rstrip() for line in handle]
+    header_lines = [line for line in lines if line.startswith('#')]
+    variant_lines = [line for line in lines if line and not line.startswith('#')]
+
+    assert any(line.startswith('##fileformat=VCFv4.2') for line in header_lines)
+    assert any(line.startswith('#CHROM') for line in header_lines)
+    assert len(variant_lines) > 0
+    assert 'SVTYPE=' in variant_lines[0].split('\t')[7]
+    assert 'END=' in variant_lines[0].split('\t')[7]
+
+
+def test_vcf_gz_svtype_counts():
+    svtypes = []
+    with predict.open_vcf_text(FIXTURE_VCF_GZ) as handle:
+        for line in handle:
+            if line.startswith('#'):
+                continue
+            info = line.rstrip().split('\t')[7]
+            for entry in info.split(';'):
+                if entry.startswith('SVTYPE='):
+                    svtypes.append(entry.split('=')[1])
+                    break
+
+    assert len(svtypes) > 0
+    assert 'INS' in svtypes
+    assert 'DEL' in svtypes
+
+
+def test_score_generates_outputs_in_tests_output(monkeypatch):
+    _prepare_output_dir()
+    input_bed_path = FIXTURE_VCF_GZ.with_suffix('.bed')
+    monkeypatch.setattr(predict, 'extract_features', _fake_extract_features)
+    monkeypatch.setattr(predict.joblib, 'load', lambda model_path: DummyModel())
+
+    summary = predict.score(
+        model='tests/fixtures/dummy_model.pkl',
+        input_vcf=str(FIXTURE_VCF_GZ),
+        output_vcf=str(FILTERED_VCF),
+        threshold=0.2,
+        sample_coverage=30,
+        large_cutoff=10000,
+        annovar_path='unused',
+        annovar_db_path='unused',
+    )
+
+    assert summary['output_vcf'] == str(FILTERED_VCF)
+    assert summary['removed_vcf'] == str(REMOVED_VCF)
+    assert summary['predictions_tsv'] == str(PREDICTIONS_TSV)
+    assert FILTERED_VCF.exists()
+    assert REMOVED_VCF.exists()
+    assert PREDICTIONS_TSV.exists()
+    assert FILTERED_IDS.exists()
+    assert not input_bed_path.exists()
+
+    predictions_df = pd.read_csv(PREDICTIONS_TSV, sep='\t')
+    kept_records = _count_vcf_records(FILTERED_VCF)
+    removed_records = _count_vcf_records(REMOVED_VCF)
+
+    assert not predictions_df.empty
+    assert set(['id', 'chrom', 'start', 'end', 'sv_type_str', 'sv_length', 'sv_length_abs', 'confidence_score']).issubset(predictions_df.columns)
+    assert predictions_df['confidence_score'].between(0, 1).all()
+    assert predictions_df['confidence_score'].max() == 0.95
+    assert predictions_df['confidence_score'].min() == 0.05
+    assert kept_records > 0
+    assert removed_records > 0
+    assert summary['total_records'] == len(predictions_df)
+    assert summary['passed_records'] == kept_records
+    assert summary['filtered_records'] == removed_records
+    assert summary['passed_records'] + summary['filtered_records'] == summary['total_records']
+
+
+def test_generated_predictions_include_multiple_svtypes():
+    assert PREDICTIONS_TSV.exists()
+    predictions_df = pd.read_csv(PREDICTIONS_TSV, sep='\t')
+
+    assert predictions_df['sv_type_str'].nunique() >= 2
+    assert {'DEL', 'INS'}.issubset(set(predictions_df['sv_type_str'].unique()))