diff --git a/discojs/src/models/onnx.ts b/discojs/src/models/onnx.ts index 8227b1ebd..449beb039 100644 --- a/discojs/src/models/onnx.ts +++ b/discojs/src/models/onnx.ts @@ -1,14 +1,17 @@ -import { AutoModelForCausalLM, PreTrainedModel, Tensor } from '@xenova/transformers'; -import { Model } from './index.js'; -import type { WeightsContainer } from '../index.js'; -import { List } from 'immutable'; -import type { CausalLMOutput} from '@xenova/transformers'; -import type { GenerationConfig as TFJSGenerationConfig } from './gpt/config.js'; -import { DefaultGenerationConfig } from './gpt/config.js'; +import { + AutoModelForCausalLM, + PreTrainedModel, + Tensor, +} from "@xenova/transformers"; +import { Model } from "./index.js"; +import type { WeightsContainer } from "../index.js"; +import { List } from "immutable"; +import type { CausalLMOutput } from "@xenova/transformers"; +import type { GenerationConfig as TFJSGenerationConfig } from "./gpt/config.js"; +import { DefaultGenerationConfig } from "./gpt/config.js"; import type { Batched, DataFormat } from "../index.js"; - -export class ONNXModel extends Model<'text'> { +export class ONNXModel extends Model<"text"> { private model: PreTrainedModel; private constructor(model: PreTrainedModel) { @@ -16,12 +19,12 @@ export class ONNXModel extends Model<'text'> { this.model = model; } - static async init_pretrained(modelName = 'Xenova/gpt2'): Promise { + static async init_pretrained(modelName = "Xenova/gpt2"): Promise { const model = await AutoModelForCausalLM.from_pretrained(modelName); return new ONNXModel(model); } - getConfig(): Record { + get config(): Record { return this.model.config as Record; } @@ -30,62 +33,76 @@ export class ONNXModel extends Model<'text'> { options?: Partial ): Promise> { const config = Object.assign({}, DefaultGenerationConfig, options); - + return List( await Promise.all( - batch.map(tokens => this.#predictSingle(tokens, config)) + batch.map((tokens) => this.#predictSingle(tokens, config)) ) ); } - async #predictSingle( tokens: DataFormat.ModelEncoded["text"][0], config: TFJSGenerationConfig ): Promise { - const contextLength = (this.model.config as { max_position_embeddings?: number }).max_position_embeddings ?? 1024; + const contextLength = + (this.model.config as { max_position_embeddings?: number }) + .max_position_embeddings ?? 1024; const truncated = tokens.slice(-contextLength).toArray(); - + if (truncated.length === 0) { - throw new Error('Token list is empty. Cannot run generate().'); + throw new Error("Token list is empty. Cannot run generate()."); } - - const input_ids = new Tensor('int64', truncated.map(BigInt), [1, truncated.length]); - - const output = await this.model.generate(input_ids, { + + const input_ids = new Tensor("int64", truncated.map(BigInt), [ + 1, + truncated.length, + ]); + + const output = (await this.model.generate(input_ids, { max_new_tokens: 1, temperature: config.temperature, do_sample: config.doSample, top_k: config.topk, - }) as number[][]; - - if (!Array.isArray(output) || output.length === 0 || !Array.isArray(output[0])) { - throw new Error('ONNX model.generate() did not return valid sequences.'); - } - + })) as number[][]; + + if ( + !Array.isArray(output) || + output.length === 0 || + !Array.isArray(output[0]) + ) { + throw new Error("ONNX model.generate() did not return valid sequences."); + } + const predicted_id = output[0].at(-1) as number; return Number(predicted_id); - } - - async getLogits(batch: List>): Promise { - const input_ids_array: number[][] = batch.toArray().map(seq => seq.toArray()); + const input_ids_array: number[][] = batch + .toArray() + .map((seq) => seq.toArray()); const attention_mask_array: number[][] = input_ids_array.map( (seq): number[] => new Array(seq.length).fill(1) ); - + const input_ids_flat = input_ids_array.flat(); const attention_mask_flat = attention_mask_array.flat(); const shape = [input_ids_array.length, input_ids_array[0].length]; - + // use BigInt for int64 compatibility - const input_ids = new Tensor('int64', input_ids_flat.map(BigInt), shape); - const attention_mask = new Tensor('int64', attention_mask_flat.map(BigInt), shape); + const input_ids = new Tensor("int64", input_ids_flat.map(BigInt), shape); + const attention_mask = new Tensor( + "int64", + attention_mask_flat.map(BigInt), + shape + ); // run model forward - const outputs = await this.model.forward({ input_ids, attention_mask }) as CausalLMOutput; + const outputs = (await this.model.forward({ + input_ids, + attention_mask, + })) as CausalLMOutput; return outputs.logits; } @@ -93,18 +110,19 @@ export class ONNXModel extends Model<'text'> { await Promise.resolve(); // dummy await const yieldFlag = false; if (yieldFlag) yield undefined as never; // satisfy 'require-yield' - throw new Error('Training not supported for ONNX models'); + throw new Error("Training not supported for ONNX models"); } get weights(): WeightsContainer { - throw new Error('Weights access not supported in ONNX models'); + throw new Error("Weights access not supported in ONNX models"); } set weights(_: WeightsContainer) { - throw new Error('Weights setting not supported in ONNX models'); + throw new Error("Weights setting not supported in ONNX models"); } [Symbol.dispose](): void { // Dispose of the model to free up memory - void this.model.dispose();} + void this.model.dispose(); + } } diff --git a/package-lock.json b/package-lock.json index 52b95cdc7..daa6993c8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -4789,9 +4789,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001727", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001727.tgz", - "integrity": "sha512-pB68nIHmbN6L/4C6MH1DokyR3bYqFwjaSs/sWDHGj4CTcFtQUQMuJftVwWkXq7mNWOybD3KhUv3oWHoGxgP14Q==", + "version": "1.0.30001764", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001764.tgz", + "integrity": "sha512-9JGuzl2M+vPL+pz70gtMF9sHdMFbY9FJaQBi186cHKH3pSzDvzoUJUPV6fqiKIMyXbud9ZLg4F3Yza1vJ1+93g==", "dev": true, "funding": [ { diff --git a/webapp/src/assets/svg/MessageArrow.vue b/webapp/src/assets/svg/MessageArrow.vue new file mode 100644 index 000000000..9db1de518 --- /dev/null +++ b/webapp/src/assets/svg/MessageArrow.vue @@ -0,0 +1,23 @@ + + diff --git a/webapp/src/assets/svg/StopIcon.vue b/webapp/src/assets/svg/StopIcon.vue new file mode 100644 index 000000000..72148ac0c --- /dev/null +++ b/webapp/src/assets/svg/StopIcon.vue @@ -0,0 +1,23 @@ + + diff --git a/webapp/src/components/testing/Benchmarcks.vue b/webapp/src/components/testing/Benchmarcks.vue new file mode 100644 index 000000000..1eeb5c8b3 --- /dev/null +++ b/webapp/src/components/testing/Benchmarcks.vue @@ -0,0 +1,3 @@ + diff --git a/webapp/src/components/testing/Chat.vue b/webapp/src/components/testing/Chat.vue new file mode 100644 index 000000000..16a553cf8 --- /dev/null +++ b/webapp/src/components/testing/Chat.vue @@ -0,0 +1,623 @@ + + + diff --git a/webapp/src/components/testing/ModelLibrary.vue b/webapp/src/components/testing/ModelLibrary.vue index e9219f356..fd3bb07b5 100644 --- a/webapp/src/components/testing/ModelLibrary.vue +++ b/webapp/src/components/testing/ModelLibrary.vue @@ -23,10 +23,15 @@ > @@ -163,6 +168,7 @@ import type { ModelID } from "@/store"; import { useModelsStore } from "@/store"; import { useTasksStore } from "@/store"; import { useValidationStore } from "@/store"; +import { useRouter } from "vue-router"; import ButtonsCard from "@/components/containers/ButtonsCard.vue"; import IconCard from "@/components/containers/IconCard.vue"; @@ -172,12 +178,14 @@ import DISCOllaboratives from "@/components/simple/DISCOllaboratives.vue"; import TestSteps from "./TestSteps.vue"; import PredictSteps from "./PredictSteps.vue"; +import { isDebuggerStatement } from "typescript"; const debug = createDebug("webapp:ModelLibrary"); const validationStore = useValidationStore(); const models = useModelsStore(); const { tasks } = storeToRefs(useTasksStore()); const toaster = useToaster(); +const router = useRouter(); type Selection = { mode: "predict" | "test"; @@ -307,4 +315,13 @@ function taskTitle(taskID: string): string | undefined { return titled.displayInformation.title; } + +const goToChat = (modelID: ModelID): void => { + validationStore.step = 0; + router.push({ path: "/chat", query: { modelID } }); +}; + +const goToBenchmarks = (): void => { + router.push({ path: "/benchmarks" }); +}; diff --git a/webapp/src/router/router.ts b/webapp/src/router/router.ts index 68d6dea7b..c07338b7a 100644 --- a/webapp/src/router/router.ts +++ b/webapp/src/router/router.ts @@ -1,16 +1,18 @@ import createDebug from "debug"; -import { createRouter, createWebHashHistory } from 'vue-router' +import { createRouter, createWebHashHistory } from "vue-router"; import { scrollToTop } from "@/utils"; -import TrainingBar from '@/components/progress_bars/TrainingBar.vue' -import TestingBar from '@/components/progress_bars/TestingBar.vue' -import HomePage from '@/components/pages/HomePage.vue' -import TaskCreationForm from '@/components/task_creation_form/TaskCreationForm.vue' -import TaskList from '@/components/pages/TaskList.vue' -import NotFound from '@/components/pages/NotFound.vue' -import Training from '@/components/training/TrainingSteps.vue' -import ModelLibrary from '@/components/testing/ModelLibrary.vue' -import AboutUs from '@/components/pages/AboutUs.vue' +import TrainingBar from "@/components/progress_bars/TrainingBar.vue"; +import TestingBar from "@/components/progress_bars/TestingBar.vue"; +import HomePage from "@/components/pages/HomePage.vue"; +import TaskCreationForm from "@/components/task_creation_form/TaskCreationForm.vue"; +import TaskList from "@/components/pages/TaskList.vue"; +import NotFound from "@/components/pages/NotFound.vue"; +import Training from "@/components/training/TrainingSteps.vue"; +import ModelLibrary from "@/components/testing/ModelLibrary.vue"; +import AboutUs from "@/components/pages/AboutUs.vue"; +import Chat from "@/components/testing/Chat.vue"; +import Benchmarks from "@/components/testing/Benchmarcks.vue"; const debug = createDebug("webapp:router"); @@ -21,70 +23,84 @@ const router = createRouter({ // Because router is wrapped in a BaseLayout, returning { top: 0 } doesn't do anything // https://github.com/vuejs/vue-router/issues/3451#issuecomment-975637797 scrollToTop(); - return { top: 0 } + return { top: 0 }; }, routes: [ { - path: '/', - name: 'HomePage', - component: HomePage + path: "/", + name: "HomePage", + component: HomePage, }, { - path: '/create', - name: 'task-creation-form', - component: TaskCreationForm + path: "/create", + name: "task-creation-form", + component: TaskCreationForm, }, { - path: '/about', - name: 'about', - component: AboutUs + path: "/about", + name: "about", + component: AboutUs, }, { - path: '/list', - name: 'task-list', + path: "/list", + name: "task-list", components: { default: TaskList, - ProgressBar: TrainingBar - } + ProgressBar: TrainingBar, + }, }, { - path: '/evaluate', - name: 'evaluate', + path: "/evaluate", + name: "evaluate", components: { default: ModelLibrary, - ProgressBar: TestingBar - } + ProgressBar: TestingBar, + }, }, { - path: '/:id', + path: "/:id", components: { default: Training, - ProgressBar: TrainingBar + ProgressBar: TrainingBar, }, props: { default: true, - ProgressBar: false - } + ProgressBar: false, + }, + }, + { + path: "/chat", + name: "chat", + components: { + default: Chat, + }, + }, + { + path: "/benchmarks", + name: "benchmarks", + components: { + default: () => Benchmarks, + }, }, { - path: '/:pathMatch(.*)*', - name: 'not-found', - component: NotFound + path: "/:pathMatch(.*)*", + name: "not-found", + component: NotFound, }, { - path: '/not-found', - name: 'not-found', - component: NotFound - } - ] -}) + path: "/not-found", + name: "not-found", + component: NotFound, + }, + ], +}); // Handle router errors router.onError((err) => { // Handle the router error here debug("router error: %o", err); // Add code for reporting or other error handling logic - void router.push({ path: '/not-found' }) -}) + void router.push({ path: "/not-found" }); +}); -export { router } +export { router };