Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 137 additions & 16 deletions Backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_spotify_redirect_uri,
)
from Backend.playlist_processing import process_all
from Backend.grouping import normalize_feature_weights
from Backend.helpers import generate_random_string
from Backend.job_status_store import (
set_job_state,
Expand All @@ -39,6 +40,7 @@
get_spotify_redirect_uri,
)
from playlist_processing import process_all # type: ignore
from grouping import normalize_feature_weights # type: ignore
from helpers import generate_random_string # type: ignore
from job_status_store import ( # type: ignore
set_job_state,
Expand Down Expand Up @@ -120,9 +122,53 @@ def _missing_required_scopes() -> list[str]:
JOB_STATUS_TTL_SECONDS = int(os.getenv("JOB_STATUS_TTL_SECONDS", "21600"))


def get_auth_token_from_request():
"""Return auth token from request cookies, falling back to server session."""
return session.get("auth_token") or request.cookies.get("auth_token")
def _clear_auth_session():
    """Drop every server-side auth/session value left over from a stale login."""
    # All auth-related session keys are removed together so a partial
    # (inconsistent) login state can never survive a reset.
    for stale_key in ("uid", "auth_token", "refresh_token", "auth_scopes"):
        session.pop(stale_key, None)


def _unauthorized_session_response(message: str = "Spotify session expired. Please log in again."):
    """Build the standardized 401 response tuple for expired/missing auth sessions.

    The ``reauth`` flag tells the client it should restart the login flow.
    """
    payload = {
        "Code": 401,
        "Error": message,
        "reauth": True,
    }
    return jsonify(payload), 401


def _resolve_active_auth_token():
    """
    Return a valid Spotify auth token for the current request.

    Attempts a refresh when the session token is expired. Returns tuple:
    (auth_token, error_response_or_none)
    """
    current_token = session.get("auth_token")
    stored_refresh = session.get("refresh_token")

    # No token at all: wipe any partial state and demand a fresh login.
    if not current_token:
        _clear_auth_session()
        return None, _unauthorized_session_response("Authorization required.")

    # Fast path: the stored token is still usable.
    if is_access_token_valid(current_token):
        return current_token, None

    # Expired token: try a one-shot refresh before giving up.
    if stored_refresh:
        replacement = refresh_access_token(stored_refresh)
        if replacement:
            session["auth_token"] = replacement
            return replacement, None

    # Refresh unavailable or failed — treat the session as dead.
    _clear_auth_session()
    return None, _unauthorized_session_response()


def _prune_old_jobs():
Expand All @@ -136,11 +182,54 @@ def _set_job_state(job_id: str, **fields):
set_job_state(job_id, **fields)


def _run_process_playlist_job(job_id: str, auth_token: str, playlist_ids: list[str]):
def _run_process_playlist_job(
job_id: str,
auth_token: str,
playlist_ids: list[str],
feature_weights: dict[str, float] | None = None,
split_criterion: str | None = None,
):
"""Run playlist processing in background and persist status fields."""
total_playlists = len(playlist_ids)

def _emit_progress(
completed_playlists: int,
total_playlists: int,
failed_playlists: int = 0,
last_completed_playlist_id: str | None = None,
last_completed_playlist_name: str | None = None,
):
safe_total = max(1, int(total_playlists))
raw_percent = int(round((completed_playlists / safe_total) * 100))
progress_percent = max(0, min(100, raw_percent))
_set_job_state(
job_id,
completed_playlists=completed_playlists,
total_playlists=total_playlists,
failed_playlists=failed_playlists,
progress_percent=progress_percent,
last_completed_playlist_id=last_completed_playlist_id,
last_completed_playlist_name=last_completed_playlist_name,
)

_set_job_state(job_id, status="running", started_at=time.time())
_emit_progress(
completed_playlists=0,
total_playlists=total_playlists,
failed_playlists=0,
)
try:
process_all(auth_token, playlist_ids)
process_all(
auth_token,
playlist_ids,
feature_weights=feature_weights,
split_criterion=split_criterion,
progress_callback=_emit_progress,
)
_emit_progress(
completed_playlists=total_playlists,
total_playlists=total_playlists,
)
_set_job_state(
job_id,
status="succeeded",
Expand Down Expand Up @@ -193,6 +282,8 @@ def login_handler():
if not is_access_token_valid(auth_token):
if refresh_token:
new_auth_token = refresh_access_token(refresh_token)
if not new_auth_token:
return redirect_to_spotify_login()
session["auth_token"] = new_auth_token
auth_token = new_auth_token
else:
Expand Down Expand Up @@ -264,16 +355,17 @@ def callback_handler():
@app.route("/api/user-playlists")
def get_playlist_handler():
"""Return current user's Spotify playlists based on auth cookie token."""
auth_token = get_auth_token_from_request()

if not auth_token:
print(f"NO AUTH: {auth_token}")
return {"Code": 401, "Error": "Authorization token required"}
auth_token, auth_error = _resolve_active_auth_token()
if auth_error:
return auth_error

playlists = get_all_playlists(auth_token)

if not playlists:
return {"Code": 500, "Error": "Failed to get playlists"}
if not is_access_token_valid(auth_token):
_clear_auth_session()
return _unauthorized_session_response()
return jsonify({"Code": 502, "Error": "Failed to get playlists"}), 502

return jsonify(playlists)

Expand All @@ -282,10 +374,9 @@ def get_playlist_handler():
@app.route("/api/process-playlist", methods=["POST"])
def process_playlist_handler():
"""Start async processing job for selected playlists."""
auth_token = get_auth_token_from_request()

if not auth_token:
return "Authorization required", 401
auth_token, auth_error = _resolve_active_auth_token()
if auth_error:
return auth_error

missing_scopes = _missing_required_scopes()
if missing_scopes:
Expand All @@ -302,9 +393,31 @@ def process_playlist_handler():

assert request.json
playlist_ids = request.json.get("playlistIds", [])
feature_weights_payload = request.json.get("featureWeights")
split_criterion_payload = request.json.get("splitCriterion")

if not playlist_ids:
return "No playlist IDs provided", 400
if feature_weights_payload is not None and not isinstance(feature_weights_payload, dict):
return (
jsonify(
{
"Code": 400,
"Error": "featureWeights must be an object keyed by feature name.",
}
),
400,
)
feature_weights = (
normalize_feature_weights(feature_weights_payload)
if isinstance(feature_weights_payload, dict)
else None
)
split_criterion = (
split_criterion_payload.strip().lower()
if isinstance(split_criterion_payload, str) and split_criterion_payload.strip()
else None
)

_prune_old_jobs()
job_id = str(uuid.uuid4())
Expand All @@ -315,10 +428,18 @@ def process_playlist_handler():
finished_at=None,
error=None,
playlist_count=len(playlist_ids),
completed_playlists=0,
total_playlists=len(playlist_ids),
failed_playlists=0,
progress_percent=0,
last_completed_playlist_id=None,
last_completed_playlist_name=None,
feature_weights=feature_weights,
split_criterion=split_criterion,
)
job_thread = threading.Thread(
target=_run_process_playlist_job,
args=(job_id, auth_token, playlist_ids),
args=(job_id, auth_token, playlist_ids, feature_weights, split_criterion),
daemon=True,
)
job_thread.start()
Expand Down
30 changes: 28 additions & 2 deletions Backend/grouping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Clustering utilities for grouping tracks by audio feature similarity."""

import os
import math

import numpy as np
from sklearn.metrics import pairwise_distances
Expand Down Expand Up @@ -40,6 +41,8 @@
"tempo": 0.85,
"valence": 1.55,
}
MIN_FEATURE_WEIGHT = 0.0
MAX_FEATURE_WEIGHT = 3.0


def _env_positive_int(name: str, default_value: int) -> int:
Expand Down Expand Up @@ -69,6 +72,26 @@ def _env_positive_int(name: str, default_value: int) -> int:
REFINE_MAX_TRACKS = _env_positive_int("CLUSTER_REFINE_MAX_TRACKS", 600)


def normalize_feature_weights(
    feature_weights: dict[str, float] | None,
) -> dict[str, float]:
    """Return bounded feature weights merged with defaults.

    Unknown keys in *feature_weights* are ignored; missing, non-numeric, or
    non-finite values fall back to the corresponding default weight. Every
    resulting weight is clamped to [MIN_FEATURE_WEIGHT, MAX_FEATURE_WEIGHT].
    """
    if not isinstance(feature_weights, dict):
        return dict(FEATURE_WEIGHTS)

    def _bounded(raw_value, fallback):
        # Reject values that cannot be coerced to a finite float.
        try:
            value = float(raw_value)
        except (TypeError, ValueError):
            value = fallback
        if not math.isfinite(value):
            value = fallback
        return min(MAX_FEATURE_WEIGHT, max(MIN_FEATURE_WEIGHT, value))

    return {
        key: _bounded(feature_weights.get(key, default), default)
        for key, default in FEATURE_WEIGHTS.items()
    }


def _merge_small_clusters(
scaled_features: np.ndarray, labels: np.ndarray, min_cluster_size: int
) -> np.ndarray:
Expand Down Expand Up @@ -251,7 +274,9 @@ def _refine_cluster_cohesion(
return refined


def cluster_df(track_audio_features: list[dict]) -> pd.DataFrame:
def cluster_df(
track_audio_features: list[dict], feature_weights: dict[str, float] | None = None
) -> pd.DataFrame:
"""Return dataframe with track id and assigned GMM clusters."""
if not track_audio_features:
return pd.DataFrame(columns=["id", "cluster"])
Expand All @@ -278,7 +303,8 @@ def cluster_df(track_audio_features: list[dict]) -> pd.DataFrame:

scaler = StandardScaler()
scaled = scaler.fit_transform(feature_frame)
weights = np.array([FEATURE_WEIGHTS.get(key, 1.0) for key in available_keys])
effective_feature_weights = normalize_feature_weights(feature_weights)
weights = np.array([effective_feature_weights.get(key, 1.0) for key in available_keys])
weighted_scaled = scaled * weights

track_count = len(feature_frame)
Expand Down
Loading
Loading