-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvert_keras_model_to_numpy.py
More file actions
37 lines (25 loc) · 1.03 KB
/
convert_keras_model_to_numpy.py
File metadata and controls
37 lines (25 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import pickle
import numpy as np
import keras
from models import PointNetClassifierNumpy, numpy_layers
def convert_layer(keras_layer):
    """Build the NumPy counterpart of a single Keras layer.

    The counterpart class is looked up in ``numpy_layers`` under the same
    name as the Keras layer's class, and the Keras layer is then handed to
    that class's ``from_keras_layer`` to populate it.

    Args:
        keras_layer: Keras layer instance to convert.

    Returns:
        The object returned by ``from_keras_layer`` of the matching
        ``numpy_layers`` class.

    Raises:
        AttributeError: If ``numpy_layers`` has no attribute named after
            the Keras layer's class.
    """
    # The NumPy layers mirror the Keras class names, so the class name
    # itself serves as the lookup key.
    numpy_cls = getattr(numpy_layers, keras_layer.__class__.__name__)
    return numpy_cls.from_keras_layer(keras_layer)
if __name__ == "__main__":
    SAVE_DIR = "./models/saved/"

    # File name of the trained Keras model to convert (pick any model
    # saved under SAVE_DIR).
    trained_model = "PointNetClassifierModelNet40Full.keras"
    keras_path = os.path.join(SAVE_DIR, trained_model)
    keras_model = keras.saving.load_model(keras_path)

    # Convert layer by layer, keyed by the Keras layer name: those names
    # are forwarded as keyword arguments to the NumPy model below.
    numpy_model_layers = {}
    for layer in keras_model.layers:
        layer_name = layer.name
        numpy_model_layers[layer_name] = convert_layer(layer)
    numpy_model = PointNetClassifierNumpy(**numpy_model_layers)

    # Pickle the converted model into the same directory, deriving the
    # output file name from the Keras file name.
    model_name = trained_model.replace(".keras", "Numpy.pkl")
    output_path = os.path.join(SAVE_DIR, model_name)
    with open(output_path, "wb") as f:
        pickle.dump(numpy_model, f)