diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cd47c3591..6d0ae8cc0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,6 +2,8 @@ ci:
   autofix_prs: true
   autoupdate_commit_msg: ':technologist: pre-commit autoupdate'
   autoupdate_schedule: 'monthly'
+default_language_version:
+  python: python3.12
 repos:
   - repo: local
     hooks:
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 1aba18896..ddc6a28f2 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -33,6 +33,7 @@ timm>=1.0.3
 torch>=2.5.0
 torchvision>=0.15.0
 tqdm>=4.64.1
+transformers>=4.51.1
 umap-learn>=0.5.3
 wsidicom>=0.18.0
 zarr>=2.13.3, <3.0.0
diff --git a/tests/models/test_arch_sam.py b/tests/models/test_arch_sam.py
new file mode 100644
index 000000000..a28b55981
--- /dev/null
+++ b/tests/models/test_arch_sam.py
@@ -0,0 +1,63 @@
+"""Unit test package for SAM."""
+
+from collections.abc import Callable
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from tiatoolbox.models.architecture.sam import SAM
+from tiatoolbox.utils import env_detection as toolbox_env
+from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
+
+ON_GPU = toolbox_env.has_gpu()
+
+# Test pretrained Model =============================
+
+
+def test_functional_sam(remote_sample: Callable) -> None:
+    """Test for SAM."""
+    # convert to pathlib Path to prevent wsireader complaint
+    tile_path = Path(remote_sample("patch-extraction-vf"))
+    img = imread(tile_path)
+
+    # test creation
+
+    model = SAM(device=select_device(on_gpu=ON_GPU))
+
+    # create image patch and prompts
+    patch = img[63:191, 750:878, :]
+
+    points = np.array([[[64, 64]]])
+    boxes = np.array([[[64, 64, 128, 128]]])
+
+    # test preproc
+    tensor = torch.from_numpy(img)
+    patch = np.expand_dims(model.preproc(tensor), axis=0)
+    patch = model.preproc(patch)
+
+    # test inference
+
+    mask_output, score_output = model.infer_batch(
+        model, patch, points, device=select_device(on_gpu=ON_GPU)
+    )
+
+    assert mask_output is not None, "Output should not be None"
+    assert len(mask_output) > 0, "Output should have at least one element"
+    assert len(score_output) > 0, "Output should have at least one element"
+
+    mask_output, score_output = model.infer_batch(
+        model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU)
+    )
+
+    assert len(mask_output) > 0, "Output should have at least one element"
+    assert len(score_output) > 0, "Output should have at least one element"
+
+    # test error when no prompts provided
+    with pytest.raises(
+        ValueError,
+        match=r"At least one of point_coords or box_coords must be provided.",
+    ):
+        _ = model.infer_batch(model, patch, device=select_device(on_gpu=ON_GPU))
diff --git a/tests/models/test_prompt_segmentor.py b/tests/models/test_prompt_segmentor.py
new file mode 100644
index 000000000..41196fbe2
--- /dev/null
+++ b/tests/models/test_prompt_segmentor.py
@@ -0,0 +1,53 @@
+"""Unit test package for prompt segmentor."""
+
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+from tiatoolbox.annotation.storage import SQLiteStore
+from tiatoolbox.models.architecture.sam import SAM
+from tiatoolbox.models.engine.prompt_segmentor import PromptSegmentor
+
+
+def test_prompt_segmentor(track_tmp_path: Path) -> None:
+    """Test for Prompt Segmentor."""
+    # create dummy image patch 256x256x3 with small circle in middle
+    img = np.zeros((256, 256, 3), dtype=np.uint8)
+    cv2.circle(img, (128, 128), 50, (255, 255, 255), -1)
+
+    expected_area = np.pi * (50**2)
+
+    # prompt with pt in center of circle
+    points = np.array([[[128, 128]]], np.uint32)
+    boxes = None
+
+    # instantiate prompt segmentor with SAM model
+    sam_model = SAM()
+    prompt_segmentor = PromptSegmentor(model=sam_model)
+
+    # run prediction
+    output_paths = prompt_segmentor.run(
+        images=[img],
+        point_coords=points,
+        box_coords=boxes,
+        save_dir=track_tmp_path / "sam_test_output",
+        device="cpu",
+    )
+
+    assert isinstance(output_paths, list), "Output should be a list of paths"
+    assert len(output_paths) == 1, "Output list should contain one path"
+
+    # load the saved annotation db
+    store = SQLiteStore(output_paths[0])
+    ann = next(iter(store.values()))
+    # area should be close to expected area
+    assert abs(ann.geometry.area - expected_area) < expected_area * 0.1, (
+        "should segment circle area correctly"
+    )
+
+    # check with model=None
+    prompt_segmentor_default = PromptSegmentor(model=None)
+    assert isinstance(prompt_segmentor_default.model, SAM), (
+        "Default model should be SAM"
+    )
diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py
index 6760a5137..75181321d 100644
--- a/tests/test_app_bokeh.py
+++ b/tests/test_app_bokeh.py
@@ -558,6 +558,52 @@ def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> No
     assert len(main.UI["type_column"].children) == 1
 
 
+def test_sam_segment(doc: Document, data_path: pytest.TempPathFactory) -> None:
+    """Test running SAM on points and a box."""
+    slide_select = doc.get_model_by_name("slide_select0")
+    slide_select.value = [data_path["slide2"].name]
+    run_button = doc.get_model_by_name("to_model0")
+    assert len(main.UI["color_column"].children) == 0
+    slide_select.value = [data_path["slide1"].name]
+    # set up a box selection
+    main.UI["box_source"].data = {
+        "x": [1200],
+        "y": [-2000],
+        "width": [400],
+        "height": [400],
+    }
+
+    # select SAM model and run it on box
+    model_select = doc.get_model_by_name("model_drop0")
+    model_select.value = "SAM"
+
+    click = ButtonClick(run_button)
+    run_button._trigger_event(click)
+    assert len(main.UI["color_column"].children) > 0
+
+    # test save functionality
+    save_button = doc.get_model_by_name("save_button0")
+    click = ButtonClick(save_button)
+    save_button._trigger_event(click)
+    saved_path = (
+        data_path["base_path"]
+        / "overlays"
+        / (data_path["slide1"].stem + "_saved_anns.db")
+    )
+    assert saved_path.exists()
+
+    # load an overlay with different types
+    cprop_select = doc.get_model_by_name("cprop0")
+    cprop_select.value = ["prob"]
+    layer_drop = doc.get_model_by_name("layer_drop0")
+    click = MenuItemClick(layer_drop, str(data_path["dat_anns"]))
+    layer_drop._trigger_event(click)
+    assert main.UI["vstate"].types == ["annotation"]
+    # check the per-type ui controls have been updated
+    assert len(main.UI["color_column"].children) == 1
+    assert len(main.UI["type_column"].children) == 1
+
+
 def test_alpha_sliders(doc: Document) -> None:
     """Test sliders for adjusting slide and overlay alpha."""
     slide_alpha = doc.get_model_by_name("slide_alpha0")
diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py
index 39d1441ce..42b758c33 100644
--- a/tiatoolbox/models/__init__.py
+++ b/tiatoolbox/models/__init__.py
@@ -8,6 +8,7 @@
 from .architecture.mapde import MapDe
 from .architecture.micronet import MicroNet
 from .architecture.nuclick import NuClick
+from .architecture.sam import SAM
 from .architecture.sccnn import SCCNN
 from .engine.multi_task_segmentor import MultiTaskSegmentor
 from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
@@ -17,6 +18,7 @@
     PatchPredictor,
     WSIPatchDataset,
 )
+from .engine.prompt_segmentor import PromptSegmentor
 from .engine.semantic_segmentor import (
     DeepFeatureExtractor,
     IOSegmentorConfig,
@@ -25,6 +27,7 @@
 )
 
 __all__ = [
+    "SAM",
     "SCCNN",
     "HoVerNet",
     "HoVerNetPlus",
@@ -35,5 +38,6 @@
     "NuClick",
     "NucleusInstanceSegmentor",
     "PatchPredictor",
+    "PromptSegmentor",
     "SemanticSegmentor",
 ]
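Reviewer note, not part of the patch: with the `__init__.py` changes above, both new classes are re-exported from `tiatoolbox.models`. A minimal sketch of that import surface, assuming this branch is installed (instantiating `SAM` downloads the pretrained weights from the Hugging Face Hub):

# Illustrative only; mirrors the __all__ additions in the diff above.
from tiatoolbox.models import SAM, PromptSegmentor

segmentor = PromptSegmentor(model=None)  # model=None falls back to a default SAM()
assert isinstance(segmentor.model, SAM)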
diff --git a/tiatoolbox/models/architecture/sam.py b/tiatoolbox/models/architecture/sam.py
new file mode 100644
index 000000000..729baf585
--- /dev/null
+++ b/tiatoolbox/models/architecture/sam.py
@@ -0,0 +1,238 @@
+"""Define SAM architecture."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import SamModel, SamProcessor
+
+from tiatoolbox.models.models_abc import ModelABC
+
+
+class SAM(ModelABC):
+    """Segment Anything Model (SAM) Architecture.
+
+    Meta AI's zero-shot segmentation model.
+    SAM is used for interactive general-purpose segmentation.
+
+    Currently wraps the Hugging Face ``transformers`` implementation of SAM,
+    loaded from a pretrained model path.
+
+    SAM accepts an RGB image patch along with a list of point and bounding
+    box coordinates as prompts.
+
+    Args:
+        model_path (str):
+            Name or path of the pretrained SAM weights, for example a
+            Hugging Face Hub identifier such as "facebook/sam-vit-huge".
+        device (str):
+            Device to run inference on.
+
+    Examples:
+        >>> # instantiate SAM with a pretrained model path
+        >>> sam = SAM(
+        ...     model_path="facebook/sam-vit-huge",
+        ...     device="cpu",
+        ... )
+    """
+
+    def __init__(
+        self: SAM,
+        model_path: str = "facebook/sam-vit-huge",
+        *,
+        device: str = "cpu",
+    ) -> None:
+        """Initialize :class:`SAM`."""
+        super().__init__()
+        self.net_name = "SAM"
+        self.device = device
+
+        self.model = SamModel.from_pretrained(model_path).to(device)
+        self.processor = SamProcessor.from_pretrained(model_path)
+
+    def _process_prompts(
+        self: SAM,
+        image: np.ndarray,
+        embeddings: torch.Tensor,
+        orig_sizes: torch.Tensor,
+        reshaped_sizes: torch.Tensor,
+        points: list | None = None,
+        boxes: list | None = None,
+        point_labels: list | None = None,
+    ) -> tuple[list, list]:
+        """Process prompts and return masks and scores."""
+        inputs = self.processor(
+            image,
+            input_points=points,
+            input_labels=point_labels,
+            input_boxes=boxes,
+            return_tensors="pt",
+        ).to(self.device)
+
+        # Replaces pixel_values with image embeddings
+        inputs.pop("pixel_values", None)
+        inputs.update(
+            {
+                "image_embeddings": embeddings,
+                "original_sizes": orig_sizes,
+                "reshaped_input_sizes": reshaped_sizes,
+            }
+        )
+
+        with torch.inference_mode():
+            # Forward pass through the model
+            outputs = self.model(**inputs, multimask_output=False)
+            image_masks = self.processor.image_processor.post_process_masks(
+                outputs.pred_masks.cpu(),
+                inputs["original_sizes"].cpu(),
+                inputs["reshaped_input_sizes"].cpu(),
+            )
+            image_scores = outputs.iou_scores.cpu()
+
+        return image_masks, image_scores
+
+    def forward(  # skipcq: PYL-W0221
+        self: SAM,
+        imgs: list,
+        point_coords: list | None = None,
+        box_coords: list | None = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Torch method. Defines forward pass on each image in the batch.
+
+        Note: This architecture only uses a single layer, so only one forward pass
+        is needed.
+
+        Args:
+            imgs (list):
+                List of images to process, of the shape NHWC.
+            point_coords (list):
+                List of point coordinates for each image.
+            box_coords (list):
+                Bounding box coordinates for each image.
+
+        Returns:
+            tuple[np.ndarray, np.ndarray]:
+                Array of masks and scores for each image.
+
+        """
+        masks, scores = [], []
+
+        for i, img in enumerate(imgs):
+            image = [Image.fromarray(img)]
+            embeddings, orig_sizes, reshaped_sizes = self._encode_image(image)
+            point_labels = None
+            points = None
+            boxes = None
+
+            if box_coords is not None:
+                boxes = box_coords[i]
+                # Convert box coordinates to list
+                boxes = [boxes[:, None, :].tolist()]
+                image_masks, image_scores = self._process_prompts(
+                    image,
+                    embeddings,
+                    orig_sizes,
+                    reshaped_sizes,
+                    None,
+                    boxes,
+                    point_labels,
+                )
+                masks.append(np.array([image_masks]))
+                scores.append(np.array([image_scores]))
+
+            if point_coords is not None:
+                points = point_coords[i]
+                # Convert point coordinates to list
+                point_labels = np.ones((1, len(points), 1), dtype=int).tolist()
+                points = [points[:, None, :].tolist()]
+                image_masks, image_scores = self._process_prompts(
+                    image,
+                    embeddings,
+                    orig_sizes,
+                    reshaped_sizes,
+                    points,
+                    None,
+                    point_labels,
+                )
+                masks.append(np.array([image_masks]))
+                scores.append(np.array([image_scores]))
+
+        torch.cuda.empty_cache()
+
+        return np.concatenate(masks, axis=2), np.concatenate(scores, axis=2)
+
+    @staticmethod
+    def infer_batch(
+        model: torch.nn.Module,
+        batch_data: list,
+        point_coords: np.ndarray | None = None,
+        box_coords: np.ndarray | None = None,
+        *,
+        device: str = "cpu",
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Run inference on an input batch.
+
+        Contains logic for forward operation as well as I/O aggregation.
+        SAM accepts a list of points and a single bounding box per image.
+
+        Args:
+            model (nn.Module):
+                PyTorch defined model.
+            batch_data (list):
+                A batch of data generated by
+                `torch.utils.data.DataLoader`.
+            point_coords (np.ndarray | None):
+                Point coordinates for each image in the batch.
+            box_coords (np.ndarray | None):
+                Bounding box coordinates for each image in the batch.
+            device (str):
+                Device to run inference on.
+
+        Returns:
+            pred_info (tuple[np.ndarray, np.ndarray]):
+                Tuple of masks and scores for each image in the batch.
+
+        """
+        model.eval().to(device)
+        if point_coords is None and box_coords is None:
+            msg = "At least one of point_coords or box_coords must be provided."
+            raise ValueError(msg)
+
+        with torch.inference_mode():
+            masks, scores = model(batch_data, point_coords, box_coords)
+
+        return masks, scores
+
+    def _encode_image(self: SAM, image: np.ndarray) -> np.ndarray:
+        """Encodes image and stores size info for later mask post-processing."""
+        processed = self.processor(image, return_tensors="pt")
+        original_sizes = processed["original_sizes"]
+        reshaped_sizes = processed["reshaped_input_sizes"]
+
+        inputs = processed.to(self.device)
+        embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
+        return embeddings, original_sizes, reshaped_sizes
+
+    @staticmethod
+    def preproc(image: np.ndarray) -> np.ndarray:
+        """Pre-processes an image - Converts it into a format accepted by SAM (HWC)."""
+        # Move the tensor to the CPU if it's a PyTorch tensor
+        if isinstance(image, torch.Tensor):
+            image = image.permute(1, 2, 0).cpu().numpy()
+
+        return image[..., :3]  # Remove alpha channel if present
+
+    def to(
+        self: ModelABC,
+        device: str = "cpu",
+        dtype: torch.dtype | None = None,
+        *,
+        non_blocking: bool = False,
+    ) -> ModelABC | torch.nn.DataParallel[ModelABC]:
+        """Moves the model to the specified device."""
+        super().to(device, dtype=dtype, non_blocking=non_blocking)
+        self.device = device
+        self.model.to(device)
+        return self
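Reviewer note, not part of the patch: a minimal usage sketch of the wrapper above, mirroring the calling convention exercised by `tests/models/test_arch_sam.py`. The random patch and the prompt values are invented; `facebook/sam-vit-huge` is fetched from the Hugging Face Hub on first use, so the point here is the API shape rather than mask quality.

# Sketch only: follows the calling convention used in test_functional_sam.
import numpy as np

from tiatoolbox.models.architecture.sam import SAM

model = SAM(device="cpu")  # downloads facebook/sam-vit-huge on first call

# A single RGB patch in NHWC layout, as expected by forward()/infer_batch().
patch = np.random.randint(0, 255, size=(1, 128, 128, 3), dtype=np.uint8)

# One point prompt per image: (N_images, N_points, 2) in (x, y) order.
points = np.array([[[64, 64]]])
# One box prompt per image: (N_images, N_boxes, 4) as (x0, y0, x1, y1).
boxes = np.array([[[32, 32, 96, 96]]])

masks, scores = model.infer_batch(model, patch, point_coords=points, device="cpu")
print(len(masks), len(scores))  # one mask/IoU-score set per prompted image

masks, scores = model.infer_batch(model, patch, box_coords=boxes, device="cpu")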
diff --git a/tiatoolbox/models/engine/prompt_segmentor.py b/tiatoolbox/models/engine/prompt_segmentor.py
new file mode 100644
index 000000000..b8d550a9a
--- /dev/null
+++ b/tiatoolbox/models/engine/prompt_segmentor.py
@@ -0,0 +1,114 @@
+"""This module enables interactive segmentation."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from tiatoolbox.models.architecture.sam import SAM
+from tiatoolbox.utils.misc import dict_to_store_semantic_segmentor
+
+if TYPE_CHECKING:  # pragma: no cover
+    import torch
+
+    from tiatoolbox.type_hints import IntPair
+
+
+class PromptSegmentor:
+    """Engine for prompt-based segmentation of WSIs.
+
+    This class is designed to work with the SAM model architecture.
+    It allows for interactive segmentation by providing point and bounding box
+    coordinates as prompts. The model is intended to be used with image tiles
+    selected interactively in some way and provided as np.arrays. At least
+    one of either point_coords or box_coords must be provided to guide
+    segmentation.
+
+    Args:
+        model (SAM):
+            Model architecture to use. If None, defaults to SAM.
+
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module | None = None,
+    ) -> None:
+        """Initializes the PromptSegmentor."""
+        if model is None:
+            model = SAM()
+        self.model = model
+        self.scale = 1.0
+        self.offset = (0, 0)
+
+    def run(  # skipcq: PYL-W0221
+        self,
+        images: list,
+        point_coords: np.ndarray | None = None,
+        box_coords: np.ndarray | None = None,
+        save_dir: str | Path | None = None,
+        device: str = "cpu",
+    ) -> list[Path]:
+        """Run inference on image patches with prompts.
+
+        Args:
+            images (list):
+                List of image patch arrays to run inference on.
+            point_coords (np.ndarray):
+                N_im x N_points x 2 array of point coordinates for each image patch.
+            box_coords (np.ndarray):
+                N_im x N_boxes x 4 array of bounding box coordinates for each
+                image patch.
+            save_dir (str or Path):
+                Directory to save the output databases.
+            device (str):
+                Device to run inference on.
+
+        Returns:
+            list[Path]:
+                Paths to the saved output databases.
+        """
+        self.model.to(device)
+        paths = []
+        masks, _ = self.model.infer_batch(
+            self.model,
+            images,
+            point_coords=point_coords,
+            box_coords=box_coords,
+            device=device,
+        )
+        save_dir = Path(save_dir)
+        for i, _mask in enumerate(masks):
+            save_path = save_dir / f"{i}"
+            mask = np.any(_mask[0], axis=0, keepdims=False)
+            dict_to_store_semantic_segmentor(
+                patch_output={"predictions": mask[0]},
+                scale_factor=self.scale,
+                offset=self.offset,
+                save_path=Path(f"{save_path}.{i}.db"),
+            )
+            paths.append(Path(f"{save_path}.{i}.db"))
+        return paths
+
+    def calc_mpp(
+        self, area_dims: IntPair, base_mpp: float, fixed_size: int = 1500
+    ) -> tuple[float, float]:
+        """Calculates the microns per pixel for a fixed area of an image.
+
+        Args:
+            area_dims (tuple):
+                Dimensions of the area to be scaled.
+            base_mpp (float):
+                Microns per pixel of the base image.
+            fixed_size (int):
+                Fixed size of the area.
+
+        Returns:
+            tuple[float, float]:
+                The scaled microns per pixel and the scale factor applied.
+        """
+        scale = max(area_dims) / fixed_size if max(area_dims) > fixed_size else 1.0
+        self.scale = scale
+        return base_mpp * scale, scale
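Reviewer note, not part of the patch: an end-to-end sketch of the engine above, based on `tests/models/test_prompt_segmentor.py`. The synthetic circle image, the temporary save directory and the `calc_mpp` numbers are all invented for illustration; `calc_mpp` applies `scale = max(area_dims) / fixed_size` only when the region exceeds `fixed_size`, and stores the scale on the engine for later annotation rescaling.

# Sketch only; follows test_prompt_segmentor with placeholder paths and values.
import tempfile
from pathlib import Path

import cv2
import numpy as np

from tiatoolbox.models.engine.prompt_segmentor import PromptSegmentor

segmentor = PromptSegmentor()  # model=None defaults to SAM()

# Synthetic patch with a white circle and a single point prompt at its centre.
img = np.zeros((256, 256, 3), dtype=np.uint8)
cv2.circle(img, (128, 128), 50, (255, 255, 255), -1)
points = np.array([[[128, 128]]], np.uint32)

db_paths = segmentor.run(
    images=[img],
    point_coords=points,
    box_coords=None,
    save_dir=Path(tempfile.mkdtemp()) / "sam_out",
    device="cpu",
)
print(db_paths)  # one SQLiteStore .db path per input image

# calc_mpp: a 3000 x 2000 px region read into a 1500 px working image
# doubles the effective mpp and records scale=2.0 on the engine.
mpp, scale = segmentor.calc_mpp((3000, 2000), base_mpp=0.5, fixed_size=1500)
print(mpp, scale)  # 1.0 2.0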
+ """ + self.model.to(device) + paths = [] + masks, _ = self.model.infer_batch( + self.model, + images, + point_coords=point_coords, + box_coords=box_coords, + device=device, + ) + save_dir = Path(save_dir) + for i, _mask in enumerate(masks): + save_path = save_dir / f"{i}" + mask = np.any(_mask[0], axis=0, keepdims=False) + dict_to_store_semantic_segmentor( + patch_output={"predictions": mask[0]}, + scale_factor=self.scale, + offset=self.offset, + save_path=Path(f"{save_path}.{i}.db"), + ) + paths.append(Path(f"{save_path}.{i}.db")) + return paths + + def calc_mpp( + self, area_dims: IntPair, base_mpp: float, fixed_size: int = 1500 + ) -> float: + """Calculates the microns per pixel for a fixed area of an image. + + Args: + area_dims (tuple): + Dimensions of the area to be scaled. + base_mpp (float): + Microns per pixel of the base image. + fixed_size (int): + Fixed size of the area. + + Returns: + float: + Microns per pixel required to scale the area to a fixed size. + """ + scale = max(area_dims) / fixed_size if max(area_dims) > fixed_size else 1.0 + self.scale = scale + return base_mpp * scale, scale diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index bf4819ebe..2e69f83c8 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1208,6 +1208,7 @@ def process_contours( contours: list[np.ndarray], hierarchy: np.ndarray, scale_factor: tuple[float, float] = (1, 1), + offset: np.ndarray | None = None, ) -> list[Annotation]: """Process contours and hierarchy to create annotations. @@ -1218,6 +1219,8 @@ def process_contours( A list of hierarchy. scale_factor (tuple[float, float]): The scale factor to use when loading the annotations. + offset (np.ndarray | None): + Optional offset to be added to the coordinates of the annotations. Returns: list: @@ -1231,6 +1234,8 @@ def process_contours( for i, layer_ in enumerate(contours): coords: np.ndarray = layer_.squeeze() scaled_coords: np.ndarray = np.array([np.array(scale_factor) * coords]) + if offset is not None: + scaled_coords += offset # save one points as a line, otherwise save the Polygon if len(layer_) > 2: # noqa: PLR2004 @@ -1308,6 +1313,7 @@ def dict_to_store_semantic_segmentor( scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, + offset: np.ndarray | None = None, ) -> AnnotationStore | Path: """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore. @@ -1324,13 +1330,14 @@ def dict_to_store_semantic_segmentor( save_path (str or Path): Optional Output directory to save the Annotation Store results. + offset: np.ndarray | None = None: + Optional offset to be added to the coordinates of the annotations. Returns: (SQLiteStore or Path): An SQLiteStore containing Annotations for each patch or Path to file storing SQLiteStore containing Annotations for each patch. 
diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py
index 832d8661d..15fd484d7 100644
--- a/tiatoolbox/visualization/bokeh_app/main.py
+++ b/tiatoolbox/visualization/bokeh_app/main.py
@@ -66,8 +66,9 @@
 # GitHub actions seems unable to find TIAToolbox unless this is here
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 from tiatoolbox import logger
-from tiatoolbox.models.engine.nucleus_instance_segmentor import (
-    NucleusInstanceSegmentor,
+from tiatoolbox.models.engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
+from tiatoolbox.models.engine.prompt_segmentor import (
+    PromptSegmentor,  # skipcq: FLK-E402
 )
 from tiatoolbox.tools.pyramid import ZoomifyGenerator
 from tiatoolbox.utils.misc import select_device
@@ -1119,6 +1120,8 @@ def to_model_cb(attr: ButtonClick) -> None:  # noqa: ARG001
     """Callback to run currently selected model."""
     if UI["vstate"].current_model == "hovernet":
         segment_on_box()
+    elif UI["vstate"].current_model == "SAM":
+        sam_segment()
     # Add any other models here
     else:  # pragma: no cover
         logger.warning("unknown model")
@@ -1276,6 +1279,104 @@ def segment_on_box() -> None:
     rmtree(tmp_mask_dir)
 
 
+def sam_segment() -> None:
+    """Callback to run SAM using point or box prompts on the slide.
+
+    Will run PromptSegmentor on the selected region of the WSI, using
+    the points in pt_source and/or the box in box_source as prompts.
+
+    """
+    prompt_segmentor = PromptSegmentor()
+    x_start = max(0, UI["p"].x_range.start)
+    y_start = max(0, -UI["p"].y_range.end)
+    x_end = min(UI["p"].x_range.end, UI["vstate"].dims[0])
+    y_end = min(-UI["p"].y_range.start, UI["vstate"].dims[1])
+    offset = np.array([x_start, y_start])
+    prompt_segmentor.offset = offset
+
+    height = y_end - y_start
+    width = x_end - x_start
+    res, scale_factor = prompt_segmentor.calc_mpp(
+        (width, height), UI["vstate"].mpp[0], 1500
+    )
+
+    # Get point coordinates
+    x = np.round(UI["pt_source"].data["x"])
+    y = np.round(UI["pt_source"].data["y"])
+    point_coords = (
+        (
+            np.array([[[x[i], -y[i]] for i in range(len(x))]], np.uint32)
+            - np.array([[x_start, y_start]])
+        )
+        / scale_factor
+        if len(x) > 0
+        else None
+    )
+
+    # Get box coordinates
+    x = np.round(UI["box_source"].data["x"])
+    y = np.round(UI["box_source"].data["y"])
+    height = np.round(UI["box_source"].data["height"])
+    width = np.round(UI["box_source"].data["width"])
+    x = [
+        round(UI["box_source"].data["x"][i] - 0.5 * UI["box_source"].data["width"][i])
+        for i in range(len(x))
+    ]
+    y = [
+        -round(UI["box_source"].data["y"][i] + 0.5 * UI["box_source"].data["height"][i])
+        for i in range(len(y))
+    ]
+    width = [round(UI["box_source"].data["width"][i]) for i in range(len(x))]
+    height = [round(UI["box_source"].data["height"][i]) for i in range(len(x))]
+    box_coords = (
+        (
+            np.array(
+                [
+                    [
+                        [x[i], y[i], x[i] + width[i], height[i] + y[i]]
+                        for i in range(len(x))
+                    ]
+                ],
+                np.uint32,
+            )
+            - np.array(
+                [[x_start, y_start, x_start, y_start]],
+            )
+        )
+        / scale_factor
+        if len(x) > 0
+        else None
+    )
+
+    tmp_save_dir = Path(tempfile.mkdtemp(suffix="bokeh_temp"))
+
+    # read the region of interest from the slide
+    roi = UI["vstate"].wsi.read_bounds(
+        (int(x_start), int(y_start), int(x_end), int(y_end)),
+        resolution=res,
+        units="mpp",
+    )
+
+    # Run SAM on the prompts
+    prediction = prompt_segmentor.run(
+        images=[roi],
+        device=select_device(on_gpu=torch.cuda.is_available()),
+        save_dir=tmp_save_dir / "sam_out",
+        point_coords=point_coords,
+        box_coords=box_coords,
+    )
+
+    ann_loc = str(prediction[0])
+
+    fname = make_safe_name(ann_loc)
+    resp = UI["s"].put(
+        f"http://{host2}:{port}/tileserver/overlay",
+        data={"overlay_path": fname},
+    )
+    ann_types = json.loads(resp.text)
+    update_ui_on_new_annotations(ann_types)
+
+
 # endregion
 
 # Set up main window
@@ -1504,7 +1605,7 @@ def gather_ui_elements(  # noqa: PLR0915
     )
     model_drop = Select(
         title="choose model:",
-        options=["hovernet"],
+        options=["hovernet", "SAM"],
         height=25,
         width=120,
         max_width=120,
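Reviewer note, not part of the patch: the prompt handling in `sam_segment()` reduces to three steps: negate bokeh's y values, subtract the viewport origin `(x_start, y_start)`, and divide by the `calc_mpp` scale so the prompts land on the downscaled ROI passed to the model. A standalone sketch of that transform with invented viewport numbers (no bokeh objects involved):

# Hedged sketch of the viewport-to-ROI prompt transform used in sam_segment();
# the coordinates below are made up for illustration.
import numpy as np

x_start, y_start = 1000.0, 2000.0  # visible region top-left in slide coordinates
scale_factor = 2.0                 # as returned by PromptSegmentor.calc_mpp()

# A point clicked at slide (x, y) = (1500, 2600); bokeh stores y negated.
xs, ys = np.array([1500.0]), np.array([-2600.0])

point_coords = (
    np.array([[[xs[i], -ys[i]] for i in range(len(xs))]], np.uint32)
    - np.array([[x_start, y_start]])
) / scale_factor
print(point_coords.tolist())  # [[[250.0, 300.0]]] -> pixel position in the resized ROI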
diff --git a/tiatoolbox/visualization/tileserver.py b/tiatoolbox/visualization/tileserver.py
index 236868f17..f284ae8b7 100644
--- a/tiatoolbox/visualization/tileserver.py
+++ b/tiatoolbox/visualization/tileserver.py
@@ -718,6 +718,7 @@ def commit_db(self: TileServer) -> str:
             if (
                 layer.store.path.suffix == ".db"
                 and layer.store.path.name != f"temp_{session_id}.db"
+                and not str(layer.store.path.parent.name).endswith("bokeh_temp")
             ):
                 logger.info("%s*.db committed.", layer.store.path.stem)
                 layer.store.commit()