Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions lightly_studio/src/lightly_studio/api/routes/api/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,49 @@

from lightly_studio.api.routes.api.status import HTTP_STATUS_NOT_FOUND
from lightly_studio.db_manager import SessionDep
from lightly_studio.models.collection import SampleType
from lightly_studio.plugins.base_operator import OperatorResult, OperatorStatus
from lightly_studio.plugins.execution_context import ExecutionContext
from lightly_studio.plugins.operator_registry import RegisteredOperatorMetadata, operator_registry
from lightly_studio.plugins.operator_scope import (
OperatorScope,
get_allowed_scopes_for_collection,
get_scope_for_sample_type,
)
from lightly_studio.plugins.parameter import BaseParameter
from lightly_studio.resolvers import collection_resolver
from lightly_studio.resolvers.image_filter import ImageFilter
from lightly_studio.resolvers.sample_resolver.sample_filter import SampleFilter
from lightly_studio.resolvers.video_resolver.video_filter import VideoFilter

operator_router = APIRouter(prefix="/operators", tags=["operators"])

HTTP_STATUS_CONFLICT = 409
HTTP_STATUS_UNPROCESSABLE = 422


class OperatorContextRequest(BaseModel):
"""Client-supplied execution context for scoped operator calls.

If ``sample_id`` is provided, the API translates it to a sample-id filter
before invoking the operator.
"""

collection_id: UUID | None = None
sample_id: UUID | None = None
filter: ImageFilter | VideoFilter | None = None


class ExecuteOperatorRequest(BaseModel):
"""Request model for executing an operator."""

parameters: dict[str, Any]
context: OperatorContextRequest | None = None


@operator_router.get("")
def get_operators() -> list[RegisteredOperatorMetadata]:
"""Get all registered operators (id, name, status)."""
"""Get all registered operators (id, name, status, scopes)."""
return operator_registry.get_all_metadata()


Expand Down Expand Up @@ -57,7 +82,7 @@ def execute_operator(
Args:
operator_id: The ID of the operator to execute.
collection_id: The ID of the collection to operate on.
request: The execution request containing parameters.
request: The execution request containing parameters and optional context.
session: Database session.

Returns:
Expand All @@ -77,9 +102,66 @@ def execute_operator(
detail=f"Operator '{operator_id}' is not ready (status: {operator.status.value})",
)

req_ctx = request.context

# Resolve the target collection: the context may specify a focused sub-collection
# (e.g. a frame or video collection within a group); fall back to the route collection.
target_collection_id = (
req_ctx.collection_id if req_ctx and req_ctx.collection_id else collection_id
)
target_collection = collection_resolver.get_by_id(session=session, collection_id=target_collection_id)
if target_collection is None:
raise HTTPException(
status_code=HTTP_STATUS_NOT_FOUND,
detail=f"Collection '{target_collection_id}' not found",
)

collection_type = target_collection.sample_type
allowed_scopes = get_allowed_scopes_for_collection(
sample_type=collection_type,
is_root_collection=target_collection.parent_collection_id is None,
)
effective_filter = _build_filter_from_context(
filter_obj=req_ctx.filter if req_ctx else None,
sample_id=req_ctx.sample_id if req_ctx else None,
collection_type=collection_type,
)

# Validate that the operator supports at least one valid scope for this collection context.
if not any(scope in operator.supported_scopes for scope in allowed_scopes):
raise HTTPException(
status_code=HTTP_STATUS_UNPROCESSABLE,
detail=(
f"Operator '{operator_id}' does not support scope(s) "
f"{[scope.value for scope in allowed_scopes]}. "
f"Supported scopes: {[s.value for s in operator.supported_scopes]}"
),
)

context = ExecutionContext(
collection_id=target_collection_id,
filter=effective_filter,
)

return operator.execute(
session=session,
collection_id=collection_id,
context=context,
parameters=request.parameters,
)


def _build_filter_from_context(
*,
filter_obj: ImageFilter | VideoFilter | None,
sample_id: UUID | None,
collection_type: SampleType,
) -> ImageFilter | VideoFilter | None:
"""Translate ``sample_id`` to a dedicated filter payload."""
if sample_id is None:
return filter_obj

sample_filter = SampleFilter(sample_ids=[sample_id])

if get_scope_for_sample_type(sample_type=collection_type) == OperatorScope.VIDEO:
return VideoFilter(sample_filter=sample_filter)
return ImageFilter(sample_filter=sample_filter)
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
from lightly_studio.models.annotation.annotation_base import AnnotationCreate, AnnotationType
from lightly_studio.models.annotation_label import AnnotationLabelCreate
from lightly_studio.plugins.base_operator import BaseOperator, OperatorResult
from lightly_studio.plugins.execution_context import ExecutionContext
from lightly_studio.plugins.operator_scope import OperatorScope
from lightly_studio.plugins.parameter import BaseParameter, FloatParameter, StringParameter
from lightly_studio.resolvers import (
annotation_label_resolver,
annotation_resolver,
image_resolver,
tag_resolver,
)
from lightly_studio.resolvers.image_filter import ImageFilter
from lightly_studio.resolvers.sample_resolver.sample_filter import SampleFilter

DEFAULT_INPUT_TAG = "unlabeled"
DEFAULT_MODEL_NAME = "dinov3/convnext-tiny-ltdetr-coco"
DEFAULT_SCORE_THRESHOLD = 0.5

PARAM_INPUT_TAG = "input_tag"
PARAM_MODEL_NAME = "model_name"
PARAM_SCORE_THRESHOLD = "score_threshold"

Expand All @@ -38,7 +36,17 @@ class LightlyTrainObjectDetectionInferenceOperator(BaseOperator):
"""Runs LightlyTrain object detection inference to auto-label images."""

name: str = "LightlyTrain object detection inference"
description: str = "Runs object detection inference and adds annotations to unlabeled images."
description: str = (
"Runs object detection inference and adds annotations to unlabeled samples. "
"Supports image collections."
)

@property
def supported_scopes(self) -> list[OperatorScope]:
"""Return the list of scopes this operator can be triggered from."""
return [
OperatorScope.IMAGE,
]

@property
def parameters(self) -> list[BaseParameter]:
Expand All @@ -56,43 +64,26 @@ def parameters(self) -> list[BaseParameter]:
default=DEFAULT_SCORE_THRESHOLD,
description="Minimum score for keeping a prediction.",
),
StringParameter(
name=PARAM_INPUT_TAG,
required=True,
default=DEFAULT_INPUT_TAG,
description="Tag of samples to auto-label.",
),
]

def execute(
self,
*,
session: Session,
collection_id: UUID,
context: ExecutionContext,
parameters: dict[str, Any],
) -> OperatorResult:
"""Execute the operator with the given parameters."""
collection_id = context.collection_id
model_name = str(parameters.get(PARAM_MODEL_NAME, DEFAULT_MODEL_NAME))
score_threshold = float(parameters.get(PARAM_SCORE_THRESHOLD, DEFAULT_SCORE_THRESHOLD))
input_tag = str(parameters.get(PARAM_INPUT_TAG, DEFAULT_INPUT_TAG))

if score_threshold < 0.0 or score_threshold > 1.0:
return OperatorResult(
success=False,
message="score_threshold must be between 0 and 1",
)

input_tag_entry = tag_resolver.get_by_name(
session=session,
tag_name=input_tag,
collection_id=collection_id,
)
if input_tag_entry is None:
return OperatorResult(
success=False,
message=f"Tag '{input_tag}' not found.",
)

model = lightly_train.load_model(model=model_name)
label_map = _get_or_create_label_map(
session=session,
Expand All @@ -103,15 +94,13 @@ def execute(
samples_result = image_resolver.get_all_by_collection_id(
session=session,
collection_id=collection_id,
filters=ImageFilter(
sample_filter=SampleFilter(tag_ids=[input_tag_entry.tag_id]),
),
filters=context.filter if isinstance(context.filter, ImageFilter) else None,
)
samples = list(samples_result.samples)
if not samples:
return OperatorResult(
success=True,
message=f"No samples found for tag '{input_tag}'.",
message="No samples found for the current filter.",
)

annotations_to_create: list[AnnotationCreate] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import time
from dataclasses import dataclass
from typing import Any
from uuid import UUID

from sqlmodel import Session

from lightly_studio.examples.coco_plugins_demo import lightly_train_inference_operator
from lightly_studio.plugins.base_operator import BaseOperator, OperatorResult
from lightly_studio.plugins.execution_context import ExecutionContext
from lightly_studio.plugins.operator_scope import OperatorScope
from lightly_studio.plugins.parameter import BaseParameter, IntParameter, StringParameter
from lightly_studio.resolvers import image_resolver, tag_resolver
from lightly_studio.resolvers.image_filter import ImageFilter
Expand All @@ -35,6 +36,11 @@ class LightlyTrainObjectDetectionTrainingOperator(BaseOperator):
name: str = "LightlyTrain object detection training"
description: str = "Runs training for labeled images."

@property
def supported_scopes(self) -> list[OperatorScope]:
"""Return the list of scopes this operator can be triggered from."""
return [OperatorScope.IMAGE]

@property
def parameters(self) -> list[BaseParameter]:
"""Return the list of parameters this operator expects."""
Expand Down Expand Up @@ -63,10 +69,11 @@ def execute(
self,
*,
session: Session,
collection_id: UUID,
context: ExecutionContext,
parameters: dict[str, Any],
) -> OperatorResult:
"""Execute the operator with the given parameters."""
collection_id = context.collection_id
model_name = str(parameters.get(PARAM_MODEL_NAME, DEFAULT_MODEL_NAME))
checkpoint_name = model_name
input_tag = str(parameters.get(PARAM_INPUT_TAG, DEFAULT_INPUT_TAG))
Expand Down
14 changes: 10 additions & 4 deletions lightly_studio/src/lightly_studio/examples/example_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

from dataclasses import dataclass
from typing import Any
from uuid import UUID

from environs import Env
from sqlmodel import Session

import lightly_studio as ls
from lightly_studio import db_manager
from lightly_studio.plugins.base_operator import BaseOperator, OperatorResult
from lightly_studio.plugins.execution_context import ExecutionContext
from lightly_studio.plugins.operator_registry import operator_registry
from lightly_studio.plugins.operator_scope import OperatorScope
from lightly_studio.plugins.parameter import (
BaseParameter,
BoolParameter,
Expand All @@ -29,6 +30,11 @@ class TestOperator(BaseOperator):
name: str = "test operator"
description: str = "used to test the operator and registry system"

@property
def supported_scopes(self) -> list[OperatorScope]:
"""Return the list of scopes this operator can be triggered from."""
return [OperatorScope.ROOT]

@property
def parameters(self) -> list[BaseParameter]:
"""Return the list of parameters this operator expects."""
Expand Down Expand Up @@ -59,14 +65,14 @@ def execute(
self,
*,
session: Session,
collection_id: UUID,
context: ExecutionContext,
parameters: dict[str, Any],
) -> OperatorResult:
"""Execute the operator with the given parameters.

Args:
session: Database session.
collection_id: ID of the collection to operate on.
context: The context for the operator (collection_id).
parameters: Parameters passed to the operator.

Returns:
Expand All @@ -84,7 +90,7 @@ def execute(
+ " "
+ str(parameters.get("test int"))
+ " "
+ str(collection_id)
+ str(context.collection_id)
+ str(session),
)

Expand Down
21 changes: 16 additions & 5 deletions lightly_studio/src/lightly_studio/plugins/base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any
from uuid import UUID

from sqlmodel import Session

from lightly_studio.plugins.execution_context import ExecutionContext
from lightly_studio.plugins.operator_scope import OperatorScope
from lightly_studio.plugins.parameter import BaseParameter


Expand Down Expand Up @@ -59,6 +60,16 @@ def description(self) -> str:
def parameters(self) -> list[BaseParameter]:
"""Return the list of parameters this operator expects."""

@property
@abstractmethod
def supported_scopes(self) -> list[OperatorScope]:
"""Return the list of scopes this operator can be triggered from.

Determines where in the UI the operator is surfaced.
``OperatorScope.ROOT`` targets dataset/root collections.
``OperatorScope.IMAGE`` covers both images and video frames.
"""

# --- Lifecycle methods ---

async def start(self) -> None:
Expand Down Expand Up @@ -89,18 +100,18 @@ def execute(
self,
*,
session: Session,
collection_id: UUID,
context: ExecutionContext,
parameters: dict[str, Any],
) -> OperatorResult:
"""Execute the operator with the given parameters.

Args:
session: Database session.
collection_id: ID of the collection to operate on.
parameters: Parameters passed to the operator.
context: Execution context containing collection_id and optional filter.
parameters: Parameters passed by the user.

Returns:
Dictionary with 'success' (bool) and 'message' (str) keys.
An OperatorResult with success flag and message.
"""
# TODO (Jonas 11/2025): The parameters dict should be validated against self.parameters,
# for now we leave it to the operator implementation.
Loading
Loading