From dfb20f7b0e881ca0a0996754ac7c928b69e617f6 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Sat, 12 Nov 2022 13:30:07 +0100 Subject: [PATCH 01/42] reordered flow to only process full batches --- spliceai/__main__.py | 19 ++++- spliceai/batch/batch.py | 174 ++++++++++++++++++++++------------------ 2 files changed, 112 insertions(+), 81 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 11687d2..980fa5e 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -118,13 +118,28 @@ def run_spliceai(input_data, output_data, reference, annotation, distance, mask, if len(scores) > 0: record.info['SpliceAI'] = scores output_data.write(record) + + # close VCF + vcf.close() if batch: # Ensure we process any leftover records in the batch when we finish iterating the VCF. This # would be a good candidate for a context manager if we removed the original non batching code above batch.finish() - - vcf.close() + # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data + vcf = pysam.VariantFile(input_data) + # have to update header again + header = vcf.header + header.add_line('##INFO=') + batch.write_records(vcf) + # close shelves + batch.shelf_records.close() + batch.shelf_preds.close() + + output_data.close() diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 12f98db..1ced6e5 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -4,8 +4,10 @@ import collections import logging import time - +import shelve +import tempfile import numpy as np +import os from spliceai.batch.batch_utils import extract_delta_scores, get_preds, encode_batch_records @@ -16,11 +18,12 @@ BatchLookupIndex = collections.namedtuple( - 'BatchLookupIndex', 'sequence_type tensor_size batch_index' + # ref/alt size batch for this size index in current batch for this size + 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' ) PreparedVCFRecord = collections.namedtuple( - 
'PreparedVCFRecord', 'vcf_record gene_info locations' + 'PreparedVCFRecord', 'vcf_idx gene_info locations' ) @@ -43,88 +46,91 @@ def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_ba self.batch_predictions = 0 self.total_predictions = 0 self.total_vcf_records = 0 + self.batch_counters = {} - def _clear_batch(self): - self.batch_predictions = 0 - self.batches.clear() - del self.prepared_vcf_records[:] + # shelves to track data. + self.tmpdir = tempfile.TemporaryDirectory() + # store batches of predictions using 'tensor_size|batch_idx' as key. + self.shelf_preds = shelve.open(os.path.join(self.tmpdir.name,"spliceai_preds.shelf")) + # track records to have order correct + self.shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) - def _process_batch(self): - start = time.time() - total_batch_predictions = 0 - logger.debug('Starting process_batch') + def _process_batch(self,tensor_size): + start = time.time() + # get last batch for this tensor_size + batch_ix = self.batch_counters[tensor_size] + batch = self.batches[tensor_size] # Sanity check dump of batch sizes - batch_sizes = ["{}:{}".format(tensor_size, len(batch)) for tensor_size, batch in self.batches.items()] - logger.debug('Batch Sizes: {}'.format(batch_sizes)) - - # Collect each batch's predictions - batch_preds = {} - for tensor_size, batch in self.batches.items(): - # Convert list of encodings into a proper sized numpy matrix - prediction_batch = np.concatenate(batch, axis=0) - - # Run predictions - batch_preds[tensor_size] = np.mean( - get_preds(self.ann, prediction_batch, self.prediction_batch_size), axis=0 - ) + logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , len(batch))) - # Iterate over original list of vcf records, reconstructing record with annotations - for prepared_record in self.prepared_vcf_records: - record_predictions = self._write_record(prepared_record, batch_preds) - 
total_batch_predictions += record_predictions + # Convert list of encodings into a proper sized numpy matrix + prediction_batch = np.concatenate(batch, axis=0) + # Run predictions && add to shelf. + self.shelf_preds["{}|{}".format(tensor_size,batch_ix)] = np.mean( + get_preds(self.ann, prediction_batch, self.prediction_batch_size), axis=0 + ) - self._clear_batch() + # clear the batch. + self.batches[tensor_size] = [] + # initialize the next batch_ix + self.batch_counters[tensor_size] += 1 + logger.debug('Predictions: {}, VCF Records: {}'.format(self.total_predictions, self.total_vcf_records)) duration = time.time() - start - preds_per_sec = total_batch_predictions / duration + preds_per_sec = len(batch) / duration preds_per_hour = preds_per_sec * 60 * 60 logger.debug('Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(duration, preds_per_sec, preds_per_hour)) - def _write_record(self, prepared_record, batch_preds): - record = prepared_record.vcf_record - gene_info = prepared_record.gene_info - record_predictions = 0 - - all_y_ref = [] - all_y_alt = [] - - # Each prediction in the batch is located and put into the correct y - for location in prepared_record.locations: - # No prediction here - if location.tensor_size == 0: + # wrapper to write out all shelved variants + def write_records(self, vcf): + line_idx = 0 + for record in vcf: + line_idx += 1 + # get prepared record by line_idx + prepared_record = self.shelf_records[str(line_idx)] + #record = prepared_record.vcf_record + gene_info = prepared_record.gene_info + record_predictions = 0 + + all_y_ref = [] + all_y_alt = [] + + # Each prediction in the batch is located and put into the correct y + for location in prepared_record.locations: + # No prediction here + if location.tensor_size == 0: + if location.sequence_type == SequenceType_REF: + all_y_ref.append(None) + else: + all_y_alt.append(None) + continue + + # Extract the prediction from the batch into a list of predictions for this record + 
batch = self.shelf_preds["{}|{}".format(location.tensor_size,location.batch_ix)] # batch_preds[location.tensor_size] if location.sequence_type == SequenceType_REF: - all_y_ref.append(None) + all_y_ref.append(batch[[location.batch_index], :, :]) else: - all_y_alt.append(None) - continue - - # Extract the prediction from the batch into a list of predictions for this record - batch = batch_preds[location.tensor_size] - if location.sequence_type == SequenceType_REF: - all_y_ref.append(batch[[location.batch_index], :, :]) - else: - all_y_alt.append(batch[[location.batch_index], :, :]) - - delta_scores = extract_delta_scores( - all_y_ref=all_y_ref, - all_y_alt=all_y_alt, - record=record, - ann=self.ann, - dist_var=self.dist, - mask=self.mask, - gene_info=gene_info, - ) + all_y_alt.append(batch[[location.batch_index], :, :]) + delta_scores = extract_delta_scores( + all_y_ref=all_y_ref, + all_y_alt=all_y_alt, + record=record, + ann=self.ann, + dist_var=self.dist, + mask=self.mask, + gene_info=gene_info, + ) - # If there are predictions, write them to the VCF INFO section - if len(delta_scores) > 0: - record.info['SpliceAI'] = delta_scores - record_predictions += len(delta_scores) + # If there are predictions, write them to the VCF INFO section + if len(delta_scores) > 0: + record.info['SpliceAI'] = delta_scores + record_predictions += len(delta_scores) - self.output.write(record) - return record_predictions + self.output.write(record) + def add_record(self, record): """ @@ -135,7 +141,7 @@ def add_record(self, record): are made, it knows where to look up the corresponding prediction for the vcf record. Once the batch size hits it's capacity, it'll process all the predictions for the - encoded batches. + encoded batch. 
""" self.total_vcf_records += 1 @@ -160,7 +166,7 @@ def add_record(self, record): if len(encoded_seq) == 0: # Add BatchLookupIndex with zeros so when the batch collects the outputs # it knows that there is no prediction for this record - batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0)) + batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0, 0)) continue # Iterate over the encoded sequence and drop into the correct batch by size and @@ -172,29 +178,39 @@ def add_record(self, record): # Create batch for this size if tensor_size not in self.batches: self.batches[tensor_size] = [] + self.batch_counters[tensor_size] = 0 - # Add encoded record to batch + # Add encoded record to batch 'n' for tensor_size self.batches[tensor_size].append(row) # Get the index of the record we just added in the batch cur_batch_record_ix = len(self.batches[tensor_size]) - 1 # Store a reference so we can pull out the prediction for this item from the batches - batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, cur_batch_record_ix)) + batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, self.batch_counters[tensor_size] , cur_batch_record_ix)) # Save the batch locations for this record on the composite object prepared_record = PreparedVCFRecord( - vcf_record=record, gene_info=gene_info, locations=batch_lookup_indexes + vcf_idx=self.total_vcf_records, gene_info=gene_info, locations=batch_lookup_indexes ) - self.prepared_vcf_records.append(prepared_record) + # add to shelf by vcf_idx + self.shelf_records[str(self.total_vcf_records)] = prepared_record # If we're reached our threshold for the max items to process, then process the batch - if self.batch_predictions >= self.prediction_batch_size: - self._process_batch() + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size]) >= self.prediction_batch_size: + logger.debug("Batch {} full. 
Processing".format(tensor_size)) + self._process_batch(tensor_size) + + def finish(self): """ - Method to process all the remaining items that have been added to the batch. + Method to process all the remaining items that have been added to the batches. """ - if len(self.prepared_vcf_records) > 0: - self._process_batch() + #if len(self.prepared_vcf_records) > 0: + # self._process_batch() + logger.debug("Processing remaining batches") + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size] ) > 0: + self._process_batch(tensor_size) From d72d9ee61f83571d10157ddfcc05a49982ff4af7 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Sat, 12 Nov 2022 13:39:00 +0100 Subject: [PATCH 02/42] sporadic missing predictions handling --- spliceai/batch/batch_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 310a82a..e36bb4c 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -67,9 +67,19 @@ def extract_delta_scores( for alt_ix in range(len(record.alts)): for gene_ix in range(len(gene_info.idxs)): + # Pull prediction out of batch - y_ref = all_y_ref[pred_ix] - y_alt = all_y_alt[pred_ix] + try: + y_ref = all_y_ref[pred_ix] + y_alt = all_y_alt[pred_ix] + except IndexError: + logger.warn("No data for record below, alt_ix {} : gene_ix {} : pred_ix {}".format(alt_ix, gene_ix,pred_ix)) + logger.warn(record) + continue + except Exception as e: + logger.error("Predction error: {}".format(e)) + logger.error(record) + raise e # No prediction here if y_ref is None or y_alt is None: From 6f1d40ef787b4a1b4591754acac9bbc969f39f6d Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Sun, 13 Nov 2022 13:53:57 +0100 Subject: [PATCH 03/42] print overall performance when batching --- spliceai/__main__.py | 29 +++++++++++++++++++++++------ spliceai/batch/batch.py | 4 +++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git 
a/spliceai/__main__.py b/spliceai/__main__.py index 980fa5e..14f2709 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -5,6 +5,7 @@ import argparse import logging import pysam +import time from spliceai.batch.batch import VCFPredictionBatch from spliceai.utils import Annotator, get_delta_scores @@ -53,12 +54,16 @@ def main(): args = get_options() if args.verbose: - logging.basicConfig( - format='%(asctime)s %(levelname)s %(name)s: - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=logging.DEBUG, - ) - + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + + logging.basicConfig( + format='%(asctime)s %(levelname)s %(name)s: - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.DEBUG, + ) + if None in [args.I, args.O, args.D, args.M]: logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]]') @@ -126,6 +131,10 @@ def run_spliceai(input_data, output_data, reference, annotation, distance, mask, # Ensure we process any leftover records in the batch when we finish iterating the VCF. This # would be a good candidate for a context manager if we removed the original non batching code above batch.finish() + # stats without writing phase + duration = time.time() - batch.start_time + preds_per_sec = batch.total_predictions / duration + preds_per_hour = preds_per_sec * 60 * 60 # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data vcf = pysam.VariantFile(input_data) # have to update header again @@ -141,6 +150,14 @@ def run_spliceai(input_data, output_data, reference, annotation, distance, mask, output_data.close() + ## stats + if batch: + duration = time.time() - batch.start_time + logging.info("Analysis Finished. 
Statistics:") + logging.info("Total RunTime: {:0.2f}s".format(duration)) + logging.info("Processed Records: {}".format(batch.total_vcf_records)) + logging.info("Processed Predictions: {}".format(batch.total_predictions)) + logging.info("Overall performance : {:0.2f} predictions/sec ; {:0.2f} predictions/hour".format(preds_per_sec, preds_per_hour)) if __name__ == '__main__': diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 1ced6e5..e179d36 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -37,7 +37,8 @@ def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_ba self.prediction_batch_size = prediction_batch_size # This is the size of the batch tensorflow will use to make the predictions self.tensorflow_batch_size = tensorflow_batch_size - + # track runtime + self.start_time = time.time() # Batch vars self.batches = {} self.prepared_vcf_records = [] @@ -47,6 +48,7 @@ def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_ba self.total_predictions = 0 self.total_vcf_records = 0 self.batch_counters = {} + # shelves to track data. self.tmpdir = tempfile.TemporaryDirectory() From 7662a76704d335ec7f7f700134a3d4f8259da086 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Mon, 14 Nov 2022 09:59:27 +0100 Subject: [PATCH 04/42] fix tensorflow batch size to actually use the provided size --- spliceai/batch/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index e179d36..a8aef92 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -70,7 +70,7 @@ def _process_batch(self,tensor_size): prediction_batch = np.concatenate(batch, axis=0) # Run predictions && add to shelf. self.shelf_preds["{}|{}".format(tensor_size,batch_ix)] = np.mean( - get_preds(self.ann, prediction_batch, self.prediction_batch_size), axis=0 + get_preds(self.ann, prediction_batch, self.tensorflow_batch_size), axis=0 ) # clear the batch. 
From ecd71f683492d6f29e436a934e4c4d035db1b8c4 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Mon, 14 Nov 2022 18:07:09 +0100 Subject: [PATCH 05/42] revise logic to read parallel to prediction --- spliceai/__main__.py | 148 ++++++++++++++++----------- spliceai/batch/batch.py | 150 +++++++-------------------- spliceai/batch/batch_utils.py | 187 +++++++++++++++++++++++++++++++++- 3 files changed, 313 insertions(+), 172 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 14f2709..3fdfe85 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -6,8 +6,11 @@ import logging import pysam import time +import tempfile +from multiprocessing import Process,Queue from spliceai.batch.batch import VCFPredictionBatch +from spliceai.batch.batch_utils import prepare_batches from spliceai.utils import Annotator, get_delta_scores try: @@ -61,7 +64,7 @@ def main(): logging.basicConfig( format='%(asctime)s %(levelname)s %(name)s: - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', - level=logging.DEBUG, + level=loglevel, ) if None in [args.I, args.O, args.D, args.M]: @@ -72,14 +75,94 @@ def main(): # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size - run_spliceai(input_data=args.I, output_data=args.O, reference=args.R, - annotation=args.A, distance=args.D, mask=args.M, - prediction_batch_size=args.prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size) + # load annotation data: + ann = Annotator(args.R, args.A) + ## revised code for batched analysis + if args.prediction_batch_size > 1: + run_spliceai_batched(input_data=args.I, output_data=args.O, reference=args.R, + ann=ann, distance=args.D, mask=args.M, + prediction_batch_size=args.prediction_batch_size, + tensorflow_batch_size=tensorflow_batch_size) + else: # run original code: + run_spliceai(input_data=args.I, output_data=args.O, 
ann=ann, distance=args.D, mask=args.M) + +## revised logic to allow batched tensorflow analysis +def run_spliceai_batched(input_data, output_data, reference, ann, distance, mask, prediction_batch_size, + tensorflow_batch_size): + + ## mk a temp directory + tmpdir = tempfile.TemporaryDirectory() + # initialize the prediction object + batch = VCFPredictionBatch( + ann=ann, + output=output_data, + dist=distance, + mask=mask, + prediction_batch_size=prediction_batch_size, + tensorflow_batch_size=tensorflow_batch_size, + tmpdir = tmpdir + ) + + # creates a queue with max 10 ready-to-go batches in it. + # starts processing & filling the queue. + prediction_queue = Queue(maxsize=10) + reader = Process(target=prepare_batches, kwargs={'ann':ann, + 'input_data':input_data, + 'prediction_batch_size':prediction_batch_size, + 'prediction_queue': prediction_queue, + 'tmpdir':tmpdir, + 'dist' : distance + }) + reader.start() + + # Process the queue. + batch.process_batches(prediction_queue) + # join the reader process. + reader.join() + + # stats without writing phase + prediction_duration = time.time() - batch.start_time + + # write results. 
+ # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data + logging.debug("Writing output file") + vcf = pysam.VariantFile(input_data) + # have to update header again + header = vcf.header + header.add_line('##INFO=') + try: + batch.output_data = pysam.VariantFile(output_data, mode='w', header=header) + + except (IOError, ValueError) as e: + logging.error('{}'.format(e)) + exit() + + batch.write_records(vcf) + # close shelf + batch.shelf_preds.close() + # close vcf + vcf.close() + batch.output_data.close() -def run_spliceai(input_data, output_data, reference, annotation, distance, mask, prediction_batch_size, - tensorflow_batch_size): + + ## stats + overall_duration = time.time() - batch.start_time + preds_per_sec = batch.total_predictions / prediction_duration + preds_per_hour = preds_per_sec * 60 * 60 + logging.info("Analysis Finished. Statistics:") + logging.info("Total RunTime: {:0.2f}s".format(overall_duration)) + logging.info("Prediction RunTime: {:0.2f}s".format(prediction_duration)) + logging.info("Processed Records: {}".format(batch.total_vcf_records)) + logging.info("Processed Predictions: {}".format(batch.total_predictions)) + logging.info("Overall performance : {:0.2f} predictions/sec ; {:0.2f} predictions/hour".format(preds_per_sec, preds_per_hour)) + + +# original flow : record by record reading/predict/write +def run_spliceai(input_data, output_data, ann, distance, mask): try: vcf = pysam.VariantFile(input_data) @@ -99,26 +182,7 @@ def run_spliceai(input_data, output_data, reference, annotation, distance, mask, logging.error('{}'.format(e)) exit() - ann = Annotator(reference, annotation) - batch = None - - # Only use the batching code if we are batching - if prediction_batch_size > 1: - batch = VCFPredictionBatch( - ann=ann, - output=output_data, - dist=distance, - mask=mask, - prediction_batch_size=prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size, - ) - for record in vcf: - if 
batch: - # Add record to batch, if batch fills, then they will all be processed at once - batch.add_record(record) - else: - # If we're not batching, let's run the original code scores = get_delta_scores(record, ann, distance, mask) if len(scores) > 0: record.info['SpliceAI'] = scores @@ -126,39 +190,7 @@ def run_spliceai(input_data, output_data, reference, annotation, distance, mask, # close VCF vcf.close() - - if batch: - # Ensure we process any leftover records in the batch when we finish iterating the VCF. This - # would be a good candidate for a context manager if we removed the original non batching code above - batch.finish() - # stats without writing phase - duration = time.time() - batch.start_time - preds_per_sec = batch.total_predictions / duration - preds_per_hour = preds_per_sec * 60 * 60 - # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data - vcf = pysam.VariantFile(input_data) - # have to update header again - header = vcf.header - header.add_line('##INFO=') - batch.write_records(vcf) - # close shelves - batch.shelf_records.close() - batch.shelf_preds.close() - - output_data.close() - ## stats - if batch: - duration = time.time() - batch.start_time - logging.info("Analysis Finished. Statistics:") - logging.info("Total RunTime: {:0.2f}s".format(duration)) - logging.info("Processed Records: {}".format(batch.total_vcf_records)) - logging.info("Processed Predictions: {}".format(batch.total_predictions)) - logging.info("Overall performance : {:0.2f} predictions/sec ; {:0.2f} predictions/hour".format(preds_per_sec, preds_per_hour)) - if __name__ == '__main__': main() diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index a8aef92..2437bb6 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -1,15 +1,15 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. 
-import collections +#import collections import logging import time import shelve -import tempfile import numpy as np import os from spliceai.batch.batch_utils import extract_delta_scores, get_preds, encode_batch_records +from multiprocessing import Queue,Process logger = logging.getLogger(__name__) @@ -17,18 +17,18 @@ SequenceType_ALT = 1 -BatchLookupIndex = collections.namedtuple( - # ref/alt size batch for this size index in current batch for this size - 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' -) - -PreparedVCFRecord = collections.namedtuple( - 'PreparedVCFRecord', 'vcf_idx gene_info locations' -) +#BatchLookupIndex = collections.namedtuple( +# # ref/alt size batch for this size index in current batch for this size +# 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' +#) +#PreparedVCFRecord = collections.namedtuple( +# 'PreparedVCFRecord', 'vcf_idx gene_info locations' +#) +# Class to handle predictions class VCFPredictionBatch: - def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_batch_size): + def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_batch_size,tmpdir): self.ann = ann self.output = output self.dist = dist @@ -44,25 +44,31 @@ def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_ba self.prepared_vcf_records = [] # Counts - self.batch_predictions = 0 self.total_predictions = 0 self.total_vcf_records = 0 self.batch_counters = {} + # shelves to track data. - self.tmpdir = tempfile.TemporaryDirectory() + self.tmpdir = tmpdir # store batches of predictions using 'tensor_size|batch_idx' as key. self.shelf_preds = shelve.open(os.path.join(self.tmpdir.name,"spliceai_preds.shelf")) - # track records to have order correct - self.shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) - - - def _process_batch(self,tensor_size): + + # monitor the queue and submit incoming batches. 
+ def process_batches(self,prediction_queue): + while True: + item =prediction_queue.get() + # reader submits None when all are queued. + if item is None: + break + self._process_batch(item['tensor_size'],item['batch_ix'], item['data']) + + def _process_batch(self,tensor_size,batch_ix, batch): start = time.time() # get last batch for this tensor_size - batch_ix = self.batch_counters[tensor_size] - batch = self.batches[tensor_size] + #batch_ix = self.batch_counters[tensor_size] + #batch = self.batches[tensor_size] # Sanity check dump of batch sizes logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , len(batch))) @@ -74,11 +80,11 @@ def _process_batch(self,tensor_size): ) # clear the batch. - self.batches[tensor_size] = [] + #self.batches[tensor_size] = [] # initialize the next batch_ix - self.batch_counters[tensor_size] += 1 + #self.batch_counters[tensor_size] += 1 - logger.debug('Predictions: {}, VCF Records: {}'.format(self.total_predictions, self.total_vcf_records)) + #logger.debug('Predictions: {}, VCF Records: {}'.format(self.total_predictions, self.total_vcf_records)) duration = time.time() - start preds_per_sec = len(batch) / duration preds_per_hour = preds_per_sec * 60 * 60 @@ -88,15 +94,17 @@ def _process_batch(self,tensor_size): # wrapper to write out all shelved variants def write_records(self, vcf): + # open the shelf with records: + shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) + # parse vcf line_idx = 0 for record in vcf: line_idx += 1 # get prepared record by line_idx - prepared_record = self.shelf_records[str(line_idx)] - #record = prepared_record.vcf_record + prepared_record = shelf_records[str(line_idx)] gene_info = prepared_record.gene_info - record_predictions = 0 - + self.total_predictions += len(record.alts) * len(gene_info.genes) + all_y_ref = [] all_y_alt = [] @@ -129,90 +137,10 @@ def write_records(self, vcf): # If there are predictions, write them to the VCF 
INFO section if len(delta_scores) > 0: record.info['SpliceAI'] = delta_scores - record_predictions += len(delta_scores) - self.output.write(record) + self.output_data.write(record) + # close shelf again + self.total_vcf_records = line_idx + shelf_records.close() - def add_record(self, record): - """ - Adds a record to a batch. It'll capture the gene information for the record and - save it for later to avoid looking it up again, then it'll encode ref and alt from - the VCF record and place the encoded values into lists of matching sizes. Once the - encoded values are added, a BatchLookupIndex is created so that after the predictions - are made, it knows where to look up the corresponding prediction for the vcf record. - - Once the batch size hits it's capacity, it'll process all the predictions for the - encoded batch. - """ - - self.total_vcf_records += 1 - # Collect gene information for this record - gene_info = self.ann.get_name_and_strand(record.chrom, record.pos) - - # Keep track of how many predictions we're going to make - prediction_count = len(record.alts) * len(gene_info.genes) - self.batch_predictions += prediction_count - self.total_predictions += prediction_count - - # Collect lists of encoded ref/alt sequences - x_ref, x_alt = encode_batch_records(record, self.ann, self.dist, gene_info) - - # List of BatchLookupIndex's so we know how to lookup predictions for records from - # the batches - batch_lookup_indexes = [] - - # Process the encodings into batches - for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)): - - if len(encoded_seq) == 0: - # Add BatchLookupIndex with zeros so when the batch collects the outputs - # it knows that there is no prediction for this record - batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0, 0)) - continue - - # Iterate over the encoded sequence and drop into the correct batch by size and - # create an index to use to pull out the result after batch is processed - for row in 
encoded_seq: - # Extract the size of the sequence that was encoded to build a batch from - tensor_size = row.shape[1] - - # Create batch for this size - if tensor_size not in self.batches: - self.batches[tensor_size] = [] - self.batch_counters[tensor_size] = 0 - - # Add encoded record to batch 'n' for tensor_size - self.batches[tensor_size].append(row) - - # Get the index of the record we just added in the batch - cur_batch_record_ix = len(self.batches[tensor_size]) - 1 - - # Store a reference so we can pull out the prediction for this item from the batches - batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, self.batch_counters[tensor_size] , cur_batch_record_ix)) - - # Save the batch locations for this record on the composite object - prepared_record = PreparedVCFRecord( - vcf_idx=self.total_vcf_records, gene_info=gene_info, locations=batch_lookup_indexes - ) - # add to shelf by vcf_idx - self.shelf_records[str(self.total_vcf_records)] = prepared_record - - # If we're reached our threshold for the max items to process, then process the batch - for tensor_size in self.batch_counters: - if len(self.batches[tensor_size]) >= self.prediction_batch_size: - logger.debug("Batch {} full. Processing".format(tensor_size)) - self._process_batch(tensor_size) - - - - def finish(self): - """ - Method to process all the remaining items that have been added to the batches. - """ - #if len(self.prepared_vcf_records) > 0: - # self._process_batch() - logger.debug("Processing remaining batches") - for tensor_size in self.batch_counters: - if len(self.batches[tensor_size] ) > 0: - self._process_batch(tensor_size) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index e36bb4c..2df0fb4 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -2,6 +2,11 @@ # Modifications copyright (c) 2021 Invitae Corporation. 
import logging +import shelve +import pysam +import collections +import os +import gc from spliceai.utils import get_alt_gene_delta_score, is_record_valid, get_seq, \ is_location_predictable, get_cov, get_wid, is_valid_alt_record, encode_seqs, create_unhandled_delta_score @@ -9,11 +14,48 @@ logger = logging.getLogger(__name__) +## CUSTOM DATA TYPES +SequenceType_REF = 0 +SequenceType_ALT = 1 + +BatchLookupIndex = collections.namedtuple( + # ref/alt size batch for this size index in current batch for this size + 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' +) + +PreparedVCFRecord = collections.namedtuple( + 'PreparedVCFRecord', 'vcf_idx gene_info locations' +) + + +## routine to create the batches for prediction. +def prepare_batches(ann, input_data,prediction_batch_size, prediction_queue,tmpdir,dist): + # create the parser object + vcf_reader = VCFReader(ann=ann, input_data=input_data, prediction_batch_size=prediction_batch_size, prediction_queue=prediction_queue,tmpdir=tmpdir,dist=dist) + # parse records + vcf_reader.add_records() + # finalize last batches + vcf_reader.finish() + # close the shelf. + vcf_reader.shelf_records.close() + # stats + logger.info("Read {} vcf records, queued {} predictions".format(vcf_reader.total_vcf_records, vcf_reader.total_predictions)) + + +## get tensorflow predictions using batch-based submissions def get_preds(ann, x, batch_size=32): logger.debug('Running get_preds with matrix size: {}'.format(x.shape)) - return [ - ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5) - ] + try: + predictions = [ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5)] + + except Exception as e: + # try a smaller batch (less efficient, but lower on memory). if it crashes again : it raises. 
+ logger.warning("TF.predict failed ({}).Retrying with smaller batch size".format(e)) + predictions = [ann.models[m].predict(x, batch_size=4, verbose=0) for m in range(5)] + # garbage collection to prevent memory overflow... + gc.collect() + return predictions + # Heavily based on utils.get_delta_scores but only handles the validation and encoding @@ -115,3 +157,142 @@ def extract_delta_scores( pred_ix += 1 return delta_scores + + + +# class to parse input and prep batches +class VCFReader: + def __init__(self, ann, input_data, prediction_batch_size, prediction_queue,tmpdir,dist): + self.ann = ann + # This is the maximum number of predictions to parse/encode/predict at a time + self.prediction_batch_size = prediction_batch_size + # the vcf file + self.input_data = input_data + # window to consider + self.dist = dist + # Batch vars + self.batches = {} + #self.prepared_vcf_records = [] + + # Counts + self.total_predictions = 0 + self.total_vcf_records = 0 + self.batch_counters = {} + + # the queue + self.prediction_queue = prediction_queue + + # shelves to track data. + self.tmpdir = tmpdir + # track records to have order correct + self.shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) + + + + def add_records(self): + + try: + vcf = pysam.VariantFile(self.input_data) + except (IOError, ValueError) as e: + logging.error('{}'.format(e)) + raise(e) + for record in vcf: + try: + self.add_record(record) + except Exception as e: + raise(e) + vcf.close() + + + def add_record(self, record): + """ + Adds a record to a batch. It'll capture the gene information for the record and + save it for later to avoid looking it up again, then it'll encode ref and alt from + the VCF record and place the encoded values into lists of matching sizes. Once the + encoded values are added, a BatchLookupIndex is created so that after the predictions + are made, it knows where to look up the corresponding prediction for the vcf record. 
+ + Once the batch size hits it's capacity, it'll process all the predictions for the + encoded batch. + """ + + self.total_vcf_records += 1 + # Collect gene information for this record + gene_info = self.ann.get_name_and_strand(record.chrom, record.pos) + + # Keep track of how many predictions we're going to make + prediction_count = len(record.alts) * len(gene_info.genes) + self.total_predictions += prediction_count + + # Collect lists of encoded ref/alt sequences + x_ref, x_alt = encode_batch_records(record, self.ann, self.dist, gene_info) + + # List of BatchLookupIndex's so we know how to lookup predictions for records from + # the batches + batch_lookup_indexes = [] + + # Process the encodings into batches + for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)): + + if len(encoded_seq) == 0: + # Add BatchLookupIndex with zeros so when the batch collects the outputs + # it knows that there is no prediction for this record + batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0, 0)) + continue + + # Iterate over the encoded sequence and drop into the correct batch by size and + # create an index to use to pull out the result after batch is processed + for row in encoded_seq: + # Extract the size of the sequence that was encoded to build a batch from + tensor_size = row.shape[1] + + # Create batch for this size + if tensor_size not in self.batches: + self.batches[tensor_size] = [] + self.batch_counters[tensor_size] = 0 + + # Add encoded record to batch 'n' for tensor_size + self.batches[tensor_size].append(row) + + # Get the index of the record we just added in the batch + cur_batch_record_ix = len(self.batches[tensor_size]) - 1 + + # Store a reference so we can pull out the prediction for this item from the batches + batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, self.batch_counters[tensor_size] , cur_batch_record_ix)) + + # Save the batch locations for this record on the composite object + 
prepared_record = PreparedVCFRecord( + vcf_idx=self.total_vcf_records, gene_info=gene_info, locations=batch_lookup_indexes + ) + # add to shelf by vcf_idx + self.shelf_records[str(self.total_vcf_records)] = prepared_record + + # If we're reached our threshold for the max items to process, then process the batch + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size]) >= self.prediction_batch_size: + logger.debug("Batch {} full. Adding to queue".format(tensor_size)) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : self.batches[tensor_size]} + self.prediction_queue.put(queue_item) + # reset + self.batches[tensor_size] = [] + self.batch_counters[tensor_size] += 10 + + #self._process_batch(tensor_size) + + + + def finish(self): + """ + Method to process all the remaining items that have been added to the batches. + """ + #if len(self.prepared_vcf_records) > 0: + # self._process_batch() + logger.debug("Queueing remaining batches") + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size] ) > 0: + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : self.batches[tensor_size]} + self.prediction_queue.put(queue_item) + # clear + self.batches[tensor_size] = [] + # all done : + self.prediction_queue.put(None) \ No newline at end of file From 5e64b960dd793f74a1355699d138dabd9d14073f Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Tue, 15 Nov 2022 11:33:55 +0100 Subject: [PATCH 06/42] disabled performance output on predict in illumina implementation --- spliceai/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spliceai/utils.py b/spliceai/utils.py index 0f79669..00d44c0 100644 --- a/spliceai/utils.py +++ b/spliceai/utils.py @@ -1,6 +1,9 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. 
+# Invitae source code modified to improve GPU utilization +# Modifications made by Geert Vandeweyer (Antwerp University Hospital, Belgium) + import collections from pkg_resources import resource_filename @@ -215,8 +218,8 @@ def get_delta_scores(record, ann, dist_var, mask): alt_ix=alt_ix, wid=wid) - y_ref = np.mean([ann.models[m].predict(x_ref) for m in range(5)], axis=0) - y_alt = np.mean([ann.models[m].predict(x_alt) for m in range(5)], axis=0) + y_ref = np.mean([ann.models[m].predict(x_ref,verbose=0) for m in range(5)], axis=0) + y_alt = np.mean([ann.models[m].predict(x_alt,verbose=0) for m in range(5)], axis=0) delta_score = get_alt_gene_delta_score(record=record, ann=ann, From e20c586a9cded4f0ac4670b0ffe992e3a5e9fc2f Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Tue, 15 Nov 2022 11:34:23 +0100 Subject: [PATCH 07/42] further optimization of gpu usage through offloading np-tensor conversion to cpu --- spliceai/__main__.py | 12 +++++--- spliceai/batch/batch.py | 56 +++++++++++++++++------------------ spliceai/batch/batch_utils.py | 37 ++++++++++++++++++----- 3 files changed, 65 insertions(+), 40 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 3fdfe85..23b5d77 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -48,6 +48,9 @@ def get_options(): parser.add_argument('-T', '--tensorflow-batch-size', metavar='tensorflow_batch_size', type=int, help='tensorflow batch size for model predictions') parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging') + parser.add_argument('-t',metavar='tmpdir',type=str,default='/tmp/',required=False, + help="Use Alternate location to store tmp files.
(Note: B=4096 equals to roughly 15Gb of tmp files)") + args = parser.parse_args() return args @@ -69,7 +72,7 @@ def main(): if None in [args.I, args.O, args.D, args.M]: logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' - '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]]') + '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]] [-t [tmp_location]]') exit() # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args @@ -82,16 +85,17 @@ def main(): run_spliceai_batched(input_data=args.I, output_data=args.O, reference=args.R, ann=ann, distance=args.D, mask=args.M, prediction_batch_size=args.prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size) + tensorflow_batch_size=tensorflow_batch_size,tempdir=args.t) else: # run original code: run_spliceai(input_data=args.I, output_data=args.O, ann=ann, distance=args.D, mask=args.M) ## revised logic to allow batched tensorflow analysis def run_spliceai_batched(input_data, output_data, reference, ann, distance, mask, prediction_batch_size, - tensorflow_batch_size): + tensorflow_batch_size,tempdir): ## mk a temp directory - tmpdir = tempfile.TemporaryDirectory() + tmpdir = tempfile.TemporaryDirectory(dir=tempdir) + logging.debug("tmp dir : {}".format(tmpdir.name)) # initialize the prediction object batch = VCFPredictionBatch( ann=ann, diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 2437bb6..cd5d101 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -1,15 +1,18 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. 
-#import collections +# Invitae source code modified to improve GPU utilization +# Modifications made by Geert Vandeweyer (Antwerp University Hospital, Belgium) + import logging import time import shelve import numpy as np import os +import tensorflow as tf +import pickle -from spliceai.batch.batch_utils import extract_delta_scores, get_preds, encode_batch_records -from multiprocessing import Queue,Process +from spliceai.batch.batch_utils import extract_delta_scores, get_preds logger = logging.getLogger(__name__) @@ -17,15 +20,6 @@ SequenceType_ALT = 1 -#BatchLookupIndex = collections.namedtuple( -# # ref/alt size batch for this size index in current batch for this size -# 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' -#) - -#PreparedVCFRecord = collections.namedtuple( -# 'PreparedVCFRecord', 'vcf_idx gene_info locations' -#) - # Class to handle predictions class VCFPredictionBatch: def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_batch_size,tmpdir): @@ -62,31 +56,28 @@ def process_batches(self,prediction_queue): # reader submits None when all are queued. if item is None: break - self._process_batch(item['tensor_size'],item['batch_ix'], item['data']) + # load pickled object + with open(os.path.join(self.tmpdir.name,item),'rb') as p: + data = pickle.load(p) + # remove from disk. 
+ os.unlink(os.path.join(self.tmpdir.name,item)) + self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) + - def _process_batch(self,tensor_size,batch_ix, batch): + def _process_batch(self,tensor_size,batch_ix, prediction_batch,nr_preds): start = time.time() - # get last batch for this tensor_size - #batch_ix = self.batch_counters[tensor_size] - #batch = self.batches[tensor_size] + # Sanity check dump of batch sizes - logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , len(batch))) + logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , nr_preds)) - # Convert list of encodings into a proper sized numpy matrix - prediction_batch = np.concatenate(batch, axis=0) # Run predictions && add to shelf. self.shelf_preds["{}|{}".format(tensor_size,batch_ix)] = np.mean( get_preds(self.ann, prediction_batch, self.tensorflow_batch_size), axis=0 ) - - # clear the batch. - #self.batches[tensor_size] = [] - # initialize the next batch_ix - #self.batch_counters[tensor_size] += 1 - #logger.debug('Predictions: {}, VCF Records: {}'.format(self.total_predictions, self.total_vcf_records)) + # status duration = time.time() - start - preds_per_sec = len(batch) / duration + preds_per_sec = nr_preds / duration preds_per_hour = preds_per_sec * 60 * 60 logger.debug('Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(duration, preds_per_sec, @@ -98,12 +89,15 @@ def write_records(self, vcf): shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) # parse vcf line_idx = 0 + batch = [] + last_batch_key = '' for record in vcf: line_idx += 1 # get prepared record by line_idx prepared_record = shelf_records[str(line_idx)] gene_info = prepared_record.gene_info - self.total_predictions += len(record.alts) * len(gene_info.genes) + # (REF + #ALT ) * #genes (* 5 models) + self.total_predictions += (1 + len(record.alts)) * len(gene_info.genes) 
all_y_ref = [] all_y_alt = [] @@ -119,7 +113,11 @@ def write_records(self, vcf): continue # Extract the prediction from the batch into a list of predictions for this record - batch = self.shelf_preds["{}|{}".format(location.tensor_size,location.batch_ix)] # batch_preds[location.tensor_size] + # recycle the batch variable if key is the same. + if not last_batch_key == "{}|{}".format(location.tensor_size,location.batch_ix): + last_batch_key = "{}|{}".format(location.tensor_size,location.batch_ix) + batch = self.shelf_preds[last_batch_key] # batch_preds[location.tensor_size] + if location.sequence_type == SequenceType_REF: all_y_ref.append(batch[[location.batch_index], :, :]) else: diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 2df0fb4..00d672b 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -1,12 +1,19 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. +# Invitae source code modified to improve GPU utilization +# Modifications made by Geert Vandeweyer (Antwerp University Hospital, Belgium) + + import logging import shelve import pysam import collections import os import gc +import numpy as np +import tensorflow as tf +import pickle from spliceai.utils import get_alt_gene_delta_score, is_record_valid, get_seq, \ is_location_predictable, get_cov, get_wid, is_valid_alt_record, encode_seqs, create_unhandled_delta_score @@ -46,8 +53,7 @@ def prepare_batches(ann, input_data,prediction_batch_size, prediction_queue,tmpd def get_preds(ann, x, batch_size=32): logger.debug('Running get_preds with matrix size: {}'.format(x.shape)) try: - predictions = [ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5)] - + predictions = [ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5)] except Exception as e: # try a smaller batch (less efficient, but lower on memory). 
if it crashes again : it raises. logger.warning("TF.predict failed ({}).Retrying with smaller batch size".format(e)) @@ -271,11 +277,20 @@ def add_record(self, record): for tensor_size in self.batch_counters: if len(self.batches[tensor_size]) >= self.prediction_batch_size: logger.debug("Batch {} full. Adding to queue".format(tensor_size)) - queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : self.batches[tensor_size]} - self.prediction_queue.put(queue_item) + # fully prep the batch outside of gpu routine... + data = np.concatenate(self.batches[tensor_size]) + concat_len = len(data) + # offload conversion of batch from np to tensor to CPU + with tf.device('CPU:0'): + data = tf.convert_to_tensor(data) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} + with open(os.path.join(self.tmpdir.name,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: + pickle.dump(queue_item,p) + self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) + # reset self.batches[tensor_size] = [] - self.batch_counters[tensor_size] += 10 + self.batch_counters[tensor_size] += 1 #self._process_batch(tensor_size) @@ -290,8 +305,16 @@ def finish(self): logger.debug("Queueing remaining batches") for tensor_size in self.batch_counters: if len(self.batches[tensor_size] ) > 0: - queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : self.batches[tensor_size]} - self.prediction_queue.put(queue_item) + # fully prep the batch outside of gpu routine... 
+ data = np.concatenate(self.batches[tensor_size]) + concat_len = len(data) + # offload conversion of batch from np to tensor to CPU + with tf.device('CPU:0'): + data = tf.convert_to_tensor(data) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} + with open(os.path.join(self.tmpdir.name,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: + pickle.dump(queue_item,p) + self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) # clear self.batches[tensor_size] = [] # all done : From ee16e0065d25167038500e492e07480f7847850f Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 06:25:58 +0100 Subject: [PATCH 08/42] add gc collect to original code to prevent oom kills --- spliceai/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spliceai/utils.py b/spliceai/utils.py index 00d44c0..5049084 100644 --- a/spliceai/utils.py +++ b/spliceai/utils.py @@ -12,7 +12,7 @@ from pyfaidx import Fasta from keras.models import load_model import logging - +import gc GeneInfo = collections.namedtuple('GeneInfo', 'genes strands idxs') @@ -231,7 +231,7 @@ def get_delta_scores(record, ann, dist_var, mask): gene_info=gene_info, mask=mask) delta_scores.append(delta_score) - + gc.collect() return delta_scores From 316079355ceb47c1556548cf68af47ee74855035 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 07:01:09 +0100 Subject: [PATCH 09/42] add long variant for all arguments --- spliceai/__main__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 23b5d77..1480d26 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -24,21 +24,21 @@ def get_options(): parser = argparse.ArgumentParser(description='Version: 1.3.1') - parser.add_argument('-I', metavar='input', nargs='?', default=std_in, + 
parser.add_argument('-I', '--input', metavar='input', nargs='?', default=std_in, help='path to the input VCF file, defaults to standard in') - parser.add_argument('-O', metavar='output', nargs='?', default=std_out, + parser.add_argument('-O', '--output', metavar='output', nargs='?', default=std_out, help='path to the output VCF file, defaults to standard out') - parser.add_argument('-R', metavar='reference', required=True, + parser.add_argument('-R', '--reference', metavar='reference', required=True, help='path to the reference genome fasta file') - parser.add_argument('-A', metavar='annotation', required=True, + parser.add_argument('-A', '--annotation',metavar='annotation', required=True, help='"grch37" (GENCODE V24lift37 canonical annotation file in ' 'package), "grch38" (GENCODE V24 canonical annotation file in ' 'package), or path to a similar custom gene annotation file') - parser.add_argument('-D', metavar='distance', nargs='?', default=50, + parser.add_argument('-D', '--distance', metavar='distance', nargs='?', default=50, type=int, choices=range(0, 5000), help='maximum distance between the variant and gained/lost splice ' 'site, defaults to 50') - parser.add_argument('-M', metavar='mask', nargs='?', default=0, + parser.add_argument('-M', '--mask', metavar='mask', nargs='?', default=0, type=int, choices=[0, 1], help='mask scores representing annotated acceptor/donor gain and ' 'unannotated acceptor/donor loss, defaults to 0') @@ -48,7 +48,7 @@ def get_options(): parser.add_argument('-T', '--tensorflow-batch-size', metavar='tensorflow_batch_size', type=int, help='tensorflow batch size for model predictions') parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging') - parser.add_argument('-t',metavar='tmpdir',type=str,default='/tmp/',required=False, + parser.add_argument('-t','--tmpdir', metavar='tmpdir',type=str,default='/tmp/',required=False, help="Use Alternate location to store tmp files. 
(Note: B=4096 equals to roughly 15Gb of tmp files)") args = parser.parse_args() From 36a66592e85c9526e51eb4286bf289322d369945 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 07:01:44 +0100 Subject: [PATCH 10/42] add dockerfile --- spliceai/docker/Dockerfile | 67 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 spliceai/docker/Dockerfile diff --git a/spliceai/docker/Dockerfile b/spliceai/docker/Dockerfile new file mode 100644 index 0000000..9a02001 --- /dev/null +++ b/spliceai/docker/Dockerfile @@ -0,0 +1,67 @@ +###################################### +## CONTAINER FOR GPU based SpliceAI ## +###################################### + +# start from the cuda docker base +FROM nvidia/cuda:11.4.0-base-ubuntu20.04 + +LABEL version="1.3" +LABEL description="This container was tested with \ + - V100 on AWS p3.2xlarge with nvidia drivers 510.47.03 and cuda v11.6 \ + - K80 on AWS p2.xlarge with nvidia drivers 470.141.03 and cuda v11.4 \ + - Geforce RTX 2070 SUPER (local) with nvidia drivers 470.141.03 and cuda v11.4" + +LABEL author="Geert Vandeweyer" +LABEL author.email="geert.vandeweyer@uza.be" + +## needed apt packages +ARG BUILD_PACKAGES="wget git bzip2" +# needed conda packages + +ARG CONDA_PACKAGES="python=3.9.13 tensorflow-gpu=2.10.0 cuda-nvcc=11.8.89" + +## ENV SETTINGS during runtime +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 +ENV PATH=/opt/conda/bin:$PATH +ENV DEBIAN_FRONTEND noninteractive + +# For micromamba: +SHELL ["/bin/bash", "-l", "-c"] +ENV MAMBA_ROOT_PREFIX=/opt/conda/ +ENV PATH=/opt/micromamba/bin:/opt/conda/bin:$PATH +ARG CONDA_CHANNEL="-c bioconda -c conda-forge -c nvidia" + +## INSTALL +RUN apt-get -y update && \ + apt-get -y install $BUILD_PACKAGES && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + + +# conda packages +RUN mkdir /opt/conda && \ + mkdir /opt/micromamba && \ + wget -qO - https://micromamba.snakepit.net/api/micromamba/linux-64/0.23.0 | tar -xvj -C /opt/micromamba bin/micromamba && 
\ + # initialize bash + micromamba shell init --shell=bash --prefix=/opt/conda && \ + # remove a statement from bashrc that prevents initialization + grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/micromamba/bashrc && \ + mv /opt/micromamba/bashrc /root/.bashrc && \ + source ~/.bashrc && \ + # activate & install base conda packag + micromamba activate && \ + micromamba install -y $CONDA_CHANNEL $CONDA_PACKAGES && \ + micromamba clean --all --yes + +# Break cache for recloning git +ENV DATE_CACHE_BREAK=$(date) + +# my fork of spliceai : has gpu optimizations +RUN cd /opt/ && \ + git clone https://github.com/geertvandeweyer/SpliceAI.git && \ + cd SpliceAI && \ + python setup.py install + +# no command given, print help. +CMD spliceai -h + From e6e06ccb071025d7e27d60e0f44cec53f107cc23 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 07:03:36 +0100 Subject: [PATCH 11/42] moved docker directory --- {spliceai/docker => docker}/Dockerfile | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {spliceai/docker => docker}/Dockerfile (100%) diff --git a/spliceai/docker/Dockerfile b/docker/Dockerfile similarity index 100% rename from spliceai/docker/Dockerfile rename to docker/Dockerfile From d223e4cc22d3ee18859d02ab945831588cdd24aa Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 07:38:27 +0100 Subject: [PATCH 12/42] corrected long argument access --- spliceai/__main__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 1480d26..a8a5c50 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -70,7 +70,7 @@ def main(): level=loglevel, ) - if None in [args.I, args.O, args.D, args.M]: + if None in [args.input, args.output, args.distance, args.mask]: logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]] [-t 
[tmp_location]]') exit() @@ -79,15 +79,15 @@ def main(): tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size # load annotation data: - ann = Annotator(args.R, args.A) + ann = Annotator(args.reference, args.annotation) ## revised code for batched analysis if args.prediction_batch_size > 1: - run_spliceai_batched(input_data=args.I, output_data=args.O, reference=args.R, - ann=ann, distance=args.D, mask=args.M, + run_spliceai_batched(input_data=args.input, output_data=args.output, reference=args.reference, + ann=ann, distance=args.distance, mask=args.mask, prediction_batch_size=args.prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size,tempdir=args.t) + tensorflow_batch_size=tensorflow_batch_size,tempdir=args.tmpdir) else: # run original code: - run_spliceai(input_data=args.I, output_data=args.O, ann=ann, distance=args.D, mask=args.M) + run_spliceai(input_data=args.input, output_data=args.output, ann=ann, distance=args.distance, mask=args.mask) ## revised logic to allow batched tensorflow analysis def run_spliceai_batched(input_data, output_data, reference, ann, distance, mask, prediction_batch_size, From 580f5d53501855396aaaa9c0edee363ebb6d30c6 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 15:10:48 +0100 Subject: [PATCH 13/42] updated benchmarks --- README.md | 59 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 42d8b5f..e9ea25c 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Required parameters: - ```-I```: Input VCF with variants of interest. - ```-O```: Output VCF with SpliceAI predictions `ALLELE|SYMBOL|DS_AG|DS_AL|DS_DG|DS_DL|DP_AG|DP_AL|DP_DG|DP_DL` included in the INFO column (see table below for details). Only SNVs and simple INDELs (REF or ALT is a single base) within genes are annotated. Variants in multiple genes have separate predictions for each gene. 
- ```-R```: Reference genome fasta file. Can be downloaded from [GRCh37/hg19](http://hgdownload.cse.ucsc.edu/goldenPath/hg19/bigZips/hg19.fa.gz) or [GRCh38/hg38](http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz). - - ```-A```: Gene annotation file. Can instead provide `grch37` or `grch38` to use GENCODE V24 canonical annotation files included with the package. To create custom gene annotation files, use `spliceai/annotations/grch37.txt` in repository as template. + - ```-A```: Gene annotation file. Can instead provide `grch37` or `grch38` to use GENCODE V24 canonical annotation files included with the package. To create custom gene annotation files, use `spliceai/annotations/grch37.txt` in repository as template and provide as full path. Optional parameters: - ```-D```: Maximum distance between the variant and gained/lost splice site (default: 50). @@ -50,20 +50,48 @@ Optional parameters: - ```-B```: Number of predictions to collect before running models on them in batch. (default: 1 (don't batch)) - ```-T```: Internal Tensorflow `predict()` batch size if you want something different from the `-B` value. (default: the `-B` value) - ```-V```: Enable verbose logging during run - -**Batching Considerations:** When setting the batching parameters, be mindful of the system and gpu memory of the machine you -are running the script on. Feel free to experiment, but some reasonable `-B` numbers would be 64/128. - -Batching Performance Benchmarks: - -| Type | Speed | -| -------- | ----------- | -| n1-standard-2 CPU (GCP) | ~800 per hour | -| CPU (2019 MacBook Pro) | ~3,000 per hour | -| K80 GPU (GCP) | ~25,000 per hour | -| V100 GPU (GCP) | ~150,000 per hour | - -Details of SpliceAI INFO field: + - ```-t```: Specify a location to create the temporary files + +**Batching Considerations:** + +When setting the batching parameters, be mindful of the system and gpu memory of the machine you +are running the script on. 
Feel free to experiment, but some reasonable `-T` numbers would be 64/128. CPU memory is larger, and increasing `-B` might further improve performance. + +*Batching Performance Benchmarks:* +- Input data: GATK generated WES sample with ~ 90K variants in genome build GRCh37. +- Total predictions made : 174,237 +- invitae v2 mainly implements logic to prioritize full batches while predicting +- settings : + - invitae & invitae v2 : B = T = 64 + - invitae v2 optimal : on V100 : B = 4096 ; T = 256 -- on K80/GeForce : B = 4096 ; T = 64 + +*Benchmark results* + +| Type | Implementation | Total Time | Speed (predictions / hour) | +| -------- | -------------- | ----------- | -------------------------- | +| CPU (intel i5-8365U)<sup>a</sup> | illumina | ~100h | ~1000 pred/h | +| | invitae | ~39h | ~4500 pred/h | +| | invitae v2 | ~35h | ~5000 pred/h | +| | invitae v2 optimal | ~35h | ~5000 pred/h | +| K80 GPU (AWS p2.large) | illumina<sup>b</sup> | ~25 h | ~7000 pred/h | +| | invitae | 242m | ~43,000 pred / h | +| | invitae v2 | 213m | ~50,000 pred / h | +| | invitae v2 optimal | 188 m | ~56,000 pred / h | +| GeForce RTX 2070 SUPER GPU (desktop) | illumina<sup>b</sup> | ~10 h | ~ 17,000 pred/h | +| | invitae | 76m | ~137,000 pred / h | +| | invitae v2 | 63m | ~166,000 pred / h | +| | invitae v2 optimal | 52m | ~200,000 pred / h | +| V100 GPU (AWS p3.xlarge) | illumina<sup>b</sup> | ~10h | ~18,000 pred/h | +| | invitae | 78m | ~135,000 pred / h | +| | invitae v2 | 54m | ~190,000 pred / h | +| | invitae v2 optimal | 31 m | ~335,000 pred / h | +| + +(a) : Extrapolated from first 500 variants +(b) : Illumina implementation showed a memory leak with the installed versions of tf/keras/.... Values extrapolated from incomplete runs at the point of OOM.
+ + +### Details of SpliceAI INFO field: | ID | Description | | -------- | ----------- | @@ -135,3 +163,4 @@ donor_prob = y[0, :, 2] ### Contact Kishore Jaganathan: kjaganathan@illumina.com +Geert Vandeweyer (This implementation) : geert.vandeweyer@uza.be From 90a3f01d77935ee188a2813f4aadf6a84ce198fe Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 16:34:23 +0100 Subject: [PATCH 14/42] added docker link --- README.md | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e9ea25c..6966670 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,25 @@ This package annotates genetic variants with their predicted effect on splicing, SpliceAI source code is provided under the [GPLv3 license](LICENSE). SpliceAI includes several third party packages provided under other open source licenses, please see [NOTICE](NOTICE) for additional details. The trained models used by SpliceAI (located in this package at spliceai/models) are provided under the [CC BY NC 4.0](LICENSE) license for academic and non-commercial use; other use requires a commercial license from Illumina, Inc. 
### Installation -The simplest way to install SpliceAI is through pip or conda: + +This release can most easily be used as a docker container: + +``' docker pull cmgantwerpen/spliceai_v1.3 + +docker run --gpus all cmgantwerpen/spliceai_v1.3 spliceai -h +``` + + +The simplest way to install (the original version of) SpliceAI is through pip or conda: ```sh pip install spliceai # or conda install -c bioconda spliceai ``` -Alternately, SpliceAI can be installed from the [github repository](https://github.com/Illumina/SpliceAI.git): +Alternately, SpliceAI can be installed from the [github repository](https://github.com/invitae/SpliceAI.git): ```sh -git clone https://github.com/Illumina/SpliceAI.git +git clone https://github.com/invitae/SpliceAI.git cd SpliceAI python setup.py install ``` From f7ec23ca4b4023dd3b304956a636946152a988ce Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 16:38:03 +0100 Subject: [PATCH 15/42] small fix to dockerfile --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9a02001..4aefaff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -54,7 +54,7 @@ RUN mkdir /opt/conda && \ micromamba clean --all --yes # Break cache for recloning git -ENV DATE_CACHE_BREAK=$(date) +ARG DATE_CACHE_BREAK=$(date) # my fork of spliceai : has gpu optimizations RUN cd /opt/ && \ From 7255331cc2e992ab94dc5db1f64feb5edccfd52e Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 16:52:27 +0100 Subject: [PATCH 16/42] small fix to dockerfile --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6966670..61209ce 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This release can most easily be used as a docker container: ``' docker pull cmgantwerpen/spliceai_v1.3 -docker run --gpus all cmgantwerpen/spliceai_v1.3 spliceai -h +docker run --gpus all cmgantwerpen/spliceai_v1.3:latest spliceai -h 
``` From b1f28d24a256e701a116809cce6fdaf415d7511f Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 17:05:54 +0100 Subject: [PATCH 17/42] fixed typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 61209ce..8832acc 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ SpliceAI source code is provided under the [GPLv3 license](LICENSE). SpliceAI in This release can most easily be used as a docker container: -``' docker pull cmgantwerpen/spliceai_v1.3 +```docker pull cmgantwerpen/spliceai_v1.3 docker run --gpus all cmgantwerpen/spliceai_v1.3:latest spliceai -h ``` From 2c5d0529c6129d876dd6079ed9b8c38fdd7fc939 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 17:06:28 +0100 Subject: [PATCH 18/42] fixed code layout --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8832acc..8662bca 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,8 @@ SpliceAI source code is provided under the [GPLv3 license](LICENSE). SpliceAI in This release can most easily be used as a docker container: -```docker pull cmgantwerpen/spliceai_v1.3 +```sh +docker pull cmgantwerpen/spliceai_v1.3 docker run --gpus all cmgantwerpen/spliceai_v1.3:latest spliceai -h ``` From 1b28d311daf22da698497c2c2364036cacd5f77f Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 17:14:04 +0100 Subject: [PATCH 19/42] fixed code layout --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8662bca..c2fda11 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ SpliceAI source code is provided under the [GPLv3 license](LICENSE). 
SpliceAI in This release can most easily be used as a docker container: ```sh -docker pull cmgantwerpen/spliceai_v1.3 +docker pull cmgantwerpen/spliceai_v1.3:latest docker run --gpus all cmgantwerpen/spliceai_v1.3:latest spliceai -h ``` @@ -98,6 +98,7 @@ are running the script on. Feel free to experiment, but some reasonable `-T` num | (a) : Extrapolated from first 500 variants + (b) : Illumina implementation showed a memory leak with the installed versions of tf/keras/.... Values extrapolated from incomplete runs at the point of OOM. @@ -170,7 +171,12 @@ donor_prob = y[0, :, 2] * Adds batch utility methods that split up what was all previously done in `get_delta_scores`. `encode_batch_record` handles what was in the first half, taking in the VCF record and generating one-hot encoded matrices for the ref/alts. `extract_delta_scores` handles the second half of the `get_delta_scores` by reassembling the annotations based on the batched predictions * Adds test cases to run a small file using a generated FASTA reference to test if the results are the same with no batching and with different batching sizes * Slightly modifies the entrypoint of running the code to allow for easier unit testing. 
Being able to pass in what would normally come from the argparser +* Offload more code to CPU (eg np to tensor conversion) to *only* perform predictions on the GPU +* Implement queuing system to always have full batches ready for prediction +* Implement new parameter, `--tmpdir` to support a custom tmp folder + ### Contact Kishore Jaganathan: kjaganathan@illumina.com + Geert Vandeweyer (This implementation) : geert.vandeweyer@uza.be From 2cf5ac5f917a9617f69a00bd4150372467d78809 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 17:14:35 +0100 Subject: [PATCH 20/42] fixed table layout --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c2fda11..1283006 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ are running the script on. Feel free to experiment, but some reasonable `-T` num | | invitae | 78m | ~135,000 pred / h | | | invitae v2 | 54m | ~190,000 pred / h | | | invitae v2 optimal | 31 m | ~335,000 pred / h | -| + (a) : Extrapolated from first 500 variants From af2e7f84d8f4d65956fd87773a1dfabdd31441fa Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Wed, 16 Nov 2022 19:15:06 +0100 Subject: [PATCH 21/42] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1283006..b4bd086 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ are running the script on. 
Feel free to experiment, but some reasonable `-T` num | | invitae | 76m | ~137,000 pred / h | | | invitae v2 | 63m | ~166,000 pred / h | | | invitae v2 optimal | 52m | ~200,000 pred / h | -| V100 GPU (AWS p3.xlarge) | illuminab | ~10h | ~18,000 pred/h | +| V100 GPU (AWS p3.xlarge) | illuminab | ~10h | ~18,000 pred/h | | | invitae | 78m | ~135,000 pred / h | | | invitae v2 | 54m | ~190,000 pred / h | | | invitae v2 optimal | 31 m | ~335,000 pred / h | From c7fd5796af3db40998003d63b41ddf7de26cd734 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Mon, 28 Nov 2022 16:02:22 +0100 Subject: [PATCH 22/42] revised code to scale to multiple gpus --- spliceai/__main__.py | 177 ++++++++------ spliceai/batch/batch.py | 217 ++++++++++------- spliceai/batch/batch_utils.py | 406 +++++++++++-------------------- spliceai/batch/data_handlers.py | 419 ++++++++++++++++++++++++++++++++ spliceai/utils.py | 6 +- 5 files changed, 791 insertions(+), 434 deletions(-) create mode 100644 spliceai/batch/data_handlers.py diff --git a/spliceai/__main__.py b/spliceai/__main__.py index a8a5c50..048727b 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -7,11 +7,16 @@ import pysam import time import tempfile -from multiprocessing import Process,Queue +from multiprocessing import Process,Queue,Pool +from functools import partial -from spliceai.batch.batch import VCFPredictionBatch -from spliceai.batch.batch_utils import prepare_batches +import tensorflow as tf +import subprocess as sp +import os + +from spliceai.batch.batch_utils import prepare_batches, start_workers,initialize_devices from spliceai.utils import Annotator, get_delta_scores +from spliceai.batch.data_handlers import VCFWriter try: from sys.stdin import buffer as std_in @@ -24,9 +29,9 @@ def get_options(): parser = argparse.ArgumentParser(description='Version: 1.3.1') - parser.add_argument('-I', '--input', metavar='input', nargs='?', default=std_in, + parser.add_argument('-I', '--input_data', metavar='input', 
nargs='?', default=std_in, help='path to the input VCF file, defaults to standard in') - parser.add_argument('-O', '--output', metavar='output', nargs='?', default=std_out, + parser.add_argument('-O', '--output_data', metavar='output', nargs='?', default=std_out, help='path to the output VCF file, defaults to standard out') parser.add_argument('-R', '--reference', metavar='reference', required=True, help='path to the reference genome fasta file') @@ -42,15 +47,18 @@ def get_options(): type=int, choices=[0, 1], help='mask scores representing annotated acceptor/donor gain and ' 'unannotated acceptor/donor loss, defaults to 0') - parser.add_argument('-B', '--prediction-batch-size', metavar='prediction_batch_size', default=1, type=int, + parser.add_argument('-B', '--prediction_batch_size', metavar='prediction_batch_size', default=1, type=int, help='number of predictions to process at a time, note a single vcf record ' 'may have multiple predictions for overlapping genes and multiple alts') - parser.add_argument('-T', '--tensorflow-batch-size', metavar='tensorflow_batch_size', type=int, + parser.add_argument('-T', '--tensorflow_batch_size', metavar='tensorflow_batch_size', type=int, help='tensorflow batch size for model predictions') parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging') parser.add_argument('-t','--tmpdir', metavar='tmpdir',type=str,default='/tmp/',required=False, help="Use Alternate location to store tmp files. (Note: B=4096 equals to roughly 15Gb of tmp files)") - + parser.add_argument('-G','--gpus',metavar='gpus',type=str,default='all',required=False, + help="Number of GPUs to use for SpliceAI. Provide 'all', or comma-seperated list of GPUs to use. eg '0,2' (first and third). 
Defaults to 'all'") + parser.add_argument('-S', '--simulated_gpus',metavar='simulated_gpus',default='0',type=int, required=False, + help="For development: simulated logical gpus on a single physical device to simulate a multi-gpu environment") args = parser.parse_args() return args @@ -58,116 +66,127 @@ def get_options(): def main(): args = get_options() - + # logging if args.verbose: loglevel = logging.DEBUG else: loglevel = logging.INFO - logging.basicConfig( format='%(asctime)s %(levelname)s %(name)s: - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=loglevel, ) - - if None in [args.input, args.output, args.distance, args.mask]: + # sanity check for mandatory arguments + if None in [args.input_data, args.output_data, args.distance, args.mask]: logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]] [-t [tmp_location]]') exit() - # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args - tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size - # load annotation data: - ann = Annotator(args.reference, args.annotation) ## revised code for batched analysis if args.prediction_batch_size > 1: - run_spliceai_batched(input_data=args.input, output_data=args.output, reference=args.reference, - ann=ann, distance=args.distance, mask=args.mask, - prediction_batch_size=args.prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size,tempdir=args.tmpdir) + # initialize the GPU and setup to estimate + devices,mem_per_logical = initialize_devices(args) + # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args + args.tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size + + # load annotation data: + ann = Annotator(args.reference, args.annotation) + 
logging.debug("Annotation loaded.") + # run + run_spliceai_batched(args,ann,devices,mem_per_logical) + else: # run original code: - run_spliceai(input_data=args.input, output_data=args.output, ann=ann, distance=args.distance, mask=args.mask) + # load annotation + ann = Annotator(args.reference, args.annotation) + # run scoring + run_spliceai(args, ann) # input_data=args.input, output_data=args.output, ann=ann, distance=args.distance, mask=args.mask) -## revised logic to allow batched tensorflow analysis -def run_spliceai_batched(input_data, output_data, reference, ann, distance, mask, prediction_batch_size, - tensorflow_batch_size,tempdir): + +## revised logic to allow batched tensorflow analysis on multiple GPUs +def run_spliceai_batched(args, ann,devices,mem_per_logical): #input_data, output_data, reference, ann, distance, mask, prediction_batch_size, + #tensorflow_batch_size,tempdir,devices,args): + ## GOAL + ## - launch a reader that preps & pickles input vcf + ## - launch per GPU/device, using sockets, a utility script that runs tasks from the queue on that device. + ## - communicate through sockets : server threads issue items from the queue to worker clients + ## - when all predictions are done, build the output vcf. 
+ + + ## track start time + start_time = time.time() + ## variables: + input_data = args.input_data + output_data = args.output_data + distance = args.distance + mask = args.mask + prediction_batch_size = args.prediction_batch_size + tensorflow_batch_size = args.tensorflow_batch_size + ## mk a temp directory - tmpdir = tempfile.TemporaryDirectory(dir=tempdir) - logging.debug("tmp dir : {}".format(tmpdir.name)) - # initialize the prediction object - batch = VCFPredictionBatch( - ann=ann, - output=output_data, - dist=distance, - mask=mask, - prediction_batch_size=prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size, - tmpdir = tmpdir - ) - + tmpdir = tempfile.mkdtemp(dir=args.tmpdir) # TemporaryDirectory(dir=args.tmpdir) + #tmpdir = tmpdir.name + logging.debug("tmp dir : {}".format(tmpdir)) + # creates a queue with max 10 ready-to-go batches in it. - # starts processing & filling the queue. prediction_queue = Queue(maxsize=10) - reader = Process(target=prepare_batches, kwargs={'ann':ann, - 'input_data':input_data, - 'prediction_batch_size':prediction_batch_size, - 'prediction_queue': prediction_queue, - 'tmpdir':tmpdir, - 'dist' : distance - }) + # starts processing & filling the queue. + reader_args={'ann':ann, 'args':args, 'tmpdir':tmpdir, 'prediction_queue': prediction_queue, 'nr_workers': len(devices)} + reader = Process(target=prepare_batches, kwargs=reader_args) reader.start() - # Process the queue. - batch.process_batches(prediction_queue) + worker_clients, worker_servers, devices = start_workers(prediction_queue,tmpdir,args,devices,mem_per_logical) - # join the reader process. + ## wait for everything to finish. + # readers sends finish signal to workers + logging.debug("Waiting for VCF reader to join") reader.join() + logging.debug("Reader joined!") + # clients receive signal, send it to servers. 
+ logging.debug("Waiting for workers to join.") + for p in worker_clients: + # subprocesses : wait() + p.wait() + logging.debug("Workers are done!") + logging.debug("Waiting for servers to join.") + for p in worker_servers: + # mp processes : join() + p.join() + logging.debug("SErvers are done") # stats without writing phase - prediction_duration = time.time() - batch.start_time - - # write results. - # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data + prediction_duration = time.time() - start_time + + # write results. in/out from args, devices to get shelf names logging.debug("Writing output file") - vcf = pysam.VariantFile(input_data) - # have to update header again - header = vcf.header - header.add_line('##INFO=') - try: - batch.output_data = pysam.VariantFile(output_data, mode='w', header=header) - - except (IOError, ValueError) as e: - logging.error('{}'.format(e)) - exit() - - batch.write_records(vcf) - # close shelf - batch.shelf_preds.close() - # close vcf - vcf.close() - batch.output_data.close() - + writer = VCFWriter(args=args,tmpdir=tmpdir,devices=devices,ann=ann) + writer.process() + # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data + logging.debug("Writing output file") + ## stats - overall_duration = time.time() - batch.start_time - preds_per_sec = batch.total_predictions / prediction_duration + overall_duration = time.time() - start_time + preds_per_sec = writer.total_predictions / prediction_duration preds_per_hour = preds_per_sec * 60 * 60 logging.info("Analysis Finished. 
Statistics:") logging.info("Total RunTime: {:0.2f}s".format(overall_duration)) logging.info("Prediction RunTime: {:0.2f}s".format(prediction_duration)) - logging.info("Processed Records: {}".format(batch.total_vcf_records)) - logging.info("Processed Predictions: {}".format(batch.total_predictions)) + logging.info("Processed Records: {}".format(writer.total_vcf_records)) + logging.info("Processed Predictions: {}".format(writer.total_predictions)) logging.info("Overall performance : {:0.2f} predictions/sec ; {:0.2f} predictions/hour".format(preds_per_sec, preds_per_hour)) # original flow : record by record reading/predict/write -def run_spliceai(input_data, output_data, ann, distance, mask): - +def run_spliceai(args, ann): + # assign variables + input_data = args.input_data + output_data = args_output_data + distance = args.distance + mask = args.mask + + # open infile try: vcf = pysam.VariantFile(input_data) except (IOError, ValueError) as e: diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index cd5d101..a81e245 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -11,31 +11,104 @@ import os import tensorflow as tf import pickle +import gc +import socket +import sys +import argparse + +#from spliceai.batch.batch_utils import extract_delta_scores, get_preds +sys.path.append('../../../spliceai') +from spliceai.batch.batch_utils import get_preds, initialize_devices +from spliceai.utils import Annotator, get_delta_scores -from spliceai.batch.batch_utils import extract_delta_scores, get_preds -logger = logging.getLogger(__name__) SequenceType_REF = 0 SequenceType_ALT = 1 +# options : revised from __main__ +def get_options(): + + parser = argparse.ArgumentParser(description='Version: 1.3.1') + parser.add_argument('-R', '--reference', metavar='reference', required=True, + help='path to the reference genome fasta file') + parser.add_argument('-A', '--annotation',metavar='annotation', required=True, + help='"grch37" (GENCODE V24lift37 
canonical annotation file in ' + 'package), "grch38" (GENCODE V24 canonical annotation file in ' + 'package), or path to a similar custom gene annotation file') + parser.add_argument('-T', '--tensorflow_batch_size', metavar='tensorflow_batch_size', type=int, + help='tensorflow batch size for model predictions') + parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging') + parser.add_argument('-t','--tmpdir', metavar='tmpdir',type=str,default='/tmp/',required=False, + help="Use Alternate location to store tmp files. (Note: B=4096 equals to roughly 15Gb of tmp files)") + parser.add_argument('-d','--device',metavar='device',type=str,required=True, + help="CPU/GPU device to deploy worker on") + parser.add_argument('-S', '--simulated_gpus',metavar='simulated_gpus',default='0',type=int, required=False, + help="For development: simulated logical gpus on a single physical device to simulate a multi-gpu environment") + parser.add_argument('-M', '--mem_per_logical', metavar='mem_per_logical',default=0,type=int, required=False, + help="For simulated GPUs assign this amount of memory (Mb)") + parser.add_argument('-G','--gpus',metavar='gpus',type=str,default='all',required=False, + help="Number of GPUs to use for SpliceAI. Provide 'all', or comma-seperated list of GPUs to use. eg '0,2' (first and third). 
Defaults to 'all'") + args = parser.parse_args() + + return args + + +def main(): + # get arguments + args = get_options() + if args.verbose: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + logging.basicConfig( + format='%(asctime)s %(levelname)s %(name)s: - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=loglevel, + ) + logger = logging.getLogger(__name__) + + # initialize && assign device + devices = [x for x in initialize_devices(args)[0] if x.name == args.device] + if not devices: + logger.error(f"Specified device '{args.device}' not found!") + sys.exit(1) + device = devices[0].name + # get annotator + logger.info("loading annotations") + ann = Annotator(args.reference, args.annotation) + + + with tf.device(device): + logger.info(f"Working on device {device}") + # initialize the VCFPredictionBatch + worker = VCFPredictionBatch(ann=ann, tensorflow_batch_size=args.tensorflow_batch_size, tmpdir=args.tmpdir,device=device,logger=logger) + # start working ! + worker.process_batches() + # done. 
+ + + + # Class to handle predictions class VCFPredictionBatch: - def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_batch_size,tmpdir): + def __init__(self, ann, tensorflow_batch_size, tmpdir,device,logger): self.ann = ann - self.output = output - self.dist = dist - self.mask = mask + #self.output = output + #self.dist = dist + #self.mask = mask + self.device = device # This is the maximum number of predictions to parse/encode/predict at a time - self.prediction_batch_size = prediction_batch_size + #self.prediction_batch_size = prediction_batch_size # This is the size of the batch tensorflow will use to make the predictions self.tensorflow_batch_size = tensorflow_batch_size # track runtime - self.start_time = time.time() + #self.start_time = time.time() # Batch vars self.batches = {} - self.prepared_vcf_records = [] + #self.prepared_vcf_records = [] + self.logger = logger # Counts self.total_predictions = 0 @@ -46,29 +119,53 @@ def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_ba # shelves to track data. self.tmpdir = tmpdir - # store batches of predictions using 'tensor_size|batch_idx' as key. - self.shelf_preds = shelve.open(os.path.join(self.tmpdir.name,"spliceai_preds.shelf")) + # store batches of predictions using 'tensor_size|batch_idx' as key. + self.shelf_preds_name = f"spliceai_preds.{self.device[1:].replace(':','_')}.shelf" + self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) # monitor the queue and submit incoming batches. - def process_batches(self,prediction_queue): - while True: - item =prediction_queue.get() - # reader submits None when all are queued. - if item is None: - break - # load pickled object - with open(os.path.join(self.tmpdir.name,item),'rb') as p: - data = pickle.load(p) - # remove from disk. 
- os.unlink(os.path.join(self.tmpdir.name,item)) - self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) - + def process_batches(self): + with socket.socket() as s: + host = socket.gethostname() # locahost + port = 54677 + try: + s.connect((host,port)) + except Exception as e: + raise(e) + # first response : server is running + res = s.recv(2048) + # then start polling queue + msg = "Ready for work..." + + while True: + # send request for work + s.send(str.encode(msg)) + res = s.recv(2048).decode('utf-8') + # response can be a job, 'hold on' for empty queue, or 'Done' for all finished. + if res == 'Hold On': + msg = 'Ready for work...' + time.sleep(0.1) + elif res == 'Finished': + self.logger.info("Worker done. Shutting down") + break + else: + # got a batch id: + with open(os.path.join(self.tmpdir,res),'rb') as p: + data = pickle.load(p) + # remove pickled batch + os.unlink(os.path.join(self.tmpdir,res)) + # process : stats are send back as next 'ready for work' result. + msg = self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) + # send signal to server thread to exit. + s.send(str.encode('Done')) + self.logger.info(f"Closing Worker on device {self.device}") + def _process_batch(self,tensor_size,batch_ix, prediction_batch,nr_preds): start = time.time() # Sanity check dump of batch sizes - logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , nr_preds)) + self.logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , nr_preds)) # Run predictions && add to shelf. 
self.shelf_preds["{}|{}".format(tensor_size,batch_ix)] = np.mean( @@ -79,66 +176,12 @@ def _process_batch(self,tensor_size,batch_ix, prediction_batch,nr_preds): duration = time.time() - start preds_per_sec = nr_preds / duration preds_per_hour = preds_per_sec * 60 * 60 - logger.debug('Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(duration, - preds_per_sec, - preds_per_hour)) - - # wrapper to write out all shelved variants - def write_records(self, vcf): - # open the shelf with records: - shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) - # parse vcf - line_idx = 0 - batch = [] - last_batch_key = '' - for record in vcf: - line_idx += 1 - # get prepared record by line_idx - prepared_record = shelf_records[str(line_idx)] - gene_info = prepared_record.gene_info - # (REF + #ALT ) * #genes (* 5 models) - self.total_predictions += (1 + len(record.alts)) * len(gene_info.genes) - - all_y_ref = [] - all_y_alt = [] - - # Each prediction in the batch is located and put into the correct y - for location in prepared_record.locations: - # No prediction here - if location.tensor_size == 0: - if location.sequence_type == SequenceType_REF: - all_y_ref.append(None) - else: - all_y_alt.append(None) - continue - - # Extract the prediction from the batch into a list of predictions for this record - # recycle the batch variable if key is the same. 
- if not last_batch_key == "{}|{}".format(location.tensor_size,location.batch_ix): - last_batch_key = "{}|{}".format(location.tensor_size,location.batch_ix) - batch = self.shelf_preds[last_batch_key] # batch_preds[location.tensor_size] - - if location.sequence_type == SequenceType_REF: - all_y_ref.append(batch[[location.batch_index], :, :]) - else: - all_y_alt.append(batch[[location.batch_index], :, :]) - delta_scores = extract_delta_scores( - all_y_ref=all_y_ref, - all_y_alt=all_y_alt, - record=record, - ann=self.ann, - dist_var=self.dist, - mask=self.mask, - gene_info=gene_info, - ) - - # If there are predictions, write them to the VCF INFO section - if len(delta_scores) > 0: - record.info['SpliceAI'] = delta_scores - - self.output_data.write(record) - # close shelf again - self.total_vcf_records = line_idx - shelf_records.close() - + msg = 'Device {} : Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(self.device, duration, preds_per_sec, preds_per_hour) + self.logger.debug(msg) + return msg + + + +if __name__ == '__main__': + main() diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 00d672b..853e277 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -7,49 +7,121 @@ import logging import shelve -import pysam -import collections +#import pysam +#import collections import os import gc import numpy as np import tensorflow as tf import pickle +import socket +from multiprocessing import Process +import subprocess +import time +import sys from spliceai.utils import get_alt_gene_delta_score, is_record_valid, get_seq, \ is_location_predictable, get_cov, get_wid, is_valid_alt_record, encode_seqs, create_unhandled_delta_score -logger = logging.getLogger(__name__) - +sys.path.append('../../spliceai') +from spliceai.batch.data_handlers import VCFReader, VCFWriter -## CUSTOM DATA TYPES -SequenceType_REF = 0 -SequenceType_ALT = 1 - -BatchLookupIndex = collections.namedtuple( - # ref/alt size batch 
for this size index in current batch for this size - 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' -) +logger = logging.getLogger(__name__) -PreparedVCFRecord = collections.namedtuple( - 'PreparedVCFRecord', 'vcf_idx gene_info locations' -) +########### +## INPUT ## +########### ## routine to create the batches for prediction. -def prepare_batches(ann, input_data,prediction_batch_size, prediction_queue,tmpdir,dist): +def prepare_batches(ann, args, tmpdir, prediction_queue,nr_workers): # input_data,prediction_batch_size, prediction_queue,tmpdir,dist): # create the parser object - vcf_reader = VCFReader(ann=ann, input_data=input_data, prediction_batch_size=prediction_batch_size, prediction_queue=prediction_queue,tmpdir=tmpdir,dist=dist) + vcf_reader = VCFReader(ann=ann, + input_data=args.input_data, + prediction_batch_size=args.prediction_batch_size, + prediction_queue=prediction_queue, + tmpdir=tmpdir,dist=args.distance, + ) # parse records vcf_reader.add_records() # finalize last batches - vcf_reader.finish() + vcf_reader.finish(nr_workers) # close the shelf. vcf_reader.shelf_records.close() # stats logger.info("Read {} vcf records, queued {} predictions".format(vcf_reader.total_vcf_records, vcf_reader.total_predictions)) -## get tensorflow predictions using batch-based submissions + + + +############## +## ANALYSIS ## +############## +## routine to start the worker Threads +def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): + # start server socket + s = socket.socket() + host = socket.gethostname() # locahost + port = 54677 + try: + s.bind((host,port)) + except Exception as e: + logger.error(f"Cannot bind to port {port} : {e}") + sys.exit(1) + s.listen(5) + # start client sockets & server threads. + clientThreads = list() + serverThreads = list() + + for device in devices: + # launch the worker. 
+ cmd = ["python",os.path.join(os.path.dirname(os.path.realpath(__file__)),"batch.py"),"-S",str(args.simulated_gpus),"-M",str(int(mem_per_logical)), "-t",tmpdir,"-d",device.name, '-R', args.reference, '-A', args.annotation, '-T', str(args.tensorflow_batch_size)] + if args.verbose: + cmd.append('-V') + #print(cmd) + fh_stdout = open(tmpdir+'/'+device.name.replace('/','_').replace(':','.')+'.stdout','w') + fh_stderr = open(tmpdir+'/'+device.name.replace('/','_').replace(':','.')+'.stderr','w') + + p = subprocess.Popen(cmd) # ,stdout=fh_stdout, stderr=fh_stderr) + clientThreads.append(p) + ## then a new thread in the server for this connection. + client, address = s.accept() + logger.debug("Connected to : " + address[0] + ' : ' + str(address[1])) + p = Process(target=_process_server,args=(client,device.name,prediction_queue,)) + p.start() + serverThreads.append(p) + logger.debug(f"Thread {device.name} activated!") + + return clientThreads, serverThreads, devices + +# routine that runs in the server threads, issuing work to the worker_clients. 
+def _process_server(clientsocket,device,queue): + # initial response + clientsocket.send(str.encode('Server is online')) + while True: + msg = clientsocket.recv(2048).decode('utf-8') + if msg == 'Done': + logger.debug(f"Stopping thread {device}") + break + elif not msg == 'Ready for work...': + logger.debug(msg) + # send/get new item + try: + item = queue.get(False) + except Exception as e: + #print(str(e)) + item = 'Hold On' + + # set reply + clientsocket.sendall(str.encode(str(item))) + + logger.debug(f"Closing {device} socket.") + clientsocket.close() + + + +## get tensorflow predictions using batch-based submissions (used in worker clients) def get_preds(ann, x, batch_size=32): logger.debug('Running get_preds with matrix size: {}'.format(x.shape)) try: @@ -63,259 +135,61 @@ def get_preds(ann, x, batch_size=32): return predictions - -# Heavily based on utils.get_delta_scores but only handles the validation and encoding -# of the record, but doesn't do any of the prediction or post-processing steps -def encode_batch_records(record, ann, dist_var, gene_info): - cov = get_cov(dist_var) - wid = get_wid(cov) - # If the record is not going to get a prediction, return this empty encoding - empty_encoding = ([], []) - - if not is_record_valid(record): - return empty_encoding - - seq = get_seq(record, ann, wid) - if not seq: - return empty_encoding - - if not is_location_predictable(record, seq, wid, dist_var): - return empty_encoding - - all_x_ref = [] - all_x_alt = [] - for alt_ix in range(len(record.alts)): - for gene_ix in range(len(gene_info.idxs)): - - if not is_valid_alt_record(record, alt_ix): - continue - - x_ref, x_alt = encode_seqs(record=record, - seq=seq, - ann=ann, - gene_info=gene_info, - gene_ix=gene_ix, - alt_ix=alt_ix, - wid=wid) - - all_x_ref.append(x_ref) - all_x_alt.append(x_alt) - - return all_x_ref, all_x_alt - - -# Heavily based on utils.get_delta_scores but only handles the post-processing steps after -# the models have made the predictions -def 
extract_delta_scores( - all_y_ref, all_y_alt, record, ann, dist_var, mask, gene_info -): - cov = get_cov(dist_var) - delta_scores = [] - pred_ix = 0 - for alt_ix in range(len(record.alts)): - for gene_ix in range(len(gene_info.idxs)): - - - # Pull prediction out of batch - try: - y_ref = all_y_ref[pred_ix] - y_alt = all_y_alt[pred_ix] - except IndexError: - logger.warn("No data for record below, alt_ix {} : gene_ix {} : pred_ix {}".format(alt_ix, gene_ix,pred_ix)) - logger.warn(record) - continue - except Exception as e: - logger.error("Predction error: {}".format(e)) - logger.error(record) - raise e - - # No prediction here - if y_ref is None or y_alt is None: - continue - - if not is_valid_alt_record(record, alt_ix): - continue - - if len(record.ref) > 1 and len(record.alts[alt_ix]) > 1: - pred_ix += 1 - delta_score = create_unhandled_delta_score(record.alts[alt_ix], gene_info.genes[gene_ix]) - delta_scores.append(delta_score) - continue - - if pred_ix >= len(all_y_ref) or pred_ix >= len(all_y_alt): - raise LookupError( - 'Prediction index {} does not exist in prediction matrices: ref({}) alt({})'.format( - pred_ix, len(all_y_ref), len(all_y_alt) - ) - ) - - delta_score = get_alt_gene_delta_score(record=record, - ann=ann, - alt_ix=alt_ix, - gene_ix=gene_ix, - y_ref=y_ref, - y_alt=y_alt, - cov=cov, - gene_info=gene_info, - mask=mask) - delta_scores.append(delta_score) - - pred_ix += 1 - - return delta_scores - - - -# class to parse input and prep batches -class VCFReader: - def __init__(self, ann, input_data, prediction_batch_size, prediction_queue,tmpdir,dist): - self.ann = ann - # This is the maximum number of predictions to parse/encode/predict at a time - self.prediction_batch_size = prediction_batch_size - # the vcf file - self.input_data = input_data - # window to consider - self.dist = dist - # Batch vars - self.batches = {} - #self.prepared_vcf_records = [] - - # Counts - self.total_predictions = 0 - self.total_vcf_records = 0 - self.batch_counters = {} - 
- # the queue - self.prediction_queue = prediction_queue - - # shelves to track data. - self.tmpdir = tmpdir - # track records to have order correct - self.shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) - - - - def add_records(self): - +## management routine to initialize gpu/cpu devices and do simulated logical devices if needed +def initialize_devices(args): + ## need to simulate gpus ? + gpus = tf.config.list_physical_devices('GPU') + mem_per_logical = 0 + if gpus and args.simulated_gpus > 1: + logger.warning(f"Simulating {args.simulated_gpus} logical GPUs on the first physical GPU device") try: - vcf = pysam.VariantFile(self.input_data) - except (IOError, ValueError) as e: - logging.error('{}'.format(e)) + gpu_mem_mb = _get_gpu_memory() + except Exception as e: + logger.error(f"Could not get GPU memory (needs nvidia-smi) : {e}") + sys.exit(1) + + # Create n virtual GPUs with [available] / n GB memory each + if hasattr(args,'mem_per_logical'): + mem_per_logical = args.mem_per_logical + else: + mem_per_logical = (int(gpu_mem_mb[0])-2048) / args.simulated_gpus + + logger.info(f"Assigning {mem_per_logical}mb of GPU memory per simulated GPU.") + try: + device_list = [tf.config.LogicalDeviceConfiguration(memory_limit=mem_per_logical)] * args.simulated_gpus + tf.config.set_logical_device_configuration( + gpus[0], + device_list) + logical_gpus = tf.config.list_logical_devices('GPU') + + except RuntimeError as e: + # Virtual devices must be set before GPUs have been initialized raise(e) - for record in vcf: - try: - self.add_record(record) - except Exception as e: - raise(e) - vcf.close() - - - def add_record(self, record): - """ - Adds a record to a batch. It'll capture the gene information for the record and - save it for later to avoid looking it up again, then it'll encode ref and alt from - the VCF record and place the encoded values into lists of matching sizes. 
Once the - encoded values are added, a BatchLookupIndex is created so that after the predictions - are made, it knows where to look up the corresponding prediction for the vcf record. - - Once the batch size hits it's capacity, it'll process all the predictions for the - encoded batch. - """ - - self.total_vcf_records += 1 - # Collect gene information for this record - gene_info = self.ann.get_name_and_strand(record.chrom, record.pos) - - # Keep track of how many predictions we're going to make - prediction_count = len(record.alts) * len(gene_info.genes) - self.total_predictions += prediction_count - - # Collect lists of encoded ref/alt sequences - x_ref, x_alt = encode_batch_records(record, self.ann, self.dist, gene_info) + if gpus: + prediction_devices = tf.config.list_logical_devices('GPU') + if not args.gpus.lower() == 'all': + idxs = [int(x) for x in args.gpus.split(',')] + prediction_devices = [prediction_devices[x] for x in idx] + else: + # run on cpu + prediction_devices = tf.config.list_logical_devices('CPU')[0] - # List of BatchLookupIndex's so we know how to lookup predictions for records from - # the batches - batch_lookup_indexes = [] + logger.info("Using the following devices for prediction:") + for d in prediction_devices: + logger.info(f" - {d.name}") + # add verbosity + if args.verbose: + tf.debugging.set_log_device_placement(True) - # Process the encodings into batches - for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)): + return prediction_devices, mem_per_logical - if len(encoded_seq) == 0: - # Add BatchLookupIndex with zeros so when the batch collects the outputs - # it knows that there is no prediction for this record - batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0, 0)) - continue +## helper function to get gpu memory. 
+def _get_gpu_memory(): + command = "nvidia-smi --query-gpu=memory.free --format=csv" + memory_free_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:] + memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] + return memory_free_values - # Iterate over the encoded sequence and drop into the correct batch by size and - # create an index to use to pull out the result after batch is processed - for row in encoded_seq: - # Extract the size of the sequence that was encoded to build a batch from - tensor_size = row.shape[1] - # Create batch for this size - if tensor_size not in self.batches: - self.batches[tensor_size] = [] - self.batch_counters[tensor_size] = 0 - - # Add encoded record to batch 'n' for tensor_size - self.batches[tensor_size].append(row) - - # Get the index of the record we just added in the batch - cur_batch_record_ix = len(self.batches[tensor_size]) - 1 - - # Store a reference so we can pull out the prediction for this item from the batches - batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, self.batch_counters[tensor_size] , cur_batch_record_ix)) - - # Save the batch locations for this record on the composite object - prepared_record = PreparedVCFRecord( - vcf_idx=self.total_vcf_records, gene_info=gene_info, locations=batch_lookup_indexes - ) - # add to shelf by vcf_idx - self.shelf_records[str(self.total_vcf_records)] = prepared_record - - # If we're reached our threshold for the max items to process, then process the batch - for tensor_size in self.batch_counters: - if len(self.batches[tensor_size]) >= self.prediction_batch_size: - logger.debug("Batch {} full. Adding to queue".format(tensor_size)) - # fully prep the batch outside of gpu routine... 
- data = np.concatenate(self.batches[tensor_size]) - concat_len = len(data) - # offload conversion of batch from np to tensor to CPU - with tf.device('CPU:0'): - data = tf.convert_to_tensor(data) - queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} - with open(os.path.join(self.tmpdir.name,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: - pickle.dump(queue_item,p) - self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) - - # reset - self.batches[tensor_size] = [] - self.batch_counters[tensor_size] += 1 - - #self._process_batch(tensor_size) - - def finish(self): - """ - Method to process all the remaining items that have been added to the batches. - """ - #if len(self.prepared_vcf_records) > 0: - # self._process_batch() - logger.debug("Queueing remaining batches") - for tensor_size in self.batch_counters: - if len(self.batches[tensor_size] ) > 0: - # fully prep the batch outside of gpu routine... 
- data = np.concatenate(self.batches[tensor_size]) - concat_len = len(data) - # offload conversion of batch from np to tensor to CPU - with tf.device('CPU:0'): - data = tf.convert_to_tensor(data) - queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} - with open(os.path.join(self.tmpdir.name,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: - pickle.dump(queue_item,p) - self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) - # clear - self.batches[tensor_size] = [] - # all done : - self.prediction_queue.put(None) \ No newline at end of file diff --git a/spliceai/batch/data_handlers.py b/spliceai/batch/data_handlers.py new file mode 100644 index 0000000..1e90d0a --- /dev/null +++ b/spliceai/batch/data_handlers.py @@ -0,0 +1,419 @@ +import logging +import shelve +import pysam +import collections +import os +import numpy as np +import pickle +import tensorflow as tf +import sys + +#from spliceai.utils import get_alt_gene_delta_score, is_record_valid, get_seq, \ +# is_location_predictable, get_cov, get_wid, is_valid_alt_record, encode_seqs, create_unhandled_delta_score +from spliceai.utils import get_cov, get_wid, get_seq, is_record_valid, is_location_predictable, \ + is_valid_alt_record, encode_seqs, create_unhandled_delta_score, get_alt_gene_delta_score + + +logger = logging.getLogger(__name__) + + +## CUSTOM DATA TYPES +SequenceType_REF = 0 +SequenceType_ALT = 1 + +BatchLookupIndex = collections.namedtuple( + # ref/alt size batch for this size index in current batch for this size + 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' +) + +PreparedVCFRecord = collections.namedtuple( + 'PreparedVCFRecord', 'vcf_idx gene_info locations' +) + + +# class to parse input and prep batches +class VCFReader: + def __init__(self, ann, input_data, prediction_batch_size, prediction_queue, tmpdir, dist): + self.ann 
= ann + # This is the maximum number of predictions to parse/encode/predict at a time + self.prediction_batch_size = prediction_batch_size + # the vcf file + self.input_data = input_data + # window to consider + self.dist = dist + # Batch vars + self.batches = {} + #self.prepared_vcf_records = [] + + # Counts + self.total_predictions = 0 + self.total_vcf_records = 0 + self.batch_counters = {} + + # the queue + self.prediction_queue = prediction_queue + + # shelves to track data. + self.tmpdir = tmpdir + # track records to have order correct + logging.info("Opening spliceai_records shelf") + try: + self.shelf_records = shelve.open(os.path.join(self.tmpdir,"spliceai_records.shelf")) + except Exception as e: + logging.error(f"Could not open shelf: {e}") + raise(e) + + + def add_records(self): + + try: + vcf = pysam.VariantFile(self.input_data) + except (IOError, ValueError) as e: + logging.error('{}'.format(e)) + raise(e) + for record in vcf: + try: + self.add_record(record) + except Exception as e: + raise(e) + vcf.close() + + + def add_record(self, record): + """ + Adds a record to a batch. It'll capture the gene information for the record and + save it for later to avoid looking it up again, then it'll encode ref and alt from + the VCF record and place the encoded values into lists of matching sizes. Once the + encoded values are added, a BatchLookupIndex is created so that after the predictions + are made, it knows where to look up the corresponding prediction for the vcf record. + + Once the batch size hits it's capacity, it'll process all the predictions for the + encoded batch. 
+ """ + + self.total_vcf_records += 1 + # Collect gene information for this record + gene_info = self.ann.get_name_and_strand(record.chrom, record.pos) + + # Keep track of how many predictions we're going to make + prediction_count = len(record.alts) * len(gene_info.genes) + self.total_predictions += prediction_count + + # Collect lists of encoded ref/alt sequences + x_ref, x_alt = self._encode_batch_records(record, self.ann, self.dist, gene_info) + + # List of BatchLookupIndex's so we know how to lookup predictions for records from + # the batches + batch_lookup_indexes = [] + + # Process the encodings into batches + for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)): + + if len(encoded_seq) == 0: + # Add BatchLookupIndex with zeros so when the batch collects the outputs + # it knows that there is no prediction for this record + batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0, 0)) + continue + + # Iterate over the encoded sequence and drop into the correct batch by size and + # create an index to use to pull out the result after batch is processed + for row in encoded_seq: + # Extract the size of the sequence that was encoded to build a batch from + tensor_size = row.shape[1] + + # Create batch for this size + if tensor_size not in self.batches: + self.batches[tensor_size] = [] + self.batch_counters[tensor_size] = 0 + + # Add encoded record to batch 'n' for tensor_size + self.batches[tensor_size].append(row) + + # Get the index of the record we just added in the batch + cur_batch_record_ix = len(self.batches[tensor_size]) - 1 + + # Store a reference so we can pull out the prediction for this item from the batches + batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, self.batch_counters[tensor_size] , cur_batch_record_ix)) + + # Save the batch locations for this record on the composite object + prepared_record = PreparedVCFRecord( + vcf_idx=self.total_vcf_records, gene_info=gene_info, 
locations=batch_lookup_indexes + ) + # add to shelf by vcf_idx + self.shelf_records[str(self.total_vcf_records)] = prepared_record + + # If we're reached our threshold for the max items to process, then process the batch + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size]) >= self.prediction_batch_size: + logger.debug("Batch {} full. Adding to queue".format(tensor_size)) + # fully prep the batch outside of gpu routine... + data = np.concatenate(self.batches[tensor_size]) + concat_len = len(data) + # offload conversion of batch from np to tensor to CPU + with tf.device('CPU:0'): + data = tf.convert_to_tensor(data) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} + with open(os.path.join(self.tmpdir,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: + pickle.dump(queue_item,p) + self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) + + # reset + self.batches[tensor_size] = [] + self.batch_counters[tensor_size] += 1 + + #self._process_batch(tensor_size) + + + + def finish(self,nr_workers): + """ + Method to process all the remaining items that have been added to the batches. + """ + #if len(self.prepared_vcf_records) > 0: + # self._process_batch() + logger.debug("Queueing remaining batches") + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size] ) > 0: + # fully prep the batch outside of gpu routine... 
+ data = np.concatenate(self.batches[tensor_size]) + concat_len = len(data) + # offload conversion of batch from np to tensor to CPU + with tf.device('CPU:0'): + data = tf.convert_to_tensor(data) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} + with open(os.path.join(self.tmpdir,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: + pickle.dump(queue_item,p) + self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) + # clear + self.batches[tensor_size] = [] + # all done : push finish signals (one per process device..). + logging.debug("Queueing finish signals") + for i in range(nr_workers): + self.prediction_queue.put('Finished') + + + # Heavily based on utils.get_delta_scores but only handles the validation and encoding + # of the record, but doesn't do any of the prediction or post-processing steps + def _encode_batch_records(self, record, ann, dist_var, gene_info): + cov = get_cov(dist_var) + wid = get_wid(cov) + # If the record is not going to get a prediction, return this empty encoding + empty_encoding = ([], []) + + if not is_record_valid(record): + return empty_encoding + + seq = get_seq(record, ann, wid) + if not seq: + return empty_encoding + + if not is_location_predictable(record, seq, wid, dist_var): + return empty_encoding + + all_x_ref = [] + all_x_alt = [] + for alt_ix in range(len(record.alts)): + for gene_ix in range(len(gene_info.idxs)): + + if not is_valid_alt_record(record, alt_ix): + continue + + x_ref, x_alt = encode_seqs(record=record, + seq=seq, + ann=ann, + gene_info=gene_info, + gene_ix=gene_ix, + alt_ix=alt_ix, + wid=wid) + + all_x_ref.append(x_ref) + all_x_alt.append(x_alt) + + return all_x_ref, all_x_alt + + + + +# class to parse input and prep batches +class VCFWriter: + def __init__(self, args, tmpdir, devices, ann): + + self.args = args + # the vcf file + self.input_data = 
args.input_data + self.output_data = args.output_data + # window to consider + self.dist = args.distance + # used devices + self.devices = [x.name for x in devices] + # shelves to track data. + self.tmpdir = tmpdir + # track records to have order correct + self.shelf_records = shelve.open(os.path.join(self.tmpdir,"spliceai_records.shelf")) + # trackers + self.total_records = 0 + self.total_predictions = 0 + # annotations. + self.ann = ann + + def process(self): + # prepare the global pred_shelf + self._aggregate_predictions() + + # open the files & update header: + self.vcf_in = pysam.VariantFile(self.input_data) + header = self.vcf_in.header + header.add_line('##INFO=') + self.vcf_out = pysam.VariantFile(self.output_data,mode='w',header=header) + + # write the output vcf. + self._write_records() + + # close shelves + self.shelf_records.close() + self.shelf_preds.close() + + # close output file. + self.vcf_in.close() + self.vcf_out.close() + + # aggregate shelves over the devices + def _aggregate_predictions(self): + logger.debug("Aggregating device shelves") + self.shelf_preds_name = f"spliceai_preds.shelf" + self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) + for device in self.devices: + device_shelf_name = f"spliceai_preds.{device[1:].replace(':','_')}.shelf" + device_shelf_preds = shelve.open(os.path.join(self.tmpdir, device_shelf_name)) + for x in device_shelf_preds: + self.shelf_preds[x] = device_shelf_preds[x] + device_shelf_preds.close() + + + + # wrapper to write out all shelved variants + def _write_records(self): + logger.debug("Writing output file") + # open the shelf with records: + #shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) + # parse vcf + line_idx = 0 + batch = [] + last_batch_key = '' + for record in self.vcf_in: + line_idx += 1 + # get prepared record by line_idx + prepared_record = self.shelf_records[str(line_idx)] + gene_info = prepared_record.gene_info + # (REF + #ALT ) * 
#genes (* 5 models) + self.total_predictions += (1 + len(record.alts)) * len(gene_info.genes) + + all_y_ref = [] + all_y_alt = [] + + # Each prediction in the batch is located and put into the correct y + for location in prepared_record.locations: + # No prediction here + if location.tensor_size == 0: + if location.sequence_type == SequenceType_REF: + all_y_ref.append(None) + else: + all_y_alt.append(None) + continue + + # Extract the prediction from the batch into a list of predictions for this record + # recycle the batch variable if key is the same. + if not last_batch_key == "{}|{}".format(location.tensor_size,location.batch_ix): + last_batch_key = "{}|{}".format(location.tensor_size,location.batch_ix) + batch = self.shelf_preds[last_batch_key] + + if location.sequence_type == SequenceType_REF: + all_y_ref.append(batch[[location.batch_index], :, :]) + else: + all_y_alt.append(batch[[location.batch_index], :, :]) + # get delta scores + delta_scores = self._extract_delta_scores( + all_y_ref=all_y_ref, + all_y_alt=all_y_alt, + record=record, + gene_info=gene_info, + ) + + # If there are predictions, write them to the VCF INFO section + if len(delta_scores) > 0: + record.info['SpliceAI'] = delta_scores + + self.vcf_out.write(record) + # close shelf again + self.total_vcf_records = line_idx + + + # Heavily based on utils.get_delta_scores but only handles the post-processing steps after + # the models have made the predictions + def _extract_delta_scores(self, all_y_ref, all_y_alt, record, gene_info): + # variables: + dist_var = self.dist + ann = self.ann + mask = self.args.mask + + cov = get_cov(dist_var) + delta_scores = [] + pred_ix = 0 + for alt_ix in range(len(record.alts)): + for gene_ix in range(len(gene_info.idxs)): + + + # Pull prediction out of batch + try: + y_ref = all_y_ref[pred_ix] + y_alt = all_y_alt[pred_ix] + except IndexError: + logger.warn("No data for record below, alt_ix {} : gene_ix {} : pred_ix {}".format(alt_ix, gene_ix,pred_ix)) + 
logger.warn(record) + continue + except Exception as e: + logger.error("Predction error: {}".format(e)) + logger.error(record) + raise e + + # No prediction here + if y_ref is None or y_alt is None: + continue + + if not is_valid_alt_record(record, alt_ix): + continue + + if len(record.ref) > 1 and len(record.alts[alt_ix]) > 1: + pred_ix += 1 + delta_score = create_unhandled_delta_score(record.alts[alt_ix], gene_info.genes[gene_ix]) + delta_scores.append(delta_score) + continue + + if pred_ix >= len(all_y_ref) or pred_ix >= len(all_y_alt): + raise LookupError( + 'Prediction index {} does not exist in prediction matrices: ref({}) alt({})'.format( + pred_ix, len(all_y_ref), len(all_y_alt) + ) + ) + + delta_score = get_alt_gene_delta_score(record=record, + ann=ann, + alt_ix=alt_ix, + gene_ix=gene_ix, + y_ref=y_ref, + y_alt=y_alt, + cov=cov, + gene_info=gene_info, + mask=mask) + delta_scores.append(delta_score) + + pred_ix += 1 + + return delta_scores + + + + diff --git a/spliceai/utils.py b/spliceai/utils.py index 5049084..99eea84 100644 --- a/spliceai/utils.py +++ b/spliceai/utils.py @@ -11,6 +11,7 @@ import numpy as np from pyfaidx import Fasta from keras.models import load_model +import tensorflow as tf import logging import gc @@ -49,9 +50,10 @@ def __init__(self, ref_fasta, annotations): except IOError as e: logging.error('{}'.format(e)) exit() - paths = ('models/spliceai{}.h5'.format(x) for x in range(1, 6)) - self.models = [load_model(resource_filename(__name__, x)) for x in paths] + # use CPU memory for loading models, to prevent gpu memory allocation. 
+ with tf.device('CPU:0'): + self.models = [load_model(resource_filename(__name__, x)) for x in paths] def get_name_and_strand(self, chrom, pos): From cbc9679ac999ec201c0af3c2e3d345bbc2bd0655 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 06:57:36 +0100 Subject: [PATCH 23/42] relocate annotation loader in worker --- spliceai/batch/batch.py | 33 ++++++++++++++++++--------------- spliceai/batch/batch_utils.py | 3 ++- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index a81e245..e9f5eda 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -75,15 +75,13 @@ def main(): logger.error(f"Specified device '{args.device}' not found!") sys.exit(1) device = devices[0].name - # get annotator - logger.info("loading annotations") - ann = Annotator(args.reference, args.annotation) - with tf.device(device): logger.info(f"Working on device {device}") + #logger.info("loading annotations") + #ann = Annotator(args.reference, args.annotation) # initialize the VCFPredictionBatch - worker = VCFPredictionBatch(ann=ann, tensorflow_batch_size=args.tensorflow_batch_size, tmpdir=args.tmpdir,device=device,logger=logger) + worker = VCFPredictionBatch(args=args,device=device,logger=logger) # , tensorflow_batch_size=args.tensorflow_batch_size, tmpdir=args.tmpdir,device=device,logger=logger) # start working ! worker.process_batches() # done. 
@@ -94,21 +92,23 @@ def main(): # Class to handle predictions class VCFPredictionBatch: def __init__(self, ann, tensorflow_batch_size, tmpdir,device,logger): - self.ann = ann - #self.output = output - #self.dist = dist - #self.mask = mask + self.args = args + self.ann = None + self.tensorflow_batch_size = args.tensorflow_batch_size + self.tmpdir = args.tmpdir self.device = device - # This is the maximum number of predictions to parse/encode/predict at a time - #self.prediction_batch_size = prediction_batch_size + self.logger = logger + + #self.ann = ann + #self.device = device + # This is the size of the batch tensorflow will use to make the predictions - self.tensorflow_batch_size = tensorflow_batch_size - # track runtime - #self.start_time = time.time() + # self.tensorflow_batch_size = tensorflow_batch_size + # Batch vars self.batches = {} #self.prepared_vcf_records = [] - self.logger = logger + # self.logger = logger # Counts self.total_predictions = 0 @@ -137,6 +137,9 @@ def process_batches(self): # then start polling queue msg = "Ready for work..." + # first load annotation + if not self.ann: + self.ann = Annotator(self.args.reference, self.args.annotation) while True: # send request for work s.send(str.encode(msg)) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 853e277..9eab675 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -76,6 +76,7 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): for device in devices: # launch the worker. 
+ logger.info(f"Starting worker on device {device.name}") cmd = ["python",os.path.join(os.path.dirname(os.path.realpath(__file__)),"batch.py"),"-S",str(args.simulated_gpus),"-M",str(int(mem_per_logical)), "-t",tmpdir,"-d",device.name, '-R', args.reference, '-A', args.annotation, '-T', str(args.tensorflow_batch_size)] if args.verbose: cmd.append('-V') @@ -91,7 +92,7 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): p = Process(target=_process_server,args=(client,device.name,prediction_queue,)) p.start() serverThreads.append(p) - logger.debug(f"Thread {device.name} activated!") + logger.info(f"Thread {device.name} activated!") return clientThreads, serverThreads, devices From 0754a37ac3adc318abadec7f3d6aac5c2cfc5700 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 07:06:17 +0100 Subject: [PATCH 24/42] fix arguments --- spliceai/batch/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index e9f5eda..e67c12f 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -91,7 +91,7 @@ def main(): # Class to handle predictions class VCFPredictionBatch: - def __init__(self, ann, tensorflow_batch_size, tmpdir,device,logger): + def __init__(self, args, device, logger): # ann, tensorflow_batch_size, tmpdir,device,logger): self.args = args self.ann = None self.tensorflow_batch_size = args.tensorflow_batch_size From e2ec04b2a92185ba9293835a80d3d56c3712e153 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 07:12:57 +0100 Subject: [PATCH 25/42] fix arguments --- spliceai/batch/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index e67c12f..c51bd47 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -118,7 +118,7 @@ def __init__(self, args, device, logger): # ann, tensorflow_batch_size, tmpdir,d # shelves to track data. 
- self.tmpdir = tmpdir + #self.tmpdir = tmpdir # store batches of predictions using 'tensor_size|batch_idx' as key. self.shelf_preds_name = f"spliceai_preds.{self.device[1:].replace(':','_')}.shelf" self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) From aad6bf9b7016688c1f10022dbaafcecdfe41170f Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 07:37:41 +0100 Subject: [PATCH 26/42] add small sleep investigating issues above 3 gpus --- spliceai/batch/batch.py | 3 ++- spliceai/batch/batch_utils.py | 1 + spliceai/utils.py | 7 +++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index c51bd47..8eeec6a 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -139,7 +139,8 @@ def process_batches(self): # first load annotation if not self.ann: - self.ann = Annotator(self.args.reference, self.args.annotation) + # load annotation + self.ann = Annotator(self.args.reference, self.args.annotation,cpu=True) while True: # send request for work s.send(str.encode(msg)) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 9eab675..f43cdbc 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -93,6 +93,7 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): p.start() serverThreads.append(p) logger.info(f"Thread {device.name} activated!") + time.sleep(3) return clientThreads, serverThreads, devices diff --git a/spliceai/utils.py b/spliceai/utils.py index 99eea84..c45cf30 100644 --- a/spliceai/utils.py +++ b/spliceai/utils.py @@ -20,7 +20,7 @@ class Annotator: - def __init__(self, ref_fasta, annotations): + def __init__(self, ref_fasta, annotations,cpu=True): if annotations == 'grch37': annotations = resource_filename(__name__, 'annotations/grch37.txt') @@ -52,7 +52,10 @@ def __init__(self, ref_fasta, annotations): exit() paths = ('models/spliceai{}.h5'.format(x) for x in 
range(1, 6)) # use CPU memory for loading models, to prevent gpu memory allocation. - with tf.device('CPU:0'): + if cpu: + with tf.device('CPU:0'): + self.models = [load_model(resource_filename(__name__, x)) for x in paths] + else: self.models = [load_model(resource_filename(__name__, x)) for x in paths] def get_name_and_strand(self, chrom, pos): From fdbcf1de1a8db9a678faea965069b5a6dd2f4e61 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 07:58:27 +0100 Subject: [PATCH 27/42] looking into startup issues --- spliceai/batch/batch_utils.py | 55 ++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index f43cdbc..d5bcd43 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -142,36 +142,39 @@ def initialize_devices(args): ## need to simulate gpus ? gpus = tf.config.list_physical_devices('GPU') mem_per_logical = 0 - if gpus and args.simulated_gpus > 1: - logger.warning(f"Simulating {args.simulated_gpus} logical GPUs on the first physical GPU device") - try: - gpu_mem_mb = _get_gpu_memory() - except Exception as e: - logger.error(f"Could not get GPU memory (needs nvidia-smi) : {e}") - sys.exit(1) - - # Create n virtual GPUs with [available] / n GB memory each - if hasattr(args,'mem_per_logical'): - mem_per_logical = args.mem_per_logical + if gpus: + if args.simulated_gpus > 1: + logger.warning(f"Simulating {args.simulated_gpus} logical GPUs on the first physical GPU device") + try: + gpu_mem_mb = _get_gpu_memory() + except Exception as e: + logger.error(f"Could not get GPU memory (needs nvidia-smi) : {e}") + sys.exit(1) + + # Create n virtual GPUs with [available] / n GB memory each + if hasattr(args,'mem_per_logical'): + mem_per_logical = args.mem_per_logical + else: + mem_per_logical = (int(gpu_mem_mb[0])-2048) / args.simulated_gpus + + logger.info(f"Assigning {mem_per_logical}mb of GPU memory per simulated GPU.") 
+ try: + device_list = [tf.config.LogicalDeviceConfiguration(memory_limit=mem_per_logical)] * args.simulated_gpus + tf.config.set_logical_device_configuration( + gpus[0], + device_list) + logical_gpus = tf.config.list_logical_devices('GPU') + + except RuntimeError as e: + # Virtual devices must be set before GPUs have been initialized + raise(e) + prediction_devices = tf.config.list_logical_devices('GPU') else: - mem_per_logical = (int(gpu_mem_mb[0])-2048) / args.simulated_gpus + prediction_devices = gpus - logger.info(f"Assigning {mem_per_logical}mb of GPU memory per simulated GPU.") - try: - device_list = [tf.config.LogicalDeviceConfiguration(memory_limit=mem_per_logical)] * args.simulated_gpus - tf.config.set_logical_device_configuration( - gpus[0], - device_list) - logical_gpus = tf.config.list_logical_devices('GPU') - - except RuntimeError as e: - # Virtual devices must be set before GPUs have been initialized - raise(e) - if gpus: - prediction_devices = tf.config.list_logical_devices('GPU') if not args.gpus.lower() == 'all': idxs = [int(x) for x in args.gpus.split(',')] - prediction_devices = [prediction_devices[x] for x in idx] + prediction_devices = [prediction_devices[x] for x in idxs] else: # run on cpu prediction_devices = tf.config.list_logical_devices('CPU')[0] From 31a2ad0d08912275e50142db48f409bbf919c593 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 08:24:42 +0100 Subject: [PATCH 28/42] There is an issue when going above 2 GPUs --- spliceai/__main__.py | 6 ++++-- spliceai/batch/batch.py | 2 +- spliceai/batch/batch_utils.py | 7 ++++--- spliceai/batch/data_handlers.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 048727b..4ba1501 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -9,7 +9,7 @@ import tempfile from multiprocessing import Process,Queue,Pool from functools import partial - +import shutil import tensorflow as tf import subprocess 
as sp import os @@ -165,7 +165,9 @@ def run_spliceai_batched(args, ann,devices,mem_per_logical): #input_data, output # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data logging.debug("Writing output file") - + + # clear out tmp + shutil.rmtree(tmpdir) ## stats overall_duration = time.time() - start_time preds_per_sec = writer.total_predictions / prediction_duration diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 8eeec6a..4c0fd34 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -74,7 +74,7 @@ def main(): if not devices: logger.error(f"Specified device '{args.device}' not found!") sys.exit(1) - device = devices[0].name + device = devices[0].name.replace('physical_','') with tf.device(device): logger.info(f"Working on device {device}") diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index d5bcd43..3516730 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -170,7 +170,8 @@ def initialize_devices(args): raise(e) prediction_devices = tf.config.list_logical_devices('GPU') else: - prediction_devices = gpus + logger.info("Running on physical devices") + prediction_devices = tf.config.list_physical_devices('GPU') if not args.gpus.lower() == 'all': idxs = [int(x) for x in args.gpus.split(',')] @@ -183,8 +184,8 @@ def initialize_devices(args): for d in prediction_devices: logger.info(f" - {d.name}") # add verbosity - if args.verbose: - tf.debugging.set_log_device_placement(True) + #if args.verbose: + # tf.debugging.set_log_device_placement(True) return prediction_devices, mem_per_logical diff --git a/spliceai/batch/data_handlers.py b/spliceai/batch/data_handlers.py index 1e90d0a..609dc68 100644 --- a/spliceai/batch/data_handlers.py +++ b/spliceai/batch/data_handlers.py @@ -286,7 +286,7 @@ def _aggregate_predictions(self): self.shelf_preds_name = f"spliceai_preds.shelf" self.shelf_preds = 
shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) for device in self.devices: - device_shelf_name = f"spliceai_preds.{device[1:].replace(':','_')}.shelf" + device_shelf_name = f"spliceai_preds.{device[1:].replace('physical_','').replace(':','_')}.shelf" device_shelf_preds = shelve.open(os.path.join(self.tmpdir, device_shelf_name)) for x in device_shelf_preds: self.shelf_preds[x] = device_shelf_preds[x] From 0e89d639faf59760f908ee76dedda6698abb18bb Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 09:19:02 +0100 Subject: [PATCH 29/42] hide physical devices in batch workers to evaluate memory issues --- spliceai/batch/batch.py | 13 +++++++++---- spliceai/batch/batch_utils.py | 30 +++++++++++++++++++++++++----- spliceai/batch/data_handlers.py | 2 +- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 4c0fd34..579ff31 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -18,7 +18,7 @@ #from spliceai.batch.batch_utils import extract_delta_scores, get_preds sys.path.append('../../../spliceai') -from spliceai.batch.batch_utils import get_preds, initialize_devices +from spliceai.batch.batch_utils import get_preds, initialize_devices, initialize_one_device from spliceai.utils import Annotator, get_delta_scores @@ -70,12 +70,17 @@ def main(): logger = logging.getLogger(__name__) # initialize && assign device - devices = [x for x in initialize_devices(args)[0] if x.name == args.device] + # no simulation : set a physical + if args.simulated_gpus > 0: + devices = [x for x in initialize_devices(args)[0] if x.name == args.device] + else: + devices = initialize_one_device(args) + + if not devices: logger.error(f"Specified device '{args.device}' not found!") sys.exit(1) - device = devices[0].name.replace('physical_','') - + device = devices[0].name with tf.device(device): logger.info(f"Working on device {device}") #logger.info("loading annotations") diff --git 
a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 3516730..a99fe4f 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -137,6 +137,29 @@ def get_preds(ann, x, batch_size=32): return predictions +## initialize a single device, hide others (only for non-simulated gpus) +def initialize_one_device(args): + gpus = tf.config.list_physical_devices('GPU') + if not gpus: + return tf.config.list_logical_devices('CPU') + # get the index of specified device. + idx = None + for i in range(len(gpus)): + if gpus[i].name.replace('physical_','') == args.device: + idx = i + break + if idx is None: + logger.error("Device not found") + logger.debug(idx) + logger.debug(args.device) + logger.debug([x.name.replace('physical_','') for x in gpus]) + raise Exception(f"specified device '{args.device}' not found.") + # set visible + tf.config.set_visible_devices(gpus[idx], 'GPU') + logical_devices = tf.config.list_logical_devices('GPU') + return logical_devices + + ## management routine to initialize gpu/cpu devices and do simulated logical devices if needed def initialize_devices(args): ## need to simulate gpus ? 
@@ -155,7 +178,7 @@ def initialize_devices(args): if hasattr(args,'mem_per_logical'): mem_per_logical = args.mem_per_logical else: - mem_per_logical = (int(gpu_mem_mb[0])-2048) / args.simulated_gpus + mem_per_logical = int((gpu_mem_mb[0]-2048) / args.simulated_gpus) logger.info(f"Assigning {mem_per_logical}mb of GPU memory per simulated GPU.") try: @@ -168,10 +191,7 @@ def initialize_devices(args): except RuntimeError as e: # Virtual devices must be set before GPUs have been initialized raise(e) - prediction_devices = tf.config.list_logical_devices('GPU') - else: - logger.info("Running on physical devices") - prediction_devices = tf.config.list_physical_devices('GPU') + prediction_devices = tf.config.list_logical_devices('GPU') if not args.gpus.lower() == 'all': idxs = [int(x) for x in args.gpus.split(',')] diff --git a/spliceai/batch/data_handlers.py b/spliceai/batch/data_handlers.py index 609dc68..1e90d0a 100644 --- a/spliceai/batch/data_handlers.py +++ b/spliceai/batch/data_handlers.py @@ -286,7 +286,7 @@ def _aggregate_predictions(self): self.shelf_preds_name = f"spliceai_preds.shelf" self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) for device in self.devices: - device_shelf_name = f"spliceai_preds.{device[1:].replace('physical_','').replace(':','_')}.shelf" + device_shelf_name = f"spliceai_preds.{device[1:].replace(':','_')}.shelf" device_shelf_preds = shelve.open(os.path.join(self.tmpdir, device_shelf_name)) for x in device_shelf_preds: self.shelf_preds[x] = device_shelf_preds[x] From 18c6dac287ca6f2b53ce56bbf523556e4b5203af Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 09:34:46 +0100 Subject: [PATCH 30/42] pass nonmasked device to worker for correct shelf names --- spliceai/batch/batch.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 579ff31..76a647a 100644 --- a/spliceai/batch/batch.py +++ 
b/spliceai/batch/batch.py @@ -72,6 +72,7 @@ def main(): # initialize && assign device # no simulation : set a physical if args.simulated_gpus > 0: + devices = [x for x in initialize_devices(args)[0] if x.name == args.device] else: devices = initialize_one_device(args) @@ -82,11 +83,11 @@ def main(): sys.exit(1) device = devices[0].name with tf.device(device): - logger.info(f"Working on device {device}") + logger.info(f"Working on device {args.device}") #logger.info("loading annotations") #ann = Annotator(args.reference, args.annotation) - # initialize the VCFPredictionBatch - worker = VCFPredictionBatch(args=args,device=device,logger=logger) # , tensorflow_batch_size=args.tensorflow_batch_size, tmpdir=args.tmpdir,device=device,logger=logger) + # initialize the VCFPredictionBatch, pass (non-masked) device name + worker = VCFPredictionBatch(args=args,logger=logger) # , tensorflow_batch_size=args.tensorflow_batch_size, tmpdir=args.tmpdir,device=device,logger=logger) # start working ! worker.process_batches() # done. 
@@ -101,24 +102,16 @@ def __init__(self, args, device, logger): # ann, tensorflow_batch_size, tmpdir,d self.ann = None self.tensorflow_batch_size = args.tensorflow_batch_size self.tmpdir = args.tmpdir - self.device = device + self.device = args.device self.logger = logger - #self.ann = ann - #self.device = device - - # This is the size of the batch tensorflow will use to make the predictions - # self.tensorflow_batch_size = tensorflow_batch_size - # Batch vars - self.batches = {} - #self.prepared_vcf_records = [] - # self.logger = logger + # self.batches = {} # Counts - self.total_predictions = 0 - self.total_vcf_records = 0 - self.batch_counters = {} + #self.total_predictions = 0 + #self.total_vcf_records = 0 + #self.batch_counters = {} From a1ce6c567f1a7316e3b78615a295f012cdc05d85 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 09:43:40 +0100 Subject: [PATCH 31/42] corrected arguments --- spliceai/batch/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 76a647a..e59c86c 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -97,7 +97,7 @@ def main(): # Class to handle predictions class VCFPredictionBatch: - def __init__(self, args, device, logger): # ann, tensorflow_batch_size, tmpdir,device,logger): + def __init__(self, args, logger): self.args = args self.ann = None self.tensorflow_batch_size = args.tensorflow_batch_size From d3cb4621e320b32cadd11143c118fa3394310648 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 29 Nov 2022 18:36:27 +0100 Subject: [PATCH 32/42] final code cleanup --- README.md | 9 ++++++++- spliceai/__main__.py | 14 +++++--------- spliceai/batch/batch.py | 19 ++----------------- spliceai/batch/batch_utils.py | 18 ++++++------------ spliceai/batch/data_handlers.py | 10 +--------- 5 files changed, 22 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index b4bd086..319860b 100644 --- a/README.md +++ 
b/README.md @@ -61,6 +61,8 @@ Optional parameters: - ```-T```: Internal Tensorflow `predict()` batch size if you want something different from the `-B` value. (default: the `-B` value) - ```-V```: Enable verbose logging during run - ```-t```: Specify a location to create the temporary files + - ```-G```: Specify the GPU(s) to run on : either indexed (eg : 0,2) or 'all'. (default: 'all') + - ```-S```: Simulate *n* multiple GPUs on a single physical device. Used for development only, currently all values above 2 crashed due to memory issues. (default: 0) **Batching Considerations:** @@ -101,6 +103,7 @@ are running the script on. Feel free to experiment, but some reasonable `-T` num (b) : Illumina implementation showed a memory leak with the installed versions of tf/keras/.... Values extrapolated from incomplete runs at the point of OOM. +*Note:* On a p3.8xlarge machine, hosting 4 V100 GPU's, we were able reach 1,379,505 predictions/hour ! This is a nearly linear scale-up. ### Details of SpliceAI INFO field: @@ -171,9 +174,13 @@ donor_prob = y[0, :, 2] * Adds batch utility methods that split up what was all previously done in `get_delta_scores`. `encode_batch_record` handles what was in the first half, taking in the VCF record and generating one-hot encoded matrices for the ref/alts. `extract_delta_scores` handles the second half of the `get_delta_scores` by reassembling the annotations based on the batched predictions * Adds test cases to run a small file using a generated FASTA reference to test if the results are the same with no batching and with different batching sizes * Slightly modifies the entrypoint of running the code to allow for easier unit testing. 
Being able to pass in what would normally come from the argparser + +**Multi-GPU support** - Geert Vandeweyer (_November 2022_) + * Offload more code to CPU (eg np to tensor conversion) to *only* perform predictions on the GPU * Implement queuing system to always have full batches ready for prediction -* Implement new parameter, `--tmpdir` to support a custom tmp folder +* Implement new parameter, `--tmpdir` to support a custom tmp folder to store prepped batches +* Implement socket-based client/server approach to scale over multiple GPUs ### Contact diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 4ba1501..7fa9188 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -100,12 +100,11 @@ def main(): # load annotation ann = Annotator(args.reference, args.annotation) # run scoring - run_spliceai(args, ann) # input_data=args.input, output_data=args.output, ann=ann, distance=args.distance, mask=args.mask) + run_spliceai(args, ann) ## revised logic to allow batched tensorflow analysis on multiple GPUs -def run_spliceai_batched(args, ann,devices,mem_per_logical): #input_data, output_data, reference, ann, distance, mask, prediction_batch_size, - #tensorflow_batch_size,tempdir,devices,args): +def run_spliceai_batched(args, ann,devices,mem_per_logical): ## GOAL ## - launch a reader that preps & pickles input vcf @@ -127,7 +126,7 @@ def run_spliceai_batched(args, ann,devices,mem_per_logical): #input_data, output ## mk a temp directory tmpdir = tempfile.mkdtemp(dir=args.tmpdir) # TemporaryDirectory(dir=args.tmpdir) #tmpdir = tmpdir.name - logging.debug("tmp dir : {}".format(tmpdir)) + logging.info("Using tmpdir : {}".format(tmpdir)) # creates a queue with max 10 ready-to-go batches in it. 
prediction_queue = Queue(maxsize=10) @@ -153,19 +152,16 @@ def run_spliceai_batched(args, ann,devices,mem_per_logical): #input_data, output for p in worker_servers: # mp processes : join() p.join() - logging.debug("SErvers are done") + logging.debug("Servers are done") # stats without writing phase prediction_duration = time.time() - start_time # write results. in/out from args, devices to get shelf names - logging.debug("Writing output file") + logging.info("Writing output file") writer = VCFWriter(args=args,tmpdir=tmpdir,devices=devices,ann=ann) writer.process() - # Iterate over original list of vcf records again, reconstructing record with annotations from shelved data - logging.debug("Writing output file") - # clear out tmp shutil.rmtree(tmpdir) ## stats diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index e59c86c..39125c1 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -70,11 +70,10 @@ def main(): logger = logging.getLogger(__name__) # initialize && assign device - # no simulation : set a physical if args.simulated_gpus > 0: - devices = [x for x in initialize_devices(args)[0] if x.name == args.device] else: + # no simulation : expose only the requested device to tensor. devices = initialize_one_device(args) @@ -84,10 +83,8 @@ def main(): device = devices[0].name with tf.device(device): logger.info(f"Working on device {args.device}") - #logger.info("loading annotations") - #ann = Annotator(args.reference, args.annotation) # initialize the VCFPredictionBatch, pass (non-masked) device name - worker = VCFPredictionBatch(args=args,logger=logger) # , tensorflow_batch_size=args.tensorflow_batch_size, tmpdir=args.tmpdir,device=device,logger=logger) + worker = VCFPredictionBatch(args=args,logger=logger) # start working ! worker.process_batches() # done. 
@@ -105,18 +102,6 @@ def __init__(self, args, logger): self.device = args.device self.logger = logger - # Batch vars - # self.batches = {} - - # Counts - #self.total_predictions = 0 - #self.total_vcf_records = 0 - #self.batch_counters = {} - - - - # shelves to track data. - #self.tmpdir = tmpdir # store batches of predictions using 'tensor_size|batch_idx' as key. self.shelf_preds_name = f"spliceai_preds.{self.device[1:].replace(':','_')}.shelf" self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index a99fe4f..e716078 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -7,8 +7,6 @@ import logging import shelve -#import pysam -#import collections import os import gc import numpy as np @@ -34,7 +32,7 @@ ## INPUT ## ########### ## routine to create the batches for prediction. -def prepare_batches(ann, args, tmpdir, prediction_queue,nr_workers): # input_data,prediction_batch_size, prediction_queue,tmpdir,dist): +def prepare_batches(ann, args, tmpdir, prediction_queue,nr_workers): # create the parser object vcf_reader = VCFReader(ann=ann, input_data=args.input_data, @@ -76,15 +74,15 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): for device in devices: # launch the worker. 
- logger.info(f"Starting worker on device {device.name}") + logger.info(f"Starting worker on device {device.name}, output is available under {tmpdir}") cmd = ["python",os.path.join(os.path.dirname(os.path.realpath(__file__)),"batch.py"),"-S",str(args.simulated_gpus),"-M",str(int(mem_per_logical)), "-t",tmpdir,"-d",device.name, '-R', args.reference, '-A', args.annotation, '-T', str(args.tensorflow_batch_size)] if args.verbose: cmd.append('-V') - #print(cmd) + logger.debug(cmd) fh_stdout = open(tmpdir+'/'+device.name.replace('/','_').replace(':','.')+'.stdout','w') fh_stderr = open(tmpdir+'/'+device.name.replace('/','_').replace(':','.')+'.stderr','w') - p = subprocess.Popen(cmd) # ,stdout=fh_stdout, stderr=fh_stderr) + p = subprocess.Popen(cmd ,stdout=fh_stdout, stderr=fh_stderr) clientThreads.append(p) ## then a new thread in the server for this connection. client, address = s.accept() @@ -92,8 +90,7 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): p = Process(target=_process_server,args=(client,device.name,prediction_queue,)) p.start() serverThreads.append(p) - logger.info(f"Thread {device.name} activated!") - time.sleep(3) + logger.debug(f"Thread {device.name} activated!") return clientThreads, serverThreads, devices @@ -107,7 +104,7 @@ def _process_server(clientsocket,device,queue): logger.debug(f"Stopping thread {device}") break elif not msg == 'Ready for work...': - logger.debug(msg) + logger.info(msg) # send/get new item try: item = queue.get(False) @@ -203,9 +200,6 @@ def initialize_devices(args): logger.info("Using the following devices for prediction:") for d in prediction_devices: logger.info(f" - {d.name}") - # add verbosity - #if args.verbose: - # tf.debugging.set_log_device_placement(True) return prediction_devices, mem_per_logical diff --git a/spliceai/batch/data_handlers.py b/spliceai/batch/data_handlers.py index 1e90d0a..a96bce3 100644 --- a/spliceai/batch/data_handlers.py +++ b/spliceai/batch/data_handlers.py @@ -8,8 
+8,6 @@ import tensorflow as tf import sys -#from spliceai.utils import get_alt_gene_delta_score, is_record_valid, get_seq, \ -# is_location_predictable, get_cov, get_wid, is_valid_alt_record, encode_seqs, create_unhandled_delta_score from spliceai.utils import get_cov, get_wid, get_seq, is_record_valid, is_location_predictable, \ is_valid_alt_record, encode_seqs, create_unhandled_delta_score, get_alt_gene_delta_score @@ -43,7 +41,6 @@ def __init__(self, ann, input_data, prediction_batch_size, prediction_queue, tmp self.dist = dist # Batch vars self.batches = {} - #self.prepared_vcf_records = [] # Counts self.total_predictions = 0 @@ -56,7 +53,7 @@ def __init__(self, ann, input_data, prediction_batch_size, prediction_queue, tmp # shelves to track data. self.tmpdir = tmpdir # track records to have order correct - logging.info("Opening spliceai_records shelf") + logging.debug("Opening spliceai_records shelf") try: self.shelf_records = shelve.open(os.path.join(self.tmpdir,"spliceai_records.shelf")) except Exception as e: @@ -161,7 +158,6 @@ def add_record(self, record): self.batches[tensor_size] = [] self.batch_counters[tensor_size] += 1 - #self._process_batch(tensor_size) @@ -169,8 +165,6 @@ def finish(self,nr_workers): """ Method to process all the remaining items that have been added to the batches. 
""" - #if len(self.prepared_vcf_records) > 0: - # self._process_batch() logger.debug("Queueing remaining batches") for tensor_size in self.batch_counters: if len(self.batches[tensor_size] ) > 0: @@ -297,8 +291,6 @@ def _aggregate_predictions(self): # wrapper to write out all shelved variants def _write_records(self): logger.debug("Writing output file") - # open the shelf with records: - #shelf_records = shelve.open(os.path.join(self.tmpdir.name,"spliceai_records.shelf")) # parse vcf line_idx = 0 batch = [] From bf00fb1a79ba101b72d36dfdb5d71d460b6ba4ba Mon Sep 17 00:00:00 2001 From: Barney Hill Date: Thu, 16 Mar 2023 11:04:12 +0000 Subject: [PATCH 33/42] Added custom port option --- spliceai/__main__.py | 4 +++- spliceai/batch/batch.py | 3 ++- spliceai/batch/batch_utils.py | 6 ++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 7fa9188..d6e6ea0 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -29,6 +29,8 @@ def get_options(): parser = argparse.ArgumentParser(description='Version: 1.3.1') + parser.add_argument('-P', '--port', metavar='port', type=int, + help='option to change port if several GPUs on one network') parser.add_argument('-I', '--input_data', metavar='input', nargs='?', default=std_in, help='path to the input VCF file, defaults to standard in') parser.add_argument('-O', '--output_data', metavar='output', nargs='?', default=std_out, @@ -81,7 +83,7 @@ def main(): logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]] [-t [tmp_location]]') exit() - + logging.debug(f"PORT:{args.port}") ## revised code for batched analysis if args.prediction_batch_size > 1: diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 39125c1..b099407 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -31,6 +31,7 @@ def get_options(): parser = 
argparse.ArgumentParser(description='Version: 1.3.1') + parser.add_argument('-P', '--port', metavar='port', required=True, type=int) parser.add_argument('-R', '--reference', metavar='reference', required=True, help='path to the reference genome fasta file') parser.add_argument('-A', '--annotation',metavar='annotation', required=True, @@ -110,7 +111,7 @@ def __init__(self, args, logger): def process_batches(self): with socket.socket() as s: host = socket.gethostname() # locahost - port = 54677 + port = self.args.port try: s.connect((host,port)) except Exception as e: diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index e716078..2a80824 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -61,7 +61,9 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): # start server socket s = socket.socket() host = socket.gethostname() # locahost - port = 54677 + port = args.port + logger.info(f"Starting server: {host}:{port}") + try: s.bind((host,port)) except Exception as e: @@ -75,7 +77,7 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): for device in devices: # launch the worker. 
logger.info(f"Starting worker on device {device.name}, output is available under {tmpdir}") - cmd = ["python",os.path.join(os.path.dirname(os.path.realpath(__file__)),"batch.py"),"-S",str(args.simulated_gpus),"-M",str(int(mem_per_logical)), "-t",tmpdir,"-d",device.name, '-R', args.reference, '-A', args.annotation, '-T', str(args.tensorflow_batch_size)] + cmd = ["python",os.path.join(os.path.dirname(os.path.realpath(__file__)),"batch.py"),"-S",str(args.simulated_gpus),"-M",str(int(mem_per_logical)), "-t",tmpdir,"-d",device.name, '-R', args.reference, '-A', args.annotation, '-T', str(args.tensorflow_batch_size), '-P', str(args.port)] if args.verbose: cmd.append('-V') logger.debug(cmd) From c985bc13e204741eae6fb2fb31c4064c91e0523e Mon Sep 17 00:00:00 2001 From: Matthias Blum Date: Wed, 22 Mar 2023 17:44:18 +0000 Subject: [PATCH 34/42] Set default port --- spliceai/__main__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index d6e6ea0..43c2f65 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -29,8 +29,8 @@ def get_options(): parser = argparse.ArgumentParser(description='Version: 1.3.1') - parser.add_argument('-P', '--port', metavar='port', type=int, - help='option to change port if several GPUs on one network') + parser.add_argument('-P', '--port', metavar='port', type=int, default=54677, + help='option to change port if several GPUs on one network (default: 54677)') parser.add_argument('-I', '--input_data', metavar='input', nargs='?', default=std_in, help='path to the input VCF file, defaults to standard in') parser.add_argument('-O', '--output_data', metavar='output', nargs='?', default=std_out, From 36d22d6c3b161e951332cbbe48930e92a391e00a Mon Sep 17 00:00:00 2001 From: Matthias Blum Date: Wed, 22 Mar 2023 17:44:29 +0000 Subject: [PATCH 35/42] Add -P option to README --- README.md | 63 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 31 
deletions(-) diff --git a/README.md b/README.md index 319860b..ac3c19a 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,8 @@ Optional parameters: - ```-V```: Enable verbose logging during run - ```-t```: Specify a location to create the temporary files - ```-G```: Specify the GPU(s) to run on : either indexed (eg : 0,2) or 'all'. (default: 'all') - - ```-S```: Simulate *n* multiple GPUs on a single physical device. Used for development only, currently all values above 2 crashed due to memory issues. (default: 0) + - ```-S```: Simulate *n* multiple GPUs on a single physical device. Used for development only, currently all values above 2 crashed due to memory issues. (default: 0) + - ```-P```: Port to use when connecting to the socket (default: 54677, only used in batch mode). **Batching Considerations:** @@ -79,24 +80,24 @@ are running the script on. Feel free to experiment, but some reasonable `-T` num *Benchmark results* -| Type | Implementation | Total Time | Speed (predictions / hour) | -| -------- | -------------- | ----------- | -------------------------- | -| CPU (intel i5-8365U)a | illumina | ~100h | ~1000 pred/h | -| | invitae | ~39h | ~4500 pred/h | -| | invitae v2 | ~35h | ~5000 pred/h | -| | invitae v2 optimal | ~35h | ~5000 pred/h | -| K80 GPU (AWS p2.large) | illuminab | ~25 h | ~7000 pred/h | -| | invitae | 242m | ~43,000 pred / h | -| | invitae v2 | 213m | ~50,000 pred / h | -| | invitae v2 optimal | 188 m | ~56,000 pred / h | -| GeForce RTX 2070 SUPER GPU (desktop) | illuminab | ~10 h | ~ 17,000 pred/h | -| | invitae | 76m | ~137,000 pred / h | -| | invitae v2 | 63m | ~166,000 pred / h | -| | invitae v2 optimal | 52m | ~200,000 pred / h | -| V100 GPU (AWS p3.xlarge) | illuminab | ~10h | ~18,000 pred/h | -| | invitae | 78m | ~135,000 pred / h | -| | invitae v2 | 54m | ~190,000 pred / h | -| | invitae v2 optimal | 31 m | ~335,000 pred / h | +| Type | Implementation | Total Time | Speed (predictions / hour) | 
+|--------------------------------------|-----------------------|------------|----------------------------| +| CPU (intel i5-8365U)a | illumina | ~100h | ~1000 pred/h | +| | invitae | ~39h | ~4500 pred/h | +| | invitae v2 | ~35h | ~5000 pred/h | +| | invitae v2 optimal | ~35h | ~5000 pred/h | +| K80 GPU (AWS p2.large) | illuminab | ~25 h | ~7000 pred/h | +| | invitae | 242m | ~43,000 pred / h | +| | invitae v2 | 213m | ~50,000 pred / h | +| | invitae v2 optimal | 188 m | ~56,000 pred / h | +| GeForce RTX 2070 SUPER GPU (desktop) | illuminab | ~10 h | ~ 17,000 pred/h | +| | invitae | 76m | ~137,000 pred / h | +| | invitae v2 | 63m | ~166,000 pred / h | +| | invitae v2 optimal | 52m | ~200,000 pred / h | +| V100 GPU (AWS p3.xlarge) | illuminab | ~10h | ~18,000 pred/h | +| | invitae | 78m | ~135,000 pred / h | +| | invitae v2 | 54m | ~190,000 pred / h | +| | invitae v2 optimal | 31 m | ~335,000 pred / h | (a) : Extrapolated from first 500 variants @@ -107,18 +108,18 @@ are running the script on. 
Feel free to experiment, but some reasonable `-T` num ### Details of SpliceAI INFO field: -| ID | Description | -| -------- | ----------- | -| ALLELE | Alternate allele | -| SYMBOL | Gene symbol | -| DS_AG | Delta score (acceptor gain) | -| DS_AL | Delta score (acceptor loss) | -| DS_DG | Delta score (donor gain) | -| DS_DL | Delta score (donor loss) | -| DP_AG | Delta position (acceptor gain) | -| DP_AL | Delta position (acceptor loss) | -| DP_DG | Delta position (donor gain) | -| DP_DL | Delta position (donor loss) | +| ID | Description | +|--------|--------------------------------| +| ALLELE | Alternate allele | +| SYMBOL | Gene symbol | +| DS_AG | Delta score (acceptor gain) | +| DS_AL | Delta score (acceptor loss) | +| DS_DG | Delta score (donor gain) | +| DS_DL | Delta score (donor loss) | +| DP_AG | Delta position (acceptor gain) | +| DP_AL | Delta position (acceptor loss) | +| DP_DG | Delta position (donor gain) | +| DP_DL | Delta position (donor loss) | Delta score of a variant, defined as the maximum of (DS_AG, DS_AL, DS_DG, DS_DL), ranges from 0 to 1 and can be interpreted as the probability of the variant being splice-altering. In the paper, a detailed characterization is provided for 0.2 (high recall), 0.5 (recommended), and 0.8 (high precision) cutoffs. Delta position conveys information about the location where splicing changes relative to the variant position (positive values are downstream of the variant, negative values are upstream). 
From e09bca67875034262499410cf85bc9e50c5451ab Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Mon, 9 Sep 2024 09:25:10 +0200 Subject: [PATCH 36/42] working on error handling to shutdown on issues --- README.md | 9 +++++++++ spliceai/__main__.py | 22 +++++++++++++++++++--- spliceai/batch/batch.py | 8 +++++++- spliceai/batch/batch_utils.py | 29 +++++++++++++++++++---------- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index ac3c19a..e6b76d3 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,15 @@ docker pull cmgantwerpen/spliceai_v1.3:latest docker run --gpus all cmgantwerpen/spliceai_v1.3:latest spliceai -h ``` +A container including reference and annotation data is available as well: + + +```sh +docker pull cmgantwerpen/spliceai_v1.3:full +``` +Note that this version has a larger footprint (12Gb). Data is available for Genome Build hg19 and hg38 under /data/ + + The simplest way to install (the original version of) SpliceAI is through pip or conda: ```sh diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 43c2f65..df1ef00 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -16,7 +16,7 @@ from spliceai.batch.batch_utils import prepare_batches, start_workers,initialize_devices from spliceai.utils import Annotator, get_delta_scores -from spliceai.batch.data_handlers import VCFWriter +from spliceai.batch.data_handlers import VCFWriter,VCFReader try: from sys.stdin import buffer as std_in @@ -140,12 +140,28 @@ def run_spliceai_batched(args, ann,devices,mem_per_logical): worker_clients, worker_servers, devices = start_workers(prediction_queue,tmpdir,args,devices,mem_per_logical) ## wait for everything to finish. + # => If exit codes != 0 are detected, the main process will exit with the first non-zero exit code. + while True: + # any exit codes defined and != 0 ? 
+ exit_codes = [p.exitcode for p in worker_servers + [reader] if p.exitcode is not None] + [p.poll() for p in worker_clients if p.poll() is not None] + if any(rc != 0 for rc in exit_codes): + logging.error("Error encountered Exiting.") + # kill all processes + for p in worker_servers + [reader]: + if p.is_alive(): + p.kill() + for p in worker_clients: + if p.poll() is None: + p.kill() + # and exit + sys.exit(1) + # readers sends finish signal to workers - logging.debug("Waiting for VCF reader to join") + logging.debug("Cleanup VCF reader") reader.join() logging.debug("Reader joined!") # clients receive signal, send it to servers. - logging.debug("Waiting for workers to join.") + logging.debug("Cleaning up workers.") for p in worker_clients: # subprocesses : wait() p.wait() diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index b099407..5841306 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -143,7 +143,13 @@ def process_batches(self): # remove pickled batch os.unlink(os.path.join(self.tmpdir,res)) # process : stats are send back as next 'ready for work' result. - msg = self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) + try: + msg = self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) + except Exception as e: + self.logger.error(f"Error processing batch {data['tensor_size']}|{data['batch_ix']}: {repr(e)}") + # send error message back to server + msg = "Error : {}".format(repr(e)) + # send signal to server thread to exit. s.send(str.encode('Done')) self.logger.info(f"Closing Worker on device {self.device}") diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 2a80824..c6893ee 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -33,21 +33,25 @@ ########### ## routine to create the batches for prediction. 
def prepare_batches(ann, args, tmpdir, prediction_queue,nr_workers): - # create the parser object - vcf_reader = VCFReader(ann=ann, + try: + # create the parser object + vcf_reader = VCFReader(ann=ann, input_data=args.input_data, prediction_batch_size=args.prediction_batch_size, prediction_queue=prediction_queue, tmpdir=tmpdir,dist=args.distance, ) - # parse records - vcf_reader.add_records() - # finalize last batches - vcf_reader.finish(nr_workers) - # close the shelf. - vcf_reader.shelf_records.close() - # stats - logger.info("Read {} vcf records, queued {} predictions".format(vcf_reader.total_vcf_records, vcf_reader.total_predictions)) + # parse records + vcf_reader.add_records() + # finalize last batches + vcf_reader.finish(nr_workers) + # close the shelf. + vcf_reader.shelf_records.close() + # stats + logger.info("Read {} vcf records, queued {} predictions".format(vcf_reader.total_vcf_records, vcf_reader.total_predictions)) + except Exception as e: + logger.error(f"Error in prepare_batches: {repr(e)}") + raise(e) @@ -105,6 +109,11 @@ def _process_server(clientsocket,device,queue): if msg == 'Done': logger.debug(f"Stopping thread {device}") break + elif msg.startswith('Error'): + # send finish signal to worker to shut down cleanly + clientsocket.sendall(str.encode('Finished')) + # then raise the error to the main thread. 
+ raise Exception(msg) elif not msg == 'Ready for work...': logger.info(msg) # send/get new item From a354e3819f8471601990b8d512c5c9b1a1e7bc93 Mon Sep 17 00:00:00 2001 From: geertvandeweyer Date: Fri, 13 Sep 2024 08:22:00 +0200 Subject: [PATCH 37/42] fix joining of workers --- spliceai/__main__.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index df1ef00..79b4896 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -136,14 +136,15 @@ def run_spliceai_batched(args, ann,devices,mem_per_logical): reader_args={'ann':ann, 'args':args, 'tmpdir':tmpdir, 'prediction_queue': prediction_queue, 'nr_workers': len(devices)} reader = Process(target=prepare_batches, kwargs=reader_args) reader.start() - + logging.debug("Reader started") worker_clients, worker_servers, devices = start_workers(prediction_queue,tmpdir,args,devices,mem_per_logical) - + logging.debug("workers started") ## wait for everything to finish. # => If exit codes != 0 are detected, the main process will exit with the first non-zero exit code. while True: # any exit codes defined and != 0 ? exit_codes = [p.exitcode for p in worker_servers + [reader] if p.exitcode is not None] + [p.poll() for p in worker_clients if p.poll() is not None] + logging.debug("exit codes: {}".format(exit_codes)) if any(rc != 0 for rc in exit_codes): logging.error("Error encountered Exiting.") # kill all processes @@ -155,18 +156,21 @@ def run_spliceai_batched(args, ann,devices,mem_per_logical): p.kill() # and exit sys.exit(1) + if len(exit_codes) == len(worker_servers + [reader] + worker_clients): + break + time.sleep(30) # readers sends finish signal to workers - logging.debug("Cleanup VCF reader") + logging.info("Cleanup VCF reader") reader.join() logging.debug("Reader joined!") # clients receive signal, send it to servers. 
- logging.debug("Cleaning up workers.") + logging.info("Cleaning up workers.") for p in worker_clients: # subprocesses : wait() p.wait() logging.debug("Workers are done!") - logging.debug("Waiting for servers to join.") + logging.info("Waiting for servers to join.") for p in worker_servers: # mp processes : join() p.join() From 6d729452abcb7f69d4dfdb1545809fbce4c9c4a3 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Mon, 24 Mar 2025 07:11:07 +0100 Subject: [PATCH 38/42] improve error handling --- docker/Dockerfile.cuda-11.4.0 | 67 ++++++++++++++ spliceai/__main__.py | 37 ++++++-- spliceai/batch/batch.py | 166 ++++++++++++++++++++-------------- spliceai/batch/batch_utils.py | 2 +- 4 files changed, 198 insertions(+), 74 deletions(-) create mode 100644 docker/Dockerfile.cuda-11.4.0 diff --git a/docker/Dockerfile.cuda-11.4.0 b/docker/Dockerfile.cuda-11.4.0 new file mode 100644 index 0000000..4aefaff --- /dev/null +++ b/docker/Dockerfile.cuda-11.4.0 @@ -0,0 +1,67 @@ +###################################### +## CONTAINER FOR GPU based SpliceAI ## +###################################### + +# start from the cuda docker base +FROM nvidia/cuda:11.4.0-base-ubuntu20.04 + +LABEL version="1.3" +LABEL description="This container was tested with \ + - V100 on AWS p3.2xlarge with nvidia drivers 510.47.03 and cuda v11.6 \ + - K80 on AWS p2.xlarge with nvidia drivers 470.141.03 and cuda v11.4 \ + - Geforce RTX 2070 SUPER (local) with nvidia drivers 470.141.03 and cuda v11.4" + +LABEL author="Geert Vandeweyer" +LABEL author.email="geert.vandeweyer@uza.be" + +## needed apt packages +ARG BUILD_PACKAGES="wget git bzip2" +# needed conda packages + +ARG CONDA_PACKAGES="python=3.9.13 tensorflow-gpu=2.10.0 cuda-nvcc=11.8.89" + +## ENV SETTINGS during runtime +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 +ENV PATH=/opt/conda/bin:$PATH +ENV DEBIAN_FRONTEND noninteractive + +# For micromamba: +SHELL ["/bin/bash", "-l", "-c"] +ENV MAMBA_ROOT_PREFIX=/opt/conda/ +ENV 
PATH=/opt/micromamba/bin:/opt/conda/bin:$PATH +ARG CONDA_CHANNEL="-c bioconda -c conda-forge -c nvidia" + +## INSTALL +RUN apt-get -y update && \ + apt-get -y install $BUILD_PACKAGES && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + + +# conda packages +RUN mkdir /opt/conda && \ + mkdir /opt/micromamba && \ + wget -qO - https://micromamba.snakepit.net/api/micromamba/linux-64/0.23.0 | tar -xvj -C /opt/micromamba bin/micromamba && \ + # initialize bash + micromamba shell init --shell=bash --prefix=/opt/conda && \ + # remove a statement from bashrc that prevents initialization + grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/micromamba/bashrc && \ + mv /opt/micromamba/bashrc /root/.bashrc && \ + source ~/.bashrc && \ + # activate & install base conda packag + micromamba activate && \ + micromamba install -y $CONDA_CHANNEL $CONDA_PACKAGES && \ + micromamba clean --all --yes + +# Break cache for recloning git +ARG DATE_CACHE_BREAK=$(date) + +# my fork of spliceai : has gpu optimizations +RUN cd /opt/ && \ + git clone https://github.com/geertvandeweyer/SpliceAI.git && \ + cd SpliceAI && \ + python setup.py install + +# no command given, print help. 
+CMD spliceai -h + diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 79b4896..7f1387a 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -13,6 +13,7 @@ import tensorflow as tf import subprocess as sp import os +import socket from spliceai.batch.batch_utils import prepare_batches, start_workers,initialize_devices from spliceai.utils import Annotator, get_delta_scores @@ -29,8 +30,8 @@ def get_options(): parser = argparse.ArgumentParser(description='Version: 1.3.1') - parser.add_argument('-P', '--port', metavar='port', type=int, default=54677, - help='option to change port if several GPUs on one network (default: 54677)') + parser.add_argument('-P', '--port', metavar='port', type=int, default=None, + help='specify a port for socket/socketserver communication') parser.add_argument('-I', '--input_data', metavar='input', nargs='?', default=std_in, help='path to the input VCF file, defaults to standard in') parser.add_argument('-O', '--output_data', metavar='output', nargs='?', default=std_out, @@ -83,20 +84,42 @@ def main(): logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]] [-t [tmp_location]]') exit() - logging.debug(f"PORT:{args.port}") + # select a free socket + if args.port is None: + try: + sock = socket.socket() + sock.bind(('', 0)) + args.port = sock.getsockname()[1] + logging.debug(f"PORT:{args.port}") + except Exception as e: + logging.error(f"Error: {repr(e)}") + sys.exit(1) ## revised code for batched analysis if args.prediction_batch_size > 1: # initialize the GPU and setup to estimate - devices,mem_per_logical = initialize_devices(args) + try: + devices,mem_per_logical = initialize_devices(args) + except Exception as e: + logging.error(f"Error initializing devices: {repr(e)}") + sys.exit(1) + # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args 
args.tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size # load annotation data: - ann = Annotator(args.reference, args.annotation) - logging.debug("Annotation loaded.") + try: + ann = Annotator(args.reference, args.annotation) + logging.debug("Annotation loaded.") + except Exception as e: + logging.error(f"Error loading annotation: {repr(e)}") + sys.exit(1) # run - run_spliceai_batched(args,ann,devices,mem_per_logical) + try: + run_spliceai_batched(args,ann,devices,mem_per_logical) + except Exception as e: + logging.error(f"Error running SpliceAI: {repr(e)}") + sys.exit(1) else: # run original code: # load annotation diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 5841306..398d74e 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -21,7 +21,9 @@ from spliceai.batch.batch_utils import get_preds, initialize_devices, initialize_one_device from spliceai.utils import Annotator, get_delta_scores - +class TensorFlowError(Exception): + """Custom exception for TensorFlow errors.""" + pass SequenceType_REF = 0 SequenceType_ALT = 1 @@ -55,6 +57,16 @@ def get_options(): return args +# set a trap so that on any unexpected error, the worker will send an error message back to the server. 
+def handle_exception(socket, exc_type, exc_value, exc_traceback): + logger = logging.getLogger(__name__) + logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)) + socket.send(str.encode(f"Error : {exc_value}")) + sys.__excepthook__(exc_type, exc_value, exc_traceback) + sys.exit(1) + +def setup_exception_hook(socket): + sys.excepthook = lambda *exc_info: handle_exception(socket, *exc_info) def main(): # get arguments @@ -70,38 +82,62 @@ def main(): ) logger = logging.getLogger(__name__) - # initialize && assign device - if args.simulated_gpus > 0: - devices = [x for x in initialize_devices(args)[0] if x.name == args.device] - else: - # no simulation : expose only the requested device to tensor. - devices = initialize_one_device(args) + # setup the socket + try: + s = socket.socket() + host = socket.gethostname() # locahost + port = args.port + s.connect((host, port)) + except Exception as e: + raise(e) + + # setup the exception hook + setup_exception_hook(s) + + try: + # initialize && assign device + if args.simulated_gpus > 0: + devices = [x for x in initialize_devices(args)[0] if x.name == args.device] + else: + # no simulation : expose only the requested device to tensor. + devices = initialize_one_device(args) + except Exception as e: + raise e if not devices: - logger.error(f"Specified device '{args.device}' not found!") - sys.exit(1) - device = devices[0].name - with tf.device(device): - logger.info(f"Working on device {args.device}") - # initialize the VCFPredictionBatch, pass (non-masked) device name - worker = VCFPredictionBatch(args=args,logger=logger) - # start working ! 
- worker.process_batches() + # raise with message + raise ValueError(f"Specified device '{args.device}' not found!") + + try: + device = devices[0].name + with tf.device(device): + logger.info(f"Working on device {args.device}") + # initialize the VCFPredictionBatch, pass (non-masked) device name + worker = VCFPredictionBatch(args=args,logger=logger,socket=s) + # start working ! + worker.process_batches() + except tf.errors.ResourceExhaustedError as e: + raise TensorFlowError("Caught TensorFlow OOM Error!") + except Exception as e: + raise e # done. - + s.close() + logger.info("Worker done. Shutting down") + sys.exit(0) # Class to handle predictions class VCFPredictionBatch: - def __init__(self, args, logger): + def __init__(self, args, logger, socket): self.args = args self.ann = None self.tensorflow_batch_size = args.tensorflow_batch_size self.tmpdir = args.tmpdir self.device = args.device self.logger = logger + self.socket = socket # store batches of predictions using 'tensor_size|batch_idx' as key. self.shelf_preds_name = f"spliceai_preds.{self.device[1:].replace(':','_')}.shelf" @@ -109,57 +145,55 @@ def __init__(self, args, logger): # monitor the queue and submit incoming batches. def process_batches(self): - with socket.socket() as s: - host = socket.gethostname() # locahost - port = self.args.port - try: - s.connect((host,port)) - except Exception as e: - raise(e) - # first response : server is running - res = s.recv(2048) - # then start polling queue - msg = "Ready for work..." - - # first load annotation - if not self.ann: - # load annotation - self.ann = Annotator(self.args.reference, self.args.annotation,cpu=True) - while True: - # send request for work - s.send(str.encode(msg)) - res = s.recv(2048).decode('utf-8') - # response can be a job, 'hold on' for empty queue, or 'Done' for all finished. - if res == 'Hold On': - msg = 'Ready for work...' - time.sleep(0.1) - elif res == 'Finished': - self.logger.info("Worker done. 
Shutting down") - break - else: - # got a batch id: - with open(os.path.join(self.tmpdir,res),'rb') as p: - data = pickle.load(p) - # remove pickled batch - os.unlink(os.path.join(self.tmpdir,res)) - # process : stats are send back as next 'ready for work' result. - try: - msg = self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) - except Exception as e: - self.logger.error(f"Error processing batch {data['tensor_size']}|{data['batch_ix']}: {repr(e)}") - # send error message back to server - msg = "Error : {}".format(repr(e)) - - # send signal to server thread to exit. - s.send(str.encode('Done')) - self.logger.info(f"Closing Worker on device {self.device}") - - - def _process_batch(self,tensor_size,batch_ix, prediction_batch,nr_preds): + + #host = socket.gethostname() # locahost + #port = self.args.port + #try: + # s.connect((host,port)) + #except Exception as e: + # raise(e) + # first response : server is running + res = self.socket.recv(2048) + # then start polling queue + msg = "Ready for work..." + # first load annotation + if not self.ann: + # load annotation + self.ann = Annotator(self.args.reference, self.args.annotation,cpu=True) + while True: + # send request for work + self.socket.send(str.encode(msg)) + res = self.socket.recv(2048).decode('utf-8') + # response can be a job, 'hold on' for empty queue, or 'Done' for all finished. + if res == 'Hold On': + msg = 'Ready for work...' + time.sleep(0.1) + elif res == 'Finished': + self.logger.info("Worker done. Shutting down") + break + else: + # got a batch id: + with open(os.path.join(self.tmpdir,res),'rb') as p: + data = pickle.load(p) + # remove pickled batch + os.unlink(os.path.join(self.tmpdir,res)) + # process : stats are send back as next 'ready for work' result. 
+ try: + msg = self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) + except Exception as e: + raise Exception(f"Error processing batch {data['tensor_size']}|{data['batch_ix']}: {repr(e)}") + # send error message back to server + #msg = "Error : {}".format(repr(e)) + # send signal to server thread to exit. + self.socket.send(str.encode('Done')) + self.logger.info(f"Closing Worker on device {self.device}") + + + def _process_batch(self, tensor_size, batch_ix, prediction_batch, nr_preds): start = time.time() # Sanity check dump of batch sizes - self.logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , nr_preds)) + self.logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix, nr_preds)) # Run predictions && add to shelf. self.shelf_preds["{}|{}".format(tensor_size,batch_ix)] = np.mean( diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index c6893ee..42bcfed 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -210,7 +210,7 @@ def initialize_devices(args): logger.info("Using the following devices for prediction:") for d in prediction_devices: - logger.info(f" - {d.name}") + logger.info(f" - {d}") return prediction_devices, mem_per_logical From 07ad8da084e27a7047d7d565568539ee7b62da19 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Mon, 24 Mar 2025 08:36:06 +0100 Subject: [PATCH 39/42] correct port selection --- spliceai/__main__.py | 11 +---------- spliceai/batch/batch_utils.py | 24 +++++++++++++++++------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 7f1387a..4b963c5 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -84,16 +84,7 @@ def main(): logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T 
[tensorflow_batch_size]] [-t [tmp_location]]') exit() - # select a free socket - if args.port is None: - try: - sock = socket.socket() - sock.bind(('', 0)) - args.port = sock.getsockname()[1] - logging.debug(f"PORT:{args.port}") - except Exception as e: - logging.error(f"Error: {repr(e)}") - sys.exit(1) + ## revised code for batched analysis if args.prediction_batch_size > 1: diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 42bcfed..36935fa 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -66,13 +66,23 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): s = socket.socket() host = socket.gethostname() # locahost port = args.port - logger.info(f"Starting server: {host}:{port}") - - try: - s.bind((host,port)) - except Exception as e: - logger.error(f"Cannot bind to port {port} : {e}") - sys.exit(1) + # select a free socket + if args.port is None: + try: + sock = socket.socket() + sock.bind(('', 0)) + args.port = sock.getsockname()[1] + logging.debug(f"PORT:{args.port}") + except Exception as e: + logging.error(f"Error: {repr(e)}") + sys.exit(1) + else: + logger.info(f"Starting server: {host}:{port}") + try: + s.bind((host,port)) + except Exception as e: + logger.error(f"Cannot bind to port {port} : {e}") + sys.exit(1) s.listen(5) # start client sockets & server threads. 
clientThreads = list() From d4a2809821d236e3ea1930930137b4bf4584c2f4 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Mon, 24 Mar 2025 08:56:30 +0100 Subject: [PATCH 40/42] correction on port selection --- spliceai/batch/batch.py | 3 ++- spliceai/batch/batch_utils.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 398d74e..4a6a9c3 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -85,8 +85,9 @@ def main(): # setup the socket try: s = socket.socket() - host = socket.gethostname() # locahost + host = socket.gethostname() # localhost port = args.port + logger.info(f"Connecting to server {host}:{port}") s.connect((host, port)) except Exception as e: raise(e) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 36935fa..7314711 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -69,9 +69,9 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): # select a free socket if args.port is None: try: - sock = socket.socket() - sock.bind(('', 0)) - args.port = sock.getsockname()[1] + #sock = socket.socket() + s.bind(('', 0)) + args.port = s.getsockname()[1] logging.debug(f"PORT:{args.port}") except Exception as e: logging.error(f"Error: {repr(e)}") @@ -80,9 +80,11 @@ def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): logger.info(f"Starting server: {host}:{port}") try: s.bind((host,port)) + except Exception as e: logger.error(f"Cannot bind to port {port} : {e}") sys.exit(1) + logging.info("Server started as : {} : {}".format(host,port)) s.listen(5) # start client sockets & server threads. 
clientThreads = list() From 7a721a65400a1835b159d8310124eb4ad5c26eac Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Mon, 24 Mar 2025 09:30:23 +0100 Subject: [PATCH 41/42] final --- spliceai/batch/batch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 7314711..3df914c 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -150,7 +150,7 @@ def get_preds(ann, x, batch_size=32): predictions = [ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5)] except Exception as e: # try a smaller batch (less efficient, but lower on memory). if it crashes again : it raises. - logger.warning("TF.predict failed ({}).Retrying with smaller batch size".format(e)) + logger.warning("TF.predict failed ({}). Retrying with smaller batch size".format(e)) predictions = [ann.models[m].predict(x, batch_size=4, verbose=0) for m in range(5)] # garbage collection to prevent memory overflow... gc.collect() From d23f835dff11326e7a58fb9c522b4b9a7518b19e Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Mon, 24 Mar 2025 09:46:17 +0100 Subject: [PATCH 42/42] more exception handling --- spliceai/batch/batch_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 3df914c..5e6f042 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -136,7 +136,10 @@ def _process_server(clientsocket,device,queue): item = 'Hold On' # set reply - clientsocket.sendall(str.encode(str(item))) + try: + clientsocket.sendall(str.encode(str(item))) + except BrokenPipeError as e: + raise Exception(f"Error in server thread {device}: {repr(e)}") logger.debug(f"Closing {device} socket.") clientsocket.close()