Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions examples/cli/cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ static void replace_all(std::string & s, const std::string & search, const std::
}
}

// Returns the number of trailing continuation bytes still needed for `s` to end
// on a complete UTF-8 codepoint. Returns 0 if the tail of `s` is already a
// complete codepoint (or if the tail looks malformed and we should stop merging).
// Used to merge whisper tokens whose bytes split a multi-byte UTF-8 character
// (e.g. CJK), so the JSON output stays valid UTF-8. See issue #1798.
static int utf8_trailing_bytes_needed(const std::string & s) {
const int n = (int) s.size();
int i = n - 1;
// walk back past continuation bytes (10xxxxxx)
while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) {
--i;
}
if (i < 0) {
// all continuation bytes, or empty — nothing we can do
return 0;
}
const unsigned char c = (unsigned char) s[i];
int expected;
if ((c & 0x80) == 0x00) expected = 1; // ASCII
else if ((c & 0xE0) == 0xC0) expected = 2;
else if ((c & 0xF0) == 0xE0) expected = 3;
else if ((c & 0xF8) == 0xF0) expected = 4;
else return 0; // malformed lead, give up
const int have = n - i;
return have >= expected ? 0 : (expected - have);
}

// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
Expand Down Expand Up @@ -738,18 +765,47 @@ static void output_json(
if (full) {
start_arr("tokens");
const int n = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n; ++j) {
auto token = whisper_full_get_token_data(ctx, i, j);

// Merge adjacent tokens whose bytes together form a
// single UTF-8 codepoint. Multi-byte characters (CJK
// in particular) can end up split across whisper
// tokens, which used to produce invalid UTF-8 in the
// JSON string. Refs issue #1798.
struct merged_token {
std::string text;
whisper_token_data data;
int64_t t1;
};
std::vector<merged_token> merged;
merged.reserve(n);
for (int j = 0; j < n; ) {
auto tok = whisper_full_get_token_data(ctx, i, j);
merged_token m{ whisper_token_to_str(ctx, tok.id), tok, tok.t1 };
++j;
while (j < n && utf8_trailing_bytes_needed(m.text) > 0) {
auto tok_next = whisper_full_get_token_data(ctx, i, j);
m.text += whisper_token_to_str(ctx, tok_next.id);
if (tok_next.t1 > -1) {
m.t1 = tok_next.t1;
}
++j;
}
merged.push_back(std::move(m));
}

const int nm = (int) merged.size();
for (int j = 0; j < nm; ++j) {
const auto & mt = merged[j];
start_obj(nullptr);
value_s("text", whisper_token_to_str(ctx, token.id), false);
if(token.t0 > -1 && token.t1 > -1) {
value_s("text", mt.text.c_str(), false);
if (mt.data.t0 > -1 && mt.t1 > -1) {
// If we have per-token timestamps, write them out
times_o(token.t0, token.t1, false);
times_o(mt.data.t0, mt.t1, false);
}
value_i("id", token.id, false);
value_f("p", token.p, false);
value_f("t_dtw", token.t_dtw, true);
end_obj(j == (n - 1));
value_i("id", mt.data.id, false);
value_f("p", mt.data.p, false);
value_f("t_dtw", mt.data.t_dtw, true);
end_obj(j == (nm - 1));
}
end_arr(!params.diarize && !params.tinydiarize);
}
Expand Down