Skip to content

Conversation

@sridhs21
Copy link
Contributor

This PR fixes incorrect tensor dimension handling in the reflect() function, improves mask interpolation in the rotate() function to preserve binary nature of segmentation masks, and removes redundant train_loss and val_loss declarations that reset loaded loss history.

Problems:

  • reflect() function was mixing 2D and 3D tensor operations which caused incorrect flips for all tensor (containing 4 channels of physical fields).
  • Using bilinear interpolation for binary masks in the rotate() function introduced intermediate values, which could corrupt the binary ground truth (i.e. instead of 0, 1, we can get 0.25, 0.75).
  • Extra declaration of train_loss and val_loss resetting loss history.

Changes made:

  • Changed BILINEAR to NEAREST for mask in rotate function to preserve binary values (0s and 1s only).
  • Corrected reflection operations to use consistent 3D tensor dimensions by changing dims=(axis,) to dims=(axis+1,) for all tensors, fixing the incorrect dimension mapping for 3D tensors.
  • Removed extra train_loss and val_loss declarations after checkpoint loading that were resetting the loaded loss history.

@cwsmith cwsmith linked an issue Jun 30, 2025 that may be closed by this pull request
Copy link
Contributor

@cwsmith cwsmith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Thank you.

@cwsmith cwsmith merged commit 624e08f into main Jul 2, 2025
1 check passed
@cwsmith cwsmith deleted the cws/debugRotation branch July 2, 2025 13:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

rotation and reflection for combined tensor is incorrect

3 participants