Skip to content

Commit ac31bb7

Browse files
authored
Merge pull request #20 from stanleygvi/weight_adjustment_ui
UI Improvements: Weight adjustments and expired session detection
2 parents f0aa2fd + d1fef84 commit ac31bb7

7 files changed

Lines changed: 1537 additions & 122 deletions

File tree

Backend/app.py

Lines changed: 137 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_spotify_redirect_uri,
2424
)
2525
from Backend.playlist_processing import process_all
26+
from Backend.grouping import normalize_feature_weights
2627
from Backend.helpers import generate_random_string
2728
from Backend.job_status_store import (
2829
set_job_state,
@@ -39,6 +40,7 @@
3940
get_spotify_redirect_uri,
4041
)
4142
from playlist_processing import process_all # type: ignore
43+
from grouping import normalize_feature_weights # type: ignore
4244
from helpers import generate_random_string # type: ignore
4345
from job_status_store import ( # type: ignore
4446
set_job_state,
@@ -120,9 +122,53 @@ def _missing_required_scopes() -> list[str]:
120122
JOB_STATUS_TTL_SECONDS = int(os.getenv("JOB_STATUS_TTL_SECONDS", "21600"))
121123

122124

123-
def get_auth_token_from_request():
124-
"""Return auth token from request cookies, falling back to server session."""
125-
return session.get("auth_token") or request.cookies.get("auth_token")
125+
def _clear_auth_session():
    """Drop all server-side auth state so a stale login is fully reset."""
    # Pop individual keys rather than session.clear(): only auth-related
    # values are removed, any unrelated session data survives.
    for stale_key in ("uid", "auth_token", "refresh_token", "auth_scopes"):
        session.pop(stale_key, None)
131+
132+
133+
def _unauthorized_session_response(message: str = "Spotify session expired. Please log in again."):
    """Build the standardized (payload, status) pair for a missing/expired auth session.

    The ``reauth`` flag tells the frontend it should send the user back
    through the Spotify login flow.
    """
    payload = {
        "Code": 401,
        "Error": message,
        "reauth": True,
    }
    return jsonify(payload), 401
145+
146+
147+
def _resolve_active_auth_token():
    """Return a usable Spotify access token for the current request.

    Returns a ``(auth_token, error_response_or_none)`` tuple. When the
    stored access token has expired, a silent refresh is attempted via
    the stored refresh token; if that also fails, the server-side auth
    session is wiped and a standardized 401 response is returned.
    """
    stored_token = session.get("auth_token")
    stored_refresh = session.get("refresh_token")

    if not stored_token:
        # No token at all: scrub any leftover state and demand a login.
        _clear_auth_session()
        return None, _unauthorized_session_response("Authorization required.")

    if is_access_token_valid(stored_token):
        return stored_token, None

    # Access token expired -- try a silent refresh before giving up.
    if stored_refresh:
        renewed = refresh_access_token(stored_refresh)
        if renewed:
            session["auth_token"] = renewed
            return renewed, None

    _clear_auth_session()
    return None, _unauthorized_session_response()
126172

127173

128174
def _prune_old_jobs():
@@ -136,11 +182,54 @@ def _set_job_state(job_id: str, **fields):
136182
set_job_state(job_id, **fields)
137183

138184

139-
def _run_process_playlist_job(job_id: str, auth_token: str, playlist_ids: list[str]):
185+
def _run_process_playlist_job(
186+
job_id: str,
187+
auth_token: str,
188+
playlist_ids: list[str],
189+
feature_weights: dict[str, float] | None = None,
190+
split_criterion: str | None = None,
191+
):
140192
"""Run playlist processing in background and persist status fields."""
193+
total_playlists = len(playlist_ids)
194+
195+
def _emit_progress(
196+
completed_playlists: int,
197+
total_playlists: int,
198+
failed_playlists: int = 0,
199+
last_completed_playlist_id: str | None = None,
200+
last_completed_playlist_name: str | None = None,
201+
):
202+
safe_total = max(1, int(total_playlists))
203+
raw_percent = int(round((completed_playlists / safe_total) * 100))
204+
progress_percent = max(0, min(100, raw_percent))
205+
_set_job_state(
206+
job_id,
207+
completed_playlists=completed_playlists,
208+
total_playlists=total_playlists,
209+
failed_playlists=failed_playlists,
210+
progress_percent=progress_percent,
211+
last_completed_playlist_id=last_completed_playlist_id,
212+
last_completed_playlist_name=last_completed_playlist_name,
213+
)
214+
141215
_set_job_state(job_id, status="running", started_at=time.time())
216+
_emit_progress(
217+
completed_playlists=0,
218+
total_playlists=total_playlists,
219+
failed_playlists=0,
220+
)
142221
try:
143-
process_all(auth_token, playlist_ids)
222+
process_all(
223+
auth_token,
224+
playlist_ids,
225+
feature_weights=feature_weights,
226+
split_criterion=split_criterion,
227+
progress_callback=_emit_progress,
228+
)
229+
_emit_progress(
230+
completed_playlists=total_playlists,
231+
total_playlists=total_playlists,
232+
)
144233
_set_job_state(
145234
job_id,
146235
status="succeeded",
@@ -193,6 +282,8 @@ def login_handler():
193282
if not is_access_token_valid(auth_token):
194283
if refresh_token:
195284
new_auth_token = refresh_access_token(refresh_token)
285+
if not new_auth_token:
286+
return redirect_to_spotify_login()
196287
session["auth_token"] = new_auth_token
197288
auth_token = new_auth_token
198289
else:
@@ -264,16 +355,17 @@ def callback_handler():
264355
@app.route("/api/user-playlists")
265356
def get_playlist_handler():
266357
"""Return current user's Spotify playlists based on auth cookie token."""
267-
auth_token = get_auth_token_from_request()
268-
269-
if not auth_token:
270-
print(f"NO AUTH: {auth_token}")
271-
return {"Code": 401, "Error": "Authorization token required"}
358+
auth_token, auth_error = _resolve_active_auth_token()
359+
if auth_error:
360+
return auth_error
272361

273362
playlists = get_all_playlists(auth_token)
274363

275364
if not playlists:
276-
return {"Code": 500, "Error": "Failed to get playlists"}
365+
if not is_access_token_valid(auth_token):
366+
_clear_auth_session()
367+
return _unauthorized_session_response()
368+
return jsonify({"Code": 502, "Error": "Failed to get playlists"}), 502
277369

278370
return jsonify(playlists)
279371

@@ -282,10 +374,9 @@ def get_playlist_handler():
282374
@app.route("/api/process-playlist", methods=["POST"])
283375
def process_playlist_handler():
284376
"""Start async processing job for selected playlists."""
285-
auth_token = get_auth_token_from_request()
286-
287-
if not auth_token:
288-
return "Authorization required", 401
377+
auth_token, auth_error = _resolve_active_auth_token()
378+
if auth_error:
379+
return auth_error
289380

290381
missing_scopes = _missing_required_scopes()
291382
if missing_scopes:
@@ -302,9 +393,31 @@ def process_playlist_handler():
302393

303394
assert request.json
304395
playlist_ids = request.json.get("playlistIds", [])
396+
feature_weights_payload = request.json.get("featureWeights")
397+
split_criterion_payload = request.json.get("splitCriterion")
305398

306399
if not playlist_ids:
307400
return "No playlist IDs provided", 400
401+
if feature_weights_payload is not None and not isinstance(feature_weights_payload, dict):
402+
return (
403+
jsonify(
404+
{
405+
"Code": 400,
406+
"Error": "featureWeights must be an object keyed by feature name.",
407+
}
408+
),
409+
400,
410+
)
411+
feature_weights = (
412+
normalize_feature_weights(feature_weights_payload)
413+
if isinstance(feature_weights_payload, dict)
414+
else None
415+
)
416+
split_criterion = (
417+
split_criterion_payload.strip().lower()
418+
if isinstance(split_criterion_payload, str) and split_criterion_payload.strip()
419+
else None
420+
)
308421

309422
_prune_old_jobs()
310423
job_id = str(uuid.uuid4())
@@ -315,10 +428,18 @@ def process_playlist_handler():
315428
finished_at=None,
316429
error=None,
317430
playlist_count=len(playlist_ids),
431+
completed_playlists=0,
432+
total_playlists=len(playlist_ids),
433+
failed_playlists=0,
434+
progress_percent=0,
435+
last_completed_playlist_id=None,
436+
last_completed_playlist_name=None,
437+
feature_weights=feature_weights,
438+
split_criterion=split_criterion,
318439
)
319440
job_thread = threading.Thread(
320441
target=_run_process_playlist_job,
321-
args=(job_id, auth_token, playlist_ids),
442+
args=(job_id, auth_token, playlist_ids, feature_weights, split_criterion),
322443
daemon=True,
323444
)
324445
job_thread.start()

Backend/grouping.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Clustering utilities for grouping tracks by audio feature similarity."""
22

33
import os
4+
import math
45

56
import numpy as np
67
from sklearn.metrics import pairwise_distances
@@ -40,6 +41,8 @@
4041
"tempo": 0.85,
4142
"valence": 1.55,
4243
}
44+
MIN_FEATURE_WEIGHT = 0.0
45+
MAX_FEATURE_WEIGHT = 3.0
4346

4447

4548
def _env_positive_int(name: str, default_value: int) -> int:
@@ -69,6 +72,26 @@ def _env_positive_int(name: str, default_value: int) -> int:
6972
REFINE_MAX_TRACKS = _env_positive_int("CLUSTER_REFINE_MAX_TRACKS", 600)
7073

7174

75+
def normalize_feature_weights(
    feature_weights: dict[str, float] | None,
) -> dict[str, float]:
    """Merge caller-supplied weights with the defaults, clamping each value.

    Keys absent from ``FEATURE_WEIGHTS`` are ignored; missing, non-numeric,
    or non-finite values silently fall back to the shipped default for that
    feature. Every returned weight lies within
    ``[MIN_FEATURE_WEIGHT, MAX_FEATURE_WEIGHT]``. A non-dict argument
    (including ``None``) yields a copy of the defaults.
    """
    if not isinstance(feature_weights, dict):
        return dict(FEATURE_WEIGHTS)

    def _clamped(raw, fallback):
        # Coerce to float; anything unusable (None, junk strings,
        # inf/nan) reverts to the default weight for this feature.
        try:
            value = float(raw)
        except (TypeError, ValueError):
            value = fallback
        if not math.isfinite(value):
            value = fallback
        return max(MIN_FEATURE_WEIGHT, min(MAX_FEATURE_WEIGHT, value))

    return {
        name: _clamped(feature_weights.get(name, default), default)
        for name, default in FEATURE_WEIGHTS.items()
    }
93+
94+
7295
def _merge_small_clusters(
7396
scaled_features: np.ndarray, labels: np.ndarray, min_cluster_size: int
7497
) -> np.ndarray:
@@ -251,7 +274,9 @@ def _refine_cluster_cohesion(
251274
return refined
252275

253276

254-
def cluster_df(track_audio_features: list[dict]) -> pd.DataFrame:
277+
def cluster_df(
278+
track_audio_features: list[dict], feature_weights: dict[str, float] | None = None
279+
) -> pd.DataFrame:
255280
"""Return dataframe with track id and assigned GMM clusters."""
256281
if not track_audio_features:
257282
return pd.DataFrame(columns=["id", "cluster"])
@@ -278,7 +303,8 @@ def cluster_df(track_audio_features: list[dict]) -> pd.DataFrame:
278303

279304
scaler = StandardScaler()
280305
scaled = scaler.fit_transform(feature_frame)
281-
weights = np.array([FEATURE_WEIGHTS.get(key, 1.0) for key in available_keys])
306+
effective_feature_weights = normalize_feature_weights(feature_weights)
307+
weights = np.array([effective_feature_weights.get(key, 1.0) for key in available_keys])
282308
weighted_scaled = scaled * weights
283309

284310
track_count = len(feature_frame)

0 commit comments

Comments
 (0)