diff --git a/src/eaa/image_proc.py b/src/eaa/image_proc.py index fd23152..9ff7db2 100644 --- a/src/eaa/image_proc.py +++ b/src/eaa/image_proc.py @@ -1,8 +1,10 @@ -from typing import Literal, Optional, List +from typing import Literal, Optional, List, Tuple import numpy as np from PIL import Image import matplotlib.pyplot as plt +from sciagent.message_proc import generate_openai_message +from sciagent.task_manager.base import BaseTaskManager def stitch_images( @@ -50,10 +52,12 @@ def stitch_images( return buffer -def windowed_phase_cross_correlation( +def phase_cross_correlation( moving: np.ndarray, ref: np.ndarray, -) -> np.ndarray: + return_correlation_value: bool = False, + use_hanning_window: bool = True, +) -> np.ndarray | Tuple[np.ndarray, float]: """Phase correlation with windowing. The result gives the offset of the moving image with respect to the reference image. If the moving image is shifted to the right, the result will have a @@ -66,6 +70,11 @@ def windowed_phase_cross_correlation( A 2D image. ref : np.ndarray A 2D image. + return_correlation_value : bool, optional + If True, the correlation value is returned along with the offset. + use_hanning_window : bool, optional + If True, a Hanning window is used to smooth the images before the + correlation is computed. Returns ------- @@ -75,12 +84,16 @@ def windowed_phase_cross_correlation( assert np.all(np.array(moving.shape) == np.array(ref.shape)), ( "The shapes of the moving and reference images must be the same." ) - win_y = np.hanning(moving.shape[0]) - win_x = np.hanning(moving.shape[1]) - win = np.outer(win_y, win_x) + if use_hanning_window: + win_y = np.hanning(moving.shape[0]) + win_x = np.hanning(moving.shape[1]) + win = np.outer(win_y, win_x) - f_moving = np.fft.fft2(moving * win) - f_ref = np.fft.fft2(ref * win) + f_moving = np.fft.fft2(moving * win) + f_ref = np.fft.fft2(ref * win) + else: + f_moving = np.fft.fft2(moving) + f_ref = np.fft.fft2(ref) f_corr = f_moving * f_ref.conj() f_corr = f_corr / np.abs(f_corr) @@ -90,7 +103,10 @@ def windowed_phase_cross_correlation( for i in range(2): if shift[i] > map.shape[i] / 2: shift[i] -= map.shape[i] - return shift + if return_correlation_value: + return shift, np.max(map) + else: + return shift def physical_pos_to_pixel( @@ -276,3 +292,45 @@ def add_marker_to_imgae( raise ValueError(f"Invalid marker type: {marker_type}") return ax.get_figure() + + +def check_feature_presence_llm( + task_manager: Optional[BaseTaskManager], + image: np.ndarray, + reference_image: np.ndarray, + n_votes: int = 1, +) -> bool: + """Lets an LLM judge if the features in the reference image + are present in the current image. + + Returns + ------- + bool + Whether the feature is present in the current image. + """ + stitched_image = stitch_images([reference_image, image], gap=10) + message = generate_openai_message(\ + role="system", + content=( + "Are the non-periodic features in the image on the left also present in the image on the right?\n" + "- Features don't have to be exactly aligned, and one may be blurrier than another.\n" + "- 'Periodic features' refers to repeating patterns like grids, repeating dots, etc. " + "They should not be considered as features.\n" + "- Just answer with 'yes' or 'no'." + ), + image=stitched_image + ) + votes = [] + for _ in range(n_votes): + while True: + response, outgoing = task_manager.agent.receive( + message, + return_outgoing_message=True + ) + if task_manager is not None: + task_manager.update_message_history(outgoing, update_context=False, update_full_history=True) + task_manager.update_message_history(response, update_context=False, update_full_history=True) + if "yes" in response["content"].lower() or "no" in response["content"].lower(): + votes.append(True if "yes" in response["content"].lower() else False) + break + return np.mean(votes) >= 0.5 diff --git a/src/eaa/task_manager/imaging/analytical_feature_tracking.py b/src/eaa/task_manager/imaging/analytical_feature_tracking.py new file mode 100644 index 0000000..cb0aa50 --- /dev/null +++ b/src/eaa/task_manager/imaging/analytical_feature_tracking.py @@ -0,0 +1,200 @@ +from typing import Optional, Tuple +import copy +import logging + +import numpy as np +from sciagent.api.llm_config import LLMConfig +from sciagent.api.memory import MemoryManagerConfig + +from eaa.tool.imaging.acquisition import AcquireImage +from eaa.tool.imaging.registration import ImageRegistration +from eaa.task_manager.imaging.base import ImagingBaseTaskManager +from eaa.image_proc import check_feature_presence_llm + +logger = logging.getLogger(__name__) + + +class AnalyticalFeatureTrackingTaskManager(ImagingBaseTaskManager): + + def __init__( + self, + llm_config: LLMConfig = None, + memory_config: Optional[MemoryManagerConfig] = None, + image_acquisition_tool: AcquireImage = None, + message_db_path: Optional[str] = None, + build: bool = True, + image_acquisition_tool_x_coordinate_args: Tuple[str, ...] = ("x_center",), + image_acquisition_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), + *args, **kwargs + ) -> None: + """Move the FOV in a spiral pattern to look for a feature in a + reference image. + + Parameters + ---------- + llm_config : LLMConfig + The configuration for the LLM. + memory_config : MemoryManagerConfig, optional + Memory configuration forwarded to the agent. + image_acquisition_tool : AcquireImage + The tool to use to acquire images. + message_db_path : Optional[str] + If provided, the entire chat history will be stored in + a SQLite database at the given path. This is essential + if you want to use the WebUI, which polls the database + for new messages. + build : bool + Whether to build the internal state of the task manager. + image_acquisition_tool_x_coordinate_args: Tuple[str, ...] + The names of the arguments of the image acquisition tool that specify x-coordinates. + image_acquisition_tool_y_coordinate_args: Tuple[str, ...] + The names of the arguments of the image acquisition tool that specify y-coordinates. + """ + if image_acquisition_tool is None: + raise ValueError("image_acquisition_tool must be provided.") + if llm_config is None: + raise ValueError("llm_config must be provided for feature presence detection.") + + self.image_acquisition_tool = image_acquisition_tool + self.image_registration_tool = self.create_image_registration_tool(image_acquisition_tool) + + self.image_acquisition_tool_x_coordinate_args = image_acquisition_tool_x_coordinate_args + self.image_acquisition_tool_y_coordinate_args = image_acquisition_tool_y_coordinate_args + + super().__init__( + llm_config=llm_config, + memory_config=memory_config, + tools=[], + message_db_path=message_db_path, + build=build, + *args, **kwargs + ) + + def create_image_registration_tool(self, acquisition_tool: AcquireImage): + image_registration_tool = ImageRegistration( + image_acquisition_tool=acquisition_tool, + reference_image=None, + reference_pixel_size=1.0, + image_coordinates_origin="top_left", + ) + return image_registration_tool + + @staticmethod + def get_position_deltas(idx: int, step_size: Tuple[float, float]) -> Tuple[float, float]: + """Get the delta of y/x positions of the FOV relative to the initial position + given the index of the current FOV in the spiral pattern. + + Parameters + ---------- + idx : int + The index of the current FOV in the spiral pattern. + step_size : Tuple[float, float] + The step size of the spiral pattern in y/x directions. + + Returns + ------- + Tuple[float, float] + The delta of y/x positions of the FOV relative to the initial position. + """ + if idx == 0: + return 0, 0 + + # Determine the "radius", or the layer of the loop in the spiral pattern. + r = 1 + while idx >= (2 * r + 1) ** 2: + r += 1 + idx_current_loop = idx - (2 * (r - 1) + 1) ** 2 + side_len = 2 * r + + # Top edge (moving left to right, includes top-right corner) + if idx_current_loop < side_len: + iy = -r + ix = -r + 1 + idx_current_loop + # Right edge (moving top to bottom, includes bottom-right corner) + elif idx_current_loop < 2 * side_len: + iy = -r + 1 + (idx_current_loop - side_len) + ix = r + # Bottom edge (moving right to left, includes bottom-left corner) + elif idx_current_loop < 3 * side_len: + iy = r + ix = r - 1 - (idx_current_loop - 2 * side_len) + # Left edge (moving bottom to top, includes top-left corner) + elif idx_current_loop < 4 * side_len: + iy = r - 1 - (idx_current_loop - 3 * side_len) + ix = -r + else: + raise ValueError(f"Invalid index: {idx}") + return iy * step_size[0], ix * step_size[1] + + def update_kwargs_buffers( + self, + current_acquisition_kwargs: dict, + y_delta: float, + x_delta: float, + ): + for arg in self.image_acquisition_tool_y_coordinate_args: + current_acquisition_kwargs[arg] += y_delta + for arg in self.image_acquisition_tool_x_coordinate_args: + current_acquisition_kwargs[arg] += x_delta + return current_acquisition_kwargs + + def run( + self, + current_acquisition_kwargs: dict, + reference_image: np.ndarray, + step_size: Tuple[float, float], + reference_image_pixel_size: float = 1.0, + n_max_rounds: int = 20, + ) -> np.ndarray: + """Run the feature tracking task manager. + + Parameters + ---------- + current_acquisition_kwargs: dict + The current kwargs of the image acquisition tool. + reference_image: np.ndarray + A 2D numpy array of the reference image to look for the feature in. + step_size: Tuple[float, float] + The step size of the spiral pattern in y/x directions. + n_max_rounds: int + The maximum number of rounds to run the feature tracking task manager. + correlation_threshold: float + The threshold of the correlation value to consider the feature present + in the current image. + + Returns + ------- + np.ndarray + Offset in y and x. If these offsets are added to the initial positions + in `initial_acquisition_kwargs`, the FOV should be aligned with the reference + image. + """ + initial_acquisition_kwargs = copy.deepcopy(current_acquisition_kwargs) + self.image_registration_tool.set_reference_image( + reference_image, reference_pixel_size=reference_image_pixel_size + ) + + for i in range(n_max_rounds): + y_delta, x_delta = self.get_position_deltas(i, step_size) + acquisition_kwargs = self.update_kwargs_buffers( + copy.deepcopy(initial_acquisition_kwargs), y_delta, x_delta + ) + self.image_acquisition_tool.acquire_image(**acquisition_kwargs) + image = self.image_acquisition_tool.image_k + + # Get offset with windowing + offset = self.image_registration_tool.register_images( + image, + reference_image, + psize_t=self.image_acquisition_tool.psize_k, + psize_r=self.image_registration_tool.reference_pixel_size, + return_correlation_value=False, + use_hanning_window=True + ) + if check_feature_presence_llm( + task_manager=self, + image=image, + reference_image=reference_image, + ): + break + return np.array([y_delta, x_delta]) + offset diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py new file mode 100644 index 0000000..901dad8 --- /dev/null +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -0,0 +1,438 @@ +from typing import Optional, Tuple, Sequence, Literal +import logging +import copy +import json + +import numpy as np +import botorch.acquisition + +from sciagent.api.llm_config import LLMConfig +from sciagent.api.memory import MemoryManagerConfig + +from eaa.tool.imaging.acquisition import AcquireImage +from eaa.tool.imaging.param_tuning import SetParameters +from eaa.task_manager.imaging.analytical_feature_tracking import AnalyticalFeatureTrackingTaskManager +from eaa.task_manager.tuning.base import BaseParameterTuningTaskManager +from eaa.tool.imaging.registration import ImageRegistration +from eaa.tool.bo import BayesianOptimizationTool +from eaa.util import to_numpy +from eaa.image_proc import check_feature_presence_llm + +logger = logging.getLogger(__name__) + + +class AnalyticalScanningMicroscopeFocusingTaskManager(BaseParameterTuningTaskManager): + + def __init__( + self, + llm_config: LLMConfig = None, + memory_config: Optional[MemoryManagerConfig] = None, + param_setting_tool: SetParameters = None, + acquisition_tool: AcquireImage = None, + initial_parameters: dict[str, float] = None, + parameter_ranges: list[tuple[float, ...], tuple[float, ...]] = None, + message_db_path: Optional[str] = None, + build: bool = True, + line_scan_tool_x_coordinate_args: Tuple[str, ...] = ("x_center",), + line_scan_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), + image_acquisition_tool_x_coordinate_args: Tuple[str, ...] = ("x_center",), + image_acquisition_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), + *args, **kwargs + ): + """Analytical scanning microscope focusing task manager driven + by logic instead of LLM. + + The workflow is as follows: + 1. Acquire a 2D image in the user-specified region of interest. + 2. Run a line scan at user-specified coordinates and record the FWHM of the Gaussian fit. + 3. Change parameter and acquire a new 2D image. + 4.1. If the same feature remains in the FOV, run image registration to get the offset and + adjust 1D/2D scan coordinates. + 4.2. If the feature is no longer in the FOV, run a spiral feature tracking to find the feature. + 5. Repeat 1 - 3 a few times to collect initial data for Bayesian optimization. + 6. Use Bayesian optimization to suggest new parameters. + 7. Change parameter. + 8. Run image registration or feature tracking as in 4. + 9. Run line scan and record the FWHM of the Gaussian fit, update Gaussian process model. + 10. Repeat 6 - 9 until the FWHM is minimized. + + Parameters + ---------- + llm_config : LLMConfig, optional + The LLM configuration to use. + memory_config : MemoryManagerConfig, optional + Memory configuration forwarded to the agent. + param_setting_tool : SetParameters + The tool to use to set the parameters. + acquisition_tool : AcquireImage + The BaseTool object used to acquire data. It should contain a 2D + image acquisition tool and a line scan tool. + bo_tool : BayesianOptimizationTool, optional + The Bayesian optimization tool to use. + image_registration_tool : ImageRegistration, optional + The image registration tool. This tool is optional and is only + used for the feature tracking sub-task if `use_feature_tracking_subtask` + is True. To use registration in the focusing task manager, refer to + ``use_registration_in_workflow`` in the ``run`` method. + initial_parameters : dict[str, float], optional + The initial parameters given as a dictionary of + parameter names and values. + parameter_ranges : list[tuple[float, ...], tuple[float, ...]] + The ranges of the parameters. It should be given as a list of + 2 tuples, where the first tuple gives the lower bounds and the + second tuple gives the upper bounds. The order of the parameters + should match the order of the initial parameters. + message_db_path : Optional[str], optional + If provided, the entire chat history will be stored in + a SQLite database at the given path. This is essential + if you want to use the WebUI, which polls the database + for new messages. + build : bool, optional + Whether to build the internal state of the task manager. + line_scan_tool_x_coordinate_args: Tuple[str, ...] + The names of the arguments of the line scan tool that specify x-coordinates. + When the lab-frame coordinates drift and offsets are found, these arguments + will be updated using the offsets. + line_scan_tool_y_coordinate_args: Tuple[str, ...] + See `line_scan_tool_x_coordinate_args`. + image_acquisition_tool_x_coordinate_args: Tuple[str, ...] + See `line_scan_tool_x_coordinate_args`. + image_acquisition_tool_y_coordinate_args: Tuple[str, ...] + See `line_scan_tool_y_coordinate_args`. + """ + if acquisition_tool is None: + raise ValueError("`acquisition_tool` must be provided.") + + self.acquisition_tool = acquisition_tool + self.bo_tool = self.create_bo_tool(parameter_ranges) + self.image_registration_tool = self.create_image_registration_tool(acquisition_tool) + + if hasattr(acquisition_tool, "line_scan_return_gaussian_fit"): + acquisition_tool.line_scan_return_gaussian_fit = True + else: + logger.warning( + "`line_scan_return_gaussian_fit` attribute is not found in the acquisition tool." + ) + + self.last_acquisition_count_registered = 0 + self.last_acquisition_count_stitched = 0 + + self.feature_tracking_task_manager: Optional[AnalyticalFeatureTrackingTaskManager] = None + + self.line_scan_tool_x_coordinate_args = line_scan_tool_x_coordinate_args + self.line_scan_tool_y_coordinate_args = line_scan_tool_y_coordinate_args + self.image_acquisition_tool_x_coordinate_args = image_acquisition_tool_x_coordinate_args + self.image_acquisition_tool_y_coordinate_args = image_acquisition_tool_y_coordinate_args + + self.line_scan_kwargs = {} + self.image_acquisition_kwargs = {} + + super().__init__( + llm_config=llm_config, + memory_config=memory_config, + param_setting_tool=param_setting_tool, + initial_parameters=initial_parameters, + parameter_ranges=parameter_ranges, + message_db_path=message_db_path, + build=build, + *args, **kwargs + ) + + def create_bo_tool(self, parameter_ranges: list[tuple[float, ...], tuple[float, ...]]): + bo_tool = BayesianOptimizationTool( + bounds=parameter_ranges, + n_observations=1, + kernel_lengthscales=None, + acquisition_function_class=botorch.acquisition.UpperConfidenceBound, + acquisition_function_kwargs={"beta": 1.0}, + ) + return bo_tool + + def create_image_registration_tool(self, acquisition_tool: AcquireImage): + image_registration_tool = ImageRegistration( + image_acquisition_tool=acquisition_tool, + reference_image=None, + reference_pixel_size=1.0, + image_coordinates_origin="top_left", + ) + return image_registration_tool + + def prerun_check( + self, + initial_sampling_range: Optional[Tuple[float, float]], + parameter_change_step_limit: Optional[float | Tuple[float, ...]] + ) -> bool: + if initial_sampling_range is None: + raise ValueError("initial_sampling_range must be provided.") + if len(initial_sampling_range) != len(self.parameter_names): + raise ValueError( + f"The length of initial_sampling_range must be the same as the number of parameters, " + f"but got {len(initial_sampling_range)} and {len(self.parameter_names)}." + ) + if isinstance(parameter_change_step_limit, Sequence): + if len(parameter_change_step_limit) != len(self.parameter_names): + raise ValueError( + f"The length of parameter_change_step_limit must be the same as the number of parameters, " + f"but got {len(parameter_change_step_limit)} and {len(self.parameter_names)}." + ) + return True + + def run( + self, + initial_2d_scan_kwargs: dict = None, + initial_line_scan_kwargs: dict = None, + n_initial_points: int = 5, + initial_sampling_window_size: Optional[Tuple[float, ...]] = None, + n_max_bo_iterations: int = 99, + parameter_change_step_limit: Optional[float | Tuple[float, ...]] = None, + termination_behavior: Literal["ask", "return"] = "ask", + *args, **kwargs + ): + """Run the focusing task. + + Parameters + ---------- + initial_line_scan_kwargs: dict + The keyword arguments for the initial line scan. The argument should + match the signature of the `scan_line` method of the acquisition tool. + initial_2d_scan_kwargs: dict + The keyword arguments for the initial 2D scan. The argument should + match the signature of the `acquire_image` method of the acquisition tool. + n_initial_line_scans: int + The number of initial points to prime the Gaussian process model. + initial_sampling_range: Optional[Tuple[float, float]] + The range over which the initial measurements for Bayesian optimization + are sampled. Should be a tuple with the same length as the number of parameters. + n_max_bo_iterations: int + The maximum number of Bayesian optimization iterations. + parameter_change_step_limit: float + The limit on the step size of the parameter change. Parameter changes + are clipped to this limit if the absolute difference between the one + suggested by BO and the current parameter value is larger than this limit. + If None, no limit is applied. + termination_behavior: Literal["ask", "return"] + The behavior when the task manager reaches the maximum number of Bayesian + optimization iterations. If "ask", the task manager will ask the user for + input. If "return", the task manager will return. + """ + try: + self.prerun_check(initial_sampling_window_size, parameter_change_step_limit) + self.initialize_kwargs_buffers(initial_line_scan_kwargs, initial_2d_scan_kwargs) + + # Initial 2D scan to populate image buffer of acquisition tool. + self.run_2d_scan() + + # Initialize BO tool. + self.collect_initial_data_for_bo( + current_x=np.array(list(self.initial_parameters.values())), + sampling_range=initial_sampling_window_size, + n=n_initial_points, + ) + self.bo_tool.build(acquisition_function_kwargs=None) + + # Run Bayesian optimization. + for i_iter in range(n_max_bo_iterations): + iter_message = f"Running Bayesian optimization iteration {i_iter}..." + logger.info(iter_message) + self.record_system_message(iter_message) + p_suggested = self.get_suggested_next_parameters(parameter_change_step_limit) + suggestion_message = f"Suggested parameter: {p_suggested}" + logger.info(suggestion_message) + self.record_system_message(suggestion_message) + self.run_tuning_iteration(p_suggested) + report = self.generate_report_csv() + final_report_message = f"Final report:\n{report}" + logger.info(final_report_message) + self.record_system_message(final_report_message, update_context=True) + except KeyboardInterrupt: + pass + + if termination_behavior == "ask": + self.run_conversation() + elif termination_behavior == "return": + return + else: + raise ValueError( + f"Invalid termination behavior: {termination_behavior}. " + "Must be one of 'ask' or 'return'." + ) + + def initialize_kwargs_buffers( + self, initial_line_scan_kwargs: dict, initial_2d_scan_kwargs: dict + ): + self.line_scan_kwargs = copy.deepcopy(initial_line_scan_kwargs) + self.image_acquisition_kwargs = copy.deepcopy(initial_2d_scan_kwargs) + + def run_line_scan(self) -> float: + """Run a line scan and return the FWHM of the Gaussian fit. + + Returns + ------- + float + The FWHM of the Gaussian fit. + """ + res = self.acquisition_tool.acquire_line_scan(**self.line_scan_kwargs) + try: + res = json.loads(res) + except json.JSONDecodeError: + raise ValueError( + f"The line scan tool should return a stringified JSON object, but got {res}." + ) + if "fwhm" not in res: + raise ValueError( + f"The stringified JSON object should contain the 'fwhm' key, but got {res}." + ) + content = f"Line scan completed with kwargs {self.line_scan_kwargs}. FWHM = {res['fwhm']:.4f}" + image_path = res.get("image_path") + if isinstance(image_path, str): + self.record_system_message(content, image_path=image_path) + else: + self.record_system_message(content) + return res["fwhm"] + + def update_bo_model(self, fwhm: float): + x = self.param_setting_tool.get_parameter_at_iteration(-1) + x = np.array(x).reshape(1, -1) + # Use negative FWHM because we want to minimize the FWHM. + self.bo_tool.update(x, -np.array([[fwhm]])) + + def run_2d_scan(self): + image_path = self.acquisition_tool.acquire_image(**self.image_acquisition_kwargs) + content = f"Acquired 2D scan with kwargs: {self.image_acquisition_kwargs}" + if isinstance(image_path, str): + self.record_system_message(content, image_path=image_path) + else: + self.record_system_message(content) + + def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[float, ...]] = None): + p_suggested = to_numpy(self.bo_tool.suggest(n_suggestions=1)[0]) + p_current = to_numpy(self.param_setting_tool.get_parameter_at_iteration(-1)) + if step_size_limit is not None: + signs = np.sign(p_suggested - p_current) + step_sizes = np.abs(p_suggested - p_current) + step_sizes = np.clip(step_sizes, min=None, max=step_size_limit) + p_suggested = p_current + signs * step_sizes + return p_suggested + + def find_offset_and_feature_presence(self) -> Tuple[np.ndarray, bool]: + """Find the offset between the latest image and the previous image + and check if the feature is present. + + Returns + ------- + np.ndarray + The offset between the latest image and the previous image. + Offset is in physical units, i.e., pixel size is already accounted for. + bool + Whether the feature is present in the current image. + """ + image_k = self.acquisition_tool.image_k + image_km1 = self.acquisition_tool.image_km1 + + if self.llm_config is None: + logger.warning("`llm_config` is not provided. Unable to check if the feature is present.") + is_present = True + else: + is_present = check_feature_presence_llm( + task_manager=self, + image=image_k, + reference_image=image_km1, + ) + + shift = self.image_registration_tool.register_images( + image_t=self.image_registration_tool.process_image(image_k), + image_r=self.image_registration_tool.process_image(image_km1), + psize_t=self.acquisition_tool.psize_k, + psize_r=self.acquisition_tool.psize_km1, + return_correlation_value=False, + ) + + # Count in the difference of scan positions. + scan_pos_diff = np.array([ + float(self.acquisition_tool.image_acquisition_call_history[-1][f"loc_{dir}"]) + - float(self.acquisition_tool.image_acquisition_call_history[-2][f"loc_{dir}"]) + for dir in ["y", "x"] + ]) + shift += scan_pos_diff + return shift, is_present + + def apply_offset_to_kwargs_buffers(self, offset: np.ndarray): + for arg in self.line_scan_tool_x_coordinate_args: + self.line_scan_kwargs[arg] += offset[1] + for arg in self.line_scan_tool_y_coordinate_args: + self.line_scan_kwargs[arg] += offset[0] + for arg in self.image_acquisition_tool_x_coordinate_args: + self.image_acquisition_kwargs[arg] += offset[1] + for arg in self.image_acquisition_tool_y_coordinate_args: + self.image_acquisition_kwargs[arg] += offset[0] + + def collect_initial_data_for_bo( + self, + current_x: np.ndarray, + sampling_range: np.ndarray, + n: int = 5, + ): + if len(sampling_range) != len(self.parameter_names): + raise ValueError( + f"The length of sampling_range must be the same as the number of parameters, " + f"but got {len(sampling_range)} and {len(self.parameter_names)}." + ) + sampling_range = np.array(sampling_range) + if len(current_x) != len(self.parameter_names): + raise ValueError( + f"The length of current_x must be the same as the number of parameters, " + f"but got {len(current_x)} and {len(self.parameter_names)}." + ) + current_x = np.array(current_x) + + xs = np.linspace(current_x - sampling_range, current_x + sampling_range, n) + for x in xs: + self.run_tuning_iteration(x) + + def run_tuning_iteration(self, x: np.ndarray): + if len(x) != len(self.parameter_names): + raise ValueError( + f"The length of x must be the same as the number of parameters, " + f"but got {len(x)} and {len(self.parameter_names)}." + ) + x = np.array(x) + self.param_setting_tool.set_parameters(x) + self.run_2d_scan() + offset, is_present = self.find_offset_and_feature_presence() + if not is_present: + msg = "Feature is not present in the current image. Running feature tracking sub-task." + logger.info(msg) + self.record_system_message(msg) + offset = self.run_feature_tracking_subtask() + self.apply_offset_to_kwargs_buffers(offset) + fwhm = self.run_line_scan() + self.update_bo_model(fwhm) + + def generate_report_csv(self) -> str: + xs = self.bo_tool.xs_untransformed.tolist() + fwhms = self.bo_tool.ys_untransformed.tolist() + report = "Parameters,FWHM\n" + for x, fwhm in zip(xs, fwhms): + report += f"{x[0]},{fwhm[0]}\n" + return report + + def run_feature_tracking_subtask(self): + if self.feature_tracking_task_manager is None: + self.feature_tracking_task_manager = AnalyticalFeatureTrackingTaskManager( + llm_config=self.llm_config, + image_acquisition_tool=self.acquisition_tool, + image_acquisition_tool_x_coordinate_args=self.image_acquisition_tool_x_coordinate_args, + image_acquisition_tool_y_coordinate_args=self.image_acquisition_tool_y_coordinate_args, + message_db_path=self.message_db_path, + ) + offset = self.feature_tracking_task_manager.run( + current_acquisition_kwargs=self.image_acquisition_kwargs, + reference_image=self.acquisition_tool.image_km1, + step_size=[ + self.acquisition_tool.image_acquisition_call_history[-1]["size_y"] * 0.8, + self.acquisition_tool.image_acquisition_call_history[-1]["size_x"] * 0.8 + ], + reference_image_pixel_size=self.acquisition_tool.psize_km1, + n_max_rounds=20, + ) + return offset diff --git a/src/eaa/tool/bo.py b/src/eaa/tool/bo.py index 391aaab..73000a3 100644 --- a/src/eaa/tool/bo.py +++ b/src/eaa/tool/bo.py @@ -151,7 +151,7 @@ def get_random_initial_points( + self.bounds[0] ) - def build(self) -> None: + def build(self, acquisition_function_kwargs: dict = None) -> None: """Build the Gaussian process model and data transform modules. This function should be called after the initial data are collected and updated to the tool using the `update` method. @@ -160,7 +160,7 @@ def build(self) -> None: self.train_transforms_and_transform_data() self.initialize_model(self.xs_transformed, self.ys_transformed) self.fit_kernel_hyperparameters() - self.build_acquisition_function() + self.build_acquisition_function(acquisition_function_kwargs) def initialize_model(self, x_train: torch.Tensor, y_train: torch.Tensor): """Initialize the Gaussian process model with recorded data. @@ -202,8 +202,10 @@ def fit_kernel_hyperparameters(self, *args, **kwargs): ) ) - def build_acquisition_function(self): + def build_acquisition_function(self, acquisition_function_kwargs: dict = None): """Build the acquisition function.""" + if acquisition_function_kwargs is not None: + self.acquisition_function_kwargs.update(acquisition_function_kwargs) self.acquisition_function = self.acquisition_function_class( model=self.model, **self.acquisition_function_kwargs, diff --git a/src/eaa/tool/imaging/acquisition.py b/src/eaa/tool/imaging/acquisition.py index 6d1b32e..1ee5169 100644 --- a/src/eaa/tool/imaging/acquisition.py +++ b/src/eaa/tool/imaging/acquisition.py @@ -1,6 +1,7 @@ from typing import Annotated, Dict, List, Any import logging import os +import json import matplotlib.pyplot as plt import numpy as np @@ -133,6 +134,7 @@ def __init__( line_scan_gaussian_fit_y_threshold: float = 0, add_line_scan_candidates_to_image: bool = False, plot_image_in_log_scale: bool = False, + line_scan_return_gaussian_fit: bool = False, *args, require_approval: bool = False, **kwargs @@ -164,6 +166,9 @@ def __init__( If True, the tool adds line scan candidates to the image. plot_image_in_log_scale : bool, optional If True, 2D images are plotted in log scale. + line_scan_return_gaussian_fit : bool, optional + If True, the function returns a stringified JSON object containing the image path + and the Gaussian fit FWHM. """ self.whole_image = whole_image self.interpolator = None @@ -178,6 +183,7 @@ def __init__( self.invert_yaxis = invert_yaxis self.add_line_scan_candidates_to_image = add_line_scan_candidates_to_image self.plot_image_in_log_scale = plot_image_in_log_scale + self.line_scan_return_gaussian_fit = line_scan_return_gaussian_fit self.line_scan_candidates: Dict[int, list[int]] = {} @@ -199,6 +205,7 @@ def build_interpolator(self, *args, **kwargs): np.arange(self.whole_image.shape[1]) ), self.whole_image, + bounds_error=False, ) def set_blur(self, blur: float): @@ -222,6 +229,7 @@ def set_offset(self, offset: np.ndarray): of (y, x) coordinates. """ self.offset = offset + logging.info(f"Offset set to {self.offset}") def add_line_scan_candidates( self, @@ -347,8 +355,8 @@ def acquire_image( else: return arr - @tool(name="scan_line", return_type=ToolReturnType.IMAGE_PATH) - def scan_line( + @tool(name="acquire_line_scan", return_type=ToolReturnType.IMAGE_PATH) + def acquire_line_scan( self, start_x: Annotated[float, "The x-coordinate of the starting point of the line scan."], start_y: Annotated[float, "The y-coordinate of the starting point of the line scan."], @@ -368,7 +376,7 @@ def scan_line( The ending point of the line scan. scan_step : float The step size of the line scan. - + Returns ------- str @@ -420,7 +428,17 @@ def scan_line( ) fig.savefig(fname) plt.close(fig) - return fname + if self.line_scan_return_gaussian_fit: + return json.dumps({ + "image_path": fname, + "fwhm": fwhm, + "a": a, + "mu": mu, + "sigma": sigma, + "c": c + }) + else: + return fname def scan_line_by_choice( self, @@ -447,4 +465,4 @@ def scan_line_by_choice( """ start_x, start_y, end_x, end_y = self.line_scan_candidates[choice] self.update_line_scan_call_history(start_x, start_y, end_x, end_y, scan_step) - return self.scan_line(start_x, start_y, end_x, end_y, scan_step=scan_step) + return self.acquire_line_scan(start_x, start_y, end_x, end_y, scan_step=scan_step) diff --git a/src/eaa/tool/imaging/aps_mic/acquisition.py b/src/eaa/tool/imaging/aps_mic/acquisition.py index b74d4b2..594fdde 100644 --- a/src/eaa/tool/imaging/aps_mic/acquisition.py +++ b/src/eaa/tool/imaging/aps_mic/acquisition.py @@ -1,6 +1,7 @@ from typing import Annotated, Tuple, Optional import logging import os +import json from sciagent.tool.base import ToolReturnType, tool @@ -49,6 +50,7 @@ def __init__( plot_image_in_log_scale: bool = False, show_colorbar_in_image: bool = False, require_approval: bool = False, + line_scan_return_gaussian_fit: bool = False, *args, **kwargs ): """Image acquisition tool with Bluesky. @@ -77,6 +79,9 @@ def __init__( Whether to plot the image in log scale. show_colorbar_in_image: bool, optional Whether to show the colorbar in the image. + line_scan_return_gaussian_fit: bool, optional + If True, the function returns a stringified JSON object containing the image path + and the Gaussian fit FWHM. Raises ------ @@ -96,6 +101,7 @@ def __init__( self.allowable_z_range = allowable_z_range self.plot_image_in_log_scale = plot_image_in_log_scale self.show_colorbar_in_image = show_colorbar_in_image + self.line_scan_return_gaussian_fit = line_scan_return_gaussian_fit super().__init__(*args, require_approval=require_approval, **kwargs) @@ -276,7 +282,7 @@ def acquire_line_scan( if not os.path.exists(png_output_dir): os.makedirs(png_output_dir) - img_path, _ = save_xrf_line_scan( + img_path, [_, _, _, fwhm] = save_xrf_line_scan( mda_file_path, png_output_dir, roi_num=self.xrf_roi_num, return_line_array=True ) @@ -284,7 +290,11 @@ def acquire_line_scan( # self.update_line_scan_buffers(img_arr, psize=stepsize_x) if img_path: - return img_path + if self.line_scan_return_gaussian_fit: + return json.dumps({ + "image_path": img_path, + "fwhm": fwhm, + }) else: logger.error(f"Failed to save images for {current_mda_file}") return f"Failed to save images for {current_mda_file}" diff --git a/src/eaa/tool/imaging/registration.py b/src/eaa/tool/imaging/registration.py index 7f58978..f0d3bf1 100644 --- a/src/eaa/tool/imaging/registration.py +++ b/src/eaa/tool/imaging/registration.py @@ -1,12 +1,13 @@ -from typing import Annotated, List, Literal +from typing import Annotated, List, Literal, Tuple import logging +import json import numpy as np import scipy.ndimage as ndi from sciagent.tool.base import BaseTool, check, ToolReturnType, tool from eaa.tool.imaging.acquisition import AcquireImage -from eaa.image_proc import windowed_phase_cross_correlation +from eaa.image_proc import phase_cross_correlation logger = logging.getLogger(__name__) @@ -140,8 +141,10 @@ def register_images( image_t: np.ndarray, image_r: np.ndarray, psize_t: float, - psize_r: float - ) -> np.ndarray: + psize_r: float, + return_correlation_value: bool = False, + use_hanning_window: bool = True, + ) -> np.ndarray | Tuple[np.ndarray, float] | str: """ Register the target image with the reference image. @@ -155,15 +158,23 @@ def register_images( The pixel size of the target image. psize_r : float The pixel size of the reference image. + return_correlation_value : bool, optional + If True, the correlation value is returned along with the offset. + use_hanning_window : bool, optional + If True, a Hanning window is used to smooth the images before the + correlation is computed. Returns ------- - np.ndarray - The offset of the target image with respect to the reference image. If the + np.ndarray | str + If `return_correlation_value` is False, the offset of the target + image with respect to the reference image is returned. If the target image is shifted to the right compared to the reference image, the result will have a positive x-component; if the target image is shifted to the bottom, the result will have a positive y-component. The returned values are in physical units, i.e., pixel size is already accounted for. + If `return_correlation_value` is True, a stringified JSON object with the + keys "offset" and "correlation_value" is returned. """ # Handle pixel size and image size differences if psize_t != psize_r: @@ -198,10 +209,15 @@ def register_images( f"Invalid value for image_coordinates_origin: {self.image_coordinates_origin}" ) - offset = windowed_phase_cross_correlation(image_t, image_r) + offset, correlation_value = phase_cross_correlation( + image_t, image_r, return_correlation_value=return_correlation_value, use_hanning_window=use_hanning_window + ) # Convert the offset from pixel units to physical units. We use psize_r here # since the target image has already been resized to have the same pixel size # as the reference image. offset = offset * psize_r - return offset + if return_correlation_value: + return json.dumps({"offset": offset.tolist(), "correlation_value": float(correlation_value)}) + else: + return offset diff --git a/tests/test_analytical_feature_tracking.py b/tests/test_analytical_feature_tracking.py new file mode 100644 index 0000000..701d7e5 --- /dev/null +++ b/tests/test_analytical_feature_tracking.py @@ -0,0 +1,108 @@ +import os +import argparse + +import numpy as np +import tifffile + +from eaa.task_manager.imaging.analytical_feature_tracking import ( + AnalyticalFeatureTrackingTaskManager, +) +from eaa.tool.imaging.acquisition import SimulatedAcquireImage + +import test_utils as tutils + +import logging + +logging.basicConfig(level=logging.INFO) + +class TestAnalyticalFeatureTracking(tutils.BaseTester): + def _build_task_manager(self): + image_path = os.path.join( + self.get_ci_input_data_dir(), + "simulated_images", + "grid_test_pattern_roi.tiff", + ) + image = tifffile.imread(image_path) + if image.ndim == 3: + image = image[..., 0] + + acquisition_tool = SimulatedAcquireImage( + whole_image=image, + add_axis_ticks=True, + add_grid_lines=False, + invert_yaxis=False, + add_line_scan_candidates_to_image=False, + plot_image_in_log_scale=False, + ) + task_manager = AnalyticalFeatureTrackingTaskManager( + image_acquisition_tool=acquisition_tool, + image_acquisition_tool_x_coordinate_args=("loc_x",), + image_acquisition_tool_y_coordinate_args=("loc_y",), + ) + return task_manager, acquisition_tool, image + + def test_get_position_deltas_matches_spiral_pattern(self): + task_manager, _, _ = self._build_task_manager() + expected_positions = [ + (0, 0), + (-1, 0), + (-1, 1), + (0, 1), + (1, 1), + (1, 0), + (1, -1), + (0, -1), + (-1, -1), + (-2, -1), + (-2, 0), + (-2, 1), + (-2, 2), + (-1, 2), + ] + for idx, expected in enumerate(expected_positions): + assert ( + task_manager.get_position_deltas(idx, (1, 1)) == expected + ), f"Index {idx} mismatch" + + def test_feature_tracking_run_returns_expected_offset(self): + task_manager, _, image = self._build_task_manager() + reference_loc = (60, 270) + size = (100, 100) + reference_image = image[ + reference_loc[0] : reference_loc[0] + size[0], + reference_loc[1] : reference_loc[1] + size[1], + ] + + drift = (100, 0) + current_kwargs = { + "loc_y": reference_loc[0] + drift[0], + "loc_x": reference_loc[1] + drift[1], + "size_y": size[0], + "size_x": size[1], + } + step_size = (80.0, 80.0) + offset = task_manager.run( + current_acquisition_kwargs=current_kwargs, + reference_image=reference_image, + step_size=step_size, + n_max_rounds=2, + correlation_threshold=0.5, + ) + expected_offset = np.array([-drift[0], -drift[1]]) + np.testing.assert_allclose(offset, expected_offset, atol=1.5) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--generate-gold", action="store_true") + args = parser.parse_args() + + tester = TestAnalyticalFeatureTracking() + tester.setup_method( + name="", + generate_data=False, + generate_gold=args.generate_gold, + debug=True, + ) + tester.test_feature_tracking_run_returns_expected_offset() + diff --git a/tests/test_analytical_focusing.py b/tests/test_analytical_focusing.py new file mode 100644 index 0000000..3609f37 --- /dev/null +++ b/tests/test_analytical_focusing.py @@ -0,0 +1,92 @@ +import argparse +import os + +import tifffile + +from eaa.task_manager.tuning.analytical_focusing import ( + AnalyticalScanningMicroscopeFocusingTaskManager, +) +from eaa.tool.imaging.acquisition import SimulatedAcquireImage +from eaa.tool.imaging.param_tuning import SimulatedSetParameters + +import test_utils as tutils + + +class TestAnalyticalFocusing(tutils.BaseTester): + def _build_task_manager(self): + image_path = os.path.join( + self.get_ci_input_data_dir(), + "simulated_images", + "grid_test_pattern_roi.tiff", + ) + image = tifffile.imread(image_path) + if image.ndim == 3: + image = image[..., 0] + + acquisition_tool = SimulatedAcquireImage( + whole_image=image, + add_axis_ticks=True, + add_grid_lines=False, + invert_yaxis=False, + add_line_scan_candidates_to_image=False, + plot_image_in_log_scale=False, + ) + + param_setting_tool = SimulatedSetParameters( + acquisition_tool=acquisition_tool, + parameter_names=["z"], + true_parameters=[3.0], + parameter_ranges=[(0.0,), (10.0,)], + drift_factor=10, + ) + + task_manager = AnalyticalScanningMicroscopeFocusingTaskManager( + param_setting_tool=param_setting_tool, + acquisition_tool=acquisition_tool, + initial_parameters={"z": 10.0}, + parameter_ranges=[(0.0,), (10.0,)], + line_scan_tool_x_coordinate_args=("start_x", "end_x"), + line_scan_tool_y_coordinate_args=("start_y", "end_y"), + image_acquisition_tool_x_coordinate_args=("loc_x",), + image_acquisition_tool_y_coordinate_args=("loc_y",), + ) + return task_manager, acquisition_tool + + def test_task_manager_runs(self): + task_manager, acquisition_tool = self._build_task_manager() + n_initial_points = 2 + n_bo_iterations = 1 + task_manager.run( + initial_2d_scan_kwargs={"loc_y": 0, "loc_x": 0, "size_y": 350, "size_x": 350}, + initial_line_scan_kwargs={ + "start_x": 130, + "start_y": 170, + "end_x": 190, + "end_y": 170, + "scan_step": 1.0, + }, + n_initial_points=n_initial_points, + initial_sampling_window_size=(0.5,), + n_max_bo_iterations=n_bo_iterations, + parameter_change_step_limit=0.5, + ) + assert ( + task_manager.param_setting_tool.len_parameter_history + == n_initial_points + n_bo_iterations + 1 + ) + assert acquisition_tool.counter_acquire_image >= n_initial_points + n_bo_iterations + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--generate-gold", action="store_true") + args = parser.parse_args() + + tester = TestAnalyticalFocusing() + tester.setup_method( + name="", + generate_data=False, + generate_gold=args.generate_gold, + debug=True, + ) + tester.test_task_manager_runs() diff --git a/tests/test_simulated_image_acquisition.py b/tests/test_simulated_image_acquisition.py index a5669fe..ed2b5bb 100644 --- a/tests/test_simulated_image_acquisition.py +++ b/tests/test_simulated_image_acquisition.py @@ -41,7 +41,7 @@ def test_simulated_line_scan(self): tool = SimulatedAcquireImage(whole_image, return_message=False) - fname = tool.scan_line( + fname = tool.acquire_line_scan( start_y=140, end_y=140, start_x=408, diff --git a/uv.lock b/uv.lock index 1e2cac6..b361eb7 100644 --- a/uv.lock +++ b/uv.lock @@ -6202,8 +6202,8 @@ wheels = [ [[package]] name = "sci-agent" -version = "0.1.dev3+g4682432b2" -source = { git = "https://github.com/mdw771/sci-agent#4682432b2cae6822ff6f61a177228cadc7ac8652" } +version = "0.1.dev8+g2a0e3e7a8" +source = { git = "https://github.com/mdw771/sci-agent#2a0e3e7a8b164e30a17274407cd6eb5e75b23c96" } dependencies = [ { name = "chromadb" }, { name = "fastapi" }, @@ -6573,6 +6573,7 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/be/f9/5e4491e5ccf42f5d9cfc663741d261b3e6e1683ae7812114e7636409fcc6/sqlalchemy-2.0.45.tar.gz", hash = "sha256:1632a4bda8d2d25703fdad6363058d882541bdaaee0e5e3ddfa0cd3229efce88", size = 9869912, upload-time = "2025-12-09T21:05:16.737Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/1c/769552a9d840065137272ebe86ffbb0bc92b0f1e0a68ee5266a225f8cd7b/sqlalchemy-2.0.45-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e90a344c644a4fa871eb01809c32096487928bd2038bf10f3e4515cb688cc56", size = 2153860, upload-time = "2025-12-10T20:03:23.843Z" }, { url = "https://files.pythonhosted.org/packages/f3/f8/9be54ff620e5b796ca7b44670ef58bc678095d51b0e89d6e3102ea468216/sqlalchemy-2.0.45-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b8c8b41b97fba5f62349aa285654230296829672fc9939cd7f35aab246d1c08b", size = 3309379, upload-time = "2025-12-09T22:06:07.461Z" }, { url = "https://files.pythonhosted.org/packages/f6/2b/60ce3ee7a5ae172bfcd419ce23259bb874d2cddd44f67c5df3760a1e22f9/sqlalchemy-2.0.45-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12c694ed6468333a090d2f60950e4250b928f457e4962389553d6ba5fe9951ac", size = 3309948, upload-time = "2025-12-09T22:09:57.643Z" }, { url = "https://files.pythonhosted.org/packages/a3/42/bac8d393f5db550e4e466d03d16daaafd2bad1f74e48c12673fb499a7fc1/sqlalchemy-2.0.45-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f7d27a1d977a1cfef38a0e2e1ca86f09c4212666ce34e6ae542f3ed0a33bc606", size = 3261239, upload-time = "2025-12-09T22:06:08.879Z" },