diff --git a/Makefile b/Makefile index 54c21088afa0..d96775cdc6ea 100644 --- a/Makefile +++ b/Makefile @@ -227,6 +227,10 @@ test-e2e: build-mock-backend prepare-e2e run-e2e-image $(MAKE) teardown-e2e docker rmi localai-tests +test-e2e-spiritlm: build-mock-backend + @echo 'Running SpiritLM e2e tests (mock backend)' + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="SpiritLM" --flake-attempts $(TEST_FLAKES) -v ./tests/e2e/... + teardown-e2e: rm -rf $(TEST_DIR) || true docker stop $$(docker ps -q --filter ancestor=localai-tests) @@ -247,6 +251,10 @@ test-stablediffusion: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) +test-spiritlm: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backend/python SPIRITLM_CHECKPOINTS_DIR=$(SPIRITLM_CHECKPOINTS_DIR) \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="spiritlm" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + test-stores: $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration @@ -312,7 +320,7 @@ protoc: .PHONY: protogen-go protogen-go: protoc install-go-tools mkdir -p pkg/grpc/proto - ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \ + PATH="$$(go env GOPATH)/bin:$$PATH" ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \ backend/backend.proto .PHONY: protogen-go-clean diff 
--git a/backend/index.yaml b/backend/index.yaml index e518170ca680..df22d9770a2d 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -594,6 +594,27 @@ nvidia-cuda-13: "cuda13-nemo" nvidia-cuda-12: "cuda12-nemo" icon: https://www.nvidia.com/favicon.ico +- &spiritlm + urls: + - https://github.com/facebookresearch/spiritlm + description: | + Meta Spirit LM: interleaved spoken and written language model. Supports text generation, text-to-speech (TTS), and automatic speech recognition (ASR) in a single 7B model. + tags: + - text-to-text + - text-to-speech + - TTS + - speech-recognition + - ASR + - LLM + - multimodal + license: fair-noncommercial + name: "spiritlm" + alias: "spiritlm" + capabilities: + nvidia: "cuda12-spiritlm" + default: "cpu-spiritlm" + nvidia-cuda-12: "cuda12-spiritlm" + icon: https://ai.meta.com/favicon.ico - &voxcpm urls: - https://github.com/ModelBest/VoxCPM diff --git a/backend/python/backend.proto b/backend/python/backend.proto new file mode 120000 index 000000000000..748607dcf7cf --- /dev/null +++ b/backend/python/backend.proto @@ -0,0 +1 @@ +../backend.proto \ No newline at end of file diff --git a/backend/python/spiritlm/E2E.md b/backend/python/spiritlm/E2E.md new file mode 100644 index 000000000000..1a3fe77be93e --- /dev/null +++ b/backend/python/spiritlm/E2E.md @@ -0,0 +1,69 @@ +# SpiritLM E2E tests + +SpiritLM is covered by two test layers: + +1. **`tests/e2e/` (recommended for CI)** – Full e2e suite using the shared mock backend. File: `tests/e2e/spiritlm_e2e_test.go`, label: `SpiritLM`. No real SpiritLM backend or model required. +2. **`core/http/app_test.go`** – Integration-style tests under context **SpiritLM backend e2e** (label: `spiritlm`). Requires the Python SpiritLM backend and fixtures. 
+ +## How to run + +From the repo root: + +**E2E suite (mock backend, no Python backend needed):** + +```bash +make test-e2e-spiritlm +``` + +Or run the full e2e suite (includes SpiritLM): + +```bash +make test-e2e +``` + +**Integration tests (real SpiritLM backend):** + +```bash +make test-spiritlm +``` + +This sets `BACKENDS_PATH=./backend/python` and `TEST_DIR=./test-dir`, runs `prepare-test`, then runs Ginkgo with `--label-filter="spiritlm"`. + +For the **transcription** test you need `test-dir/audio.wav` (e.g. run `make test-models/testmodel.ggml` once to download it, or set `TEST_DIR` to a directory that contains `audio.wav`). + +## Backend setup + +1. **Protos** + Generate Python gRPC stubs (required for the backend to start): + ```bash + cd backend/python/spiritlm && bash protogen.sh + ``` + Or run the full install (which also creates the venv and installs deps): + ```bash + make -C backend/python/spiritlm + ``` + +2. **Full e2e pass (all 3 specs pass)** + - Install the backend: `make -C backend/python/spiritlm` + - Download the Spirit LM model from [Meta AI Spirit LM](https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/) and place it so the checkpoint directory layout is: + ``` + <SPIRITLM_CHECKPOINTS_DIR>/ + spiritlm_model/ + spirit-lm-base-7b/ # model files (config.json, tokenizer, etc.) + ``` + - Run the tests with the checkpoint dir set: + ```bash + SPIRITLM_CHECKPOINTS_DIR=/path/to/checkpoints make test-spiritlm + ``` + - Ensure LocalAI runs the backend with that env (e.g. export it before `make test-spiritlm`, or configure the backend to pass it through). + +Without the model, the backend starts and responds to Health, but LoadModel fails; the e2e specs **skip** with a message pointing here, and the suite still **passes** (0 failed). 
+ +## Requirements + +- Linux (tests skip on other OS) +- SpiritLM backend runnable: `backend/python/spiritlm/run.sh` must exist (satisfied in-tree) +- For backend to start: Python protos generated (`backend_pb2.py`, `backend_pb2_grpc.py`) and venv with grpc/spiritlm (via `make -C backend/python/spiritlm`) +- For all 3 specs to pass: Spirit LM model under `SPIRITLM_CHECKPOINTS_DIR` as above + +Tests are skipped if `BACKENDS_PATH` is unset or `BACKENDS_PATH/spiritlm/run.sh` is missing. diff --git a/backend/python/spiritlm/Makefile b/backend/python/spiritlm/Makefile new file mode 100644 index 000000000000..85b00878cdbe --- /dev/null +++ b/backend/python/spiritlm/Makefile @@ -0,0 +1,23 @@ +.PHONY: spiritlm +spiritlm: + bash install.sh + +.PHONY: run +run: spiritlm + @echo "Running spiritlm..." + bash run.sh + @echo "spiritlm run." + +.PHONY: test +test: spiritlm + @echo "Testing spiritlm..." + bash test.sh + @echo "spiritlm tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/spiritlm/backend.py b/backend/python/spiritlm/backend.py new file mode 100644 index 000000000000..1f3172e875d2 --- /dev/null +++ b/backend/python/spiritlm/backend.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +LocalAI gRPC backend for Meta Spirit LM: interleaved text and speech model. +Supports text generation (Predict), TTS, and audio transcription (ASR). 
+""" +from concurrent import futures +import argparse +import io +import os +import signal +import struct +import sys +import time +from typing import Any, Dict, List, Optional + +import backend_pb2 +import backend_pb2_grpc +import grpc + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1")) +DEFAULT_SAMPLE_RATE = 16000 + + +def _parse_option_value(value: str) -> Any: + if value.lower() in ("true", "false"): + return value.lower() == "true" + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value + + +def _float32_audio_to_wav_bytes(audio: Any, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes: + """Convert float32 mono audio array to WAV bytes.""" + import numpy as np + samples = np.asarray(audio, dtype=np.float32) + if samples.ndim != 1: + samples = samples.flatten() + n_frames = len(samples) + n_channels = 1 + sample_width = 2 # 16-bit + byte_rate = sample_rate * n_channels * sample_width + block_align = n_channels * sample_width + data_size = n_frames * sample_width + # Clip to [-1, 1] and convert to int16 + samples = samples.clip(-1.0, 1.0) + int16_samples = (samples * 32767).astype(" backend_pb2.Reply: + return backend_pb2.Reply(message=b"OK") + + def LoadModel(self, request: backend_pb2.ModelOptions, context: grpc.ServicerContext) -> backend_pb2.Result: + try: + from spiritlm.model.spiritlm_model import Spiritlm + except ImportError as e: + return backend_pb2.Result(success=False, message=f"SpiritLM not installed: {e}") + + self._options = self._parse_options(request.Options) + model_name = (request.Model or "").strip() or "spirit-lm-base-7b" + if model_name not in ("spirit-lm-base-7b", "spirit-lm-expressive-7b"): + return backend_pb2.Result(success=False, message=f"Unknown model: {model_name}") + + try: + print(f"Loading Spirit LM model: {model_name}", file=sys.stderr) + self._model = Spiritlm(model_name) + self._sample_rate = 
self._options.get("sample_rate", DEFAULT_SAMPLE_RATE) + print("Spirit LM model loaded successfully", file=sys.stderr) + except Exception as e: + print(f"LoadModel failed: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return backend_pb2.Result(success=False, message=str(e)) + + return backend_pb2.Result(success=True, message="Model loaded successfully") + + def _parse_options(self, options: List[str]) -> Dict[str, Any]: + out: Dict[str, Any] = {} + for opt in options or []: + if ":" not in opt: + continue + key, _, value = opt.partition(":") + key = key.strip() + value = value.strip() + if key: + out[key] = _parse_option_value(value) + return out + + def _generation_config( + self, + tokens: int = 200, + temperature: float = 0.9, + top_p: float = 0.95, + do_sample: bool = True, + ) -> Any: + from transformers import GenerationConfig + return GenerationConfig( + max_new_tokens=tokens, + temperature=temperature, + top_p=top_p, + do_sample=do_sample, + ) + + def Predict(self, request: backend_pb2.PredictOptions, context: grpc.ServicerContext) -> backend_pb2.Reply: + if not getattr(self, "_model", None): + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("Model not loaded") + return backend_pb2.Reply(message=b"") + try: + text = self._generate_text( + prompt=request.Prompt or "", + tokens=max(1, request.Tokens or 200), + temperature=request.Temperature if request.Temperature > 0 else 0.9, + top_p=request.TopP if request.TopP > 0 else 0.95, + stop_prompts=list(request.StopPrompts) if request.StopPrompts else None, + ) + return backend_pb2.Reply( + message=text.encode("utf-8"), + tokens=len(text.split()), + prompt_tokens=0, + ) + except Exception as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + return backend_pb2.Reply(message=b"") + + def PredictStream( + self, request: backend_pb2.PredictOptions, context: grpc.ServicerContext + ) -> Any: + if not getattr(self, "_model", 
None): + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("Model not loaded") + return + try: + text = self._generate_text( + prompt=request.Prompt or "", + tokens=max(1, request.Tokens or 200), + temperature=request.Temperature if request.Temperature > 0 else 0.9, + top_p=request.TopP if request.TopP > 0 else 0.95, + stop_prompts=list(request.StopPrompts) if request.StopPrompts else None, + ) + yield backend_pb2.Reply(message=text.encode("utf-8"), tokens=len(text.split())) + except Exception as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + def _generate_text( + self, + prompt: str, + tokens: int = 200, + temperature: float = 0.9, + top_p: float = 0.95, + stop_prompts: Optional[List[str]] = None, + ) -> str: + from spiritlm.model.spiritlm_model import ContentType, GenerationInput, OutputModality + + if not prompt.strip(): + return "" + + inputs = [GenerationInput(content=prompt.strip(), content_type=ContentType.TEXT)] + config = self._generation_config(tokens=tokens, temperature=temperature, top_p=top_p) + outputs = self._model.generate( + output_modality=OutputModality.TEXT, + interleaved_inputs=inputs, + generation_config=config, + ) + parts: List[str] = [] + for out in outputs or []: + if getattr(out, "content_type", None) and str(getattr(out.content_type, "name", "")) == "TEXT": + content = getattr(out, "content", None) + if isinstance(content, str): + parts.append(content) + result = "".join(parts) + if stop_prompts: + for stop in stop_prompts: + if stop in result: + result = result.split(stop)[0].strip() + return result + + def AudioTranscription( + self, request: backend_pb2.TranscriptRequest, context: grpc.ServicerContext + ) -> backend_pb2.TranscriptResult: + if not getattr(self, "_model", None): + return backend_pb2.TranscriptResult(segments=[], text="") + audio_path = (request.dst or "").strip() + if not audio_path or not os.path.isfile(audio_path): + print(f"Audio file not found: 
{audio_path}", file=sys.stderr) + return backend_pb2.TranscriptResult(segments=[], text="") + + try: + from spiritlm.model.spiritlm_model import ContentType, GenerationInput, OutputModality + + inputs = [GenerationInput(content=audio_path, content_type=ContentType.SPEECH)] + config = self._generation_config(tokens=500, temperature=0.2, top_p=0.95) + outputs = self._model.generate( + output_modality=OutputModality.TEXT, + interleaved_inputs=inputs, + generation_config=config, + ) + parts: List[str] = [] + for out in outputs or []: + if getattr(out, "content_type", None) and str(getattr(out.content_type, "name", "")) == "TEXT": + content = getattr(out, "content", None) + if isinstance(content, str): + parts.append(content) + text = " ".join(parts).strip() + segment = backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text) + return backend_pb2.TranscriptResult(segments=[segment], text=text) + except Exception as e: + print(f"AudioTranscription failed: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return backend_pb2.TranscriptResult(segments=[], text="") + + def TTS(self, request: backend_pb2.TTSRequest, context: grpc.ServicerContext) -> backend_pb2.Result: + if not getattr(self, "_model", None): + return backend_pb2.Result(success=False, message="Model not loaded") + text = (request.text or "").strip() + if not text: + return backend_pb2.Result(success=False, message="TTS request has no text") + + try: + from spiritlm.model.spiritlm_model import ContentType, GenerationInput, OutputModality + + inputs = [GenerationInput(content=text, content_type=ContentType.TEXT)] + config = self._generation_config(tokens=400, temperature=0.9, top_p=0.95) + outputs = self._model.generate( + output_modality=OutputModality.SPEECH, + interleaved_inputs=inputs, + generation_config=config, + ) + audio_float32 = None + for out in outputs or []: + if getattr(out, "content_type", None) and str(getattr(out.content_type, "name", "")) == "SPEECH": + 
content = getattr(out, "content", None) + if content is not None and hasattr(content, "__len__"): + import numpy as np + arr = np.asarray(content, dtype=np.float32) + if audio_float32 is None: + audio_float32 = arr + else: + audio_float32 = np.concatenate([audio_float32, arr]) + if audio_float32 is None: + return backend_pb2.Result(success=False, message="No speech output from model") + + wav_bytes = _float32_audio_to_wav_bytes(audio_float32, self._sample_rate) + if request.dst: + with open(request.dst, "wb") as f: + f.write(wav_bytes) + return backend_pb2.Result(success=True, message="OK") + except Exception as e: + print(f"TTS failed: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return backend_pb2.Result(success=False, message=str(e)) + + def TTSStream( + self, request: backend_pb2.TTSRequest, context: grpc.ServicerContext + ) -> Any: + if not getattr(self, "_model", None): + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("Model not loaded") + return + result = self.TTS(request, context) + if not result.success: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(result.message) + return + if request.dst and os.path.isfile(request.dst): + with open(request.dst, "rb") as f: + data = f.read() + yield backend_pb2.Reply(audio=data) + + +def serve(address: str) -> None: + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ("grpc.max_message_length", 50 * 1024 * 1024), + ("grpc.max_send_message_length", 50 * 1024 * 1024), + ("grpc.max_receive_message_length", 50 * 1024 * 1024), + ], + ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Spirit LM backend listening on: " + address, file=sys.stderr) + + def signal_handler(sig: int, frame: Any) -> None: + print("Shutting down...", file=sys.stderr) + server.stop(0) + sys.exit(0) + + signal.signal(signal.SIGINT, 
signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Spirit LM gRPC backend for LocalAI") + parser.add_argument("--addr", default="localhost:50051", help="Address to bind") + args = parser.parse_args() + serve(args.addr) diff --git a/backend/python/spiritlm/install.sh b/backend/python/spiritlm/install.sh new file mode 100644 index 000000000000..4136d8765589 --- /dev/null +++ b/backend/python/spiritlm/install.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +installRequirements diff --git a/backend/python/spiritlm/protogen.sh b/backend/python/spiritlm/protogen.sh new file mode 100644 index 000000000000..46dd35567ddf --- /dev/null +++ b/backend/python/spiritlm/protogen.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +# backend.proto lives at repo backend/; from backend/python/spiritlm that is ../../../backend +proto_root="${backend_dir}/../../../backend" +python3 -m grpc_tools.protoc -I"${proto_root}" --python_out=. --grpc_python_out=. 
"${proto_root}/backend.proto" diff --git a/backend/python/spiritlm/requirements-install.txt b/backend/python/spiritlm/requirements-install.txt new file mode 100644 index 000000000000..d5987270e043 --- /dev/null +++ b/backend/python/spiritlm/requirements-install.txt @@ -0,0 +1,2 @@ +# SpiritLM: install from GitHub (not on PyPI) +git+https://github.com/facebookresearch/spiritlm.git diff --git a/backend/python/spiritlm/requirements.txt b/backend/python/spiritlm/requirements.txt new file mode 100644 index 000000000000..72750ac22933 --- /dev/null +++ b/backend/python/spiritlm/requirements.txt @@ -0,0 +1,7 @@ +grpcio>=1.71.0 +grpcio-tools +protobuf +torch +transformers +soundfile +numpy diff --git a/backend/python/spiritlm/run.sh b/backend/python/spiritlm/run.sh new file mode 100755 index 000000000000..eae121f37b0b --- /dev/null +++ b/backend/python/spiritlm/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/spiritlm/test.sh b/backend/python/spiritlm/test.sh new file mode 100644 index 000000000000..eb59f2aaf3f3 --- /dev/null +++ b/backend/python/spiritlm/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/http/app_test.go b/core/http/app_test.go index 6202573b5734..297bd532e12c 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "time" "github.com/mudler/LocalAI/core/application" @@ -1378,6 +1379,140 @@ parameters: }) }) + Context("SpiritLM backend e2e", Label("spiritlm"), func() { + BeforeEach(func() { + if runtime.GOOS != "linux" { + Skip("SpiritLM e2e runs only on linux") + } + backendPath := 
os.Getenv("BACKENDS_PATH") + if backendPath == "" { + Skip("BACKENDS_PATH not set (e.g. use make test-spiritlm)") + } + spiritlmRun := filepath.Join(backendPath, "spiritlm", "run.sh") + if _, err := os.Stat(spiritlmRun); err != nil { + Skip("SpiritLM backend not found at BACKENDS_PATH/spiritlm/run.sh") + } + + var err error + tmpdir, err = os.MkdirTemp("", "spiritlm-e2e-") + Expect(err).ToNot(HaveOccurred()) + + modelDir = filepath.Join(tmpdir, "models") + err = os.Mkdir(modelDir, 0750) + Expect(err).ToNot(HaveOccurred()) + + modelYAML := []byte(`name: spirit-lm-base-7b +backend: spiritlm +known_usecases: + - transcript + - tts +parameters: + model: spirit-lm-base-7b +`) + err = os.WriteFile(filepath.Join(modelDir, "spirit-lm-base-7b.yaml"), modelYAML, 0600) + Expect(err).ToNot(HaveOccurred()) + + c, cancel = context.WithCancel(context.Background()) + + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelDir), + ) + Expect(err).ToNot(HaveOccurred()) + + application, err := application.New( + append(commonOpts, + config.WithExternalBackend("spiritlm", spiritlmRun), + config.WithContext(c), + config.WithSystemState(systemState), + )...) 
+ Expect(err).ToNot(HaveOccurred()) + app, err = API(application) + Expect(err).ToNot(HaveOccurred()) + + go func() { + if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + xlog.Error("server error", "error", err) + } + }() + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + client2 = openaigo.NewClient("") + client2.BaseURL = defaultConfig.BaseURL + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + AfterEach(func() { + cancel() + if app != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := app.Shutdown(ctx) + Expect(err).ToNot(HaveOccurred()) + } + if tmpdir != "" { + _ = os.RemoveAll(tmpdir) + } + }) + + It("loads model and returns chat completion", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ + Model: "spirit-lm-base-7b", + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "Say hello in one word."}, + }, + }) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "grpc service not ready") || strings.Contains(errStr, "failed to load model") || strings.Contains(errStr, "could not load model") || strings.Contains(errStr, "Repo id must be") || strings.Contains(errStr, "is not a local folder and is not a valid model identifier") { + Skip("SpiritLM backend not ready (run 'make -C backend/python/spiritlm', set SPIRITLM_CHECKPOINTS_DIR to model dir, see backend/python/spiritlm/E2E.md)") + } + Expect(err).ToNot(HaveOccurred()) + } + Expect(resp.Choices).ToNot(BeEmpty()) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + }) + + It("TTS returns audio", func() { + resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", + bytes.NewBuffer([]byte(`{"input": "Hello", "model": "spirit-lm-base-7b"}`))) + 
Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + dat, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + if resp.StatusCode != 200 && (strings.Contains(string(dat), "failed to load") || strings.Contains(string(dat), "not ready") || strings.Contains(string(dat), "could not load") || strings.Contains(string(dat), "is not a valid model identifier")) { + Skip("SpiritLM backend not ready (run 'make -C backend/python/spiritlm', set SPIRITLM_CHECKPOINTS_DIR, see backend/python/spiritlm/E2E.md)") + } + Expect(resp.StatusCode).To(Equal(200), string(dat)) + Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/wav"), Equal("audio/vnd.wave"))) + Expect(len(dat)).To(BeNumerically(">", 0)) + }) + + It("transcription returns text", func() { + testDir := os.Getenv("TEST_DIR") + audioPath := filepath.Join(testDir, "audio.wav") + if _, err := os.Stat(audioPath); err != nil { + Skip("TEST_DIR/audio.wav not found (run prepare-test or set TEST_DIR)") + } + resp, err := client.CreateTranscription(context.Background(), openai.AudioRequest{ + Model: "spirit-lm-base-7b", + FilePath: audioPath, + }) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "grpc service not ready") || strings.Contains(errStr, "failed to load model") || strings.Contains(errStr, "could not load model") || strings.Contains(errStr, "Repo id must be") || strings.Contains(errStr, "is not a local folder and is not a valid model identifier") { + Skip("SpiritLM backend not ready (run 'make -C backend/python/spiritlm', set SPIRITLM_CHECKPOINTS_DIR, see backend/python/spiritlm/E2E.md)") + } + Expect(err).ToNot(HaveOccurred()) + } + Expect(resp.Text).ToNot(BeEmpty()) + }) + }) + Context("Config file", func() { BeforeEach(func() { if runtime.GOOS != "linux" { diff --git a/gallery/index.yaml b/gallery/index.yaml index 0cd16d3cec09..e891c2c13dbf 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -859,6 +859,39 @@ - transcript parameters: 
model: Qwen/Qwen3-ASR-0.6B +- &spiritlm + urls: + - https://github.com/facebookresearch/spiritlm + - https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/ + description: | + Meta Spirit LM is an interleaved spoken and written language model. It supports text generation, text-to-speech (TTS), and automatic speech recognition (ASR) in a single 7B model. Base version uses phonetic tokens; Expressive adds pitch and style. + tags: + - speech-recognition + - ASR + - text-to-speech + - TTS + - llm + - multimodal + license: fair-noncommercial + icon: https://ai.meta.com/favicon.ico + name: "spirit-lm-base-7b" + url: "github:mudler/LocalAI/gallery/virtual.yaml@master" + overrides: + backend: spiritlm + known_usecases: + - transcript + - tts + parameters: + model: spirit-lm-base-7b +- !!merge <<: *spiritlm + name: "spirit-lm-expressive-7b" + overrides: + backend: spiritlm + known_usecases: + - transcript + - tts + parameters: + model: spirit-lm-expressive-7b - name: "huihui-glm-4.7-flash-abliterated-i1" url: "github:mudler/LocalAI/gallery/virtual.yaml@master" urls: diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go index 66d9d6cd7ffb..b15a58050e52 100644 --- a/tests/e2e/e2e_suite_test.go +++ b/tests/e2e/e2e_suite_test.go @@ -99,6 +99,18 @@ var _ = BeforeSuite(func() { Expect(err).ToNot(HaveOccurred()) Expect(os.WriteFile(configPath, configYAML, 0644)).To(Succeed()) + // SpiritLM-style model config (same mock backend, for e2e coverage of SpiritLM path) + spiritlmConfig := map[string]interface{}{ + "name": "spirit-lm-base-7b", + "backend": "spiritlm", + "parameters": map[string]interface{}{ + "model": "spirit-lm-base-7b", + }, + } + spiritlmConfigYAML, err := yaml.Marshal(spiritlmConfig) + Expect(err).ToNot(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(modelsPath, "spirit-lm-base-7b.yaml"), spiritlmConfigYAML, 0644)).To(Succeed()) + // Set up system state systemState, err := system.GetSystemState( system.WithBackendPath(backendPath), 
@@ -122,6 +134,7 @@ var _ = BeforeSuite(func() { // Register backend with application's model loader application.ModelLoader().SetExternalBackend("mock-backend", mockBackendPath) + application.ModelLoader().SetExternalBackend("spiritlm", mockBackendPath) // Create HTTP app app, err = httpapi.API(application) diff --git a/tests/e2e/spiritlm_e2e_test.go b/tests/e2e/spiritlm_e2e_test.go new file mode 100644 index 000000000000..3a6472e73d40 --- /dev/null +++ b/tests/e2e/spiritlm_e2e_test.go @@ -0,0 +1,103 @@ +package e2e_test + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "mime/multipart" + "net/http" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/openai/openai-go/v3" +) + +var _ = Describe("SpiritLM backend E2E", Label("SpiritLM"), func() { + Describe("Chat completions", func() { + It("returns response for spirit-lm-base-7b", func() { + resp, err := client.Chat.Completions.New( + context.TODO(), + openai.ChatCompletionNewParams{ + Model: "spirit-lm-base-7b", + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say hello."), + }, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + }) + }) + + Describe("TTS", func() { + It("returns audio for spirit-lm-base-7b", func() { + body := `{"model":"spirit-lm-base-7b","input":"Hello","voice":"default"}` + req, err := http.NewRequest("POST", apiURL+"/audio/speech", io.NopCloser(strings.NewReader(body))) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/")) + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + 
Expect(len(data)).To(BeNumerically(">", 0)) + }) + }) + + Describe("Transcription", func() { + It("returns transcription for spirit-lm-base-7b", func() { + var buf bytes.Buffer + w := multipart.NewWriter(&buf) + part, err := w.CreateFormFile("file", "audio.wav") + Expect(err).ToNot(HaveOccurred()) + _, _ = part.Write(minimalWAVBytes()) + _ = w.WriteField("model", "spirit-lm-base-7b") + Expect(w.Close()).To(Succeed()) + + req, err := http.NewRequest("POST", apiURL+"/audio/transcriptions", &buf) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", w.FormDataContentType()) + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(Equal(200)) + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(ContainSubstring("mocked")) + }) + }) +}) + +func minimalWAVBytes() []byte { + const sampleRate = 16000 + const numChannels = 1 + const bitsPerSample = 16 + const numSamples = 160 + dataSize := numSamples * numChannels * (bitsPerSample / 8) + headerLen := 44 + var buf bytes.Buffer + buf.Write([]byte("RIFF")) + _ = binary.Write(&buf, binary.LittleEndian, uint32(headerLen-8+dataSize)) + buf.Write([]byte("WAVEfmt ")) + _ = binary.Write(&buf, binary.LittleEndian, uint32(16)) + _ = binary.Write(&buf, binary.LittleEndian, uint16(1)) + _ = binary.Write(&buf, binary.LittleEndian, uint16(numChannels)) + _ = binary.Write(&buf, binary.LittleEndian, uint32(sampleRate)) + _ = binary.Write(&buf, binary.LittleEndian, uint32(sampleRate*numChannels*(bitsPerSample/8))) + _ = binary.Write(&buf, binary.LittleEndian, uint16(numChannels*(bitsPerSample/8))) + _ = binary.Write(&buf, binary.LittleEndian, uint16(bitsPerSample)) + buf.Write([]byte("data")) + _ = binary.Write(&buf, binary.LittleEndian, uint32(dataSize)) + buf.Write(make([]byte, dataSize)) + return buf.Bytes() +}