Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 58 additions & 40 deletions discojs/src/models/onnx.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
import { AutoModelForCausalLM, PreTrainedModel, Tensor } from '@xenova/transformers';
import { Model } from './index.js';
import type { WeightsContainer } from '../index.js';
import { List } from 'immutable';
import type { CausalLMOutput} from '@xenova/transformers';
import type { GenerationConfig as TFJSGenerationConfig } from './gpt/config.js';
import { DefaultGenerationConfig } from './gpt/config.js';
import {
AutoModelForCausalLM,
PreTrainedModel,
Tensor,
} from "@xenova/transformers";
import { Model } from "./index.js";
import type { WeightsContainer } from "../index.js";
import { List } from "immutable";
import type { CausalLMOutput } from "@xenova/transformers";
import type { GenerationConfig as TFJSGenerationConfig } from "./gpt/config.js";
import { DefaultGenerationConfig } from "./gpt/config.js";
import type { Batched, DataFormat } from "../index.js";


export class ONNXModel extends Model<'text'> {
export class ONNXModel extends Model<"text"> {
private model: PreTrainedModel;

// Private on purpose: model loading is asynchronous, so instances must be
// created through the async factory `init_pretrained` rather than `new`.
private constructor(model: PreTrainedModel) {
super();
this.model = model;
}

static async init_pretrained(modelName = 'Xenova/gpt2'): Promise<ONNXModel> {
static async init_pretrained(modelName = "Xenova/gpt2"): Promise<ONNXModel> {
const model = await AutoModelForCausalLM.from_pretrained(modelName);
return new ONNXModel(model);
}

getConfig(): Record<string, unknown> {
get config(): Record<string, unknown> {
return this.model.config as Record<string, unknown>;
}

Expand All @@ -30,81 +33,96 @@ export class ONNXModel extends Model<'text'> {
options?: Partial<TFJSGenerationConfig>
): Promise<Batched<DataFormat.ModelEncoded["text"][1]>> {
const config = Object.assign({}, DefaultGenerationConfig, options);

return List(
await Promise.all(
batch.map(tokens => this.#predictSingle(tokens, config))
batch.map((tokens) => this.#predictSingle(tokens, config))
)
);
}


async #predictSingle(
tokens: DataFormat.ModelEncoded["text"][0],
config: TFJSGenerationConfig
): Promise<DataFormat.ModelEncoded["text"][1]> {
const contextLength = (this.model.config as { max_position_embeddings?: number }).max_position_embeddings ?? 1024;
const contextLength =
(this.model.config as { max_position_embeddings?: number })
.max_position_embeddings ?? 1024;
const truncated = tokens.slice(-contextLength).toArray();

if (truncated.length === 0) {
throw new Error('Token list is empty. Cannot run generate().');
throw new Error("Token list is empty. Cannot run generate().");
}

const input_ids = new Tensor('int64', truncated.map(BigInt), [1, truncated.length]);

const output = await this.model.generate(input_ids, {

const input_ids = new Tensor("int64", truncated.map(BigInt), [
1,
truncated.length,
]);

const output = (await this.model.generate(input_ids, {
max_new_tokens: 1,
temperature: config.temperature,
do_sample: config.doSample,
top_k: config.topk,
}) as number[][];

if (!Array.isArray(output) || output.length === 0 || !Array.isArray(output[0])) {
throw new Error('ONNX model.generate() did not return valid sequences.');
}

})) as number[][];

if (
!Array.isArray(output) ||
output.length === 0 ||
!Array.isArray(output[0])
) {
throw new Error("ONNX model.generate() did not return valid sequences.");
}

const predicted_id = output[0].at(-1) as number;
return Number(predicted_id);

}



async getLogits(batch: List<List<number>>): Promise<Tensor> {
const input_ids_array: number[][] = batch.toArray().map(seq => seq.toArray());
const input_ids_array: number[][] = batch
.toArray()
.map((seq) => seq.toArray());
const attention_mask_array: number[][] = input_ids_array.map(
(seq): number[] => new Array<number>(seq.length).fill(1)
);

const input_ids_flat = input_ids_array.flat();
const attention_mask_flat = attention_mask_array.flat();
const shape = [input_ids_array.length, input_ids_array[0].length];

// use BigInt for int64 compatibility
const input_ids = new Tensor('int64', input_ids_flat.map(BigInt), shape);
const attention_mask = new Tensor('int64', attention_mask_flat.map(BigInt), shape);
const input_ids = new Tensor("int64", input_ids_flat.map(BigInt), shape);
const attention_mask = new Tensor(
"int64",
attention_mask_flat.map(BigInt),
shape
);

// run model forward
const outputs = await this.model.forward({ input_ids, attention_mask }) as CausalLMOutput;
const outputs = (await this.model.forward({
input_ids,
attention_mask,
})) as CausalLMOutput;
return outputs.logits;
}

async *train(): AsyncGenerator<never, never> {
await Promise.resolve(); // dummy await
const yieldFlag = false;
if (yieldFlag) yield undefined as never; // satisfy 'require-yield'
throw new Error('Training not supported for ONNX models');
throw new Error("Training not supported for ONNX models");
}

get weights(): WeightsContainer {
throw new Error('Weights access not supported in ONNX models');
throw new Error("Weights access not supported in ONNX models");
}

set weights(_: WeightsContainer) {
throw new Error('Weights setting not supported in ONNX models');
throw new Error("Weights setting not supported in ONNX models");
}

[Symbol.dispose](): void {
// Dispose of the model to free up memory
void this.model.dispose();}
void this.model.dispose();
}
}
6 changes: 3 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions webapp/src/assets/svg/MessageArrow.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<template>
  <svg
    xmlns="http://www.w3.org/2000/svg"
    fill="currentColor"
    :class="customClass"
    :viewBox="viewBox"
  >
    <path
      d="M214.6 17.4c-12.5-12.5-32.8-12.5-45.3 0l-160 160c-12.5 12.5-12.5 32.8 0 45.3s32.8 12.5 45.3 0L160 117.3 160 488c0 17.7 14.3 32 32 32s32-14.3 32-32l0-370.7 105.4 105.4c12.5 12.5 32.8 12.5 45.3 0s12.5-32.8 0-45.3l-160-160z"
    />
  </svg>
</template>
<script lang="ts">
// Upward-arrow SVG icon. Sizing and coordinate system are exposed as props
// so callers can reuse the glyph at different scales.
export default {
  props: {
    // CSS classes applied to the <svg> root (defaults to Tailwind sizing).
    customClass: { type: String, default: "w-6 h-6" },
    // SVG viewBox for the icon's coordinate system.
    viewBox: { type: String, default: "-60 0 512 512" },
  },
};
</script>
23 changes: 23 additions & 0 deletions webapp/src/assets/svg/StopIcon.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<template>
  <svg
    xmlns="http://www.w3.org/2000/svg"
    fill="currentColor"
    :class="customClass"
    :viewBox="viewBox"
  >
    <path
      d="M320 576C461.4 576 576 461.4 576 320C576 178.6 461.4 64 320 64C178.6 64 64 178.6 64 320C64 461.4 178.6 576 320 576zM256 224L384 224C401.7 224 416 238.3 416 256L416 384C416 401.7 401.7 416 384 416L256 416C238.3 416 224 401.7 224 384L224 256C224 238.3 238.3 224 256 224z"
    />
  </svg>
</template>
<script lang="ts">
// Circled-square "stop" SVG icon. Sizing and coordinate system are exposed
// as props so callers can reuse the glyph at different scales.
export default {
  props: {
    // CSS classes applied to the <svg> root (defaults to Tailwind sizing).
    customClass: { type: String, default: "w-7 h-7" },
    // SVG viewBox for the icon's coordinate system.
    viewBox: { type: String, default: "0 0 640 640" },
  },
};
</script>
3 changes: 3 additions & 0 deletions webapp/src/components/testing/Benchmarcks.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<template>
<!-- Benchmarks testing page: currently renders only a heading. -->
<h1>Benchmarks</h1>
</template>
Loading