Skip to content

Commit bafd3a5

Browse files
committed
add fastAPI
1 parent fd6a478 commit bafd3a5

1 file changed

Lines changed: 316 additions & 0 deletions

File tree

main.py

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# app.py
2+
import os
3+
import json
4+
import numpy as np
5+
import pandas as pd
6+
import tensorflow as tf
7+
from fastapi import FastAPI, HTTPException
8+
from pydantic import BaseModel, Field
9+
from typing import List, Optional, Dict
10+
11+
from anime_recommendation_app.modeling.model import HybridRecommenderNet
12+
13+
ARTIFACT_DIR = os.getenv("ARTIFACT_DIR", "./models/artifacts")
14+
MODEL_PATH = os.getenv("MODEL_PATH", "./models/model.keras")
15+
ANIME_CSV = os.getenv("ANIME_CSV", "./data/raw/anime.csv")
16+
17+
USER_TO_ENC_PATH = os.path.join(ARTIFACT_DIR, "user_to_user_encoded.json")
18+
ANIME_TO_ENC_PATH = os.path.join(ARTIFACT_DIR, "anime_to_anime_encoded.json")
19+
GENRE_TO_ENC_PATH = os.path.join(ARTIFACT_DIR, "genre_to_genre_encoded.json")
20+
ANIME_ENC_TO_ID = os.path.join(ARTIFACT_DIR, "anime_encoded_to_anime.json")
21+
SCALE_PATH = os.path.join(ARTIFACT_DIR, "rating_scale.json")
22+
23+
app = FastAPI(title="Anime Hybrid Recommender API", version="1.0.0")
24+
25+
class PredictRequest(BaseModel):
26+
user_id: int
27+
anime_id: int
28+
29+
class PredictResponse(BaseModel):
30+
user_id: int
31+
anime_id: int
32+
predicted_score_0_1: float = Field(..., description="Model output in [0,1]")
33+
predicted_rating: Optional[float] = Field(None, description="Denormalized rating (e.g., 0–10)")
34+
35+
class RecommendRequest(BaseModel):
36+
user_id: Optional[int] = Field(None, description="Known user. If None, use cold-start via preferred_genres.")
37+
top_k: int = 10
38+
allowed_genres: Optional[List[str]] = None
39+
exclude_anime_ids: Optional[List[int]] = None
40+
only_type: Optional[str] = Field(None, description="e.g., 'TV', 'Movie'")
41+
42+
preferred_genres: Optional[List[str]] = None
43+
44+
class RecommendedItem(BaseModel):
45+
anime_id: int
46+
name: Optional[str]
47+
main_genre: Optional[str]
48+
predicted_score_0_1: float
49+
50+
class RecommendResponse(BaseModel):
51+
items: List[RecommendedItem]
52+
53+
@app.on_event("startup")
54+
def load_artifacts():
55+
global model, anime_df, user_to_enc, anime_to_enc, genre_to_enc, enc_to_anime, rating_scale
56+
57+
# Load model
58+
model = tf.keras.models.load_model(
59+
MODEL_PATH,
60+
custom_objects={'HybridRecommenderNet': HybridRecommenderNet}
61+
)
62+
63+
# Load data for lookups
64+
anime_df = pd.read_csv(ANIME_CSV)
65+
if "genre" not in anime_df.columns:
66+
anime_df["genre"] = "Unknown"
67+
anime_df["genre"] = anime_df["genre"].fillna("Unknown")
68+
anime_df["main_genre"] = anime_df["genre"].apply(
69+
lambda x: x.split(",")[0].strip() if isinstance(x, str) and x else "Unknown"
70+
)
71+
72+
# Load encoders
73+
with open(USER_TO_ENC_PATH, "r") as f:
74+
user_to_enc = {int(k): int(v) for k, v in json.load(f).items()}
75+
with open(ANIME_TO_ENC_PATH, "r") as f:
76+
anime_to_enc = {int(k): int(v) for k, v in json.load(f).items()}
77+
with open(GENRE_TO_ENC_PATH, "r") as f:
78+
genre_to_enc = json.load(f)
79+
80+
enc_to_anime = None
81+
if os.path.exists(ANIME_ENC_TO_ID):
82+
with open(ANIME_ENC_TO_ID, "r") as f:
83+
enc_to_anime = {int(k): int(v) for k, v in json.load(f).items()}
84+
85+
rating_scale = {"min": 0.0, "max": 10.0}
86+
if os.path.exists(SCALE_PATH):
87+
with open(SCALE_PATH, "r") as f:
88+
rating_scale = json.load(f)
89+
90+
global COLD_BASE_INDEX
91+
max_known = max(user_to_enc.values()) if len(user_to_enc) > 0 else -1
92+
COLD_BASE_INDEX = max_known + 1
93+
94+
95+
def encode_row(user_id: int, anime_id: int) -> np.ndarray:
96+
if user_id not in user_to_enc:
97+
raise KeyError("unknown_user")
98+
if anime_id not in anime_to_enc:
99+
raise KeyError("unknown_anime")
100+
101+
row = anime_df.loc[anime_df["anime_id"] == anime_id]
102+
if row.empty:
103+
raise KeyError("anime_not_found_in_master")
104+
main_genre = row.iloc[0]["main_genre"]
105+
if main_genre not in genre_to_enc:
106+
raise KeyError("unknown_genre")
107+
108+
user_code = user_to_enc[user_id]
109+
anime_code = anime_to_enc[anime_id]
110+
genre_code = genre_to_enc[main_genre]
111+
return np.array([[user_code, anime_code, genre_code]], dtype=np.int64)
112+
113+
114+
def denormalize(y_pred: float) -> float:
115+
return rating_scale["min"] + y_pred * (rating_scale["max"] - rating_scale["min"])
116+
117+
118+
def filter_candidate_anime(
119+
allowed_genres: Optional[List[str]], only_type: Optional[str], exclude_anime_ids: Optional[List[int]]
120+
) -> pd.DataFrame:
121+
df = anime_df
122+
if only_type:
123+
df = df[df["type"] == only_type]
124+
if allowed_genres and len(allowed_genres) > 0:
125+
df = df[df["main_genre"].isin(allowed_genres)]
126+
if exclude_anime_ids and len(exclude_anime_ids) > 0:
127+
df = df[~df["anime_id"].isin(exclude_anime_ids)]
128+
129+
df = df[df["anime_id"].isin(anime_to_enc.keys())]
130+
131+
df = df[df["main_genre"].isin(genre_to_enc.keys())]
132+
return df.copy()
133+
134+
@app.get("/health")
135+
def health():
136+
return {"status": "ok"}
137+
138+
@app.post("/predict", response_model=PredictResponse)
139+
def predict(req: PredictRequest):
140+
try:
141+
X = encode_row(req.user_id, req.anime_id)
142+
except KeyError as e:
143+
msg = str(e)
144+
if "unknown_user" in msg:
145+
raise HTTPException(status_code=400, detail="User not found in trained encoders.")
146+
if "unknown_anime" in msg or "anime_not_found_in_master" in msg:
147+
raise HTTPException(status_code=400, detail="Anime not found or not in trained encoders.")
148+
if "unknown_genre" in msg:
149+
raise HTTPException(status_code=400, detail="Anime main_genre not recognized by encoder.")
150+
raise
151+
152+
y_pred = float(model.predict(X, verbose=0).reshape(-1)[0]) # [0,1]
153+
out = PredictResponse(
154+
user_id=req.user_id,
155+
anime_id=req.anime_id,
156+
predicted_score_0_1=y_pred,
157+
predicted_rating=round(denormalize(y_pred), 3)
158+
)
159+
return out
160+
161+
@app.post("/recommend", response_model=RecommendResponse)
162+
def recommend(req: RecommendRequest):
163+
candidates = filter_candidate_anime(req.allowed_genres, req.only_type, req.exclude_anime_ids)
164+
165+
# If we have a known user, give recommendation based on preference (collaborative + content base)
166+
if req.user_id is not None and req.user_id in user_to_enc:
167+
user_code = user_to_enc[req.user_id]
168+
# Build [user_code, anime_code, genre_code] for all candidates
169+
anime_codes = candidates["anime_id"].map(anime_to_enc)
170+
genre_codes = candidates["main_genre"].map(genre_to_enc)
171+
X = np.column_stack([np.full(len(candidates), user_code, dtype=np.int64),
172+
anime_codes.values.astype(np.int64),
173+
genre_codes.values.astype(np.int64)])
174+
y_pred = model.predict(X, verbose=0).reshape(-1)
175+
candidates = candidates.assign(score=y_pred)
176+
top = candidates.sort_values("score", ascending=False).head(req.top_k)
177+
178+
else:
179+
# Cold-start: score by genre preference if provided. Otherwise, return popular/random
180+
if not req.preferred_genres:
181+
# simple neutral score = 0.5; you can plug a popularity prior
182+
candidates = candidates.assign(score=0.5)
183+
top = candidates.sample(n=min(req.top_k, len(candidates)), random_state=42)
184+
else:
185+
# simple heuristic: preferred genre gets 0.7, others 0.4
186+
candidates = candidates.assign(
187+
score=np.where(candidates["main_genre"].isin(req.preferred_genres), 0.7, 0.4)
188+
)
189+
top = candidates.sort_values("score", ascending=False).head(req.top_k)
190+
191+
items = [
192+
RecommendedItem(
193+
anime_id=int(r.anime_id),
194+
name=r.get("name") if "name" in r.index else None,
195+
main_genre=r.get("main_genre") if "main_genre" in r.index else None,
196+
predicted_score_0_1=float(r.score)
197+
)
198+
for _, r in top.iterrows()
199+
]
200+
return RecommendResponse(items=items)
201+
202+
# ---- Config for cold-start pool ----
203+
COLD_SLOTS = int(os.getenv("COLD_SLOTS", "1000"))
204+
COLD_BASE_INDEX = None # set at startup based on loaded encoders
205+
cold_slot_in_use: Dict[str, int] = {} # map session/user token -> reserved slot
206+
207+
class RatedItem(BaseModel):
208+
anime_id: int
209+
rating: float
210+
211+
class BootstrapRequest(BaseModel):
212+
session_key: str = Field(..., description="Your client session/user token")
213+
rated: List[RatedItem]
214+
top_k: int = 10
215+
allowed_genres: Optional[List[str]] = None
216+
only_type: Optional[str] = None
217+
218+
class BootstrapResponse(BaseModel):
219+
personalized_user_code: int
220+
items: List[RecommendedItem]
221+
222+
def get_or_assign_cold_slot(session_key: str) -> int:
223+
# reuse if already assigned this session
224+
if session_key in cold_slot_in_use:
225+
return cold_slot_in_use[session_key]
226+
# find next free slot
227+
for i in range(COLD_SLOTS):
228+
slot = COLD_BASE_INDEX + i
229+
if slot not in cold_slot_in_use.values():
230+
cold_slot_in_use[session_key] = slot
231+
return slot
232+
raise HTTPException(status_code=429, detail="No cold-start slots available right now.")
233+
234+
235+
236+
def _normalize(y: np.ndarray) -> np.ndarray:
237+
# assumes rating_scale["min"], ["max"]
238+
return (y - rating_scale["min"]) / max(1e-8, (rating_scale["max"] - rating_scale["min"]))
239+
240+
def _prepare_bootstrap_xy(rated: List[RatedItem], cold_user_code: int):
241+
xs, ys = [], []
242+
for r in rated:
243+
if r.anime_id not in anime_to_enc:
244+
# skip unknown items to the model
245+
continue
246+
row = anime_df.loc[anime_df["anime_id"] == r.anime_id]
247+
if row.empty:
248+
continue
249+
main_genre = row.iloc[0]["main_genre"]
250+
if main_genre not in genre_to_enc:
251+
continue
252+
xs.append([cold_user_code, anime_to_enc[r.anime_id], genre_to_enc[main_genre]])
253+
ys.append(r.rating)
254+
if not xs:
255+
raise HTTPException(status_code=400, detail="None of the provided anime exist in the trained encoders.")
256+
X = np.array(xs, dtype=np.int64)
257+
y = _normalize(np.array(ys, dtype=np.float32))
258+
return X, y
259+
260+
@app.post("/bootstrap_recommend", response_model=BootstrapResponse)
261+
def bootstrap_recommend(req: BootstrapRequest):
262+
# Pick or create a cold slot for this session
263+
cold_user_code = get_or_assign_cold_slot(req.session_key)
264+
265+
# Build training mini-batch from user’s rated items
266+
X, y = _prepare_bootstrap_xy(req.rated, cold_user_code)
267+
268+
# Freeze everything except user_embedding and user_bias
269+
for layer in model.layers:
270+
layer.trainable = False
271+
try:
272+
model.user_embedding.trainable = True
273+
model.user_bias.trainable = True
274+
except Exception:
275+
276+
for l in model.layers:
277+
if "user_embedding" in l.name or "user_bias" in l.name:
278+
l.trainable = True
279+
280+
# Quick personalization fit
281+
model.compile(
282+
loss=tf.keras.losses.MeanSquaredError(),
283+
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2), # faster convergence
284+
metrics=[tf.keras.metrics.RootMeanSquaredError()]
285+
)
286+
model.fit(X, y, batch_size=min(64, len(X)), epochs=5, verbose=0)
287+
288+
# Score candidates and return top-K (same filtering as /recommend)
289+
candidates = filter_candidate_anime(req.allowed_genres, req.only_type, exclude_anime_ids=None)
290+
if candidates.empty:
291+
raise HTTPException(status_code=404, detail="No candidates after filters.")
292+
293+
anime_codes = candidates["anime_id"].map(anime_to_enc)
294+
genre_codes = candidates["main_genre"].map(genre_to_enc)
295+
Xc = np.column_stack([
296+
np.full(len(candidates), cold_user_code, dtype=np.int64),
297+
anime_codes.values.astype(np.int64),
298+
genre_codes.values.astype(np.int64),
299+
])
300+
y_pred = model.predict(Xc, verbose=0).reshape(-1)
301+
candidates = candidates.assign(score=y_pred)
302+
top = candidates.sort_values("score", ascending=False).head(req.top_k)
303+
304+
items = [
305+
RecommendedItem(
306+
anime_id=int(r.anime_id),
307+
name=r.get("name") if "name" in r.index else None,
308+
main_genre=r.get("main_genre") if "main_genre" in r.index else None,
309+
predicted_score_0_1=float(r.score),
310+
)
311+
for _, r in top.iterrows()
312+
]
313+
return BootstrapResponse(
314+
personalized_user_code=cold_user_code,
315+
items=items
316+
)

0 commit comments

Comments
 (0)