Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 38 additions & 50 deletions vlmrun/client/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def execute(

def generate(
self,
domain: str,
domain: Optional[str] = None,
images: Optional[List[Union[Path, Image.Image]]] = None,
urls: Optional[List[str]] = None,
model: str = "vlm-1",
Expand All @@ -271,7 +271,7 @@ def generate(
"""Generate a document prediction.

Args:
domain: Domain to use for prediction
domain: Domain to use for prediction. Optional when skills are provided via config.
images: List of file paths (Path) or PIL Image objects to process. Either images or urls must be provided.
urls: List of HTTP URLs pointing to images. Either images or urls must be provided.
model: Model to use for prediction
Expand All @@ -287,59 +287,37 @@ def generate(
Raises:
ValueError: If neither images nor urls are provided, or if both are provided
"""
# Input validation
if not images and not urls:
raise ValueError("Either `images` or `urls` must be provided")
if images and urls:
raise ValueError("Only one of `images` or `urls` can be provided")

if images:
# Check if all images are of the same type
image_type = type(images[0])
if not all(isinstance(image, image_type) for image in images):
raise ValueError("All images must be of the same type")
if isinstance(images[0], Path):
images = [_open_image_with_exif(str(image)) for image in images]
elif isinstance(images[0], Image.Image):
pass
else:
raise ValueError("Image must be a path or a PIL Image")
images_data = [encode_image(image, format="JPEG") for image in images]
else:
# URL handling
if not urls:
raise ValueError("URLs list cannot be empty")
if not isinstance(urls[0], str):
raise ValueError("URLs must be strings")
if not all(isinstance(url, str) for url in urls):
raise ValueError("All URLs must be strings")
if not all(url.startswith("http") for url in urls):
raise ValueError("URLs must start with 'http'")
images_data = urls
has_skills = (
config is not None and config.skills is not None and len(config.skills) > 0
)
if not domain and not has_skills:
raise ValueError("Either `domain` or `config.skills` must be provided")

images_data = self._handle_images_or_urls(images, urls)
additional_kwargs = {}
if config:
additional_kwargs["config"] = config.model_dump()
if metadata:
additional_kwargs["metadata"] = metadata.model_dump()
data = {
"model": model,
"images": images_data,
"batch": batch,
"callback_url": callback_url,
**additional_kwargs,
}
if domain is not None:
data["domain"] = domain
response, status_code, headers = self._requestor.request(
method="POST",
url="image/generate",
data={
"model": model,
"images": images_data,
"domain": domain,
"batch": batch,
"callback_url": callback_url,
**additional_kwargs,
},
data=data,
)
if not isinstance(response, dict):
raise TypeError("Expected dict response")
prediction = PredictionResponse(**response)

if autocast:
if autocast and domain:
self._cast_response_to_schema(prediction, domain, config)
return prediction

Expand Down Expand Up @@ -435,7 +413,7 @@ def generate(
model: Model to use for prediction
file: File (pathlib.Path) or file_id to generate prediction from
url: URL to generate prediction from
domain: Domain to use for prediction
domain: Domain to use for prediction. Optional when skills are provided via config.
batch: Whether to run prediction in batch mode
config: GenerateConfig to use for prediction
metadata: Metadata to include in prediction
Expand All @@ -445,24 +423,34 @@ def generate(
Returns:
PredictionResponse: Prediction response
"""
has_skills = (
config is not None
and config.skills is not None
and len(config.skills) > 0
)
if not domain and not has_skills:
raise ValueError("Either `domain` or `config.skills` must be provided")

is_url, file_or_url = self._handle_file_or_url(file, url)

additional_kwargs = {}
if config:
additional_kwargs["config"] = config.model_dump()
if metadata:
additional_kwargs["metadata"] = metadata.model_dump()
data = {
"model": model,
"url" if is_url else "file_id": file_or_url,
"batch": batch,
"callback_url": callback_url,
**additional_kwargs,
}
if domain is not None:
data["domain"] = domain
response, status_code, headers = self._requestor.request(
method="POST",
url=f"{route}/generate",
data={
"model": model,
"url" if is_url else "file_id": file_or_url,
"domain": domain,
"batch": batch,
"callback_url": callback_url,
**additional_kwargs,
},
data=data,
)
if not isinstance(response, dict):
raise TypeError("Expected dict response")
Expand All @@ -479,7 +467,7 @@ def generate(
except Exception as e:
logger.warning(f"Failed to cast response to MarkdownDocument: {e}")
# Handle other domains with autocast
elif autocast:
elif autocast and domain:
self._cast_response_to_schema(prediction, domain, config)
return prediction

Expand Down
37 changes: 37 additions & 0 deletions vlmrun/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ class AgentExecutionOrCreationConfig(BaseModel):
json_schema: Optional[Dict[str, Any]] = Field(
default=None, description="The JSON schema response model of the agent"
)
skills: Optional[List[AgentSkill]] = Field(
default=None,
description="List of agent skills to enable for this execution.",
)

@model_validator(mode="after")
def validate_config(self):
Expand Down Expand Up @@ -308,11 +312,44 @@ class AgentCreationResponse(BaseModel):
status: JobStatus = Field(..., description="The status of the agent")


class AgentSkill(BaseModel):
    """A unit of domain-specific expertise that can be enabled on a run.

    A skill is identified either by ``skill_id`` or by ``skill_name``; at
    least one of the two must be supplied (enforced after model init).
    ``version`` only matters for name-based lookup and pins a specific
    release of the skill, defaulting to ``"latest"``.
    """

    # Provider/kind discriminator; the built-in skill type is 'vlm-run'.
    type: str = Field(
        default="vlm-run",
        description="The type of the skill (e.g., 'vlm-run').",
    )
    # Exactly one of skill_id / skill_name is required (either may be given;
    # both empty is rejected by the validator below).
    skill_id: Optional[str] = Field(
        default=None,
        description="The unique identifier of the skill (UUID or name string).",
    )
    skill_name: Optional[str] = Field(
        default=None,
        description="Human-readable skill name for lookup. Alternative to skill_id.",
    )
    version: str = Field(
        default="latest",
        description="The version of the skill (e.g., 'latest', '20260219-abc123').",
    )

    @model_validator(mode="after")
    def _check_skill_identifier(self):
        # Note: falsy values (None or "") count as "not provided" for both
        # identifier fields — matching truthiness, not just None-ness.
        if self.skill_id or self.skill_name:
            return self
        raise ValueError("Either 'skill_id' or 'skill_name' must be provided")


class GenerationConfig(BaseModel):
prompt: Optional[str] = Field(default=None)
response_model: Optional[Type[BaseModel]] = Field(default=None)
json_schema: Optional[Dict[str, Any]] = Field(default=None)
gql_stmt: Optional[str] = Field(default=None)
skills: Optional[List[AgentSkill]] = Field(default=None)
max_retries: int = Field(default=3)
max_tokens: int = Field(default=65535)
temperature: float = Field(default=0.0)
Expand Down
Loading