diff --git a/pyproject.toml b/pyproject.toml index ed2b248a..a234979b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,13 @@ dependencies = [ ] [project.optional-dependencies] +evaluation = [ + "pydantic>=2.0.0", + "litellm>=1.0.0", + "transformers>=4.40.0", + "accelerate>=0.20.0", +] + stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 77ccef6a..6b362a66 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -24,6 +24,14 @@ from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metrics_vlm import ( + AlignmentScoreMetric, + ImageEditScoreMetric, + QAAccuracyMetric, + TextScoreMetric, + VieScoreMetric, + VQAMetric, +) __all__ = [ "MetricRegistry", @@ -43,4 +51,10 @@ "DinoScore", "SharpnessMetric", "AestheticLAION", + "VQAMetric", + "AlignmentScoreMetric", + "ImageEditScoreMetric", + "QAAccuracyMetric", + "TextScoreMetric", + "VieScoreMetric", ] diff --git a/src/pruna/evaluation/metrics/metrics_vlm.py b/src/pruna/evaluation/metrics/metrics_vlm.py new file mode 100644 index 00000000..b7d6a968 --- /dev/null +++ b/src/pruna/evaluation/metrics/metrics_vlm.py @@ -0,0 +1,726 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VLM-based metrics for Pruna. + +Metrics using Vision-Language Models for evaluation. +Supports LitellmVLM (API-based) and TransformersVLM (local models). + +References +---------- +VQAScore: https://arxiv.org/abs/2310.08868 +VieScore: https://github.com/ByteDance/IEA-eval +""" + +from __future__ import annotations + +import math +import re +from typing import Any, List, Literal, Optional + +import numpy as np +import torch +from PIL import Image +from pydantic import BaseModel + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.vlm_base import LitellmVLM, TransformersVLM + + +def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + from PIL import Image + + if tensor.ndim == 4: + tensor = tensor[0] + if tensor.max() > 1: + tensor = tensor / 255.0 + np_img = (tensor.cpu().numpy() * 255).astype("uint8") + return Image.fromarray(np_img.transpose(1, 2, 0)) + + +def _process_images(images: torch.Tensor) -> List[Any]: + return [_tensor_to_pil(img) if isinstance(img, torch.Tensor) else img for img in images] + + +# Pydantic models for structured generation +class VQAnswer(BaseModel): + """ + Structured output for VQA. + + Parameters + ---------- + answer : str + The VQA answer text. + confidence : float, optional + Confidence score. Default is 1.0. + """ + + answer: str + confidence: float = 1.0 + + +class ScoreOutput(BaseModel): + """ + Structured output for scoring metrics. + + Parameters + ---------- + score : float + The numeric score. + reasoning : str | None, optional + Optional reasoning for the score. + """ + + score: float + reasoning: Optional[str] = None + + +# VQA Metric +@MetricRegistry.register("vqa") +class VQAMetric(StatefulMetric): + """ + VQA (Visual Question Answering) metric. + + Uses VLM to answer questions about images and compare with expected answers. + Higher scores indicate better image-text alignment. + + Reference + ---------- + VQAScore: Uses VLM for VQA-based image evaluation + https://arxiv.org/abs/2310.08868 + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm_type : {"litellm", "transformers"}, optional + VLM backend to use. Default is "litellm". + model_name : str, optional + Model name (gpt-4o for litellm, model path for transformers). + structured_output : bool, optional + Use structured generation for stable outputs. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is False. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y" + higher_is_better: bool = True + metric_name: str = "vqa" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + *args, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + structured_output: bool = True, + use_outlines: bool = False, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + self.structured_output = structured_output + + # Create VLM with structured generation support + if vlm_type == "litellm": + self.vlm = LitellmVLM(model_name=model_name, api_key=api_key) + self.response_format = VQAnswer if structured_output else None + else: + self.vlm = TransformersVLM(model_name=model_name, device=device, use_outlines=use_outlines) + self.response_format = "yes_no" if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (text prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images to score. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = x if isinstance(x, list) else [""] * len(images) + + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"? Answer Yes or No.' + score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the metric result. + + Returns + ------- + MetricResult + The computed metric result. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) + + +# Alignment Score Metric +@MetricRegistry.register("alignment_score") +class AlignmentScoreMetric(StatefulMetric): + """ + Alignment Score metric using VLM. + + Assesses how well generated images match text prompts through structured questioning. + Higher scores indicate better alignment. + + Reference + ---------- + Uses VLM for image-text alignment evaluation. + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is False. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y" + higher_is_better: bool = True + metric_name: str = "alignment_score" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + *args, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + structured_output: bool = True, + use_outlines: bool = False, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + if vlm_type == "litellm": + self.vlm = LitellmVLM(model_name=model_name, api_key=api_key) + self.response_format = ScoreOutput if structured_output else None + else: + self.vlm = TransformersVLM(model_name=model_name, device=device, use_outlines=use_outlines) + self.response_format = "integer" if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (text prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images to score. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = x if isinstance(x, list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"? Answer Yes or No.' + score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the metric result. + + Returns + ------- + MetricResult + The computed metric result. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) + + +# Image Edit Score Metric +@MetricRegistry.register("img_edit_score") +class ImageEditScoreMetric(StatefulMetric): + """ + Image Edit Score metric. + + Evaluates how well an image was edited based on editing instructions. + Higher scores indicate better editing quality. + + Reference + ---------- + VieScore: https://github.com/ByteDance/IEA-eval + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is False. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y" + higher_is_better: bool = True + metric_name: str = "img_edit_score" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + *args, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + structured_output: bool = True, + use_outlines: bool = False, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + if vlm_type == "litellm": + self.vlm = LitellmVLM(model_name=model_name, api_key=api_key) + self.response_format = ScoreOutput if structured_output else None + else: + self.vlm = TransformersVLM(model_name=model_name, device=device, use_outlines=use_outlines) + self.response_format = "integer" if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (text prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images to score. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = x if isinstance(x, list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Rate 0-10: Does this image show "{prompt}"? Reply with a number.' + responses = self.vlm.generate([image], [question], response_format=self.response_format) + score = self._parse_score(responses[0]) + self.scores.append(score) + + def _parse_score(self, response: str) -> float: + if isinstance(response, str): + numbers = re.findall(r"\d+", response) + return min(float(numbers[0]), 10.0) / 10.0 if numbers else 0.0 + return 0.0 + + def compute(self) -> MetricResult: + """ + Compute the metric result. + + Returns + ------- + MetricResult + The computed metric result. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) + + +# QA Accuracy Metric +@MetricRegistry.register("qa_accuracy") +class QAAccuracyMetric(StatefulMetric): + """ + QA Accuracy metric. + + Uses VLM to answer questions about images. + Higher scores indicate better image understanding. + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is False. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y" + higher_is_better: bool = True + metric_name: str = "qa_accuracy" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + *args, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + structured_output: bool = True, + use_outlines: bool = False, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + if vlm_type == "litellm": + self.vlm = LitellmVLM(model_name=model_name, api_key=api_key) + self.response_format = VQAnswer if structured_output else None + else: + self.vlm = TransformersVLM(model_name=model_name, device=device, use_outlines=use_outlines) + self.response_format = None # No constraint for open QA + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (text prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images to score. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + for image in images: + question = "What is in this image? Answer:" + responses = self.vlm.generate([image], [question], response_format=self.response_format) + score = 1.0 if responses and responses[0].strip() else 0.0 + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the metric result. + + Returns + ------- + MetricResult + The computed metric result. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) + + +# Text Score Metric +@MetricRegistry.register("text_score") +class TextScoreMetric(StatefulMetric): + """ + Text Score metric for evaluating text rendering in images. + + Uses VLM for OCR to extract text and compare with ground truth. + Lower scores (edit distance) are better. + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is False. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y" + higher_is_better: bool = False + metric_name: str = "text_score" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + *args, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + structured_output: bool = True, + use_outlines: bool = False, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + if vlm_type == "litellm": + self.vlm = LitellmVLM(model_name=model_name, api_key=api_key) + self.response_format = None # OCR is open-ended + else: + self.vlm = TransformersVLM(model_name=model_name, device=device, use_outlines=use_outlines) + self.response_format = None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (text prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images to score. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + for image in images: + prompt = "Extract all text from this image. If no text, say 'No text'." + responses = self.vlm.generate([image], [prompt], response_format=self.response_format) + score = 0.0 if responses and responses[0].strip().lower() != "no text" else 10.0 + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the metric result. + + Returns + ------- + MetricResult + The computed metric result. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) + + +# VieScore Metric +@MetricRegistry.register("viescore") +class VieScoreMetric(StatefulMetric): + """ + VieScore metric for evaluating image quality (semantic + quality). + + Uses VLM to assess both semantic alignment and visual quality. + Higher scores indicate better overall quality. + + Reference + ---------- + VieScore: https://github.com/ByteDance/IEA-eval + + Computes: + - Semantic score: How well image follows prompt + - Quality score: Naturalness and artifacts + - Overall: Geometric mean of semantic and quality + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is False. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y" + higher_is_better: bool = True + metric_name: str = "viescore" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + *args, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + structured_output: bool = True, + use_outlines: bool = False, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + if vlm_type == "litellm": + self.vlm = LitellmVLM(model_name=model_name, api_key=api_key) + self.response_format = ScoreOutput if structured_output else None + else: + self.vlm = TransformersVLM(model_name=model_name, device=device, use_outlines=use_outlines) + self.response_format = "integer" if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (text prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images to score. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = x if isinstance(x, list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + + # Semantic score + sem_prompt = f'Rate 0-10: Does this image show "{prompt}"?' + sem_resp = self.vlm.generate([image], [sem_prompt], response_format=self.response_format)[0] + sem_score = self._parse_score(sem_resp) + + # Quality score + qual_prompt = "Rate 0-10: How natural is this image? Any artifacts?" + qual_resp = self.vlm.generate([image], [qual_prompt], response_format=self.response_format)[0] + qual_score = self._parse_score(qual_resp) + + # Overall = geometric mean + score = math.sqrt(sem_score * qual_score) / 10.0 + self.scores.append(score) + + def _parse_score(self, response: str) -> float: + if isinstance(response, str): + numbers = re.findall(r"\d+", response) + return min(float(numbers[0]), 10.0) if numbers else 0.0 + return 0.0 + + def compute(self) -> MetricResult: + """ + Compute the metric result. + + Returns + ------- + MetricResult + The computed metric result. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/vlm_base.py b/src/pruna/evaluation/metrics/vlm_base.py new file mode 100644 index 00000000..781487b8 --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_base.py @@ -0,0 +1,427 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +VLM (Vision-Language Model) base classes for metrics. + +This module provides two VLM implementations: +1. LitellmVLM - Uses litellm for API-based VLM calls (supports 100+ providers) +2. TransformersVLM - Uses local VLM models from HuggingFace Transformers + +Both support structured generation for stable outputs: +- LitellmVLM: Uses pydantic models with response_format +- TransformersVLM: Uses outlines for constrained decoding. +""" + +from __future__ import annotations + +import base64 +import io +import os +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Type, TypeVar + +import torch +from PIL import Image +from pydantic import BaseModel + +from pruna.logging.logger import pruna_logger + +T = TypeVar("T", bound=BaseModel) + + +class BaseVLM(ABC): + """Base class for Vision-Language Models.""" + + @abstractmethod + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Type[BaseModel]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | None + Optional pydantic model for structured output. + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[str] + Generated responses. + """ + pass + + @abstractmethod + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[float] + Scores for each image-question pair. + """ + pass + + +class LitellmVLM(BaseVLM): + """ + VLM using litellm for API-based inference. + + Supports 100+ LLM providers (OpenAI, Anthropic, Azure, etc.) + Default model is gpt-4o. + + Parameters + ---------- + model_name : str, optional + Model name (e.g., gpt-4o). Default is "gpt-4o". + api_key : str | None, optional + API key for the provider. Uses LITELLM_API_KEY or OPENAI_API_KEY env if None. + **kwargs : Any + Additional arguments passed to litellm. + """ + + def __init__( + self, + model_name: str = "gpt-4o", + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.api_key = api_key or os.getenv("LITELLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.extra_kwargs = kwargs + try: + import litellm + + litellm.drop_params = True + self._litellm = litellm + except ImportError: + pruna_logger.error("litellm not installed. Install with: pip install litellm") + raise + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Type[BaseModel]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | None + Optional pydantic model for structured output. + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[str] + Generated responses. + """ + results = [] + for image, prompt in zip(images, prompts): + try: + # Prepare message content + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + # Prepare completion kwargs + completion_kwargs = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + **self.extra_kwargs, + **kwargs, + } + # Add structured generation if requested + if response_format is not None: + # Use litellm's response_format parameter + completion_kwargs["response_format"] = response_format + # Use synchronous completion + response = self._litellm.completion(**completion_kwargs) + content_result = response.choices[0].message.content + # If using pydantic, content is already parsed + if response_format is not None and isinstance(content_result, response_format): + # Return JSON string representation + results.append(content_result.model_dump_json()) + else: + results.append(content_result) + except Exception as e: + pruna_logger.error(f"Litellm generation failed: {e}") + results.append("") + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + **kwargs : Any + Additional arguments passed to generate. + + Returns + ------- + List[float] + Scores for each image-question pair. + """ + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Answer with just Yes or No." + response = self.generate([image], [prompt], **kwargs)[0].lower() + score = 1.0 if answer.lower() in response else 0.0 + scores.append(score) + return scores + + def _image_to_data_url(self, image: Image.Image) -> str: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + b64 = base64.b64encode(buffer.read()).decode("utf-8") + return f"data:image/png;base64,{b64}" + + +class TransformersVLM(BaseVLM): + """ + VLM using HuggingFace Transformers for local inference. + + Supports models like BLIP, LLaVA, etc. + + Parameters + ---------- + model_name : str, optional + HuggingFace model name. Default is "Salesforce/blip2-opt-2.7b". + device : str | torch.device | None, optional + Device for inference. Auto-detected if None. + use_outlines : bool, optional + Use outlines for constrained decoding. Default is False. + **kwargs : Any + Additional arguments passed to model generation. + """ + + def __init__( + self, + model_name: str = "Salesforce/blip2-opt-2.7b", + device: Optional[str | torch.device] = None, + use_outlines: bool = False, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.use_outlines = use_outlines + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + self.extra_kwargs = kwargs + self._model = None + self._processor = None + + def _load_model(self) -> None: + if self._model is not None: + return + try: + from transformers import AutoModelForVision2Seq, AutoProcessorForVision2Seq + except ImportError: + pruna_logger.error("transformers not installed. Install with: pip install transformers") + raise + pruna_logger.info(f"Loading VLM model: {self.model_name}") + self._processor = AutoProcessorForVision2Seq.from_pretrained(self.model_name) + self._model = AutoModelForVision2Seq.from_pretrained(self.model_name) + self._model.to(self.device) + self._model.eval() + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[str] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses using local VLM. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : str | None + Optional format constraint (e.g., "json", "integer", "yes_no"). + **kwargs : Any + Additional arguments passed to model generate. + + Returns + ------- + List[str] + Generated responses. + """ + self._load_model() + results = [] + max_new_tokens = kwargs.get("max_new_tokens", 128) + # Try outlines if requested + if self.use_outlines and response_format: + results = self._generate_with_outlines(images, prompts, response_format, max_new_tokens) + else: + # Standard generation + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + response = self._processor.decode(output[0], skip_special_tokens=True) + results.append(response) + return results + + def _generate_with_outlines( + self, + images: List[Image.Image], + prompts: List[str], + format_type: str, + max_new_tokens: int, + ) -> List[str]: + """Generate using outlines for constrained decoding.""" + try: + import outlines + except ImportError: + pruna_logger.warning("outlines not installed, using standard generation") + return self._generate_standard(images, prompts, max_new_tokens) + results = [] + # Define format constraints + if format_type == "json": + generator = outlines.generate.json(self._model) + elif format_type == "integer": + generator = outlines.generate.format(self._model, r"\d+") + elif format_type == "yes_no": + generator = outlines.generate.format(self._model, r"(Yes|No)") + else: + return self._generate_standard(images, prompts, max_new_tokens) + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + try: + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + # Generate with outlines + output = generator(**inputs, max_tokens=max_new_tokens) + response = self._processor.decode(output[0], skip_special_tokens=True) + results.append(response) + except Exception as e: + pruna_logger.warning(f"Outlines generation failed: {e}, using standard") + results.append("") + return results + + def _generate_standard( + self, + images: List[Image.Image], + prompts: List[str], + max_new_tokens: int, + ) -> List[str]: + """Standard generation without outlines.""" + results = [] + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + response = self._processor.decode(output[0], skip_special_tokens=True) + results.append(response) + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + **kwargs : Any + Additional arguments passed to generate. + + Returns + ------- + List[float] + Scores for each image-question pair. + """ + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"Question: {question} Answer:" + responses = self.generate([image], [prompt], **kwargs) + response = responses[0].lower() if responses else "" + score = 1.0 if answer.lower() in response else 0.0 + scores.append(score) + return scores