Optimize torch.nn.functional.one_hot #3284

@xinyu-intel

🚀 The feature, motivation and pitch

`one_hot` is used in the vLLM Gemma 4 26B workload: https://github.com/vllm-project/vllm/blob/308cec5864890f5c0724e1d4531d9fe2ee0a8209/vllm/model_executor/models/gemma4.py#L214. We observed 120 D2H copies per step, most likely caused by the boundary check in https://github.com/pytorch/pytorch/blob/beae96dfc1a2880f80a62f196306188a8d6dfdd9/aten/src/ATen/native/Onehot.cpp#L62. Please consider following the CUDA path and skipping this check.
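For illustration only, here is a minimal sketch of a one-hot encoding that never reads tensor values back to the host, which is the behavior this request is after. The helper name `one_hot_no_sync` is hypothetical and is not part of PyTorch or vLLM. It assumes `num_classes` is known up front and that the caller accepts losing the range validation.

```python
import torch
import torch.nn.functional as F


def one_hot_no_sync(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: one-hot encode without host-side value checks."""
    # Build the class ids [0, num_classes) on the same device as the input,
    # so the comparison below stays entirely on that device.
    classes = torch.arange(num_classes, device=indices.device, dtype=indices.dtype)
    # Broadcast-compare each index against every class id; no min()/max()
    # readback is needed, so no device-to-host copy is triggered here.
    return (indices.unsqueeze(-1) == classes).to(torch.int64)


idx = torch.tensor([[0, 2], [1, 3]])
expected = F.one_hot(idx, num_classes=4)
actual = one_hot_no_sync(idx, num_classes=4)
assert torch.equal(actual, expected)
```

Note the trade-off: an out-of-range index yields an all-zero row here instead of raising an error, so input validation is silently dropped. The sketch also does not cover `num_classes=-1`, where the class count has to be inferred from the data and a readback is unavoidable.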

Alternatives

No response

Additional context

No response
