diff --git a/tests/README.md b/tests/README.md index 0b725dba..9115300c 100644 --- a/tests/README.md +++ b/tests/README.md @@ -34,7 +34,7 @@ Theses file describe metadata about the test vector to encode an - `base_test`: The recommended textproto to diff against. - Other fields refer to the OBUs and data within the test vector. -# Input WAV files +## Input WAV files Test vectors may have multiple substreams with several input .wav files. These .wav files may be shared with other test vectors. The .textproto file has a @@ -68,7 +68,7 @@ Title | Summary `Transport_TOA_5s.wav` | Short clip of vehicles driving by using third-order ambisonics. | 16 | 48kHz | pcm_s16le | 5s `Transport_9.1.6_5s.wav` | Short clip of vehicles driving by using 9.1.6. | 16 | 48kHz | pcm_s16le | 5s -# Output WAV files +## Output WAV files Output wav files are based on the [layout](https://aomediacodec.github.io/iamf/#syntax-layout) in the mix @@ -93,20 +93,77 @@ Sound System 12 | IAMF | C Sound System 13 | IAMF | FL, FR, FC, LFE, BL, BR, FLc, FRc, SiL, SiR, TpFL, TpFR, TpBL, TpBR, TpSiL, TpSiR Binaural Layout | IAMF | L2, R2 -# Verification +## Decode and Verification -For test cases using Opus or AAC codecs, the average PSNR value must exceed 30, and for the other codecs, an average PSNR value exceeding 80 is considered PASS. -You can use `psnr_calc.py` file to calculate PSNR between reference and generated output. +For test cases with lossy codecs (Opus or AAC), the average PSNR value must +exceed 30. otherwise the average PSNR must exceed 80. -- How to use `psnr_calc.py` script: - ``` - python psnr_calc.py - --dir - --target - --ref - ``` +`run_decode_and_psnr_tests` will run the decoder for all reference test cases +and compare the PSNR between all outputs. + +Prerequisites: + +- The path to a built `iamfdec`, usually + `libiamf/code/test/tools/iamfdec/iamfdec` +- `protoc`, and compiled `libiamf/code/proto` files. +- A python environment with `scipy`, `protobuf`, `tqdm`, `numpy`. + +Note that example commands below assume a working directory of `libiamf/tests`. + +To compile the proto files run + +`protoc -I=proto/ --python_out=proto/ proto/*.proto` + +To set up a python environment using pip + +``` +python3 -m venv venv +source venv/bin/activate +pip install scipy protobuf tqdm numpy +``` + +Run the test suite. + +Arguments: -- Calculate PSNR values of multiple wav files +`iamfdec_path`, full path to the built `iamfdec` tool. `test_file_directory`, +full path to folder containing `.textproto` and reference `.wav` files. +`working_directory`, full path to write audio files produced by `iamfdec`. +`csv_summary`, optionally included, full path and filename to write a summary of +test results. + +``` +python3 run_decode_and_psnr_test.py --iamfdec_path /your/full/path/to/libiamf/code/test/tools/iamfdec/iamfdec --test_file_directory /your/full/path/to/libiamf/tests/ --working_directory /your/path/for/scratch/wav/files --csv_summary /your/path/to/write/summary.csv +``` + +For a simple configuration, this example will dump all files to the current +working directory. + +`python3 run_decode_and_psnr_test.py --iamfdec_path +../code/test/tools/iamfdec/iamfdec --test_file_directory $PWD --csv_summary +$PWD/summary.csv -w $PWD` + +Extra arguments: + +`regex_filter`, optionally included, regex to filter output files. For example +`--regex_filter="000100"` will run a single file, or +`--regex_filter="0001\d{2}"` will process files in the range [test_000100, +test_000199]. `verbose_test_summary`, turns on verbose logging. +`--preserve_output_files`, set to keep the output generated `.wav` files, +otherwise they are deleted. + +## Verification Only + +For test cases using Opus or AAC codecs, the average PSNR value must exceed 30, +and for the other codecs, an average PSNR value exceeding 80 is considered PASS. +You can use `psnr_calc.py` file to calculate PSNR between reference and +generated output. + +- How to use `psnr_calc.py` script: `python psnr_calc.py --dir --target --ref --verbose` + +- Calculate PSNR values of multiple wav files Multiple files can be entered as `::` @@ -114,4 +171,4 @@ You can use `psnr_calc.py` file to calculate PSNR between reference and generate Example: python psnr_calc.py --dir . --target target1.wav::target2.wav --ref ref1.wav::ref2.wav - ``` \ No newline at end of file + ``` diff --git a/tests/dsp_utils.py b/tests/dsp_utils.py new file mode 100644 index 00000000..52a64830 --- /dev/null +++ b/tests/dsp_utils.py @@ -0,0 +1,104 @@ +"""PSNR calculation utilities.""" + +import logging +import math +import wave +import numpy as np +import scipy.io.wavfile as wavfile + + +def calc_average_channel_psnr_pcm( + ref_signal: np.ndarray, signal: np.ndarray, sampwidth_bytes: int +): + """Calculates the PSNR between two signals. + + Args: + ref_signal: The reference signal as a numpy array. + signal: The signal to compare as a numpy array. + sampwidth_bytes: The sample width in bytes (e.g. 2 for 16-bit, 3 for + 24-bit). + + Returns: + The average PSNR in dB across all channels, or -1 if all channels are + identical. + """ + assert ( + sampwidth_bytes > 1 + ), "Supports sample format: [pcm_s16le, pcm_s24le, pcm_s32le]" + max_value = pow(2, sampwidth_bytes * 8) - 1 + + # To prevent overflow + ref_signal = ref_signal.astype("int64") + signal = signal.astype("int64") + + mse = np.mean((ref_signal - signal) ** 2, axis=0, dtype="float64") + + psnr_list = list() + + # To support mono signal + num_channels = 1 if ref_signal.shape[1:] == () else ref_signal.shape[1] + for i in range(num_channels): + mse_value = mse[i] if num_channels > 1 else mse + if mse_value == 0: + logging.debug("ch#%d PSNR: inf", i) + else: + psnr_value = 10 * math.log10(max_value**2 / mse_value) + psnr_list.append(psnr_value) + logging.debug("ch#%d PSNR: %f dB", i, psnr_value) + + return -1 if len(psnr_list) == 0 else sum(psnr_list) / len(psnr_list) + + +def calc_average_channel_psnr_wav(ref_filepath: str, target_filepath: str): + """Calculates the PSNR between two WAV files. + + Args: + ref_filepath: Path to the reference WAV file. + target_filepath: Path to the target WAV file to compare. + + Returns: + The average PSNR in dB across all channels. Or -1 if all channels are + identical. + + Raises: + Exception: If the wav files have different samplerate, channels, bit-depth + or number of samples. + """ + ref_wav = wave.open(ref_filepath, "rb") + target_wav = wave.open(target_filepath, "rb") + + # Check sampling rate + if ref_wav.getframerate() != target_wav.getframerate(): + raise ValueError( + "Sampling rate of reference file and comparison file are different:" + f" {ref_filepath} vs {target_filepath}" + ) + + # Check number of channels + if ref_wav.getnchannels() != target_wav.getnchannels(): + raise ValueError( + "Number of channels of reference file and comparison file are" + f" different: {ref_filepath} vs {target_filepath}" + ) + + # Check number of samples + if ref_wav.getnframes() != target_wav.getnframes(): + raise ValueError( + "Number of samples of reference file and comparison file are different:" + f" {ref_filepath} vs {target_filepath}" + ) + + # Check bit depth + if ref_wav.getsampwidth() != target_wav.getsampwidth(): + raise ValueError( + "Bit depth of reference file and comparison file are different:" + f" {ref_filepath} vs {target_filepath}" + ) + + # Open wav as a np array + _, ref_data = wavfile.read(ref_filepath) + _, target_data = wavfile.read(target_filepath) + + return calc_average_channel_psnr_pcm( + ref_data, target_data, ref_wav.getsampwidth() + ) diff --git a/proto/arbitrary_obu.proto b/tests/proto/arbitrary_obu.proto similarity index 100% rename from proto/arbitrary_obu.proto rename to tests/proto/arbitrary_obu.proto diff --git a/proto/audio_element.proto b/tests/proto/audio_element.proto similarity index 100% rename from proto/audio_element.proto rename to tests/proto/audio_element.proto diff --git a/proto/audio_frame.proto b/tests/proto/audio_frame.proto similarity index 100% rename from proto/audio_frame.proto rename to tests/proto/audio_frame.proto diff --git a/proto/codec_config.proto b/tests/proto/codec_config.proto similarity index 100% rename from proto/codec_config.proto rename to tests/proto/codec_config.proto diff --git a/proto/ia_sequence_header.proto b/tests/proto/ia_sequence_header.proto similarity index 100% rename from proto/ia_sequence_header.proto rename to tests/proto/ia_sequence_header.proto diff --git a/proto/mix_presentation.proto b/tests/proto/mix_presentation.proto similarity index 100% rename from proto/mix_presentation.proto rename to tests/proto/mix_presentation.proto diff --git a/proto/obu_header.proto b/tests/proto/obu_header.proto similarity index 100% rename from proto/obu_header.proto rename to tests/proto/obu_header.proto diff --git a/proto/param_definitions.proto b/tests/proto/param_definitions.proto similarity index 100% rename from proto/param_definitions.proto rename to tests/proto/param_definitions.proto diff --git a/proto/parameter_block.proto b/tests/proto/parameter_block.proto similarity index 100% rename from proto/parameter_block.proto rename to tests/proto/parameter_block.proto diff --git a/proto/parameter_data.proto b/tests/proto/parameter_data.proto similarity index 100% rename from proto/parameter_data.proto rename to tests/proto/parameter_data.proto diff --git a/proto/temporal_delimiter.proto b/tests/proto/temporal_delimiter.proto similarity index 100% rename from proto/temporal_delimiter.proto rename to tests/proto/temporal_delimiter.proto diff --git a/proto/test_vector_metadata.proto b/tests/proto/test_vector_metadata.proto similarity index 100% rename from proto/test_vector_metadata.proto rename to tests/proto/test_vector_metadata.proto diff --git a/proto/user_metadata.proto b/tests/proto/user_metadata.proto similarity index 100% rename from proto/user_metadata.proto rename to tests/proto/user_metadata.proto diff --git a/tests/psnr_calc.py b/tests/psnr_calc.py index 06f737fe..c523f535 100644 --- a/tests/psnr_calc.py +++ b/tests/psnr_calc.py @@ -1,139 +1,91 @@ import argparse -import wave +import logging import os -import scipy.io.wavfile as wavfile -import numpy as np -import math - -parser = argparse.ArgumentParser(description="PSNR verification script") -parser.add_argument( - "--dir", - type=str, - required=True, - help="decoder verification wav output directory", -) -parser.add_argument( - "--target", - type=str, - required=True, - help="decoder verification wav output file. Multiple files can be entered as ::. (ex - test1.wav::test2.wav)", -) -parser.add_argument( - "--ref", - type=str, - required=True, - help="decoder verification PSNR evaluation reference file. Multiple files can be entered as ::. (ex - test1.wav::test2.wav)", -) -args = parser.parse_args() - - -def get_sampwdith(path): - with wave.open(path, "rb") as wf: - sampwidth_bytes = wf.getsampwidth() - return sampwidth_bytes - - -def calc_psnr(ref_signal, signal, sampwidth_bytes): - assert ( - sampwidth_bytes > 1 - ), "Supports sample format: [pcm_s16le, pcm_s24le, pcm_s32le]" - max_value = pow(2, sampwidth_bytes * 8) - 1 - - # To prevent overflow - ref_signal = ref_signal.astype("int64") - signal = signal.astype("int64") - - mse = np.mean((ref_signal - signal) ** 2, axis=0, dtype="float64") - - psnr_list = list() - - # To support mono signal - num_channels = 1 if ref_signal.shape[1:] == () else ref_signal.shape[1] - for i in range(num_channels): - mse_value = mse[i] if num_channels > 1 else mse - if mse_value == 0: - print(f"ch#{i} PSNR: inf") - else: - psnr_value = 10 * math.log10(max_value**2 / mse_value) - psnr_list.append(psnr_value) - print(f"ch#{i} PSNR: {psnr_value} dB") - - return -1 if len(psnr_list) == 0 else sum(psnr_list) / len(psnr_list) - - -target_file_list = args.target.split("::") -ref_file_list = args.ref.split("::") - -tc_number_list = [] -psnr_list = [] -for file_idx in range(len(target_file_list)): +import dsp_utils + + +def main(): + """Main function for PSNR calculation script.""" + parser = argparse.ArgumentParser(description='PSNR verification script') + parser.add_argument( + '--dir', + type=str, + required=True, + help='decoder verification wav output directory', + ) + parser.add_argument( + '--target', + type=str, + required=True, + help=( + 'decoder verification wav output file. Multiple files can be entered' + ' as ::. (ex - test1.wav::test2.wav)' + ), + ) + parser.add_argument( + '--ref', + type=str, + required=True, + help=( + 'decoder verification PSNR evaluation reference file. Multiple files' + ' can be entered as ::. (ex - test1.wav::test2.wav)' + ), + ) + parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Verbose logging, of PSNR valuesfor each channel.', + ) + args = parser.parse_args() + logging.basicConfig( + level=logging.DEBUG if args.verbose_logging else logging.INFO, + format='%(message)s', + ) + + target_file_list = args.target.split('::') + ref_file_list = args.ref.split('::') + + tc_number_list = [] + psnr_list = [] + for file_idx in range(len(target_file_list)): target_file = target_file_list[file_idx] ref_file = ref_file_list[file_idx] print( - "[%d] PSNR evaluation: compare %s with %s" + '[%d] PSNR evaluation: compare %s with %s' % (file_idx, target_file, ref_file) ) tc_number_list.append(file_idx) try: - ref_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), args.dir, ref_file - ) - target_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), args.dir, target_file - ) - - ref_samplerate, ref_data = wavfile.read(ref_filepath) - target_samplerate, target_data = wavfile.read(target_filepath) - - ref_sampwdith_bytes = get_sampwdith(ref_filepath) - target_sampwidth_bytes = get_sampwdith(target_filepath) - - # Check sampling rate - if not (ref_samplerate == target_samplerate): - print(ref_file, " / ", target_file) - raise Exception( - "Sampling rate of reference file and comparison file are different." - ) - - # Check number of channels - if not (ref_data.shape[1:] == target_data.shape[1:]): - raise Exception( - "Number of channels of reference file and comparison file are different." - ) - - # Check number of samples - if not (ref_data.shape[0] == target_data.shape[0]): - print(ref_file, " / ", target_file) - raise Exception( - "Number of samples of reference file and comparison file are different." - ) - - # Check bit depth - if not (ref_sampwdith_bytes == target_sampwidth_bytes): - print(ref_file, " / ", target_file) - raise Exception( - "Bit depth of reference file and comparison file are different." - ) - - average_psnr = calc_psnr(ref_data, target_data, ref_sampwdith_bytes) - if average_psnr != -1: - print("average PSNR: %.15f" % (average_psnr)) - psnr_list.append(average_psnr) - else: - print("average PSNR: %.15f" % (100)) - psnr_list.append(100) - except Exception as err: - print(str(err)) - psnr_list.append(0) - print("") - -# print result -print( - "\n\n\n[Result] - (If the OPUS or AAC codec has a over avgPSNR 30, it is considered PASS. Other codecs must be over avgPSNR 80.)" -) -for i in range(len(tc_number_list)): + ref_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), args.dir, ref_file + ) + target_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), args.dir, target_file + ) + + average_psnr = dsp_utils.calc_average_channel_psnr_wav( + ref_filepath, target_filepath + ) + if average_psnr != -1: + print('average PSNR: %.15f' % (average_psnr)) + psnr_list.append(average_psnr) + else: + print('average PSNR: %.15f' % (100)) + psnr_list.append(100) + except ValueError as err: + print(str(err)) + psnr_list.append(0) + print('') + + # print result + print( + '\n\n\n[Result] - (If the OPUS or AAC codec has a over avgPSNR 30, it is' + ' considered PASS. Other codecs must be over avgPSNR 80.)' + ) + for i in range(len(tc_number_list)): print( - "TC#%d : %.3f (compare %s with %s)" + 'TC#%d : %.3f (compare %s with %s)' % ( tc_number_list[i], round(psnr_list[i], 3), @@ -141,3 +93,7 @@ def calc_psnr(ref_signal, signal, sampwidth_bytes): ref_file_list[i], ) ) + + +if __name__ == '__main__': + main() diff --git a/tests/run_decode_and_psnr_test.py b/tests/run_decode_and_psnr_test.py new file mode 100644 index 00000000..03fd374a --- /dev/null +++ b/tests/run_decode_and_psnr_test.py @@ -0,0 +1,336 @@ +import argparse +from collections import defaultdict +import csv +from dataclasses import dataclass, field +import enum +import glob +import logging +import os +import re +import subprocess +import sys +from typing import Optional +from google.protobuf import text_format +from tqdm import tqdm + +sys.path.append(os.path.join(os.path.dirname(__file__), 'proto')) +import user_metadata_pb2 +import dsp_utils +import user_metadata_parsing_utils as utils +from user_metadata_parsing_utils import TestExclusions + + +exclusions = [ + TestExclusions( + file_name_prefix='test_000710', + mix_presentation_id=42, + layout_index=0, + reason='Mix surpasses base-enhanced profile limits', + ), + TestExclusions( + file_name_prefix='test_000711', + mix_presentation_id=42, + layout_index=0, + reason='Mix surpasses base-enhanced profile limits', + ), + TestExclusions( + file_name_prefix='test_000126', + mix_presentation_id=42, + layout_index=1, + reason='Extension layouts cannot be decoded.', + ), +] + +# Opus/AAC are lossy codecs, we allow a more lenient threshold for them. +LOSSY_PSNR_THRESHOLD = 30 +LOSSLESS_PSNR_THRESHOLD = 80 + + +class ResultStatus(enum.Enum): + SUCCESS = 1 + FAILURE = 2 + CRASH = 3 + SKIPPED = 4 + + +@dataclass +class Result: + test_prefix: str + is_lossy: bool + mix_presentation_id: int + sub_mix_index: int + layout_index: int + psnr_score: Optional[float] = None + reason: Optional[str] = None + iamfdec_command: Optional[str] = None + + def log(self, status: ResultStatus): + logging.debug( + '%s: %s >= %s for %s', + status.name, + self.psnr_score, + self.is_lossy, + self.test_prefix, + ) + logging.debug('') + + +@dataclass +class TestSummary: + results: dict[ResultStatus, list[Result]] = field( + default_factory=lambda: defaultdict(list) + ) + + def print_test_summary(self, csv_summary_file=None): + """Prints test summary to console and optionally a CSV file. + + Args: + csv_summary_file: Path to CSV file to log test results. + """ + logging.info('\n-----------------SUMMARY-----------------') + for status in ResultStatus: + logging.info('%s: %d', status.name, len(self.results[status])) + + logging.info('-----------------------------------------') + + if csv_summary_file: + with open(csv_summary_file, 'w', newline='') as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow([ + 'Test Prefix', + 'Mix ID', + 'Submix Index', + 'Layout Index', + 'Status', + 'PSNR', + 'Is Lossy', + 'Reason', + 'Command', + ]) + for status in ResultStatus: + for item in self.results[status]: + csvwriter.writerow([ + item.test_prefix, + item.mix_presentation_id, + item.sub_mix_index, + item.layout_index, + status.name, + item.psnr_score if item.psnr_score is not None else '', + 'lossy' if item.is_lossy else 'lossless', + item.reason if item.reason is not None else '', + item.iamfdec_command + if item.iamfdec_command is not None + else '', + ]) + + +def run_decoder(args, metadata): + """Runs the iamfdec decoder and returns True if successful.""" + iamfdec_args = utils.get_iamfdec_args( + metadata, args.test_file_directory, args.working_directory + ) + if iamfdec_args is None: + return False, None + cmd = [args.iamfdec_path] + iamfdec_args + cmd_str = ' '.join(cmd) + + verbose_logging = logging.getLogger().isEnabledFor(logging.DEBUG) + logging.debug('Running: %s', cmd) + try: + subprocess.run( + cmd, + check=True, + stdout=subprocess.PIPE if verbose_logging else subprocess.DEVNULL, + stderr=subprocess.PIPE if verbose_logging else subprocess.DEVNULL, + ) + except subprocess.CalledProcessError: + return False, cmd_str + return True, cmd_str + + +def run_psnr_test(args, metadata): + """Gets PSNR score, returns None if calculation fails. + + Args: + args: Command line arguments. + metadata: Metadata for the test vector. + + Returns: + A tuple of (ResultStatus, reason, psnr_score). + """ + ref_file = os.path.join( + args.test_file_directory, metadata.golden_wav_file_name + ) + test_file = os.path.join( + args.working_directory, metadata.base_name_to_generate + ) + assert os.path.exists(ref_file), f'Reference file {ref_file} does not exist.' + assert os.path.exists(test_file), f'Test file {test_file} does not exist.' + logging.debug('ref_file: %s', ref_file) + logging.debug('test_file: %s', test_file) + try: + raw_psnr_score = dsp_utils.calc_average_channel_psnr_wav( + ref_file, test_file + ) + except ValueError as e: + print(f'Failed to calculate PSNR: {e}') + return ResultStatus.CRASH, 'PSNR calculation failed', None + + psnr_score = 100 if raw_psnr_score == -1 else raw_psnr_score + # Check if this PSNR is a pass or a fail, it depends on whether the test + # represents a lossy or lossless codec. + logging.debug('psnr score: %s', psnr_score) + threshold = ( + LOSSY_PSNR_THRESHOLD if metadata.is_lossy else LOSSLESS_PSNR_THRESHOLD + ) + if psnr_score >= threshold: + return ResultStatus.SUCCESS, None, psnr_score + else: + return ResultStatus.FAILURE, 'PSNR score below threshold.', psnr_score + + +def _is_excluded(metadata, exclusions, args): + if not args.test_binaural and metadata.is_binaural: + return 'Binaural layout not tested by default.' + + for exclusion in exclusions: + if ( + exclusion.file_name_prefix == metadata.test_prefix + and exclusion.mix_presentation_id == metadata.mix_presentation_id + and exclusion.layout_index == metadata.layout_index + ): + print( + f'Skipping {metadata.test_prefix} layout' + f' {metadata.layout_index} for mix ID' + f' {metadata.mix_presentation_id} because ({exclusion.reason})' + ) + return exclusion.reason + return None + + +def run_tests(args, text_proto_files) -> TestSummary: + """Runs tests on the given textproto files. + + Args: + args: Command line arguments. + text_proto_files: List of textproto files to run tests on. + + Returns: + A TestSummary object containing the results of the tests. + """ + summary = TestSummary() + progress_bar = tqdm(text_proto_files) + for text_proto_path in progress_bar: + progress_bar.set_description(os.path.basename(text_proto_path)) + with open(text_proto_path, 'r') as f: + user_metadata = text_format.Parse( + f.read(), user_metadata_pb2.UserMetadata() + ) + + # Get the metadata for this test vector, there may be multiple mix + # presentation, submix, and layout combinations per test vector. + metadatas = utils.get_test_combination_metadata( + user_metadata, args.test_file_directory + ) + for metadata in metadatas: + # Usually each loop will generate a new file. Delete them if they are new + file_to_generate = os.path.join( + args.working_directory, metadata.base_name_to_generate + ) + generated_file_is_new = not os.path.exists(file_to_generate) + cleanup_after_decode = ( + generated_file_is_new and not args.preserve_output_files + ) + status, reason, psnr_score, cmd_str = None, None, None, None + if skip_reason := _is_excluded(metadata, exclusions, args): + # Test was intentionally excluded. + reason = skip_reason + status = ResultStatus.SKIPPED + else: + decoder_success, cmd_str = run_decoder(args, metadata) + if decoder_success: + # Run the PSNR test, this could crash, or be better than or worse than + # the threshold PSNR. + status, reason, psnr_score = run_psnr_test(args, metadata) + else: + # Decoder crashed. + reason = 'iamfdec crash' + status = ResultStatus.CRASH + + # Regardless of status, record what happened. + result = Result( + metadata.test_prefix, + metadata.is_lossy, + metadata.mix_presentation_id, + metadata.sub_mix_index, + metadata.layout_index, + psnr_score, + reason, + iamfdec_command=cmd_str, + ) + result.log(status) + summary.results[status].append(result) + if cleanup_after_decode and os.path.exists(file_to_generate): + os.remove(file_to_generate) + return summary + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--test_file_directory', required=True) + parser.add_argument('-w', '--working_directory', required=True) + parser.add_argument('-l', '--iamfdec_path', required=True) + parser.add_argument( + '-p', + '--preserve_output_files', + action='store_true', + help='Preserve output files in working directory.', + ) + parser.add_argument( + '-r', + '--regex_filter', + help='Regex filter to apply to textproto filenames.', + ) + parser.add_argument( + '-v', + '--verbose_test_summary', + action='store_true', + help='Print verbose test summary', + ) + parser.add_argument( + '-c', + '--csv_summary_file', + help='Path to CSV file to log test results.', + ) + parser.add_argument( + '-b', + '--test_binaural', + action='store_true', + help='Enable testing binaural layouts.', + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose_test_summary else logging.INFO, + format='%(message)s', + ) + + # The textprotos contain metadata the various test vectors. + glob_path = os.path.join(args.test_file_directory, '*.textproto') + text_proto_files = glob.glob(glob_path) + assert text_proto_files, f'No textproto files found in {glob_path}' + + if args.regex_filter: + text_proto_files = [ + f + for f in text_proto_files + if re.search(args.regex_filter, os.path.basename(f)) + ] + text_proto_files.sort() + + test_summary = run_tests(args, text_proto_files) + test_summary.print_test_summary(args.csv_summary_file) + + +if __name__ == '__main__': + main() diff --git a/tests/test_000070_rendered_id_42_sub_mix_0_layout_1.wav b/tests/test_000070_rendered_id_42_sub_mix_0_layout_1.wav index 4080a071..910bcceb 100644 Binary files a/tests/test_000070_rendered_id_42_sub_mix_0_layout_1.wav and b/tests/test_000070_rendered_id_42_sub_mix_0_layout_1.wav differ diff --git a/tests/test_000210_rendered_id_42_sub_mix_0_layout_2.wav b/tests/test_000210_rendered_id_42_sub_mix_0_layout_2.wav index 0cda704e..35c16ed5 100644 Binary files a/tests/test_000210_rendered_id_42_sub_mix_0_layout_2.wav and b/tests/test_000210_rendered_id_42_sub_mix_0_layout_2.wav differ diff --git a/tests/user_metadata_parsing_utils.py b/tests/user_metadata_parsing_utils.py new file mode 100644 index 00000000..102e7fc3 --- /dev/null +++ b/tests/user_metadata_parsing_utils.py @@ -0,0 +1,181 @@ +"""Utils for parsing user metadata.""" + +from dataclasses import dataclass, field +import os +import sys +import wave + +sys.path.append(os.path.join(os.path.dirname(__file__), 'proto')) +import codec_config_pb2 +import mix_presentation_pb2 +import user_metadata_pb2 + + +@dataclass +class TestExclusions: + file_name_prefix: str + mix_presentation_id: int + layout_index: int + reason: str + + +@dataclass +class TestCombinationMetadata: + test_prefix: str + mix_presentation_id: int + sub_mix_index: int + layout_index: int + golden_wav_file_name: str + base_name_to_generate: str + layout_type_enum: mix_presentation_pb2.LayoutType = field(repr=False) + sound_system_enum: mix_presentation_pb2.SoundSystem | None = field(repr=False) + is_lossy: bool = field(repr=False) + is_binaural: bool = False + bit_depth: int = 16 + sample_rate: int = 48000 + + +_SOUND_SYSTEM_TO_FLAG = { + mix_presentation_pb2.SOUND_SYSTEM_A_0_2_0: '0', + mix_presentation_pb2.SOUND_SYSTEM_B_0_5_0: '1', + mix_presentation_pb2.SOUND_SYSTEM_C_2_5_0: '2', + mix_presentation_pb2.SOUND_SYSTEM_D_4_5_0: '3', + mix_presentation_pb2.SOUND_SYSTEM_E_4_5_1: '4', + mix_presentation_pb2.SOUND_SYSTEM_F_3_7_0: '5', + mix_presentation_pb2.SOUND_SYSTEM_G_4_9_0: '6', + mix_presentation_pb2.SOUND_SYSTEM_H_9_10_3: '7', + mix_presentation_pb2.SOUND_SYSTEM_I_0_7_0: '8', + mix_presentation_pb2.SOUND_SYSTEM_J_4_7_0: '9', + mix_presentation_pb2.SOUND_SYSTEM_10_2_7_0: '10', + mix_presentation_pb2.SOUND_SYSTEM_11_2_3_0: '11', + mix_presentation_pb2.SOUND_SYSTEM_12_0_1_0: '12', + mix_presentation_pb2.SOUND_SYSTEM_13_6_9_0: '13', +} + + +def _map_layout_to_iamfdec_s_flag( + layout_type_enum: mix_presentation_pb2.LayoutType, + sound_system_enum: mix_presentation_pb2.SoundSystem | None, +) -> str: + """Maps layout in metadata to '-s' flag.""" + if layout_type_enum == mix_presentation_pb2.LAYOUT_TYPE_BINAURAL: + return 'b' + if ( + layout_type_enum + == mix_presentation_pb2.LAYOUT_TYPE_LOUDSPEAKERS_SS_CONVENTION + ): + flag = _SOUND_SYSTEM_TO_FLAG.get(sound_system_enum) + if flag is None: + # Some test vectors use reserved types. + print(f'Could not map sound system to -s flag: {sound_system_enum}') + return flag + + # Some test vectors use reserved types. + print(f'Could not map layout type to -s flag: {layout_type_enum}') + return None + + +def get_iamfdec_args( + metadata: TestCombinationMetadata, input_path: str, output_path: str +) -> list[str]: + """Generates renderer command line arguments for a TestCombinationMetadata. + + Args: + metadata: A TestCombinationMetadata object. + output_path: Path to output directory for generated WAV file. + + Returns: + A list of strings for the command line arguments. + """ + iamfdec_s_flag = _map_layout_to_iamfdec_s_flag( + metadata.layout_type_enum, metadata.sound_system_enum + ) + if iamfdec_s_flag is None: + return None + return [ + '-i0', + '-mp', + str(metadata.mix_presentation_id), + f'-s{iamfdec_s_flag}', + '-o3', + ## Join directory and path + os.path.join(output_path, metadata.base_name_to_generate), + '-d', + str(metadata.bit_depth), + '-r', + str(metadata.sample_rate), + '-disable_limiter', + os.path.join(input_path, f'{metadata.test_prefix}.iamf'), + ] + + +def get_test_combination_metadata(user_metadata_proto, test_file_directory): + """Parses TestCombinationMetadata from UserMetadata proto. + + Args: + user_metadata_proto: A UserMetadata proto object. + test_file_directory: Directory containing golden WAV files. + + Returns: + A list of TestCombinationMetadata objects containing mix presentation + layout information, or None if file_name_prefix is missing. + """ + file_name_prefix = user_metadata_proto.test_vector_metadata.file_name_prefix + if not file_name_prefix: + return [] + if not user_metadata_proto.test_vector_metadata.is_valid_to_decode: + # Skip test vectors that are not valid to decode. + return [] + + assert len(user_metadata_proto.codec_config_metadata) == 1 + codec_id = user_metadata_proto.codec_config_metadata[0].codec_config.codec_id + is_lossy = codec_id in [ + codec_config_pb2.CODEC_ID_AAC_LC, + codec_config_pb2.CODEC_ID_OPUS, + ] + + result = [] + for mp_metadata in user_metadata_proto.mix_presentation_metadata: + for sub_mix_idx, sub_mix in enumerate(mp_metadata.sub_mixes): + for layout_idx, layout in enumerate(sub_mix.layouts): + golden_wav_file_name = f'{file_name_prefix}_rendered_id_{mp_metadata.mix_presentation_id}_sub_mix_{sub_mix_idx}_layout_{layout_idx}.wav' + golden_wav_path = os.path.join( + test_file_directory, golden_wav_file_name + ) + bit_depth = 16 + sample_rate = 48000 + if os.path.exists(golden_wav_path): + with wave.open(golden_wav_path, 'rb') as wave_file: + bit_depth = wave_file.getsampwidth() * 8 + sample_rate = wave_file.getframerate() + else: + print( + 'Warning: golden wav file not found, sometimes this is because' + f' the mix presentation is invalid to decode: {golden_wav_path}' + ) + continue + ss_enum = None + if ( + layout.loudness_layout.layout_type + == mix_presentation_pb2.LAYOUT_TYPE_LOUDSPEAKERS_SS_CONVENTION + ): + ss_enum = layout.loudness_layout.ss_layout.sound_system + data = TestCombinationMetadata( + test_prefix=file_name_prefix, + mix_presentation_id=mp_metadata.mix_presentation_id, + sub_mix_index=sub_mix_idx, + layout_index=layout_idx, + golden_wav_file_name=golden_wav_file_name, + base_name_to_generate=golden_wav_file_name.replace( + '.wav', '_generated.wav' + ), + layout_type_enum=layout.loudness_layout.layout_type, + sound_system_enum=ss_enum, + is_lossy=is_lossy, + bit_depth=bit_depth, + sample_rate=sample_rate, + is_binaural=layout.loudness_layout.layout_type + == mix_presentation_pb2.LAYOUT_TYPE_BINAURAL, + ) + result.append(data) + return result