Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/Native/LibTorchSharp/THSTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,20 +284,18 @@ void THSTorch_scalar_to_float16(Scalar value, unsigned short *res)
*res = value->toHalf().x;
}

void THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length))
void THSTorch_scalar_to_complex32(Scalar value, float* real, float* imaginary)
{
auto result = value->toComplexFloat();
auto space = allocator(2);
space[0] = result.real();
space[1] = result.imag();
*real = result.real();
*imaginary = result.imag();
}

void THSTorch_scalar_to_complex64(Scalar value, double* (*allocator)(size_t length))
void THSTorch_scalar_to_complex64(Scalar value, double* real, double* imaginary)
{
auto result = value->toComplexDouble();
auto space = allocator(2);
space[0] = result.real();
space[1] = result.imag();
*real = result.real();
*imaginary = result.imag();
}

bool THSTorch_scalar_to_bool(Scalar value)
Expand Down
4 changes: 2 additions & 2 deletions src/Native/LibTorchSharp/THSTorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ EXPORT_API(bool) THSTorch_scalar_to_bool(Scalar value);

EXPORT_API(void) THSTorch_scalar_to_float16(Scalar value, unsigned short* res);

EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* (*allocator)(size_t length));
EXPORT_API(void) THSTorch_scalar_to_complex64(Scalar value, double* (*allocator)(size_t length));
EXPORT_API(void) THSTorch_scalar_to_complex32(Scalar value, float* real, float* imaginary);
EXPORT_API(void) THSTorch_scalar_to_complex64(Scalar value, double* real, double* imaginary);

EXPORT_API(int8_t) THSTorch_scalar_type(Scalar value);

Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ internal static partial class NativeMethods
internal static extern long THSTorch_scalar_to_int64(IntPtr handle);

[DllImport("LibTorchSharp")]
internal static extern void THSTorch_scalar_to_complex32(IntPtr handle, AllocatePinnedArray allocator);
internal static extern void THSTorch_scalar_to_complex32(IntPtr handle, out float real, out float imaginary);

[DllImport("LibTorchSharp")]
internal static extern void THSTorch_scalar_to_complex64(IntPtr handle, AllocatePinnedArray allocator);
internal static extern void THSTorch_scalar_to_complex64(IntPtr handle, out double real, out double imaginary);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTorch_get_and_reset_last_err();
Expand Down
22 changes: 6 additions & 16 deletions src/TorchSharp/Scalar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,10 @@ public static bool ToBoolean(this Scalar value)
/// <param name="value">The input value.</param>
public static (float Real, float Imaginary) ToComplexFloat32(this Scalar value)
{
float[] floatArray;
THSTorch_scalar_to_complex32(value.Handle, out float real, out float imaginary);
torch.CheckForErrors();

using (var pa = new PinnedArray<float>()) {
THSTorch_scalar_to_complex32(value.Handle, pa.CreateArray);
torch.CheckForErrors();
floatArray = pa.Array;
}

return (floatArray[0], floatArray[1]);
return (real, imaginary);
}

/// <summary>
Expand All @@ -390,15 +385,10 @@ public static (float Real, float Imaginary) ToComplexFloat32(this Scalar value)
/// <param name="value">The input value.</param>
public static System.Numerics.Complex ToComplexFloat64(this Scalar value)
{
double[] floatArray;

using (var pa = new PinnedArray<double>()) {
THSTorch_scalar_to_complex64(value.Handle, pa.CreateArray);
torch.CheckForErrors();
floatArray = pa.Array;
}
THSTorch_scalar_to_complex64(value.Handle, out double real, out double imaginary);
torch.CheckForErrors();

return new System.Numerics.Complex(floatArray[0], floatArray[1]);
return new System.Numerics.Complex(real, imaginary);
}
}
}