Skip to content

Conversation

@isaaccorley
Copy link
Collaborator

Running the torchfix linter results in the following errors. This PR fixes these.

segmentation_models_pytorch/base/heads.py:13:13: TOR101 Use of deprecated function torch.nn.UpsamplingBilinear2d
segmentation_models_pytorch/decoders/deeplabv3/decoder.py:108:19: TOR101 Use of deprecated function torch.nn.UpsamplingBilinear2d
segmentation_models_pytorch/losses/_functional.py:229:33: TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.

@isaaccorley isaaccorley self-assigned this Dec 1, 2025
@isaaccorley isaaccorley requested review from Copilot and qubvel and removed request for Copilot December 1, 2025 18:35
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses deprecation warnings from the torchfix linter by replacing deprecated PyTorch functions with their modern equivalents. The changes improve code accuracy and prevent future compatibility issues.

  • Replaced torch.nn.UpsamplingBilinear2d with torch.nn.Upsample(mode="bilinear") in two locations
  • Replaced torch.log(1 + x) with torch.log1p(x) for improved numerical accuracy

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
segmentation_models_pytorch/base/heads.py Updated upsampling layer to use non-deprecated API
segmentation_models_pytorch/decoders/deeplabv3/decoder.py Updated upsampling layer to use non-deprecated API
segmentation_models_pytorch/losses/_functional.py Improved numerical accuracy in wing loss calculation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@codecov
Copy link

codecov bot commented Dec 1, 2025

Codecov Report

❌ Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
segmentation_models_pytorch/losses/_functional.py 0.00% 1 Missing ⚠️
Files with missing lines Coverage Δ
segmentation_models_pytorch/base/heads.py 94.44% <ø> (ø)
...ation_models_pytorch/decoders/deeplabv3/decoder.py 98.68% <100.00%> (ø)
segmentation_models_pytorch/losses/_functional.py 60.33% <0.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@isaaccorley isaaccorley requested a review from Copilot December 1, 2025 20:07
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

)
upsampling = (
nn.UpsamplingBilinear2d(scale_factor=upsampling)
nn.Upsample(mode="bilinear", scale_factor=upsampling, align_corners=True)
Copy link

Copilot AI Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting align_corners=True changes the upsampling behavior compared to the deprecated UpsamplingBilinear2d which used align_corners=False by default. This could affect model output and should be verified against existing trained models. Consider using align_corners=False to maintain backward compatibility unless the behavior change is intentional.

Suggested change
nn.Upsample(mode="bilinear", scale_factor=upsampling, align_corners=True)
nn.Upsample(mode="bilinear", scale_factor=upsampling, align_corners=False)

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@isaaccorley can we keep the previous behaviour with align_corners=False?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried this initially but this caused the logit tests to fail. I believe this means that the original pretrained checkpoints used align_corners=True.

Copy link
Collaborator

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update! Just a question re align corners

)
upsampling = (
nn.UpsamplingBilinear2d(scale_factor=upsampling)
nn.Upsample(mode="bilinear", scale_factor=upsampling, align_corners=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@isaaccorley can we keep the previous behaviour with align_corners=False?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants