diff --git a/asr_test.sh b/asr_test.sh new file mode 100755 index 0000000..89e486c --- /dev/null +++ b/asr_test.sh @@ -0,0 +1,297 @@ +#!/usr/bin/env bash +# asr_test.sh +# +# Automatic Speech Recognition (ASR) testing for the Radio Autoencoder. This script +# takes the samples from a clean dataset (e.g. Librispeech test-clean), and generates +# a dataset with channel simulations (RADE, SSB etc) applied. + +CODEC2_DEV=${CODEC2_DEV:-${HOME}/codec2-dev} +PATH=${PATH}:${CODEC2_DEV}/build_linux/src:${CODEC2_DEV}/build_linux/misc:${PWD}/build/src + +which ch >/dev/null || { printf "\n**** Can't find ch - check CODEC2_PATH **** \n\n"; exit 1; } + +source utils.sh + +function print_help { + echo + echo "Automated Speech Recognition (ASR) dataset processing for Radio Autoencoder testing" + echo + echo " usage ./asr_test.sh ssb|rade|700D|fargan|4kHz [test option below]" + echo " usage ./ota_test.sh ssb --No -30" + echo " usage ./ota_test.sh rade --EbNodB 10" + echo + echo " --EbNodB EbNodB inference.py simulation noise level (experiment to get desired SNR)" + echo " --No NodB ch channel simulation No value (experiment to get desired SNR)" + echo " -n numSamples number of dataset samples to process (default all)" + echo " --results resultsFile name of results file (deafult results.txt)" + echo " -d verbose debug information" + exit +} + +n_samples=0 +No=-100 +EbNodB=100 +setpoint_rms=2048 +comp_gain=6 +results=asr_results.txt +inference_args="" +ch_args="" +sil=0.5 + +POSITIONAL=() +while [[ $# -gt 0 ]] +do +key="$1" +case $key in + --EbNodB) + EbNodB="$2" + shift + shift + ;; + --g_file) + g_file="$2" + if [ ! -f $2 ]; then + echo "can't find $2" + exit 1 + fi + inference_args="${inference_args} --g_file ${2}" + cp ${2} fast_fading_samples.float + ch_args="${ch_args} --fading_dir . 
--mpp --gain 0.5" + shift + shift + ;; + --No) + No="$2" + shift + shift + ;; + --results) + results="$2" + shift + shift + ;; + -n) + n_samples="$2" + shift + shift + ;; + -d) + set -x; + shift + ;; + -h) + print_help + ;; + *) + POSITIONAL+=("$1") # save it in an array for later + shift + ;; +esac +done +set -- "${POSITIONAL[@]}" # restore positional parameters + +if [ $# -lt 1 ]; then + print_help +fi +mode=$1 + +source=~/.cache/LibriSpeech/test-clean +if [ ! -d $source ]; then + echo "cant find Librispeech source directory" $source + exit 1 +fi +# results must be written to a directory known by Librispeech package (can't be any name) +dest=~/.cache/LibriSpeech/test-other +rm -Rf $dest + +# cp translation files to new dataset directory +function cp_translation_files { + pushd $source > /dev/null; trans=$(find . -name '*.txt'); popd > /dev/null + for f in $trans + do + d=$(dirname $f) + mkdir -p ${dest}/${d} + cp ${source}/${f} ${dest}/${f} + done +} + +function print_mean_text_file { + file_name=$1 + python3 - < /dev/null; flac=$(find . 
-name '*.flac'); popd > /dev/null + if [ $n_samples -ne 0 ]; then + flac=$(echo "$flac" | shuf --random-source=<(yes 42) | head -n $n_samples) + fi + + n=$(echo "$flac" | wc -l) + printf "Processing %d samples in dataset\n" $n + + in=in.raw + comp=comp.raw + ch_log=ch_log.txt + rade_log=rade_log.txt + snr_log=snr_log.txt + asr_log=asr.txt + rm -f ${snr_log} + CNo_log=CNo_log.txt + rm -f ${CNo_log} + sox -n -r 16000 -c 1 /tmp/silence.wav trim 0.0 ${sil} + + if [ $mode == "ssb" ] || [ $mode == "4kHz" ]; then + + fading_adv=0 + for f in $flac + do + d=$(dirname $f) + mkdir -p ${dest}/${d} + + if [ $mode == "ssb" ]; then + sox ${source}/${f} -t .s16 -r 8000 ${in} + # AGC and Hilbert compression + set_rms ${in} $setpoint_rms + analog_compressor ${in} ${comp} ${comp_gain} 2>/dev/null + ch ${comp} - --No ${No} ${ch_args} --fading_adv ${fading_adv} 2>${ch_log} | sox -t .s16 -r 8000 -c 1 - -r 16000 ${dest}/${f} + grep "Fading file finished" $ch_log + if [ $? -eq 0 ]; then + echo "Error - fading file too short after" $fading_adv " seconds" + exit 1 + fi + snr=$(cat $ch_log | grep "SNR3k" | tr -s ' ' | cut -d' ' -f3) + CNo=$(cat $ch_log | grep "SNR3k" | tr -s ' ' | cut -d' ' -f5) + echo $snr >> ${snr_log} + echo $CNo >> ${CNo_log} + + # advance through fading simulation file + dur=$(sox --info -D ${source}/${f}) + fading_adv=$(python3 -c "print(${fading_adv} + ${dur})") + else + # $mode == "4kHz" (4kHz bandwidth, representing ideal Fs=8kHz vocoder) + sox ${source}/${f} -r 8000 -t .s16 -c 1 - | sox -r 8000 -t .s16 -c 1 - -r 16000 ${dest}/${f} + fi + done + if [ $mode == "ssb" ]; then + SNR_mean=$(print_mean_text_file ${snr_log}) + CNo_mean=$(print_mean_text_file ${CNo_log}) + fi + fi + + if [ $mode == "700D" ]; then + + fading_adv=0 + for f in $flac + do + d=$(dirname $f) + mkdir -p ${dest}/${d} + + # silence either side of sample to allow time for acquisition and latency + sox /tmp/silence.wav /tmp/silence.wav ${source}/${f} /tmp/silence.wav -t .s16 -r 8000 ${in} + + # trim 
start to remove acquisition noise + freedv_tx 700D ${in} - | \ + ch - - --No ${No} ${ch_args} --fading_adv ${fading_adv} 2>${ch_log} | \ + freedv_rx 700D - out.raw 2>/dev/null + cat out.raw | sox -t .s16 -r 8000 -c 1 - -r 16000 ${dest}/${f} trim 0.5 + # error check + grep "Fading file finished" $ch_log + if [ $? -eq 0 ]; then + echo "Error - fading file too short after" $fading_adv " seconds" + exit 1 + fi + snr=$(cat $ch_log | grep "SNR3k" | tr -s ' ' | cut -d' ' -f3) + CNo=$(cat $ch_log | grep "SNR3k" | tr -s ' ' | cut -d' ' -f5) + echo $snr >> ${snr_log} + echo $CNo >> ${CNo_log} + + # advance through fading simulation file + dur=$(sox --info -D ${source}/${f}) + fading_adv=$(python3 -c "print(${fading_adv} + ${dur})") + + done + SNR_mean=$(print_mean_text_file ${snr_log}) + CNo_mean=$(print_mean_text_file ${CNo_log}) + fi + + if [ $mode == "rade" ] || [ $mode == "fargan" ]; then + # find length of each file + duration_log="" + flac_full="" + pushd $source > /dev/null; + for f in $flac + do + duration_log+=$(sox --info -D ${f}) + duration_log+=" " + flac_full+="${source}/${f} /tmp/silence.wav " + done + popd > /dev/null; + + # cat samples into one long input file, insert 500ms at end of sample to allow for processing at output + sox $flac_full -t .s16 ${in} + + # process all samples as one file to save time + + if [ $mode == "rade" ]; then + ./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth ${in} out.wav \ + --rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --auxdata --time_offset -16 \ + --EbNodB $EbNodB ${inference_args} | tee ${rade_log} + grep "Multipath Doppler spread file too short" $rade_log + if [ $? 
-eq 0 ]; then + echo "Error - fading file too short" + exit 1 + fi + + SNR_mean=$(cat $rade_log | grep "Measured" | tr -s ' ' | cut -d' ' -f4) + CNo_mean=$(cat $rade_log | grep "Measured" | tr -s ' ' | cut -d' ' -f3) + else + # $mode == "fargan" + ./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth ${in} out.wav --auxdata --passthru + #sox -t .s16 -r 16000 -c 1 ${in} out.wav + fi + + # extract individual output files + duration_array=( ${duration_log} ) + i=0 + st=0 + for f in $flac + do + dur=${duration_array[i]} + dur=$(python3 -c "print($dur + ${sil})") + #printf "%4d %s %5.2f %5.2f\n" $i $f $st $dur + ((i++)) + if [ $i -eq ${#duration_array[@]} ]; then + sox out.wav ${dest}/${f} trim $st + else + sox out.wav ${dest}/${f} trim $st $dur + fi + st=$(python3 -c "print($st + $dur)") + done + fi + + # test mode that just copies files + if [ $mode == "clean" ]; then + for f in $flac + do + cp ${source}/${f} ${dest}/${f} + done + + fi + + python3 asr_wer.py test-other -n $n_samples --model turbo | tee > $asr_log + wer=$(tail -n1 $asr_log | tr -s ' ' | cut -d' ' -f2) + if [ $mode == "ssb" ] || [ $mode == "rade" ] || [ $mode == "700D" ]; then + printf "%-6s %5.2f %5.2f %5.2f\n" $mode $SNR_mean $CNo_mean $wer | tee -a $results + else + printf "%-6s %5.2f\n" $mode $wer | tee -a $results + fi +} + +cp_translation_files +process + diff --git a/asr_test_top.sh b/asr_test_top.sh new file mode 100755 index 0000000..41631e0 --- /dev/null +++ b/asr_test_top.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# asr_test_awgn.sh +# +# Top level ASR test script for AWGN and MPP channels +set -x +results_file=241221_asr +n=500 + +function ssb { + local results_file=$1 + No_range=$2 + for No in $No_range + do + ./asr_test.sh ssb --No $No -n $n --results ${results_file} $3 + done + cat ${results_file} | grep ssb | sed -e "s/ssb//" > tmp.txt + mv tmp.txt ${results_file} +} + +function rade { + local results_file=$1 + EbNodB_range=$2 + for EbNodB in $EbNodB_range + do + 
./asr_test.sh rade --EbNodB $EbNodB -n $n --results ${results_file} $3 + done + cat ${results_file} | grep rade | sed -e "s/rade//" > tmp.txt + mv tmp.txt ${results_file} +} + +function freedv_700D { + local results_file=$1 + No_range=$2 + for No in $No_range + do + ./asr_test.sh 700D --No $No -n $n --results ${results_file} $3 + done + cat ${results_file} | grep 700D | sed -e "s/700D//" > tmp.txt + mv tmp.txt ${results_file} +} + +freedv_700D ${results_file}_awgn_700D.txt "-100 -30 -26 -23 -20 -17 -15 -13" +freedv_700D ${results_file}_mpp_700D.txt "-100 -39 -36 -33 -30 -27" "--g_file g_mpp.f32" +#freedv_700D ${results_file}_awgn_700D.txt "-100 -38 -35 -32 -29 -26 -23 -20 -17" +#freedv_700D ${results_file}_mpp_700D.txt "-100 -44 -39 -36 -33 -30 -27" "--g_file g_mpp.f32" +exit 0 + +# run the controls +controls_file=${results_file}_controls.txt +rm -f ${controls_file} +./asr_test.sh clean -n $n --results ${controls_file} +./asr_test.sh fargan -n $n --results ${controls_file} +./asr_test.sh 4kHz -n $n --results ${controls_file} +./asr_test.sh ssb -n $n --results ${controls_file} +./asr_test.sh rade -n $n --results ${controls_file} +# strip off all but last column for Octave plotting +cat ${controls_file} | awk '{print $NF}' > ${results_file}_c.txt + + +ssb ${results_file}_awgn_ssb.txt "-100 -38 -35 -32 -29 -26 -23 -20 -17" +rade ${results_file}_awgn_rade.txt "100 15 10 5 2.5 0 -2.5" +ssb ${results_file}_mpp_ssb.txt "-100 -44 -39 -36 -33 -30 -27" "--g_file g_mpp.f32" +rade ${results_file}_mpp_rade.txt "100 15 10 5 2.5 0" "--g_file g_mpp.f32" + diff --git a/asr_wer.py b/asr_wer.py new file mode 100644 index 0000000..68b441a --- /dev/null +++ b/asr_wer.py @@ -0,0 +1,91 @@ +# coding: utf-8 + +# derived from: https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb + +import os,argparse +import numpy as np +import torch +import pandas as pd +import whisper +import torchaudio +from tqdm.notebook import tqdm + +DEVICE = "cuda" if torch.cuda.is_available() else 
"cpu" + + +class LibriSpeech(torch.utils.data.Dataset): + """ + A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds. + It will drop the last few seconds of a very small portion of the utterances. + """ + def __init__(self, n_mels, split="test-clean", device=DEVICE): + self.dataset = torchaudio.datasets.LIBRISPEECH( + root=os.path.expanduser("~/.cache"), + url=split, + download=True, + ) + self.device = device + self.n_mels = n_mels + print(n_mels) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item): + audio, sample_rate, text, _, _, _ = self.dataset[item] + assert sample_rate == 16000 + audio = whisper.pad_or_trim(audio.flatten()).to(self.device) + mel = whisper.log_mel_spectrogram(audio,n_mels=self.n_mels) + + return (mel, text) + + +parser = argparse.ArgumentParser() +parser.add_argument('test_name', type=str, help='Librispeech dataset name (e.g. test-clean)') +parser.add_argument('-n', type=str, help='Number of dataset entries to use (default all of them)') +parser.add_argument('--model', default='base.en',type=str, help='Whisper model') +args = parser.parse_args() + +model = whisper.load_model(args.model) +print( + f"Model is {'multilingual' if model.is_multilingual else 'English-only'} " + f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters." 
+) +# predict without timestamps for short-form transcription +options = whisper.DecodingOptions(language="en", without_timestamps=True) + +dataset = LibriSpeech(model.dims.n_mels, args.test_name) +if args.n: + dataset = torch.utils.data.Subset(dataset,list(range(0,int(args.n)))) +print("dataset length:", dataset.__len__()) +loader = torch.utils.data.DataLoader(dataset, batch_size=16) + + +hypotheses = [] +references = [] + +for mels, texts in loader: + results = model.decode(mels, options) + hypotheses.extend([result.text for result in results]) + references.extend(texts) + +data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references)) + + +# # Calculating the word error rate +# +# Now, we use our English normalizer implementation to standardize the transcription and calculate the WER. + +import jiwer +from whisper.normalizers import EnglishTextNormalizer + +normalizer = EnglishTextNormalizer() + +data["hypothesis_clean"] = [normalizer(text) for text in data["hypothesis"]] +data["reference_clean"] = [normalizer(text) for text in data["reference"]] +print(data) + +wer = jiwer.wer(list(data["reference_clean"]), list(data["hypothesis_clean"])) + +print(f"WER: {wer * 100:.2f} %") + diff --git a/compare_models_inf.sh b/compare_models_inf.sh new file mode 100755 index 0000000..9f6fe7c --- /dev/null +++ b/compare_models_inf.sh @@ -0,0 +1,253 @@ +#!/bin/bash -x +# +# Compare models by plotting loss v SNR/PSNR curves from data generated by inference.py +# Similar to compare_models.sh, but uses time domain path in forward() that is closer to the +# real world configuration, e.g. time domain multipath model, cyclic prefix, ISI, pilot +# insertion and DSP EQ (if used). + +# Build an input test file from Librispeech +function build_input_file_from_librispeech() { + n_samples=$1 + input_file=$2 + source=~/.cache/LibriSpeech/test-clean + if [ ! 
-d $source ]; then + echo "cant find Librispeech source directory" $source + exit 1 + fi + flac=$(find ${source} -name '*.flac') + # randomise selection of files so we don't get them all from one speaker, + # --random-source makes it the same repeatable random sequence so we get the + # same results on each run + flac=$(echo "$flac" | shuf --random-source=<(yes 42) | head -n $n_samples) + n=$(echo "$flac" | wc -l) + printf "Collecting %d samples from Librispeech\n" $n + sox ${flac} -c 1 -r 16000 ${input_file} + dur=$(sox --info -D ${input_file}) + printf "%s duration %d\n" $input_file $dur +} + +# Run inference on a range of SNRs to compute loss +function run_model() { + model=$1 + dim=$2 + epoch=$3 + chan=$4 + freq_offset=$5 + shift + shift + shift + shift + shift + EbNodB_list='-3 0 3 6 9 12 15 18 21' + results=${model}_${chan}_${freq_offset}Hz_loss_SNR3k.txt + + # return if results file already exists + if [ $rebuild -eq 0 ]; then + if [ -f $results ]; then + return + fi + fi + + rm -f $results + for aEbNodB in $EbNodB_list + do + log=$(./inference.sh ${model}/checkpoints/checkpoint_epoch_${epoch}.pth ${input_file} /dev/null --bottleneck 3 --rate_Fs \ + --latent-dim ${dim} $@ --EbNodB ${aEbNodB} --freq_offset ${freq_offset}) + SNR3k=$(echo "$log" | grep "Measured:" | tr -s ' ' | cut -d' ' -f4) + PAPR=$(echo "$log" | grep "Measured:" | tr -s ' ' | cut -d' ' -f5) + loss=$(echo "$log" | grep "loss:" | tr -s ' ' | cut -d' ' -f2) + printf "%f\t%f\t%f\n" $SNR3k $loss $PAPR >> $results + done +} + +function print_help { + echo + echo " Compare models by plotting loss v SNR/PSNR curves from time domain inference" + echo + echo " usage ./compare_models_inf.sh [-n NumberSpeechSamples] [-p plotName] [-r]" + echo + echo " -n NumberSpeechSamples Use a low number (e.g. 10) when testing plots" + echo " -p Name of plot (see source)" + echo " -r Rebuild results files even if they already exists (e.g. 
if -n has changed)" + echo + exit +} + +# default is about 10 minutes long +n_samples=80 +input_file="wav/librispeech.wav" +plot="250413_inf" +rebuild=0 + +POSITIONAL=() +while [[ $# -gt 0 ]] +do +key="$1" +case $key in + -n) + n_samples="$2" + shift + shift + ;; + -p) + plot="$2" + shift + shift + ;; + -r) + rebuild=1 + shift + ;; + -h) + print_help + ;; + *) + POSITIONAL+=("$1") # save it in an array for later + shift + ;; +esac +done +set -- "${POSITIONAL[@]}" # restore positional parameters + +build_input_file_from_librispeech ${n_samples} ${input_file} + + +# compare RADE V1 to 250227b +if [ $plot == "250227b_inf" ]; then + #run_model model19_check3 80 100 awgn --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + #run_model model19_check3 80 100 mpp --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + #run_model 250227b_test 40 200 awgn --cp 0.004 --time_offset -16 --correct_time_offset -32 + #run_model 250227b_test 40 200 mpp --cp 0.004 --time_offset -16 --correct_time_offset -32 --g_file g_mpp.f32 + + model_list='model19_check3_awgn model19_check3_mpp 250227b_test_awgn 250227b_test_mpp' + declare -a model_legend=("model19_check3 AWGN Nc=30" "model19_check3 MPP Nc=30" "250227b_test AWGN Nc=10" "250227b_test MPP Nc=10") +fi + +# comparsion of models trained with and without --freq_rand and --auxdata, 0 Hz freq offset +if [ $plot == "250412_inf" ]; then + #run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + #run_model 2504227b_test 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 + #run_model 250411 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + #run_model 250411 40 200 awgn 1 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + #run_model 250411 40 200 awgn 2 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + #run_model 250411 
40 200 awgn 3 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + #run_model 250411b 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 + #run_model 250412 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + + model_list='model19_check3_awgn_0Hz 250227b_test_awgn_0Hz 250411_awgn_0Hz 250411b_awgn_0Hz 250412_awgn_0Hz' + declare -a model_legend=("model19_check3" "250227b_test" "250411 --freq_rand --auxdata" "250411b --freq_rand" "250412 repeat of 250227b") +fi + +# compare RADE V1 to 250411 which can handle +/-2 Hz, although all curves tested here at 0 Hz offset +if [ $plot == "250413_inf" ]; then + run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model model19_check3 80 100 mpp 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + run_model 250411 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250411 40 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --g_file g_mpp.f32 --auxdata + + model_list='model19_check3_awgn_0Hz model19_check3_mpp_0Hz 250411_awgn_0Hz 250411_mpp_0Hz' + declare -a model_legend=("model19_check3 AWGN Nc=30" "model19_check3 MPP Nc=30" "250411 AWGN Nc=10" "250411 MPP Nc=10") +fi + +# 250411 at different freq offsets +if [ $plot == "250413a_inf" ]; then + run_model 250411 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250411 40 200 awgn 2 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250411 40 200 awgn -2 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250411 40 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --g_file g_mpp.f32 --auxdata + run_model 250411 40 200 mpp 2 --cp 0.004 --time_offset -16 --correct_time_offset -32 --g_file g_mpp.f32 --auxdata + run_model 250411 40 200 mpp -2 
--cp 0.004 --time_offset -16 --correct_time_offset -32 --g_file g_mpp.f32 --auxdata + + model_list='250411_awgn_0Hz 250411_awgn_2Hz 250411_awgn_-2Hz 250411_mpp_0Hz 250411_mpp_2Hz 250411_mpp_-2Hz' + declare -a model_legend=("250411 AWGN 0 Hz" "250411 AWGN 2 Hz" "250411 AWGN -2 Hz" "250411 MPP 0 Hz" "250411 MPP 2 Hz" "250411 MPP -2 Hz") +fi + +# Basic AWGN test of frame step 2 models +if [ $plot == "250416_inf" ]; then + run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model 250411 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250413_test 20 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 + run_model 250416_test 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 + + model_list='model19_check3_awgn_0Hz 250411_awgn_0Hz 250413_test_awgn_0Hz 250416_test_awgn_0Hz' + declare -a model_legend=("model19_check3 AWGN fs=4 d=80 Nc=30" "250411 AWGN fs=4 d=40 Nc=10" "250413_test AWGN fs=2 d=20 Nc=10" \ + "250416_test AWGN fs=2 d=40 Nc=20 start=-3dB") +fi + +# compare RADE V1 to 250415, poor MPP results as expected as not trained correctly. 
+if [ $plot == "250416b_inf" ]; then + run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model model19_check3 80 100 mpp 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + run_model 250415_test 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 + run_model 250415_test 40 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 --g_file g_mpp.f32 + + model_list='model19_check3_awgn_0Hz model19_check3_mpp_0Hz 250415_test_awgn_0Hz 250415_test_mpp_0Hz' + declare -a model_legend=("model19_check3 AWGN fs=4 d=80 Nc=30" "model19_check3 MPP fs=4 d=80 Nc=30" "250415 AWGN fs=2 d=40 Nc=20" "250415 MPP fs=2 d=40 Nc=20") +fi + +# compare RADE V1 to framestep 2 250416, MPP results a bit worse than RADE V1 at high SNR. Perhaps this is a result of the lower frame step +if [ $plot == "250416a_inf" ]; then + run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model model19_check3 80 100 mpp 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + run_model 250416_test 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 + run_model 250416a_test 80 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 8 + run_model 250416_test 40 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 --g_file g_mpp.f32 + run_model 250416a_test 80 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 8 --g_file g_mpp.f32 + + model_list='model19_check3_awgn_0Hz model19_check3_mpp_0Hz 250416_test_awgn_0Hz 250416_test_mpp_0Hz 250416a_test_awgn_0Hz 250416a_test_mpp_0Hz' + declare -a 
model_legend=("model19_check3 AWGN fs=4 d=80 Nc=30" "model19_check3 MPP fs=4 d=80 Nc=30" "250416 AWGN fs=2 d=40 Nc=20" \ + "250416 MPP fs=2 d=40 Nc=20" "250416a AWGN fs=8 d=80 Nc=10" "250416a MPP fs=8 d=80 Nc=10" ) +fi + +# compare RADE V1 to framestep 2,d=20,Nc=10 250413 +if [ $plot == "250417_inf" ]; then + run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model model19_check3 80 100 mpp 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + run_model 250413_test 20 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 + run_model 250413_test 20 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --frames_per_step 2 --g_file g_mpp.f32 + + model_list='model19_check3_awgn_0Hz model19_check3_mpp_0Hz 250413_test_awgn_0Hz 250413_test_mpp_0Hz' + declare -a model_legend=("model19_check3 AWGN fs=4 d=80 Nc=30" "model19_check3 MPP fs=4 d=80 Nc=30" \ + "250413 AWGN fs=2 d=20 Nc=10" "250413 MPP fs=2 d=20 Nc=10" ) +fi + +# compare RADE V1 to framestep 4,d=10,Nc=10 250417 (240411 repeat with frame step arg) +if [ $plot == "250417a_inf" ]; then + run_model model19_check3 80 100 awgn 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model model19_check3 80 100 mpp 0 --tanh_clipper --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + run_model 250417_test 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250417_test 40 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --g_file g_mpp.f32 + run_model 250417a_test 40 200 awgn 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata + run_model 250417a_test 40 200 mpp 0 --cp 0.004 --time_offset -16 --correct_time_offset -32 --auxdata --g_file g_mpp.f32 + + model_list='model19_check3_awgn_0Hz 
model19_check3_mpp_0Hz 250417_test_awgn_0Hz 250417_test_mpp_0Hz 250417a_test_awgn_0Hz 250417a_test_mpp_0Hz' + declare -a model_legend=("model19_check3 AWGN fs=4 d=80 Nc=30" "model19_check3 MPP fs=4 d=80 Nc=30" \ + "250417 AWGN fs=4 d=40 Nc=10" "250417 MPP fs=4 d=40 Nc=10" \ + "250417a AWGN fs=4 d=40 Nc=10" "250417a MPP fs=4 d=40 Nc=10") +fi + +# compare RADE V1 candidate2 to RADE V1 candidate3 (prototyped in dr-asr branch). Tracking down training issue +# discovered in HF RADE paper review, mapping of z elements to symbols. +if [ $plot == "250503_inf" ]; then + run_model model19_check3 80 100 awgn 0 --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model model19_check3 80 100 mpp 0 --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + run_model 250502 60 200 awgn 0 --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls + run_model 250502 60 200 mpp 0 --cp 0.004 --time_offset -16 --auxdata --pilots --pilot_eq --eq_ls --g_file g_mpp.f32 + + model_list='model19_check3_awgn_0Hz model19_check3_mpp_0Hz 250502_awgn_0Hz 250502_mpp_0Hz' + declare -a model_legend=("cand2 model19_check3 AWGN d=80 Nc=30" "cand2 model19_check3 MPP d=80 Nc=30" \ + "cand3 250502 AWGN d=60 Nc=30" "cand3 250502 MPP d=60 Nc=30") +fi + +# Generate the plots in PNG and EPS form, file names have suffix of ${plot} +vargs="" +i=0 +for model in $model_list + do + vargs="${vargs},'${model}_loss_SNR3k.txt','${model_legend[i]}'" + ((i++)) + done +echo "radae_plots; loss_SNR3k_plot(psnr=0,'${plot}_loss_SNR3k_models',''${vargs}); quit" | octave-cli -qf # PNG +echo "radae_plots; loss_SNR3k_plot(psnr=1,'${plot}_loss_PNR3k_models',''${vargs}); quit" | octave-cli -qf # PNG +echo "radae_plots; loss_SNR3k_plot(psnr=1,'','${plot}_loss_PNR3k_models'${vargs}); quit" | octave-cli -qf # EPS + diff --git a/doc/FreeDV-032 Radio Autoencoder Waveform Design.ods b/doc/FreeDV-032 Radio Autoencoder Waveform Design.ods index d4c52ed..fdfb047 100644 Binary files 
a/doc/FreeDV-032 Radio Autoencoder Waveform Design.ods and b/doc/FreeDV-032 Radio Autoencoder Waveform Design.ods differ diff --git a/inference.py b/inference.py index f66a18c..62e580f 100644 --- a/inference.py +++ b/inference.py @@ -75,6 +75,7 @@ parser.add_argument('--sine_amp', type=float, default=0.0, help='single freq interferer level (default zero)') parser.add_argument('--sine_freq', type=float, default=1000.0, help='single freq interferer freq (default 1000Hz)') parser.add_argument('--auxdata', action='store_true', help='inject auxillary data symbol') +parser.add_argument('--print_frame', action='store_true', help='print OFDM modem frame and exit') args = parser.parse_args() if len(args.h_file): @@ -101,7 +102,8 @@ phase_offset=args.phase_offset, freq_offset=args.freq_offset, df_dt=args.df_dt, gain=args.gain, pilots=args.pilots, pilot_eq=args.pilot_eq, eq_mean6 = not args.eq_ls, cyclic_prefix = args.cp, time_offset=args.time_offset, coarse_mag=args.coarse_mag, - bottleneck=args.bottleneck, correct_freq_offset=args.correct_freq_offset) + bottleneck=args.bottleneck, correct_freq_offset=args.correct_freq_offset, + print_frame=args.print_frame) checkpoint = torch.load(args.model_name, map_location='cpu',weights_only=True) model.load_state_dict(checkpoint['state_dict'], strict=False) checkpoint['state_dict'] = model.state_dict() diff --git a/multipath_samples.m b/multipath_samples.m index 208f711..6646340 100644 --- a/multipath_samples.m +++ b/multipath_samples.m @@ -61,12 +61,14 @@ function multipath_samples(ch, Fs, Rs, Nc, Nseconds, H_fn, G_fn="") LCR_meas = LC/Nseconds subplot(211); hold on; stem(LC_log,sqrt(P)*ones(length(LC_log))); hold off; axis([0 Nsecplot*Rs 0 3]); end - printf("H file size is Nseconds*Rs*Nc*(4 bytes/sample) = %d*%d*%d*4 = %d bytes\n", Nseconds,Rs,Nc,Nseconds*Rs*Nc*4) - f=fopen(H_fn,"wb"); - [r c] = size(H); - Hflat = reshape(H', 1, r*c); - fwrite(f, Hflat, 'float32'); - fclose(f); + if length(H_fn) + printf("H file size is 
Nseconds*Rs*Nc*(4 bytes/sample) = %d*%d*%d*4 = %d bytes\n", Nseconds,Rs,Nc,Nseconds*Rs*Nc*4) + f=fopen(H_fn,"wb"); + [r c] = size(H); + Hflat = reshape(H', 1, r*c); + fwrite(f, Hflat, 'float32'); + fclose(f); + end if length(G_fn) % G matrix cols are G1 G2, rows timesteps, with hf_gain the first row, diff --git a/radae/radae.py b/radae/radae.py index 09fbe27..7e08bc1 100644 --- a/radae/radae.py +++ b/radae/radae.py @@ -80,7 +80,8 @@ def __init__(self, time_offset = 0, coarse_mag = False, correct_freq_offset = False, - stateful_decoder = False + stateful_decoder = False, + print_frame = False ): super(RADAE, self).__init__() @@ -110,6 +111,7 @@ def __init__(self, self.coarse_mag = coarse_mag self.correct_freq_offset = correct_freq_offset self.stateful_decoder = stateful_decoder + self.print_frame = print_frame # TODO: nn.DataParallel() shouldn't be needed self.core_encoder = nn.DataParallel(radae_base.CoreEncoder(feature_dim, latent_dim, bottleneck=bottleneck)) @@ -134,8 +136,9 @@ def __init__(self, # wide in frequency and Ns symbols in duration bps = 2 # BPSK symbols per QPSK symbol + # TODO: consider a better way to set this up, e.g. 
so we can handle candidate2/3 without code changes (if useful in future) if self.pilots: - Ts = 0.03 # OFDM QPSK symbol period (without pilots or CP) + Ts = 0.03 # OFDM QPSK symbol period else: Ts = 0.02 Rs = 1/Ts # OFDM QPSK symbol rate @@ -218,7 +221,7 @@ def __init__(self, eoo = torch.tanh(torch.abs(eoo)) * torch.exp(1j*torch.angle(eoo)) self.eoo = eoo - print(f"Rs: {Rs:5.2f} Rs': {Rs_dash:5.2f} Ts': {Ts_dash:5.3f} Nsmf: {Nsmf:3d} Ns: {Ns:3d} Nc: {Nc:3d} M: {self.M:d} Ncp: {self.Ncp:d}", file=sys.stderr) + print(f"d: {latent_dim:3d} Rs: {Rs:5.2f} Rs': {Rs_dash:5.2f} Ts': {Ts_dash:5.3f} Nsmf: {Nsmf:3d} Ns: {Ns:3d} Nc: {Nc:3d} M: {self.M:d} Ncp: {self.Ncp:d}", file=sys.stderr) self.Tmf = Tmf self.bps = bps @@ -455,8 +458,12 @@ def forward(self, features, H, G=None): # run encoder, outputs sequence of latents that each describe 40ms of speech z = self.core_encoder(features) + if self.ber_test: z = torch.sign(torch.rand_like(z)-0.5) + if self.print_frame: + # replace z with element indexes + z[:,:,:] = torch.arange(0,self.latent_dim, dtype=torch.float32) # map z to QPSK symbols, note Es = var(tx_sym) = 2 var(z) = 2 # assuming |z| ~ 1 after training @@ -469,7 +476,7 @@ def forward(self, features, H, G=None): # reshape into sequence of OFDM modem frames tx_sym = torch.reshape(tx_sym,(num_batches,num_timesteps_at_rate_Rs,self.Nc)) - + # optionally insert pilot symbols, at the start of each modem frame if self.pilots: num_modem_frames = num_timesteps_at_rate_Rs // self.Ns @@ -480,6 +487,18 @@ def forward(self, features, H, G=None): num_timesteps_at_rate_Rs = num_timesteps_at_rate_Rs + num_modem_frames tx_sym = torch.reshape(tx_sym_pilots,(num_batches, num_timesteps_at_rate_Rs, self.Nc)) + # optionally print modem frame + if self.print_frame: + Ns = self.Ns + if self.pilots: + Ns += 1 + print(file=sys.stderr) + for c in range(self.Nc): + for t in range(Ns): + print(f"{tx_sym[0,t,c]:5.0f}\t", end='', file=sys.stderr) + print(file=sys.stderr) + quit() + tx_before_channel = 
None
         rx = None
         if self.rate_Fs:
diff --git a/radae/radae_base.py b/radae/radae_base.py
index 95d8790..bab6236 100644
--- a/radae/radae_base.py
+++ b/radae/radae_base.py
@@ -185,7 +185,7 @@ def __init__(self, feature_dim, output_dim, bottleneck = 1):
         self.z_dense = nn.Linear(864, self.output_dim)
 
         nb_params = sum(p.numel() for p in self.parameters())
-        print(f"encoder: {nb_params} weights", file=sys.stderr)
+        #print(f"encoder: {nb_params} weights", file=sys.stderr)
 
         # initialize weights
         self.apply(init_weights)
@@ -251,7 +251,7 @@ def __init__(self, feature_dim, output_dim, bottleneck = 1):
         self.z_dense = nn.Linear(864, self.output_dim)
 
         nb_params = sum(p.numel() for p in self.parameters())
-        print(f"encoder: {nb_params} weights", file=sys.stderr)
+        #print(f"encoder: {nb_params} weights", file=sys.stderr)
 
         # initialize weights
         self.apply(init_weights)
@@ -326,7 +326,7 @@ def __init__(self, input_dim, output_dim):
         self.glu5 = GLU(96)
 
         nb_params = sum(p.numel() for p in self.parameters())
-        print(f"decoder: {nb_params} weights", file=sys.stderr)
+        #print(f"decoder: {nb_params} weights", file=sys.stderr)
 
         # initialize weights
         self.apply(init_weights)
@@ -393,7 +393,7 @@ def __init__(self, input_dim, output_dim):
         self.glu5 = GLU(96)
 
         nb_params = sum(p.numel() for p in self.parameters())
-        print(f"decoder: {nb_params} weights", file=sys.stderr)
+        #print(f"decoder: {nb_params} weights", file=sys.stderr)
 
         # initialize weights
         self.apply(init_weights)
diff --git a/radae_plots.m b/radae_plots.m
index 3925862..84e446d 100644
--- a/radae_plots.m
+++ b/radae_plots.m
@@ -176,6 +176,45 @@ function loss_CNo_plot(png_fn, Rs, B, varargin)
     end
 endfunction
 
+% Plots loss v SNR3k curves from text files dumped by inference.py, see compare_models_inf.py
+%   pnsr      non-zero: add column 3 of each file (PAPR per the original note) to SNR3k, x axis is PNR
+%   png_fn    PNG file name passed to print(), "" to skip; epslatex: name for print_eps_restore(), "" to skip
+%   varargin  pairs of (data file name, legend text); any "_" in the legend text is shown as a space
+%   Data file columns used: 1 = SNR3k (dB), 2 = loss, 3 = offset added when pnsr is set; axis limits are fixed
+function loss_SNR3k_plot(pnsr=0,png_fn, epslatex, varargin)
+    if length(epslatex)
+        [textfontsize linewidth] = set_fonts(20);
+    end
+    figure(1); clf; hold on;
+    i = 1;
+    while i <= length(varargin)
+        fn = varargin{i};
+        data = load(fn);
+        i++; leg = varargin{i}; leg = strrep (leg, "_", " ");
+        SNR3k = data(:,1);
+        if pnsr
+            SNR3k += data(:,3);
+        end
+        plot(SNR3k,data(:,2),sprintf("+-;%s;",leg))
+        i++;
+    end
+    hold off; grid('minor');
+    if pnsr
+        xlabel('PNR (dB)');
+    else
+        xlabel('SNR (dB)');
+    end
+    ylabel('loss');
+    axis([-5 20 0.05 0.35])
+    legend('boxoff');
+    if length(png_fn)
+        print("-dpng",png_fn);
+    end
+    if length(epslatex)
+        print_eps_restore(epslatex,"-S300,300",textfontsize,linewidth);
+    end
+endfunction
+
 % usage:
 % radae_plots; ofdm_sync_plots("","ofdm_sync.txt","go-;genie;","ofdm_sync_pilot_eq.txt","r+-;mean6;","ofdm_sync_pilot_eq_f2.txt","bx-;mean6 2 Hz;","ofdm_sync_pilot_eq_g0.1.txt","gx-;mean6 gain 0.1;","ofdm_sync_pilot_eq_ls.txt","ro-;LS;","ofdm_sync_pilot_eq_ls_f2.txt","bo-;LS 2 Hz;")
 
@@ -440,6 +479,7 @@ function plot_SNR_CNR(epslatex="")
     plot(theta, phi, "r-;phi;")
     hold off
 endfunction
+
 function compare_pitch_corr(wav_fn,feat1_fn,feat2_fn,png_feat_fn="")
     Fs=16000;
     s=load_raw(wav_fn);
@@ -483,3 +523,66 @@ function plot_sample_spec(wav_fn,png_spec_fn="")
         print("-dpng",png_spec_fn,"-S800,600");
     end
 end
+
+% Plots WER v C/No (figure 1) and WER v SNR3k (figure 2) for SSB, RADE and 700D on AWGN and MPP channels.
+% Loads <prefix_fn>_asr_{awgn,mpp}_{ssb,rade,700D}.txt (columns: 1 SNR3k, 2 C/No, 3 WER) and <prefix_fn>_asr_c.txt
+function plot_wer(prefix_fn, png_fn="", epslatex="")
+    ssb_awgn_fn = sprintf("%s_asr_awgn_ssb.txt",prefix_fn);
+    rade_awgn_fn = sprintf("%s_asr_awgn_rade.txt",prefix_fn);
+    freedv_700D_awgn_fn = sprintf("%s_asr_awgn_700D.txt",prefix_fn);
+    ssb_mpp_fn = sprintf("%s_asr_mpp_ssb.txt",prefix_fn);
+    rade_mpp_fn = sprintf("%s_asr_mpp_rade.txt",prefix_fn);
+    freedv_700D_mpp_fn = sprintf("%s_asr_mpp_700D.txt",prefix_fn);
+    controls_fn = sprintf("%s_asr_c.txt",prefix_fn);
+
+    ssb_awgn = load(ssb_awgn_fn);
+    rade_awgn = load(rade_awgn_fn);
+    freedv_700D_awgn = load(freedv_700D_awgn_fn);
+    ssb_mpp = load(ssb_mpp_fn);
+    rade_mpp = load(rade_mpp_fn);
+    freedv_700D_mpp = load(freedv_700D_mpp_fn);
+    c = load(controls_fn);    % control WERs: c(1) clean, c(2) FARGAN, c(3) 4kHz
+
+    if length(epslatex)
+        [textfontsize linewidth] = set_fonts(20);
+    end
+
+    % WER v C/No plot; colours match figure 2 (SSB blue, RADE red, 700D green)
+    figure(1); clf;
+    plot(ssb_awgn(:,2),ssb_awgn(:,3),'b+-;SSB AWGN;');
+    hold on;
+    plot(rade_awgn(:,2),rade_awgn(:,3),'r+-;RADE AWGN;');
+    plot(freedv_700D_awgn(:,2),freedv_700D_awgn(:,3),'g+-;700D AWGN;');
+    plot(ssb_mpp(:,2),ssb_mpp(:,3),'bo--;SSB MPP;');
+    plot(rade_mpp(:,2),rade_mpp(:,3),'ro--;RADE MPP;');
+    plot(freedv_700D_mpp(:,2),freedv_700D_mpp(:,3),'go--;700D MPP;');
+    xmin=30; xmax=60;
+    plot(xmax-5,c(1),'cx;clean;')
+    plot(xmax-5,c(2),'mo;FARGAN;')
+    plot(xmax-5,c(3),'k+;4kHz;')
+    hold off;
+    axis([xmin,xmax,0,40]); grid; ylabel('WER (\%)'); xlabel("C/No (dB)");
+
+    % WER v SNR plot; controls are drawn as horizontal lines across the SNR range
+    figure(2); clf;
+    plot(ssb_awgn(:,1),ssb_awgn(:,3),'b+-;SSB AWGN;');
+    hold on;
+    plot(rade_awgn(:,1),rade_awgn(:,3),'r+-;RADE AWGN;');
+    plot(freedv_700D_awgn(:,1),freedv_700D_awgn(:,3),'g+-;700D AWGN;');
+    plot(ssb_mpp(:,1),ssb_mpp(:,3),'bo--;SSB MPP;');
+    plot(rade_mpp(:,1),rade_mpp(:,3),'ro--;RADE MPP;');
+    plot(freedv_700D_mpp(:,1),freedv_700D_mpp(:,3),'go--;700D MPP;');
+    xmin=-5; xmax=20;
+    plot([xmin xmax],[c(2) c(2)],'m-;FARGAN;')
+    plot([xmin xmax],[c(1) c(1)],'c-;clean;')
+    hold off;
+    axis([xmin,xmax,0,40]); grid; ylabel('WER (\%)'); xlabel("SNR3k (dB)");
+    legend('boxoff'); legend("left");
+
+    if length(png_fn)
+        print("-dpng",png_fn,"-S800,600");    % no figure handle given: prints the current figure (figure 2)
+    end
+    if length(epslatex)
+        print_eps_restore(epslatex,"-S250,250",textfontsize,linewidth);
+    end
+end
diff --git a/train.py b/train.py
index 2c934a9..1b24ad3 100644
--- a/train.py
+++ b/train.py
@@ -71,6 +71,7 @@
 training_group.add_argument('--plot_loss', action='store_true', help='plot loss versus epoch as we train')
 training_group.add_argument('--plot_EqNo', type=str, default="", help='plot loss versus Eq/No for final epoch')
 training_group.add_argument('--auxdata', action='store_true', help='inject auxillary data symbol')
+training_group.add_argument('--print_frame', action='store_true', help='print OFDM modem frame and exit')
 
 args = parser.parse_args()
 
@@ -123,7 +124,8 @@
               rate_Fs = args.rate_Fs,
               range_EbNo_start=args.range_EbNo_start,
               freq_rand=args.freq_rand,gain_rand=args.gain_rand, bottleneck=args.bottleneck,
-              pilots=args.pilots, pilot_eq=args.pilot_eq, eq_mean6 = not args.eq_ls, cyclic_prefix = args.cp)
+              pilots=args.pilots, pilot_eq=args.pilot_eq, eq_mean6 = not args.eq_ls,
+              cyclic_prefix = args.cp, print_frame=args.print_frame)
 
 if type(args.initial_checkpoint) != type(None):
     print(f"Loading from checkpoint: {args.initial_checkpoint}")