From 60c22a9a7bf1feacb3ebefa53b17e9c526ea9cbd Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Thu, 15 Feb 2024 12:04:26 +0100 Subject: [PATCH] Raise default threads for BAM parsing 8->32, BLAS 8->16 It was originally capped at 8 under the belief that reading 8 BAM files in parallel would saturate the disk, so there would be no benefit of going higher. However, my laptop can read at 4 GB/s, and decompress BAM files perhaps 40 times slower, so it's CPU bottlenecked even with 32 threads. This change is significant, because users have reported slow BAM file parsing. However, it will potentially quadruple the memory usage of the BAM parsing step. Will be benchmarked before merging. The BLAS change is simply because I think 8 CPUs is too conservative. --- vamb/__main__.py | 22 +++++++++++++--------- vamb/aamb_encode.py | 1 - vamb/parsebam.py | 15 ++++++++------- vamb/semisupervised_encode.py | 1 - vamb/taxvamb_encode.py | 1 - workflow_avamb/src/rip_bins.py | 12 ++++++------ 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/vamb/__main__.py b/vamb/__main__.py index 87abc48f..d9942629 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -22,13 +22,13 @@ import pandas as pd _ncpu = os.cpu_count() -DEFAULT_THREADS = 8 if _ncpu is None else min(_ncpu, 8) +DEFAULT_BLAS_THREADS = 16 if _ncpu is None else min(_ncpu, 16) # These MUST be set before importing numpy # I know this is a shitty hack, see https://github.com/numpy/numpy/issues/11826 -os.environ["MKL_NUM_THREADS"] = str(DEFAULT_THREADS) -os.environ["NUMEXPR_NUM_THREADS"] = str(DEFAULT_THREADS) -os.environ["OMP_NUM_THREADS"] = str(DEFAULT_THREADS) +os.environ["MKL_NUM_THREADS"] = str(DEFAULT_BLAS_THREADS) +os.environ["NUMEXPR_NUM_THREADS"] = str(DEFAULT_BLAS_THREADS) +os.environ["OMP_NUM_THREADS"] = str(DEFAULT_BLAS_THREADS) # Append vamb to sys.path to allow vamb import even if vamb was not installed # using pip @@ -771,9 +771,11 @@ def cluster_and_write_files( print( str(i + 1), None if cluster.radius is None else round(cluster.radius, 3), - None - if cluster.observed_pvr is None - else round(cluster.observed_pvr, 2), + ( + None + if cluster.observed_pvr is None + else round(cluster.observed_pvr, 2) + ), cluster.kind_str, sum(sequence_lens[i] for i in cluster.members), len(cluster.members), @@ -1686,9 +1688,11 @@ def add_input_output_arguments(subparser): dest="nthreads", metavar="", type=int, - default=DEFAULT_THREADS, + default=vamb.parsebam.DEFAULT_BAM_THREADS, help=( - "number of threads to use " "[min(" + str(DEFAULT_THREADS) + ", nbamfiles)]" + "number of threads to read BAM files [min(" + + str(vamb.parsebam.DEFAULT_BAM_THREADS) + + ", nbamfiles)]" ), ) inputos.add_argument( diff --git a/vamb/aamb_encode.py b/vamb/aamb_encode.py index 689688cd..19f4ce85 100644 --- a/vamb/aamb_encode.py +++ b/vamb/aamb_encode.py @@ -1,6 +1,5 @@ """Adversarial autoencoders (AAE) for metagenomics binning, this files contains the implementation of the AAE""" - import numpy as np from math import log, isfinite import time diff --git a/vamb/parsebam.py b/vamb/parsebam.py index afa113f2..4e8ef680 100644 --- a/vamb/parsebam.py +++ b/vamb/parsebam.py @@ -14,12 +14,13 @@ from typing import Optional, TypeVar, Union, IO, Sequence, Iterable from pathlib import Path import shutil - -_ncpu = _os.cpu_count() -DEFAULT_THREADS = 8 if _ncpu is None else _ncpu +import os A = TypeVar("A", bound="Abundance") +_ncpu = os.cpu_count() +DEFAULT_BAM_THREADS = 32 if _ncpu is None else min(_ncpu, 32) + class Abundance: "Object representing contig abundance. Contains a matrix and refhash." @@ -115,10 +116,10 @@ def from_files( chunksize = min(nthreads, len(paths)) - # We cap it to 16 threads, max. This will prevent pycoverm from consuming a huge amount + # We cap it to DEFAULT_BAM_THREADS threads, max. This will prevent pycoverm from consuming a huge amount # of memory if given a crapload of threads, and most programs will probably be IO bound - # when reading 16 files at a time. - chunksize = min(chunksize, 16) + # when reading DEFAULT_BAM_THREADS files at a time. + chunksize = min(chunksize, DEFAULT_BAM_THREADS) # If it can be done in memory, do so if chunksize >= len(paths): @@ -134,7 +135,7 @@ def from_files( else: if cache_directory is None: raise ValueError( - "If min(16, nthreads) < len(paths), cache_directory must not be None" + "If min(DEFAULT_BAM_THREADS, nthreads) < len(paths), cache_directory must not be None" ) return cls.chunkwise_loading( paths, diff --git a/vamb/semisupervised_encode.py b/vamb/semisupervised_encode.py index 8af77f86..ac81e58a 100644 --- a/vamb/semisupervised_encode.py +++ b/vamb/semisupervised_encode.py @@ -1,6 +1,5 @@ """Semisupervised multimodal VAEs for metagenomics binning, this files contains the implementation of the VAEVAE for MMSEQ predictions""" - __cmd_doc__ = """Encode depths and TNF using a VAE to latent representation""" import numpy as _np diff --git a/vamb/taxvamb_encode.py b/vamb/taxvamb_encode.py index 4c14a778..0854dde8 100644 --- a/vamb/taxvamb_encode.py +++ b/vamb/taxvamb_encode.py @@ -1,6 +1,5 @@ """Hierarchical loss for the labels suggested in https://arxiv.org/abs/2210.10929""" - __cmd_doc__ = """Hierarchical loss for the labels""" diff --git a/workflow_avamb/src/rip_bins.py b/workflow_avamb/src/rip_bins.py index 9287cbd7..e4fa705b 100644 --- a/workflow_avamb/src/rip_bins.py +++ b/workflow_avamb/src/rip_bins.py @@ -183,9 +183,9 @@ def remove_meaningless_edges_from_pairs( contig_length, ) print("Cluster ripped because of a meaningless edge ", cluster_updated) - clusters_changed_but_not_intersecting_contigs[ - cluster_updated - ] = cluster_contigs[cluster_updated] + clusters_changed_but_not_intersecting_contigs[cluster_updated] = ( + cluster_contigs[cluster_updated] + ) components: list[set[str]] = list() for component in nx.connected_components(graph_clusters): @@ -295,9 +295,9 @@ def make_all_components_pair( contig_length, ) print("Cluster ripped because of a pairing component ", cluster_updated) - clusters_changed_but_not_intersecting_contigs[ - cluster_updated - ] = cluster_contigs[cluster_updated] + clusters_changed_but_not_intersecting_contigs[cluster_updated] = ( + cluster_contigs[cluster_updated] + ) component_len = max( [ len(nx.node_connected_component(graph_clusters, node_i))