Skip to content

Commit 87ff2d6

Browse files
feat: add quantization script
1 parent ebfcf15 commit 87ff2d6

2 files changed

Lines changed: 39 additions & 0 deletions

File tree

nn.vim

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
" Treat *.nn files as JSON for syntax highlighting purposes: on entering a
" .nn buffer, source the built-in JSON syntax file.
autocmd BufEnter *.nn runtime! syntax/json.vim

scripts/quantize_network.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env python3
"""Round the weights of a JSON-serialized network model.

Usage:
    quantize_network.py input.json output.json [quantize_level]

The input JSON must contain ``parameters.weights``: a list of layers, each
layer a list of rows, each row a list of numbers.  The rounded weights are
written back to ``output.json`` together with a ``parameters.quantization``
metadata block.
"""
import json
import sys


def quantize_weights(weights, quantize_level):
    """Return *weights* with every value rounded to a fixed decimal count.

    Args:
        weights: list of layers; each layer is a list of rows of numbers.
        quantize_level: decimals to keep, or ``None`` for the default of 1.

    Returns:
        A new nested list of the same shape with rounded float values.
    """
    # Explicit None check: the old expression `quantize_level or 1` treated
    # a requested level of 0 (round to whole numbers) as falsy and silently
    # substituted 1 — that was a bug.  Hoisted out of the loop as well.
    decimals = 1 if quantize_level is None else quantize_level
    return [
        [[round(float(val), decimals) for val in row] for row in layer]
        for layer in weights
    ]


def main():
    """CLI entry point: parse argv, quantize the model, write it out."""
    if len(sys.argv) < 3:
        # Fixed: the previous message referred to "quantize_model.py",
        # which does not match this script's actual name.
        print("Usage: quantize_network.py input.json output.json [quantize_level]")
        sys.exit(1)

    input_file = sys.argv[1]
    output_file = sys.argv[2]
    quantize_level = int(sys.argv[3]) if len(sys.argv) >= 4 else None

    with open(input_file) as f:
        model = json.load(f)

    model["parameters"]["weights"] = quantize_weights(
        model["parameters"]["weights"], quantize_level
    )
    # NOTE(review): these labels are preserved byte-for-byte for output
    # compatibility, but the code above only performs decimal rounding —
    # no int8 scaling actually happens.  Confirm downstream expectations.
    model["parameters"]["quantization"] = {
        "type": "int8" if quantize_level is None else "rounding",
        "scheme": "symmetric" if quantize_level is None else "decimal",
        "per": "layer",
        "quantize_level": quantize_level,
    }

    with open(output_file, "w") as f:
        json.dump(model, f, indent=2)

    print("Quantization complete.")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)