diff --git a/src/decima/cli/finetune.py b/src/decima/cli/finetune.py index 70105bb..64fbdf3 100755 --- a/src/decima/cli/finetune.py +++ b/src/decima/cli/finetune.py @@ -45,6 +45,7 @@ @click.option("--logger", default="wandb", type=str, help="Logger.") @click.option("--num-workers", default=16, type=int, help="Number of workers.") @click.option("--seed", default=0, type=int, help="Random seed.") +@click.option("--checkpoint", default=None, type=str, help="Path to a checkpoint to resume training from.") def cli_finetune( name, model, @@ -63,6 +64,7 @@ def cli_finetune( logger, num_workers, seed, + checkpoint, ): """Finetune the Decima model. @@ -102,8 +104,11 @@ def cli_finetune( ) val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0) - if isinstance(device, str) and device.isdigit(): - device = int(device) + if isinstance(device, str): + if "," in device: + device = [int(d) for d in device.split(",")] + elif device.isdigit(): + device = int(device) train_params = { "batch_size": batch_size, @@ -137,7 +142,7 @@ def cli_finetune( run = wandb.init(project="decima", dir=name, name=name) logger.info("Training") - model.train_on_dataset(train_dataset, val_dataset) + model.train_on_dataset(train_dataset, val_dataset, checkpoint_path=checkpoint) train_dataset.close() val_dataset.close() if logger == "wandb": diff --git a/src/decima/data/write_hdf5.py b/src/decima/data/write_hdf5.py index ccd3c1d..0d01f45 100644 --- a/src/decima/data/write_hdf5.py +++ b/src/decima/data/write_hdf5.py @@ -1,10 +1,12 @@ import h5py import numpy as np -from grelu.sequence.format import convert_input_type +from grelu.io.genome import get_genome +from grelu.sequence.format import _BASE_LUT from grelu.sequence.utils import get_unique_length +from tqdm import tqdm -def write_hdf5(file, ad, pad=0, genome="hg38"): +def write_hdf5(file, ad, pad=0, genome="hg38", batch_size=1000): """Write AnnData object to HDF5 file. Args: @@ -12,16 +14,22 @@ def write_hdf5(file, ad, pad=0, genome="hg38"): ad: AnnData object containing the data pad: Amount of padding to add. Defaults to 0 genome: Genome name or path to the genome fasta file. Defaults to "hg38" + batch_size: Number of genes per write batch. Defaults to 1000 """ - # Calculate seq_len seq_len = get_unique_length(ad.var) + padded_seq_len = seq_len + 2 * pad + n_genes = ad.var.shape[0] + genome_obj = get_genome(genome) + + intervals = ad.var[["chrom", "start", "end", "strand"]].copy() + intervals["start"] = intervals["start"] - pad + intervals["end"] = intervals["end"] + pad with h5py.File(file, "w") as f: # Metadata print("Writing metadata") f.create_dataset("pad", shape=(), data=pad) f.create_dataset("seq_len", shape=(), data=seq_len) - padded_seq_len = seq_len + 2 * pad f.create_dataset("padded_seq_len", shape=(), data=padded_seq_len) # Tasks @@ -35,26 +43,47 @@ def write_hdf5(file, ad, pad=0, genome="hg38"): f.create_dataset("genes", shape=arr.shape, data=arr) # Labels - arr = np.expand_dims(ad.X.T.astype(np.float32), 2) - print(f"Writing labels array of shape: {arr.shape}") + print("Writing labels") + X = ad.X.toarray() if hasattr(ad.X, "toarray") else np.asarray(ad.X) + arr = np.expand_dims(X.T.astype(np.float32), 2) + print(f" shape: {arr.shape}") f.create_dataset("labels", shape=arr.shape, dtype=np.float32, data=arr) + del X, arr + + # Masks and sequences — written in batches to avoid OOM + print("Writing masks and sequences") + masks_ds = f.create_dataset( + "masks", shape=(n_genes, padded_seq_len), dtype=np.float32 + ) + seqs_ds = f.create_dataset( + "sequences", shape=(n_genes, padded_seq_len), dtype=np.int8 + ) + + n_batches = (n_genes + batch_size - 1) // batch_size + for b in tqdm(range(n_batches), desc="Batches"): + start_i = b * batch_size + end_i = min(start_i + batch_size, n_genes) + batch_var = ad.var.iloc[start_i:end_i] + batch_iv = intervals.iloc[start_i:end_i] + + masks = np.zeros((end_i - start_i, padded_seq_len), dtype=np.float32) + seqs = np.empty((end_i - start_i, padded_seq_len), dtype=np.int8) + + for j, (row_var, row_iv) in enumerate( + zip(batch_var.itertuples(), batch_iv.itertuples()) + ): + masks[j, row_var.gene_mask_start + pad : row_var.gene_mask_end + pad] = 1.0 + seq = str( + genome_obj.get_seq( + row_iv.chrom, + row_iv.start + 1, + row_iv.end, + rc=row_iv.strand == "-", + ) + ).upper() + seqs[j] = _BASE_LUT[np.frombuffer(seq.encode("ascii"), dtype=np.uint8)] - # Gene masks - print("Making gene masks") - shape = (ad.var.shape[0], padded_seq_len) - arr = np.zeros(shape=shape) - for i, row in enumerate(ad.var.itertuples()): - arr[i, row.gene_mask_start + pad : row.gene_mask_end + pad] += 1 - print(f"Writing mask array of shape: {arr.shape}") - f.create_dataset("masks", shape=shape, dtype=np.float32, data=arr) - - # Sequences - print("Encoding sequences") - arr = ad.var[["chrom", "start", "end", "strand"]].copy() - arr.start = arr.start - pad - arr.end = arr.end + pad - arr = convert_input_type(arr, "indices", genome=genome) - print(f"Writing sequence array of shape: {arr.shape}") - f.create_dataset("sequences", shape=arr.shape, dtype=np.int8, data=arr) + masks_ds[start_i:end_i] = masks + seqs_ds[start_i:end_i] = seqs print("Done!") diff --git a/src/decima/model/lightning.py b/src/decima/model/lightning.py index 1c20ac7..5c5573d 100644 --- a/src/decima/model/lightning.py +++ b/src/decima/model/lightning.py @@ -313,12 +313,14 @@ def train_on_dataset( logger = self.parse_logger() # Set up trainer + devices = make_list(self.train_params["devices"]) trainer = pl.Trainer( max_epochs=self.train_params["max_epochs"], accelerator="gpu", - devices=make_list(self.train_params["devices"]), + devices=devices, + strategy="ddp" if len(devices) > 1 else "auto", logger=logger, - callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=self.train_params["save_top_k"])], + callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=self.train_params["save_top_k"], save_last=True)], default_root_dir=self.train_params["save_dir"], accumulate_grad_batches=self.train_params["accumulate_grad_batches"], gradient_clip_val=self.train_params["clip"],