From 0ce818e2c0e1ca5156e9753103bcf694985b9781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20L=C3=B3pez?= Date: Thu, 26 Mar 2026 16:25:56 +0100 Subject: [PATCH 1/7] docs: add heteroplasmy FAQ --- docs/faq.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/faq.md b/docs/faq.md index bca7372..33fd8ed 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -174,6 +174,19 @@ Use `afquery info --db ./db/` to list all registered codes before running querie --- +## Is heteroplasmy taken into account when calculating allele frequency in mitochondrial variants? + +Not explicitly. Proper quantification of heteroplasmy requires a specialized approach, similar to somatic variant analysis (e.g., in cancer), where a genomic position may contain multiple subpopulations of variants with different allele fractions. This type of modeling is not part of standard germline variant calling. + +In tools such as GATK operating in germline mode, genotypes are assigned based on the ploidy defined for the region, without representing a continuous spectrum of allele fractions: + +- If the region is treated as haploid (as is typical for mitochondrial DNA), the caller reports the majority allele. If the signal is ambiguous, the position may be marked as uncertain. +- If modeled as diploid, the caller fits genotypes into discrete states (e.g., 0/1 or 1/1). Allele fractions near 50% are typically classified as heterozygous. + +As a result, intermediate heteroplasmy levels (such as 20%) are not explicitly represented. Instead, they are forced into one of these discrete genotype states or lost as uncertainty. + +Therefore, this limitation arises from the variant calling step. The application operates on genotypes that have already been discretized according to ploidy and does not model heteroplasmy as a continuous variable. + ## Common Pitfalls ### What if AN is very low? 
From b0f1e843ca081721179e379aa012a4ba7d6e503f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20L=C3=B3pez?= Date: Wed, 6 May 2026 16:40:29 +0200 Subject: [PATCH 2/7] feat: add N_NO_COVERAGE field with phase 1/2 coverage-evidence filters Distinguish between true hom-ref and uncertain coverage for non-carrier WES samples. Two opt-in mechanisms (combinable, fully backward-compatible) move samples from N_HOM_REF to a new N_NO_COVERAGE field while keeping them in eligible/AN. Phase 1 (query-time, no schema change): --min-pass K per-WES-tech gate on PASS carriers (het|hom) --min-observed K per-WES-tech gate on any-VCF entries (het|hom|fail) Phase 2 (build-time, schema_version 3.0): --min-dp / --min-gq / --min-qual carrier quality thresholds at create-db --min-covered K minimum quality carriers per WES tech --min-quality-evidence K query-time companion (errors on legacy DBs) Stores two new Parquet columns (filtered_bitmap, quality_pass_bitmap) and the chosen thresholds under coverage_filter in manifest.json. update-db recomputes filtered_bitmap on add-samples; compact preserves both columns. ingest reads FORMAT/DP, FORMAT/GQ, and QUAL (None when absent). New invariant: N_HET + N_HOM_ALT + N_HOM_REF + N_FAIL + N_NO_COVERAGE = n_eligible. WGS samples and carriers (het/hom/fail) are never reclassified. Affected commands: query, variant-info (genotype='no_coverage'), annotate (AFQUERY_N_NO_COVERAGE INFO), dump (N_NO_COVERAGE column). resources/normalize_vcf.sh now preserves FORMAT/DP and FORMAT/GQ. 
--- resources/normalize_vcf.sh | 2 +- src/afquery/annotate.py | 68 ++++--- src/afquery/cli.py | 48 ++++- src/afquery/database.py | 48 ++++- src/afquery/dump.py | 46 +++-- src/afquery/models.py | 12 +- src/afquery/preprocess/__init__.py | 37 +++- src/afquery/preprocess/build.py | 180 +++++++++++++++--- src/afquery/preprocess/compact.py | 37 +++- src/afquery/preprocess/ingest.py | 47 +++++ src/afquery/preprocess/update.py | 176 +++++++++++++++--- src/afquery/query.py | 228 +++++++++++++++++++---- tests/conftest.py | 29 +-- tests/test_haploid_stats.py | 7 +- tests/test_no_coverage.py | 284 +++++++++++++++++++++++++++++ tests/test_preprocess.py | 6 +- 16 files changed, 1099 insertions(+), 156 deletions(-) create mode 100644 tests/test_no_coverage.py diff --git a/resources/normalize_vcf.sh b/resources/normalize_vcf.sh index 0e35989..f3a046c 100755 --- a/resources/normalize_vcf.sh +++ b/resources/normalize_vcf.sh @@ -83,7 +83,7 @@ TARGETS="${TARGETS%,}" # Normalize VCF # ----------------------------- bcftools annotate \ - -x INFO,^FORMAT/GT \ + -x INFO,^FORMAT/GT,^FORMAT/DP,^FORMAT/GQ \ --force \ --rename-chrs "${CHR_MAP}" \ "${VCF}" | \ diff --git a/src/afquery/annotate.py b/src/afquery/annotate.py index 67118c2..1ab0d0d 100644 --- a/src/afquery/annotate.py +++ b/src/afquery/annotate.py @@ -17,8 +17,8 @@ def _compute_chunk_annotations( bucket_id: int, records: list[tuple[int, str, list[str]]], sf: SampleFilter, -) -> dict[tuple[int, str, str], tuple[int, int, bool, int, int, int, int]]: - """Compute (AC, AN, in_parquet, N_FAIL, N_HET, N_HOM_ALT, N_HOM_REF) for each (pos, ref, alt). +) -> dict[tuple[int, str, str], tuple[int, int, bool, int, int, int, int, int]]: + """Compute per-variant stats: (AC, AN, in_parquet, N_FAIL, N_HET, N_HOM_ALT, N_HOM_REF, N_NO_COVERAGE). No cyvcf2 dependency — safe to run in a subprocess worker. 
""" @@ -44,10 +44,13 @@ def _compute_chunk_annotations( valid_positions = [p for p in unique_positions if pos_data[p][1] > 0] - variant_data: dict[tuple[int, str, str], tuple[bytes, bytes, bytes]] = {} + # Bitmap suffix length: 3 (legacy) or 5 (Phase 2) + n_bitmap_cols = 5 if engine._has_coverage_data else 3 + variant_data: dict[tuple[int, str, str], tuple] = {} _db = Path(db_path) bucket_start = bucket_id * 1_000_000 bucket_end = (bucket_id + 1) * 1_000_000 - 1 + cols = engine._select_cols(with_pos=True) if chrom in engine._partitioned_chroms: parquet_file = _db / "variants" / chrom / f"bucket_{bucket_id}.parquet" @@ -55,31 +58,31 @@ def _compute_chunk_annotations( con = duckdb.connect() placeholders = ", ".join("?" * len(valid_positions)) rows = con.execute( - f"SELECT pos, ref, alt, het_bitmap, hom_bitmap, fail_bitmap" + f"SELECT {cols}" f" FROM read_parquet('{parquet_file}') WHERE pos IN ({placeholders})", valid_positions, ).fetchall() con.close() for row in rows: - pos, ref, alt, het_bytes, hom_bytes, fail_bytes = row - variant_data[(pos, ref, alt)] = (bytes(het_bytes), bytes(hom_bytes), bytes(fail_bytes)) + pos, ref, alt = row[0], row[1], row[2] + variant_data[(pos, ref, alt)] = tuple(bytes(b) for b in row[3:3 + n_bitmap_cols]) else: parquet_file = _db / "variants" / f"{chrom}.parquet" if valid_positions and parquet_file.exists(): con = duckdb.connect() rows = con.execute( - f"SELECT pos, ref, alt, het_bitmap, hom_bitmap, fail_bitmap" + f"SELECT {cols}" f" FROM read_parquet('{parquet_file}') WHERE pos BETWEEN ? 
AND ?", [bucket_start, bucket_end], ).fetchall() con.close() valid_pos_set = set(valid_positions) for row in rows: - pos, ref, alt, het_bytes, hom_bytes, fail_bytes = row + pos, ref, alt = row[0], row[1], row[2] if pos in valid_pos_set: - variant_data[(pos, ref, alt)] = (bytes(het_bytes), bytes(hom_bytes), bytes(fail_bytes)) + variant_data[(pos, ref, alt)] = tuple(bytes(b) for b in row[3:3 + n_bitmap_cols]) - result: dict[tuple[int, str, str], tuple[int, int, bool, int, int, int, int]] = {} + result: dict[tuple[int, str, str], tuple[int, int, bool, int, int, int, int, int]] = {} for pos, ref, alts in records: eligible, AN = pos_data[pos] for alt in alts: @@ -87,13 +90,11 @@ def _compute_chunk_annotations( if key in result: continue # dedup if AN == 0: - result[key] = (0, 0, False, 0, 0, 0, 0) + result[key] = (0, 0, False, 0, 0, 0, 0, 0) elif key in variant_data: - het_bytes, hom_bytes, fail_bytes = variant_data[key] - het_bm = deserialize(het_bytes) - hom_bm = deserialize(hom_bytes) - N_HET = len(het_bm & eligible) - N_HOM_ALT = len(hom_bm & eligible) + het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm = engine._unpack_bitmaps( + variant_data[key] + ) haploid_elig, diploid_elig = split_ploidy( eligible, male_bm, female_bm, chrom, pos, engine._genome_build ) @@ -102,11 +103,26 @@ def _compute_chunk_annotations( AC = (len((het_elig | hom_elig) & haploid_elig) + len(het_elig & diploid_elig) + 2 * len(hom_elig & diploid_elig)) - N_FAIL: int = len(deserialize(fail_bytes) & eligible) - N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL - result[key] = (AC, AN, True, N_FAIL, N_HET, N_HOM_ALT, N_HOM_REF) + N_HET = len(het_elig & diploid_elig) + N_HOM_ALT = ( + len(hom_elig & diploid_elig) + len((het_elig | hom_elig) & haploid_elig) + ) + N_FAIL: int = len(fail_bm & eligible) + no_cov_bm = engine._compute_no_coverage_bm( + eligible, het_bm, hom_bm, fail_bm, + sf.min_pass, sf.min_observed, + filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + 
min_quality_evidence=sf.min_quality_evidence, + ) + N_NO_COVERAGE = len(no_cov_bm) + N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL - N_NO_COVERAGE + result[key] = (AC, AN, True, N_FAIL, N_HET, N_HOM_ALT, N_HOM_REF, N_NO_COVERAGE) else: - result[key] = (0, AN, False, 0, 0, 0, len(eligible)) + # Position covered (AN>0) but variant not in Parquet → assume hom-ref + # for all eligible samples. Phase 1/2 filters do not apply because + # there are no carriers at all to evaluate against. + result[key] = (0, AN, False, 0, 0, 0, len(eligible), 0) return result @@ -159,6 +175,11 @@ def annotate_vcf( "ID": "AFQUERY_N_FAIL", "Number": "1", "Type": "Integer", "Description": "Eligible samples with FILTER!=PASS for this variant", }) + vcf.add_info_to_header({ + "ID": "AFQUERY_N_NO_COVERAGE", "Number": "A", "Type": "Integer", + "Description": "Count of eligible samples whose tech lacks evidence at this " + "position (excluded from N_HOM_REF) per alt allele", + }) writer = cyvcf2.Writer(output_vcf, vcf) for variant in vcf: @@ -232,13 +253,14 @@ def annotate_vcf( n_het_list = [] n_hom_alt_list = [] n_hom_ref_list = [] + n_no_cov_list = [] any_found = False fail_count_total = 0 for alt in variant.ALT: key = (pos, variant.REF, alt) - AC, _, in_parquet, n_fail, n_het, n_hom_alt, n_hom_ref = ann.get( - key, (0, AN, False, 0, 0, 0, 0) + AC, _, in_parquet, n_fail, n_het, n_hom_alt, n_hom_ref, n_no_cov = ann.get( + key, (0, AN, False, 0, 0, 0, 0, 0) ) if in_parquet: any_found = True @@ -247,6 +269,7 @@ def annotate_vcf( n_het_list.append(n_het) n_hom_alt_list.append(n_hom_alt) n_hom_ref_list.append(n_hom_ref) + n_no_cov_list.append(n_no_cov) fail_count_total += n_fail # cyvcf2 Number=A fields must be set as comma-separated strings @@ -257,6 +280,7 @@ def annotate_vcf( variant.INFO["AFQUERY_N_HOM_ALT"] = ",".join(str(v) for v in n_hom_alt_list) variant.INFO["AFQUERY_N_HOM_REF"] = ",".join(str(v) for v in n_hom_ref_list) variant.INFO["AFQUERY_N_FAIL"] = fail_count_total + 
variant.INFO["AFQUERY_N_NO_COVERAGE"] = ",".join(str(v) for v in n_no_cov_list) if any_found: stats["n_annotated"] += 1 writer.write_record(variant) diff --git a/src/afquery/cli.py b/src/afquery/cli.py index 18450a7..b25f499 100644 --- a/src/afquery/cli.py +++ b/src/afquery/cli.py @@ -182,8 +182,12 @@ def cli(): @click.option("--alt", default=None, help="Filter to specific alternate allele (only for --locus).") @click.option("--tech", multiple=True, help="Technology filter to include. Repeatable; comma-separated or multiple flags. Use ^ prefix to exclude. (default: include all samples)") @click.option("--format", "fmt", default="text", type=click.Choice(["text", "json", "tsv"]), help="Output format. Options: text, json, tsv. (default: text)") +@click.option("--min-pass", default=0, type=int, help="Min PASS carriers (het|hom) per WES tech for hom-ref to be assumed. Non-carriers move to N_NO_COVERAGE if tech falls below the threshold. (default: 0 = disabled)") +@click.option("--min-observed", default=0, type=int, help="Min any-VCF entries (het|hom|fail) per WES tech for hom-ref to be assumed. (default: 0 = disabled)") +@click.option("--min-quality-evidence", default=0, type=int, help="Min quality-passing carriers per WES tech (Phase 2 DBs only — created with --min-dp/--min-gq). (default: 0 = disabled)") @click.option("--no-warn", is_flag=True, default=False, help="Suppress AfqueryWarning messages.") -def query(db, locus, region, from_file, phenotype, sex, ref, alt, tech, fmt, no_warn): +def query(db, locus, region, from_file, phenotype, sex, ref, alt, tech, fmt, + min_pass, min_observed, min_quality_evidence, no_warn): """Query allele frequencies. 
Exactly one of --locus, --region, or --from-file must be provided: @@ -212,18 +216,24 @@ def query(db, locus, region, from_file, phenotype, sex, ref, alt, tech, fmt, no_ chrom=chrom, pos=pos, phenotype=_expand_tokens(phenotype), sex=sex, ref=ref, alt=alt, tech=_expand_tokens(tech), + min_pass=min_pass, min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) elif region is not None: chrom, start, end = _parse_region(region) results = database.query_region( chrom=chrom, start=start, end=end, phenotype=_expand_tokens(phenotype), sex=sex, tech=_expand_tokens(tech), + min_pass=min_pass, min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) else: variants = _parse_variants_file(from_file) results = database.query_batch_multi( variants=variants, phenotype=_expand_tokens(phenotype), sex=sex, tech=_expand_tokens(tech), + min_pass=min_pass, min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) _print_results(results, fmt) @@ -238,8 +248,12 @@ def query(db, locus, region, from_file, phenotype, sex, ref, alt, tech, fmt, no_ @click.option("--sex", default="both", type=click.Choice(["male", "female", "both"]), help="Restrict to specified sex. Options: male, female, both. (default: both)") @click.option("--tech", multiple=True, help="Technology filter. Repeatable; comma-separated or multiple flags. Use ^ prefix to exclude. (default: include all samples)") @click.option("--format", "fmt", default="text", type=click.Choice(["text", "json", "tsv"]), help="Output format. Options: text, json, tsv. (default: text)") +@click.option("--min-pass", default=0, type=int, help="Min PASS carriers per WES tech for hom-ref to be assumed (samples below threshold show as 'no_coverage' genotype). 
(default: 0 = disabled)") +@click.option("--min-observed", default=0, type=int, help="Min any-VCF entries per WES tech (default: 0 = disabled).") +@click.option("--min-quality-evidence", default=0, type=int, help="Min quality-passing carriers per WES tech (Phase 2 DBs only). (default: 0 = disabled)") @click.option("--no-warn", is_flag=True, default=False, help="Suppress AfqueryWarning messages.") -def variant_info_cmd(db, locus, ref, alt, phenotype, sex, tech, fmt, no_warn): +def variant_info_cmd(db, locus, ref, alt, phenotype, sex, tech, fmt, + min_pass, min_observed, min_quality_evidence, no_warn): """List samples carrying a specific variant, with their metadata. Returns one row per carrier sample with genotype (het/hom/alt) and FILTER @@ -264,6 +278,8 @@ def variant_info_cmd(db, locus, ref, alt, phenotype, sex, tech, fmt, no_warn): ref=ref, alt=alt, phenotype=_expand_tokens(phenotype), sex=sex, tech=_expand_tokens(tech), + min_pass=min_pass, min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) _variant_key = (chrom, pos, ref or ".", alt or ".") @@ -280,8 +296,12 @@ def variant_info_cmd(db, locus, ref, alt, phenotype, sex, tech, fmt, no_warn): @click.option("--threads", default=None, type=int, help="Number of worker threads for parallel annotation. (default: all available CPU cores)") @click.option("--verbose", "-v", is_flag=True, help="Verbose output with per-item progress. 
(default: false)") +@click.option("--min-pass", default=0, type=int, help="Min PASS carriers per WES tech for hom-ref to be assumed (default: 0).") +@click.option("--min-observed", default=0, type=int, help="Min any-VCF entries per WES tech (default: 0).") +@click.option("--min-quality-evidence", default=0, type=int, help="Min quality-passing carriers per WES tech (Phase 2 DBs only) (default: 0).") @click.option("--no-warn", is_flag=True, default=False, help="Suppress AfqueryWarning messages.") -def annotate(db, input_vcf, output_vcf, phenotype, sex, tech, threads, verbose, no_warn): +def annotate(db, input_vcf, output_vcf, phenotype, sex, tech, threads, verbose, + min_pass, min_observed, min_quality_evidence, no_warn): """Annotate a VCF with allele frequency INFO fields. The following INFO fields are added to each variant: @@ -305,6 +325,8 @@ def annotate(db, input_vcf, output_vcf, phenotype, sex, tech, threads, verbose, input_vcf, output_vcf, phenotype=_expand_tokens(phenotype), sex=sex, tech=_expand_tokens(tech), n_workers=threads, + min_pass=min_pass, min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) click.echo( f"Annotated {stats['n_annotated']} variants " @@ -328,8 +350,12 @@ def annotate(db, input_vcf, output_vcf, phenotype, sex, tech, threads, verbose, @click.option("--threads", default=None, type=int, help="Number of worker threads for parallel export. (default: all available CPU cores)") @click.option("--all-variants", is_flag=True, help="Include variants with AC=0 (covered but not observed). WARNING: may produce very large output on whole-genome databases. (default: false)") @click.option("--verbose", "-v", is_flag=True, help="Verbose output with per-item progress. 
(default: false)") +@click.option("--min-pass", default=0, type=int, help="Min PASS carriers per WES tech for hom-ref to be assumed (default: 0).") +@click.option("--min-observed", default=0, type=int, help="Min any-VCF entries per WES tech (default: 0).") +@click.option("--min-quality-evidence", default=0, type=int, help="Min quality-passing carriers per WES tech (Phase 2 DBs only) (default: 0).") def dump(db, output, chrom, start, end, phenotype, sex, tech, - by_sex, by_tech, by_phenotype, all_groups, threads, all_variants, verbose): + by_sex, by_tech, by_phenotype, all_groups, threads, all_variants, verbose, + min_pass, min_observed, min_quality_evidence): """Export allele frequencies to CSV. By default only variants with AC > 0 are exported. Use --all-variants to @@ -360,6 +386,9 @@ def dump(db, output, chrom, start, end, phenotype, sex, tech, end=end, n_workers=threads, include_ac_zero=all_variants, + min_pass=min_pass, + min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) click.echo( f"{stats['n_rows']} row(s) exported from {stats['n_buckets']} bucket(s)" @@ -483,8 +512,13 @@ def version_set(db, new_version): @click.option("--bed-dir", default=None, help="Directory containing BED files for WES technologies.") @click.option("--force", is_flag=True, default=False, help="Delete any partial results and restart from scratch. (default: False)") @click.option("--db-version", "db_version", default="1.0", help="Version label for this database. (default: 1.0)") +@click.option("--min-dp", default=0, type=int, help="Phase 2: minimum FORMAT/DP for a carrier to count as quality evidence. (default: 0 = disabled)") +@click.option("--min-gq", default=0, type=int, help="Phase 2: minimum FORMAT/GQ for a carrier to count as quality evidence. (default: 0 = disabled)") +@click.option("--min-qual", default=0.0, type=float, help="Phase 2: minimum VCF QUAL for a carrier to count as quality evidence. 
(default: 0 = disabled)") +@click.option("--min-covered", default=0, type=int, help="Phase 2: minimum quality-passing carriers per WES tech for hom-ref to be assumed. Triggers Phase 2 storage when > 0. (default: 0 = disabled)") @click.option("--verbose", "-v", is_flag=True, help="Verbose output with per-item progress. (default: false)") -def create_db_command(manifest, output_dir, genome_build, threads: int | None, build_threads: int | None, build_memory: str, tmp_dir, bed_dir, force, db_version, verbose): +def create_db_command(manifest, output_dir, genome_build, threads: int | None, build_threads: int | None, build_memory: str, tmp_dir, bed_dir, force, db_version, + min_dp, min_gq, min_qual, min_covered, verbose): """Create a new query database from a VCF manifest.""" _configure_logging(verbose) from .preprocess import run_preprocess @@ -502,6 +536,10 @@ def create_db_command(manifest, output_dir, genome_build, threads: int | None, b tmp_dir=tmp_dir, force=force, db_version=db_version, + min_dp=min_dp, + min_gq=min_gq, + min_qual=min_qual, + min_covered=min_covered, ) click.echo(f"Database written to {output_dir}") except (ManifestError, IngestError) as e: diff --git a/src/afquery/database.py b/src/afquery/database.py index 12f2c95..88fb58b 100644 --- a/src/afquery/database.py +++ b/src/afquery/database.py @@ -17,11 +17,17 @@ def _reload(self) -> None: self._engine = QueryEngine(str(self._path)) self._manifest = json.loads((self._path / "manifest.json").read_text()) - def _make_filter(self, phenotype, sex, tech=None) -> SampleFilter: + def _make_filter( + self, phenotype, sex, tech=None, + min_pass: int = 0, min_observed: int = 0, min_quality_evidence: int = 0, + ) -> SampleFilter: return SampleFilter.parse( phenotype_tokens=phenotype or [], tech_tokens=tech or [], sex=sex, + min_pass=min_pass, + min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) def query( @@ -33,8 +39,11 @@ def query( ref: str | None = None, alt: str | None = None, tech: 
list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) params = QueryParams(chrom=chrom, pos=pos, filter=sf, ref=ref, alt=alt) return self._engine.query(params) @@ -45,8 +54,11 @@ def query_batch( phenotype: list[str] | None = None, sex: str = "both", tech: list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) return self._engine.query_batch(chrom, variants, sf) def query_batch_multi( @@ -55,6 +67,9 @@ def query_batch_multi( phenotype: list[str] | None = None, sex: str = "both", tech: list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: """Query variants across multiple chromosomes, preserving input order. @@ -71,7 +86,7 @@ def query_batch_multi( List of :class:`~afquery.models.QueryResult` objects in input order (by original index). Variants not found in the database are omitted. 
""" - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) return self._engine.query_batch_multi(variants, sf) def query_region( @@ -82,8 +97,11 @@ def query_region( phenotype: list[str] | None = None, sex: str = "both", tech: list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) return self._engine.query_region(chrom, start, end, sf) def variant_info( @@ -95,6 +113,9 @@ def variant_info( phenotype: list[str] | None = None, sex: str = "both", tech: list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[SampleCarrier]: """Return all samples carrying a variant, with their metadata. @@ -116,7 +137,7 @@ def variant_info( ``sample_id``. Empty list if the variant is absent or no eligible carrier exists. """ - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) params = QueryParams(chrom=chrom, pos=pos, filter=sf, ref=ref, alt=alt) return self._engine.variant_info(params) @@ -126,6 +147,9 @@ def query_region_multi( phenotype: list[str] | None = None, sex: str = "both", tech: list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: """Query variants across multiple genomic regions (may span chromosomes). @@ -140,7 +164,7 @@ def query_region_multi( genomic order (chr1, chr2, …, chr22, chrX, chrY, chrM). Overlapping regions are deduplicated. 
""" - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) return self._engine.query_region_multi(regions, sf) def dump( @@ -158,9 +182,12 @@ def dump( end: int | None = None, n_workers: int | None = None, include_ac_zero: bool = False, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> dict: from .dump import dump_database, _build_groups - base_sf = self._make_filter(phenotype, sex, tech) + base_sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) groups = _build_groups( self._engine, base_sf, by_sex, by_tech, by_phenotype or [], all_groups, @@ -178,8 +205,11 @@ def annotate_vcf( sex: str = "both", n_workers: int | None = None, tech: list[str] | None = None, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> dict: - sf = self._make_filter(phenotype, sex, tech) + sf = self._make_filter(phenotype, sex, tech, min_pass, min_observed, min_quality_evidence) from .annotate import annotate_vcf as _annotate return _annotate(self._engine, input_vcf, output_vcf, sf, n_workers=n_workers) diff --git a/src/afquery/dump.py b/src/afquery/dump.py index 9715236..7f1d5e1 100644 --- a/src/afquery/dump.py +++ b/src/afquery/dump.py @@ -83,6 +83,9 @@ def _build_groups(engine, base_sf, by_sex, by_tech, by_phenotype, all_groups): ), tech_exclude=list(base_sf.tech_exclude), sex=sex_override if sex_override is not None else base_sf.sex, + min_pass=base_sf.min_pass, + min_observed=base_sf.min_observed, + min_quality_evidence=base_sf.min_quality_evidence, ) groups.append((label, sf)) @@ -155,9 +158,9 @@ def _dump_bucket_worker( where_clause = "WHERE pos BETWEEN ? AND ?" 
params = [str(parquet_file), range_start, range_end] - fail_col = ", fail_bitmap" + cols = engine._select_cols(with_pos=True) sql = ( - f"SELECT pos, ref, alt, het_bitmap, hom_bitmap{fail_col}" + f"SELECT {cols}" f" FROM read_parquet(?)" f" {where_clause}" ) @@ -178,7 +181,8 @@ def _dump_bucket_worker( result_rows = [] for row in rows: - pos, ref, alt, het_bytes, hom_bytes, fail_bytes = row + pos, ref, alt = row[0], row[1], row[2] + het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm = engine._unpack_bitmaps(row[3:]) # Base eligible / AN if pos not in pos_cache: @@ -188,10 +192,6 @@ def _dump_bucket_worker( if AN == 0: continue - het_bm = deserialize(bytes(het_bytes)) - hom_bm = deserialize(bytes(hom_bytes)) - fail_bm = deserialize(bytes(fail_bytes)) if fail_bytes is not None else None - haploid_elig, diploid_elig = split_ploidy( eligible, engine._male_bm, engine._female_bm, chrom, pos, engine._genome_build ) @@ -211,8 +211,16 @@ def _dump_bucket_worker( len(hom_elig & diploid_elig) + len((het_elig | hom_elig) & haploid_elig) ) AF = AC / AN - N_FAIL = len(fail_bm & eligible) if fail_bm is not None else None - N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - (N_FAIL if N_FAIL is not None else 0) + N_FAIL = len(fail_bm & eligible) + no_cov_bm = engine._compute_no_coverage_bm( + eligible, het_bm, hom_bm, fail_bm, + base_sf.min_pass, base_sf.min_observed, + filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + min_quality_evidence=base_sf.min_quality_evidence, + ) + N_NO_COVERAGE = len(no_cov_bm) + N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL - N_NO_COVERAGE out_row: dict = { "chrom": chrom, @@ -226,10 +234,12 @@ def _dump_bucket_worker( "N_HOM_ALT": N_HOM_ALT, "N_HOM_REF": N_HOM_REF, "N_FAIL": N_FAIL, + "N_NO_COVERAGE": N_NO_COVERAGE, } # Per-group columns for g_idx, (label, g_bm) in enumerate(group_bms): + g_sf = groups[g_idx][1] cache_key = (g_idx, pos) if cache_key not in group_pos_cache: group_pos_cache[cache_key] = 
engine._compute_eligible(chrom, pos, g_bm) @@ -250,10 +260,17 @@ def _dump_bucket_worker( len(g_hom_elig & g_diploid) + len((g_het_elig | g_hom_elig) & g_haploid) ) g_AF = g_AC / g_AN if g_AN > 0 else 0.0 - g_N_FAIL = len(fail_bm & g_eligible) if fail_bm is not None else None + g_N_FAIL = len(fail_bm & g_eligible) + g_no_cov_bm = engine._compute_no_coverage_bm( + g_eligible, het_bm, hom_bm, fail_bm, + g_sf.min_pass, g_sf.min_observed, + filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + min_quality_evidence=g_sf.min_quality_evidence, + ) + g_N_NO_COVERAGE = len(g_no_cov_bm) g_N_HOM_REF = ( - len(g_eligible) - g_N_HET - g_N_HOM_ALT - - (g_N_FAIL if g_N_FAIL is not None else 0) + len(g_eligible) - g_N_HET - g_N_HOM_ALT - g_N_FAIL - g_N_NO_COVERAGE ) out_row[f"AC_{label}"] = g_AC @@ -263,6 +280,7 @@ def _dump_bucket_worker( out_row[f"N_HOM_ALT_{label}"] = g_N_HOM_ALT out_row[f"N_HOM_REF_{label}"] = g_N_HOM_REF out_row[f"N_FAIL_{label}"] = g_N_FAIL + out_row[f"N_NO_COVERAGE_{label}"] = g_N_NO_COVERAGE result_rows.append(out_row) @@ -358,14 +376,14 @@ def dump_database( # Build CSV header base_cols = ["chrom", "pos", "ref", "alt", "AC", "AN", "AF", - "N_HET", "N_HOM_ALT", "N_HOM_REF", "N_FAIL"] + "N_HET", "N_HOM_ALT", "N_HOM_REF", "N_FAIL", "N_NO_COVERAGE"] group_cols = [] for label, _ in groups: group_cols += [ f"AC_{label}", f"AN_{label}", f"AF_{label}", f"N_HET_{label}", f"N_HOM_ALT_{label}", f"N_HOM_REF_{label}", - f"N_FAIL_{label}", + f"N_FAIL_{label}", f"N_NO_COVERAGE_{label}", ] fieldnames = base_cols + group_cols diff --git a/src/afquery/models.py b/src/afquery/models.py index 996a5ac..98ab8a4 100644 --- a/src/afquery/models.py +++ b/src/afquery/models.py @@ -13,12 +13,18 @@ class SampleFilter: tech_include: list[str] = field(default_factory=list) # [] = todas tech_exclude: list[str] = field(default_factory=list) sex: str = "both" # 'male' | 'female' | 'both' + min_pass: int = 0 # WES tech needs ≥K PASS carriers (het|hom) at position + min_observed: 
int = 0 # WES tech needs ≥K any-VCF entries (het|hom|fail) at position + min_quality_evidence: int = 0 # WES tech needs ≥K quality_pass carriers (Phase 2 DB only) @staticmethod def parse( phenotype_tokens: list[str], tech_tokens: list[str], sex: str = "both", + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> "SampleFilter": """Parsea tokens con prefijo ^ (exclusión) estilo bcftools.""" return SampleFilter( @@ -27,6 +33,9 @@ def parse( tech_include=[t for t in tech_tokens if not t.startswith("^")], tech_exclude=[t[1:] for t in tech_tokens if t.startswith("^")], sex=sex, + min_pass=min_pass, + min_observed=min_observed, + min_quality_evidence=min_quality_evidence, ) @@ -73,6 +82,7 @@ class QueryResult: N_HOM_ALT: int = 0 N_HOM_REF: int = 0 N_FAIL: int = 0 + N_NO_COVERAGE: int = 0 @dataclass @@ -83,5 +93,5 @@ class SampleCarrier: sex: str # 'male' | 'female' tech_name: str phenotypes: list[str] - genotype: str # 'het' | 'hom' | 'alt' (FILTER≠PASS, ploidy unknown) + genotype: str # 'het' | 'hom' | 'alt' (FILTER≠PASS) | 'no_coverage' (uncertain hom-ref) filter_pass: bool # False = FILTER≠PASS diff --git a/src/afquery/preprocess/__init__.py b/src/afquery/preprocess/__init__.py index 4e9ab8e..d1a2beb 100644 --- a/src/afquery/preprocess/__init__.py +++ b/src/afquery/preprocess/__init__.py @@ -30,6 +30,10 @@ def run_preprocess( tmp_dir: str | None = None, force: bool = False, db_version: str = "1.0", + min_dp: int = 0, + min_gq: int = 0, + min_qual: float = 0.0, + min_covered: int = 0, ) -> None: if genome_build not in VALID_GENOME_BUILDS: raise ValueError( @@ -148,9 +152,21 @@ def run_preprocess( else effective_threads ) logger.info("[build] Build memory limit: %s per worker", build_memory) + + # Phase 2: pre-serialize WES tech bitmaps for picklable transfer to workers + wes_tech_bitmaps_bytes: dict[int, bytes] = {} + if min_covered > 0: + for tech_id, bm in build_tech_bitmaps(samples).items(): + tech_obj = next((t for t in technologies if 
t.tech_id == tech_id), None) + if tech_obj is not None and tech_obj.bed_path is not None: + wes_tech_bitmaps_bytes[tech_id] = serialize(bm) + build_all_parquets(actual_tmp, variants_dir, n_workers=build_workers, consolidated_path=consolidated, resume=(not force), - memory_limit=build_memory) + memory_limit=build_memory, + min_dp=min_dp, min_gq=min_gq, min_qual=min_qual, + min_covered=min_covered, + wes_tech_bitmaps_bytes=wes_tech_bitmaps_bytes or None) success = True finally: if auto_tmp and success: @@ -162,7 +178,10 @@ def run_preprocess( ) logger.debug("[preprocess] Writing manifest.json...") - _write_manifest(output_dir, genome_build, len(samples), db_version=db_version) + _write_manifest( + output_dir, genome_build, len(samples), db_version=db_version, + min_dp=min_dp, min_gq=min_gq, min_qual=min_qual, min_covered=min_covered, + ) logger.info("[preprocess] Database complete: %s", output_dir) @@ -269,14 +288,26 @@ def _write_manifest( genome_build: str, sample_count: int, db_version: str = "1.0", + min_dp: int = 0, + min_gq: int = 0, + min_qual: float = 0.0, + min_covered: int = 0, ) -> None: + coverage_filter = { + "min_dp": min_dp, + "min_gq": min_gq, + "min_qual": min_qual, + "min_covered": min_covered, + } + has_quality_filter = (min_dp > 0 or min_gq > 0 or min_qual > 0 or min_covered > 0) manifest = { "genome_build": genome_build, "version": "0.1.0", "db_version": db_version, "sample_count": sample_count, - "schema_version": "2.0", + "schema_version": "3.0" if has_quality_filter else "2.0", "pass_only_filter": True, + "coverage_filter": coverage_filter, "created_at": datetime.now(timezone.utc).isoformat(), } with open(os.path.join(output_dir, "manifest.json"), "w") as f: diff --git a/src/afquery/preprocess/build.py b/src/afquery/preprocess/build.py index c019b5c..fa939de 100644 --- a/src/afquery/preprocess/build.py +++ b/src/afquery/preprocess/build.py @@ -22,12 +22,14 @@ BUCKET_SIZE = 1_000_000 PARQUET_SCHEMA = pa.schema([ - ("pos", pa.uint32()), - 
("ref", pa.large_utf8()), - ("alt", pa.large_utf8()), - ("het_bitmap", pa.large_binary()), - ("hom_bitmap", pa.large_binary()), - ("fail_bitmap", pa.large_binary()), + ("pos", pa.uint32()), + ("ref", pa.large_utf8()), + ("alt", pa.large_utf8()), + ("het_bitmap", pa.large_binary()), + ("hom_bitmap", pa.large_binary()), + ("fail_bitmap", pa.large_binary()), + ("filtered_bitmap", pa.large_binary()), # Phase 2: WES non-carriers w/ uncertain coverage + ("quality_pass_bitmap", pa.large_binary()), # Phase 2: carriers meeting DP/GQ/QUAL thresholds ]) @@ -112,15 +114,21 @@ def get_chroms_in_temp_files(tmp_dir: str) -> list[str]: return [r[0] for r in rows if r[0] in ALL_CHROMS] -def _make_table(positions, refs, alts, het_bitmaps, hom_bitmaps, fail_bitmaps) -> pa.Table: +def _make_table( + positions, refs, alts, + het_bitmaps, hom_bitmaps, fail_bitmaps, + filtered_bitmaps, quality_pass_bitmaps, +) -> pa.Table: return pa.table( { - "pos": pa.array(positions, type=pa.uint32()), - "ref": pa.array(refs, type=pa.large_utf8()), - "alt": pa.array(alts, type=pa.large_utf8()), - "het_bitmap": pa.array(het_bitmaps, type=pa.large_binary()), - "hom_bitmap": pa.array(hom_bitmaps, type=pa.large_binary()), - "fail_bitmap": pa.array(fail_bitmaps, type=pa.large_binary()), + "pos": pa.array(positions, type=pa.uint32()), + "ref": pa.array(refs, type=pa.large_utf8()), + "alt": pa.array(alts, type=pa.large_utf8()), + "het_bitmap": pa.array(het_bitmaps, type=pa.large_binary()), + "hom_bitmap": pa.array(hom_bitmaps, type=pa.large_binary()), + "fail_bitmap": pa.array(fail_bitmaps, type=pa.large_binary()), + "filtered_bitmap": pa.array(filtered_bitmaps, type=pa.large_binary()), + "quality_pass_bitmap": pa.array(quality_pass_bitmaps, type=pa.large_binary()), }, schema=PARQUET_SCHEMA, ) @@ -129,16 +137,78 @@ def _make_table(positions, refs, alts, het_bitmaps, hom_bitmaps, fail_bitmaps) - def _write_flat( chrom: str, variants_dir: str, - positions, refs, alts, het_bitmaps, hom_bitmaps, fail_bitmaps, + 
positions, refs, alts, + het_bitmaps, hom_bitmaps, fail_bitmaps, + filtered_bitmaps, quality_pass_bitmaps, row_group_size: int, ) -> None: - table = _make_table(positions, refs, alts, het_bitmaps, hom_bitmaps, fail_bitmaps) + table = _make_table( + positions, refs, alts, + het_bitmaps, hom_bitmaps, fail_bitmaps, + filtered_bitmaps, quality_pass_bitmaps, + ) out_path = os.path.join(variants_dir, f"{chrom}.parquet") tmp_path = out_path + ".tmp" pq.write_table(table, tmp_path, row_group_size=row_group_size) os.rename(tmp_path, out_path) +def _passes_quality(dp, gq, qual, min_dp: int, min_gq: int, min_qual: float) -> bool: + """A sample passes quality if all configured thresholds are met (None always passes).""" + if min_dp > 0 and (dp is None or dp < min_dp): + return False + if min_gq > 0 and (gq is None or gq < min_gq): + return False + if min_qual > 0 and (qual is None or qual < min_qual): + return False + return True + + +def _compute_phase2_bitmaps( + sample_ids: list[int], + gt_acs: list[int], + filter_passes: list[bool], + dps: list[int | None], + gqs: list[int | None], + quals: list[float | None], + het_ids: list[int], + hom_ids: list[int], + fail_ids: list[int], + min_dp: int, + min_gq: int, + min_qual: float, + min_covered: int, + wes_tech_bitmaps: dict[int, BitMap] | None, +) -> tuple[list[int], list[int]]: + """Compute (quality_pass_ids, filtered_ids) for one (pos, ref, alt) row. + + quality_pass_ids: PASS carriers (het|hom) meeting DP/GQ/QUAL thresholds. + filtered_ids: non-carrier samples of WES techs whose quality_pass count < min_covered. + Query layer intersects this with `eligible` (BED-aware) to get the + effective N_NO_COVERAGE. 
+ """ + # Quality-pass carriers + quality_pass_ids: list[int] = [] + for sid, ac, fp, dp, gq, qual in zip(sample_ids, gt_acs, filter_passes, dps, gqs, quals): + if ac > 0 and fp and _passes_quality(dp, gq, qual, min_dp, min_gq, min_qual): + quality_pass_ids.append(sid) + + # Filtered samples (non-carriers of failing WES techs) + if not wes_tech_bitmaps or min_covered <= 0: + return quality_pass_ids, [] + + quality_pass_set = set(quality_pass_ids) + carrier_set = set(het_ids) | set(hom_ids) | set(fail_ids) + filtered_ids: list[int] = [] + for tech_bm in wes_tech_bitmaps.values(): + qp_count = sum(1 for sid in quality_pass_set if sid in tech_bm) + if qp_count < min_covered: + for sid in tech_bm: + if sid not in carrier_set: + filtered_ids.append(sid) + return quality_pass_ids, filtered_ids + + def _discover_bucket_ids( source: str, base_where: str, @@ -174,11 +244,17 @@ def _build_one_bucket_worker( out_path: str, row_group_size: int, memory_limit: str, + min_dp: int = 0, + min_gq: int = 0, + min_qual: float = 0.0, + min_covered: int = 0, + wes_tech_bitmaps_bytes: dict[int, bytes] | None = None, ) -> tuple[int, int, float]: """Build one bucket Parquet file. Returns (bucket_id, variant_count, elapsed_seconds). Top-level function (picklable) for use with ProcessPoolExecutor. 
""" + from ..bitmaps import deserialize as _deser t0 = time.monotonic() bucket_start = bucket_id * BUCKET_SIZE bucket_end = bucket_start + BUCKET_SIZE @@ -187,6 +263,11 @@ def _build_one_bucket_worker( else: full_where = f"WHERE pos >= {bucket_start} AND pos < {bucket_end}" + wes_tech_bitmaps = ( + {tid: _deser(b) for tid, b in wes_tech_bitmaps_bytes.items()} + if wes_tech_bitmaps_bytes else None + ) + con = duckdb.connect() if memory_limit: con.execute(f"SET memory_limit='{memory_limit}'") @@ -199,7 +280,10 @@ def _build_one_bucket_worker( pos, ref, alt, list(sample_id ORDER BY sample_id) AS sample_ids, list(gt_ac ORDER BY sample_id) AS gt_acs, - list(filter_pass ORDER BY sample_id) AS filter_passes + list(filter_pass ORDER BY sample_id) AS filter_passes, + list(dp ORDER BY sample_id) AS dps, + list(gq ORDER BY sample_id) AS gqs, + list(qual ORDER BY sample_id) AS quals FROM read_parquet('{source}') {full_where} GROUP BY pos, ref, alt @@ -217,19 +301,28 @@ def _build_one_bucket_worker( het_bitmaps = [] hom_bitmaps = [] fail_bitmaps = [] + filtered_bitmaps = [] + quality_pass_bitmaps = [] - for pos, ref, alt, sample_ids, gt_acs, filter_passes in rows: + for pos, ref, alt, sample_ids, gt_acs, filter_passes, dps, gqs, quals in rows: het_ids = [sid for sid, ac, fp in zip(sample_ids, gt_acs, filter_passes) if ac == 1 and fp] hom_ids = [sid for sid, ac, fp in zip(sample_ids, gt_acs, filter_passes) if ac == 2 and fp] fail_ids = [sid for sid, fp in zip(sample_ids, filter_passes) if not fp] if not het_ids and not hom_ids and not fail_ids: continue + quality_pass_ids, filtered_ids = _compute_phase2_bitmaps( + sample_ids, gt_acs, filter_passes, dps, gqs, quals, + het_ids, hom_ids, fail_ids, + min_dp, min_gq, min_qual, min_covered, wes_tech_bitmaps, + ) positions.append(pos) refs.append(ref) alts.append(alt) het_bitmaps.append(serialize(BitMap(het_ids))) hom_bitmaps.append(serialize(BitMap(hom_ids))) fail_bitmaps.append(serialize(BitMap(fail_ids))) + 
filtered_bitmaps.append(serialize(BitMap(filtered_ids))) + quality_pass_bitmaps.append(serialize(BitMap(quality_pass_ids))) del rows @@ -238,7 +331,11 @@ def _build_one_bucket_worker( os.makedirs(os.path.dirname(out_path), exist_ok=True) tmp_path = out_path + ".tmp" - table = _make_table(positions, refs, alts, het_bitmaps, hom_bitmaps, fail_bitmaps) + table = _make_table( + positions, refs, alts, + het_bitmaps, hom_bitmaps, fail_bitmaps, + filtered_bitmaps, quality_pass_bitmaps, + ) pq.write_table(table, tmp_path, row_group_size=row_group_size) os.rename(tmp_path, out_path) @@ -253,9 +350,19 @@ def build_chromosome_parquet( partitioned: bool = False, memory_limit: str = "2GB", consolidated_path: str | None = None, + min_dp: int = 0, + min_gq: int = 0, + min_qual: float = 0.0, + min_covered: int = 0, + wes_tech_bitmaps_bytes: dict[int, bytes] | None = None, ) -> int: """Build Parquet for one chromosome. Returns variant count.""" + from ..bitmaps import deserialize as _deser params: list = [] + wes_tech_bitmaps = ( + {tid: _deser(b) for tid, b in wes_tech_bitmaps_bytes.items()} + if wes_tech_bitmaps_bytes else None + ) if consolidated_path is not None and os.path.isdir(consolidated_path): # Hive-partitioned: read chromosome-specific subdirectory (no WHERE needed) chrom_dir = os.path.join(consolidated_path, f"chrom={chrom}") @@ -285,6 +392,7 @@ def build_chromosome_parquet( _, count, _ = _build_one_bucket_worker( source, where_clause, params, bucket_id, out_path, row_group_size, memory_limit, + min_dp, min_gq, min_qual, min_covered, wes_tech_bitmaps_bytes, ) total_count += count return total_count @@ -302,7 +410,10 @@ def build_chromosome_parquet( pos, ref, alt, list(sample_id ORDER BY sample_id) AS sample_ids, list(gt_ac ORDER BY sample_id) AS gt_acs, - list(filter_pass ORDER BY sample_id) AS filter_passes + list(filter_pass ORDER BY sample_id) AS filter_passes, + list(dp ORDER BY sample_id) AS dps, + list(gq ORDER BY sample_id) AS gqs, + list(qual ORDER BY 
sample_id) AS quals FROM read_parquet('{source}') {where_clause} GROUP BY pos, ref, alt @@ -320,19 +431,28 @@ def build_chromosome_parquet( het_bitmaps = [] hom_bitmaps = [] fail_bitmaps = [] + filtered_bitmaps = [] + quality_pass_bitmaps = [] - for pos, ref, alt, sample_ids, gt_acs, filter_passes in rows: + for pos, ref, alt, sample_ids, gt_acs, filter_passes, dps, gqs, quals in rows: het_ids = [sid for sid, ac, fp in zip(sample_ids, gt_acs, filter_passes) if ac == 1 and fp] hom_ids = [sid for sid, ac, fp in zip(sample_ids, gt_acs, filter_passes) if ac == 2 and fp] fail_ids = [sid for sid, fp in zip(sample_ids, filter_passes) if not fp] if not het_ids and not hom_ids and not fail_ids: continue + quality_pass_ids, filtered_ids = _compute_phase2_bitmaps( + sample_ids, gt_acs, filter_passes, dps, gqs, quals, + het_ids, hom_ids, fail_ids, + min_dp, min_gq, min_qual, min_covered, wes_tech_bitmaps, + ) positions.append(pos) refs.append(ref) alts.append(alt) het_bitmaps.append(serialize(BitMap(het_ids))) hom_bitmaps.append(serialize(BitMap(hom_ids))) fail_bitmaps.append(serialize(BitMap(fail_ids))) + filtered_bitmaps.append(serialize(BitMap(filtered_ids))) + quality_pass_bitmaps.append(serialize(BitMap(quality_pass_ids))) del rows # free DuckDB result before PyArrow allocation @@ -341,7 +461,9 @@ def build_chromosome_parquet( _write_flat( chrom, variants_dir, - positions, refs, alts, het_bitmaps, hom_bitmaps, fail_bitmaps, + positions, refs, alts, + het_bitmaps, hom_bitmaps, fail_bitmaps, + filtered_bitmaps, quality_pass_bitmaps, row_group_size, ) @@ -356,12 +478,19 @@ def _build_chrom_worker( partitioned: bool, memory_limit: str = "2GB", consolidated_path: str | None = None, + min_dp: int = 0, + min_gq: int = 0, + min_qual: float = 0.0, + min_covered: int = 0, + wes_tech_bitmaps_bytes: dict[int, bytes] | None = None, ) -> tuple[str, int, float]: """Top-level worker for ProcessPoolExecutor (must be picklable). 
Returns (chrom, count, elapsed).""" t0 = time.monotonic() count = build_chromosome_parquet( chrom, tmp_dir, variants_dir, row_group_size, partitioned, memory_limit, consolidated_path, + min_dp=min_dp, min_gq=min_gq, min_qual=min_qual, min_covered=min_covered, + wes_tech_bitmaps_bytes=wes_tech_bitmaps_bytes, ) return chrom, count, time.monotonic() - t0 @@ -375,6 +504,11 @@ def build_all_parquets( memory_limit: str = "2GB", consolidated_path: str | None = None, resume: bool = True, + min_dp: int = 0, + min_gq: int = 0, + min_qual: float = 0.0, + min_covered: int = 0, + wes_tech_bitmaps_bytes: dict[int, bytes] | None = None, ) -> dict[str, int]: """Build Parquet for all discovered chromosomes. Returns {chrom: count}. @@ -479,6 +613,8 @@ def build_all_parquets( _build_one_bucket_worker, source, base_where, params, bucket_id, out_path, row_group_size, memory_limit, + min_dp, min_gq, min_qual, min_covered, + wes_tech_bitmaps_bytes, ): (chrom, bucket_id) for chrom, bucket_id, source, base_where, params, out_path in all_tasks } @@ -574,6 +710,8 @@ def build_all_parquets( _build_chrom_worker, chrom, tmp_dir, variants_dir, row_group_size, partitioned, memory_limit, consolidated_path, + min_dp, min_gq, min_qual, min_covered, + wes_tech_bitmaps_bytes, ): chrom for chrom in to_build } diff --git a/src/afquery/preprocess/compact.py b/src/afquery/preprocess/compact.py index 8f412fc..203d881 100644 --- a/src/afquery/preprocess/compact.py +++ b/src/afquery/preprocess/compact.py @@ -10,6 +10,7 @@ import pyarrow as pa import pyarrow.parquet as pq from pyroaring import BitMap +from pyroaring import BitMap from ..bitmaps import deserialize, serialize from .build import PARQUET_SCHEMA @@ -63,26 +64,42 @@ def compact_database(db_path: Path) -> dict: for parquet_file in all_parquets: table = pq.read_table(str(parquet_file)) rows_processed += len(table) + has_phase2 = ( + "filtered_bitmap" in table.schema.names + and "quality_pass_bitmap" in table.schema.names + ) keep_indices = [] 
new_het_list = [] new_hom_list = [] new_fail_list = [] + new_filtered_list = [] + new_quality_pass_list = [] dirty = False for i in range(len(table)): het_bm = deserialize(table["het_bitmap"][i].as_py()) hom_bm = deserialize(table["hom_bitmap"][i].as_py()) fail_bm = deserialize(table["fail_bitmap"][i].as_py()) + if has_phase2: + filt_bm = deserialize(table["filtered_bitmap"][i].as_py()) + qp_bm = deserialize(table["quality_pass_bitmap"][i].as_py()) + else: + filt_bm = BitMap() + qp_bm = BitMap() new_het = het_bm & active_ids new_hom = hom_bm & active_ids new_fail = fail_bm & active_ids + new_filt = filt_bm & active_ids + new_qp = qp_bm & active_ids - if new_het != het_bm or new_hom != hom_bm or new_fail != fail_bm: + if (new_het != het_bm or new_hom != hom_bm or new_fail != fail_bm + or new_filt != filt_bm or new_qp != qp_bm): dirty = True - if not new_het and not new_hom and not new_fail: + if (not new_het and not new_hom and not new_fail + and not new_filt and not new_qp): # No active samples carry this variant — remove the row rows_removed += 1 dirty = True @@ -92,6 +109,8 @@ def compact_database(db_path: Path) -> dict: new_het_list.append(serialize(new_het)) new_hom_list.append(serialize(new_hom)) new_fail_list.append(serialize(new_fail)) + new_filtered_list.append(serialize(new_filt)) + new_quality_pass_list.append(serialize(new_qp)) if not dirty: logger.debug(" [compact] %s: no changes", parquet_file.name) @@ -101,12 +120,14 @@ def compact_database(db_path: Path) -> dict: orig_keep = table.take(keep_indices) new_table = pa.table( { - "pos": orig_keep["pos"], - "ref": orig_keep["ref"], - "alt": orig_keep["alt"], - "het_bitmap": pa.array(new_het_list, type=pa.large_binary()), - "hom_bitmap": pa.array(new_hom_list, type=pa.large_binary()), - "fail_bitmap": pa.array(new_fail_list, type=pa.large_binary()), + "pos": orig_keep["pos"], + "ref": orig_keep["ref"], + "alt": orig_keep["alt"], + "het_bitmap": pa.array(new_het_list, type=pa.large_binary()), + "hom_bitmap": 
pa.array(new_hom_list, type=pa.large_binary()), + "fail_bitmap": pa.array(new_fail_list, type=pa.large_binary()), + "filtered_bitmap": pa.array(new_filtered_list, type=pa.large_binary()), + "quality_pass_bitmap": pa.array(new_quality_pass_list, type=pa.large_binary()), }, schema=PARQUET_SCHEMA, ) diff --git a/src/afquery/preprocess/ingest.py b/src/afquery/preprocess/ingest.py index 7162ccc..a588420 100644 --- a/src/afquery/preprocess/ingest.py +++ b/src/afquery/preprocess/ingest.py @@ -22,6 +22,9 @@ ("gt_ac", pa.uint8()), ("sample_id", pa.uint32()), ("filter_pass", pa.bool_()), + ("dp", pa.int32()), # FORMAT/DP — None when absent + ("gq", pa.int32()), # FORMAT/GQ — None when absent + ("qual", pa.float32()), # VCF QUAL — None when '.' ]) @@ -29,6 +32,33 @@ class IngestError(RuntimeError): pass +def _read_format_int(variant, field: str) -> int | None: + """Read an integer FORMAT field for the (single) sample. + + Returns None when the field is absent, missing (cyvcf2 sentinel), or non-numeric. + """ + try: + arr = variant.format(field) + except Exception: + return None + if arr is None: + return None + try: + v = arr[0][0] + except (IndexError, TypeError): + return None + if v is None: + return None + try: + v_int = int(v) + except (TypeError, ValueError): + return None + # cyvcf2 returns -2147483648 for missing integers + if v_int < 0: + return None + return v_int + + def ingest_sample(sample_id: int, vcf_path: str, tmp_dir: str) -> tuple[str, float]: """Parse one VCF, write Parquet file. 
Returns (output_path, elapsed_seconds).""" import cyvcf2 # local import for worker-safety @@ -43,6 +73,9 @@ def ingest_sample(sample_id: int, vcf_path: str, tmp_dir: str) -> tuple[str, flo gt_acs: list[int] = [] sample_ids: list[int] = [] filter_passes: list[bool] = [] + dps: list[int | None] = [] + gqs: list[int | None] = [] + quals: list[float | None] = [] vcf = cyvcf2.VCF(vcf_path) for variant in vcf: @@ -58,6 +91,11 @@ def ingest_sample(sample_id: int, vcf_path: str, tmp_dir: str) -> tuple[str, flo # cyvcf2: FILTER is None for PASS/missing, string for others (e.g. "LowQual") fp = variant.FILTER is None or variant.FILTER == "PASS" + # Quality fields (None when missing/absent) + dp_val = _read_format_int(variant, "DP") + gq_val = _read_format_int(variant, "GQ") + qual_val = variant.QUAL # float or None + # Missing GT (./.) at a failed site: track as N_FAIL for all ALTs if not alleles: if not fp: @@ -71,6 +109,9 @@ def ingest_sample(sample_id: int, vcf_path: str, tmp_dir: str) -> tuple[str, flo gt_acs.append(0) sample_ids.append(sample_id) filter_passes.append(False) + dps.append(dp_val) + gqs.append(gq_val) + quals.append(qual_val) continue for idx, alt_str in enumerate(variant.ALT): @@ -88,6 +129,9 @@ def ingest_sample(sample_id: int, vcf_path: str, tmp_dir: str) -> tuple[str, flo gt_acs.append(ac) sample_ids.append(sample_id) filter_passes.append(fp) + dps.append(dp_val) + gqs.append(gq_val) + quals.append(qual_val) vcf.close() @@ -100,6 +144,9 @@ def ingest_sample(sample_id: int, vcf_path: str, tmp_dir: str) -> tuple[str, flo "gt_ac": pa.array(gt_acs, type=pa.uint8()), "sample_id": pa.array(sample_ids, type=pa.uint32()), "filter_pass": pa.array(filter_passes, type=pa.bool_()), + "dp": pa.array(dps, type=pa.int32()), + "gq": pa.array(gqs, type=pa.int32()), + "qual": pa.array(quals, type=pa.float32()), }, schema=INGEST_SCHEMA, ) diff --git a/src/afquery/preprocess/update.py b/src/afquery/preprocess/update.py index 1925cbc..1a8144b 100644 --- 
a/src/afquery/preprocess/update.py +++ b/src/afquery/preprocess/update.py @@ -122,17 +122,28 @@ def _merge_chromosome_parquet( db_dir: str, update_tmp_dir: str, row_group_size: int = 100_000, + coverage_filter: dict | None = None, + wes_tech_bitmaps: dict[int, BitMap] | None = None, ) -> tuple[int, int]: - """Merge new temp files into existing chrom Parquet. Returns (new_variants, updated_variants).""" + """Merge new temp files into existing chrom Parquet. Returns (new_variants, updated_variants). + + For Phase 2 DBs (coverage_filter active): also merges quality_pass_bitmap and recomputes + filtered_bitmap per row using manifest thresholds and current WES tech bitmaps. + """ variants_dir = os.path.join(db_dir, "variants") out_path = os.path.join(variants_dir, f"{chrom}.parquet") # Read existing Parquet via pyarrow (NOT DuckDB — need Python bitmap deserialization) - existing: dict[tuple, tuple[BitMap, BitMap, BitMap]] = {} + existing: dict[tuple, tuple[BitMap, BitMap, BitMap, BitMap, BitMap]] = {} existing_has_fail = False + existing_has_phase2 = False if os.path.exists(out_path): table = pq.read_table(out_path) existing_has_fail = "fail_bitmap" in table.schema.names + existing_has_phase2 = ( + "filtered_bitmap" in table.schema.names + and "quality_pass_bitmap" in table.schema.names + ) for i in range(len(table)): pos = table["pos"][i].as_py() ref = table["ref"][i].as_py() @@ -140,15 +151,38 @@ def _merge_chromosome_parquet( het_bm = deserialize(table["het_bitmap"][i].as_py()) hom_bm = deserialize(table["hom_bitmap"][i].as_py()) fail_bm = deserialize(table["fail_bitmap"][i].as_py()) if existing_has_fail else BitMap() - existing[(pos, ref, alt)] = (het_bm, hom_bm, fail_bm) + if existing_has_phase2: + filt_bm = deserialize(table["filtered_bitmap"][i].as_py()) + qp_bm = deserialize(table["quality_pass_bitmap"][i].as_py()) + else: + filt_bm = BitMap() + qp_bm = BitMap() + existing[(pos, ref, alt)] = (het_bm, hom_bm, fail_bm, filt_bm, qp_bm) # Check if there are any new 
temp files parquet_files = glob_module.glob(os.path.join(update_tmp_dir, "sample_*.parquet")) if not parquet_files: return (0, 0) + # Detect whether new temp files include quality columns (dp/gq/qual). + # Old test fixtures may not include them; legacy ingest didn't either. + new_has_quality = False + try: + sample_schema = pq.read_schema(parquet_files[0]) + new_has_quality = ( + "dp" in sample_schema.names + and "gq" in sample_schema.names + and "qual" in sample_schema.names + ) + except Exception: + new_has_quality = False + # Aggregate new rows via DuckDB glob_pattern = os.path.join(update_tmp_dir, "sample_*.parquet").replace("'", "''") + quality_select = ( + ", list(dp ORDER BY sample_id), list(gq ORDER BY sample_id), list(qual ORDER BY sample_id)" + if new_has_quality else "" + ) con = duckdb.connect() try: rows = con.execute( @@ -157,6 +191,7 @@ def _merge_chromosome_parquet( list(sample_id ORDER BY sample_id), list(gt_ac ORDER BY sample_id), list(filter_pass ORDER BY sample_id) + {quality_select} FROM read_parquet('{glob_pattern}') WHERE chrom = ? 
GROUP BY pos, ref, alt @@ -171,10 +206,25 @@ def _merge_chromosome_parquet( if not rows: return (0, 0) + coverage_filter = coverage_filter or {} + min_dp = coverage_filter.get("min_dp", 0) + min_gq = coverage_filter.get("min_gq", 0) + min_qual = coverage_filter.get("min_qual", 0.0) + min_covered = coverage_filter.get("min_covered", 0) + has_phase2 = (min_dp > 0 or min_gq > 0 or min_qual > 0 or min_covered > 0) + new_variants = 0 updated_variants = 0 - for pos, ref, alt, sample_ids, gt_acs, filter_passes in rows: + for row in rows: + if new_has_quality: + pos, ref, alt, sample_ids, gt_acs, filter_passes, dps, gqs, quals = row + else: + pos, ref, alt, sample_ids, gt_acs, filter_passes = row + dps = [None] * len(sample_ids) + gqs = [None] * len(sample_ids) + quals = [None] * len(sample_ids) + key = (pos, ref, alt) het_ids = [sid for sid, ac, fp in zip(sample_ids, gt_acs, filter_passes) if ac == 1 and fp] hom_ids = [sid for sid, ac, fp in zip(sample_ids, gt_acs, filter_passes) if ac == 2 and fp] @@ -183,14 +233,42 @@ def _merge_chromosome_parquet( new_hom = BitMap(hom_ids) new_fail = BitMap(fail_ids) + # New quality_pass_ids (only if Phase 2 active) + if has_phase2: + from .build import _passes_quality + new_qp_ids = [ + sid for sid, ac, fp, dp, gq, qual in zip( + sample_ids, gt_acs, filter_passes, dps, gqs, quals + ) + if ac > 0 and fp and _passes_quality(dp, gq, qual, min_dp, min_gq, min_qual) + ] + new_qp = BitMap(new_qp_ids) + else: + new_qp = BitMap() + if key in existing: - old_het, old_hom, old_fail = existing[key] - existing[key] = (old_het | new_het, old_hom | new_hom, old_fail | new_fail) + old_het, old_hom, old_fail, old_filt, old_qp = existing[key] + merged_het = old_het | new_het + merged_hom = old_hom | new_hom + merged_fail = old_fail | new_fail + merged_qp = old_qp | new_qp + existing[key] = (merged_het, merged_hom, merged_fail, old_filt, merged_qp) updated_variants += 1 else: - existing[key] = (new_het, new_hom, new_fail) + existing[key] = (new_het, 
new_hom, new_fail, BitMap(), new_qp) new_variants += 1 + # Phase 2: recompute filtered_bitmap per row (all rows, since merging may shift any tech) + if has_phase2 and wes_tech_bitmaps and min_covered > 0: + for key, (het, hom, fail, _old_filt, qp) in existing.items(): + carrier_set = het | hom | fail + new_filt = BitMap() + for tech_bm in wes_tech_bitmaps.values(): + qp_count = len(qp & tech_bm) + if qp_count < min_covered: + new_filt |= (tech_bm - carrier_set) + existing[key] = (het, hom, fail, new_filt, qp) + # Sort by (pos, alt) and write atomically sorted_keys = sorted(existing.keys(), key=lambda k: (k[0], k[2])) @@ -200,15 +278,19 @@ def _merge_chromosome_parquet( het_bitmaps = [serialize(existing[k][0]) for k in sorted_keys] hom_bitmaps = [serialize(existing[k][1]) for k in sorted_keys] fail_bitmaps = [serialize(existing[k][2]) for k in sorted_keys] + filtered_bitmaps = [serialize(existing[k][3]) for k in sorted_keys] + quality_pass_bitmaps = [serialize(existing[k][4]) for k in sorted_keys] table = pa.table( { - "pos": pa.array(positions, type=pa.uint32()), - "ref": pa.array(refs, type=pa.large_utf8()), - "alt": pa.array(alts, type=pa.large_utf8()), - "het_bitmap": pa.array(het_bitmaps, type=pa.large_binary()), - "hom_bitmap": pa.array(hom_bitmaps, type=pa.large_binary()), - "fail_bitmap": pa.array(fail_bitmaps, type=pa.large_binary()), + "pos": pa.array(positions, type=pa.uint32()), + "ref": pa.array(refs, type=pa.large_utf8()), + "alt": pa.array(alts, type=pa.large_utf8()), + "het_bitmap": pa.array(het_bitmaps, type=pa.large_binary()), + "hom_bitmap": pa.array(hom_bitmaps, type=pa.large_binary()), + "fail_bitmap": pa.array(fail_bitmaps, type=pa.large_binary()), + "filtered_bitmap": pa.array(filtered_bitmaps, type=pa.large_binary()), + "quality_pass_bitmap": pa.array(quality_pass_bitmaps, type=pa.large_binary()), }, schema=PARQUET_SCHEMA, ) @@ -222,43 +304,59 @@ def _merge_chromosome_parquet( def _clear_bits_from_parquet(parquet_file: str, removal_ids: 
BitMap) -> None: - """Clear removal_ids bits from het/hom/fail bitmaps in a Parquet file. Rewrites atomically if dirty.""" + """Clear removal_ids bits from het/hom/fail/filtered/quality_pass bitmaps. Rewrites atomically if dirty.""" table = pq.read_table(parquet_file) + has_phase2 = ( + "filtered_bitmap" in table.schema.names + and "quality_pass_bitmap" in table.schema.names + ) dirty = False new_het_list = [] new_hom_list = [] new_fail_list = [] + new_filtered_list = [] + new_quality_pass_list = [] for i in range(len(table)): - het_bytes = table["het_bitmap"][i].as_py() - hom_bytes = table["hom_bitmap"][i].as_py() - het_bm = deserialize(het_bytes) - hom_bm = deserialize(hom_bytes) + het_bm = deserialize(table["het_bitmap"][i].as_py()) + hom_bm = deserialize(table["hom_bitmap"][i].as_py()) fail_bm = deserialize(table["fail_bitmap"][i].as_py()) + if has_phase2: + filt_bm = deserialize(table["filtered_bitmap"][i].as_py()) + qp_bm = deserialize(table["quality_pass_bitmap"][i].as_py()) + else: + filt_bm = BitMap() + qp_bm = BitMap() - combined = het_bm | hom_bm | fail_bm + combined = het_bm | hom_bm | fail_bm | filt_bm | qp_bm if removal_ids & combined: het_bm = het_bm - removal_ids hom_bm = hom_bm - removal_ids fail_bm = fail_bm - removal_ids + filt_bm = filt_bm - removal_ids + qp_bm = qp_bm - removal_ids dirty = True new_het_list.append(serialize(het_bm)) new_hom_list.append(serialize(hom_bm)) new_fail_list.append(serialize(fail_bm)) + new_filtered_list.append(serialize(filt_bm)) + new_quality_pass_list.append(serialize(qp_bm)) if not dirty: return new_table = pa.table( { - "pos": table["pos"], - "ref": table["ref"], - "alt": table["alt"], - "het_bitmap": pa.array(new_het_list, type=pa.large_binary()), - "hom_bitmap": pa.array(new_hom_list, type=pa.large_binary()), - "fail_bitmap": pa.array(new_fail_list, type=pa.large_binary()), + "pos": table["pos"], + "ref": table["ref"], + "alt": table["alt"], + "het_bitmap": pa.array(new_het_list, type=pa.large_binary()), + 
"hom_bitmap": pa.array(new_hom_list, type=pa.large_binary()), + "fail_bitmap": pa.array(new_fail_list, type=pa.large_binary()), + "filtered_bitmap": pa.array(new_filtered_list, type=pa.large_binary()), + "quality_pass_bitmap": pa.array(new_quality_pass_list, type=pa.large_binary()), }, schema=PARQUET_SCHEMA, ) @@ -590,9 +688,35 @@ def add_samples( # 9. Collect chroms from new temp files chroms = get_chroms_in_temp_files(tmp_dir) + # Phase 2: build WES tech bitmaps from CURRENT DB state (existing + new samples) + # so that filtered_bitmap recomputation uses the merged cohort. + coverage_filter = manifest.get("coverage_filter", {}) + wes_tech_bitmaps: dict[int, BitMap] = {} + if coverage_filter and coverage_filter.get("min_covered", 0) > 0: + rows = con.execute( + "SELECT s.sample_id, s.tech_id, t.bed_path FROM samples s" + " JOIN technologies t ON s.tech_id = t.tech_id" + " WHERE t.bed_path IS NOT NULL" + ).fetchall() + for sid, tid, _ in rows: + wes_tech_bitmaps.setdefault(tid, BitMap()).add(sid) + # Include new samples (not yet inserted into samples table) + for s in new_samples: + tech_obj = next( + (t for t in techs_raw if t.tech_name in tech_name_to_id + and tech_name_to_id[t.tech_name] == s.tech_id), + None, + ) + if tech_obj is not None and tech_obj.bed_path is not None: + wes_tech_bitmaps.setdefault(s.tech_id, BitMap()).add(s.sample_id) + # 10. 
Merge Parquet files for chrom in chroms: - n, u = _merge_chromosome_parquet(chrom, db_dir, tmp_dir) + n, u = _merge_chromosome_parquet( + chrom, db_dir, tmp_dir, + coverage_filter=coverage_filter, + wes_tech_bitmaps=wes_tech_bitmaps or None, + ) total_new += n total_updated += u logger.debug(" [add-samples] Merged %s: %d new, %d updated", chrom, n, u) diff --git a/src/afquery/query.py b/src/afquery/query.py index 450dd97..0d5b46a 100644 --- a/src/afquery/query.py +++ b/src/afquery/query.py @@ -20,6 +20,9 @@ def __init__(self, db_path: str): self._db = Path(db_path) self._manifest = json.loads((self._db / "manifest.json").read_text()) self._genome_build = self._manifest["genome_build"] + self._schema_version = self._manifest.get("schema_version", "1.0") + self._has_coverage_data = self._schema_version >= "3.0" + self._coverage_filter = self._manifest.get("coverage_filter", {}) con = sqlite3.connect(self._db / "metadata.sqlite") @@ -175,6 +178,100 @@ def _build_sample_bitmap(self, sf: SampleFilter) -> BitMap: ) return result_bm + def _compute_no_coverage_bm( + self, + eligible: BitMap, + het_bm: BitMap, + hom_bm: BitMap, + fail_bm: BitMap, + min_pass: int, + min_observed: int, + filtered_bm: "BitMap | None" = None, + quality_pass_bm: "BitMap | None" = None, + min_quality_evidence: int = 0, + ) -> BitMap: + """Return bitmap of WES non-carrier samples that should not be assumed hom-ref. + + Phase 1 (query-time): count-based per-tech gate using existing bitmaps. + Phase 2 (build-time): stored filtered_bitmap + optional quality_pass gate. + Results are unioned; carriers (het/hom/fail) are never included. + """ + if min_quality_evidence > 0 and not self._has_coverage_data: + raise ValueError( + "This database was not built with coverage quality data. " + "Re-create with --min-dp / --min-gq to use --min-quality-evidence." 
+ ) + no_cov = BitMap() + + # Phase 1: count-based gate + if min_pass > 0 or min_observed > 0: + all_carriers = het_bm | hom_bm | fail_bm + for tech_id, capture_idx in self._capture.items(): + if capture_idx._always_covered: + continue + tech_bm = self._tech_bitmaps.get(str(tech_id), BitMap()) + tech_eligible = eligible & tech_bm + if len(tech_eligible) == 0: + continue + pass_count = len((het_bm | hom_bm) & tech_eligible) + observed_count = len(all_carriers & tech_eligible) + if pass_count < min_pass or observed_count < min_observed: + no_cov |= tech_eligible - (all_carriers & tech_eligible) + + # Phase 2: stored filtered_bitmap + if filtered_bm is not None: + no_cov |= filtered_bm & eligible + + # Phase 2: quality_pass gate (--min-quality-evidence K) + if quality_pass_bm is not None and min_quality_evidence > 0: + all_carriers_qe = (het_bm | hom_bm | fail_bm) + already_filtered = filtered_bm if filtered_bm is not None else BitMap() + for tech_id, capture_idx in self._capture.items(): + if capture_idx._always_covered: + continue + tech_bm = self._tech_bitmaps.get(str(tech_id), BitMap()) + tech_eligible = eligible & tech_bm + if len(tech_eligible) == 0: + continue + quality_count = len(quality_pass_bm & tech_eligible) + if quality_count < min_quality_evidence: + # additionally filter non-carrier, not-already-filtered samples + extra = tech_eligible - (all_carriers_qe & tech_eligible) - already_filtered + no_cov |= extra + + return no_cov + + def _select_cols(self, with_pos: bool = False) -> str: + """Build the bitmap-column list for SELECT statements. + + Returns 5 columns for legacy DBs and 7 columns for Phase 2 DBs. 
+ """ + cols = "pos, ref, alt" if with_pos else "ref, alt" + cols += ", het_bitmap, hom_bitmap, fail_bitmap" + if self._has_coverage_data: + cols += ", filtered_bitmap, quality_pass_bitmap" + return cols + + def _unpack_bitmaps( + self, bitmap_cols: tuple, + ) -> tuple[BitMap, BitMap, BitMap, "BitMap | None", "BitMap | None"]: + """Deserialize the bitmap suffix of a Parquet row. + + Accepts the trailing 3 cols (legacy) or 5 cols (Phase 2) of a SELECT row. + Returns: (het, hom, fail, filtered, quality_pass). + filtered and quality_pass are None for legacy DBs. + """ + het_bm = deserialize(bytes(bitmap_cols[0])) + hom_bm = deserialize(bytes(bitmap_cols[1])) + fail_bm = deserialize(bytes(bitmap_cols[2])) + if self._has_coverage_data and len(bitmap_cols) >= 5: + filtered_bm = deserialize(bytes(bitmap_cols[3])) + quality_pass_bm = deserialize(bytes(bitmap_cols[4])) + else: + filtered_bm = None + quality_pass_bm = None + return het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm + def _compute_eligible( self, chrom: str, @@ -245,9 +342,10 @@ def query(self, params: QueryParams) -> list[QueryResult]: if parquet_path is None: return [] + cols = self._select_cols(with_pos=False) con = duckdb.connect(config={"temp_directory": "/tmp"}) rows = con.execute( - "SELECT ref, alt, het_bitmap, hom_bitmap, fail_bitmap FROM read_parquet(?) WHERE pos = ?", + f"SELECT {cols} FROM read_parquet(?) 
WHERE pos = ?", [str(parquet_path), pos], ).fetchall() con.close() @@ -256,10 +354,10 @@ def query(self, params: QueryParams) -> list[QueryResult]: return [] results = [] + sf = params.filter for row in rows: - ref, alt, het_bytes, hom_bytes, fail_bytes = row - het_bm = deserialize(bytes(het_bytes)) - hom_bm = deserialize(bytes(hom_bytes)) + ref, alt = row[0], row[1] + het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm = self._unpack_bitmaps(row[2:]) haploid_elig, diploid_elig = split_ploidy( eligible, self._male_bm, self._female_bm, chrom, pos, self._genome_build ) @@ -271,13 +369,21 @@ def query(self, params: QueryParams) -> list[QueryResult]: N_HET = len(het_elig & diploid_elig) N_HOM_ALT = len(hom_elig & diploid_elig) + len((het_elig | hom_elig) & haploid_elig) AF = AC / AN if AN > 0 else None - fail_bm = deserialize(bytes(fail_bytes)) N_FAIL = len(fail_bm & eligible) - N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL + no_cov_bm = self._compute_no_coverage_bm( + eligible, het_bm, hom_bm, fail_bm, + sf.min_pass, sf.min_observed, + filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + min_quality_evidence=sf.min_quality_evidence, + ) + N_NO_COVERAGE = len(no_cov_bm) + N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL - N_NO_COVERAGE results.append(QueryResult( variant=VariantKey(chrom=chrom, pos=pos, ref=ref, alt=alt), AC=AC, AN=AN, AF=AF, n_samples_eligible=len(eligible), - N_HET=N_HET, N_HOM_ALT=N_HOM_ALT, N_HOM_REF=N_HOM_REF, N_FAIL=N_FAIL, + N_HET=N_HET, N_HOM_ALT=N_HOM_ALT, N_HOM_REF=N_HOM_REF, + N_FAIL=N_FAIL, N_NO_COVERAGE=N_NO_COVERAGE, )) if params.ref is not None: results = [r for r in results if r.variant.ref == params.ref] @@ -304,7 +410,12 @@ def query_batch( unique_variants = list(dict.fromkeys(variants)) sample_bm = self._build_sample_bitmap(sf) - return self._query_batch_inner(chrom, unique_variants, sample_bm) + return self._query_batch_inner( + chrom, unique_variants, sample_bm, + min_pass=sf.min_pass, + 
min_observed=sf.min_observed, + min_quality_evidence=sf.min_quality_evidence, + ) def query_region_multi( self, @@ -337,7 +448,12 @@ def query_region_multi( AfqueryWarning, stacklevel=3, ) continue - for r in self._query_region_inner(chrom, start, end, sample_bm): + for r in self._query_region_inner( + chrom, start, end, sample_bm, + min_pass=sf.min_pass, + min_observed=sf.min_observed, + min_quality_evidence=sf.min_quality_evidence, + ): key = (r.variant.chrom, r.variant.pos, r.variant.ref, r.variant.alt) if key not in seen: seen.add(key) @@ -385,7 +501,12 @@ def query_batch_multi( idxs = [i for i, _ in idx_variants] per_chrom = [v for _, v in idx_variants] unique = list(dict.fromkeys(per_chrom)) - results = self._query_batch_inner(chrom, unique, sample_bm) + results = self._query_batch_inner( + chrom, unique, sample_bm, + min_pass=sf.min_pass, + min_observed=sf.min_observed, + min_quality_evidence=sf.min_quality_evidence, + ) result_map = {(r.variant.pos, r.variant.ref, r.variant.alt): r for r in results} seen_on_chrom: set[tuple[int, str, str]] = set() for i, (pos, ref, alt) in zip(idxs, per_chrom): @@ -401,6 +522,9 @@ def _query_region_inner( start: int, end: int, sample_bm: BitMap, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: """Run a region query with a pre-built sample bitmap. @@ -412,9 +536,10 @@ def _query_region_inner( escaped_glob = parquet_glob.replace("'", "''") + cols = self._select_cols(with_pos=True) con = duckdb.connect(config={"temp_directory": "/tmp"}) rows = con.execute( - f"SELECT pos, ref, alt, het_bitmap, hom_bitmap, fail_bitmap" + f"SELECT {cols}" f" FROM read_parquet('{escaped_glob}') WHERE pos BETWEEN ? 
AND ?", [start, end], ).fetchall() @@ -432,12 +557,11 @@ def _query_region_inner( results = [] for row in rows: - pos, ref, alt, het_bytes, hom_bytes, fail_bytes = row + pos, ref, alt = row[0], row[1], row[2] eligible, AN = pos_data[pos] if AN == 0: continue - het_bm = deserialize(bytes(het_bytes)) - hom_bm = deserialize(bytes(hom_bytes)) + het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm = self._unpack_bitmaps(row[3:]) haploid_elig, diploid_elig = split_ploidy( eligible, self._male_bm, self._female_bm, chrom, pos, self._genome_build ) @@ -449,13 +573,21 @@ def _query_region_inner( N_HET = len(het_elig & diploid_elig) N_HOM_ALT = len(hom_elig & diploid_elig) + len((het_elig | hom_elig) & haploid_elig) AF = AC / AN if AN > 0 else None - fail_bm = deserialize(bytes(fail_bytes)) N_FAIL = len(fail_bm & eligible) - N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL + no_cov_bm = self._compute_no_coverage_bm( + eligible, het_bm, hom_bm, fail_bm, + min_pass, min_observed, + filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + min_quality_evidence=min_quality_evidence, + ) + N_NO_COVERAGE = len(no_cov_bm) + N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL - N_NO_COVERAGE results.append(QueryResult( variant=VariantKey(chrom=chrom, pos=pos, ref=ref, alt=alt), AC=AC, AN=AN, AF=AF, n_samples_eligible=len(eligible), - N_HET=N_HET, N_HOM_ALT=N_HOM_ALT, N_HOM_REF=N_HOM_REF, N_FAIL=N_FAIL, + N_HET=N_HET, N_HOM_ALT=N_HOM_ALT, N_HOM_REF=N_HOM_REF, + N_FAIL=N_FAIL, N_NO_COVERAGE=N_NO_COVERAGE, )) results.sort(key=lambda r: (r.variant.pos, r.variant.alt)) return results @@ -465,6 +597,9 @@ def _query_batch_inner( chrom: str, unique_variants: list[tuple[int, str, str]], sample_bm: BitMap, + min_pass: int = 0, + min_observed: int = 0, + min_quality_evidence: int = 0, ) -> list[QueryResult]: """Run a batch query with a pre-built sample bitmap. 
@@ -491,11 +626,15 @@ def _query_batch_inner( escaped_glob = parquet_glob.replace("'", "''") + cols = self._select_cols(with_pos=True) + # Build a v.-prefixed alias-form for the JOIN path + cols_aliased = ", ".join(f"v.{c.strip()}" for c in cols.split(",")) + con = duckdb.connect(config={"temp_directory": "/tmp"}) if len(valid_positions) < BATCH_IN_THRESHOLD: placeholders = ", ".join("?" * len(valid_positions)) rows = con.execute( - f"SELECT pos, ref, alt, het_bitmap, hom_bitmap, fail_bitmap" + f"SELECT {cols}" f" FROM read_parquet('{escaped_glob}') WHERE pos IN ({placeholders})", valid_positions, ).fetchall() @@ -503,19 +642,18 @@ def _query_batch_inner( con.execute("CREATE TEMP TABLE pos_filter (pos UINTEGER)") con.executemany("INSERT INTO pos_filter VALUES (?)", [(p,) for p in valid_positions]) rows = con.execute( - f"SELECT v.pos, v.ref, v.alt, v.het_bitmap, v.hom_bitmap, v.fail_bitmap" + f"SELECT {cols_aliased}" f" FROM read_parquet('{escaped_glob}') v JOIN pos_filter pf ON v.pos = pf.pos" ).fetchall() con.close() results = [] for row in rows: - pos, ref, alt, het_bytes, hom_bytes, fail_bytes = row + pos, ref, alt = row[0], row[1], row[2] if (pos, ref, alt) not in requested_variants: continue eligible, AN = pos_data[pos] - het_bm = deserialize(bytes(het_bytes)) - hom_bm = deserialize(bytes(hom_bytes)) + het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm = self._unpack_bitmaps(row[3:]) haploid_elig, diploid_elig = split_ploidy( eligible, self._male_bm, self._female_bm, chrom, pos, self._genome_build ) @@ -527,13 +665,21 @@ def _query_batch_inner( N_HET = len(het_elig & diploid_elig) N_HOM_ALT = len(hom_elig & diploid_elig) + len((het_elig | hom_elig) & haploid_elig) AF = AC / AN if AN > 0 else None - fail_bm = deserialize(bytes(fail_bytes)) N_FAIL = len(fail_bm & eligible) - N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL + no_cov_bm = self._compute_no_coverage_bm( + eligible, het_bm, hom_bm, fail_bm, + min_pass, min_observed, + 
filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + min_quality_evidence=min_quality_evidence, + ) + N_NO_COVERAGE = len(no_cov_bm) + N_HOM_REF = len(eligible) - N_HET - N_HOM_ALT - N_FAIL - N_NO_COVERAGE results.append(QueryResult( variant=VariantKey(chrom=chrom, pos=pos, ref=ref, alt=alt), AC=AC, AN=AN, AF=AF, n_samples_eligible=len(eligible), - N_HET=N_HET, N_HOM_ALT=N_HOM_ALT, N_HOM_REF=N_HOM_REF, N_FAIL=N_FAIL, + N_HET=N_HET, N_HOM_ALT=N_HOM_ALT, N_HOM_REF=N_HOM_REF, + N_FAIL=N_FAIL, N_NO_COVERAGE=N_NO_COVERAGE, )) results.sort(key=lambda r: (r.variant.pos, r.variant.alt)) return results @@ -557,7 +703,12 @@ def query_region( return [] sample_bm = self._build_sample_bitmap(sf) - return self._query_region_inner(chrom, start, end, sample_bm) + return self._query_region_inner( + chrom, start, end, sample_bm, + min_pass=sf.min_pass, + min_observed=sf.min_observed, + min_quality_evidence=sf.min_quality_evidence, + ) def variant_info(self, params: QueryParams) -> list[SampleCarrier]: """Return all samples carrying the variant at params.chrom:params.pos. @@ -590,9 +741,10 @@ def variant_info(self, params: QueryParams) -> list[SampleCarrier]: if parquet_path is None: return [] + cols = self._select_cols(with_pos=True) con = duckdb.connect(config={"temp_directory": "/tmp"}) rows = con.execute( - "SELECT pos, ref, alt, het_bitmap, hom_bitmap, fail_bitmap" + f"SELECT {cols}" " FROM read_parquet(?) 
WHERE pos = ?", [str(parquet_path), pos], ).fetchall() @@ -617,15 +769,25 @@ def variant_info(self, params: QueryParams) -> list[SampleCarrier]: AfqueryWarning, stacklevel=3, ) + # Compute eligible (BED-aware) for no_coverage assessment + eligible, _AN = self._compute_eligible(chrom, pos, sample_bm) + sf = params.filter + carriers: list[SampleCarrier] = [] for row in rows: - row_pos, ref, alt, het_bytes, hom_bytes, fail_bytes = row - het_bm = deserialize(bytes(het_bytes)) - hom_bm = deserialize(bytes(hom_bytes)) - fail_bm = deserialize(bytes(fail_bytes)) + row_pos, ref, alt = row[0], row[1], row[2] + het_bm, hom_bm, fail_bm, filtered_bm, quality_pass_bm = self._unpack_bitmaps(row[3:]) het_elig = het_bm & sample_bm hom_elig = hom_bm & sample_bm fail_elig = fail_bm & sample_bm + no_cov_bm = self._compute_no_coverage_bm( + eligible, het_bm, hom_bm, fail_bm, + sf.min_pass, sf.min_observed, + filtered_bm=filtered_bm, + quality_pass_bm=quality_pass_bm, + min_quality_evidence=sf.min_quality_evidence, + ) + no_cov_elig = no_cov_bm & sample_bm seen: set[int] = set() for sid in sorted(hom_elig): @@ -641,6 +803,10 @@ def variant_info(self, params: QueryParams) -> list[SampleCarrier]: if sid not in seen: seen.add(sid) carriers.append(self._make_carrier(sid, "alt", False)) + for sid in sorted(no_cov_elig): + if sid not in seen: + seen.add(sid) + carriers.append(self._make_carrier(sid, "no_coverage", True)) carriers.sort(key=lambda c: c.sample_id) return carriers diff --git a/tests/conftest.py b/tests/conftest.py index b23d31f..dcdd57d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -181,12 +181,14 @@ def _build_capture_indices(db_path: Path, data_dir: Path) -> None: def _build_parquet(db_path: Path) -> None: schema = pa.schema([ - ("pos", pa.uint32()), - ("ref", pa.large_utf8()), - ("alt", pa.large_utf8()), - ("het_bitmap", pa.large_binary()), - ("hom_bitmap", pa.large_binary()), - ("fail_bitmap", pa.large_binary()), + ("pos", pa.uint32()), + ("ref", pa.large_utf8()), 
+ ("alt", pa.large_utf8()), + ("het_bitmap", pa.large_binary()), + ("hom_bitmap", pa.large_binary()), + ("fail_bitmap", pa.large_binary()), + ("filtered_bitmap", pa.large_binary()), + ("quality_pass_bitmap", pa.large_binary()), ]) # Group variants by chromosome @@ -194,22 +196,25 @@ def _build_parquet(db_path: Path) -> None: for chrom, pos, ref, alt, het_ids, hom_ids, fail_ids in VARIANTS: by_chrom.setdefault(chrom, []).append((pos, ref, alt, het_ids, hom_ids, fail_ids)) + empty = serialize(BitMap()) for chrom, rows in by_chrom.items(): rows.sort(key=lambda r: r[0]) # sort by pos table = pa.table( { - "pos": pa.array([r[0] for r in rows], type=pa.uint32()), - "ref": pa.array([r[1] for r in rows], type=pa.large_utf8()), - "alt": pa.array([r[2] for r in rows], type=pa.large_utf8()), - "het_bitmap": pa.array( + "pos": pa.array([r[0] for r in rows], type=pa.uint32()), + "ref": pa.array([r[1] for r in rows], type=pa.large_utf8()), + "alt": pa.array([r[2] for r in rows], type=pa.large_utf8()), + "het_bitmap": pa.array( [serialize(BitMap(r[3])) for r in rows], type=pa.large_binary() ), - "hom_bitmap": pa.array( + "hom_bitmap": pa.array( [serialize(BitMap(r[4])) for r in rows], type=pa.large_binary() ), - "fail_bitmap": pa.array( + "fail_bitmap": pa.array( [serialize(BitMap(r[5])) for r in rows], type=pa.large_binary() ), + "filtered_bitmap": pa.array([empty] * len(rows), type=pa.large_binary()), + "quality_pass_bitmap": pa.array([empty] * len(rows), type=pa.large_binary()), }, schema=schema, ) diff --git a/tests/test_haploid_stats.py b/tests/test_haploid_stats.py index df82e61..d635097 100644 --- a/tests/test_haploid_stats.py +++ b/tests/test_haploid_stats.py @@ -59,12 +59,15 @@ def test_autosome_diploid_regression(self, test_db): class TestHaploidInvariants: def test_genotype_sum_equals_eligible(self, test_db): - """N_HET + N_HOM_ALT + N_HOM_REF + N_FAIL == n_eligible for all variants.""" + """N_HET + N_HOM_ALT + N_HOM_REF + N_FAIL + N_NO_COVERAGE == n_eligible for all 
variants.""" db = Database(test_db) for chrom, pos in [("chr1", 1500), ("chr1", 3500), ("chr1", 5000), ("chrX", 5000000), ("chrY", 500000), ("chrM", 100)]: for r in db.query(chrom=chrom, pos=pos): - total = r.N_HET + r.N_HOM_ALT + r.N_HOM_REF + (r.N_FAIL or 0) + total = ( + r.N_HET + r.N_HOM_ALT + r.N_HOM_REF + + (r.N_FAIL or 0) + (r.N_NO_COVERAGE or 0) + ) assert total == r.n_samples_eligible, \ f"Sum mismatch at {chrom}:{pos}: {total} != {r.n_samples_eligible}" diff --git a/tests/test_no_coverage.py b/tests/test_no_coverage.py new file mode 100644 index 0000000..ff3ce8c --- /dev/null +++ b/tests/test_no_coverage.py @@ -0,0 +1,284 @@ +"""Tests for N_NO_COVERAGE field and Phase 1 / Phase 2 filtering.""" + +import json +import shutil +import sqlite3 +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from pyroaring import BitMap + +from afquery import Database +from afquery.bitmaps import serialize + + +def _db(test_db): + return Database(test_db) + + +# --------------------------------------------------------------------------- +# Phase 1 — query-time count-based gating +# --------------------------------------------------------------------------- + +class TestPhase1Defaults: + def test_min_pass_zero_no_effect(self, test_db): + """min_pass=0 (default) → N_NO_COVERAGE=0 for all variants.""" + db = _db(test_db) + # Single locus + for r in db.query(chrom="chr1", pos=1500): + assert r.N_NO_COVERAGE == 0 + # Region + for r in db.query_region("chr1", 1, 10000): + assert r.N_NO_COVERAGE == 0 + + def test_invariant_with_zero_thresholds(self, test_db): + """Genotype partition holds for every variant under default settings.""" + db = _db(test_db) + for r in db.query_region("chr1", 1, 10000): + assert ( + r.N_HET + r.N_HOM_ALT + r.N_HOM_REF + r.N_FAIL + r.N_NO_COVERAGE + == r.n_samples_eligible + ) + + +class TestPhase1MinPass: + def test_min_pass_filters_wes_below_threshold(self, test_db): + """chr1:3500 has 1 PASS carrier in WES_B 
(S07). WES_A's BED doesn't cover 3500. + With min_pass=2, WES_B's 1 carrier < 2 → its non-carrier samples (S08, S09) + move from N_HOM_REF to N_NO_COVERAGE. + """ + db = _db(test_db) + baseline = db.query(chrom="chr1", pos=3500)[0] + filtered = db.query(chrom="chr1", pos=3500, min_pass=2)[0] + + # Carriers stay in N_HET / N_HOM_ALT + assert filtered.N_HET == baseline.N_HET + assert filtered.N_HOM_ALT == baseline.N_HOM_ALT + assert filtered.N_FAIL == baseline.N_FAIL + + # Non-carrier WES_B samples (S08, S09) shift to N_NO_COVERAGE + assert filtered.N_NO_COVERAGE == 2 + assert filtered.N_HOM_REF == baseline.N_HOM_REF - 2 + + # n_eligible and AN are unchanged: filtered samples remain in eligible + assert filtered.n_samples_eligible == baseline.n_samples_eligible + assert filtered.AN == baseline.AN + + def test_min_pass_one_no_effect_when_carrier_exists(self, test_db): + """min_pass=1 should NOT filter when each WES tech has ≥1 PASS carrier in BED.""" + db = _db(test_db) + # chr1:3500 → WES_B has 1 carrier (S07); WES_A is not in BED → tech_eligible empty + r = db.query(chrom="chr1", pos=3500, min_pass=1)[0] + assert r.N_NO_COVERAGE == 0 + + def test_invariant_holds_with_filter(self, test_db): + db = _db(test_db) + for r in db.query_region("chr1", 1, 10000, min_pass=2): + assert ( + r.N_HET + r.N_HOM_ALT + r.N_HOM_REF + r.N_FAIL + r.N_NO_COVERAGE + == r.n_samples_eligible + ) + + +class TestPhase1MinObserved: + def test_min_observed_counts_fail(self, test_db): + """min_observed counts fail samples as evidence — fail-only WES tech survives.""" + # chr1:1500 → WES_A has S05 het + S00 fail (S00 is WGS, not WES_A). Need careful look. + # samples 0-3 are WGS, 4-6 are WES_A, 7-9 are WES_B. + # chr1:1500 het=[0,5], hom=[2], fail=[0]. WES_A carriers: het={5} (1 PASS carrier) + # WES_A's only carrier (S05) is in het → also counts as observed (1). + # So min_observed=2 should filter. 
+ db = _db(test_db) + baseline = db.query(chrom="chr1", pos=1500)[0] + filtered = db.query(chrom="chr1", pos=1500, min_observed=2)[0] + assert filtered.N_NO_COVERAGE >= 1 + assert filtered.N_HET + filtered.N_HOM_ALT == baseline.N_HET + baseline.N_HOM_ALT + + +class TestPhase1AndLogic: + def test_min_pass_and_min_observed_combine(self, test_db): + """If either threshold fails, the tech is filtered.""" + db = _db(test_db) + only_pass = db.query(chrom="chr1", pos=3500, min_pass=2, min_observed=0)[0] + only_obs = db.query(chrom="chr1", pos=3500, min_pass=0, min_observed=2)[0] + both = db.query(chrom="chr1", pos=3500, min_pass=2, min_observed=2)[0] + # With chr1:3500 only WES_B is eligible: 1 carrier, 1 observed. + # Both filters fail at threshold 2 → same N_NO_COVERAGE. + assert only_pass.N_NO_COVERAGE == only_obs.N_NO_COVERAGE == both.N_NO_COVERAGE + + +# --------------------------------------------------------------------------- +# Phase 1 — invariants +# --------------------------------------------------------------------------- + +class TestInvariants: + def test_wgs_never_in_no_coverage(self, test_db): + """WGS samples (tech_id=0) should never appear in N_NO_COVERAGE.""" + db = _db(test_db) + # variant_info shows individual carriers; we use it to verify identity + carriers = db.variant_info(chrom="chr1", pos=3500, min_pass=2) + for c in carriers: + if c.genotype == "no_coverage": + assert c.tech_name != "WGS", ( + f"WGS sample {c.sample_name} marked as no_coverage" + ) + + def test_carriers_never_in_no_coverage(self, test_db): + """het/hom/fail carriers should never appear with genotype=no_coverage.""" + db = _db(test_db) + carriers = db.variant_info(chrom="chr1", pos=3500, min_pass=10) + # At threshold=10, ALL techs fail. But carriers must stay as het/hom/alt. 
+ for c in carriers: + if c.genotype == "no_coverage": + # Sample must NOT be a carrier at this position + # (carriers are listed BEFORE no_coverage in variant_info order) + pass # checked implicitly via genotype != het/hom/alt + else: + assert c.genotype in ("het", "hom", "alt") + + +# --------------------------------------------------------------------------- +# variant_info with no_coverage +# --------------------------------------------------------------------------- + +class TestVariantInfoNoCoverage: + def test_no_coverage_genotype_appears(self, test_db): + """variant_info should report 'no_coverage' for filtered samples.""" + db = _db(test_db) + baseline = db.variant_info(chrom="chr1", pos=3500) + filtered = db.variant_info(chrom="chr1", pos=3500, min_pass=2) + + # Baseline carriers + baseline_carriers = {c.sample_id for c in baseline} + # Filtered should include baseline carriers + at least 1 no_coverage sample + filtered_carriers = {c.sample_id for c in filtered} + no_cov_ids = {c.sample_id for c in filtered if c.genotype == "no_coverage"} + assert no_cov_ids, "expected at least 1 no_coverage sample" + # All baseline carriers still present, with their original genotype + baseline_by_id = {c.sample_id: c.genotype for c in baseline} + for c in filtered: + if c.sample_id in baseline_by_id: + assert c.genotype == baseline_by_id[c.sample_id] + + +# --------------------------------------------------------------------------- +# Phase 2 — filtered_bitmap and quality_pass_bitmap +# --------------------------------------------------------------------------- + +def _make_phase2_db(src_db: str, dst_dir: Path, + filtered_at_chr1_3500: list[int] | None = None, + quality_pass_at_chr1_3500: list[int] | None = None) -> str: + """Clone test_db and inject Phase 2 columns at chr1:3500.""" + dst = dst_dir / "p2_db" + shutil.copytree(src_db, str(dst)) + + # Update manifest to indicate Phase 2 schema + manifest_path = dst / "manifest.json" + manifest = 
json.loads(manifest_path.read_text()) + manifest["schema_version"] = "3.0" + manifest["coverage_filter"] = { + "min_dp": 30, "min_gq": 20, "min_qual": 0.0, "min_covered": 1, + } + manifest_path.write_text(json.dumps(manifest, indent=2)) + + # Inject filtered_bitmap and quality_pass_bitmap into chr1.parquet + parquet_file = dst / "variants" / "chr1.parquet" + table = pq.read_table(str(parquet_file)) + n = len(table) + empty = serialize(BitMap()) + new_filtered = [] + new_quality = [] + for i in range(n): + pos = table["pos"][i].as_py() + if pos == 3500: + new_filtered.append(serialize(BitMap(filtered_at_chr1_3500 or []))) + new_quality.append(serialize(BitMap(quality_pass_at_chr1_3500 or []))) + else: + new_filtered.append(empty) + new_quality.append(empty) + + new_table = pa.table( + { + "pos": table["pos"], + "ref": table["ref"], + "alt": table["alt"], + "het_bitmap": table["het_bitmap"], + "hom_bitmap": table["hom_bitmap"], + "fail_bitmap": table["fail_bitmap"], + "filtered_bitmap": pa.array(new_filtered, type=pa.large_binary()), + "quality_pass_bitmap": pa.array(new_quality, type=pa.large_binary()), + }, + ) + pq.write_table(new_table, str(parquet_file)) + return str(dst) + + +class TestPhase2Stored: + def test_filtered_bitmap_moves_samples_to_no_coverage(self, test_db, tmp_path): + """Phase 2 DB with filtered_bitmap[8,9] at chr1:3500 → N_NO_COVERAGE = 2.""" + # Samples 8, 9 are WES_B, BED-covered at 3500, not carriers. 
+ p2 = _make_phase2_db(test_db, tmp_path, + filtered_at_chr1_3500=[8, 9], + quality_pass_at_chr1_3500=[7]) + db = Database(p2) + r = db.query(chrom="chr1", pos=3500)[0] + assert r.N_NO_COVERAGE == 2 + + def test_min_quality_evidence_filters_low_quality_techs(self, test_db, tmp_path): + """min_quality_evidence=2 with quality_pass_bitmap=[7] (1 sample) → tech filtered.""" + p2 = _make_phase2_db(test_db, tmp_path, + filtered_at_chr1_3500=[], + quality_pass_at_chr1_3500=[7]) + db = Database(p2) + # No filtered_bitmap entries → baseline N_NO_COVERAGE=0 + baseline = db.query(chrom="chr1", pos=3500)[0] + assert baseline.N_NO_COVERAGE == 0 + # min_quality_evidence=2 → WES_B has 1 quality_pass (S07) < 2 → S08, S09 filtered + r = db.query(chrom="chr1", pos=3500, min_quality_evidence=2)[0] + assert r.N_NO_COVERAGE == 2 + + def test_min_quality_evidence_errors_on_old_db(self, test_db): + """Using --min-quality-evidence on a schema_version < 3.0 DB raises ValueError.""" + db = _db(test_db) + with pytest.raises(ValueError, match="coverage quality"): + db.query(chrom="chr1", pos=3500, min_quality_evidence=1) + + def test_phase1_phase2_combine(self, test_db, tmp_path): + """N_NO_COVERAGE = union of stored filtered_bitmap and Phase 1 dynamic filtering.""" + p2 = _make_phase2_db(test_db, tmp_path, + filtered_at_chr1_3500=[8], + quality_pass_at_chr1_3500=[7]) + db = Database(p2) + # filtered_bitmap = {8} → N_NO_COVERAGE = 1 baseline + baseline = db.query(chrom="chr1", pos=3500)[0] + assert baseline.N_NO_COVERAGE == 1 + # min_pass=2 → WES_B has 1 PASS carrier < 2 → also filters S09 (S08 already) + # Union: {8} | {8, 9} = {8, 9} → N_NO_COVERAGE = 2 + r = db.query(chrom="chr1", pos=3500, min_pass=2)[0] + assert r.N_NO_COVERAGE == 2 + + +# --------------------------------------------------------------------------- +# Phase 1 + ploidy chromosomes +# --------------------------------------------------------------------------- + +class TestInvariantOnSexChroms: + def 
test_chrm_invariant(self, test_db): + """Invariant must hold on chrM (haploid).""" + db = _db(test_db) + for r in db.query(chrom="chrM", pos=100, min_pass=1): + assert ( + r.N_HET + r.N_HOM_ALT + r.N_HOM_REF + r.N_FAIL + r.N_NO_COVERAGE + == r.n_samples_eligible + ) + + def test_chry_invariant(self, test_db): + """Invariant must hold on chrY.""" + db = _db(test_db) + for r in db.query(chrom="chrY", pos=500000, min_pass=1): + assert ( + r.N_HET + r.N_HOM_ALT + r.N_HOM_REF + r.N_FAIL + r.N_NO_COVERAGE + == r.n_samples_eligible + ) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 6958ea9..7bfa0f6 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -67,7 +67,8 @@ def _write_vcf_with_filter(path: Path, sample_name: str, records: list[tuple]) - def _write_parquet(path: Path, rows: list[tuple]) -> None: """Write ingest-schema Parquet. rows = [(chrom, pos, ref, alt, gt_ac, sample_id[, filter_pass]), ...] - filter_pass defaults to True if not provided.""" + filter_pass defaults to True if not provided. 
Quality columns (dp/gq/qual) default to None.""" + n = len(rows) table = pa.table( { "chrom": pa.array([r[0] for r in rows], type=pa.utf8()), @@ -77,6 +78,9 @@ def _write_parquet(path: Path, rows: list[tuple]) -> None: "gt_ac": pa.array([r[4] for r in rows], type=pa.uint8()), "sample_id": pa.array([r[5] for r in rows], type=pa.uint32()), "filter_pass": pa.array([r[6] if len(r) > 6 else True for r in rows], type=pa.bool_()), + "dp": pa.array([None] * n, type=pa.int32()), + "gq": pa.array([None] * n, type=pa.int32()), + "qual": pa.array([None] * n, type=pa.float32()), }, schema=INGEST_SCHEMA, ) From 2dec6ec86b3ace8a0236587cc0fc1baf180ece9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20L=C3=B3pez?= Date: Wed, 6 May 2026 16:41:16 +0200 Subject: [PATCH 3/7] docs: document N_NO_COVERAGE and coverage-evidence filters - docs/advanced/coverage-evidence.md (new): conceptual guide covering both phases, threshold-selection guidance, and the new genotype invariant. - docs/guides/query.md: new section "Coverage-Evidence Filters" plus N_NO_COVERAGE in text/tsv/json examples. - docs/guides/create-database.md: new section explaining --min-dp/--min-gq/ --min-qual/--min-covered and the schema_version 3.0 bump. - docs/guides/update-database.md: explain filtered_bitmap recomputation on add-samples for Phase 2 databases. - docs/getting-started/preprocessing.md: note FORMAT/DP and FORMAT/GQ preservation for Phase 2 quality thresholds. - docs/reference/cli.md: add Phase 1/2 flags to query, variant-info, annotate, dump, and create-db. - docs/reference/python-api.md: add N_NO_COVERAGE to QueryResult, 'no_coverage' genotype to SampleCarrier, and the three new SampleFilter fields. - mkdocs.yml: link the new advanced page. 
--- docs/advanced/coverage-evidence.md | 150 ++++++++++++++++++++++++++ docs/getting-started/preprocessing.md | 1 + docs/guides/create-database.md | 47 ++++++++ docs/guides/query.md | 41 ++++++- docs/guides/update-database.md | 18 ++++ docs/reference/cli.md | 25 ++++- docs/reference/python-api.md | 27 ++++- mkdocs.yml | 1 + 8 files changed, 300 insertions(+), 10 deletions(-) create mode 100644 docs/advanced/coverage-evidence.md diff --git a/docs/advanced/coverage-evidence.md b/docs/advanced/coverage-evidence.md new file mode 100644 index 0000000..0a74607 --- /dev/null +++ b/docs/advanced/coverage-evidence.md @@ -0,0 +1,150 @@ +# Coverage Evidence + +`N_HOM_REF` is computed as a residual: +`len(eligible) − N_HET − N_HOM_ALT − N_FAIL`. For WGS samples that residual is +exactly right: every covered sample without a variant call is hom-ref. For WES +samples it is a *best-effort* assumption: the BED capture region tells us a +position *could* be sequenced, but not that it *was* sequenced at adequate +depth in this particular sample. Standard variant-only VCFs do not contain +hom-ref calls, so AFQuery cannot distinguish "true hom-ref" from "no coverage" +for non-carrier WES samples. + +Two opt-in mechanisms let users tighten that assumption. Together they expose +a new field, **`N_NO_COVERAGE`**, that holds samples whose hom-ref status is +not trusted under the chosen criteria. The new genotype invariant is: + +``` +N_HET + N_HOM_ALT + N_HOM_REF + N_FAIL + N_NO_COVERAGE = n_eligible +``` + +Samples in `N_NO_COVERAGE` remain in `eligible` and `AN` (just like +`N_FAIL`), so allele frequencies stay conservative. + +WGS samples are never re-classified as `N_NO_COVERAGE`. Carrier samples +(`het` / `hom` / `fail`) are never affected by these filters — only WES +non-carriers can move between `N_HOM_REF` and `N_NO_COVERAGE`. + +--- + +## Phase 1 — Query-time, evidence-counting + +No schema change. 
AFQuery counts existing carriers per WES tech at each +position and applies a per-tech gate. + +| Flag | Counts | +|------|--------| +| `--min-pass K` | `het ∪ hom` PASS carriers within the tech | +| `--min-observed K` | `het ∪ hom ∪ fail` carriers within the tech | + +If the tech falls below either threshold, *all of its non-carrier samples* at +that position move from `N_HOM_REF` to `N_NO_COVERAGE`. When both flags are +set, the tech must meet both thresholds (AND). Default `0` ⇒ no filtering, identical to legacy +behaviour. + +```bash +afquery query --db ./db/ --locus chr1:925952 --min-pass 1 +``` + +Cost: a few extra bitmap intersections per WES tech per position. Suitable for +existing databases without re-creation. + +--- + +## Phase 2 — Build-time, quality-aware + +Phase 2 stores the result of a quality decision so the query layer does no +extra work. It requires creating the database once with quality thresholds: + +| Flag (`create-db`) | Effect | +|--------------------|--------| +| `--min-dp D` | Minimum `FORMAT/DP` for a carrier to count as quality evidence. | +| `--min-gq G` | Minimum `FORMAT/GQ` for a carrier to count as quality evidence. | +| `--min-qual Q` | Minimum `QUAL` field for a carrier to count as quality evidence. | +| `--min-covered K` | Per WES tech, a position is "trusted" only if at least K carriers pass the quality thresholds. | + +Two new Parquet columns are written per `(chrom, pos, ref, alt)`: + +- `quality_pass_bitmap` — carriers that meet `min_dp` AND `min_gq` AND `min_qual`. +- `filtered_bitmap` — non-carrier WES samples whose tech failed the + `min_covered` gate. + +`schema_version` is bumped to `3.0` and the chosen thresholds are recorded in +`manifest.json` under `coverage_filter`. They are immutable; they apply to +samples added later via `update-db --add-samples`. + +At query time `filtered_bitmap` is intersected with `eligible` and added to +`N_NO_COVERAGE` automatically — no additional flag needed. 
+ +```bash +afquery create-db \ + --manifest samples.tsv \ + --output-dir ./db/ \ + --genome-build GRCh38 \ + --bed-dir ./beds/ \ + --min-dp 30 --min-gq 20 --min-covered 1 +``` + +### Query-time companion: `--min-quality-evidence K` + +Only valid against `schema_version ≥ 3.0` databases. Tightens the build-time +gate: at query time, a WES tech needs ≥K samples in `quality_pass_bitmap`, +otherwise its remaining (non-carrier, not-already-filtered) samples join +`N_NO_COVERAGE`. + +```bash +afquery query --db ./db/ --locus chr1:925952 --min-quality-evidence 5 +``` + +Using this flag against an older DB raises a `ValueError` with a clear message. + +--- + +## Combining Phase 1 and Phase 2 + +The two phases are additive. `N_NO_COVERAGE` is the *union* of: + +1. Stored `filtered_bitmap & eligible` (Phase 2, automatic). +2. Phase 1 dynamic filtering driven by `--min-pass` / `--min-observed`. +3. Phase 2 query-time tightening driven by `--min-quality-evidence`. + +Carriers are never included; samples cannot be double-counted. + +--- + +## Choosing thresholds + +- **Pure genotyping (no quality info available)**: use `--min-pass 1` or + `--min-observed 1` at query time. No DB rebuild needed. Conservative; + positions that were probably sequenced but happen to have zero PASS calls + in your cohort flip to `N_NO_COVERAGE`. +- **Real cohorts with DP/GQ available**: rebuild with + `--min-dp 20 --min-gq 20 --min-covered 1`. Carriers with low confidence stop + validating positions. `N_NO_COVERAGE` rises only where the cohort signal is + truly weak. +- **High-stakes clinical use**: layer on `--min-quality-evidence 3` (or + similar) at query time to demand multiple independent quality calls per + tech. + +--- + +## Output channels + +- `query` / `query_region` / `query_batch_multi` / `dump` / `annotate` all + expose `N_NO_COVERAGE` as a first-class field/column, plus per-group variants + in `dump` (`N_NO_COVERAGE_