158 changes: 84 additions & 74 deletions Dockerfile
@@ -1,123 +1,133 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
WORKDIR /usr/src
# Dockerfile for TEI with Python backend and CUDA support
# Supports: L40s (sm_89), RTX 3090 (sm_86)

# =============================================================================
# Stage 1: Rust Builder
# =============================================================================
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS rust-builder

ENV SCCACHE=0.10.0
ENV RUSTC_WRAPPER=/usr/local/bin/sccache
ENV PATH="/root/.cargo/bin:${PATH}"
ENV CARGO_CHEF=0.1.71

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
libssl-dev \
pkg-config \
protobuf-compiler \
&& rm -rf /var/lib/apt/lists/*

# Download and configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
RUN cargo install cargo-chef --version $CARGO_CHEF --locked

# =============================================================================
# Stage 2: Recipe Planner
# =============================================================================
FROM rust-builder AS planner

WORKDIR /usr/src

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder
# =============================================================================
# Stage 3: Dependency Builder
# =============================================================================
FROM rust-builder AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG SCCACHE_GHA_ENABLED

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
intel-oneapi-mkl-devel=2024.0.0-49656 \
build-essential \
&& rm -rf /var/lib/apt/lists/*

RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
gcc -shared -fPIC -o libfakeintel.so fakeintel.c
WORKDIR /usr/src

COPY --from=planner /usr/src/recipe.json recipe.json

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s
RUN cargo chef cook --release --features python --features http --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

FROM builder AS http-builder
RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s
# =============================================================================
# Stage 4: Python Environment
# =============================================================================
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS python-builder

FROM builder AS grpc-builder
ENV DEBIAN_FRONTEND=noninteractive

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.10 \
python3.10-dev \
python3-pip \
git \
&& rm -rf /var/lib/apt/lists/*

COPY proto proto
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
ln -sf /usr/bin/python3.10 /usr/bin/python3

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s
RUN pip install --no-cache-dir --upgrade pip setuptools wheel

FROM debian:bookworm-slim AS base
WORKDIR /opt/server

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80 \
MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
RAYON_NUM_THREADS=8 \
LD_PRELOAD=/usr/local/libfakeintel.so \
LD_LIBRARY_PATH=/usr/local/lib
COPY backends/proto /opt/proto
COPY backends/python/server /opt/server

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libomp-dev \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*
RUN pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir && \
mkdir -p text_embeddings_server/pb && \
python -m grpc_tools.protoc -I/opt/proto --python_out=text_embeddings_server/pb \
--grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb /opt/proto/embed.proto && \
find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; && \
touch text_embeddings_server/pb/__init__.py

# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so
RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124

FROM base AS grpc
RUN pip install --no-cache-dir -r requirements.txt

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
RUN pip install --no-cache-dir -e .

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
# =============================================================================
# Stage 5: Final Image
# =============================================================================
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV HUGGINGFACE_HUB_CACHE=/data
ENV PORT=80
ENV TQDM_DISABLE=1

RUN apt-get update && apt-get install -y --no-install-recommends \
python3.10 \
python3-pip \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*

FROM base AS http
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
ln -sf /usr/bin/python3.10 /usr/bin/python3

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages
COPY --from=python-builder /opt/server /opt/server

# Amazon SageMaker compatible image
FROM http AS sagemaker
COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh
COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["./entrypoint.sh"]
ENV PATH="/usr/local/bin:${PATH}"
ENV PYTHONPATH="/opt/server:${PYTHONPATH}"

# Default image
FROM http
WORKDIR /opt/server

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
7 changes: 5 additions & 2 deletions backends/candle/src/lib.rs
@@ -14,7 +14,7 @@ use serde::{de::Deserializer, Deserialize};
use std::collections::HashMap;
use std::path::Path;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Prediction, Predictions,
};

#[cfg(feature = "cuda")]
@@ -653,7 +653,10 @@ impl Backend for CandleBackend
let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
for (i, r) in results.into_iter().enumerate() {
predictions.insert(i, r);
predictions.insert(i, Prediction {
scores: r,
pruned_text: None,
});
}

Ok(predictions)
15 changes: 14 additions & 1 deletion backends/core/src/lib.rs
@@ -14,6 +14,10 @@ pub struct Batch {
pub max_length: u32,
pub pooled_indices: Vec<u32>,
pub raw_indices: Vec<u32>,
/// XProvence: raw query texts for context pruning
pub raw_queries: Vec<Option<String>>,
/// XProvence: raw context texts for context pruning
pub raw_texts: Vec<Option<String>>,
}

impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
}

pub type Embeddings = IntMap<usize, Embedding>;
pub type Predictions = IntMap<usize, Vec<f32>>;

/// XProvence: Prediction result containing scores and optional pruned text
#[derive(Debug, Clone)]
pub struct Prediction {
pub scores: Vec<f32>,
/// XProvence: pruned context text after removing irrelevant sentences
pub pruned_text: Option<String>,
}

pub type Predictions = IntMap<usize, Prediction>;

pub trait Backend {
fn health(&self) -> Result<(), BackendError>;
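For orientation, a minimal sketch of how a backend could fill the new `Predictions` map once it also produces pruned text. Only `Prediction` and `Predictions` come from the change above; the `to_predictions` helper and its inputs are hypothetical.

```rust
use std::collections::HashMap;

use nohash_hasher::BuildNoHashHasher;
use text_embeddings_backend_core::{Prediction, Predictions};

// Hypothetical helper: pair per-request score rows with optional pruned
// contexts (e.g. from an XProvence-style model) and build the Predictions
// map the same way the candle/ort backends do above.
fn to_predictions(rows: Vec<Vec<f32>>, pruned: Vec<Option<String>>) -> Predictions {
    let mut predictions: Predictions =
        HashMap::with_capacity_and_hasher(rows.len(), BuildNoHashHasher::default());
    for (i, (scores, pruned_text)) in rows.into_iter().zip(pruned).enumerate() {
        predictions.insert(i, Prediction { scores, pruned_text });
    }
    predictions
}
```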
6 changes: 6 additions & 0 deletions backends/grpc-client/src/client.rs
@@ -59,6 +59,8 @@ impl Client {
position_ids,
max_length,
cu_seq_lengths,
raw_query: None,
raw_text: None,
})
.inject_context();
let response = self.stub.embed(request).await?.into_inner();
@@ -73,13 +75,17 @@
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
raw_query: Option<String>,
raw_text: Option<String>,
) -> Result<Vec<Score>> {
let request = tonic::Request::new(EmbedRequest {
input_ids,
token_type_ids,
position_ids,
max_length,
cu_seq_lengths,
raw_query,
raw_text,
})
.inject_context();
let response = self.stub.predict(request).await?.into_inner();
7 changes: 5 additions & 2 deletions backends/ort/src/lib.rs
@@ -8,7 +8,7 @@ use std::ops::{Div, Mul};
use std::path::Path;
use std::sync::Mutex;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
};

#[derive(Debug, Clone, Deserialize)]
@@ -679,7 +679,10 @@ impl Backend for OrtBackend {
let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
for (i, r) in outputs.rows().into_iter().enumerate() {
predictions.insert(i, r.to_vec());
predictions.insert(i, Prediction {
scores: r.to_vec(),
pruned_text: None,
});
}

Ok(predictions)
6 changes: 6 additions & 0 deletions backends/proto/embed.proto
@@ -21,6 +21,10 @@ message EmbedRequest {
repeated uint32 cu_seq_lengths = 4;
/// Length of the longest request
uint32 max_length = 5;
/// XProvence: raw query text for context pruning
optional string raw_query = 6;
/// XProvence: raw context text for context pruning
optional string raw_text = 7;
}

message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {

message Score {
repeated float values = 1;
/// XProvence: pruned context text after removing irrelevant sentences
optional string pruned_text = 2;
}

message PredictResponse {
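A minimal call-site sketch, assuming the extended `Client::predict` signature from `backends/grpc-client` above and the `Score.pruned_text` field added here. The crate path, the `Result` alias, and all input values are placeholders, not part of this change.

```rust
use text_embeddings_backend_client::{Client, Result}; // crate path and Result alias assumed

// Hypothetical call site: forward the raw query/context so an XProvence
// model can prune the passage, then read scores and pruned text back.
async fn predict_with_pruning(client: &mut Client) -> Result<()> {
    let input_ids: Vec<u32> = vec![101, 7592, 2088, 102]; // placeholder token ids
    let token_type_ids = vec![0u32; input_ids.len()];
    let position_ids: Vec<u32> = (0..input_ids.len() as u32).collect();
    let cu_seq_lengths = vec![0u32, input_ids.len() as u32];
    let max_length = input_ids.len() as u32;

    let scores = client
        .predict(
            input_ids,
            token_type_ids,
            position_ids,
            cu_seq_lengths,
            max_length,
            Some("what is the capital of France?".to_string()),            // raw_query
            Some("Paris is the capital of France. It rains.".to_string()), // raw_text
        )
        .await?;

    for score in scores {
        // `values` carries the relevance scores; `pruned_text` (when set)
        // is the context with irrelevant sentences removed.
        if let Some(pruned) = &score.pruned_text {
            println!("scores={:?} pruned_context={pruned}", score.values);
        }
    }
    Ok(())
}
```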
31 changes: 23 additions & 8 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -11,11 +11,19 @@
from text_embeddings_server.models.masked_model import MaskedLanguageModel
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
from text_embeddings_server.models.xprovence_model import XProvenceModel
from text_embeddings_server.utils.device import get_device, use_ipex

FlashJinaBert = None
FlashMistral = None
FlashQwen3 = None
try:
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
except ImportError as e:
logger.warning(f"Flash attention models not available: {e}")

__all__ = ["Model"]

TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
@@ -76,13 +84,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

if (
hasattr(config, "auto_map")
hasattr(config, "architectures")
and config.architectures
and "XProvence" in config.architectures[0]
):
logger.info("Detected XProvence model for context pruning")
return XProvenceModel(model_path, device, datatype, trust_remote=True)

if (
FlashJinaBert is not None
and hasattr(config, "auto_map")
and isinstance(config.auto_map, dict)
and "AutoModel" in config.auto_map
and config.auto_map["AutoModel"]
== "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
):
# Add specific offline modeling for "jinaai/jina-embeddings-v2-base-code", which uses "auto_map" to reference code in another repository
return create_model(FlashJinaBert, model_path, device, datatype)

if config.model_type == "bert":
@@ -116,19 +132,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
else:
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "mistral" and device.type == "hpu":
if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu":
try:
return create_model(FlashMistral, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "qwen3" and device.type == "hpu":
if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu":
try:
return create_model(FlashQwen3, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

# Default case
if config.architectures[0].endswith("Classification"):
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":