-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathjina_server.py
More file actions
1163 lines (952 loc) · 38.2 KB
/
jina_server.py
File metadata and controls
1163 lines (952 loc) · 38.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Jina Embedding & Reranker Server
OpenAI-compatible API for embeddings + reranking
Endpoints:
- POST /v1/embeddings (OpenAI compatible)
- POST /v1/rerank (Jina/Cohere style)
- POST /v1/files (OpenAI Files API)
- POST /v1/batches (OpenAI Batch API)
"""
import asyncio
import json
import os
import platform
import sys
import threading
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
# =============================================================================
# Thread configuration — MUST be set before torch import
# =============================================================================
# Target CPU: AMD R9 9900X (12 cores / 24 threads, Zen 5).
# NOTE(review): the original comment said "use physical cores only; SMT yields
# negligible gains for dense matmul", but 24 is the *logical* (SMT) thread
# count for this part — confirm whether 12 (physical) was intended.
MAX_THREADS = 24
os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
os.environ["MKL_NUM_THREADS"] = str(MAX_THREADS)
os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
os.environ["NUMEXPR_NUM_THREADS"] = str(MAX_THREADS)
# NOTE(review): "TORCH_INTEROP_THREADS" is not a documented PyTorch env var;
# inter-op parallelism is normally configured via
# torch.set_num_interop_threads() before any parallel work — verify intent.
os.environ["TORCH_INTEROP_THREADS"] = "8"
import torch
torch.set_num_threads(MAX_THREADS)
from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks
from fastapi.responses import Response
from pydantic import BaseModel, Field, field_validator, model_validator
from sentence_transformers import SentenceTransformer
# =============================================================================
# Configuration
# =============================================================================
# Windows-style relative path to the locally downloaded model folders
MODELS_DIR = r".\jinaai"
EMBEDDING_MODEL_PATH = os.path.join(MODELS_DIR, "jina-embeddings-v5-text-small")
RERANKER_MODEL_PATH = os.path.join(MODELS_DIR, "jina-reranker-v3")
# Valid task types for jina-embeddings-v5-text-small (LoRA task adapters)
# See: https://huggingface.co/jinaai/jina-embeddings-v5-text-small
VALID_EMBEDDING_TASKS = ("retrieval", "text-matching", "clustering", "classification")
# For retrieval task, prompt_name selects query vs document prefix
VALID_PROMPT_NAMES = ("query", "document")
# Jina cloud API sends dot-notation tasks; map them to (task, prompt_name).
# A prompt_name of None means "no prompt override" for that task.
TASK_ALIAS_MAP = {
    "retrieval.query": ("retrieval", "query"),
    "retrieval.passage": ("retrieval", "document"),
    "text-matching": ("text-matching", None),
    "classification": ("classification", None),
    "clustering": ("clustering", None),
    # NOTE(review): 'separation' is routed to text-matching — confirm this is
    # the intended closest adapter for this model family.
    "separation": ("text-matching", None),
}
# Check CPU capabilities
def check_cpu_capabilities():
"""Check CPU instruction set support using py-cpuinfo."""
import cpuinfo
cpu_info = {
"cpu_count": os.cpu_count(),
"pytorch_threads": torch.get_num_threads(),
"platform": platform.platform(),
"cpu_brand": cpuinfo.get_cpu_info().get("brand_raw", "Unknown"),
"avx": False,
"avx2": False,
"avx512": False,
"avx512f": False,
"avx512_vnni": False,
}
# Get all CPU flags
info = cpuinfo.get_cpu_info()
flags = info.get("flags", [])
# Check instruction set support
cpu_info["avx"] = "avx" in flags
cpu_info["avx2"] = "avx2" in flags
cpu_info["avx512f"] = "avx512f" in flags
cpu_info["avx512_vnni"] = "avx512vnni" in flags
# AVX512 considered supported if F (foundation) is present
cpu_info["avx512"] = cpu_info["avx512f"]
# Log detailed AVX512 subsets if available
if cpu_info["avx512"]:
avx512_subsets = [f for f in flags if f.startswith("avx512")]
cpu_info["avx512_subsets"] = avx512_subsets
return cpu_info
# Global model references (populated by lifespan() at startup)
embedding_model = None  # SentenceTransformer instance, or None before load
reranker_model = None  # JinaForRanking instance, or None before load
reranker_lock = threading.Lock()  # Protects reranker_model._block_size mutations
# In-memory storage for Files and Batches (lost on restart)
files_storage: Dict[str, Dict[str, Any]] = {}  # file_id -> file metadata + content
batches_storage: Dict[str, Dict[str, Any]] = {}  # batch_id -> batch metadata
# ---------------------------------------------------------------------------
# Dynamic batching infrastructure for /v1/embeddings
# ---------------------------------------------------------------------------
BATCH_WINDOW_MS = 50  # How long to wait before firing a batch (ms)
BATCH_MAX_SIZE = 64  # Max requests to accumulate before firing early
class _PendingEmbedRequest:
"""Holds a single request awaiting batch encoding."""
__slots__ = ("texts", "task", "prompt_name", "batch_size", "future")
def __init__(self, texts, task, prompt_name, batch_size, future):
self.texts = texts
self.task = task
self.prompt_name = prompt_name
self.batch_size = batch_size
self.future = future
# Shared dynamic-batching state; mutated only from the event loop thread.
_pending_embeddings: list[_PendingEmbedRequest] = []  # queue drained by flush
_batch_flush_lock: asyncio.Lock | None = None  # created in lifespan()
_batch_timer_handle: asyncio.TimerHandle | None = None  # pending flush timer
async def _flush_embedding_batch():
    """Drain pending requests, group by (task, prompt_name), batch-encode.

    Requests sharing a (task, prompt_name) key are merged into one
    ``encode`` call: their texts are flattened, encoded together, and the
    resulting embeddings are sliced back out per request.
    """
    global _pending_embeddings
    if not _pending_embeddings:
        return
    # Grab everything in the queue; rebinding (not clearing) the module list
    # lets requests that arrive during encoding land in a fresh queue.
    batch = _pending_embeddings[:]
    _pending_embeddings = []
    # Group by (task, prompt_name)
    groups: Dict[tuple, List[_PendingEmbedRequest]] = {}
    for req in batch:
        key = (req.task, req.prompt_name)
        groups.setdefault(key, []).append(req)
    for (task, prompt_name), reqs in groups.items():
        try:
            # Flatten texts, track which request they belong to
            flat_texts: List[str] = []
            spans: List[tuple] = []  # (req_index, start, end)
            offset = 0
            for idx, r in enumerate(reqs):
                flat_texts.extend(r.texts)
                spans.append((idx, offset, offset + len(r.texts)))
                offset += len(r.texts)
            # Honor the largest batch_size requested by any member
            max_bs = max(r.batch_size for r in reqs)
            encode_kwargs = {
                "task": task,
                "normalize_embeddings": True,
                "batch_size": max_bs,
                "convert_to_numpy": False,
            }
            if prompt_name is not None:
                encode_kwargs["prompt_name"] = prompt_name
            all_embs = embedding_model.encode(flat_texts, **encode_kwargs)
            # Distribute results back. FIX: skip futures that are already
            # done (e.g. cancelled after a client disconnect) — calling
            # set_result() on a done future raises InvalidStateError, which
            # previously jumped into the except block and failed every
            # remaining request in the group.
            for req_idx, start, end in spans:
                fut = reqs[req_idx].future
                if not fut.done():
                    fut.set_result(all_embs[start:end])
        except Exception as e:
            # If encoding fails, propagate to all requests in group
            for r in reqs:
                if not r.future.done():
                    r.future.set_exception(e)
def _schedule_batch_flush():
    """Arm a one-shot flush timer firing after BATCH_WINDOW_MS.

    At most one live timer exists at a time: if an un-cancelled handle is
    already registered, this is a no-op.
    """
    global _batch_timer_handle
    loop = asyncio.get_running_loop()
    timer = _batch_timer_handle
    if timer is not None and not timer.cancelled():
        # A flush is already pending; let it collect this request too.
        return
    delay_s = BATCH_WINDOW_MS / 1000.0
    _batch_timer_handle = loop.call_later(
        delay_s, lambda: loop.create_task(_safe_flush())
    )
async def _safe_flush():
    """Run one queue flush under the module lock so flushes never overlap."""
    global _batch_timer_handle
    lock = _batch_flush_lock
    if lock is None:
        # lifespan() hasn't created the lock yet — nothing can be queued safely.
        return
    async with lock:
        # Drop the timer handle first so the next request can arm a new timer.
        _batch_timer_handle = None
        await _flush_embedding_batch()
# =============================================================================
# Lifespan
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup.

    Startup order: print CPU capabilities, load + torch.compile the embedding
    model, load + torch.compile the reranker, create the dynamic-batching
    lock (requires the running event loop), then warm both models before the
    server accepts traffic. Shutdown only logs.
    """
    global embedding_model, reranker_model, _batch_flush_lock
    # Display CPU capabilities
    cpu_caps = check_cpu_capabilities()
    print("=" * 60)
    print("System Information")
    print("=" * 60)
    print(f" Platform: {cpu_caps['platform']}")
    print(f" CPU: {cpu_caps['cpu_brand']}")
    print(f" CPU Cores: {cpu_caps['cpu_count']}")
    print(f" PyTorch Threads: {cpu_caps['pytorch_threads']}")
    print(f" AVX: {cpu_caps['avx']}")
    print(f" AVX2: {cpu_caps['avx2']}")
    print(f" AVX-512: {cpu_caps['avx512']}")
    if cpu_caps.get("avx512"):
        print(f" AVX-512 Subsets: {', '.join(cpu_caps.get('avx512_subsets', []))}")
    print("=" * 60)
    print("\nLoading models...")
    print("=" * 60)
    # Load embedding model
    print(f"\n[1/2] Loading embedding model from: {EMBEDDING_MODEL_PATH}")
    try:
        embedding_model = SentenceTransformer(
            EMBEDDING_MODEL_PATH,
            trust_remote_code=True,
            model_kwargs={
                "default_task": "retrieval",
                "default_prompt": "query",
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "sdpa",
            },
        )
        embedding_model.to(torch.device("cpu"))
        print(
            f" [OK] Embedding model loaded (dim={embedding_model.get_sentence_embedding_dimension()})"
        )
    except Exception as e:
        # Startup aborts if the embedding model cannot load — server is useless without it.
        print(f" [FAIL] Failed to load embedding model: {e}")
        raise
    # Compile embedding model for optimized CPU inference
    # max-autotune: tries many kernel strategies, picks best for long-running serving
    # dynamic=True: handles variable-length text inputs without recompilation
    print("\n[OPTIMIZE] torch.compile() on embedding model...")
    try:
        embedding_model = torch.compile(embedding_model, dynamic=True, mode="max-autotune")
        print(" [OK] torch.compile applied")
    except Exception as e:
        # Compilation is best-effort: eager-mode model is kept on failure.
        print(f" [WARN] torch.compile() failed (falling back to eager): {e}")
    # Load reranker model
    print(f"\n[2/2] Loading reranker model from: {RERANKER_MODEL_PATH}")
    try:
        # Import from local modeling.py to avoid AutoModel issues
        sys.path.insert(0, RERANKER_MODEL_PATH)
        from modeling import JinaForRanking
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(
            RERANKER_MODEL_PATH, trust_remote_code=True, dtype=torch.bfloat16
        )
        # Force SDPA attention on the private config attribute
        config._attn_implementation = "sdpa"
        # NOTE(review): JinaForRanking(config) builds the architecture from
        # config alone — confirm that modeling.py's __init__ loads the
        # pretrained checkpoint; otherwise this serves randomly-initialized
        # weights and JinaForRanking.from_pretrained(RERANKER_MODEL_PATH)
        # would be the expected call.
        reranker_model = JinaForRanking(config)
        reranker_model.eval()
        reranker_model = torch.compile(reranker_model, mode="max-autotune", dynamic=True)
        reranker_model.to(torch.device("cpu"))
        print(" [OK] Reranker model loaded")
    except Exception as e:
        print(f" [FAIL] Failed to load reranker model: {e}")
        raise
    # Initialize dynamic batching lock (needs running event loop)
    _batch_flush_lock = asyncio.Lock()
    print("\n" + "=" * 60)
    print("Server ready!")
    print("=" * 60)
    # Warmup: pre-allocate memory and trigger JIT compilation
    print("\n[WARMUP] Pre-warming models...")
    _ = embedding_model.encode(
        [
            "dummy_text_here_put_your_string_lol1234567890 vgyfsFewwg4rgeghrafW EDDDDD₫fvvv俄国v恶个过程各方位"
        ],
        task="retrieval",
        prompt_name="query",
        convert_to_numpy=False,
        normalize_embeddings=True,
    )
    # Reranker warmup is best-effort: a compiled-model quirk here must not
    # block startup.
    try:
        _ = reranker_model.rerank(
            "warmup query", ["dummy_text_here_put_your_string_lol"], top_n=1
        )
    except Exception:
        pass
    print(" [OK] Models pre-warmed")
    yield
    # Cleanup
    print("Shutting down...")
# FastAPI application; model loading/warmup is handled by `lifespan` above.
app = FastAPI(
    title="Jina Embedding & Reranker Server",
    description="OpenAI-compatible API for embeddings + reranking",
    version="1.5.0",
    lifespan=lifespan,
)
# =============================================================================
# Request/Response Models
# =============================================================================
class EmbeddingRequest(BaseModel):
    """OpenAI-compatible embedding request.

    Also accepts Jina cloud dot-notation tasks (e.g. 'retrieval.query');
    these are expanded into (task, prompt_name) by the model validator.
    """
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    model: str = Field(default="jina-embeddings-v5-text-small")
    encoding_format: str = Field(default="float", description="float or base64")
    batch_size: int = Field(
        default=32, ge=1, le=128, description="Batch size for processing"
    )
    task: str = Field(
        default="retrieval",
        description="Task adapter: retrieval, text-matching, clustering, classification",
    )
    prompt_name: Optional[str] = Field(
        default="query",
        description="For retrieval task only: 'query' or 'document'. Required when task='retrieval'.",
    )
    @field_validator("task")
    @classmethod
    def validate_task(cls, v: str) -> str:
        # Accept both plain tasks ("retrieval") and Jina cloud aliases ("retrieval.query")
        # Aliases are resolved after validation — see model_validator below
        if v not in VALID_EMBEDDING_TASKS and v not in TASK_ALIAS_MAP:
            raise ValueError(
                f"Invalid task '{v}'. Must be one of: {VALID_EMBEDDING_TASKS} or aliases: {list(TASK_ALIAS_MAP.keys())}"
            )
        return v
    @field_validator("prompt_name")
    @classmethod
    def validate_prompt_name(cls, v: Optional[str], info) -> Optional[str]:
        # Field-level check only; the cross-field rule (retrieval requires a
        # prompt_name) is enforced in resolve_task_alias below.
        if v is not None and v not in VALID_PROMPT_NAMES:
            raise ValueError(
                f"Invalid prompt_name '{v}'. Must be one of: {VALID_PROMPT_NAMES}"
            )
        return v
    @model_validator(mode="after")
    def resolve_task_alias(self) -> "EmbeddingRequest":
        """Expand dot-notation task aliases (e.g. 'retrieval.query') into task + prompt_name."""
        if self.task in TASK_ALIAS_MAP:
            resolved_task, resolved_prompt = TASK_ALIAS_MAP[self.task]
            # object.__setattr__ bypasses pydantic's assignment machinery
            object.__setattr__(self, "task", resolved_task)
            # Only fill prompt_name from the alias when the caller didn't set it
            if self.prompt_name is None:
                object.__setattr__(self, "prompt_name", resolved_prompt)
        # Retrieval task requires prompt_name
        if self.task == "retrieval" and self.prompt_name is None:
            raise ValueError(
                "prompt_name is required when task='retrieval'. Use 'query' or 'document'."
            )
        return self
class EmbeddingObject(BaseModel):
    """Single embedding object (one element of the response 'data' list)."""
    object: str = "embedding"  # constant discriminator per OpenAI schema
    index: int  # position of this embedding in the request's input list
    embedding: List[float]  # the embedding vector
class EmbeddingResponse(BaseModel):
    """OpenAI-compatible embedding response."""
    object: str = "list"  # constant discriminator per OpenAI schema
    data: List[EmbeddingObject]  # one entry per input text, in input order
    model: str  # echoed from the request
    usage: dict  # token-count estimates (see create_embeddings)
class RerankRequest(BaseModel):
    """Rerank request (Jina/Cohere style)."""
    model: str = Field(default="jina-reranker-v3")
    query: str = Field(..., description="Search query")
    documents: List[str] = Field(..., description="Documents to rerank")
    top_n: Optional[int] = Field(default=None, description="Return only top N results")
    return_documents: bool = Field(
        default=False, description="Include document text in response"
    )
    # Mapped onto the reranker's private _block_size in the /v1/rerank handler
    batch_size: int = Field(
        default=64, ge=1, le=256, description="Batch size for reranking"
    )
class RerankResult(BaseModel):
    """Single rerank result."""
    index: int  # index of the document in the original request list
    relevance_score: float  # model-assigned relevance to the query
    document: Optional[str] = None  # populated only when return_documents=True
class RerankResponse(BaseModel):
    """Rerank response."""
    model: str  # echoed from the request
    results: List[RerankResult]  # sorted by the reranker, best first
    usage: dict  # token estimates + timing (see rerank handler)
# =============================================================================
# File & Batch Models (OpenAI Batch API)
# =============================================================================
class FileObject(BaseModel):
    """OpenAI File object (metadata only; content lives in files_storage)."""
    id: str  # "file-<24 hex chars>"
    object: str = "file"  # constant discriminator per OpenAI schema
    bytes: int  # size of the uploaded content
    created_at: int  # unix timestamp of upload
    filename: str
    purpose: str  # only "batch" is accepted by upload_file
    status: str = "uploaded"
class FileListResponse(BaseModel):
    """List of files."""
    object: str = "list"  # constant discriminator per OpenAI schema
    data: List[FileObject]
class BatchRequest(BaseModel):
    """Batch creation request."""
    input_file_id: str  # must reference a previously uploaded JSONL file
    endpoint: str = "/v1/embeddings"  # target endpoint for every line
    completion_window: str = "24h"  # accepted for API compatibility
    metadata: Optional[Dict[str, str]] = None  # opaque caller metadata
class BatchObject(BaseModel):
    """OpenAI Batch object.

    The *_at fields are unix timestamps; each is set when the batch enters
    the corresponding lifecycle state and stays None otherwise.
    """
    id: str  # "batch_..." identifier
    object: str = "batch"  # constant discriminator per OpenAI schema
    endpoint: str
    errors: Optional[Dict[str, Any]] = None  # populated on failure
    input_file_id: str
    completion_window: str
    status: str  # e.g. "in_progress", "completed", "failed"
    output_file_id: Optional[str] = None  # results file once completed
    error_file_id: Optional[str] = None
    created_at: int
    in_progress_at: Optional[int] = None
    expires_at: Optional[int] = None
    finalizing_at: Optional[int] = None
    completed_at: Optional[int] = None
    failed_at: Optional[int] = None
    expired_at: Optional[int] = None
    cancelling_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    # Per-batch progress counters; defaults to all-zero for a fresh batch
    request_counts: Dict[str, int] = Field(
        default_factory=lambda: {"total": 0, "completed": 0, "failed": 0}
    )
    metadata: Optional[Dict[str, str]] = None
class BatchListResponse(BaseModel):
    """List of batches."""
    object: str = "list"  # constant discriminator per OpenAI schema
    data: List[BatchObject]
# =============================================================================
# Endpoints
# =============================================================================
@app.get("/")
async def root():
    """Health check: reports whether each model has finished loading."""
    loaded = {
        "embedding": embedding_model is not None,
        "reranker": reranker_model is not None,
    }
    return {"status": "ok", "models": loaded}
@app.get("/v1/models")
async def list_models():
    """List available models."""
    # Both served models share owner/object fields; build entries from ids.
    model_ids = ("jina-embeddings-v5-text-small", "jina-reranker-v3")
    entries = [
        {"id": mid, "object": "model", "owned_by": "jina-ai"} for mid in model_ids
    ]
    return {"object": "list", "data": entries}
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """
    Create embeddings for input text(s).
    OpenAI-compatible endpoint with dynamic batching.

    Small requests (<= BATCH_MAX_SIZE texts) are funneled through the dynamic
    batch queue so concurrent callers share one encode() call; larger
    requests — or requests arriving before lifespan init — are encoded
    directly.

    Raises:
        HTTPException 503 when the model is not loaded, 400 on empty input.
    """
    if embedding_model is None:
        raise HTTPException(status_code=503, detail="Embedding model not loaded")
    # Normalize input to list
    texts = [request.input] if isinstance(request.input, str) else request.input
    if not texts:
        raise HTTPException(status_code=400, detail="Input cannot be empty")
    start_time = time.time()
    # Submit to dynamic batch queue when available and request is small enough
    if _batch_flush_lock is not None and len(texts) <= BATCH_MAX_SIZE:
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        _pending_embeddings.append(
            _PendingEmbedRequest(
                texts=texts,
                task=request.task,
                prompt_name=request.prompt_name,
                batch_size=request.batch_size,
                future=future,
            )
        )
        # Flush immediately if we hit max batch size, otherwise schedule timer
        if len(_pending_embeddings) >= BATCH_MAX_SIZE:
            loop.create_task(_safe_flush())
        else:
            _schedule_batch_flush()
        all_embeddings = await future
    else:
        # Fallback: encode directly (large single requests or before lifespan init)
        encode_kwargs = {
            "task": request.task,
            "normalize_embeddings": True,
            "batch_size": request.batch_size,
            "convert_to_numpy": False,
        }
        if request.prompt_name is not None:
            encode_kwargs["prompt_name"] = request.prompt_name
        # FIX: removed an unused `loop = asyncio.get_running_loop()` here —
        # apparently a leftover of an unfinished run_in_executor plan.
        # NOTE(review): this encode blocks the event loop; offloading it to a
        # worker thread would need a thread-safety check on the model first.
        all_embeddings = embedding_model.encode(texts, **encode_kwargs)
    # FIX: floor elapsed to avoid ZeroDivisionError in the rate log when the
    # platform timer is too coarse to observe a very fast encode.
    elapsed = max(time.time() - start_time, 1e-9)
    print(
        f" [INFO] Processed {len(texts)} texts in {elapsed:.2f}s ({len(texts) / elapsed:.1f} texts/s)"
    )
    # Build response
    data = []
    total_tokens = 0
    for i, emb in enumerate(all_embeddings):
        data.append(
            EmbeddingObject(
                object="embedding",
                index=i,
                embedding=emb.tolist(),
            )
        )
        # Estimate tokens (rough): ~2 tokens per whitespace-separated word
        total_tokens += len(texts[i].split()) * 2
    return EmbeddingResponse(
        object="list",
        data=data,
        model=request.model,
        usage={
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
            "batch_size": request.batch_size,
        },
    )
@app.post("/v1/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
    """
    Rerank documents based on query relevance.
    Jina/Cohere style endpoint.

    Raises:
        HTTPException 503 when the model is not loaded, 400 on empty
        query/documents.
    """
    if reranker_model is None:
        raise HTTPException(status_code=503, detail="Reranker model not loaded")
    if not request.documents:
        raise HTTPException(status_code=400, detail="Documents cannot be empty")
    if not request.query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    # Rerank using model's built-in rerank method with batch control
    start_time = time.time()
    # Thread-safe: protect _block_size mutation from concurrent requests
    with reranker_lock:
        original_block_size = getattr(reranker_model, "_block_size", 125)
        reranker_model._block_size = request.batch_size
        try:
            results = reranker_model.rerank(
                query=request.query,
                documents=request.documents,
                top_n=request.top_n,
            )
        finally:
            # FIX: restore even when rerank() raises, so a failed request
            # cannot leak its batch size into subsequent callers.
            reranker_model._block_size = original_block_size
    elapsed = time.time() - start_time
    print(
        f" [INFO] Reranked {len(request.documents)} documents in {elapsed:.2f}s (batch_size={request.batch_size})"
    )
    # Build response
    rerank_results = []
    for r in results:
        rerank_results.append(
            RerankResult(
                index=r["index"],
                relevance_score=r["relevance_score"],
                document=r["document"] if request.return_documents else None,
            )
        )
    # Estimate tokens (rough): ~2 tokens per whitespace-separated word
    total_tokens = len(request.query.split()) * 2
    for doc in request.documents:
        total_tokens += len(doc.split()) * 2
    return RerankResponse(
        model=request.model,
        results=rerank_results,
        usage={
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
            "batch_size": request.batch_size,
            "elapsed_time": f"{elapsed:.3f}s",
        },
    )
# =============================================================================
# File Endpoints (OpenAI Files API)
# =============================================================================
@app.post("/v1/files", response_model=FileObject)
async def upload_file(file: UploadFile = File(...), purpose: str = "batch"):
    """
    Upload a file for batch processing.
    Expects JSONL format with one request per line.
    """
    if purpose != "batch":
        raise HTTPException(status_code=400, detail="Only 'batch' purpose is supported")
    # Read file content into memory (storage is in-process only)
    payload = await file.read()
    file_id = f"file-{uuid.uuid4().hex[:24]}"
    now = int(time.time())
    name = file.filename or "batch.jsonl"
    # Store metadata and raw bytes together under the generated id
    files_storage[file_id] = {
        "id": file_id,
        "bytes": len(payload),
        "created_at": now,
        "filename": name,
        "purpose": purpose,
        "status": "uploaded",
        "content": payload,
    }
    print(f" [INFO] Uploaded file {file_id}: {file.filename} ({len(payload)} bytes)")
    return FileObject(
        id=file_id,
        bytes=len(payload),
        created_at=now,
        filename=name,
        purpose=purpose,
        status="uploaded",
    )
@app.get("/v1/files", response_model=FileListResponse)
async def list_files(purpose: str = "batch"):
    """List all uploaded files, optionally filtered by purpose."""
    # An empty purpose string disables the filter (matches original behavior)
    matching = [
        FileObject(
            id=meta["id"],
            bytes=meta["bytes"],
            created_at=meta["created_at"],
            filename=meta["filename"],
            purpose=meta["purpose"],
            status=meta.get("status", "uploaded"),
        )
        for meta in files_storage.values()
        if not purpose or meta.get("purpose") == purpose
    ]
    return FileListResponse(object="list", data=matching)
@app.get("/v1/files/{file_id}", response_model=FileObject)
async def get_file(file_id: str):
    """Get file metadata."""
    meta = files_storage.get(file_id)
    if meta is None:
        raise HTTPException(status_code=404, detail="File not found")
    return FileObject(
        id=meta["id"],
        bytes=meta["bytes"],
        created_at=meta["created_at"],
        filename=meta["filename"],
        purpose=meta["purpose"],
        status=meta.get("status", "uploaded"),
    )
@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    """Delete a file."""
    if file_id not in files_storage:
        raise HTTPException(status_code=404, detail="File not found")
    files_storage.pop(file_id)
    print(f" [INFO] Deleted file {file_id}")
    return {"id": file_id, "object": "file", "deleted": True}
@app.get("/v1/files/{file_id}/content")
async def get_file_content(file_id: str):
    """Get file content (for output files)."""
    meta = files_storage.get(file_id)
    if meta is None:
        raise HTTPException(status_code=404, detail="File not found")
    disposition = f'attachment; filename="{meta["filename"]}"'
    return Response(
        content=meta.get("content", b""),
        media_type="application/jsonl",
        headers={"Content-Disposition": disposition},
    )
# =============================================================================
# Batch Processing Logic
# =============================================================================
def _build_batch_object(batch: Dict[str, Any]) -> BatchObject:
    """Construct a BatchObject from batch storage dict.

    Required keys are read directly; lifecycle timestamps and other
    optional keys default to None, and request_counts falls back to
    all-zero counters.
    """
    optional_keys = (
        "in_progress_at",
        "completed_at",
        "failed_at",
        "output_file_id",
        "error_file_id",
        "errors",
        "metadata",
    )
    optional = {key: batch.get(key) for key in optional_keys}
    counts = batch.get("request_counts", {"total": 0, "completed": 0, "failed": 0})
    return BatchObject(
        id=batch["id"],
        endpoint=batch["endpoint"],
        input_file_id=batch["input_file_id"],
        completion_window=batch["completion_window"],
        status=batch["status"],
        created_at=batch["created_at"],
        request_counts=counts,
        **optional,
    )
async def process_batch_job(batch_id: str):
"""Background task to process a batch job.
Optimization: collect all texts from same task/prompt_name group,
batch-encode them in a single model call, then distribute results.
"""
global embedding_model
if batch_id not in batches_storage:
return
batch = batches_storage[batch_id]
input_file_id = batch["input_file_id"]
if input_file_id not in files_storage:
batch["status"] = "failed"
batch["failed_at"] = int(time.time())
batch["errors"] = {"message": "Input file not found"}
return
# Update status to in_progress
batch["status"] = "in_progress"
batch["in_progress_at"] = int(time.time())
print(f" [BATCH] Starting batch {batch_id}")
try:
# Read input file
input_content = files_storage[input_file_id]["content"].decode("utf-8")
lines = [
line.strip() for line in input_content.strip().split("\n") if line.strip()
]
total_requests = len(lines)
default_batch_size = 32
# ---- Phase 1: Parse all requests, collect valid embedding tasks ----
# Group: (task, prompt_name) -> [(line_index, text_list)]
task_groups: Dict[tuple, List[tuple]] = {}
# Per-line metadata for reconstruction
line_meta: List[Dict[str, Any]] = []
parse_errors: List[tuple] = [] # (index, custom_id, error_msg)
for i, line in enumerate(lines):
custom_id = f"request-{i}"
try:
request_data = json.loads(line)
custom_id = request_data.get("custom_id", f"request-{i}")
body = request_data.get("body", {})
endpoint = request_data.get("endpoint") or batch.get("endpoint")
if endpoint != "/v1/embeddings":
raise ValueError(f"Unsupported endpoint: {endpoint}")
# Extract input texts
input_texts = body.get("input", [])
if isinstance(input_texts, str):
input_texts = [input_texts]
if not input_texts or embedding_model is None:
raise ValueError("No input texts or model not loaded")
# Resolve task/prompt_name
task = body.get("task", "retrieval")
prompt_name = body.get("prompt_name", None)
if task in TASK_ALIAS_MAP:
task, prompt_name = TASK_ALIAS_MAP[task]
if task not in VALID_EMBEDDING_TASKS:
raise ValueError(
f"Invalid task '{task}'. Must be one of: {VALID_EMBEDDING_TASKS}"
)
if prompt_name is not None and prompt_name not in VALID_PROMPT_NAMES:
raise ValueError(
f"Invalid prompt_name '{prompt_name}'. Must be one of: {VALID_PROMPT_NAMES}"
)
group_key = (task, prompt_name)
if group_key not in task_groups:
task_groups[group_key] = []
task_groups[group_key].append((i, input_texts))
line_meta.append(
{
"index": i,
"custom_id": custom_id,
"group_key": group_key,
"text_count": len(input_texts),
"model": body.get("model", "jina-embeddings-v5-text-small"),
"texts": input_texts,
}
)
except Exception as e:
parse_errors.append((i, custom_id, str(e)))
# ---- Phase 2: Batch-encode per (task, prompt_name) group ----
# Maps: line_index -> list[torch.Tensor] (embeddings for that line)
embeddings_by_line: Dict[int, list] = {}
for (task, prompt_name), entries in task_groups.items():
# Flatten all texts in this group, tracking which line they belong to
flat_texts: List[str] = []
line_spans: List[tuple] = [] # (line_index, start, end)
offset = 0
for line_idx, texts in entries:
flat_texts.extend(texts)
line_spans.append((line_idx, offset, offset + len(texts)))
offset += len(texts)
encode_kwargs = {
"task": task,
"normalize_embeddings": True,
"batch_size": default_batch_size,
"convert_to_numpy": False,
}
if prompt_name is not None:
encode_kwargs["prompt_name"] = prompt_name
all_embs = embedding_model.encode(flat_texts, **encode_kwargs)
# Distribute embeddings back to their originating lines
for line_idx, start, end in line_spans:
embeddings_by_line[line_idx] = all_embs[start:end]
# ---- Phase 3: Build results ----
results: List[Dict[str, Any]] = []
completed = 0
failed = len(parse_errors)
# Successful lines (in original order)
for meta in line_meta:
idx = meta["index"]
embs = embeddings_by_line.get(idx, [])
response_data = []
total_tokens = 0
for emb_idx, emb in enumerate(embs):
response_data.append(
{
"object": "embedding",
"index": emb_idx,
"embedding": emb.tolist(),
}
)
total_tokens += len(meta["texts"][emb_idx].split()) * 2
results.append(
{
"id": f"resp-{uuid.uuid4().hex[:24]}",
"custom_id": meta["custom_id"],
"response": {
"status_code": 200,
"body": {
"object": "list",
"data": response_data,
"model": meta["model"],
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
},
},
"error": None,
}
)
completed += 1
# Failed lines from parse errors
for err_idx, custom_id, err_msg in parse_errors:
results.append(
{
"id": f"resp-{uuid.uuid4().hex[:24]}",
"custom_id": custom_id,
"response": None,
"error": {"message": err_msg, "type": "processing_error"},
}