From 991877dfc259cda16c44b99fd6b6a71ebc238c41 Mon Sep 17 00:00:00 2001 From: ds5678 <49847914+ds5678@users.noreply.github.com> Date: Sat, 15 Nov 2025 11:56:14 -0800 Subject: [PATCH] Add THSTorch_scalar_to_bfloat16 --- src/Native/LibTorchSharp/THSTorch.cpp | 5 +++++ src/Native/LibTorchSharp/THSTorch.h | 1 + src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs | 3 +++ 3 files changed, 9 insertions(+) diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp index b846557bc..a6a079232 100644 --- a/src/Native/LibTorchSharp/THSTorch.cpp +++ b/src/Native/LibTorchSharp/THSTorch.cpp @@ -279,6 +279,11 @@ double THSTorch_scalar_to_float64(Scalar value) return value->toDouble(); } +void THSTorch_scalar_to_bfloat16(Scalar value, unsigned short* res) +{ + *res = value->toBFloat16().x; +} + void THSTorch_scalar_to_float16(Scalar value, unsigned short *res) { *res = value->toHalf().x; diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h index 9ab80e828..ede2bcf54 100644 --- a/src/Native/LibTorchSharp/THSTorch.h +++ b/src/Native/LibTorchSharp/THSTorch.h @@ -77,6 +77,7 @@ EXPORT_API(float) THSTorch_scalar_to_float32(Scalar value); EXPORT_API(double) THSTorch_scalar_to_float64(Scalar value); EXPORT_API(bool) THSTorch_scalar_to_bool(Scalar value); +EXPORT_API(void) THSTorch_scalar_to_bfloat16(Scalar value, unsigned short* res); EXPORT_API(void) THSTorch_scalar_to_float16(Scalar value, unsigned short* res); EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length)); diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs index 3d3919ee3..25e4de947 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs @@ -62,6 +62,9 @@ internal static partial class NativeMethods [DllImport("LibTorchSharp")] internal static extern float THSTorch_scalar_to_float32(IntPtr handle); + [DllImport("LibTorchSharp")] + internal static extern void THSTorch_scalar_to_bfloat16(IntPtr value, out ushort res); + #if NET6_0_OR_GREATER [DllImport("LibTorchSharp")] internal static extern void THSTorch_scalar_to_float16(IntPtr value, out Half res);