diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp index b846557bc..a6a079232 100644 --- a/src/Native/LibTorchSharp/THSTorch.cpp +++ b/src/Native/LibTorchSharp/THSTorch.cpp @@ -279,6 +279,11 @@ double THSTorch_scalar_to_float64(Scalar value) return value->toDouble(); } +void THSTorch_scalar_to_bfloat16(Scalar value, unsigned short* res) +{ + *res = value->toBFloat16().x; +} + void THSTorch_scalar_to_float16(Scalar value, unsigned short *res) { *res = value->toHalf().x; diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h index 9ab80e828..ede2bcf54 100644 --- a/src/Native/LibTorchSharp/THSTorch.h +++ b/src/Native/LibTorchSharp/THSTorch.h @@ -77,6 +77,7 @@ EXPORT_API(float) THSTorch_scalar_to_float32(Scalar value); EXPORT_API(double) THSTorch_scalar_to_float64(Scalar value); EXPORT_API(bool) THSTorch_scalar_to_bool(Scalar value); +EXPORT_API(void) THSTorch_scalar_to_bfloat16(Scalar value, unsigned short* res); EXPORT_API(void) THSTorch_scalar_to_float16(Scalar value, unsigned short* res); EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length)); diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs index 3d3919ee3..25e4de947 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs @@ -62,6 +62,9 @@ internal static partial class NativeMethods [DllImport("LibTorchSharp")] internal static extern float THSTorch_scalar_to_float32(IntPtr handle); + [DllImport("LibTorchSharp")] + internal static extern void THSTorch_scalar_to_bfloat16(IntPtr value, out ushort res); + #if NET6_0_OR_GREATER [DllImport("LibTorchSharp")] internal static extern void THSTorch_scalar_to_float16(IntPtr value, out Half res);