
[RFC] Migrate XPU kernel math functions from std::/:: to sycl::/sycl::native:: namespace #3266

@jianyizh


Summary

XPU SYCL kernels in src/ATen/native/xpu/sycl/ currently use std::, ::, or bare C math functions (e.g. std::exp, ::expf, sqrtf) instead of the SYCL-native sycl:: or sycl::native:: namespace equivalents. This should be unified to use sycl:: namespace functions for correctness, portability, and potential performance benefits on Intel GPU hardware.
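For illustration, here is roughly how the change looks in a single real-typed functor. `SigmoidFunctorBefore` and `SigmoidFunctorAfter` are hypothetical names, not the actual torch-xpu-ops functors. The `__has_include` fallback to `std::` exists only so the sketch also compiles with a plain host compiler.

```cpp
#include <cmath>

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
#define SKETCH_HAS_SYCL 1
#else
#define SKETCH_HAS_SYCL 0
#endif

// Before: a device functor calling libm-style std::exp (hypothetical name).
template <typename T>
struct SigmoidFunctorBefore {
  T operator()(T x) const { return T(1) / (T(1) + std::exp(-x)); }
};

// After: the same functor routed through sycl::exp for real types.
template <typename T>
struct SigmoidFunctorAfter {
  T operator()(T x) const {
#if SKETCH_HAS_SYCL
    return T(1) / (T(1) + sycl::exp(-x));
#else
    // Fallback only so the sketch builds without a SYCL toolchain.
    return T(1) / (T(1) + std::exp(-x));
#endif
  }
};
```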

Mapping convention (mirroring CUDA → SYCL)

| Current (std/bare) | SYCL replacement | Notes |
| --- | --- | --- |
| std::exp, ::expf, expf | sycl::exp | |
| std::log, ::logf, logf | sycl::log | |
| std::sqrt, sqrtf | sycl::sqrt | |
| std::abs, std::fabs, ::abs | sycl::fabs (float/double), sycl::abs (int) | complex → keep std::abs |
| std::pow | sycl::pow | |
| std::ceil, ceilf | sycl::ceil | |
| std::floor, floorf | sycl::floor | |
| std::tanh, ::tanhf | sycl::tanh | complex → keep std:: |
| std::fmod | sycl::fmod | |
| std::lgamma | sycl::lgamma | |
| std::log10 | sycl::log10 | complex → keep std:: |
| std::log1p, ::log1p | sycl::log1p | complex → keep std:: |
| std::log2 | sycl::log2 | complex → keep std:: |
| std::isinf | sycl::isinf | device code only |
| std::isnan | sycl::isnan | device code only |
| std::isfinite | sycl::isfinite | device code only |
| std::copysign | sycl::copysign | |
| std::erf | sycl::erf | |
| std::erfc | sycl::erfc | |
| std::exp2 | sycl::exp2 | |
| std::expm1 | sycl::expm1 | complex → keep std:: |
| std::frexp | sycl::frexp | |
| std::sin | sycl::sin | complex → keep std:: |
| std::cos | sycl::cos | complex → keep std:: |
| std::tan | sycl::tan | complex → keep std:: |
| std::asin | sycl::asin | complex → keep std:: |
| std::acos | sycl::acos | complex → keep std:: |
| std::atan | sycl::atan | complex → keep std:: |
| std::sinh | sycl::sinh | complex → keep std:: |
| std::cosh | sycl::cosh | complex → keep std:: |
| std::asinh | sycl::asinh | |
| std::acosh | sycl::acosh | |
| std::atanh | sycl::atanh | |
| std::trunc, std::truncf | sycl::trunc | complex component calls → keep std:: |
| std::nearbyintf | sycl::rint (closest SYCL equivalent) | needs accuracy verification |
| std::rsqrt | sycl::rsqrt | std::rsqrt is not in the C++ standard library; RsqrtFunctor is listed under task 30 (sqrt) |
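The abs row is the one mapping that splits three ways by type. Below is a minimal sketch of a type-dispatching helper, assuming C++17 `if constexpr`. `abs_compat` and the `is_complex` trait are hypothetical names, and `std::complex` stands in for `c10::complex`. The `__has_include` fallback to `std::` exists only so the sketch builds without a SYCL toolchain.

```cpp
#include <cmath>
#include <complex>
#include <cstdlib>
#include <type_traits>

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
#define SKETCH_HAS_SYCL 1
#else
#define SKETCH_HAS_SYCL 0
#endif

// Trait: true for std::complex<U> (stand-in for c10::complex<U>).
template <typename T> struct is_complex : std::false_type {};
template <typename U> struct is_complex<std::complex<U>> : std::true_type {};

// Hypothetical helper applying the abs row of the mapping table.
template <typename T>
auto abs_compat(T x) {
  if constexpr (is_complex<T>::value) {
    return std::abs(x);  // complex: sycl:: has no overload, keep std::abs
  } else if constexpr (std::is_floating_point_v<T>) {
#if SKETCH_HAS_SYCL
    return sycl::fabs(x);  // float/double → sycl::fabs
#else
    return std::fabs(x);
#endif
  } else {
    static_assert(std::is_integral_v<T> && std::is_signed_v<T>,
                  "abs_compat: expected a signed integer type");
#if SKETCH_HAS_SYCL
    // Cast back in case the SYCL integer overload returns a different type.
    return static_cast<T>(sycl::abs(x));  // int → sycl::abs
#else
    return static_cast<T>(std::abs(x));
#endif
  }
}
```

Folding the split into one helper lets call sites dispatched for integer, floating-point, and complex types share a single spelling. Whether such a helper belongs in XPUMathCompat.h is a design choice for the infrastructure task.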

Important notes

  • Complex types: sycl:: math functions do NOT support c10::complex/std::complex types. When an operator is dispatched for both real and complex types, the std:: call must be preserved for complex and only changed to sycl:: for real types.
  • Host code: std:: calls in host code (outside kernels/functors) should remain as-is.
  • STD_FUNCTOR macro: ForeachUnaryKernels.cpp defines a STD_FUNCTOR macro (line 181) that generates std::OP_NAME(t) inside device functors. Many functions are affected by this single macro. Migrating these requires either specializing the macro or replacing it with explicit functor definitions that dispatch between sycl:: and std:: based on type.
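The actual STD_FUNCTOR definition in ForeachUnaryKernels.cpp is not reproduced here. This sketch combines both options: a reworked macro whose generated functor dispatches on type. `MATH_FUNCTOR` is a hypothetical macro that calls `sycl::OP_NAME` for real types and `std::OP_NAME` for complex types. `std::complex` stands in for `c10::complex`, and `SKETCH_MATH_NS` falls back to `std` when no SYCL header is available.

```cpp
#include <cmath>
#include <complex>
#include <type_traits>

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
#define SKETCH_MATH_NS sycl
#else
// Host-only fallback so the sketch builds without a SYCL toolchain.
#define SKETCH_MATH_NS std
#endif

template <typename T> struct is_complex : std::false_type {};
template <typename U> struct is_complex<std::complex<U>> : std::true_type {};

// Hypothetical STD_FUNCTOR replacement: real types go through
// sycl::OP_NAME, complex types keep std::OP_NAME. The complex branch is
// discarded at compile time for real T, and vice versa.
#define MATH_FUNCTOR(OP_NAME, FUNCTOR_NAME)    \
  template <typename T>                        \
  struct FUNCTOR_NAME {                        \
    T operator()(T t) const {                  \
      if constexpr (is_complex<T>::value) {    \
        return std::OP_NAME(t);                \
      } else {                                 \
        return SKETCH_MATH_NS::OP_NAME(t);     \
      }                                        \
    }                                          \
  };

MATH_FUNCTOR(sin, SinFunctor)
MATH_FUNCTOR(cos, CosFunctor)
MATH_FUNCTOR(tanh, TanhFunctor)
```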

SYCL math function reference

https://github.khronos.org/SYCL_Reference/iface/math-functions.html


Task List

Each task = one function → one PR with code change + accuracy test + performance test.
PR must be merged before checking the box.
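For the accuracy-test part of each PR, one possible measure is the maximum ULP error of the migrated float function against a double-precision reference. This is an assumption, not the project's existing test harness. The sketch runs on the host. In a real test the candidate would be the sycl:: result computed on the device. `ulp_distance` and `max_ulp_error` are hypothetical helper names.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Map a float's bit pattern onto an ordered integer line: adjacent
// representable floats differ by exactly 1, and -0.0f / +0.0f both map to 0.
inline std::int64_t ordered_bits(float f) {
  std::int32_t i;
  std::memcpy(&i, &f, sizeof(i));
  return i < 0 ? std::int64_t{INT32_MIN} - i : std::int64_t{i};
}

// Distance in units in the last place between two non-NaN floats.
inline std::int64_t ulp_distance(float a, float b) {
  std::int64_t d = ordered_bits(a) - ordered_bits(b);
  return d < 0 ? -d : d;
}

// Worst ULP error of a float candidate against a double-precision reference,
// sampled at n >= 2 evenly spaced points in [lo, hi]. A NaN mismatch is
// reported as the maximum possible error.
template <typename Candidate, typename Reference>
std::int64_t max_ulp_error(Candidate candidate, Reference reference, float lo,
                           float hi, int n) {
  std::int64_t worst = 0;
  for (int k = 0; k < n; ++k) {
    float x = lo + (hi - lo) * static_cast<float>(k) / static_cast<float>(n - 1);
    float expected = static_cast<float>(reference(static_cast<double>(x)));
    float got = candidate(x);
    if (std::isnan(expected) || std::isnan(got)) {
      if (std::isnan(expected) != std::isnan(got)) {
        return std::numeric_limits<std::int64_t>::max();
      }
      continue;
    }
    std::int64_t d = ulp_distance(got, expected);
    if (d > worst) worst = d;
  }
  return worst;
}
```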

Infrastructure

  • 0. XPUMathCompat.h internal updates

    • Migrate: ::expf → sycl::exp, ::tanhf → sycl::tanh, etc.
    • Add missing compat functions to match CUDAMathCompat.h: abs, ceil, copysign, floor, log, log1p, max, min, pow, sincos, sqrt, tan, normcdf
  • 0b. NumericUtils.h — add XPU fast math branches (PyTorch main repo)

    • aten/src/ATen/NumericUtils.h defines at::exp, at::log, at::log1p, at::tan template functions
    • On CUDA/HIP these use fast math intrinsics (__expf, __logf, __tanf); CUDA has no __log1pf intrinsic, so the log1p path goes through __logf(1.0f + x)
    • On XPU (non-CUDA/HIP), they currently fall back to ::exp, ::log, etc. (C libm)
    • Should add #elif defined(__SYCL_DEVICE_ONLY__) branches using sycl::native::exp, sycl::native::log, etc. to match CUDA behavior
    • Note: XPU kernels currently do NOT call at::exp etc. (0 occurrences in torch-xpu-ops/src/), but aligning them is still worthwhile for correctness and future use
    • ⚠️ This is a PyTorch main repo change — needs separate PR to pytorch/pytorch
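Below is a minimal sketch of what such a branch could look like. It uses a hypothetical `sketch::exp` / `sketch::log1p` pair rather than the real at::exp / at::log1p templates. The sketch assumes that only float should take the `sycl::native::` path, because SYCL 2020 defines the native functions for float arguments and leaves their accuracy implementation-defined. There is no `sycl::native::log1p`, so the float log1p path approximates log1p through log(1 + x), which loses precision for small |x|.

```cpp
#include <cmath>
#include <type_traits>

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
#endif

namespace sketch {

template <typename T>
inline T exp(T x) {
#if defined(__SYCL_DEVICE_ONLY__)
  if constexpr (std::is_same_v<T, float>) {
    return sycl::native::exp(x);  // fast, reduced-accuracy float path
  } else {
    return sycl::exp(x);  // double keeps the precise function
  }
#else
  return std::exp(x);  // host / non-SYCL path (upstream falls back to ::exp)
#endif
}

template <typename T>
inline T log1p(T x) {
#if defined(__SYCL_DEVICE_ONLY__)
  if constexpr (std::is_same_v<T, float>) {
    // No native log1p: log(1 + x) trades precision near 0 for speed.
    return sycl::native::log(1.0f + x);
  } else {
    return sycl::log1p(x);
  }
#else
  return std::log1p(x);
#endif
}

}  // namespace sketch
```

Whether the native float path is acceptable should be settled by the per-function accuracy tests, because CUDA's fast intrinsics and SYCL's native functions need not have the same error bounds.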

Math functions (alphabetical)

  • 1. abs / fabs: std::abs, std::fabs, ::abs → sycl::fabs (float/double), sycl::abs (int)

    • ~65 occurrences across 17+ files
    • Files: ActivationLogSigmoidKernels.cpp, BinaryMiscOpsKernels.cpp, DeformConv2dKernels.cpp, DistanceKernels.cpp, Embedding.cpp, ForeachReduceKernels.cpp, GcdLcmKernels.cpp, LaguerrePolynomialLKernel.cpp, LegendrePolynomialPKernel.cpp, LogAddExpKernels.cpp, ReflectionPadKernels.cpp, UpSampleBilinear2dKernels.cpp, MathExtensions.h
    • ⚠️ Keep std::abs for complex types in: AbsKernel.cpp, ForeachUnaryKernels.cpp, UnarySignKernels.cpp, SharedReduceOps.h
  • 2. acos: std::acos → sycl::acos

    • UnaryGeometricAcosKernel.cpp (1 call)
    • ForeachUnaryKernels.cpp:246 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 3. acosh: std::acosh → sycl::acosh

    • UnaryGeometricAcoshKernel.cpp (1 call)
  • 4. asin: std::asin → sycl::asin

    • UnaryGeometricAsinKernel.cpp (2 calls: AsinComplexFunctor, AsinFunctor)
    • ForeachUnaryKernels.cpp:247 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 5. asinh: std::asinh → sycl::asinh

    • UnaryGeometricAsinhKernel.cpp (2 calls: AsinhComplexFunctor, AsinhFunctor)
  • 6. atan: std::atan → sycl::atan

    • UnaryGeometricAtanKernel.cpp (2 calls: AtanComplexFunctor, AtanFunctor)
    • ForeachUnaryKernels.cpp:248 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 7. atanh: std::atanh → sycl::atanh

    • UnaryGeometricAtanhKernel.cpp (2 calls: AtanhComplexFunctor, AtanhFunctor)
  • 8. ceil: std::ceil, bare ceilf → sycl::ceil

    • UpSampleBilinear2dKernels.cpp: 8 occurrences (4 device + 4 host)
    • ForeachUnaryKernels.cpp:195 via STD_FUNCTOR macro
    • ⚠️ Keep host-code ceilf as-is
  • 9. copysign: std::copysign → sycl::copysign

    • CopysignKernel.cpp
  • 10. cos: std::cos → sycl::cos

    • UnaryGeometricCosKernel.cpp (1 call)
    • ForeachUnaryKernels.cpp:252 via STD_FUNCTOR macro
    • Also used in the c10::complex variant of Expm1Functor in UnaryKernels.cpp
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 11. cosh: std::cosh → sycl::cosh

    • UnaryGeometricCoshKernel.cpp (2 calls: CoshComplexFunctor, CoshFunctor)
    • ForeachUnaryKernels.cpp:249 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 12. erf: std::erf → sycl::erf

    • UnarySpecialOpsKernels.cpp (2 calls: ErfFunctor for float + double)
    • ForeachUnaryKernels.cpp:189 via STD_FUNCTOR macro (float types only)
  • 13. erfc: std::erfc → sycl::erfc

    • UnarySpecialOpsKernels.cpp (2 calls: ErfcFunctor for float + double)
    • ForeachUnaryKernels.cpp:190 via STD_FUNCTOR macro (float types only)
  • 14. exp: std::exp, ::expf, bare expf → sycl::exp

    • UnaryKernels.cpp (ExpFunctor), UnarySpecialOpsKernels.cpp (SigmoidFunctor, Exp2Functor, Logit*, EntrFunctor)
    • ActivationEluKernels.cpp, ActivationGeluKernel.cpp, ActivationMishKernels.cpp, ActivationSiluKernels.cpp, ActivationSoftplusKernels.cpp, ActivationLogSigmoidKernels.cpp
    • BatchNormKernels.cpp, DistributionTemplates.h, FusedAdamUtils.h, LossCTCKernels.cpp, LossKernels.cpp, SoftMaxKernels.cpp, RNNKernels.cpp
    • Philox4x32.h (bare expf, 4 occurrences)
    • ForeachUnaryKernels.cpp:319 via STD_FUNCTOR macro
  • 15. exp2: std::exp2 → sycl::exp2

    • UnarySpecialOpsKernels.cpp (Exp2Functor, 1 call for real types)
    • Note: the complex overload computes exp(ln_2 * x) instead (there is no exp2 for complex); because sycl:: math does not accept complex types, that call must stay std::exp
  • 16. expm1: std::expm1 → sycl::expm1

    • UnaryKernels.cpp (2 calls: Expm1Functor for scalar + complex)
    • ForeachUnaryKernels.cpp:191 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 17. floor: std::floor, bare floorf → sycl::floor

    • UpSampleBicubic2dKernels.cpp (4 occurrences, bare floorf)
    • Philox4x32.h (1 occurrence, bare floorf)
    • ForeachUnaryKernels.cpp:194 via STD_FUNCTOR macro
  • 18. fmod: std::fmod → sycl::fmod

    • BinaryDivFloorKernel.cpp (1), BinaryRemainderKernel.cpp (2)
  • 19. frexp: std::frexp → sycl::frexp

    • UnaryKernels.cpp (FrexpFunctor, 1 call)
  • 20. isinf / isnan / isfinite: std::isinf/isnan/isfinite → sycl::isinf/isnan/isfinite

    • ~34 device-code occurrences
    • Files: AmpKernels.cpp, LogAddExpKernels.cpp, LogcumsumexpKernel.cpp, UnaryFractionKernels.cpp, SegmentReduceKernels.cpp, LaguerrePolynomialLKernel.cpp, LegendrePolynomialPKernel.cpp, ForeachFunctors.h, MathExtensions.h, GridSampler.h
    • ⚠️ Keep std:: in host code: SummaryOpsKernels.cpp, DistanceKernels.cpp
  • 21. lgamma: std::lgamma → sycl::lgamma

    • UnaryGammaKernels.cpp (1), MathExtensions.h (3)
    • ForeachUnaryKernels.cpp:192 via STD_FUNCTOR macro
  • 22. log: std::log, bare logf → sycl::log

    • UnaryLogKernels.cpp (LogFunctor), UnarySpecialOpsKernels.cpp (Logit0/1Functor, EntrFunctor)
    • ActivationLogSigmoidKernels.cpp, BatchNormKernels.cpp, DistributionTemplates.h, LossCTCKernels.cpp, LossNLLKernel.cpp, RandpermKernel.cpp, RangeFactoriesKernel.cpp
    • Philox4x32.h (bare logf, 2 occurrences)
    • ForeachUnaryKernels.cpp:320 via STD_FUNCTOR macro
  • 23. log1p: std::log1p, ::log1p → sycl::log1p

    • UnaryLogKernels.cpp (Log1pFunctor)
    • LogAddExpKernels.cpp (2, complex → keep std::)
    • LogcumsumexpKernel.cpp (2)
    • ForeachUnaryKernels.cpp:321 via STD_FUNCTOR macro
  • 24. log2: std::log2 → sycl::log2

    • UnaryLogKernels.cpp (Log2Functor)
    • ForeachUnaryKernels.cpp:322 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 25. log10: std::log10 → sycl::log10

    • UnaryLogKernels.cpp (1, complex → keep std::)
    • ForeachUnaryKernels.cpp:323 via STD_FUNCTOR macro
    • ⚠️ Complex types → keep std::
  • 26. nearbyintf: std::nearbyintf → sycl::rint (closest SYCL equivalent)

    • UnaryFractionKernels.cpp:107,117,118 (1 generic scalar + 2 complex components)
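    • ⚠️ Semantics differ slightly: both nearbyint and rint round using the current rounding mode (round-half-to-even by default), but rint may raise the FE_INEXACT floating-point exception while nearbyint does not. Needs careful accuracy testing.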
    • ⚠️ Complex component calls (lines 117-118) → keep std::nearbyintf
  • 27. pow: std::pow → sycl::pow

    • Pow.h, PowKernels.cpp
    • SharedReduceOps.h: #define compat_pow std::pow → sycl::pow
    • Various other files using pow
  • 28. sin: std::sin → sycl::sin

    • UnaryGeometricSinKernel.cpp (1 call)
    • UnarySpecialOpsKernels.cpp (SincFunctor)
    • UnaryKernels.cpp (c10::complex variant of Expm1Functor)
    • ForeachUnaryKernels.cpp:253 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 29. sinh: std::sinh → sycl::sinh

    • UnaryGeometricSinhKernel.cpp (2 calls: SinhComplexFunctor, SinhFunctor)
    • ForeachUnaryKernels.cpp:250 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 30. sqrt: std::sqrt, bare sqrtf → sycl::sqrt

    • UnaryKernels.cpp (SqrtFunctor, RsqrtFunctor)
    • UnarySpecialOpsKernels.cpp (KaiserWindowFunctor)
    • WeightNormKernels.cpp (2 bare sqrtf)
    • Philox4x32.h (6 bare sqrtf)
    • SharedReduceOps.h: #define device_sqrt std::sqrt → sycl::sqrt
    • ForeachUnaryKernels.cpp:324 via STD_FUNCTOR macro
  • 31. tan: std::tan → sycl::tan

    • UnaryGeometricTanKernel.cpp (2 calls: TanComplexFunctor, TanFunctor)
    • ForeachUnaryKernels.cpp:254 via STD_FUNCTOR macro
    • ⚠️ ForeachUnaryKernels also dispatched for complex types → keep std:: for complex
  • 32. tanh: std::tanh → sycl::tanh

    • UnaryGeometricTanhKernel.cpp (TanhFunctor)
    • RNNKernels.cpp, ActivationMishKernels.cpp, ActivationGeluKernel.cpp, etc.
    • ForeachUnaryKernels.cpp:251 via STD_FUNCTOR macro
  • 33. trunc / truncf: std::trunc, std::truncf → sycl::trunc

    • UnaryFractionKernels.cpp:197 (generic scalar std::truncf), :206-207 (complex components)
    • ForeachUnaryKernels.cpp:193 via STD_FUNCTOR macro
    • ⚠️ Complex component calls (lines 206-207) → keep std::truncf
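One migration pitfall is worth checking for pow (task 27), especially around the compat_pow macro in SharedReduceOps.h. Since C++11, std::pow(float, int) promotes both arguments to double and returns double. The SYCL 2020 spec, by contrast, declares sycl::pow with both arguments of the same floating-point type. A mixed-type call site can therefore fail to compile or silently change precision after migration. The sketch uses a hypothetical `compat_pow` function template (not the existing macro) that converts the exponent to the base's type. `sycl::pown`, which takes an integer exponent, is another option to evaluate.

```cpp
#include <cmath>
#include <type_traits>

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
#define SKETCH_HAS_SYCL 1
#else
#define SKETCH_HAS_SYCL 0
#endif

// Hypothetical pow wrapper: converts the exponent to the base's type so both
// arguments share one floating-point type, and keeps the result in that type
// (std::pow(float, int) would instead promote to double).
template <typename T, typename U>
T compat_pow(T base, U exponent) {
  static_assert(std::is_floating_point_v<T>, "compat_pow: floating base only");
#if SKETCH_HAS_SYCL
  return sycl::pow(base, static_cast<T>(exponent));
#else
  // Host-only fallback so the sketch builds without a SYCL toolchain.
  return std::pow(base, static_cast<T>(exponent));
#endif
}
```

Keeping the result in the base's type matches what the sycl:: overloads return. The flip side is that call sites relying on the implicit double promotion will see float precision after migration, which the accuracy tests should cover.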
