Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 97 additions & 18 deletions js/llama_webgpu_bridge.js
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,26 @@ function toUint8Array(value) {
return null;
}

/**
 * Strip characters from the end of `text` that may change once more
 * bytes of a streamed UTF-8 sequence arrive: a run of U+FFFD
 * replacement characters (incomplete multi-byte sequences decoded too
 * early) and, after that, a single dangling high surrogate (the first
 * half of a surrogate pair whose second half has not arrived yet).
 *
 * @param {string} text - Possibly partially-decoded streamed text.
 * @returns {string} The stable prefix of `text`; `''` for non-strings.
 */
function trimUnstableUtf8Tail(text) {
  if (typeof text !== 'string') {
    return '';
  }

  let stableLength = text.length;
  // Drop every trailing U+FFFD replacement character.
  for (; stableLength > 0; stableLength -= 1) {
    if (text.charCodeAt(stableLength - 1) !== 0xFFFD) {
      break;
    }
  }

  // Drop one dangling high surrogate (0xD800–0xDBFF) if it now ends
  // the stable prefix — its low-surrogate partner has not arrived.
  if (stableLength > 0) {
    const lastUnit = text.charCodeAt(stableLength - 1);
    const isHighSurrogate = lastUnit >= 0xD800 && lastUnit <= 0xDBFF;
    if (isHighSurrogate) {
      stableLength -= 1;
    }
  }

  // Avoid allocating a new string when nothing was trimmed.
  return stableLength === text.length ? text : text.slice(0, stableLength);
}

function toFloat32Array(value) {
if (!value) {
return null;
Expand Down Expand Up @@ -3831,7 +3851,8 @@ class LlamaWebGpuBridgeRuntime {
const shouldYieldForResponsiveness =
!(typeof WorkerGlobalScope !== 'undefined' && globalThis instanceof WorkerGlobalScope);
const yieldInterval = shouldYieldForResponsiveness ? 4 : 0;
let streamed = shouldEmitCurrentText ? '' : null;
let streamed = '';
let emittedStableText = '';

while (generated < nPredict) {
if (this._abortRequested || options.signal?.aborted) {
Expand Down Expand Up @@ -3888,19 +3909,25 @@ class LlamaWebGpuBridgeRuntime {
}

generated += 1;
const piece = this._core.ccall('llamadart_webgpu_last_piece', 'string', [], []) || '';
if (piece.length === 0) {
const fullText = this._core.ccall('llamadart_webgpu_last_output', 'string', [], []) || '';
streamed = fullText;
const stableText = trimUnstableUtf8Tail(fullText);

if (!stableText.startsWith(emittedStableText)) {
emittedStableText = '';
}

const deltaText = stableText.slice(emittedStableText.length);
if (deltaText.length === 0) {
continue;
}
emittedStableText = stableText;

if (typeof options.onToken === 'function') {
const piecePayload = emitTokenText ? piece : textEncoder.encode(piece);
if (shouldEmitCurrentText) {
streamed += piece;
options.onToken(piecePayload, streamed);
} else {
options.onToken(piecePayload, null);
}
const piecePayload = emitTokenText
? deltaText
: textEncoder.encode(deltaText);
options.onToken(piecePayload, shouldEmitCurrentText ? fullText : null);
}

if (yieldInterval > 0 && (generated % yieldInterval) === 0) {
Expand All @@ -3909,6 +3936,17 @@ class LlamaWebGpuBridgeRuntime {
}

const text = this._core.ccall('llamadart_webgpu_last_output', 'string', [], []) || streamed || '';
if (typeof options.onToken === 'function') {
const tailText = text.startsWith(emittedStableText)
? text.slice(emittedStableText.length)
: '';
if (tailText.length > 0) {
const piecePayload = emitTokenText
? tailText
: textEncoder.encode(tailText);
options.onToken(piecePayload, shouldEmitCurrentText ? text : null);
}
}
return text;
} finally {
if (generationStarted) {
Expand Down Expand Up @@ -4203,6 +4241,40 @@ export class LlamaWebGpuBridge {
return sanitized;
}

_createCpuSafeMultimodalLoadOptions(options = {}) {
const sanitized = this._sanitizeModelLoadOptions(options);
sanitized.nGpuLayers = 0;

if (Number.isFinite(Number(sanitized.nCtx)) && Number(sanitized.nCtx) > 4096) {
sanitized.nCtx = 4096;
}

if (!Number.isFinite(Number(sanitized.nThreads)) || Number(sanitized.nThreads) <= 0) {
sanitized.nThreads = 4;
} else {
sanitized.nThreads = Math.min(4, Math.max(1, Math.trunc(Number(sanitized.nThreads))));
}

sanitized.nThreadsBatch = sanitized.nThreads;

if (!Number.isFinite(Number(sanitized.nBatch)) || Number(sanitized.nBatch) <= 0) {
sanitized.nBatch = 128;
} else {
sanitized.nBatch = Math.min(128, Math.max(32, Math.trunc(Number(sanitized.nBatch))));
}

if (!Number.isFinite(Number(sanitized.nUbatch)) || Number(sanitized.nUbatch) <= 0) {
sanitized.nUbatch = Math.min(64, sanitized.nBatch);
} else {
sanitized.nUbatch = Math.min(
sanitized.nBatch,
Math.min(64, Math.max(1, Math.trunc(Number(sanitized.nUbatch)))),
);
}

return sanitized;
}

_rememberLoadedModel(url, options = {}) {
const normalizedUrl = String(url || '').trim();
if (normalizedUrl.length === 0) {
Expand Down Expand Up @@ -4277,7 +4349,9 @@ export class LlamaWebGpuBridge {
return false;
}

const selectedOptions = this._sanitizeModelLoadOptions(this._loadedModelOptions || {});
const selectedOptions = this._createCpuSafeMultimodalLoadOptions(
this._loadedModelOptions || {},
);

const applyWorkerSafeMode = async () => {
await this._callWorker('loadModelFromUrl', [this._loadedModelUrl, selectedOptions]);
Expand Down Expand Up @@ -4385,8 +4459,11 @@ export class LlamaWebGpuBridge {
this._hasMediaParts(options)
&& typeof this._loadedMmProjUrl === 'string'
&& this._loadedMmProjUrl.length > 0;
const forceCpuMultimodalFallback =
this._hasMediaParts(options)
&& Number(this._loadedModelOptions?.nGpuLayers) !== 0;

if (Number(this._runtime?._modelBytes) > 0 && !forceReloadRequested) {
if (Number(this._runtime?._modelBytes) > 0 && !forceReloadRequested && !forceCpuMultimodalFallback) {
if (shouldEnsureMultimodalInRuntime) {
const runtimeSupportsMedia =
(typeof this._runtime.supportsVision === 'function' && this._runtime.supportsVision())
Expand All @@ -4407,16 +4484,18 @@ export class LlamaWebGpuBridge {
return;
}

const loadOptions = this._sanitizeModelLoadOptions(this._loadedModelOptions || {});
const loadOptions = forceCpuMultimodalFallback
? this._createCpuSafeMultimodalLoadOptions(this._loadedModelOptions || {})
: this._sanitizeModelLoadOptions(this._loadedModelOptions || {});
const workerTimedOut = this._isWorkerTimeoutError(fallbackError);
const forcedCpuFallback = this._isForcedCpuMultimodalFallbackError(fallbackError);
const forceCpuMultimodalFallback =
this._hasMediaParts(options)
const shouldWarnCpuMultimodalFallback =
forceCpuMultimodalFallback
&& (this._isDispatchWorkgroupLimitError(fallbackError)
|| forcedCpuFallback)
&& Number(loadOptions.nGpuLayers) !== 0;
|| forcedCpuFallback
|| Number(this._loadedModelOptions?.nGpuLayers) !== 0);

if (forceCpuMultimodalFallback) {
if (shouldWarnCpuMultimodalFallback) {
loadOptions.nGpuLayers = 0;
if (Number.isFinite(loadOptions.nCtx) && Number(loadOptions.nCtx) > 4096) {
loadOptions.nCtx = 4096;
Expand Down
15 changes: 14 additions & 1 deletion src/llama_webgpu_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,14 @@ std::string normalize_media_markers(const std::string & prompt, const size_t med
replace_all_inplace(normalized, "<|image|>", marker);
replace_all_inplace(normalized, "<img>", marker);
replace_all_inplace(normalized, "<|img|>", marker);
replace_all_inplace(
normalized,
"<|vision_start|><|image_pad|><|vision_end|>",
marker);
replace_all_inplace(
normalized,
"<|vision_start|><|video_pad|><|vision_end|>",
marker);
replace_all_inplace(normalized, "<audio>", marker);
replace_all_inplace(normalized, "<|audio|>", marker);

Expand Down Expand Up @@ -861,7 +869,12 @@ int32_t next_token_impl() {
return 0;
}

g_last_piece = token_to_piece(token, true);
if (llama_vocab_is_control(g_state.vocab, token)) {
end_generation_state();
return 0;
}

g_last_piece = token_to_piece(token, false);
g_last_output += g_last_piece;

llama_token token_for_decode = token;
Expand Down