From ff2c6ad8a2962181c22ad7d9f9eb1c60fd90198c Mon Sep 17 00:00:00 2001 From: ds5678 <49847914+ds5678@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:18:52 -0800 Subject: [PATCH] Improve Scalar to Complex performance --- src/Native/LibTorchSharp/THSTorch.cpp | 14 +++++------- src/Native/LibTorchSharp/THSTorch.h | 4 ++-- .../PInvoke/LibTorchSharp.THSTorch.cs | 4 ++-- src/TorchSharp/Scalar.cs | 22 +++++-------------- 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp index b846557bc..4e55b637a 100644 --- a/src/Native/LibTorchSharp/THSTorch.cpp +++ b/src/Native/LibTorchSharp/THSTorch.cpp @@ -284,20 +284,18 @@ void THSTorch_scalar_to_float16(Scalar value, unsigned short *res) *res = value->toHalf().x; } -void THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length)) +void THSTorch_scalar_to_complex32(Scalar value, float* real, float* imaginary) { auto result = value->toComplexFloat(); - auto space = allocator(2); - space[0] = result.real(); - space[1] = result.imag(); + *real = result.real(); + *imaginary = result.imag(); } -void THSTorch_scalar_to_complex64(Scalar value, double* (*allocator)(size_t length)) +void THSTorch_scalar_to_complex64(Scalar value, double* real, double* imaginary) { auto result = value->toComplexDouble(); - auto space = allocator(2); - space[0] = result.real(); - space[1] = result.imag(); + *real = result.real(); + *imaginary = result.imag(); } bool THSTorch_scalar_to_bool(Scalar value) diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h index 9ab80e828..df51cd990 100644 --- a/src/Native/LibTorchSharp/THSTorch.h +++ b/src/Native/LibTorchSharp/THSTorch.h @@ -79,8 +79,8 @@ EXPORT_API(bool) THSTorch_scalar_to_bool(Scalar value); EXPORT_API(void) THSTorch_scalar_to_float16(Scalar value, unsigned short* res); -EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length)); -EXPORT_API(void) THSTorch_scalar_to_complex64(Scalar value, double* (*allocator)(size_t length)); +EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* real, float* imaginary); +EXPORT_API(void) THSTorch_scalar_to_complex64(Scalar value, double* real, double* imaginary); EXPORT_API(int8_t) THSTorch_scalar_type(Scalar value); diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs index 3d3919ee3..d3c64f960 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs @@ -86,10 +86,10 @@ internal static partial class NativeMethods internal static extern long THSTorch_scalar_to_int64(IntPtr handle); [DllImport("LibTorchSharp")] - internal static extern void THSTorch_scalar_to_complex32(IntPtr handle, AllocatePinnedArray allocator); + internal static extern void THSTorch_scalar_to_complex32(IntPtr handle, out float real, out float imaginary); [DllImport("LibTorchSharp")] - internal static extern void THSTorch_scalar_to_complex64(IntPtr handle, AllocatePinnedArray allocator); + internal static extern void THSTorch_scalar_to_complex64(IntPtr handle, out double real, out double imaginary); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTorch_get_and_reset_last_err(); diff --git a/src/TorchSharp/Scalar.cs b/src/TorchSharp/Scalar.cs index cfe92cd98..972039c0c 100644 --- a/src/TorchSharp/Scalar.cs +++ b/src/TorchSharp/Scalar.cs @@ -373,15 +373,10 @@ public static bool ToBoolean(this Scalar value) /// The input value. public static (float Real, float Imaginary) ToComplexFloat32(this Scalar value) { - float[] floatArray; + THSTorch_scalar_to_complex32(value.Handle, out float real, out float imaginary); + torch.CheckForErrors(); - using (var pa = new PinnedArray()) { - THSTorch_scalar_to_complex32(value.Handle, pa.CreateArray); - torch.CheckForErrors(); - floatArray = pa.Array; - } - - return (floatArray[0], floatArray[1]); + return (real, imaginary); } /// @@ -390,15 +385,10 @@ public static (float Real, float Imaginary) ToComplexFloat32(this Scalar value) /// The input value. public static System.Numerics.Complex ToComplexFloat64(this Scalar value) { - double[] floatArray; - - using (var pa = new PinnedArray()) { - THSTorch_scalar_to_complex64(value.Handle, pa.CreateArray); - torch.CheckForErrors(); - floatArray = pa.Array; - } + THSTorch_scalar_to_complex64(value.Handle, out double real, out double imaginary); + torch.CheckForErrors(); - return new System.Numerics.Complex(floatArray[0], floatArray[1]); + return new System.Numerics.Complex(real, imaginary); } } }