-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
106 lines (93 loc) · 4.9 KB
/
main.py
File metadata and controls
106 lines (93 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
import pickle

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# Application object that the route decorators below register endpoints on.
# The metadata passed here is what FastAPI shows in its generated API docs.
app = FastAPI(
    version="2.0.0",
    title="Customer Churn Prediction API",
    description="Predicts whether a telecom customer will churn based on their account profile.",
)
# ── Load model and feature list ───────────────────────────────────────────────
# NOTE(security): pickle.load executes arbitrary code embedded in the file.
# Only load artifacts produced by a trusted training run.
try:
    # Context managers close the file handles as soon as unpickling finishes,
    # instead of leaving them open until garbage collection.
    with open("churn_model.pkl", "rb") as model_file:
        model = pickle.load(model_file)
    with open("feature_columns.pkl", "rb") as columns_file:
        feature_columns = pickle.load(columns_file)
except FileNotFoundError as exc:
    # Fail fast at import time; chaining keeps the missing path in the traceback.
    raise RuntimeError("Model files not found. Run train_churn_model.py first.") from exc
# ── Request Schema (mirrors Telco dataset features) ───────────────────────────
class CustomerData(BaseModel):
    """Request body for /predict: one customer's already-encoded account profile.

    Every categorical field is an integer code; the allowed codes are listed in
    each field's description and enforced with ``ge``/``le`` bounds, so values
    outside the documented range are rejected during request validation
    instead of being passed to the model.
    """

    # Binary flags (0/1).
    gender: int = Field(..., ge=0, le=1, description="0 = Female, 1 = Male")
    SeniorCitizen: int = Field(..., ge=0, le=1, description="0 = No, 1 = Yes")
    Partner: int = Field(..., ge=0, le=1, description="0 = No, 1 = Yes")
    Dependents: int = Field(..., ge=0, le=1, description="0 = No, 1 = Yes")
    # Only the lower bound is enforced: the "0–72" in the description may be
    # the range seen in the training data rather than a hard limit.
    tenure: int = Field(..., ge=0, description="Months with company (0–72)")
    PhoneService: int = Field(..., ge=0, le=1, description="0 = No, 1 = Yes")
    # Three-level service codes (0..2).
    MultipleLines: int = Field(..., ge=0, le=2, description="0 = No, 1 = No phone, 2 = Yes")
    InternetService: int = Field(..., ge=0, le=2, description="0 = DSL, 1 = Fiber optic, 2 = No")
    OnlineSecurity: int = Field(..., ge=0, le=2, description="0 = No, 1 = No internet, 2 = Yes")
    OnlineBackup: int = Field(..., ge=0, le=2, description="0 = No, 1 = No internet, 2 = Yes")
    DeviceProtection: int = Field(..., ge=0, le=2, description="0 = No, 1 = No internet, 2 = Yes")
    TechSupport: int = Field(..., ge=0, le=2, description="0 = No, 1 = No internet, 2 = Yes")
    StreamingTV: int = Field(..., ge=0, le=2, description="0 = No, 1 = No internet, 2 = Yes")
    StreamingMovies: int = Field(..., ge=0, le=2, description="0 = No, 1 = No internet, 2 = Yes")
    Contract: int = Field(..., ge=0, le=2, description="0 = Month-to-month, 1 = One year, 2 = Two year")
    PaperlessBilling: int = Field(..., ge=0, le=1, description="0 = No, 1 = Yes")
    PaymentMethod: int = Field(..., ge=0, le=3, description="0 = Bank transfer, 1 = Credit card, 2 = Electronic check, 3 = Mailed check")
    # Monetary amounts: must be finite and non-negative.
    MonthlyCharges: float = Field(..., ge=0, allow_inf_nan=False, description="Monthly bill amount in USD")
    TotalCharges: float = Field(..., ge=0, allow_inf_nan=False, description="Total amount charged to date in USD")

    # Sample payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "gender": 1,
                "SeniorCitizen": 0,
                "Partner": 1,
                "Dependents": 0,
                "tenure": 5,
                "PhoneService": 1,
                "MultipleLines": 0,
                "InternetService": 1,
                "OnlineSecurity": 0,
                "OnlineBackup": 0,
                "DeviceProtection": 0,
                "TechSupport": 0,
                "StreamingTV": 0,
                "StreamingMovies": 0,
                "Contract": 0,
                "PaperlessBilling": 1,
                "PaymentMethod": 2,
                "MonthlyCharges": 70.35,
                "TotalCharges": 351.75,
            }
        }
    }
# ── Prediction Endpoint ───────────────────────────────────────────────────────
@app.post("/predict", summary="Predict customer churn")
def predict(customer: CustomerData):
    """Score a single customer and describe the churn risk.

    Args:
        customer: Validated request body holding the encoded customer features.

    Returns:
        A dict with ``prediction`` ("Will Churn" / "Will Stay"),
        ``churn_probability`` (rounded to 4 decimals) and ``risk_level``
        ("High" at >= 0.7, "Medium" at >= 0.4, otherwise "Low").

    Raises:
        HTTPException: status 500 when building the feature frame or running
            the model fails; ``detail`` carries the error text.
    """
    # Keep the try body limited to the steps that can actually fail.
    try:
        frame = pd.DataFrame([customer.model_dump()])
        frame = frame[feature_columns]  # Ensure column order matches training
        prediction = int(model.predict(frame)[0])
        # Column 1 of predict_proba is taken as the churn probability.
        # NOTE(review): assumes the model's classes_ are ordered [0, 1] — confirm.
        probability = round(float(model.predict_proba(frame)[0][1]), 4)
    except Exception as e:
        # Record the full traceback server-side; previously it was discarded
        # and only the message text survived in the response.
        logging.getLogger(__name__).exception("Churn prediction failed")
        # NOTE(review): detail still echoes the raw exception text to the
        # client for backward compatibility; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Thresholds are applied to the rounded probability that is returned.
    if probability >= 0.7:
        risk_level = "High"
    elif probability >= 0.4:
        risk_level = "Medium"
    else:
        risk_level = "Low"

    return {
        "prediction": "Will Churn" if prediction == 1 else "Will Stay",
        "churn_probability": probability,
        "risk_level": risk_level,
    }
# ── Health Check ──────────────────────────────────────────────────────────────
@app.get("/health", summary="Health check")
def health():
    """Return a fixed status payload: service state, model name and API version."""
    status_payload = dict(
        status="ok",
        model="RandomForestClassifier",
        version="2.0.0",
    )
    return status_payload
# ── Root ──────────────────────────────────────────────────────────────────────
@app.get("/", summary="API info")
def root():
    """Return the API name together with the paths of its main endpoints."""
    return dict(
        name="Customer Churn Prediction API",
        docs="/docs",
        predict="/predict",
        health="/health",
    )