Skip to content

Commit 793e06b

Browse files
authored
Merge pull request #927 from epfml/NAN-init_from_ONNX-christinakopi
ONNX to Tensorflow.js conversion of GPT-2
2 parents cf96721 + 9fbe6b3 commit 793e06b

29 files changed

Lines changed: 10453 additions & 132 deletions

cli/src/hellaswag_gpt.ts

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,95 @@
1+
import fsPromise from 'node:fs/promises';
2+
import { dirname } from 'path';
3+
import { fileURLToPath } from 'url';
4+
import { parse } from 'ts-command-line-args'
5+
16
import '@tensorflow/tfjs-node';
2-
import fs from 'node:fs';
37
import path from 'node:path';
4-
import { Tokenizer, models } from '@epfml/discojs';
8+
import { models, serialization, Tokenizer } from '@epfml/discojs';
59
import { loadHellaSwag } from '@epfml/discojs-node';
610

7-
const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
8-
const logLines: string[] = [];
9-
10-
function log(message: string) {
11-
console.log(message);
12-
logLines.push(message);
13-
}
11+
const __dirname = dirname(fileURLToPath(import.meta.url));
1412

15-
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)
16-
17-
async function evaluateTFJS(tokenizer: Tokenizer) {
18-
const model = new models.GPT({ seed: 42 });
19-
log('Evaluating TFJS GPT on HellaSwag...');
13+
async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) {
14+
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints)
15+
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
16+
console.log('Starting the HellaSwag benchmark...');
2017

2118
const start = Date.now();
22-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
19+
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true);
2320
const duration = ((Date.now() - start) / 1000).toFixed(2);
2421

25-
log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
26-
log(`TFJS GPT Evaluation Time: ${duration} seconds`);
22+
console.log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`);
23+
console.log(`Evaluation Time: ${duration} seconds`);
2724
}
2825

29-
async function evaluateXenova(tokenizer: Tokenizer) {
30-
const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
31-
log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
26+
const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const;
27+
type ModelType = typeof ModelTypes[number];
3228

33-
const start = Date.now();
34-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
35-
const duration = ((Date.now() - start) / 1000).toFixed(2);
29+
interface HellaSwagArgs {
30+
model: ModelType
31+
numDataPoints: number
32+
logFile: string
33+
pretrainedModelPath: string
34+
help?: boolean
35+
}
3636

37-
log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
38-
log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
37+
function castModelType(raw: string): ModelType {
38+
for (const t of ModelTypes) if (raw === t) return t
39+
throw new Error(`Invalid model type: ${raw}`)
3940
}
4041

4142
async function main(): Promise<void> {
42-
fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
43+
const args = parse<HellaSwagArgs>({
44+
model: {
45+
type: (raw: string) => castModelType(raw),
46+
description: `Model type, one of ${ModelTypes.toString()}`,
47+
defaultValue: 'onnx'
48+
},
49+
numDataPoints: {
50+
type: Number,
51+
description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark',
52+
defaultValue: -1
53+
},
54+
logFile: {
55+
type: String,
56+
description: 'Relative path to the log file, defaults to ./hellaswag.log', defaultValue: 'hellaswag.log'
57+
},
58+
pretrainedModelPath: {
59+
type: String,
60+
description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
61+
defaultValue: path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
62+
},
63+
help: {
64+
type: Boolean,
65+
optional: true,
66+
alias: 'h',
67+
description: 'Prints this usage guide'
68+
}
69+
}, { helpArg: 'help' })
4370

44-
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
45-
await evaluateTFJS(tokenizer);
46-
log('\n---\n');
47-
await evaluateXenova(tokenizer);
71+
let model: models.GPT | models.ONNXModel | undefined;
72+
switch (args.model) {
73+
case 'onnx':
74+
console.log("Using ONNX pretrained model Xenova/gpt2")
75+
model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
76+
break;
77+
case 'gpt-tfjs-random':
78+
console.log("Using GPT-TFJS with random initialization")
79+
model = new models.GPT({ seed: 42 });
80+
break;
81+
case 'gpt-tfjs-pretrained':
82+
console.log("Using GPT-TFJS with pretrained weights")
83+
if (args.pretrainedModelPath === undefined) {
84+
throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model via `pretrainedModelPath`")
85+
}
86+
const encodedModel = await fsPromise.readFile(args.pretrainedModelPath);
87+
model = await serialization.model.decode(encodedModel) as models.GPT;
88+
break;
89+
}
90+
await evaluateModel(model, args.numDataPoints);
4891

49-
fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
50-
console.log(`\nResults written to ${logFile}`);
92+
console.log("Benchmark completed!")
5193
}
5294

5395
main().catch(console.error);

datasets/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@
2020

2121
# GDHF demo
2222
/tinder_dog/
23+
24+
# HellaSwag benchmark
25+
hellaswag*

discojs-node/src/hellaswag.ts

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1-
import { models } from '@epfml/discojs';
1+
import path from "node:path";
22
import fetch from 'node-fetch';
3+
import fs from 'node:fs/promises';
4+
5+
import { models } from '@epfml/discojs';
6+
7+
import { dirname } from 'path';
8+
import { fileURLToPath } from 'url';
9+
const __dirname = dirname(fileURLToPath(import.meta.url));
10+
11+
const DATASET_DIR = path.join(__dirname, "..", "..", "datasets");
12+
const hellaswag_filepath = path.join(DATASET_DIR, "hellaswag_val.jsonl")
313

414
/**
515
* Loads the HellaSwag dataset from the remote URL in Node.js
@@ -8,12 +18,23 @@ import fetch from 'node-fetch';
818
* @returns A HellaSwagDataset containing the examples.
919
*/
1020
export async function load(limit = -1): Promise<models.HellaSwagDataset> {
11-
const response = await fetch(models.HELLASWAG_URL);
12-
if (!response.ok) {
13-
throw new Error(`Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`);
21+
let text: string;
22+
try {
23+
// Reads the file if it exists locally
24+
text = (await fs.readFile(hellaswag_filepath)).toString();
25+
} catch {
26+
console.log("Downloading the HellaSwag benchmark")
27+
// Otherwise fetch it
28+
const response = await fetch(models.HELLASWAG_URL);
29+
if (!response.ok) {
30+
throw new Error(`Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`);
31+
}
32+
33+
text = await response.text();
34+
// Save the file locally
35+
await fs.writeFile(hellaswag_filepath, text);
1436
}
15-
16-
const text = await response.text();
37+
1738
const lines = text.split('\n');
1839

1940
const dataset: models.HellaSwagDataset = [];

discojs/src/default_tasks/cifar10.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
1414
title: 'CIFAR10',
1515
summary: {
1616
preview: 'CIFAR-10 is a classic image classification task, and one of the most widely used datasets for machine learning research.',
17-
overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found <a class='underline text-blue-400' href='https://www.cs.toronto.edu/~kriz/cifar.html' target='_blank'>here</a>. You can find a link to a sample dataset at the next step (Connect Your Data)."
17+
overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found at https://www.cs.toronto.edu/~kriz/cifar.html . You can find a link to a sample dataset at the next step."
1818
},
19-
model: 'The model is a pretrained <a class="underline text-blue-400" target="_blank" href="https://github.com/tensorflow/tfjs-models/tree/master/mobilenet">MobileNetV1 model</a> trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
19+
model: 'The model is a pretrained MobileNetV1 model trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
2020
dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The CSV file should start with the exact header "filename,label", and each row should contain an image filename (without extension) and its label.<br><br> For example if you have images: 0.png (of a frog) and 1.png (of a car) <br> The CSV file should be: <br>filename, label <br><br> 0, frog <br> 1, car',
2121
dataExample:
2222
"https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png",

discojs/src/default_tasks/lus_covid.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ export const lusCovid: TaskProvider<"image", "federated"> = {
1212
title: 'Lung Ultrasound Image Classification',
1313
summary: {
1414
preview: "Medical images are a typical example of data that exists in huge quantity yet that can't be shared due to confidentiality reasons. Medical applications would immensely benefit from training on data currently locked. More data diversity leads to better generalization and bias mitigation.",
15-
overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. <br>Don't have a dataset of your own? You can find a link to a sample dataset at the next step."
15+
overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. You can find a link to a sample dataset at the next step."
1616
},
1717
model: "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 100x100 pixels and normalizes values between 0 and 1",
1818
dataFormatInformation: 'This model takes as input an image dataset of lung ultrasounds. The images are resized automatically.',

discojs/src/default_tasks/mnist.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ export const mnist: TaskProvider<"image", "decentralized"> = {
1212
title: 'Handwritten Digit Recognition',
1313
summary: {
1414
preview: "The MNIST handwritten digit classification problem is a classic dataset used in computer vision and deep learning. The objective is to classify handwritten digits from 28x28 pixel images.",
15-
overview: "Download the classic MNIST dataset of hand-written numbers <a class='underline text-blue-400' target='_blank' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. You can also find a sample dataset at the next step."
15+
overview: "Download the classic MNIST dataset of hand-written numbers at https://www.kaggle.com/scolianni/mnistasjpg . You can also find a sample dataset at the next step."
1616
},
1717
model: "The model is a simple Convolutional Neural Network composed of three convolutional layers with ReLU activations and max pooling layers, followed by two fully connected layers. The data preprocessing simply normalizes values between 0 and 1. The neural network is optimized via RMSProp and a categorical cross-entropy loss.",
1818
dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can connect your own images of each digit in the box corresponding to its label. The model takes images of size 28x28 as input.',

discojs/src/default_tasks/titanic.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ export const titanic: TaskProvider<"tabular", "federated"> = {
1212
title: 'Titanic Prediction',
1313
summary: {
1414
preview: "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.",
15-
overview: "The original competition can be found on <a target='_blank' class='underline text-blue-400' href='https://www.kaggle.com/c/titanic'>Kaggle</a> and a link to the training set can be found here <a target='_blank' class='underline text-blue-400' href='https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv'>here</a>."
15+
overview: "The original competition can be found on Kaggle (https://www.kaggle.com/c/titanic) and a link to the training set can be found here: https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv"
1616
},
1717
model: 'The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing value with a placeholder value (0).',
18-
dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br>The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked"<br>Each subsequent row contains passenger data.',
18+
dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc. The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked". Each subsequent row contains passenger data.',
1919
dataExample: [
2020
{ name: "PassengerId", data: "1" },
2121
{ name: "Survived", data: "0" },

discojs/src/default_tasks/wikitext.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export const wikitext: TaskProvider<"text", "federated"> = {
1010
title: "GPT Language Modeling",
1111
summary: {
1212
preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.',
13-
overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling, which you can download <a class='underline text-blue-400' target='_blank' href='https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'>here</a>. More information on how to connect the dataset at the next step."
13+
overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling. More information on how to connect the dataset at the next step."
1414
},
1515
model: [
1616
"The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js.",

discojs/src/models/gpt/config.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@ export type GPTConfig = {
1212
contextLength: number
1313
vocabSize?: number
1414
modelType: GPTModelType
15-
name?: string,
1615
evaluate?: boolean
1716
maxEvalBatches?: number
1817
evaluateEvery?: number
1918
maxIter?: number
2019
weightDecay?: number
2120
verbose?: 0 | 1
2221
debug?: boolean
23-
dropout?: number
22+
attnDrop?: number
2423
residDrop?: number
2524
embdDrop?: number
2625
nLayer?: number
@@ -30,7 +29,6 @@ export type GPTConfig = {
3029
}
3130
// for a benchmark of performance, see https://github.com/epfml/disco/pull/659
3231
export const DefaultGPTConfig: Required<GPTConfig> = {
33-
name: 'transformer', // prefix for the model layer names
3432
lr: 0.001,
3533
weightDecay: 0,
3634
maxIter: 10,
@@ -42,9 +40,9 @@ export const DefaultGPTConfig: Required<GPTConfig> = {
4240
contextLength: 128,
4341
vocabSize: 50257,
4442
debug: false,
45-
dropout: 0.2,
46-
residDrop: 0.2,
47-
embdDrop: 0.2,
43+
attnDrop: 0.1,
44+
residDrop: 0.1,
45+
embdDrop: 0.1,
4846
nLayer: 3,
4947
nHead: 3,
5048
nEmbd: 48,

discojs/src/models/gpt/layers.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ describe('GPT Layers', () => {
174174
name: 'testCSA',
175175
contextLength: 5,
176176
nHead: 2,
177-
nEmbd: 8, // divisible by nHead, so head size = 4
178-
dropout: 0.0, // no dropout for deterministic tests
177+
nEmbd: 8, // divisible by nHead, so head size = 4
178+
attnDrop: 0.0, // no dropout for deterministic tests
179+
residDrop: 0.0,
179180
nLayer: 2,
180181
seed: 42
181182
};

0 commit comments

Comments
 (0)