Skip to content

Commit 7936c22

Browse files
committed
multiregression: Initial MicroPython code
1 parent f21d1ad commit 7936c22

1 file changed

Lines changed: 68 additions & 0 deletions

File tree

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
2+
"""MicroPython code for doing multi-output regression with emlearn_trees
3+
"""
4+
5+
import os
6+
7+
import emlearn_trees
8+
import array
9+
10+
class MultiRegressor():
11+
"""Convenience wrapper for a collection of tree-based regression models"""
12+
13+
def __init__(self, max_trees=10, max_nodes=1000, max_leaves=1000):
14+
self.models = []
15+
16+
self.max_trees = max_trees
17+
self.max_nodes = max_nodes
18+
self.max_leaves = max_leaves
19+
20+
# temporary buffer for invididual model output
21+
self._output = array.array('f', [0.0])
22+
23+
def load(self, path):
24+
"""Load a directory of model files"""
25+
26+
for filename in os.listdir(path):
27+
if not filename.endswith('.csv'):
28+
print('Warning: Ignoring unknown file in model directory', filename)
29+
continue
30+
31+
model_path = path + '/' + filename
32+
33+
# TODO: support reading neccesary capacity from file
34+
model = emlearn_trees.new(self.max_trees, self.max_nodes, self.max_leaves)
35+
36+
with open(model_path, 'r') as f:
37+
emlearn_trees.load_model(model, f)
38+
39+
self.models.append(model)
40+
41+
def predict(self, features : array.array, outputs : array.array):
42+
assert len(self.models), 'no models'
43+
44+
for i, model in self.models():
45+
model.predict(features, self._output)
46+
outputs[i] = self._output[0]
47+
48+
def main():
49+
50+
# FIXME: read paths from sys.argv
51+
model = MultiRegressor()
52+
model.load('models')
53+
54+
outputs = array.array('f', [0.0 for _ in range(model.models)])
55+
56+
import npyfile
57+
(n_samples, n_features), data = npyfile.load('data.npy')
58+
59+
# TODO: write output to a file
60+
for row in range(n_samples):
61+
offset = row*n_features
62+
f = data[offset:offset+n_features]
63+
model.predict(f, outputs)
64+
65+
66+
67+
if __name__ == '__main__':
68+
main()

0 commit comments

Comments
 (0)