Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions src/eaa/image_proc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Literal, Optional, List
from typing import Literal, Optional, List, Tuple

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sciagent.message_proc import generate_openai_message
from sciagent.task_manager.base import BaseTaskManager


def stitch_images(
Expand Down Expand Up @@ -50,10 +52,12 @@ def stitch_images(
return buffer


def windowed_phase_cross_correlation(
def phase_cross_correlation(
moving: np.ndarray,
ref: np.ndarray,
) -> np.ndarray:
return_correlation_value: bool = False,
use_hanning_window: bool = True,
) -> np.ndarray | Tuple[np.ndarray, float]:
"""Phase correlation with windowing. The result gives
the offset of the moving image with respect to the reference image.
If the moving image is shifted to the right, the result will have a
Expand All @@ -66,6 +70,11 @@ def windowed_phase_cross_correlation(
A 2D image.
ref : np.ndarray
A 2D image.
return_correlation_value : bool, optional
If True, the correlation value is returned along with the offset.
use_hanning_window : bool, optional
If True, a Hanning window is used to smooth the images before the
correlation is computed.

Returns
-------
Expand All @@ -75,12 +84,16 @@ def windowed_phase_cross_correlation(
assert np.all(np.array(moving.shape) == np.array(ref.shape)), (
"The shapes of the moving and reference images must be the same."
)
win_y = np.hanning(moving.shape[0])
win_x = np.hanning(moving.shape[1])
win = np.outer(win_y, win_x)
if use_hanning_window:
win_y = np.hanning(moving.shape[0])
win_x = np.hanning(moving.shape[1])
win = np.outer(win_y, win_x)

f_moving = np.fft.fft2(moving * win)
f_ref = np.fft.fft2(ref * win)
f_moving = np.fft.fft2(moving * win)
f_ref = np.fft.fft2(ref * win)
else:
f_moving = np.fft.fft2(moving)
f_ref = np.fft.fft2(ref)

f_corr = f_moving * f_ref.conj()
f_corr = f_corr / np.abs(f_corr)
Expand All @@ -90,7 +103,10 @@ def windowed_phase_cross_correlation(
for i in range(2):
if shift[i] > map.shape[i] / 2:
shift[i] -= map.shape[i]
return shift
if return_correlation_value:
return shift, np.max(map)
else:
return shift


def physical_pos_to_pixel(
Expand Down Expand Up @@ -276,3 +292,45 @@ def add_marker_to_imgae(
raise ValueError(f"Invalid marker type: {marker_type}")

return ax.get_figure()


def check_feature_presence_llm(
task_manager: Optional[BaseTaskManager],
image: np.ndarray,
reference_image: np.ndarray,
n_votes: int = 1,
) -> bool:
"""Lets an LLM judge if the features in the reference image
are present in the current image.

Returns
-------
bool
Whether the feature is present in the current image.
"""
stitched_image = stitch_images([reference_image, image], gap=10)
message = generate_openai_message(\
role="system",
content=(
"Are the non-periodic features in the image on the left also present in the image on the right?\n"
"- Features don't have to be exactly aligned, and one may be blurrier than another.\n"
"- 'Periodic features' refers to repeating patterns like grids, repeating dots, etc. "
"They should not be considered as features.\n"
"- Just answer with 'yes' or 'no'."
),
image=stitched_image
)
votes = []
for _ in range(n_votes):
while True:
response, outgoing = task_manager.agent.receive(
message,
return_outgoing_message=True
)
if task_manager is not None:
task_manager.update_message_history(outgoing, update_context=False, update_full_history=True)
task_manager.update_message_history(response, update_context=False, update_full_history=True)
if "yes" in response["content"].lower() or "no" in response["content"].lower():
votes.append(True if "yes" in response["content"].lower() else False)
break
return np.mean(votes) >= 0.5
200 changes: 200 additions & 0 deletions src/eaa/task_manager/imaging/analytical_feature_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from typing import Optional, Tuple
import copy
import logging

import numpy as np
from sciagent.api.llm_config import LLMConfig
from sciagent.api.memory import MemoryManagerConfig

from eaa.tool.imaging.acquisition import AcquireImage
from eaa.tool.imaging.registration import ImageRegistration
from eaa.task_manager.imaging.base import ImagingBaseTaskManager
from eaa.image_proc import check_feature_presence_llm

logger = logging.getLogger(__name__)


class AnalyticalFeatureTrackingTaskManager(ImagingBaseTaskManager):

def __init__(
self,
llm_config: LLMConfig = None,
memory_config: Optional[MemoryManagerConfig] = None,
image_acquisition_tool: AcquireImage = None,
message_db_path: Optional[str] = None,
build: bool = True,
image_acquisition_tool_x_coordinate_args: Tuple[str, ...] = ("x_center",),
image_acquisition_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",),
*args, **kwargs
) -> None:
"""Move the FOV in a spiral pattern to look for a feature in a
reference image.

Parameters
----------
llm_config : LLMConfig
The configuration for the LLM.
memory_config : MemoryManagerConfig, optional
Memory configuration forwarded to the agent.
image_acquisition_tool : AcquireImage
The tool to use to acquire images.
message_db_path : Optional[str]
If provided, the entire chat history will be stored in
a SQLite database at the given path. This is essential
if you want to use the WebUI, which polls the database
for new messages.
build : bool
Whether to build the internal state of the task manager.
image_acquisition_tool_x_coordinate_args: Tuple[str, ...]
The names of the arguments of the image acquisition tool that specify x-coordinates.
image_acquisition_tool_y_coordinate_args: Tuple[str, ...]
The names of the arguments of the image acquisition tool that specify y-coordinates.
"""
if image_acquisition_tool is None:
raise ValueError("image_acquisition_tool must be provided.")
if llm_config is None:
raise ValueError("llm_config must be provided for feature presence detection.")

self.image_acquisition_tool = image_acquisition_tool
self.image_registration_tool = self.create_image_registration_tool(image_acquisition_tool)

self.image_acquisition_tool_x_coordinate_args = image_acquisition_tool_x_coordinate_args
self.image_acquisition_tool_y_coordinate_args = image_acquisition_tool_y_coordinate_args

super().__init__(
llm_config=llm_config,
memory_config=memory_config,
tools=[],
message_db_path=message_db_path,
build=build,
*args, **kwargs
)

def create_image_registration_tool(self, acquisition_tool: AcquireImage):
image_registration_tool = ImageRegistration(
image_acquisition_tool=acquisition_tool,
reference_image=None,
reference_pixel_size=1.0,
image_coordinates_origin="top_left",
)
return image_registration_tool

@staticmethod
def get_position_deltas(idx: int, step_size: Tuple[float, float]) -> Tuple[float, float]:
"""Get the delta of y/x positions of the FOV relative to the initial position
given the index of the current FOV in the spiral pattern.

Parameters
----------
idx : int
The index of the current FOV in the spiral pattern.
step_size : Tuple[float, float]
The step size of the spiral pattern in y/x directions.

Returns
-------
Tuple[float, float]
The delta of y/x positions of the FOV relative to the initial position.
"""
if idx == 0:
return 0, 0

# Determine the "radius", or the layer of the loop in the spiral pattern.
r = 1
while idx >= (2 * r + 1) ** 2:
r += 1
idx_current_loop = idx - (2 * (r - 1) + 1) ** 2
side_len = 2 * r

# Top edge (moving left to right, includes top-right corner)
if idx_current_loop < side_len:
iy = -r
ix = -r + 1 + idx_current_loop
# Right edge (moving top to bottom, includes bottom-right corner)
elif idx_current_loop < 2 * side_len:
iy = -r + 1 + (idx_current_loop - side_len)
ix = r
# Bottom edge (moving right to left, includes bottom-left corner)
elif idx_current_loop < 3 * side_len:
iy = r
ix = r - 1 - (idx_current_loop - 2 * side_len)
# Left edge (moving bottom to top, includes top-left corner)
elif idx_current_loop < 4 * side_len:
iy = r - 1 - (idx_current_loop - 3 * side_len)
ix = -r
else:
raise ValueError(f"Invalid index: {idx}")
return iy * step_size[0], ix * step_size[1]

def update_kwargs_buffers(
self,
current_acquisition_kwargs: dict,
y_delta: float,
x_delta: float,
):
for arg in self.image_acquisition_tool_y_coordinate_args:
current_acquisition_kwargs[arg] += y_delta
for arg in self.image_acquisition_tool_x_coordinate_args:
current_acquisition_kwargs[arg] += x_delta
return current_acquisition_kwargs

def run(
self,
current_acquisition_kwargs: dict,
reference_image: np.ndarray,
step_size: Tuple[float, float],
reference_image_pixel_size: float = 1.0,
n_max_rounds: int = 20,
) -> np.ndarray:
"""Run the feature tracking task manager.

Parameters
----------
current_acquisition_kwargs: dict
The current kwargs of the image acquisition tool.
reference_image: np.ndarray
A 2D numpy array of the reference image to look for the feature in.
step_size: Tuple[float, float]
The step size of the spiral pattern in y/x directions.
n_max_rounds: int
The maximum number of rounds to run the feature tracking task manager.
correlation_threshold: float
The threshold of the correlation value to consider the feature present
in the current image.

Returns
-------
np.ndarray
Offset in y and x. If these offsets are added to the initial positions
in `initial_acquisition_kwargs`, the FOV should be aligned with the reference
image.
"""
initial_acquisition_kwargs = copy.deepcopy(current_acquisition_kwargs)
self.image_registration_tool.set_reference_image(
reference_image, reference_pixel_size=reference_image_pixel_size
)

for i in range(n_max_rounds):
y_delta, x_delta = self.get_position_deltas(i, step_size)
acquisition_kwargs = self.update_kwargs_buffers(
copy.deepcopy(initial_acquisition_kwargs), y_delta, x_delta
)
self.image_acquisition_tool.acquire_image(**acquisition_kwargs)
image = self.image_acquisition_tool.image_k

# Get offset with windowing
offset = self.image_registration_tool.register_images(
image,
reference_image,
psize_t=self.image_acquisition_tool.psize_k,
psize_r=self.image_registration_tool.reference_pixel_size,
return_correlation_value=False,
use_hanning_window=True
)
if check_feature_presence_llm(
task_manager=self,
image=image,
reference_image=reference_image,
):
break
return np.array([y_delta, x_delta]) + offset
Loading
Loading