[Cpp API Compatibility] Fix arange default dtype#78552
[Cpp API Compatibility] Fix arange default dtype#78552SigureMo merged 3 commits intoPaddlePaddle:developfrom
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Pull request overview
This PR fixes ATen compat arange dtype inference to better match PyTorch behavior, especially when dtype is omitted, and expands C++ compatibility tests to cover the updated overloads and defaults.
Changes:
- Update compat
at::arangeto default tokLongwhen inputs are integral anddtypeis omitted, while keeping current default dtype for floating inputs. - Refactor
arangeoverloads to funnel through a common(start, end, step, options)implementation. - Extend C++ tests to cover more
arangeoverloads, pinned-memory behavior, and default-dtype expectations.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
paddle/phi/api/include/compat/ATen/ops/arange.h |
Adjusts dtype resolution for omitted dtype and consolidates overload implementations. |
test/cpp/compat/ATen_pin_memory_creation_test.cc |
Adds coverage for pinned-memory behavior across additional arange overloads. |
test/cpp/compat/ATen_factory_default_dtype_test.cc |
Adds tests ensuring integral arange defaults to kLong and floating arange follows current default dtype. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| auto dense = paddle::experimental::arange( | ||
| paddle::experimental::full( | ||
| {}, start.to<double>(), phi::DataType::FLOAT64), | ||
| paddle::experimental::full( | ||
| {}, end.to<double>(), phi::DataType::FLOAT64), | ||
| paddle::experimental::full({}, 1, phi::DataType::FLOAT64), | ||
| compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), | ||
| paddle::experimental::full( | ||
| {}, step.to<double>(), phi::DataType::FLOAT64), | ||
| compat::_PD_AtenScalarTypeToPhiDataType(dtype), | ||
| phi::CPUPlace()); | ||
| return dense.copy_to(pinned_place, /*blocking=*/true); | ||
| } | ||
| return paddle::experimental::arange( | ||
| paddle::experimental::full( | ||
| {}, start.to<double>(), phi::DataType::FLOAT64), | ||
| paddle::experimental::full({}, end.to<double>(), phi::DataType::FLOAT64), | ||
| paddle::experimental::full({}, 1, phi::DataType::FLOAT64), | ||
| compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), | ||
| paddle::experimental::full({}, step.to<double>(), phi::DataType::FLOAT64), | ||
| compat::_PD_AtenScalarTypeToPhiDataType(dtype), | ||
| options._PD_GetPlace()); |
There was a problem hiding this comment.
Now that omitted dtype resolves to kLong for integral inputs, the implementation still materializes start/end/step as FLOAT64 scalars via to<double>(). This loses integer precision for values outside the exact double range (e.g., >2^53) and can produce incorrect sequences compared to an int64 arange. Consider constructing start/end/step using an integral scalar tensor type when the resolved dtype is integral (or otherwise avoid the double round-trip) so large int64 ranges remain exact.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #78552 +/- ##
===========================================
Coverage ? 100.00%
===========================================
Files ? 1
Lines ? 42
Branches ? 0
===========================================
Hits ? 42
Misses ? 0
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
/re-run all-failed |
1 similar comment
|
/re-run all-failed |
|
/re-run all-failed |
|
@ShigureNyako 这个怎么不再 review 了? |
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
拆分自 #78484
修复
at::arange在不指定dtype时创建的 tensor 数据类型不是kLong的错误,用于 DeepGEMM 的对齐。变更详情
问题描述
arange(5)kLong(int64)kFloatarange(2, 7)kLongkFloatarange(1, 10, 2)kLongkFloat修复内容 (
ATen/ops/arange.h)实现与 PyTorch 一致的类型推断逻辑:
关键修改:
is_integral类型判断kLong(int64)回归测试补充 (
test/cpp/compat/ATen_factory_default_dtype_test.cc)ArangeNoDtypeInt:整数输入的 dtype 推断ArangeNoDtypeFloat:浮点输入的 dtype 推断dtype=nullopt两种路径对齐效果
NoDtypeWithEndInt相关文档
是否引起精度变化
否