-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
36 lines (27 loc) · 1.22 KB
/
train.py
File metadata and controls
36 lines (27 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from random import randrange
from model import LinUCB
from utils import load_data, preprocess_fast, preprocess_without_title, preprocess_with_title, get_next_song_context
def _read_rating(low: int = 1, high: int = 10):
    """Read a rating in ``[low, high]`` from stdin, re-prompting on bad input.

    Returns the rating as an ``int``, or ``None`` when the user asks to stop
    (types ``q``/``quit``/``exit``) or stdin reaches end-of-file, so the caller
    can leave its loop cleanly instead of crashing.
    """
    while True:
        try:
            raw = input().strip()
        except EOFError:
            # No more input (e.g. piped stdin exhausted or Ctrl-D): treat as "stop".
            return None
        if raw.lower() in {"q", "quit", "exit"}:
            return None
        try:
            rating = int(raw)
        except ValueError:
            print(f"'{raw}' is not a whole number; enter {low}-{high} or 'q' to quit:")
            continue
        if low <= rating <= high:
            return rating
        print(f"Rating must be between {low} and {high}; try again (or 'q' to quit):")


def main(data_path: str = "data/songs.csv", alpha: float = 0.9) -> None:
    """Run an interactive LinUCB training session driven by user ratings.

    Loads songs from ``data_path``, builds features with ``preprocess_fast``,
    then repeatedly asks the user to rate a song from 1 to 10 and feeds each
    rating back to the model as the reward. Entering ``q`` (or sending EOF)
    at the rating prompt ends the session.

    Args:
        data_path: CSV file passed to ``load_data``.
        alpha: Exploration parameter passed through to ``LinUCB``.
    """
    # Load and preprocess data
    print("preprocessing data...")
    data = load_data(data_path)
    features = preprocess_fast(data)

    # One action per row of the feature matrix; the column count is the
    # feature dimension handed to the model.
    num_actions = len(features)
    num_features = features.shape[1]
    linucb_model = LinUCB(num_actions=num_actions, num_features=num_features, alpha=alpha)

    # Start from a uniformly random song before the model has seen any feedback.
    chosen_action = randrange(num_actions)

    # Online training loop: continues until the user quits at the rating prompt.
    while True:
        title, context = get_next_song_context(data, features, chosen_action)
        chosen_action = linucb_model.choose_action(context)
        # NOTE(review): `title` is derived from `chosen_action` *before* it is
        # reassigned on the line above, yet the reward below is credited to the
        # newly chosen action. Confirm in utils.get_next_song_context that the
        # song being rated is really `chosen_action`; otherwise each rating is
        # attributed to the wrong arm.
        print(f"Please rate the song titled '{title}' on a scale from 1 to 10:")
        user_rating = _read_rating()
        if user_rating is None:
            print("Stopping training.")
            break
        # Update the model with the user feedback
        linucb_model.update_model(context=context, action=chosen_action, reward=user_rating)

    # Save the trained model for later use (optional)
    # linucb_model.save_model("trained_model.pkl")
if __name__ == "__main__":
    # Only start the interactive session when executed as a script, not on import.
    main()