@@ -363,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
363363template <>
364364DS_D_INLINE __nv_bfloat16 to (int64_t val)
365365{
366+ #ifdef __HIP_PLATFORM_AMD__
367+ return __double2bfloat16 (__ll2double_rn (val));
368+ #else
366369 return __ll2bfloat16_rn (val);
370+ #endif
367371}
368372template <>
369373DS_D_INLINE __nv_bfloat16 to (int32_t val)
370374{
375+ #ifdef __HIP_PLATFORM_AMD__
376+ return __float2bfloat16 (__int2float_rn (val));
377+ #else
371378 return __int2bfloat16_rn (val);
379+ #endif
372380}
373381template <>
374382DS_D_INLINE __nv_bfloat16 to (int16_t val)
375383{
384+ #ifdef __HIP_PLATFORM_AMD__
385+ return __float2bfloat16 (__int2float_rn (val));
386+ #else
376387 return __short2bfloat16_rn (val);
388+ #endif
377389}
378390template <>
379391DS_D_INLINE __nv_bfloat16 to (int8_t val)
380392{
393+ #ifdef __HIP_PLATFORM_AMD__
394+ return __float2bfloat16 (__int2float_rn (val));
395+ #else
381396 return __int2bfloat16_rn (val);
397+ #endif
382398}
383399template <>
384400DS_D_INLINE __nv_bfloat16 to (uint64_t val)
385401{
402+ #ifdef __HIP_PLATFORM_AMD__
403+ return __double2bfloat16 (__ull2double_rn (val));
404+ #else
386405 return __ull2bfloat16_rn (val);
406+ #endif
387407}
388408template <>
389409DS_D_INLINE __nv_bfloat16 to (uint32_t val)
390410{
411+ #ifdef __HIP_PLATFORM_AMD__
412+ return __float2bfloat16 (__uint2float_rn (val));
413+ #else
391414 return __uint2bfloat16_rn (val);
415+ #endif
392416}
393417template <>
394418DS_D_INLINE __nv_bfloat16 to (uint16_t val)
395419{
420+ #ifdef __HIP_PLATFORM_AMD__
421+ return __float2bfloat16 (__uint2float_rn (val));
422+ #else
396423 return __ushort2bfloat16_rn (val);
424+ #endif
397425}
398426template <>
399427DS_D_INLINE __nv_bfloat16 to (uint8_t val)
400428{
429+ #ifdef __HIP_PLATFORM_AMD__
430+ return __float2bfloat16 (__uint2float_rn (val));
431+ #else
401432 return __uint2bfloat16_rn (val);
433+ #endif
402434}
403435#endif
404436
@@ -412,7 +444,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
412444template <>
413445DS_D_INLINE __nv_bfloat162 to (float val)
414446{
447+ #ifdef __HIP_PLATFORM_AMD__
448+ return __bfloat162bfloat162 (__float2bfloat16 (val));
449+ #else
415450 return __float2bfloat162_rn (val);
451+ #endif
416452}
417453template <>
418454DS_D_INLINE __nv_bfloat162 to (__half2 val)
@@ -444,7 +480,11 @@ DS_D_INLINE int64_t to(__half val)
444480template <>
445481DS_D_INLINE int64_t to (__nv_bfloat16 val)
446482{
483+ #ifdef __HIP_PLATFORM_AMD__
484+ return __float2ll_rn (__bfloat162float (val));
485+ #else
447486 return __bfloat162ll_rn (val);
487+ #endif
448488}
449489#endif
450490
@@ -471,7 +511,11 @@ DS_D_INLINE int32_t to(__half val)
471511template <>
472512DS_D_INLINE int32_t to (__nv_bfloat16 val)
473513{
514+ #ifdef __HIP_PLATFORM_AMD__
515+ return __float2int_rn (__bfloat162float (val));
516+ #else
474517 return __bfloat162int_rn (val);
518+ #endif
475519}
476520#endif
477521
@@ -498,7 +542,11 @@ DS_D_INLINE int16_t to(__half val)
498542template <>
499543DS_D_INLINE int16_t to (__nv_bfloat16 val)
500544{
545+ #ifdef __HIP_PLATFORM_AMD__
546+ return __float2int_rn (__bfloat162float (val));
547+ #else
501548 return __bfloat162int_rn (val);
549+ #endif
502550}
503551#endif
504552
@@ -525,7 +573,11 @@ DS_D_INLINE int8_t to(__half val)
525573template <>
526574DS_D_INLINE int8_t to (__nv_bfloat16 val)
527575{
576+ #ifdef __HIP_PLATFORM_AMD__
577+ return __float2int_rn (__bfloat162float (val));
578+ #else
528579 return __bfloat162int_rn (val);
580+ #endif
529581}
530582#endif
531583
@@ -552,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val)
552604template <>
553605DS_D_INLINE uint64_t to (__nv_bfloat16 val)
554606{
607+ #ifdef __HIP_PLATFORM_AMD__
608+ return __float2ull_rn (__bfloat162float (val));
609+ #else
555610 return __bfloat162ull_rn (val);
611+ #endif
556612}
557613#endif
558614
@@ -579,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val)
579635template <>
580636DS_D_INLINE uint32_t to (__nv_bfloat16 val)
581637{
638+ #ifdef __HIP_PLATFORM_AMD__
639+ return __float2uint_rn (__bfloat162float (val));
640+ #else
582641 return __bfloat162uint_rn (val);
642+ #endif
583643}
584644#endif
585645
@@ -606,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val)
606666template <>
607667DS_D_INLINE uint16_t to (__nv_bfloat16 val)
608668{
669+ #ifdef __HIP_PLATFORM_AMD__
670+ return __float2uint_rn (__bfloat162float (val));
671+ #else
609672 return __bfloat162uint_rn (val);
673+ #endif
610674}
611675#endif
612676
@@ -633,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val)
633697template <>
634698DS_D_INLINE uint8_t to (__nv_bfloat16 val)
635699{
700+ #ifdef __HIP_PLATFORM_AMD__
701+ return __float2uint_rn (__bfloat162float (val));
702+ #else
636703 return __bfloat162uint_rn (val);
704+ #endif
637705}
638706#endif
639707
0 commit comments