From ff2c6ad8a2962181c22ad7d9f9eb1c60fd90198c Mon Sep 17 00:00:00 2001
From: ds5678 <49847914+ds5678@users.noreply.github.com>
Date: Wed, 19 Nov 2025 00:18:52 -0800
Subject: [PATCH] Improve Scalar to Complex performance
---
src/Native/LibTorchSharp/THSTorch.cpp | 14 +++++-------
src/Native/LibTorchSharp/THSTorch.h | 4 ++--
.../PInvoke/LibTorchSharp.THSTorch.cs | 4 ++--
src/TorchSharp/Scalar.cs | 22 +++++--------------
4 files changed, 16 insertions(+), 28 deletions(-)
diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp
index b846557bc..4e55b637a 100644
--- a/src/Native/LibTorchSharp/THSTorch.cpp
+++ b/src/Native/LibTorchSharp/THSTorch.cpp
@@ -284,20 +284,18 @@ void THSTorch_scalar_to_float16(Scalar value, unsigned short *res)
*res = value->toHalf().x;
}
-void THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length))
+void THSTorch_scalar_to_complex32(Scalar value, float* real, float* imaginary)
{
auto result = value->toComplexFloat();
- auto space = allocator(2);
- space[0] = result.real();
- space[1] = result.imag();
+ *real = result.real();
+ *imaginary = result.imag();
}
-void THSTorch_scalar_to_complex64(Scalar value, double* (*allocator)(size_t length))
+void THSTorch_scalar_to_complex64(Scalar value, double* real, double* imaginary)
{
auto result = value->toComplexDouble();
- auto space = allocator(2);
- space[0] = result.real();
- space[1] = result.imag();
+ *real = result.real();
+ *imaginary = result.imag();
}
bool THSTorch_scalar_to_bool(Scalar value)
diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h
index 9ab80e828..df51cd990 100644
--- a/src/Native/LibTorchSharp/THSTorch.h
+++ b/src/Native/LibTorchSharp/THSTorch.h
@@ -79,8 +79,8 @@ EXPORT_API(bool) THSTorch_scalar_to_bool(Scalar value);
EXPORT_API(void) THSTorch_scalar_to_float16(Scalar value, unsigned short* res);
-EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length));
-EXPORT_API(void) THSTorch_scalar_to_complex64(Scalar value, double* (*allocator)(size_t length));
+EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* real, float* imaginary);
+EXPORT_API(void) THSTorch_scalar_to_complex64(Scalar value, double* real, double* imaginary);
EXPORT_API(int8_t) THSTorch_scalar_type(Scalar value);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
index 3d3919ee3..d3c64f960 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
@@ -86,10 +86,10 @@ internal static partial class NativeMethods
internal static extern long THSTorch_scalar_to_int64(IntPtr handle);
[DllImport("LibTorchSharp")]
- internal static extern void THSTorch_scalar_to_complex32(IntPtr handle, AllocatePinnedArray allocator);
+ internal static extern void THSTorch_scalar_to_complex32(IntPtr handle, out float real, out float imaginary);
[DllImport("LibTorchSharp")]
- internal static extern void THSTorch_scalar_to_complex64(IntPtr handle, AllocatePinnedArray allocator);
+ internal static extern void THSTorch_scalar_to_complex64(IntPtr handle, out double real, out double imaginary);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTorch_get_and_reset_last_err();
diff --git a/src/TorchSharp/Scalar.cs b/src/TorchSharp/Scalar.cs
index cfe92cd98..972039c0c 100644
--- a/src/TorchSharp/Scalar.cs
+++ b/src/TorchSharp/Scalar.cs
@@ -373,15 +373,10 @@ public static bool ToBoolean(this Scalar value)
/// The input value.
public static (float Real, float Imaginary) ToComplexFloat32(this Scalar value)
{
- float[] floatArray;
+ THSTorch_scalar_to_complex32(value.Handle, out float real, out float imaginary);
+ torch.CheckForErrors();
- using (var pa = new PinnedArray()) {
- THSTorch_scalar_to_complex32(value.Handle, pa.CreateArray);
- torch.CheckForErrors();
- floatArray = pa.Array;
- }
-
- return (floatArray[0], floatArray[1]);
+ return (real, imaginary);
}
///
@@ -390,15 +385,10 @@ public static (float Real, float Imaginary) ToComplexFloat32(this Scalar value)
/// The input value.
public static System.Numerics.Complex ToComplexFloat64(this Scalar value)
{
- double[] floatArray;
-
- using (var pa = new PinnedArray()) {
- THSTorch_scalar_to_complex64(value.Handle, pa.CreateArray);
- torch.CheckForErrors();
- floatArray = pa.Array;
- }
+ THSTorch_scalar_to_complex64(value.Handle, out double real, out double imaginary);
+ torch.CheckForErrors();
- return new System.Numerics.Complex(floatArray[0], floatArray[1]);
+ return new System.Numerics.Complex(real, imaginary);
}
}
}