From 3b2b1adb047ee098dd5abe57019cfc79a7192d46 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 14 Feb 2025 12:34:21 +0000 Subject: [PATCH 01/16] move model to cuda if available --- tests/models/test_arch_mapde.py | 3 +-- tests/models/test_arch_sccnn.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 61bfde817..21f8cd611 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -21,7 +21,7 @@ def _load_mapde(name: str) -> torch.nn.Module: map_location = select_device(on_gpu=ON_GPU) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - + model.to(map_location) return model @@ -45,7 +45,6 @@ def test_functionality(remote_sample: Callable) -> None: model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - model = model.to() output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index 2729d2b3a..f2bc933a4 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -19,7 +19,7 @@ def _load_sccnn(name: str) -> torch.nn.Module: map_location = select_device(on_gpu=env_detection.has_gpu()) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - + model.to(map_location) return model @@ -48,7 +48,6 @@ def test_functionality(remote_sample: Callable) -> None: ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) - model = _load_sccnn(name="sccnn-conic") output = model.infer_batch( model, From b580e1eb391440b67f72ed3a14a079337945163b Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 14 Feb 2025 12:40:03 +0000 Subject: [PATCH 02/16] improve typing --- tests/models/test_arch_mapde.py | 2 +- tests/models/test_arch_sccnn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 21f8cd611..2a65583c6 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -14,7 +14,7 @@ ON_GPU = toolbox_env.has_gpu() -def _load_mapde(name: str) -> torch.nn.Module: +def _load_mapde(name: str) -> MapDe: """Loads MapDe model with specified weights.""" model = MapDe() weights_path = fetch_pretrained_weights(name) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index f2bc933a4..16c99cc49 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -12,7 +12,7 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def _load_sccnn(name: str) -> torch.nn.Module: +def _load_sccnn(name: str) -> SCCNN: """Loads SCCNN model with specified weights.""" model = SCCNN() weights_path = fetch_pretrained_weights(name) From 35ffd61fdd218a63254d0a193732a43b37c4817e Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 14 Feb 2025 15:59:18 +0000 Subject: [PATCH 03/16] fix typing --- .github/workflows/mypy-type-check.yml | 5 ++- tiatoolbox/models/architecture/__init__.py | 36 +++++++++++++++------- tiatoolbox/models/models_abc.py | 7 ++--- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/.github/workflows/mypy-type-check.yml b/.github/workflows/mypy-type-check.yml index b7c15a5ec..b5f11d17c 100644 --- a/.github/workflows/mypy-type-check.yml +++ b/.github/workflows/mypy-type-check.yml @@ -46,4 +46,7 @@ jobs: tiatoolbox/tools \ tiatoolbox/data \ tiatoolbox/annotation \ - tiatoolbox/cli/common.py + tiatoolbox/cli/common.py \ + tiatoolbox/models/__init__.py \ + tiatoolbox/models/models_abc.py \ + tiatoolbox/models/architecture/__init__.py \ diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index a2c33dc4f..38f38bf17 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -3,18 +3,18 @@ from __future__ import annotations import os +from pathlib import Path from pydoc import locate -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union, cast import torch from tiatoolbox import rcParam from tiatoolbox.models.dataset.classification import predefined_preproc_func +from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover - from pathlib import Path - from tiatoolbox.models.models_abc import IOConfigABC @@ -53,10 +53,13 @@ def fetch_pretrained_weights( if save_path is None: file_name = info["url"].split("/")[-1] - save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name + processed_save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name + + if type(save_path) is str: + processed_save_path = Path(save_path) - download_data(info["url"], save_path=save_path, overwrite=overwrite) - return save_path + download_data(info["url"], save_path=processed_save_path, overwrite=overwrite) + return processed_save_path def get_pretrained_model( @@ -129,9 +132,15 @@ def get_pretrained_model( info = PRETRAINED_INFO[pretrained_model] arch_info = info["architecture"] - creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}") - - model = creator(**arch_info["kwargs"]) + model_class_info = arch_info["class"] + model_module_name = str(".".join(model_class_info.split(".")[:-1])) + model_name = str(model_class_info.split(".")[-1]) + + # Import module containing required model class + arch_module = locate(f"tiatoolbox.models.architecture.{model_module_name}") + # Get model class form module + model_class = getattr(arch_module, model_name) + model = model_class(**arch_info["kwargs"]) # TODO(TBC): Dictionary of dataset specific or transformation? # noqa: FIX002,TD003 if "dataset" in info: # ! this is a hack currently, need another PR to clean up @@ -152,7 +161,12 @@ def get_pretrained_model( # ! io_info = info["ioconfig"] - creator = locate(f"tiatoolbox.models.engine.{io_info['class']}") + io_class_info = io_info["class"] + io_module_name = str(".".join(io_class_info.split(".")[:-1])) + io_class_name = str(io_class_info.split(".")[-1]) + + engine_module = locate(f"tiatoolbox.models.engine.{io_module_name}") + engine_class = getattr(engine_module, io_class_name) - iostate = creator(**io_info["kwargs"]) + iostate = engine_class(**io_info["kwargs"]) return model, iostate diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index a3af4e7f0..689eec21f 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -7,7 +7,6 @@ import torch import torch._dynamo -from torch import device as torch_device torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001 @@ -189,12 +188,12 @@ def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: The model after being moved to cpu/gpu. """ - device = torch_device(device) - model = super().to(device) + torch_device = torch.device(device) + model = super().to(torch_device) # If target device istorch.cuda and more # than one GPU is available, use DataParallel - if device.type == "cuda" and torch.cuda.device_count() > 1: + if torch_device.type == "cuda" and torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) # pragma: no cover return model From 0bf42a7183382fdab44558cd5c81a8f5da1574a1 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 14 Feb 2025 16:03:46 +0000 Subject: [PATCH 04/16] tidy up imports --- tiatoolbox/models/architecture/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 38f38bf17..bbe24d6b9 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -2,16 +2,14 @@ from __future__ import annotations -import os from pathlib import Path from pydoc import locate -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING import torch from tiatoolbox import rcParam from tiatoolbox.models.dataset.classification import predefined_preproc_func -from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover From 34d2c98c7d285403f631ab3cd74e24f5638e224c Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 14 Feb 2025 16:30:52 +0000 Subject: [PATCH 05/16] fix bug --- tiatoolbox/models/architecture/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index bbe24d6b9..3c61d7424 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -52,9 +52,10 @@ def fetch_pretrained_weights( if save_path is None: file_name = info["url"].split("/")[-1] processed_save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name - - if type(save_path) is str: + elif type(save_path) is str: processed_save_path = Path(save_path) + else: + processed_save_path = save_path download_data(info["url"], save_path=processed_save_path, overwrite=overwrite) return processed_save_path From 2d7a0475bfe8ce1e0b4ff3f52ac31f6a1723136b Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 14 Feb 2025 17:13:20 +0000 Subject: [PATCH 06/16] add architecture/utils.py --- .github/workflows/mypy-type-check.yml | 1 + tiatoolbox/models/architecture/utils.py | 17 ++++++++++------- tiatoolbox/models/models_abc.py | 14 ++++++++------ 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/.github/workflows/mypy-type-check.yml b/.github/workflows/mypy-type-check.yml index b5f11d17c..02befe4d4 100644 --- a/.github/workflows/mypy-type-check.yml +++ b/.github/workflows/mypy-type-check.yml @@ -50,3 +50,4 @@ jobs: tiatoolbox/models/__init__.py \ tiatoolbox/models/models_abc.py \ tiatoolbox/models/architecture/__init__.py \ + tiatoolbox/models/architecture/utils.py \ diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index 1f5cbfd64..a23f9f396 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import sys +from typing import cast import numpy as np import torch @@ -48,7 +49,7 @@ def compile_model( model: nn.Module | None = None, *, mode: str = "default", -) -> nn.Module: +) -> nn.Module | None: """A decorator to compile a model using torch-compile. Args: @@ -97,12 +98,12 @@ def compile_model( ) return model - return torch.compile(model, mode=mode) # pragma: no cover + return cast(nn.Module, torch.compile(model, mode=mode)) # pragma: no cover def centre_crop( - img: np.ndarray | torch.tensor, - crop_shape: np.ndarray | torch.tensor, + img: np.ndarray | torch.Tensor, + crop_shape: np.ndarray | torch.Tensor | tuple, data_format: str = "NCHW", ) -> np.ndarray | torch.Tensor: """A function to center crop image with given crop shape. @@ -136,8 +137,8 @@ def centre_crop( def centre_crop_to_shape( - x: np.ndarray | torch.tensor, - y: np.ndarray | torch.tensor, + x: np.ndarray | torch.Tensor, + y: np.ndarray | torch.Tensor, data_format: str = "NCHW", ) -> np.ndarray | torch.Tensor: """A function to center crop image to shape. @@ -200,11 +201,13 @@ def __init__(self: UpSample2x) -> None: """Initialize :class:`UpSample2x`.""" super().__init__() # correct way to create constant within module + + self.unpool_mat: torch.Tensor self.register_buffer( "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")), ) - self.unpool_mat.unsqueeze(0) + self.unpool_mat.unsqueeze_(0) def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor: """Logic for using layers defined in init. diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 689eec21f..97b917952 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -10,7 +10,6 @@ torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001 - if TYPE_CHECKING: # pragma: no cover from pathlib import Path @@ -56,8 +55,8 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: # DataParallel work only for cuda model = torch.nn.DataParallel(model) - device = torch.device(device) - return model.to(device) + torch_device = torch.device(device) + return model.to(torch_device) class ModelABC(ABC, torch.nn.Module): @@ -174,7 +173,10 @@ def postproc_func(self: ModelABC, func: Callable) -> None: else: self._postproc = func - def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: + def to( # type: ignore[override] + self: ModelABC, + device: str = "cpu", + ) -> ModelABC | torch.nn.DataParallel[ModelABC]: """Transfers model to cpu/gpu. Args: @@ -184,7 +186,7 @@ def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: Transfers model to the specified device. Default is "cpu". Returns: - torch.nn.Module: + torch.nn.Module | torch.nn.DataParallel: The model after being moved to cpu/gpu. """ @@ -194,7 +196,7 @@ def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: # If target device istorch.cuda and more # than one GPU is available, use DataParallel if torch_device.type == "cuda" and torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) # pragma: no cover + return torch.nn.DataParallel(model) # pragma: no cover return model From 9333aa593f572e0077c8c9510a5e39a9665884f7 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 21 Feb 2025 12:18:07 +0000 Subject: [PATCH 07/16] fix model_abc.py --- tiatoolbox/models/models_abc.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 97b917952..a8a8f7262 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -70,7 +70,9 @@ def __init__(self: ModelABC) -> None: @abstractmethod # This is generic abc, else pylint will complain - def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: + def forward( + self: ModelABC, *args: tuple[Any, ...], **kwargs: dict + ) -> None | torch.Tensor: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover @@ -176,6 +178,9 @@ def postproc_func(self: ModelABC, func: Callable) -> None: def to( # type: ignore[override] self: ModelABC, device: str = "cpu", + dtype: torch.dtype | None = None, + *, + non_blocking: bool = False, ) -> ModelABC | torch.nn.DataParallel[ModelABC]: """Transfers model to cpu/gpu. @@ -184,6 +189,11 @@ def to( # type: ignore[override] PyTorch defined model. device (str): Transfers model to the specified device. Default is "cpu". + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module. + non_blocking (bool): When set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. Returns: torch.nn.Module | torch.nn.DataParallel: @@ -191,7 +201,7 @@ def to( # type: ignore[override] """ torch_device = torch.device(device) - model = super().to(torch_device) + model = super().to(torch_device, dtype=dtype, non_blocking=non_blocking) # If target device istorch.cuda and more # than one GPU is available, use DataParallel From dedb2c3bdfa05bca861a35666296bfd7bf7ca967 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 21 Feb 2025 12:18:58 +0000 Subject: [PATCH 08/16] fix utils.py --- tiatoolbox/models/architecture/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index a23f9f396..f532e7d99 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -46,10 +46,10 @@ def is_torch_compile_compatible() -> bool: def compile_model( - model: nn.Module | None = None, + model: nn.Module, *, mode: str = "default", -) -> nn.Module | None: +) -> nn.Module: """A decorator to compile a model using torch-compile. Args: From 4c2fc4423ab96569bc06737dde1e39b3b1686652 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Mon, 24 Mar 2025 15:51:25 +0000 Subject: [PATCH 09/16] try to fix pytest --- tiatoolbox/models/architecture/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index f532e7d99..cffc3d8cb 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -201,13 +201,12 @@ def __init__(self: UpSample2x) -> None: """Initialize :class:`UpSample2x`.""" super().__init__() # correct way to create constant within module - - self.unpool_mat: torch.Tensor + self.unpool_mat:torch.Tensor self.register_buffer( "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")), ) - self.unpool_mat.unsqueeze_(0) + self.unpool_mat.unsqueeze(0) def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor: """Logic for using layers defined in init. From 3f3ea1208dad3c326f1006548b6c3d724c8cb9fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Mar 2025 15:51:51 +0000 Subject: [PATCH 10/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/architecture/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index cffc3d8cb..59378383d 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -201,7 +201,7 @@ def __init__(self: UpSample2x) -> None: """Initialize :class:`UpSample2x`.""" super().__init__() # correct way to create constant within module - self.unpool_mat:torch.Tensor + self.unpool_mat: torch.Tensor self.register_buffer( "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")), From 901a96835f65dcf43af6dbfed6a8dda7b623a58d Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 28 Mar 2025 12:33:34 +0000 Subject: [PATCH 11/16] pin glymur version < 0.14 --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index c29a54620..fcb00e75d 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 flask-cors>=4.0.0 -glymur>=0.12.7 +glymur>=0.12.7 < 0.14 # 0.14 is not compatible with python3.9 imagecodecs>=2022.9.26 joblib>=1.1.1 jupyterlab>=3.5.2 From 98e3687a98b27b1d46c7a9c4461781145b92a363 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 28 Mar 2025 12:35:47 +0000 Subject: [PATCH 12/16] pin glymur version < 0.14 --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index fcb00e75d..9756b73d2 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 flask-cors>=4.0.0 -glymur>=0.12.7 < 0.14 # 0.14 is not compatible with python3.9 +glymur>=0.12.7, < 0.14 # 0.14 is not compatible with python3.9 imagecodecs>=2022.9.26 joblib>=1.1.1 jupyterlab>=3.5.2 From a0835a932cad0cc69e960a5ec99d257430f72fff Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 28 Mar 2025 13:10:23 +0000 Subject: [PATCH 13/16] improve test coverage --- tests/test_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index a06d57d90..200fc316a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1614,10 +1614,16 @@ def test_fetch_pretrained_weights(tmp_path: Path) -> None: fetch_pretrained_weights(model_name="mobilenet_v3_small-pcam", save_path=file_path) assert file_path.exists() assert file_path.stat().st_size > 0 + file_path.unlink() with pytest.raises(ValueError, match="does not exist"): fetch_pretrained_weights("abc", file_path) + # Test save_path is str + file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", os.path.join(tmp_path, "test_fetch_pretrained_weights.pth")) + assert Path(file_path).exists() + assert Path(file_path).stat().st_size > 0 + def test_imwrite(tmp_path: Path) -> NoReturn: """Create a temporary file path.""" From 3b0156e28593c9551095286193e4ee72e07c5765 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:12:17 +0000 Subject: [PATCH 14/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 200fc316a..20c828953 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1620,7 +1620,10 @@ def test_fetch_pretrained_weights(tmp_path: Path) -> None: fetch_pretrained_weights("abc", file_path) # Test save_path is str - file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", os.path.join(tmp_path, "test_fetch_pretrained_weights.pth")) + file_path = fetch_pretrained_weights( + "mobilenet_v3_small-pcam", + os.path.join(tmp_path, "test_fetch_pretrained_weights.pth"), + ) assert Path(file_path).exists() assert Path(file_path).stat().st_size > 0 From 66718358452435a5d97aa89494dc7c23c61c4ac5 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 28 Mar 2025 13:17:10 +0000 Subject: [PATCH 15/16] fix ruff --- tests/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 200fc316a..d84aba0b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1620,7 +1620,8 @@ def test_fetch_pretrained_weights(tmp_path: Path) -> None: fetch_pretrained_weights("abc", file_path) # Test save_path is str - file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", os.path.join(tmp_path, "test_fetch_pretrained_weights.pth")) + file_path_str = str(file_path) + file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path_str) assert Path(file_path).exists() assert Path(file_path).stat().st_size > 0 From c8f2adc867d0ef3dc5398d5e6f1329c78dd81abf Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Thu, 1 May 2025 14:14:07 +0100 Subject: [PATCH 16/16] fix vanilla.py --- tiatoolbox/models/architecture/vanilla.py | 39 +++++++++++++++++------ tiatoolbox/models/models_abc.py | 9 +++--- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 12f2bf8bb..08a9d2d16 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import timm @@ -19,9 +19,9 @@ def _get_architecture( arch_name: str, - weights: str or WeightsEnum = "DEFAULT", + weights: str | WeightsEnum = "DEFAULT", **kwargs: dict, -) -> list[nn.Sequential, ...] | nn.Sequential: +) -> torch.nn.ModuleList | nn.Sequential: """Retrieve a CNN model architecture. This function fetches a Convolutional Neural Network (CNN) model architecture, @@ -38,7 +38,7 @@ def _get_architecture( Key-word arguments. Returns: - list[nn.Sequential, ...] | nn.Sequential: + list[nn.Sequential] | nn.Sequential: A list of PyTorch network layers wrapped with `nn.Sequential`. Raises: @@ -94,7 +94,7 @@ def _get_timm_architecture( arch_name: str, *, pretrained: bool, -) -> list[nn.Sequential, ...] | nn.Sequential: +) -> torch.nn.ModuleList | nn.Sequential: """Retrieve a timm model architecture. This function fetches a model architecture from the timm library, specifically for @@ -124,6 +124,7 @@ def _get_timm_architecture( model = timm.create_model(arch_name, pretrained=pretrained) return nn.Sequential(*list(model.children())[:-1]) + arch_map: dict[str, dict] = {} arch_map = { # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI "UNI": { @@ -306,7 +307,9 @@ def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None: # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: CNNModel, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -318,6 +321,9 @@ def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor: The output logits after passing through the model. """ + imgs = args[0] + if imgs is None: + return None feat = self.feat_extract(imgs) gap_feat = self.pool(feat) gap_feat = torch.flatten(gap_feat, 1) @@ -431,7 +437,9 @@ def __init__( # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: TimmModel, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -443,6 +451,9 @@ def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor: The output logits after passing through the model. """ + imgs = args[0] + if imgs is None: + return None feat = self.feat_extract(imgs) feat = torch.flatten(feat, 1) logit = self.classifier(feat) @@ -552,7 +563,9 @@ def __init__(self: CNNBackbone, backbone: str) -> None: # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: CNNBackbone, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -564,6 +577,9 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: The extracted features. """ + imgs = args[0] + if imgs is None: + return None feat = self.feat_extract(imgs) gap_feat = self.pool(feat) return torch.flatten(gap_feat, 1) @@ -645,7 +661,9 @@ def __init__(self: TimmBackbone, backbone: str, *, pretrained: bool) -> None: # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: TimmBackbone, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -657,6 +675,9 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor: The extracted features. """ + imgs = args[0] + if imgs is None: + return None feats = self.feat_extract(imgs) return torch.flatten(feats, 1) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index a8a8f7262..53da54818 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -80,9 +80,9 @@ def forward( @abstractmethod def infer_batch( model: torch.nn.Module, - batch_data: np.ndarray, + batch_data: torch.Tensor, device: str, - ) -> None: + ) -> None | dict[str, np.ndarray] | list[dict[str, np.ndarray]]: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -90,7 +90,7 @@ def infer_batch( Args: model (nn.Module): PyTorch defined model. - batch_data (np.ndarray): + batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. device (str): @@ -227,4 +227,5 @@ def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Modu # ! assume to be saved in single GPU mode # always load on to the CPU saved_state_dict = torch.load(weights, map_location="cpu") - return super().load_state_dict(saved_state_dict, strict=True) + super().load_state_dict(saved_state_dict, strict=True) + return self