-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.js
More file actions
152 lines (125 loc) · 4.04 KB
/
main.js
File metadata and controls
152 lines (125 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import { addLog, getTrainingConfig } from './util.js'
import { ResNet10 } from './models.js'
import * as tf from '@tensorflow/tfjs'
/**
 * Dataset loading helper: fetch a JSON file and hand the parsed payload to
 * processData().
 *
 * Failures (network error, HTTP error status, invalid JSON, or an error thrown
 * by processData) are logged with console.error and are not rethrown.
 *
 * @param {string} jsonFile - URL of the JSON file to fetch.
 * @returns {Promise<object|undefined>} Whatever processData() returns, or
 *   undefined when loading failed.
 */
async function loadDataset(jsonFile) {
  try {
    const response = await fetch(jsonFile)
    // fetch() only rejects on network-level failures; an HTTP error status
    // still resolves, so it has to be turned into an error explicitly.
    if (!response.ok) {
      throw new Error(`Request for ${jsonFile} failed with status ${response.status}`)
    }
    const jsonData = await response.json()
    return processData(jsonData)
  } catch (error) {
    console.error('Error loading data:', error)
    return undefined
  }
}
/**
 * Data processing helper: split a list of records into two parallel lists.
 *
 * @param {Array<object>} jsonData - Records, each read for its `image` and
 *   `label` properties.
 * @returns {{images: Array, labels: Array}} The `image` values and the `label`
 *   values, in record order.
 */
function processData(jsonData) {
  // Collect one property from every record, keeping the original order.
  const pluck = (field) => jsonData.map((record) => record[field])
  const images = pluck('image')
  const labels = pluck('label')
  return { images, labels }
}
// globals
// Model instance; null until the "startTraining" click handler assigns a new ResNet10.
let model = null
// Result of loadDataset(); null until the "loadData" click handler assigns it.
let dataSet = null
// Page controls, looked up once by element id when this module is evaluated.
const loadDataBtn = document.getElementById('loadData')
const startTrainingBtn = document.getElementById('startTraining')
const dumpWeightsBtn = document.getElementById('dumpWeights')
// initial page loading
// Writes a single log entry when the DOMContentLoaded event fires.
function handlePageLoaded() {
  addLog('Page loaded.')
}
document.addEventListener('DOMContentLoaded', handlePageLoaded)
// handle dataloading
// Loads the dataset on click and unlocks training only when data was obtained.
loadDataBtn.addEventListener('click', async () => {
  addLog('Loading data...')
  dataSet = await loadDataset('/data/bloodmnist_test.json')
  // loadDataset() logs its own errors and resolves to undefined on failure,
  // so a missing result must not be reported as success or unlock training.
  if (dataSet == null) {
    addLog('Failed to load data.')
    startTrainingBtn.disabled = true
    return
  }
  addLog('Data loaded.')
  startTrainingBtn.disabled = false
})
// handle training
// Creates a fresh ResNet10, trains it on the loaded dataset, and enables the
// weight dump only when training finished without an error.
startTrainingBtn.addEventListener('click', async () => {
  // Lock the controls first so a second run cannot be started meanwhile.
  startTrainingBtn.disabled = true
  loadDataBtn.disabled = true
  try {
    // create model
    model = new ResNet10()
    // process config from html
    const trainingConfig = getTrainingConfig()
    await model.train(dataSet, trainingConfig)
  } catch (error) {
    // Every other addLog call in this file passes a single string, so the
    // failure reason is put into the message itself instead of a second
    // argument; the full error object still goes to the console.
    console.error('Error during training:', error)
    const reason = error instanceof Error ? error.message : String(error)
    addLog(`Error during training: ${reason}`)
    // Give the controls back so the user can reload data or retry.
    startTrainingBtn.disabled = false
    loadDataBtn.disabled = false
    return
  }
  dumpWeightsBtn.disabled = false
  addLog("Done.")
})
// handle weight dumping
// Serializes the weights returned by model.model.getWeights() and offers the
// resulting JSON text as a download named model_weights.json.
dumpWeightsBtn.addEventListener('click', async () => {
  addLog("Dumping weights...")
  let url = null
  let link = null
  try {
    const weights = model.model.getWeights()
    const weightString = await serializeMessage(weights)
    addLog(`Done processing. Downloading...`)
    const blob = new Blob([weightString], { type: 'application/json' })
    url = URL.createObjectURL(blob)
    // The download is started by clicking a temporary anchor that is
    // attached to the document for the duration of the click.
    link = document.createElement('a')
    link.href = url
    link.download = 'model_weights.json'
    document.body.appendChild(link)
    link.click()
  } catch (error) {
    // Report the failure instead of leaving an unhandled promise rejection.
    console.error('Error dumping weights:', error)
    const reason = error instanceof Error ? error.message : String(error)
    addLog(`Error dumping weights: ${reason}`)
  } finally {
    // Clean up the temporary anchor and object URL on every path.
    if (link !== null && link.parentNode !== null) {
      link.parentNode.removeChild(link)
    }
    if (url !== null) {
      URL.revokeObjectURL(url)
    }
  }
})
/**
 * Recursively replace every tf.Tensor inside a value with a plain descriptor
 * object so the result can be passed to JSON.stringify.
 *
 * A tensor becomes `{ "__tensor__": true, data, dtype, shape }`, where `data`
 * is the resolved value of `tensor.array()`. Arrays and non-null objects are
 * walked element by element / key by key; any other value is returned as is.
 *
 * @param {*} obj - Tensor, array, object, or primitive to convert.
 * @returns {Promise<*>} The converted structure.
 */
async function tensorToSerializable(obj) {
  if (obj instanceof tf.Tensor) {
    const data = await obj.array();
    return {
      "__tensor__": true,
      "data": data,
      "dtype": obj.dtype,
      "shape": obj.shape
    };
  }

  if (Array.isArray(obj)) {
    const pendingItems = obj.map(element => tensorToSerializable(element));
    return Promise.all(pendingItems);
  }

  // Primitives, functions and null carry no tensors; pass them through.
  if (obj === null || typeof obj !== "object") {
    return obj;
  }

  // Convert all property values concurrently, then pair them back with their keys.
  const pairs = Object.entries(obj);
  const convertedValues = await Promise.all(
    pairs.map(pair => tensorToSerializable(pair[1]))
  );
  return Object.fromEntries(
    pairs.map((pair, index) => [pair[0], convertedValues[index]])
  );
}
/**
 * Inverse of tensorToSerializable: rebuild tensors inside a parsed structure.
 *
 * Any object containing a "__tensor__" key is turned into
 * `tf.tensor(obj.data, obj.shape, obj.dtype)`. Arrays stay arrays and other
 * objects stay objects, with their elements / values converted recursively.
 * Everything else is returned unchanged.
 *
 * @param {*} obj - Parsed JSON value.
 * @returns {*} The same structure with tensor descriptors replaced by tensors.
 */
function serializableToTensor(obj) {
  // Arrays must be handled before the generic object branch: typeof [] is
  // "object", so checking objects first would rebuild arrays as plain
  // objects keyed by index.
  if (Array.isArray(obj)) {
    return obj.map(item => serializableToTensor(item));
  }
  if (typeof obj === "object" && obj !== null) {
    if ("__tensor__" in obj) {
      return tf.tensor(obj.data, obj.shape, obj.dtype);
    }
    return Object.fromEntries(
      Object.entries(obj).map(([key, value]) => [key, serializableToTensor(value)])
    );
  }
  return obj;
}
/**
 * Convert a message with tensorToSerializable() and render the result as
 * JSON text indented by two spaces.
 *
 * @param {*} message - Value to serialize.
 * @returns {Promise<string>} The JSON.stringify output for the converted value.
 */
async function serializeMessage(message) {
  return JSON.stringify(await tensorToSerializable(message), null, 2);
}
/**
 * Parse JSON text and rebuild any tensors in it via serializableToTensor().
 *
 * @param {string} jsonStr - JSON text, e.g. as produced by serializeMessage().
 * @returns {*} The parsed structure after serializableToTensor() has run on it.
 * @throws {SyntaxError} If jsonStr is not valid JSON.
 */
function deserializeMessage(jsonStr) {
  // Parse the parameter itself; the previous code read an undeclared `json`
  // identifier and threw a ReferenceError on every call.
  const serializableDict = JSON.parse(jsonStr);
  return serializableToTensor(serializableDict);
}