@@ -85,8 +85,8 @@ constexpr bool is_fp8_v = false;
8585 * @brief Type trait to check if T is any floating point type (including half)
8686 */
8787template <typename T>
88- struct is_floating : std::integral_constant< bool ,
89- std::is_floating_point <T>::value || is_half_v<T> || is_fp8_v<T>
88+ struct is_floating : std::bool_constant<
89+ std::is_floating_point_v <T> || is_half_v<T> || is_fp8_v<T>
9090> {};
9191
9292template <typename T>
@@ -209,40 +209,36 @@ constexpr size_t type_bits_v = type_bits<T>::value;
209209 */
210210template <typename T>
211211TC_HOST_DEVICE_INLINE float to_float (T val) {
212- return static_cast <float >(val);
213- }
214-
215- template <>
216- TC_HOST_DEVICE_INLINE float to_float<__half>(__half val) {
217- return __half2float (val);
218- }
219-
212+ if constexpr (std::is_same_v<T, __half>) {
213+ return __half2float (val);
214+ }
220215#if defined(TC_HAS_BF16)
221- template <>
222- TC_HOST_DEVICE_INLINE float to_float<__nv_bfloat16>(__nv_bfloat16 val) {
223- return __bfloat162float (val);
224- }
216+ else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
217+ return __bfloat162float (val);
218+ }
225219#endif
220+ else {
221+ return static_cast <float >(val);
222+ }
223+ }
226224
227225/* *
228226 * @brief Convert float to target type
229227 */
230228template <typename T>
231229TC_HOST_DEVICE_INLINE T from_float (float val) {
232- return static_cast <T>(val);
233- }
234-
235- template <>
236- TC_HOST_DEVICE_INLINE __half from_float<__half>(float val) {
237- return __float2half (val);
238- }
239-
230+ if constexpr (std::is_same_v<T, __half>) {
231+ return __float2half (val);
232+ }
240233#if defined(TC_HAS_BF16)
241- template <>
242- TC_HOST_DEVICE_INLINE __nv_bfloat16 from_float<__nv_bfloat16>(float val) {
243- return __float2bfloat16 (val);
244- }
234+ else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
235+ return __float2bfloat16 (val);
236+ }
245237#endif
238+ else {
239+ return static_cast <T>(val);
240+ }
241+ }
246242
247243// ============================================================================
248244// Data Type Enumeration
@@ -266,45 +262,26 @@ enum class DataType {
266262 * @brief Get DataType enum from C++ type
267263 */
268264template <typename T>
269- struct dtype_of {
270- static constexpr DataType value = DataType::FP32;
271- };
272-
273- template <>
274- struct dtype_of <float > {
275- static constexpr DataType value = DataType::FP32;
276- };
277-
278- template <>
279- struct dtype_of <__half> {
280- static constexpr DataType value = DataType::FP16;
281- };
282-
265+ constexpr DataType get_dtype () {
266+ if constexpr (std::is_same_v<T, float >) {
267+ return DataType::FP32;
268+ } else if constexpr (std::is_same_v<T, __half>) {
269+ return DataType::FP16;
270+ }
283271#if defined(TC_HAS_BF16)
284- template <>
285- struct dtype_of <__nv_bfloat16> {
286- static constexpr DataType value = DataType::BF16;
287- };
272+ else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
273+ return DataType::BF16;
274+ }
288275#endif
289-
290- template <>
291- struct dtype_of <int8_t > {
292- static constexpr DataType value = DataType::INT8;
293- };
294-
295- template <>
296- struct dtype_of <int32_t > {
297- static constexpr DataType value = DataType::INT32;
298- };
299-
300- template <>
301- struct dtype_of <int64_t > {
302- static constexpr DataType value = DataType::INT64;
303- };
304-
305- template <typename T>
306- constexpr DataType get_dtype () {
307- return dtype_of<T>::value;
276+ else if constexpr (std::is_same_v<T, int8_t >) {
277+ return DataType::INT8;
278+ } else if constexpr (std::is_same_v<T, int32_t >) {
279+ return DataType::INT32;
280+ } else if constexpr (std::is_same_v<T, int64_t >) {
281+ return DataType::INT64;
282+ } else {
283+ return DataType::FP32;
284+ }
308285}
309286
310287/* *
0 commit comments