From 9006f5030de2554383bab80f6ee3b1931a27e52d Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Wed, 13 May 2026 02:01:54 -0400 Subject: [PATCH] fix: move predict_image input tensor to the model's device `_predict_image_` built the input tensor on CPU and never moved it to the model's device, so a model placed on CUDA (`m.model.to('cuda')` or via a Lightning trainer that has been on a GPU) crashed with "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same". Move the image tensor to `next(model.parameters()).device` before the forward pass, and add a regression test that loads the model onto CUDA and runs `predict_image` (skipped when CUDA is unavailable). Closes #1390. --- src/deepforest/predict.py | 4 ++++ tests/test_main.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 7396a023f..d1dd0e696 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -35,6 +35,10 @@ def _predict_image_( """ image = torch.tensor(image).permute(2, 0, 1) image = image / 255 + try: + image = image.to(next(model.parameters()).device) + except StopIteration: + pass with torch.no_grad(): prediction = model(image.unsqueeze(0)) diff --git a/tests/test_main.py b/tests/test_main.py index c403483e1..9756be879 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -411,6 +411,19 @@ def test_predict_image_fromfile(m): assert not prediction.empty +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_predict_image_on_cuda(m): + """Regression: model on CUDA must not raise a device mismatch in predict_image.""" + m.model.to("cuda") + try: + prediction = m.predict_image( + path=get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") + ) + finally: + m.model.to("cpu") + assert isinstance(prediction, pd.DataFrame) + + def test_predict_image_fromarray(m): image_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png")