diff --git a/AGENTS.md b/AGENTS.md index 444b6104f3..77f5d1b179 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -72,6 +72,9 @@ pusher agent-proxy (agent-proxy/main.py) └── ws ──► user agent VM (private IP, port 8080) +backend-sync (main.py, Cloud Run) + └── ──────► Cloud Tasks queue `sync-jobs` ──► POST /v2/sync-jobs/run (OIDC, same service) + notifications-job (modal/job.py) [cron] ``` @@ -83,6 +86,7 @@ Helm charts: `backend/charts/{backend-listen,pusher,diarizer,vad,deepgram-self-h - **diarizer** (`diarizer/main.py`) — GPU. Speaker embeddings at `/v2/embedding`. Called by backend and pusher (`HOSTED_SPEAKER_EMBEDDING_API_URL`). - **vad** (`modal/main.py`) — GPU. `/v1/vad` and `/v1/speaker-identification`. Called by backend only. - **deepgram** — STT. Streaming uses self-hosted (`DEEPGRAM_SELF_HOSTED_URL`) or cloud based on `DEEPGRAM_SELF_HOSTED_ENABLED`. Pre-recorded always uses Deepgram cloud. Called by backend and pusher. +- **backend-sync** (`main.py`, same image as backend) — Cloud Run service for `/v2/sync-local-files`. When `SYNC_DISPATCH_MODE=cloud_tasks`: stages raw audio in GCS, enqueues to Cloud Tasks queue `sync-jobs`, which POSTs `/v2/sync-jobs/run` (OIDC-verified, `utils/cloud_tasks.py`) to run decode→VAD→STT inside a request. Inline fallback when the flag is off, env is incomplete, BYOK headers are present, or enqueue fails. - **notifications-job** (`modal/job.py`) — Cron job, reads Firestore/Redis, sends push notifications. Keep this map up to date. When adding, removing, or changing inter-service calls, update this section. If a PR changes audio streaming, transcription, conversation lifecycle, speaker identification, or the listen/pusher WebSocket protocol — update `docs/doc/developer/backend/listen_pusher_pipeline.mdx` in the same PR. diff --git a/app/lib/pages/apps/app_detail/app_detail.dart b/app/lib/pages/apps/app_detail/app_detail.dart index cf558d5235..f7c944f2ec 100644 --- a/app/lib/pages/apps/app_detail/app_detail.dart +++ b/app/lib/pages/apps/app_detail/app_detail.dart @@ -682,8 +682,9 @@ class _AppDetailPageState extends State { // Get the position of the share button for iOS final RenderBox? box = context.findRenderObject() as RenderBox?; - final Rect? sharePositionOrigin = - box != null ? box.localToGlobal(Offset.zero) & box.size : null; + final Rect? sharePositionOrigin = box != null + ? box.localToGlobal(Offset.zero) & box.size + : null; await Share.share( 'https://h.omi.me/apps/${app.id}', @@ -697,32 +698,32 @@ class _AppDetailPageState extends State { ), appProvider.isAppOwner ? (isLoading - ? const SizedBox.shrink() - : Container( - width: 36, - height: 36, - margin: const EdgeInsets.only(right: 8), - decoration: BoxDecoration(color: Colors.grey.withOpacity(0.3), shape: BoxShape.circle), - child: IconButton( - padding: EdgeInsets.zero, - icon: const FaIcon(FontAwesomeIcons.edit, size: 16.0, color: Colors.white), - onPressed: () async { - HapticFeedback.mediumImpact(); - await showModalBottomSheet( - context: context, - shape: const RoundedRectangleBorder( - borderRadius: BorderRadius.only( - topLeft: Radius.circular(16), - topRight: Radius.circular(16), + ? const SizedBox.shrink() + : Container( + width: 36, + height: 36, + margin: const EdgeInsets.only(right: 8), + decoration: BoxDecoration(color: Colors.grey.withOpacity(0.3), shape: BoxShape.circle), + child: IconButton( + padding: EdgeInsets.zero, + icon: const FaIcon(FontAwesomeIcons.edit, size: 16.0, color: Colors.white), + onPressed: () async { + HapticFeedback.mediumImpact(); + await showModalBottomSheet( + context: context, + shape: const RoundedRectangleBorder( + borderRadius: BorderRadius.only( + topLeft: Radius.circular(16), + topRight: Radius.circular(16), + ), ), - ), - builder: (context) { - return ShowAppOptionsSheet(app: app); - }, - ); - }, - ), - )) + builder: (context) { + return ShowAppOptionsSheet(app: app); + }, + ); + }, + ), + )) : const SizedBox(width: 8), ], ), @@ -839,73 +840,73 @@ class _AppDetailPageState extends State { color: const Color(0xFF35343B), ) : app.enabled - ? AnimatedLoadingButton( - text: 'Disable', - width: 90, - height: 32, - onPressed: () => _toggleApp(app.id, false), - color: Colors.grey.shade700, - ) - : (app.isPaid && !app.isUserPaid - ? AnimatedLoadingButton( - width: 100, - height: 32, - text: "Subscribe", - onPressed: () async { - // Track subscribe button clicked - PlatformManager.instance.analytics.appDetailSubscribeClicked( - appId: app.id, - appName: app.name, - ); - - if (app.paymentLink != null && app.paymentLink!.isNotEmpty) { - final uri = Uri.tryParse(app.paymentLink!); - if (uri == null) { - ScaffoldMessenger.of(context).showSnackBar( - SnackBar(content: Text(context.l10n.invalidPaymentUrl)), - ); - return; - } - _checkPaymentStatus(app.id); - await _launchUrlSafely(uri); - } else { - await _toggleApp(app.id, true); - } - }, - color: const Color(0xFF8B5CF6), - ) - : AnimatedLoadingButton( - width: 75, - height: 32, - text: 'Enable', - onPressed: () async { - if (app.worksExternally()) { - showDialog( - context: context, - builder: (ctx) { - return StatefulBuilder( - builder: (ctx, setState) { - return ConfirmationDialog( - title: context.l10n.dataAccessNotice, - description: context.l10n.dataAccessNoticeDescription, - onConfirm: () { - _toggleApp(app.id, true); - Navigator.pop(context); - }, - onCancel: () { - Navigator.pop(context); - }, - ); - }, - ); - }, + ? AnimatedLoadingButton( + text: 'Disable', + width: 90, + height: 32, + onPressed: () => _toggleApp(app.id, false), + color: Colors.grey.shade700, + ) + : (app.isPaid && !app.isUserPaid + ? AnimatedLoadingButton( + width: 100, + height: 32, + text: "Subscribe", + onPressed: () async { + // Track subscribe button clicked + PlatformManager.instance.analytics.appDetailSubscribeClicked( + appId: app.id, + appName: app.name, + ); + + if (app.paymentLink != null && app.paymentLink!.isNotEmpty) { + final uri = Uri.tryParse(app.paymentLink!); + if (uri == null) { + ScaffoldMessenger.of(context).showSnackBar( + SnackBar(content: Text(context.l10n.invalidPaymentUrl)), ); - } else { - _toggleApp(app.id, true); + return; } - }, - color: const Color(0xFF8B5CF6), - )), + _checkPaymentStatus(app.id); + await _launchUrlSafely(uri); + } else { + await _toggleApp(app.id, true); + } + }, + color: const Color(0xFF8B5CF6), + ) + : AnimatedLoadingButton( + width: 75, + height: 32, + text: 'Enable', + onPressed: () async { + if (app.worksExternally()) { + showDialog( + context: context, + builder: (ctx) { + return StatefulBuilder( + builder: (ctx, setState) { + return ConfirmationDialog( + title: context.l10n.dataAccessNotice, + description: context.l10n.dataAccessNoticeDescription, + onConfirm: () { + _toggleApp(app.id, true); + Navigator.pop(context); + }, + onCancel: () { + Navigator.pop(context); + }, + ); + }, + ); + }, + ); + } else { + _toggleApp(app.id, true); + } + }, + color: const Color(0xFF8B5CF6), + )), ], ), ), @@ -1396,8 +1397,10 @@ class _AppDetailPageState extends State { ), const SizedBox(height: 16), RecentReviewsSection( - reviews: - app.reviews.sorted((a, b) => b.ratedAt.compareTo(a.ratedAt)).take(3).toList(), + reviews: app.reviews + .sorted((a, b) => b.ratedAt.compareTo(a.ratedAt)) + .take(3) + .toList(), userReview: app.userReview, app: app, onReviewUpdated: () { @@ -1679,8 +1682,8 @@ class _RecentReviewsSectionState extends State { final userName = widget.userReview?.username.isNotEmpty == true ? widget.userReview!.username : prefs.fullName.isNotEmpty - ? prefs.fullName - : prefs.givenName; + ? prefs.fullName + : prefs.givenName; final rev = AppReview( uid: prefs.uid, @@ -1897,8 +1900,9 @@ class _RecentReviewsSectionState extends State { Widget _buildReviewItem(BuildContext context, AppReview review, {bool isUserReview = false}) { final l10n = AppLocalizations.of(context)!; - final displayName = - isUserReview ? l10n.yourReview : (review.username.isNotEmpty ? review.username : l10n.anonymousUser); + final displayName = isUserReview + ? l10n.yourReview + : (review.username.isNotEmpty ? review.username : l10n.anonymousUser); final avatarSeed = review.uid.isNotEmpty ? review.uid : review.username; return Padding( diff --git a/app/lib/pages/chat/widgets/ai_message.dart b/app/lib/pages/chat/widgets/ai_message.dart index 548bcdba6a..a5429ab08d 100644 --- a/app/lib/pages/chat/widgets/ai_message.dart +++ b/app/lib/pages/chat/widgets/ai_message.dart @@ -59,7 +59,8 @@ Widget _buildAppIcon(BuildContext context, String appId, {double size = 15, doub final appProvider = Provider.of(context, listen: false); final messageProvider = Provider.of(context, listen: false); // Check both public apps and user's installed chat apps (includes private MCP apps) - final app = appProvider.apps.firstWhereOrNull((a) => a.id == appId) ?? + final app = + appProvider.apps.firstWhereOrNull((a) => a.id == appId) ?? messageProvider.chatApps.firstWhereOrNull((a) => a.id == appId); if (app != null) { @@ -751,28 +752,28 @@ class _MemoriesMessageWidgetState extends State { ), ) : widget.showTypingIndicator - ? const Row( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: CrossAxisAlignment.start, - mainAxisAlignment: MainAxisAlignment.start, - children: [SizedBox(width: 4), TypingIndicator(), Spacer()], - ) - : Builder( - builder: (context) { - String? selectedText; - return SelectionArea( - onSelectionChanged: (SelectedContent? selectedContent) { - selectedText = selectedContent?.plainText; - }, - contextMenuBuilder: (context, selectableRegionState) { - return omiSelectionMenuBuilder(context, selectableRegionState, (text) { - widget.onAskOmi?.call(text); - }, selectedText: selectedText); - }, - child: getMarkdownWidget(context, widget.messageText, onAskOmi: widget.onAskOmi), - ); + ? const Row( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisAlignment: MainAxisAlignment.start, + children: [SizedBox(width: 4), TypingIndicator(), Spacer()], + ) + : Builder( + builder: (context) { + String? selectedText; + return SelectionArea( + onSelectionChanged: (SelectedContent? selectedContent) { + selectedText = selectedContent?.plainText; }, - ), + contextMenuBuilder: (context, selectableRegionState) { + return omiSelectionMenuBuilder(context, selectableRegionState, (text) { + widget.onAskOmi?.call(text); + }, selectedText: selectedText); + }, + child: getMarkdownWidget(context, widget.messageText, onAskOmi: widget.onAskOmi), + ); + }, + ), if (widget.messageText.isNotEmpty && widget.messageText != '...' && !widget.showTypingIndicator) MessageActionBar( messageText: widget.messageText, diff --git a/backend/database/sync_jobs.py b/backend/database/sync_jobs.py index 3426249b8f..c8a12b24f1 100644 --- a/backend/database/sync_jobs.py +++ b/backend/database/sync_jobs.py @@ -35,6 +35,16 @@ JOB_TTL_SECONDS = 86400 # 24 hours — reconcile window (see module docstring) STALE_THRESHOLD_SECONDS = 600 # 10 minutes — if processing exceeds this, treat as failed +TERMINAL_STATUSES = ('completed', 'partial_failure', 'failed') + +RUN_LOCK_KEY_PREFIX = 'sync_job_lock:' +# Must stay above the handler's request timeout (HTTP_SYNC_JOBS_RUN_TIMEOUT, +# 1500s) so the lock can never expire while a run is still executing. +RUN_LOCK_TTL_SECONDS = 1800 + +PROCESSED_SEGMENTS_KEY_PREFIX = 'sync_job_segments:' +ONCE_KEY_PREFIX = 'sync_job_once:' + def create_sync_job(uid: str, total_files: int, total_segments: int, job_id: str | None = None) -> dict: """Create a new sync job and store in Redis. Returns the job dict.""" @@ -163,3 +173,91 @@ def mark_job_failed(job_id: str, error: str) -> Optional[dict]: 'error': error, }, ) + + +def mark_job_queued_for_retry(job_id: str, attempt: int, error: str) -> Optional[dict]: + """Reset a job to 'queued' before a Cloud Tasks retry. + + 'queued' is exempt from the stale detector in get_sync_job(), so the app + polling during the retry backoff window cannot flip the job to a terminal + 'failed' while a retry is still pending. + """ + return update_sync_job( + job_id, + { + 'status': 'queued', + 'attempt': attempt, + 'last_error': error, + }, + ) + + +def try_acquire_job_run_lock(job_id: str) -> Optional[str]: + """Acquire the per-job run lock. Returns a release token, or None if held. + + Fails CLOSED: Redis errors propagate to the caller. An unobtainable lock + must block execution (the Cloud Tasks retry will come back later), never + allow two concurrent runs of the same job. + """ + token = str(uuid.uuid4()) + acquired = r.set(f'{RUN_LOCK_KEY_PREFIX}{job_id}', token, nx=True, ex=RUN_LOCK_TTL_SECONDS) + return token if acquired else None + + +_RELEASE_LOCK_SCRIPT = """ +if redis.call('get', KEYS[1]) == ARGV[1] then + return redis.call('del', KEYS[1]) +end +return 0 +""" + + +def release_job_run_lock(job_id: str, token: str) -> None: + """Release the run lock if we still own it (compare-and-delete). + + Best-effort: on Redis failure the lock simply expires via its TTL and a + duplicate delivery in the meantime gets 409-retried. + """ + try: + r.eval(_RELEASE_LOCK_SCRIPT, 1, f'{RUN_LOCK_KEY_PREFIX}{job_id}', token) + except Exception as e: + logger.warning('release_job_run_lock failed for %s: %s', job_id, e) + + +def add_processed_segment(job_id: str, segment_path: str) -> None: + """Record a segment as fully processed (conversation written) for this job. + + Lets a Cloud Tasks retry skip segments that already landed. Best-effort: + on failure the retry falls back to the timestamp-based segment dedup. + """ + try: + key = f'{PROCESSED_SEGMENTS_KEY_PREFIX}{job_id}' + r.sadd(key, segment_path) + r.expire(key, JOB_TTL_SECONDS) + except Exception as e: + logger.warning('add_processed_segment failed for %s: %s', job_id, e) + + +def get_processed_segments(job_id: str) -> set: + """Return segment paths already processed for this job.""" + try: + members = r.smembers(f'{PROCESSED_SEGMENTS_KEY_PREFIX}{job_id}') + return {m.decode() if isinstance(m, bytes) else m for m in members} + except Exception as e: + logger.warning('get_processed_segments failed for %s: %s', job_id, e) + return set() + + +def try_mark_once(job_id: str, tag: str) -> bool: + """SETNX guard so per-job side effects (fair-use metering, usage recording) + run at most once across Cloud Tasks retries. + + Fails OPEN (returns True on Redis error) to match the metering functions' + own fail-open posture — better to occasionally double-count than to + silently never count. + """ + try: + return bool(r.set(f'{ONCE_KEY_PREFIX}{job_id}:{tag}', '1', nx=True, ex=JOB_TTL_SECONDS)) + except Exception as e: + logger.warning('try_mark_once failed for %s:%s: %s', job_id, tag, e) + return True diff --git a/backend/main.py b/backend/main.py index f6f54404e4..119486f091 100644 --- a/backend/main.py +++ b/backend/main.py @@ -153,7 +153,14 @@ "DELETE": os.environ.get('HTTP_DELETE_TIMEOUT'), } -app.add_middleware(TimeoutMiddleware, methods_timeout=methods_timeout) +# The Cloud Tasks sync-job handler runs the whole pipeline inside the request, +# so it needs a much higher cap than the default. Must stay below the job run +# lock TTL (1800s) so a lock can never expire under a live run. +paths_timeout = { + "/v2/sync-jobs/run": os.environ.get('HTTP_SYNC_JOBS_RUN_TIMEOUT', 1500), +} + +app.add_middleware(TimeoutMiddleware, methods_timeout=methods_timeout, paths_timeout=paths_timeout) from utils.byok import BYOKMiddleware diff --git a/backend/requirements.txt b/backend/requirements.txt index 5330b1ceea..49bf31caf1 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -60,6 +60,7 @@ google-auth-httplib2==0.2.0 google-cloud-core==2.4.1 google-cloud-firestore==2.20.0 google-cloud-storage==2.18.0 +google-cloud-tasks==2.16.4 google-crc32c==1.5.0 google-resumable-media==2.7.1 googleapis-common-protos==1.63.2 diff --git a/backend/routers/sync.py b/backend/routers/sync.py index b33bfcb9f5..d92fbabf9d 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import io import logging import os @@ -36,12 +37,19 @@ from database import users as users_db from database.conversations import get_closest_conversation_to_timestamps, update_conversation_segments from database.sync_jobs import ( + TERMINAL_STATUSES, create_sync_job, get_sync_job, update_sync_job, mark_job_processing, mark_job_completed, mark_job_failed, + mark_job_queued_for_retry, + try_acquire_job_run_lock, + release_job_run_lock, + add_processed_segment, + get_processed_segments, + try_mark_once, ) from models.conversation import Conversation, CreateConversation from models.conversation_enums import ConversationSource @@ -53,6 +61,8 @@ from utils.other.storage import ( get_syncing_file_temporal_signed_url, delete_syncing_temporal_file, + upload_syncing_temporal_file, + download_syncing_temporal_file, download_audio_chunks_and_merge, get_or_create_merged_audio, get_merged_audio_signed_url, @@ -60,7 +70,13 @@ ) from utils import encryption -from utils.byok import get_byok_keys, set_byok_keys +from utils.byok import get_byok_keys, set_byok_keys, has_byok_keys +from utils.cloud_tasks import ( + enqueue_sync_job, + get_sync_tasks_max_attempts, + is_cloud_tasks_dispatch_enabled, + verify_cloud_tasks_oidc, +) from utils.http_client import _get_semaphore from utils.log_sanitizer import sanitize from utils.stt.pre_recorded import postprocess_words, prerecorded @@ -1032,11 +1048,11 @@ def delete_file(): # DG processed audio successfully but found no speech (silence/noise). # Real DG failures now raise RuntimeError and are caught by the except block. logger.info(f'No transcript words for segment {path} (silence or noise-only audio)') - return + return True transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) if not transcript_segments: logger.warning(f'Postprocessing returned empty for segment {path} (words present but no segments)') - return + return True # Speaker identification: voice embedding matching + text-based detection audio_bytes = _download_audio_bytes(url) if person_embeddings_cache else None @@ -1108,7 +1124,7 @@ def delete_file(): logger.info(f'All segments already exist in conversation {closest_memory["id"]}, skipping merge') with lock: response['updated_memories'].add(closest_memory['id']) - return + return True # merge and sort segments by start timestamp segments = closest_memory['transcript_segments'] + deduped_segments @@ -1154,11 +1170,13 @@ def delete_file(): # instead of once per merged segment. with lock: response.setdefault('_merged', {})[closest_memory['id']] = language + return True except Exception as e: error_msg = f'Failed to process segment {path}: {e}' logger.error(error_msg) with lock: errors.append(error_msg) + return False finally: if turnstile: turnstile.complete(path) @@ -1457,14 +1475,26 @@ async def _run_full_pipeline_background_async( should_lock: bool, job_dir: str, target_conversation_id: str = None, + task_mode: bool = False, ): """Async coordinator for the full sync pipeline (decode → VAD → fair-use → STT → LLM). - Runs as an asyncio task on the event loop. All blocking work is offloaded to - thread pools via run_blocking(). The coordinator itself holds zero thread pool - slots — only leaf operations use threads, and only for their actual duration. + Inline dispatch (task_mode=False): runs as a fire-and-forget asyncio task, + bounded by the per-instance pipeline semaphore; unexpected errors mark the + job failed (no retry exists). + + Cloud Tasks dispatch (task_mode=True): runs inside the /v2/sync-jobs/run + request — Cloud Run's containerConcurrency is the concurrency bound, so no + semaphore; unexpected errors re-raise so the handler can reset the job for + a queue retry; segments that completed in a prior attempt are skipped via + the processed-segment ledger. + + All blocking work is offloaded to thread pools via run_blocking(). The + coordinator itself holds zero thread pool slots — only leaf operations use + threads, and only for their actual duration. """ - async with _get_sync_pipeline_semaphore(): + concurrency_gate = contextlib.nullcontext() if task_mode else _get_sync_pipeline_semaphore() + async with concurrency_gate: segmented_paths = set() wav_paths = [] stage_timings = {} @@ -1563,7 +1593,9 @@ def _run_vad_bg(path): return if FAIR_USE_ENABLED and total_speech_ms > 0: - await run_blocking(db_executor, record_speech_ms, uid, total_speech_ms, source='sync') + # Once-guard: a Cloud Tasks retry must not count the same audio twice + if await run_blocking(db_executor, try_mark_once, job_id, 'speech_ms'): + await run_blocking(db_executor, record_speech_ms, uid, total_speech_ms, source='sync') speech_totals = await run_blocking(db_executor, get_rolling_speech_ms, uid) triggered_caps = await run_blocking(db_executor, check_soft_caps, uid, speech_totals=speech_totals) if triggered_caps: @@ -1609,6 +1641,16 @@ def _run_vad_bg(path): segment_errors = [] segment_lock = threading.Lock() + # Segments that fully landed in a prior Cloud Tasks attempt are skipped + already_processed = set() + if task_mode: + already_processed = await run_blocking(db_executor, get_processed_segments, job_id) + if already_processed: + logger.info( + f'sync_v2 bg: job={job_id} skipping {len(already_processed)} ' + f'segment(s) processed in a prior attempt' + ) + # Chronological order + turnstile: STT runs in parallel (per chunk), but # conversation assignment is serialized oldest-first so adjacent chunks merge # instead of racing into separate conversations (#6551, #5747). @@ -1616,7 +1658,11 @@ def _run_vad_bg(path): assignment_turnstile = _OrderedTurnstile(segment_list) def _process_one_segment(path): - process_segment( + if path in already_processed: + # Release the assignment slot — later segments wait on it + assignment_turnstile.complete(path) + return + ok = process_segment( path, uid, response, @@ -1629,6 +1675,8 @@ def _process_one_segment(path): target_conversation_id, assignment_turnstile, ) + if ok and task_mode: + add_processed_segment(job_id, path) chunk_size = 5 for i in range(0, len(segment_list), chunk_size): @@ -1664,7 +1712,7 @@ def _process_one_segment(path): if fair_use_restrict_dg: try: dg_ms = int(total_speech_seconds * 1000) - if dg_ms > 0: + if dg_ms > 0 and await run_blocking(db_executor, try_mark_once, job_id, 'dg_ms'): await run_blocking(db_executor, record_dg_usage_ms, uid, dg_ms) except Exception as e: logger.error(f'sync_v2 bg: DG usage record error for {uid}: {e}') @@ -1684,7 +1732,7 @@ def _process_one_segment(path): if successful_segments > 0: try: usage_seconds = int(total_speech_seconds) - if usage_seconds > 0: + if usage_seconds > 0 and await run_blocking(db_executor, try_mark_once, job_id, 'usage'): await run_blocking( db_executor, record_usage, @@ -1717,6 +1765,10 @@ def _process_one_segment(path): ) except Exception as e: logger.error(f'sync_v2 bg failed job={job_id} uid={uid}: {e}') + if task_mode: + # Let the handler decide: queued-reset + Cloud Tasks retry, or + # final-attempt consume. Marking failed here would be terminal. + raise try: await run_blocking(db_executor, mark_job_failed, job_id, str(e)) except Exception: @@ -1732,6 +1784,32 @@ def _process_one_segment(path): logger.error(f'sync_v2 bg: failed to cleanup job dir {job_dir}: {e}') +def _stage_files_to_gcs(paths: list): + """Upload raw .bin files to the syncing bucket (blob name = local path).""" + for p in paths: + upload_syncing_temporal_file(p) + + +def _delete_staged_blobs(blob_paths: list): + for p in blob_paths: + try: + delete_syncing_temporal_file(p) + except Exception as e: + logger.warning(f'Failed to delete staged blob {p}: {e}') + + +async def _delete_staged_blobs_async(blob_paths: list): + await run_blocking(storage_executor, _delete_staged_blobs, blob_paths) + + +def _download_staged_files(blob_paths: list) -> bool: + """Download staged blobs back to their local paths. False if any is gone.""" + for p in blob_paths: + if not download_syncing_temporal_file(p): + return False + return True + + @router.post("/v2/sync-local-files") async def sync_local_files_v2( files: List[UploadFile] = File(...), @@ -1777,20 +1855,52 @@ async def sync_local_files_v2( owned_paths = list(paths) paths = [] # Prevent finally cleanup of files now owned by bg task - # Async coordinator: runs on event loop, offloads blocking work to pools. - # No thread pool slot held for the full pipeline duration (fixes #7361). - start_background_task( - _run_full_pipeline_background_async( - job_id, - uid, - owned_paths, - source, - should_lock, - job_dir, - conversation_id, - ), - name=f'sync_pipeline:{job_id}', - ) + dispatched = False + # BYOK keys live only in this request's context and cannot follow a + # Cloud Task, so BYOK requests always run inline. + if is_cloud_tasks_dispatch_enabled() and not has_byok_keys(): + try: + # sync_executor, NOT storage_executor — same reasoning as the + # file save above (#7372): a saturated storage pool would queue + # the staging upload and delay the 202. + await run_blocking(sync_executor, _stage_files_to_gcs, owned_paths) + await run_blocking( + db_executor, + enqueue_sync_job, + { + 'schema_version': 1, + 'job_id': job_id, + 'uid': uid, + 'raw_blob_paths': owned_paths, + 'source': source.value, + 'should_lock': should_lock, + 'conversation_id': conversation_id, + 'enqueued_at': time.time(), + }, + ) + dispatched = True + # The handler instance downloads from GCS; local copies are done + await run_blocking(sync_executor, _cleanup_files, owned_paths) + await run_blocking(sync_executor, shutil.rmtree, job_dir, True) + except Exception as e: + logger.error(f'sync_v2: Cloud Tasks dispatch failed job={job_id}, falling back inline: {e}') + start_background_task(_delete_staged_blobs_async(owned_paths), name=f'sync_unstage:{job_id}') + + if not dispatched: + # Async coordinator: runs on event loop, offloads blocking work to pools. + # No thread pool slot held for the full pipeline duration (fixes #7361). + start_background_task( + _run_full_pipeline_background_async( + job_id, + uid, + owned_paths, + source, + should_lock, + job_dir, + conversation_id, + ), + name=f'sync_pipeline:{job_id}', + ) return JSONResponse( status_code=202, @@ -1837,3 +1947,89 @@ def get_sync_job_status(job_id: str, uid: str = Depends(auth.get_current_user_ui resp['error'] = job['error'] return resp + + +@router.post("/v2/sync-jobs/run", include_in_schema=False) +async def run_sync_job(request: Request, task_retry_count: int = Depends(verify_cloud_tasks_oidc)): + """Cloud Tasks handler: runs one sync job inside the request. + + Auth is the Cloud Tasks OIDC token (verify_cloud_tasks_oidc), not Firebase. + Response semantics drive the queue: 2xx consumes the task, 409 while the + run-lock is held retries later, 500 retries with backoff. + + Idempotency: a per-job Redis run-lock serializes concurrent deliveries; + terminal jobs are acked without re-running; segments completed by a prior + attempt are skipped via the processed-segment ledger inside the pipeline. + """ + try: + payload = await request.json() + job_id = payload['job_id'] + uid = payload['uid'] + blob_paths = list(payload['raw_blob_paths']) + source = ConversationSource(payload.get('source') or 'omi') + should_lock = bool(payload.get('should_lock', False)) + conversation_id = payload.get('conversation_id') + except Exception as e: + # A malformed payload will not fix itself on retry — consume it. + logger.error(f'sync job handler: invalid payload, dropping task: {e}') + return JSONResponse(status_code=200, content={'status': 'dropped', 'reason': 'invalid_payload'}) + + # Fail-closed lock: Redis errors propagate → 500 → Cloud Tasks retries later. + lock_token = await run_blocking(db_executor, try_acquire_job_run_lock, job_id) + if not lock_token: + logger.warning(f'sync job {job_id}: run-lock held by another attempt, deferring') + return JSONResponse(status_code=409, content={'status': 'locked'}) + + try: + job = await run_blocking(db_executor, get_sync_job, job_id) + if not job: + # Job TTL (24h) expired before dispatch — staged blobs are gone or + # about to be (1-day lifecycle); the app re-uploads on 404. + logger.warning(f'sync job {job_id}: job expired before dispatch, dropping task') + await _delete_staged_blobs_async(blob_paths) + return JSONResponse(status_code=200, content={'status': 'dropped', 'reason': 'job_expired'}) + + if job['status'] in TERMINAL_STATUSES: + # Duplicate delivery, stale-detector-failed job, or a prior attempt + # that finished. Never re-run terminal jobs — the app may already be + # re-uploading these files as a new job. + await _delete_staged_blobs_async(blob_paths) + return JSONResponse(status_code=200, content={'status': 'acked', 'job_status': job['status']}) + + if not await run_blocking(storage_executor, _download_staged_files, blob_paths): + # Blobs deleted by the bucket's 1-day lifecycle (deep queue backlog). + await run_blocking(db_executor, mark_job_failed, job_id, 'Staged audio expired before processing') + await _delete_staged_blobs_async(blob_paths) + return JSONResponse(status_code=200, content={'status': 'dropped', 'reason': 'staged_audio_expired'}) + + job_dir = f'syncing/{uid}/{job_id}' + try: + await _run_full_pipeline_background_async( + job_id, + uid, + blob_paths, + source, + should_lock, + job_dir, + conversation_id, + task_mode=True, + ) + except Exception as e: + max_attempts = get_sync_tasks_max_attempts() + if task_retry_count >= max_attempts - 1: + logger.error(f'sync job {job_id}: final attempt {task_retry_count + 1} failed, consuming: {e}') + await run_blocking(db_executor, mark_job_failed, job_id, f'Failed after {max_attempts} attempts: {e}') + await _delete_staged_blobs_async(blob_paths) + return JSONResponse(status_code=200, content={'status': 'failed_final'}) + # Reset to 'queued' so the stale detector cannot terminally fail the + # job while the Cloud Tasks retry backoff elapses. Blobs are kept. + logger.warning(f'sync job {job_id}: attempt {task_retry_count + 1} failed, will retry: {e}') + await run_blocking(db_executor, mark_job_queued_for_retry, job_id, task_retry_count + 1, str(e)) + return JSONResponse(status_code=500, content={'status': 'retry'}) + + # Pipeline returned normally: completed, or it marked the job failed + # itself (decode/VAD/DG-budget) — terminal either way, staging is done. + await _delete_staged_blobs_async(blob_paths) + return JSONResponse(status_code=200, content={'status': 'done'}) + finally: + await run_blocking(db_executor, release_job_run_lock, job_id, lock_token) diff --git a/backend/test.sh b/backend/test.sh index f48ea67ce5..cf12aea3d9 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -108,6 +108,7 @@ pytest tests/unit/test_rate_limiting.py -v pytest tests/unit/test_memories_batch.py -v pytest tests/unit/test_memories_create.py -v pytest tests/unit/test_sync_v2.py -v +pytest tests/unit/test_sync_cloud_tasks.py -v pytest tests/unit/test_sync_transcription_prefs.py -v pytest tests/unit/test_sync_record_usage.py -v pytest tests/unit/test_vision_stream_async.py -v diff --git a/backend/tests/unit/test_sync_cloud_tasks.py b/backend/tests/unit/test_sync_cloud_tasks.py new file mode 100644 index 0000000000..6e6b28cc4e --- /dev/null +++ b/backend/tests/unit/test_sync_cloud_tasks.py @@ -0,0 +1,288 @@ +""" +Tests for Cloud Tasks dispatch of the v2 sync pipeline. + +Covers the new primitives in database/sync_jobs.py (run lock, queued-reset, +processed-segment ledger, metering once-guards), the OIDC verification in +utils/cloud_tasks.py, and the structural contract of the /v2/sync-jobs/run +handler in routers/sync.py. +""" + +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +BACKEND_DIR = os.path.join(os.path.dirname(__file__), '..', '..') + + +def _load_module_with_stubs(relative_path, module_name, stubs): + """Load a backend module with selected imports stubbed in sys.modules.""" + import importlib.util + + saved = {} + for mod, mock in stubs.items(): + saved[mod] = sys.modules.get(mod) + sys.modules[mod] = mock + try: + spec = importlib.util.spec_from_file_location(module_name, os.path.join(BACKEND_DIR, relative_path)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + finally: + for mod, original in saved.items(): + if original is None: + sys.modules.pop(mod, None) + else: + sys.modules[mod] = original + + +def _load_sync_jobs(): + mock_redis = MagicMock() + module = _load_module_with_stubs( + os.path.join('database', 'sync_jobs.py'), + 'sync_jobs_under_test', + { + 'database': MagicMock(), + 'database.redis_db': MagicMock(r=mock_redis), + }, + ) + return module, mock_redis + + +# --------------------------------------------------------------------------- +# Run lock +# --------------------------------------------------------------------------- + + +class TestJobRunLock: + def test_acquire_returns_token_when_free(self): + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.set.return_value = True + token = sync_jobs.try_acquire_job_run_lock('job-1') + assert token + args, kwargs = mock_redis.set.call_args + assert args[0] == 'sync_job_lock:job-1' + assert kwargs['nx'] is True + assert kwargs['ex'] == sync_jobs.RUN_LOCK_TTL_SECONDS + + def test_acquire_returns_none_when_held(self): + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.set.return_value = None + assert sync_jobs.try_acquire_job_run_lock('job-1') is None + + def test_acquire_fails_closed_on_redis_error(self): + """A Redis outage must block execution, never allow duplicate runs.""" + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.set.side_effect = ConnectionError('redis down') + with pytest.raises(ConnectionError): + sync_jobs.try_acquire_job_run_lock('job-1') + + def test_release_is_compare_and_delete(self): + sync_jobs, mock_redis = _load_sync_jobs() + sync_jobs.release_job_run_lock('job-1', 'tok') + args = mock_redis.eval.call_args[0] + assert args[1] == 1 + assert args[2] == 'sync_job_lock:job-1' + assert args[3] == 'tok' + + def test_release_swallows_redis_errors(self): + """Failed release just lets the TTL expire — must not raise.""" + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.eval.side_effect = ConnectionError('redis down') + sync_jobs.release_job_run_lock('job-1', 'tok') + + def test_lock_ttl_exceeds_handler_timeout(self): + """Invariant: a run-lock can never expire under a live run (request + timeout HTTP_SYNC_JOBS_RUN_TIMEOUT=1500 < lock TTL).""" + sync_jobs, _ = _load_sync_jobs() + assert sync_jobs.RUN_LOCK_TTL_SECONDS > 1500 + + +# --------------------------------------------------------------------------- +# Queued-reset, ledger, once-guards +# --------------------------------------------------------------------------- + + +class TestRetryPrimitives: + def test_mark_job_queued_for_retry_resets_status(self): + import json + + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.get.return_value = json.dumps({'job_id': 'job-1', 'status': 'processing'}) + sync_jobs.mark_job_queued_for_retry('job-1', attempt=2, error='boom') + written = json.loads(mock_redis.set.call_args[0][1]) + assert written['status'] == 'queued' + assert written['attempt'] == 2 + assert written['last_error'] == 'boom' + + def test_terminal_statuses(self): + sync_jobs, _ = _load_sync_jobs() + assert set(sync_jobs.TERMINAL_STATUSES) == {'completed', 'partial_failure', 'failed'} + + def test_processed_segment_ledger_roundtrip(self): + sync_jobs, mock_redis = _load_sync_jobs() + sync_jobs.add_processed_segment('job-1', 'syncing/u/job-1/seg_1.wav') + mock_redis.sadd.assert_called_once_with('sync_job_segments:job-1', 'syncing/u/job-1/seg_1.wav') + mock_redis.expire.assert_called_once() + + mock_redis.smembers.return_value = {b'a.wav', 'b.wav'} + assert sync_jobs.get_processed_segments('job-1') == {'a.wav', 'b.wav'} + + def test_ledger_fails_open(self): + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.sadd.side_effect = ConnectionError('redis down') + sync_jobs.add_processed_segment('job-1', 'x.wav') # must not raise + mock_redis.smembers.side_effect = ConnectionError('redis down') + assert sync_jobs.get_processed_segments('job-1') == set() + + def test_try_mark_once_first_and_second_call(self): + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.set.return_value = True + assert sync_jobs.try_mark_once('job-1', 'speech_ms') is True + mock_redis.set.return_value = None + assert sync_jobs.try_mark_once('job-1', 'speech_ms') is False + + def test_try_mark_once_fails_open(self): + """Metering guard prefers occasional double-count over never counting.""" + sync_jobs, mock_redis = _load_sync_jobs() + mock_redis.set.side_effect = ConnectionError('redis down') + assert sync_jobs.try_mark_once('job-1', 'usage') is True + + +# --------------------------------------------------------------------------- +# OIDC verification (utils/cloud_tasks.py) +# --------------------------------------------------------------------------- + + +def _load_cloud_tasks(): + tasks_v2_mock = MagicMock() + return _load_module_with_stubs( + os.path.join('utils', 'cloud_tasks.py'), + 'cloud_tasks_under_test', + {'google.cloud.tasks_v2': tasks_v2_mock}, + ) + + +def _request_with(headers: dict): + request = MagicMock() + request.headers = headers + return request + + +OIDC_ENV = { + 'SYNC_TASKS_HANDLER_URL': 'https://backend-sync.example.com/v2/sync-jobs/run', + 'SYNC_TASKS_INVOKER_SA': 'invoker@project.iam.gserviceaccount.com', +} + + +class TestVerifyCloudTasksOidc: + def test_env_unset_fails_closed(self): + """Services not configured as task targets must reject all task traffic.""" + from fastapi import HTTPException + + cloud_tasks = _load_cloud_tasks() + with patch.dict(os.environ, {}, clear=False): + os.environ.pop('SYNC_TASKS_HANDLER_URL', None) + os.environ.pop('SYNC_TASKS_OIDC_AUDIENCE', None) + os.environ.pop('SYNC_TASKS_INVOKER_SA', None) + with pytest.raises(HTTPException) as exc: + cloud_tasks.verify_cloud_tasks_oidc(_request_with({'authorization': 'Bearer x'})) + assert exc.value.status_code == 403 + + def test_missing_bearer_rejected(self): + from fastapi import HTTPException + + cloud_tasks = _load_cloud_tasks() + with patch.dict(os.environ, OIDC_ENV): + with pytest.raises(HTTPException) as exc: + cloud_tasks.verify_cloud_tasks_oidc(_request_with({})) + assert exc.value.status_code == 403 + + def test_invalid_token_rejected(self): + from fastapi import HTTPException + + cloud_tasks = _load_cloud_tasks() + with patch.dict(os.environ, OIDC_ENV): + with patch.object(cloud_tasks.id_token, 'verify_oauth2_token', side_effect=ValueError('bad')): + with pytest.raises(HTTPException) as exc: + cloud_tasks.verify_cloud_tasks_oidc(_request_with({'authorization': 'Bearer bad'})) + assert exc.value.status_code == 403 + + def test_wrong_identity_rejected(self): + from fastapi import HTTPException + + cloud_tasks = _load_cloud_tasks() + claims = {'email': 'attacker@project.iam.gserviceaccount.com', 'email_verified': True} + with patch.dict(os.environ, OIDC_ENV): + with patch.object(cloud_tasks.id_token, 'verify_oauth2_token', return_value=claims): + with pytest.raises(HTTPException) as exc: + cloud_tasks.verify_cloud_tasks_oidc(_request_with({'authorization': 'Bearer t'})) + assert exc.value.status_code == 403 + + def test_valid_token_returns_retry_count(self): + cloud_tasks = _load_cloud_tasks() + claims = {'email': OIDC_ENV['SYNC_TASKS_INVOKER_SA'], 'email_verified': True} + headers = {'authorization': 'Bearer t', 'x-cloudtasks-taskretrycount': '3'} + with patch.dict(os.environ, OIDC_ENV): + with patch.object(cloud_tasks.id_token, 'verify_oauth2_token', return_value=claims) as verify: + assert cloud_tasks.verify_cloud_tasks_oidc(_request_with(headers)) == 3 + assert verify.call_args.kwargs['audience'] == OIDC_ENV['SYNC_TASKS_HANDLER_URL'] + + def test_enqueue_requires_complete_env(self): + cloud_tasks = _load_cloud_tasks() + with patch.dict(os.environ, {}, clear=False): + for var in ('SYNC_TASKS_PROJECT', 'SYNC_TASKS_LOCATION', 'SYNC_TASKS_QUEUE'): + os.environ.pop(var, None) + with pytest.raises(RuntimeError): + cloud_tasks.enqueue_sync_job({'job_id': 'j'}) + + +# --------------------------------------------------------------------------- +# Structural contract of routers/sync.py and main.py wiring +# --------------------------------------------------------------------------- + + +class TestSyncRouterStructure: + @staticmethod + def _read(relative_path): + with open(os.path.join(BACKEND_DIR, relative_path), encoding='utf-8') as f: + return f.read() + + def test_handler_endpoint_exists_with_oidc(self): + source = self._read(os.path.join('routers', 'sync.py')) + assert '"/v2/sync-jobs/run"' in source + assert 'Depends(verify_cloud_tasks_oidc)' in source + + def test_handler_respects_terminal_statuses(self): + source = self._read(os.path.join('routers', 'sync.py')) + handler = source[source.index('async def run_sync_job') :] + assert 'TERMINAL_STATUSES' in handler + assert 'mark_job_queued_for_retry' in handler + assert 'status_code=409' in handler + + def test_fast_path_gates_on_env_and_byok(self): + source = self._read(os.path.join('routers', 'sync.py')) + assert 'is_cloud_tasks_dispatch_enabled() and not has_byok_keys()' in source + + def test_pipeline_reraises_in_task_mode(self): + source = self._read(os.path.join('routers', 'sync.py')) + assert 'task_mode: bool = False' in source + # Catch-all must re-raise in task mode so the handler controls retry + assert 'if task_mode:' in source + + def test_timeout_override_wired(self): + main_src = self._read('main.py') + assert 'paths_timeout' in main_src + assert 'HTTP_SYNC_JOBS_RUN_TIMEOUT' in main_src + timeout_src = self._read(os.path.join('utils', 'other', 'timeout.py')) + assert 'paths_timeout' in timeout_src + + def test_v1_endpoint_unchanged(self): + source = self._read(os.path.join('routers', 'sync.py')) + assert '"/v1/sync-local-files"' in source + + +if __name__ == '__main__': + sys.exit(pytest.main([__file__, '-v'])) diff --git a/backend/tests/unit/test_sync_silent_failure.py b/backend/tests/unit/test_sync_silent_failure.py index 8d056e6b0a..3fe9af73b4 100644 --- a/backend/tests/unit/test_sync_silent_failure.py +++ b/backend/tests/unit/test_sync_silent_failure.py @@ -617,6 +617,7 @@ def test_all_duplicates_skips_merge(self): 'utils.stt.speaker_embedding', 'utils.fair_use', 'utils.subscription', + 'utils.cloud_tasks', 'utils.conversations.process_conversation', ] @@ -659,6 +660,12 @@ def setup_class(cls): sys.modules['utils.other.storage'].get_or_create_merged_audio = MagicMock() sys.modules['utils.other.storage'].get_merged_audio_signed_url = MagicMock() sys.modules['utils.other.storage']._PRECACHE_FILE_SEM = MagicMock() + sys.modules['utils.other.storage'].upload_syncing_temporal_file = MagicMock() + sys.modules['utils.other.storage'].download_syncing_temporal_file = MagicMock(return_value=True) + sys.modules['utils.cloud_tasks'].enqueue_sync_job = MagicMock() + sys.modules['utils.cloud_tasks'].get_sync_tasks_max_attempts = MagicMock(return_value=5) + sys.modules['utils.cloud_tasks'].is_cloud_tasks_dispatch_enabled = MagicMock(return_value=False) + sys.modules['utils.cloud_tasks'].verify_cloud_tasks_oidc = MagicMock() sys.modules['utils.log_sanitizer'].sanitize = lambda value: value sys.modules['utils.encryption'].encrypt = MagicMock() sys.modules['utils.stt.pre_recorded'].deepgram_prerecorded = MagicMock() diff --git a/backend/tests/unit/test_sync_v2.py b/backend/tests/unit/test_sync_v2.py index 0280be503e..25ba0fab5a 100644 --- a/backend/tests/unit/test_sync_v2.py +++ b/backend/tests/unit/test_sync_v2.py @@ -1161,7 +1161,7 @@ def test_target_conversation_id_forwarded_to_process_segment(self): """target_conversation_id must be passed through to _process_one_segment / process_segment.""" body = self._get_bg_func_body() process_segment_section = body[body.index('def _process_one_segment') :] - process_segment_call = process_segment_section[:500] + process_segment_call = process_segment_section[:800] assert 'target_conversation_id' in process_segment_call # --- Cleanup on success and failure --- @@ -1196,10 +1196,14 @@ def test_finally_removes_job_directory(self): assert 'job_dir' in after_finally def test_general_exception_marks_failed(self): - """General except Exception must mark job failed with error message.""" + """Inline mode: general except Exception must mark job failed. + Task mode: it must re-raise instead, so the Cloud Tasks handler + controls retry vs final-attempt consume.""" body = self._get_bg_func_body() main_except = body[body.index("except Exception as e:\n logger.error(f'sync_v2 bg failed") :] - main_except_early = main_except[:200] + main_except_early = main_except[:600] + assert 'if task_mode:' in main_except_early + assert 'raise' in main_except_early assert 'mark_job_failed' in main_except_early def test_cleanup_order_byok_before_files(self): @@ -1248,6 +1252,7 @@ def _load_sync_module(): 'utils', 'utils.analytics', 'utils.byok', + 'utils.cloud_tasks', 'utils.conversations', 'utils.conversations.process_conversation', 'utils.conversations.factory', @@ -1688,6 +1693,7 @@ def _build_test_app(): 'utils', 'utils.analytics', 'utils.byok', + 'utils.cloud_tasks', 'utils.conversations', 'utils.conversations.process_conversation', 'utils.conversations.factory', diff --git a/backend/utils/cloud_tasks.py b/backend/utils/cloud_tasks.py new file mode 100644 index 0000000000..7bcd4cf39f --- /dev/null +++ b/backend/utils/cloud_tasks.py @@ -0,0 +1,134 @@ +"""Cloud Tasks dispatch + OIDC verification for the v2 sync pipeline. + +The /v2/sync-local-files fast path enqueues one named task per sync job; +Cloud Tasks POSTs it back to /v2/sync-jobs/run on the backend-sync service +with an OIDC token minted for SYNC_TASKS_INVOKER_SA. + +All functions fail closed when the SYNC_TASKS_* env vars are unset: enqueue +raises (caller falls back to the inline pipeline) and verification returns +403 — the handler ships in the shared image to services that must never +accept task traffic. +""" + +import json +import logging +import os +from typing import Optional + +from fastapi import HTTPException, Request +from google.api_core.exceptions import AlreadyExists +from google.auth.transport import requests as google_auth_requests +from google.cloud import tasks_v2 +from google.oauth2 import id_token +from google.protobuf import duration_pb2 + +logger = logging.getLogger(__name__) + +# Must match the queue's dispatchDeadline and the handler's request timeout +# (HTTP_SYNC_JOBS_RUN_TIMEOUT); see the run-lock TTL invariant in sync_jobs.py. +DISPATCH_DEADLINE_SECONDS = 1500 + +_tasks_client: Optional[tasks_v2.CloudTasksClient] = None +_google_auth_request: Optional[google_auth_requests.Request] = None + + +def _get_tasks_client() -> tasks_v2.CloudTasksClient: + global _tasks_client + if _tasks_client is None: + _tasks_client = tasks_v2.CloudTasksClient() + return _tasks_client + + +def _get_auth_request() -> google_auth_requests.Request: + global _google_auth_request + if _google_auth_request is None: + _google_auth_request = google_auth_requests.Request() + return _google_auth_request + + +def _handler_url() -> str: + return os.getenv('SYNC_TASKS_HANDLER_URL', '') + + +def _oidc_audience() -> str: + return os.getenv('SYNC_TASKS_OIDC_AUDIENCE') or _handler_url() + + +def _invoker_sa() -> str: + return os.getenv('SYNC_TASKS_INVOKER_SA', '') + + +def get_sync_tasks_max_attempts() -> int: + # Must mirror the queue's maxAttempts (documented invariant). + return int(os.getenv('SYNC_TASKS_MAX_ATTEMPTS', '5')) + + +def is_cloud_tasks_dispatch_enabled() -> bool: + return os.getenv('SYNC_DISPATCH_MODE', 'inline') == 'cloud_tasks' + + +def enqueue_sync_job(payload: dict) -> None: + """Enqueue one named HTTP task (task id = job_id) for a sync job. + + A duplicate enqueue of the same job_id is treated as success — Cloud Tasks + deduplicates named tasks. Any other failure raises; the caller falls back + to the inline pipeline. + """ + project = os.getenv('SYNC_TASKS_PROJECT', '') + location = os.getenv('SYNC_TASKS_LOCATION', '') + queue = os.getenv('SYNC_TASKS_QUEUE', '') + url = _handler_url() + invoker_sa = _invoker_sa() + if not all([project, location, queue, url, invoker_sa]): + raise RuntimeError('Cloud Tasks dispatch enabled but SYNC_TASKS_* env vars are incomplete') + + client = _get_tasks_client() + parent = client.queue_path(project, location, queue) + task = tasks_v2.Task( + name=client.task_path(project, location, queue, payload['job_id']), + http_request=tasks_v2.HttpRequest( + http_method=tasks_v2.HttpMethod.POST, + url=url, + headers={'Content-Type': 'application/json'}, + body=json.dumps(payload).encode(), + oidc_token=tasks_v2.OidcToken(service_account_email=invoker_sa, audience=_oidc_audience()), + ), + dispatch_deadline=duration_pb2.Duration(seconds=DISPATCH_DEADLINE_SECONDS), + ) + try: + client.create_task(parent=parent, task=task) + except AlreadyExists: + logger.info('sync job task %s already enqueued, skipping duplicate', payload['job_id']) + + +def verify_cloud_tasks_oidc(request: Request) -> int: + """FastAPI dependency for /v2/sync-jobs/run. Returns the task retry count. + + Sync function on purpose — verify_oauth2_token fetches Google certs over + HTTP, and FastAPI runs sync dependencies in the threadpool. + """ + audience = _oidc_audience() + invoker_sa = _invoker_sa() + if not audience or not invoker_sa: + # Env unset: this service is not a task target (e.g. main backend + # running the shared image) — never accept task traffic. + raise HTTPException(status_code=403, detail='Task dispatch not configured on this service') + + auth_header = request.headers.get('authorization', '') + if not auth_header.startswith('Bearer '): + raise HTTPException(status_code=403, detail='Missing bearer token') + + try: + claims = id_token.verify_oauth2_token(auth_header[len('Bearer ') :], _get_auth_request(), audience=audience) + except Exception as e: + # Distinguishes bad tokens from transient JWKS-fetch failures in logs + logger.warning('OIDC token verification failed: %s', e) + raise HTTPException(status_code=403, detail='Invalid OIDC token') + + if claims.get('email') != invoker_sa or not claims.get('email_verified'): + raise HTTPException(status_code=403, detail='Unexpected token identity') + + try: + return int(request.headers.get('x-cloudtasks-taskretrycount', '0')) + except ValueError: + return 0 diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py index a887ef9341..df95cdc00f 100644 --- a/backend/utils/other/storage.py +++ b/backend/utils/other/storage.py @@ -342,6 +342,30 @@ def delete_syncing_temporal_file(file_path: str): pass +def upload_syncing_temporal_file(file_path: str): + """Stage a local file in the syncing bucket (blob name = local relative path).""" + bucket = storage_client.bucket(syncing_local_bucket) + bucket.blob(file_path).upload_from_filename(file_path) + + +def download_syncing_temporal_file(file_path: str) -> bool: + """Download a staged blob back to its local relative path. + + Returns False when the blob no longer exists (e.g. deleted by the + bucket's 1-day lifecycle rule before a deeply delayed task ran). + """ + bucket = storage_client.bucket(syncing_local_bucket) + blob = bucket.blob(file_path) + directory = os.path.dirname(file_path) + if directory: + os.makedirs(directory, exist_ok=True) + try: + blob.download_to_filename(file_path) + return True + except BlobNotFound: + return False + + # ************************************************ # *********** PRIVATE CLOUD SYNC ***************** # ************************************************ diff --git a/backend/utils/other/timeout.py b/backend/utils/other/timeout.py index 8eb9d89510..831a3c81ee 100644 --- a/backend/utils/other/timeout.py +++ b/backend/utils/other/timeout.py @@ -7,7 +7,7 @@ class TimeoutMiddleware(BaseHTTPMiddleware): - def __init__(self, app, methods_timeout: dict = None): + def __init__(self, app, methods_timeout: dict = None, paths_timeout: dict = None): super().__init__(app) self.default_timeout = self._get_timeout_from_env("HTTP_DEFAULT_TIMEOUT", default=2 * 60) @@ -15,6 +15,7 @@ def __init__(self, app, methods_timeout: dict = None): self.clock_skew_allowance = self._get_timeout_from_env("HTTP_CLOCK_SKEW_ALLOWANCE", default=5 * 60) self.methods_timeout = self._parse_methods_timeout(methods_timeout or {}) + self.paths_timeout = self._parse_paths_timeout(paths_timeout or {}) @staticmethod def _get_timeout_from_env(env_var: str, default: float) -> float: @@ -36,6 +37,18 @@ def _parse_methods_timeout(methods_timeout: dict) -> dict: raise ValueError(f"Invalid timeout value for method {method}: {timeout}") return result + @staticmethod + def _parse_paths_timeout(paths_timeout: dict) -> dict: + result = {} + for path, timeout in paths_timeout.items(): + if timeout is None: + continue + try: + result[path] = float(timeout) + except ValueError: + raise ValueError(f"Invalid timeout value for path {path}: {timeout}") + return result + async def dispatch(self, request: Request, call_next): # Check for stale request header first # Uses clock_skew_allowance to tolerate client/server clock drift (#5929) @@ -61,7 +74,10 @@ async def dispatch(self, request: Request, call_next): except (ValueError, TypeError): pass - timeout = self.methods_timeout.get(request.method, self.default_timeout) + path_timeout = self.paths_timeout.get(request.url.path) + timeout = ( + path_timeout if path_timeout is not None else self.methods_timeout.get(request.method, self.default_timeout) + ) try: return await asyncio.wait_for(call_next(request), timeout=timeout) except asyncio.TimeoutError: