Commit 5be80a0

Author: Sigrid Jin (committed)
feat: add XProvence context pruning support for RAG
Add XProvence model integration for zero-cost context pruning in reranking.
XProvence removes irrelevant sentences from passages based on query relevance,
returning both reranking scores and pruned context.

Changes:
- Add XProvenceModel class with process() method for sentence-level pruning
- Add pruned_text field to Score/Prediction types and HTTP response
- Pass raw_query/raw_text through tokenization pipeline for pruning
- Make flash_attn imports optional for XProvence compatibility
- Add XProvence architecture detection in router and Python backend
- Handle bfloat16 to float32 conversion for XProvence process() method
- Update candle, ort backends to support Prediction with pruned_text
- Add Dockerfile-cuda-python for Python backend with CUDA support

Configuration:
- XPROVENCE_THRESHOLD: pruning threshold, 0.0-1.0 (default: 0.3)
- XPROVENCE_ALWAYS_SELECT_TITLE: keep first sentence as title (default: true)

Usage:

    XPROVENCE_THRESHOLD=0.3 text-embeddings-router \
        --model-id naver/xprovence-reranker-bgem3-v1 --port 8080

Docker build:

    docker build -f Dockerfile-cuda-python -t tei-python-cuda .
1 parent 78502d8 commit 5be80a0
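For reference, a minimal client-side sketch of what the added pruning output looks like over HTTP. The /rerank endpoint and request shape follow the existing TEI API; the `pruned_text` response field name comes from the commit description above, so treat the exact response shape as illustrative rather than authoritative.

```python
# Sketch: call a TEI instance built from this commit and read the
# XProvence-pruned context next to each reranking score.
import requests

resp = requests.post(
    "http://localhost:8080/rerank",
    json={
        "query": "What does XProvence do?",
        "texts": [
            "XProvence prunes irrelevant sentences from retrieved passages. "
            "The weather was nice yesterday. "
            "It also returns a reranking score for each passage."
        ],
    },
)
resp.raise_for_status()

for item in resp.json():
    # Expected: the usual relevance score plus the passage with sentences
    # below XPROVENCE_THRESHOLD removed (field name per the commit message).
    print(item["index"], item["score"])
    print(item.get("pruned_text"))
```

Raising XPROVENCE_THRESHOLD prunes more aggressively; lowering it keeps more of the original passage.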

17 files changed (+394 additions, -98 deletions)


Dockerfile

Lines changed: 84 additions & 74 deletions
@@ -1,123 +1,133 @@
-FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
-WORKDIR /usr/src
+# Dockerfile for TEI with Python backend and CUDA support
+# Supports: L40s (sm_89), RTX 3090 (sm_86)
+
+# =============================================================================
+# Stage 1: Rust Builder
+# =============================================================================
+FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS rust-builder

 ENV SCCACHE=0.10.0
 ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+ENV PATH="/root/.cargo/bin:${PATH}"
+ENV CARGO_CHEF=0.1.71
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    libssl-dev \
+    pkg-config \
+    protobuf-compiler \
+    && rm -rf /var/lib/apt/lists/*

-# Donwload, configure sccache
 RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
     chmod +x /usr/local/bin/sccache

-FROM chef AS planner
+RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
+RUN cargo install cargo-chef --version $CARGO_CHEF --locked
+
+# =============================================================================
+# Stage 2: Recipe Planner
+# =============================================================================
+FROM rust-builder AS planner
+
+WORKDIR /usr/src

 COPY backends backends
 COPY core core
 COPY router router
 COPY Cargo.toml ./
 COPY Cargo.lock ./

-RUN cargo chef prepare --recipe-path recipe.json
+RUN cargo chef prepare --recipe-path recipe.json

-FROM chef AS builder
+# =============================================================================
+# Stage 3: Dependency Builder
+# =============================================================================
+FROM rust-builder AS builder

 ARG GIT_SHA
 ARG DOCKER_LABEL

-# sccache specific variables
-ARG SCCACHE_GHA_ENABLED
-
-RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
-    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
-    echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
-    tee /etc/apt/sources.list.d/oneAPI.list
-
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    intel-oneapi-mkl-devel=2024.0.0-49656 \
-    build-essential \
-    && rm -rf /var/lib/apt/lists/*
-
-RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
-    gcc -shared -fPIC -o libfakeintel.so fakeintel.c
+WORKDIR /usr/src

 COPY --from=planner /usr/src/recipe.json recipe.json

-RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
-    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
-    cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s
+RUN cargo chef cook --release --features python --features http --recipe-path recipe.json && sccache -s

 COPY backends backends
 COPY core core
 COPY router router
 COPY Cargo.toml ./
 COPY Cargo.lock ./

-FROM builder AS http-builder
+RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

-RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
-    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
-    cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s
+# =============================================================================
+# Stage 4: Python Environment
+# =============================================================================
+FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS python-builder

-FROM builder AS grpc-builder
+ENV DEBIAN_FRONTEND=noninteractive

-RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
-    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
-    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
-    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
-    rm -f $PROTOC_ZIP
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.10 \
+    python3.10-dev \
+    python3-pip \
+    git \
+    && rm -rf /var/lib/apt/lists/*

-COPY proto proto
+RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
+    ln -sf /usr/bin/python3.10 /usr/bin/python3

-RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
-    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
-    cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s
+RUN pip install --no-cache-dir --upgrade pip setuptools wheel

-FROM debian:bookworm-slim AS base
+WORKDIR /opt/server

-ENV HUGGINGFACE_HUB_CACHE=/data \
-    PORT=80 \
-    MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
-    RAYON_NUM_THREADS=8 \
-    LD_PRELOAD=/usr/local/libfakeintel.so \
-    LD_LIBRARY_PATH=/usr/local/lib
+COPY backends/proto /opt/proto
+COPY backends/python/server /opt/server

-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    libomp-dev \
-    ca-certificates \
-    libssl-dev \
-    curl \
-    && rm -rf /var/lib/apt/lists/*
+RUN pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir && \
+    mkdir -p text_embeddings_server/pb && \
+    python -m grpc_tools.protoc -I/opt/proto --python_out=text_embeddings_server/pb \
+    --grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb /opt/proto/embed.proto && \
+    find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; && \
+    touch text_embeddings_server/pb/__init__.py

-# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
-COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so
+RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124

-FROM base AS grpc
+RUN pip install --no-cache-dir -r requirements.txt

-COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+RUN pip install --no-cache-dir -e .

-ENTRYPOINT ["text-embeddings-router"]
-CMD ["--json-output"]
+# =============================================================================
+# Stage 5: Final Image
+# =============================================================================
+FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV HUGGINGFACE_HUB_CACHE=/data
+ENV PORT=80
+ENV TQDM_DISABLE=1
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.10 \
+    python3-pip \
+    ca-certificates \
+    libssl-dev \
+    curl \
+    && rm -rf /var/lib/apt/lists/*

-FROM base AS http
+RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
+    ln -sf /usr/bin/python3.10 /usr/bin/python3

-COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages
+COPY --from=python-builder /opt/server /opt/server

-# Amazon SageMaker compatible image
-FROM http AS sagemaker
-COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh
+COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

-ENTRYPOINT ["./entrypoint.sh"]
+ENV PATH="/usr/local/bin:${PATH}"
+ENV PYTHONPATH="/opt/server:${PYTHONPATH}"

-# Default image
-FROM http
+WORKDIR /opt/server

 ENTRYPOINT ["text-embeddings-router"]
 CMD ["--json-output"]

backends/candle/src/lib.rs

Lines changed: 5 additions & 2 deletions
@@ -14,7 +14,7 @@ use serde::{de::Deserializer, Deserialize};
 use std::collections::HashMap;
 use std::path::Path;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Prediction, Predictions,
 };

 #[cfg(feature = "cuda")]
@@ -653,7 +653,10 @@ impl Backend for CandleBackend {
         let mut predictions =
             HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
         for (i, r) in results.into_iter().enumerate() {
-            predictions.insert(i, r);
+            predictions.insert(i, Prediction {
+                scores: r,
+                pruned_text: None,
+            });
         }

         Ok(predictions)

backends/core/src/lib.rs

Lines changed: 14 additions & 1 deletion
@@ -14,6 +14,10 @@ pub struct Batch {
     pub max_length: u32,
     pub pooled_indices: Vec<u32>,
     pub raw_indices: Vec<u32>,
+    /// XProvence: raw query texts for context pruning
+    pub raw_queries: Vec<Option<String>>,
+    /// XProvence: raw context texts for context pruning
+    pub raw_texts: Vec<Option<String>>,
 }

 impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
 }

 pub type Embeddings = IntMap<usize, Embedding>;
-pub type Predictions = IntMap<usize, Vec<f32>>;
+
+/// XProvence: Prediction result containing scores and optional pruned text
+#[derive(Debug, Clone)]
+pub struct Prediction {
+    pub scores: Vec<f32>,
+    /// XProvence: pruned context text after removing irrelevant sentences
+    pub pruned_text: Option<String>,
+}
+
+pub type Predictions = IntMap<usize, Prediction>;

 pub trait Backend {
     fn health(&self) -> Result<(), BackendError>;

backends/grpc-client/src/client.rs

Lines changed: 6 additions & 0 deletions
@@ -59,6 +59,8 @@ impl Client {
             position_ids,
             max_length,
             cu_seq_lengths,
+            raw_query: None,
+            raw_text: None,
         })
         .inject_context();
         let response = self.stub.embed(request).await?.into_inner();
@@ -73,13 +75,17 @@ impl Client {
         position_ids: Vec<u32>,
         cu_seq_lengths: Vec<u32>,
         max_length: u32,
+        raw_query: Option<String>,
+        raw_text: Option<String>,
     ) -> Result<Vec<Score>> {
         let request = tonic::Request::new(EmbedRequest {
             input_ids,
             token_type_ids,
             position_ids,
             max_length,
             cu_seq_lengths,
+            raw_query,
+            raw_text,
         })
         .inject_context();
         let response = self.stub.predict(request).await?.into_inner();

backends/ort/src/lib.rs

Lines changed: 5 additions & 2 deletions
@@ -8,7 +8,7 @@ use std::ops::{Div, Mul};
 use std::path::Path;
 use std::sync::Mutex;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
 };

 #[derive(Debug, Clone, Deserialize)]
@@ -679,7 +679,10 @@ impl Backend for OrtBackend {
         let mut predictions =
             HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
         for (i, r) in outputs.rows().into_iter().enumerate() {
-            predictions.insert(i, r.to_vec());
+            predictions.insert(i, Prediction {
+                scores: r.to_vec(),
+                pruned_text: None,
+            });
         }

         Ok(predictions)

backends/proto/embed.proto

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,10 @@ message EmbedRequest {
     repeated uint32 cu_seq_lengths = 4;
     /// Length of the longest request
     uint32 max_length = 5;
+    /// XProvence: raw query text for context pruning
+    optional string raw_query = 6;
+    /// XProvence: raw context text for context pruning
+    optional string raw_text = 7;
 }

 message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {

 message Score {
     repeated float values = 1;
+    /// XProvence: pruned context text after removing irrelevant sentences
+    optional string pruned_text = 2;
 }

 message PredictResponse {
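To illustrate how the new optional fields travel across the gRPC boundary, here is a small sketch against the generated Python stubs. The module path assumes the protoc step from Dockerfile-cuda-python, which compiles embed.proto into text_embeddings_server/pb; adjust if the stubs live elsewhere.

```python
# Sketch only: construct the new request/response fields via the generated stubs.
from text_embeddings_server.pb import embed_pb2

req = embed_pb2.EmbedRequest(
    input_ids=[101, 2054, 2003, 102],
    token_type_ids=[0, 0, 0, 0],
    position_ids=[0, 1, 2, 3],
    cu_seq_lengths=[0, 4],
    max_length=4,
    raw_query="what is xprovence?",        # new field 6
    raw_text="XProvence prunes context.",  # new field 7
)

score = embed_pb2.Score(values=[0.87], pruned_text="XProvence prunes context.")

# Because the fields are declared `optional`, presence is tracked explicitly,
# so callers can distinguish "no pruning performed" from an empty string.
assert req.HasField("raw_query")
assert score.HasField("pruned_text")
```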

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 23 additions & 8 deletions
@@ -11,11 +11,19 @@
 from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
-from text_embeddings_server.models.jinaBert_model import FlashJinaBert
-from text_embeddings_server.models.flash_mistral import FlashMistral
-from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+from text_embeddings_server.models.xprovence_model import XProvenceModel
 from text_embeddings_server.utils.device import get_device, use_ipex

+FlashJinaBert = None
+FlashMistral = None
+FlashQwen3 = None
+try:
+    from text_embeddings_server.models.jinaBert_model import FlashJinaBert
+    from text_embeddings_server.models.flash_mistral import FlashMistral
+    from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+except ImportError as e:
+    logger.warning(f"Flash attention models not available: {e}")
+
 __all__ = ["Model"]

 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
@@ -76,13 +84,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

     if (
-        hasattr(config, "auto_map")
+        hasattr(config, "architectures")
+        and config.architectures
+        and "XProvence" in config.architectures[0]
+    ):
+        logger.info("Detected XProvence model for context pruning")
+        return XProvenceModel(model_path, device, datatype, trust_remote=True)
+
+    if (
+        FlashJinaBert is not None
+        and hasattr(config, "auto_map")
         and isinstance(config.auto_map, dict)
         and "AutoModel" in config.auto_map
         and config.auto_map["AutoModel"]
         == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
     ):
-        # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
         return create_model(FlashJinaBert, model_path, device, datatype)

     if config.model_type == "bert":
@@ -116,19 +132,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     else:
         return create_model(DefaultModel, model_path, device, datatype, pool)

-    if config.model_type == "mistral" and device.type == "hpu":
+    if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu":
         try:
             return create_model(FlashMistral, model_path, device, datatype, pool)
         except FileNotFoundError:
             return create_model(DefaultModel, model_path, device, datatype, pool)

-    if config.model_type == "qwen3" and device.type == "hpu":
+    if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu":
         try:
             return create_model(FlashQwen3, model_path, device, datatype, pool)
         except FileNotFoundError:
             return create_model(DefaultModel, model_path, device, datatype, pool)

-    # Default case
     if config.architectures[0].endswith("Classification"):
         return create_model(ClassificationModel, model_path, device, datatype)
     elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
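The XProvenceModel added by this commit wraps the checkpoint's own process() method for sentence-level pruning. Below is a hedged standalone sketch of that call, based only on the commit notes (threshold, title handling, bfloat16-to-float32 conversion); the argument and return-key names follow Provence-style rerankers and may differ for naver/xprovence-reranker-bgem3-v1.

```python
# Sketch: run the pruning call directly, outside the TEI server.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "naver/xprovence-reranker-bgem3-v1",
    trust_remote_code=True,
    torch_dtype=torch.float32,  # process() is run in float32, not bfloat16
)

out = model.process(
    "What does XProvence do?",
    "XProvence prunes irrelevant sentences. Cats sleep a lot. "
    "It also returns a reranking score.",
    threshold=0.3,             # maps to XPROVENCE_THRESHOLD
    always_select_title=True,  # maps to XPROVENCE_ALWAYS_SELECT_TITLE
)
# Expected (Provence-style) keys: a relevance score and the pruned passage.
print(out["reranking_score"])
print(out["pruned_context"])
```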
