diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index eb06565..b7a8c32 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -8,3 +8,4 @@
 # Plugins
 /plugins/data-designer-template/ @NVIDIA-NeMo/data_designer_reviewers
+/plugins/data-designer-visual-search/ eric.tramel@gmail.com
 
diff --git a/docs/plugins/data-designer-visual-search/examples.md b/docs/plugins/data-designer-visual-search/examples.md
new file mode 100644
index 0000000..20227c3
--- /dev/null
+++ b/docs/plugins/data-designer-visual-search/examples.md
@@ -0,0 +1,86 @@
+# Practical Examples
+
+## Branch From an Earlier Crop
+
+The image workspace is tree-shaped. A model can create one crop, inspect it,
+then operate on the original image again:
+
+1. `open_image()` returns `img_0000`.
+2. `crop_image(image_id="img_0000", x=0, y=0, width=50, height=50, unit="percent")`
+   returns `img_0001`.
+3. `edit_color(image_id="img_0001", contrast=1.5)` returns `img_0002`.
+4. `crop_image(image_id="img_0000", x=50, y=50, width=50, height=50, unit="percent")`
+   returns `img_0003`.
+
+The resulting history preserves both branches:
+
+```text
+img_0000 open_image
+|-- img_0001 crop_image
+|   `-- img_0002 edit_color
+`-- img_0003 crop_image
+```
+
+This is useful when the model needs to compare multiple areas or recover from a
+crop that turned out to be unhelpful.
+
+## Read Small Text
+
+```python
+builder.add_column(
+    name="label_text",
+    column_type="visual-search",
+    image_column="product_photo",
+    prompt=(
+        "Find the ingredients label. Crop tightly around it, increase contrast "
+        "if needed, and return the text you can read."
+    ),
+    model_alias="vision",
+    max_tool_call_turns=5,
+)
+```
+
+Expected model behavior:
+
+- Inspect the original image.
+- Crop the label region.
+- Optionally increase contrast or convert to grayscale.
+- Answer using the attached edited crop.
+
+## Compare Two Regions
+
+```python
+builder.add_column(
+    name="comparison",
+    column_type="visual-search",
+    image_column="shelf_image",
+    prompt=(
+        "Compare the price tags on the left and right sides of the shelf. "
+        "Use separate crops and report which price is lower."
+    ),
+    model_alias="vision",
+    max_tool_call_turns=6,
+)
+```
+
+The model can crop the left tag from `img_0000`, crop the right tag from
+`img_0000`, inspect both resulting IDs, and answer from the evidence.
+
+## Data URI Input
+
+The `image_column` can contain base64 data or a full data URI instead of a file
+path:
+
+```python
+builder.add_column(
+    name="base64_answer",
+    column_type="visual-search",
+    image_column="image_data_uri",
+    prompt="Crop the center of the image and describe what is visible.",
+    model_alias="vision",
+)
+```
+
+If values are raw base64 and the format cannot be detected reliably, set
+`image_data_type="base64"` and `image_format="png"` or another supported image
+format.
diff --git a/docs/plugins/data-designer-visual-search/index.md b/docs/plugins/data-designer-visual-search/index.md
new file mode 100644
index 0000000..b14c019
--- /dev/null
+++ b/docs/plugins/data-designer-visual-search/index.md
@@ -0,0 +1,94 @@
+# data-designer-visual-search
+
+`data-designer-visual-search` adds a `visual-search` column type for
+image-grounded visual search workflows. It is intended for cases where a VLM
+needs to inspect an image, crop into regions, transform the view, adjust color,
+and then continue reasoning over the resulting image.
+
+The plugin owns the extra plumbing that ordinary model tool calling does not
+handle: each local image operation returns an `image_id`, the new image is held
+in memory, and the generated image is attached back into the next model turn as
+multimodal context.
+
+## What It Provides
+
+- A `VisualSearchColumnConfig` registered as column type `visual-search`.
+- A row-scoped in-memory image workspace.
+- Local tools for opening images, listing image IDs, inspecting image metadata,
+  cropping, transforming, and editing color.
+- Tree-shaped image history, so the model can branch from any previous
+  `image_id` instead of following a single linear edit chain.
+- A default side-effect column named `{column_name}__image_history` that records
+  image IDs, parent IDs, child IDs, operations, dimensions, and operation
+  metadata.
+- Optional model trace and reasoning-content side-effect columns that match the
+  conventions used by Data Designer LLM columns.
+
+## Column Interface
+
+| Field | Required | Description |
+| --- | --- | --- |
+| `name` | Yes | Output column name. |
+| `column_type` | Yes | Must be `visual-search`. |
+| `image_column` | Yes | Existing column containing a local image path, URL, base64 image, or image data URI. |
+| `prompt` | Yes | Jinja2 prompt template for the visual search task. |
+| `model_alias` | Yes | Alias of a vision-capable chat model in the Data Designer config. |
+| `system_prompt` | No | Optional Jinja2 system prompt appended to the built-in visual search instructions. |
+| `image_data_type` | No | Optional explicit image data type, such as `url` or `base64`. Leave unset for auto-detection. |
+| `image_format` | Conditional | Required when `image_data_type` is explicitly `base64`. |
+| `image_placeholder` | No | Optional text token to include next to every image attachment for endpoints that require one. |
+| `max_tool_call_turns` | No | Maximum tool-calling turns per row. Defaults to `6`. |
+| `allowed_tools` | No | Optional allowlist of built-in visual tools. Defaults to all tools. |
+| `attach_images_after_tool_calls` | No | Whether to attach tool-created images into the next model turn. Defaults to `True`. |
+| `include_image_history` | No | Whether to write `{name}__image_history`. Defaults to `True`. |
+| `with_trace` | No | Optional trace capture mode. Defaults to `none`. |
+| `extract_reasoning_content` | No | Whether to write `{name}__reasoning_content`. Defaults to `False`. |
+| `use_default_system_prompt` | No | Whether to prepend built-in image-tool instructions. Defaults to `True`. |
+
+## Built-In Tools
+
+| Tool | Purpose |
+| --- | --- |
+| `open_image` | Opens the configured row image and returns the root `image_id`. |
+| `get_image_info` | Returns dimensions, parent ID, children IDs, operation name, and metadata for an `image_id`. |
+| `list_images` | Lists every image currently held in the row workspace. |
+| `crop_image` | Crops an existing image by pixel or percent coordinates and returns a new `image_id`. |
+| `transform_image` | Rotates, flips, or resizes an existing image and returns a new `image_id`. |
+| `edit_color` | Adjusts brightness, contrast, saturation, sharpness, grayscale, or inversion and returns a new `image_id`. |
+
+Tool results are ordinary tool messages containing JSON metadata. When a tool
+creates an image, the plugin also attaches that image to the next user turn so
+the model can inspect it visually.
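+
+For example, a successful `crop_image` call comes back to the model as a tool
+message that wraps the new node's metadata in an `ok`/`tool`/`result` envelope
+(values here are illustrative; the `result` payload follows the image-history
+schema below):
+
+```json
+{
+  "ok": true,
+  "tool": "crop_image",
+  "result": {
+    "image_id": "img_0001",
+    "parent_image_id": "img_0000",
+    "children_image_ids": [],
+    "operation": "crop_image",
+    "width": 512,
+    "height": 384,
+    "metadata": {
+      "box": {"left": 0, "top": 0, "right": 512, "bottom": 384},
+      "unit": "pixels"
+    }
+  }
+}
+```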
+
+## Image History
+
+Every image node has stable metadata:
+
+```json
+{
+  "image_id": "img_0001",
+  "parent_image_id": "img_0000",
+  "children_image_ids": [],
+  "operation": "crop_image",
+  "width": 512,
+  "height": 384,
+  "metadata": {
+    "box": {"left": 0, "top": 0, "right": 512, "bottom": 384},
+    "unit": "pixels"
+  }
+}
+```
+
+Because the model controls the `image_id` argument, it can crop from the root
+image, transform that crop, rewind to the root, and crop a different region.
+The workspace keeps the whole tree for the duration of that row.
+
+## When To Use It
+
+Use `visual-search` when the model needs iterative visual operations before it
+can answer reliably. Good examples include reading small labels, comparing
+regions, checking color after contrast adjustment, or zooming into a specific
+part of a larger image.
+
+For a single prompt over an image with no iterative image manipulation, a
+standard Data Designer LLM column with multimodal context may be simpler.
diff --git a/docs/plugins/data-designer-visual-search/usage.md b/docs/plugins/data-designer-visual-search/usage.md
new file mode 100644
index 0000000..cf4cc12
--- /dev/null
+++ b/docs/plugins/data-designer-visual-search/usage.md
@@ -0,0 +1,125 @@
+# Usage
+
+This example starts with a dataframe column containing image paths and adds a
+`visual-search` column. The model can call image tools while answering the
+prompt, and the plugin will pass each resulting crop or edited image back to the
+model automatically.
+
+```python
+import pandas as pd
+
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
+from data_designer.config.seed_source_dataframe import DataFrameSeedSource
+from data_designer.interface.data_designer import DataDesigner
+
+seed_df = pd.DataFrame(
+    {
+        "image_path": ["/path/to/store-shelf.png"],
+        "target": ["the nutrition label on the cereal box"],
+    }
+)
+
+provider = ModelProvider(
+    name="nvidia",
+    endpoint="https://integrate.api.nvidia.com/v1",
+    api_key="NVIDIA_API_KEY",
+    provider_type="openai",
+)
+
+vision_model = ModelConfig(
+    alias="vision",
+    model="qwen/qwen3.5-122b-a10b",
+    provider="nvidia",
+    inference_parameters=ChatCompletionInferenceParams(
+        temperature=0,
+        max_tokens=512,
+        timeout=60,
+    ),
+)
+
+builder = DataDesignerConfigBuilder(model_configs=[vision_model])
+builder.with_seed_dataset(DataFrameSeedSource(df=seed_df))
+builder.add_column(
+    name="visual_answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt=(
+        "Find {{ target }}. Use crop_image or edit_color if that helps. "
+        "Return the text you can read and explain which image_id you used."
+    ),
+    model_alias="vision",
+    max_tool_call_turns=4,
+)
+
+result = DataDesigner(
+    artifact_path="artifacts",
+    model_providers=[provider],
+).preview(builder, num_records=1)
+```
+
+The generated dataset includes:
+
+- `visual_answer`: the model's final answer.
+- `visual_answer__image_history`: the image operation tree produced while
+  answering the row.
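+
+A minimal sketch of what `visual_answer__image_history` can look like after a
+single crop (values illustrative; each entry follows the node schema from the
+plugin reference):
+
+```json
+[
+  {
+    "image_id": "img_0000",
+    "parent_image_id": null,
+    "children_image_ids": ["img_0001"],
+    "operation": "open_image",
+    "width": 1024,
+    "height": 768,
+    "metadata": {"source": "/path/to/store-shelf.png"}
+  },
+  {
+    "image_id": "img_0001",
+    "parent_image_id": "img_0000",
+    "children_image_ids": [],
+    "operation": "crop_image",
+    "width": 256,
+    "height": 128,
+    "metadata": {"box": {"left": 96, "top": 512, "right": 352, "bottom": 640}, "unit": "pixels"}
+  }
+]
+```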
+
+## Restricting Tools
+
+Use `allowed_tools` when you want the model to perform only a narrower set of
+operations:
+
+```python
+builder.add_column(
+    name="crop_only_answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Crop the upper-right quadrant and describe the dominant color.",
+    model_alias="vision",
+    allowed_tools=["open_image", "get_image_info", "crop_image"],
+    max_tool_call_turns=2,
+)
+```
+
+## Endpoint Image Tokens
+
+Most OpenAI-compatible multimodal endpoints accept image content blocks directly.
+Some model servers also require a model-specific image token in the text for
+each attached image. Set `image_placeholder` for those endpoints:
+
+```python
+builder.add_column(
+    name="answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Inspect the attached image and answer the question.",
+    model_alias="vision",
+    image_placeholder="<image>",  # example token; use the one your endpoint expects
+)
+```
+
+The plugin prepends the placeholder to the initial image turn and to every later
+turn that attaches a tool-created image.
+
+## Capturing Trace Output
+
+The column supports the same trace side-effect pattern as other LLM-backed Data
+Designer columns:
+
+```python
+from data_designer.config.utils.trace_type import TraceType
+
+builder.add_column(
+    name="answer_with_trace",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Zoom into the serial number and read it.",
+    model_alias="vision",
+    with_trace=TraceType.ALL_MESSAGES,
+    extract_reasoning_content=True,
+)
+```
+
+This adds `answer_with_trace__trace` and
+`answer_with_trace__reasoning_content` when the selected model provides
+reasoning content.
diff --git a/docs/plugins/index.md b/docs/plugins/index.md
index 4e54e2e..3e3b34c 100644
--- a/docs/plugins/index.md
+++ b/docs/plugins/index.md
@@ -16,4 +16,15 @@
 Browse available Data Designer plugins by what they add to your data generation
 
     text-transform
+
+
+    data-designer-visual-search
+    v0.1.0
+
+    Visual search column with local image crop, transform, and color-edit tools
+
+    Column types
+    visual-search
+
+
 
diff --git a/plugins/data-designer-visual-search/CODEOWNERS b/plugins/data-designer-visual-search/CODEOWNERS
new file mode 100644
index 0000000..013e51f
--- /dev/null
+++ b/plugins/data-designer-visual-search/CODEOWNERS
@@ -0,0 +1,3 @@
+# Owner(s) of this plugin — used to generate the root CODEOWNERS file.
+# GitHub accepts @username, @org/team, or email format.
+* eric.tramel@gmail.com
diff --git a/plugins/data-designer-visual-search/README.md b/plugins/data-designer-visual-search/README.md
new file mode 100644
index 0000000..20ae4bd
--- /dev/null
+++ b/plugins/data-designer-visual-search/README.md
@@ -0,0 +1,61 @@
+# data-designer-visual-search
+
+Data Designer plugin for VLM-driven visual search over image columns, with
+local image crop, transform, and color-edit tools.
+
+The `visual-search` column runs a vision-capable chat model with built-in
+image-operation tools:
+
+- `open_image`
+- `get_image_info`
+- `list_images`
+- `crop_image`
+- `transform_image`
+- `edit_color`
+
+Each operation returns an `image_id`. The column keeps intermediate images in
+memory and re-attaches tool-produced images to the following model turn, so the
+model can inspect a crop or transformed image before deciding what to do next.
+Because IDs remain addressable, the model can branch from an earlier image
+rather than being forced through a linear edit chain.
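+
+As a small illustration, mirroring the branching example in the plugin docs, a
+run that crops twice from the root and color-edits the first crop leaves a
+history tree like this:
+
+```text
+img_0000 open_image
+|-- img_0001 crop_image
+|   `-- img_0002 edit_color
+`-- img_0003 crop_image
+```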
+
+## Installation
+
+```bash
+pip install data-designer-visual-search
+```
+
+## Usage
+
+Once installed, the `visual-search` column type is automatically discovered by
+[NeMo Data Designer](https://github.com/NVIDIA-NeMo/DataDesigner).
+
+```python
+import pandas as pd
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.seed_source_dataframe import DataFrameSeedSource
+from data_designer.interface.data_designer import DataDesigner
+
+seed_df = pd.DataFrame({"image_path": ["/path/to/scene.png"]})
+
+builder = DataDesignerConfigBuilder()
+builder.with_seed_dataset(DataFrameSeedSource(df=seed_df))
+builder.add_column(
+    name="visual_answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Find the red object. Crop or transform the image if that helps.",
+    model_alias="nvidia-vision",
+    # Optional: set a model-specific image token here if your endpoint requires
+    # one in the text for every attached image.
+    # image_placeholder="<image>",
+)
+
+result = DataDesigner(artifact_path="artifacts").preview(builder, num_records=1)
+```
+
+The main output column contains the model's final answer. By default the plugin
+also writes `{column_name}__image_history`, a compact tree of image IDs, parent
+IDs, operations, and dimensions.
+
+See `docs/` for the full interface reference and practical examples.
diff --git a/plugins/data-designer-visual-search/docs/examples.md b/plugins/data-designer-visual-search/docs/examples.md
new file mode 100644
index 0000000..20227c3
--- /dev/null
+++ b/plugins/data-designer-visual-search/docs/examples.md
@@ -0,0 +1,86 @@
+# Practical Examples
+
+## Branch From an Earlier Crop
+
+The image workspace is tree-shaped. A model can create one crop, inspect it,
+then operate on the original image again:
+
+1. `open_image()` returns `img_0000`.
+2. `crop_image(image_id="img_0000", x=0, y=0, width=50, height=50, unit="percent")`
+   returns `img_0001`.
+3. `edit_color(image_id="img_0001", contrast=1.5)` returns `img_0002`.
+4. `crop_image(image_id="img_0000", x=50, y=50, width=50, height=50, unit="percent")`
+   returns `img_0003`.
+
+The resulting history preserves both branches:
+
+```text
+img_0000 open_image
+|-- img_0001 crop_image
+|   `-- img_0002 edit_color
+`-- img_0003 crop_image
+```
+
+This is useful when the model needs to compare multiple areas or recover from a
+crop that turned out to be unhelpful.
+
+## Read Small Text
+
+```python
+builder.add_column(
+    name="label_text",
+    column_type="visual-search",
+    image_column="product_photo",
+    prompt=(
+        "Find the ingredients label. Crop tightly around it, increase contrast "
+        "if needed, and return the text you can read."
+    ),
+    model_alias="vision",
+    max_tool_call_turns=5,
+)
+```
+
+Expected model behavior:
+
+- Inspect the original image.
+- Crop the label region.
+- Optionally increase contrast or convert to grayscale.
+- Answer using the attached edited crop.
+
+## Compare Two Regions
+
+```python
+builder.add_column(
+    name="comparison",
+    column_type="visual-search",
+    image_column="shelf_image",
+    prompt=(
+        "Compare the price tags on the left and right sides of the shelf. "
+        "Use separate crops and report which price is lower."
+    ),
+    model_alias="vision",
+    max_tool_call_turns=6,
+)
+```
+
+The model can crop the left tag from `img_0000`, crop the right tag from
+`img_0000`, inspect both resulting IDs, and answer from the evidence.
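+
+A plausible tool sequence for this prompt, with illustrative coordinates (the
+model picks its own regions):
+
+```text
+crop_image(image_id="img_0000", x=0, y=40, width=50, height=35, unit="percent")   -> img_0001
+crop_image(image_id="img_0000", x=50, y=40, width=50, height=35, unit="percent")  -> img_0002
+```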
+
+## Data URI Input
+
+The `image_column` can contain base64 data or a full data URI instead of a file
+path:
+
+```python
+builder.add_column(
+    name="base64_answer",
+    column_type="visual-search",
+    image_column="image_data_uri",
+    prompt="Crop the center of the image and describe what is visible.",
+    model_alias="vision",
+)
+```
+
+If values are raw base64 and the format cannot be detected reliably, set
+`image_data_type="base64"` and `image_format="png"` or another supported image
+format.
diff --git a/plugins/data-designer-visual-search/docs/index.md b/plugins/data-designer-visual-search/docs/index.md
new file mode 100644
index 0000000..b14c019
--- /dev/null
+++ b/plugins/data-designer-visual-search/docs/index.md
@@ -0,0 +1,94 @@
+# data-designer-visual-search
+
+`data-designer-visual-search` adds a `visual-search` column type for
+image-grounded visual search workflows. It is intended for cases where a VLM
+needs to inspect an image, crop into regions, transform the view, adjust color,
+and then continue reasoning over the resulting image.
+
+The plugin owns the extra plumbing that ordinary model tool calling does not
+handle: each local image operation returns an `image_id`, the new image is held
+in memory, and the generated image is attached back into the next model turn as
+multimodal context.
+
+## What It Provides
+
+- A `VisualSearchColumnConfig` registered as column type `visual-search`.
+- A row-scoped in-memory image workspace.
+- Local tools for opening images, listing image IDs, inspecting image metadata,
+  cropping, transforming, and editing color.
+- Tree-shaped image history, so the model can branch from any previous
+  `image_id` instead of following a single linear edit chain.
+- A default side-effect column named `{column_name}__image_history` that records
+  image IDs, parent IDs, child IDs, operations, dimensions, and operation
+  metadata.
+- Optional model trace and reasoning-content side-effect columns that match the
+  conventions used by Data Designer LLM columns.
+
+## Column Interface
+
+| Field | Required | Description |
+| --- | --- | --- |
+| `name` | Yes | Output column name. |
+| `column_type` | Yes | Must be `visual-search`. |
+| `image_column` | Yes | Existing column containing a local image path, URL, base64 image, or image data URI. |
+| `prompt` | Yes | Jinja2 prompt template for the visual search task. |
+| `model_alias` | Yes | Alias of a vision-capable chat model in the Data Designer config. |
+| `system_prompt` | No | Optional Jinja2 system prompt appended to the built-in visual search instructions. |
+| `image_data_type` | No | Optional explicit image data type, such as `url` or `base64`. Leave unset for auto-detection. |
+| `image_format` | Conditional | Required when `image_data_type` is explicitly `base64`. |
+| `image_placeholder` | No | Optional text token to include next to every image attachment for endpoints that require one. |
+| `max_tool_call_turns` | No | Maximum tool-calling turns per row. Defaults to `6`. |
+| `allowed_tools` | No | Optional allowlist of built-in visual tools. Defaults to all tools. |
+| `attach_images_after_tool_calls` | No | Whether to attach tool-created images into the next model turn. Defaults to `True`. |
+| `include_image_history` | No | Whether to write `{name}__image_history`. Defaults to `True`. |
+| `with_trace` | No | Optional trace capture mode. Defaults to `none`. |
+| `extract_reasoning_content` | No | Whether to write `{name}__reasoning_content`. Defaults to `False`. |
+| `use_default_system_prompt` | No | Whether to prepend built-in image-tool instructions. Defaults to `True`. |
+
+## Built-In Tools
+
+| Tool | Purpose |
+| --- | --- |
+| `open_image` | Opens the configured row image and returns the root `image_id`. |
+| `get_image_info` | Returns dimensions, parent ID, children IDs, operation name, and metadata for an `image_id`. |
+| `list_images` | Lists every image currently held in the row workspace. |
+| `crop_image` | Crops an existing image by pixel or percent coordinates and returns a new `image_id`. |
+| `transform_image` | Rotates, flips, or resizes an existing image and returns a new `image_id`. |
+| `edit_color` | Adjusts brightness, contrast, saturation, sharpness, grayscale, or inversion and returns a new `image_id`. |
+
+Tool results are ordinary tool messages containing JSON metadata. When a tool
+creates an image, the plugin also attaches that image to the next user turn so
+the model can inspect it visually.
+
+## Image History
+
+Every image node has stable metadata:
+
+```json
+{
+  "image_id": "img_0001",
+  "parent_image_id": "img_0000",
+  "children_image_ids": [],
+  "operation": "crop_image",
+  "width": 512,
+  "height": 384,
+  "metadata": {
+    "box": {"left": 0, "top": 0, "right": 512, "bottom": 384},
+    "unit": "pixels"
+  }
+}
+```
+
+Because the model controls the `image_id` argument, it can crop from the root
+image, transform that crop, rewind to the root, and crop a different region.
+The workspace keeps the whole tree for the duration of that row.
+
+## When To Use It
+
+Use `visual-search` when the model needs iterative visual operations before it
+can answer reliably. Good examples include reading small labels, comparing
+regions, checking color after contrast adjustment, or zooming into a specific
+part of a larger image.
+
+For a single prompt over an image with no iterative image manipulation, a
+standard Data Designer LLM column with multimodal context may be simpler.
diff --git a/plugins/data-designer-visual-search/docs/usage.md b/plugins/data-designer-visual-search/docs/usage.md
new file mode 100644
index 0000000..cf4cc12
--- /dev/null
+++ b/plugins/data-designer-visual-search/docs/usage.md
@@ -0,0 +1,125 @@
+# Usage
+
+This example starts with a dataframe column containing image paths and adds a
+`visual-search` column. The model can call image tools while answering the
+prompt, and the plugin will pass each resulting crop or edited image back to the
+model automatically.
+
+```python
+import pandas as pd
+
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
+from data_designer.config.seed_source_dataframe import DataFrameSeedSource
+from data_designer.interface.data_designer import DataDesigner
+
+seed_df = pd.DataFrame(
+    {
+        "image_path": ["/path/to/store-shelf.png"],
+        "target": ["the nutrition label on the cereal box"],
+    }
+)
+
+provider = ModelProvider(
+    name="nvidia",
+    endpoint="https://integrate.api.nvidia.com/v1",
+    api_key="NVIDIA_API_KEY",
+    provider_type="openai",
+)
+
+vision_model = ModelConfig(
+    alias="vision",
+    model="qwen/qwen3.5-122b-a10b",
+    provider="nvidia",
+    inference_parameters=ChatCompletionInferenceParams(
+        temperature=0,
+        max_tokens=512,
+        timeout=60,
+    ),
+)
+
+builder = DataDesignerConfigBuilder(model_configs=[vision_model])
+builder.with_seed_dataset(DataFrameSeedSource(df=seed_df))
+builder.add_column(
+    name="visual_answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt=(
+        "Find {{ target }}. Use crop_image or edit_color if that helps. "
+        "Return the text you can read and explain which image_id you used."
+    ),
+    model_alias="vision",
+    max_tool_call_turns=4,
+)
+
+result = DataDesigner(
+    artifact_path="artifacts",
+    model_providers=[provider],
+).preview(builder, num_records=1)
+```
+
+The generated dataset includes:
+
+- `visual_answer`: the model's final answer.
+- `visual_answer__image_history`: the image operation tree produced while
+  answering the row.
+
+## Restricting Tools
+
+Use `allowed_tools` when you want the model to perform only a narrower set of
+operations:
+
+```python
+builder.add_column(
+    name="crop_only_answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Crop the upper-right quadrant and describe the dominant color.",
+    model_alias="vision",
+    allowed_tools=["open_image", "get_image_info", "crop_image"],
+    max_tool_call_turns=2,
+)
+```
+
+## Endpoint Image Tokens
+
+Most OpenAI-compatible multimodal endpoints accept image content blocks directly.
+Some model servers also require a model-specific image token in the text for
+each attached image. Set `image_placeholder` for those endpoints:
+
+```python
+builder.add_column(
+    name="answer",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Inspect the attached image and answer the question.",
+    model_alias="vision",
+    image_placeholder="<image>",  # example token; use the one your endpoint expects
+)
+```
+
+The plugin prepends the placeholder to the initial image turn and to every later
+turn that attaches a tool-created image.
+
+## Capturing Trace Output
+
+The column supports the same trace side-effect pattern as other LLM-backed Data
+Designer columns:
+
+```python
+from data_designer.config.utils.trace_type import TraceType
+
+builder.add_column(
+    name="answer_with_trace",
+    column_type="visual-search",
+    image_column="image_path",
+    prompt="Zoom into the serial number and read it.",
+    model_alias="vision",
+    with_trace=TraceType.ALL_MESSAGES,
+    extract_reasoning_content=True,
+)
+```
+
+This adds `answer_with_trace__trace` and
+`answer_with_trace__reasoning_content` when the selected model provides
+reasoning content.
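+
+As a rough sketch of the trace shape (roles follow the plugin's tool loop;
+contents are illustrative), a short run captured with `TraceType.ALL_MESSAGES`
+contains the system turn, the user turn with the attached image, any assistant
+tool calls with their tool results, and the final assistant answer:
+
+```text
+system    -> built-in visual-search instructions
+user      -> rendered prompt + attached source image (img_0000)
+assistant -> tool_calls: [crop_image(image_id="img_0000", ...)]
+tool      -> {"ok": true, "tool": "crop_image", "result": {...}}
+user      -> tool result image attached (img_0001)
+assistant -> final answer text
+```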
diff --git a/plugins/data-designer-visual-search/pyproject.toml b/plugins/data-designer-visual-search/pyproject.toml
new file mode 100644
index 0000000..0414b27
--- /dev/null
+++ b/plugins/data-designer-visual-search/pyproject.toml
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+[project]
+name = "data-designer-visual-search"
+version = "0.1.0"
+description = "Visual search column with local image crop, transform, and color-edit tools"
+requires-python = ">=3.10"
+dependencies = [
+    "data-designer>=0.5.7",
+    "pillow",
+    "requests",
+]
+license = "Apache-2.0"
+readme = "README.md"
+authors = [
+    {name = "NVIDIA Corporation"},
+]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Programming Language :: Python :: 3",
+]
+
+[project.entry-points."data_designer.plugins"]
+visual-search = "data_designer_visual_search.plugin:plugin"
+
+[project.urls]
+Repository = "https://github.com/NVIDIA-NeMo/DataDesignerPlugins"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/data_designer_visual_search"]
+
+[tool.ruff]
+extend = "../../pyproject.toml"
diff --git a/plugins/data-designer-visual-search/src/data_designer_visual_search/__init__.py b/plugins/data-designer-visual-search/src/data_designer_visual_search/__init__.py
new file mode 100644
index 0000000..52a7a9d
--- /dev/null
+++ b/plugins/data-designer-visual-search/src/data_designer_visual_search/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/plugins/data-designer-visual-search/src/data_designer_visual_search/config.py b/plugins/data-designer-visual-search/src/data_designer_visual_search/config.py
new file mode 100644
index 0000000..9b6179c
--- /dev/null
+++ b/plugins/data-designer-visual-search/src/data_designer_visual_search/config.py
@@ -0,0 +1,105 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Literal
+
+from data_designer.config.base import SingleColumnConfig
+from data_designer.config.models import ModalityDataType
+from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX
+from data_designer.config.utils.image_helpers import ImageFormat
+from data_designer.config.utils.misc import assert_valid_jinja2_template, extract_keywords_from_jinja2_template
+from data_designer.config.utils.trace_type import TraceType
+from pydantic import Field, model_validator
+from typing_extensions import Self
+
+VisualSearchToolName = Literal[
+    "open_image",
+    "get_image_info",
+    "list_images",
+    "crop_image",
+    "transform_image",
+    "edit_color",
+]
+
+
+class VisualSearchColumnConfig(SingleColumnConfig):
+    """Configuration for image-grounded visual search with local image-operation tools.
+
+    The column runs a vision-capable chat model with built-in image tools. Each tool
+    returns an image ID, and subsequent calls may operate on any previous image ID,
+    which lets the model branch from earlier points in the image history.
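+
+    Example (a minimal configuration, matching the shape used in this plugin's tests):
+
+        VisualSearchColumnConfig(
+            name="answer",
+            image_column="image_path",
+            prompt="Find {{ target }} in the image.",
+            model_alias="vision",
+        )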
+ """ + + column_type: Literal["visual-search"] = "visual-search" + + image_column: str = Field(description="Column containing a local image path, URL, base64 string, or data URI.") + prompt: str = Field(description="Jinja2 prompt template for the visual search task.") + model_alias: str = Field(description="Alias of the vision-capable chat model to use.") + system_prompt: str | None = Field(default=None, description="Optional Jinja2 system prompt template.") + image_data_type: ModalityDataType | None = Field( + default=None, + description="Optional explicit format for values in image_column. Leave unset for auto-detection.", + ) + image_format: ImageFormat | None = Field( + default=None, + description="Required when image_data_type is base64 and the image format cannot be auto-detected.", + ) + image_placeholder: str | None = Field( + default=None, + description="Optional model-specific image token to include in text for endpoints that require it.", + ) + max_tool_call_turns: int = Field( + default=6, + ge=1, + description="Maximum tool-calling turns allowed for each row before the model must answer.", + ) + allowed_tools: list[VisualSearchToolName] | None = Field( + default=None, + description="Optional allowlist of built-in visual tools. Defaults to all tools.", + ) + attach_images_after_tool_calls: bool = Field( + default=True, + description="Attach resulting tool images back into the next model turn.", + ) + include_image_history: bool = Field( + default=True, + description="Add a side-effect column with the tree of image operations and IDs.", + ) + with_trace: TraceType = Field(default=TraceType.NONE, description="Optional chat trace capture mode.") + extract_reasoning_content: bool = Field( + default=False, + description="If True, capture reasoning_content from the final assistant message.", + ) + use_default_system_prompt: bool = Field( + default=True, + description="Prepend built-in instructions explaining image IDs and visual tools.", + ) + + @staticmethod + def get_column_emoji() -> str: + return "🔎" + + @property + def required_columns(self) -> list[str]: + required_cols = [self.image_column, *extract_keywords_from_jinja2_template(self.prompt)] + if self.system_prompt: + required_cols.extend(extract_keywords_from_jinja2_template(self.system_prompt)) + return list(dict.fromkeys(required_cols)) + + @property + def side_effect_columns(self) -> list[str]: + return [ + *([f"{self.name}__image_history"] if self.include_image_history else []), + *([f"{self.name}{TRACE_COLUMN_POSTFIX}"] if self.with_trace != TraceType.NONE else []), + *([f"{self.name}{REASONING_CONTENT_COLUMN_POSTFIX}"] if self.extract_reasoning_content else []), + ] + + @model_validator(mode="after") + def validate_templates_and_image_format(self) -> Self: + """Validate prompt templates and image modality settings.""" + assert_valid_jinja2_template(self.prompt) + if self.system_prompt: + assert_valid_jinja2_template(self.system_prompt) + if self.image_data_type == ModalityDataType.BASE64 and self.image_format is None: + raise ValueError("image_format is required when image_data_type is base64") + return self diff --git a/plugins/data-designer-visual-search/src/data_designer_visual_search/impl.py b/plugins/data-designer-visual-search/src/data_designer_visual_search/impl.py new file mode 100644 index 0000000..d03ced3 --- /dev/null +++ b/plugins/data-designer-visual-search/src/data_designer_visual_search/impl.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX +from data_designer.config.utils.trace_type import TraceType +from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy +from data_designer.engine.models.utils import ChatMessage +from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering +from data_designer.engine.processing.utils import deserialize_json_values + +from data_designer_visual_search.config import VisualSearchColumnConfig +from data_designer_visual_search.tools import VisualImageWorkspace, VisualSearchToolExecutor + +if TYPE_CHECKING: + from typing import Any + + from data_designer.engine.models.clients.types import ChatCompletionResponse, ToolCall + +DEFAULT_VISUAL_SEARCH_SYSTEM_PROMPT = """\ +You are a visual search agent working with an in-memory image tree. +Use the available image tools when cropping, transforming, or color-adjusting the image would help answer. +Every image has an image_id. Tool calls may operate on any previous image_id, so you can branch from earlier images. +After a tool creates an image, that image will be attached in the next user turn with its image_id. +When you have enough evidence, stop calling tools and answer the user's prompt directly. +""" + +TOOL_BUDGET_EXHAUSTED_MESSAGE = ( + "Tool call budget exhausted. Use the images and tool results already shown, then provide the final answer." +) + + +class VisualSearchColumnGenerator( + WithJinja2UserTemplateRendering, + ColumnGeneratorWithModel[VisualSearchColumnConfig], +): + """Run a vision model with built-in image-operation tools.""" + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + """Generate a visual-search answer for one row.""" + deserialized_record = deserialize_json_values(data) + workspace = self._create_workspace(deserialized_record) + executor = VisualSearchToolExecutor(workspace=workspace, allowed_tools=self.config.allowed_tools) + root = workspace.open_image() + + messages = self._build_initial_messages(deserialized_record, workspace, root["image_id"]) + final_text, trace = self._run_tool_loop(messages, workspace, executor) + + data[self.config.name] = final_text + if self.config.include_image_history: + data[f"{self.config.name}__image_history"] = workspace.history() + if self.config.with_trace == TraceType.ALL_MESSAGES: + data[f"{self.config.name}{TRACE_COLUMN_POSTFIX}"] = [message.to_dict() for message in trace] + elif self.config.with_trace == TraceType.LAST_MESSAGE: + last_assistant = next((message for message in reversed(trace) if message.role == "assistant"), None) + data[f"{self.config.name}{TRACE_COLUMN_POSTFIX}"] = ( + [last_assistant.to_dict()] if last_assistant is not None else [] + ) + if self.config.extract_reasoning_content: + data[f"{self.config.name}{REASONING_CONTENT_COLUMN_POSTFIX}"] = self._extract_reasoning_content(trace) + return data + + def _create_workspace(self, record: dict[str, Any]) -> VisualImageWorkspace: + return VisualImageWorkspace( + source_value=record[self.config.image_column], + base_path=self.base_dataset_path, + image_data_type=self.config.image_data_type, + image_format=self.config.image_format, + ) + + def _build_initial_messages( + self, + record: dict[str, 
Any], + workspace: VisualImageWorkspace, + root_image_id: str, + ) -> list[ChatMessage]: + prompt = self._render_template(self.config.prompt, record) + prompt = ( + f"{prompt}\n\n" + f"The source image is attached and is available in the tool workspace as image_id {root_image_id!r}. " + "You may call open_image() to retrieve the same root image_id, or operate on this image_id directly." + ) + if self.config.image_placeholder: + prompt = f"{self.config.image_placeholder}\n{prompt}" + + messages: list[ChatMessage] = [] + system_prompt = self._build_system_prompt(record) + if system_prompt: + messages.append(ChatMessage.as_system(system_prompt)) + messages.append( + ChatMessage.as_user([{"type": "text", "text": prompt}, workspace.image_context_block(root_image_id)]) + ) + return messages + + def _build_system_prompt(self, record: dict[str, Any]) -> str | None: + system_parts: list[str] = [] + if self.config.use_default_system_prompt: + system_parts.append(DEFAULT_VISUAL_SEARCH_SYSTEM_PROMPT) + if self.config.system_prompt: + system_parts.append(self._render_template(self.config.system_prompt, record)) + return "\n\n".join(part for part in system_parts if part).strip() or None + + def _render_template(self, template: str, record: dict[str, Any]) -> str: + jinja_render_env = self._create_render_environment(dataset_variables=list(record.keys())) + jinja_render_env.validate_template(template) + return jinja_render_env.render_template(template, record, skip_template_validation=True) + + def _run_tool_loop( + self, + messages: list[ChatMessage], + workspace: VisualImageWorkspace, + executor: VisualSearchToolExecutor, + ) -> tuple[str, list[ChatMessage]]: + tool_call_turns = 0 + tools_enabled = True + tool_schemas = executor.get_tool_schemas() + + while True: + completion_response = self._complete(messages, tool_schemas if tools_enabled else None) + tool_calls = completion_response.message.tool_calls + if tool_calls and tools_enabled: + tool_call_turns += 1 + messages.append(_assistant_tool_message(completion_response)) + + if tool_call_turns > self.config.max_tool_call_turns: + messages.extend(ChatMessage.as_tool(TOOL_BUDGET_EXHAUSTED_MESSAGE, call.id) for call in tool_calls) + messages.append(ChatMessage.as_user(TOOL_BUDGET_EXHAUSTED_MESSAGE)) + tools_enabled = False + continue + + image_ids = self._execute_tool_calls(messages, executor, tool_calls) + if image_ids and self.config.attach_images_after_tool_calls: + messages.append( + ChatMessage.as_user( + _tool_image_context_blocks( + workspace, + image_ids, + image_placeholder=self.config.image_placeholder, + ) + ) + ) + continue + + response_text = (completion_response.message.content or "").strip() + messages.append( + ChatMessage.as_assistant( + content=response_text, + reasoning_content=completion_response.message.reasoning_content or None, + ) + ) + return response_text, messages + + def _complete( + self, messages: list[ChatMessage], tool_schemas: list[dict[str, Any]] | None + ) -> ChatCompletionResponse: + completion_kwargs = {"purpose": f"running visual search for column {self.config.name!r}"} + if tool_schemas: + completion_kwargs["tools"] = tool_schemas + return self.model.completion(messages, **completion_kwargs) + + def _execute_tool_calls( + self, + messages: list[ChatMessage], + executor: VisualSearchToolExecutor, + tool_calls: list[ToolCall], + ) -> list[str]: + image_ids: list[str] = [] + for tool_call in tool_calls: + result = executor.execute(tool_call.name, tool_call.arguments_json) + 
messages.append(ChatMessage.as_tool(content=result.content, tool_call_id=tool_call.id)) + image_ids.extend(result.image_ids) + return image_ids + + def _extract_reasoning_content(self, trace: list[ChatMessage]) -> str | None: + reasoning_value: str | None = None + for message in reversed(trace): + if message.role == "assistant": + reasoning_value = message.reasoning_content + break + return reasoning_value.strip() or None if reasoning_value is not None else None + + +def _assistant_tool_message(completion_response: ChatCompletionResponse) -> ChatMessage: + tool_calls = [ + { + "id": tool_call.id, + "type": "function", + "function": {"name": tool_call.name, "arguments": tool_call.arguments_json}, + } + for tool_call in completion_response.message.tool_calls + ] + return ChatMessage.as_assistant( + content=(completion_response.message.content or "").strip(), + reasoning_content=completion_response.message.reasoning_content or None, + tool_calls=tool_calls, + ) + + +def _tool_image_context_blocks( + workspace: VisualImageWorkspace, + image_ids: list[str], + *, + image_placeholder: str | None = None, +) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + for image_id in image_ids: + info = workspace.get_image_info(image_id) + text = ( + "Tool result image attached: " + f"image_id={image_id}, parent_image_id={info['parent_image_id']}, " + f"operation={info['operation']}, size={info['width']}x{info['height']}." + ) + if image_placeholder: + text = f"{image_placeholder}\n{text}" + blocks.append( + { + "type": "text", + "text": text, + } + ) + blocks.append(workspace.image_context_block(image_id)) + return blocks diff --git a/plugins/data-designer-visual-search/src/data_designer_visual_search/plugin.py b/plugins/data-designer-visual-search/src/data_designer_visual_search/plugin.py new file mode 100644 index 0000000..414f9a6 --- /dev/null +++ b/plugins/data-designer-visual-search/src/data_designer_visual_search/plugin.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.plugins.plugin import Plugin, PluginType + +plugin = Plugin( + config_qualified_name="data_designer_visual_search.config.VisualSearchColumnConfig", + impl_qualified_name="data_designer_visual_search.impl.VisualSearchColumnGenerator", + plugin_type=PluginType.COLUMN_GENERATOR, +) diff --git a/plugins/data-designer-visual-search/src/data_designer_visual_search/tools.py b/plugins/data-designer-visual-search/src/data_designer_visual_search/tools.py new file mode 100644 index 0000000..5f93b62 --- /dev/null +++ b/plugins/data-designer-visual-search/src/data_designer_visual_search/tools.py @@ -0,0 +1,643 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +import io +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import requests +from data_designer.config.models import ModalityDataType +from data_designer.config.utils.image_helpers import ( + ImageFormat, + decode_base64_image, + detect_image_format, + extract_base64_from_data_uri, + is_base64_image, + is_image_url, +) +from PIL import Image, ImageEnhance, ImageOps + +DEFAULT_IMAGE_FORMAT = ImageFormat.PNG + + +@dataclass +class ImageNode: + """Image plus lineage metadata stored in a visual-search workspace.""" + + image_id: str + image: Image.Image + parent_image_id: str | None + operation: str + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self, children_image_ids: list[str] | None = None) -> dict[str, Any]: + """Return JSON-serializable metadata for this image node.""" + return { + "image_id": self.image_id, + "parent_image_id": self.parent_image_id, + "children_image_ids": children_image_ids or [], + "operation": self.operation, + "width": self.image.width, + "height": self.image.height, + "metadata": self.metadata, + } + + +@dataclass +class VisualToolExecution: + """Result of executing a local visual tool call.""" + + content: str + image_ids: list[str] = field(default_factory=list) + is_error: bool = False + + +class VisualImageWorkspace: + """In-memory image tree for visual search tool calls. + + The workspace keeps every intermediate image addressable by ID. Tools can + operate on any prior image ID, so a model can branch from an earlier crop or + transform instead of being forced into a linear edit history. + """ + + def __init__( + self, + *, + source_value: Any, + base_path: Path | None = None, + image_data_type: ModalityDataType | None = None, + image_format: ImageFormat | None = None, + ) -> None: + self._source_value = source_value + self._base_path = base_path + self._image_data_type = image_data_type + self._image_format = image_format + self._nodes: dict[str, ImageNode] = {} + self._next_image_index = 0 + self._root_image_id: str | None = None + + @property + def root_image_id(self) -> str | None: + """Return the root image ID after the source image has been opened.""" + return self._root_image_id + + def open_image(self, path: str | None = None) -> dict[str, Any]: + """Open the configured source image or an explicitly supplied image reference. + + Args: + path: Optional local path, URL, base64 string, or data URI. When omitted, + the configured source image for this row is opened. + + Returns: + Metadata for the opened image node. 
+ """ + if path is None and self._root_image_id is not None: + return self.get_image_info(self._root_image_id) + + source = self._source_value if path is None else path + image = self._load_image(source) + node = self._create_node( + image=image, + parent_image_id=None, + operation="open_image", + metadata={"source": _summarize_source(source)}, + ) + if path is None: + self._root_image_id = node.image_id + return self.get_image_info(node.image_id) + + def get_image_info(self, image_id: str) -> dict[str, Any]: + """Return metadata for an image ID.""" + node = self._get_node(image_id) + return node.to_dict(children_image_ids=self._children_for(image_id)) + + def list_images(self) -> dict[str, Any]: + """Return the current image tree metadata.""" + return { + "root_image_id": self._root_image_id, + "images": [self.get_image_info(image_id) for image_id in self._nodes], + } + + def crop_image( + self, + image_id: str, + x: float, + y: float, + width: float, + height: float, + unit: str = "pixels", + ) -> dict[str, Any]: + """Crop an image by pixels or percentages and return the new image metadata.""" + node = self._get_node(image_id) + left, top, right, bottom = _resolve_crop_box(node.image.size, x, y, width, height, unit) + cropped = node.image.crop((left, top, right, bottom)) + child = self._create_node( + image=cropped, + parent_image_id=image_id, + operation="crop_image", + metadata={"box": {"left": left, "top": top, "right": right, "bottom": bottom}, "unit": "pixels"}, + ) + return self.get_image_info(child.image_id) + + def transform_image( + self, + image_id: str, + rotate_degrees: float = 0.0, + flip_horizontal: bool = False, + flip_vertical: bool = False, + resize_width: int | None = None, + resize_height: int | None = None, + preserve_aspect_ratio: bool = True, + ) -> dict[str, Any]: + """Rotate, flip, and resize an image, returning a new image ID.""" + node = self._get_node(image_id) + image = node.image.copy() + + if rotate_degrees: + image = image.rotate(-rotate_degrees, expand=True) + if flip_horizontal: + image = ImageOps.mirror(image) + if flip_vertical: + image = ImageOps.flip(image) + if resize_width is not None or resize_height is not None: + image = _resize_image(image, resize_width, resize_height, preserve_aspect_ratio) + + child = self._create_node( + image=image, + parent_image_id=image_id, + operation="transform_image", + metadata={ + "rotate_degrees": rotate_degrees, + "flip_horizontal": flip_horizontal, + "flip_vertical": flip_vertical, + "resize_width": resize_width, + "resize_height": resize_height, + "preserve_aspect_ratio": preserve_aspect_ratio, + }, + ) + return self.get_image_info(child.image_id) + + def edit_color( + self, + image_id: str, + brightness: float = 1.0, + contrast: float = 1.0, + saturation: float = 1.0, + sharpness: float = 1.0, + grayscale: bool = False, + invert: bool = False, + ) -> dict[str, Any]: + """Adjust color properties and return a new image ID.""" + node = self._get_node(image_id) + image = node.image.copy() + + if grayscale: + image = ImageOps.grayscale(image).convert("RGB") + if invert: + image = _invert_image(image) + image = ImageEnhance.Brightness(image).enhance(brightness) + image = ImageEnhance.Contrast(image).enhance(contrast) + image = ImageEnhance.Color(image).enhance(saturation) + image = ImageEnhance.Sharpness(image).enhance(sharpness) + + child = self._create_node( + image=image, + parent_image_id=image_id, + operation="edit_color", + metadata={ + "brightness": brightness, + "contrast": contrast, + "saturation": 
saturation, + "sharpness": sharpness, + "grayscale": grayscale, + "invert": invert, + }, + ) + return self.get_image_info(child.image_id) + + def image_context_block(self, image_id: str) -> dict[str, Any]: + """Return an OpenAI-compatible image content block for an image ID.""" + data_uri = self.image_data_uri(image_id) + return {"type": "image_url", "image_url": {"url": data_uri}} + + def image_data_uri(self, image_id: str) -> str: + """Return an image as a PNG data URI.""" + base64_data = self.image_base64(image_id) + return f"data:image/{DEFAULT_IMAGE_FORMAT.value};base64,{base64_data}" + + def image_base64(self, image_id: str) -> str: + """Return an image encoded as base64 PNG.""" + node = self._get_node(image_id) + buffer = io.BytesIO() + image = _normalize_image_for_png(node.image) + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("ascii") + + def history(self) -> list[dict[str, Any]]: + """Return JSON-serializable operation history.""" + return [self.get_image_info(image_id) for image_id in self._nodes] + + def _create_node( + self, + *, + image: Image.Image, + parent_image_id: str | None, + operation: str, + metadata: dict[str, Any], + ) -> ImageNode: + image_id = f"img_{self._next_image_index:04d}" + self._next_image_index += 1 + node = ImageNode( + image_id=image_id, + image=_normalize_loaded_image(image), + parent_image_id=parent_image_id, + operation=operation, + metadata=metadata, + ) + self._nodes[image_id] = node + return node + + def _get_node(self, image_id: str) -> ImageNode: + try: + return self._nodes[image_id] + except KeyError: + known = ", ".join(self._nodes) or "(none)" + raise ValueError(f"Unknown image_id {image_id!r}. Known image IDs: {known}") + + def _children_for(self, image_id: str) -> list[str]: + return [node.image_id for node in self._nodes.values() if node.parent_image_id == image_id] + + def _load_image(self, source: Any) -> Image.Image: + if isinstance(source, Image.Image): + return source.copy() + if isinstance(source, bytes): + return Image.open(io.BytesIO(source)) + if isinstance(source, str): + return Image.open(io.BytesIO(self._load_image_bytes_from_string(source))) + raise TypeError(f"Unsupported image source type: {type(source).__name__}") + + def _load_image_bytes_from_string(self, source: str) -> bytes: + if self._image_data_type == ModalityDataType.URL or is_image_url(source): + response = requests.get(source, timeout=60) + response.raise_for_status() + return response.content + + if ( + self._image_data_type == ModalityDataType.BASE64 + or source.startswith("data:image/") + or is_base64_image(source) + ): + return decode_base64_image(source) + + path = Path(source) + if not path.is_absolute() and self._base_path is not None: + candidate = self._base_path / path + if candidate.exists(): + path = candidate + if not path.is_absolute() and not path.exists(): + path = Path.cwd() / source + if path.exists(): + return path.read_bytes() + + try: + return decode_base64_image(extract_base64_from_data_uri(source)) + except ValueError as exc: + raise ValueError(f"Could not load image source {source!r} as a path, URL, or base64 image") from exc + + +class VisualSearchToolExecutor: + """Executes the built-in visual-search tools for one row.""" + + def __init__( + self, + *, + workspace: VisualImageWorkspace, + allowed_tools: list[str] | None = None, + allow_external_open: bool = False, + ) -> None: + self._workspace = workspace + self._allowed_tools = set(allowed_tools or TOOL_FUNCTIONS) + self._allow_external_open = 
allow_external_open + + def get_tool_schemas(self) -> list[dict[str, Any]]: + """Return OpenAI-compatible tool schemas for the allowed tools.""" + return [schema for schema in VISUAL_SEARCH_TOOL_SCHEMAS if schema["function"]["name"] in self._allowed_tools] + + def execute(self, tool_name: str, arguments_json: str) -> VisualToolExecution: + """Execute a tool call and return a tool-message-ready result.""" + if tool_name not in self._allowed_tools: + return _error_result(tool_name, f"Tool {tool_name!r} is not allowed for this column.") + if tool_name not in TOOL_FUNCTIONS: + return _error_result(tool_name, f"Unknown visual-search tool {tool_name!r}.") + + try: + arguments = json.loads(arguments_json) if arguments_json else {} + if not isinstance(arguments, dict): + raise ValueError("Tool arguments must decode to a JSON object.") + if tool_name == "open_image" and arguments.get("path") and not self._allow_external_open: + raise ValueError("open_image path is managed by the visual-search column; omit path for row input.") + payload = TOOL_FUNCTIONS[tool_name](self._workspace, **arguments) + return _success_result(tool_name, payload) + except Exception as exc: + return _error_result(tool_name, str(exc)) + + +def _success_result(tool_name: str, payload: dict[str, Any]) -> VisualToolExecution: + image_ids = [payload["image_id"]] if isinstance(payload.get("image_id"), str) else [] + return VisualToolExecution( + content=json.dumps({"ok": True, "tool": tool_name, "result": payload}, sort_keys=True), + image_ids=image_ids, + ) + + +def _error_result(tool_name: str, message: str) -> VisualToolExecution: + return VisualToolExecution( + content=json.dumps({"ok": False, "tool": tool_name, "error": message}, sort_keys=True), + is_error=True, + ) + + +def _open_image(workspace: VisualImageWorkspace, path: str | None = None) -> dict[str, Any]: + return workspace.open_image(path=path) + + +def _get_image_info(workspace: VisualImageWorkspace, image_id: str) -> dict[str, Any]: + return workspace.get_image_info(image_id) + + +def _list_images(workspace: VisualImageWorkspace) -> dict[str, Any]: + return workspace.list_images() + + +def _crop_image( + workspace: VisualImageWorkspace, + image_id: str, + x: float, + y: float, + width: float, + height: float, + unit: str = "pixels", +) -> dict[str, Any]: + return workspace.crop_image(image_id=image_id, x=x, y=y, width=width, height=height, unit=unit) + + +def _transform_image( + workspace: VisualImageWorkspace, + image_id: str, + rotate_degrees: float = 0.0, + flip_horizontal: bool = False, + flip_vertical: bool = False, + resize_width: int | None = None, + resize_height: int | None = None, + preserve_aspect_ratio: bool = True, +) -> dict[str, Any]: + return workspace.transform_image( + image_id=image_id, + rotate_degrees=rotate_degrees, + flip_horizontal=flip_horizontal, + flip_vertical=flip_vertical, + resize_width=resize_width, + resize_height=resize_height, + preserve_aspect_ratio=preserve_aspect_ratio, + ) + + +def _edit_color( + workspace: VisualImageWorkspace, + image_id: str, + brightness: float = 1.0, + contrast: float = 1.0, + saturation: float = 1.0, + sharpness: float = 1.0, + grayscale: bool = False, + invert: bool = False, +) -> dict[str, Any]: + return workspace.edit_color( + image_id=image_id, + brightness=brightness, + contrast=contrast, + saturation=saturation, + sharpness=sharpness, + grayscale=grayscale, + invert=invert, + ) + + +TOOL_FUNCTIONS = { + "open_image": _open_image, + "get_image_info": _get_image_info, + "list_images": _list_images, + 
"crop_image": _crop_image, + "transform_image": _transform_image, + "edit_color": _edit_color, +} + +VISUAL_SEARCH_TOOL_SCHEMAS: list[dict[str, Any]] = [ + { + "type": "function", + "function": { + "name": "open_image", + "description": ( + "Open the configured source image for this row and return its image_id. " + "If called repeatedly without a path, returns the existing root image." + ), + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Optional image path, URL, or base64 data. Usually omit this.", + } + }, + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_image_info", + "description": "Return dimensions, parent, children, and operation metadata for an image_id.", + "parameters": { + "type": "object", + "properties": {"image_id": {"type": "string", "description": "Image ID to inspect."}}, + "required": ["image_id"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "list_images", + "description": "List all image IDs currently in memory with parent/child relationships.", + "parameters": {"type": "object", "properties": {}, "additionalProperties": False}, + }, + }, + { + "type": "function", + "function": { + "name": "crop_image", + "description": ( + "Create a crop from any existing image_id. Use unit='percent' for approximate visual regions " + "or unit='pixels' for exact coordinates." + ), + "parameters": { + "type": "object", + "properties": { + "image_id": {"type": "string"}, + "x": {"type": "number", "description": "Left coordinate in pixels or percent."}, + "y": {"type": "number", "description": "Top coordinate in pixels or percent."}, + "width": {"type": "number", "description": "Crop width in pixels or percent."}, + "height": {"type": "number", "description": "Crop height in pixels or percent."}, + "unit": {"type": "string", "enum": ["pixels", "percent"], "default": "pixels"}, + }, + "required": ["image_id", "x", "y", "width", "height"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "transform_image", + "description": "Rotate, flip, and/or resize an existing image_id and return a new image_id.", + "parameters": { + "type": "object", + "properties": { + "image_id": {"type": "string"}, + "rotate_degrees": {"type": "number", "default": 0}, + "flip_horizontal": {"type": "boolean", "default": False}, + "flip_vertical": {"type": "boolean", "default": False}, + "resize_width": {"type": "integer", "minimum": 1}, + "resize_height": {"type": "integer", "minimum": 1}, + "preserve_aspect_ratio": {"type": "boolean", "default": True}, + }, + "required": ["image_id"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "edit_color", + "description": ( + "Adjust brightness, contrast, saturation, sharpness, grayscale, or inversion for an image_id." 
+            ),
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "image_id": {"type": "string"},
+                    "brightness": {"type": "number", "default": 1.0, "minimum": 0},
+                    "contrast": {"type": "number", "default": 1.0, "minimum": 0},
+                    "saturation": {"type": "number", "default": 1.0, "minimum": 0},
+                    "sharpness": {"type": "number", "default": 1.0, "minimum": 0},
+                    "grayscale": {"type": "boolean", "default": False},
+                    "invert": {"type": "boolean", "default": False},
+                },
+                "required": ["image_id"],
+                "additionalProperties": False,
+            },
+        },
+    },
+]
+
+
+def _resolve_crop_box(
+    image_size: tuple[int, int],
+    x: float,
+    y: float,
+    width: float,
+    height: float,
+    unit: str,
+) -> tuple[int, int, int, int]:
+    image_width, image_height = image_size
+    if width <= 0 or height <= 0:
+        raise ValueError("Crop width and height must be positive.")
+    if unit == "percent":
+        left = round(image_width * (x / 100.0))
+        top = round(image_height * (y / 100.0))
+        right = round(image_width * ((x + width) / 100.0))
+        bottom = round(image_height * ((y + height) / 100.0))
+    elif unit == "pixels":
+        left = round(x)
+        top = round(y)
+        right = round(x + width)
+        bottom = round(y + height)
+    else:
+        raise ValueError("unit must be either 'pixels' or 'percent'.")
+
+    left = max(0, min(image_width - 1, left))
+    top = max(0, min(image_height - 1, top))
+    right = max(left + 1, min(image_width, right))
+    bottom = max(top + 1, min(image_height, bottom))
+    return left, top, right, bottom
+
+
+def _resize_image(
+    image: Image.Image,
+    resize_width: int | None,
+    resize_height: int | None,
+    preserve_aspect_ratio: bool,
+) -> Image.Image:
+    if resize_width is not None and resize_width < 1:
+        raise ValueError("resize_width must be at least 1.")
+    if resize_height is not None and resize_height < 1:
+        raise ValueError("resize_height must be at least 1.")
+
+    if resize_width is None and resize_height is None:
+        return image
+    if preserve_aspect_ratio:
+        if resize_width is None:
+            ratio = resize_height / image.height
+            resize_width = max(1, round(image.width * ratio))
+        elif resize_height is None:
+            ratio = resize_width / image.width
+            resize_height = max(1, round(image.height * ratio))
+        resized = image.copy()
+        resized.thumbnail((resize_width, resize_height), Image.Resampling.LANCZOS)
+        return resized
+
+    target_width = resize_width or image.width
+    target_height = resize_height or image.height
+    return image.resize((target_width, target_height), Image.Resampling.LANCZOS)
+
+
+def _invert_image(image: Image.Image) -> Image.Image:
+    if image.mode == "RGBA":
+        red, green, blue, alpha = image.split()
+        inverted = ImageOps.invert(Image.merge("RGB", (red, green, blue)))
+        inverted.putalpha(alpha)
+        return inverted
+    return ImageOps.invert(image.convert("RGB"))
+
+
+def _normalize_loaded_image(image: Image.Image) -> Image.Image:
+    image.load()
+    if image.mode in {"RGBA", "RGB", "L"}:
+        return image.copy()
+    return image.convert("RGBA" if "A" in image.getbands() else "RGB")
+
+
+def _normalize_image_for_png(image: Image.Image) -> Image.Image:
+    if image.mode in {"RGB", "RGBA", "L"}:
+        return image
+    return image.convert("RGBA" if "A" in image.getbands() else "RGB")
+
+
+def _summarize_source(source: Any) -> str:
+    if not isinstance(source, str):
+        return type(source).__name__
+    if source.startswith("data:image/") or is_base64_image(source):
+        try:
+            image_format = detect_image_format(decode_base64_image(source))
+            return f"{image_format.value} base64 image"
+        except ValueError:
+            return "base64 image"
+    return source
diff --git a/plugins/data-designer-visual-search/tests/test_plugin.py b/plugins/data-designer-visual-search/tests/test_plugin.py
new file mode 100644
index 0000000..e71f3e3
--- /dev/null
+++ b/plugins/data-designer-visual-search/tests/test_plugin.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import json
+from types import SimpleNamespace
+
+import pytest
+from data_designer.config.models import ModalityDataType
+from data_designer.config.run_config import RunConfig
+from data_designer.config.utils.image_helpers import ImageFormat
+from data_designer.config.utils.trace_type import TraceType
+from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall
+from data_designer.engine.storage.artifact_storage import ArtifactStorage
+from data_designer.engine.testing.utils import assert_valid_plugin
+from PIL import Image, ImageDraw
+
+from data_designer_visual_search.config import VisualSearchColumnConfig
+from data_designer_visual_search.impl import VisualSearchColumnGenerator
+from data_designer_visual_search.plugin import plugin
+from data_designer_visual_search.tools import VisualImageWorkspace, VisualSearchToolExecutor
+
+
+def test_valid_plugin() -> None:
+    assert_valid_plugin(plugin)
+
+
+class TestVisualSearchColumnConfig:
+    def test_required_columns_include_image_and_template_references(self) -> None:
+        config = VisualSearchColumnConfig(
+            name="answer",
+            image_column="image_path",
+            prompt="Find {{ target }} in the image.",
+            system_prompt="Prefer {{ style }} answers.",
+            model_alias="vision",
+        )
+
+        assert config.required_columns == ["image_path", "target", "style"]
+
+    def test_side_effect_columns_follow_options(self) -> None:
+        config = VisualSearchColumnConfig(
+            name="answer",
+            image_column="image_path",
+            prompt="Find the object.",
+            model_alias="vision",
+            with_trace=TraceType.LAST_MESSAGE,
+            extract_reasoning_content=True,
+        )
+
+        assert config.side_effect_columns == [
+            "answer__image_history",
+            "answer__trace",
+            "answer__reasoning_content",
+        ]
+
+    def test_base64_data_type_requires_image_format(self) -> None:
+        with pytest.raises(ValueError, match="image_format is required"):
+            VisualSearchColumnConfig(
+                name="answer",
+                image_column="image_base64",
+                prompt="Find the object.",
+                model_alias="vision",
+                image_data_type=ModalityDataType.BASE64,
+            )
+
+    def test_base64_data_type_accepts_image_format(self) -> None:
+        config = VisualSearchColumnConfig(
+            name="answer",
+            image_column="image_base64",
+            prompt="Find the object.",
+            model_alias="vision",
+            image_data_type=ModalityDataType.BASE64,
+            image_format=ImageFormat.PNG,
+        )
+
+        assert config.image_format == ImageFormat.PNG
+
+
+class TestVisualImageWorkspace:
+    def test_tools_create_branching_image_tree(self, tmp_path) -> None:
+        image_path = tmp_path / "scene.png"
+        _write_test_image(image_path)
+        workspace = VisualImageWorkspace(source_value=str(image_path), base_path=tmp_path)
+
+        root = workspace.open_image()
+        crop = workspace.crop_image(root["image_id"], x=0, y=0, width=50, height=50, unit="percent")
+        transform = workspace.transform_image(root["image_id"], flip_horizontal=True, resize_width=40)
+        color_edit = workspace.edit_color(crop["image_id"], saturation=0.0, contrast=1.5)
+
+        assert root["image_id"] == "img_0000"
+        assert crop["image_id"] == "img_0001"
+        assert crop["width"] == 50
+        assert crop["height"] == 40
transform["parent_image_id"] == root["image_id"] + assert color_edit["parent_image_id"] == crop["image_id"] + assert workspace.get_image_info(root["image_id"])["children_image_ids"] == ["img_0001", "img_0002"] + assert workspace.image_context_block(color_edit["image_id"])["image_url"]["url"].startswith("data:image/png") + + def test_executor_returns_json_tool_results(self, tmp_path) -> None: + image_path = tmp_path / "scene.png" + _write_test_image(image_path) + workspace = VisualImageWorkspace(source_value=str(image_path), base_path=tmp_path) + executor = VisualSearchToolExecutor(workspace=workspace, allowed_tools=["open_image", "crop_image"]) + + open_result = executor.execute("open_image", "{}") + crop_result = executor.execute( + "crop_image", + json.dumps({"image_id": "img_0000", "x": 10, "y": 10, "width": 20, "height": 20}), + ) + blocked_result = executor.execute("edit_color", json.dumps({"image_id": "img_0000"})) + + assert json.loads(open_result.content)["result"]["image_id"] == "img_0000" + assert crop_result.image_ids == ["img_0001"] + assert json.loads(blocked_result.content)["ok"] is False + + +class TestVisualSearchColumnGenerator: + def test_generate_executes_tool_loop_and_attaches_resulting_image(self, tmp_path) -> None: + image_path = tmp_path / "scene.png" + _write_test_image(image_path) + fake_model = FakeVisionModel() + generator = _make_generator( + VisualSearchColumnConfig( + name="answer", + image_column="image_path", + prompt="Crop the red square and answer what color it is.", + model_alias="vision", + image_placeholder="", + with_trace=TraceType.LAST_MESSAGE, + ), + fake_model=fake_model, + artifact_path=tmp_path, + ) + + result = generator.generate({"image_path": str(image_path)}) + + assert result["answer"] == "The cropped object is red." 
+ assert [node["image_id"] for node in result["answer__image_history"]] == ["img_0000", "img_0001"] + assert result["answer__trace"][0]["role"] == "assistant" + assert len(fake_model.requests) == 2 + + initial_request_messages = fake_model.requests[0]["messages"] + assert initial_request_messages[1]["content"][0]["text"].startswith("") + second_request_messages = fake_model.requests[1]["messages"] + attached_blocks = second_request_messages[-1]["content"] + assert any(block["type"] == "image_url" for block in attached_blocks) + assert attached_blocks[0]["text"].startswith("") + assert "tools" in fake_model.requests[0]["kwargs"] + + +class FakeVisionModel: + def __init__(self) -> None: + self.requests: list[dict] = [] + + def completion(self, messages: list, **kwargs) -> ChatCompletionResponse: + self.requests.append({"messages": [message.to_dict() for message in messages], "kwargs": kwargs}) + if len(self.requests) == 1: + return ChatCompletionResponse( + message=AssistantMessage( + tool_calls=[ + ToolCall( + id="call_crop", + name="crop_image", + arguments_json=json.dumps( + {"image_id": "img_0000", "x": 0, "y": 0, "width": 50, "height": 50, "unit": "percent"} + ), + ) + ] + ) + ) + return ChatCompletionResponse(message=AssistantMessage(content="The cropped object is red.")) + + +def _make_generator( + config: VisualSearchColumnConfig, + *, + fake_model: FakeVisionModel, + artifact_path, +) -> VisualSearchColumnGenerator: + generator = VisualSearchColumnGenerator.__new__(VisualSearchColumnGenerator) + generator._config = config + generator._resource_provider = SimpleNamespace( + artifact_storage=ArtifactStorage(artifact_path=artifact_path), + run_config=RunConfig(), + ) + generator.__dict__["model"] = fake_model + return generator + + +def _write_test_image(path) -> None: + image = Image.new("RGB", (100, 80), "white") + draw = ImageDraw.Draw(image) + draw.rectangle((0, 0, 50, 40), fill="red") + draw.rectangle((50, 40, 99, 79), fill="blue") + image.save(path) diff --git a/uv.lock b/uv.lock index 97c44fd..e196618 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ resolution-markers = [ members = [ "data-designer-plugins-workspace", "data-designer-template", + "data-designer-visual-search", "ddp", ] constraints = [{ name = "idna", specifier = ">=3.12" }] @@ -448,6 +449,23 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "data-designer", specifier = ">=0.5.7" }] +[[package]] +name = "data-designer-visual-search" +version = "0.1.0" +source = { editable = "plugins/data-designer-visual-search" } +dependencies = [ + { name = "data-designer" }, + { name = "pillow" }, + { name = "requests" }, +] + +[package.metadata] +requires-dist = [ + { name = "data-designer", specifier = ">=0.5.7" }, + { name = "pillow" }, + { name = "requests" }, +] + [[package]] name = "ddp" version = "0.1.0" diff --git a/zensical.toml b/zensical.toml index 3f1af80..7bf284d 100644 --- a/zensical.toml +++ b/zensical.toml @@ -23,6 +23,11 @@ nav = [ {"Overview" = "plugins/data-designer-template/index.md"}, {"Usage" = "plugins/data-designer-template/usage.md"}, ]}, + {"data-designer-visual-search" = [ + {"Overview" = "plugins/data-designer-visual-search/index.md"}, + {"Practical Examples" = "plugins/data-designer-visual-search/examples.md"}, + {"Usage" = "plugins/data-designer-visual-search/usage.md"}, + ]}, # END GENERATED PLUGIN DOCS NAV ]}, ]