13 changes: 11 additions & 2 deletions README.md
@@ -113,10 +113,11 @@ make download-punkt # download the punkt dataset for NLTK

### Models

Download models from huggingface.
Download the Gemma model from HuggingFace. The model is small enough for local
testing but still supports multimodal input.

```shell
huggingface-cli download HuggingFaceTB/SmolLM2-1.7B-Instruct --local-dir models/SmolLM-1.7B
huggingface-cli download google/gemma-3n-e4b-it --local-dir models/gemma
wget -P models https://huggingface.co/geneing/Kokoro/resolve/f610f07c62f8baa30d4ed731530e490230e4ee83/kokoro-v0_19.pth

```
Expand All @@ -127,6 +128,14 @@ models/tg a general model accessible with model=multilingual
models/tgz an instruct model accessible with model=instruct
models/tgc a chat model accessible with model=chat

For the Gemma image pipeline you can override the default model or device using
the following environment variables:

```
GEMMA_MODEL_ID=myorg/my-gemma-checkpoint
GEMMA_DEVICE=0 # set to -1 for CPU
```
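
For example, a quick local sanity check with the `scripts/run_gemma.py` helper added in this change, forcing CPU and captioning the image used by the tests (adjust the image path and checkpoint to your setup):

```shell
GEMMA_MODEL_ID=google/gemma-3n-e4b-it GEMMA_DEVICE=-1 \
  python scripts/run_gemma.py static/img/me.jpg "Describe this image."
```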

model=best is configured to pick which model to use by scoring the prompt with each model and comparing their perplexities.

This needs tuning for the average and standard deviation of the perplexity, as each model has a different idea of how confident it is. Overtrained models are more confident about all text being in the dataset (they tend to generate text verbatim from the dataset).
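
The selection can be thought of as a z-score comparison. A minimal sketch, with hypothetical `score_fn` callables standing in for the real per-model perplexity scoring:

```python
def pick_best_model(models, prompt):
    """Pick the model whose normalized perplexity on the prompt is lowest.

    `models` maps a model name to (score_fn, ppl_mean, ppl_std), where
    score_fn(prompt) is assumed to return that model's perplexity for the
    prompt. Normalizing with per-model mean/std keeps an overtrained,
    overconfident model from winning on raw perplexity alone.
    """
    def normalized(name):
        score_fn, mean, std = models[name]
        return (score_fn(prompt) - mean) / std

    return min(models, key=normalized)
```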
9 changes: 6 additions & 3 deletions questions/constants.py
@@ -1,4 +1,7 @@
import os
weights_path_tgz = os.getenv("WEIGHTS_PATH_TGZ", "models/SmolLM-1.7B")
weights_path_tgc = os.getenv("WEIGHTS_PATH_TGC", "models/SmolLM-1.7B")
weights_path_tg = os.getenv("WEIGHTS_PATH", "models/SmolLM-1.7B")
# Default to the multimodal Gemma model which can handle both text generation
# and image description. The environment variables allow overriding the model
# path, but when unset a small Gemma checkpoint from HuggingFace is used.
weights_path_tgz = os.getenv("WEIGHTS_PATH_TGZ", "google/gemma-3n-e4b-it")
weights_path_tgc = os.getenv("WEIGHTS_PATH_TGC", "google/gemma-3n-e4b-it")
weights_path_tg = os.getenv("WEIGHTS_PATH", "google/gemma-3n-e4b-it")
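
If you prefer to keep serving the previous local SmolLM weights instead of the new Gemma default, the same environment variables can still point at the old checkpoint directory (a hedged example, assuming the checkpoint from the previous defaults is still present locally):

```shell
export WEIGHTS_PATH=models/SmolLM-1.7B
export WEIGHTS_PATH_TGZ=models/SmolLM-1.7B
export WEIGHTS_PATH_TGC=models/SmolLM-1.7B
```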
1 change: 1 addition & 0 deletions questions/inference_server/inference_server.py
@@ -1343,3 +1343,4 @@ def tts_demo(request: Request):
# return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")

if __name__ == "__main__":
pass
63 changes: 36 additions & 27 deletions questions/link_enricher.py
@@ -15,8 +15,11 @@
setup_logging()
logger = logging.getLogger(__name__)
from requests_futures.sessions import FuturesSession
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import torch
except ModuleNotFoundError: # pragma: no cover - optional dependency
torch = None
from transformers import pipeline
from questions.inference_server.model_cache import ModelCache

# change into OFA dir
@@ -97,42 +100,48 @@ def get_title_from_html(html, long_description=False):

LINK_MODEL_CACHE = ModelCache()

def load_moondream_model():
"""Load the Moondream model for image captioning"""
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"

model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
revision=revision
def load_gemma_pipe():
"""Load the Gemma model for image captioning and text generation."""
model_id = os.getenv("GEMMA_MODEL_ID", "google/gemma-3n-e4b-it")
device_env = os.getenv("GEMMA_DEVICE")
if device_env is not None:
device = int(device_env)
else:
device = 0 if torch and torch.cuda.is_available() else -1
return pipeline(
"image-text-to-text",
model=model_id,
device=device,
torch_dtype=torch.bfloat16 if torch and torch.cuda.is_available() else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision
)
return model, tokenizer

def get_caption_for_image_response(response, prompt="Describe this image."):
"""Get image caption using Moondream model"""
"""Generate a caption for an image using Gemma."""
response.raw.decode_content = True
image_bytes = response.content

img = Image.open(BytesIO(image_bytes))

img = Image.open(BytesIO(response.content))

with log_time("image captioning"):
model, tokenizer = LINK_MODEL_CACHE.add_or_get("moondream_model", load_moondream_model)
enc_image = model.encode_image(img)
caption = model.answer_question(enc_image, prompt, tokenizer)

pipe = LINK_MODEL_CACHE.add_or_get("gemma_pipe", load_gemma_pipe)
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": img},
{"type": "text", "text": prompt},
],
},
]
output = pipe(text=messages, max_new_tokens=100)
caption = output[0]["generated_text"][-1]["content"]

if debug:
logger.info(f"Image description: {caption}")

if any(ocr_tag in caption.lower() for ocr_tag in ocr_tags):
with log_time("OCR"):
ocr_data = ocr_tess(img)
caption += " " + ocr_data

return caption


39 changes: 39 additions & 0 deletions questions/test_gemma_multimodal.py
@@ -0,0 +1,39 @@
import os
import importlib.util
import pytest
from PIL import Image
from transformers import pipeline
from questions.logging_config import setup_logging

# Skip the tests if torch is not available as the transformers pipelines
# require it for model execution.
if importlib.util.find_spec("torch") is None:
pytest.skip("torch is required for Gemma tests", allow_module_level=True)

setup_logging()

def test_gemma_image_captioning():
model_id = "yujiepan/gemma-3n-tiny-random"
pipe = pipeline(
"image-text-to-text",
model=model_id,
device=-1,
)

image_path = "static/img/me.jpg"
assert os.path.exists(image_path)
img = Image.open(image_path)

messages = [
{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": "Describe this image."}]},
]
output = pipe(text=messages, max_new_tokens=5)
assert isinstance(output[0]["generated_text"][-1]["content"], str)


def test_gemma_text_generation():
model_id = "yujiepan/gemma-3n-tiny-random"
text_pipe = pipeline("text-generation", model=model_id, device=-1)
out = text_pipe("Hello", max_new_tokens=5)
assert isinstance(out[0]["generated_text"], str)

47 changes: 0 additions & 47 deletions questions/test_inf_moon.py

This file was deleted.

8 changes: 4 additions & 4 deletions questions/text_generator_inference.py
@@ -97,10 +97,10 @@ def load_model(weights_path):
if (Path("/" + weights_path) / "config.json").exists():
weights_path = str(Path("/") / weights_path)

# if (Path("/models")).exists(): # prefer to save in ramdisk
# weights_path = "/" + weights_path

if not (Path(weights_path) / "config.json").exists():
# Only attempt to download weights if a local directory was specified and no
# config file is present. When using a HuggingFace model id the directory
# will not exist and downloading is handled by `from_pretrained`.
if os.path.isdir(weights_path) and not (Path(weights_path) / "config.json").exists():
download_model(weights_path, weights_path)

logger.info(f"loading model from {weights_path}")
59 changes: 22 additions & 37 deletions scripts/example_moodream.py
@@ -1,14 +1,6 @@
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

def ensure_model_downloaded():
"""Ensure model is downloaded to models directory"""
model_path = "models/moondream2"
if not os.path.exists(model_path):
print(f"Model not found in {model_path}, downloading...")
os.makedirs(model_path, exist_ok=True)
return model_path
from transformers import pipeline
import torch  # used below for device and dtype selection

def load_image():
"""Load the local chrome icon image"""
@@ -18,45 +10,38 @@ def load_image():
return Image.open(image_path)

def main():
# Initialize model and tokenizer
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model_path = ensure_model_downloaded()

print("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
revision=revision,
cache_dir=model_path,
force_download=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
cache_dir=model_path,
force_download=True
"""Run a simple caption generation using the Gemma model."""
model_id = "google/gemma-3n-e4b-it"
pipe = pipeline(
"image-text-to-text",
model=model_id,
device=0 if torch.cuda.is_available() else -1,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
)

# Load and process image
print("Loading local image...")
image = load_image()

print("Encoding image...")
enc_image = model.encode_image(image)

# Ask questions about the image
questions = [
"Describe this image.",
"What colors are prominent in this image?",
"Is this an icon or logo? If so, describe its design."
"Is this an icon or logo? If so, describe its design.",
]

print("\nAsking questions about the image:")
for question in questions:
print(f"\nQ: {question}")
answer = model.answer_question(enc_image, question, tokenizer)
print(f"A: {answer}")
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": question},
],
},
]
answer = pipe(text=messages, max_new_tokens=100)[0]["generated_text"][-1]["content"]
print(f"\nQ: {question}\nA: {answer}")

if __name__ == "__main__":
main()
main()
35 changes: 35 additions & 0 deletions scripts/run_gemma.py
@@ -0,0 +1,35 @@
import argparse
from PIL import Image
from transformers import pipeline
import os


def main():
parser = argparse.ArgumentParser(description="Generate captions using Gemma")
parser.add_argument("image", help="Path to image")
parser.add_argument("prompt", nargs="?", default="Describe this image.", help="Prompt for the model")
args = parser.parse_args()

model_id = os.getenv("GEMMA_MODEL_ID", "google/gemma-3n-e4b-it")
device_env = os.getenv("GEMMA_DEVICE")
if device_env is not None:
device = int(device_env)
else:
try:
import torch
device = 0 if torch.cuda.is_available() else -1
except ModuleNotFoundError:
device = -1

pipe = pipeline("image-text-to-text", model=model_id, device=device)

img = Image.open(args.image)
messages = [
{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": args.prompt}]},
]
output = pipe(text=messages, max_new_tokens=100)
print(output[0]["generated_text"][-1]["content"])


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions tests/unit/questions/test_link_enricher.py
@@ -1,6 +1,10 @@
import logging
import importlib.util
import pytest

if importlib.util.find_spec("transformers") is None:
pytest.skip("transformers not installed", allow_module_level=True)

bs4 = pytest.importorskip("bs4", reason="bs4 required for link enrichment tests")
from questions.link_enricher import get_urls, enrich_links
from questions.logging_config import setup_logging
5 changes: 5 additions & 0 deletions tests/unit/test_audio_model.py
@@ -1,5 +1,10 @@
import builtins
from unittest import mock
import importlib.util
import pytest

if importlib.util.find_spec("torch") is None:
pytest.skip("torch not installed", allow_module_level=True)

import questions.inference_server.inference_server as server

4 changes: 4 additions & 0 deletions tests/unit/test_document_processor.py
@@ -1,5 +1,9 @@
import requests
import pytest
import importlib.util

if importlib.util.find_spec("docx") is None:
pytest.skip("python-docx not installed", allow_module_level=True)

from questions.document_processor import convert_to_markdown, convert_documents_async
