🐛[BUG]: GeoTransolver raises ValueError: The last dimension of p_grid and x must be 3 #1606

@chychen

Description

Version

2.1.0a0

On which installation method(s) does this occur?

Docker

Describe the issue

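I am encountering a ValueError in the forward pass of GeoTransolver. Constructing the model succeeds; the error is raised by the model(...) call. The traceback shows that extract_context_features calls processor(spatial_coords, geometry), whose forward then calls self.bq_warp(query_points, key_features). It appears that the geometry tensor, whose last dimension is geometry_dim=5 in the example below, is passed into the ball_query module, which strictly requires a last dimension of 3 for spatial coordinates.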
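
To make the failure mode concrete, here is a minimal mock of the shape check, assuming it only compares the last dimension of both inputs against 3, as the error message suggests. check_ball_query_inputs and passes are hypothetical helpers, not PhysicsNeMo code, and shapes are plain tuples so no tensors are needed.

```python
# Hypothetical mock of the check that fails in
# physicsnemo/nn/module/ball_query.py (line 89). The real implementation
# may differ; this only mirrors the behavior implied by the error message.

def check_ball_query_inputs(p_grid_shape, x_shape):
    """Raise ValueError unless both inputs have a last dimension of 3."""
    if p_grid_shape[-1] != 3 or x_shape[-1] != 3:
        raise ValueError("The last dimension of p_grid and x must be 3")


def passes(p_grid_shape, x_shape):
    """Return True if the mock check accepts the given shapes."""
    try:
        check_ball_query_inputs(p_grid_shape, x_shape)
    except ValueError:
        return False
    return True


coords_shape = (1, 1000, 3)    # local_positions in the example
geometry_shape = (1, 1000, 5)  # geometry with geometry_dim=5

print(passes(coords_shape, coords_shape))    # True: both inputs are 3D
print(passes(coords_shape, geometry_shape))  # False: geometry has 5 channels
```

Because the check is symmetric, it fails regardless of which argument the geometry tensor ends up in.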

Expected Behavior

The model should handle geometry features of any width (here geometry_dim=5) separately from the 3D spatial coordinates used for the neighbor search, or the ball_query path should accept the input dimensions that GeoTransolver passes to it.
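
One way to read the first option: run the neighbor search only on 3D coordinates, then gather the wider geometry features of the neighbors found. The pure-Python sketch below illustrates that separation. ball_query_gather is a hypothetical helper, not a PhysicsNeMo API; it uses brute force instead of the Warp-based ball query in the traceback (bq_warp), and it assumes each geometry feature vector is aligned with a key point that has its own 3D position.

```python
import math


def ball_query_gather(query_pts, key_pts, key_feats, radius, max_neighbors):
    """Brute-force ball query on 3D points that gathers arbitrary-width features.

    query_pts: list of (x, y, z) tuples
    key_pts:   list of (x, y, z) tuples, one per key point
    key_feats: list of feature vectors of any length, aligned with key_pts
    Returns, per query point, up to max_neighbors feature vectors from key
    points within `radius` (inclusive), in key order.
    """
    if len(key_pts) != len(key_feats):
        raise ValueError("key_pts and key_feats must have the same length")
    # Only the coordinates used for the distance test must be 3D.
    for p in list(query_pts) + list(key_pts):
        if len(p) != 3:
            raise ValueError("neighbor search coordinates must be 3D")
    result = []
    for q in query_pts:
        neighbors = []
        for k, f in zip(key_pts, key_feats):
            if math.dist(q, k) <= radius:
                neighbors.append(f)  # features may have any width, e.g. 5
                if len(neighbors) == max_neighbors:
                    break
        result.append(neighbors)
    return result


# Usage: 3D positions for the search, 5-channel features for the output.
query = [(0.0, 0.0, 0.0)]
keys = [(0.1, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 0.2, 0.0)]
feats = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]
out = ball_query_gather(query, keys, feats, radius=0.5, max_neighbors=8)
print(out)  # [[[1, 2, 3, 4, 5], [11, 12, 13, 14, 15]]]
```

Keeping the distance computation on 3D coordinates means the feature width never reaches the spatial check.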

Actual Behavior

The code crashes at physicsnemo/nn/module/ball_query.py, line 89:
ValueError: The last dimension of p_grid and x must be 3

Minimum reproducible example

import torch
from physicsnemo.experimental.models.geotransolver import GeoTransolver
model = GeoTransolver(
    functional_dim=64,
    out_dim=3,
    geometry_dim=5,
    global_dim=16,
    n_hidden=256,
    n_layers=4,
    use_te=False,
    include_local_features=True
)
local_emb = torch.randn(1, 1000, 64)  # (batch, nodes, functional_dim)
coords = torch.randn(1, 1000, 3)  # (batch, nodes, spatial_dim)
geometry = torch.randn(1, 1000, 5)  # (batch, nodes, geometry_dim)
global_emb = torch.randn(1, 1, 16)  # (batch, 1, global_features)
output = model(local_emb, local_positions=coords, global_embedding=global_emb, geometry=geometry)

Relevant log output

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/jaych/Data/physicsnemo/physicsnemo/experimental/models/geotransolver/geotransolver.py", line 532, in forward
    embedding_states, local_embedding_bq = self.context_builder.build_context(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/jaych/Data/physicsnemo/physicsnemo/experimental/models/geotransolver/context_projector.py", line 801, in build_context
    context_feats = self.local_extractors[i].extract_context_features(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/jaych/Data/physicsnemo/physicsnemo/experimental/models/geotransolver/context_projector.py", line 552, in extract_context_features
    tokenizer(processor(spatial_coords, geometry))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/jaych/Data/physicsnemo/physicsnemo/experimental/models/geotransolver/context_projector.py", line 423, in forward
    _, neighbors = self.bq_warp(query_points, key_features)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/jaych/Data/physicsnemo/physicsnemo/nn/module/ball_query.py", line 89, in forward
    raise ValueError("The last dimension of p_grid and x must be 3")
ValueError: The last dimension of p_grid and x must be 3

Environment details

docker run -it --rm nvcr.io/nvidia/physicsnemo/physicsnemo:26.03 bash

Metadata

Assignees: none
Labels: bug (Something isn't working), external (Issues/PR filed by people outside the team)
Type: Bug (no fields configured)
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests
