From 3d1e24cfb6534e7d1c7816c84dd9a43a7776787c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 24 Dec 2025 08:15:21 -0500 Subject: [PATCH 1/2] Move getDataType and castToDtype out-of-line to polymorphic_value.cpp Reduces cumulative template instantiation time by 22% (clang -ftime-trace). --- csrc/polymorphic_value.cpp | 61 ++++++++++++++++++++++++++++++++++++ csrc/type.h | 63 ++++---------------------------------- 2 files changed, 67 insertions(+), 57 deletions(-) diff --git a/csrc/polymorphic_value.cpp b/csrc/polymorphic_value.cpp index 58c3c3344eb..f9b4a26fb66 100644 --- a/csrc/polymorphic_value.cpp +++ b/csrc/polymorphic_value.cpp @@ -13,6 +13,67 @@ namespace nvfuser { +// Implementation of getDataType - moved from type.h to reduce template bloat. +// This function uses for_all_types which triggers heavy template instantiation. +DataType getDataType(const PolymorphicValue& value) { + std::optional dtype = std::nullopt; + PolymorphicValue::for_all_types([&value, &dtype](auto _) { + using T = typename decltype(_)::type; + if constexpr (IsPrimitiveNativeType::value) { + if (value.is()) { + dtype = NativeTypeToDataType::type; + } + } else if constexpr (std::is_same_v>) { + if (value.is()) { + const auto& vec = value.as(); + size_t size = vec.size(); + NVF_CHECK(size > 0, "Empty array is not supported"); + dtype = + ArrayType{std::make_shared(getDataType(vec[0])), size}; + } + } else if constexpr (std::is_same_v) { + // For pointers in polymorphic value, we only store the data size of the + // pointee, so it is impossible to infer the pointer type. + NVF_CHECK(!value.is(), "Can not infer pointer type."); + } else if constexpr (std::is_same_v) { + if (value.is()) { + dtype = value.as().type(); + } + } else if constexpr (std::is_same_v) { + if (value.is()) { + const auto& opaque = value.as(); + dtype = DataType(OpaqueType{ + .type_info = opaque.any().type(), .size = opaque.size()}); + } + } + }); + NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name()); + return dtype.value(); +} + +// Implementation of castToDtype - moved from type.h to reduce template bloat. +// This function uses for_all_types which triggers heavy template instantiation. +PolymorphicValue castToDtype(PolymorphicValue value, const DataType& dtype) { + if (!value.hasValue()) { + return value; + } + // Cast the given value to the given data type. This enables interface + // like: IrBuilder::create(0, DataType::Double) where value is + // an integer but the desired data type is double. + if (!hasCompatibleDataType(value, dtype)) { + PolymorphicValue::for_all_types([&](auto _) { + using T = typename decltype(_)::type; + if constexpr (IsPrimitiveNativeType::value) { + if (isCompatibleDataType(NativeTypeToDataType::type, dtype)) { + value = PolymorphicValue(static_cast(value)); + } + } + // TODO: support arrays and pointers + }); + } + return value; +} + bool StructHandle::operator==(const StructHandle& other) const { if (struct_ptr_ == other.struct_ptr_) { return true; diff --git a/csrc/type.h b/csrc/type.h index 649297229b9..fdbb1128927 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -414,41 +414,9 @@ DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::ComplexDouble, std::complex); #undef DEFINE_DATATYPE_TO_NATIVE_TYPE -inline DataType getDataType(const PolymorphicValue& value) { - std::optional dtype = std::nullopt; - PolymorphicValue::for_all_types([&value, &dtype](auto _) { - using T = typename decltype(_)::type; - if constexpr (IsPrimitiveNativeType::value) { - if (value.is()) { - dtype = NativeTypeToDataType::type; - } - } else if constexpr (std::is_same_v>) { - if (value.is()) { - const auto& vec = value.as(); - size_t size = vec.size(); - NVF_CHECK(size > 0, "Empty array is not supported"); - dtype = - ArrayType{std::make_shared(getDataType(vec[0])), size}; - } - } else if constexpr (std::is_same_v) { - // For pointers in polymorphic value, we only store the data size of the - // pointee, so it is impossible to infer the pointer type. - NVF_CHECK(!value.is(), "Can not infer pointer type."); - } else if constexpr (std::is_same_v) { - if (value.is()) { - dtype = value.as().type(); - } - } else if constexpr (std::is_same_v) { - if (value.is()) { - const auto& opaque = value.as(); - dtype = DataType(OpaqueType{ - .type_info = opaque.any().type(), .size = opaque.size()}); - } - } - }); - NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name()); - return dtype.value(); -} +// Get the DataType corresponding to the runtime type held in a PolymorphicValue. +// Implementation moved to polymorphic_value.cpp to reduce template instantiation. +NVF_API DataType getDataType(const PolymorphicValue& value); inline bool isCompatibleDataType(DataType dtype, DataType dtype2) { if (dtype == dtype2) { @@ -1128,28 +1096,9 @@ Pointer::Pointer(void* ptr, DataType dtype) : ptr_(reinterpret_cast(ptr)), size_bit_(dataTypeSizeBit(dtype)) {} -inline PolymorphicValue castToDtype( - PolymorphicValue value, - const DataType& dtype) { - if (!value.hasValue()) { - return value; - } - // Cast the given value to the given data type. This enables interface - // like: IrBuilder::create(0, DataType::Double) where value is - // an integer but the desired data type is double. - if (!hasCompatibleDataType(value, dtype)) { - PolymorphicValue::for_all_types([&](auto _) { - using T = typename decltype(_)::type; - if constexpr (IsPrimitiveNativeType::value) { - if (isCompatibleDataType(NativeTypeToDataType::type, dtype)) { - value = PolymorphicValue(static_cast(value)); - } - } - // TODO: support arrays and pointers - }); - } - return value; -} +// Cast a PolymorphicValue to match the specified DataType. +// Implementation moved to polymorphic_value.cpp to reduce template instantiation. +NVF_API PolymorphicValue castToDtype(PolymorphicValue value, const DataType& dtype); // Converts an enum to its underlying type. // It corresponds with std::to_underlying introduced in c++23 From 11b14cc3a582daa16a3867bd149cf50fbab02425 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 24 Dec 2025 09:18:29 -0500 Subject: [PATCH 2/2] Move operator-using functions to polymorphic_value.cpp Move isSame, ceildiv, max, fmax, min, fmin from header to cpp to reduce template instantiation costs. These functions use PolymorphicValue operators that trigger ForAllTypes recursion. --- csrc/polymorphic_value.cpp | 70 +++++++++++++++++++++++++++++++++++ csrc/polymorphic_value.h | 76 ++++++++------------------------------ 2 files changed, 85 insertions(+), 61 deletions(-) diff --git a/csrc/polymorphic_value.cpp b/csrc/polymorphic_value.cpp index f9b4a26fb66..6e7a5fcaa03 100644 --- a/csrc/polymorphic_value.cpp +++ b/csrc/polymorphic_value.cpp @@ -103,6 +103,76 @@ bool StructHandle::operator==(const StructHandle& other) const { namespace PolymorphicValue_functions { +// Implementation of isSame - moved from polymorphic_value.h to reduce template bloat. +// Uses operator== which triggers ForAllTypes template instantiation. +bool isSame(const PolymorphicValue& a, const PolymorphicValue& b) { + if (a.type() != b.type()) { + return false; + } + if (a.is()) { + return (a.as().is_same(b.as())); + } + if (a.is()) { + return isSameNanSensitive(a.as(), b.as()); + } + if (a.is>()) { + return isSameNanSensitive( + a.as>(), b.as>()); + } + return a == b; +} + +// Implementation of ceildiv - moved from polymorphic_value.h to reduce template bloat. +// Uses operator/ which triggers ForAllTypes template instantiation. +PolymorphicValue ceildiv(const PolymorphicValue& a, const PolymorphicValue& b) { + if (a.is() && b.is()) { + auto aa = a.as(); + auto bb = b.as(); + if (bb > 0) { + return PolymorphicValue((aa + bb - 1) / bb); + } else { + return PolymorphicValue((aa + bb + 1) / bb); + } + } + return PolymorphicValue(std::ceil((a / b).as())); +} + +// Implementation of max - moved from polymorphic_value.h to reduce template bloat. +// Uses operator!= and operator> which trigger ForAllTypes template instantiation. +PolymorphicValue max(const PolymorphicValue& a, const PolymorphicValue& b) { + if (a != a) { + return PolymorphicValue(a); + } + return PolymorphicValue(a > b ? a : b); +} + +// Implementation of fmax - moved from polymorphic_value.h to reduce template bloat. +// Uses operator!= and operator< which trigger ForAllTypes template instantiation. +PolymorphicValue fmax(const PolymorphicValue& a, const PolymorphicValue& b) { + if (a != a) { + return PolymorphicValue(b); + } + return PolymorphicValue(a < b ? b : a); +} + +// Implementation of min - moved from polymorphic_value.h to reduce template bloat. +// Uses operator!= and operator< which trigger ForAllTypes template instantiation. +PolymorphicValue min(const PolymorphicValue& a, const PolymorphicValue& b) { + if (a != a) { + return PolymorphicValue(a); + } + return PolymorphicValue(a < b ? a : b); +} + +// Implementation of fmin - moved from polymorphic_value.h to reduce template bloat. +// Uses operator!= and operator> which trigger ForAllTypes template instantiation. +PolymorphicValue fmin(const PolymorphicValue& a, const PolymorphicValue& b) { + if (a != a) { + return PolymorphicValue(b); + } + return PolymorphicValue(a > b ? b : a); +} + size_t hash(const PolymorphicValue& v) { constexpr size_t nan_hash_value = 572491308; // NaNs are considered the same, so map all NaN values to same hash value. diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index 49b42555d79..0677cd26eb2 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -251,22 +251,9 @@ inline bool isSameNanSensitive(const T& a, const T& b) { return a == b; } -inline bool isSame(const PolymorphicValue& a, const PolymorphicValue& b) { - if (a.type() != b.type()) { - return false; - } - if (a.is()) { - return (a.as().is_same(b.as())); - } - if (a.is()) { - return isSameNanSensitive(a.as(), b.as()); - } - if (a.is>()) { - return isSameNanSensitive( - a.as>(), b.as>()); - } - return a == b; -} +// Declaration only - implementation in polymorphic_value.cpp +// Uses operator== which triggers ForAllTypes template instantiation +NVF_API bool isSame(const PolymorphicValue& a, const PolymorphicValue& b); inline PolymorphicValue signbit(const PolymorphicValue& a) { if (a.is()) { @@ -322,56 +309,23 @@ inline PolymorphicValue fmod( b.type().name()); } -inline PolymorphicValue ceildiv( +// Declarations only - implementations in polymorphic_value.cpp +// These functions use PolymorphicValue operators which trigger ForAllTypes +NVF_API PolymorphicValue ceildiv( const PolymorphicValue& a, - const PolymorphicValue& b) { - if (a.is() && b.is()) { - auto aa = a.as(); - auto bb = b.as(); - if (bb > 0) { - return PolymorphicValue((aa + bb - 1) / bb); - } else { - return PolymorphicValue((aa + bb + 1) / bb); - } - } - return PolymorphicValue(std::ceil((a / b).as())); -} - -inline PolymorphicValue max( + const PolymorphicValue& b); +NVF_API PolymorphicValue max( const PolymorphicValue& a, - const PolymorphicValue& b) { - if (a != a) { - return PolymorphicValue(a); - } - return PolymorphicValue(a > b ? a : b); -} - -inline PolymorphicValue fmax( + const PolymorphicValue& b); +NVF_API PolymorphicValue fmax( const PolymorphicValue& a, - const PolymorphicValue& b) { - if (a != a) { - return PolymorphicValue(b); - } - return PolymorphicValue(a < b ? b : a); -} - -inline PolymorphicValue min( + const PolymorphicValue& b); +NVF_API PolymorphicValue min( const PolymorphicValue& a, - const PolymorphicValue& b) { - if (a != a) { - return PolymorphicValue(a); - } - return PolymorphicValue(a < b ? a : b); -} - -inline PolymorphicValue fmin( + const PolymorphicValue& b); +NVF_API PolymorphicValue fmin( const PolymorphicValue& a, - const PolymorphicValue& b) { - if (a != a) { - return PolymorphicValue(b); - } - return PolymorphicValue(a > b ? b : a); -} + const PolymorphicValue& b); inline PolymorphicValue gcd( const PolymorphicValue& a,