From ad6274f8814911dabc4fcc2d3d8256bfab69823b Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 27 Feb 2026 16:16:32 +0100 Subject: [PATCH 01/33] parakeet : add support for NVIDIA Parakeet --- CMakeLists.txt | 37 + bindings/ruby/ext/extconf.rb | 2 +- cmake/parakeet-config.cmake.in | 30 + cmake/parakeet.pc.in | 10 + examples/CMakeLists.txt | 1 + examples/parakeet-cli/CMakeLists.txt | 8 + examples/parakeet-cli/README.md | 112 + examples/parakeet-cli/parakeet-cli.cpp | 220 ++ include/parakeet.h | 383 +++ models/convert-parakeet-to-ggml.py | 331 ++ models/requirements-parakeet.txt | 1 + scripts/upload-parakeet.py | 75 + src/CMakeLists.txt | 23 + src/parakeet-arch.h | 194 ++ src/parakeet.cpp | 4315 ++++++++++++++++++++++++ tests/CMakeLists.txt | 29 + tests/test-parakeet-full.cpp | 62 + tests/test-parakeet-stream.cpp | 107 + tests/test-parakeet.cpp | 99 + 19 files changed, 6038 insertions(+), 1 deletion(-) create mode 100644 cmake/parakeet-config.cmake.in create mode 100644 cmake/parakeet.pc.in create mode 100644 examples/parakeet-cli/CMakeLists.txt create mode 100644 examples/parakeet-cli/README.md create mode 100644 examples/parakeet-cli/parakeet-cli.cpp create mode 100644 include/parakeet.h create mode 100755 models/convert-parakeet-to-ggml.py create mode 100644 models/requirements-parakeet.txt create mode 100644 scripts/upload-parakeet.py create mode 100644 src/parakeet-arch.h create mode 100644 src/parakeet.cpp create mode 100644 tests/test-parakeet-full.cpp create mode 100644 tests/test-parakeet-stream.cpp create mode 100644 tests/test-parakeet.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a0f74041321..7aa4728ba81 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -179,12 +179,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS) set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h) + install(TARGETS 
whisper LIBRARY PUBLIC_HEADER) target_compile_definitions(whisper PRIVATE WHISPER_VERSION="${PROJECT_VERSION}" ) +set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h) +install(TARGETS parakeet LIBRARY PUBLIC_HEADER) + +target_compile_definitions(parakeet PRIVATE + PARAKEET_VERSION="${PROJECT_VERSION}" +) + configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake @@ -210,6 +218,35 @@ configure_file(cmake/whisper.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" DESTINATION lib/pkgconfig) +set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet + PATH_VARS + PARAKEET_INCLUDE_INSTALL_DIR + PARAKEET_LIB_INSTALL_DIR + PARAKEET_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet) + +configure_file(cmake/parakeet.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + DESTINATION lib/pkgconfig) + # # programs, examples and tests # diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index acff501aa3b..ab740ca4c29 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -18,6 +18,6 @@ #{libs}: 
cmake-targets cmake-targets: #{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} - #{"\t"}#{cmake} --build build --config Release --target common whisper + #{"\t"}#{cmake} --build build --config Release --target common whisper parakeet EOF end diff --git a/cmake/parakeet-config.cmake.in b/cmake/parakeet-config.cmake.in new file mode 100644 index 00000000000..aadb55c2d19 --- /dev/null +++ b/cmake/parakeet-config.cmake.in @@ -0,0 +1,30 @@ +set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@) +set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@") +set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@") +set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake) + +find_library(parakeet_LIBRARY parakeet + REQUIRED + HINTS ${PARAKEET_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(parakeet UNKNOWN IMPORTED) +set_target_properties(parakeet + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${parakeet_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(parakeet) diff --git a/cmake/parakeet.pc.in b/cmake/parakeet.pc.in new file mode 100644 index 00000000000..a83cc472547 --- /dev/null +++ b/cmake/parakeet.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/lib +includedir=${prefix}/include + +Name: parakeet +Description: Port of NVIDIA's Parakeet model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lparakeet +Cflags: 
-I${includedir} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b202ca00b77..2cf5a67aff3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -107,6 +107,7 @@ else() add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) + add_subdirectory(parakeet-cli) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/parakeet-cli/CMakeLists.txt b/examples/parakeet-cli/CMakeLists.txt new file mode 100644 index 00000000000..adb9aba38ef --- /dev/null +++ b/examples/parakeet-cli/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parakeet-cli) +add_executable(${TARGET} parakeet-cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md new file mode 100644 index 00000000000..25a3f1e49c1 --- /dev/null +++ b/examples/parakeet-cli/README.md @@ -0,0 +1,112 @@ +# whisper.cpp/examples/parakeet-cli + +This is an example of using the [Parakeet] model in whisper.cpp. + +### Download converted model +```console +$ hf download danbev/parakeet parakeet-tdt-0.6b-v3.bin --local-dir models +``` + +### Building +```console +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +### Usage +```console +$ ./build/bin/parakeet-cli --help + +usage: ./build/bin/parakeet-cli [options] file0 file1 ... 
+supported audio formats: flac, mp3, ogg, wav + +options: + -h, --help [default] show this help message and exit + -t N, --threads N [4 ] number of threads to use during computation + -cl N, --chunk-length N [10000 ] chunk length in milliseconds + -lc N, --left-context N [10000 ] left context in milliseconds + -rc N, --right-context N [4960 ] right context in milliseconds + -m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path + -f, --file FILE [ ] input audio file + -ng, --no-gpu [false ] disable GPU + -dev N, --device N [0 ] GPU device to use + -fa, --flash-attn [true ] enable flash attention + -nfa, --no-flash-attn [false ] disable flash attention + -ps, --print-segments [false ] print segment information +``` + +### Example +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3.bin -f samples/jfk.wav +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. +``` + +To print segment information: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3.bin -f samples/jfk.wav --print-segments +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + +Segments (1): +Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country." 
+Tokens [38]: + [ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And" + [ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so" + [ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false "," + [ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my" + [ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f" + [ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell" + [ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow" + [ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer" + [ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic" + [ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans" + [10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false "," + [11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a" + [12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk" + [13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not" + [14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what" + [15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your" + [16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co" + [17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un" + [18] id= 958 frame= 78 dur_idx= 2 
dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr" + [19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y" + [20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can" + [21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do" + [22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for" + [23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you" + [24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false "," + [25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a" + [26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk" + [27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what" + [28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you" + [29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can" + [30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do" + [31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for" + [32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your" + [33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co" + [34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un" + [35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr" + [36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 
t1=1056 word_start=false "y" + [37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "." +``` + +### Model conversion +Clone the original model from Hugging Face: +```console +$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 +``` +Convert the model: +```console +(venv) $ python models/convert-parakeet-to-ggml.py \ + --model \ + --use-f32 \ + --out-dir models \ + --out-name ggml-parakeet-tdt-0.6b-v3.bin +``` + +[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp new file mode 100644 index 00000000000..bdb1739ce5e --- /dev/null +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -0,0 +1,220 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include +#include +#include +#include + +// command-line parameters +struct parakeet_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t chunk_length_ms = 10000; + int32_t left_context_ms = 10000; + int32_t right_context_ms = 4960; + + bool use_gpu = true; + bool flash_attn = true; + int32_t gpu_device = 0; + + bool print_segments = false; + + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::vector fname_inp = {}; +}; + +static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params); + +static char * requires_value_error(const std::string & arg) { + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); + exit(1); +} + +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) { + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-"){ + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + + if (arg == 
"-h" || arg == "--help") { + parakeet_print_usage(argc, argv, params); + exit(0); + } + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-cl" || arg == "--chunk-length") { params.chunk_length_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-lc" || arg == "--left-context") { params.left_context_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-rc" || arg == "--right-context") { params.right_context_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + parakeet_print_usage(argc, argv, params); + exit(1); + } + } + + return true; +} + +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]); + fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -cl N, --chunk-length N [%-7d] chunk length in milliseconds\n", params.chunk_length_ms); + fprintf(stderr, " -lc N, --left-context N [%-7d] left context in 
milliseconds\n", params.left_context_ms); + fprintf(stderr, " -rc N, --right-context N [%-7d] right context in milliseconds\n", params.right_context_ms); + fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", !params.flash_attn ? "true" : "false"); + fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); + fprintf(stderr, "\n"); +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + parakeet_params params; + + if (parakeet_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + parakeet_print_usage(argc, argv, params); + return 1; + } + + // Process each input file + for (const auto & fname : params.fname_inp) { + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + + std::vector pcmf32; + std::vector> pcmf32s; + if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str()); + continue; + } + + if (pcmf32.empty()) { + fprintf(stderr, "error: no audio 
data in file '%s'\n", fname.c_str()); + continue; + } + + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + ctx_params.use_gpu = params.use_gpu; + ctx_params.flash_attn = params.flash_attn; + ctx_params.gpu_device = params.gpu_device; + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); + return 1; + } + + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n", + pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE); + + struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + full_params.n_threads = params.n_threads; + full_params.chunk_length_ms = params.chunk_length_ms; + full_params.left_context_ms = params.left_context_ms; + full_params.right_context_ms = params.right_context_ms; + full_params.new_token_callback = token_callback; + full_params.new_token_callback_user_data = nullptr; + + const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); + if (mel_frames <= parakeet_n_audio_ctx(pctx)) { + full_params.chunk_length_ms = 0; + } + + int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); + + if (ret != 0) { + fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); + parakeet_free(pctx); + continue; + } + + printf("\n"); + + if (params.print_segments) { + const int n_segments = parakeet_full_n_segments(pctx); + fprintf(stderr, "\nSegments (%d):\n", n_segments); + + for (int i = 0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text(pctx, i); + const int64_t t0 = parakeet_full_get_segment_t0(pctx, i); + const int64_t t1 = parakeet_full_get_segment_t1(pctx, i); + const int n_tokens = 
parakeet_full_n_tokens(pctx, i); + + fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + fprintf(stderr, "Tokens [%d]:\n", n_tokens); + + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j); + const char * token_str = parakeet_token_to_str(pctx, token_data.id); + + fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start ? "true": "false", + token_str); + } + } + } + + parakeet_free(pctx); + } + + return 0; +} diff --git a/include/parakeet.h b/include/parakeet.h new file mode 100644 index 00000000000..34179f7df18 --- /dev/null +++ b/include/parakeet.h @@ -0,0 +1,383 @@ +#ifndef PARAKEET_H +#define PARAKEET_H + +#include "ggml.h" +#include "ggml-cpu.h" + +#include +#include +#include + +#ifdef __GNUC__ +# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define PARAKEET_DEPRECATED(func, hint) func +#endif + +#ifdef PARAKEET_SHARED +# ifdef _WIN32 +# ifdef PARAKEET_BUILD +# define PARAKEET_API __declspec(dllexport) +# else +# define PARAKEET_API __declspec(dllimport) +# endif +# else +# define PARAKEET_API __attribute__ ((visibility ("default"))) +# endif +#else +# define PARAKEET_API +#endif + +#define PARAKEET_SAMPLE_RATE 16000 +#define PARAKEET_HOP_LENGTH 160 + +#ifdef __cplusplus +extern "C" { +#endif + + struct parakeet_context; + struct parakeet_state; + struct parakeet_full_params; + + typedef int32_t parakeet_pos; + typedef int32_t parakeet_token; + typedef int32_t parakeet_seq_id; + + struct parakeet_context_params { + bool use_gpu; 
+ bool flash_attn; + int gpu_device; // CUDA device + }; + + typedef struct parakeet_token_data { + parakeet_token id; // the BPE subword ID (0-8191) + + int duration_idx; // index into the models durations array + int duration_value; // actual duration value + int frame_index; + + float p; + float plog; + + int64_t t0; + int64_t t1; + + bool is_word_start; + } parakeet_token_data; + + typedef struct parakeet_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } parakeet_model_loader; + + PARAKEET_API const char * parakeet_version(void); + + // Various functions for loading a ggml parakeet model. + // Allocate (almost) all memory needed for the model. + // Return NULL on failure + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523) + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + PARAKEET_API struct parakeet_state * parakeet_init_state(struct 
parakeet_context * ctx); + + // Frees all allocated memory + PARAKEET_API void parakeet_free (struct parakeet_context * ctx); + PARAKEET_API void parakeet_free_state(struct parakeet_state * state); + PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params); + PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided parakeet context. + // Returns 0 on success + PARAKEET_API int parakeet_pcm_to_mel( + struct parakeet_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + PARAKEET_API int parakeet_pcm_to_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. + // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 80 + // Returns 0 on success + PARAKEET_API int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel); + + PARAKEET_API int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context. + // Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. 
+ // Returns 0 on success + PARAKEET_API int parakeet_encode( + struct parakeet_context * ctx, + int offset, + int n_threads); + + PARAKEET_API int parakeet_encode_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + int offset, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + PARAKEET_API int parakeet_tokenize( + struct parakeet_context * ctx, + const char * text, + parakeet_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0) + int parakeet_token_count(struct parakeet_context * ctx, const char * text); + + PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length + PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length + PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx); + + PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx); + + // Token logits obtained from the last call to parakeet_full/parakeet_chunk + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + PARAKEET_API float * 
parakeet_get_logits (struct parakeet_context * ctx); + PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state); + + // Token Id -> String. Uses the vocabulary in the provided context + PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token); + + PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); + + // Special tokens + PARAKEET_API parakeet_token parakeet_token_blank (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos(struct parakeet_context * ctx); + + // Performance information from the default state. + struct parakeet_timings { + float sample_ms; + float encode_ms; + float decode_ms; + float batchd_ms; + float prompt_ms; + }; + PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx); + + // Print system information + PARAKEET_API const char * parakeet_print_system_info(void); + + // Available sampling strategies + enum parakeet_sampling_strategy { + PARAKEET_SAMPLING_GREEDY, + }; + + // Token callback. + // Called for each new predicted token. 
+ // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_token_callback)( + struct parakeet_context * ctx, + struct parakeet_state * state, + const parakeet_token_data * token_data, + void * user_data); + + // Text segment callback + // Called on every newly generated text segment + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data); + + // Parameters for the parakeet_full() function + // If you change the order or add new parameters, make sure to update the default values in parakeet.cpp: + // parakeet_full_default_params() + struct parakeet_full_params { + enum parakeet_sampling_strategy strategy; + + int n_threads; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool no_context; // do not use past transcription (if any) as context + + // [EXPERIMENTAL] speed-up techniques + int audio_ctx; // overwrite the audio context size (0 = use default) + + int chunk_length_ms; // length of each chunk in ms + int left_context_ms; // left context in ms + int right_context_ms; // right context in ms + + // called for every newly generated text segment + parakeet_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called for every newly generated token + parakeet_new_token_callback new_token_callback; + void * new_token_callback_user_data; + + // called on each progress 
update + parakeet_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + parakeet_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called each time before ggml computation starts + ggml_abort_callback abort_callback; + void * abort_callback_user_data; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params() + PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void); + PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void); + + PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy); + PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + PARAKEET_API int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + PARAKEET_API int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using parakeet_full_with_state() + // Result is stored in the default state of the context + // Not thread safe if executed in parallel on the same context. + PARAKEET_API int parakeet_full_parallel( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples, + int n_processors); + + // Process a single chunk of audio data that fits within the model's audio context window. 
+ // This is more efficient than parakeet_full() for short audio clips. + PARAKEET_API int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Initialize streaming state for a new stream. + PARAKEET_API int parakeet_stream_init( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params); + + // Push audio samples in streaming mode. Internally this function will structure + the samples in a buffer with a left context, a center chunk, and a + right context. The encoder will see the complete buffer which enables it + to get boundary context for the target/center audio chunk. This avoids hard + cut-offs at the chunk boundaries. The joint network then only sees the + center chunk and this function internally handles the context windowing. + PARAKEET_API int parakeet_stream_push( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples); + + // Flush the final partial chunk at end-of-stream. 
+ PARAKEET_API int parakeet_stream_flush( + struct parakeet_context * ctx, + struct parakeet_state * state); + + // Number of generated text segments + PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx); + PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state); + + // Get the start and end time of the specified segment + PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment); + + PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment); + + // Get the text of the specified segment + PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment); + PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment); + + // Get number of tokens in the specified segment + PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token); + + // Get the token id of the specified token in the specified segment + PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get token data 
for the specified token in the specified segment + PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Control logging output; default behavior is to print to stderr + + PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py new file mode 100755 index 00000000000..1cceba3ac92 --- /dev/null +++ b/models/convert-parakeet-to-ggml.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# Convert Parakeet TDT model from NeMo format to ggml format +# +# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32] +# +# The NeMo file is a tar archive containing: +# - model_weights.ckpt (PyTorch checkpoint) +# - model_config.yaml (model configuration) +# - tokenizer files +# +# This script extracts the NeMo archive, loads the model weights and configuration, +# and saves them in ggml format compatible with whisper.cpp. 
+# + +import torch +import argparse +import io +import os +import sys +import struct +import tarfile +import tempfile +import shutil +import yaml +import numpy as np +from pathlib import Path +from typing import Optional + +def hz_to_mel(freq): + return 2595.0 * np.log10(1.0 + freq / 700.0) + +def mel_to_hz(mel): + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + +def create_relative_positional_encoding(d_model: int, n_pos_max_len: int) -> np.ndarray: + max_len = n_pos_max_len * 2 - 1 + log_10000 = np.log(10000.0) + pe = np.zeros((max_len, d_model), dtype=np.float32) + + for idx in range(max_len): + position = float((max_len // 2) - idx) + + for i in range(0, d_model, 2): + div_term = np.exp(-float(i) * log_10000 / float(d_model)) + angle = position * div_term + + pe[idx, i] = np.sin(angle) + pe[idx, i + 1] = np.cos(angle) + + return pe + +def extract_nemo_archive(nemo_path, extract_dir): + print(f"Extracting {nemo_path} to {extract_dir}") + with tarfile.open(nemo_path, 'r') as tar: + tar.extractall(path=extract_dir) + print("Extraction complete") + +def load_model_config(config_path): + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + return config + +def load_tokenizer(extract_dir, config): + tokenizer_model_path = None + tokenizer_vocab_path = None + + for file in os.listdir(extract_dir): + if file.endswith('_tokenizer.model'): + tokenizer_model_path = os.path.join(extract_dir, file) + elif file.endswith('tokenizer.vocab'): + tokenizer_vocab_path = os.path.join(extract_dir, file) + + if not tokenizer_model_path: + raise FileNotFoundError("Tokenizer model file not found") + + if not tokenizer_vocab_path: + raise FileNotFoundError("Tokenizer vocab file not found") + + tokens = {} + with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + parts = line.strip().split('\t') + if len(parts) >= 1: + token = parts[0] + tokens[token.encode('utf-8')] = idx + + print(f"Loaded {len(tokens)} tokens from 
{os.path.basename(tokenizer_vocab_path)}") + + if len(tokens) != 8192: + print(f"WARNING: Expected 8192 tokens, got {len(tokens)}") + + return tokens + +def write_tensor(fout, name, data, use_f16=True, force_f32=False): + if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1: + data = data.reshape(1, -1, 1, 1) + print(f" Reshaped conv bias {name} to {data.shape}") + + n_dims = len(data.shape) + + ftype = 1 if use_f16 and not force_f32 else 0 + if force_f32: + data = data.astype(np.float32) + elif use_f16: + if n_dims < 2 or 'bias' in name or 'norm' in name: + data = data.astype(np.float32) + ftype = 0 + else: + data = data.astype(np.float16) + else: + data = data.astype(np.float32) + + dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)] + print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}") + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + + data.tofile(fout) + +def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None): + nemo_path = Path(nemo_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary directory for extraction + with tempfile.TemporaryDirectory() as temp_dir: + extract_nemo_archive(nemo_path, temp_dir) + + config_path = os.path.join(temp_dir, 'model_config.yaml') + config = load_model_config(config_path) + + print("Model configuration:") + print(f" Sample rate: {config['sample_rate']}") + print(f" Encoder layers: {config['encoder']['n_layers']}") + print(f" Encoder d_model: {config['encoder']['d_model']}") + print(f" Mel features: {config['preprocessor']['features']}") + + weights_path = os.path.join(temp_dir, 'model_weights.ckpt') + print(f"\nLoading model weights from {weights_path}") + checkpoint = 
torch.load(weights_path, map_location='cpu') + + # Extract state dict + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + print(f"Loaded {len(state_dict)} tensors") + + # Load tokenizer + print("\nLoading tokenizer...") + tokens = load_tokenizer(temp_dir, config) + print(f"Loaded {len(tokens)} tokens") + + # Prepare hyperparameters for the Parakeet ggml format. + hparams = { + 'n_audio_ctx': 5000, + 'n_audio_state': config['encoder']['d_model'], + 'n_audio_head': config['encoder']['n_heads'], + 'n_audio_layer': config['encoder']['n_layers'], + 'n_mels': config['preprocessor']['features'], + 'n_fft': config['preprocessor']['n_fft'], + 'subsampling_factor': config['encoder']['subsampling_factor'], + 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], + 'n_pos_max_len': config['encoder']['pos_emb_max_len'], + + 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], + 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], + 'n_vocab': config['decoder']['vocab_size'], + 'n_tdt_durations': config['model_defaults']['num_tdt_durations'], + 'n_max_tokens': config['decoding']['greedy']['max_symbols'], + } + + print("\nGGML hyperparameters:") + for key, value in hparams.items(): + print(f" {key}: {value}") + + pe = create_relative_positional_encoding(hparams['n_audio_state'], hparams['n_pos_max_len']) + print(f"\nGenerated positional encoding tensor 'encoder.pe' with shape {pe.shape}") + + # Create output file + if out_name: + fname_out = output_dir / out_name + else: + fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin") + print(f"\nWriting to {fname_out}") + + with open(fname_out, 'wb') as fout: + # Write magic number + fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex + + # Write hyperparameters + fout.write(struct.pack("i", hparams['n_vocab'])) + fout.write(struct.pack("i", hparams['n_audio_ctx'])) + fout.write(struct.pack("i", 
hparams['n_audio_state'])) + fout.write(struct.pack("i", hparams['n_audio_head'])) + fout.write(struct.pack("i", hparams['n_audio_layer'])) + fout.write(struct.pack("i", hparams['n_mels'])) + fout.write(struct.pack("i", 1 if use_f16 else 0)) + fout.write(struct.pack("i", hparams['n_fft'])) + fout.write(struct.pack("i", hparams['subsampling_factor'])) + fout.write(struct.pack("i", hparams['n_subsampling_channels'])) + fout.write(struct.pack("i", hparams['n_pos_max_len'])) + fout.write(struct.pack("i", hparams['n_pred_dim'])) + fout.write(struct.pack("i", hparams['n_pred_layers'])) + fout.write(struct.pack("i", hparams['n_tdt_durations'])) + fout.write(struct.pack("i", hparams['n_max_tokens'])) + + # Extract mel filterbank from model + fb_key = None + for key in state_dict.keys(): + if 'featurizer.fb' in key or 'filterbank' in key.lower(): + fb_key = key + break + + if not fb_key: + print("\nERROR: Mel filterbank not found in model!") + print("Expected tensor with 'featurizer.fb' or 'filterbank' in name") + print("\nAvailable preprocessor tensors:") + for key in sorted(state_dict.keys()): + if 'preprocessor' in key or 'featurizer' in key: + print(f" {key}: {state_dict[key].shape}") + raise ValueError("Mel filterbank tensor not found in model") + + print(f"\nUsing model's mel filterbank from: {fb_key}") + mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32) + print(f" Filterbank shape: {mel_filters.shape}") + print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}") + print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}") + print(f" First row sum: {mel_filters[0].sum():.6f}") + + if len(mel_filters.shape) != 2: + raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}") + + n_mels, n_freqs = mel_filters.shape + fout.write(struct.pack("i", n_mels)) # n_mel + fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins) + + # Write mel filterbank + for i in 
range(n_mels): + for j in range(n_freqs): + fout.write(struct.pack("f", mel_filters[i, j])) + + # Extract window function from model + window_key = None + for key in state_dict.keys(): + if 'featurizer.window' in key or 'preproc' in key and 'window' in key: + window_key = key + break + + if not window_key: + print("\nERROR: Window function not found in model!") + print("Expected tensor with 'featurizer.window' in name") + raise ValueError("Window function tensor not found in model") + + print(f"\nUsing model's window function from: {window_key}") + window = state_dict[window_key].squeeze().numpy().astype(np.float32) + print(f" Window shape: {window.shape}") + print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}") + print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}") + print(f" Window sum: {window.sum():.6f}") + + if len(window.shape) != 1: + raise ValueError(f"Expected 1D window, got shape {window.shape}") + + n_window = window.shape[0] + fout.write(struct.pack("i", n_window)) + + # Write window function + for i in range(n_window): + fout.write(struct.pack("f", window[i])) + + # Write TDT durations + tdt_durations = config['model_defaults']['tdt_durations'] + if len(tdt_durations) != hparams['n_tdt_durations']: + raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}") + + for duration in tdt_durations: + fout.write(struct.pack("I", duration)) + + fout.write(struct.pack("i", len(tokens))) + for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]): + fout.write(struct.pack("i", len(token_bytes))) + fout.write(token_bytes) + + print("\nConverting model weights...") + for name, tensor in state_dict.items(): + # Skip the filterbank and window - already written in preprocessing section + if name == fb_key: + continue + if name == window_key: + continue + + # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions + if 'conv' in name and 'weight' in name 
and len(tensor.shape) == 4: + data = tensor.numpy() + else: + data = tensor.squeeze().numpy() + + write_tensor(fout, name, data, use_f16=use_f16) + + write_tensor(fout, "encoder.pe", pe, use_f16=use_f16, force_f32=True) + + print(f"\nConversion complete!") + print(f"Output file: {fname_out}") + print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert Parakeet TDT model from NeMo format to ggml format' + ) + parser.add_argument('--model', type=str, required=True, + help='Path to Parakeet .nemo model file') + parser.add_argument('--out-dir', type=str, required=True, + help='Directory to write ggml model file') + parser.add_argument('--use-f32', action='store_true', default=False, + help='Use f32 instead of f16 (default: f16)') + parser.add_argument('--out-name', type=str, default=None, + help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)') + + args = parser.parse_args() + + if not os.path.exists(args.model): + print(f"Error: {args.model} not found") + sys.exit(1) + + use_f16 = not args.use_f32 + convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name) diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt new file mode 100644 index 00000000000..c3726e8bfee --- /dev/null +++ b/models/requirements-parakeet.txt @@ -0,0 +1 @@ +pyyaml diff --git a/scripts/upload-parakeet.py b/scripts/upload-parakeet.py new file mode 100644 index 00000000000..311a16a0239 --- /dev/null +++ b/scripts/upload-parakeet.py @@ -0,0 +1,75 @@ +import os +from huggingface_hub import HfApi, create_repo + +# TODO: change to ggml-org once merged. 
+USER_NAME = "danbev" +REPO_ID = f"{USER_NAME}/parakeet" +LOCAL_GGUF_PATH = "models/ggml-parakeet-tdt-0.6b-v3.bin" +REMOTE_GGUF_NAME = "parakeet-tdt-0.6b-v3.bin" + +MODEL_CARD_CONTENT = f"""--- +license: apache-2.0 +base_model: {USER_NAME}/parakeet +tags: +- gguf +--- + +# Parakeet Model Card + +## Description +This is an iterative release of the Parakeet model in whisper.cpp format. + +## Usage +You can use this file with [parakeet-cli](https://github.com/danbev/whisper.cpp/tree/parakeet-support/examples/parakeet-cli). + +Build parakeet-cli: +```console +$ git clone -b parakeet-support https://github.com/danbev/whisper.cpp.git +$ cd whisper.cpp +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +Download the model: +```console +$ hf download danbev/parakeet parakeet-tdt-0.6b-v3.bin --local-dir models +``` + +Run: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3.bin -f samples/jfk.wav +``` + +""" + +api = HfApi() + +def deploy_iteration(): + create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True) + + print("Updating Model Card...") + api.upload_file( + path_or_fileobj=MODEL_CARD_CONTENT.encode(), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + commit_message="Update README.md" + ) + + print(f"Uploading {REMOTE_GGUF_NAME}...") + api.upload_file( + path_or_fileobj=LOCAL_GGUF_PATH, + path_in_repo=REMOTE_GGUF_NAME, + repo_id=REPO_ID, + repo_type="model", + commit_message="Upload new parakeet iteration" + ) + + print(f"\nDeployment successful!") + print(f"URL: https://huggingface.co/{REPO_ID}") + +if __name__ == "__main__": + if os.path.exists(LOCAL_GGUF_PATH): + deploy_iteration() + else: + print(f"Error: {LOCAL_GGUF_PATH} not found.") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 095a2791de5..4e7c5b24dc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,23 +109,43 @@ add_library(whisper whisper.cpp ) +add_library(parakeet + ../include/parakeet.h + 
parakeet-arch.h + parakeet.cpp + ) + +target_include_directories(parakeet PUBLIC . ../include) +target_compile_features (parakeet PUBLIC cxx_std_11) +target_link_libraries(parakeet PUBLIC ggml Threads::Threads) + # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} SOVERSION ${SOVERSION} ) +set_target_properties(parakeet PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + target_include_directories(whisper PUBLIC . ../include) target_compile_features (whisper PUBLIC cxx_std_11) # don't bump if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN) + set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN) endif() if (WHISPER_EXTRA_FLAGS) target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS}) endif() +if (PARAKEET_EXTRA_FLAGS) + target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS}) +endif() + find_package(Threads REQUIRED) target_link_libraries(whisper PUBLIC ggml Threads::Threads) @@ -144,4 +164,7 @@ endif() if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) + + set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD) endif() diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h new file mode 100644 index 00000000000..03eb1303edf --- /dev/null +++ b/src/parakeet-arch.h @@ -0,0 +1,194 @@ +#pragma once + +#include "ggml.h" + +#include + +enum parakeet_tensor { + // Encoder pre_encode + PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, + 
PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, + PARAKEET_TENSOR_ENC_PE, + + // Encoder layers (per-layer) + PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, + PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, + PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_BIAS, + PARAKEET_TENSOR_ENC_CONV_BN_MEAN, + PARAKEET_TENSOR_ENC_CONV_BN_VAR, + PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, + PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, + PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, + PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, + + // Prediction network + PARAKEET_TENSOR_PRED_EMBED_WEIGHT, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, + + // Joint network + PARAKEET_TENSOR_JOINT_PRED_WEIGHT, + PARAKEET_TENSOR_JOINT_PRED_BIAS, + PARAKEET_TENSOR_JOINT_ENC_WEIGHT, + PARAKEET_TENSOR_JOINT_ENC_BIAS, + PARAKEET_TENSOR_JOINT_NET_WEIGHT, + PARAKEET_TENSOR_JOINT_NET_BIAS, +}; + +static const std::map PARAKEET_TENSOR_NAMES = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, 
"encoder.pre_encode.out.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, + {PARAKEET_TENSOR_ENC_PE, "encoder.pe"}, + + // Encoder layers (use %d for layer number) + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, 
"encoder.layers.%d.conv.batch_norm.num_batches_tracked"}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, "decoder.prediction.dec_rnn.lstm.bias_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, "decoder.prediction.dec_rnn.lstm.bias_hh_l%d"}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, 
"joint.pred.bias"}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"}, +}; + +static const std::map PARAKEET_TENSOR_INFO = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PE, GGML_OP_ADD}, + + // Encoder layers + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, 
GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, GGML_OP_ADD}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, GGML_OP_ADD}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD}, +}; diff --git a/src/parakeet.cpp b/src/parakeet.cpp new file mode 100644 index 00000000000..43f22aca0ea --- /dev/null +++ b/src/parakeet.cpp @@ -0,0 +1,4315 @@ +#include "parakeet.h" +#include "parakeet-arch.h" + +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include +#include +#include +#include +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#include +#endif + +#if defined(PARAKEET_BIG_ENDIAN) +template +static T byteswap(T value) { + T value_swapped; + char * source 
= reinterpret_cast(&value); + char * target = reinterpret_cast(&value_swapped); + int size = sizeof(T); + for (int i = 0; i < size; i++) { + target[size - 1 - i] = source[i]; + } + return value_swapped; +} + +template +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +PARAKEET_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal (ggml_log_level level, const char * format, ...); +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define PARAKEET_LOG_INFO(...) 
parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define PARAKEET_DEBUG + +#if defined(PARAKEET_DEBUG) +#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define PARAKEET_LOG_DEBUG(...) +#endif + +#define PARAKEET_ASSERT(x) \ + do { \ + if (!(x)) { \ + PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#define PARAKEET_MAX_NODES 4096 + +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); + } + + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend.get(), n_threads); + } + + return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS; +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph 
* graph, + int n_threads, + bool sched_reset = true) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + + auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) { + fn_set_n_threads(backend, n_threads); + } + } + + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + + return t; +} + +// TODO: move these functions to ggml-base with support for ggml-backend? + +static ggml_tensor * parakeet_set_f32(struct ggml_tensor * t, float v) { + GGML_ASSERT(t->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(t)); + size_t nels = ggml_nelements(t); + for (size_t i = 0; i < nels; ++i) { + ((float *) t->data)[i] = v; + } + return t; +} + +static ggml_tensor * parakeet_set_i32(struct ggml_tensor * t, int32_t v) { + GGML_ASSERT(t->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(t)); + size_t nels = ggml_nelements(t); + for (size_t i = 0; i < nels; ++i) { + ((int32_t *) t->data)[i] = v; + } + return t; +} + +static float parakeet_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + GGML_ASSERT(t->type == GGML_TYPE_F32); + void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; + return *(float *) data; +} + +static void parakeet_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) { + GGML_ASSERT(t->type == GGML_TYPE_F32); + void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; + *(float *) data = v; +} + +static int32_t parakeet_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, 
int64_t i2, int64_t i3) { + GGML_ASSERT(t->type == GGML_TYPE_I32); + void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; + return *(int32_t *) data; +} + +static void parakeet_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) { + GGML_ASSERT(t->type == GGML_TYPE_I32); + void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; + *(int32_t *) data = v; +} + +struct parakeet_mel { + int n_len = 0; + int n_len_org = 0; + int n_mel = 0; + + std::vector data; +}; + +struct parakeet_filters { + int32_t n_mel = 0; + int32_t n_fb = 0; // number of frequency bins + + std::vector data; +}; + +struct parakeet_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 8192; + size_t max_token_length = 0; + + std::map token_to_id; + std::map id_to_token; + + id token_unk; + id token_bos; + id token_blank; + id token_eos; +}; + +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "" || piece == "" || piece == "" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + 
size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + +struct parakeet_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector tokens; +}; + +struct parakeet_batch { + int32_t n_tokens; + + parakeet_token * token; + int32_t * i_time; // index of the audio frame + parakeet_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + parakeet_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i 
< n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + +// ggml_backend_sched wrapper for parakeet usage +struct parakeet_sched { + ggml_backend_sched_t sched = nullptr; + + std::vector meta; +}; + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector backends, std::function && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + // since there are dependencies between the different graphs, + // we need to allocate them instead of only reserving to get the correct compute buffer size + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + // failed to allocate the compute buffer + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +// TODO: Find out is there a multiple version types. It is not yet clear to me +// at this point. 
+enum parakeet_arch { + PARAKEET_ARCH_UNKNOWN = 0, + PARAKEET_ARCH_TDT = 1, // NVIDIA Parakeet TDT (RNN-T) +}; + +struct parakeet_hparams { + int32_t n_vocab = 8192; + int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input + int32_t n_audio_state = 1024; + int32_t n_audio_head = 8; + int32_t n_audio_layer = 24; + int32_t n_mels = 128; + int32_t ftype = 1; + int32_t n_fft = 512; // FFT size for mel spectrogram + float eps = 1e-5f; + int32_t subsampling_factor = 8; + int32_t n_subsampling_channels = 256; + int32_t n_pos_max_len = 5000; + int32_t n_pred_dim = 640; + int32_t n_pred_layers = 2; + int32_t n_tdt_durations = 5; + int32_t n_max_tokens = 10; + + parakeet_arch arch = PARAKEET_ARCH_TDT; +}; + +struct parakeet_layer_encoder { + struct ggml_tensor * norm_ff1_w; + struct ggml_tensor * norm_ff1_b; + + struct ggml_tensor * ff1_linear1_w; + struct ggml_tensor * ff1_linear2_w; + + struct ggml_tensor * norm_conv_w; + struct ggml_tensor * norm_conv_b; + + struct ggml_tensor * conv_pw1_w; // pointwise_conv1 + struct ggml_tensor * conv_dw_w; // depthwise_conv + struct ggml_tensor * conv_bn_w; // batch_norm weight + struct ggml_tensor * conv_bn_b; // batch_norm bias + struct ggml_tensor * conv_bn_mean; // batch_norm running_mean + struct ggml_tensor * conv_bn_var; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w; + struct ggml_tensor * norm_attn_b; + + struct ggml_tensor * attn_pos_bias_u; + struct ggml_tensor * attn_pos_bias_v; + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_k_w; + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_out_w; + struct ggml_tensor * attn_pos_w; + + struct ggml_tensor * norm_ff2_w; + struct ggml_tensor * norm_ff2_b; + + struct ggml_tensor * ff2_linear1_w; + struct ggml_tensor * ff2_linear2_w; + + struct ggml_tensor * norm_out_w; + struct ggml_tensor * norm_out_b; 
+}; + +struct parakeet_lsmt_layer { + struct ggml_tensor * ih_w; // input-to-hidden weight + struct ggml_tensor * ih_b; // input-to-hidden bias + struct ggml_tensor * hh_w; // hidden-to-hidden weight + struct ggml_tensor * hh_b; // hidden-to-hidden bias +}; + +struct parakeet_prediction_network { + struct ggml_tensor * embed_w; + + std::vector lstm_layer; +}; + +struct parakeet_joint_network { + struct ggml_tensor * pred_w; + struct ggml_tensor * pred_b; + struct ggml_tensor * enc_w; + struct ggml_tensor * enc_b; + struct ggml_tensor * net_w; + struct ggml_tensor * net_b; +}; + +struct parakeet_model { + parakeet_filters filters; + parakeet_hparams hparams; + + // Relative positional encoding (pe) + struct ggml_tensor * pe; + + struct ggml_tensor * enc_pre_out_w; + struct ggml_tensor * enc_pre_out_b; + struct ggml_tensor * enc_pre_conv_0_w; + struct ggml_tensor * enc_pre_conv_0_b; + struct ggml_tensor * enc_pre_conv_2_w; + struct ggml_tensor * enc_pre_conv_2_b; + struct ggml_tensor * enc_pre_conv_3_w; + struct ggml_tensor * enc_pre_conv_3_b; + struct ggml_tensor * enc_pre_conv_5_w; + struct ggml_tensor * enc_pre_conv_5_b; + struct ggml_tensor * enc_pre_conv_6_w; + struct ggml_tensor * enc_pre_conv_6_b; + + std::vector layers; + + parakeet_prediction_network prediction; + + parakeet_joint_network joint; + + std::vector tdt_durations; + + std::vector ctxs; + + std::vector buffers; + + int n_loaded; + std::map tensors; +}; + +struct parakeet_lstm_state_layer { + struct ggml_tensor * h_state; + struct ggml_tensor * c_state; +}; + +struct parakeet_lstm_state { + std::vector layer; + + std::vector ctx_buf; + + ggml_backend_buffer_t buffer = nullptr; +}; + +struct tdt_stream_state { + parakeet_token last_token; // last emitted token (for predictor input) + int32_t time_step; // overflow frames into next chunk + int32_t decoded_length; // total tokens decoded + bool initialized; // whether prediction LSTM state has been initialized +}; + +struct parakeet_stream { + 
std::vector buffer; + int64_t n_samples_advanced = 0; + + int n_left_ctx = 0; + int n_chunk = 0; + int n_right_ctx = 0; + + parakeet_full_params params = {}; + bool initialized = false; +}; + +struct parakeet_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_batchd_us = 0; + int64_t t_prompt_us = 0; + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + parakeet_mel mel; + + parakeet_batch batch; + + int n_frames; + + std::vector backends; + + parakeet_sched sched_conv; + parakeet_sched sched_encode; + parakeet_sched sched_decode; + + // outputs from encoder stages + struct ggml_tensor * pre_enc_out = nullptr; + struct ggml_tensor * enc_out = nullptr; + struct ggml_tensor * pred_out = nullptr; + + std::vector pred_out_buf; + ggml_backend_buffer_t pred_out_buffer = nullptr; + + struct ggml_tensor * attn_mask = nullptr; + + std::vector inp_mel; + std::vector inp_mask; + + std::vector enc_out_buffer; + int enc_out_frames = 0; + + std::vector logits; + + std::vector result_all; + + std::vector decoded_tokens; + std::vector decoded_token_data; + + std::string path_model; + + int32_t n_audio_ctx = 0; + + parakeet_lstm_state lstm_state; + + struct tdt_stream_state tdt_stream_state = {0, 0, 0, false}; + + parakeet_stream stream; +}; + +// FFT cache for mel spectrogram computation +struct parakeet_mel_cache { + int n_fft = 0; + + // In FFT, we frequently use sine and cosine operations with the same values. 
+ // We can use precalculated values to speed up the process. + std::vector sin_vals; + std::vector cos_vals; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector hann_window; + + // Window function from model (Parakeet uses actual window from training) + std::vector window; + + void init(int fft_size) { + n_fft = fft_size; + sin_vals.resize(n_fft); + cos_vals.resize(n_fft); + hann_window.resize(n_fft); + + fill_sin_cos_table(); + fill_hann_window(n_fft, true, hann_window.data()); + } + + void fill_sin_cos_table() { + for (int i = 0; i < n_fft; i++) { + double theta = (2 * M_PI * i) / n_fft; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +}; + +struct parakeet_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; + ggml_type itype = ggml_type::GGML_TYPE_F16; + + parakeet_context_params params; + + parakeet_model model; + parakeet_vocab vocab; + + parakeet_state * state = nullptr; + + parakeet_mel_cache mel_cache; + + std::string path_model; +}; + +struct parakeet_global { + // We save the log callback globally + ggml_log_callback log_callback = parakeet_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static parakeet_global g_state; + +template +static void read_safe(parakeet_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool parakeet_lstm_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_layer) { + parakeet_lstm_state & lstm_state = pstate.lstm_state; + + 
lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2); + lstm_state.layer.resize(n_layer); + + struct ggml_init_params params = { + /*.mem_size =*/ lstm_state.ctx_buf.size(), + /*.mem_buffer =*/ lstm_state.ctx_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__); + return false; + } + + + for (int il = 0; il < n_layer; ++il) { + lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 640); + lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 640); + } + + lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!lstm_state.buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__); + return false; + } + + ggml_backend_buffer_clear(lstm_state.buffer, 0); + + ggml_free(ctx); + + return true; +} + +static bool parakeet_pred_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_pred_dim) { + pstate.pred_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.pred_out_buf.size(), + /*.mem_buffer =*/ pstate.pred_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__); + return false; + } + + pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.pred_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static ggml_backend_t whisper_backend_init_gpu(const parakeet_context_params & params) { + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + + ggml_backend_dev_t dev = nullptr; 
+ + int cnt = 0; + if (params.use_gpu) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur); + const char * dev_name = ggml_backend_dev_name(dev_cur); + PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) { + PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt); + if (cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + if (dev == nullptr) { + PARAKEET_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; +} + +static std::vector whisper_backend_init(const parakeet_context_params & params) { + std::vector result; + + ggml_backend_t backend_gpu = whisper_backend_init_gpu(params); + + if (backend_gpu) { + result.push_back(backend_gpu); + } + + // ACCEL backends + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + continue; + } + result.push_back(backend); + } + } + + ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw 
std::runtime_error("failed to initialize CPU backend"); + } + result.push_back(backend_cpu); + + return result; +} + +using buft_list_t = std::vector>; + +static buft_list_t make_buft_list(parakeet_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; + + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { + if (cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only 
support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } else if (op == GGML_OP_GET_ROWS) { + int64_t num_indices = 8; + ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + +// load the model from a ggml file +// + +// see the convert-parakeet-to-ggml.py script for details +// +static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) { + PARAKEET_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = 
t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + parakeet_hparams hparams; + { + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_fft); + read_safe(loader, hparams.subsampling_factor); + read_safe(loader, hparams.n_subsampling_channels); + read_safe(loader, hparams.n_pos_max_len); + read_safe(loader, hparams.n_pred_dim); + read_safe(loader, hparams.n_pred_layers); + read_safe(loader, hparams.n_tdt_durations); + read_safe(loader, hparams.n_max_tokens); + + hparams.arch = PARAKEET_ARCH_TDT; + wctx.model.hparams = hparams; + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype); + if (wctx.wtype == GGML_TYPE_COUNT) { + PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype); + return false; + } + + const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? 
"Parakeet TDT" : "unknown";
+        PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name);
+        PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
+        PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
+        PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
+        PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft);
+        PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype);
+        PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
+        PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor);
+        PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels);
+        PARAKEET_LOG_INFO("%s: n_pos_max_len = %d\n", __func__, hparams.n_pos_max_len);
+        PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim);
+        PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers);
+        PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations);
+        PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens);
+    }
+
+    // load mel filters
+    // (n_mel x n_fb matrix of floats, stored row-major right after the hparams)
+    {
+        auto & filters = wctx.model.filters;
+
+        read_safe(loader, filters.n_mel);
+        read_safe(loader, filters.n_fb);
+
+        filters.data.resize(filters.n_mel * filters.n_fb);
+        loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
+        BYTESWAP_FILTERS(filters);
+    }
+
+    // load window function
+    // (precomputed STFT analysis window, cached so the mel stage does not rebuild it)
+    {
+        int32_t n_window = 0;
+        read_safe(loader, n_window);
+
+        wctx.mel_cache.window.resize(n_window);
+        loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float));
+
+#ifdef GGML_BIG_ENDIAN
+        for (auto & datum : wctx.mel_cache.window) {
+            datum = byteswap(datum);
+        }
+#endif
+
+        PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window);
+    }
+
+    // load tdt durations values
+    // (the discrete duration outputs emitted by the TDT joint head)
+    {
+        auto & tdt_durations = wctx.model.tdt_durations;
+        tdt_durations.resize(hparams.n_tdt_durations);
+        loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t));
+
+        PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__);
+        for (const auto value : tdt_durations) {
+            PARAKEET_LOG_INFO("%u ", value);
+        }
+        PARAKEET_LOG_INFO("]\n");
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        read_safe(loader, n_vocab);
+
+        std::string word;
+        // NOTE(review): the element type of this vector appears to have been
+        // stripped in transit (presumably std::vector<char>) - confirm against
+        // the original patch.
+        std::vector tmp;
+
+        tmp.reserve(128);
+
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            read_safe(loader, len);
+
+            if (len > 0) {
+                tmp.resize(len);
+                loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
+                word.assign(&tmp[0], tmp.size());
+            } else {
+                PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                word = "";
+            }
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+            vocab.max_token_length = std::max(vocab.max_token_length, word.size());
+        }
+        // Blank token for transducer is at index n_vocab (8192), outside the vocabulary
+        int blank_id = n_vocab;
+        vocab.token_blank = blank_id;
+        vocab.id_to_token[blank_id] = "[BLANK]";
+        vocab.token_to_id["[BLANK]"] = blank_id;
+
+        // Set special token IDs by looking them up in the loaded vocabulary
+        // These are from the SentencePiece vocab file loaded above
+        // NOTE(review): the token literals in the find()/at() calls below look
+        // like they lost their angle-bracket contents in transit (probably
+        // "<unk>", "<s>" and "</s>"); as written they look up the empty string,
+        // which the loop above maps to the last empty token read. Confirm
+        // against the original patch.
+        if (vocab.token_to_id.find("") != vocab.token_to_id.end()) {
+            vocab.token_unk = vocab.token_to_id.at("");
+        } else {
+            vocab.token_unk = 0; // Fallback
+        }
+
+        if (vocab.token_to_id.find("") != vocab.token_to_id.end()) {
+            vocab.token_bos = vocab.token_to_id.at("");
+        } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) {
+            vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>");
+        } else {
+            vocab.token_bos = 0; // Fallback
+        }
+
+        if (vocab.token_to_id.find("") != vocab.token_to_id.end()) {
+            vocab.token_eos = vocab.token_to_id.at("");
+        } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) {
+            vocab.token_eos = vocab.token_to_id.at("<|endoftext|>");
+        } else {
+            vocab.token_eos = 0; // Fallback
+        }
+
+        vocab.n_vocab = model.hparams.n_vocab;
+
+        PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n",
+                __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos);
+    }
+
+    const ggml_type wtype = wctx.wtype;
+    const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
+
+    const int n_audio_layer = hparams.n_audio_layer;
+
+    // Calculate tensor count based on architecture
+    size_t n_tensors = 2 + 14 + 1 + 29*n_audio_layer + 9 + 6;
+
+    // One ggml context per backend buffer type; tensors are created lazily in
+    // whichever context matches the buffer type selected for them.
+    // NOTE(review): the map's template arguments appear to have been stripped
+    // in transit (presumably std::map<ggml_backend_buffer_type_t, ggml_context *>).
+    std::map ctx_map;
+    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error("failed to create ggml context");
+            }
+
+            ctx_map[buft] = ctx;
+            wctx.model.ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    // Create a list of available bufts, in priority order
+    buft_list_t buft_list = make_buft_list(wctx.params);
+
+    // Duplicate `meta` into the context of the best compatible buffer type and
+    // register it in the model's name -> tensor map (formatted with the layer
+    // index when `layer` >= 0).
+    auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * {
+        ggml_op op = PARAKEET_TENSOR_INFO.at(type);
+        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+        if (!buft) {
+            throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", PARAKEET_TENSOR_NAMES.at(type)));
+        }
+
+        ggml_context * ctx = get_ctx(buft);
+        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+
+        std::string tensor_name;
+        if (layer
>= 0) {
+            tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer);
+        } else {
+            tensor_name = PARAKEET_TENSOR_NAMES.at(type);
+        }
+
+        wctx.model.tensors[tensor_name] = tensor;
+
+        return tensor;
+    };
+
+    // prepare tensors for the weights
+
+    // Count: preprocessor (2) + pre_encode (14) + positional encoding (1) + encoder layers (29 per layer) + prediction (9) + joint (6)
+    // NOTE(review): this duplicates the n_tensors computation above - consider
+    // keeping a single constant so the two counts cannot drift apart.
+    const int n_parakeet_tensors = 2 + 14 + 1 + (29 * n_audio_layer) + 9 + 6;
+
+    // Scratch context holding only tensor metadata ("meta" tensors passed to
+    // create_tensor); freed once all weights have been described.
+    ggml_init_params params = {
+        /*.mem_size   =*/ n_parakeet_tensors * ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context * ctx = ggml_init(params);
+
+    const int n_audio_state = hparams.n_audio_state;
+
+    model.layers.resize(n_audio_layer);
+
+    // Encoder pre_encode
+    const int n_subsampling_channels = hparams.n_subsampling_channels;
+    model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4096, n_audio_state));
+    ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w");
+    model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
+    ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b");
+
+    model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels));
+    ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w");
+    model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
+    ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b");
+
+    model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels));
+    ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w");
+    model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
+    ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b");
+
+    model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, wtype, 1, 1, n_subsampling_channels, n_subsampling_channels));
+    ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w");
+    model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
+    ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b");
+
+    model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels));
+    ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w");
+    model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
+    ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b");
+
+    model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, wtype, 1, 1, n_subsampling_channels, n_subsampling_channels));
+    ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w");
+    model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
+    ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b");
+
+    // Encoder layers (Conformer blocks: FF1 / conv module / self-attention / FF2)
+    for (int i = 0; i < n_audio_layer; ++i) {
+        auto & layer = model.layers[i];
+
+        // Feed forward 1
+        layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+        ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i);
+        layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+        ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i);
+
+        // Convolution module
+        layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i);
+        layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i);
+        layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i);
+        ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i);
+        layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 9, n_audio_state), i);
+        ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i);
+        layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i);
+        layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i);
+        // NOTE(review): conv_bn_mean and conv_bn_num_batches are the only
+        // tensors in this loop without a ggml_format_name() call - confirm
+        // whether that is intentional.
+        layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i);
+        layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i);
+        layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+        ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i);
+
+        // Self attention
+        layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        // NOTE(review): 128 and 8 look like hardcoded d_head and n_audio_head -
+        // consider deriving them from hparams so other model sizes load.
+        layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 8), i);
+        layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 8), i);
+        layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+        layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+        layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+        layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+        layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+        ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i);
+
+        // Feed forward 2
+        layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+        layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+
+        // Output norm
+        layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+        layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+    }
+
+    // Prediction network (decoder)
+    // NOTE(review): 8193 embeddings = n_vocab (8192) + 1 blank row, and
+    // 2560 = 4 gates * n_pred_dim (640) for the packed LSTM weights -
+    // presumably; confirm and consider deriving from hparams.
+    const int dec_hidden = hparams.n_pred_dim;
+    model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 8193));
+    model.prediction.lstm_layer.resize(hparams.n_pred_layers);
+    for (int i = 0; i < hparams.n_pred_layers; ++i) {
+        auto & layer = model.prediction.lstm_layer[i];
+        layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 2560), i);
+        ggml_format_name(layer.ih_w, "pred_%d_ih_w", i);
+
+        layer.ih_b = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2560), i);
+        ggml_format_name(layer.ih_b, "pred_%d_ih_b", i);
+
+        layer.hh_b = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2560), i);
+        ggml_format_name(layer.hh_b, "pred_%d_hh_b", i);
+
+        layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 2560), i);
+        ggml_format_name(layer.hh_w, "pred_%d_hh_w", i);
+    }
+
+    // Joint network
+    // NOTE(review): 8198 outputs = 8192 vocab + 1 blank + 5 TDT durations -
+    // presumably; confirm against n_vocab/n_tdt_durations.
+    model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, wtype, dec_hidden, dec_hidden));
+    ggml_set_name(model.joint.pred_w, "pred_w");
+    model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden));
+    ggml_set_name(model.joint.pred_b, "pred_b");
+    model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden));
+    ggml_set_name(model.joint.enc_w, "enc_w");
+    model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden));
+    ggml_set_name(model.joint.enc_b, "enc_b");
+    model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 8198));
+    ggml_set_name(model.joint.net_w, "net_w");
+    model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8198));
+    ggml_set_name(model.joint.net_b, "net_b");
+
+    // Relative positional encoding
+    {
+        model.pe = create_tensor(PARAKEET_TENSOR_ENC_PE, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, hparams.n_pos_max_len * 2 - 1));
+        ggml_set_name(model.pe, "pe");
+    }
+
+    ggml_free(ctx);
+
+    // allocate tensors in the backend buffers
+    for (auto & p : ctx_map) {
+        ggml_backend_buffer_type_t buft = p.first;
+        ggml_context * ctx = p.second;
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (buf) {
+            wctx.model.buffers.emplace_back(buf);
+
+            size_t size_main = ggml_backend_buffer_get_size(buf);
+            PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+        }
+    }
+
+    // load weights
+    // Stream records of (n_dims, name length, type, shape, name, data) until EOF,
+    // matching each record to a tensor registered by create_tensor above.
+    {
+        size_t total_size = 0;
+
+        auto & tensors_map = wctx.model.tensors;
+        int & n_loaded = wctx.model.n_loaded;
+
+        n_loaded = 0;
+
+        // NOTE(review): element type stripped in transit (presumably
+        // std::vector<uint8_t> or std::vector<char>) - confirm.
+        std::vector read_buf;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            read_safe(loader, n_dims);
+            read_safe(loader, length);
+            read_safe(loader, ttype);
+
+            // EOF is only detectable after attempting the header reads
+            if (loader->eof(loader->context)) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[4] = { 1, 1, 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                read_safe(loader, ne[i]);
+                nelements *= ne[i];
+            }
+
+            std::string name;
+            // NOTE(review): element type stripped in transit (presumably
+            // std::vector<char>) - confirm.
+            std::vector tmp(length); // create a buffer
+            loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
+            name.assign(&tmp[0], tmp.size());
+
+            if (tensors_map.find(name) == tensors_map.end()) {
+                PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = tensors_map[name.data()];
+
+            if (ggml_nelements(tensor) != nelements) {
+                PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) {
+                PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]);
+                return false;
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            if (ggml_backend_buffer_is_host(tensor->buffer)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+
+                loader->read(loader->context, read_buf.data(), read_buf.size());
+
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+            }
+
+            total_size += ggml_nbytes(tensor);
+            n_loaded++;
+        }
+
+        PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
+
+        if (n_loaded == 0) {
+            PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+        } else if (n_loaded != (int) tensors_map.size()) {
+            PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded);
+            return false;
+        }
+    }
+
+    auto & buffers = wctx.model.buffers;
+    for (auto & buf : buffers) {
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+    }
+
+    wctx.t_load_us = ggml_time_us() - t_start_us;
+
+    return true;
+}
+
+//
+// This function builds the computation graph for the pre-encoder stage.
+//
+// It takes the generated mel spectrogram as an input tensor, which will be set
+// during processing, and performs a number of convolutions to detect/filter
+// features, reduces the time dimension by a factor of 8, and finally projects
+// this information into the model's abstract feature space.
+//
+static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, parakeet_state & pstate) {
+    const auto & model = pctx.model;
+    const auto & hparams = model.hparams;
+    // the state's audio context (if set) overrides the model default
+    const int n_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx;
+    const int n_mels = hparams.n_mels;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ pstate.sched_conv.meta.size(),
+        /*.mem_buffer =*/ pstate.sched_conv.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // [freq, time]
+    struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_time, 1, 1);
+    ggml_set_name(mel, "mel");
+    ggml_set_input(mel);
+
+    // conv 0: strided 3x3, halves both freq and time
+    // [freq, time, channels, batch]
+    struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1);
+    cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b);
+    ggml_set_name(cur, "pre_conv_0");
+
+    cur = ggml_relu(ctx0, cur);
+    ggml_set_name(cur, "pre_conv_0_relu");
+
+    // enc_pre_conv_2_w: {3, 3, 1, 256} (depthwise)
+    // [freq, time, channels, batch]
+    cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1);
+    cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b);
+    ggml_set_name(cur, "pre_conv_2");
+
+    // enc_pre_conv_3_w: {1, 1, 256, 256} (pointwise)
+    // [freq, time, channels, batch]
+    cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1);
+    cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b);
+    ggml_set_name(cur, "pre_conv_3");
+
+    cur = ggml_relu(ctx0, cur);
+    ggml_set_name(cur, "pre_conv_3_relu");
+
+    // enc_pre_conv_5_w: {3, 3, 1, 256} (depthwise)
+    // [freq, time, channels, batch]
+    cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1);
+    ggml_set_name(cur, "pre_conv_5_direct");
+    cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b);
+    ggml_set_name(cur, "pre_conv_5");
+
+    // enc_pre_conv_6_w: {1, 1, 256, 256} (pointwise)
+    // [freq, time, channels, batch]
+    cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1);
+    cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b);
+    ggml_set_name(cur, "pre_conv_6");
+
+    cur = ggml_relu(ctx0, cur);
+    ggml_set_name(cur, "pre_conv_6_relu");
+
+    // [freq, time, chan]
+    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+    // [freq, chan, time]
+    cur = ggml_cont(ctx0, cur);
+
+    const int n_freq = cur->ne[0]; // 16
+    const int n_chan = cur->ne[1]; // 256
+    const int n_frames = cur->ne[2]; // time
+
+    // flatten freq x channel per frame for the output projection
+    // [freq, time, chan, batch] -> [(freq * chan), time]
+    cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames);
+
+    cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur);
+    cur = ggml_add(ctx0, cur, model.enc_pre_out_b);
+
+    ggml_set_name(cur, "pre_enc_out");
+    ggml_set_output(cur);
+    // stashed so the encoder graph can view this tensor as its input
+    pstate.pre_enc_out = cur;
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// This function builds the graph for the encoder part of the model.
+//
+// It takes the output of the pre-encoder convolution above which will have the
+// shape of [hidden_dim, time_frames]
+static struct ggml_cgraph * parakeet_build_graph_encoder(parakeet_context & pctx, parakeet_state & pstate) {
+    const auto & model = pctx.model;
+    const auto & hparams = model.hparams;
+    const int n_layer = hparams.n_audio_layer;
+    const int n_state = hparams.n_audio_state;
+    // Conformer macaron-style FFN: residual takes only half of the FFN output
+    const float fc_factor = 0.5f;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ pstate.sched_encode.meta.size(),
+        /*.mem_buffer =*/ pstate.sched_encode.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
+
+    // Create a view of the output produced by parakeet_build_graph_conv
+    // [feat, time_frames, 1, 1]
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, pstate.pre_enc_out);
+    ggml_set_name(cur, "encoder_inp");
+
+    // additive padding mask, filled in by the caller at compute time
+    // [time_frames, time_frames, 1, 1]]
+    struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, cur->ne[1], cur->ne[1]);
+    ggml_set_name(attn_mask, "attn_mask");
+    ggml_set_input(attn_mask);
+
+    for (int il = 0; il < n_layer; ++il) {
+        // FFN1
+        {
+            struct ggml_tensor * residual = cur;
+            ggml_format_name(cur, "enc_%d_res", il);
+
+            // norm
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_ff1_w), model.layers[il].norm_ff1_b);
+            ggml_format_name(cur, "enc_%d_ffn_norm_1", il);
+
+            // ffn_1
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff1_linear1_w, cur);
+            cur = ggml_silu(ctx0, cur);
+            ggml_format_name(cur, "enc_%d_silu", il);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff1_linear2_w, cur);
+            ggml_format_name(cur, "enc_%d_ffn_1", il);
+
+            cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
+            ggml_format_name(cur, "enc_%d_res_ffn", il);
+        }
+
+        // self attention block using relative positional encoding from the model.pe tensor.
+        {
+            // [feat, time_frames, 1, 1]
+            struct ggml_tensor * residual = cur;
+
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_attn_w), model.layers[il].norm_attn_b);
+            ggml_format_name(cur, "enc_%d_attn_norm", il);
+
+            const int n_head = hparams.n_audio_head;
+            const int d_head = n_state / n_head;
+            const int n_time = cur->ne[1];
+
+            // [feat, time_frames, 1, 1]
+            struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, model.layers[il].attn_q_w, cur);
+            struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, model.layers[il].attn_k_w, cur);
+            struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, model.layers[il].attn_v_w, cur);
+
+            // [d_head, n_heads, time_frames, 1]
+            Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time);
+            K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time);
+            V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time);
+
+            // slice a centered window of 2*T-1 relative positions out of the PE table
+            // NOTE(review): start_pos goes negative when input_len exceeds
+            // center_pos (i.e. input longer than n_pos_max_len); the size_t
+            // offset would then wrap - confirm inputs are clamped upstream.
+            const int input_len = cur->ne[1];
+            const int center_pos = model.pe->ne[1] / 2 + 1;
+            const int start_pos = center_pos - input_len;
+            const int window_size = 2 * input_len - 1;
+            const size_t offset = start_pos * model.pe->nb[1];
+
+            // [feat, window_size]
+            struct ggml_tensor * pos_emb = ggml_view_2d(ctx0, model.pe,
+                    n_state, window_size,
+                    model.pe->nb[1], offset);
+            ggml_format_name(pos_emb, "enc_%d_attn_pos_emb", il);
+
+            struct ggml_tensor * pos = ggml_mul_mat(ctx0, model.layers[il].attn_pos_w, pos_emb);
+            ggml_format_name(pos, "enc_%d_attn_pos", il);
+
+            // Add the content bias to Q. Like a baseline query for content. Like
+            // regardless of my specific position always look for these base features.
+            // So we shift the Q by the content bias.
+            // [feat, head, time_frames, batch]
+            struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, model.layers[il].attn_pos_bias_u);
+            ggml_format_name(Q_u, "enc_%d_attn_q_u", il);
+
+            // [feat, time_frames, head, 1]
+            struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3);
+            // [feat, time_frames, head, 1]
+            struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3);
+            // [feat, feat, head, 1]
+            struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep);
+            ggml_format_name(content_scores, "enc_%d_attn_content_scores", il);
+
+            // Add the position bias to Q. Like a baseline for query positions. Like
+            // regardless of my specific content I'm always interested in tokens a certain
+            // distance away. So we shift the Q by the position bias.
+            // [feat, head, time_frames, batch]
+            struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, model.layers[il].attn_pos_bias_v);
+            ggml_format_name(Q_v, "enc_%d_attn_q_v", il);
+
+            // [feat, window_size, 1, 1] and we are doing multi-head attention so
+            // we need to split this into heads.
+            // [feat, head, window_size, 1]
+            pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, pos_emb->ne[1]);
+
+            // [feat, window_size, head, 1]
+            pos = ggml_permute(ctx0, pos, 0, 2, 1, 3);
+            pos = ggml_cont(ctx0, pos);
+            ggml_format_name(pos, "enc_%d_attn_pos_perm", il);
+            // [feat, time, head, 1]
+            Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3);
+            Q_v = ggml_cont(ctx0, Q_v);
+            ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il);
+
+            // [window_size, time_frames, head, 1]
+            struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v);
+            ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il);
+
+            // Relative positional shift
+            // (pad/roll/reshape trick converts relative-position columns into
+            // absolute-position columns aligned with content_scores)
+            {
+                // NOTE(review): this n_head shadows the outer n_head declared
+                // above (same value, different type) - consider renaming.
+                const auto pos_window = rel_pos_scores->ne[0];
+                const auto n_frame = rel_pos_scores->ne[1];
+                const auto n_head = rel_pos_scores->ne[2];
+
+                // [feat_padded, window_size, head, 1]
+                rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0);
+                rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0);
+
+                rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head);
+                rel_pos_scores = ggml_cont(ctx0, rel_pos_scores);
+                ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il);
+
+                int center = pos_window / 2;
+                size_t offset = rel_pos_scores->nb[0] * (center+1);
+
+                // NOTE(review): the row stride is hardcoded as pos_window * 4,
+                // i.e. it assumes 4-byte (F32) elements - prefer deriving it
+                // from nb[] / ggml_type_size.
+                rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores,
+                        n_frame, pos_window, n_head,
+                        (pos_window) * 4,
+                        rel_pos_scores->nb[2],
+                        offset);
+
+                rel_pos_scores = ggml_cont(ctx0, rel_pos_scores);
+                ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il);
+
+                rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores,
+                        content_scores->ne[0],
+                        content_scores->ne[1],
+                        rel_pos_scores->ne[2],
+                        rel_pos_scores->nb[1],
+                        rel_pos_scores->nb[2],
+                        0);
+                rel_pos_scores = ggml_cont(ctx0, rel_pos_scores);
+                ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il);
+            }
+
+            struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores);
+            attn_scores = ggml_cont(ctx0, attn_scores);
+            ggml_format_name(attn_scores, "enc_%d_attn_scores", il);
+            attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head));
+            attn_scores = ggml_add(ctx0, attn_scores, attn_mask);
+            ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il);
+
+            struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores);
+            ggml_format_name(probs, "enc_%d_attn_probs", il);
+
+            V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3));
+            ggml_format_name(V_cur, "enc_%d_attn_v_cur", il);
+            cur = ggml_mul_mat(ctx0, probs, V_cur);
+            ggml_format_name(cur, "enc_%d_attn_inp", il);
+
+            // merge heads back to [n_state, n_time] and project out
+            cur = ggml_permute(ctx0, cur, 2, 0, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, n_state, n_time);
+            cur = ggml_mul_mat(ctx0, model.layers[il].attn_out_w, cur);
+            ggml_format_name(cur, "enc_%d_attn_out", il);
+
+            cur = ggml_add(ctx0, residual, cur);
+            ggml_format_name(cur, "enc_%d_attn_res", il);
+        }
+
+        // Convolution
+        {
+            struct ggml_tensor * residual = cur;
+            ggml_format_name(cur, "enc_%d_residual_conv", il);
+
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_conv_w), model.layers[il].norm_conv_b);
+            ggml_format_name(cur, "enc_%d_norm_conv", il);
+
+            // pointwise 1d convolution: [1024, 138] -> [2048, 138]
+            cur = ggml_mul_mat(ctx0, model.layers[il].conv_pw1_w, cur);
+            ggml_format_name(cur, "enc_%d_conv_pw1", il);
+
+            // GLU: first half is the signal, second half gates it
+            {
+                int64_t d = cur->ne[0] / 2;
+                struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0);
+                struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]);
+
+                cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate));
+                ggml_format_name(cur, "enc_%d_conv_glu", il);
+            }
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            // use ggml_ssm_conv for f32 precision
+            // (symmetric 4-sample padding for the 9-tap depthwise kernel)
+            cur = ggml_pad(ctx0, cur, 4, 0, 0, 0);
+            cur = ggml_roll(ctx0, cur, 4, 0, 0, 0);
+            cur = ggml_pad(ctx0, cur, 4, 0, 0, 0);
+            ggml_format_name(cur, "enc_%d_conv_dw_pad", il);
+
+            cur = ggml_ssm_conv(ctx0, cur, model.layers[il].conv_dw_w);
+            ggml_format_name(cur, "enc_%d_conv_1d_dw", il);
+
+            // inference-time batch norm using the stored running statistics
+            // NOTE(review): no epsilon is added before the sqrt - confirm the
+            // exported running variance already includes it.
+            cur = ggml_sub(ctx0, cur, model.layers[il].conv_bn_mean);
+            struct ggml_tensor * std = ggml_sqrt(ctx0, model.layers[il].conv_bn_var);
+            cur = ggml_div(ctx0, cur, std);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].conv_bn_w), model.layers[il].conv_bn_b);
+            ggml_format_name(cur, "enc_%d_conv_bn", il);
+
+            cur = ggml_silu(ctx0, cur);
+            ggml_format_name(cur, "enc_%d_conv_silu", il);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].conv_pw2_w, cur);
+            ggml_format_name(cur, "enc_%d_conv_pw2", il);
+
+            cur = ggml_add(ctx0, residual, cur);
+            ggml_format_name(cur, "enc_%d_conv_res", il);
+        }
+
+        // FFN2
+        {
+            struct ggml_tensor * residual = cur;
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_ff2_w), model.layers[il].norm_ff2_b);
+            ggml_format_name(cur, "enc_%d_ffn_norm_2", il);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff2_linear1_w, cur);
+            cur = ggml_silu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff2_linear2_w, cur);
+            cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5));
+            ggml_format_name(cur, "enc_%d_ffn_res", il);
+        }
+
+        // final per-layer norm
+        cur = ggml_norm(ctx0, cur, hparams.eps);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_out_w), model.layers[il].norm_out_b);
+    }
+
+    ggml_set_name(cur, "encoder_out");
+    ggml_set_output(cur);
+    pstate.enc_out = cur;
+    pstate.n_frames = cur->ne[1];
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// Runs the pre-encoder conv graph followed by the encoder graph for the mel
+// window starting at mel_offset. Returns false on scheduling/compute failure
+// or when the abort callback requests cancellation.
+static bool parakeet_encode_internal(
+        parakeet_context & pctx,
+        parakeet_state & pstate,
+        const int mel_offset,
+        const int n_threads,
+        ggml_abort_callback abort_callback,
+        void * abort_callback_data) {
+    const int64_t t_start_us = ggml_time_us();
+
+    // conv
+    {
+        auto & sched = pstate.sched_conv.sched;
+
+        ggml_cgraph * gf = parakeet_build_graph_conv(pctx, pstate);
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            // should never happen as we pre-allocate the memory
+            return false;
+        }
+
+        struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
+
+        // set the input
+        // (copy up to n_ctx mel frames starting at mel_offset; the remainder of
+        // the input tensor stays zero-padded)
+        {
+            const auto & mel_inp = pstate.mel;
+            const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx;
+
+            assert(mel->type == GGML_TYPE_F32);
+            assert(mel_inp.n_mel == pctx.model.hparams.n_mels);
+
+            pstate.inp_mel.resize(ggml_nelements(mel));
+
+            float * dst = pstate.inp_mel.data();
+            memset(dst, 0, ggml_nbytes(mel));
+
+            const int i0 = std::min(mel_offset, mel_inp.n_len);
+            const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len);
+
+            const int n_frames_to_copy = i1 - i0;
+            memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, n_frames_to_copy * mel_inp.n_mel * sizeof(float));
+
+            ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
+        }
+
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
+            return false;
+        }
+    }
+
+    // encoder
+    auto & sched = pstate.sched_encode.sched;
+
+    ggml_cgraph * gf = parakeet_build_graph_encoder(pctx, pstate);
+
+    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+        // should never happen as we pre-allocate the memory
+        return false;
+    }
+
+    // set the inputs
+    // (mask out key columns beyond the number of real, non-padded frames)
+    {
+        struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask");
+        const int n_q = attn_mask->ne[1];
+        const int n_k = attn_mask->ne[0];
+
+        const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor;
+        const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor-1) / subsampl_factor;
+
+        std::vector mask_data(n_q * n_k);
+        const float mask_value = -1e30f;
+
+        for (int q = 0; q < n_q; ++q) {
+            for (int k = 0; k < n_k; ++k) {
+                bool is_padding = (k >= n_tokens_real);
+
+                mask_data[q * n_k + k] = (is_padding) ? mask_value : 0.0f;
+            }
+        }
+        ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float));
+    }
+
+    if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
+        return false;
+    }
+
+    pstate.t_encode_us += ggml_time_us() - t_start_us;
+    pstate.n_encode++;
+
+    return !(abort_callback && abort_callback(abort_callback_data));
+}
+
+// Single LSTM cell step: consumes x_t and the layer's (h, c) states, writes the
+// updated states back in place (via ggml_cpy nodes added to gf) and returns the
+// new hidden state h_t.
+static struct ggml_tensor * parakeet_build_graph_lstm_layer(
+        struct ggml_context * ctx0,
+        struct ggml_cgraph * gf,
+        struct ggml_tensor * x_t, // the current input token embedding
+        struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed)
+        struct ggml_tensor * b_ih, // input to hidden bias (4 bias packed together)
+        struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed)
+        struct ggml_tensor * b_hh, // hidden to hidden bias (4 bias tensors packed)
+        struct ggml_tensor * h_state, // this layer's hidden state
+        struct ggml_tensor * c_state, // this layer's cell state
+        int li) { // layer index (for tensor naming)
+
+    ggml_format_name(x_t, "lstm_layer_%d_x_t", li);
+    ggml_format_name(h_state, "lstm_layer_%d_h_state", li);
+    ggml_format_name(c_state, "lstm_layer_%d_c_state", li);
+
+    // Input/Forget/Cell/Output gates are packed in the same weight tensor.
+    struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t);
+    inp_gates = ggml_add(ctx0, inp_gates, b_ih);
+
+    // Hidden-to-Hidden Projections are also packed in the same weight tensor.
+    struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state);
+    hid_gates = ggml_add(ctx0, hid_gates, b_hh);
+
+    // Combine the input and hidden contributions of the gates.
+    struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates);
+    ggml_format_name(gates, "lstm_layer_%d_gates", li);
+
+    const int h_dim = h_state->ne[0];
+    const size_t row_size = ggml_row_size(gates->type, h_dim);
+
+    // Gate order below follows the packed i, f, g, o layout - presumably the
+    // PyTorch nn.LSTM convention used by the converter; confirm.
+
+    // 1. Input Gate at time t.
+    struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, h_dim, 0 * row_size));
+    ggml_format_name(i_t, "lstm_layer_%d_i_t", li);
+
+    // Forget gate.
+    struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, h_dim, 1 * row_size));
+    ggml_format_name(f_t, "lstm_layer_%d_f_t", li);
+
+    // Cell gate.
+    struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 2 * row_size));
+    ggml_format_name(c_t, "lstm_layer_%d_c_t", li);
+
+    // Output gate.
+    struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size));
+    ggml_format_name(o_t, "lstm_layer_%d_o_t", li);
+
+    // Calculate the new cell state.
+    struct ggml_tensor * c_new = ggml_add(ctx0,
+            ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state.
+            ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate.
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state));
+
+    // Calculate the new hidden state.
+    struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new));
+    ggml_set_output(h_new);
+    ggml_format_name(h_new, "lstm_layer_%d_h_new", li);
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state));
+
+    return h_new;
+}
+
+// Builds the prediction (decoder) network graph: token embedding -> stacked
+// LSTM layers -> projection to the joint network's hidden dimension ("h_pred").
+static struct ggml_cgraph * parakeet_build_graph_prediction(
+        parakeet_context & pctx,
+        parakeet_state & pstate,
+        const parakeet_batch & batch,
+        bool worst_case) {
+    GGML_UNUSED(worst_case);
+    const auto & model = pctx.model;
+    const auto & hparams = model.hparams;
+    const int n_tokens = batch.n_tokens;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ pstate.sched_decode.meta.size(),
+        /*.mem_buffer =*/ pstate.sched_decode.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
+
+    // Prediction Network
+    struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_name(token, "token_inp");
+    ggml_set_input(token);
+
+    struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token);
+    // NOTE(review): marking a computed tensor (get_rows result) as an input
+    // looks unintended - the graph input is `token` above; confirm.
+    ggml_set_input(token_embd);
+
+    struct ggml_tensor * inpL = token_embd;
+
+    for (int il = 0; il < hparams.n_pred_layers; ++il) {
+        inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL,
+                model.prediction.lstm_layer[il].ih_w,
+                model.prediction.lstm_layer[il].ih_b,
+                model.prediction.lstm_layer[il].hh_w,
+                model.prediction.lstm_layer[il].hh_b,
+                pstate.lstm_state.layer[il].h_state,
+                pstate.lstm_state.layer[il].c_state,
+                il);
+    }
+
+    struct ggml_tensor * pred_out = inpL;
+    ggml_format_name(pred_out, "lstm_pred_out");
+
+    // Project the prediction network output to the joint network hidden dimension.
+    struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out);
+    pred = ggml_add(ctx0, pred, model.joint.pred_b);
+    ggml_set_output(pred);
+    ggml_set_name(pred, "h_pred");
+
+    ggml_build_forward_expand(gf, pred);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// Builds the joint network graph: combines a single prediction-network vector
+// and a single encoder frame into the final logits.
+static struct ggml_cgraph * parakeet_build_graph_joint(
+        parakeet_context & pctx,
+        parakeet_state & pstate,
+        const parakeet_batch & batch,
+        bool worst_case) {
+    GGML_UNUSED(worst_case);
+    GGML_UNUSED(batch);
+    const auto & model = pctx.model;
+    const auto & hparams = model.hparams;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ pstate.sched_decode.meta.size(),
+        /*.mem_buffer =*/ pstate.sched_decode.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
+
+    struct ggml_tensor * pred = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_pred_dim);
+    ggml_set_name(pred, "pred");
+    ggml_set_input(pred);
+
+    struct ggml_tensor * enc_out = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_audio_state);
+    ggml_set_name(enc_out, "enc_out");
+    ggml_set_input(enc_out);
+
+    // Project the encoder output to the joint network hidden dimension.
+ struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out); + enc = ggml_add(ctx0, enc, model.joint.enc_b); + ggml_set_name(enc, "enc"); + + struct ggml_tensor * joint = ggml_add(ctx0, enc, pred); + ggml_set_name(joint, "joint"); + joint = ggml_relu(ctx0, joint); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint); + logits = ggml_add(ctx0, logits, model.joint.net_b); + ggml_set_output(logits); + ggml_set_name(logits, "logits"); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, logits); + struct ggml_tensor * log_probs = ggml_log(ctx0, probs); + ggml_set_output(log_probs); + ggml_format_name(log_probs, "log_probs"); + + ggml_build_forward_expand(gf, log_probs); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_predict( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + + const int n_tokens = batch.n_tokens; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set the inputs + { + struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp"); + ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp)); + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + // Copy h_pred output to pstate.pred_out for use in joint network + { + struct ggml_tensor * h_pred = ggml_graph_get_tensor(gf, "h_pred"); + ggml_backend_tensor_copy(h_pred, pstate.pred_out); + } + + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback 
abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + auto & logits_out = pstate.logits; + + struct ggml_tensor * logits; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set the inputs + { + struct ggml_tensor * pred = ggml_graph_get_tensor(gf, "pred"); + if (n_tokens == 1) { + ggml_backend_tensor_copy(pstate.pred_out, pred); + } + + struct ggml_tensor * enc_out_inp = ggml_graph_get_tensor(gf, "enc_out"); + if (n_tokens == 1) { + const int t_idx = batch.i_time[0]; + const int d_enc = hparams.n_audio_state; + + std::vector frame_buf(d_enc); + + // If we have accumulated encoder output (chunked processing), use that + if (!pstate.enc_out_buffer.empty() && t_idx < pstate.enc_out_frames) { + const float * src = pstate.enc_out_buffer.data() + (t_idx * d_enc); + std::copy(src, src + d_enc, frame_buf.begin()); + } else { + ggml_backend_tensor_get(pstate.enc_out, frame_buf.data(), t_idx * d_enc * sizeof(float), d_enc * sizeof(float)); + } + + ggml_backend_tensor_set(enc_out_inp, frame_buf.data(), 0, d_enc * sizeof(float)); + } + } + + logits = ggml_graph_node(gf, -1); + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + } + + const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // one for the blank token + logits_out.resize(n_tokens * n_logits); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits); + } + + if (batch.n_tokens == 1) { + pstate.t_decode_us += ggml_time_us() - t_start_us; + pstate.n_decode++; + } 
else if (batch.n_tokens < 16) {
+ pstate.t_batchd_us += ggml_time_us() - t_start_us;
+ pstate.n_batchd += n_tokens;
+ } else {
+ pstate.t_prompt_us += ggml_time_us() - t_start_us;
+ pstate.n_prompt += n_tokens;
+ }
+
+ return !(abort_callback && abort_callback(abort_callback_data));
+}
+
+static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) {
+ const std::string & token_str = vocab.id_to_token[token_id];
+ // check if it starts with the SentencePiece meta-space "▁" (U+2581) or 3-byte UTF-8 character: 0xE2 0x96 0x81
+ if (!token_str.empty()) {
+ if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') {
+ return true;
+ }
+ }
+ return false;
+}
+
+static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) {
+ const std::string & token_str = vocab.id_to_token[token_id];
+ static const std::string punct_chars = ".,!?;:'\"-()[]{}";
+
+ if (token_str.empty()) {
+ return false;
+ }
+
+ std::string clean_token = token_str;
+ if (clean_token.find("\xE2\x96\x81") == 0) {
+ clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 character
+ } else if (clean_token[0] == '_') {
+ clean_token = clean_token.substr(1);
+ }
+
+ return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos;
+}
+
+// Collapse punctuation timestamps to match the original Parakeet model.
+// Punctuation symbols like ',', '.' and others are not spoken words but the
+// model will still produce a duration for these tokens. But since these are
+// non-spoken we collapse the timestamps so that they don't have a time duration. 
+static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector & tokens) { + if (tokens.empty()) { + return; + } + + int64_t last_non_punct_t1 = -1; + + for (size_t i = 0; i < tokens.size(); ++i) { + if (is_punctuation_token(vocab, tokens[i].id)) { + if (last_non_punct_t1 >= 0) { + tokens[i].t0 = last_non_punct_t1; + tokens[i].t1 = last_non_punct_t1; + } + } else { + last_non_punct_t1 = tokens[i].t1; + } + } +} + +static parakeet_token_data create_token_data( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_token token_id, + int duration_idx, + int duration_value, + int frame_index, + float token_logit, + int n_vocab_logits) { + + float token_sum = 0.0f; + for (int i = 0; i < n_vocab_logits; ++i) { + token_sum += expf(pstate.logits[i]); + } + float token_p = expf(token_logit) / token_sum; + + parakeet_token_data token_data; + token_data.id = token_id; + token_data.duration_idx = duration_idx; + token_data.duration_value = duration_value; + token_data.frame_index = frame_index; + token_data.p = token_p; + token_data.plog = token_logit; + token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor; + token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor; + token_data.is_word_start = is_word_start_token(pctx.vocab, token_id); + + return token_data; +} + +static bool parakeet_decode( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int n_tdt_durations = hparams.n_tdt_durations; + const int n_frames = pstate.n_frames; + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = 0; + // number of symbols emitted for the current time 
frame + int tokens_emitted = 0; + + // Start with the blank token (8192) + parakeet_token last_token = blank_id; + + PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames); + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 1; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network below. + if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + // process all time frames of the encoder output + while (t < n_frames) { + batch.n_tokens = 1; + batch.i_time[0] = t; + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. batch.i_time + // is used in to select the correct frame from the encoder output. + // The joint network outputs logits for all the tokens in the vocabulary + // plus the blank token, and also n_duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + // find the best token (greedy). + // TODO: implement beam search? + int best_token = 0; + float max_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > max_logit) { + max_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. 
+ int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < n_tdt_durations; ++i) { + if (pstate.logits[n_vocab_logits + i] > best_duration_logit) { + best_duration_logit = pstate.logits[n_vocab_logits + i]; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + // skip forward by duration time frames. + t += duration; + // reset symbols emitted counter + tokens_emitted = 0; + // continue without predicting. + continue; + } + + // Emit non-blank token at current frame t. + pstate.decoded_tokens.push_back(best_token); + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, best_token, best_duration_idx, duration, t, + max_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // Call token callback if registered (for real-time streaming) + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + last_token = best_token; + + // advance predictor for the non-blank token. + batch.token[0] = last_token; + if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + // if duration greater than 0, continue looping over the encoder frames + // and skip to the updated time frame (t). + if (duration > 0) { + t += duration; + tokens_emitted = 0; + continue; + } + + // if duration is zero we stay on the current time frame. 
+ tokens_emitted++; + if (tokens_emitted >= max_tokens_per_timestep) { + t += 1; // forced blank/time advance behavior + tokens_emitted = 0; + } + } + + return true; +} + +static bool parakeet_decode_chunk( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int chunk_enc_frames, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = pstate.tdt_stream_state.time_step; + parakeet_token last_token = pstate.tdt_stream_state.last_token; + + PARAKEET_LOG_DEBUG("%s: chunk_frames=%d, start_t=%d, last_token=%d, initialized=%d\n", + __func__, chunk_enc_frames, t, last_token, pstate.tdt_stream_state.initialized); + + if (!pstate.tdt_stream_state.initialized) { + // Start with the blank token (8192) + last_token = blank_id; + t = 0; + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 0; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network. + if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + pstate.tdt_stream_state.initialized = true; + pstate.tdt_stream_state.last_token = last_token; + } + + // Count non-blank emissions at the same encoder frame to mimic max_symbols handling. + int last_emission_time = -1; + int tokens_at_time = 0; + + // loop over all the time frames the encoder produced for this chunk. 
+ while (t < chunk_enc_frames) { + int current_token_time = t; + + int chosen_token = blank_id; + int chosen_duration = 1; + int chosen_duration_idx = 0; + float chosen_token_logit = 0.0f; + + // inner loop handles the joint network + while (t < chunk_enc_frames) { + batch.n_tokens = 1; + batch.i_time[0] = std::min(t, chunk_enc_frames - 1); + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. + // The joint network outputs logits for all the tokens in the vocabulary + // and the blank token, and also addtionally N duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + // find the token with the highest value. + int best_token = 0; + float best_token_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > best_token_logit) { + best_token_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < hparams.n_tdt_durations; ++i) { + const float v = pstate.logits[n_vocab_logits + i]; + if (v > best_duration_logit) { + best_duration_logit = v; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + // handle blank token output from the joint network. + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + + // always add the duration. + t += duration; + + // if blank duration jumped past the chunk break. 
+ if (t >= chunk_enc_frames) { + break; + } + + // we continue with the current frame + current_token_time = t; + continue; + } + + // found the next non-blank token and label for this outer step. + chosen_token = best_token; + chosen_duration = duration; + chosen_token_logit = best_token_logit; + chosen_duration_idx = best_duration_idx; + break; + } + + // if we exited because blank durations ran beyond the chunk, no label to emit. + if (chosen_token == blank_id) { + break; + } + + // store token at the frame where it was found. + pstate.decoded_tokens.push_back(chosen_token); + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, chosen_token, chosen_duration_idx, chosen_duration, current_token_time, + chosen_token_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // emit the token to the callback if registered. + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + // advance predictor for the non-blank token. + last_token = chosen_token; + pstate.tdt_stream_state.last_token = last_token; + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 0; + batch.i_time[0] = 0; + + if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + if (chosen_duration > 0) { + // if the duration looked up in the duration array was greater than zero + t += chosen_duration; + last_emission_time = -1; + tokens_at_time = 0; + } else { + // the duration looked up in the duration array was zero + if (current_token_time == last_emission_time) { + // update the token count as it can be valid for the model + // to emit a token with a zero duration. + tokens_at_time++; + } else { + // update the current emission time and reset the token count. 
+ last_emission_time = current_token_time; + tokens_at_time = 1; + } + + // handle the case where the model is emitting too many tokens with + // zero duration to force progress. + if (tokens_at_time >= max_tokens_per_timestep) { + t += 1; + last_emission_time = -1; + tokens_at_time = 0; + } + } + } + + pstate.tdt_stream_state.time_step = t - chunk_enc_frames; + pstate.tdt_stream_state.decoded_length += chunk_enc_frames; + + PARAKEET_LOG_DEBUG("%s: finished t=%d, time_step=%d, decoded_length=%d, total_tokens=%zu\n", + __func__, + t, + pstate.tdt_stream_state.time_step, + pstate.tdt_stream_state.decoded_length, + pstate.decoded_tokens.size()); + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +static std::string to_timestamp(int64_t t, bool comma = false) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); + + return std::string(buf); +} + +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) { + const int sin_cos_step = cache.n_fft / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N + re += in[n]*cache.cos_vals[idx]; // cos(t) + im -= in[n]*cache.sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) { + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out, cache); + return; + } + + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; + } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft, cache); + + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft, cache); + + const int sin_cos_step = cache.n_fft / N; + for (int k = 0; k < half_N; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; // cos(t) + float im = -cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +static void log_mel_spectrogram_worker_thread( + int ith, + const float * window_func, + int window_size, + const std::vector & samples, + int 
n_samples, + int frame_size, + int frame_step, + int n_threads, + const parakeet_filters & filters, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + std::vector fft_in(frame_size * 2, 0.0); + std::vector fft_out(frame_size * 2 * 2 * 2); + + int n_fb = filters.n_fb; // number of frequency bins + int i = ith; + + // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist + assert(n_fb == 1 + (frame_size / 2)); + + const double eps = 5.960464477539063e-08; + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + const int offset = i * frame_step; + + const int window_pad_left = (frame_size - window_size) / 2; + + // Zero-pad left + std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); + + // Apply windowed samples in the center + const int n_to_process = std::min({window_size, n_samples - offset}); + for (int j = 0; j < n_to_process; j++) { + fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; + } + + // Zero-pad right (and any samples we didn't have) + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + frame_size, 0.0f); + + // FFT + fft(fft_in.data(), frame_size, fft_out.data(), cache); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. 
+ for (int j = 0; j < n_fb; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fb - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fb + k + 0] + + fft_out[k + 1] * filters.data[j * n_fb + k + 1] + + fft_out[k + 2] * filters.data[j * n_fb + k + 2] + + fft_out[k + 3] * filters.data[j * n_fb + k + 3]; + } + // handle n_fb remainder + for (; k < n_fb; k++) { + sum += fft_out[k] * filters.data[j * n_fb + k]; + } + + mel.data[i * mel.n_mel + j] = std::log(sum + eps); + } + } + + // Otherwise fft_out are all zero - use log(eps) for consistency + const double empty_sum = std::log(eps); + for (; i < mel.n_len; i += n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[i * mel.n_mel + j] = empty_sum; + } + } +} + +static bool log_mel_spectrogram( + parakeet_state & wstate, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const parakeet_filters & filters, + const bool debug, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + const int64_t t_start_us = ggml_time_us(); + + const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data(); + const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size(); + + std::vector samples_preprocessed(samples, samples + n_samples); + + // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1] + { + const float preemph = 0.97f; + for (int i = n_samples - 1; i > 0; i--) { + samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1]; + } + } + + // Parakeet Pytorch implementation uses centered contant padding. 
+ const int pad = frame_size / 2; + std::vector samples_padded(n_samples + 2 * pad); // 176512 + std::fill(samples_padded.begin(), samples_padded.begin() + pad, 0.0f); + std::fill(samples_padded.begin() + pad + n_samples, samples_padded.end(), 0.0f); + std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); + + mel.n_mel = n_mel; + mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1; + mel.n_len_org = mel.n_len; + mel.data.resize(mel.n_mel * mel.n_len); + + // Worker Threads (STFT + Mel + Natural Log) + { + std::vector workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread(log_mel_spectrogram_worker_thread, + iw + 1, // thread index + window_func, + window_size, + std::cref(samples_padded), + samples_padded.size(), + frame_size, + frame_step, + n_threads, + std::cref(filters), + std::ref(mel), + std::cref(cache)); + } + + log_mel_spectrogram_worker_thread( + 0, + window_func, + window_size, + samples_padded, + samples_padded.size(), + frame_size, + frame_step, + n_threads, + filters, + mel, + cache); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + { + const double eps = 1e-5; + int valid_frames = n_samples / frame_step; + + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + double sq_diff_sum = 0.0; + + // Calculate Mean ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + sum += (double)mel.data[i * mel.n_mel + j]; + } + double mean = sum / valid_frames; + + // Calculate Variance ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + double diff = (double)mel.data[i * mel.n_mel + j] - mean; + sq_diff_sum += diff * diff; + } + + double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0)); + double denominator = std_dev + eps; + + // Apply to ALL frames (including the padded ones) + for (int i = 0; i < mel.n_len; i++) { + mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + 
j] - mean) / denominator);
+ }
+ }
+ }
+
+ wstate.t_mel_us += ggml_time_us() - t_start_us;
+
+ if (debug) {
+ std::ofstream outFile("log_mel_spectrogram.json");
+ outFile << "[";
+ for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+ outFile << mel.data[i] << ", ";
+ }
+ outFile << mel.data[mel.data.size() - 1] << "]";
+ outFile.close();
+ }
+
+ return true;
+}
+
+static std::vector<parakeet_token> tokenize(const parakeet_vocab & vocab, const std::string & text) {
+ std::vector<parakeet_token> tokens;
+ const std::string normalized = sentencepiece_normalize(text);
+
+ size_t i = 0;
+ while (i < normalized.size()) {
+ const size_t remaining = normalized.size() - i;
+ const size_t max_len = std::min(vocab.max_token_length, remaining);
+
+ bool found = false;
+ for (size_t len = max_len; len > 0; --len) {
+ const auto it = vocab.token_to_id.find(normalized.substr(i, len));
+ if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) {
+ tokens.push_back(it->second);
+ i += len;
+ found = true;
+ break;
+ }
+ }
+
+ if (!found) {
+ if (vocab.token_unk >= 0) {
+ tokens.push_back(vocab.token_unk);
+ }
+
+ const unsigned char c = static_cast<unsigned char>(normalized[i]);
+ i += utf8_codepoint_len(c);
+ }
+ }
+
+ return tokens;
+}
+
+
+//
+// interface implementation
+//
+
+struct parakeet_state * parakeet_init_state(parakeet_context * ctx) {
+ parakeet_state * state = new parakeet_state;
+
+ state->backends = whisper_backend_init(ctx->params);
+ if (state->backends.empty()) {
+ PARAKEET_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+ parakeet_free_state(state);
+ return nullptr;
+ }
+
+ const int batch_size = ctx->model.hparams.n_audio_ctx;
+
+ state->logits.reserve(ctx->vocab.n_vocab * batch_size);
+
+ state->batch = parakeet_batch_init(batch_size);
+
+ // conv allocator
+ {
+ bool ok = parakeet_sched_graph_init(state->sched_conv, state->backends,
+ [&]() {
+ return parakeet_build_graph_conv(*ctx, *state);
+ });
+
+ if (!ok) {
+ PARAKEET_LOG_ERROR("%s: failed to init conv allocator\n", 
__func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_conv) / 1e6); + } + + // encoder allocator + bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends, + [&]() { + return parakeet_build_graph_encoder(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init encoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers)) { + PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size(); + const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer); + PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0); + } + + if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_pred_ctx = state->pred_out_buf.size(); + const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer); + PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0); + } + + PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6); + + { + bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends, + [&]() { + const auto & hparams = ctx->model.hparams; + const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet + + parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0); + + return 
parakeet_build_graph_prediction(*ctx, *state, state->batch, true); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +struct parakeet_context_params parakeet_context_default_params() { + struct parakeet_context_params result = { + /*.use_gpu =*/ true, + /*.flash_attn =*/ true, + /*.gpu_device =*/ 0, + }; + return result; +} + +struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) { + PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) { + PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + parakeet_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = parakeet_init_with_params_no_state(&loader, params); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + 
struct buf_context {
+        uint8_t* buffer;
+        size_t size;
+        size_t current_offset;
+    };
+
+    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
+
+    PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__);
+
+    parakeet_model_loader loader = {};
+
+    loader.context = &ctx;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        buf_context * buf = reinterpret_cast<buf_context*>(ctx);
+
+        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
+
+        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
+        buf->current_offset += size_to_copy;
+
+        return size_to_copy;
+    };
+
+    loader.eof = [](void * ctx) {
+        buf_context * buf = reinterpret_cast<buf_context*>(ctx);
+
+        return buf->current_offset >= buf->size;
+    };
+
+    loader.close = [](void * /*ctx*/) { };
+
+    return parakeet_init_with_params_no_state(&loader, params);
+}
+
+struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) {
+    ggml_time_init();
+
+    PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+    PARAKEET_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+    PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
+    PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
+
+    parakeet_context * ctx = new parakeet_context;
+    ctx->params = params;
+
+    if (!parakeet_model_load(loader, *ctx)) {
+        loader->close(loader->context);
+        PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__);
+        delete ctx;
+        return nullptr;
+    }
+
+    loader->close(loader->context);
+
+    // Initialize mel cache with model's FFT size
+    ctx->mel_cache.init(ctx->model.hparams.n_fft);
+    PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft);
+
+    return ctx;
+}
+
+struct parakeet_context *
parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +void parakeet_free_state(struct parakeet_state * state) { + if (state) { + ggml_backend_buffer_free(state->lstm_state.buffer); + ggml_backend_buffer_free(state->pred_out_buffer); + + parakeet_batch_free(state->batch); + + ggml_backend_sched_free(state->sched_conv.sched); + ggml_backend_sched_free(state->sched_encode.sched); + ggml_backend_sched_free(state->sched_decode.sched); + + for (auto & backend : state->backends) { + ggml_backend_free(backend); + } + + delete state; + } +} + +void parakeet_free(struct parakeet_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + parakeet_free_state(ctx->state); + + delete ctx; + } +} + +void parakeet_free_context_params(struct 
parakeet_context_params * params) { + if (params) { + delete params; + } +} + +void parakeet_free_params(struct parakeet_full_params * params) { + if (params) { + delete params; + } +} + +int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, + samples, + n_samples, + PARAKEET_SAMPLE_RATE, + ctx->model.hparams.n_fft, + PARAKEET_HOP_LENGTH, + ctx->model.filters.n_mel, + n_threads, + ctx->model.filters, + false, // debug + state->mel, + ctx->mel_cache)) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) { + return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel) { + return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_encode(struct parakeet_context * ctx, int 
offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int parakeet_token_count(struct parakeet_context * ctx, const char * text) { + return -parakeet_tokenize(ctx, text, NULL, 0); +} + +int parakeet_model_n_vocab(struct parakeet_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int parakeet_model_n_audio_state(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int parakeet_model_n_audio_head(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int parakeet_model_n_audio_layer(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int parakeet_model_n_mels(struct parakeet_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int parakeet_model_ftype(struct parakeet_context * ctx) { + return ctx->model.hparams.ftype; +} + +int parakeet_n_len_from_state(struct parakeet_state * state) { + return state->mel.n_len_org; +} + +int parakeet_n_len(struct parakeet_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int parakeet_n_vocab(struct parakeet_context * ctx) { + return ctx->vocab.n_vocab; +} + +int parakeet_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +float * parakeet_get_logits(struct parakeet_context * ctx) 
{ + return ctx->state->logits.data(); +} + +float * parakeet_get_logits_from_state(struct parakeet_state * state) { + return state->logits.data(); +} + +const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) { + std::string text = sentencepiece_piece_to_text(token_str, is_first); + + if (output == nullptr) { + return text.size(); + } + + int bytes_to_copy = std::min((int)text.size(), max_len - 1); + if (bytes_to_copy > 0) { + memcpy(output, text.c_str(), bytes_to_copy); + output[bytes_to_copy] = '\0'; + } else if (max_len > 0) { + output[0] = '\0'; + } + + return text.size(); +} + +parakeet_token parakeet_token_bos(struct parakeet_context * ctx) { + return ctx->vocab.token_bos; +} + +parakeet_token parakeet_token_unk(struct parakeet_context * ctx) { + return ctx->vocab.token_unk; +} + +parakeet_token parakeet_token_blank(struct parakeet_context * ctx) { + return ctx->vocab.token_blank; +} + +struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { + if (ctx->state == nullptr) { + return nullptr; + } + parakeet_timings * timings = new parakeet_timings; + timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); + timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); + timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); + timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd); + timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt); + return timings; +} + +void parakeet_print_timings(struct parakeet_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + PARAKEET_LOG_INFO("\n"); + PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state 
!= nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_batchd = std::max(1, ctx->state->n_batchd); + const int32_t n_prompt = std::max(1, ctx->state->n_prompt); + + PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + PARAKEET_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); + PARAKEET_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); + } + PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void parakeet_reset_timings(struct parakeet_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_batchd_us = 0; + ctx->state->t_prompt_us = 0; + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_batchd = 0; + ctx->state->n_prompt = 0; + } 
+} + +const char * parakeet_print_system_info(void) { + static std::string s; + + s = ""; + s += "PARAKEET : "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) { + ggml_backend_feature * features = get_features_fn(reg); + s += ggml_backend_reg_name(reg); + s += " : "; + for (; features->name; features++) { + s += features->name; + s += " = "; + s += features->value; + s += " | "; + } + } + } + return s.c_str(); +} + +struct parakeet_context_params * parakeet_context_default_params_by_ref(void) { + struct parakeet_context_params params = parakeet_context_default_params(); + + struct parakeet_context_params* result = new parakeet_context_params(); + *result = params; + return result; +} + +struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params params = parakeet_full_default_params(strategy); + + struct parakeet_full_params* result = new parakeet_full_params(); + *result = params; + return result; +} + +struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params result = { + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.no_context =*/ true, + + /*.audio_ctx =*/ 0, + + /*.chunk_length_ms =*/ 10000, // 10 second chunks + /*.left_context_ms =*/ 10000, // 10 second left context + /*.right_context_ms =*/ 4960, // 4.96 second right context + + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + + 
/*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + }; + + return result; +} + +static void parakeet_stream_reset_state(struct parakeet_state * state) { + if (state == nullptr) { + return; + } + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } + + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + state->result_all.clear(); + + state->tdt_stream_state.initialized = false; + state->tdt_stream_state.last_token = 0; + state->tdt_stream_state.time_step = 0; + state->tdt_stream_state.decoded_length = 0; + + state->stream.buffer.clear(); + state->stream.n_samples_advanced = 0; + state->stream.n_left_ctx = 0; + state->stream.n_chunk = 0; + state->stream.n_right_ctx = 0; + state->stream.params = {}; + state->stream.initialized = false; + + state->enc_out_buffer.clear(); + state->enc_out_frames = 0; + state->n_frames = 0; + state->n_audio_ctx = 0; +} + +static int parakeet_stream_process_window( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_chunk) { + const parakeet_stream & stream = state->stream; + const parakeet_full_params & params = stream.params; + const int d_enc = ctx->model.hparams.n_audio_state; + + // process all the samples. + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + return -2; + } + + const int left_mel_frames = stream.n_left_ctx / PARAKEET_HOP_LENGTH; + const int chunk_mel_frames = n_chunk / PARAKEET_HOP_LENGTH; + + state->n_audio_ctx = state->mel.n_len; + // process entire log mel spectrogram. 
+    if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads,
+        params.abort_callback, params.abort_callback_user_data)) {
+        return -6;
+    }
+
+    const int left_enc_frames = left_mel_frames / ctx->model.hparams.subsampling_factor;
+    const int chunk_enc_frames = chunk_mel_frames / ctx->model.hparams.subsampling_factor;
+
+    if (chunk_enc_frames <= 0) {
+        return 0;
+    }
+
+    // Copy the center chunk so that it is the only part that the joint network sees.
+    state->enc_out_buffer.resize(chunk_enc_frames * d_enc);
+    ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(),
+        left_enc_frames * d_enc * sizeof(float),
+        chunk_enc_frames * d_enc * sizeof(float));
+
+    state->enc_out_frames = chunk_enc_frames;
+    state->n_frames = chunk_enc_frames;
+
+    const size_t tokens_before = state->decoded_tokens.size();
+
+    // Run the prediction and joint network on the center chunk.
+    if (!parakeet_decode_chunk(*ctx, *state, state->batch, chunk_enc_frames, params.n_threads, &params)) {
+        return -7;
+    }
+
+    const size_t tokens_after = state->decoded_tokens.size();
+    const size_t new_token_count = tokens_after - tokens_before;
+
+    if (new_token_count > 0) {
+        std::string text;
+        std::vector<parakeet_token_data> result_tokens;
+        const int64_t chunk_t0 = 100LL * stream.n_samples_advanced / PARAKEET_SAMPLE_RATE;
+        const int64_t chunk_t1 = 100LL * (stream.n_samples_advanced + n_chunk) / PARAKEET_SAMPLE_RATE;
+        const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor;
+
+        result_tokens.reserve(new_token_count);
+
+        for (size_t i = tokens_before; i < tokens_after; ++i) {
+            const auto token_id = state->decoded_tokens[i];
+            const char * token_str = parakeet_token_to_str(ctx, token_id);
+            if (token_str) {
+                const bool is_first_piece = (tokens_before == 0) && text.empty();
+                text += sentencepiece_piece_to_text(token_str, is_first_piece);
+            }
+
+            auto token_data = state->decoded_token_data[i];
+            token_data.frame_index += frame_offset;
+            token_data.t0 += chunk_t0;
+            token_data.t1 += chunk_t0;
+
result_tokens.push_back(token_data); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = chunk_t0; + segment.t1 = chunk_t1; + segment.text = std::move(text); + segment.tokens = std::move(result_tokens); + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +static int ms_to_n_samples(int ms) { + return ms * PARAKEET_SAMPLE_RATE / 1000; +} + +int parakeet_stream_init( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params) { + if (ctx == nullptr || state == nullptr) { + return -1; + } + + const int n_left_ctx = ms_to_n_samples(params.left_context_ms); + const int n_chunk = ms_to_n_samples(params.chunk_length_ms); + const int n_right_ctx = ms_to_n_samples(params.right_context_ms); + + if (n_left_ctx < 0 || n_chunk <= 0 || n_right_ctx < 0) { + return -1; + } + + parakeet_stream_reset_state(state); + + state->stream.n_left_ctx = n_left_ctx; + state->stream.n_chunk = n_chunk; + state->stream.n_right_ctx = n_right_ctx; + state->stream.params = params; + state->stream.initialized = true; + + if (n_left_ctx > 0) { + state->stream.buffer.assign(n_left_ctx, 0.0f); + } + + return 0; +} + +int parakeet_stream_push( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples) { + if (ctx == nullptr || state == nullptr || samples == nullptr || n_samples <= 0) { + return -1; + } + + if (!state->stream.initialized) { + return -1; + } + + const int n_total_samples = state->stream.n_left_ctx + state->stream.n_chunk + state->stream.n_right_ctx; + + // Insert the new chunk of samples as the new center and right context. 
+    state->stream.buffer.insert(state->stream.buffer.end(), samples, samples + n_samples);
+
+    // As long as we have enough samples to form a complete window we process it.
+    while (state->stream.buffer.size() >= (size_t) n_total_samples) {
+        const int ret = parakeet_stream_process_window(
+            ctx,
+            state,
+            state->stream.buffer.data(),
+            n_total_samples,
+            state->stream.n_chunk);
+        if (ret != 0) {
+            return ret;
+        }
+
+        // TODO: std::vector::erase is O(n) and not optimal. We should probably
+        // use a ring buffer instead.
+        // Shift the center and right context to the start of the buffer. This
+        // allows the next call to have the current center chunk as its left
+        // context, and the right context will become part of the next target
+        // chunk together with the new samples which will make up the rest of
+        // the target chunk and the new right context.
+        state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + state->stream.n_chunk);
+
+        state->stream.n_samples_advanced += state->stream.n_chunk;
+    }
+
+    return 0;
+}
+
+int parakeet_stream_flush(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state) {
+    if (ctx == nullptr || state == nullptr) {
+        return -1;
+    }
+
+    if (!state->stream.initialized) {
+        return -1;
+    }
+
+    while (state->stream.buffer.size() > (size_t) state->stream.n_left_ctx) {
+        const int n_remaining_samples = (int) state->stream.buffer.size() - state->stream.n_left_ctx;
+        const int n_flush_chunk = std::min(state->stream.n_chunk, n_remaining_samples);
+        const int n_right_available = std::min(state->stream.n_right_ctx, n_remaining_samples - n_flush_chunk);
+        const int n_copied = state->stream.n_left_ctx + n_flush_chunk + n_right_available;
+
+        std::vector<float> flush_window(state->stream.n_left_ctx + n_flush_chunk + state->stream.n_right_ctx, 0.0f);
+
+        std::copy_n(state->stream.buffer.begin(), n_copied, flush_window.begin());
+
+        const int ret = parakeet_stream_process_window(
+            ctx,
+            state,
+            flush_window.data(),
+
(int) flush_window.size(), + n_flush_chunk); + if (ret != 0) { + return ret; + } + + state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + n_flush_chunk); + state->stream.n_samples_advanced += n_flush_chunk; + } + + state->stream.buffer.clear(); + state->stream.n_samples_advanced = 0; + state->stream.n_left_ctx = 0; + state->stream.n_chunk = 0; + state->stream.n_right_ctx = 0; + state->stream.params = {}; + state->stream.initialized = false; + + return 0; +} + +int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = state->result_all; + result_all.clear(); + + // Clear any previous decoded tokens if no_context is set + if (params.no_context) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + } + + // If no chunking specified, delegate to parakeet_chunk (processes entire audio as one chunk) + if (params.chunk_length_ms == 0) { + return parakeet_chunk(ctx, state, params, samples, n_samples); + } + + // Nemo sliding window chunking (chunk raw samples) + { + PARAKEET_LOG_INFO("%s: chunking: chunk_ms=%d, left_ctx=%d, right_ctx=%d\n", + __func__, params.chunk_length_ms, params.left_context_ms, params.right_context_ms); + + // reset tdt streaming state + state->tdt_stream_state.initialized = false; + state->tdt_stream_state.last_token = 0; + state->tdt_stream_state.time_step = 0; + state->tdt_stream_state.decoded_length = 0; + + const int left_context_samples = (params.left_context_ms * PARAKEET_SAMPLE_RATE) / 1000; + const int chunk_samples = (params.chunk_length_ms * PARAKEET_SAMPLE_RATE) / 1000; + const int right_context_samples = (params.right_context_ms * PARAKEET_SAMPLE_RATE) / 1000; + + const int d_enc = ctx->model.hparams.n_audio_state; + + int left_sample = 0; + + // Process audio in sliding chunks of raw samples. 
+ while (left_sample < n_samples) { + const int buffer_start_sample = std::max(0, left_sample - left_context_samples); + const int right_sample = std::min(left_sample + chunk_samples, n_samples); + const int buffer_end_sample = std::min(right_sample + right_context_samples, n_samples); + const int buffer_length_samples = buffer_end_sample - buffer_start_sample; + const size_t tokens_before = state->decoded_tokens.size(); + + PARAKEET_LOG_DEBUG("%s: sample chunk: buffer=[%d-%d] samples, chunk=[%d-%d]\n", + __func__, buffer_start_sample, buffer_end_sample, left_sample, right_sample); + + // Compute mel spectrogram for the left context, the center chunk, + // and the right context. + if (parakeet_pcm_to_mel_with_state(ctx, state, samples + buffer_start_sample, + buffer_length_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram for chunk\n", __func__); + return -2; + } + + const int total_mel_frames = state->mel.n_len; + + // Calculate mel frame offsets within this chunks mel spectrogram + // mel frames = raw samples / 160 (for 16kHz, 10ms hop) + const int left_context_mel_frames = (left_sample - buffer_start_sample) / 160; + const int chunk_mel_frames = (right_sample - left_sample) / 160; + + PARAKEET_LOG_DEBUG("%s: Mel features: total=%d frames, left_ctx=%d, chunk=%d\n", + __func__, total_mel_frames, left_context_mel_frames, chunk_mel_frames); + + // The encoder will see the left context, the main chunk, and the right + // context so that it has access to nearby frames during processing. + // Without this there would be hard cutoffs and the encoder might not + // be able to detect speech near the edges. 
+            state->n_audio_ctx = total_mel_frames;
+            if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads,
+                params.abort_callback, params.abort_callback_user_data)) {
+                PARAKEET_LOG_ERROR("%s: failed to encode chunk\n", __func__);
+                return -6;
+            }
+
+            // Calculate which encoder frames correspond to the actual chunk (center chunk)
+            const int left_context_enc_frames = left_context_mel_frames / ctx->model.hparams.subsampling_factor;
+            const int chunk_enc_frames = chunk_mel_frames / ctx->model.hparams.subsampling_factor;
+
+            PARAKEET_LOG_DEBUG("%s: encoder output: total=%d frames, left_ctx=%d, chunk=%d\n",
+                __func__, state->n_frames, left_context_enc_frames, chunk_enc_frames);
+
+            // The joint network only sees the center chunk encoder frames and
+            // not the left and right context that the encoder did.
+            state->enc_out_buffer.resize(chunk_enc_frames * d_enc);
+            ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(),
+                left_context_enc_frames * d_enc * sizeof(float),
+                chunk_enc_frames * d_enc * sizeof(float));
+            state->enc_out_frames = chunk_enc_frames;
+
+            state->n_frames = chunk_enc_frames;
+
+            PARAKEET_LOG_DEBUG("%s: decoding %d encoder frames for this chunk\n", __func__, chunk_enc_frames);
+
+            if (!parakeet_decode_chunk(*ctx, *state, state->batch, chunk_enc_frames, params.n_threads, &params)) {
+                PARAKEET_LOG_ERROR("%s: failed to decode chunk\n", __func__);
+                return -7;
+            }
+
+            const size_t tokens_after = state->decoded_tokens.size();
+            const size_t new_token_count = tokens_after - tokens_before;
+
+            if (new_token_count > 0) {
+                std::string text;
+                std::vector<parakeet_token_data> result_tokens;
+                const int64_t chunk_t0 = 100LL * left_sample / PARAKEET_SAMPLE_RATE;
+                const int64_t chunk_t1 = 100LL * right_sample / PARAKEET_SAMPLE_RATE;
+                const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor;
+
+                result_tokens.reserve(new_token_count);
+
+                for (size_t i = tokens_before; i < tokens_after; ++i) {
+                    const auto token_id = state->decoded_tokens[i];
+                    const 
char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + auto token_data = state->decoded_token_data[i]; + token_data.frame_index += frame_offset; + token_data.t0 += chunk_t0; + token_data.t1 += chunk_t0; + result_tokens.push_back(token_data); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = chunk_t0; + segment.t1 = chunk_t1; + segment.text = std::move(text); + segment.tokens = std::move(result_tokens); + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + // shift the window. + left_sample = right_sample; + } + + // Clear buffers after all chunks processed + state->enc_out_buffer.clear(); + state->enc_out_frames = 0; + } + + return 0; +} + +int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + + if (params.no_context) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + state->tdt_stream_state.initialized = false; + state->tdt_stream_state.last_token = 0; + state->tdt_stream_state.time_step = 0; + state->tdt_stream_state.decoded_length = 0; + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + if (params.audio_ctx 
== 0) {
+        const int total_len = parakeet_n_len_from_state(state);
+        const int model_max_ctx = parakeet_n_audio_ctx(ctx);
+        params.audio_ctx = std::min(total_len, model_max_ctx);
+        PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx);
+    }
+    state->n_audio_ctx = params.audio_ctx;
+
+    const int n_frames = parakeet_n_len_from_state(state);
+
+    if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+        PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__);
+        return -6;
+    }
+
+    const size_t tokens_before = state->decoded_tokens.size();
+
+    if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, &params)) {
+        PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__);
+        return -7;
+    }
+
+    const size_t tokens_after = state->decoded_tokens.size();
+    const size_t new_token_count = tokens_after - tokens_before;
+
+    if (new_token_count > 0) {
+        std::string text;
+        std::vector<parakeet_token_data> result_tokens;
+
+        for (size_t i = tokens_before; i < tokens_after; i++) {
+            const auto token_id = state->decoded_tokens[i];
+            const char * token_str = parakeet_token_to_str(ctx, token_id);
+            if (token_str) {
+                const bool is_first_piece = (tokens_before == 0) && text.empty();
+                text += sentencepiece_piece_to_text(token_str, is_first_piece);
+            }
+
+            // Use the stored token data from parakeet_decode
+            result_tokens.push_back(state->decoded_token_data[i]);
+        }
+
+        refine_timestamps_tdt(ctx->vocab, result_tokens);
+
+        if (!text.empty()) {
+            parakeet_segment segment;
+            segment.t0 = 0; // Caller tracks timing
+            segment.t1 = n_frames;
+            segment.text = text;
+            segment.tokens = result_tokens;
+
+            state->result_all.push_back(std::move(segment));
+
+            if (params.new_segment_callback) {
+                params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
+            }
+        }
+    }
+
+    return 0;
+}
+
+int parakeet_full_parallel(
+        struct parakeet_context * ctx,
+        struct parakeet_full_params params,
+
const float * samples, + int n_samples, + int n_processors) { + if (n_processors == 1) { + return parakeet_full(ctx, params, samples, n_samples); + } + + int ret = 0; + + // prepare separate states for each thread + std::vector states; + + const int offset_samples = (PARAKEET_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) { + // create a new state for each thread + states.push_back(parakeet_init_state(ctx)); + + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + + workers[i] = std::thread(parakeet_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); + } + + { + auto params_cur = params; + + // Run the first transformation using default state but only for the first chunk. 
+ ret = parakeet_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) { + workers[i].join(); + } + + const int64_t offset_t = (int64_t) params.offset_ms/10.0; + + // combine results into result_state->result_all from all other states + for (int i = 0; i < n_processors - 1; ++i) { + auto& results_i = states[i]->result_all; + + for (auto& result : results_i) { + // correct the segment timestamp taking into account the offset + result.t0 += 100 * ((i + 1) * n_samples_per_processor) / PARAKEET_SAMPLE_RATE + offset_t; + result.t1 += 100 * ((i + 1) * n_samples_per_processor) / PARAKEET_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (!ctx->state->result_all.empty()) { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); + } + + ctx->state->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) { + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); + } + } + + ctx->state->t_mel_us += states[i]->t_mel_us; + + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; + ctx->state->n_prompt += states[i]->n_prompt; + + parakeet_free_state(states[i]); + } + + // average the timings + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; + + // print information about the audio boundaries + PARAKEET_LOG_WARN("\n"); + 
PARAKEET_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + for (int i = 0; i < n_processors - 1; ++i) { + PARAKEET_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/PARAKEET_SAMPLE_RATE + offset_t).c_str()); + } + PARAKEET_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__); + + return ret; +} + +int parakeet_full_n_segments_from_state(struct parakeet_state * state) { + return state->result_all.size(); +} + +int parakeet_full_n_segments(struct parakeet_context * ctx) { + return ctx->state->result_all.size(); +} + +int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment); +} + +int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment); +} + +const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * 
state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +void parakeet_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default; + g_state.log_callback_user_data = user_data; + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); +} + +const char * parakeet_version(void) { + return PARAKEET_VERSION; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal(ggml_log_level level, const char * format, ...) 
{ + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; +#ifndef WHISPER_DEBUG + if (level == GGML_LOG_LEVEL_DEBUG) { + return; + } +#endif + fputs(text, stderr); + fflush(stderr); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..a426d9dca88 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -110,3 +110,32 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Parakeet model loading test +set(PARAKEET_TEST test-parakeet) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) + +set(PARAKEET_TEST test-parakeet-stream) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + 
PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/gb1.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) + +set(PARAKEET_TEST test-parakeet-full) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/gb1.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) +set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;unit") diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp new file mode 100644 index 00000000000..11d52488f51 --- /dev/null +++ b/tests/test-parakeet-full.cpp @@ -0,0 +1,62 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); // no stereo vector + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + 
struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + + params.chunk_length_ms = 10000; + params.left_context_ms = 10000; + params.right_context_ms = 4960; + + int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free(pctx); + + printf("\nTest passed: parakeet_full_parallel succeeded!\n"); + return 0; +} diff --git a/tests/test-parakeet-stream.cpp b/tests/test-parakeet-stream.cpp new file mode 100644 index 00000000000..f112bcb932b --- /dev/null +++ b/tests/test-parakeet-stream.cpp @@ -0,0 +1,107 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 
= parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { return 1; } + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + + params.left_context_ms = 10000; + params.chunk_length_ms = 10000; + params.right_context_ms = 4960; + + parakeet_state * state = parakeet_init_state(pctx); + + // initialize streaming state + assert(parakeet_stream_init(pctx, state, params) == 0); + + const int samples_batch_size = 1600; + int position = 0; + + while (position < (int)pcmf32.size()) { + int samples_to_push = std::min(samples_batch_size, (int)pcmf32.size() - position); + + int ret = parakeet_stream_push(pctx, state, pcmf32.data() + position, samples_to_push); + assert(ret == 
0); + + position += samples_to_push; + } + + // flush remaining samples. + assert(parakeet_stream_flush(pctx, state) == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\n\nTest passed: Streaming logic.\n"); + return 0; +} diff --git a/tests/test-parakeet.cpp b/tests/test-parakeet.cpp new file mode 100644 index 00000000000..83237c600ac --- /dev/null +++ b/tests/test-parakeet.cpp @@ -0,0 +1,99 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld 
t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + // Load the sample audio file + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + params.new_segment_callback = segment_callback; + params.new_segment_callback_user_data = nullptr; + parakeet_state * state = parakeet_init_state(pctx); + + int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\nTest passed: Parakeet model loaded and freed successfully\n"); + return 0; +} From 661c9e2981b55111c8c82d5837ce0a2a004b6e0a Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 20 Apr 2026 12:40:38 +0200 Subject: [PATCH 02/33] ci : add -DGGML_NATIVE=OFF to windows job --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fb115b22abb..6fa07179a4f 100644 --- 
a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -634,6 +634,7 @@ jobs: -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF - name: Build run: | From 1e0c5dd5803a571548716b32d8e9b554766a59ee Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 20 Apr 2026 13:36:17 +0200 Subject: [PATCH 03/33] ci : add GGML_BMI2=OFF to window job --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6fa07179a4f..0d704052583 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -635,6 +635,7 @@ jobs: -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} -DGGML_NATIVE=OFF + -DGGML_BMI2=OFF - name: Build run: | From 1ed636a70d4e45b0b4f5ccc73fe91bee18b437d6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 21 Apr 2026 09:18:56 +0200 Subject: [PATCH 04/33] examples : set flash attention to false for parakeet-cli [no ci] --- examples/parakeet-cli/README.md | 2 +- examples/parakeet-cli/parakeet-cli.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md index 25a3f1e49c1..dfe58f57a03 100644 --- a/examples/parakeet-cli/README.md +++ b/examples/parakeet-cli/README.md @@ -30,7 +30,7 @@ options: -f, --file FILE [ ] input audio file -ng, --no-gpu [false ] disable GPU -dev N, --device N [0 ] GPU device to use - -fa, --flash-attn [true ] enable flash attention + -fa, --flash-attn [false ] enable flash attention -nfa, --no-flash-attn [false ] disable flash attention -ps, --print-segments [false ] print segment information ``` diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index bdb1739ce5e..35f2381967c 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -62,7 +62,7 @@ static bool parakeet_params_parse(int argc, 
char ** argv, parakeet_params & para else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = false; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } else { From ac39626016c21f92e5ce212815a844d399ebb93e Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 22 Apr 2026 14:17:41 +0200 Subject: [PATCH 05/33] parakeet : initialize ggml_tensors pointers to nullptr [no ci] --- src/parakeet.cpp | 132 +++++++++++++++++++++++------------------------ 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 43f22aca0ea..c88ff27e216 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -463,65 +463,65 @@ struct parakeet_hparams { }; struct parakeet_layer_encoder { - struct ggml_tensor * norm_ff1_w; - struct ggml_tensor * norm_ff1_b; - - struct ggml_tensor * ff1_linear1_w; - struct ggml_tensor * ff1_linear2_w; - - struct ggml_tensor * norm_conv_w; - struct ggml_tensor * norm_conv_b; - - struct ggml_tensor * conv_pw1_w; // pointwise_conv1 - struct ggml_tensor * conv_dw_w; // depthwise_conv - struct ggml_tensor * conv_bn_w; // batch_norm weight - struct ggml_tensor * conv_bn_b; // batch_norm bias - struct ggml_tensor * conv_bn_mean; // batch_norm running_mean - struct ggml_tensor * conv_bn_var; // batch_norm running_var - struct ggml_tensor * conv_bn_num_batches; // batch_norm num_batches_tracked - struct ggml_tensor * conv_pw2_w; // pointwise_conv2 - - struct ggml_tensor * norm_attn_w; - struct ggml_tensor * norm_attn_b; - - struct ggml_tensor * attn_pos_bias_u; - 
struct ggml_tensor * attn_pos_bias_v; - struct ggml_tensor * attn_q_w; - struct ggml_tensor * attn_k_w; - struct ggml_tensor * attn_v_w; - struct ggml_tensor * attn_out_w; - struct ggml_tensor * attn_pos_w; - - struct ggml_tensor * norm_ff2_w; - struct ggml_tensor * norm_ff2_b; - - struct ggml_tensor * ff2_linear1_w; - struct ggml_tensor * ff2_linear2_w; - - struct ggml_tensor * norm_out_w; - struct ggml_tensor * norm_out_b; + struct ggml_tensor * norm_ff1_w = nullptr; + struct ggml_tensor * norm_ff1_b = nullptr; + + struct ggml_tensor * ff1_linear1_w = nullptr; + struct ggml_tensor * ff1_linear2_w = nullptr; + + struct ggml_tensor * norm_conv_w = nullptr; + struct ggml_tensor * norm_conv_b = nullptr; + + struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1 + struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv + struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight + struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias + struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean + struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w = nullptr; + struct ggml_tensor * norm_attn_b = nullptr; + + struct ggml_tensor * attn_pos_bias_u = nullptr; + struct ggml_tensor * attn_pos_bias_v = nullptr; + struct ggml_tensor * attn_q_w = nullptr; + struct ggml_tensor * attn_k_w = nullptr; + struct ggml_tensor * attn_v_w = nullptr; + struct ggml_tensor * attn_out_w = nullptr; + struct ggml_tensor * attn_pos_w = nullptr; + + struct ggml_tensor * norm_ff2_w = nullptr; + struct ggml_tensor * norm_ff2_b = nullptr; + + struct ggml_tensor * ff2_linear1_w = nullptr; + struct ggml_tensor * ff2_linear2_w = nullptr; + + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; }; struct parakeet_lsmt_layer { - 
struct ggml_tensor * ih_w; // input-to-hidden weight - struct ggml_tensor * ih_b; // input-to-hidden bias - struct ggml_tensor * hh_w; // hidden-to-hidden weight - struct ggml_tensor * hh_b; // hidden-to-hidden bias + struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight + struct ggml_tensor * ih_b = nullptr; // input-to-hidden bias + struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight + struct ggml_tensor * hh_b = nullptr; // hidden-to-hidden bias }; struct parakeet_prediction_network { - struct ggml_tensor * embed_w; + struct ggml_tensor * embed_w = nullptr; std::vector lstm_layer; }; struct parakeet_joint_network { - struct ggml_tensor * pred_w; - struct ggml_tensor * pred_b; - struct ggml_tensor * enc_w; - struct ggml_tensor * enc_b; - struct ggml_tensor * net_w; - struct ggml_tensor * net_b; + struct ggml_tensor * pred_w = nullptr; + struct ggml_tensor * pred_b = nullptr; + struct ggml_tensor * enc_w = nullptr; + struct ggml_tensor * enc_b = nullptr; + struct ggml_tensor * net_w = nullptr; + struct ggml_tensor * net_b = nullptr; }; struct parakeet_model { @@ -529,20 +529,20 @@ struct parakeet_model { parakeet_hparams hparams; // Relative positional encoding (pe) - struct ggml_tensor * pe; - - struct ggml_tensor * enc_pre_out_w; - struct ggml_tensor * enc_pre_out_b; - struct ggml_tensor * enc_pre_conv_0_w; - struct ggml_tensor * enc_pre_conv_0_b; - struct ggml_tensor * enc_pre_conv_2_w; - struct ggml_tensor * enc_pre_conv_2_b; - struct ggml_tensor * enc_pre_conv_3_w; - struct ggml_tensor * enc_pre_conv_3_b; - struct ggml_tensor * enc_pre_conv_5_w; - struct ggml_tensor * enc_pre_conv_5_b; - struct ggml_tensor * enc_pre_conv_6_w; - struct ggml_tensor * enc_pre_conv_6_b; + struct ggml_tensor * pe = nullptr; + + struct ggml_tensor * enc_pre_out_w = nullptr; + struct ggml_tensor * enc_pre_out_b = nullptr; + struct ggml_tensor * enc_pre_conv_0_w = nullptr; + struct ggml_tensor * enc_pre_conv_0_b = nullptr; + struct ggml_tensor * enc_pre_conv_2_w 
= nullptr; + struct ggml_tensor * enc_pre_conv_2_b = nullptr; + struct ggml_tensor * enc_pre_conv_3_w = nullptr; + struct ggml_tensor * enc_pre_conv_3_b = nullptr; + struct ggml_tensor * enc_pre_conv_5_w = nullptr; + struct ggml_tensor * enc_pre_conv_5_b = nullptr; + struct ggml_tensor * enc_pre_conv_6_w = nullptr; + struct ggml_tensor * enc_pre_conv_6_b = nullptr; std::vector layers; @@ -556,13 +556,13 @@ struct parakeet_model { std::vector buffers; - int n_loaded; + int n_loaded = 0; std::map tensors; }; struct parakeet_lstm_state_layer { - struct ggml_tensor * h_state; - struct ggml_tensor * c_state; + struct ggml_tensor * h_state = nullptr; + struct ggml_tensor * c_state = nullptr; }; struct parakeet_lstm_state { @@ -612,7 +612,7 @@ struct parakeet_state { parakeet_batch batch; - int n_frames; + int n_frames = 0; std::vector backends; From 05ffa91e05e46cdaad68354ff90bcb803ba9aa87 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 24 Apr 2026 10:58:59 +0200 Subject: [PATCH 06/33] parakeet : remove non relavant fields in parakeet_state [no ci] --- src/parakeet.cpp | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index c88ff27e216..17e37ff9de9 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -596,15 +596,11 @@ struct parakeet_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; - int64_t t_batchd_us = 0; - int64_t t_prompt_us = 0; int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) - int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding) - int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures @@ -2230,12 
+2226,6 @@ static bool parakeet_joint( if (batch.n_tokens == 1) { pstate.t_decode_us += ggml_time_us() - t_start_us; pstate.n_decode++; - } else if (batch.n_tokens < 16) { - pstate.t_batchd_us += ggml_time_us() - t_start_us; - pstate.n_batchd += n_tokens; - } else { - pstate.t_prompt_us += ggml_time_us() - t_start_us; - pstate.n_prompt += n_tokens; } return !(abort_callback && abort_callback(abort_callback_data)); @@ -3463,8 +3453,6 @@ struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); - timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd); - timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt); return timings; } @@ -3478,16 +3466,13 @@ void parakeet_print_timings(struct parakeet_context * ctx) { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_decode = std::max(1, ctx->state->n_decode); - const int32_t n_batchd = std::max(1, ctx->state->n_batchd); - const int32_t n_prompt = std::max(1, ctx->state->n_prompt); PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( 
%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); - PARAKEET_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); - PARAKEET_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); + } PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } @@ -3499,13 +3484,10 @@ void parakeet_reset_timings(struct parakeet_context * ctx) { ctx->state->t_sample_us = 0; ctx->state->t_encode_us = 0; ctx->state->t_decode_us = 0; - ctx->state->t_batchd_us = 0; - ctx->state->t_prompt_us = 0; + ctx->state->n_sample = 0; ctx->state->n_encode = 0; ctx->state->n_decode = 0; - ctx->state->n_batchd = 0; - ctx->state->n_prompt = 0; } } @@ -4169,18 +4151,12 @@ int parakeet_full_parallel( } ctx->state->t_mel_us += states[i]->t_mel_us; - ctx->state->t_sample_us += states[i]->t_sample_us; ctx->state->t_encode_us += states[i]->t_encode_us; ctx->state->t_decode_us += states[i]->t_decode_us; - ctx->state->t_batchd_us += states[i]->t_batchd_us; - ctx->state->t_prompt_us += states[i]->t_prompt_us; - ctx->state->n_sample += states[i]->n_sample; ctx->state->n_encode += states[i]->n_encode; ctx->state->n_decode += states[i]->n_decode; - ctx->state->n_batchd += states[i]->n_batchd; - ctx->state->n_prompt += states[i]->n_prompt; parakeet_free_state(states[i]); } From b899ce0b799eed096b9c6f6bb01e908b1de45833 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 24 Apr 2026 11:05:47 +0200 Subject: [PATCH 07/33] parakeet : fix indentation in default params [no ci] --- src/parakeet.cpp | 39 +++++++++++++++------------------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 17e37ff9de9..86864aafd60 100644 --- 
a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -3533,32 +3533,23 @@ struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_ struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { struct parakeet_full_params result = { - /*.strategy =*/ strategy, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.no_context =*/ true, - - /*.audio_ctx =*/ 0, - - /*.chunk_length_ms =*/ 10000, // 10 second chunks - /*.left_context_ms =*/ 10000, // 10 second left context - /*.right_context_ms =*/ 4960, // 4.96 second right context - - /*.new_token_callback =*/ nullptr, - /*.new_token_callback_user_data =*/ nullptr, - - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, - - /*.progress_callback =*/ nullptr, - /*.progress_callback_user_data =*/ nullptr, - + /*.strategy =*/ strategy, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + /*.no_context =*/ true, + /*.audio_ctx =*/ 0, + /*.chunk_length_ms =*/ 10000, // 10 second chunks + /*.left_context_ms =*/ 10000, // 10 second left context + /*.right_context_ms =*/ 4960, // 4.96 second right context + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, /*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr, - /*.abort_callback =*/ nullptr, /*.abort_callback_user_data =*/ nullptr, }; From 9fe3b5cc3066e2170f2d0f7979b9b744615526dd Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 24 Apr 2026 11:49:57 +0200 Subject: [PATCH 08/33] parakeet : check and free ggml_backend_sched_t [no-ci] --- src/parakeet.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff 
--git a/src/parakeet.cpp b/src/parakeet.cpp index 86864aafd60..c16ca6ab84f 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -419,6 +419,11 @@ static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vecto sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); // since there are dependencies between the different graphs, @@ -426,6 +431,8 @@ static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vecto if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { // failed to allocate the compute buffer PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; return false; } From ceaa8bb90d55c90eb9a5b6e4f37a0aca981cb0d6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 29 Apr 2026 09:09:53 +0200 Subject: [PATCH 09/33] parakeet : group related types and helpers [no ci] --- src/parakeet.cpp | 277 ++++++++++++++++++++++++----------------------- 1 file changed, 140 insertions(+), 137 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index c16ca6ab84f..fc3d9f72b22 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -277,61 +277,6 @@ struct parakeet_vocab { id token_eos; }; -static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; - -static inline int utf8_codepoint_len(unsigned char c) { - if ((c & 0x80) == 0x00) return 1; - if ((c & 0xE0) == 0xC0) return 2; - if ((c & 0xF0) == 0xE0) return 3; - if ((c & 0xF8) == 0xF0) return 4; - return 1; -} - -static bool is_sentencepiece_control(const std::string & piece) { - return piece == "" || piece == "" || piece == "" || piece == "[BLANK]"; -} - -static std::string sentencepiece_normalize(const std::string & text) { - std::string normalized; - 
normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); - normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix - - for (unsigned char c : text) { - if (std::isspace(c)) { - normalized += PARAKEET_SPM_SPACE; - } else { - normalized += static_cast(c); - } - } - - return normalized; -} - -static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { - if (is_sentencepiece_control(piece)) { - return ""; - } - - std::string text; - text.reserve(piece.size()); - - size_t pos = 0; - while (pos < piece.size()) { - if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { - if (!is_first_piece || !text.empty()) { - text += ' '; - } - pos += PARAKEET_SPM_SPACE.size(); - continue; - } - - text += piece[pos]; - ++pos; - } - - return text; -} - struct parakeet_segment { int64_t t0; int64_t t1; @@ -352,50 +297,6 @@ struct parakeet_batch { int8_t * logits; }; -static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { - parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; - - batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); - batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); - batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); - batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); - batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); - for (int i = 0; i < n_tokens; ++i) { - batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); - } - batch.seq_id[n_tokens] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); - - return batch; -} - -static void parakeet_batch_free(struct parakeet_batch batch) { - if (batch.token) free(batch.token); - if (batch.pos) free(batch.pos); - if (batch.n_seq_id) free(batch.n_seq_id); - if (batch.seq_id) { - for (int i = 0; batch.seq_id[i]; ++i) { - free(batch.seq_id[i]); - } - 
free(batch.seq_id); - } - if (batch.logits) free(batch.logits); -} - -static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { - batch.n_tokens = n_tokens; - for (int i = 0; i < n_tokens; ++i) { - if (tokens) { - batch.token[i] = tokens[i]; - } - batch.pos [i] = n_past + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i][0] = seq_id; - batch.logits [i] = 0; - } - batch.logits[n_tokens - 1] = 1; -} - // ggml_backend_sched wrapper for parakeet usage struct parakeet_sched { ggml_backend_sched_t sched = nullptr; @@ -403,44 +304,6 @@ struct parakeet_sched { std::vector meta; }; -static size_t parakeet_sched_size(struct parakeet_sched & allocr) { - size_t size = allocr.meta.size(); - for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { - ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); - size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); - } - return size; -} - -// measure the memory usage of a graph and prepare the allocr's internal data buffer -static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector backends, std::function && get_graph) { - auto & sched = allocr.sched; - auto & meta = allocr.meta; - - sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); - - if (!sched) { - PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); - return false; - } - - meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); - - // since there are dependencies between the different graphs, - // we need to allocate them instead of only reserving to get the correct compute buffer size - if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { - // failed to allocate the compute buffer - PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); - ggml_backend_sched_free(sched); - sched = nullptr; - return false; - } - - 
ggml_backend_sched_reset(sched); - - return true; -} - // TODO: Find out is there a multiple version types. It is not yet clear to me // at this point. enum parakeet_arch { @@ -730,6 +593,146 @@ struct parakeet_global { static parakeet_global g_state; +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "" || piece == "" || piece == "" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = 
(int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector backends, std::function && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + + 
meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + // since there are dependencies between the different graphs, + // we need to allocate them instead of only reserving to get the correct compute buffer size + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + // failed to allocate the compute buffer + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + + template static void read_safe(parakeet_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); From 6e61988157afd0ef0f097b5e818bf1e06f7e63ff Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 30 Apr 2026 16:02:20 +0200 Subject: [PATCH 10/33] parakeet : remove parakeet_full_parallel() API and implementation [no ci] --- include/parakeet.h | 10 ---- src/parakeet.cpp | 105 ----------------------------------- tests/test-parakeet-full.cpp | 2 +- 3 files changed, 1 insertion(+), 116 deletions(-) diff --git a/include/parakeet.h b/include/parakeet.h index 34179f7df18..bd2f1434663 100644 --- a/include/parakeet.h +++ b/include/parakeet.h @@ -295,16 +295,6 @@ extern "C" { const float * samples, int n_samples); - // Split the input audio in chunks and process each chunk separately using parakeet_full_with_state() - // Result is stored in the default state of the context - // Not thread safe if executed in parallel on the same context. - PARAKEET_API int parakeet_full_parallel( - struct parakeet_context * ctx, - struct parakeet_full_params params, - const float * samples, - int n_samples, - int n_processors); - // Process a single chunk of audio data that fits within the model's audio context window. // This is more efficient than parakeet_full() for short audio clips. 
PARAKEET_API int parakeet_chunk( diff --git a/src/parakeet.cpp b/src/parakeet.cpp index fc3d9f72b22..c3a31b07cec 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -4074,111 +4074,6 @@ int parakeet_chunk( return 0; } -int parakeet_full_parallel( - struct parakeet_context * ctx, - struct parakeet_full_params params, - const float * samples, - int n_samples, - int n_processors) { - if (n_processors == 1) { - return parakeet_full(ctx, params, samples, n_samples); - } - - int ret = 0; - - // prepare separate states for each thread - std::vector states; - - const int offset_samples = (PARAKEET_SAMPLE_RATE*params.offset_ms)/1000; - const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; - - // the calling thread will process the first chunk - // while the other threads will process the remaining chunks - - std::vector workers(n_processors - 1); - for (int i = 0; i < n_processors - 1; ++i) { - // create a new state for each thread - states.push_back(parakeet_init_state(ctx)); - - const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; - const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; - - auto params_cur = params; - - params_cur.offset_ms = 0; - - params_cur.new_segment_callback = nullptr; - params_cur.new_segment_callback_user_data = nullptr; - - params_cur.progress_callback = nullptr; - params_cur.progress_callback_user_data = nullptr; - - workers[i] = std::thread(parakeet_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); - } - - { - auto params_cur = params; - - // Run the first transformation using default state but only for the first chunk. 
- ret = parakeet_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); - } - - for (int i = 0; i < n_processors - 1; ++i) { - workers[i].join(); - } - - const int64_t offset_t = (int64_t) params.offset_ms/10.0; - - // combine results into result_state->result_all from all other states - for (int i = 0; i < n_processors - 1; ++i) { - auto& results_i = states[i]->result_all; - - for (auto& result : results_i) { - // correct the segment timestamp taking into account the offset - result.t0 += 100 * ((i + 1) * n_samples_per_processor) / PARAKEET_SAMPLE_RATE + offset_t; - result.t1 += 100 * ((i + 1) * n_samples_per_processor) / PARAKEET_SAMPLE_RATE + offset_t; - - // make sure that segments are not overlapping - if (!ctx->state->result_all.empty()) { - result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); - } - - ctx->state->result_all.push_back(std::move(result)); - - // call the new_segment_callback for each segment - if (params.new_segment_callback) { - params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); - } - } - - ctx->state->t_mel_us += states[i]->t_mel_us; - ctx->state->t_sample_us += states[i]->t_sample_us; - ctx->state->t_encode_us += states[i]->t_encode_us; - ctx->state->t_decode_us += states[i]->t_decode_us; - ctx->state->n_sample += states[i]->n_sample; - ctx->state->n_encode += states[i]->n_encode; - ctx->state->n_decode += states[i]->n_decode; - - parakeet_free_state(states[i]); - } - - // average the timings - ctx->state->t_mel_us /= n_processors; - ctx->state->t_sample_us /= n_processors; - ctx->state->t_encode_us /= n_processors; - ctx->state->t_decode_us /= n_processors; - - // print information about the audio boundaries - PARAKEET_LOG_WARN("\n"); - PARAKEET_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); - for (int i = 0; i < n_processors - 1; ++i) { - PARAKEET_LOG_WARN("%s: split %d - %s\n", 
__func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/PARAKEET_SAMPLE_RATE + offset_t).c_str()); - } - PARAKEET_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__); - - return ret; -} - int parakeet_full_n_segments_from_state(struct parakeet_state * state) { return state->result_all.size(); } diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp index 11d52488f51..3de9e86b308 100644 --- a/tests/test-parakeet-full.cpp +++ b/tests/test-parakeet-full.cpp @@ -57,6 +57,6 @@ int main() { parakeet_free(pctx); - printf("\nTest passed: parakeet_full_parallel succeeded!\n"); + printf("\nTest passed: parakeet_full succeeded!\n"); return 0; } From 66e6b093e1576464e648d23f44dced454437f501 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 30 Apr 2026 16:12:27 +0200 Subject: [PATCH 11/33] parakeet : remove unused timeing fields [no ci] --- include/parakeet.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/include/parakeet.h b/include/parakeet.h index bd2f1434663..3db69c71d68 100644 --- a/include/parakeet.h +++ b/include/parakeet.h @@ -196,8 +196,6 @@ extern "C" { float sample_ms; float encode_ms; float decode_ms; - float batchd_ms; - float prompt_ms; }; PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); @@ -245,7 +243,6 @@ extern "C" { bool no_context; // do not use past transcription (if any) as context - // [EXPERIMENTAL] speed-up techniques int audio_ctx; // overwrite the audio context size (0 = use default) int chunk_length_ms; // length of each chunk in ms From 2ad47d1127dc268fb7af7389f20afd3051758dc8 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 May 2026 08:30:40 +0200 Subject: [PATCH 12/33] parakeet : generate rel pos tensor in graph instead of in conversion [no ci] This commit removes the generation of the relative positional tensor in the model conversion 
script and instead computes it in the encoder graph. This is only done for the window of positions required for the current audio sample. This was suggested in the mtmd integration of parakeet and the same approach is used there. --- models/convert-parakeet-to-ggml.py | 24 ----------- src/parakeet-arch.h | 3 -- src/parakeet.cpp | 69 ++++++++++++++++++++---------- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py index 1cceba3ac92..9df6dc3e679 100755 --- a/models/convert-parakeet-to-ggml.py +++ b/models/convert-parakeet-to-ggml.py @@ -32,23 +32,6 @@ def hz_to_mel(freq): def mel_to_hz(mel): return 700.0 * (10.0**(mel / 2595.0) - 1.0) -def create_relative_positional_encoding(d_model: int, n_pos_max_len: int) -> np.ndarray: - max_len = n_pos_max_len * 2 - 1 - log_10000 = np.log(10000.0) - pe = np.zeros((max_len, d_model), dtype=np.float32) - - for idx in range(max_len): - position = float((max_len // 2) - idx) - - for i in range(0, d_model, 2): - div_term = np.exp(-float(i) * log_10000 / float(d_model)) - angle = position * div_term - - pe[idx, i] = np.sin(angle) - pe[idx, i + 1] = np.cos(angle) - - return pe - def extract_nemo_archive(nemo_path, extract_dir): print(f"Extracting {nemo_path} to {extract_dir}") with tarfile.open(nemo_path, 'r') as tar: @@ -165,7 +148,6 @@ def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None) 'n_fft': config['preprocessor']['n_fft'], 'subsampling_factor': config['encoder']['subsampling_factor'], 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], - 'n_pos_max_len': config['encoder']['pos_emb_max_len'], 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], @@ -178,9 +160,6 @@ def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None) for key, value in hparams.items(): print(f" {key}: {value}") - pe = 
create_relative_positional_encoding(hparams['n_audio_state'], hparams['n_pos_max_len']) - print(f"\nGenerated positional encoding tensor 'encoder.pe' with shape {pe.shape}") - # Create output file if out_name: fname_out = output_dir / out_name @@ -203,7 +182,6 @@ def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None) fout.write(struct.pack("i", hparams['n_fft'])) fout.write(struct.pack("i", hparams['subsampling_factor'])) fout.write(struct.pack("i", hparams['n_subsampling_channels'])) - fout.write(struct.pack("i", hparams['n_pos_max_len'])) fout.write(struct.pack("i", hparams['n_pred_dim'])) fout.write(struct.pack("i", hparams['n_pred_layers'])) fout.write(struct.pack("i", hparams['n_tdt_durations'])) @@ -302,8 +280,6 @@ def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None) write_tensor(fout, name, data, use_f16=use_f16) - write_tensor(fout, "encoder.pe", pe, use_f16=use_f16, force_f32=True) - print(f"\nConversion complete!") print(f"Output file: {fname_out}") print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h index 03eb1303edf..2f3668a84f7 100644 --- a/src/parakeet-arch.h +++ b/src/parakeet-arch.h @@ -18,7 +18,6 @@ enum parakeet_tensor { PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, - PARAKEET_TENSOR_ENC_PE, // Encoder layers (per-layer) PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, @@ -81,7 +80,6 @@ static const std::map PARAKEET_TENSOR_NAMES = { {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, - {PARAKEET_TENSOR_ENC_PE, "encoder.pe"}, // Encoder layers (use %d for layer number) {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, @@ -144,7 +142,6 @@ static const std::map 
PARAKEET_TENSOR_INFO = { {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, - {PARAKEET_TENSOR_ENC_PE, GGML_OP_ADD}, // Encoder layers {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, diff --git a/src/parakeet.cpp b/src/parakeet.cpp index c3a31b07cec..8c042576898 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -323,7 +323,6 @@ struct parakeet_hparams { float eps = 1e-5f; int32_t subsampling_factor = 8; int32_t n_subsampling_channels = 256; - int32_t n_pos_max_len = 5000; int32_t n_pred_dim = 640; int32_t n_pred_layers = 2; int32_t n_tdt_durations = 5; @@ -398,9 +397,6 @@ struct parakeet_model { parakeet_filters filters; parakeet_hparams hparams; - // Relative positional encoding (pe) - struct ggml_tensor * pe = nullptr; - struct ggml_tensor * enc_pre_out_w = nullptr; struct ggml_tensor * enc_pre_out_b = nullptr; struct ggml_tensor * enc_pre_conv_0_w = nullptr; @@ -1033,7 +1029,6 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ read_safe(loader, hparams.n_fft); read_safe(loader, hparams.subsampling_factor); read_safe(loader, hparams.n_subsampling_channels); - read_safe(loader, hparams.n_pos_max_len); read_safe(loader, hparams.n_pred_dim); read_safe(loader, hparams.n_pred_layers); read_safe(loader, hparams.n_tdt_durations); @@ -1067,7 +1062,6 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); - PARAKEET_LOG_INFO("%s: n_pos_max_len = %d\n", __func__, hparams.n_pos_max_len); PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); 
PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); @@ -1375,8 +1369,6 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ // Relative positional encoding { - model.pe = create_tensor(PARAKEET_TENSOR_ENC_PE, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, hparams.n_pos_max_len * 2 - 1)); - ggml_set_name(model.pe, "pe"); } ggml_free(ctx); @@ -1618,6 +1610,27 @@ static struct ggml_cgraph * parakeet_build_graph_encoder(parakeet_context & pctx ggml_set_name(attn_mask, "attn_mask"); ggml_set_input(attn_mask); + const int n_time = cur->ne[1]; + const int window_size = 2 * n_time - 1; + const int d_half = n_state / 2; + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + // [n_state, window_size] + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + for (int il = 0; il < n_layer; ++il) { // FFN1 { @@ -1641,7 +1654,7 @@ static struct ggml_cgraph * parakeet_build_graph_encoder(parakeet_context & pctx ggml_format_name(cur, "enc_%d_res_ffn", il); } - // self attention block using relative positional encoding from the model.pe tensor. + // self attention block using relative positional encoding computed in graph. 
{ // [feat, time_frames, 1, 1] struct ggml_tensor * residual = cur; @@ -1664,18 +1677,7 @@ static struct ggml_cgraph * parakeet_build_graph_encoder(parakeet_context & pctx K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time); V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time); - const int input_len = cur->ne[1]; - const int center_pos = model.pe->ne[1] / 2 + 1; - const int start_pos = center_pos - input_len; - const int window_size = 2 * input_len - 1; - const size_t offset = start_pos * model.pe->nb[1]; - - // [feat, window_size] - struct ggml_tensor * pos_emb = ggml_view_2d(ctx0, model.pe, - n_state, window_size, - model.pe->nb[1], offset); - ggml_format_name(pos_emb, "enc_%d_attn_pos_emb", il); - + // [n_state, window_size] struct ggml_tensor * pos = ggml_mul_mat(ctx0, model.layers[il].attn_pos_w, pos_emb); ggml_format_name(pos, "enc_%d_attn_pos", il); @@ -1704,7 +1706,7 @@ static struct ggml_cgraph * parakeet_build_graph_encoder(parakeet_context & pctx // [feat, window_size, 1, 1] and we are doing multi-head attention so // we need to split this into heads. 
// [feat, head, window_size, 1] - pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, pos_emb->ne[1]); + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size); // [feat, window_size, head, 1] pos = ggml_permute(ctx0, pos, 0, 2, 1, 3); @@ -1941,6 +1943,29 @@ static bool parakeet_encode_internal( ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); } + { + struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs"); + const int d_half = pos_freqs_t->ne[0]; + const int n_state = pctx.model.hparams.n_audio_state; + const float log_10000 = logf(10000.0f); + std::vector freqs(d_half); + for (int k = 0; k < d_half; ++k) { + freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state))); + } + ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float)); + } + + { + struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions"); + const int window_size = rel_pos_t->ne[1]; + const int n_time = (window_size + 1) / 2; + std::vector pos(window_size); + for (int t = 0; t < window_size; ++t) { + pos[t] = float(n_time - 1 - t); + } + ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float)); + } + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { return false; } From 42dcf1923fc41b2050bb0617b22ce827930ab274 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 May 2026 08:49:20 +0200 Subject: [PATCH 13/33] examples : print system info and timings for parakeet-cli [no ci] --- examples/parakeet-cli/parakeet-cli.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index 35f2381967c..5f92144a4b8 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -154,6 +154,8 @@ int main(int argc, char ** argv) { } fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + 
params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n", pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE); @@ -180,6 +182,8 @@ int main(int argc, char ** argv) { printf("\n"); + parakeet_print_timings(pctx); + if (params.print_segments) { const int n_segments = parakeet_full_n_segments(pctx); fprintf(stderr, "\nSegments (%d):\n", n_segments); From 2f9216f1b8278f99b6a6868a2f9b5b9cfcccdc93 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 May 2026 12:25:16 +0200 Subject: [PATCH 14/33] parakeet : add missing free of batch.i_time [no ci] --- src/parakeet.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 8c042576898..9a7dfddf475 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -664,6 +664,7 @@ static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { static void parakeet_batch_free(struct parakeet_batch batch) { if (batch.token) free(batch.token); + if (batch.i_time) free(batch.i_time); if (batch.pos) free(batch.pos); if (batch.n_seq_id) free(batch.n_seq_id); if (batch.seq_id) { From 27cc5d7fb908c2a67cc9836166579285ecc197d6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 May 2026 13:41:52 +0200 Subject: [PATCH 15/33] examples : add --output-txt and --output-file to parakeet-cli [no ci] --- examples/parakeet-cli/parakeet-cli.cpp | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index 5f92144a4b8..b69d25a9226 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -6,6 +6,7 @@ #include #include #include +#include // command-line parameters struct parakeet_params { @@ -19,8 +20,10 @@ struct parakeet_params { int32_t gpu_device = 0; bool print_segments = false; + bool output_txt = false; - std::string model = 
"models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string output_file = ""; std::vector fname_inp = {}; }; @@ -65,6 +68,8 @@ static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & para else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = false; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); parakeet_print_usage(argc, argv, params); @@ -93,6 +98,8 @@ static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_para fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", !params.flash_attn ? "true" : "false"); fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); fprintf(stderr, "\n"); } @@ -182,6 +189,23 @@ int main(int argc, char ** argv) { printf("\n"); + if (params.output_txt) { + const std::string fname_out = (!params.output_file.empty() ? 
params.output_file : fname) + ".txt"; + + std::ofstream fout(fname_out); + if (fout.is_open()) { + const int n_segments = parakeet_full_n_segments(pctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = parakeet_full_get_segment_text(pctx, i); + fout << text << "\n"; + } + fout.close(); + fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + } else { + fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str()); + } + } + parakeet_print_timings(pctx); if (params.print_segments) { From 93806f4dec5d4a2b2a65e19601349a9d10619c40 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 May 2026 14:19:45 +0200 Subject: [PATCH 16/33] examples : add --no-prints to parakeet-cli This is to enable librispeech testing which will be enabled in a follow up commit. --- examples/parakeet-cli/parakeet-cli.cpp | 36 +++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index b69d25a9226..c8a77a42fbc 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -21,6 +21,7 @@ struct parakeet_params { bool print_segments = false; bool output_txt = false; + bool no_prints = false; std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; std::string output_file = ""; @@ -70,6 +71,7 @@ static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & para else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); parakeet_print_usage(argc, argv, params); @@ -100,6 +102,7 @@ static void parakeet_print_usage(int /*argc*/, char ** argv, const 
parakeet_para fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); fprintf(stderr, "\n"); } @@ -115,6 +118,7 @@ void token_callback(parakeet_context * ctx, parakeet_state * state, const parake is_first = false; } +static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } int main(int argc, char ** argv) { ggml_backend_load_all(); @@ -125,6 +129,10 @@ int main(int argc, char ** argv) { return 1; } + if (params.no_prints) { + parakeet_log_set(cb_log_disable, NULL); + } + if (params.fname_inp.empty()) { fprintf(stderr, "error: no input files specified\n"); parakeet_print_usage(argc, argv, params); @@ -133,7 +141,9 @@ int main(int argc, char ** argv) { // Process each input file for (const auto & fname : params.fname_inp) { - fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + if (!params.no_prints) { + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + } std::vector pcmf32; std::vector> pcmf32s; @@ -147,7 +157,9 @@ int main(int argc, char ** argv) { continue; } - fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + if (!params.no_prints) { + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + } struct parakeet_context_params ctx_params = parakeet_context_default_params(); ctx_params.use_gpu = params.use_gpu; @@ -160,11 +172,13 @@ int main(int argc, char ** argv) { return 1; } - fprintf(stderr, "Successfully loaded Parakeet model\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, (int32_t) std::thread::hardware_concurrency(), 
parakeet_print_system_info()); - fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n", - pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE); + if (!params.no_prints) { + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n", + pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE); + } struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); full_params.n_threads = params.n_threads; @@ -200,13 +214,17 @@ int main(int argc, char ** argv) { fout << text << "\n"; } fout.close(); - fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + if (!params.no_prints) { + fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + } } else { fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str()); } } - parakeet_print_timings(pctx); + if (!params.no_prints) { + parakeet_print_timings(pctx); + } if (params.print_segments) { const int n_segments = parakeet_full_n_segments(pctx); From 3f6b17c48db3f5eb7189e1eae13a61a8e4bea7bd Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 May 2026 18:22:33 +0200 Subject: [PATCH 17/33] tests : add script to benchmark parakeet.cpp on LibriSpeech [no ci] The result from running the tests was: ```console $ cat parakeet-tdt-0.6b-v3.txt WER: 1.96% ``` --- tests/librispeech-parakeet/.gitignore | 6 + tests/librispeech-parakeet/Makefile | 15 + tests/librispeech-parakeet/README.md | 57 + tests/librispeech-parakeet/eval.mk | 39 + tests/librispeech-parakeet/eval.py | 47 + .../librispeech-parakeet/normalizers/LICENSE | 25 + .../normalizers/__init__.py | 2 + .../librispeech-parakeet/normalizers/basic.py | 80 + .../normalizers/english.json | 1741 +++++++++++++++++ .../normalizers/english.py | 550 ++++++ 10 files changed, 
2562 insertions(+) create mode 100644 tests/librispeech-parakeet/.gitignore create mode 100644 tests/librispeech-parakeet/Makefile create mode 100644 tests/librispeech-parakeet/README.md create mode 100644 tests/librispeech-parakeet/eval.mk create mode 100644 tests/librispeech-parakeet/eval.py create mode 100644 tests/librispeech-parakeet/normalizers/LICENSE create mode 100644 tests/librispeech-parakeet/normalizers/__init__.py create mode 100644 tests/librispeech-parakeet/normalizers/basic.py create mode 100644 tests/librispeech-parakeet/normalizers/english.json create mode 100644 tests/librispeech-parakeet/normalizers/english.py diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore new file mode 100644 index 00000000000..838bfeae9db --- /dev/null +++ b/tests/librispeech-parakeet/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.tar.gz +*.txt +eval.conf +venv +LibriSpeech diff --git a/tests/librispeech-parakeet/Makefile b/tests/librispeech-parakeet/Makefile new file mode 100644 index 00000000000..0afa2465f49 --- /dev/null +++ b/tests/librispeech-parakeet/Makefile @@ -0,0 +1,15 @@ +TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz + +all: eval + +eval: + $(MAKE) -f eval.mk + +clean: + $(MAKE) -f eval.mk clean + +get-audio: + wget -c $(TAR_URL) + tar -xf test-clean.tar.gz + +.PHONY: all eval clean setup-venv clean-venv get-audio diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md new file mode 100644 index 00000000000..e09cba405ef --- /dev/null +++ b/tests/librispeech-parakeet/README.md @@ -0,0 +1,57 @@ +# parakeet.cpp/tests/librispeech + +[LibriSpeech](https://www.openslr.org/12) is a standard dataset for +training and evaluating automatic speech recognition systems. + +This directory contains a set of tools to evaluate the recognition +performance of parakeet.cpp on LibriSpeech corpus. + +## Quick Start + +1. 
(Pre-requirement) Compile `parakeet-cli` and prepare the Parakeet + model in `ggml` format. + + ``` + $ # Execute the commands below in the project root dir. + $ cmake -B build + $ cmake --build build --config Release + ``` + +2. Download the audio files from LibriSpeech project. + + ``` + $ make get-audio + ``` + +3. Set up the environment to compute WER score. + + ``` + $ pip install -r requirements.txt + ``` + + For example, if you use `virtualenv`, you can set up it as follows: + + ``` + $ python3 -m venv venv + $ . venv/bin/activate + $ pip install -r requirements.txt + ``` + +4. Run the benchmark test. + + ``` + $ make + ``` + +## How-to guides + +### How to change the inference parameters + +Create `eval.conf` and override variables. + +``` +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 +PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt +``` + +Check out `eval.mk` for more details. diff --git a/tests/librispeech-parakeet/eval.mk b/tests/librispeech-parakeet/eval.mk new file mode 100644 index 00000000000..fa6d2e10264 --- /dev/null +++ b/tests/librispeech-parakeet/eval.mk @@ -0,0 +1,39 @@ +PYTHON = python + +PARAKEET_PREFIX = ../../ +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 + +PARAKEET_CLI = $(PARAKEET_PREFIX)build-cuda-89-release/bin/parakeet-cli +PARAKEET_FLAGS = --no-prints --output-txt + +# You can create eval.conf to override the PARAKEET_* variables +# defined above. +-include eval.conf + +# This follows the file structure of the LibriSpeech project. +AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac)) +TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS)) + +# We output the evaluation result to this file. +DONE = $(PARAKEET_MODEL).txt + +all: $(DONE) + +$(DONE): $(TRANS_TXTS) + $(PYTHON) eval.py > $@.tmp + mv $@.tmp $@ + +# Note: This task writes to a temporary file first to +# create the target file atomically. 
+%.flac.txt: %.flac + $(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp + mv $^.tmp.txt $^.txt + +archive: + tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE) + +clean: + @rm -f $(TRANS_TXTS) + @rm -f $(DONE) + +.PHONY: all clean diff --git a/tests/librispeech-parakeet/eval.py b/tests/librispeech-parakeet/eval.py new file mode 100644 index 00000000000..cdaf8352fd4 --- /dev/null +++ b/tests/librispeech-parakeet/eval.py @@ -0,0 +1,47 @@ +import os +import glob +import jiwer +from normalizers import EnglishTextNormalizer + +def get_reference(): + ref = {} + for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'): + with open(path) as fp: + for line in fp: + code, text = line.strip().split(" ", maxsplit=1) + ref [code] = text + return ref + +def get_hypothesis(): + hyp = {} + for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'): + with open(path) as fp: + text = fp.read().strip() + code = os.path.basename(path).replace('.flac.txt', '') + hyp[code] = text + return hyp + +def get_codes(): + codes = [] + for path in glob.glob('LibriSpeech/*/*/*/*.flac'): + codes.append(os.path.basename(path).replace('.flac', '')) + return sorted(codes) + +def main(): + normalizer = EnglishTextNormalizer() + + ref_orig = get_reference() + hyp_orig = get_hypothesis() + + ref_clean = [] + hyp_clean = [] + + for code in get_codes(): + ref_clean.append(normalizer(ref_orig[code])) + hyp_clean.append(normalizer(hyp_orig[code])) + + wer = jiwer.wer(ref_clean, hyp_clean) + print(f"WER: {wer * 100:.2f}%") + +if __name__ == '__main__': + main() diff --git a/tests/librispeech-parakeet/normalizers/LICENSE b/tests/librispeech-parakeet/normalizers/LICENSE new file mode 100644 index 00000000000..7c8e603b0fc --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/LICENSE @@ -0,0 +1,25 @@ +Code in this directory is adapted from OpenAI Whisper project +(https://github.com/openai/whisper) and carries the 
following +copyright and license. + + MIT License + + Copyright (c) 2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
diff --git a/tests/librispeech-parakeet/normalizers/__init__.py b/tests/librispeech-parakeet/normalizers/__init__.py new file mode 100644 index 00000000000..896d5e33641 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/__init__.py @@ -0,0 +1,2 @@ +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/tests/librispeech-parakeet/normalizers/basic.py b/tests/librispeech-parakeet/normalizers/basic.py new file mode 100644 index 00000000000..8690ae71c5f --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/basic.py @@ -0,0 +1,80 @@ +import re +import unicodedata + +import regex + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, + and drop any diacritics (category 'Mn' and some manual mappings) + """ + return "".join( + ( + c + if c in keep + else ( + ADDITIONAL_DIACRITICS[c] + if c in ADDITIONAL_DIACRITICS + else ( + "" + if unicodedata.category(c) == "Mn" + else " " if unicodedata.category(c)[0] in "MSP" else c + ) + ) + ) + for c in unicodedata.normalize("NFKD", s) + ) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join( + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) + ) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) + self.split_letters = split_letters + + def __call__(self, s: str): + s = 
s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/tests/librispeech-parakeet/normalizers/english.json b/tests/librispeech-parakeet/normalizers/english.json new file mode 100644 index 00000000000..74a1c3521d9 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": 
"anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + 
"bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + 
"canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + 
"circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + 
"cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": 
"demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + 
"dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": 
"epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + 
"feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + 
"gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + 
"haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": 
"immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": 
"ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": 
"liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": 
"metabolizing", + "metre": "meter", + "metres": "meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + 
"mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", 
+ "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + 
"pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": 
"pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummeled", + "pummelling": "pummeling", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": 
"randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + 
"ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + 
"serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplows", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": 
"standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + 
"syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + 
"tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": 
"videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/tests/librispeech-parakeet/normalizers/english.py b/tests/librispeech-parakeet/normalizers/english.py new file mode 100644 index 00000000000..4932042bc5b --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: 
+ """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 
1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + 
continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + 
str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) 
if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." 
+ else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace " and a half" with " point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. 
+ + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. 
+ r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s From d8b74e463f4037fbb72923bbdd7a9f3df95c7f2b Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 5 May 2026 08:48:29 +0200 Subject: [PATCH 18/33] squash! tests : add script to benchmark parakeet.cpp on LibriSpeech [no ci] Remove hardcoded build-cuda-89-release and just use build like whisper.cpp does. 
--- tests/librispeech-parakeet/eval.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/librispeech-parakeet/eval.mk b/tests/librispeech-parakeet/eval.mk index fa6d2e10264..7d8992ec471 100644 --- a/tests/librispeech-parakeet/eval.mk +++ b/tests/librispeech-parakeet/eval.mk @@ -3,7 +3,7 @@ PYTHON = python PARAKEET_PREFIX = ../../ PARAKEET_MODEL = parakeet-tdt-0.6b-v3 -PARAKEET_CLI = $(PARAKEET_PREFIX)build-cuda-89-release/bin/parakeet-cli +PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli PARAKEET_FLAGS = --no-prints --output-txt # You can create eval.conf to override the PARAKEET_* variables From 3fa1f6935f8ca51c125b734bf339409f68513ea5 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 5 May 2026 14:27:33 +0200 Subject: [PATCH 19/33] parakeet : enable model conversion on win [no ci] This commit updates the parakeet requirements that are out of date as I've been using a virtual environment on linux/mac that contains torch and numpy. This also fixes the reading of the model configuration which was failing on Windows.
--- models/convert-parakeet-to-ggml.py | 2 +- models/requirements-parakeet.txt | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py index 9df6dc3e679..cbeec7e6c70 100755 --- a/models/convert-parakeet-to-ggml.py +++ b/models/convert-parakeet-to-ggml.py @@ -39,7 +39,7 @@ def extract_nemo_archive(nemo_path, extract_dir): print("Extraction complete") def load_model_config(config_path): - with open(config_path, 'r') as f: + with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) return config diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt index c3726e8bfee..5239ae0af5d 100644 --- a/models/requirements-parakeet.txt +++ b/models/requirements-parakeet.txt @@ -1 +1,3 @@ +torch +numpy pyyaml From 0c4f4baec869f483339683c1b15321821cacdb51 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 7 May 2026 15:01:12 +0200 Subject: [PATCH 20/33] parakeet : add parakeet_reset_state function [no ci] This commit adds a function to reset the parakeet state that can be reused instead of duplicating code.
It also resets the lstm state which was not done by parakeet_full leading to incorrect transcriptions when called multiple times --- src/parakeet.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 9a7dfddf475..d74a9f38580 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -3593,17 +3593,22 @@ struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_ return result; } +static void parakeet_reset_state(struct parakeet_state * state) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } +} + static void parakeet_stream_reset_state(struct parakeet_state * state) { if (state == nullptr) { return; } - if (state->lstm_state.buffer) { - ggml_backend_buffer_clear(state->lstm_state.buffer, 0); - } + parakeet_reset_state(state); - state->decoded_tokens.clear(); - state->decoded_token_data.clear(); state->result_all.clear(); state->tdt_stream_state.initialized = false; @@ -3858,8 +3863,7 @@ int parakeet_full_with_state( // Clear any previous decoded tokens if no_context is set if (params.no_context) { - state->decoded_tokens.clear(); - state->decoded_token_data.clear(); + parakeet_reset_state(state); } // If no chunking specified, delegate to parakeet_chunk (processes entire audio as one chunk) @@ -4022,9 +4026,7 @@ int parakeet_chunk( int n_samples) { if (params.no_context) { - ggml_backend_buffer_clear(state->lstm_state.buffer, 0); - state->decoded_tokens.clear(); - state->decoded_token_data.clear(); + parakeet_reset_state(state); state->tdt_stream_state.initialized = false; state->tdt_stream_state.last_token = 0; From cb611a44f577542190c7dc028fd22d5e50087f3a Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 7 May 2026 15:07:00 +0200 Subject: [PATCH 21/33] examples : reuse context in parakeet-cli [no ci] --- 
examples/parakeet-cli/parakeet-cli.cpp | 59 +++++++++++++------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index c8a77a42fbc..df2de45b84e 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -107,15 +107,15 @@ static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_para } void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { - static bool is_first = true; + bool * is_first = (bool *) user_data; const char * token_str = parakeet_token_to_str(ctx, token_data->id); char text_buf[256]; - parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf)); printf("%s", text_buf); fflush(stdout); - is_first = false; + *is_first = false; } static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } @@ -139,6 +139,28 @@ int main(int argc, char ** argv) { return 1; } + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + ctx_params.use_gpu = params.use_gpu; + ctx_params.flash_attn = params.flash_attn; + ctx_params.gpu_device = params.gpu_device; + + if (!params.no_prints) { + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + } + + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); + return 1; + } + + if (!params.no_prints) { + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + } + // Process each input file for (const auto & fname : 
params.fname_inp) { if (!params.no_prints) { @@ -157,36 +179,14 @@ int main(int argc, char ** argv) { continue; } - if (!params.no_prints) { - fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); - } - - struct parakeet_context_params ctx_params = parakeet_context_default_params(); - ctx_params.use_gpu = params.use_gpu; - ctx_params.flash_attn = params.flash_attn; - ctx_params.gpu_device = params.gpu_device; - - struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); - if (pctx == nullptr) { - fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); - return 1; - } - - if (!params.no_prints) { - fprintf(stderr, "Successfully loaded Parakeet model\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); - fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n", - pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE); - } - + bool is_first = true; struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); full_params.n_threads = params.n_threads; full_params.chunk_length_ms = params.chunk_length_ms; full_params.left_context_ms = params.left_context_ms; full_params.right_context_ms = params.right_context_ms; full_params.new_token_callback = token_callback; - full_params.new_token_callback_user_data = nullptr; + full_params.new_token_callback_user_data = &is_first; const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); if (mel_frames <= parakeet_n_audio_ctx(pctx)) { @@ -197,7 +197,6 @@ int main(int argc, char ** argv) { if (ret != 0) { fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); - parakeet_free(pctx); continue; } @@ -258,9 +257,9 @@ int main(int argc, char ** argv) { } } } - - parakeet_free(pctx); } + parakeet_free(pctx); + return 0; } From 
646976b8d67929789586191516d74ec57da9b636 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 08:18:34 +0200 Subject: [PATCH 22/33] parakeet : remove unused functions [no ci] --- src/parakeet.cpp | 58 ------------------------------------------------ 1 file changed, 58 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index d74a9f38580..a1617320d4b 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -202,49 +202,6 @@ static bool ggml_graph_compute_helper( // TODO: move these functions to ggml-base with support for ggml-backend? -static ggml_tensor * parakeet_set_f32(struct ggml_tensor * t, float v) { - GGML_ASSERT(t->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(t)); - size_t nels = ggml_nelements(t); - for (size_t i = 0; i < nels; ++i) { - ((float *) t->data)[i] = v; - } - return t; -} - -static ggml_tensor * parakeet_set_i32(struct ggml_tensor * t, int32_t v) { - GGML_ASSERT(t->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_is_contiguous(t)); - size_t nels = ggml_nelements(t); - for (size_t i = 0; i < nels; ++i) { - ((int32_t *) t->data)[i] = v; - } - return t; -} - -static float parakeet_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { - GGML_ASSERT(t->type == GGML_TYPE_F32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - return *(float *) data; -} - -static void parakeet_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) { - GGML_ASSERT(t->type == GGML_TYPE_F32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - *(float *) data = v; -} - -static int32_t parakeet_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { - GGML_ASSERT(t->type == GGML_TYPE_I32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - return *(int32_t *) data; -} - -static void parakeet_set_i32_nd(struct 
ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) { - GGML_ASSERT(t->type == GGML_TYPE_I32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - *(int32_t *) data = v; -} struct parakeet_mel { int n_len = 0; @@ -2677,21 +2634,6 @@ static bool parakeet_decode_chunk( // 500 -> 00:05.000 // 6000 -> 01:00.000 -static std::string to_timestamp(int64_t t, bool comma = false) { - int64_t msec = t * 10; - int64_t hr = msec / (1000 * 60 * 60); - msec = msec - hr * (1000 * 60 * 60); - int64_t min = msec / (1000 * 60); - msec = msec - min * (1000 * 60); - int64_t sec = msec / 1000; - msec = msec - sec * 1000; - - char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); - - return std::string(buf); -} - // naive Discrete Fourier Transform // input is real-valued // output is complex-valued From 5cc5a559dfc3b418cae0096f264f56b278feb04c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 10:02:40 +0200 Subject: [PATCH 23/33] parakeet : remove unused include [no ci] --- src/parakeet.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index a1617320d4b..cbb815429d0 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include From f9dc4883feff0e7f8cdd2e9fda6464c66163508c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 13:25:53 +0200 Subject: [PATCH 24/33] parakeet : zero initialize samples_padded [no ci] --- src/parakeet.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index cbb815429d0..8286a1f8a87 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -2811,10 +2811,8 @@ static bool log_mel_spectrogram( } // Parakeet Pytorch implementation uses centered contant padding. 
- const int pad = frame_size / 2; - std::vector samples_padded(n_samples + 2 * pad); // 176512 - std::fill(samples_padded.begin(), samples_padded.begin() + pad, 0.0f); - std::fill(samples_padded.begin() + pad + n_samples, samples_padded.end(), 0.0f); + const size_t pad = (size_t)(frame_size / 2); + std::vector samples_padded(n_samples + 2 * pad, 0.0f); std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); mel.n_mel = n_mel; From 64160ed28709ffd69274bad85b73480a11149f23 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 13:46:14 +0200 Subject: [PATCH 25/33] parakeet : add mel_worker_params struct [no ci] --- src/parakeet.cpp | 66 +++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 8286a1f8a87..25552f87dac 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -2702,49 +2702,53 @@ static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) } } +struct mel_worker_params { + int ith; + int window_size; + int n_samples; + int frame_size; + int frame_step; + int n_threads; +}; + static void log_mel_spectrogram_worker_thread( - int ith, - const float * window_func, - int window_size, - const std::vector & samples, - int n_samples, - int frame_size, - int frame_step, - int n_threads, - const parakeet_filters & filters, - parakeet_mel & mel, - const parakeet_mel_cache & cache) { - std::vector fft_in(frame_size * 2, 0.0); - std::vector fft_out(frame_size * 2 * 2 * 2); + mel_worker_params params, + const float * window_func, + const std::vector & samples, + const parakeet_filters & filters, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + std::vector fft_in(params.frame_size * 2, 0.0); + std::vector fft_out(params.frame_size * 2 * 2 * 2); int n_fb = filters.n_fb; // number of frequency bins - int i = ith; + int i = params.ith; // make sure n_fb == 1 + (frame_size / 2), bin_0 to 
bin_nyquist - assert(n_fb == 1 + (frame_size / 2)); + assert(n_fb == 1 + (params.frame_size / 2)); const double eps = 5.960464477539063e-08; // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { - const int offset = i * frame_step; + for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) { + const int offset = i * params.frame_step; - const int window_pad_left = (frame_size - window_size) / 2; + const int window_pad_left = (params.frame_size - params.window_size) / 2; // Zero-pad left std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); // Apply windowed samples in the center - const int n_to_process = std::min({window_size, n_samples - offset}); + const int n_to_process = std::min({params.window_size, params.n_samples - offset}); for (int j = 0; j < n_to_process; j++) { fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; } // Zero-pad right (and any samples we didn't have) - std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + frame_size, 0.0f); + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f); // FFT - fft(fft_in.data(), frame_size, fft_out.data(), cache); + fft(fft_in.data(), params.frame_size, fft_out.data(), cache); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. 
@@ -2775,7 +2779,7 @@ static void log_mel_spectrogram_worker_thread( // Otherwise fft_out are all zero - use log(eps) for consistency const double empty_sum = std::log(eps); - for (; i < mel.n_len; i += n_threads) { + for (; i < mel.n_len; i += params.n_threads) { for (int j = 0; j < mel.n_mel; j++) { mel.data[i * mel.n_mel + j] = empty_sum; } @@ -2823,30 +2827,24 @@ static bool log_mel_spectrogram( // Worker Threads (STFT + Mel + Natural Log) { std::vector workers(n_threads - 1); + const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads }; + for (int iw = 0; iw < n_threads - 1; ++iw) { + mel_worker_params params = mel_params; + params.ith = iw + 1; workers[iw] = std::thread(log_mel_spectrogram_worker_thread, - iw + 1, // thread index + params, window_func, - window_size, std::cref(samples_padded), - samples_padded.size(), - frame_size, - frame_step, - n_threads, std::cref(filters), std::ref(mel), std::cref(cache)); } log_mel_spectrogram_worker_thread( - 0, + mel_params, window_func, - window_size, samples_padded, - samples_padded.size(), - frame_size, - frame_step, - n_threads, filters, mel, cache); From 89d59f342f729a21988055602aee5979641fd6c0 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 13:59:46 +0200 Subject: [PATCH 26/33] parakeet : add missing eps to model output and align [no ci] --- src/parakeet.cpp | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 25552f87dac..9c6384511e5 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -1007,22 +1007,23 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ } const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? 
"Parakeet TDT" : "unknown"; - PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); - PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); - PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); - PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); - PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); - PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); - PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); - PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); - PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); - PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); - PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); + PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); + PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); + PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps); + PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); + PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); - PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); - PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); - PARAKEET_LOG_INFO("%s: n_tdt_durations = 
%d\n", __func__, hparams.n_tdt_durations); - PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); + PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); + PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); + PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); + PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); } // load mel filters From a80354118570e23f2fbd74efc9aba5c9f481328f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 14:25:22 +0200 Subject: [PATCH 27/33] parakeet : fix n_tensor calculation and cleanup [no ci] --- src/parakeet.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 9c6384511e5..d4b8c866ff8 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -1055,7 +1055,7 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window); } - // load tdt durations values + // load TDT (Token and Duration Transducer) values { auto & tdt_durations = wctx.model.tdt_durations; tdt_durations.resize(hparams.n_tdt_durations); @@ -1136,8 +1136,8 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ const int n_audio_layer = hparams.n_audio_layer; - // Calculate tensor count based on architecture - size_t n_tensors = 2 + 14 + 1 + 29*n_audio_layer + 9 + 6; + // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6) + size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6; std::map ctx_map; auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { @@ -1170,7 +1170,8 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ ggml_op op = PARAKEET_TENSOR_INFO.at(type); ggml_backend_buffer_type_t buft = select_weight_buft(hparams, 
meta, op, buft_list); if (!buft) { - throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", PARAKEET_TENSOR_NAMES.at(type))); + throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", + PARAKEET_TENSOR_NAMES.at(type))); } ggml_context * ctx = get_ctx(buft); @@ -1190,11 +1191,8 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ // prepare tensors for the weights - // Count: preprocessor (2) + pre_encode (14) + positional encoding (1) + encoder layers (29 per layer) + prediction (9) + joint (6) - const int n_parakeet_tensors = 2 + 14 + 1 + (29 * n_audio_layer) + 9 + 6; - ggml_init_params params = { - /*.mem_size =*/ n_parakeet_tensors * ggml_tensor_overhead(), + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; From a24a5dbd74d1aba737c71a8fdcb3c86311ff4c85 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 14:26:54 +0200 Subject: [PATCH 28/33] parakeet : remove empty block [no ci] --- src/parakeet.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index d4b8c866ff8..bfe8063f035 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -1323,10 +1323,6 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8198)); ggml_set_name(model.joint.net_b, "net_b"); - // Relative positional encoding - { - } - ggml_free(ctx); // allocate tensors in the backend buffers From ab031ed1775e360ba29746aec3462a449746ed75 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 8 May 2026 14:36:17 +0200 Subject: [PATCH 29/33] parakeet : clean up comments [no ci] --- src/parakeet.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index bfe8063f035..cc94881764e 
100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -1440,7 +1440,7 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_ } // -// This function build the computation graph for the pre-encoder stage. +// This function builds the computation graph for the pre-encoder stage. // // It takes the generated mel spectrogram as an input tensor, which will be set // during processing, and performs a number of convolutions to detect/filter @@ -1475,13 +1475,11 @@ static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, p cur = ggml_relu(ctx0, cur); ggml_set_name(cur, "pre_conv_0_relu"); - // enc_pre_conv_2_w: {3, 3, 1, 256} (depthwise) // [freq, time, channels, batch] cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1); cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b); ggml_set_name(cur, "pre_conv_2"); - // enc_pre_conv_3_w: {1, 1, 256, 256} (pointwise) // [freq, time, channels, batch] cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1); cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b); @@ -1490,14 +1488,12 @@ static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, p cur = ggml_relu(ctx0, cur); ggml_set_name(cur, "pre_conv_3_relu"); - // enc_pre_conv_5_w: {3, 3, 1, 256} (depthwise) // [freq, time, channels, batch] cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1); ggml_set_name(cur, "pre_conv_5_direct"); cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b); ggml_set_name(cur, "pre_conv_5"); - // enc_pre_conv_6_w: {1, 1, 256, 256} (pointwise) // [freq, time, channels, batch] cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1); cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b); @@ -2250,7 +2246,7 @@ static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id // Collapse punctuation timestamps to match the original Parakeet model. // Punctuations symbols like ',', '.' 
and others are not spoken words but the // model will still produce a duration for these tokens. But since these are -// non-spoken we collapse the timesstamps so that they don't have an time duration. +// non-spoken we collapse the timestamps so that they don't have a time duration. static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector & tokens) { if (tokens.empty()) { return; @@ -2486,7 +2482,7 @@ static bool parakeet_decode_chunk( int chosen_duration_idx = 0; float chosen_token_logit = 0.0f; - // inner loop handles the joint network + // inner loop handles the joint network. while (t < chunk_enc_frames) { batch.n_tokens = 1; batch.i_time[0] = std::min(t, chunk_enc_frames - 1); @@ -2540,7 +2536,7 @@ break; } - // we continue with the current frame + // we continue with the current frame. current_token_time = t; continue; } From d9341017f01656a0d653e87ce51abf7bc9de3200 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 08:35:38 +0200 Subject: [PATCH 30/33] parakeet : clear tdt_stream_state in parakeet_reset_state This commit moves the rest of the tdt state into the parakeet_reset_state function instead of clearing them separately and duplicating this code. This commit also contains some minor code format changes for clarity.
--- src/parakeet.cpp | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index cc94881764e..feac4ea423a 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -3528,6 +3528,11 @@ static void parakeet_reset_state(struct parakeet_state * state) { if (state->lstm_state.buffer) { ggml_backend_buffer_clear(state->lstm_state.buffer, 0); } + + state->tdt_stream_state.initialized = false; + state->tdt_stream_state.last_token = 0; + state->tdt_stream_state.time_step = 0; + state->tdt_stream_state.decoded_length = 0; } static void parakeet_stream_reset_state(struct parakeet_state * state) { @@ -3539,11 +3544,6 @@ static void parakeet_stream_reset_state(struct parakeet_state * state) { state->result_all.clear(); - state->tdt_stream_state.initialized = false; - state->tdt_stream_state.last_token = 0; - state->tdt_stream_state.time_step = 0; - state->tdt_stream_state.decoded_length = 0; - state->stream.buffer.clear(); state->stream.n_samples_advanced = 0; state->stream.n_left_ctx = 0; @@ -3612,7 +3612,7 @@ static int parakeet_stream_process_window( if (new_token_count > 0) { std::string text; std::vector result_tokens; - const int64_t chunk_t0 = 100LL * stream.n_samples_advanced / PARAKEET_SAMPLE_RATE; + const int64_t chunk_t0 = 100LL * stream.n_samples_advanced / PARAKEET_SAMPLE_RATE; const int64_t chunk_t1 = 100LL * (stream.n_samples_advanced + n_chunk) / PARAKEET_SAMPLE_RATE; const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor; @@ -3848,7 +3848,7 @@ int parakeet_full_with_state( __func__, total_mel_frames, left_context_mel_frames, chunk_mel_frames); // The encoder will see the left context, the main chunk, and the right - // context so that it has access to nearby frames during processing. + // context, so that it has access to nearby frames during processing. // Without this there would be hard cutoffs and the encoder might not // be able to detect speech near the edges. 
state->n_audio_ctx = total_mel_frames; @@ -3865,8 +3865,8 @@ int parakeet_full_with_state( PARAKEET_LOG_DEBUG("%s: encoder output: total=%d frames, left_ctx=%d, chunk=%d\n", __func__, state->n_frames, left_context_enc_frames, chunk_enc_frames); - // The joint network only sees the center chunk encoder frames and - // not the left and right context that the encoder did. + // The joint network only sees the center chunk encoded frames and + // not the left and right context like the encoder did. state->enc_out_buffer.resize(chunk_enc_frames * d_enc); ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(), left_context_enc_frames * d_enc * sizeof(float), @@ -3888,7 +3888,7 @@ int parakeet_full_with_state( if (new_token_count > 0) { std::string text; std::vector result_tokens; - const int64_t chunk_t0 = 100LL * left_sample / PARAKEET_SAMPLE_RATE; + const int64_t chunk_t0 = 100LL * left_sample / PARAKEET_SAMPLE_RATE; const int64_t chunk_t1 = 100LL * right_sample / PARAKEET_SAMPLE_RATE; const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor; @@ -3955,11 +3955,6 @@ int parakeet_chunk( if (params.no_context) { parakeet_reset_state(state); - - state->tdt_stream_state.initialized = false; - state->tdt_stream_state.last_token = 0; - state->tdt_stream_state.time_step = 0; - state->tdt_stream_state.decoded_length = 0; } if (n_samples > 0) { From 7e0e274dd8ecce51bc05c9dfda9a1c6804535a95 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 09:29:33 +0200 Subject: [PATCH 31/33] parakeet : add parakeet_stream_clear helper function --- src/parakeet.cpp | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/parakeet.cpp b/src/parakeet.cpp index feac4ea423a..00688715ae2 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -3521,6 +3521,16 @@ struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_ return result; } +static void parakeet_stream_clear(struct 
parakeet_state * state) { + state->stream.buffer.clear(); + state->stream.n_samples_advanced = 0; + state->stream.n_left_ctx = 0; + state->stream.n_chunk = 0; + state->stream.n_right_ctx = 0; + state->stream.params = {}; + state->stream.initialized = false; +} + static void parakeet_reset_state(struct parakeet_state * state) { state->decoded_tokens.clear(); state->decoded_token_data.clear(); @@ -3544,13 +3554,7 @@ static void parakeet_stream_reset_state(struct parakeet_state * state) { state->result_all.clear(); - state->stream.buffer.clear(); - state->stream.n_samples_advanced = 0; - state->stream.n_left_ctx = 0; - state->stream.n_chunk = 0; - state->stream.n_right_ctx = 0; - state->stream.params = {}; - state->stream.initialized = false; + parakeet_stream_clear(state); state->enc_out_buffer.clear(); state->enc_out_frames = 0; @@ -3768,13 +3772,7 @@ int parakeet_stream_flush( state->stream.n_samples_advanced += n_flush_chunk; } - state->stream.buffer.clear(); - state->stream.n_samples_advanced = 0; - state->stream.n_left_ctx = 0; - state->stream.n_chunk = 0; - state->stream.n_right_ctx = 0; - state->stream.params = {}; - state->stream.initialized = false; + parakeet_stream_clear(state); return 0; } From ab609b0a5e4196ffbc5ad48e53ce6eed58c4f46b Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 13 May 2026 08:17:51 +0200 Subject: [PATCH 32/33] parakeet : minor cleanup of parakeet.h [no ci] --- include/parakeet.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/parakeet.h b/include/parakeet.h index 3db69c71d68..e73f6aa2f90 100644 --- a/include/parakeet.h +++ b/include/parakeet.h @@ -116,7 +116,7 @@ extern "C" { // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. 
- // n_mel must be 80 + // n_mel must be 128 // Returns 0 on success PARAKEET_API int parakeet_set_mel( struct parakeet_context * ctx, @@ -187,9 +187,9 @@ extern "C" { PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); // Special tokens - PARAKEET_API parakeet_token parakeet_token_blank (struct parakeet_context * ctx); - PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); - PARAKEET_API parakeet_token parakeet_token_bos(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx); // Performance information from the default state. struct parakeet_timings { @@ -216,7 +216,7 @@ extern "C" { struct parakeet_context * ctx, struct parakeet_state * state, const parakeet_token_data * token_data, - void * user_data); + void * user_data); // Text segment callback // Called on every newly generated text segment From 9164f9ce31d747be96e719429230d1c5b47387c5 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 13 May 2026 15:33:49 +0200 Subject: [PATCH 33/33] parakeet : remove flash attention param This commit removes the `flash_attn` parameter as flash attention is not currently supported. 
--- examples/parakeet-cli/parakeet-cli.cpp | 6 ------ include/parakeet.h | 1 - src/parakeet.cpp | 2 -- 3 files changed, 9 deletions(-) diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index df2de45b84e..293484431dc 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -16,7 +16,6 @@ struct parakeet_params { int32_t right_context_ms = 4960; bool use_gpu = true; - bool flash_attn = true; int32_t gpu_device = 0; bool print_segments = false; @@ -66,8 +65,6 @@ static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & para else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = false; } - else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } @@ -97,8 +94,6 @@ static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_para fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); - fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", !params.flash_attn ? "true" : "false"); fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? 
"true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); @@ -141,7 +136,6 @@ int main(int argc, char ** argv) { struct parakeet_context_params ctx_params = parakeet_context_default_params(); ctx_params.use_gpu = params.use_gpu; - ctx_params.flash_attn = params.flash_attn; ctx_params.gpu_device = params.gpu_device; if (!params.no_prints) { diff --git a/include/parakeet.h b/include/parakeet.h index e73f6aa2f90..9e72483f788 100644 --- a/include/parakeet.h +++ b/include/parakeet.h @@ -47,7 +47,6 @@ extern "C" { struct parakeet_context_params { bool use_gpu; - bool flash_attn; int gpu_device; // CUDA device }; diff --git a/src/parakeet.cpp b/src/parakeet.cpp index 00688715ae2..a593c47d9f6 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -3026,7 +3026,6 @@ struct parakeet_state * parakeet_init_state(parakeet_context * ctx) { struct parakeet_context_params parakeet_context_default_params() { struct parakeet_context_params result = { /*.use_gpu =*/ true, - /*.flash_attn =*/ true, /*.gpu_device =*/ 0, }; return result; @@ -3117,7 +3116,6 @@ struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_mod ggml_time_init(); PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); - PARAKEET_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn); PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count()); PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());