From fea70ee4490e9819330e6831b850e7b2aad404e8 Mon Sep 17 00:00:00 2001 From: Sigrid Jin Date: Fri, 5 Dec 2025 10:20:29 +0000 Subject: [PATCH 01/16] feat: add XProvence context pruning support for RAG Add XProvence model integration for zero-cost context pruning in reranking. XProvence removes irrelevant sentences from passages based on query relevance, returning both reranking scores and pruned context. Changes: - Add XProvenceModel class with process() method for sentence-level pruning - Add pruned_text field to Score type and HTTP response - Pass raw_query/raw_text through tokenization pipeline for pruning - Make flash_attn imports optional for XProvence compatibility - Add XProvence architecture detection in router and Python backend - Handle bfloat16 to float32 conversion for XProvence process() method Configuration: - XPROVENCE_THRESHOLD: Pruning threshold 0.0-1.0 (default: 0.3) - XPROVENCE_ALWAYS_SELECT_TITLE: Keep first sentence as title (default: true) Usage: XPROVENCE_THRESHOLD=0.3 text-embeddings-router \ --model-id naver/xprovence-reranker-bgem3-v1 --port 8080 --- backends/core/src/lib.rs | 15 +- backends/grpc-client/src/client.rs | 4 + backends/proto/embed.proto | 6 + .../text_embeddings_server/models/__init__.py | 31 ++- .../text_embeddings_server/models/types.py | 9 + .../models/xprovence_model.py | 176 ++++++++++++++++++ backends/python/src/lib.rs | 20 +- core/src/infer.rs | 10 +- core/src/queue.rs | 10 + core/src/tokenization.rs | 12 ++ router/src/http/server.rs | 12 +- router/src/http/types.rs | 4 + 12 files changed, 290 insertions(+), 19 deletions(-) create mode 100644 backends/python/server/text_embeddings_server/models/xprovence_model.py diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 8e134d2be..55dad0d8e 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -14,6 +14,10 @@ pub struct Batch { pub max_length: u32, pub pooled_indices: Vec, pub raw_indices: Vec, + /// XProvence: raw query texts for context pruning + pub raw_queries: Vec>, + /// XProvence: raw context texts for context pruning + pub raw_texts: Vec>, } impl Batch { @@ -32,7 +36,16 @@ pub enum Embedding { } pub type Embeddings = IntMap; -pub type Predictions = IntMap>; + +/// XProvence: Prediction result containing scores and optional pruned text +#[derive(Debug, Clone)] +pub struct Prediction { + pub scores: Vec, + /// XProvence: pruned context text after removing irrelevant sentences + pub pruned_text: Option, +} + +pub type Predictions = IntMap; pub trait Backend { fn health(&self) -> Result<(), BackendError>; diff --git a/backends/grpc-client/src/client.rs b/backends/grpc-client/src/client.rs index 1f6036eed..6c3968614 100644 --- a/backends/grpc-client/src/client.rs +++ b/backends/grpc-client/src/client.rs @@ -73,6 +73,8 @@ impl Client { position_ids: Vec, cu_seq_lengths: Vec, max_length: u32, + raw_query: Option, + raw_text: Option, ) -> Result> { let request = tonic::Request::new(EmbedRequest { input_ids, @@ -80,6 +82,8 @@ impl Client { position_ids, max_length, cu_seq_lengths, + raw_query, + raw_text, }) .inject_context(); let response = self.stub.predict(request).await?.into_inner(); diff --git a/backends/proto/embed.proto b/backends/proto/embed.proto index 036f3db4b..55df0889f 100644 --- a/backends/proto/embed.proto +++ b/backends/proto/embed.proto @@ -21,6 +21,10 @@ message EmbedRequest { repeated uint32 cu_seq_lengths = 4; /// Length of the longest request uint32 max_length = 5; + /// XProvence: raw query text for context pruning + optional string raw_query 
= 6; + /// XProvence: raw context text for context pruning + optional string raw_text = 7; } message Embedding { @@ -33,6 +37,8 @@ message EmbedResponse { message Score { repeated float values = 1; + /// XProvence: pruned context text after removing irrelevant sentences + optional string pruned_text = 2; } message PredictResponse { diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 1e919f233..55f93c0bf 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -11,11 +11,19 @@ from text_embeddings_server.models.masked_model import MaskedLanguageModel from text_embeddings_server.models.default_model import DefaultModel from text_embeddings_server.models.classification_model import ClassificationModel -from text_embeddings_server.models.jinaBert_model import FlashJinaBert -from text_embeddings_server.models.flash_mistral import FlashMistral -from text_embeddings_server.models.flash_qwen3 import FlashQwen3 +from text_embeddings_server.models.xprovence_model import XProvenceModel from text_embeddings_server.utils.device import get_device, use_ipex +FlashJinaBert = None +FlashMistral = None +FlashQwen3 = None +try: + from text_embeddings_server.models.jinaBert_model import FlashJinaBert + from text_embeddings_server.models.flash_mistral import FlashMistral + from text_embeddings_server.models.flash_qwen3 import FlashQwen3 +except ImportError as e: + logger.warning(f"Flash attention models not available: {e}") + __all__ = ["Model"] TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] @@ -76,13 +84,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) if ( - hasattr(config, "auto_map") + hasattr(config, "architectures") + and config.architectures + and "XProvence" in config.architectures[0] + ): + logger.info("Detected XProvence model for context pruning") + return XProvenceModel(model_path, device, datatype, trust_remote=True) + + if ( + FlashJinaBert is not None + and hasattr(config, "auto_map") and isinstance(config.auto_map, dict) and "AutoModel" in config.auto_map and config.auto_map["AutoModel"] == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel" ): - # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository return create_model(FlashJinaBert, model_path, device, datatype) if config.model_type == "bert": @@ -116,19 +132,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): else: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "mistral" and device.type == "hpu": + if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu": try: return create_model(FlashMistral, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "qwen3" and device.type == "hpu": + if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu": try: return create_model(FlashQwen3, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - # Default case if config.architectures[0].endswith("Classification"): return 
create_model(ClassificationModel, model_path, device, datatype) elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py index f27572a9b..f4da0da8e 100644 --- a/backends/python/server/text_embeddings_server/models/types.py +++ b/backends/python/server/text_embeddings_server/models/types.py @@ -36,6 +36,9 @@ class PaddedBatch(Batch): token_type_ids: torch.Tensor position_ids: torch.Tensor attention_mask: torch.Tensor + # XProvence: raw text for context pruning + raw_query: str = None + raw_text: str = None @classmethod @tracer.start_as_current_span("from_pb") @@ -77,11 +80,17 @@ def from_pb( # Move padded tensors all at once all_tensors = all_tensors.to(device) + # XProvence: Extract raw text if present in proto + raw_query = pb.raw_query if hasattr(pb, 'raw_query') and pb.raw_query else None + raw_text = pb.raw_text if hasattr(pb, 'raw_text') and pb.raw_text else None + return PaddedBatch( input_ids=all_tensors[0], token_type_ids=all_tensors[1], position_ids=all_tensors[2], attention_mask=all_tensors[3], + raw_query=raw_query, + raw_text=raw_text, ) def __len__(self): diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py new file mode 100644 index 000000000..f4145a871 --- /dev/null +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -0,0 +1,176 @@ +import os +import torch + +from pathlib import Path +from typing import Type, List +from transformers import AutoModel +from opentelemetry import trace +from loguru import logger + +from text_embeddings_server.models.model import Model +from text_embeddings_server.models.types import PaddedBatch, Embedding, Score + +tracer = trace.get_tracer(__name__) + + +def _parse_bool(value: str) -> bool: + """Parse boolean from string with common conventions.""" + return str(value).lower() in ("true", "1", "t", "yes", "on") + + +class XProvenceModel(Model): + """ + XProvence: Zero-cost context pruning model for RAG. + + XProvence removes irrelevant sentences from passages based on relevance + to the query, returning both a reranking score and pruned context. + + Based on bge-reranker-v2-m3 (XLM-RoBERTa), supports 16+ languages. + + Environment Variables: + XPROVENCE_THRESHOLD (float): Pruning threshold between 0.0-1.0. + - 0.3 (default): Conservative pruning, minimal performance drop + - 0.7: Aggressive pruning, higher compression + XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep first sentence as title. 
+ - true (default): Always include first sentence (useful for Wikipedia) + - false: Only include sentences above threshold + """ + + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "cls", + trust_remote: bool = True, + ): + model = AutoModel.from_pretrained(model_path, trust_remote_code=True) + + if dtype == torch.bfloat16: + logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility") + dtype = torch.float32 + + model = model.to(dtype).to(device) + + self.hidden_size = model.config.hidden_size + + position_offset = 0 + model_type = model.config.model_type + if model_type in ["xlm-roberta", "camembert", "roberta"]: + position_offset = model.config.pad_token_id + 1 + + if hasattr(model.config, "max_seq_length"): + self.max_input_length = model.config.max_seq_length + else: + self.max_input_length = ( + model.config.max_position_embeddings - position_offset + ) + + try: + threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3") + self.threshold = float(threshold_env) + if not (0.0 <= self.threshold <= 1.0): + logger.warning( + f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], " + f"defaulting to 0.3" + ) + self.threshold = 0.3 + except ValueError: + logger.error( + f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3" + ) + self.threshold = 0.3 + + self.always_select_title = _parse_bool( + os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true") + ) + + logger.info( + f"XProvence model loaded: threshold={self.threshold}, " + f"always_select_title={self.always_select_title} " + f"(Configure via XPROVENCE_THRESHOLD, XPROVENCE_ALWAYS_SELECT_TITLE env vars)" + ) + + super(XProvenceModel, self).__init__(model=model, dtype=dtype, device=device) + + @property + def batch_type(self) -> Type[PaddedBatch]: + return PaddedBatch + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + pass + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + """ + XProvence prediction with context pruning support. + + For single-item batches with raw_query/raw_text available, + uses XProvence's process() method for sentence-level pruning. + Otherwise falls back to standard forward pass. + """ + batch_size = len(batch) + + if batch_size == 1 and batch.raw_query and batch.raw_text: + return self._predict_with_pruning(batch.raw_query, batch.raw_text) + + return self._predict_standard(batch) + + def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]: + """ + Use XProvence's process() method for context pruning. + + Returns score with pruned_text containing only relevant sentences. 
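+
+        Illustrative sketch only (shapes inferred from how the output is
+        consumed below; not an authoritative API description):
+
+            output = self.model.process(
+                raw_query, raw_text, threshold=0.3, always_select_title=True
+            )
+            score = float(output["reranking_score"])  # relevance score
+            pruned = output["pruned_context"]         # passage with irrelevant sentences removed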
+ """ + try: + os.environ["TQDM_DISABLE"] = "1" + + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float32) + + try: + output = self.model.process( + raw_query, + raw_text, + threshold=self.threshold, + always_select_title=self.always_select_title, + ) + finally: + torch.set_default_dtype(original_dtype) + + reranking_score = float(output["reranking_score"]) + pruned_context = output["pruned_context"] + + logger.debug( + f"XProvence pruning: score={reranking_score:.4f}, " + f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}" + ) + + return [Score(values=[reranking_score], pruned_text=pruned_context)] + + except Exception as e: + logger.error(f"XProvence process() failed: {e}, falling back to standard") + return [Score(values=[0.0], pruned_text=None)] + + def _predict_standard(self, batch: PaddedBatch) -> List[Score]: + kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} + + output = self.model(**kwargs, return_dict=True) + + if hasattr(output, "ranking_scores"): + scores_tensor = output.ranking_scores + elif hasattr(output, "logits"): + scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits + else: + scores_tensor = output[0] + + if scores_tensor.dim() == 0: + scores = [float(scores_tensor.item())] + else: + scores = scores_tensor.view(-1).tolist() + + if isinstance(scores, float): + scores = [scores] + + return [Score(values=[float(s)], pruned_text=None) for s in scores] diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 53255b07d..0c1d04684 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -5,7 +5,7 @@ use backend_grpc_client::Client; use nohash_hasher::BuildNoHashHasher; use std::collections::HashMap; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions, }; use tokio::runtime::Runtime; @@ -108,6 +108,11 @@ impl Backend for PythonBackend { )); } let batch_size = batch.len(); + + // XProvence: Get first raw query/text from batch (for single request) + let raw_query = batch.raw_queries.first().cloned().flatten(); + let raw_text = batch.raw_texts.first().cloned().flatten(); + let results = self .tokio_runtime .block_on(self.backend_client.clone().predict( @@ -116,15 +121,22 @@ impl Backend for PythonBackend { batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, + raw_query, + raw_text, )) .map_err(|err| BackendError::Inference(err.to_string()))?; - let raw_results: Vec> = results.into_iter().map(|r| r.values).collect(); let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - for (i, r) in raw_results.into_iter().enumerate() { - predictions.insert(i, r); + for (i, score) in results.into_iter().enumerate() { + predictions.insert( + i, + Prediction { + scores: score.values, + pruned_text: score.pruned_text, + }, + ); } Ok(predictions) diff --git a/core/src/infer.rs b/core/src/infer.rs index a2ff22c51..fb16eb15a 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -561,11 +561,13 @@ async fn backend_task(backend: Backend, mut embed_receiver: mpsc::Receiver, + /// XProvence: pruned context text after removing irrelevant sentences + pub pruned_text: Option, pub metadata: InferMetadata, } diff --git a/core/src/queue.rs b/core/src/queue.rs index 3fd8b7715..acc3409d4 100644 --- a/core/src/queue.rs +++ b/core/src/queue.rs @@ 
-129,6 +129,10 @@ fn queue_blocking_task( let mut cu_seq_lengths = Vec::with_capacity(capacity); cu_seq_lengths.push(0); + // XProvence: raw text vectors for context pruning + let mut raw_queries = Vec::with_capacity(capacity); + let mut raw_texts = Vec::with_capacity(capacity); + let mut current_tokens = 0; let mut max_length = 0; @@ -168,6 +172,10 @@ fn queue_blocking_task( token_type_ids.extend(entry.encoding.token_type_ids); position_ids.extend(entry.encoding.position_ids); + // XProvence: collect raw texts for context pruning + raw_queries.push(entry.encoding.raw_query); + raw_texts.push(entry.encoding.raw_text); + current_tokens += entry_tokens; metadata.push(entry.metadata); cu_seq_lengths.push(current_tokens as u32); @@ -193,6 +201,8 @@ fn queue_blocking_task( max_length, pooled_indices, raw_indices, + raw_queries, + raw_texts, }, )) }; diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index 3639b9845..f42ceb352 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -374,6 +374,12 @@ fn encode_input( prompts: Option<&HashMap>, tokenizer: &mut Tokenizer, ) -> Result { + // XProvence: Extract raw query and text before tokenization (for Dual inputs) + let (raw_query, raw_text) = match &inputs { + EncodingInput::Dual(query, text) => (Some(query.clone()), Some(text.clone())), + _ => (None, None), + }; + // Default truncation params let truncate_params = truncate.then_some(TruncationParams { direction: truncation_direction, @@ -406,6 +412,8 @@ fn encode_input( token_type_ids: encoding.get_type_ids().to_vec(), position_ids: (position_offset as u32..(seq_len + position_offset) as u32) .collect::>(), + raw_query, + raw_text, }) } @@ -414,6 +422,10 @@ pub struct ValidEncoding { pub input_ids: Vec, pub token_type_ids: Vec, pub position_ids: Vec, + /// XProvence: raw query text for context pruning (from Dual input) + pub raw_query: Option, + /// XProvence: raw context text for context pruning (from Dual input) + pub raw_text: Option, } #[derive(Debug)] diff --git a/router/src/http/server.rs b/router/src/http/server.rs index a22af9628..1cb57a165 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -361,13 +361,16 @@ async fn rerank( .map_err(ErrorResponse::from)?; let score = response.results[0]; + // XProvence: extract pruned_text from response + let pruned_text = response.pruned_text; - Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>(( + Ok::<(usize, Duration, Duration, Duration, f32, Option), ErrorResponse>(( response.metadata.prompt_tokens, response.metadata.tokenization, response.metadata.queue, response.metadata.inference, score, + pruned_text, )) }; @@ -410,7 +413,7 @@ async fn rerank( let results = join_all(futures) .await .into_iter() - .collect::, ErrorResponse>>()?; + .collect::)>, ErrorResponse>>()?; let mut ranks = Vec::with_capacity(batch_size); let mut total_tokenization_time = 0; @@ -430,6 +433,9 @@ async fn rerank( }; let score = r.4; + // XProvence: extract pruned_text from result + let pruned_text = r.5; + // Check that s is not NaN or the partial_cmp below will panic if score.is_nan() { Err(ErrorResponse { @@ -438,7 +444,7 @@ async fn rerank( })?; } - ranks.push(Rank { index, text, score }) + ranks.push(Rank { index, text, score, pruned_text }) } // Reverse sort diff --git a/router/src/http/types.rs b/router/src/http/types.rs index dedaab60a..ce9994b22 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -266,6 +266,10 @@ pub(crate) struct Rank { pub text: Option, #[schema(example = 
"1.0")] pub score: f32, + /// XProvence: pruned context with irrelevant sentences removed + #[schema(nullable = true, default = "null")] + #[serde(skip_serializing_if = "Option::is_none")] + pub pruned_text: Option, } #[derive(Serialize, ToSchema)] From 443a227187fae5040f571b265b6599dabd8f8705 Mon Sep 17 00:00:00 2001 From: Sigrid Jin Date: Fri, 5 Dec 2025 10:20:29 +0000 Subject: [PATCH 02/16] feat: add XProvence context pruning support for RAG Add XProvence model integration for zero-cost context pruning in reranking. XProvence removes irrelevant sentences from passages based on query relevance, returning both reranking scores and pruned context. Changes: - Add XProvenceModel class with process() method for sentence-level pruning - Add pruned_text field to Score/Prediction types and HTTP response - Pass raw_query/raw_text through tokenization pipeline for pruning - Make flash_attn imports optional for XProvence compatibility - Add XProvence architecture detection in router and Python backend - Handle bfloat16 to float32 conversion for XProvence process() method - Update candle, ort backends to support Prediction with pruned_text - Add Dockerfile-cuda-python for Python backend with CUDA support Configuration: - XPROVENCE_THRESHOLD: Pruning threshold 0.0-1.0 (default: 0.3) - XPROVENCE_ALWAYS_SELECT_TITLE: Keep first sentence as title (default: true) Usage: XPROVENCE_THRESHOLD=0.3 text-embeddings-router \ --model-id naver/xprovence-reranker-bgem3-v1 --port 8080 Docker build: docker build -f Dockerfile-cuda-python -t tei-python-cuda . --- Dockerfile | 158 +++++++++++++++-------------- backends/candle/src/lib.rs | 7 +- backends/grpc-client/src/client.rs | 2 + backends/ort/src/lib.rs | 7 +- backends/src/lib.rs | 6 ++ router/src/lib.rs | 3 +- 6 files changed, 104 insertions(+), 79 deletions(-) diff --git a/Dockerfile b/Dockerfile index e4a01b249..fbddbf631 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,35 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef -WORKDIR /usr/src +# Dockerfile for TEI with Python backend and CUDA support +# Supports: L40s (sm_89), RTX 3090 (sm_86) + +# ============================================================================= +# Stage 1: Rust Builder +# ============================================================================= +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS rust-builder ENV SCCACHE=0.10.0 ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" +ENV CARGO_CHEF=0.1.71 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + libssl-dev \ + pkg-config \ + protobuf-compiler \ + && rm -rf /var/lib/apt/lists/* -# Donwload, configure sccache RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ chmod +x /usr/local/bin/sccache -FROM chef AS planner +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --version $CARGO_CHEF --locked + +# ============================================================================= +# Stage 2: Recipe Planner +# ============================================================================= +FROM rust-builder AS planner + +WORKDIR /usr/src COPY backends backends COPY core core @@ -16,34 +37,21 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ -RUN cargo chef prepare --recipe-path recipe.json +RUN cargo chef 
prepare --recipe-path recipe.json -FROM chef AS builder +# ============================================================================= +# Stage 3: Dependency Builder +# ============================================================================= +FROM rust-builder AS builder ARG GIT_SHA ARG DOCKER_LABEL -# sccache specific variables -ARG SCCACHE_GHA_ENABLED - -RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ - | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ - echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \ - tee /etc/apt/sources.list.d/oneAPI.list - -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - intel-oneapi-mkl-devel=2024.0.0-49656 \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \ - gcc -shared -fPIC -o libfakeintel.so fakeintel.c +WORKDIR /usr/src COPY --from=planner /usr/src/recipe.json recipe.json -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s +RUN cargo chef cook --release --features python --features http --recipe-path recipe.json && sccache -s COPY backends backends COPY core core @@ -51,73 +59,75 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ -FROM builder AS http-builder +RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s +# ============================================================================= +# Stage 4: Python Environment +# ============================================================================= +FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS python-builder -FROM builder AS grpc-builder +ENV DEBIAN_FRONTEND=noninteractive -RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ - unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ - unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.10 \ + python3.10-dev \ + python3-pip \ + git \ + && rm -rf /var/lib/apt/lists/* -COPY proto proto +RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ + ln -sf /usr/bin/python3.10 /usr/bin/python3 -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s +RUN pip install --no-cache-dir --upgrade pip setuptools wheel -FROM debian:bookworm-slim AS base +WORKDIR /opt/server -ENV HUGGINGFACE_HUB_CACHE=/data \ - PORT=80 \ - MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \ - RAYON_NUM_THREADS=8 \ - LD_PRELOAD=/usr/local/libfakeintel.so \ - LD_LIBRARY_PATH=/usr/local/lib +COPY backends/proto /opt/proto +COPY backends/python/server 
/opt/server -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - libomp-dev \ - ca-certificates \ - libssl-dev \ - curl \ - && rm -rf /var/lib/apt/lists/* +RUN pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir && \ + mkdir -p text_embeddings_server/pb && \ + python -m grpc_tools.protoc -I/opt/proto --python_out=text_embeddings_server/pb \ + --grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb /opt/proto/embed.proto && \ + find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; && \ + touch text_embeddings_server/pb/__init__.py -# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch... -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2 -COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so +RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124 -FROM base AS grpc +RUN pip install --no-cache-dir -r requirements.txt -COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +RUN pip install --no-cache-dir -e . 
-ENTRYPOINT ["text-embeddings-router"] -CMD ["--json-output"] +# ============================================================================= +# Stage 5: Final Image +# ============================================================================= +FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV HUGGINGFACE_HUB_CACHE=/data +ENV PORT=80 +ENV TQDM_DISABLE=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.10 \ + python3-pip \ + ca-certificates \ + libssl-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* -FROM base AS http +RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ + ln -sf /usr/bin/python3.10 /usr/bin/python3 -COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages +COPY --from=python-builder /opt/server /opt/server -# Amazon SageMaker compatible image -FROM http AS sagemaker -COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh +COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router -ENTRYPOINT ["./entrypoint.sh"] +ENV PATH="/usr/local/bin:${PATH}" +ENV PYTHONPATH="/opt/server:${PYTHONPATH}" -# Default image -FROM http +WORKDIR /opt/server ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index ff824f555..0d5fa97fc 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -14,7 +14,7 @@ use serde::{de::Deserializer, Deserialize}; use std::collections::HashMap; use std::path::Path; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Prediction, Predictions, }; #[cfg(feature = "cuda")] @@ -653,7 +653,10 @@ impl Backend for CandleBackend { let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); for (i, r) in results.into_iter().enumerate() { - predictions.insert(i, r); + predictions.insert(i, Prediction { + scores: r, + pruned_text: None, + }); } Ok(predictions) diff --git a/backends/grpc-client/src/client.rs b/backends/grpc-client/src/client.rs index 6c3968614..33f75da6e 100644 --- a/backends/grpc-client/src/client.rs +++ b/backends/grpc-client/src/client.rs @@ -59,6 +59,8 @@ impl Client { position_ids, max_length, cu_seq_lengths, + raw_query: None, + raw_text: None, }) .inject_context(); let response = self.stub.embed(request).await?.into_inner(); diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index bfc2d03ad..4f84d4f79 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -8,7 +8,7 @@ use std::ops::{Div, Mul}; use std::path::Path; use std::sync::Mutex; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions, }; #[derive(Debug, Clone, Deserialize)] @@ -679,7 +679,10 @@ impl Backend for OrtBackend { let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); for (i, r) in outputs.rows().into_iter().enumerate() { - predictions.insert(i, r.to_vec()); + predictions.insert(i, Prediction { + scores: r.to_vec(), + pruned_text: None, + }); } Ok(predictions) diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 
245715b38..79bc05d29 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -223,6 +223,8 @@ impl Backend { max_length: tmp_length, pooled_indices, raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], } } @@ -280,6 +282,8 @@ impl Backend { max_length, pooled_indices, raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], }; match &self.model_type { @@ -314,6 +318,8 @@ impl Backend { max_length: 1, pooled_indices: vec![0], raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], }; match &self.model_type { ModelType::Classifier => self.predict(batch).await.map(|_| ()), diff --git a/router/src/lib.rs b/router/src/lib.rs index d83bd95c5..9c5eb98f4 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -396,7 +396,8 @@ fn get_backend_model_type( return Ok(text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, )); - } else if arch.ends_with("Classification") { + } else if arch.ends_with("Classification") || arch == "XProvence" { + // XProvence is a reranker model for context pruning if pooling.is_some() { tracing::warn!( "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg." From 7f37c49f3d75d6813da12f28f2309feef72f0944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigrid=20Jin=20=28=E0=B8=87=27=CC=80-=27=CC=81=29=E0=B8=87?= =?UTF-8?q?=20oO?= Date: Sun, 7 Dec 2025 12:40:07 +0900 Subject: [PATCH 03/16] Create revision --- revision | 1 + 1 file changed, 1 insertion(+) create mode 100644 revision diff --git a/revision b/revision new file mode 100644 index 000000000..d00491fd7 --- /dev/null +++ b/revision @@ -0,0 +1 @@ +1 From 1287a94c29803cffd931d680fbc1fb6355377ce9 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sun, 7 Dec 2025 12:55:50 +0900 Subject: [PATCH 04/16] Remove unused assets --- assets/bs1-lat.png | 3 --- assets/bs1-tp.png | 3 --- assets/bs32-lat.png | 3 --- assets/bs32-tp.png | 3 --- 4 files changed, 12 deletions(-) delete mode 100644 assets/bs1-lat.png delete mode 100644 assets/bs1-tp.png delete mode 100644 assets/bs32-lat.png delete mode 100644 assets/bs32-tp.png diff --git a/assets/bs1-lat.png b/assets/bs1-lat.png deleted file mode 100644 index 6105ddcc9..000000000 --- a/assets/bs1-lat.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:778b29d7d21382004fef2c528973f66bb175951ab7cd168d588cd245e36bd629 -size 15202 diff --git a/assets/bs1-tp.png b/assets/bs1-tp.png deleted file mode 100644 index 953ff0b68..000000000 --- a/assets/bs1-tp.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:478984ace4f33044bc0a53b0503a0cbfcd0a64f601922e2a13cc34d52c2b7c2b -size 17169 diff --git a/assets/bs32-lat.png b/assets/bs32-lat.png deleted file mode 100644 index ed352e40f..000000000 --- a/assets/bs32-lat.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:769326aad7e582a2e5271dd2d73c3bb5289684add10eb7146ddadd00d3b2077f -size 17596 diff --git a/assets/bs32-tp.png b/assets/bs32-tp.png deleted file mode 100644 index c952bd285..000000000 --- a/assets/bs32-tp.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c227c5adbb8664af7aa3d59aaa408557b2865dcfbd3c6c6353caf71f2eb5b7bc -size 18521 From 367a696f7110fc33cc2c65fa9d560a2e1f26423b Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sun, 7 Dec 2025 12:55:50 +0900 Subject: [PATCH 05/16] Remove unused assets --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 
fbddbf631..cec7291c8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -96,7 +96,7 @@ RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch RUN pip install --no-cache-dir -r requirements.txt -RUN pip install --no-cache-dir -e . +RUN pip install --no-cache-dir . # ============================================================================= # Stage 5: Final Image @@ -120,6 +120,7 @@ RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ ln -sf /usr/bin/python3.10 /usr/bin/python3 COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages +COPY --from=python-builder /usr/local/bin/python-text-embeddings-server /usr/local/bin/python-text-embeddings-server COPY --from=python-builder /opt/server /opt/server COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router From 795cf5270bbd690cc2abe5cd49d2cd0a5a45af1c Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Sun, 7 Dec 2025 23:57:39 +0900 Subject: [PATCH 06/16] fix: snapshots --- .../models/xprovence_model.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index f4145a871..c782045c1 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Type, List from transformers import AutoModel +from huggingface_hub import snapshot_download from opentelemetry import trace from loguru import logger @@ -44,6 +45,25 @@ def __init__( pool: str = "cls", trust_remote: bool = True, ): + # Download all model files including custom Python files for trust_remote_code + # The Rust router only downloads config/tokenizer/weights, but not custom modeling files + model_path_str = str(model_path) + if model_path_str.startswith("/data/models--"): + # Extract model_id from HF cache path format: /data/models--org--name/... 
+ # Convert "models--naver--xprovence-reranker-bgem3-v1" to "naver/xprovence-reranker-bgem3-v1" + parts = model_path_str.split("/") + for part in parts: + if part.startswith("models--"): + model_id = part.replace("models--", "").replace("--", "/", 1) + logger.info(f"XProvence: Downloading custom files for {model_id}") + cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE", "/data") + snapshot_download( + repo_id=model_id, + cache_dir=cache_dir, + local_files_only=False, + ) + break + model = AutoModel.from_pretrained(model_path, trust_remote_code=True) if dtype == torch.bfloat16: From 1b0aba04f9b14c74376e7a5f2acd008899a9a084 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 00:45:26 +0900 Subject: [PATCH 07/16] feat: add spacy dependency for XProvence sentence tokenization --- Dockerfile | 7 +++++++ backends/python/server/requirements.txt | 1 + 2 files changed, 8 insertions(+) diff --git a/Dockerfile b/Dockerfile index cec7291c8..1b543ccc0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -128,6 +128,13 @@ COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bi ENV PATH="/usr/local/bin:${PATH}" ENV PYTHONPATH="/opt/server:${PYTHONPATH}" +# Download spacy model in final image (ensures it's available at runtime) +# This is needed because spacy models may not be fully copied from builder stage +RUN pip install --no-cache-dir spacy>=3.7.0 && \ + python -m spacy download xx_sent_ud_sm && \ + python -c "import spacy; spacy.load('xx_sent_ud_sm')" && \ + echo "Spacy model verified successfully" + WORKDIR /opt/server ENTRYPOINT ["text-embeddings-router"] diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt index 687ec1028..b893f569c 100644 --- a/backends/python/server/requirements.txt +++ b/backends/python/server/requirements.txt @@ -52,6 +52,7 @@ safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13" +spacy>=3.7.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==75.6.0 ; python_version >= "3.9" and python_version < "3.13" sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13" From 7ff382c4b7fbe4d5782856e447d777f8ae8a664b Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 01:17:50 +0900 Subject: [PATCH 08/16] fix: load XProvenceConfig before model to avoid config class mismatch --- .../server/text_embeddings_server/models/xprovence_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index c782045c1..23c6e3eff 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Type, List -from transformers import AutoModel +from transformers import AutoModel, AutoConfig from huggingface_hub import snapshot_download from opentelemetry import trace from loguru import logger @@ -64,7 +64,9 @@ def __init__( ) break - model = AutoModel.from_pretrained(model_path, trust_remote_code=True) + # Load config first with trust_remote_code to get the correct 
XProvenceConfig + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True) if dtype == torch.bfloat16: logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility") From cda8d7981c2ad109cf639d276af99c2f6a4491fd Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 01:50:32 +0900 Subject: [PATCH 09/16] fix: use model_id directly to avoid XProvenceConfig class mismatch The previous fix (7ff382c) incorrectly passed config from AutoConfig.from_pretrained to AutoModel.from_pretrained. Since XProvence's config.json lacks auto_map for AutoConfig, it returned XLMRobertaConfig while the model expected XProvenceConfig. New approach: - Extract model_id from cache path (e.g., naver/xprovence-reranker-bgem3-v1) - Use model_id directly with AutoModel.from_pretrained(model_id, trust_remote_code=True) - Let AutoModel handle config internally via model class's config_class attribute - Remove explicit config passing and snapshot_download (AutoModel handles downloads) --- .../models/xprovence_model.py | 67 ++++++++++++------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index 23c6e3eff..5d4ab6122 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -2,9 +2,8 @@ import torch from pathlib import Path -from typing import Type, List -from transformers import AutoModel, AutoConfig -from huggingface_hub import snapshot_download +from typing import Type, List, Optional +from transformers import AutoModel from opentelemetry import trace from loguru import logger @@ -19,6 +18,23 @@ def _parse_bool(value: str) -> bool: return str(value).lower() in ("true", "1", "t", "yes", "on") +def _extract_model_id(model_path_str: str) -> Optional[str]: + """Extract model_id from HF cache path format. + + Converts paths like '/data/models--naver--xprovence-reranker-bgem3-v1/snapshots/...' + to 'naver/xprovence-reranker-bgem3-v1' + """ + if "/models--" not in model_path_str: + return None + + parts = model_path_str.split("/") + for part in parts: + if part.startswith("models--"): + # models--naver--xprovence-reranker-bgem3-v1 -> naver/xprovence-reranker-bgem3-v1 + return part.replace("models--", "").replace("--", "/", 1) + return None + + class XProvenceModel(Model): """ XProvence: Zero-cost context pruning model for RAG. @@ -45,28 +61,31 @@ def __init__( pool: str = "cls", trust_remote: bool = True, ): - # Download all model files including custom Python files for trust_remote_code - # The Rust router only downloads config/tokenizer/weights, but not custom modeling files model_path_str = str(model_path) - if model_path_str.startswith("/data/models--"): - # Extract model_id from HF cache path format: /data/models--org--name/... 
- # Convert "models--naver--xprovence-reranker-bgem3-v1" to "naver/xprovence-reranker-bgem3-v1" - parts = model_path_str.split("/") - for part in parts: - if part.startswith("models--"): - model_id = part.replace("models--", "").replace("--", "/", 1) - logger.info(f"XProvence: Downloading custom files for {model_id}") - cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE", "/data") - snapshot_download( - repo_id=model_id, - cache_dir=cache_dir, - local_files_only=False, - ) - break - - # Load config first with trust_remote_code to get the correct XProvenceConfig - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True) + cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE", "/data") + + # Extract model_id from cache path for proper trust_remote_code handling + model_id = _extract_model_id(model_path_str) + + if model_id: + # Use model_id directly with AutoModel.from_pretrained + # This ensures: + # 1. All custom Python files (modeling_*.py) are downloaded + # 2. The correct XProvenceConfig is loaded via model class's config_class attribute + # 3. No config class mismatch (unlike passing config from AutoConfig.from_pretrained) + logger.info(f"XProvence: Loading {model_id} with trust_remote_code=True") + model = AutoModel.from_pretrained( + model_id, + trust_remote_code=True, + cache_dir=cache_dir, + ) + else: + # Fallback for local paths not in HF cache format + logger.info(f"XProvence: Loading from local path {model_path}") + model = AutoModel.from_pretrained( + model_path, + trust_remote_code=True, + ) if dtype == torch.bfloat16: logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility") From 7b6363422a607197bea0479e7db17704a88e6735 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 02:23:18 +0900 Subject: [PATCH 10/16] fix: check XProvence before AutoConfig to prevent registry pollution The previous fix still failed because __init__.py called AutoConfig.from_pretrained before XProvenceModel was created. This polluted transformers' internal config registry with XLMRobertaConfig, causing conflicts when XProvenceModel tried to load the custom XProvenceConfig. Solution: - Add _is_xprovence_model() helper that reads config.json directly - Check for XProvence BEFORE calling AutoConfig.from_pretrained - This prevents transformers from caching the wrong config class --- .../text_embeddings_server/models/__init__.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 55f93c0bf..ac6fd0211 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -1,4 +1,5 @@ import os +import json import torch from loguru import logger @@ -14,6 +15,25 @@ from text_embeddings_server.models.xprovence_model import XProvenceModel from text_embeddings_server.utils.device import get_device, use_ipex + +def _is_xprovence_model(model_path: Path) -> bool: + """Check if model is XProvence by reading config.json directly. + + This avoids calling AutoConfig.from_pretrained which can pollute + transformers' internal registry and cause config class conflicts. 
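+
+    Illustrative config.json fragment this check is meant to match (the
+    field value is hypothetical):
+
+        {"architectures": ["XProvence"]}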
+ """ + config_path = model_path / "config.json" + if not config_path.exists(): + return False + + try: + with open(config_path, "r") as f: + config = json.load(f) + architectures = config.get("architectures", []) + return any("XProvence" in arch for arch in architectures) + except Exception: + return False + FlashJinaBert = None FlashMistral = None FlashQwen3 = None @@ -81,16 +101,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): device = get_device() logger.info(f"backend device: {device}") - config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) - - if ( - hasattr(config, "architectures") - and config.architectures - and "XProvence" in config.architectures[0] - ): + # Check for XProvence BEFORE calling AutoConfig.from_pretrained + # to avoid polluting transformers' internal config registry + if _is_xprovence_model(model_path): logger.info("Detected XProvence model for context pruning") return XProvenceModel(model_path, device, datatype, trust_remote=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) + if ( FlashJinaBert is not None and hasattr(config, "auto_map") From 09e8491d7624598ba4d975eced4d3d9cc49bff1c Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 02:50:05 +0900 Subject: [PATCH 11/16] fix: use get_class_from_dynamic_module to avoid config mismatch AutoModel.from_pretrained internally calls AutoConfig which returns XLMRobertaConfig, causing a conflict with the model's XProvenceConfig. Solution: Use transformers.dynamic_module_utils.get_class_from_dynamic_module to directly import the custom XProvenceForSequenceClassification class, then call from_pretrained on the custom class which uses its own config_class. --- .../models/xprovence_model.py | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index 5d4ab6122..42f44deb5 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -3,7 +3,8 @@ from pathlib import Path from typing import Type, List, Optional -from transformers import AutoModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from huggingface_hub import hf_hub_download from opentelemetry import trace from loguru import logger @@ -68,21 +69,32 @@ def __init__( model_id = _extract_model_id(model_path_str) if model_id: - # Use model_id directly with AutoModel.from_pretrained - # This ensures: - # 1. All custom Python files (modeling_*.py) are downloaded - # 2. The correct XProvenceConfig is loaded via model class's config_class attribute - # 3. 
No config class mismatch (unlike passing config from AutoConfig.from_pretrained) - logger.info(f"XProvence: Loading {model_id} with trust_remote_code=True") - model = AutoModel.from_pretrained( + # Directly import the custom model class to avoid AutoModel's config class mismatch + # AutoModel.from_pretrained internally loads config which causes XLMRobertaConfig + # to be registered, conflicting with the model's expected XProvenceConfig + logger.info(f"XProvence: Loading custom model class for {model_id}") + + # Get the custom model class directly from the dynamic module + model_class = get_class_from_dynamic_module( + "modeling_xprovence_hf.XProvenceForSequenceClassification", + model_id, + cache_dir=cache_dir, + ) + + # Load using the custom class directly - this uses the correct config_class + model = model_class.from_pretrained( model_id, trust_remote_code=True, cache_dir=cache_dir, ) else: - # Fallback for local paths not in HF cache format + # Fallback for local paths - try to import from local path logger.info(f"XProvence: Loading from local path {model_path}") - model = AutoModel.from_pretrained( + model_class = get_class_from_dynamic_module( + "modeling_xprovence_hf.XProvenceForSequenceClassification", + model_path, + ) + model = model_class.from_pretrained( model_path, trust_remote_code=True, ) From 8016c4fe9896ae2c368c4547d4643ab06b786a83 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 03:24:15 +0900 Subject: [PATCH 12/16] fix: correct XProvence class name (XProvence, not XProvenceForSequenceClassification) --- .../server/text_embeddings_server/models/xprovence_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index 42f44deb5..407c1a423 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -76,7 +76,7 @@ def __init__( # Get the custom model class directly from the dynamic module model_class = get_class_from_dynamic_module( - "modeling_xprovence_hf.XProvenceForSequenceClassification", + "modeling_xprovence_hf.XProvence", model_id, cache_dir=cache_dir, ) @@ -91,7 +91,7 @@ def __init__( # Fallback for local paths - try to import from local path logger.info(f"XProvence: Loading from local path {model_path}") model_class = get_class_from_dynamic_module( - "modeling_xprovence_hf.XProvenceForSequenceClassification", + "modeling_xprovence_hf.XProvence", model_path, ) model = model_class.from_pretrained( From 6765247933562c2e0275c1dbab5abae9677a39e4 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 04:14:31 +0900 Subject: [PATCH 13/16] debug: add logging to XProvence predict to diagnose raw_query/raw_text --- .../text_embeddings_server/models/xprovence_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index 407c1a423..94ba34216 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -165,9 +165,19 @@ def predict(self, batch: PaddedBatch) -> List[Score]: """ batch_size = len(batch) + # Debug: log raw_query/raw_text availability + has_query = batch.raw_query is not None + has_text = batch.raw_text is not 
None + logger.info( + f"XProvence predict: batch_size={batch_size}, " + f"has_raw_query={has_query}, has_raw_text={has_text}" + ) + if batch_size == 1 and batch.raw_query and batch.raw_text: + logger.info("XProvence: Using process() for context pruning") return self._predict_with_pruning(batch.raw_query, batch.raw_text) + logger.info("XProvence: Using standard forward pass (no raw_query/raw_text)") return self._predict_standard(batch) def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]: From cc0b4e57140089e5d99a5cc5ec7bc8b740c0e5f6 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 04:16:40 +0900 Subject: [PATCH 14/16] fix: use HasField for proto3 optional fields in types.py --- .../server/text_embeddings_server/models/types.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py index f4da0da8e..ca014a46b 100644 --- a/backends/python/server/text_embeddings_server/models/types.py +++ b/backends/python/server/text_embeddings_server/models/types.py @@ -81,8 +81,18 @@ def from_pb( all_tensors = all_tensors.to(device) # XProvence: Extract raw text if present in proto - raw_query = pb.raw_query if hasattr(pb, 'raw_query') and pb.raw_query else None - raw_text = pb.raw_text if hasattr(pb, 'raw_text') and pb.raw_text else None + # Use HasField for proto3 optional fields to properly detect if they were set + raw_query = None + raw_text = None + if hasattr(pb, 'HasField'): + if pb.HasField('raw_query'): + raw_query = pb.raw_query + if pb.HasField('raw_text'): + raw_text = pb.raw_text + else: + # Fallback for older proto versions + raw_query = pb.raw_query if pb.raw_query else None + raw_text = pb.raw_text if pb.raw_text else None return PaddedBatch( input_ids=all_tensors[0], From 778caf7c7e2e34fa0eee8d0a2161f4b5c1128807 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 05:01:36 +0900 Subject: [PATCH 15/16] feat: support batch processing with pruned_text for multiple texts Previously, only the first raw_query/raw_text was sent to Python backend, so process() was only called when batch_size == 1. Now all pairs are sent. Changes: - embed.proto: change to repeated string raw_queries/raw_texts - grpc-client: accept Vec instead of Option - backends/python/src/lib.rs: send all raw_queries/texts from batch - types.py: extract lists from proto repeated fields - xprovence_model.py: iterate batch and call process() for each pair Now /rerank with multiple texts returns pruned_text for each result. 
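
Illustrative request, not part of this patch: a minimal sketch assuming a
router on localhost:8080 with an XProvence reranker loaded, and that /rerank
responds with the list of rank objects built in router/src/http/server.rs.

    import requests

    resp = requests.post(
        "http://localhost:8080/rerank",
        json={
            "query": "What is the capital of France?",
            "texts": [
                "Paris is the capital of France. The Eiffel Tower is in Paris.",
                "Berlin is the capital of Germany.",
            ],
        },
    )
    for rank in resp.json():
        print(rank["index"], rank["score"], rank.get("pruned_text"))

With this change, each rank may include pruned_text in addition to index and
score when pruning succeeds.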
--- backends/grpc-client/src/client.rs | 12 +++--- backends/proto/embed.proto | 8 ++-- .../text_embeddings_server/models/types.py | 29 +++++-------- .../models/xprovence_model.py | 42 ++++++++++++++----- backends/python/src/lib.rs | 18 +++++--- 5 files changed, 64 insertions(+), 45 deletions(-) diff --git a/backends/grpc-client/src/client.rs b/backends/grpc-client/src/client.rs index 33f75da6e..a5872642f 100644 --- a/backends/grpc-client/src/client.rs +++ b/backends/grpc-client/src/client.rs @@ -59,8 +59,8 @@ impl Client { position_ids, max_length, cu_seq_lengths, - raw_query: None, - raw_text: None, + raw_queries: vec![], + raw_texts: vec![], }) .inject_context(); let response = self.stub.embed(request).await?.into_inner(); @@ -75,8 +75,8 @@ impl Client { position_ids: Vec, cu_seq_lengths: Vec, max_length: u32, - raw_query: Option, - raw_text: Option, + raw_queries: Vec, + raw_texts: Vec, ) -> Result> { let request = tonic::Request::new(EmbedRequest { input_ids, @@ -84,8 +84,8 @@ impl Client { position_ids, max_length, cu_seq_lengths, - raw_query, - raw_text, + raw_queries, + raw_texts, }) .inject_context(); let response = self.stub.predict(request).await?.into_inner(); diff --git a/backends/proto/embed.proto b/backends/proto/embed.proto index 55df0889f..e233902d0 100644 --- a/backends/proto/embed.proto +++ b/backends/proto/embed.proto @@ -21,10 +21,10 @@ message EmbedRequest { repeated uint32 cu_seq_lengths = 4; /// Length of the longest request uint32 max_length = 5; - /// XProvence: raw query text for context pruning - optional string raw_query = 6; - /// XProvence: raw context text for context pruning - optional string raw_text = 7; + /// XProvence: raw query texts for context pruning (one per batch item) + repeated string raw_queries = 6; + /// XProvence: raw context texts for context pruning (one per batch item) + repeated string raw_texts = 7; } message Embedding { diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py index ca014a46b..92eb5b2ee 100644 --- a/backends/python/server/text_embeddings_server/models/types.py +++ b/backends/python/server/text_embeddings_server/models/types.py @@ -3,7 +3,8 @@ import torch from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import List, Optional from opentelemetry import trace from text_embeddings_server.pb import embed_pb2 @@ -36,9 +37,9 @@ class PaddedBatch(Batch): token_type_ids: torch.Tensor position_ids: torch.Tensor attention_mask: torch.Tensor - # XProvence: raw text for context pruning - raw_query: str = None - raw_text: str = None + # XProvence: raw texts for context pruning (one per batch item) + raw_queries: Optional[List[str]] = None + raw_texts: Optional[List[str]] = None @classmethod @tracer.start_as_current_span("from_pb") @@ -80,27 +81,17 @@ def from_pb( # Move padded tensors all at once all_tensors = all_tensors.to(device) - # XProvence: Extract raw text if present in proto - # Use HasField for proto3 optional fields to properly detect if they were set - raw_query = None - raw_text = None - if hasattr(pb, 'HasField'): - if pb.HasField('raw_query'): - raw_query = pb.raw_query - if pb.HasField('raw_text'): - raw_text = pb.raw_text - else: - # Fallback for older proto versions - raw_query = pb.raw_query if pb.raw_query else None - raw_text = pb.raw_text if pb.raw_text else None + # XProvence: Extract repeated raw_queries/raw_texts from proto + raw_queries = 
list(pb.raw_queries) if pb.raw_queries else None + raw_texts = list(pb.raw_texts) if pb.raw_texts else None return PaddedBatch( input_ids=all_tensors[0], token_type_ids=all_tensors[1], position_ids=all_tensors[2], attention_mask=all_tensors[3], - raw_query=raw_query, - raw_text=raw_text, + raw_queries=raw_queries, + raw_texts=raw_texts, ) def __len__(self): diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index 94ba34216..d59f5a0ba 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -159,25 +159,45 @@ def predict(self, batch: PaddedBatch) -> List[Score]: """ XProvence prediction with context pruning support. - For single-item batches with raw_query/raw_text available, - uses XProvence's process() method for sentence-level pruning. + For batches with raw_queries/raw_texts available (one per item), + uses XProvence's process() method for sentence-level pruning on each pair. Otherwise falls back to standard forward pass. """ batch_size = len(batch) - # Debug: log raw_query/raw_text availability - has_query = batch.raw_query is not None - has_text = batch.raw_text is not None + # Check if we have raw data for the full batch + has_raw_data = ( + batch.raw_queries is not None + and batch.raw_texts is not None + and len(batch.raw_queries) == batch_size + and len(batch.raw_texts) == batch_size + ) + logger.info( f"XProvence predict: batch_size={batch_size}, " - f"has_raw_query={has_query}, has_raw_text={has_text}" + f"has_raw_queries={batch.raw_queries is not None}, " + f"has_raw_texts={batch.raw_texts is not None}, " + f"has_full_raw_data={has_raw_data}" ) - if batch_size == 1 and batch.raw_query and batch.raw_text: - logger.info("XProvence: Using process() for context pruning") - return self._predict_with_pruning(batch.raw_query, batch.raw_text) - - logger.info("XProvence: Using standard forward pass (no raw_query/raw_text)") + if has_raw_data: + logger.info(f"XProvence: Processing batch of {batch_size} with pruning") + results = [] + for i in range(batch_size): + query = batch.raw_queries[i] + text = batch.raw_texts[i] + + # Verify we have valid strings (not empty) + if query and text: + scores = self._predict_with_pruning(query, text) + results.append(scores[0]) + else: + # Empty string fallback - use standard forward pass result + logger.warning(f"XProvence: Empty query/text at index {i}, using 0.0") + results.append(Score(values=[0.0], pruned_text=None)) + return results + + logger.info("XProvence: Using standard forward pass (no raw_queries/raw_texts)") return self._predict_standard(batch) def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]: diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 0c1d04684..331391a99 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -109,9 +109,17 @@ impl Backend for PythonBackend { } let batch_size = batch.len(); - // XProvence: Get first raw query/text from batch (for single request) - let raw_query = batch.raw_queries.first().cloned().flatten(); - let raw_text = batch.raw_texts.first().cloned().flatten(); + // XProvence: Collect all raw queries/texts for the batch (one per item) + let raw_queries: Vec = batch + .raw_queries + .into_iter() + .map(|q| q.unwrap_or_default()) + .collect(); + let raw_texts: Vec = batch + .raw_texts + .into_iter() + .map(|t| 
t.unwrap_or_default()) + .collect(); let results = self .tokio_runtime @@ -121,8 +129,8 @@ impl Backend for PythonBackend { batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, - raw_query, - raw_text, + raw_queries, + raw_texts, )) .map_err(|err| BackendError::Inference(err.to_string()))?; From 42654a49f53fbaef1d574b2f20ad2158a7c76747 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Dec 2025 05:30:14 +0900 Subject: [PATCH 16/16] fix: optimize XProvence batch processing with broadcasting and warnings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add broadcasting support: 1 query → N texts (common reranking pattern) - Replace silent fallback with explicit warning on dimension mismatch - Use torch.inference_mode() around entire batch for better performance - Reduce per-item overhead by batching dtype handling and TQDM_DISABLE - Add per-item error handling with graceful fallback to 0.0 score Performance improvements: - Single dtype context switch instead of per-item - Single inference_mode context for entire batch - Reduced logging overhead with debug level for per-item details --- .../models/xprovence_model.py | 108 ++++++++++++------ 1 file changed, 76 insertions(+), 32 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py index d59f5a0ba..fb4856432 100644 --- a/backends/python/server/text_embeddings_server/models/xprovence_model.py +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -164,41 +164,85 @@ def predict(self, batch: PaddedBatch) -> List[Score]: Otherwise falls back to standard forward pass. """ batch_size = len(batch) + raw_queries = batch.raw_queries or [] + raw_texts = batch.raw_texts or [] - # Check if we have raw data for the full batch - has_raw_data = ( - batch.raw_queries is not None - and batch.raw_texts is not None - and len(batch.raw_queries) == batch_size - and len(batch.raw_texts) == batch_size - ) + # Broadcasting: 1 query → N texts (common reranking pattern) + if len(raw_queries) == 1 and len(raw_texts) == batch_size and batch_size > 1: + logger.info(f"XProvence: Broadcasting single query to {batch_size} texts") + raw_queries = raw_queries * batch_size - logger.info( - f"XProvence predict: batch_size={batch_size}, " - f"has_raw_queries={batch.raw_queries is not None}, " - f"has_raw_texts={batch.raw_texts is not None}, " - f"has_full_raw_data={has_raw_data}" - ) + # Check for dimension mismatch with explicit warning + if len(raw_queries) != batch_size or len(raw_texts) != batch_size: + if raw_queries or raw_texts: + logger.warning( + f"XProvence: Dimension mismatch - batch_size={batch_size}, " + f"raw_queries={len(raw_queries)}, raw_texts={len(raw_texts)}. " + f"Falling back to standard inference (no pruned_text)." + ) + return self._predict_standard(batch) + + # Process batch with pruning (optimized) + logger.info(f"XProvence: Processing {batch_size} pairs with pruning") + return self._predict_batch_with_pruning(raw_queries, raw_texts) + + def _predict_batch_with_pruning( + self, raw_queries: List[str], raw_texts: List[str] + ) -> List[Score]: + """ + Optimized batch processing with pruning. + + Uses inference_mode and batched dtype handling to reduce per-item overhead. + Note: XProvence process() is inherently per-pair for sentence-level analysis. 
+ """ + batch_size = len(raw_queries) + results = [] + + # Suppress progress bars once for entire batch + os.environ["TQDM_DISABLE"] = "1" + + # Use inference_mode for better performance (no grad tracking) + with torch.inference_mode(): + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float32) + + try: + for i in range(batch_size): + query = raw_queries[i] + text = raw_texts[i] + + if not query or not text: + logger.warning( + f"XProvence: Empty query/text at index {i}, score=0.0" + ) + results.append(Score(values=[0.0], pruned_text=None)) + continue + + try: + output = self.model.process( + query, + text, + threshold=self.threshold, + always_select_title=self.always_select_title, + ) + + score = float(output["reranking_score"]) + pruned = output["pruned_context"] + + logger.debug( + f"XProvence [{i}]: score={score:.4f}, " + f"len={len(text)}→{len(pruned)}" + ) + results.append(Score(values=[score], pruned_text=pruned)) + + except Exception as e: + logger.error(f"XProvence process() failed at index {i}: {e}") + results.append(Score(values=[0.0], pruned_text=None)) + + finally: + torch.set_default_dtype(original_dtype) - if has_raw_data: - logger.info(f"XProvence: Processing batch of {batch_size} with pruning") - results = [] - for i in range(batch_size): - query = batch.raw_queries[i] - text = batch.raw_texts[i] - - # Verify we have valid strings (not empty) - if query and text: - scores = self._predict_with_pruning(query, text) - results.append(scores[0]) - else: - # Empty string fallback - use standard forward pass result - logger.warning(f"XProvence: Empty query/text at index {i}, using 0.0") - results.append(Score(values=[0.0], pruned_text=None)) - return results - - logger.info("XProvence: Using standard forward pass (no raw_queries/raw_texts)") - return self._predict_standard(batch) + return results def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]: """