-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_predict_churn_data.py
More file actions
51 lines (41 loc) · 1.84 KB
/
test_predict_churn_data.py
File metadata and controls
51 lines (41 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pandas as pd
from pycaret.classification import predict_model, load_model
# Function to load data from a CSV file
def load_data(filepath):
    """
    Read the churn dataset at *filepath* into a DataFrame.

    The 'customerID' column is used as the index.
    """
    return pd.read_csv(filepath, index_col='customerID')
# Function to make predictions using a pre-trained model
def make_predictions(df, threshold=0.75):
    """
    Uses a loaded PyCaret model to make predictions on the dataframe.

    A customer is labeled as churned (1) when the model's probability for
    the positive class is greater than or equal to ``threshold``; otherwise 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature rows to score (indexed by customerID).
    threshold : float, optional
        Minimum churn probability required to predict churn. Default 0.75.

    Returns
    -------
    pandas.DataFrame
        Columns ['Score', 'Churn_prediction'] where 'Score' is the churn
        probability and 'Churn_prediction' is the thresholded 0/1 label.
    """
    # Load your pre-trained model
    model = load_model('lr')  # Ensure the 'lr' model file is in an accessible location
    # raw_score=True adds one probability column per class. The default
    # 'prediction_score' is the confidence of whichever label was predicted,
    # so thresholding it directly would flag ANY confident prediction --
    # including a confident "no churn" -- as churn. We must threshold the
    # positive-class probability instead.
    predictions = predict_model(model, data=df, raw_score=True)
    # NOTE(review): assumes the positive (churn) class is encoded as 1, so
    # its probability column is 'prediction_score_1' -- confirm against the
    # trained model's label encoding.
    predictions['Churn_prediction'] = (
        predictions['prediction_score_1'] >= threshold
    ).astype(int)
    predictions = predictions.rename(columns={'prediction_score_1': 'Score'})
    # Keep only the columns necessary for output
    return predictions[['Score', 'Churn_prediction']]
if __name__ == "__main__":
    # Load new data for prediction
    df_new = load_data('new_churn_data.csv')

    # Print the head of the data to confirm correct load
    print("Loaded New Data:")
    print(df_new.head())

    # Get and print predictions
    predictions = make_predictions(df_new)
    print('Predictions:')
    print(predictions)

    # True values for comparison (from the problem statement)
    true_values = [1, 0, 0, 1, 0]

    # Pair each predicted label with its known answer and score the matches
    predicted_labels = predictions['Churn_prediction'].tolist()
    pairs = list(zip(predicted_labels, true_values))
    accuracy = sum(guess == actual for guess, actual in pairs) / len(pairs)
    print(f"Prediction Accuracy: {accuracy:.2f}")