diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fb115b22abb..0d704052583 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -634,6 +634,8 @@ jobs: -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + -DGGML_BMI2=OFF - name: Build run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index a0f74041321..7aa4728ba81 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -179,12 +179,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS) set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h) + install(TARGETS whisper LIBRARY PUBLIC_HEADER) target_compile_definitions(whisper PRIVATE WHISPER_VERSION="${PROJECT_VERSION}" ) +set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h) +install(TARGETS parakeet LIBRARY PUBLIC_HEADER) + +target_compile_definitions(parakeet PRIVATE + PARAKEET_VERSION="${PROJECT_VERSION}" +) + configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake @@ -210,6 +218,35 @@ configure_file(cmake/whisper.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" DESTINATION lib/pkgconfig) +set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet + PATH_VARS + PARAKEET_INCLUDE_INSTALL_DIR + PARAKEET_LIB_INSTALL_DIR + PARAKEET_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet) + +configure_file(cmake/parakeet.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + DESTINATION lib/pkgconfig) + # # programs, examples and tests # diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 4b09b6ebe13..99894f1234d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -30,6 +30,6 @@ #{libs}: cmake-targets cmake-targets: #{"\t"}"#{cmake}" -S sources -B build #{options} - #{"\t"}"#{cmake}" --build build --config Release --target common whisper + #{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet EOF end diff --git a/cmake/parakeet-config.cmake.in b/cmake/parakeet-config.cmake.in new file mode 100644 index 00000000000..aadb55c2d19 --- /dev/null +++ b/cmake/parakeet-config.cmake.in @@ -0,0 +1,30 @@ +set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@) +set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@") +set_and_check(PARAKEET_LIB_DIR 
"@PACKAGE_PARAKEET_LIB_INSTALL_DIR@") +set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake) + +find_library(parakeet_LIBRARY parakeet + REQUIRED + HINTS ${PARAKEET_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(parakeet UNKNOWN IMPORTED) +set_target_properties(parakeet + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${parakeet_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(parakeet) diff --git a/cmake/parakeet.pc.in b/cmake/parakeet.pc.in new file mode 100644 index 00000000000..a83cc472547 --- /dev/null +++ b/cmake/parakeet.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/lib +includedir=${prefix}/include + +Name: parakeet +Description: Port of NVIDIA's Parakeet model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lparakeet +Cflags: -I${includedir} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b202ca00b77..2cf5a67aff3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -107,6 +107,7 @@ else() add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) + add_subdirectory(parakeet-cli) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/parakeet-cli/CMakeLists.txt b/examples/parakeet-cli/CMakeLists.txt new file mode 100644 index 00000000000..adb9aba38ef --- /dev/null +++ b/examples/parakeet-cli/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parakeet-cli) +add_executable(${TARGET} parakeet-cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md new file mode 100644 index 00000000000..dfe58f57a03 --- /dev/null +++ b/examples/parakeet-cli/README.md @@ -0,0 +1,112 @@ +# whisper.cpp/examples/parakeet-cli + +This is an example of using the [Parakeet] model in whisper.cpp. + +### Download converted model +```console +$ hf download danbev/parakeet parakeet-tdt-0.6b-v3.bin --local-dir models +``` + +### Building +```console +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +### Usage +```console +$ ./build/bin/parakeet-cli --help + +usage: ./build/bin/parakeet-cli [options] file0 file1 ... 
+supported audio formats: flac, mp3, ogg, wav + +options: + -h, --help [default] show this help message and exit + -t N, --threads N [4 ] number of threads to use during computation + -cl N, --chunk-length N [10000 ] chunk length in milliseconds + -lc N, --left-context N [10000 ] left context in milliseconds + -rc N, --right-context N [4960 ] right context in milliseconds + -m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path + -f, --file FILE [ ] input audio file + -ng, --no-gpu [false ] disable GPU + -dev N, --device N [0 ] GPU device to use + -fa, --flash-attn [false ] enable flash attention + -nfa, --no-flash-attn [false ] disable flash attention + -ps, --print-segments [false ] print segment information +``` + +### Example +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3.bin -f samples/jfk.wav +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. +``` + +To print segment information: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3.bin -f samples/jfk.wav --print-segments +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + +Segments (1): +Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country." +Tokens [38]: + [ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And" + [ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so" + [ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false "," + [ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my" + [ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f" + [ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell" + [ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow" + [ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer" + [ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic" + [ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans" + [10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false "," + [11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a" + [12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk" + [13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not" + [14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what" + [15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your" + [16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co" + [17] id= 331 
frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un" + [18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr" + [19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y" + [20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can" + [21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do" + [22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for" + [23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you" + [24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false "," + [25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a" + [26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk" + [27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what" + [28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you" + [29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can" + [30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do" + [31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for" + [32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your" + [33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co" + [34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un" + [35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr" + [36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y" + [37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "." 
+```
+
+### Model conversion
+Clone the original model from Hugging Face:
+```console
+$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
+```
+Convert the model:
+```console
+(venv) $ python models/convert-parakeet-to-ggml.py \
+    --model \
+    --use-f32 \
+    --out-dir models \
+    --out-name ggml-parakeet-tdt-0.6b-v3.bin
+```
+
+[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp
new file mode 100644
index 00000000000..293484431dc
--- /dev/null
+++ b/examples/parakeet-cli/parakeet-cli.cpp
@@ -0,0 +1,259 @@
+#include "parakeet.h"
+#include "common-whisper.h"
+
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <string>
+#include <thread>
+#include <vector>
+
+// command-line parameters
+struct parakeet_params {
+    int32_t n_threads        = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t chunk_length_ms  = 10000;
+    int32_t left_context_ms  = 10000;
+    int32_t right_context_ms = 4960;
+
+    bool    use_gpu    = true;
+    int32_t gpu_device = 0;
+
+    bool print_segments = false;
+    bool output_txt     = false;
+    bool no_prints      = false;
+
+    std::string model       = "models/ggml-parakeet-tdt-0.6b-v3.bin";
+    std::string output_file = "";
+    std::vector<std::string> fname_inp = {};
+};
+
+static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params);
+
+static char * requires_value_error(const std::string & arg) {
+    fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
+    exit(1);
+}
+
+static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) {
+    if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) {
+        params.gpu_device = std::stoi(env_device);
+    }
+
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-"){
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+
+        if (arg[0] != '-') {
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+
+        if (arg == "-h" || arg == "--help") {
+            parakeet_print_usage(argc, argv, params);
+            exit(0);
+        }
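+        // ARGV_NEXT consumes the following argv entry as the value of the
+        // current flag, or exits with an error message if no value follows.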
+        #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
+        else if (arg == "-t"    || arg == "--threads")        { params.n_threads        = std::stoi(ARGV_NEXT); }
+        else if (arg == "-cl"   || arg == "--chunk-length")   { params.chunk_length_ms  = std::stoi(ARGV_NEXT); }
+        else if (arg == "-lc"   || arg == "--left-context")   { params.left_context_ms  = std::stoi(ARGV_NEXT); }
+        else if (arg == "-rc"   || arg == "--right-context")  { params.right_context_ms = std::stoi(ARGV_NEXT); }
+        else if (arg == "-m"    || arg == "--model")          { params.model = ARGV_NEXT; }
+        else if (arg == "-f"    || arg == "--file")           { params.fname_inp.emplace_back(ARGV_NEXT); }
+        else if (arg == "-ng"   || arg == "--no-gpu")         { params.use_gpu = false; }
+        else if (arg == "-dev"  || arg == "--device")         { params.gpu_device = std::stoi(ARGV_NEXT); }
+        else if (arg == "-ps"   || arg == "--print-segments") { params.print_segments = true; }
+        else if (arg == "-otxt" || arg == "--output-txt")     { params.output_txt = true; }
+        else if (arg == "-of"   || arg == "--output-file")    { params.output_file = ARGV_NEXT; }
+        else if (arg == "-np"   || arg == "--no-prints")      { params.no_prints = true; }
+        else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            parakeet_print_usage(argc, argv, params);
+            exit(1);
+        }
+    }
+
+    return true;
+}
+
+static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) {
+    fprintf(stderr, "\n");
+    fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]);
+    fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h,     --help            [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,   --threads N       [%-7d] number of threads to use during computation\n", params.n_threads);
+    fprintf(stderr, "  -cl N,  --chunk-length N  [%-7d] chunk length in milliseconds\n", params.chunk_length_ms);
+    fprintf(stderr, "  -lc N,  --left-context N  [%-7d] left context in milliseconds\n", params.left_context_ms);
+    fprintf(stderr, "  -rc N,  --right-context N [%-7d] right context in milliseconds\n", params.right_context_ms);
+    fprintf(stderr, "  -m,     --model FILE      [%-7s] model path\n", params.model.c_str());
+    fprintf(stderr, "  -f,     --file FILE       [%-7s] input audio file\n", "");
+    fprintf(stderr, "  -ng,    --no-gpu          [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
+    fprintf(stderr, "  -dev N, --device N        [%-7d] GPU device to use\n", params.gpu_device);
+    fprintf(stderr, "  -ps,    --print-segments  [%-7s] print segment information\n", params.print_segments ? "true" : "false");
+    fprintf(stderr, "  -otxt,  --output-txt      [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
+    fprintf(stderr, "  -of,    --output-file FILE [%-7s] output file path (without file extension)\n", "");
+    fprintf(stderr, "  -np,    --no-prints       [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
+    fprintf(stderr, "\n");
+}
+
+void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
+    bool * is_first = (bool *) user_data;
+
+    const char * token_str = parakeet_token_to_str(ctx, token_data->id);
+    char text_buf[256];
+    parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf));
+    printf("%s", text_buf);
+    fflush(stdout);
+
+    *is_first = false;
+}
+
+static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
+
+int main(int argc, char ** argv) {
+    ggml_backend_load_all();
+
+    parakeet_params params;
+
+    if (parakeet_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.no_prints) {
+        parakeet_log_set(cb_log_disable, NULL);
+    }
+
+    if (params.fname_inp.empty()) {
+        fprintf(stderr, "error: no input files specified\n");
+        parakeet_print_usage(argc, argv, params);
+        return 1;
+    }
+
+    struct parakeet_context_params ctx_params = parakeet_context_default_params();
+    ctx_params.use_gpu    = params.use_gpu;
+    ctx_params.gpu_device = params.gpu_device;
+
+    if (!params.no_prints) {
+        fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
+    }
+
+    struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
+    if (pctx == nullptr) {
+        fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
+        return 1;
+    }
+
+    if (!params.no_prints) {
+        fprintf(stderr, "Successfully loaded Parakeet model\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
+    }
+
+    // Process each input file
+    for (const auto & fname : params.fname_inp) {
+        if (!params.no_prints) {
+            fprintf(stderr, "\nProcessing file: %s\n", fname.c_str());
+        }
+
+        std::vector<float> pcmf32;
+        std::vector<std::vector<float>> pcmf32s;
+        if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) {
+            fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str());
+            continue;
+        }
+
+        if (pcmf32.empty()) {
+            fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str());
+            continue;
+        }
+
+        bool is_first = true;
+        struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
+        full_params.n_threads        = params.n_threads;
+        full_params.chunk_length_ms  = params.chunk_length_ms;
+        full_params.left_context_ms  = params.left_context_ms;
+        full_params.right_context_ms = params.right_context_ms;
+        full_params.new_token_callback           = token_callback;
+        full_params.new_token_callback_user_data = &is_first;
+
+        // if the whole clip fits within the encoder's audio context, disable chunking
+        const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH);
+        if (mel_frames <= parakeet_n_audio_ctx(pctx)) {
+            full_params.chunk_length_ms = 0;
+        }
+
+        int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size());
+
+        if (ret != 0) {
+            fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str());
+            continue;
+        }
+
+        printf("\n");
+
+        if (params.output_txt) {
+            // derive the output path from --output-file if given, otherwise from the input file name
+            const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt";
+
+            std::ofstream fout(fname_out);
+            if (fout.is_open()) {
+                const int n_segments = parakeet_full_n_segments(pctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = parakeet_full_get_segment_text(pctx, i);
+                    fout << text << "\n";
+                }
+                fout.close();
+                if (!params.no_prints) {
+                    fprintf(stderr, "Output written to: %s\n", fname_out.c_str());
+                }
+            } else {
+                fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str());
+            }
+        }
+
+        if (!params.no_prints) {
+            parakeet_print_timings(pctx);
+        }
+
+        if (params.print_segments) {
+            const int n_segments = parakeet_full_n_segments(pctx);
+            fprintf(stderr, "\nSegments (%d):\n", n_segments);
+
+            for (int i = 0; i < n_segments; i++) {
+                const char *  text     = parakeet_full_get_segment_text(pctx, i);
+                const int64_t t0       = parakeet_full_get_segment_t0(pctx, i);
+                const int64_t t1       = parakeet_full_get_segment_t1(pctx, i);
+                const int     n_tokens = parakeet_full_n_tokens(pctx, i);
+
+                fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
+                fprintf(stderr, "Tokens [%d]:\n", n_tokens);
+
+                for (int j = 0; j < n_tokens; j++) {
+                    parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j);
+                    const char * token_str = parakeet_token_to_str(pctx, token_data.id);
+
+                    fprintf(stderr, "  [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n",
+                            j,
+                            token_data.id,
+                            token_data.frame_index,
+                            token_data.duration_idx,
+                            token_data.duration_value,
+                            token_data.p,
+                            token_data.plog,
+                            (long long)token_data.t0,
+                            (long long)token_data.t1,
+                            token_data.is_word_start ? "true": "false",
+                            token_str);
+                }
+            }
+        }
+    }
+
+    parakeet_free(pctx);
+
+    return 0;
+}
diff --git a/include/parakeet.h b/include/parakeet.h
new file mode 100644
index 00000000000..9e72483f788
--- /dev/null
+++ b/include/parakeet.h
@@ -0,0 +1,369 @@
+#ifndef PARAKEET_H
+#define PARAKEET_H
+
+#include "ggml.h"
+#include "ggml-cpu.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef __GNUC__
+#    define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define PARAKEET_DEPRECATED(func, hint) func
+#endif
+
+#ifdef PARAKEET_SHARED
+#    ifdef _WIN32
+#        ifdef PARAKEET_BUILD
+#            define PARAKEET_API __declspec(dllexport)
+#        else
+#            define PARAKEET_API __declspec(dllimport)
+#        endif
+#    else
+#        define PARAKEET_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define PARAKEET_API
+#endif
+
+#define PARAKEET_SAMPLE_RATE 16000
+#define PARAKEET_HOP_LENGTH  160
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+    struct parakeet_context;
+    struct parakeet_state;
+    struct parakeet_full_params;
+
+    typedef int32_t parakeet_pos;
+    typedef int32_t parakeet_token;
+    typedef int32_t parakeet_seq_id;
+
+    struct parakeet_context_params {
+        bool use_gpu;
+        int  gpu_device;  // CUDA device
+    };
+
+    typedef struct parakeet_token_data {
+        parakeet_token id;  // the BPE subword ID (0-8191)
+
+        int duration_idx;   // index into the model's durations array
+        int duration_value; // actual duration value
+        int frame_index;
+
+        float p;
+        float plog;
+
+        int64_t t0;
+        int64_t t1;
+
+        bool is_word_start;
+    } parakeet_token_data;
+
+    typedef struct parakeet_model_loader {
+        void * context;
+
+        size_t (*read)(void * ctx, void * output, size_t read_size);
+        bool   (*eof)(void * ctx);
+        void   (*close)(void * ctx);
+    } parakeet_model_loader;
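+    // Illustrative sketch only (not part of this header): a parakeet_model_loader
+    // can wrap any stream, for example a plain FILE handle:
+    //
+    //     size_t my_read (void * ctx, void * out, size_t n) { return fread(out, 1, n, (FILE *) ctx); }
+    //     bool   my_eof  (void * ctx)                       { return feof((FILE *) ctx) != 0; }
+    //     void   my_close(void * ctx)                       { fclose((FILE *) ctx); }
+    //
+    //     parakeet_model_loader loader = { fopen("model.bin", "rb"), my_read, my_eof, my_close };
+    //     struct parakeet_context * pctx = parakeet_init_with_params(&loader, parakeet_context_default_params());
+
+    PARAKEET_API const 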
char * parakeet_version(void); + + // Various functions for loading a ggml parakeet model. + // Allocate (almost) all memory needed for the model. + // Return NULL on failure + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523) + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx); + + // Frees all allocated memory + PARAKEET_API void parakeet_free (struct parakeet_context * ctx); + PARAKEET_API void parakeet_free_state(struct parakeet_state * state); + PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params); + PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided parakeet context. + // Returns 0 on success + PARAKEET_API int parakeet_pcm_to_mel( + struct parakeet_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + PARAKEET_API int parakeet_pcm_to_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. + // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 128 + // Returns 0 on success + PARAKEET_API int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel); + + PARAKEET_API int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context. + // Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + PARAKEET_API int parakeet_encode( + struct parakeet_context * ctx, + int offset, + int n_threads); + + PARAKEET_API int parakeet_encode_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + int offset, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. 
+ // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + PARAKEET_API int parakeet_tokenize( + struct parakeet_context * ctx, + const char * text, + parakeet_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0) + int parakeet_token_count(struct parakeet_context * ctx, const char * text); + + PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length + PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length + PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx); + + PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx); + + // Token logits obtained from the last call to parakeet_full/parakeet_chunk + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx); + PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state); + + // Token Id -> String. Uses the vocabulary in the provided context + PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token); + + PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); + + // Special tokens + PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx); + + // Performance information from the default state. + struct parakeet_timings { + float sample_ms; + float encode_ms; + float decode_ms; + }; + PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx); + + // Print system information + PARAKEET_API const char * parakeet_print_system_info(void); + + // Available sampling strategies + enum parakeet_sampling_strategy { + PARAKEET_SAMPLING_GREEDY, + }; + + // Token callback. + // Called for each new predicted token. 
+    // Use the parakeet_full_...() functions to obtain the text segments
+    typedef void (*parakeet_new_token_callback)(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state,
+        const parakeet_token_data * token_data,
+        void * user_data);
+
+    // Text segment callback
+    // Called on every newly generated text segment
+    // Use the parakeet_full_...() functions to obtain the text segments
+    typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data);
+
+    // Progress callback
+    typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data);
+
+    // Encoder begin callback
+    // If not NULL, called before the encoder starts
+    // If it returns false, the computation is aborted
+    typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data);
+
+    // Parameters for the parakeet_full() function
+    // If you change the order or add new parameters, make sure to update the default values in parakeet.cpp:
+    // parakeet_full_default_params()
+    struct parakeet_full_params {
+        enum parakeet_sampling_strategy strategy;
+
+        int n_threads;
+        int offset_ms;          // start offset in ms
+        int duration_ms;        // audio duration to process in ms
+
+        bool no_context;        // do not use past transcription (if any) as context
+
+        int audio_ctx;          // overwrite the audio context size (0 = use default)
+
+        int chunk_length_ms;    // length of each chunk in ms
+        int left_context_ms;    // left context in ms
+        int right_context_ms;   // right context in ms
+
+        // called for every newly generated text segment
+        parakeet_new_segment_callback new_segment_callback;
+        void * new_segment_callback_user_data;
+
+        // called for every newly generated token
+        parakeet_new_token_callback new_token_callback;
+        void * new_token_callback_user_data;
+
+        // called on each progress update
+        parakeet_progress_callback progress_callback;
+        void * progress_callback_user_data;
+
+        // called each time before the encoder starts
+        parakeet_encoder_begin_callback encoder_begin_callback;
+        void * encoder_begin_callback_user_data;
+
+        // called each time before ggml computation starts
+        ggml_abort_callback abort_callback;
+        void * abort_callback_user_data;
+    };
+
+    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params()
+    PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void);
+    PARAKEET_API struct parakeet_context_params   parakeet_context_default_params       (void);
+
+    PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy);
+    PARAKEET_API struct parakeet_full_params   parakeet_full_default_params       (enum parakeet_sampling_strategy strategy);
+
+    // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+    // Not thread safe for same context
+    PARAKEET_API int parakeet_full(
+        struct parakeet_context * ctx,
+        struct parakeet_full_params params,
+        const float * samples,
+        int n_samples);
+
+    PARAKEET_API int parakeet_full_with_state(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state,
+        struct parakeet_full_params params,
+        const float * samples,
+        int n_samples);
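+    // Illustrative sketch of typical usage (assuming 16 kHz mono float PCM in `samples`);
+    // see examples/parakeet-cli for a complete program:
+    //
+    //     struct parakeet_context * pctx = parakeet_init_from_file_with_params(
+    //             "models/ggml-parakeet-tdt-0.6b-v3.bin", parakeet_context_default_params());
+    //
+    //     struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
+    //     if (parakeet_full(pctx, params, samples, n_samples) == 0) {
+    //         for (int i = 0; i < parakeet_full_n_segments(pctx); i++) {
+    //             printf("%s\n", parakeet_full_get_segment_text(pctx, i));
+    //         }
+    //     }
+    //
+    //     parakeet_free(pctx);
+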
+    // Process a single chunk of audio data that fits within the model's audio context window.
+    // This is more efficient than parakeet_full() for short audio clips.
+    PARAKEET_API int parakeet_chunk(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state,
+        struct parakeet_full_params params,
+        const float * samples,
+        int n_samples);
+
+    // Initialize streaming state for a new stream.
+    PARAKEET_API int parakeet_stream_init(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state,
+        struct parakeet_full_params params);
+
+    // Push audio samples in streaming mode. Internally this function arranges the
+    // samples in a buffer with a left context, a center chunk, and a right context.
+    // The encoder sees the complete buffer, which gives it boundary context for the
+    // target/center audio chunk and avoids hard cut-offs at the chunk boundaries.
+    // The joint network then only sees the center chunk; this function handles the
+    // context windowing internally.
+    PARAKEET_API int parakeet_stream_push(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state,
+        const float * samples,
+        int n_samples);
+
+    // Flush the final partial chunk at end-of-stream.
+    PARAKEET_API int parakeet_stream_flush(
+        struct parakeet_context * ctx,
+        struct parakeet_state * state);
+
+    // Number of generated text segments
+    PARAKEET_API int parakeet_full_n_segments           (struct parakeet_context * ctx);
+    PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state);
+
+    // Get the start and end time of the specified segment
+    PARAKEET_API int64_t parakeet_full_get_segment_t0           (struct parakeet_context * ctx, int i_segment);
+    PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment);
+
+    PARAKEET_API int64_t parakeet_full_get_segment_t1           (struct parakeet_context * ctx, int i_segment);
+    PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment);
+
+    // Get the text of the specified segment
+    PARAKEET_API const char * parakeet_full_get_segment_text           (struct parakeet_context * ctx, int i_segment);
+    PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment);
+
+    // Get number of tokens in the specified segment
+    PARAKEET_API int parakeet_full_n_tokens           (struct parakeet_context * ctx, int i_segment);
+    PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment);
+
+    // Get the token text of the specified token in the specified segment
+    PARAKEET_API const char * parakeet_full_get_token_text           (struct parakeet_context * ctx, int i_segment, int i_token);
+    PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token);
+
+    // Get the token id of the specified token in the specified segment
+    PARAKEET_API parakeet_token parakeet_full_get_token_id           (struct parakeet_context * ctx, int i_segment, int i_token);
+    PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token);
+
+    // Get token data for the specified token in the specified segment
+    PARAKEET_API parakeet_token_data parakeet_full_get_token_data           (struct parakeet_context * ctx, int i_segment, int i_token);
+    PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token);
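+    // Illustrative sketch only: inspecting per-token timing and probability
+    // data for segment i after a call to parakeet_full():
+    //
+    //     for (int j = 0; j < parakeet_full_n_tokens(pctx, i); j++) {
+    //         parakeet_token_data td = parakeet_full_get_token_data(pctx, i, j);
+    //         printf("%s [%lld -> %lld] p=%.4f\n",
+    //                parakeet_token_to_str(pctx, td.id), (long long) td.t0, (long long) td.t1, td.p);
+    //     }
+
+    // Get the probability of the specified token in the specified segment
+    PARAKEET_API float 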
parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Control logging output; default behavior is to print to stderr + + PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py new file mode 100755 index 00000000000..cbeec7e6c70 --- /dev/null +++ b/models/convert-parakeet-to-ggml.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +# Convert Parakeet TDT model from NeMo format to ggml format +# +# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32] +# +# The NeMo file is a tar archive containing: +# - model_weights.ckpt (PyTorch checkpoint) +# - model_config.yaml (model configuration) +# - tokenizer files +# +# This script extracts the NeMo archive, loads the model weights and configuration, +# and saves them in ggml format compatible with whisper.cpp. +# + +import torch +import argparse +import io +import os +import sys +import struct +import tarfile +import tempfile +import shutil +import yaml +import numpy as np +from pathlib import Path +from typing import Optional + +def hz_to_mel(freq): + return 2595.0 * np.log10(1.0 + freq / 700.0) + +def mel_to_hz(mel): + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + +def extract_nemo_archive(nemo_path, extract_dir): + print(f"Extracting {nemo_path} to {extract_dir}") + with tarfile.open(nemo_path, 'r') as tar: + tar.extractall(path=extract_dir) + print("Extraction complete") + +def load_model_config(config_path): + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config + +def load_tokenizer(extract_dir, config): + tokenizer_model_path = None + tokenizer_vocab_path = None + + for file in os.listdir(extract_dir): + if file.endswith('_tokenizer.model'): + tokenizer_model_path = os.path.join(extract_dir, file) + elif file.endswith('tokenizer.vocab'): + tokenizer_vocab_path = os.path.join(extract_dir, file) + + if not tokenizer_model_path: + raise FileNotFoundError("Tokenizer model file not found") + + if not tokenizer_vocab_path: + raise FileNotFoundError("Tokenizer vocab file not found") + + tokens = {} + with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + parts = line.strip().split('\t') + if len(parts) >= 1: + token = parts[0] + tokens[token.encode('utf-8')] = idx + + print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}") + + if len(tokens) != 8192: + print(f"WARNING: Expected 8192 tokens, got {len(tokens)}") + + return tokens + +def write_tensor(fout, name, data, use_f16=True, force_f32=False): + if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1: + data = data.reshape(1, -1, 1, 1) + print(f" Reshaped conv bias {name} to {data.shape}") + + n_dims = len(data.shape) + + ftype = 1 if use_f16 and not force_f32 else 0 + if force_f32: + data = data.astype(np.float32) + elif use_f16: + if n_dims < 2 or 'bias' in name or 'norm' in name: + data = data.astype(np.float32) + ftype = 0 + else: + data = data.astype(np.float16) + else: + data = data.astype(np.float32) + + dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)] + print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}") + name_bytes = 
name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + + data.tofile(fout) + +def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None): + nemo_path = Path(nemo_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary directory for extraction + with tempfile.TemporaryDirectory() as temp_dir: + extract_nemo_archive(nemo_path, temp_dir) + + config_path = os.path.join(temp_dir, 'model_config.yaml') + config = load_model_config(config_path) + + print("Model configuration:") + print(f" Sample rate: {config['sample_rate']}") + print(f" Encoder layers: {config['encoder']['n_layers']}") + print(f" Encoder d_model: {config['encoder']['d_model']}") + print(f" Mel features: {config['preprocessor']['features']}") + + weights_path = os.path.join(temp_dir, 'model_weights.ckpt') + print(f"\nLoading model weights from {weights_path}") + checkpoint = torch.load(weights_path, map_location='cpu') + + # Extract state dict + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + print(f"Loaded {len(state_dict)} tensors") + + # Load tokenizer + print("\nLoading tokenizer...") + tokens = load_tokenizer(temp_dir, config) + print(f"Loaded {len(tokens)} tokens") + + # Prepare hyperparameters for the Parakeet ggml format. + hparams = { + 'n_audio_ctx': 5000, + 'n_audio_state': config['encoder']['d_model'], + 'n_audio_head': config['encoder']['n_heads'], + 'n_audio_layer': config['encoder']['n_layers'], + 'n_mels': config['preprocessor']['features'], + 'n_fft': config['preprocessor']['n_fft'], + 'subsampling_factor': config['encoder']['subsampling_factor'], + 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], + + 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], + 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], + 'n_vocab': config['decoder']['vocab_size'], + 'n_tdt_durations': config['model_defaults']['num_tdt_durations'], + 'n_max_tokens': config['decoding']['greedy']['max_symbols'], + } + + print("\nGGML hyperparameters:") + for key, value in hparams.items(): + print(f" {key}: {value}") + + # Create output file + if out_name: + fname_out = output_dir / out_name + else: + fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin") + print(f"\nWriting to {fname_out}") + + with open(fname_out, 'wb') as fout: + # Write magic number + fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex + + # Write hyperparameters + fout.write(struct.pack("i", hparams['n_vocab'])) + fout.write(struct.pack("i", hparams['n_audio_ctx'])) + fout.write(struct.pack("i", hparams['n_audio_state'])) + fout.write(struct.pack("i", hparams['n_audio_head'])) + fout.write(struct.pack("i", hparams['n_audio_layer'])) + fout.write(struct.pack("i", hparams['n_mels'])) + fout.write(struct.pack("i", 1 if use_f16 else 0)) + fout.write(struct.pack("i", hparams['n_fft'])) + fout.write(struct.pack("i", hparams['subsampling_factor'])) + fout.write(struct.pack("i", hparams['n_subsampling_channels'])) + fout.write(struct.pack("i", hparams['n_pred_dim'])) + fout.write(struct.pack("i", hparams['n_pred_layers'])) + fout.write(struct.pack("i", hparams['n_tdt_durations'])) + fout.write(struct.pack("i", hparams['n_max_tokens'])) + + # Extract mel filterbank from model + fb_key = None + for key in state_dict.keys(): + if 
'featurizer.fb' in key or 'filterbank' in key.lower(): + fb_key = key + break + + if not fb_key: + print("\nERROR: Mel filterbank not found in model!") + print("Expected tensor with 'featurizer.fb' or 'filterbank' in name") + print("\nAvailable preprocessor tensors:") + for key in sorted(state_dict.keys()): + if 'preprocessor' in key or 'featurizer' in key: + print(f" {key}: {state_dict[key].shape}") + raise ValueError("Mel filterbank tensor not found in model") + + print(f"\nUsing model's mel filterbank from: {fb_key}") + mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32) + print(f" Filterbank shape: {mel_filters.shape}") + print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}") + print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}") + print(f" First row sum: {mel_filters[0].sum():.6f}") + + if len(mel_filters.shape) != 2: + raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}") + + n_mels, n_freqs = mel_filters.shape + fout.write(struct.pack("i", n_mels)) # n_mel + fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins) + + # Write mel filterbank + for i in range(n_mels): + for j in range(n_freqs): + fout.write(struct.pack("f", mel_filters[i, j])) + + # Extract window function from model + window_key = None + for key in state_dict.keys(): + if 'featurizer.window' in key or 'preproc' in key and 'window' in key: + window_key = key + break + + if not window_key: + print("\nERROR: Window function not found in model!") + print("Expected tensor with 'featurizer.window' in name") + raise ValueError("Window function tensor not found in model") + + print(f"\nUsing model's window function from: {window_key}") + window = state_dict[window_key].squeeze().numpy().astype(np.float32) + print(f" Window shape: {window.shape}") + print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}") + print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}") + print(f" Window sum: {window.sum():.6f}") + + if len(window.shape) != 1: + raise ValueError(f"Expected 1D window, got shape {window.shape}") + + n_window = window.shape[0] + fout.write(struct.pack("i", n_window)) + + # Write window function + for i in range(n_window): + fout.write(struct.pack("f", window[i])) + + # Write TDT durations + tdt_durations = config['model_defaults']['tdt_durations'] + if len(tdt_durations) != hparams['n_tdt_durations']: + raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}") + + for duration in tdt_durations: + fout.write(struct.pack("I", duration)) + + fout.write(struct.pack("i", len(tokens))) + for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]): + fout.write(struct.pack("i", len(token_bytes))) + fout.write(token_bytes) + + print("\nConverting model weights...") + for name, tensor in state_dict.items(): + # Skip the filterbank and window - already written in preprocessing section + if name == fb_key: + continue + if name == window_key: + continue + + # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions + if 'conv' in name and 'weight' in name and len(tensor.shape) == 4: + data = tensor.numpy() + else: + data = tensor.squeeze().numpy() + + write_tensor(fout, name, data, use_f16=use_f16) + + print(f"\nConversion complete!") + print(f"Output file: {fname_out}") + print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + 
description='Convert Parakeet TDT model from NeMo format to ggml format' + ) + parser.add_argument('--model', type=str, required=True, + help='Path to Parakeet .nemo model file') + parser.add_argument('--out-dir', type=str, required=True, + help='Directory to write ggml model file') + parser.add_argument('--use-f32', action='store_true', default=False, + help='Use f32 instead of f16 (default: f16)') + parser.add_argument('--out-name', type=str, default=None, + help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)') + + args = parser.parse_args() + + if not os.path.exists(args.model): + print(f"Error: {args.model} not found") + sys.exit(1) + + use_f16 = not args.use_f32 + convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name) diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt new file mode 100644 index 00000000000..5239ae0af5d --- /dev/null +++ b/models/requirements-parakeet.txt @@ -0,0 +1,3 @@ +torch +numpy +pyyaml diff --git a/scripts/upload-parakeet.py b/scripts/upload-parakeet.py new file mode 100644 index 00000000000..311a16a0239 --- /dev/null +++ b/scripts/upload-parakeet.py @@ -0,0 +1,75 @@ +import os +from huggingface_hub import HfApi, create_repo + +# TODO: change to ggml-org once merged. +USER_NAME = "danbev" +REPO_ID = f"{USER_NAME}/parakeet" +LOCAL_GGUF_PATH = "models/ggml-parakeet-tdt-0.6b-v3.bin" +REMOTE_GGUF_NAME = "parakeet-tdt-0.6b-v3.bin" + +MODEL_CARD_CONTENT = f"""--- +license: apache-2.0 +base_model: {USER_NAME}/parakeet +tags: +- gguf +--- + +# Parakeet Model Card + +## Description +This is an iterative release of the Parakeet model in whisper.cpp format. + +## Usage +You can use this file with [parakeet-cli](https://github.com/danbev/whisper.cpp/tree/parakeet-support/examples/parakeet-cli). + +Build parakeet-cli: +```console +$ git clone -b parakeet-support https://github.com/danbev/whisper.cpp.git +$ cd whisper.cpp +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +Download the model: +```console +$ hf download danbev/parakeet parakeet-tdt-0.6b-v3.bin --local-dir models +``` + +Run: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3.bin -f samples/jfk.wav +``` + +""" + +api = HfApi() + +def deploy_iteration(): + create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True) + + print("Updating Model Card...") + api.upload_file( + path_or_fileobj=MODEL_CARD_CONTENT.encode(), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + commit_message="Update README.md" + ) + + print(f"Uploading {REMOTE_GGUF_NAME}...") + api.upload_file( + path_or_fileobj=LOCAL_GGUF_PATH, + path_in_repo=REMOTE_GGUF_NAME, + repo_id=REPO_ID, + repo_type="model", + commit_message="Upload new parakeet iteration" + ) + + print(f"\nDeployment successful!") + print(f"URL: https://huggingface.co/{REPO_ID}") + +if __name__ == "__main__": + if os.path.exists(LOCAL_GGUF_PATH): + deploy_iteration() + else: + print(f"Error: {LOCAL_GGUF_PATH} not found.") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 095a2791de5..4e7c5b24dc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,23 +109,43 @@ add_library(whisper whisper.cpp ) +add_library(parakeet + ../include/parakeet.h + parakeet-arch.h + parakeet.cpp + ) + +target_include_directories(parakeet PUBLIC . 
../include) +target_compile_features (parakeet PUBLIC cxx_std_11) +target_link_libraries(parakeet PUBLIC ggml Threads::Threads) + # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} SOVERSION ${SOVERSION} ) +set_target_properties(parakeet PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + target_include_directories(whisper PUBLIC . ../include) target_compile_features (whisper PUBLIC cxx_std_11) # don't bump if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN) + set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN) endif() if (WHISPER_EXTRA_FLAGS) target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS}) endif() +if (PARAKEET_EXTRA_FLAGS) + target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS}) +endif() + find_package(Threads REQUIRED) target_link_libraries(whisper PUBLIC ggml Threads::Threads) @@ -144,4 +164,7 @@ endif() if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) + + set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD) endif() diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h new file mode 100644 index 00000000000..2f3668a84f7 --- /dev/null +++ b/src/parakeet-arch.h @@ -0,0 +1,191 @@ +#pragma once + +#include "ggml.h" + +#include + +enum parakeet_tensor { + // Encoder pre_encode + PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, + + // Encoder layers (per-layer) + PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, + PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, + PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_BIAS, + PARAKEET_TENSOR_ENC_CONV_BN_MEAN, + PARAKEET_TENSOR_ENC_CONV_BN_VAR, + PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, + PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, + PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, + PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, + + // Prediction network + PARAKEET_TENSOR_PRED_EMBED_WEIGHT, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, + + // Joint network + PARAKEET_TENSOR_JOINT_PRED_WEIGHT, + PARAKEET_TENSOR_JOINT_PRED_BIAS, + 
PARAKEET_TENSOR_JOINT_ENC_WEIGHT, + PARAKEET_TENSOR_JOINT_ENC_BIAS, + PARAKEET_TENSOR_JOINT_NET_WEIGHT, + PARAKEET_TENSOR_JOINT_NET_BIAS, +}; + +static const std::map PARAKEET_TENSOR_NAMES = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, + + // Encoder layers (use %d for layer number) + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"}, + 
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, "decoder.prediction.dec_rnn.lstm.bias_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, "decoder.prediction.dec_rnn.lstm.bias_hh_l%d"}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"}, +}; + +static const std::map PARAKEET_TENSOR_INFO = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, + + // Encoder layers + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, 
+ {PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, GGML_OP_ADD}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, GGML_OP_ADD}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD}, +}; diff --git a/src/parakeet.cpp b/src/parakeet.cpp new file mode 100644 index 00000000000..a593c47d9f6 --- /dev/null +++ b/src/parakeet.cpp @@ -0,0 +1,4134 @@ +#include "parakeet.h" +#include "parakeet-arch.h" + +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include +#include +#include +#include +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#include +#endif + +#if defined(PARAKEET_BIG_ENDIAN) +template +static T byteswap(T value) { + T value_swapped; + char * source = reinterpret_cast(&value); + char * target = reinterpret_cast(&value_swapped); + int size = sizeof(T); + for (int i = 0; i < size; i++) { + target[size - 1 - i] = source[i]; + } + return value_swapped; +} + +template +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +PARAKEET_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal (ggml_log_level level, const char * format, ...); +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define PARAKEET_DEBUG + +#if defined(PARAKEET_DEBUG) +#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define PARAKEET_LOG_DEBUG(...) 
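+// with PARAKEET_DEBUG undefined, the macro body is empty so debug logging
+// calls compile away and add no runtime cost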
+#endif + +#define PARAKEET_ASSERT(x) \ + do { \ + if (!(x)) { \ + PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#define PARAKEET_MAX_NODES 4096 + +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); + } + + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend.get(), n_threads); + } + + return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS; +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph, + int n_threads, + bool sched_reset = true) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + + auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) { + fn_set_n_threads(backend, n_threads); + } + } + + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + + return t; +} + +// TODO: move these functions to ggml-base with support for ggml-backend? + + +struct parakeet_mel { + int n_len = 0; + int n_len_org = 0; + int n_mel = 0; + + std::vector data; +}; + +struct parakeet_filters { + int32_t n_mel = 0; + int32_t n_fb = 0; // number of frequency bins + + std::vector data; +}; + +struct parakeet_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 8192; + size_t max_token_length = 0; + + std::map token_to_id; + std::map id_to_token; + + id token_unk; + id token_bos; + id token_blank; + id token_eos; +}; + +struct parakeet_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector tokens; +}; + +struct parakeet_batch { + int32_t n_tokens; + + parakeet_token * token; + int32_t * i_time; // index of the audio frame + parakeet_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + parakeet_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +// ggml_backend_sched wrapper for parakeet usage +struct parakeet_sched { + ggml_backend_sched_t sched = nullptr; + + std::vector meta; +}; + +// TODO: Find out is there a multiple version types. 
It is not yet clear to me +// at this point. +enum parakeet_arch { + PARAKEET_ARCH_UNKNOWN = 0, + PARAKEET_ARCH_TDT = 1, // NVIDIA Parakeet TDT (RNN-T) +}; + +struct parakeet_hparams { + int32_t n_vocab = 8192; + int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input + int32_t n_audio_state = 1024; + int32_t n_audio_head = 8; + int32_t n_audio_layer = 24; + int32_t n_mels = 128; + int32_t ftype = 1; + int32_t n_fft = 512; // FFT size for mel spectrogram + float eps = 1e-5f; + int32_t subsampling_factor = 8; + int32_t n_subsampling_channels = 256; + int32_t n_pred_dim = 640; + int32_t n_pred_layers = 2; + int32_t n_tdt_durations = 5; + int32_t n_max_tokens = 10; + + parakeet_arch arch = PARAKEET_ARCH_TDT; +}; + +struct parakeet_layer_encoder { + struct ggml_tensor * norm_ff1_w = nullptr; + struct ggml_tensor * norm_ff1_b = nullptr; + + struct ggml_tensor * ff1_linear1_w = nullptr; + struct ggml_tensor * ff1_linear2_w = nullptr; + + struct ggml_tensor * norm_conv_w = nullptr; + struct ggml_tensor * norm_conv_b = nullptr; + + struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1 + struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv + struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight + struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias + struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean + struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w = nullptr; + struct ggml_tensor * norm_attn_b = nullptr; + + struct ggml_tensor * attn_pos_bias_u = nullptr; + struct ggml_tensor * attn_pos_bias_v = nullptr; + struct ggml_tensor * attn_q_w = nullptr; + struct ggml_tensor * attn_k_w = nullptr; + struct ggml_tensor * attn_v_w = nullptr; + struct ggml_tensor * attn_out_w = nullptr; + struct ggml_tensor * attn_pos_w = nullptr; + + struct ggml_tensor * norm_ff2_w = nullptr; + struct ggml_tensor * norm_ff2_b = nullptr; + + struct ggml_tensor * ff2_linear1_w = nullptr; + struct ggml_tensor * ff2_linear2_w = nullptr; + + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; +}; + +struct parakeet_lsmt_layer { + struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight + struct ggml_tensor * ih_b = nullptr; // input-to-hidden bias + struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight + struct ggml_tensor * hh_b = nullptr; // hidden-to-hidden bias +}; + +struct parakeet_prediction_network { + struct ggml_tensor * embed_w = nullptr; + + std::vector lstm_layer; +}; + +struct parakeet_joint_network { + struct ggml_tensor * pred_w = nullptr; + struct ggml_tensor * pred_b = nullptr; + struct ggml_tensor * enc_w = nullptr; + struct ggml_tensor * enc_b = nullptr; + struct ggml_tensor * net_w = nullptr; + struct ggml_tensor * net_b = nullptr; +}; + +struct parakeet_model { + parakeet_filters filters; + parakeet_hparams hparams; + + struct ggml_tensor * enc_pre_out_w = nullptr; + struct ggml_tensor * enc_pre_out_b = nullptr; + struct ggml_tensor * enc_pre_conv_0_w = nullptr; + struct ggml_tensor * enc_pre_conv_0_b = nullptr; + struct ggml_tensor * enc_pre_conv_2_w = nullptr; + struct ggml_tensor * enc_pre_conv_2_b = nullptr; + struct ggml_tensor * enc_pre_conv_3_w = nullptr; + struct ggml_tensor * enc_pre_conv_3_b = nullptr; + struct ggml_tensor * enc_pre_conv_5_w = nullptr; + struct ggml_tensor * 
enc_pre_conv_5_b = nullptr; + struct ggml_tensor * enc_pre_conv_6_w = nullptr; + struct ggml_tensor * enc_pre_conv_6_b = nullptr; + + std::vector layers; + + parakeet_prediction_network prediction; + + parakeet_joint_network joint; + + std::vector tdt_durations; + + std::vector ctxs; + + std::vector buffers; + + int n_loaded = 0; + std::map tensors; +}; + +struct parakeet_lstm_state_layer { + struct ggml_tensor * h_state = nullptr; + struct ggml_tensor * c_state = nullptr; +}; + +struct parakeet_lstm_state { + std::vector layer; + + std::vector ctx_buf; + + ggml_backend_buffer_t buffer = nullptr; +}; + +struct tdt_stream_state { + parakeet_token last_token; // last emitted token (for predictor input) + int32_t time_step; // overflow frames into next chunk + int32_t decoded_length; // total tokens decoded + bool initialized; // whether prediction LSTM state has been initialized +}; + +struct parakeet_stream { + std::vector buffer; + int64_t n_samples_advanced = 0; + + int n_left_ctx = 0; + int n_chunk = 0; + int n_right_ctx = 0; + + parakeet_full_params params = {}; + bool initialized = false; +}; + +struct parakeet_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + parakeet_mel mel; + + parakeet_batch batch; + + int n_frames = 0; + + std::vector backends; + + parakeet_sched sched_conv; + parakeet_sched sched_encode; + parakeet_sched sched_decode; + + // outputs from encoder stages + struct ggml_tensor * pre_enc_out = nullptr; + struct ggml_tensor * enc_out = nullptr; + struct ggml_tensor * pred_out = nullptr; + + std::vector pred_out_buf; + ggml_backend_buffer_t pred_out_buffer = nullptr; + + struct ggml_tensor * attn_mask = nullptr; + + std::vector inp_mel; + std::vector inp_mask; + + std::vector enc_out_buffer; + int enc_out_frames = 0; + + std::vector logits; + + std::vector result_all; + + std::vector decoded_tokens; + std::vector decoded_token_data; + + std::string path_model; + + int32_t n_audio_ctx = 0; + + parakeet_lstm_state lstm_state; + + struct tdt_stream_state tdt_stream_state = {0, 0, 0, false}; + + parakeet_stream stream; +}; + +// FFT cache for mel spectrogram computation +struct parakeet_mel_cache { + int n_fft = 0; + + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. 
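+    // Illustrative use of the tables (a sketch, not the actual FFT code):
+    // a DFT bin k over one frame can be accumulated as
+    //   re += frame[n] * cos_vals[(k * n) % n_fft];
+    //   im -= frame[n] * sin_vals[(k * n) % n_fft];
+    // since exp(-2*pi*i*k*n/N) is periodic in (k*n) mod N.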
+ std::vector sin_vals; + std::vector cos_vals; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector hann_window; + + // Window function from model (Parakeet uses actual window from training) + std::vector window; + + void init(int fft_size) { + n_fft = fft_size; + sin_vals.resize(n_fft); + cos_vals.resize(n_fft); + hann_window.resize(n_fft); + + fill_sin_cos_table(); + fill_hann_window(n_fft, true, hann_window.data()); + } + + void fill_sin_cos_table() { + for (int i = 0; i < n_fft; i++) { + double theta = (2 * M_PI * i) / n_fft; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +}; + +struct parakeet_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; + ggml_type itype = ggml_type::GGML_TYPE_F16; + + parakeet_context_params params; + + parakeet_model model; + parakeet_vocab vocab; + + parakeet_state * state = nullptr; + + parakeet_mel_cache mel_cache; + + std::string path_model; +}; + +struct parakeet_global { + // We save the log callback globally + ggml_log_callback log_callback = parakeet_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static parakeet_global g_state; + +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "" || piece == "" || piece == "" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = 
(parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.i_time) free(batch.i_time); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector backends, std::function && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + // since there are dependencies between the different graphs, + // we need to allocate them instead of only reserving to get the correct compute buffer size + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + // failed to allocate the compute buffer + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + + +template +static void read_safe(parakeet_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool parakeet_lstm_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_layer) { + parakeet_lstm_state & lstm_state = pstate.lstm_state; + + lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2); + lstm_state.layer.resize(n_layer); + + struct ggml_init_params params = { + /*.mem_size =*/ lstm_state.ctx_buf.size(), + /*.mem_buffer =*/ lstm_state.ctx_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__); + return false; + } + + + for (int il = 0; il < n_layer; ++il) { + lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 640); + lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 640); + } + + lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!lstm_state.buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__); + return false; 
+ } + + ggml_backend_buffer_clear(lstm_state.buffer, 0); + + ggml_free(ctx); + + return true; +} + +static bool parakeet_pred_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_pred_dim) { + pstate.pred_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.pred_out_buf.size(), + /*.mem_buffer =*/ pstate.pred_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__); + return false; + } + + pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.pred_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static ggml_backend_t whisper_backend_init_gpu(const parakeet_context_params & params) { + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + + ggml_backend_dev_t dev = nullptr; + + int cnt = 0; + if (params.use_gpu) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur); + const char * dev_name = ggml_backend_dev_name(dev_cur); + PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) { + PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt); + if (cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + if (dev == nullptr) { + PARAKEET_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; +} + +static std::vector whisper_backend_init(const parakeet_context_params & params) { + std::vector result; + + ggml_backend_t backend_gpu = whisper_backend_init_gpu(params); + + if (backend_gpu) { + result.push_back(backend_gpu); + } + + // ACCEL backends + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + continue; + } + result.push_back(backend); + } + } + + ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + result.push_back(backend_cpu); + + return result; +} + +using buft_list_t = std::vector>; + +static buft_list_t make_buft_list(parakeet_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; + + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + 
ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { + if (cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } else if (op == GGML_OP_GET_ROWS) { + int64_t num_indices = 8; + ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + +// load the model from a ggml file +// + +// see the convert-parakeet-to-ggml.py script for details +// +static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) { + PARAKEET_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = 
ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + parakeet_hparams hparams; + { + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_fft); + read_safe(loader, hparams.subsampling_factor); + read_safe(loader, hparams.n_subsampling_channels); + read_safe(loader, hparams.n_pred_dim); + read_safe(loader, hparams.n_pred_layers); + read_safe(loader, hparams.n_tdt_durations); + read_safe(loader, hparams.n_max_tokens); + + hparams.arch = PARAKEET_ARCH_TDT; + wctx.model.hparams = hparams; + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype); + if (wctx.wtype == GGML_TYPE_COUNT) { + PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype); + return false; + } + + const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown"; + PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); + PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); + PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps); + PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); + PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); + PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); + PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); + PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); + PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); + PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fb); + + filters.data.resize(filters.n_mel * filters.n_fb); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load window function + { + int32_t n_window = 0; + read_safe(loader, n_window); + + wctx.mel_cache.window.resize(n_window); + loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float)); + +#ifdef GGML_BIG_ENDIAN + for (auto & datum : wctx.mel_cache.window) { + datum = byteswap(datum); + } +#endif 
+ + PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window); + } + + // load TDT (Token and Duration Transducer) values + { + auto & tdt_durations = wctx.model.tdt_durations; + tdt_durations.resize(hparams.n_tdt_durations); + loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + + PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__); + for (const auto value : tdt_durations) { + PARAKEET_LOG_INFO("%u ", value); + } + PARAKEET_LOG_INFO("]\n"); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + std::string word; + std::vector tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + vocab.max_token_length = std::max(vocab.max_token_length, word.size()); + } + // Blank token for transducer is at index n_vocab (8192), outside the vocabulary + int blank_id = n_vocab; + vocab.token_blank = blank_id; + vocab.id_to_token[blank_id] = "[BLANK]"; + vocab.token_to_id["[BLANK]"] = blank_id; + + // Set special token IDs by looking them up in the loaded vocabulary + // These are from the SentencePiece vocab file loaded above + if (vocab.token_to_id.find("") != vocab.token_to_id.end()) { + vocab.token_unk = vocab.token_to_id.at(""); + } else { + vocab.token_unk = 0; // Fallback + } + + if (vocab.token_to_id.find("") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at(""); + } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>"); + } else { + vocab.token_bos = 0; // Fallback + } + + if (vocab.token_to_id.find("") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at(""); + } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("<|endoftext|>"); + } else { + vocab.token_eos = 0; // Fallback + } + + vocab.n_vocab = model.hparams.n_vocab; + + PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n", + __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos); + } + + const ggml_type wtype = wctx.wtype; + const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? 
GGML_TYPE_F32 : GGML_TYPE_F16; // conv type + + const int n_audio_layer = hparams.n_audio_layer; + + // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6) + size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6; + + std::map ctx_map; + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ctx_map[buft] = ctx; + wctx.model.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // Create a list of available bufts, in priority order + buft_list_t buft_list = make_buft_list(wctx.params); + + auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * { + ggml_op op = PARAKEET_TENSOR_INFO.at(type); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", + PARAKEET_TENSOR_NAMES.at(type))); + } + + ggml_context * ctx = get_ctx(buft); + ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + + std::string tensor_name; + if (layer >= 0) { + tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer); + } else { + tensor_name = PARAKEET_TENSOR_NAMES.at(type); + } + + wctx.model.tensors[tensor_name] = tensor; + + return tensor; + }; + + // prepare tensors for the weights + + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + + const int n_audio_state = hparams.n_audio_state; + + model.layers.resize(n_audio_layer); + + // Encoder pre_encode + const int n_subsampling_channels = hparams.n_subsampling_channels; + model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4096, n_audio_state)); + ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w"); + model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); + ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b"); + + model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w"); + model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b"); + + model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w"); + model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b"); + + model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, wtype, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w"); + model.enc_pre_conv_3_b = 
create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b"); + + model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w"); + model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b"); + + model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, wtype, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w"); + model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b"); + + // Encoder layers + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers[i]; + + // Feed forward 1 + layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i); + layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i); + + // Convolution module + layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i); + layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i); + layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i); + ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i); + layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 9, n_audio_state), i); + ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i); + layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i); + layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i); + layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i); + layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i); + layer.conv_pw2_w = 
create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i); + + // Self attention + layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 8), i); + layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 8), i); + layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i); + + // Feed forward 2 + layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + + // Output norm + layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + } + + // Prediction network (decoder) + const int dec_hidden = hparams.n_pred_dim; + model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 8193)); + model.prediction.lstm_layer.resize(hparams.n_pred_layers); + for (int i = 0; i < hparams.n_pred_layers; ++i) { + auto & layer = model.prediction.lstm_layer[i]; + layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 2560), i); + ggml_format_name(layer.ih_w, "pred_%d_ih_w", i); + + layer.ih_b = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_IH, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2560), i); + ggml_format_name(layer.ih_b, "pred_%d_ih_b", i); + + layer.hh_b = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_HH, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2560), i); + ggml_format_name(layer.hh_b, "pred_%d_hh_b", i); + + layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 2560), i); + ggml_format_name(layer.hh_w, "pred_%d_hh_w", i); + } + + // Joint network + model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, wtype, dec_hidden, dec_hidden)); + 
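+    // Note on the joint sizes used here and below: the output dimension 8198
+    // appears to be n_vocab (8192) + 1 blank + n_tdt_durations (5), i.e. the
+    // first 8193 logits score tokens/blank and the remaining 5 score TDT
+    // durations.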
ggml_set_name(model.joint.pred_w, "pred_w"); + model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.pred_b, "pred_b"); + model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden)); + ggml_set_name(model.joint.enc_w, "enc_w"); + model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.enc_b, "enc_b"); + model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, wtype, dec_hidden, 8198)); + ggml_set_name(model.joint.net_w, "net_w"); + model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8198)); + ggml_set_name(model.joint.net_b, "net_b"); + + ggml_free(ctx); + + // allocate tensors in the backend buffers + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + wctx.model.buffers.emplace_back(buf); + + size_t size_main = ggml_backend_buffer_get_size(buf); + PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } + + // load weights + { + size_t total_size = 0; + + auto & tensors_map = wctx.model.tensors; + int & n_loaded = wctx.model.n_loaded; + + n_loaded = 0; + + std::vector read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (tensors_map.find(name) == tensors_map.end()) { + PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = tensors_map[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + if (ggml_backend_buffer_is_host(tensor->buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + 
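+                // on big-endian hosts BYTESWAP_TENSOR swaps the freshly read
+                // values in place; on little-endian builds it is a no-op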
BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + n_loaded++; + } + + PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (n_loaded == 0) { + PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (n_loaded != (int) tensors_map.size()) { + PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded); + return false; + } + } + + auto & buffers = wctx.model.buffers; + for (auto & buf : buffers) { + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// +// This function builds the computation graph for the pre-encoder stage. +// +// It takes the generated mel spectrogram as an input tensor, which will be set +// during processing, and performs a number of convolutions to detect/filter +// features, reduces the time dimension by a factor of 8, and finally projects +// this information into the models abstract feature space. +// +static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, parakeet_state & pstate) { + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx; + const int n_mels = hparams.n_mels; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_conv.meta.size(), + /*.mem_buffer =*/ pstate.sched_conv.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph(ctx0); + + // [freq, time] + struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_time, 1, 1); + ggml_set_name(mel, "mel"); + ggml_set_input(mel); + + // [freq, time, channels, batch] + struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b); + ggml_set_name(cur, "pre_conv_0"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_0_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b); + ggml_set_name(cur, "pre_conv_2"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b); + ggml_set_name(cur, "pre_conv_3"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_3_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1); + ggml_set_name(cur, "pre_conv_5_direct"); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b); + ggml_set_name(cur, "pre_conv_5"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b); + ggml_set_name(cur, "pre_conv_6"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_6_relu"); + + // [freq, time, chan] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // [freq, chan, time] + cur = ggml_cont(ctx0, 
cur); + + const int n_freq = cur->ne[0]; // 16 + const int n_chan = cur->ne[1]; // 256 + const int n_frames = cur->ne[2]; // time + + // [freq, time, chan, batch] -> [(freq * chan), time] + cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames); + + cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur); + cur = ggml_add(ctx0, cur, model.enc_pre_out_b); + + ggml_set_name(cur, "pre_enc_out"); + ggml_set_output(cur); + pstate.pre_enc_out = cur; + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +// This function builds the graph for the encoder part of the model. +// +// It takes the output of the pre-encoder convolution above which will have the +// shape of [hidden_dim, time_frames] +static struct ggml_cgraph * parakeet_build_graph_encoder(parakeet_context & pctx, parakeet_state & pstate) { + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_layer = hparams.n_audio_layer; + const int n_state = hparams.n_audio_state; + const float fc_factor = 0.5f; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_encode.meta.size(), + /*.mem_buffer =*/ pstate.sched_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Create a view of the output produced by parakeet_build_graph_conv + // [feat, time_frames, 1, 1] + struct ggml_tensor * cur = ggml_view_tensor(ctx0, pstate.pre_enc_out); + ggml_set_name(cur, "encoder_inp"); + + // [time_frames, time_frames, 1, 1]] + struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, cur->ne[1], cur->ne[1]); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + const int n_time = cur->ne[1]; + const int window_size = 2 * n_time - 1; + const int d_half = n_state / 2; + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + // [n_state, window_size] + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + + for (int il = 0; il < n_layer; ++il) { + // FFN1 + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_res", il); + + // norm + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_ff1_w), model.layers[il].norm_ff1_b); + ggml_format_name(cur, "enc_%d_ffn_norm_1", il); + + // ffn_1 + cur = ggml_mul_mat(ctx0, model.layers[il].ff1_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_silu", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ff1_linear2_w, cur); + ggml_format_name(cur, "enc_%d_ffn_1", il); + + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + ggml_format_name(cur, "enc_%d_res_ffn", il); + } + + // self attention block using relative positional 
encoding computed in the graph.
+        {
+            // [feat, time_frames, 1, 1]
+            struct ggml_tensor * residual = cur;
+
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_attn_w), model.layers[il].norm_attn_b);
+            ggml_format_name(cur, "enc_%d_attn_norm", il);
+
+            const int n_head = hparams.n_audio_head;
+            const int d_head = n_state / n_head;
+            const int n_time = cur->ne[1];
+
+            // [feat, time_frames, 1, 1]
+            struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, model.layers[il].attn_q_w, cur);
+            struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, model.layers[il].attn_k_w, cur);
+            struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, model.layers[il].attn_v_w, cur);
+
+            // [d_head, n_heads, time_frames, 1]
+            Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time);
+            K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time);
+            V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time);
+
+            // [n_state, window_size]
+            struct ggml_tensor * pos = ggml_mul_mat(ctx0, model.layers[il].attn_pos_w, pos_emb);
+            ggml_format_name(pos, "enc_%d_attn_pos", il);
+
+            // Add the content bias to Q. It acts as a baseline query for content:
+            // regardless of its specific position, every query always looks for
+            // these base features. So we shift Q by the content bias.
+            // [feat, head, time_frames, batch]
+            struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, model.layers[il].attn_pos_bias_u);
+            ggml_format_name(Q_u, "enc_%d_attn_q_u", il);
+
+            // [feat, time_frames, head, 1]
+            struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3);
+            // [feat, time_frames, head, 1]
+            struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3);
+            // [time_frames, time_frames, head, 1]
+            struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep);
+            ggml_format_name(content_scores, "enc_%d_attn_content_scores", il);
+
+            // Add the position bias to Q. It acts as a baseline for positions:
+            // regardless of its specific content, every query is always interested
+            // in tokens a certain distance away. So we shift Q by the position bias.
+            // [feat, head, time_frames, batch]
+            struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, model.layers[il].attn_pos_bias_v);
+            ggml_format_name(Q_v, "enc_%d_attn_q_v", il);
+
+            // pos is [feat, window_size, 1, 1] and we are doing multi-head attention,
+            // so we need to split it into heads.
+ // [feat, head, window_size, 1] + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size); + + // [feat, window_size, head, 1] + pos = ggml_permute(ctx0, pos, 0, 2, 1, 3); + pos = ggml_cont(ctx0, pos); + ggml_format_name(pos, "enc_%d_attn_pos_perm", il); + // [feat, time, head, 1] + Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3); + Q_v = ggml_cont(ctx0, Q_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il); + + // [window_size, time_frames, head, 1] + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il); + + // Relative positional shift + { + + const auto pos_window = rel_pos_scores->ne[0]; + const auto n_frame = rel_pos_scores->ne[1]; + const auto n_head = rel_pos_scores->ne[2]; + + // [feat_padded, window_size, head, 1] + rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0); + rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0); + + rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il); + + int center = pos_window / 2; + size_t offset = rel_pos_scores->nb[0] * (center+1); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + n_frame, pos_window, n_head, + (pos_window) * 4, + rel_pos_scores->nb[2], + offset); + + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + content_scores->ne[0], + content_scores->ne[1], + rel_pos_scores->ne[2], + rel_pos_scores->nb[1], + rel_pos_scores->nb[2], + 0); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il); + } + + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + attn_scores = ggml_cont(ctx0, attn_scores); + ggml_format_name(attn_scores, "enc_%d_attn_scores", il); + attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head)); + attn_scores = ggml_add(ctx0, attn_scores, attn_mask); + ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores); + ggml_format_name(probs, "enc_%d_attn_probs", il); + + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3)); + ggml_format_name(V_cur, "enc_%d_attn_v_cur", il); + cur = ggml_mul_mat(ctx0, probs, V_cur); + ggml_format_name(cur, "enc_%d_attn_inp", il); + + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, model.layers[il].attn_out_w, cur); + ggml_format_name(cur, "enc_%d_attn_out", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_attn_res", il); + } + + // Convolution + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_residual_conv", il); + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_conv_w), model.layers[il].norm_conv_b); + ggml_format_name(cur, "enc_%d_norm_conv", il); + + // pointwise 1d convolution: [1024, 138] -> [2048, 138] + cur = ggml_mul_mat(ctx0, model.layers[il].conv_pw1_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw1", il); + + { + int64_t d = cur->ne[0] / 2; + struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0); + struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]); 
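+                // GLU (gated linear unit): the first half of the channels is the
+                // signal, the second half is the gate: out = signal * sigmoid(gate).
+                // This halves the channel count back down to n_state.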
+ + cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate)); + ggml_format_name(cur, "enc_%d_conv_glu", il); + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // use ggml_ssm_conv for f32 precision + cur = ggml_pad(ctx0, cur, 4, 0, 0, 0); + cur = ggml_roll(ctx0, cur, 4, 0, 0, 0); + cur = ggml_pad(ctx0, cur, 4, 0, 0, 0); + ggml_format_name(cur, "enc_%d_conv_dw_pad", il); + + cur = ggml_ssm_conv(ctx0, cur, model.layers[il].conv_dw_w); + ggml_format_name(cur, "enc_%d_conv_1d_dw", il); + + cur = ggml_sub(ctx0, cur, model.layers[il].conv_bn_mean); + struct ggml_tensor * std = ggml_sqrt(ctx0, model.layers[il].conv_bn_var); + cur = ggml_div(ctx0, cur, std); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].conv_bn_w), model.layers[il].conv_bn_b); + ggml_format_name(cur, "enc_%d_conv_bn", il); + + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_conv_silu", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].conv_pw2_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw2", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_conv_res", il); + } + + // FFN2 + { + struct ggml_tensor * residual = cur; + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_ff2_w), model.layers[il].norm_ff2_b); + ggml_format_name(cur, "enc_%d_ffn_norm_2", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ff2_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.layers[il].ff2_linear2_w, cur); + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5)); + ggml_format_name(cur, "enc_%d_ffn_res", il); + } + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].norm_out_w), model.layers[il].norm_out_b); + } + + ggml_set_name(cur, "encoder_out"); + ggml_set_output(cur); + pstate.enc_out = cur; + pstate.n_frames = cur->ne[1]; + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_encode_internal( + parakeet_context & pctx, + parakeet_state & pstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + // conv + { + auto & sched = pstate.sched_conv.sched; + + ggml_cgraph * gf = parakeet_build_graph_conv(pctx, pstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + + // set the input + { + const auto & mel_inp = pstate.mel; + const int n_ctx = pstate.n_audio_ctx > 0 ? 
pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx;
+
+            assert(mel->type == GGML_TYPE_F32);
+            assert(mel_inp.n_mel == pctx.model.hparams.n_mels);
+
+            pstate.inp_mel.resize(ggml_nelements(mel));
+
+            float * dst = pstate.inp_mel.data();
+            memset(dst, 0, ggml_nbytes(mel));
+
+            const int i0 = std::min(mel_offset, mel_inp.n_len);
+            const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len);
+
+            const int n_frames_to_copy = i1 - i0;
+            memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, n_frames_to_copy * mel_inp.n_mel * sizeof(float));
+
+            ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
+        }
+
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
+            return false;
+        }
+    }
+
+    // encoder
+    auto & sched = pstate.sched_encode.sched;
+
+    ggml_cgraph * gf = parakeet_build_graph_encoder(pctx, pstate);
+
+    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+        // should never happen as we pre-allocate the memory
+        return false;
+    }
+
+    // set the inputs
+    {
+        struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask");
+        const int n_q = attn_mask->ne[1];
+        const int n_k = attn_mask->ne[0];
+
+        const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor;
+        const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor;
+
+        std::vector<float> mask_data(n_q * n_k);
+        const float mask_value = -1e30f;
+
+        for (int q = 0; q < n_q; ++q) {
+            for (int k = 0; k < n_k; ++k) {
+                bool is_padding = (k >= n_tokens_real);
+
+                mask_data[q * n_k + k] = (is_padding) ? mask_value : 0.0f;
+            }
+        }
+        ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float));
+    }
+
+    {
+        struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs");
+        const int d_half = pos_freqs_t->ne[0];
+        const int n_state = pctx.model.hparams.n_audio_state;
+        const float log_10000 = logf(10000.0f);
+        std::vector<float> freqs(d_half);
+        for (int k = 0; k < d_half; ++k) {
+            freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state)));
+        }
+        ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float));
+    }
+
+    {
+        struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions");
+        const int window_size = rel_pos_t->ne[1];
+        const int n_time = (window_size + 1) / 2;
+        std::vector<float> pos(window_size);
+        for (int t = 0; t < window_size; ++t) {
+            pos[t] = float(n_time - 1 - t);
+        }
+        ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float));
+    }
+
+    if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
+        return false;
+    }
+
+    pstate.t_encode_us += ggml_time_us() - t_start_us;
+    pstate.n_encode++;
+
+    return !(abort_callback && abort_callback(abort_callback_data));
+}
+
+static struct ggml_tensor * parakeet_build_graph_lstm_layer(
+        struct ggml_context * ctx0,
+        struct ggml_cgraph * gf,
+        struct ggml_tensor * x_t,     // the current input token embedding
+        struct ggml_tensor * w_ih,    // input-to-hidden weights (4 gate weight tensors packed)
+        struct ggml_tensor * b_ih,    // input-to-hidden bias (4 gate biases packed)
+        struct ggml_tensor * w_hh,    // hidden-to-hidden weights (4 gate weight tensors packed)
+        struct ggml_tensor * b_hh,    // hidden-to-hidden bias (4 gate biases packed)
+        struct ggml_tensor * h_state, // this layer's hidden state
+        struct ggml_tensor * c_state, // this layer's cell state
+        int li) {                     // layer index (for tensor naming)
+
+    ggml_format_name(x_t, "lstm_layer_%d_x_t", li);
+    ggml_format_name(h_state, "lstm_layer_%d_h_state", li);
+    ggml_format_name(c_state, "lstm_layer_%d_c_state", li);
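+    // Standard LSTM cell, for reference while reading the gate math below:
+    //   i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)   input gate
+    //   f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)   forget gate
+    //   g_t = tanh   (W_g x_t + U_g h_{t-1} + b_g)   cell gate
+    //   o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)   output gate
+    //   c_t = f_t * c_{t-1} + i_t * g_t
+    //   h_t = o_t * tanh(c_t)
+    // The four W (and U) matrices are packed row-wise in w_ih (and w_hh), so a
+    // single matmul computes all four pre-activations at once; the gates are
+    // then split out with 1d views at offsets of h_dim rows.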
+ + // Input/Forget/Cell/Output gates are packed in the same weight tensor. + struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t); + inp_gates = ggml_add(ctx0, inp_gates, b_ih); + + // Hidden-to-Hidden Projections are also packed in the same weight tensor. + struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state); + hid_gates = ggml_add(ctx0, hid_gates, b_hh); + + // Combine the input and hidden contributions of the gates. + struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates); + ggml_format_name(gates, "lstm_layer_%d_gates", li); + + const int h_dim = h_state->ne[0]; + const size_t row_size = ggml_row_size(gates->type, h_dim); + + // 1. Input Gate at time t. + struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, h_dim, 0 * row_size)); + ggml_format_name(i_t, "lstm_layer_%d_i_t", li); + + // Forget gate. + struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, h_dim, 1 * row_size)); + ggml_format_name(f_t, "lstm_layer_%d_f_t", li); + + // Cell gate. + struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 2 * row_size)); + ggml_format_name(c_t, "lstm_layer_%d_c_t", li); + + // Output gate. + struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size)); + ggml_format_name(o_t, "lstm_layer_%d_o_t", li); + + // Calculate the new cell state. + struct ggml_tensor * c_new = ggml_add(ctx0, + ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state. + ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state)); + + // Calculate the new hidden state. + struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new)); + ggml_set_output(h_new); + ggml_format_name(h_new, "lstm_layer_%d_h_new", li); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state)); + + return h_new; +} + +static struct ggml_cgraph * parakeet_build_graph_prediction( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Prediction Network + struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(token, "token_inp"); + ggml_set_input(token); + + struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token); + ggml_set_input(token_embd); + + struct ggml_tensor * inpL = token_embd; + + for (int il = 0; il < hparams.n_pred_layers; ++il) { + inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL, + model.prediction.lstm_layer[il].ih_w, + model.prediction.lstm_layer[il].ih_b, + model.prediction.lstm_layer[il].hh_w, + model.prediction.lstm_layer[il].hh_b, + pstate.lstm_state.layer[il].h_state, + pstate.lstm_state.layer[il].c_state, + il); + } + + struct ggml_tensor * pred_out = inpL; + ggml_format_name(pred_out, "lstm_pred_out"); + + // Project the prediction network output to the joint network hidden dimension. 
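+    // (In the joint network this projected prediction vector is added to the
+    // projected encoder frame and passed through a ReLU before being mapped to
+    // vocabulary + duration logits; see parakeet_build_graph_joint below.)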
+ struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out); + pred = ggml_add(ctx0, pred, model.joint.pred_b); + ggml_set_output(pred); + ggml_set_name(pred, "h_pred"); + + ggml_build_forward_expand(gf, pred); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * parakeet_build_graph_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + GGML_UNUSED(batch); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + struct ggml_tensor * pred = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_pred_dim); + ggml_set_name(pred, "pred"); + ggml_set_input(pred); + + struct ggml_tensor * enc_out = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_audio_state); + ggml_set_name(enc_out, "enc_out"); + ggml_set_input(enc_out); + + // Project the encoder output to the joint network hidden dimension. + struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out); + enc = ggml_add(ctx0, enc, model.joint.enc_b); + ggml_set_name(enc, "enc"); + + struct ggml_tensor * joint = ggml_add(ctx0, enc, pred); + ggml_set_name(joint, "joint"); + joint = ggml_relu(ctx0, joint); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint); + logits = ggml_add(ctx0, logits, model.joint.net_b); + ggml_set_output(logits); + ggml_set_name(logits, "logits"); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, logits); + struct ggml_tensor * log_probs = ggml_log(ctx0, probs); + ggml_set_output(log_probs); + ggml_format_name(log_probs, "log_probs"); + + ggml_build_forward_expand(gf, log_probs); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_predict( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + + const int n_tokens = batch.n_tokens; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set the inputs + { + struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp"); + ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp)); + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + // Copy h_pred output to pstate.pred_out for use in joint network + { + struct ggml_tensor * h_pred = ggml_graph_get_tensor(gf, "h_pred"); + ggml_backend_tensor_copy(h_pred, pstate.pred_out); + } + + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + auto & logits_out = pstate.logits; + + struct ggml_tensor * logits; + + { + auto & sched = 
pstate.sched_decode.sched;
+
+        ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false);
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            // should never happen as we pre-allocate the memory
+            return false;
+        }
+
+        // set the inputs
+        {
+            struct ggml_tensor * pred = ggml_graph_get_tensor(gf, "pred");
+            if (n_tokens == 1) {
+                ggml_backend_tensor_copy(pstate.pred_out, pred);
+            }
+
+            struct ggml_tensor * enc_out_inp = ggml_graph_get_tensor(gf, "enc_out");
+            if (n_tokens == 1) {
+                const int t_idx = batch.i_time[0];
+                const int d_enc = hparams.n_audio_state;
+
+                std::vector<float> frame_buf(d_enc);
+
+                // If we have accumulated encoder output (chunked processing), use that
+                if (!pstate.enc_out_buffer.empty() && t_idx < pstate.enc_out_frames) {
+                    const float * src = pstate.enc_out_buffer.data() + (t_idx * d_enc);
+                    std::copy(src, src + d_enc, frame_buf.begin());
+                } else {
+                    ggml_backend_tensor_get(pstate.enc_out, frame_buf.data(), t_idx * d_enc * sizeof(float), d_enc * sizeof(float));
+                }
+
+                ggml_backend_tensor_set(enc_out_inp, frame_buf.data(), 0, d_enc * sizeof(float));
+            }
+        }
+
+        logits = ggml_graph_node(gf, -1);
+
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
+            return false;
+        }
+
+    }
+
+    const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // plus one for the blank token
+    logits_out.resize(n_tokens * n_logits);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits);
+    }
+
+    if (batch.n_tokens == 1) {
+        pstate.t_decode_us += ggml_time_us() - t_start_us;
+        pstate.n_decode++;
+    }
+
+    return !(abort_callback && abort_callback(abort_callback_data));
+}
+
+static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) {
+    const std::string & token_str = vocab.id_to_token[token_id];
+    // check if it starts with the SentencePiece meta-space "▁" (U+2581), i.e. the 3-byte UTF-8 sequence 0xE2 0x96 0x81
+    if (!token_str.empty()) {
+        if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') {
+            return true;
+        }
+    }
+    return false;
+}
+
+static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) {
+    const std::string & token_str = vocab.id_to_token[token_id];
+    static const std::string punct_chars = ".,!?;:'\"-()[]{}";
+
+    if (token_str.empty()) {
+        return false;
+    }
+
+    std::string clean_token = token_str;
+    if (clean_token.find("\xE2\x96\x81") == 0) {
+        clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 meta-space character
+    } else if (clean_token[0] == '_') {
+        clean_token = clean_token.substr(1);
+    }
+
+    return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos;
+}
+
+// Collapse punctuation timestamps to match the original Parakeet model.
+// Punctuation symbols like ',', '.' and others are not spoken words, but the
+// model will still produce a duration for these tokens. Since they are
+// non-spoken we collapse the timestamps so that they don't have a time duration.
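+// For example, with "hello ," the ',' token ends up with t0 == t1 == the end
+// time of "hello" instead of a duration of its own.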
+static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) {
+    if (tokens.empty()) {
+        return;
+    }
+
+    int64_t last_non_punct_t1 = -1;
+
+    for (size_t i = 0; i < tokens.size(); ++i) {
+        if (is_punctuation_token(vocab, tokens[i].id)) {
+            if (last_non_punct_t1 >= 0) {
+                tokens[i].t0 = last_non_punct_t1;
+                tokens[i].t1 = last_non_punct_t1;
+            }
+        } else {
+            last_non_punct_t1 = tokens[i].t1;
+        }
+    }
+}
+
+static parakeet_token_data create_token_data(
+        parakeet_context & pctx,
+        parakeet_state & pstate,
+        parakeet_token token_id,
+        int duration_idx,
+        int duration_value,
+        int frame_index,
+        float token_logit,
+        int n_vocab_logits) {
+
+    float token_sum = 0.0f;
+    for (int i = 0; i < n_vocab_logits; ++i) {
+        token_sum += expf(pstate.logits[i]);
+    }
+    float token_p = expf(token_logit) / token_sum;
+
+    parakeet_token_data token_data;
+    token_data.id = token_id;
+    token_data.duration_idx = duration_idx;
+    token_data.duration_value = duration_value;
+    token_data.frame_index = frame_index;
+    token_data.p = token_p;
+    token_data.plog = token_logit;
+    token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor;
+    token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor;
+    token_data.is_word_start = is_word_start_token(pctx.vocab, token_id);
+
+    return token_data;
+}
+
+static bool parakeet_decode(
+        parakeet_context & pctx,
+        parakeet_state & pstate,
+        parakeet_batch & batch,
+        const int n_threads,
+        const parakeet_full_params * params = nullptr) {
+    const auto & hparams = pctx.model.hparams;
+    const auto & tdt_durations = pctx.model.tdt_durations;
+
+    const int n_tdt_durations = hparams.n_tdt_durations;
+    const int n_frames = pstate.n_frames;
+    const int blank_id = pctx.vocab.token_blank;
+    const int n_vocab_logits = blank_id + 1;
+    const int max_tokens_per_timestep = hparams.n_max_tokens;
+
+    // time index into the encoder frames (current time frame)
+    int t = 0;
+    // number of symbols emitted for the current time frame
+    int tokens_emitted = 0;
+
+    // Start with the blank token (8192)
+    parakeet_token last_token = blank_id;
+
+    PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames);
+
+    batch.n_tokens = 1;
+    batch.token[0] = last_token;
+    batch.logits[0] = 1;
+    batch.i_time[0] = 0;
+
+    // run the prediction network for the initial blank token. This will
+    // initialize the LSTM state and produce an initial hidden state that can
+    // be used in the joint network below.
+    if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) {
+        return false;
+    }
+
+    // process all time frames of the encoder output
+    while (t < n_frames) {
+        batch.n_tokens = 1;
+        batch.i_time[0] = t;
+        batch.logits[0] = 1;
+
+        // Use the current encoder frame (t) and the output of the prediction to
+        // generate probabilities for the next token and duration. batch.i_time
+        // is used to select the correct frame from the encoder output.
+        // The joint network outputs logits for all the tokens in the vocabulary
+        // plus the blank token, and also n_duration logits for the duration
+        // tokens which contain information about how many frames to skip/advance forward.
+        if (!parakeet_joint(pctx, pstate, batch, n_threads, nullptr, nullptr)) {
+            return false;
+        }
+
+        // find the best token (greedy).
+        // TODO: implement beam search?
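+        // Greedy TDT step: take the argmax over the token logits (vocabulary +
+        // blank) and, separately, the argmax over the duration logits. The
+        // duration head decides how many encoder frames to advance, which is
+        // what lets TDT skip frames instead of always advancing one frame per
+        // blank as in classic RNN-T decoding.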
+ int best_token = 0; + float max_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > max_logit) { + max_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < n_tdt_durations; ++i) { + if (pstate.logits[n_vocab_logits + i] > best_duration_logit) { + best_duration_logit = pstate.logits[n_vocab_logits + i]; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + // skip forward by duration time frames. + t += duration; + // reset symbols emitted counter + tokens_emitted = 0; + // continue without predicting. + continue; + } + + // Emit non-blank token at current frame t. + pstate.decoded_tokens.push_back(best_token); + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, best_token, best_duration_idx, duration, t, + max_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // Call token callback if registered (for real-time streaming) + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + last_token = best_token; + + // advance predictor for the non-blank token. + batch.token[0] = last_token; + if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + // if duration greater than 0, continue looping over the encoder frames + // and skip to the updated time frame (t). + if (duration > 0) { + t += duration; + tokens_emitted = 0; + continue; + } + + // if duration is zero we stay on the current time frame. + tokens_emitted++; + if (tokens_emitted >= max_tokens_per_timestep) { + t += 1; // forced blank/time advance behavior + tokens_emitted = 0; + } + } + + return true; +} + +static bool parakeet_decode_chunk( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int chunk_enc_frames, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = pstate.tdt_stream_state.time_step; + parakeet_token last_token = pstate.tdt_stream_state.last_token; + + PARAKEET_LOG_DEBUG("%s: chunk_frames=%d, start_t=%d, last_token=%d, initialized=%d\n", + __func__, chunk_enc_frames, t, last_token, pstate.tdt_stream_state.initialized); + + if (!pstate.tdt_stream_state.initialized) { + // Start with the blank token (8192) + last_token = blank_id; + t = 0; + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 0; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network. 
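+        // For streaming this happens only once per utterance: the LSTM
+        // hidden/cell state and the last emitted token persist in
+        // tdt_stream_state across chunk boundaries.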
+ if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + pstate.tdt_stream_state.initialized = true; + pstate.tdt_stream_state.last_token = last_token; + } + + // Count non-blank emissions at the same encoder frame to mimic max_symbols handling. + int last_emission_time = -1; + int tokens_at_time = 0; + + // loop over all the time frames the encoder produced for this chunk. + while (t < chunk_enc_frames) { + int current_token_time = t; + + int chosen_token = blank_id; + int chosen_duration = 1; + int chosen_duration_idx = 0; + float chosen_token_logit = 0.0f; + + // inner loop handles the joint network. + while (t < chunk_enc_frames) { + batch.n_tokens = 1; + batch.i_time[0] = std::min(t, chunk_enc_frames - 1); + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. + // The joint network outputs logits for all the tokens in the vocabulary + // and the blank token, and also addtionally N duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, nullptr, nullptr)) { + return false; + } + + // find the token with the highest value. + int best_token = 0; + float best_token_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > best_token_logit) { + best_token_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < hparams.n_tdt_durations; ++i) { + const float v = pstate.logits[n_vocab_logits + i]; + if (v > best_duration_logit) { + best_duration_logit = v; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + // handle blank token output from the joint network. + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + + // always add the duration. + t += duration; + + // if blank duration jumped past the chunk break. + if (t >= chunk_enc_frames) { + break; + } + + // we continue with the current frame. + current_token_time = t; + continue; + } + + // found the next non-blank token and label for this outer step. + chosen_token = best_token; + chosen_duration = duration; + chosen_token_logit = best_token_logit; + chosen_duration_idx = best_duration_idx; + break; + } + + // if we exited because blank durations ran beyond the chunk, no label to emit. + if (chosen_token == blank_id) { + break; + } + + // store token at the frame where it was found. + pstate.decoded_tokens.push_back(chosen_token); + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, chosen_token, chosen_duration_idx, chosen_duration, current_token_time, + chosen_token_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // emit the token to the callback if registered. + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + // advance predictor for the non-blank token. 
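+        // Feeding the emitted token back through the prediction network keeps
+        // the LSTM state in sync with the label history, mirroring the
+        // non-chunked parakeet_decode() above.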
+        last_token = chosen_token;
+        pstate.tdt_stream_state.last_token = last_token;
+
+        batch.n_tokens = 1;
+        batch.token[0] = last_token;
+        batch.logits[0] = 0;
+        batch.i_time[0] = 0;
+
+        if (!parakeet_predict(pctx, pstate, batch, n_threads, nullptr, nullptr)) {
+            return false;
+        }
+
+        if (chosen_duration > 0) {
+            // the duration looked up in the duration array was greater than zero
+            t += chosen_duration;
+            last_emission_time = -1;
+            tokens_at_time = 0;
+        } else {
+            // the duration looked up in the duration array was zero
+            if (current_token_time == last_emission_time) {
+                // update the token count as it can be valid for the model
+                // to emit a token with a zero duration.
+                tokens_at_time++;
+            } else {
+                // update the current emission time and reset the token count.
+                last_emission_time = current_token_time;
+                tokens_at_time = 1;
+            }
+
+            // handle the case where the model is emitting too many tokens with
+            // zero duration: force progress.
+            if (tokens_at_time >= max_tokens_per_timestep) {
+                t += 1;
+                last_emission_time = -1;
+                tokens_at_time = 0;
+            }
+        }
+    }
+
+    pstate.tdt_stream_state.time_step = t - chunk_enc_frames;
+    pstate.tdt_stream_state.decoded_length += chunk_enc_frames;
+
+    PARAKEET_LOG_DEBUG("%s: finished t=%d, time_step=%d, decoded_length=%d, total_tokens=%zu\n",
+            __func__,
+            t,
+            pstate.tdt_stream_state.time_step,
+            pstate.tdt_stream_state.decoded_length,
+            pstate.decoded_tokens.size());
+
+    return true;
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) {
+    const int sin_cos_step = cache.n_fft / N;
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N
+            re += in[n]*cache.cos_vals[idx]; // cos(t)
+            im -= in[n]*cache.sin_vals[idx]; // sin(t)
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) {
+    if (N == 1) {
+        out[0] = in[0];
+        out[1] = 0;
+        return;
+    }
+
+    const int half_N = N / 2;
+    if (N - half_N*2 == 1) {
+        dft(in, N, out, cache);
+        return;
+    }
+
+    float* even = in + N;
+    for (int i = 0; i < half_N; ++i) {
+        even[i] = in[2*i];
+    }
+    float* even_fft = out + 2 * N;
+    fft(even, half_N, even_fft, cache);
+
+    float* odd = even;
+    for (int i = 0; i < half_N; ++i) {
+        odd[i] = in[2*i + 1];
+    }
+    float* odd_fft = even_fft + N;
+    fft(odd, half_N, odd_fft, cache);
+
+    const int sin_cos_step = cache.n_fft / N;
+    for (int k = 0; k < half_N; k++) {
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = cache.cos_vals[idx]; // cos(t)
+        float im = -cache.sin_vals[idx]; // sin(t)
+
+        float re_odd = odd_fft[2*k + 0];
+        float im_odd = odd_fft[2*k + 1];
+
+        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+    }
+}
+
+struct mel_worker_params {
+    int ith;
+    int window_size;
+    int n_samples;
+    int frame_size;
+    int frame_step;
+    int n_threads;
+};
+
+static void log_mel_spectrogram_worker_thread(
+        mel_worker_params params,
+        const float * window_func,
+        const std::vector<float> & samples,
+        const parakeet_filters & filters,
+        parakeet_mel & mel,
+        const parakeet_mel_cache & cache) {
+    std::vector<float> fft_in(params.frame_size * 2, 0.0);
+    std::vector<float> fft_out(params.frame_size * 2 * 2 * 2);
+
+    int n_fb = filters.n_fb; // number of frequency bins
+    int i = params.ith;
+
+    // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist
+    assert(n_fb == 1 + (params.frame_size / 2));
+
+    const double eps = 5.960464477539063e-08;
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) {
+        const int offset = i * params.frame_step;
+
+        const int window_pad_left = (params.frame_size - params.window_size) / 2;
+
+        // Zero-pad left
+        std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f);
+
+        // Apply windowed samples in the center
+        const int n_to_process = std::min({params.window_size, params.n_samples - offset});
+        for (int j = 0; j < n_to_process; j++) {
+            fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j];
+        }
+
+        // Zero-pad right (and any samples we didn't have)
+        std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f);
+
+        // FFT
+        fft(fft_in.data(), params.frame_size, fft_out.data(), cache);
+
+        // Calculate modulus^2 of the complex numbers
+        // Note: using pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) here
+        // was observed to cause inference quality problems, interestingly.
+        for (int j = 0; j < n_fb; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
+        }
+
+        // mel spectrogram
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+            // unroll loop (suggested by GH user @lunixbochs)
+            int k = 0;
+            for (k = 0; k < n_fb - 3; k += 4) {
+                sum +=
+                        fft_out[k + 0] * filters.data[j * n_fb + k + 0] +
+                        fft_out[k + 1] * filters.data[j * n_fb + k + 1] +
+                        fft_out[k + 2] * filters.data[j * n_fb + k + 2] +
+                        fft_out[k + 3] * filters.data[j * n_fb + k + 3];
+            }
+            // handle n_fb remainder
+            for (; k < n_fb; k++) {
+                sum += fft_out[k] * filters.data[j * n_fb + k];
+            }
+
+            mel.data[i * mel.n_mel + j] = std::log(sum + eps);
+        }
+    }
+
+    // Otherwise fft_out are all zero - use log(eps) for consistency
+    const double empty_sum = std::log(eps);
+    for (; i < mel.n_len; i += params.n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[i * mel.n_mel + j] = empty_sum;
+        }
+    }
+}
+
+static bool log_mel_spectrogram(
+        parakeet_state & wstate,
+        const float * samples,
+        const int n_samples,
+        const int /*sample_rate*/,
+        const int frame_size,
+        const int frame_step,
+        const int n_mel,
+        const int n_threads,
+        const parakeet_filters & filters,
+        const bool debug,
+        parakeet_mel & mel,
+        const parakeet_mel_cache & cache) {
+    const int64_t t_start_us = ggml_time_us();
+
+    const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data();
+    const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size();
+
+    std::vector<float> samples_preprocessed(samples, samples + n_samples);
+
+    // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1]
+    {
+        const float preemph = 0.97f;
+        for (int i = n_samples - 1; i > 0; i--) {
+            samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1];
+        }
+    }
+
+    // The Parakeet PyTorch implementation uses centered constant padding.
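+    // With pad = frame_size/2 samples of zero padding on each side, the frame
+    // count below works out to n_len = (n_samples + 2*pad - frame_size) / frame_step + 1.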
+    const size_t pad = (size_t)(frame_size / 2);
+    std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f);
+    std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad);
+
+    mel.n_mel = n_mel;
+    mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1;
+    mel.n_len_org = mel.n_len;
+    mel.data.resize(mel.n_mel * mel.n_len);
+
+    // Worker threads (STFT + mel filterbank + natural log)
+    {
+        std::vector<std::thread> workers(n_threads - 1);
+        const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads };
+
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            mel_worker_params params = mel_params;
+            params.ith = iw + 1;
+            workers[iw] = std::thread(log_mel_spectrogram_worker_thread,
+                    params,
+                    window_func,
+                    std::cref(samples_padded),
+                    std::cref(filters),
+                    std::ref(mel),
+                    std::cref(cache));
+        }
+
+        log_mel_spectrogram_worker_thread(
+                mel_params,
+                window_func,
+                samples_padded,
+                filters,
+                mel,
+                cache);
+
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw].join();
+        }
+    }
+
+    {
+        const double eps = 1e-5;
+        int valid_frames = n_samples / frame_step;
+
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+            double sq_diff_sum = 0.0;
+
+            // Calculate the mean ONLY over valid audio frames
+            for (int i = 0; i < valid_frames; i++) {
+                sum += (double)mel.data[i * mel.n_mel + j];
+            }
+            double mean = sum / valid_frames;
+
+            // Calculate the variance ONLY over valid audio frames
+            for (int i = 0; i < valid_frames; i++) {
+                double diff = (double)mel.data[i * mel.n_mel + j] - mean;
+                sq_diff_sum += diff * diff;
+            }
+
+            double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0));
+            double denominator = std_dev + eps;
+
+            // Apply to ALL frames (including the padded ones)
+            for (int i = 0; i < mel.n_len; i++) {
+                mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator);
+            }
+        }
+    }
+
+    wstate.t_mel_us += ggml_time_us() - t_start_us;
+
+    if (debug) {
+        std::ofstream outFile("log_mel_spectrogram.json");
+        outFile << "[";
+        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+            outFile << mel.data[i] << ", ";
+        }
+        outFile << mel.data[mel.data.size() - 1] << "]";
+        outFile.close();
+    }
+
+    return true;
+}
+
+static std::vector<parakeet_token> tokenize(const parakeet_vocab & vocab, const std::string & text) {
+    std::vector<parakeet_token> tokens;
+    const std::string normalized = sentencepiece_normalize(text);
+
+    size_t i = 0;
+    while (i < normalized.size()) {
+        const size_t remaining = normalized.size() - i;
+        const size_t max_len = std::min(vocab.max_token_length, remaining);
+
+        bool found = false;
+        for (size_t len = max_len; len > 0; --len) {
+            const auto it = vocab.token_to_id.find(normalized.substr(i, len));
+            if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) {
+                tokens.push_back(it->second);
+                i += len;
+                found = true;
+                break;
+            }
+        }
+
+        if (!found) {
+            if (vocab.token_unk >= 0) {
+                tokens.push_back(vocab.token_unk);
+            }
+
+            const unsigned char c = static_cast<unsigned char>(normalized[i]);
+            i += utf8_codepoint_len(c);
+        }
+    }
+
+    return tokens;
+}
+
+
+//
+// interface implementation
+//
+
+struct parakeet_state * parakeet_init_state(parakeet_context * ctx) {
+    parakeet_state * state = new parakeet_state;
+
+    state->backends = whisper_backend_init(ctx->params);
+    if (state->backends.empty()) {
+        PARAKEET_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        parakeet_free_state(state);
+        return nullptr;
+    }
+
+    const int batch_size = ctx->model.hparams.n_audio_ctx;
+
+
state->logits.reserve(ctx->vocab.n_vocab * batch_size); + + state->batch = parakeet_batch_init(batch_size); + + // conv allocator + { + bool ok = parakeet_sched_graph_init(state->sched_conv, state->backends, + [&]() { + return parakeet_build_graph_conv(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init conv allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_conv) / 1e6); + } + + // encoder allocator + bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends, + [&]() { + return parakeet_build_graph_encoder(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init encoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers)) { + PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size(); + const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer); + PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0); + } + + if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_pred_ctx = state->pred_out_buf.size(); + const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer); + PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0); + } + + PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6); + + { + bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends, + [&]() { + const auto & hparams = ctx->model.hparams; + const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet + + parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0); + + return parakeet_build_graph_prediction(*ctx, *state, state->batch, true); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +struct parakeet_context_params parakeet_context_default_params() { + struct parakeet_context_params result = { + /*.use_gpu =*/ true, + /*.gpu_device =*/ 0, + }; + return result; +} + +struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) { + PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. 
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring path_model_wide = converter.from_bytes(path_model);
+    auto fin = std::ifstream(path_model_wide, std::ios::binary);
+#else
+    auto fin = std::ifstream(path_model, std::ios::binary);
+#endif
+    if (!fin) {
+        PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    parakeet_model_loader loader = {};
+
+    loader.context = &fin;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    auto ctx = parakeet_init_with_params_no_state(&loader, params);
+
+    if (ctx) {
+        ctx->path_model = path_model;
+    }
+
+    return ctx;
+}
+
+struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) {
+    struct buf_context {
+        uint8_t* buffer;
+        size_t size;
+        size_t current_offset;
+    };
+
+    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
+
+    PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__);
+
+    parakeet_model_loader loader = {};
+
+    loader.context = &ctx;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
+
+        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
+        buf->current_offset += size_to_copy;
+
+        return size_to_copy;
+    };
+
+    loader.eof = [](void * ctx) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        return buf->current_offset >= buf->size;
+    };
+
+    loader.close = [](void * /*ctx*/) { };
+
+    return parakeet_init_with_params_no_state(&loader, params);
+}
+
+struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) {
+    ggml_time_init();
+
+    PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+    PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
+    PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
+
+    parakeet_context * ctx = new parakeet_context;
+    ctx->params = params;
+
+    if (!parakeet_model_load(loader, *ctx)) {
+        loader->close(loader->context);
+        PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__);
+        delete ctx;
+        return nullptr;
+    }
+
+    loader->close(loader->context);
+
+    // Initialize the mel cache with the model's FFT size
+    ctx->mel_cache.init(ctx->model.hparams.n_fft);
+    PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft);
+
+    return ctx;
+}
+
+struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) {
+    parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = parakeet_init_state(ctx);
+    if (!ctx->state) {
+        parakeet_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) {
+    parakeet_context * ctx =
parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +void parakeet_free_state(struct parakeet_state * state) { + if (state) { + ggml_backend_buffer_free(state->lstm_state.buffer); + ggml_backend_buffer_free(state->pred_out_buffer); + + parakeet_batch_free(state->batch); + + ggml_backend_sched_free(state->sched_conv.sched); + ggml_backend_sched_free(state->sched_encode.sched); + ggml_backend_sched_free(state->sched_decode.sched); + + for (auto & backend : state->backends) { + ggml_backend_free(backend); + } + + delete state; + } +} + +void parakeet_free(struct parakeet_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + parakeet_free_state(ctx->state); + + delete ctx; + } +} + +void parakeet_free_context_params(struct parakeet_context_params * params) { + if (params) { + delete params; + } +} + +void parakeet_free_params(struct parakeet_full_params * params) { + if (params) { + delete params; + } +} + +int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, + samples, + n_samples, + PARAKEET_SAMPLE_RATE, + ctx->model.hparams.n_fft, + PARAKEET_HOP_LENGTH, + ctx->model.filters.n_mel, + n_threads, + ctx->model.filters, + false, // debug + state->mel, + ctx->mel_cache)) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) { + return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel) { + return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, 
nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int parakeet_token_count(struct parakeet_context * ctx, const char * text) { + return -parakeet_tokenize(ctx, text, NULL, 0); +} + +int parakeet_model_n_vocab(struct parakeet_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int parakeet_model_n_audio_state(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int parakeet_model_n_audio_head(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int parakeet_model_n_audio_layer(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int parakeet_model_n_mels(struct parakeet_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int parakeet_model_ftype(struct parakeet_context * ctx) { + return ctx->model.hparams.ftype; +} + +int parakeet_n_len_from_state(struct parakeet_state * state) { + return state->mel.n_len_org; +} + +int parakeet_n_len(struct parakeet_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int parakeet_n_vocab(struct parakeet_context * ctx) { + return ctx->vocab.n_vocab; +} + +int parakeet_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +float * parakeet_get_logits(struct parakeet_context * ctx) { + return ctx->state->logits.data(); +} + +float * parakeet_get_logits_from_state(struct parakeet_state * state) { + return state->logits.data(); +} + +const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) { + std::string text = sentencepiece_piece_to_text(token_str, is_first); + + if (output == nullptr) { + return text.size(); + } + + int bytes_to_copy = std::min((int)text.size(), max_len - 1); + if (bytes_to_copy > 0) { + memcpy(output, text.c_str(), bytes_to_copy); + output[bytes_to_copy] = '\0'; + } else if (max_len > 0) { + output[0] = '\0'; + } + + return text.size(); +} + +parakeet_token parakeet_token_bos(struct parakeet_context * ctx) { + return ctx->vocab.token_bos; +} + +parakeet_token parakeet_token_unk(struct parakeet_context * ctx) { + return ctx->vocab.token_unk; +} + +parakeet_token parakeet_token_blank(struct parakeet_context * ctx) { + return ctx->vocab.token_blank; +} + +struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { + if (ctx->state == nullptr) { + return nullptr; + } + parakeet_timings * timings = new parakeet_timings; + timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); + timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); + timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); + return timings; +} + +void parakeet_print_timings(struct parakeet_context * ctx) { + const int64_t 
t_end_us = ggml_time_us(); + + PARAKEET_LOG_INFO("\n"); + PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + + PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + + } + PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void parakeet_reset_timings(struct parakeet_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + } +} + +const char * parakeet_print_system_info(void) { + static std::string s; + + s = ""; + s += "PARAKEET : "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) { + ggml_backend_feature * features = get_features_fn(reg); + s += ggml_backend_reg_name(reg); + s += " : "; + for (; features->name; features++) { + s += features->name; + s += " = "; + s += features->value; + s += " | "; + } + } + } + return s.c_str(); +} + +struct parakeet_context_params * parakeet_context_default_params_by_ref(void) { + struct parakeet_context_params params = parakeet_context_default_params(); + + struct parakeet_context_params* result = new parakeet_context_params(); + *result = params; + return result; +} + +struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params params = parakeet_full_default_params(strategy); + + struct parakeet_full_params* result = new parakeet_full_params(); + *result = params; + return result; +} + +struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params result = { + /*.strategy =*/ strategy, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + /*.no_context =*/ true, + /*.audio_ctx =*/ 0, + /*.chunk_length_ms =*/ 10000, // 10 second chunks + /*.left_context_ms =*/ 10000, // 10 second left context + /*.right_context_ms =*/ 4960, // 4.96 second right context + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + 
/*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + }; + + return result; +} + +static void parakeet_stream_clear(struct parakeet_state * state) { + state->stream.buffer.clear(); + state->stream.n_samples_advanced = 0; + state->stream.n_left_ctx = 0; + state->stream.n_chunk = 0; + state->stream.n_right_ctx = 0; + state->stream.params = {}; + state->stream.initialized = false; +} + +static void parakeet_reset_state(struct parakeet_state * state) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } + + state->tdt_stream_state.initialized = false; + state->tdt_stream_state.last_token = 0; + state->tdt_stream_state.time_step = 0; + state->tdt_stream_state.decoded_length = 0; +} + +static void parakeet_stream_reset_state(struct parakeet_state * state) { + if (state == nullptr) { + return; + } + + parakeet_reset_state(state); + + state->result_all.clear(); + + parakeet_stream_clear(state); + + state->enc_out_buffer.clear(); + state->enc_out_frames = 0; + state->n_frames = 0; + state->n_audio_ctx = 0; +} + +static int parakeet_stream_process_window( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_chunk) { + const parakeet_stream & stream = state->stream; + const parakeet_full_params & params = stream.params; + const int d_enc = ctx->model.hparams.n_audio_state; + + // process all the samples. + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + return -2; + } + + const int left_mel_frames = stream.n_left_ctx / PARAKEET_HOP_LENGTH; + const int chunk_mel_frames = n_chunk / PARAKEET_HOP_LENGTH; + + state->n_audio_ctx = state->mel.n_len; + // process entire log mel spectrogram. + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + return -6; + } + + const int left_enc_frames = left_mel_frames / ctx->model.hparams.subsampling_factor; + const int chunk_enc_frames = chunk_mel_frames / ctx->model.hparams.subsampling_factor; + + if (chunk_enc_frames <= 0) { + return 0; + } + + // Copy the center chunk so that it is the only part that the joint network sees. + state->enc_out_buffer.resize(chunk_enc_frames * d_enc); + ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(), + left_enc_frames * d_enc * sizeof(float), + chunk_enc_frames * d_enc * sizeof(float)); + + state->enc_out_frames = chunk_enc_frames; + state->n_frames = chunk_enc_frames; + + const size_t tokens_before = state->decoded_tokens.size(); + + // Run the prediction and joint network on the center chunk. 
+ if (!parakeet_decode_chunk(*ctx, *state, state->batch, chunk_enc_frames, params.n_threads, ¶ms)) { + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector result_tokens; + const int64_t chunk_t0 = 100LL * stream.n_samples_advanced / PARAKEET_SAMPLE_RATE; + const int64_t chunk_t1 = 100LL * (stream.n_samples_advanced + n_chunk) / PARAKEET_SAMPLE_RATE; + const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor; + + result_tokens.reserve(new_token_count); + + for (size_t i = tokens_before; i < tokens_after; ++i) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + auto token_data = state->decoded_token_data[i]; + token_data.frame_index += frame_offset; + token_data.t0 += chunk_t0; + token_data.t1 += chunk_t0; + result_tokens.push_back(token_data); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = chunk_t0; + segment.t1 = chunk_t1; + segment.text = std::move(text); + segment.tokens = std::move(result_tokens); + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +static int ms_to_n_samples(int ms) { + return ms * PARAKEET_SAMPLE_RATE / 1000; +} + +int parakeet_stream_init( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params) { + if (ctx == nullptr || state == nullptr) { + return -1; + } + + const int n_left_ctx = ms_to_n_samples(params.left_context_ms); + const int n_chunk = ms_to_n_samples(params.chunk_length_ms); + const int n_right_ctx = ms_to_n_samples(params.right_context_ms); + + if (n_left_ctx < 0 || n_chunk <= 0 || n_right_ctx < 0) { + return -1; + } + + parakeet_stream_reset_state(state); + + state->stream.n_left_ctx = n_left_ctx; + state->stream.n_chunk = n_chunk; + state->stream.n_right_ctx = n_right_ctx; + state->stream.params = params; + state->stream.initialized = true; + + if (n_left_ctx > 0) { + state->stream.buffer.assign(n_left_ctx, 0.0f); + } + + return 0; +} + +int parakeet_stream_push( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples) { + if (ctx == nullptr || state == nullptr || samples == nullptr || n_samples <= 0) { + return -1; + } + + if (!state->stream.initialized) { + return -1; + } + + const int n_total_samples = state->stream.n_left_ctx + state->stream.n_chunk + state->stream.n_right_ctx; + + // Insert the new chunk of samples as the new center and right context. + state->stream.buffer.insert(state->stream.buffer.end(), samples, samples + n_samples); + + // As long as we have enough samples to form a complete window we process it. + while (state->stream.buffer.size() >= (size_t) n_total_samples) { + const int ret = parakeet_stream_process_window( + ctx, + state, + state->stream.buffer.data(), + n_total_samples, + state->stream.n_chunk); + if (ret != 0) { + return ret; + } + + // TODO: std::vector::erase is O(n) and not optimal. We should probably + // use a ring buffer instead. + // Shift the center and right context to the start of the buffer. 
This + // allows the next call to have the current center chunk as its left + // context, and the right context will become part of the next target + // chunk together with the new samples which will make up the rest of + // the target chunk and the new right context. + state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + state->stream.n_chunk); + + state->stream.n_samples_advanced += state->stream.n_chunk; + } + + return 0; +} + +int parakeet_stream_flush( + struct parakeet_context * ctx, + struct parakeet_state * state) { + if (ctx == nullptr || state == nullptr) { + return -1; + } + + if (!state->stream.initialized) { + return -1; + } + + while (state->stream.buffer.size() > (size_t) state->stream.n_left_ctx) { + const int n_remaining_samples = (int) state->stream.buffer.size() - state->stream.n_left_ctx; + const int n_flush_chunk = std::min(state->stream.n_chunk, n_remaining_samples); + const int n_right_available = std::min(state->stream.n_right_ctx, n_remaining_samples - n_flush_chunk); + const int n_copied = state->stream.n_left_ctx + n_flush_chunk + n_right_available; + + std::vector flush_window(state->stream.n_left_ctx + n_flush_chunk + state->stream.n_right_ctx, 0.0f); + + std::copy_n(state->stream.buffer.begin(), n_copied, flush_window.begin()); + + const int ret = parakeet_stream_process_window( + ctx, + state, + flush_window.data(), + (int) flush_window.size(), + n_flush_chunk); + if (ret != 0) { + return ret; + } + + state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + n_flush_chunk); + state->stream.n_samples_advanced += n_flush_chunk; + } + + parakeet_stream_clear(state); + + return 0; +} + +int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = state->result_all; + result_all.clear(); + + // Clear any previous decoded tokens if no_context is set + if (params.no_context) { + parakeet_reset_state(state); + } + + // If no chunking specified, delegate to parakeet_chunk (processes entire audio as one chunk) + if (params.chunk_length_ms == 0) { + return parakeet_chunk(ctx, state, params, samples, n_samples); + } + + // Nemo sliding window chunking (chunk raw samples) + { + PARAKEET_LOG_INFO("%s: chunking: chunk_ms=%d, left_ctx=%d, right_ctx=%d\n", + __func__, params.chunk_length_ms, params.left_context_ms, params.right_context_ms); + + // reset tdt streaming state + state->tdt_stream_state.initialized = false; + state->tdt_stream_state.last_token = 0; + state->tdt_stream_state.time_step = 0; + state->tdt_stream_state.decoded_length = 0; + + const int left_context_samples = (params.left_context_ms * PARAKEET_SAMPLE_RATE) / 1000; + const int chunk_samples = (params.chunk_length_ms * PARAKEET_SAMPLE_RATE) / 1000; + const int right_context_samples = (params.right_context_ms * PARAKEET_SAMPLE_RATE) / 1000; + + const int d_enc = ctx->model.hparams.n_audio_state; + + int left_sample = 0; + + // Process audio in sliding chunks of raw samples. 
+ while (left_sample < n_samples) { + const int buffer_start_sample = std::max(0, left_sample - left_context_samples); + const int right_sample = std::min(left_sample + chunk_samples, n_samples); + const int buffer_end_sample = std::min(right_sample + right_context_samples, n_samples); + const int buffer_length_samples = buffer_end_sample - buffer_start_sample; + const size_t tokens_before = state->decoded_tokens.size(); + + PARAKEET_LOG_DEBUG("%s: sample chunk: buffer=[%d-%d] samples, chunk=[%d-%d]\n", + __func__, buffer_start_sample, buffer_end_sample, left_sample, right_sample); + + // Compute mel spectrogram for the left context, the center chunk, + // and the right context. + if (parakeet_pcm_to_mel_with_state(ctx, state, samples + buffer_start_sample, + buffer_length_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram for chunk\n", __func__); + return -2; + } + + const int total_mel_frames = state->mel.n_len; + + // Calculate mel frame offsets within this chunks mel spectrogram + // mel frames = raw samples / 160 (for 16kHz, 10ms hop) + const int left_context_mel_frames = (left_sample - buffer_start_sample) / 160; + const int chunk_mel_frames = (right_sample - left_sample) / 160; + + PARAKEET_LOG_DEBUG("%s: Mel features: total=%d frames, left_ctx=%d, chunk=%d\n", + __func__, total_mel_frames, left_context_mel_frames, chunk_mel_frames); + + // The encoder will see the left context, the main chunk, and the right + // context, so that it has access to nearby frames during processing. + // Without this there would be hard cutoffs and the encoder might not + // be able to detect speech near the edges. + state->n_audio_ctx = total_mel_frames; + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode chunk\n", __func__); + return -6; + } + + // Calculate which encoder frames correspond to the actual chunk (center chunk) + const int left_context_enc_frames = left_context_mel_frames / ctx->model.hparams.subsampling_factor; + const int chunk_enc_frames = chunk_mel_frames / ctx->model.hparams.subsampling_factor; + + PARAKEET_LOG_DEBUG("%s: encoder output: total=%d frames, left_ctx=%d, chunk=%d\n", + __func__, state->n_frames, left_context_enc_frames, chunk_enc_frames); + + // The joint network only sees the center chunk encoded frames and + // not the left and right context like the encoder did. 
+ state->enc_out_buffer.resize(chunk_enc_frames * d_enc); + ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(), + left_context_enc_frames * d_enc * sizeof(float), + chunk_enc_frames * d_enc * sizeof(float)); + state->enc_out_frames = chunk_enc_frames; + + state->n_frames = chunk_enc_frames; + + PARAKEET_LOG_DEBUG("%s: decoding %d encoder frames for this chunk\n", __func__, chunk_enc_frames); + + if (!parakeet_decode_chunk(*ctx, *state, state->batch, chunk_enc_frames, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode chunk\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector result_tokens; + const int64_t chunk_t0 = 100LL * left_sample / PARAKEET_SAMPLE_RATE; + const int64_t chunk_t1 = 100LL * right_sample / PARAKEET_SAMPLE_RATE; + const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor; + + result_tokens.reserve(new_token_count); + + for (size_t i = tokens_before; i < tokens_after; ++i) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + auto token_data = state->decoded_token_data[i]; + token_data.frame_index += frame_offset; + token_data.t0 += chunk_t0; + token_data.t1 += chunk_t0; + result_tokens.push_back(token_data); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = chunk_t0; + segment.t1 = chunk_t1; + segment.text = std::move(text); + segment.tokens = std::move(result_tokens); + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + // shift the window. 
+ left_sample = right_sample; + } + + // Clear buffers after all chunks processed + state->enc_out_buffer.clear(); + state->enc_out_frames = 0; + } + + return 0; +} + +int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + if (params.audio_ctx == 0) { + const int total_len = parakeet_n_len_from_state(state); + const int model_max_ctx = parakeet_n_audio_ctx(ctx); + params.audio_ctx = std::min(total_len, model_max_ctx); + PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx); + } + state->n_audio_ctx = params.audio_ctx; + + const int n_frames = parakeet_n_len_from_state(state); + + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + // Use the stored token data from parakeet_decode + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = 0; // Caller tracks timing + segment.t1 = n_frames; + segment.text = text; + segment.tokens = result_tokens; + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full_n_segments_from_state(struct parakeet_state * state) { + return state->result_all.size(); +} + +int parakeet_full_n_segments(struct parakeet_context * ctx) { + return ctx->state->result_all.size(); +} + +int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment); +} + +int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) { + return 
parakeet_full_get_segment_t1_from_state(ctx->state, i_segment); +} + +const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +void parakeet_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default; + g_state.log_callback_user_data = user_data; + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); +} + +const char * parakeet_version(void) { + return PARAKEET_VERSION; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal(ggml_log_level level, const char * format, ...) 
{ + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; +#ifndef WHISPER_DEBUG + if (level == GGML_LOG_LEVEL_DEBUG) { + return; + } +#endif + fputs(text, stderr); + fflush(stderr); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..a426d9dca88 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -110,3 +110,32 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Parakeet model loading test +set(PARAKEET_TEST test-parakeet) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) + +set(PARAKEET_TEST test-parakeet-stream) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/gb1.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) + +set(PARAKEET_TEST test-parakeet-full) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/gb1.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) +set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;unit") diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore new file mode 100644 index 00000000000..838bfeae9db --- /dev/null +++ b/tests/librispeech-parakeet/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.tar.gz +*.txt +eval.conf +venv +LibriSpeech diff --git a/tests/librispeech-parakeet/Makefile b/tests/librispeech-parakeet/Makefile new file mode 100644 index 00000000000..0afa2465f49 --- /dev/null +++ b/tests/librispeech-parakeet/Makefile @@ -0,0 +1,15 @@ +TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz + +all: eval + +eval: + $(MAKE) -f eval.mk + +clean: + $(MAKE) -f eval.mk clean + +get-audio: + wget -c $(TAR_URL) + tar -xf test-clean.tar.gz + +.PHONY: all eval clean setup-venv clean-venv get-audio diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md new file mode 
100644 index 00000000000..e09cba405ef --- /dev/null +++ b/tests/librispeech-parakeet/README.md @@ -0,0 +1,57 @@ +# parakeet.cpp/tests/librispeech-parakeet + +[LibriSpeech](https://www.openslr.org/12) is a standard dataset for +training and evaluating automatic speech recognition systems. + +This directory contains a set of tools to evaluate the recognition +performance of parakeet.cpp on the LibriSpeech corpus. + +## Quick Start + +1. (Prerequisite) Compile `parakeet-cli` and prepare the Parakeet + model in `ggml` format. + + ``` + $ # Execute the commands below in the project root dir. + $ cmake -B build + $ cmake --build build --config Release + ``` + +2. Download the audio files from the LibriSpeech project. + + ``` + $ make get-audio + ``` + +3. Set up the environment to compute the WER score. + + ``` + $ pip install -r requirements.txt + ``` + + For example, if you use `virtualenv`, you can set it up as follows: + + ``` + $ python3 -m venv venv + $ . venv/bin/activate + $ pip install -r requirements.txt + ``` + +4. Run the benchmark test. + + ``` + $ make + ``` + +## How-to guides + +### How to change the inference parameters + +Create `eval.conf` and override the variables. + +``` +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 +PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt +``` + +Check out `eval.mk` for more details. diff --git a/tests/librispeech-parakeet/eval.mk new file mode 100644 index 00000000000..7d8992ec471 --- /dev/null +++ b/tests/librispeech-parakeet/eval.mk @@ -0,0 +1,39 @@ +PYTHON = python + +PARAKEET_PREFIX = ../../ +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 + +PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli +PARAKEET_FLAGS = --no-prints --output-txt + +# You can create eval.conf to override the PARAKEET_* variables +# defined above. +-include eval.conf + +# This follows the file structure of the LibriSpeech project. +AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac)) +TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS)) + +# We output the evaluation result to this file. +DONE = $(PARAKEET_MODEL).txt + +all: $(DONE) + +$(DONE): $(TRANS_TXTS) + $(PYTHON) eval.py > $@.tmp + mv $@.tmp $@ + +# Note: This task writes to a temporary file first to +# create the target file atomically.
+%.flac.txt: %.flac + $(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp + mv $^.tmp.txt $^.txt + +archive: + tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE) + +clean: + @rm -f $(TRANS_TXTS) + @rm -f $(DONE) + +.PHONY: all clean diff --git a/tests/librispeech-parakeet/eval.py b/tests/librispeech-parakeet/eval.py new file mode 100644 index 00000000000..cdaf8352fd4 --- /dev/null +++ b/tests/librispeech-parakeet/eval.py @@ -0,0 +1,47 @@ +import os +import glob +import jiwer +from normalizers import EnglishTextNormalizer + +def get_reference(): + ref = {} + for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'): + with open(path) as fp: + for line in fp: + code, text = line.strip().split(" ", maxsplit=1) + ref [code] = text + return ref + +def get_hypothesis(): + hyp = {} + for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'): + with open(path) as fp: + text = fp.read().strip() + code = os.path.basename(path).replace('.flac.txt', '') + hyp[code] = text + return hyp + +def get_codes(): + codes = [] + for path in glob.glob('LibriSpeech/*/*/*/*.flac'): + codes.append(os.path.basename(path).replace('.flac', '')) + return sorted(codes) + +def main(): + normalizer = EnglishTextNormalizer() + + ref_orig = get_reference() + hyp_orig = get_hypothesis() + + ref_clean = [] + hyp_clean = [] + + for code in get_codes(): + ref_clean.append(normalizer(ref_orig[code])) + hyp_clean.append(normalizer(hyp_orig[code])) + + wer = jiwer.wer(ref_clean, hyp_clean) + print(f"WER: {wer * 100:.2f}%") + +if __name__ == '__main__': + main() diff --git a/tests/librispeech-parakeet/normalizers/LICENSE b/tests/librispeech-parakeet/normalizers/LICENSE new file mode 100644 index 00000000000..7c8e603b0fc --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/LICENSE @@ -0,0 +1,25 @@ +Code in this directory is adapted from OpenAI Whisper project +(https://github.com/openai/whisper) and carries the following +copyright and license. + + MIT License + + Copyright (c) 2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
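For orientation, here is a minimal sketch (not part of the patch) of the per-utterance step that `eval.mk` and `eval.py` above automate over the whole corpus: transcribe one LibriSpeech `.flac` with `parakeet-cli`, then normalize both texts and score them with `jiwer`. The binary and model paths, and the example utterance ID, are assumptions taken from the defaults in `eval.mk`; the sketch also assumes, as `eval.mk` does, that `parakeet-cli` appends `.txt` to the `--output-file` base name. Run it from `tests/librispeech-parakeet` after `make get-audio` and with the requirements installed, so the `normalizers` package and the extracted `LibriSpeech` directory are available.

```python
# Sketch only: score a single LibriSpeech utterance, mirroring eval.mk + eval.py.
# Paths and the utterance ID are assumptions based on the defaults in eval.mk.
import subprocess
import jiwer
from normalizers import EnglishTextNormalizer

CLI   = "../../build/bin/parakeet-cli"                # assumed build location
MODEL = "../../models/ggml-parakeet-tdt-0.6b-v3.bin"  # assumed model path
CODE  = "1089-134686-0000"                            # example utterance
FLAC  = "LibriSpeech/test-clean/1089/134686/" + CODE + ".flac"
TRANS = "LibriSpeech/test-clean/1089/134686/1089-134686.trans.txt"

# Transcribe one file; the CLI is assumed to write <output-file>.txt.
subprocess.run([CLI, "--no-prints", "--output-txt",
                "--model", MODEL, "--file", FLAC,
                "--output-file", FLAC], check=True)

# Look up the reference line ("<code> <text>") and read the hypothesis text.
with open(TRANS) as fp:
    ref = next(line.split(" ", 1)[1].strip() for line in fp if line.startswith(CODE))
with open(FLAC + ".txt") as fp:
    hyp = fp.read().strip()

# Apply the same English normalization as eval.py before computing WER.
normalize = EnglishTextNormalizer()
print(f"WER: {jiwer.wer(normalize(ref), normalize(hyp)) * 100:.2f}%")
```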
diff --git a/tests/librispeech-parakeet/normalizers/__init__.py b/tests/librispeech-parakeet/normalizers/__init__.py new file mode 100644 index 00000000000..896d5e33641 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/__init__.py @@ -0,0 +1,2 @@ +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/tests/librispeech-parakeet/normalizers/basic.py b/tests/librispeech-parakeet/normalizers/basic.py new file mode 100644 index 00000000000..8690ae71c5f --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/basic.py @@ -0,0 +1,80 @@ +import re +import unicodedata + +import regex + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, + and drop any diacritics (category 'Mn' and some manual mappings) + """ + return "".join( + ( + c + if c in keep + else ( + ADDITIONAL_DIACRITICS[c] + if c in ADDITIONAL_DIACRITICS + else ( + "" + if unicodedata.category(c) == "Mn" + else " " if unicodedata.category(c)[0] in "MSP" else c + ) + ) + ) + for c in unicodedata.normalize("NFKD", s) + ) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join( + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) + ) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) + self.split_letters = split_letters + + def __call__(self, s: str): + s = s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/tests/librispeech-parakeet/normalizers/english.json b/tests/librispeech-parakeet/normalizers/english.json new file mode 100644 index 00000000000..74a1c3521d9 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + 
"aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + 
"breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": 
"colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + 
"demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": 
"energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": 
"fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": 
"humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + 
"jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": 
"meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + 
"optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": 
"polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", 
+ "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": 
"shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + 
"synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": 
"vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/tests/librispeech-parakeet/normalizers/english.py b/tests/librispeech-parakeet/normalizers/english.py new file mode 100644 index 00000000000..4932042bc5b --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. 
`$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not 
None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, 
suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." 
+ else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace " and a half" with " point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. 
+ r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp new file mode 100644 index 00000000000..3de9e86b308 --- /dev/null +++ b/tests/test-parakeet-full.cpp @@ -0,0 +1,62 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); // no stereo vector + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + + params.chunk_length_ms = 10000; + params.left_context_ms = 10000; + params.right_context_ms = 4960; + + int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free(pctx); + + printf("\nTest passed: parakeet_full succeeded!\n"); + return 0; +} diff --git a/tests/test-parakeet-stream.cpp b/tests/test-parakeet-stream.cpp new file mode 100644 index 00000000000..f112bcb932b 
--- /dev/null +++ b/tests/test-parakeet-stream.cpp @@ -0,0 +1,107 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { return 1; } + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + + params.left_context_ms = 10000; + params.chunk_length_ms = 10000; + params.right_context_ms = 4960; + + parakeet_state * state = parakeet_init_state(pctx); + + // initialize streaming state + assert(parakeet_stream_init(pctx, state, params) == 0); + + const int samples_batch_size = 1600; + int position = 0; + + while (position < (int)pcmf32.size()) { + int samples_to_push = std::min(samples_batch_size, (int)pcmf32.size() - position); + + int ret = parakeet_stream_push(pctx, state, pcmf32.data() + position, samples_to_push); + assert(ret == 0); + + position += samples_to_push; + } + + // flush remaining samples. 
+ assert(parakeet_stream_flush(pctx, state) == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\n\nTest passed: Streaming logic.\n"); + return 0; +} diff --git a/tests/test-parakeet.cpp b/tests/test-parakeet.cpp new file mode 100644 index 00000000000..83237c600ac --- /dev/null +++ b/tests/test-parakeet.cpp @@ -0,0 +1,99 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + // Load the sample audio file + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + params.new_segment_callback = segment_callback; + params.new_segment_callback_user_data = nullptr; + parakeet_state * state = parakeet_init_state(pctx); + + int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\nTest passed: Parakeet model loaded and freed successfully\n"); + return 0; +}
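The `EnglishTextNormalizer` added above lowercases the text, expands contractions, strips bracketed words and most punctuation, converts spelled-out numbers through `EnglishNumberNormalizer`, and finally applies the British-to-American spelling map from `english.json`, so that reference and hypothesis transcripts are compared in a common form before word-error-rate scoring. A minimal usage sketch follows; it assumes the `tests/librispeech-parakeet/normalizers` directory is importable as a package (with a `basic.py` providing `remove_symbols_and_diacritics`, and `more_itertools` installed) — the LibriSpeech evaluation driver itself is not shown here.

```python
# Hypothetical driver snippet: normalize both sides before scoring so that
# spelling, contraction, and number-format differences do not count as errors.
from normalizers.english import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

reference  = "Mr. Brown spent twenty dollars and seven cents, didn't he?"
hypothesis = "mister brown spent $20.07 didn't he"

# Both strings should reduce to the same normalized form,
# roughly "mister brown spent $20.07 did not he".
print(normalizer(reference))
print(normalizer(hypothesis))
```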