Fix device mismatch errors in torchstain color space conversion functions #71
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The
rgb2lab()andlab2rgb()functions intorchstain/torch/utils/were failing withRuntimeError: Expected all tensors to be on the same devicewhen processing CUDA tensors.This occurred because:
Constant tensors (
_rgb2xyz,_white,_xyz2rgb) were created on CPU by defaultInput tensors were on CUDA devices
PyTorch operations between tensors on different devices are not allowed
Solution
Modified both functions to automatically detect the input tensor's device and move constant tensors to the same device:
rgb2lab.py: Added device detection and moved_rgb2xyzand_whitetensors to input devicelab2rgb.py: Added device detection and moved_whiteand_xyz2rgbtensors to input device