-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
61 lines (45 loc) · 2.03 KB
/
train_model.py
File metadata and controls
61 lines (45 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
from pathlib import Path

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
# Hand-entered training set. The three lists are parallel: index i of each
# list describes the same house.
data_sqft = [800, 1200, 1500, 1800, 2000, 2200, 2400, 2600]
data_beds = [2, 3, 3, 4, 4, 5, 4, 5]
data_price = [
    150_000, 200_000, 250_000, 300_000,
    320_000, 360_000, 380_000, 400_000,
]

# Feature matrix: one [square_footage, bedrooms] row per house.
TRAIN_X = [list(row) for row in zip(data_sqft, data_beds)]
# Regression targets; this is the same list object as data_price, not a copy.
TRAIN_Y = data_price
def train_and_export(output_dir="Geviti_App/public"):
    """Fit the price model on the bundled data and export it as ONNX + JSON.

    Trains a degree-2 polynomial regression with non-negative coefficients on
    the module-level ``TRAIN_X`` / ``TRAIN_Y``, prints the training-set MAE
    and R² plus one sample prediction, then writes ``model.onnx`` and
    ``model.json`` into *output_dir*.

    Args:
        output_dir: Directory that receives both output files. It is created
            (including parents) when missing. A relative path is resolved
            against the current working directory. Defaults to
            ``"Geviti_App/public"``.

    Returns:
        The metadata dict that was serialized to ``model.json``.
    """
    # Single source of truth for values that appear both in the pipeline /
    # file names and in the exported metadata, so they cannot drift apart.
    poly_degree = 2
    out_dir = Path(output_dir)
    onnx_path = out_dir / "model.onnx"
    json_path = out_dir / "model.json"

    print("Training monotonic polynomial regression model...")
    # positive=True restricts every fitted coefficient to be >= 0. With
    # non-negative inputs each polynomial term is non-decreasing in both
    # features, which is what makes the prediction monotonic.
    model = Pipeline([
        ("poly", PolynomialFeatures(degree=poly_degree, include_bias=False)),
        ("lin", LinearRegression(positive=True)),
    ])
    model.fit(TRAIN_X, TRAIN_Y)

    # NOTE: these metrics are computed on the training data itself; there is
    # no held-out set, so they measure fit quality, not generalization.
    train_pred = model.predict(TRAIN_X)
    train_mae = mean_absolute_error(TRAIN_Y, train_pred)
    train_r2 = r2_score(TRAIN_Y, train_pred)
    print(f"Training MAE: ${train_mae:,.0f}")
    print(f"Training R²: {train_r2:.4f}")

    # Sanity-check prediction; the message is built from the same values that
    # are fed to the model so the two cannot disagree.
    sample_sqft, sample_beds = 1900, 3
    sample_pred = model.predict([[sample_sqft, sample_beds]])[0]
    print(
        f"Predicted price for {sample_sqft} sqft, {sample_beds} beds:",
        round(sample_pred),
    )

    print("\nExporting ONNX model...")
    # The exported graph takes one float tensor named "float_input" with a
    # dynamic batch dimension and 2 feature columns.
    initial_type = [("float_input", FloatTensorType([None, 2]))]
    onnx_model = convert_sklearn(model, initial_types=initial_type)

    # Create the target directory on demand instead of failing with
    # FileNotFoundError after the model has already been trained.
    out_dir.mkdir(parents=True, exist_ok=True)
    onnx_path.write_bytes(onnx_model.SerializeToString())
    print(f"ONNX model exported to {onnx_path.as_posix()}")

    model_metadata = {
        "model_type": "onnx_polynomial_linear_positive",
        "input_features": ["square_footage", "bedrooms"],
        "poly_degree": poly_degree,
        "model_file": onnx_path.name,
    }
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(model_metadata, f, indent=2)
    print(f"Model metadata exported to {json_path.as_posix()}")

    return model_metadata
# Script entry point: train and export only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    train_and_export()