fix: squeeze 3D action tensor in LinUCB learn_batch #129
batch.action can have shape [B, 1, N] for one-hot encoded actions, but torch.cat with batch.state (shape [B, D]) requires 2D tensors. Squeeze dim=1 to handle both [B, N] and [B, 1, N] action shapes. Fixes facebookresearch#125
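The mismatch can be reproduced outside the library with plain tensors. The sketch below is illustrative only: it uses the shapes reported for the tutorial (a [1, 16] state and a [1, 1, 10] one-hot action), and the variable names are assumptions rather than code from the repository.

```python
import torch

state = torch.zeros(1, 16)       # [B, D]
action = torch.zeros(1, 1, 10)   # [B, 1, N], one-hot with an extra singleton axis
action[0, 0, 3] = 1.0

# torch.cat requires all inputs to have the same number of dimensions,
# so mixing a 2D and a 3D tensor raises a RuntimeError.
try:
    torch.cat([state, action], dim=1)
    cat_failed = False
except RuntimeError:
    cat_failed = True

# Squeezing dim=1 drops the singleton axis, leaving two 2D tensors.
x = torch.cat([state, torch.squeeze(action, dim=1)], dim=1)
print(cat_failed, tuple(x.shape))
```

The concatenated feature vector has D + N columns, one row per batch element.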
Pull request overview
This PR addresses a shape mismatch in LinearBandit.learn_batch() (LinUCB) when TransitionBatch.action is provided as a 3D tensor (e.g., [B, 1, N] one-hot), which breaks feature concatenation with 2D batch.state.
Changes:
- Squeezes batch.action on dimension 1 before concatenating with batch.state in LinearBandit.learn_batch().
  else torch.ones_like(expected_values)
  )
- x = torch.cat([batch.state, batch.action], dim=1)
+ x = torch.cat([batch.state, torch.squeeze(batch.action, dim=1)], dim=1)
Problem
When running the contextual_bandits_tutorial with LinUCB, training fails due to tensor dimension mismatch:
- batch.state shape: [1, 16]
- batch.action shape: [1, 1, 10] (one-hot encoded, 3D)
- torch.cat([batch.state, batch.action], dim=1) fails because state is 2D but action is 3D.

Fix
Added torch.squeeze(batch.action, dim=1) to handle 3D action tensors. This safely converts [B, 1, N] → [B, N] while leaving already-2D [B, N] tensors unchanged (squeeze is a no-op when dim=1 has size > 1).

Fixes #125
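The no-op behaviour can be checked directly. The sketch below wraps the patched expression in a hypothetical helper, build_features, which is not part of the library, and exercises both accepted shapes. It also shows one edge case the description does not cover: a 2D action of shape [B, 1] is squeezed to [B], so a single-column action would no longer be 2D after the squeeze.

```python
import torch

def build_features(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the patched line: squeeze dim=1, then concatenate."""
    return torch.cat([state, torch.squeeze(action, dim=1)], dim=1)

B, D, N = 4, 16, 10
state = torch.zeros(B, D)

# [B, 1, N]: the singleton axis at dim=1 is removed before concatenation.
shape_3d = tuple(build_features(state, torch.zeros(B, 1, N)).shape)

# [B, N]: dim=1 has size N > 1, so squeeze returns the tensor unchanged.
shape_2d = tuple(build_features(state, torch.zeros(B, N)).shape)

# Edge case: [B, 1] has a singleton dim=1, so it collapses to [B].
squeezed_single = tuple(torch.squeeze(torch.zeros(B, 1), dim=1).shape)

print(shape_3d, shape_2d, squeezed_single)
```

If single-column actions can occur, one option would be to squeeze only when action.dim() == 3. Whether that guard is needed depends on how the library encodes actions, which this PR does not state.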