From 73f174127752f6f2561db63b8d3a72dbaabe6d15 Mon Sep 17 00:00:00 2001 From: cttps <67242476+cttps@users.noreply.github.com> Date: Mon, 23 Sep 2024 15:54:35 -0400 Subject: [PATCH 01/12] Convert MLP to CNN on 16 categories --- .gitignore | 3 +- model/main.py | 220 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 177 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 46c21b8..131570e 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,5 @@ backend/app/out model/trainingdata *.pt *.onnx -*.bin \ No newline at end of file +*.bin +*.exe \ No newline at end of file diff --git a/model/main.py b/model/main.py index 56b2b1a..f1c228a 100644 --- a/model/main.py +++ b/model/main.py @@ -34,7 +34,7 @@ # to fix cairo bug & not have to run: # export DYLD_LIBRARY_PATH="/opt/homebrew/opt/cairo/lib:$DYLD_LIBRARY_PATH" # Should be macos only tho. thread: https://github.com/Kozea/CairoSVG/issues/354 -dyld.DEFAULT_LIBRARY_FALLBACK.append("/opt/homebrew/lib") +# dyld.DEFAULT_LIBRARY_FALLBACK.append("/opt/homebrew/lib") import cairocffi as cairo @@ -268,6 +268,8 @@ def simplify_strokes(input_strokes, epsilon=2.0): # plt.show() ################################ +torch.cuda.empty_cache() + labels = [] values_dict = {} @@ -277,25 +279,27 @@ def simplify_strokes(input_strokes, epsilon=2.0): print(f"Items: {items}") labels = trianing_data_names + for i,v in enumerate(trianing_data_names): - labels[i] = v.replace("full_binary_","").replace(".bin","") + # shortened_name = v.replace("full_binary_","").replace(".bin","") + labels[i] = v# shortened_name values_dict[v] = [] -# for item in values_dict.keys(): -# i = 0 -# for drawing in unpack_drawings(item): -# simplifiedVector = drawing["image"] -# raster = vector_to_raster([simplifiedVector])[0] -# values_dict[item].append(raster) -# i += 1 -# if i > 34999: -# break - +for item in values_dict.keys(): + i = 0 + for drawing in unpack_drawings(item): + simplifiedVector = drawing["image"] + raster = vector_to_raster([simplifiedVector])[0] + values_dict[item].append(raster) + i += 1 + if i > 1999: + break X = [] y = [] for key, value in enumerate(labels): + print(value+"\n") data_i = values_dict[value] Xi = np.concatenate([data_i], axis = 0) yi = np.full((len(Xi), 1), key).ravel() @@ -328,50 +332,147 @@ def view_images_grid(X, y): view_images_grid(X,y) -class SimpleMLP(nn.Module): - def __init__(self, input_size, hidden_sizes, output_size): - super(SimpleMLP, self).__init__() - self.fc1 = nn.Linear(input_size, hidden_sizes[0]) - self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1]) - self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2]) - self.fc4 = nn.Linear(hidden_sizes[2], hidden_sizes[3]) - self.fc5 = nn.Linear(hidden_sizes[3], output_size) +print(torch.cuda.is_available()) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# class SimpleMLP(nn.Module): +# def __init__(self, input_size, hidden_sizes, output_size): +# super(SimpleMLP, self).__init__() +# self.fc1 = nn.Linear(input_size, hidden_sizes[0]) +# self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1]) +# self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2]) +# self.fc4 = nn.Linear(hidden_sizes[2], hidden_sizes[3]) +# self.fc5 = nn.Linear(hidden_sizes[3], output_size) + +# def forward(self, x): +# x = F.relu(self.fc1(x)) +# x = F.relu(self.fc2(x)) +# x = F.relu(self.fc3(x)) +# x = F.relu(self.fc4(x)) +# x = self.fc5(x) +# return x + +class SimpleCNN(nn.Module): + def __init__(self, num_classes): + super(SimpleCNN, 
self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1) + self.pool = nn.MaxPool2d(2, 2) + self.fc1 = nn.Linear(64 * 3 * 3, 128) + self.fc2 = nn.Linear(128, num_classes) + self.dropout = nn.Dropout(0.5) def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = self.pool(F.relu(self.conv3(x))) + x = x.view(-1, 64 * 3 * 3) x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = F.relu(self.fc3(x)) - x = F.relu(self.fc4(x)) - x = self.fc5(x) + x = self.dropout(x) + x = self.fc2(x) return x +# reshape data for the cnn +def reshape_for_cnn(X, y): + X = X.reshape(-1, 1, 28, 28) + X = torch.from_numpy(X).float() + y = torch.from_numpy(y).long() + return X, y + + # Define training and evaluation functions def train_model(model, X_train, y_train, epochs=10, learning_rate=0.01): criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=learning_rate) + # move data to same device as model + device = next(model.parameters()).device + X_train = torch.from_numpy(X_train).float().to(device) + y_train = torch.from_numpy(y_train).long().to(device) + model.train() for epoch in range(epochs): optimizer.zero_grad() - outputs = model(torch.from_numpy(X_train).float()) - loss = criterion(outputs, torch.from_numpy(y_train).long()) + outputs = model(X_train) + loss = criterion(outputs, y_train) loss.backward() optimizer.step() print(f'Epoch {epoch}, Loss: {loss.item():.4f}') +def train_cnn(model, X_train, y_train, X_val, y_val, epochs=10, learning_rate=0.001): + model = model.to(device) + X_train = X_train.reshape(-1, 1, 28, 28) + X_train, y_train = torch.from_numpy(X_train).float().to(device), torch.from_numpy(y_train).long().to(device) + X_val = X_val.reshape(-1, 1, 28, 28) + X_val, y_val = torch.from_numpy(X_val).float().to(device), torch.from_numpy(y_val).long().to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + outputs = model(X_train) + loss = criterion(outputs, y_train) + loss.backward() + optimizer.step() + + # Validation + model.eval() + with torch.no_grad(): + val_outputs = model(X_val) + val_loss = criterion(val_outputs, y_val) + _, predicted = torch.max(val_outputs.data, 1) + correct = (predicted == y_val).sum().item() + total = y_val.size(0) + + print(f'Epoch {epoch+1}, Train Loss: {loss.item():.4f}, ' + f'Val Loss: {val_loss.item():.4f}, ' + f'Val Accuracy: {100 * correct / total:.2f}%') + def evaluate_model(model, X_train, y_train, X_test, y_test): model.eval() + device = next(model.parameters()).device + with torch.no_grad(): - outputs1 = model(torch.from_numpy(X_train).float()) - _, predicted1 = torch.max(outputs1, 1) - accuracy = (predicted1.numpy() == y_train).mean() - print(f'Train Accuracy: {accuracy:.4f}') + # Move data to the same device as the model + X_train = torch.from_numpy(X_train).float().to(device) + y_train = torch.from_numpy(y_train).long().to(device) + X_test = torch.from_numpy(X_test).float().to(device) + y_test = torch.from_numpy(y_test).long().to(device) + + outputs_train = model(X_train) + _, predicted_train = torch.max(outputs_train, 1) + accuracy_train = (predicted_train == y_train).float().mean().item() + print(f'Train Accuracy: {accuracy_train:.4f}') + + outputs_test = model(X_test) + _, predicted_test = torch.max(outputs_test, 1) + 
accuracy_test = (predicted_test == y_test).float().mean().item() + print(f'Test Accuracy: {accuracy_test:.4f}') + + return accuracy_test + +def evaluate_cnn(model, X_test, y_test): + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.eval() + # X_test, y_test = X_test.to(device), y_test.to(device) + X_test, y_test = reshape_for_cnn(X_test,y_test) + X_test = X_test.to(device) + y_test = y_test.to(device) - outputs = model(torch.from_numpy(X_test).float()) - _, predicted = torch.max(outputs, 1) - accuracy = (predicted.numpy() == y_test).mean() - print(f'Test Accuracy: {accuracy:.4f}') + with torch.no_grad(): + outputs = model(X_test) + _, predicted = torch.max(outputs.data, 1) + correct = (predicted == y_test).sum().item() + total = y_test.size(0) + + accuracy = 100 * correct / total + print(f'Test Accuracy: {accuracy:.2f}%') return accuracy def view_img(raster): @@ -380,7 +481,8 @@ def view_img(raster): plt.show() def get_pred(model, raster): - raster_tensor = torch.tensor(raster, dtype=torch.float).unsqueeze(0) + device = next(model.parameters()).device # i dont know what this does but yeah + raster_tensor = torch.tensor(raster, dtype=torch.float).unsqueeze(0).to(device) # use device for this model.eval() with torch.no_grad(): outputs = model(raster_tensor) @@ -389,40 +491,68 @@ def get_pred(model, raster): predicted_label = predicted.item() return labels[predicted_label] -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1) +def get_pred_cnn(model, raster): + device = next(model.parameters()).device # i dont know what this does but yeah + raster_tensor = torch.tensor(raster, dtype=torch.float).unsqueeze(0).to(device) + raster_tensor = raster_tensor.reshape(1, 1, 28, 28) + model.eval() + with torch.no_grad(): + outputs = model(raster_tensor) + _, predicted = torch.max(outputs, 1) + + predicted_label = predicted.item() + return labels[predicted_label] + + +# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) +X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42) + +train_loader = reshape_for_cnn(X_train, y_train) +val_loader = reshape_for_cnn(X_val, y_val) +test_loader = reshape_for_cnn(X_test, y_test) # Set up and train the model input_size = 784 hidden_sizes = [600, 400, 160, 80] output_size = items -model = SimpleMLP(input_size, hidden_sizes, output_size) +# model = SimpleMLP(input_size, hidden_sizes, output_size).to(device) # use device +model = SimpleCNN(num_classes=len(labels)) # 16 cats + +# train_model(model, X_train, y_train, epochs=200, learning_rate=0.025) +train_cnn(model, X_train, y_train, X_val, y_val, epochs=100, learning_rate=0.001) -train_model(model, X_train, y_train, epochs=200, learning_rate=0.03) +# use sample input instead of 'torch.tensor(X_train[0], dtype=torch.float).unsqueeze(0)' to keep device consistent +# sampleinput = torch.randn(1, input_size, dtype=torch.float).to(device) +sampleinput = torch.randn(1, 1, 28, 28).to(device) -torch.save(model, "model4_3_large.pt") -torch.onnx.export(model, torch.tensor(X_train[0], dtype=torch.float).unsqueeze(0), - "model4_3_large.onnx", export_params=True, do_constant_folding=True, +modelname = "CNN_cat16_v6-0_large_gputrain" +torch.save(model, modelname+".pt") +torch.onnx.export(model, sampleinput, + modelname+".onnx", export_params=True, do_constant_folding=True, input_names = ['input'], 
output_names = ['output']) # model = torch.load("model3_1_large.pt") # Evaluate the model -evaluate_model(model, X_train, y_train, X_test, y_test) +# evaluate_model(model, X_train, y_train, X_test, y_test) +evaluate_cnn(model,X_test,y_test) model.eval() with torch.no_grad(): - evaluate_model(model, X_train, y_train, X_test, y_test) + # evaluate_model(model, X_train, y_train, X_test, y_test) + evaluate_cnn(model,X_test,y_test) rawStrokes = svg_to_strokes(svg_string1) reformattedStrokes = svg_strokes_reformat(rawStrokes) simplifiedVector = simplify_strokes(reformattedStrokes) raster = vector_to_raster([simplifiedVector])[0] - print(f"PREDICTED DRAWING: {get_pred(model,raster)}") + print(f"PREDICTED DRAWING: {get_pred_cnn(model,raster)}") view_img(raster) for i in range(61, 9771, 299): - print(f"Actual: {labels[y_train[i]]}, Pred: {get_pred(model,X_train[i])}") + print(f"Actual: {labels[y_train[i]]}, Pred: {get_pred_cnn(model,X_train[i])}") view_img(X_train[i]) # test_model(model,img, y_train[i]) From 0d394e29640bec9b630908381f5a0c2e89085b6a Mon Sep 17 00:00:00 2001 From: cttps <67242476+cttps@users.noreply.github.com> Date: Fri, 27 Sep 2024 17:22:16 -0400 Subject: [PATCH 02/12] Use CNN model for frontend, update websocket loop to not use player id --- backend/app.log | 1 + backend/app/main.py | 2 +- deploy/nginx/nginx.conf | 2 +- frontend/src/components/DrawCanvas.tsx | 59 +++++++++++++++++--------- model/main.py | 3 +- 5 files changed, 45 insertions(+), 22 deletions(-) create mode 100644 backend/app.log diff --git a/backend/app.log b/backend/app.log new file mode 100644 index 0000000..fac191e --- /dev/null +++ b/backend/app.log @@ -0,0 +1 @@ +[16:48:43] {C:\Users\mypc\Desktop\quickdraw-vs\backend\app\main.py:62} INFO - Redis connection established diff --git a/backend/app/main.py b/backend/app/main.py index 8b7fa97..854938d 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -223,7 +223,7 @@ async def websocket_endpoint(websocket: WebSocket): await wait_pubsub_subscribe(f"game:{game_id}:channel", subs) await send_next_round(game_id, 0) await asyncio.gather( - websocket_loop(websocket, game_id, player_id, game_data, player_data), + websocket_loop(websocket, game_id, game_data, player_data), pubsub_loop(websocket, pubsub), ) diff --git a/deploy/nginx/nginx.conf b/deploy/nginx/nginx.conf index 551b569..5990926 100644 --- a/deploy/nginx/nginx.conf +++ b/deploy/nginx/nginx.conf @@ -25,7 +25,7 @@ server { proxy_set_header X-Forwarded-Proto $scheme; } - location /model3_4_large.onnx { + location /CNN_cat16_v6-0_large_gputrain.onnx { root /out; expires 7d; add_header Cache-Control "public"; diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index fe02194..5579edb 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -50,28 +50,29 @@ const DrawCanvas: React.FC = ({ const predDebounce = 750; const modelCategories = [ - "apple", - "anvil", - "dresser", - "broom", - "hat", - "camera", - "dog", - "basketball", - "pencil", - "hammer", - "hexagon", - "banana", - "angel", - "airplane", - "ant", - "paper clip", + 'airplane', + 'angel', + 'ant', + 'anvil', + 'apple', + 'banana', + 'basketball', + 'broom', + 'camera', + 'dog', + 'dresser', + 'hammer', + 'hat', + 'hexagon', + 'paper clip', + 'pencil' ]; + useEffect(() => { (async () => { try { - session.current = await InferenceSession.create("model3_4_large.onnx"); + session.current = await InferenceSession.create("CNN_cat16_v6-0_large_gputrain.onnx"); } catch 
(error) { // TODO: Handle this error properly console.error("Failed to load model", error); @@ -201,6 +202,23 @@ const DrawCanvas: React.FC = ({ rasterImage[i / 4] = imageData.data[i]; // Invert colors (black on white background) } + // const rasterImage: number[][][][] = new Array(1).fill(null).map(() => + // new Array(1).fill(null).map(() => + // new Array(side).fill(null).map(() => + // new Array(side).fill(0) + // ) + // ) + // ); + + // for (let y = 0; y < side; y++) { + // for (let x = 0; x < side; x++) { + // const i = (y * side + x) * 4; + // rasterImage[0][0][y][x] = imageData.data[i] / 255; // Normalize to 0-1 + // } + // } + + + return rasterImage; }; @@ -213,7 +231,7 @@ const DrawCanvas: React.FC = ({ const argMax = (arr: Float32Array): number => arr.indexOf(Math.max(...arr)); - async function ONNX(input: any) { + async function ONNX(input: number[]) { if (session.current === null) { console.error( "Attempted to run inference while InferenceSession is null" @@ -221,7 +239,10 @@ const DrawCanvas: React.FC = ({ return; } try { - const tensor = new Tensor("float32", new Float32Array(input), [1, 784]); + const flattenedInput = input.flat(3); + const tensor = new Tensor("float32", new Float32Array(flattenedInput), [1, 1, 28, 28]); + + // const tensor = new Tensor("float32", new Float32Array(input), [1, 784]); const inputMap = { input: tensor }; diff --git a/model/main.py b/model/main.py index f1c228a..ab67506 100644 --- a/model/main.py +++ b/model/main.py @@ -275,8 +275,9 @@ def simplify_strokes(input_strokes, epsilon=2.0): datapath = os.path.join(os.getcwd(), "model", "trainingdata") trianing_data_names = [f for f in listdir(datapath) if isfile(join(datapath, f))] +print(f"Item Order {trianing_data_names}") items = len(trianing_data_names) # Number of items -print(f"Items: {items}") +print(f"Item Count: {items}") labels = trianing_data_names From 6688e54feb3578f1a58ff5745873288673bfdfe5 Mon Sep 17 00:00:00 2001 From: cttps <67242476+cttps@users.noreply.github.com> Date: Fri, 27 Sep 2024 17:34:30 -0400 Subject: [PATCH 03/12] Update prediction to use timeout loop rather than drawing change --- frontend/src/components/DrawCanvas.tsx | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index 5579edb..2d40550 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -35,7 +35,6 @@ interface DrawCanvasProps { clearCanvas: boolean; } -let lastDrawn = Date.now(); const DrawCanvas: React.FC = ({ dataPass, onParentClearCanvas, @@ -47,7 +46,7 @@ const DrawCanvas: React.FC = ({ const isDrawing = useRef(false); const session = useRef(null); - const predDebounce = 750; + const predDebounce = 400; const modelCategories = [ 'airplane', @@ -79,7 +78,10 @@ const DrawCanvas: React.FC = ({ } })(); + let evalTimer = setTimeout(handleEvaluate,predDebounce); + return () => { + clearTimeout(evalTimer) session.current?.release(); }; }, []); @@ -99,12 +101,6 @@ const DrawCanvas: React.FC = ({ const lastLine = lines[lines.length - 1].concat([point]); setLines(lines.slice(0, -1).concat([lastLine])); } - - if (Date.now() - lastDrawn > predDebounce) { - // console.log("Evaluating drawing now"); - lastDrawn = Date.now(); - handleEvaluate(); - } }; const handleMouseUp = () => { From f8a0183a8cd351bfe56abbadc6d9d252d6ab6860 Mon Sep 17 00:00:00 2001 From: cttps <67242476+cttps@users.noreply.github.com> Date: Fri, 27 Sep 2024 17:44:04 -0400 Subject: [PATCH 04/12] Update 
to interval instead of timeout --- frontend/src/components/DrawCanvas.tsx | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index 2d40550..a85d6c4 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -8,6 +8,7 @@ import React, { import { Stage, Layer, Line } from "react-konva"; import { InferenceSession, Tensor } from "onnxruntime-web"; import { clear } from "console"; +import { clearInterval } from "timers"; interface AnimateProps { children: React.ReactNode; @@ -69,19 +70,20 @@ const DrawCanvas: React.FC = ({ useEffect(() => { + var evalTimer : NodeJS.Timeout (async () => { try { session.current = await InferenceSession.create("CNN_cat16_v6-0_large_gputrain.onnx"); + evalTimer = setInterval(handleEvaluate,predDebounce); + console.log("evaltimer!",evalTimer) } catch (error) { // TODO: Handle this error properly console.error("Failed to load model", error); } })(); - let evalTimer = setTimeout(handleEvaluate,predDebounce); - return () => { - clearTimeout(evalTimer) + clearInterval(evalTimer) session.current?.release(); }; }, []); @@ -291,7 +293,7 @@ const DrawCanvas: React.FC = ({ setConfidence(probPercent); if (probPercent > 70) { dataPass(prediction); - } + } }); }; From 2d5fcd9abc850758e0e2fde0be773aee769bd8c4 Mon Sep 17 00:00:00 2001 From: cttps <67242476+cttps@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:14:33 -0400 Subject: [PATCH 05/12] Update prediction loop to evaluate after debounce --- .gitignore | 3 +- backend/app.log | 1 - frontend/src/components/DrawCanvas.tsx | 78 ++++++++++++++++---------- 3 files changed, 51 insertions(+), 31 deletions(-) delete mode 100644 backend/app.log diff --git a/.gitignore b/.gitignore index 131570e..b3901fa 100644 --- a/.gitignore +++ b/.gitignore @@ -33,4 +33,5 @@ model/trainingdata *.pt *.onnx *.bin -*.exe \ No newline at end of file +*.exe +*.log \ No newline at end of file diff --git a/backend/app.log b/backend/app.log deleted file mode 100644 index fac191e..0000000 --- a/backend/app.log +++ /dev/null @@ -1 +0,0 @@ -[16:48:43] {C:\Users\mypc\Desktop\quickdraw-vs\backend\app\main.py:62} INFO - Redis connection established diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index a85d6c4..68a0499 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -50,40 +50,41 @@ const DrawCanvas: React.FC = ({ const predDebounce = 400; const modelCategories = [ - 'airplane', - 'angel', - 'ant', - 'anvil', - 'apple', - 'banana', - 'basketball', - 'broom', - 'camera', - 'dog', - 'dresser', - 'hammer', - 'hat', - 'hexagon', - 'paper clip', - 'pencil' + "airplane", + "angel", + "ant", + "anvil", + "apple", + "banana", + "basketball", + "broom", + "camera", + "dog", + "dresser", + "hammer", + "hat", + "hexagon", + "paper clip", + "pencil", ]; useEffect(() => { - var evalTimer : NodeJS.Timeout + (async () => { try { session.current = await InferenceSession.create("CNN_cat16_v6-0_large_gputrain.onnx"); - evalTimer = setInterval(handleEvaluate,predDebounce); - console.log("evaltimer!",evalTimer) } catch (error) { // TODO: Handle this error properly console.error("Failed to load model", error); } })(); + // const evalTimer = setInterval(() => { console.log("debug"); handleEvaluate()},predDebounce); + // console.log("evaltimer!",evalTimer); + return () => { - clearInterval(evalTimer) + // clearInterval(evalTimer) 
session.current?.release(); }; }, []); @@ -237,10 +238,8 @@ const DrawCanvas: React.FC = ({ return; } try { - const flattenedInput = input.flat(3); - const tensor = new Tensor("float32", new Float32Array(flattenedInput), [1, 1, 28, 28]); - - // const tensor = new Tensor("float32", new Float32Array(input), [1, 784]); + // const flattenedInput = input.flat(3); + const tensor = new Tensor("float32", new Float32Array(input), [1, 1, 28, 28]); const inputMap = { input: tensor }; @@ -248,7 +247,7 @@ const DrawCanvas: React.FC = ({ const output = outputMap["output"].data as Float32Array; - // console.log(output); + return output; } catch (error) { console.error("Error running ONNX model:", error); @@ -279,26 +278,47 @@ const DrawCanvas: React.FC = ({ return svgContent; }; + let lastDrawn= Date.now() + useEffect(() => { // evaluate on lines changed + if (lines.length > 0) { + let curDrawn = Date.now() + if (curDrawn - lastDrawn < predDebounce) { + handleEvaluate(); + } else { + let curLines = lines + setTimeout(() => { + if (lines === curLines) { // if lines have changed, then the drawing will get evaluated anyway + handleEvaluate(); + } + },predDebounce - (curDrawn - lastDrawn)) + } + } + }, [lines]); + const handleEvaluate = () => { const normalizedStrokes = normalizeStrokes(lines); const rasterArray = rasterizeStrokes(normalizedStrokes); + ONNX(rasterArray).then((res) => { - // console.log(res); res = res as Float32Array; let i = argMax(res); - setPrediction(modelCategories[i]); let prob = softmax(res)[i]; let probPercent = Math.floor(prob * 1000) / 10; + + setPrediction(modelCategories[i]); + setConfidence(probPercent); + + if (probPercent > 70) { + dataPass(prediction); } }); }; - useEffect(() => { - // effect to check if clearCanvas is true + useEffect(() => { // effect to check if clearCanvas is true if (clearCanvas) { setLines([]); onParentClearCanvas(); // call the callback function to reset the state in parent component From e89871a2a36258954816b75096d3eb0188def4ad Mon Sep 17 00:00:00 2001 From: cttps <67242476+cttps@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:25:24 -0400 Subject: [PATCH 06/12] Add gitignore --- .gitignore | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.gitignore b/.gitignore index 0e2a451..88a79af 100644 --- a/.gitignore +++ b/.gitignore @@ -33,9 +33,5 @@ model/trainingdata *.pt *.onnx *.bin -<<<<<<< HEAD *.exe *.log -======= -*.log ->>>>>>> main From 096247138b72d36a0d54446f8d3ce4d84452f41b Mon Sep 17 00:00:00 2001 From: mxpph <52426524+mxpph@users.noreply.github.com> Date: Thu, 3 Oct 2024 18:30:08 +0100 Subject: [PATCH 07/12] Update python backend version --- backend.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend.Dockerfile b/backend.Dockerfile index dd3b695..ea84450 100644 --- a/backend.Dockerfile +++ b/backend.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim +FROM python:3.12-slim WORKDIR /quickdraw COPY ./backend/requirements.txt . 
RUN pip install --no-cache-dir --upgrade -r ./requirements.txt From e0bf035566df48cf2758be34f55bbd988111099f Mon Sep 17 00:00:00 2001 From: mxpph <52426524+mxpph@users.noreply.github.com> Date: Thu, 3 Oct 2024 19:02:29 +0100 Subject: [PATCH 08/12] Change required confidence to 80% --- frontend/src/components/DrawCanvas.tsx | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index 68a0499..0a0bba8 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -2,12 +2,9 @@ import React, { useRef, useState, useEffect, - useImperativeHandle, - forwardRef, } from "react"; import { Stage, Layer, Line } from "react-konva"; import { InferenceSession, Tensor } from "onnxruntime-web"; -import { clear } from "console"; import { clearInterval } from "timers"; interface AnimateProps { @@ -68,9 +65,7 @@ const DrawCanvas: React.FC = ({ "pencil", ]; - useEffect(() => { - (async () => { try { session.current = await InferenceSession.create("CNN_cat16_v6-0_large_gputrain.onnx"); @@ -82,7 +77,7 @@ const DrawCanvas: React.FC = ({ // const evalTimer = setInterval(() => { console.log("debug"); handleEvaluate()},predDebounce); // console.log("evaltimer!",evalTimer); - + return () => { // clearInterval(evalTimer) session.current?.release(); @@ -208,7 +203,7 @@ const DrawCanvas: React.FC = ({ // ) // ) // ); - + // for (let y = 0; y < side; y++) { // for (let x = 0; x < side; x++) { // const i = (y * side + x) * 4; @@ -216,8 +211,6 @@ const DrawCanvas: React.FC = ({ // } // } - - return rasterImage; }; @@ -299,22 +292,16 @@ const DrawCanvas: React.FC = ({ const normalizedStrokes = normalizeStrokes(lines); const rasterArray = rasterizeStrokes(normalizedStrokes); - ONNX(rasterArray).then((res) => { res = res as Float32Array; let i = argMax(res); let prob = softmax(res)[i]; let probPercent = Math.floor(prob * 1000) / 10; - setPrediction(modelCategories[i]); - setConfidence(probPercent); - - - if (probPercent > 70) { - + if (probPercent > 80) { dataPass(prediction); - } + } }); }; @@ -335,10 +322,10 @@ const DrawCanvas: React.FC = ({
         {prediction && (
-            I guess... {confidence > 70 ? prediction : "not sure"}!
+            I guess... {confidence > 80 ? prediction : "not sure"}!
         )}
-        {/* confidence > 70 ? (
             Confidence (dev): {confidence + "%"}

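The confidence value compared against the new 80% threshold comes from running the raw logits returned by ONNX() through a softmax and taking the winning index, as handleEvaluate does above. The component's softmax helper is not shown in these diffs, so the sketch below assumes a standard numerically stable implementation; toConfidencePercent is a hypothetical name used only for illustration.

// Sketch only: assumed softmax implementation, mirroring the component's argMax logic.
const softmax = (arr: Float32Array): number[] => {
  const max = Math.max(...arr);                         // subtract the max for numerical stability
  const exps = Array.from(arr, (v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
};

const toConfidencePercent = (logits: Float32Array): { index: number; percent: number } => {
  const index = logits.indexOf(Math.max(...logits));    // same as the component's argMax helper
  const percent = Math.floor(softmax(logits)[index] * 1000) / 10; // one decimal place, e.g. 83.4
  return { index, percent };                            // percent > 80 gates dataPass(prediction)
};
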
From 6d755ea47f1aac64d8aeb7ada3b0d3be57b0c4cc Mon Sep 17 00:00:00 2001 From: c ttps <67242476+cttps@users.noreply.github.com> Date: Tue, 8 Oct 2024 08:54:56 -0400 Subject: [PATCH 09/12] Add proper debouncing, update dockerignore --- frontend/.dockerignore | 1 + frontend/package-lock.json | 22 ++++++++++++++++++++++ frontend/package.json | 2 ++ frontend/src/components/DrawCanvas.tsx | 23 ++++++++++------------- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/frontend/.dockerignore b/frontend/.dockerignore index 7cd53fd..e02706e 100644 --- a/frontend/.dockerignore +++ b/frontend/.dockerignore @@ -1 +1,2 @@ **/node_modules/ +**/out/ diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 87d41b4..6f8be6f 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -9,6 +9,7 @@ "version": "0.1.0", "dependencies": { "js-cookie": "^3.0.5", + "lodash.debounce": "^4.0.8", "next": "14.2.5", "onnxruntime-web": "^1.19.0", "react": "^18", @@ -18,6 +19,7 @@ }, "devDependencies": { "@types/js-cookie": "^3.0.6", + "@types/lodash.debounce": "^4.0.9", "@types/node": "^20", "@types/react": "^18", "@types/react-dom": "^18", @@ -506,6 +508,21 @@ "integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==", "dev": true }, + "node_modules/@types/lodash": { + "version": "4.17.10", + "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.10.tgz", + "integrity": "sha512-YpS0zzoduEhuOWjAotS6A5AVCva7X4lVlYLF0FYHAY9sdraBfnatttHItlWeZdGhuEkf+OzMNg2ZYAx8t+52uQ==", + "dev": true + }, + "node_modules/@types/lodash.debounce": { + "version": "4.0.9", + "resolved": "https://registry.npmjs.org/@types/lodash.debounce/-/lodash.debounce-4.0.9.tgz", + "integrity": "sha512-Ma5JcgTREwpLRwMM+XwBR7DaWe96nC38uCBDFKZWbNKD+osjVzdpnUSwBcqCptrp16sSOLBAUb50Car5I0TCsQ==", + "dev": true, + "dependencies": { + "@types/lodash": "*" + } + }, "node_modules/@types/node": { "version": "20.14.13", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.14.13.tgz", @@ -3210,6 +3227,11 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/lodash.debounce": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/lodash.debounce/-/lodash.debounce-4.0.8.tgz", + "integrity": "sha512-FT1yDzDYEoYWhnSGnpE/4Kj1fLZkDFyqRb7fNt6FdYOSxlUWAtp42Eh6Wb0rGIv/m9Bgo7x4GhQbm5Ys4SG5ow==" + }, "node_modules/lodash.merge": { "version": "4.6.2", "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", diff --git a/frontend/package.json b/frontend/package.json index 267b409..4efb2bb 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,6 +10,7 @@ }, "dependencies": { "js-cookie": "^3.0.5", + "lodash.debounce": "^4.0.8", "next": "14.2.5", "onnxruntime-web": "^1.19.0", "react": "^18", @@ -19,6 +20,7 @@ }, "devDependencies": { "@types/js-cookie": "^3.0.6", + "@types/lodash.debounce": "^4.0.9", "@types/node": "^20", "@types/react": "^18", "@types/react-dom": "^18", diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index 0a0bba8..f940f5a 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -5,7 +5,8 @@ import React, { } from "react"; import { Stage, Layer, Line } from "react-konva"; import { InferenceSession, Tensor } from "onnxruntime-web"; -import { clearInterval } from "timers"; +// import { clearInterval } from "timers"; +import debounce from "lodash.debounce"; interface AnimateProps { children: 
React.ReactNode; @@ -271,24 +272,20 @@ const DrawCanvas: React.FC = ({ return svgContent; }; - let lastDrawn= Date.now() useEffect(() => { // evaluate on lines changed if (lines.length > 0) { - let curDrawn = Date.now() - if (curDrawn - lastDrawn < predDebounce) { - handleEvaluate(); - } else { - let curLines = lines - setTimeout(() => { - if (lines === curLines) { // if lines have changed, then the drawing will get evaluated anyway - handleEvaluate(); - } - },predDebounce - (curDrawn - lastDrawn)) - } + const debounceEval = debounce(handleEvaluate,predDebounce) + debounceEval() + + return () => { + debounceEval.cancel(); // cleanup on unmount + }; + } }, [lines]); const handleEvaluate = () => { + console.log("ah") const normalizedStrokes = normalizeStrokes(lines); const rasterArray = rasterizeStrokes(normalizedStrokes); From 5e226f296a2e15eebeeed638b554fb1e0d5d7418 Mon Sep 17 00:00:00 2001 From: c ttps <67242476+cttps@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:21:31 -0400 Subject: [PATCH 10/12] Use ref to defer eval with setTimeout --- frontend/src/components/DrawCanvas.tsx | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index f940f5a..bda0a5f 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -44,6 +44,7 @@ const DrawCanvas: React.FC = ({ const [confidence, setConfidence] = useState(0); const isDrawing = useRef(false); const session = useRef(null); + const [shouldReEval,setShouldReEval] = useState(false); const predDebounce = 400; @@ -272,14 +273,29 @@ const DrawCanvas: React.FC = ({ return svgContent; }; + const evalTimeoutRef = useRef(null); + useEffect(() => { // evaluate on lines changed if (lines.length > 0) { - const debounceEval = debounce(handleEvaluate,predDebounce) - debounceEval() + + if (lines.length > 0) { + if (evalTimeoutRef.current) { + clearTimeout(evalTimeoutRef.current); + } + evalTimeoutRef.current = setTimeout(handleEvaluate, predDebounce); + } return () => { - debounceEval.cancel(); // cleanup on unmount + if (evalTimeoutRef.current) { + clearTimeout(evalTimeoutRef.current); + } }; + // const debounceEval = debounce(handleEvaluate,predDebounce) + // debounceEval() + + // return () => { + // debounceEval.cancel(); // cleanup on unmount + // }; } }, [lines]); From 6e4bf4af2cb880d11494bf26b3c490ac255b019d Mon Sep 17 00:00:00 2001 From: c ttps <67242476+cttps@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:24:03 -0400 Subject: [PATCH 11/12] Update code for improved remote training --- model/main.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/model/main.py b/model/main.py index ab67506..16921c5 100644 --- a/model/main.py +++ b/model/main.py @@ -38,6 +38,10 @@ import cairocffi as cairo +# Reduce mem usage on remote training runs +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +torch.cuda.empty_cache() + def unpack_drawing(file_handle): key_id, = unpack('Q', file_handle.read(8)) country_code, = unpack('2s', file_handle.read(2)) @@ -333,7 +337,7 @@ def view_images_grid(X, y): view_images_grid(X,y) -print(torch.cuda.is_available()) +print(torch.cuda.is_available()) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") @@ -406,7 +410,7 @@ def train_model(model, X_train, y_train, epochs=10, learning_rate=0.01): def train_cnn(model, X_train, y_train, X_val, y_val, epochs=10, 
learning_rate=0.001): model = model.to(device) - X_train = X_train.reshape(-1, 1, 28, 28) + X_train = X_train.reshape(-1, 1, 28, 28) X_train, y_train = torch.from_numpy(X_train).float().to(device), torch.from_numpy(y_train).long().to(device) X_val = X_val.reshape(-1, 1, 28, 28) X_val, y_val = torch.from_numpy(X_val).float().to(device), torch.from_numpy(y_val).long().to(device) @@ -503,7 +507,7 @@ def get_pred_cnn(model, raster): predicted_label = predicted.item() return labels[predicted_label] - + # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) @@ -523,7 +527,7 @@ def get_pred_cnn(model, raster): # train_model(model, X_train, y_train, epochs=200, learning_rate=0.025) train_cnn(model, X_train, y_train, X_val, y_val, epochs=100, learning_rate=0.001) -# use sample input instead of 'torch.tensor(X_train[0], dtype=torch.float).unsqueeze(0)' to keep device consistent +# use sample input instead of 'torch.tensor(X_train[0], dtype=torch.float).unsqueeze(0)' to keep device consistent # sampleinput = torch.randn(1, input_size, dtype=torch.float).to(device) sampleinput = torch.randn(1, 1, 28, 28).to(device) From b771b07619385e8b0aacf3a28ac3d56460211db3 Mon Sep 17 00:00:00 2001 From: c ttps <67242476+cttps@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:27:10 -0400 Subject: [PATCH 12/12] Update categories for new model on frontend --- frontend/src/components/DrawCanvas.tsx | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/frontend/src/components/DrawCanvas.tsx b/frontend/src/components/DrawCanvas.tsx index bda0a5f..2392ef7 100644 --- a/frontend/src/components/DrawCanvas.tsx +++ b/frontend/src/components/DrawCanvas.tsx @@ -48,29 +48,12 @@ const DrawCanvas: React.FC = ({ const predDebounce = 400; - const modelCategories = [ - "airplane", - "angel", - "ant", - "anvil", - "apple", - "banana", - "basketball", - "broom", - "camera", - "dog", - "dresser", - "hammer", - "hat", - "hexagon", - "paper clip", - "pencil", - ]; + const modelCategories = ['The%20Eiffel%20Tower.bin', 'airplane.bin', 'alarm%20clock.bin', 'anvil.bin', 'apple.bin', 'axe.bin', 'banana.bin', 'bed.bin', 'bee.bin', 'birthday%20cake.bin', 'book.bin', 'brain.bin', 'broom.bin', 'bucket.bin', 'calculator.bin', 'camera.bin', 'carrot.bin', 'car.bin', 'clock.bin', 'chair.bin', 'cookie.bin', 'diamond.bin', 'donut.bin', 'door.bin', 'elephant.bin', 'eye.bin', 'fish.bin', 'giraffe.bin', 'hammer.bin', 'hat.bin', 'key.bin', 'knife.bin', 'leaf.bin', 'map.bin', 'microphone.bin', 'mug.bin', 'mushroom.bin', 'nose.bin', 'palm%20tree.bin', 'pants.bin', 'paper%20clip.bin', 'peanut.bin', 'pillow.bin', 'rabbit.bin', 'river.bin'] useEffect(() => { (async () => { try { - session.current = await InferenceSession.create("CNN_cat16_v6-0_large_gputrain.onnx"); + session.current = await InferenceSession.create("CNN_cat45_v6-1_large_gputrain.onnx"); } catch (error) { // TODO: Handle this error properly console.error("Failed to load model", error);
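
Patch 12 ships the raw training-file names (URL-encoded, still carrying the .bin suffix) as the frontend category list, which matches patch 01's change in model/main.py that stopped stripping "full_binary_" and ".bin" from the labels. If these strings are ever surfaced to players through setPrediction, a small helper along the following lines could restore readable names; toDisplayName is hypothetical and not part of the patches.

// Hypothetical helper (not in the patches): turn a raw label such as "paper%20clip.bin"
// or "full_binary_paper%20clip.bin" back into a display name.
const toDisplayName = (category: string): string =>
  decodeURIComponent(category)
    .replace(/^full_binary_/, "")  // prefix stripped by the original label clean-up in model/main.py
    .replace(/\.bin$/, "");        // training-file extension

// toDisplayName("paper%20clip.bin") === "paper clip"
// toDisplayName("The%20Eiffel%20Tower.bin") === "The Eiffel Tower"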