diff --git a/Makefile b/Makefile index d203506..dd27143 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,10 @@ ## datasets (Argelaguet, CRC, Ecker) on a server with enough RAM. ## ## Not intended for laptops: the runs allocate hundreds of GB of virtual -## memory across many parallel amet jobs. The recipes set ulimit -v -## 200 GB as a soft safeguard and let snakemake fan out across CORES cores. +## memory across many parallel amet jobs. The recipes set ulimit -v as a +## soft safeguard (100 GB per process, see ULIMIT_KB) and let snakemake fan +## out across CORES cores. The ulimit is per-process: every job shell +## inherits the same cap, it is not a shared budget across rules. ## ## Usage: ## make argelaguet # proto by default (results/argelaguet_proto/) @@ -17,15 +19,16 @@ ## Variables (override on the command line): ## MODE proto | full which dataset config file to load ## (default: proto) -## CORES snakemake --cores value (default 16) -## ULIMIT_KB virtual memory cap in KB (default 209715200, i.e. 200 GB) +## CORES snakemake --cores value (default 40) +## ULIMIT_KB per-process virtual memory cap in KB, inherited by every +## job shell (default 104857600, i.e. 100 GB) ## CONDA_ENV name of the conda env that holds snakemake (default snakemake) ## CONDA_INIT path to the conda activation script ## (default ~/miniconda3/bin/activate) MODE ?= proto -CORES ?= 16 -ULIMIT_KB ?= 209715200 +CORES ?= 40 +ULIMIT_KB ?= 104857600 CONDA_ENV ?= snakemake CONDA_INIT ?= $(HOME)/miniconda3/bin/activate diff --git a/README.md b/README.md index 0869fe0..8ba88dc 100644 --- a/README.md +++ b/README.md @@ -95,8 +95,8 @@ Tunable variables: | Variable | Default | Description | |---|---|---| | `MODE` | `proto` | `proto` or `full`; picks `workflow/config/datasets_$(MODE).yaml` | -| `CORES` | 16 | Snakemake `--cores` value | -| `ULIMIT_KB` | 209715200 (200 GB) | Virtual-memory cap; inherited by every amet job | +| `CORES` | 40 | Snakemake `--cores` value | +| `ULIMIT_KB` | 104857600 (100 GB) | Per-process virtual-memory cap, inherited by every amet job shell. Not a shared budget: each of the `CORES` concurrent jobs is capped independently. | | `CONDA_ENV` | `snakemake` | Conda env that holds snakemake | | `CONDA_INIT` | `~/miniconda3/bin/activate` | Conda activation script | @@ -125,7 +125,7 @@ The three dataset rule files expand over a fixed list of annotations defined at ### Server deployment -The Makefile is designed for a workstation with enough RAM for the whole-genome amet runs (hundreds of GB virtual memory under parallel jobs; the recipes apply `ulimit -v` as a soft cap). It is not designed for laptops. +The Makefile is designed for a workstation with enough RAM for the whole-genome amet runs. The recipes apply `ulimit -v` as a soft per-process cap (`ULIMIT_KB`, default 100 GB); it is inherited by every job shell and bounds each amet process independently, so peak machine memory is roughly `CORES` times one job's actual usage, not the cap. Each amet job scores all of a dataset's annotation BEDs in one pass, so its footprint grows with the total number of features across those BEDs. It is not designed for laptops. If you are in the Mark Robinson lab at UZH, `workflow/scripts/internal/setup_barbara_links.sh` populates `results//{cells,raw,features,mm10,hg19}` as symlinks to the pre-staged data tree on `barbara`'s filesystem. See `workflow/scripts/internal/README.md`. Outside that lab, run the per-dataset download rules in each `.smk` instead. @@ -136,7 +136,7 @@ If you are in the Mark Robinson lab at UZH, `workflow/scripts/internal/setup_bar | `--genome` | (required, mutually exclusive with `--cpg-reference`) | FASTA of the reference genome. amet derives all CpG positions on first use and caches them to `.cpg`. | | `--cpg-reference` | (required, mutually exclusive with `--genome`) | Tab-separated `chrom\tpos` of every CpG, 0-based. Defines adjacency: an uncovered reference CpG breaks 2-mer pairing across it. | | `--cells` | (required unless `--build-cpg-only`) | Manifest TSV (see below). | -| `--features` | (required unless `--build-cpg-only`) | BED of regions to score. Features must not overlap. | +| `--features` | (required unless `--build-cpg-only`) | BED of regions to score. Features within a BED must not overlap. Pass `--features` multiple times to score against several BEDs in one cell-read pass; see "Multiple feature sets" below. | | `--output-prefix` | (required unless `--build-cpg-only`) | Prefix for the output files. | | `--build-cpg-only` | off | Only materialise `.cpg` and exit. Requires `--genome`. | | `--group-column` | `group` | Manifest column to use as the group label. | @@ -144,7 +144,7 @@ If you are in the Mark Robinson lab at UZH, `workflow/scripts/internal/setup_bar | `--min-reads-per-cpg` | `1` | A CpG is observed only if covered by at least this many reads. Bulk WGBS users typically set 5-10. | | `--min-cpgs-per-feature` | `5` | A `(cell, feature)` is scored only if at least this many CpGs are covered. Below the threshold, scores are reported as `NA`. | | `--min-cells-per-group` | `10` | A `(feature, group)` reports `jsd` only if at least this many cells pass the per-cell coverage filter. Otherwise `jsd` is `NA`. | -| `--i-max-lag` | `3` | Maximum CpG lag k for `I_total = sum_{k=1..max} I_k`. | +| `--i-max-lag` | `3` | Maximum CpG lag k for `I_total = sum_{k=1..max} I_k`. Must be at least 1; lag 1 underpins JSD. | | `--max-pair-distance` | `0` (off) | Maximum nucleotide distance allowed between two CpGs of a pair. Pairs whose genomic distance exceeds this value are not counted, at any lag. `0` disables the cap. | | `--threads` | `0` (all) | Number of threads. | @@ -189,6 +189,35 @@ chr1 1000 2000 promoter_GENE1 chr1 5000 7000 cgi_chr1_5000 ``` +### Multiple feature sets + +`--features` can be passed more than once to score the same cells against several BEDs in a single run, so each cell file is parsed only once regardless of how many feature sets are used. This is the recommended way to compare, for example, promoters, enhancers, and heterochromatin in one pass. + +``` +amet \ + --genome mm10.fa \ + --cells cells.tsv \ + --features promoters.bed \ + --features enhancers.bed \ + --features heterochromatin.bed \ + --output-prefix run1 +``` + +With a single `--features` the output paths are exactly as documented above (`run1.cell_feature.tsv.gz`, etc.). With two or more, amet writes one output triplet per BED, keyed by the BED basename: + +``` +run1.promoters.cell_feature.tsv.gz +run1.promoters.feature.tsv.gz +run1.promoters.pair_counts.tsv.gz +run1.enhancers.cell_feature.tsv.gz +run1.enhancers.feature.tsv.gz +run1.enhancers.pair_counts.tsv.gz +run1.heterochromatin.cell_feature.tsv.gz +... +``` + +The label is the BED file name with any of `.bed.gz`, `.bed.bgz`, `.bed`, `.gz`, or `.bgz` stripped. If two BEDs resolve to the same label (for example `regions.bed` in two different directories), amet exits with an error before doing any work; rename one of the inputs. + ### Outputs `.cell_feature.tsv.gz`. One row per `(cell, feature)`: diff --git a/method/src/cli.rs b/method/src/cli.rs index 1ebc4ba..6542376 100644 --- a/method/src/cli.rs +++ b/method/src/cli.rs @@ -12,9 +12,17 @@ pub struct Cli { #[arg(long, value_name = "TSV", required_unless_present = "build_cpg_only")] pub cells: Option, - /// BED file of features to score. Features should not overlap. - #[arg(long, value_name = "BED", required_unless_present = "build_cpg_only")] - pub features: Option, + /// BED file of features to score. Features within a single BED should not overlap. + /// Pass --features multiple times to score the same cells against several feature + /// sets in one cell-read pass; each set writes its own output triplet keyed by the + /// BED basename. With a single --features the output paths are unchanged. + #[arg( + long, + value_name = "BED", + action = clap::ArgAction::Append, + required_unless_present = "build_cpg_only" + )] + pub features: Vec, /// FASTA of the reference genome. amet derives all CpG positions from it on first /// use and caches them to .cpg next to the input. Subsequent runs reuse the @@ -63,7 +71,8 @@ pub struct Cli { pub min_cells_per_group: u32, /// Maximum CpG lag k for the I_total within-cell score: I_total = sum_{k=1..max} I_k. - #[arg(long, default_value_t = 3)] + /// Must be at least 1; lag 1 is required to compute JSD. + #[arg(long, default_value_t = 3, value_parser = clap::value_parser!(u32).range(1..))] pub i_max_lag: u32, /// Maximum nucleotide distance allowed between paired CpGs. Pairs whose genomic diff --git a/method/src/io.rs b/method/src/io.rs index 4add1d1..89a21a3 100644 --- a/method/src/io.rs +++ b/method/src/io.rs @@ -17,8 +17,9 @@ pub fn open_read(path: &Path) -> Result> { } } -/// Open a file for writing, gzipping if the path ends with .gz. -pub fn open_write(path: &Path) -> Result> { +/// Open a file for writing, gzipping if the path ends with .gz. The returned +/// writer is `Send` so it can be shared across worker threads behind a `Mutex`. +pub fn open_write(path: &Path) -> Result> { let file = File::create(path)?; let ext = path.extension().and_then(|s| s.to_str()).unwrap_or(""); if ext == "gz" { diff --git a/method/src/kmer.rs b/method/src/kmer.rs index 0122567..6f70d7d 100644 --- a/method/src/kmer.rs +++ b/method/src/kmer.rs @@ -1,39 +1,40 @@ -//! Per-cell L-mer counting at fixed lag. +//! Per-cell pair counting at fixed lag. //! -//! For a feature, the cell's calls are placed onto the reference's CpG positions, leaving -//! gaps where the cell has no call. A pair (X_i, X_{i+k}) is counted only when both -//! reference positions i and i+k inside the feature are observed in this cell. +//! For a feature, the cell's calls are placed onto the reference's CpG positions. Only +//! observed positions are kept in a compact list, sorted by reference-CpG index. A pair +//! (X_i, X_{i+k}) at lag k is counted only when both reference positions i and i+k inside +//! the feature are observed in this cell. use crate::MethCall; use crate::features::Feature; use crate::reference::CpgReference; -/// One feature's per-cell binary observation: one entry per reference CpG in the feature, -/// either Some(0/1) if observed in this cell or None if missing. `positions` mirrors -/// `calls` and holds the 0-based CpG start coordinate for each slot. +/// One observed methylation call inside a feature, addressed by the reference-CpG index +/// and the genomic position of the C on the + strand. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Observation { + pub ref_idx: u32, + pub value: u8, + pub pos: u64, +} + +/// Per-cell observations for a feature, sorted by `ref_idx` ascending. pub struct CellWindow<'a> { pub feature: &'a Feature, - pub calls: Vec>, - pub positions: &'a [u64], + pub observed: Vec, } impl<'a> CellWindow<'a> { pub fn n_observed(&self) -> usize { - self.calls.iter().filter(|c| c.is_some()).count() + self.observed.len() } pub fn mean_meth(&self) -> Option { - let mut sum = 0u32; - let mut n = 0u32; - for v in self.calls.iter().flatten() { - sum += *v as u32; - n += 1; - } - if n == 0 { - None - } else { - Some(sum as f64 / n as f64) + if self.observed.is_empty() { + return None; } + let sum: u32 = self.observed.iter().map(|o| o.value as u32).sum(); + Some(sum as f64 / self.observed.len() as f64) } } @@ -73,10 +74,11 @@ impl MarginalCounts { } } -/// Build the per-feature observation vector for one cell. +/// Build the per-feature observation list for one cell. /// /// `calls` must be sorted by (chrom_id, pos). The function aligns them to the reference -/// positions in the feature's `cpg_start_idx..cpg_end_idx` range. +/// positions in the feature's `cpg_start_idx..cpg_end_idx` range and emits one entry per +/// observed CpG that passes the binarization filter. pub fn build_window<'a>( feature: &'a Feature, reference: &'a CpgReference, @@ -87,13 +89,11 @@ pub fn build_window<'a>( let positions = &reference.positions[feature.chrom_id as usize]; let feature_positions = &positions[feature.cpg_start_idx..feature.cpg_end_idx]; let n = feature_positions.len(); - let mut window = vec![None; n]; if n == 0 { return CellWindow { feature, - calls: window, - positions: feature_positions, + observed: Vec::new(), }; } @@ -103,7 +103,8 @@ pub fn build_window<'a>( let lo = calls.partition_point(|c| (c.chrom_id, c.pos) < (feature.chrom_id, feature_start_pos)); let hi = calls.partition_point(|c| (c.chrom_id, c.pos) <= (feature.chrom_id, feature_end_pos)); - let mut ref_idx = 0; + let mut observed = Vec::new(); + let mut ref_idx = 0usize; for call in &calls[lo..hi] { if call.chrom_id != feature.chrom_id { continue; @@ -116,43 +117,88 @@ pub fn build_window<'a>( } if feature_positions[ref_idx] == call.pos { if let Some(b) = call.binarize(threshold, min_reads) { - window[ref_idx] = Some(b); + observed.push(Observation { + ref_idx: ref_idx as u32, + value: b, + pos: call.pos, + }); } ref_idx += 1; } } - CellWindow { - feature, - calls: window, - positions: feature_positions, - } + CellWindow { feature, observed } } -/// Count (X_i, X_{i+k}) pairs in a cell window for a single lag k. Pairs whose genomic -/// distance exceeds `max_distance` are skipped; pass 0 to disable the cap. +/// Count (X_i, X_{i+k}) pairs at a single lag k. Pairs whose genomic distance exceeds +/// `max_distance` are skipped; pass 0 to disable the cap. pub fn pair_counts(window: &CellWindow, lag: usize, max_distance: u64) -> PairCounts { let mut pc = PairCounts::default(); - let n = window.calls.len(); - if lag == 0 || lag >= n { + let obs = &window.observed; + if lag == 0 || obs.is_empty() { return pc; } - for i in 0..(n - lag) { - if max_distance > 0 && window.positions[i + lag] - window.positions[i] > max_distance { - continue; + let lag_u32 = lag as u32; + let mut j = 0usize; + for i in 0..obs.len() { + let target = obs[i].ref_idx + lag_u32; + while j < obs.len() && obs[j].ref_idx < target { + j += 1; + } + if j >= obs.len() { + break; } - if let (Some(a), Some(b)) = (window.calls[i], window.calls[i + lag]) { - pc.counts[(a as usize) * 2 + b as usize] += 1; + if obs[j].ref_idx == target + && (max_distance == 0 || obs[j].pos - obs[i].pos <= max_distance) + { + let idx = (obs[i].value as usize) * 2 + obs[j].value as usize; + pc.counts[idx] += 1; } } pc } -/// Marginal counts over the whole window. +/// Count pairs at every lag in 1..=k_max in a single sweep over observed pairs. +/// Breaks the inner walk as soon as the lag exceeds k_max or the genomic distance +/// exceeds `max_distance` (0 disables the cap). Returns a `Vec` of length `k_max`, +/// where index `k - 1` holds the table for lag k. +/// +/// Cost is O(n_observed * local_neighbours_within_k_max), independent of the feature's +/// total CpG count. This is the hot path for large features such as heterochromatin +/// blocks where most reference CpGs are not observed in a given single cell. +pub fn pair_counts_all_lags( + window: &CellWindow, + k_max: usize, + max_distance: u64, +) -> Vec { + let mut out = vec![PairCounts::default(); k_max]; + if k_max == 0 { + return out; + } + let obs = &window.observed; + let k_max_u32 = k_max as u32; + for i in 0..obs.len() { + let oi = obs[i]; + for oj in &obs[i + 1..] { + let lag = oj.ref_idx - oi.ref_idx; + if lag > k_max_u32 { + break; + } + if max_distance > 0 && oj.pos - oi.pos > max_distance { + break; + } + let idx = (oi.value as usize) * 2 + oj.value as usize; + out[(lag as usize) - 1].counts[idx] += 1; + } + } + out +} + +/// Marginal counts over all observed positions in the window. pub fn marginal_counts(window: &CellWindow) -> MarginalCounts { let mut mc = MarginalCounts::default(); - for v in window.calls.iter().flatten() { - mc.counts[*v as usize] += 1; + for o in &window.observed { + mc.counts[o.value as usize] += 1; } mc } @@ -206,7 +252,26 @@ mod tests { }, ]; let w = build_window(&f, &r, &calls, 0.0, 1); - assert_eq!(w.calls, vec![Some(1), None, Some(0), None, Some(1)]); + assert_eq!( + w.observed, + vec![ + Observation { + ref_idx: 0, + value: 1, + pos: 10 + }, + Observation { + ref_idx: 2, + value: 0, + pos: 30 + }, + Observation { + ref_idx: 4, + value: 1, + pos: 50 + }, + ] + ); } #[test] @@ -241,11 +306,10 @@ mod tests { }, ]; let w = build_window(&f, &r, &calls, 0.0, 1); - // calls = [1,1,_,0,0]; lag-1 pairs: (1,1), (0,0). (1,_) and (_,0) and (0,0) excluded. + // observed = [1@0, 1@1, 0@3, 0@4]; lag-1 pairs: (1,1) ref 0->1, (0,0) ref 3->4. let pc = pair_counts(&w, 1, 0); - // 11 (idx 3) = 1; 00 (idx 0) = 1. - assert_eq!(pc.counts[3], 1); - assert_eq!(pc.counts[0], 1); + assert_eq!(pc.counts[3], 1); // (1,1) + assert_eq!(pc.counts[0], 1); // (0,0) assert_eq!(pc.total(), 2); } @@ -286,7 +350,7 @@ mod tests { }, ]; let w = build_window(&f, &r, &calls, 0.0, 1); - // calls = [1,0,1,0,1]; lag-2 pairs: (1,1), (0,0), (1,1). + // observed values [1,0,1,0,1]; lag-2 pairs: (1,1), (0,0), (1,1). let pc = pair_counts(&w, 2, 0); assert_eq!(pc.counts[3], 2); // (1,1) assert_eq!(pc.counts[0], 1); // (0,0) @@ -362,7 +426,16 @@ mod tests { }, ]; let w = build_window(&f, &r, &calls, 0.0, 2); - assert_eq!(w.calls, vec![None, Some(1), None, None, None]); + // pos 10 dropped (t < 2); pos 20 kept (value 1, ref_idx 1). + assert_eq!(w.observed.len(), 1); + assert_eq!( + w.observed[0], + Observation { + ref_idx: 1, + value: 1, + pos: 20 + } + ); } #[test] @@ -471,4 +544,219 @@ mod tests { let pc = pair_counts(&w, 10, 0); assert_eq!(pc.total(), 0); } + + #[test] + fn pair_counts_all_lags_matches_single_lag() { + // Five observations spanning ref indices 0..5; cross-check the multi-lag sweep + // against repeated single-lag calls. + let r = ref3(); + let f = feat_full(); + let calls = vec![ + MethCall { + chrom_id: 0, + pos: 10, + m: 1, + t: 1, + }, + MethCall { + chrom_id: 0, + pos: 20, + m: 0, + t: 1, + }, + MethCall { + chrom_id: 0, + pos: 30, + m: 1, + t: 1, + }, + MethCall { + chrom_id: 0, + pos: 40, + m: 0, + t: 1, + }, + MethCall { + chrom_id: 0, + pos: 50, + m: 1, + t: 1, + }, + ]; + let w = build_window(&f, &r, &calls, 0.0, 1); + let all = pair_counts_all_lags(&w, 4, 0); + assert_eq!(all.len(), 4); + for k in 1..=4 { + let single = pair_counts(&w, k, 0); + assert_eq!( + all[k - 1].counts, + single.counts, + "lag {} mismatch: all={:?} single={:?}", + k, + all[k - 1].counts, + single.counts + ); + } + } + + /// Brute-force reference implementation against which the optimized sweep is checked. + /// Mirrors the pre-rewrite dense-window algorithm: walk every reference slot, treat + /// unobserved slots as gaps, and emit one PairCounts per lag. + fn brute_force_pair_counts( + feature_len: usize, + slot_value: &[Option], + slot_pos: &[u64], + k_max: usize, + max_distance: u64, + ) -> Vec { + assert_eq!(slot_value.len(), feature_len); + assert_eq!(slot_pos.len(), feature_len); + let mut out = vec![PairCounts::default(); k_max]; + for lag in 1..=k_max { + if lag >= feature_len { + continue; + } + for i in 0..(feature_len - lag) { + if max_distance > 0 && slot_pos[i + lag] - slot_pos[i] > max_distance { + continue; + } + if let (Some(a), Some(b)) = (slot_value[i], slot_value[i + lag]) { + out[lag - 1].counts[(a as usize) * 2 + b as usize] += 1; + } + } + } + out + } + + /// xorshift32 — deterministic, no external rng dependency, good enough for tests. + fn rng_next(state: &mut u32) -> u32 { + let mut x = *state; + x ^= x << 13; + x ^= x >> 17; + x ^= x << 5; + *state = x; + x + } + + #[test] + fn pair_counts_all_lags_matches_brute_force_random() { + // Stress test: 200 reference CpGs, irregular spacing, ~30% observed, varying + // values. Run several seeds and several max_distance settings. + let feature_len = 200usize; + let mut positions = Vec::with_capacity(feature_len); + // Irregular but monotonic positions so max_distance has bite. + let mut p = 100u64; + let mut spacing_state = 0xdead_beef_u32; + for _ in 0..feature_len { + positions.push(p); + let gap = (rng_next(&mut spacing_state) % 50) as u64 + 1; + p += gap; + } + + let chrom_id_of: std::collections::HashMap = + [("chr1".into(), 0u32)].into_iter().collect(); + let reference = CpgReference { + chrom_names: vec!["chr1".into()], + chrom_id_of, + positions: vec![positions.clone()], + }; + let feature = Feature { + feature_id: "f".into(), + chrom_id: 0, + start: 0, + end: positions[feature_len - 1] + 10, + cpg_start_idx: 0, + cpg_end_idx: feature_len, + }; + + let k_max = 8usize; + let max_distances = [0u64, 50, 200, 1000]; + + for seed in [1u32, 42, 9999, 0x5a5a5a5a] { + let mut state = seed; + // Build synthetic dense slot arrays (Option per reference slot), + // and the matching sorted MethCall list for build_window. + let mut slot_value: Vec> = vec![None; feature_len]; + let mut calls: Vec = Vec::new(); + for i in 0..feature_len { + if rng_next(&mut state) % 100 < 30 { + let v = (rng_next(&mut state) & 1) as u8; + slot_value[i] = Some(v); + calls.push(MethCall { + chrom_id: 0, + pos: positions[i], + m: v as u32, + t: 1, + }); + } + } + let window = build_window(&feature, &reference, &calls, 0.0, 1); + + for &md in &max_distances { + let got = pair_counts_all_lags(&window, k_max, md); + let want = brute_force_pair_counts(feature_len, &slot_value, &positions, k_max, md); + for k in 1..=k_max { + assert_eq!( + got[k - 1].counts, + want[k - 1].counts, + "seed={seed} max_distance={md} lag={k}: optimized={:?} brute={:?}", + got[k - 1].counts, + want[k - 1].counts, + ); + } + // Also check the single-lag pair_counts API matches the multi-lag sweep. + for k in 1..=k_max { + let single = pair_counts(&window, k, md); + assert_eq!( + single.counts, + got[k - 1].counts, + "single vs sweep mismatch at lag {k}", + ); + } + } + } + } + + #[test] + fn pair_counts_all_lags_respects_max_distance() { + // Same fixture as max_distance_drops_far_pairs but via the multi-lag sweep. + let r = CpgReference { + chrom_names: vec!["chr1".into()], + chrom_id_of: [("chr1".into(), 0u32)].into_iter().collect(), + positions: vec![vec![10, 20, 1100]], + }; + let f = Feature { + feature_id: "f".into(), + chrom_id: 0, + start: 0, + end: 2000, + cpg_start_idx: 0, + cpg_end_idx: 3, + }; + let calls = vec![ + MethCall { + chrom_id: 0, + pos: 10, + m: 1, + t: 1, + }, + MethCall { + chrom_id: 0, + pos: 20, + m: 0, + t: 1, + }, + MethCall { + chrom_id: 0, + pos: 1100, + m: 1, + t: 1, + }, + ]; + let w = build_window(&f, &r, &calls, 0.0, 1); + let all = pair_counts_all_lags(&w, 2, 1000); + assert_eq!(all[0].counts[2], 1); // lag-1 pair (1,0) at distance 10 + assert_eq!(all[0].total(), 1); + assert_eq!(all[1].total(), 0); // lag-2 pair distance 1090 dropped + } } diff --git a/method/src/main.rs b/method/src/main.rs index 7ed3cae..be4e04c 100644 --- a/method/src/main.rs +++ b/method/src/main.rs @@ -1,17 +1,18 @@ use amet::cli::Cli; -use amet::features::read_features; +use amet::features::{Feature, read_features}; use amet::genome::ensure_cpg_index; use amet::io::open_write; -use amet::kmer::{PairCounts, build_window, marginal_counts, pair_counts}; -use amet::manifest::read_manifest; +use amet::kmer::{PairCounts, build_window, marginal_counts, pair_counts_all_lags}; +use amet::manifest::{CellRow, read_manifest}; use amet::parsers::{CellFormat, read_cell}; use amet::reference::read_cpg_reference; -use amet::scores::{i_total::i_total, jsd::multi_jsd}; -use anyhow::{Context, Result}; +use amet::scores::{i_total::i_total, jsd::JsdAccumulator}; +use anyhow::{Context, Result, anyhow}; use rayon::prelude::*; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::io::Write; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; +use std::sync::Mutex; fn main() -> Result<()> { let cli = Cli::parse_args(); @@ -41,216 +42,354 @@ fn main() -> Result<()> { ); let cells_path = cli.cells.as_ref().expect("--cells required"); - let features_path = cli.features.as_ref().expect("--features required"); let output_prefix = cli .output_prefix .as_ref() .expect("--output-prefix required"); - eprintln!("[amet] reading features: {}", features_path.display()); - let features = read_features(features_path, &reference).context("reading features")?; - eprintln!("[amet] features: {}", features.len()); + if cli.features.is_empty() { + return Err(anyhow!("--features is required")); + } + + // Resolve per-set labels and check uniqueness up front so we fail before + // reading feature BEDs and creating per-set outputs. + let labels: Vec = cli.features.iter().map(|p| features_label(p)).collect(); + { + let mut seen: HashSet<&str> = HashSet::new(); + for l in &labels { + if !seen.insert(l.as_str()) { + return Err(anyhow!( + "two --features BEDs resolve to the same label `{}`; rename one of the input files", + l + )); + } + } + } + let single_set = cli.features.len() == 1; + + let mut sets: Vec = Vec::with_capacity(cli.features.len()); + let mut sinks: Vec> = Vec::with_capacity(cli.features.len()); + let mut feat_writers: Vec> = Vec::with_capacity(cli.features.len()); + for (path, label) in cli.features.iter().zip(labels.iter()) { + eprintln!("[amet] reading features: {}", path.display()); + let features = read_features(path, &reference).context("reading features")?; + eprintln!("[amet] features ({}): {}", label, features.len()); + let set_prefix = if single_set { + output_prefix.clone() + } else { + with_suffix(output_prefix, &format!(".{}", label)) + }; + let cf_path = with_suffix(&set_prefix, ".cell_feature.tsv.gz"); + let feat_path = with_suffix(&set_prefix, ".feature.tsv.gz"); + let pair_path = with_suffix(&set_prefix, ".pair_counts.tsv.gz"); + let mut cf_writer = open_write(&cf_path).context("opening cell_feature output")?; + let mut feat_writer = open_write(&feat_path).context("opening feature output")?; + let mut pair_writer = open_write(&pair_path).context("opening pair_counts output")?; + write_headers( + &mut cf_writer, + &mut feat_writer, + &mut pair_writer, + cli.i_max_lag, + )?; + sets.push(FeatureSet { + label: label.clone(), + features, + cf_path, + feat_path, + pair_path, + }); + sinks.push(Mutex::new(SetSink { + cf_writer, + pair_writer, + agg: HashMap::new(), + })); + feat_writers.push(feat_writer); + } eprintln!("[amet] reading manifest: {}", cells_path.display()); let manifest = read_manifest(cells_path, &cli.group_column).context("reading manifest")?; eprintln!("[amet] cells: {}", manifest.len()); - let i_max_lag = cli.i_max_lag as usize; - let cf_path = with_suffix(output_prefix, ".cell_feature.tsv.gz"); - let feat_path = with_suffix(output_prefix, ".feature.tsv.gz"); - let pair_path = with_suffix(output_prefix, ".pair_counts.tsv.gz"); - - let mut cf_writer = open_write(&cf_path).context("opening cell_feature output")?; - let mut feat_writer = open_write(&feat_path).context("opening feature output")?; - let mut pair_writer = open_write(&pair_path).context("opening pair_counts output")?; - - write!( - cf_writer, - "cell_id\tgroup\tfeature_id\tn_covered\tmean_meth\tn_zeros\tn_ones\ti_total" - )?; - for k in 1..=i_max_lag { - write!(cf_writer, "\ti_{}", k)?; + // Intern group labels: the per-row aggregate key becomes two integers + // (feature index, group id) instead of two freshly allocated strings. + // A run has only a handful of distinct groups, so a linear scan is fine. + let mut group_names: Vec = Vec::new(); + for cell in &manifest { + if !group_names.iter().any(|g| g == &cell.group) { + group_names.push(cell.group.clone()); + } } - writeln!(cf_writer)?; - writeln!( - feat_writer, - "feature_id\tgroup\tn_cells\tmean_coverage\tmean_meth_mean\tmean_meth_var\ti_total_mean\ti_total_var\tjsd" - )?; - writeln!( - pair_writer, - "cell_id\tgroup\tfeature_id\tlag\tn00\tn01\tn10\tn11" - )?; - // Per-cell processing in parallel; results streamed to the writers afterwards. - let per_cell_results: Vec> = manifest - .par_iter() - .map(|cell| { - let format = match cell.format.as_deref() { - Some(s) => { - CellFormat::parse(s).unwrap_or_else(|| CellFormat::detect_from_path(&cell.path)) - } - None => CellFormat::detect_from_path(&cell.path), - }; - let calls = match read_cell(&cell.path, format, &reference) { - Ok(c) => c, - Err(e) => { - eprintln!("[amet] error reading {}: {}", cell.path.display(), e); - return Vec::new(); - } - }; - let mut rows = Vec::with_capacity(features.len()); - for feature in &features { - let window = build_window( - feature, - &reference, - &calls, - cli.meth_call_threshold, - cli.min_reads_per_cpg, - ); - let n_cov = window.n_observed() as u32; - let mc = marginal_counts(&window); - let pair_tables: Vec = (1..=i_max_lag) - .map(|lag| pair_counts(&window, lag, cli.max_pair_distance)) - .collect(); - let mean = window.mean_meth(); - let i_per_lag: Vec = - pair_tables.iter().map(amet::scores::i_total::i_k).collect(); - let total = i_total(&pair_tables); - rows.push(CellFeatureRow { - cell_id: cell.cell_id.clone(), - group: cell.group.clone(), - feature_id: feature.feature_id.clone(), - n_covered: n_cov, - n_zeros: mc.counts[0], - n_ones: mc.counts[1], - mean_meth: mean, - i_total_value: total, - i_per_lag, - pair_tables, - }); - } - rows - }) - .collect(); - - // Write per-cell-per-feature, per-cell-per-feature-per-lag, and collect lag-1 for JSD. - let mut feat_to_group_cells: HashMap<(String, String), Vec> = HashMap::new(); - let mut feat_to_group_coverage: HashMap<(String, String), (u64, u64)> = HashMap::new(); - let mut feat_to_group_meth: HashMap<(String, String), Vec> = HashMap::new(); - let mut feat_to_group_itotal: HashMap<(String, String), Vec> = HashMap::new(); + let i_max_lag = cli.i_max_lag as usize; let min_n = cli.min_cpgs_per_feature; - for cell_rows in &per_cell_results { - for row in cell_rows { - write!( - cf_writer, - "{}\t{}\t{}\t{}\t", - row.cell_id, row.group, row.feature_id, row.n_covered - )?; - match row.mean_meth { - Some(m) => write!(cf_writer, "{:.6}", m)?, - None => write!(cf_writer, "NA")?, + // Score cells in parallel. Each cell file is read once and scored against + // every feature set; its cell_feature/pair_counts rows are formatted into a + // thread-local buffer, then a brief per-set lock flushes the buffer and + // folds the streaming aggregates. Nothing accumulates across cells, so peak + // memory stays near `threads * one (cell, set) batch`. cell_feature and + // pair_counts row order is therefore cell-interleaved; consumers key by + // cell_id, and feature.tsv below is written in a sorted order. + manifest.par_iter().try_for_each(|cell| -> Result<()> { + let format = match cell.format.as_deref() { + Some(s) => { + CellFormat::parse(s).unwrap_or_else(|| CellFormat::detect_from_path(&cell.path)) } - write!(cf_writer, "\t{}\t{}", row.n_zeros, row.n_ones)?; - if row.n_covered >= min_n { - write!(cf_writer, "\t{:.6}", row.i_total_value)?; - for v in &row.i_per_lag { - write!(cf_writer, "\t{:.6}", v)?; - } - } else { - write!(cf_writer, "\tNA")?; - for _ in 0..i_max_lag { - write!(cf_writer, "\tNA")?; - } + None => CellFormat::detect_from_path(&cell.path), + }; + let calls = match read_cell(&cell.path, format, &reference) { + Ok(c) => c, + Err(e) => { + eprintln!("[amet] error reading {}: {}", cell.path.display(), e); + return Ok(()); } - writeln!(cf_writer)?; - - // Pair counts per lag, always emitted (regardless of min_cpgs_per_feature). - for (idx, pt) in row.pair_tables.iter().enumerate() { - let lag = idx + 1; - writeln!( - pair_writer, - "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", - row.cell_id, - row.group, - row.feature_id, - lag, - pt.counts[0], - pt.counts[1], - pt.counts[2], - pt.counts[3] - )?; + }; + // Every cell's group was interned from this same manifest above. + let group_id = group_names + .iter() + .position(|g| g == &cell.group) + .expect("cell group interned from manifest") as u32; + + for (set, sink) in sets.iter().zip(sinks.iter()) { + // Cells are the parallel axis, so features are scored sequentially. + let rows: Vec = set + .features + .iter() + .map(|feature| { + let window = build_window( + feature, + &reference, + &calls, + cli.meth_call_threshold, + cli.min_reads_per_cpg, + ); + let mc = marginal_counts(&window); + let pair_tables = + pair_counts_all_lags(&window, i_max_lag, cli.max_pair_distance); + let i_per_lag: Vec = + pair_tables.iter().map(amet::scores::i_total::i_k).collect(); + CellFeatureRow { + feature_id: &feature.feature_id, + n_covered: window.n_observed() as u32, + n_zeros: mc.counts[0], + n_ones: mc.counts[1], + mean_meth: window.mean_meth(), + i_total_value: i_total(&pair_tables), + i_per_lag, + pair_tables, + } + }) + .collect(); + + // Format outside the lock so the critical section is just the + // buffered write and the aggregate fold. + let mut cf_buf: Vec = Vec::new(); + let mut pair_buf: Vec = Vec::new(); + for row in &rows { + write_cell_feature_row(&mut cf_buf, cell, row, min_n, i_max_lag)?; + write_pair_count_rows(&mut pair_buf, cell, row)?; } - if row.n_covered >= min_n { - let key = (row.feature_id.clone(), row.group.clone()); - feat_to_group_cells - .entry(key.clone()) - .or_default() - .push(row.pair_tables[0]); - let acc = feat_to_group_coverage.entry(key.clone()).or_insert((0, 0)); - acc.0 += row.n_covered as u64; - acc.1 += 1; - if let Some(m) = row.mean_meth { - feat_to_group_meth.entry(key.clone()).or_default().push(m); + let mut guard = sink.lock().expect("sink mutex poisoned"); + guard.cf_writer.write_all(&cf_buf)?; + guard.pair_writer.write_all(&pair_buf)?; + for (feat_idx, row) in rows.iter().enumerate() { + if row.n_covered >= min_n { + let e = guard.agg.entry((feat_idx as u32, group_id)).or_default(); + e.cov_sum += row.n_covered as u64; + e.n_cells += 1; + if let Some(m) = row.mean_meth { + e.meth.add(m); + } + e.itotal.add(row.i_total_value); + e.jsd.add(&row.pair_tables[0]); } - feat_to_group_itotal - .entry(key) - .or_default() - .push(row.i_total_value); } } + Ok(()) + })?; + + // Feature-level aggregates per set, written sorted by (feature_id, group). + // Aggregates are keyed by integer indices; reconstruct the names here and + // sort lexically so the output order is stable and independent of the + // HashMap iteration order. The feature index breaks ties so the order is + // fully determined even if two features share a feature_id. + for ((set, sink), feat_writer) in sets.iter().zip(sinks).zip(feat_writers.iter_mut()) { + let agg = sink.into_inner().expect("sink mutex poisoned").agg; + let w: &mut dyn Write = &mut **feat_writer; + let mut keys: Vec<(u32, u32)> = agg.keys().copied().collect(); + keys.sort_by(|a, b| { + let ka = ( + set.features[a.0 as usize].feature_id.as_str(), + group_names[a.1 as usize].as_str(), + a.0, + ); + let kb = ( + set.features[b.0 as usize].feature_id.as_str(), + group_names[b.1 as usize].as_str(), + b.0, + ); + ka.cmp(&kb) + }); + for key in keys { + let e = &agg[&key]; + let feature_id = set.features[key.0 as usize].feature_id.as_str(); + let group = group_names[key.1 as usize].as_str(); + let mean_cov = if e.n_cells > 0 { + e.cov_sum as f64 / e.n_cells as f64 + } else { + 0.0 + }; + let jsd = if e.n_cells >= cli.min_cells_per_group as u64 { + Some(e.jsd.finish()) + } else { + None + }; + write!( + w, + "{}\t{}\t{}\t{:.6}\t", + feature_id, group, e.n_cells, mean_cov + )?; + write_opt(w, e.meth.mean())?; + write!(w, "\t")?; + write_opt(w, e.meth.var())?; + write!(w, "\t")?; + write_opt(w, e.itotal.mean())?; + write!(w, "\t")?; + write_opt(w, e.itotal.var())?; + write!(w, "\t")?; + write_opt(w, jsd)?; + writeln!(w)?; + } + + eprintln!( + "[amet] done {}: wrote {}, {}, {}", + set.label, + set.cf_path.display(), + set.feat_path.display(), + set.pair_path.display() + ); } + Ok(()) +} - let mut keys: Vec<_> = feat_to_group_cells.keys().cloned().collect(); - keys.sort(); - for key in keys { - let cells = &feat_to_group_cells[&key]; - let (cov_sum, n_cells) = feat_to_group_coverage[&key]; - let mean_cov = if n_cells > 0 { - cov_sum as f64 / n_cells as f64 - } else { - 0.0 - }; - let meth_vals = feat_to_group_meth.get(&key); - let i_vals = feat_to_group_itotal.get(&key); - let (meth_mean, meth_var) = mean_var(meth_vals); - let (i_mean, i_var) = mean_var(i_vals); - let jsd = if n_cells >= cli.min_cells_per_group as u64 { - Some(multi_jsd(cells)) - } else { +/// Welford online accumulator for mean and sample variance. Used so the +/// feature-level aggregates need not retain every per-cell value. +#[derive(Default)] +struct Welford { + n: u64, + mean: f64, + m2: f64, +} + +impl Welford { + fn add(&mut self, x: f64) { + self.n += 1; + let delta = x - self.mean; + self.mean += delta / self.n as f64; + let delta2 = x - self.mean; + self.m2 += delta * delta2; + } + + fn mean(&self) -> Option { + if self.n == 0 { None } else { Some(self.mean) } + } + + fn var(&self) -> Option { + if self.n < 2 { None - }; - write!( - feat_writer, - "{}\t{}\t{}\t{:.6}\t", - key.0, key.1, n_cells, mean_cov - )?; - write_opt(&mut feat_writer, meth_mean)?; - write!(feat_writer, "\t")?; - write_opt(&mut feat_writer, meth_var)?; - write!(feat_writer, "\t")?; - write_opt(&mut feat_writer, i_mean)?; - write!(feat_writer, "\t")?; - write_opt(&mut feat_writer, i_var)?; - write!(feat_writer, "\t")?; - write_opt(&mut feat_writer, jsd)?; - writeln!(feat_writer)?; + } else { + Some(self.m2 / (self.n - 1) as f64) + } } +} - eprintln!( - "[amet] done. wrote {}, {}, {}", - cf_path.display(), - feat_path.display(), - pair_path.display() - ); +/// Streaming feature-level aggregate for one (feature, group). +#[derive(Default)] +struct AggEntry { + cov_sum: u64, + n_cells: u64, + meth: Welford, + itotal: Welford, + jsd: JsdAccumulator, +} + +fn write_cell_feature_row( + w: &mut dyn Write, + cell: &CellRow, + row: &CellFeatureRow<'_>, + min_n: u32, + i_max_lag: usize, +) -> std::io::Result<()> { + write!( + w, + "{}\t{}\t{}\t{}\t", + cell.cell_id, cell.group, row.feature_id, row.n_covered + )?; + match row.mean_meth { + Some(m) => write!(w, "{:.6}", m)?, + None => write!(w, "NA")?, + } + write!(w, "\t{}\t{}", row.n_zeros, row.n_ones)?; + if row.n_covered >= min_n { + write!(w, "\t{:.6}", row.i_total_value)?; + for v in &row.i_per_lag { + write!(w, "\t{:.6}", v)?; + } + } else { + write!(w, "\tNA")?; + for _ in 0..i_max_lag { + write!(w, "\tNA")?; + } + } + writeln!(w) +} + +fn write_pair_count_rows( + w: &mut dyn Write, + cell: &CellRow, + row: &CellFeatureRow<'_>, +) -> std::io::Result<()> { + for (idx, pt) in row.pair_tables.iter().enumerate() { + writeln!( + w, + "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", + cell.cell_id, + cell.group, + row.feature_id, + idx + 1, + pt.counts[0], + pt.counts[1], + pt.counts[2], + pt.counts[3] + )?; + } Ok(()) } -struct CellFeatureRow { - cell_id: String, - group: String, - feature_id: String, +struct FeatureSet { + label: String, + features: Vec, + cf_path: PathBuf, + feat_path: PathBuf, + pair_path: PathBuf, +} + +/// Per-set mutable state shared across the parallel cell workers behind a +/// `Mutex`. The feature.tsv writer is not here: it is touched only before and +/// after the parallel section, never concurrently. +struct SetSink { + cf_writer: Box, + pair_writer: Box, + /// Keyed by (feature index within the set, interned group id) so the hot + /// per-row fold does not allocate a fresh key string each time. + agg: HashMap<(u32, u32), AggEntry>, +} + +/// One scored (cell, feature). `feature_id` borrows from the feature set, which +/// outlives the short-lived per-(cell, set) batch this row belongs to. +struct CellFeatureRow<'a> { + feature_id: &'a str, n_covered: u32, n_zeros: u32, n_ones: u32, @@ -260,29 +399,110 @@ struct CellFeatureRow { pair_tables: Vec, } -fn with_suffix(prefix: &std::path::Path, suffix: &str) -> std::path::PathBuf { - let mut s = prefix.as_os_str().to_owned(); - s.push(suffix); - std::path::PathBuf::from(s) +fn write_headers( + cf_writer: &mut dyn Write, + feat_writer: &mut dyn Write, + pair_writer: &mut dyn Write, + i_max_lag: u32, +) -> std::io::Result<()> { + write!( + cf_writer, + "cell_id\tgroup\tfeature_id\tn_covered\tmean_meth\tn_zeros\tn_ones\ti_total" + )?; + for k in 1..=i_max_lag { + write!(cf_writer, "\ti_{}", k)?; + } + writeln!(cf_writer)?; + writeln!( + feat_writer, + "feature_id\tgroup\tn_cells\tmean_coverage\tmean_meth_mean\tmean_meth_var\ti_total_mean\ti_total_var\tjsd" + )?; + writeln!( + pair_writer, + "cell_id\tgroup\tfeature_id\tlag\tn00\tn01\tn10\tn11" + )?; + Ok(()) } -fn mean_var(values: Option<&Vec>) -> (Option, Option) { - let v = match values { - Some(v) if !v.is_empty() => v, - _ => return (None, None), - }; - let n = v.len() as f64; - let mean = v.iter().sum::() / n; - if v.len() < 2 { - return (Some(mean), None); +/// Derive a stable label for a BED path by stripping known BED and compression +/// suffixes if present; otherwise use the file name as-is. Used to disambiguate +/// output paths when multiple --features are supplied. +fn features_label(path: &Path) -> String { + let raw = path + .file_name() + .map(|n| n.to_string_lossy().into_owned()) + .unwrap_or_else(|| "features".to_string()); + for suffix in [".bed.gz", ".bed.bgz", ".bed", ".gz", ".bgz"] { + if let Some(s) = raw.strip_suffix(suffix) { + return s.to_string(); + } } - let var = v.iter().map(|x| (x - mean).powi(2)).sum::() / (n - 1.0); - (Some(mean), Some(var)) + raw +} + +fn with_suffix(prefix: &Path, suffix: &str) -> PathBuf { + let mut s = prefix.as_os_str().to_owned(); + s.push(suffix); + PathBuf::from(s) } -fn write_opt(w: &mut W, x: Option) -> std::io::Result<()> { +fn write_opt(w: &mut W, x: Option) -> std::io::Result<()> { match x { Some(v) => write!(w, "{:.6}", v), None => write!(w, "NA"), } } + +#[cfg(test)] +mod tests { + use super::features_label; + use std::path::Path; + + #[test] + fn label_strips_bed_gz() { + assert_eq!( + features_label(Path::new("/x/promoters.bed.gz")), + "promoters" + ); + } + + #[test] + fn label_strips_bed() { + assert_eq!(features_label(Path::new("/x/enhancers.bed")), "enhancers"); + } + + #[test] + fn label_strips_gz_only() { + assert_eq!( + features_label(Path::new("/x/regions.tsv.gz")), + "regions.tsv" + ); + } + + #[test] + fn label_keeps_unknown_extension() { + assert_eq!(features_label(Path::new("/x/regions.txt")), "regions.txt"); + } + + #[test] + fn label_handles_multiple_dots() { + // Common case: dataset-tagged bed names like "mm10.heterochromatin.bed". + assert_eq!( + features_label(Path::new("/x/mm10.heterochromatin.bed")), + "mm10.heterochromatin" + ); + } + + #[test] + fn label_strips_bed_bgz() { + assert_eq!(features_label(Path::new("/x/regions.bed.bgz")), "regions"); + } + + #[test] + fn label_strips_bgz_only() { + assert_eq!( + features_label(Path::new("/x/regions.tsv.bgz")), + "regions.tsv" + ); + } +} diff --git a/method/src/scores/jsd.rs b/method/src/scores/jsd.rs index e4c8af6..19453ba 100644 --- a/method/src/scores/jsd.rs +++ b/method/src/scores/jsd.rs @@ -6,37 +6,65 @@ use super::shannon_entropy; use crate::kmer::PairCounts; +/// Streaming accumulator for multi-distribution JSD. +/// +/// Folds one cell's lag-1 2-mer histogram at a time, so per-cell counts need +/// not be retained for the whole group. Holds a running sum of normalised +/// distributions and of per-cell entropies; both terms of the JSD are means +/// over cells, so they accumulate exactly. +#[derive(Debug, Clone, Default)] +pub struct JsdAccumulator { + mixture_sum: [f64; 4], + entropy_sum: f64, + n: u64, +} + +impl JsdAccumulator { + /// Fold one cell's lag-1 2-mer counts. Cells with no pairs are ignored, + /// matching the non-empty filter in `multi_jsd`. + pub fn add(&mut self, cell: &PairCounts) { + let total = cell.total(); + if total == 0 { + return; + } + let t = total as f64; + let mut p = [0.0f64; 4]; + for (k, slot) in p.iter_mut().enumerate() { + *slot = cell.counts[k] as f64 / t; + self.mixture_sum[k] += *slot; + } + self.entropy_sum += entropy_of_distribution(&p); + self.n += 1; + } + + /// JSD = H(mean P_i) - mean H(P_i). Returns 0 with fewer than 2 non-empty cells. + pub fn finish(&self) -> f64 { + if self.n < 2 { + return 0.0; + } + let n = self.n as f64; + let mixture = [ + self.mixture_sum[0] / n, + self.mixture_sum[1] / n, + self.mixture_sum[2] / n, + self.mixture_sum[3] / n, + ]; + let jsd = entropy_of_distribution(&mixture) - self.entropy_sum / n; + if jsd < 0.0 { 0.0 } else { jsd } + } +} + /// Multi-distribution generalised JSD (in bits, using log base 2). /// /// JSD(P_1, ..., P_n) = H(mean P_i) - mean H(P_i). /// /// Returns 0 when fewer than 2 cells have non-zero histograms. pub fn multi_jsd(per_cell: &[PairCounts]) -> f64 { - let nonempty: Vec<&PairCounts> = per_cell.iter().filter(|c| c.total() > 0).collect(); - if nonempty.len() < 2 { - return 0.0; - } - - // Per-cell normalised distributions and their entropies. - let mut h_avg = 0.0; - let mut mixture = [0.0f64; 4]; - for cell in &nonempty { - let n = cell.total() as f64; - let mut p = [0.0; 4]; - for k in 0..4 { - p[k] = cell.counts[k] as f64 / n; - mixture[k] += p[k]; - } - h_avg += entropy_of_distribution(&p); - } - let n_cells = nonempty.len() as f64; - h_avg /= n_cells; - for v in &mut mixture { - *v /= n_cells; + let mut acc = JsdAccumulator::default(); + for cell in per_cell { + acc.add(cell); } - let h_mix = entropy_of_distribution(&mixture); - let jsd = h_mix - h_avg; - if jsd < 0.0 { 0.0 } else { jsd } + acc.finish() } fn entropy_of_distribution(p: &[f64]) -> f64 { @@ -118,4 +146,29 @@ mod tests { let far = vec![pc(50, 25, 15, 10), pc(10, 15, 25, 50)]; assert!(multi_jsd(&far) > multi_jsd(&close)); } + + #[test] + fn accumulator_matches_multi_jsd() { + // The streaming accumulator must yield bit-identical results to the + // slice-based multi_jsd, since the workflow relies on the streaming path. + let cases: Vec> = vec![ + vec![pc(10, 10, 10, 10), pc(20, 20, 20, 20), pc(5, 5, 5, 5)], + vec![pc(50, 25, 15, 10), pc(10, 15, 25, 50), pc(0, 0, 0, 0)], + vec![pc(10, 0, 0, 0), pc(0, 10, 0, 0), pc(0, 0, 10, 0)], + vec![pc(7, 3, 1, 9)], + vec![], + ]; + for case in &cases { + let mut acc = JsdAccumulator::default(); + for c in case { + acc.add(c); + } + assert_eq!( + acc.finish(), + multi_jsd(case), + "accumulator diverged from multi_jsd on {:?}", + case + ); + } + } } diff --git a/method/tests/integration.rs b/method/tests/integration.rs index 6432a68..1cd7f67 100644 --- a/method/tests/integration.rs +++ b/method/tests/integration.rs @@ -99,11 +99,18 @@ fn end_to_end_two_cells_one_feature_allc() { assert!(header.contains("i_1")); // The 010101 cell should have higher i_total than the 011010 cell. - let row_a: Vec<&str> = lines[1].split('\t').collect(); - let row_b: Vec<&str> = lines[2].split('\t').collect(); + // Rows are looked up by cell_id (column 0): cell_feature row order is + // cell-interleaved and not guaranteed to follow the manifest. let i_total_col = header.split('\t').position(|h| h == "i_total").unwrap(); - let a_score: f64 = row_a[i_total_col].parse().unwrap(); - let b_score: f64 = row_b[i_total_col].parse().unwrap(); + let score_of = |cell_id: &str| -> f64 { + let row = lines[1..] + .iter() + .find(|l| l.split('\t').next() == Some(cell_id)) + .unwrap_or_else(|| panic!("no cell_feature row for cell {}", cell_id)); + row.split('\t').nth(i_total_col).unwrap().parse().unwrap() + }; + let a_score = score_of("A"); + let b_score = score_of("B"); assert!( a_score > b_score, "010101 (A) should have higher i_total than 011010 (B); got {} vs {}", @@ -350,6 +357,265 @@ fn mixed_chrom_naming_still_runs() { assert_eq!(lines.len(), 2, "header + 1 cell, got {}", lines.len()); } +#[test] +fn multi_features_writes_one_triplet_per_set() { + // Two BEDs over the same cells, scored in one cell-read pass. + // Each BED gets its own output triplet keyed by basename. + let dir = tempdir().unwrap(); + + let cpgs = write_file( + dir.path(), + "cpgs.tsv", + "chr1\t9\nchr1\t19\nchr1\t29\nchr1\t39\nchr1\t49\nchr1\t59\n", + ); + + // BED 1: one region named "prom_a" over CpGs 1..3. + let bed_prom = write_file(dir.path(), "prom.bed", "chr1\t0\t35\tprom_a\n"); + // BED 2: one region named "enh_a" over CpGs 4..6. + let bed_enh = write_file(dir.path(), "enh.bed", "chr1\t35\t100\tenh_a\n"); + + let cell = write_file( + dir.path(), + "cell.allc.tsv", + "chr1\t10\t+\tCGN\t0\t1\t0\n\ + chr1\t20\t+\tCGN\t1\t1\t1\n\ + chr1\t30\t+\tCGN\t0\t1\t0\n\ + chr1\t40\t+\tCGN\t1\t1\t1\n\ + chr1\t50\t+\tCGN\t0\t1\t0\n\ + chr1\t60\t+\tCGN\t1\t1\t1\n", + ); + let manifest = write_file( + dir.path(), + "cells.tsv", + &format!("cell_id\tgroup\tpath\nA\tg1\t{}\n", cell.display()), + ); + + let prefix = dir.path().join("run"); + let status = Command::new(binary_path()) + .args([ + "--cpg-reference", + cpgs.to_str().unwrap(), + "--features", + bed_prom.to_str().unwrap(), + "--features", + bed_enh.to_str().unwrap(), + "--cells", + manifest.to_str().unwrap(), + "--output-prefix", + prefix.to_str().unwrap(), + "--min-cpgs-per-feature", + "3", + ]) + .status() + .expect("running amet binary"); + assert!(status.success(), "amet exited with non-zero status"); + + // Each set must have its own triplet, suffixed with the BED stem. + let prom_cf = dir.path().join("run.prom.cell_feature.tsv.gz"); + let prom_feat = dir.path().join("run.prom.feature.tsv.gz"); + let prom_pair = dir.path().join("run.prom.pair_counts.tsv.gz"); + let enh_cf = dir.path().join("run.enh.cell_feature.tsv.gz"); + let enh_feat = dir.path().join("run.enh.feature.tsv.gz"); + let enh_pair = dir.path().join("run.enh.pair_counts.tsv.gz"); + for p in [ + &prom_cf, &prom_feat, &prom_pair, &enh_cf, &enh_feat, &enh_pair, + ] { + assert!(p.exists(), "expected output {} to exist", p.display()); + } + + // Single-features output paths must NOT exist (no fallback bleed-through). + assert!(!dir.path().join("run.cell_feature.tsv.gz").exists()); + assert!(!dir.path().join("run.feature.tsv.gz").exists()); + assert!(!dir.path().join("run.pair_counts.tsv.gz").exists()); + + // Each set sees only its own feature id. + let prom_cf_text = read_gz(&prom_cf); + let enh_cf_text = read_gz(&enh_cf); + assert!(prom_cf_text.contains("prom_a")); + assert!(!prom_cf_text.contains("enh_a")); + assert!(enh_cf_text.contains("enh_a")); + assert!(!enh_cf_text.contains("prom_a")); +} + +#[test] +fn cli_rejects_zero_i_max_lag() { + // --i-max-lag 0 leaves pair_tables empty, which would later panic on JSD aggregation. + // The CLI is the right place to reject this; assert the binary exits non-zero. + let dir = tempdir().unwrap(); + let cpgs = write_file(dir.path(), "cpgs.tsv", "chr1\t9\n"); + let bed = write_file(dir.path(), "f.bed", "chr1\t0\t100\tx\n"); + let cells = write_file(dir.path(), "c.tsv", "cell_id\tgroup\tpath\n"); + let prefix = dir.path().join("p"); + let status = Command::new(binary_path()) + .args([ + "--cpg-reference", + cpgs.to_str().unwrap(), + "--features", + bed.to_str().unwrap(), + "--cells", + cells.to_str().unwrap(), + "--output-prefix", + prefix.to_str().unwrap(), + "--i-max-lag", + "0", + ]) + .status() + .expect("running amet binary"); + assert!( + !status.success(), + "expected non-zero exit when --i-max-lag is 0" + ); +} + +#[test] +fn multi_features_rejects_duplicate_basenames() { + // Two BEDs whose basenames collapse to the same label must be rejected. + let dir = tempdir().unwrap(); + let cpgs = write_file(dir.path(), "cpgs.tsv", "chr1\t9\n"); + let sub = dir.path().join("sub"); + std::fs::create_dir_all(&sub).unwrap(); + let bed1 = write_file(dir.path(), "regions.bed", "chr1\t0\t100\tx\n"); + let bed2 = write_file(&sub, "regions.bed", "chr1\t0\t100\ty\n"); + let cells = write_file(dir.path(), "cells.tsv", "cell_id\tgroup\tpath\n"); + let prefix = dir.path().join("run"); + let status = Command::new(binary_path()) + .args([ + "--cpg-reference", + cpgs.to_str().unwrap(), + "--features", + bed1.to_str().unwrap(), + "--features", + bed2.to_str().unwrap(), + "--cells", + cells.to_str().unwrap(), + "--output-prefix", + prefix.to_str().unwrap(), + ]) + .status() + .expect("running amet binary"); + assert!( + !status.success(), + "expected duplicate-label rejection to exit non-zero" + ); +} + +#[test] +fn feature_tsv_multi_feature_multi_group() { + // Two features in one BED, four cells across two groups. Exercises the + // per-(feature index, group id) aggregate keying: each feature must get + // its own feature.tsv row per group, correctly labelled, with no + // cross-contamination between features. + let dir = tempdir().unwrap(); + + // 12 CpGs at 0-based positions 9, 19, ..., 119 on chr1. + let cpgs: String = (0..12).map(|i| format!("chr1\t{}\n", i * 10 + 9)).collect(); + let cpg_path = write_file(dir.path(), "cpgs.tsv", &cpgs); + + // featA covers the first 6 CpGs (0-based start < 60), featB the last 6. + let bed = write_file( + dir.path(), + "feat.bed", + "chr1\t0\t60\tfeatA\nchr1\t60\t130\tfeatB\n", + ); + + // Every cell: featA CpGs unmethylated (m=0), featB CpGs methylated (m=1). + // The allc + strand C position is the 0-based CpG start + 1, so 10..=120. + let mut cell_body = String::new(); + for pos in (10..=60).step_by(10) { + cell_body.push_str(&format!("chr1\t{}\t+\tCGN\t0\t1\t0\n", pos)); + } + for pos in (70..=120).step_by(10) { + cell_body.push_str(&format!("chr1\t{}\t+\tCGN\t1\t1\t1\n", pos)); + } + let cell_paths: Vec<_> = ["A", "B", "C", "D"] + .iter() + .map(|id| write_file(dir.path(), &format!("cell{}.allc.tsv", id), &cell_body)) + .collect(); + + let manifest = write_file( + dir.path(), + "cells.tsv", + &format!( + "cell_id\tgroup\tpath\nA\tg1\t{}\nB\tg1\t{}\nC\tg2\t{}\nD\tg2\t{}\n", + cell_paths[0].display(), + cell_paths[1].display(), + cell_paths[2].display(), + cell_paths[3].display(), + ), + ); + + let prefix = dir.path().join("run"); + let status = Command::new(binary_path()) + .args([ + "--cpg-reference", + cpg_path.to_str().unwrap(), + "--features", + bed.to_str().unwrap(), + "--cells", + manifest.to_str().unwrap(), + "--output-prefix", + prefix.to_str().unwrap(), + "--min-cpgs-per-feature", + "3", + "--min-cells-per-group", + "2", + ]) + .status() + .expect("running amet binary"); + assert!(status.success(), "amet exited with non-zero status"); + + // cell_feature: header + 4 cells x 2 features = 8 data rows. + let cf = read_gz(&dir.path().join("run.cell_feature.tsv.gz")); + assert_eq!( + cf.lines().count(), + 9, + "expected header + 8 cell_feature rows" + ); + + // feature.tsv: header + (2 features x 2 groups) = 4 data rows, sorted. + let feat = read_gz(&dir.path().join("run.feature.tsv.gz")); + let lines: Vec<&str> = feat.lines().collect(); + assert_eq!( + lines.len(), + 5, + "expected header + 4 feature rows, got {}", + lines.len() + ); + let header: Vec<&str> = lines[0].split('\t').collect(); + let col = |name: &str| header.iter().position(|h| *h == name).unwrap(); + let (fid, grp, ncells, meth) = ( + col("feature_id"), + col("group"), + col("n_cells"), + col("mean_meth_mean"), + ); + let rows: Vec> = lines[1..].iter().map(|l| l.split('\t').collect()).collect(); + + // Sorted by (feature_id, group): featA/g1, featA/g2, featB/g1, featB/g2. + let expect = [ + ("featA", "g1"), + ("featA", "g2"), + ("featB", "g1"), + ("featB", "g2"), + ]; + for (i, (efid, egrp)) in expect.iter().enumerate() { + assert_eq!(rows[i][fid], *efid, "row {} feature_id", i); + assert_eq!(rows[i][grp], *egrp, "row {} group", i); + assert_eq!(rows[i][ncells], "2", "row {} n_cells", i); + } + + // featA is all-unmethylated, featB all-methylated. If the feature-index + // keying were wrong, these per-feature means would be swapped or merged. + for row in &rows[0..2] { + let m: f64 = row[meth].parse().unwrap(); + assert!(m < 0.5, "featA mean_meth_mean should be ~0, got {}", m); + } + for row in &rows[2..4] { + let m: f64 = row[meth].parse().unwrap(); + assert!(m > 0.5, "featB mean_meth_mean should be ~1, got {}", m); + } +} + #[test] fn cli_rejects_neither_genome_nor_cpg_reference() { let dir = tempdir().unwrap(); diff --git a/method/tests/snapshot.rs b/method/tests/snapshot.rs index fc696ed..d3a535c 100644 --- a/method/tests/snapshot.rs +++ b/method/tests/snapshot.rs @@ -114,18 +114,35 @@ fn snapshot_matches_golden() { let cf_golden = fs::read_to_string(&cf_golden_path).unwrap(); let feat_golden = fs::read_to_string(&feat_golden_path).unwrap(); + // amet scores cells in parallel and writes cell_feature/pair_counts rows in + // a cell-interleaved order that depends on thread scheduling, so the row + // *set* is the stable thing, not the line order. Compare with the header + // pinned and the data rows sorted; this still catches any value, schema, or + // formatting drift. (feature.tsv is already written sorted, so this is a + // no-op for it.) assert_eq!( - cf_actual, - cf_golden, + normalize(&cf_actual), + normalize(&cf_golden), "cell_feature output drifted vs golden at {}.\n\ To accept the new output, rerun with UPDATE_SNAPSHOTS=1.", cf_golden_path.display() ); assert_eq!( - feat_actual, - feat_golden, + normalize(&feat_actual), + normalize(&feat_golden), "feature output drifted vs golden at {}.\n\ To accept the new output, rerun with UPDATE_SNAPSHOTS=1.", feat_golden_path.display() ); } + +/// Keep the header as the first line, sort the data rows. Lets the comparison +/// ignore the (thread-scheduling dependent) row order while still catching any +/// change to the row contents. +fn normalize(s: &str) -> String { + let mut lines: Vec<&str> = s.lines().collect(); + if lines.len() > 1 { + lines[1..].sort_unstable(); + } + lines.join("\n") +} diff --git a/workflow/Rmd/argelaguet.Rmd b/workflow/Rmd/argelaguet.Rmd index bf4722c..a106ca7 100644 --- a/workflow/Rmd/argelaguet.Rmd +++ b/workflow/Rmd/argelaguet.Rmd @@ -114,12 +114,13 @@ sanitize <- function(x) gsub("[ ._]", "-", x) recover_annotation <- function(fid) sub("_\\d+$", "", fid) ## Filename parsers: -## filename = "__.cell_feature.tsv.gz" -## annotation may itself contain hyphens; stage and lineage are the trailing -## two underscore-separated tokens before the extension. -get_annotation <- function(fn) sub("^(.*)_[^_]+_[^_]+\\..*$", "\\1", fn) -get_stage <- function(fn) sub("^.*_([^_]+)_[^_]+\\..*$", "\\1", fn) -get_lineage <- function(fn) sub("^.*_([^_.]+)\\.[^_]*$", "\\1", fn) +## filename = "_..cell_feature.tsv.gz" +## amet writes one triplet per --features BED; the label is the BED basename +## (the annotation). annotation may contain hyphens; stage and lineage are +## sanitized (gsub '[ ._]' '-'). None of the three contains '_' or '.'. +get_annotation <- function(fn) sub("^[^_.]+_[^_.]+\\.([^_.]+)\\..*$", "\\1", fn) +get_stage <- function(fn) sub("^([^_.]+)_[^_.]+\\..*$", "\\1", fn) +get_lineage <- function(fn) sub("^[^_.]+_([^_.]+)\\..*$", "\\1", fn) ``` ```{r load_metadata} diff --git a/workflow/Rmd/crc.Rmd b/workflow/Rmd/crc.Rmd index a867a0b..8532c4e 100644 --- a/workflow/Rmd/crc.Rmd +++ b/workflow/Rmd/crc.Rmd @@ -66,19 +66,21 @@ opts_chunk$set( These are amet feature TSVs (per (subcat, cat, patient, location) combo). Per-cell `i_total` from cell_feature.tsv.gz captures within-cell heterogeneity; per-feature `jsd` from feature.tsv.gz captures across-cell heterogeneity. ```{r, import_short_reports} -## Filename layout: ___.{cell_feature,feature}.tsv.gz -## subcat may itself contain underscores (e.g., 0_Enhancer, crc01_nc_scna), so -## we pin location, patient and cat to the last three tokens and let subcat -## absorb the rest. +## Filename layout: _...{cell_feature,feature}.tsv.gz +## amet writes one triplet per --features BED; the label is the staged BED +## basename .. patient and location are single tokens with no dot +## or underscore, joined by "_"; subcat may contain underscores (e.g., +## 0_Enhancer, crc01_nc_scna) but neither subcat nor cat contains a dot. parse_combo_base <- function(base) { - parts <- strsplit(base, "_", fixed = TRUE)[[1]] - n <- length(parts) - stopifnot(n >= 4) + parts <- strsplit(base, ".", fixed = TRUE)[[1]] + stopifnot(length(parts) == 3) + combo <- strsplit(parts[1], "_", fixed = TRUE)[[1]] + stopifnot(length(combo) == 2) list( - subcat = paste(parts[seq_len(n - 3)], collapse = "_"), - cat = parts[n - 2], - patient = parts[n - 1], - location = parts[n] + subcat = parts[2], + cat = parts[3], + patient = combo[1], + location = combo[2] ) } diff --git a/workflow/Rmd/ecker.Rmd b/workflow/Rmd/ecker.Rmd index ddce718..0b4c89d 100644 --- a/workflow/Rmd/ecker.Rmd +++ b/workflow/Rmd/ecker.Rmd @@ -70,12 +70,14 @@ knitr::opts_chunk$set( Mouse (mm10), CpG context only. ```{r helpers} -## filenames: {annotation}_{region}_{sub_type}.*.gz -## annotation has no underscores; neither do region or sub_type values -get_annotation <- function(fn) sub("^(.*)_[^_]+_[^_]+\\..*$", "\\1", fn) -## Slab axis (e.g. 2C/3C/4B/5D), second-to-last filename token. -get_region <- function(fn) sub("^.*_([^_]+)_[^_]+\\..*$", "\\1", fn) -get_sub_type <- function(fn) sub("^.*_([^_.]+)\\.[^_]*$", "\\1", fn) +## filenames: {region}_{sub_type}.{annotation}.*.gz +## amet writes one triplet per --features BED; the label is the BED basename +## (the annotation). region, sub_type and annotation contain no underscores +## or dots. +get_annotation <- function(fn) sub("^[^_.]+_[^_.]+\\.([^_.]+)\\..*$", "\\1", fn) +## Slab axis (e.g. 2C/3C/4B/5D), the first filename token. +get_region <- function(fn) sub("^([^_.]+)_[^_.]+\\..*$", "\\1", fn) +get_sub_type <- function(fn) sub("^[^_.]+_([^_.]+)\\..*$", "\\1", fn) ann_labels <- c( "genes" = "Genes", diff --git a/workflow/rules/argelaguet.smk b/workflow/rules/argelaguet.smk index fb31a31..8c84552 100644 --- a/workflow/rules/argelaguet.smk +++ b/workflow/rules/argelaguet.smk @@ -231,11 +231,16 @@ rule chr19_sizes: rule run_amet_on_argelaguet_features: - """Run amet on one (annotation, stage, lineage) combo. Wildcards: - {annotation, stage, lineage}, where stage and lineage are sanitized - strings (gsub '[ ._]' '-').""" + """Run amet once per (stage, lineage) combo across every annotation BED. + Each BED is passed as a separate --features so the cell files are parsed + only once for the whole annotation panel. amet writes a cell_feature, + feature, and pair_counts file per BED, keyed by the BED basename (the + annotation name); this rule declares the cell_feature and feature files + as tracked outputs. stage and lineage are sanitized strings + (gsub '[ ._]' '-').""" wildcard_constraints: - annotation = "|".join(_ALL_ARGELAGUET_ANN_NAMES), + stage = r"[^_.]+", + lineage = r"[^_.]+", conda: op.join("..", "envs", "bedtools.yml") input: @@ -243,33 +248,35 @@ rule run_amet_on_argelaguet_features: cells = op.join(ARG_DATA, "manifests", "{stage}_{lineage}.tsv"), genome = op.join(REFS, "mm10_ucsc", "genome.fa"), cpg = op.join(REFS, "mm10_ucsc", "genome.fa.cpg"), - bed = op.join(ARG_RUN, "beds", "{annotation}.bed"), + beds = [op.join(ARG_RUN, "beds", f"{ann}.bed") + for ann in _ALL_ARGELAGUET_ANN_NAMES], output: - cell_feature = op.join( - ARG_RUN, "features", - "{annotation}_{stage}_{lineage}.cell_feature.tsv.gz"), - feature = op.join( - ARG_RUN, "features", - "{annotation}_{stage}_{lineage}.feature.tsv.gz"), + cell_feature = [ + op.join(ARG_RUN, "features", + "{stage}_{lineage}." + f"{ann}.cell_feature.tsv.gz") + for ann in _ALL_ARGELAGUET_ANN_NAMES], + feature = [ + op.join(ARG_RUN, "features", + "{stage}_{lineage}." + f"{ann}.feature.tsv.gz") + for ann in _ALL_ARGELAGUET_ANN_NAMES], params: - prefix = op.join( - ARG_RUN, "features", - "{annotation}_{stage}_{lineage}"), + prefix = op.join(ARG_RUN, "features", "{stage}_{lineage}"), + features_flags = lambda w, input: " ".join( + f"--features {b}" for b in input.beds), i_max_lag = config["amet"]["i_max_lag"], min_cpgs = config["amet"]["min_cpgs_per_feature"], min_cells = min_cells_per_group(), thresh = config["amet"]["meth_call_threshold"], threads: min(workflow.cores, 4) log: - op.join(ARG_RUN, "logs", - "amet_{annotation}_{stage}_{lineage}.log"), + op.join(ARG_RUN, "logs", "amet_features_{stage}_{lineage}.log"), shell: """ mkdir -p $(dirname {params.prefix}) {input.binary} \ --genome {input.genome} \ --cells {input.cells} \ - --features {input.bed} \ + {params.features_flags} \ --output-prefix {params.prefix} \ --i-max-lag {params.i_max_lag} \ --min-cpgs-per-feature {params.min_cpgs} \ @@ -400,15 +407,15 @@ def _argelaguet_combos(): def list_argelaguet_features_outputs(wildcards): - """All (annotation x stage x lineage) amet output files.""" + """All (stage x lineage x annotation) amet output files.""" combos = _argelaguet_combos() out = [] - for ann in _ALL_ARGELAGUET_ANN_NAMES: - for stage, lineage in combos: + for stage, lineage in combos: + for ann in _ALL_ARGELAGUET_ANN_NAMES: out.append(op.join(ARG_RUN, "features", - f"{ann}_{stage}_{lineage}.cell_feature.tsv.gz")) + f"{stage}_{lineage}.{ann}.cell_feature.tsv.gz")) out.append(op.join(ARG_RUN, "features", - f"{ann}_{stage}_{lineage}.feature.tsv.gz")) + f"{stage}_{lineage}.{ann}.feature.tsv.gz")) return out diff --git a/workflow/rules/crc.smk b/workflow/rules/crc.smk index 8e48b43..8faa0c1 100644 --- a/workflow/rules/crc.smk +++ b/workflow/rules/crc.smk @@ -328,10 +328,15 @@ rule crc_combine_window_annotations: rule run_amet_on_crc_features: - """Run amet on one (subcat, cat, patient, location) combo.""" + """Run amet once per (patient, location) combo across every annotation BED. + Each BED is passed as a separate --features, so the cell files are parsed + only once for the whole annotation panel. amet writes a cell_feature, + feature, and pair_counts file per BED, keyed by the staged BED basename + .; this rule declares the cell_feature and feature files as + tracked outputs.""" wildcard_constraints: - subcat = _CRC_SUBCAT_RE, - cat = _CRC_CAT_RE, + patient = r"[^_.]+", + location = r"[^_.]+", conda: op.join("..", "envs", "bedtools.yml") input: @@ -340,18 +345,21 @@ rule run_amet_on_crc_features: "{patient}_{location}.tsv"), genome = op.join(REFS, "hg19_ucsc", "genome.fa"), cpg = op.join(REFS, "hg19_ucsc", "genome.fa.cpg"), - bed = op.join(CRC_RUN, "beds", "{subcat}.{cat}.bed"), + beds = [op.join(CRC_RUN, "beds", f"{sc}.{c}.bed") + for sc, c in _CRC_LOCAL_PAIRS], output: - cell_feature = op.join( - CRC_RUN, "features", - "{subcat}_{cat}_{patient}_{location}.cell_feature.tsv.gz"), - feature = op.join( - CRC_RUN, "features", - "{subcat}_{cat}_{patient}_{location}.feature.tsv.gz"), + cell_feature = [ + op.join(CRC_RUN, "features", + "{patient}_{location}." + f"{sc}.{c}.cell_feature.tsv.gz") + for sc, c in _CRC_LOCAL_PAIRS], + feature = [ + op.join(CRC_RUN, "features", + "{patient}_{location}." + f"{sc}.{c}.feature.tsv.gz") + for sc, c in _CRC_LOCAL_PAIRS], params: - prefix = op.join( - CRC_RUN, "features", - "{subcat}_{cat}_{patient}_{location}"), + prefix = op.join(CRC_RUN, "features", "{patient}_{location}"), + features_flags = lambda w, input: " ".join( + f"--features {b}" for b in input.beds), i_max_lag = config["amet"]["i_max_lag"], min_cpgs = config["amet"]["min_cpgs_per_feature"], min_cells = min_cells_per_group(), @@ -359,14 +367,14 @@ rule run_amet_on_crc_features: threads: min(workflow.cores, 4) log: op.join(CRC_RUN, "logs", - "amet_{subcat}_{cat}_{patient}_{location}.log"), + "amet_features_{patient}_{location}.log"), shell: """ mkdir -p $(dirname {params.prefix}) {input.binary} \ --genome {input.genome} \ --cells {input.cells} \ - --features {input.bed} \ + {params.features_flags} \ --output-prefix {params.prefix} \ --i-max-lag {params.i_max_lag} \ --min-cpgs-per-feature {params.min_cpgs} \ @@ -439,12 +447,12 @@ def _crc_combos(): def list_crc_features_outputs(wildcards): combos = _crc_combos() out = [] - for sc, c in _CRC_LOCAL_PAIRS: - for p, l in combos: + for p, l in combos: + for sc, c in _CRC_LOCAL_PAIRS: out.append(op.join(CRC_RUN, "features", - f"{sc}_{c}_{p}_{l}.cell_feature.tsv.gz")) + f"{p}_{l}.{sc}.{c}.cell_feature.tsv.gz")) out.append(op.join(CRC_RUN, "features", - f"{sc}_{c}_{p}_{l}.feature.tsv.gz")) + f"{p}_{l}.{sc}.{c}.feature.tsv.gz")) return out diff --git a/workflow/rules/ecker.smk b/workflow/rules/ecker.smk index 0036122..9de3746 100644 --- a/workflow/rules/ecker.smk +++ b/workflow/rules/ecker.smk @@ -385,9 +385,15 @@ def _ecker_combo_cell_tsvs(wildcards): rule run_amet_on_ecker_features: - """Run amet on one (annotation, region, sub_type) combo.""" + """Run amet once per (region, sub_type) combo across every annotation BED. + Each BED is passed as a separate --features so the cell files are parsed + only once for the whole annotation panel. amet writes a cell_feature, + feature, and pair_counts file per BED, keyed by the BED basename (the + annotation name); this rule declares the cell_feature and feature files + as tracked outputs.""" wildcard_constraints: - annotation = "|".join(_ECKER_ALL_ANN_NAMES), + region = r"[^_.]+", + sub_type = r"[^_.]+", conda: op.join("..", "envs", "bedtools.yml") input: @@ -397,18 +403,21 @@ rule run_amet_on_ecker_features: cell_files = _ecker_combo_cell_tsvs, genome = op.join(REFS, "mm10_ensembl", "genome.fa"), cpg = op.join(REFS, "mm10_ensembl", "genome.fa.cpg"), - bed = op.join(ECKER_RUN, "beds", "{annotation}.bed"), + beds = [op.join(ECKER_RUN, "beds", f"{ann}.bed") + for ann in _ECKER_ALL_ANN_NAMES], output: - cell_feature = op.join( - ECKER_RUN, "features", - "{annotation}_{region}_{sub_type}.cell_feature.tsv.gz"), - feature = op.join( - ECKER_RUN, "features", - "{annotation}_{region}_{sub_type}.feature.tsv.gz"), + cell_feature = [ + op.join(ECKER_RUN, "features", + "{region}_{sub_type}." + f"{ann}.cell_feature.tsv.gz") + for ann in _ECKER_ALL_ANN_NAMES], + feature = [ + op.join(ECKER_RUN, "features", + "{region}_{sub_type}." + f"{ann}.feature.tsv.gz") + for ann in _ECKER_ALL_ANN_NAMES], params: - prefix = op.join( - ECKER_RUN, "features", - "{annotation}_{region}_{sub_type}"), + prefix = op.join(ECKER_RUN, "features", "{region}_{sub_type}"), + features_flags = lambda w, input: " ".join( + f"--features {b}" for b in input.beds), i_max_lag = config["amet"]["i_max_lag"], min_cpgs = config["amet"]["min_cpgs_per_feature"], min_cells = min_cells_per_group(), @@ -416,14 +425,14 @@ rule run_amet_on_ecker_features: threads: min(workflow.cores, 4) log: op.join(ECKER_RUN, "logs", - "amet_{annotation}_{region}_{sub_type}.log"), + "amet_features_{region}_{sub_type}.log"), shell: """ mkdir -p $(dirname {params.prefix}) {input.binary} \ --genome {input.genome} \ --cells {input.cells} \ - --features {input.bed} \ + {params.features_flags} \ --output-prefix {params.prefix} \ --i-max-lag {params.i_max_lag} \ --min-cpgs-per-feature {params.min_cpgs} \ @@ -493,12 +502,12 @@ def _ecker_combos(): def list_ecker_features_outputs(wildcards): combos = _ecker_combos() out = [] - for ann in _ECKER_ALL_ANN_NAMES: - for sr, st in combos: + for sr, st in combos: + for ann in _ECKER_ALL_ANN_NAMES: out.append(op.join(ECKER_RUN, "features", - f"{ann}_{sr}_{st}.cell_feature.tsv.gz")) + f"{sr}_{st}.{ann}.cell_feature.tsv.gz")) out.append(op.join(ECKER_RUN, "features", - f"{ann}_{sr}_{st}.feature.tsv.gz")) + f"{sr}_{st}.{ann}.feature.tsv.gz")) return out diff --git a/workflow/scripts/internal/setup_barbara_links.sh b/workflow/scripts/internal/setup_barbara_links.sh index 928a3e9..6107d2b 100755 --- a/workflow/scripts/internal/setup_barbara_links.sh +++ b/workflow/scripts/internal/setup_barbara_links.sh @@ -18,7 +18,7 @@ set -euo pipefail YAMET="${YAMET:-$HOME/src/yamet/workflow}" YAMET_HG19_CURATED="${YAMET_HG19_CURATED:-$HOME/src/yamet/hg19}" -repo_root="$(cd "$(dirname "$0")/../.." && pwd)" +repo_root="$(cd "$(dirname "$0")/../../.." && pwd)" res="$repo_root/results" link() { diff --git a/workflow/scripts/internal/sync_from_barbara.sh b/workflow/scripts/internal/sync_from_barbara.sh index 469f331..98fb34c 100755 --- a/workflow/scripts/internal/sync_from_barbara.sh +++ b/workflow/scripts/internal/sync_from_barbara.sh @@ -15,7 +15,7 @@ if [[ $# -lt 1 ]]; then fi dataset="$1" -repo_root="$(cd "$(dirname "$0")/../.." && pwd)" +repo_root="$(cd "$(dirname "$0")/../../.." && pwd)" config="$repo_root/workflow/config/datasets.yaml" out_root="$repo_root/results/$dataset" mkdir -p "$out_root/cells" "$out_root/features"