diff --git a/.github/scripts/free-up-disk-space-fast.sh b/.github/scripts/free-up-disk-space-fast.sh index 2e8ddef7..ec3fcdd5 100755 --- a/.github/scripts/free-up-disk-space-fast.sh +++ b/.github/scripts/free-up-disk-space-fast.sh @@ -1,5 +1,7 @@ #!/bin/sh +# Disk space cleanup # https://dev.to/mathio/squeezing-disk-space-from-github-actions-runners-an-engineers-guide-3pjg + # Remove Java (JDKs) sudo rm -rf /usr/lib/jvm diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index efea555f..d08ef8ee 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -59,22 +59,12 @@ jobs: - name: verify feature consistency run: .github/scripts/verify-feature-consistency.sh - - name: install ollama - run: | - # Disk space cleanup # https://dev.to/mathio/squeezing-disk-space-from-github-actions-runners-an-engineers-guide-3pjg - ./.github/scripts/free-up-disk-space-fast.sh & - - if [ "$RUNNER_OS" = "Linux" ]; then - curl -fsSL https://ollama.com/install.sh | bash - # note: install.sh will start ollama as an systemd service, no need to start it ourselves (actively harmful -- port conflict) - elif [ "$RUNNER_OS" = "macOS" ]; then - brew update - brew install ollama - brew services start ollama - fi - wait # for disk space cleanup + - name: free up disk space + run: ./.github/scripts/free-up-disk-space-fast.sh - name: cargo test + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | if [ "$RUNNER_OS" = "macOS" ]; then # IMPORTANT: Cannot use --all-features because mistralrs (used by local/metal/cuda features) @@ -82,10 +72,10 @@ jobs: # # When adding new features to spnl/Cargo.toml, update this list! # Current features tested (from spnl/Cargo.toml): - # - Included: cli_support, print, lisp, run, ollama, openai, gemini, pull, yaml, metal, rag, rag-deep-debug, spnl-api, vllm, k8s, gce, ffi, pypi, run_py, tok, openssl-vendored + # - Included: cli_support, print, lisp, run, ollama, openai, gemini, yaml, metal, rag, rag-deep-debug, spnl-api, vllm, k8s, gce, ffi, pypi, run_py, tok, openssl-vendored # - Excluded: local (CPU inference - pulls in CUDA deps via mistralrs) # - Excluded: cuda, cuda-flash-attn, cuda-flash-attn-v3 (CUDA features) - cargo test -p spnl --features cli_support,print,lisp,run,ollama,openai,gemini,pull,yaml,metal,rag,rag-deep-debug,spnl-api,vllm,k8s,gce,ffi,pypi,run_py,tok,openssl-vendored -- --nocapture + cargo test -p spnl --features cli_support,print,lisp,run,ollama,openai,gemini,yaml,metal,rag,rag-deep-debug,spnl-api,vllm,k8s,gce,ffi,pypi,run_py,tok,openssl-vendored -- --nocapture cargo test -p spnl-cli --features rag,spnl-api,vllm,k8s,gce,local,metal -- --nocapture else # Test default features on Linux (no GPU features) @@ -101,7 +91,7 @@ jobs: # When adding new features, update the feature list above in cargo test! if: runner.os == 'macOS' run: | - cargo clippy -p spnl --features cli_support,print,lisp,run,ollama,openai,gemini,pull,yaml,metal,rag,rag-deep-debug,spnl-api,vllm,k8s,gce,ffi,pypi,run_py,tok,openssl-vendored --tests --no-deps -- -D warnings + cargo clippy -p spnl --features cli_support,print,lisp,run,ollama,openai,gemini,yaml,metal,rag,rag-deep-debug,spnl-api,vllm,k8s,gce,ffi,pypi,run_py,tok,openssl-vendored --tests --no-deps -- -D warnings cargo clippy -p spnl-cli --features rag,spnl-api,vllm,k8s,gce,local,metal --tests --no-deps -- -D warnings - name: rustfmt diff --git a/.github/workflows/release-cli.yml b/.github/workflows/release-cli.yml index 9cf6878d..423ed89d 100644 --- a/.github/workflows/release-cli.yml +++ b/.github/workflows/release-cli.yml @@ -73,15 +73,15 @@ jobs: local_feature: local name: Windows x86_64 - - runner: windows-latest - target: aarch64-pc-windows-msvc - platform: windows - arch: aarch64 - libc: "" - # local: very slow to build on aarch64 windows (60+ minutes); disabled for now. maybe we can restore if we figure out rustc caching - #local_feature: local - #rustflags: "-Ctarget-feature=+fp16,+fhm" # https://github.com/sarah-quinones/gemm/issues/31 avoid "error: instruction requires: fullfp16" - name: Windows ARM64 + # TODO. the local feature is very slow to build on aarch64 windows (60+ minutes); disabled for now. maybe we can restore if we figure out rustc caching + # - runner: windows-latest + # target: aarch64-pc-windows-msvc + # platform: windows + # arch: aarch64 + # libc: "" + # local_feature: local + # rustflags: "-Ctarget-feature=+fp16,+fhm" # https://github.com/sarah-quinones/gemm/issues/31 avoid "error: instruction requires: fullfp16" + # name: Windows ARM64 env: TARGET: ${{ matrix.platform.target }} diff --git a/Cargo.lock b/Cargo.lock index 16547833..28b9a517 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10370,7 +10370,6 @@ dependencies = [ "tokenizers 0.22.2", "tokio", "tokio-stream", - "tokio-util", "tower-test", "tracing", "uuid 1.21.0", diff --git a/cli/src/args.rs b/cli/src/args.rs index 69b23fb7..63ffaf78 100644 --- a/cli/src/args.rs +++ b/cli/src/args.rs @@ -43,19 +43,14 @@ pub struct Args { pub builtin: Option, /// Generative Model - #[arg( - short, - long, - default_value = "ollama/granite3.3:2b", - env = "SPNL_MODEL" - )] + #[arg(short, long, default_value = "llama3.2:1b", env = "SPNL_MODEL")] pub model: String, /// Embedding Model #[arg( short, long, - default_value = "ollama/mxbai-embed-large:335m", + default_value = "local/google/embeddinggemma-300m", env = "SPNL_EMBEDDING_MODEL" )] pub embedding_model: String, diff --git a/docker/Containerfile.hostbuild b/docker/Containerfile.hostbuild deleted file mode 100644 index 0a147272..00000000 --- a/docker/Containerfile.hostbuild +++ /dev/null @@ -1,14 +0,0 @@ -FROM debian:stable-slim - -ARG TARGET=target/release/spnl -COPY $TARGET /usr/local/bin/spnl -RUN groupadd -g 1001 spnl && useradd -u 1001 -g spnl spnl && mkdir /home/spnl && chown spnl /home/spnl -USER spnl - -LABEL org.opencontainers.image.source=https://github.com/IBM/spnl -LABEL org.opencontainers.image.description="Span Query CLI" -LABEL org.opencontainers.image.licenses="Apache-2.0" - -ENV SPNL_BUILTIN=email2 - -ENTRYPOINT ["spnl"] diff --git a/docker/Containerfile.hostbuild.ollama b/docker/Containerfile.hostbuild.ollama deleted file mode 100644 index 757973af..00000000 --- a/docker/Containerfile.hostbuild.ollama +++ /dev/null @@ -1,30 +0,0 @@ -FROM debian:stable-slim - -LABEL org.opencontainers.image.source=https://github.com/IBM/spnl -LABEL org.opencontainers.image.description="Span Query CLI with Ollama" -LABEL org.opencontainers.image.licenses="Apache-2.0" - -ARG TARGET=target/release/spnl -ARG DEBIAN_FRONTEND=noninteractive - -RUN apt update && \ - apt install -y ca-certificates curl lshw zstd && \ - curl -fsSL https://ollama.com/install.sh | sh && \ - apt remove -y ca-certificates curl && apt autoremove -y && \ - rm -rf /var/lib/apt/lists/* - -RUN printf '#!/bin/sh\nollama serve > /dev/null 2> /dev/null & spnl $@\n' > /usr/local/bin/spnl.sh && chmod a+rx /usr/local/bin/spnl.sh -RUN groupadd -g 1001 spnl && useradd -u 1001 -g spnl spnl && mkdir /home/spnl && chown spnl /home/spnl - -USER spnl - -ARG SPNL_MODEL_BASE=smollm:135m -RUN ollama serve & sleep 5 ; ollama pull $SPNL_MODEL_BASE - -ENV SPNL_BUILTIN=email2 -ENV SPNL_MODEL=ollama/$SPNL_MODEL_BASE -ENV OLLAMA_NUM_PARALLEL=4 - -COPY $TARGET /usr/local/bin/spnl - -ENTRYPOINT ["spnl.sh"] diff --git a/docker/gce/vllm/create-vllm-gce-image.sh b/docker/gce/vllm/create-vllm-gce-image.sh index 8da56de4..f2822998 100755 --- a/docker/gce/vllm/create-vllm-gce-image.sh +++ b/docker/gce/vllm/create-vllm-gce-image.sh @@ -143,7 +143,7 @@ if [[ -f /etc/environment ]]; then fi echo "=== Disabling unnecessary services ===" -# Disable services not needed for vLLM/ollama +# Disable services not needed for vLLM sudo systemctl disable snapd.service || true sudo systemctl disable snapd.socket || true sudo systemctl disable unattended-upgrades.service || true @@ -176,9 +176,6 @@ uv venv --seed source .venv/bin/activate VLLM_USE_PRECOMPILED=1 uv pip install --editable . -echo "=== Installing ollama ===" -curl -fsSL https://ollama.com/install.sh | sh - echo "=== Creating systemd service for vLLM ===" # Create directory for vLLM configuration sudo mkdir -p /etc/vllm @@ -219,10 +216,6 @@ StandardError=journal WantedBy=multi-user.target VLLM_SERVICE_EOF -echo "=== Creating systemd service for Ollama ===" -# Create Ollama systemd service (ollama install.sh already creates one, but we ensure it's enabled) -sudo systemctl enable ollama.service - echo "=== Enabling services to start at boot ===" sudo systemctl enable vllm.service diff --git a/docker/gce/vllm/setup-dev.sh b/docker/gce/vllm/setup-dev.sh index 742b9feb..ce37f196 100644 --- a/docker/gce/vllm/setup-dev.sh +++ b/docker/gce/vllm/setup-dev.sh @@ -137,22 +137,16 @@ VLLM_ATTENTION_BACKEND=TRITON_ATTN \ VLLM_SERVER_DEV_MODE=1 \ nohup vllm serve $MODEL --enforce-eager & -# Install ollama (for embedding) -(curl -fsSL https://ollama.com/install.sh | sh && ollama serve) & - # Wait till vllm is ready timeout 5m bash -c 'until curl --output /dev/null --silent --fail http://localhost:8000/health; do sleep 3; done' echo "vllm is ready" -# Wait till ollama is ready -timeout 5m bash -c 'until ollama ps; do sleep 3; done' -echo "ollama is ready" - # Run tests # Here are the variables we will allow to be used in the test.d/* scripts declare -x GCS_BUCKET declare -x RUN_ID declare -x MODEL +declare -x HF_TOKEN declare -x OPENAI_API_BASE=http://localhost:8000/v1 cd $HOME diff --git a/docker/gce/vllm/test.d/spnl-speedup.sh b/docker/gce/vllm/test.d/spnl-speedup.sh index 0a66312b..6bb22316 100755 --- a/docker/gce/vllm/test.d/spnl-speedup.sh +++ b/docker/gce/vllm/test.d/spnl-speedup.sh @@ -8,8 +8,6 @@ SCRIPTDIR=$(cd $(dirname "$0") && pwd) #set -x # debug -export SPNL_EMBEDDING_MODEL=ollama/qwen3-embedding:0.6b - # TODO: make at least the inner-most loop bound a parameter rather than hard-coded for b in email2 rag do diff --git a/docs/feature-flags.md b/docs/feature-flags.md index f7c949c3..670b7e7e 100644 --- a/docs/feature-flags.md +++ b/docs/feature-flags.md @@ -13,30 +13,78 @@ messages from a filesystem or from stdin. Or you may wish to have your server side also support fetching message content from a filesystem. The choice is yours. -- **rag**: This allows a span query to express that a given message - should be augmented with fragments from a given set of - documents. The query process, with this feature flag enabled, - handles the fragmentation, indexing, etc. - -- **run**: This allows for execution of a query. Without this flag - enabled, the compiled code will only be able to parse - -- **ollama**: This allows the query execution to direct `g` (generate) - at a local Ollama model server. - -- **openai**: This allows the query execution to direct `g` (generate) - at an OpenAI compatible model server. By default, this will talk to - `http://localhost:8000`, but this can be changed via the - `OPENAI_BASE_URL` environment variable. - -- **pull**: This allows the query execution to pull down Ollama models - specified in a query. - -- **tok**: This adds an API for both parsing and then tokenizing the - messages in a query. - -- **python_bindings**: This adds python bindings to the span query - APIs (currently only the tokenization APIs are supported). - -- **lisp**: A highly experimental effort to allow for [static - compilation](./lisp) of a query into a shrinkwrapped executable. +## Core Features + +- **run**: Enables execution of span queries. Without this flag enabled, the compiled code will only be able to parse queries. + +- **print**: Enables printing/display functionality for span queries. + +- **cli_support**: Enables CLI-specific support features. Depends on `print` and includes pretty-printing with `ptree`. + +## Model Backend Features + +- **ollama**: Enables support for directing `g` (generate) operations to a local Ollama model server. Depends on `openai` for API compatibility. + +- **openai**: Enables support for directing `g` (generate) operations to an OpenAI-compatible model server. By default, this will talk to `http://localhost:8000`, but this can be changed via the `OPENAI_BASE_URL` environment variable. + +- **gemini**: Enables support for Google's Gemini API. Depends on `openai` for API compatibility. + +- **local**: Enables local model inference using mistral.rs. Supports running models directly on your machine without external API calls. Depends on `run` and includes mistralrs, tokio, and related dependencies. + +- **metal**: Enables Metal GPU acceleration for local inference on macOS. Depends on `local` and enables mistralrs Metal backend. + +- **cuda**: Enables CUDA GPU acceleration for local inference on NVIDIA GPUs. Depends on `local` and enables mistralrs CUDA backend. + +- **cuda-flash-attn**: Enables Flash Attention optimization for CUDA. Depends on `cuda`. + +- **cuda-flash-attn-v3**: Enables Flash Attention v3 optimization for CUDA. Depends on `cuda` and uses mistralrs-core directly. + +## RAG (Retrieval-Augmented Generation) Features + +- **rag**: Enables RAG capabilities, allowing span queries to augment messages with fragments from a given set of documents. The query process handles fragmentation, indexing, embedding, and retrieval. Depends on `run` and includes LanceDB, PDF extraction, and vector operations. + +- **rag-deep-debug**: Enables deep debugging output for RAG operations. + +## Language & Format Features + +- **lisp**: A highly experimental effort to allow for [static compilation](./lisp) of a query into a shrinkwrapped executable. + +- **yaml**: Enables YAML parsing and serialization support. + +## Tokenization & Python Features + +- **tok**: Adds an API for both parsing and then tokenizing the messages in a query. + +- **ffi**: Enables Foreign Function Interface support for calling spnl from other languages. + +- **pypi**: Enables Python bindings for spnl. Depends on `ffi` and `tok`. Includes PyO3 for Python interop. + +- **run_py**: Enables running span queries from Python with async support. Depends on `run`, `pypi`, and model backends. + +## Cloud & Infrastructure Features + +- **spnl-api**: Enables the spnl API client for communicating with spnl services. + +- **vllm**: Enables support for vLLM model serving. Depends on `yaml`. + +- **k8s**: Enables Kubernetes integration for deploying and managing vLLM instances. Depends on `vllm` and includes kube client libraries. + +- **gce**: Enables Google Compute Engine integration for deploying and managing vLLM instances. Depends on `vllm` and includes GCP client libraries. + +## Utility Features + +- **openssl-vendored**: Uses a vendored (statically linked) version of OpenSSL instead of the system version. Useful for portable builds. + +## Default Features + +The following features are enabled by default: +- `cli_support` +- `lisp` +- `run` +- `ollama` +- `openai` +- `gemini` +- `yaml` +- `local` + +This provides a full-featured experience with CLI support, multiple model backends (local and remote), and common format support. diff --git a/spnl/Cargo.toml b/spnl/Cargo.toml index 2bd65ca5..13e29703 100644 --- a/spnl/Cargo.toml +++ b/spnl/Cargo.toml @@ -14,20 +14,19 @@ include = [ ] [features] -default = ["cli_support","lisp","run","ollama","openai","gemini","pull","yaml"] +default = ["cli_support","lisp","run","ollama","openai","gemini","yaml","local"] cli_support = ["print","dep:ptree"] print = [] lisp = ["dep:serde-lexpr"] ollama = ["openai"] openai = ["dep:async-openai","dep:tokio"] gemini = ["openai"] -local = ["run","dep:mistralrs","dep:tokio","dep:uuid","dep:indexmap","dep:either"] +local = ["run","dep:mistralrs","dep:tokio","dep:uuid","dep:indexmap","dep:either","dep:reqwest"] metal = ["local","mistralrs/metal"] cuda = ["local","mistralrs/cuda"] cuda-flash-attn = ["cuda","mistralrs/flash-attn"] cuda-flash-attn-v3 = ["cuda","dep:mistralrs-core","mistralrs-core/flash-attn-v3"] openssl-vendored = ["dep:openssl"] -pull = ["dep:reqwest","dep:tokio-util"] ffi = [] pypi = ["ffi","tok","dep:pyo3","pyo3/extension-module","dep:thiserror"] rag = ["run","dep:sha2", "dep:lancedb","dep:tracing","dep:arrow-schema","dep:arrow-array","dep:itertools","dep:pdf-extract","dep:async-recursion","dep:regex","dep:rand"] @@ -71,7 +70,6 @@ futures = { version = "0.3.31", optional = true } indicatif = { version = "0.18.0", optional = true } tokio = { version = "1.44.1", features = ["io-std", "io-util", "signal"], optional = true } tokio-stream = { version = "0.1.18", features = ["net"], optional = true } -tokio-util = { version = "0.7.16", optional = true } anyhow = { version = "1.0.98" } lancedb = { version = "0.26.0", default-features = false, optional = true } tracing = { version = "0.1.41", optional = true } diff --git a/spnl/benches/mt_rag.rs b/spnl/benches/mt_rag.rs index 76d10c4b..531785ba 100644 --- a/spnl/benches/mt_rag.rs +++ b/spnl/benches/mt_rag.rs @@ -113,8 +113,8 @@ async fn run_rag_benchmark( fn mt_rag_benchmark(c: &mut Criterion) { let runtime = tokio::runtime::Runtime::new().unwrap(); - let model = "ollama/granite3.3:2b"; - let embedding_model = "ollama/mxbai-embed-large:335m"; + let model = "llama3.2:3b"; + let embedding_model = "local/google/embeddinggemma-300m"; let temperature = 0.0; let max_tokens = 100; // Use small token limit for faster benchmarking diff --git a/spnl/src/augment/embed.rs b/spnl/src/augment/embed.rs index 63db630e..66ee01f3 100644 --- a/spnl/src/augment/embed.rs +++ b/spnl/src/augment/embed.rs @@ -1,5 +1,5 @@ use crate::generate::backend::openai; -use crate::ir::Query; +use crate::ir::{Message::*, Query}; pub enum EmbedData { String(String), @@ -7,31 +7,56 @@ pub enum EmbedData { Vec(Vec), } +/// Helper function to convert Query to text content for embeddings +pub fn contentify(input: &Query) -> Vec { + match input { + Query::Seq(v) | Query::Plus(v) | Query::Cross(v) => v.iter().flat_map(contentify).collect(), + Query::Message(Assistant(s)) | Query::Message(System(s)) => vec![s.clone()], + o => { + let s = o.to_string(); + if s.is_empty() { + vec![] + } else { + vec![o.to_string()] + } + } + } +} + pub async fn embed( embedding_model: &String, data: EmbedData, ) -> anyhow::Result>> { - match embedding_model { - #[cfg(feature = "ollama")] - m if m.starts_with("ollama/") => { - openai::embed(openai::Provider::Ollama, &m[7..], &data).await + let embeddings: Vec> = match embedding_model { + #[cfg(feature = "local")] + m if m.starts_with("local/") => { + crate::generate::backend::mistralrs::embed::embed(&m[6..], &data).await? } + #[cfg(feature = "ollama")] + m if m.starts_with("ollama/") => openai::embed(openai::Provider::Ollama, &m[7..], &data) + .await? + .collect(), + #[cfg(feature = "ollama")] m if m.starts_with("ollama_chat/") => { - openai::embed(openai::Provider::Ollama, &m[12..], &data).await + openai::embed(openai::Provider::Ollama, &m[12..], &data) + .await? + .collect() } #[cfg(feature = "openai")] - m if m.starts_with("openai/") => { - openai::embed(openai::Provider::OpenAI, &m[7..], &data).await - } + m if m.starts_with("openai/") => openai::embed(openai::Provider::OpenAI, &m[7..], &data) + .await? + .collect(), #[cfg(feature = "gemini")] - m if m.starts_with("gemini/") => { - openai::embed(openai::Provider::Gemini, &m[7..], &data).await - } + m if m.starts_with("gemini/") => openai::embed(openai::Provider::Gemini, &m[7..], &data) + .await? + .collect(), _ => todo!("Unsupported embedding model {embedding_model}"), - } + }; + + Ok(embeddings.into_iter()) } diff --git a/spnl/src/augment/index/layer1.rs b/spnl/src/augment/index/layer1.rs index 21e8726a..a09d8489 100644 --- a/spnl/src/augment/index/layer1.rs +++ b/spnl/src/augment/index/layer1.rs @@ -54,9 +54,6 @@ async fn process_document( options: &AugmentOptions, m: &MultiProgress, ) -> anyhow::Result> { - #[cfg(feature = "pull")] - crate::pull::pull_model_if_needed(a.embedding_model.as_str()).await?; - let (filename, content) = &a.doc; let window_size = match content { Document::Text(_) => 1, diff --git a/spnl/src/augment/index/mod.rs b/spnl/src/augment/index/mod.rs index 836d5c45..f29754fe 100644 --- a/spnl/src/augment/index/mod.rs +++ b/spnl/src/augment/index/mod.rs @@ -33,7 +33,7 @@ pub async fn index(query: &Query, options: &AugmentOptions) -> anyhow::Result<() let augments = extract_augments(query, &None); match options.indexer { - Indexer::Raptor => raptor::index(query, &augments, options, &m).await, + Indexer::Raptor => raptor::index(&augments, options, &m).await, Indexer::SimpleEmbedRetrieve => { simple_embed_retrieve::index(query, &augments, options, &m).await } diff --git a/spnl/src/augment/index/raptor.rs b/spnl/src/augment/index/raptor.rs index efcf6280..5d241df3 100644 --- a/spnl/src/augment/index/raptor.rs +++ b/spnl/src/augment/index/raptor.rs @@ -18,16 +18,10 @@ const CONCURRENCY_LIMIT: usize = 32; /// Index using the RAPTOR algorithm https://github.com/parthsarthi03/raptor pub async fn index( - query: &Query, a: &[(String, Augment)], // (enclosing_model, Augment) options: &AugmentOptions, m: &MultiProgress, ) -> anyhow::Result<()> { - // TODO if we really want the pulls to be done in parallel with - // the process_corpora, we'll need something fancier... - #[cfg(feature = "pull")] - crate::pull::pull_if_needed(query).await?; - // This will generate one Fragments struct per corpus, and then iterate over each Fragments struct to "cross index" it let cross_index_futures = process_corpora(a, options, m) .await? diff --git a/spnl/src/execute/mod.rs b/spnl/src/execute/mod.rs index 94cd784b..87c71406 100644 --- a/spnl/src/execute/mod.rs +++ b/spnl/src/execute/mod.rs @@ -2,9 +2,6 @@ use crate::ir::{Bulk, Generate, GenerateBuilder, Message::*, Query, Repeat}; use crate::optimizer::hlo::simplify; use indicatif::MultiProgress; -#[cfg(feature = "pull")] -pub mod pull; - pub type ExecuteOptions = crate::generate::GenerateOptions; pub type SpnlError = anyhow::Error; @@ -76,9 +73,6 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres } async fn run_subtree_(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgress>) -> SpnlResult { - #[cfg(feature = "pull")] - crate::execute::pull::pull_if_needed(query).await?; - match query { Query::Message(_) => Ok(query.clone()), diff --git a/spnl/src/execute/pull.rs b/spnl/src/execute/pull.rs deleted file mode 100644 index ef72c193..00000000 --- a/spnl/src/execute/pull.rs +++ /dev/null @@ -1,268 +0,0 @@ -use futures::stream::StreamExt; -use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; -use std::collections::HashMap; -use tokio::io::{AsyncBufReadExt, BufReader}; - -use crate::ir::{Generate, GenerateMetadata, Query}; - -/// Pull models (in parallel, if needed) used by the given query -pub async fn pull_if_needed(query: &Query) -> anyhow::Result<()> { - futures::future::try_join_all( - extract_models(query) - .iter() - .map(String::as_str) - .map(pull_model_if_needed), - ) - .await?; - - Ok(()) -} - -/// Pull the given model, if needed -pub async fn pull_model_if_needed(model: &str) -> anyhow::Result<()> { - match model { - m if model.starts_with("ollama/") => ollama_pull_if_needed(&m[7..]).await, - m if model.starts_with("ollama_chat/") => ollama_pull_if_needed(&m[12..]).await, - _ => Ok(()), - } -} - -#[derive(serde::Deserialize)] -struct OllamaModel { - model: String, -} - -#[derive(serde::Deserialize)] -struct OllamaTags { - models: Vec, -} - -#[derive(serde::Serialize)] -struct DeleteRequest { - model: String, - name: String, -} - -// struct to hold request params -#[derive(serde::Serialize)] -struct PullRequest { - model: String, - insecure: Option, - stream: Option, -} - -// struct to hold response params -#[derive(Debug, serde::Deserialize)] -struct ValidPullResponse { - status: String, - digest: Option, - total: Option, - completed: Option, -} - -#[derive(Debug, serde::Deserialize)] -struct InvalidPullResponse { - error: String, -} - -#[derive(Debug, serde::Deserialize)] -#[serde(untagged)] -enum PullResponse { - Ok(ValidPullResponse), - Err(InvalidPullResponse), -} - -fn api_base() -> String { - ::std::env::var("OLLAMA_API_BASE").unwrap_or("http://localhost:11434/api".to_string()) -} - -async fn ollama_exists(model: &str) -> anyhow::Result { - let tags: OllamaTags = reqwest::get(format!("{}/tags", api_base())) - .await? - .json() - .await?; - Ok(tags - .models - .into_iter() - .any(|m| m.model.eq_ignore_ascii_case(model))) -} - -// The Ollama implementation of a single model pull -async fn ollama_pull_if_needed(model: &str) -> anyhow::Result<()> { - let mut err: Option = None; - for _ in 0..5 { - if !ollama_exists(model).await? { - // creating client and request body - let http_client = reqwest::Client::new(); - let request_body = PullRequest { - model: model.to_string(), - insecure: Some(false), - stream: Some(true), - }; - - // receiving response and error handling - let response = http_client - .post(format!("{}/pull", api_base())) - .json(&request_body) - .send() - .await?; - if !response.status().is_success() { - eprintln!("API request failed with status: {}", response.status(),); - return Err(anyhow::anyhow!("Ollama API request failed")); - } - - // creating streaming structure - let byte_stream = response - .bytes_stream() - .map(|r| r.map_err(std::io::Error::other)); - let stream_reader = tokio_util::io::StreamReader::new(byte_stream); - let buf_reader = BufReader::new(stream_reader); - let mut lines = buf_reader.lines(); - - // creation of multiprogress container and style - let m = MultiProgress::new(); - let style = - ProgressStyle::with_template("{msg:<20} {percent:>3}% ▕{wide_bar}▏ {bytes:>7}") - .expect("Failed to create progress style template") - .progress_chars("█ "); - let mut digests: HashMap = HashMap::new(); - let mut final_status_lines: Vec = Vec::new(); - - while let Some(line) = lines.next_line().await? { - // stores in pull response struct - let update = match serde_json::from_str(&line) { - Ok(PullResponse::Ok(u)) => { - err = None; // clear any prior error - Ok(u) - } - Ok(PullResponse::Err(e)) => { - if e.error == "pull model manifest: file does not exist" { - return Err(anyhow::anyhow!(e.error)); - } - eprintln!("Possible transient error in ollama pull {}", e.error); - err = Some(anyhow::anyhow!(e.error)); - - let _ = http_client - .post("http://localhost:11434/api/delete") // TODO use OLLAMA_API_BASE? - .json(&DeleteRequest { - model: model.to_string(), - name: model.to_string(), - }) - .send() - .await; - - ::std::thread::sleep(::std::time::Duration::from_millis(2000)); - break; // break out of while iteration over lines - } - Err(e) => { - eprintln!("Invalid response from ollama pull: {line}"); - eprintln!("Parse error: {e}"); - Err(anyhow::anyhow!("Ollama API request failed")) - } - }?; - - let my_status = update.status.to_lowercase(); - - if let Some(digest) = update.digest { - // handles multiple progress bars - let current_pb = digests.entry(digest.clone()).or_insert_with(|| { - let new_pb = m.add(ProgressBar::new(0)); - new_pb.set_style(style.clone()); - new_pb - }); - - current_pb.set_message(my_status.clone()); - - // sets progress bar length - if let (Some(total), Some(done)) = (update.total, update.completed) { - if current_pb.length().unwrap_or(0) == 0 { - current_pb.set_length(total); - } - current_pb.set_position(done); - } - } else if digests.is_empty() { - // prints out status updates (before download) - m.println(&my_status).unwrap(); - } else { - // stores to print out status updates (after download) - final_status_lines.push(my_status.clone()); - } - - // checks for error or end of stream - if my_status == "error" { - return Err(anyhow::anyhow!("Ollama streaming error: {}", line)); - } else if my_status == "success" { - break; - } - } - - if err.is_some() { - continue; // continue for retry loop - } - - // finishes drawing progress bars and outputs rest of status updates - m.set_draw_target(indicatif::ProgressDrawTarget::hidden()); - for line in final_status_lines { - eprintln!("{}", line); - } - - break; // break out of retry loop - } - } - - match err { - Some(err) => Err(err), - None => Ok(()), - } -} - -/// Extract models referenced by the query -pub fn extract_models(query: &Query) -> Vec { - let mut models = vec![]; - extract_models_iter(query, &mut models); - - // A single query may specify the same model more than once. Dedup! - models.sort(); - models.dedup(); - - models -} - -/// Produce a vector of the models used by the given `query` -fn extract_models_iter(query: &Query, models: &mut Vec) { - match query { - #[cfg(feature = "rag")] - Query::Augment(crate::ir::Augment { - embedding_model, .. - }) => models.push(embedding_model.clone()), - Query::Generate(Generate { - metadata: GenerateMetadata { model, .. }, - .. - }) => models.push(model.clone()), - Query::Plus(v) | Query::Cross(v) => { - v.iter().for_each(|vv| extract_models_iter(vv, models)); - } - _ => {} - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio; - - // testing a valid model pull - #[tokio::test] - async fn test_pull_local_ollama() { - // this is the smallest model @starpit could find as of 20260108 - let result = ollama_pull_if_needed("smollm:135m").await; - assert!(result.is_ok()); - } - - // testing invalid model pull - #[tokio::test] - async fn test_pull_invalid_model() { - let result = ollama_pull_if_needed("notamodel").await; - assert!(result.is_err()); - } -} diff --git a/spnl/src/generate/backend/mistralrs/embed.rs b/spnl/src/generate/backend/mistralrs/embed.rs new file mode 100644 index 00000000..b22a65d7 --- /dev/null +++ b/spnl/src/generate/backend/mistralrs/embed.rs @@ -0,0 +1,64 @@ +//! Embedding support for mistral.rs backend + +use crate::augment::embed::{EmbedData, contentify}; +use mistralrs::{EmbeddingModelBuilder, EmbeddingRequest, best_device}; + +/// Returns true if MISTRALRS_VERBOSE env var is set to "true" or "1" +fn should_enable_logging() -> bool { + std::env::var("MISTRALRS_VERBOSE") + .map(|v| v.to_lowercase() == "true" || v == "1") + .unwrap_or(false) +} + +/// Generate embeddings using mistral.rs backend +/// +/// Note: Unlike text generation models, embedding models are loaded fresh each time +/// because the EmbeddingModel type is not publicly exported from mistralrs. +pub async fn embed(embedding_model: &str, data: &EmbedData) -> anyhow::Result>> { + // Load the embedding model + let device = best_device(false).expect("Failed to detect device"); + + if should_enable_logging() { + eprintln!( + "Loading embedding model: {} on device: {:?}", + embedding_model, device + ); + } + + let mut builder = EmbeddingModelBuilder::new(embedding_model).with_device(device); + + // Optionally enable logging + if should_enable_logging() { + builder = builder.with_logging(); + } + + let model = builder.build().await?; + + if should_enable_logging() { + eprintln!("Embedding model loaded successfully"); + } + + // Convert data to text strings + let docs = match data { + EmbedData::String(s) => vec![s.clone()], + EmbedData::Vec(v) => v.clone(), + EmbedData::Query(u) => contentify(u), + }; + + if should_enable_logging() { + eprintln!("Generating embeddings for {} documents", docs.len()); + } + + // Create an embedding request using the builder pattern + let mut request = EmbeddingRequest::builder(); + for doc in docs { + request = request.add_prompt(doc); + } + + // Get embeddings from the model - returns Vec> directly + let embeddings = model.generate_embeddings(request).await?; + + Ok(embeddings) +} + +// Made with Bob diff --git a/spnl/src/generate/backend/mistralrs/mod.rs b/spnl/src/generate/backend/mistralrs/mod.rs index 8928a59a..29175d6b 100644 --- a/spnl/src/generate/backend/mistralrs/mod.rs +++ b/spnl/src/generate/backend/mistralrs/mod.rs @@ -20,6 +20,9 @@ use crate::{ mod loader; use loader::ModelPool; +#[cfg(feature = "rag")] +pub mod embed; + // Global model pool - initialized once and reused across all requests static MODEL_POOL: OnceLock = OnceLock::new(); diff --git a/spnl/src/generate/backend/openai.rs b/spnl/src/generate/backend/openai.rs index 934072ed..ac6e16b2 100644 --- a/spnl/src/generate/backend/openai.rs +++ b/spnl/src/generate/backend/openai.rs @@ -397,22 +397,6 @@ pub fn messagify(input: &Query) -> Vec { } } -#[cfg(feature = "rag")] -pub fn contentify(input: &Query) -> Vec { - match input { - Query::Seq(v) | Query::Plus(v) | Query::Cross(v) => v.iter().flat_map(contentify).collect(), - Query::Message(Assistant(s)) | Query::Message(System(s)) => vec![s.clone()], - o => { - let s = o.to_string(); - if s.is_empty() { - vec![] - } else { - vec![o.to_string()] - } - } - } -} - #[cfg(feature = "rag")] pub async fn embed( provider: Provider, @@ -426,7 +410,7 @@ pub async fn embed( let docs = match data { EmbedData::String(s) => &vec![s.clone()], EmbedData::Vec(v) => v, - EmbedData::Query(u) => &contentify(u), + EmbedData::Query(u) => &crate::augment::embed::contentify(u), }; let request = CreateEmbeddingRequestArgs::default() diff --git a/spnl/src/optimizer/hlo.rs b/spnl/src/optimizer/hlo.rs index 1c54c6b4..065144c0 100644 --- a/spnl/src/optimizer/hlo.rs +++ b/spnl/src/optimizer/hlo.rs @@ -268,13 +268,21 @@ mod tests { #[cfg(feature = "rag")] #[tokio::test] // <-- needed for async tests async fn retrieve() -> anyhow::Result<()> { + // Skip test with warning if HF_TOKEN is not set in CI environment + if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() { + eprintln!( + "WARNING: HF_TOKEN is not set in CI environment. Skipping retrieve test to avoid failures when accessing Hugging Face models." + ); + return Ok(()); + } + let model = "spnl/m"; // This should work, because we use SimpleEmbedRetrieve which won't do any generation let q = Message(User("Hello".to_string())); let d = "I know all about Hello and stuff"; let outer_generate = GenerateBuilder::default() .metadata(GenerateMetadataBuilder::default().model(model).build()?) .input(Box::new(Query::Augment(crate::ir::Augment { - embedding_model: "ollama/mxbai-embed-large:335m".to_string(), + embedding_model: "local/google/embeddinggemma-300m".to_string(), body: Box::new(q), doc: ( "path/to/doc.txt".to_string(),