Skip to content

Allow mjx.get_rgb/get_depth to accept batched render output#3216

Open
tkelestemur wants to merge 2 commits intogoogle-deepmind:mainfrom
tkelestemur:mjx-render-batched-unpackers
Open

Allow mjx.get_rgb/get_depth to accept batched render output#3216
tkelestemur wants to merge 2 commits intogoogle-deepmind:mainfrom
tkelestemur:mjx-render-batched-unpackers

Conversation

@tkelestemur
Copy link
Copy Markdown

Summary

  • Make mjx.get_rgb() and mjx.get_depth() accept packed buffers with leading batch dimensions, so callers can pass mjx.render() outputs directly for both single-world and batched rendering.
  • Preserve any leading dimensions in the unpacked outputs, returning (..., H, W, 3) for RGB and (..., H, W, 1) for depth.
  • Update render_util tests to cover direct batched inputs and remove the unnecessary Warp/CUDA dependency from these JAX-side unpacking tests.

Why

mjx.render() always returns packed outputs with a leading world axis, including when nworld == 1, but the public unpacking helpers previously only handled a single packed buffer. This change makes the helper API match the render API and removes the need for downstream callers to manually squeeze or vmap before unpacking.

Testing

  • python -m pytest mjx/mujoco/mjx/_src/render_util_test.py -q
  • python -m pytest -n auto -v -k 'not IntegrationTest' --pyargs mujoco.mjx
  • python -m pytest -v --pyargs mujoco
  • ctest -C Release --output-on-failure

yield


def setUp(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we keep this setup? we have other CI besides github CI, and want to make sure the tests have minimal changes. LGTM otherwise

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restored the existing Warp/CUDA-gated setup and kept the new batched-shape assertions as the only functional test change. Re-ran the MJX suite on this final shape (376 passed, 39 skipped).

Restore the existing Warp/CUDA-gated render_util test setup so the review change stays minimal while preserving the new batched shape coverage.

Made-with: Cursor
@tkelestemur
Copy link
Copy Markdown
Author

Thanks @btaba Good to go!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants