diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 7396a023f..d1dd0e696 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -35,6 +35,10 @@ def _predict_image_( """ image = torch.tensor(image).permute(2, 0, 1) image = image / 255 + try: + image = image.to(next(model.parameters()).device) + except StopIteration: + pass with torch.no_grad(): prediction = model(image.unsqueeze(0)) diff --git a/tests/test_main.py b/tests/test_main.py index c403483e1..9756be879 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -411,6 +411,19 @@ def test_predict_image_fromfile(m): assert not prediction.empty +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_predict_image_on_cuda(m): + """Regression: model on CUDA must not raise a device mismatch in predict_image.""" + m.model.to("cuda") + try: + prediction = m.predict_image( + path=get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") + ) + finally: + m.model.to("cpu") + assert isinstance(prediction, pd.DataFrame) + + def test_predict_image_fromarray(m): image_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png")