-
Notifications
You must be signed in to change notification settings - Fork 375
Empty tensor handling #3891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
apbose
wants to merge
1
commit into
main
Choose a base branch
from
abose/torchTRT_empty_tensor_handling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Empty tensor handling #3891
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| import pytest | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch_tensorrt as torchtrt | ||
| from parameterized import parameterized | ||
| from torch.testing._internal.common_utils import TestCase, run_tests | ||
|
|
||
| DECIMALS_OF_AGREEMENT = 5 # for output comparison | ||
|
|
||
|
|
||
| # We provide non null address to TRT | ||
| class ConcatEmptyModel(nn.Module): | ||
| def __init__(self, dim=0): | ||
| super().__init__() | ||
| self.dim = dim | ||
|
|
||
| def forward(self, x, y): | ||
| return torch.cat([x, y], dim=self.dim) | ||
|
|
||
|
|
||
| # TRT will handle | ||
| class ConcatEmptyModelEmptyConstant(nn.Module): | ||
| def __init__(self, dim=0): | ||
| super().__init__() | ||
| self.dim = dim | ||
|
|
||
| def forward(self, x): | ||
| y = torch.empty((0, 4), dtype=torch.float).cuda() | ||
| return torch.cat([x, y], dim=self.dim) | ||
|
|
||
|
|
||
| # makes use of validator | ||
| class ConcatEmptyModelEmptyConstantMisMatchDim(nn.Module): | ||
| def __init__(self, dim=0): | ||
| super().__init__() | ||
| self.dim = dim | ||
|
|
||
| def forward(self, x): | ||
| y = torch.tensor([], device="cuda") | ||
| return torch.cat([x, y], dim=self.dim) | ||
|
|
||
|
|
||
| class TestConcatEmptyTensor(TestCase): | ||
|
|
||
| @parameterized.expand( | ||
| [ | ||
| ( | ||
| "python_runtime_model_one_empty_0", | ||
| True, | ||
| ConcatEmptyModel, | ||
| "two_inputs", | ||
| (0,), | ||
| ), | ||
| ( | ||
| "cpp_runtime_model_one_empty_0", | ||
| False, | ||
| ConcatEmptyModel, | ||
| "two_inputs", | ||
| (0,), | ||
| ), | ||
| ( | ||
| "python_runtime_model_one_empty_0_4", | ||
| True, | ||
| ConcatEmptyModel, | ||
| "two_inputs", | ||
| (0, 4), | ||
| ), | ||
| ( | ||
| "cpp_runtime_model_one_empty_0_4", | ||
| False, | ||
| ConcatEmptyModel, | ||
| "two_inputs", | ||
| (0, 4), | ||
| ), | ||
| ( | ||
| "python_runtime_model_two_empty_0_4", | ||
| True, | ||
| ConcatEmptyModelEmptyConstant, | ||
| "one_input", | ||
| (0, 4), | ||
| ), | ||
| ( | ||
| "cpp_runtime_model_two_empty_0_4", | ||
| False, | ||
| ConcatEmptyModelEmptyConstant, | ||
| "one_input", | ||
| (0, 4), | ||
| ), | ||
| ( | ||
| "python_runtime_model_three_empty_0", | ||
| True, | ||
| ConcatEmptyModelEmptyConstantMisMatchDim, | ||
| "one_input", | ||
| (0,), | ||
| ), | ||
| ( | ||
| "cpp_runtime_model_three_empty_0", | ||
| False, | ||
| ConcatEmptyModelEmptyConstantMisMatchDim, | ||
| "one_input", | ||
| (0,), | ||
| ), | ||
| ] | ||
| ) | ||
| def test_concat_empty_with_nonempty( | ||
| self, _, use_python_runtime, model_class, input_type, empty_shape | ||
| ): | ||
| """ | ||
| Test concatenation of empty tensor with non-empty tensor | ||
| along a specific dimension using Torch-TensorRT compiled model. | ||
| """ | ||
| # Create model | ||
| model = model_class(dim=0).eval().cuda() | ||
|
|
||
| # Inputs: prepare based on model requirements | ||
| empty_input = torch.empty(empty_shape, dtype=torch.float).cuda() | ||
| non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda() | ||
|
|
||
| if input_type == "two_inputs": | ||
| inputs = [empty_input, non_empty_input] | ||
| else: # one_input | ||
| inputs = [non_empty_input] | ||
|
|
||
| # Compile with Torch-TensorRT | ||
| compiled_model = torchtrt.compile( | ||
| model, | ||
| "dynamo", | ||
| inputs, | ||
| min_block_size=5, | ||
| use_python_runtime=use_python_runtime, | ||
| ) | ||
|
|
||
| # Run reference model | ||
| ref_out = model(*inputs) | ||
| # Run compiled model | ||
| trt_out = compiled_model(*inputs) | ||
|
|
||
| # Assertions | ||
| self.assertEqual(ref_out.shape, trt_out.shape) | ||
| self.assertAlmostEqual( | ||
| float(torch.max(torch.abs(ref_out - trt_out))), | ||
| 0, | ||
| DECIMALS_OF_AGREEMENT, | ||
| msg="Concat with empty tensor output mismatch", | ||
| ) | ||
|
|
||
| @parameterized.expand( | ||
| [ | ||
| ("python_runtime_empty_0", True, (0,)), | ||
| ("cpp_runtime_empty_0", False, (0,)), | ||
| ("python_runtime_empty_0_4", True, (0, 4)), | ||
| ("cpp_runtime_empty_0_4", False, (0, 4)), | ||
| ] | ||
| ) | ||
| def test_concat_nonempty_with_empty(self, _, use_python_runtime, empty_shape): | ||
| """ | ||
| Concatenate non-empty tensor with empty tensor (opposite order) | ||
| """ | ||
| model = ConcatEmptyModel(dim=0).eval().cuda() | ||
|
|
||
| non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda() | ||
| empty_input = torch.empty(empty_shape, dtype=torch.float).cuda() | ||
| inputs = [non_empty_input, empty_input] | ||
|
|
||
| compiled_model = torchtrt.compile( | ||
| model, | ||
| "dynamo", | ||
| inputs, | ||
| min_block_size=5, | ||
| use_python_runtime=use_python_runtime, | ||
| ) | ||
|
|
||
| ref_out = model(*inputs) | ||
| trt_out = compiled_model(*inputs) | ||
|
|
||
| self.assertEqual(ref_out.shape, trt_out.shape) | ||
| self.assertAlmostEqual( | ||
| float(torch.max(torch.abs(ref_out - trt_out))), | ||
| 0, | ||
| DECIMALS_OF_AGREEMENT, | ||
| msg="Concat with empty tensor (opposite order) output mismatch", | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I strongly want to avoid having nullptr basically anywhere, we should be looking for some sane default