Skip to content

Commit bd8ded6

Browse files
Add full BFloat16 managed type support
Implement a custom BFloat16 struct (2 bytes, binary-compatible with c10::BFloat16) with round-to-nearest-even conversion matching PyTorch. New features: - BFloat16 struct with arithmetic, comparison, and conversion operators - data<BFloat16>() and item<BFloat16>() for zero-copy tensor data access - torch.tensor(BFloat16[], ...) factory overloads for 1D-4D arrays - Scalar implicit conversion and ToBFloat16() extraction - ToBFloat16() tensor extension method - ToTensor<T> support for BFloat16 - as_tensor() BFloat16 overloads - frombuffer offset fix for BFloat16 and Float16 (2-byte types) Fixes #1469 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 6beec00 commit bd8ded6

File tree

8 files changed

+457
-2
lines changed

8 files changed

+457
-2
lines changed

src/TorchSharp/BFloat16.cs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
using System;
3+
using System.Globalization;
4+
using System.Runtime.InteropServices;
5+
6+
#nullable enable
7+
namespace TorchSharp
8+
{
9+
/// <summary>
/// Represents a 16-bit brain floating-point number (BFloat16).
/// Binary layout: 1 sign bit, 8 exponent bits, 7 mantissa bits — the upper 16 bits of IEEE 754 float32.
/// Binary-compatible with c10::BFloat16 in LibTorch.
/// </summary>
/// <remarks>
/// The relational operators (<c>==</c>, <c>&lt;</c>, ...) compare the widened float values and therefore
/// follow IEEE 754: NaN is unequal to everything, including itself, and +0 equals -0.
/// <see cref="Equals(BFloat16)"/>, <see cref="GetHashCode"/> and <see cref="CompareTo(BFloat16)"/> follow the
/// conventions of <see cref="float"/>: every NaN is equal to every other NaN, and +0 is equal to -0, so the
/// type can be used safely as a dictionary key or sort key.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
public readonly struct BFloat16 : IComparable<BFloat16>, IEquatable<BFloat16>, IComparable, IFormattable
{
    // Raw 16-bit pattern: sign (bit 15), exponent (bits 14..7), mantissa (bits 6..0).
    internal readonly ushort value;

    // The unused bool parameter only disambiguates this raw-bits constructor from BFloat16(float).
    internal BFloat16(ushort rawValue, bool _)
    {
        value = rawValue;
    }

    /// <summary>
    /// Creates a BFloat16 from a float value using round-to-nearest-even (matching PyTorch c10::BFloat16).
    /// </summary>
    public BFloat16(float f)
    {
        value = FloatToBFloat16Bits(f);
    }

    /// <summary>
    /// Creates a BFloat16 from the raw 16-bit representation.
    /// </summary>
    public static BFloat16 FromRawValue(ushort rawValue) => new BFloat16(rawValue, false);

    // --- Conversion to/from float ---

    private static unsafe ushort FloatToBFloat16Bits(float f)
    {
        uint bits = *(uint*)&f;
        // NaN: keep the sign and the upper payload bits, and force the quiet bit (0x0040).
        // Without it, a NaN whose payload lives only in the low 16 bits would truncate to an
        // all-zero mantissa, i.e. to infinity.
        if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0)
            return (ushort)(bits >> 16 | 0x0040u);
        // Round-to-nearest-even: add 0x7FFF plus the lowest kept bit, so exact ties
        // round towards the value whose lowest kept bit is 0.
        uint lsb = (bits >> 16) & 1u;
        uint roundingBias = 0x7FFFu + lsb;
        bits += roundingBias;
        return (ushort)(bits >> 16);
    }

    private static unsafe float BFloat16BitsToFloat(ushort raw)
    {
        // The BFloat16 pattern is the upper half of the float32 pattern; the lower half is zero.
        int bits = raw << 16;
        return *(float*)&bits;
    }

    /// <summary>
    /// Converts this BFloat16 to a float.
    /// </summary>
    public float ToSingle() => BFloat16BitsToFloat(value);

    // --- Conversion operators ---

    public static explicit operator float(BFloat16 bf) => bf.ToSingle();
    public static explicit operator double(BFloat16 bf) => bf.ToSingle();
    public static explicit operator BFloat16(float f) => new BFloat16(f);
    public static explicit operator BFloat16(double d) => new BFloat16((float)d);

    // --- Arithmetic operators (promote to float, round the result back to BFloat16) ---

    public static BFloat16 operator +(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() + b.ToSingle());
    public static BFloat16 operator -(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() - b.ToSingle());
    public static BFloat16 operator *(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() * b.ToSingle());
    public static BFloat16 operator /(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() / b.ToSingle());
    public static BFloat16 operator %(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() % b.ToSingle());
    public static BFloat16 operator -(BFloat16 a) => new BFloat16(-a.ToSingle());

    // --- Comparison operators (IEEE 754 semantics: any comparison involving NaN is false, except !=) ---

    public static bool operator ==(BFloat16 a, BFloat16 b) => a.ToSingle() == b.ToSingle();
    public static bool operator !=(BFloat16 a, BFloat16 b) => a.ToSingle() != b.ToSingle();
    public static bool operator <(BFloat16 a, BFloat16 b) => a.ToSingle() < b.ToSingle();
    public static bool operator >(BFloat16 a, BFloat16 b) => a.ToSingle() > b.ToSingle();
    public static bool operator <=(BFloat16 a, BFloat16 b) => a.ToSingle() <= b.ToSingle();
    public static bool operator >=(BFloat16 a, BFloat16 b) => a.ToSingle() >= b.ToSingle();

    // --- IEquatable / IComparable ---

    /// <summary>
    /// Value equality, consistent with <see cref="CompareTo(BFloat16)"/>: +0 equals -0 and
    /// any NaN equals any other NaN (the same rules <see cref="float.Equals(float)"/> uses).
    /// </summary>
    public bool Equals(BFloat16 other) => ToSingle().Equals(other.ToSingle());

    /// <summary>
    /// Returns true when <paramref name="obj"/> is a BFloat16 that is equal to this value.
    /// </summary>
    public override bool Equals(object? obj) => obj is BFloat16 other && Equals(other);

    /// <summary>
    /// Returns a hash code that agrees with <see cref="Equals(BFloat16)"/>: both zeros share one
    /// hash code and all NaN bit patterns share one hash code.
    /// </summary>
    public override int GetHashCode()
    {
        int magnitude = value & 0x7FFF; // bit pattern without the sign bit
        ushort canonical;
        if (magnitude > 0x7F80)
        {
            // Exponent all ones with a non-zero mantissa: NaN. Collapse every payload.
            canonical = 0x7FC0;
        }
        else if (magnitude == 0)
        {
            // +0 and -0 compare equal, so they must hash equal.
            canonical = 0;
        }
        else
        {
            canonical = value;
        }
        return canonical.GetHashCode();
    }

    /// <summary>
    /// Orders two values using <see cref="float.CompareTo(float)"/> on the widened values.
    /// </summary>
    public int CompareTo(BFloat16 other) => ToSingle().CompareTo(other.ToSingle());

    /// <summary>
    /// Orders this value against a boxed BFloat16. A null reference sorts before any value.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="obj"/> is not a BFloat16.</exception>
    public int CompareTo(object? obj)
    {
        if (obj is null) return 1;
        if (obj is BFloat16 other) return CompareTo(other);
        throw new ArgumentException("Object must be of type BFloat16.");
    }

    // --- Formatting (delegates to float formatting of the widened value) ---

    public override string ToString() => ToSingle().ToString();
    public string ToString(string? format, IFormatProvider? formatProvider) => ToSingle().ToString(format, formatProvider);

    // --- Constants ---

    public static readonly BFloat16 Zero = FromRawValue(0x0000);
    public static readonly BFloat16 One = FromRawValue(0x3F80);
    public static readonly BFloat16 NaN = FromRawValue(0x7FC0);
    public static readonly BFloat16 PositiveInfinity = FromRawValue(0x7F80);
    public static readonly BFloat16 NegativeInfinity = FromRawValue(0xFF80);
    public static readonly BFloat16 MaxValue = FromRawValue(0x7F7F);  // ~3.39e+38
    public static readonly BFloat16 MinValue = FromRawValue(0xFF7F);  // ~-3.39e+38

    /// <summary>
    /// The smallest positive normal value (raw bits 0x0080).
    /// Note that this differs from <see cref="float.Epsilon"/>, which is the smallest positive
    /// subnormal value; the equivalent of that is <see cref="SmallestSubnormal"/>.
    /// </summary>
    public static readonly BFloat16 Epsilon = FromRawValue(0x0080);

    /// <summary>
    /// The smallest positive subnormal value (raw bits 0x0001).
    /// </summary>
    public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001);

    // --- Static helpers ---

    public static bool IsNaN(BFloat16 bf) => float.IsNaN(bf.ToSingle());
    public static bool IsInfinity(BFloat16 bf) => float.IsInfinity(bf.ToSingle());
    public static bool IsPositiveInfinity(BFloat16 bf) => float.IsPositiveInfinity(bf.ToSingle());
    public static bool IsNegativeInfinity(BFloat16 bf) => float.IsNegativeInfinity(bf.ToSingle());
    public static bool IsFinite(BFloat16 bf) => !IsInfinity(bf) && !IsNaN(bf);
}
127+
}

src/TorchSharp/Scalar.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ public static implicit operator Scalar(Half value)
8181
}
8282
#endif
8383

84+
/// <summary>
/// Implicit conversion from a BFloat16 value to a Scalar; delegates to the ToScalar() extension.
/// </summary>
/// <param name="value">The BFloat16 value to convert.</param>
public static implicit operator Scalar(BFloat16 value) => value.ToScalar();
92+
8493
/// <summary>
8594
/// Implicitly convert a .NET scalar value to Scalar
8695
/// </summary>
@@ -282,6 +291,16 @@ public static Scalar ToBFloat16Scalar(this float value)
282291
return new Scalar(THSTorch_bfloat16_to_scalar(value));
283292
}
284293

294+
/// <summary>
/// Builds a Scalar from a BFloat16 value.
/// </summary>
/// <param name="value">The BFloat16 value to wrap.</param>
/// <returns>A new Scalar created from the handle returned by the native bfloat16 conversion.</returns>
public static Scalar ToScalar(this BFloat16 value)
{
    torch.InitializeDeviceType(DeviceType.CPU);

    // The native entry point takes a float, so widen the value first.
    float widened = value.ToSingle();
    var handle = THSTorch_bfloat16_to_scalar(widened);
    return new Scalar(handle);
}
303+
285304
#if NET6_0_OR_GREATER
286305
/// <summary>
287306
/// Explicitly convert a Scalar value to a .NET scalar
@@ -295,6 +314,16 @@ public static Half ToHalf(this Scalar value)
295314
}
296315
#endif
297316

317+
/// <summary>
/// Extracts the content of a Scalar as a BFloat16.
/// </summary>
/// <param name="value">The Scalar to read.</param>
/// <returns>A BFloat16 built from the raw 16-bit pattern produced by the native call.</returns>
public static BFloat16 ToBFloat16(this Scalar value)
{
    ushort rawBits;
    THSTorch_scalar_to_bfloat16(value.Handle, out rawBits);
    return BFloat16.FromRawValue(rawBits);
}
326+
298327
/// <summary>
299328
/// Explicitly convert a Scalar value to a .NET scalar
300329
/// </summary>

src/TorchSharp/Tensor/Factories/Tensor.Factories.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ public static Tensor frombuffer(Array rawArray, ScalarType dtype, long count = -
370370

371371
switch (origType) {
372372
case ScalarType.Int16:
373+
case ScalarType.Float16:
374+
case ScalarType.BFloat16:
373375
offset *= 2;
374376
break;
375377
case ScalarType.Int32:

src/TorchSharp/Tensor/Factories/as_tensor.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,5 +121,15 @@ public static Tensor as_tensor(System.Numerics.Complex[] rawArray, torch.ScalarT
121121
{
122122
return torch.from_array(rawArray, dtype, device);
123123
}
124+
125+
/// <summary>
/// Creates a tensor from a list of BFloat16 values.
/// The list is first materialized into an array via ToArray(), then handed to torch.from_array.
/// </summary>
/// <param name="rawArray">The source values.</param>
/// <param name="dtype">Optional element type, forwarded to torch.from_array.</param>
/// <param name="device">Optional device, forwarded to torch.from_array.</param>
public static Tensor as_tensor(IList<BFloat16> rawArray, torch.ScalarType? dtype = null, torch.Device? device = null)
{
    BFloat16[] materialized = rawArray.ToArray();
    return torch.from_array(materialized, dtype, device);
}
129+
130+
/// <summary>
/// Creates a tensor from an array of BFloat16 values by forwarding to torch.from_array.
/// </summary>
/// <param name="rawArray">The source values.</param>
/// <param name="dtype">Optional element type, forwarded to torch.from_array.</param>
/// <param name="device">Optional device, forwarded to torch.from_array.</param>
public static Tensor as_tensor(BFloat16[] rawArray, torch.ScalarType? dtype = null, torch.Device? device = null)
    => torch.from_array(rawArray, dtype, device);
124134
}
125135
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Diagnostics.Contracts;
5+
using System.Linq;
6+
7+
#nullable enable
8+
namespace TorchSharp
9+
{
10+
public static partial class torch
{
    // BFloat16 overloads of torch.tensor(). Every overload ends up in _tensor_generic with
    // ScalarType.BFloat16 as the source element type; dtype, device, requires_grad and names
    // are forwarded unchanged.

    /// <summary>
    /// Creates a tensor from a list of BFloat16 values, using the supplied shape.
    /// The list is materialized with ToArray() before being handed to _tensor_generic.
    /// </summary>
    /// <param name="rawArray">The source values.</param>
    /// <param name="dimensions">The shape of the resulting tensor.</param>
    /// <param name="dtype">Optional element type, forwarded as-is.</param>
    /// <param name="device">Optional device, forwarded as-is.</param>
    /// <param name="requires_grad">Forwarded as-is.</param>
    /// <param name="names">Optional dimension names, forwarded as-is.</param>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray.ToArray(), dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, false, names);

    /// <summary>
    /// Creates a one-dimensional tensor from an array of BFloat16 values; the shape is the array length.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, stackalloc long[] { rawArray.LongLength }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a tensor from an array of BFloat16 values, using the supplied shape.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[] rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a one-dimensional tensor from a list of BFloat16 values; the shape is the list's Count.
    /// </summary>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => tensor(rawArray, stackalloc long[] { (long)rawArray.Count }, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a two-dimensional tensor of shape (rows, columns) from a list of BFloat16 values.
    /// </summary>
    /// <remarks>
    /// The input array must have rows * columns elements.
    /// </remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, long rows, long columns, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => tensor(rawArray, stackalloc long[] { rows, columns }, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a three-dimensional tensor of shape (dim0, dim1, dim2) from a list of BFloat16 values.
    /// </summary>
    /// <remarks>
    /// The input array must have dim0*dim1*dim2 elements.
    /// </remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => tensor(rawArray, stackalloc long[] { dim0, dim1, dim2 }, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a four-dimensional tensor of shape (dim0, dim1, dim2, dim3) from a list of BFloat16 values.
    /// </summary>
    /// <remarks>
    /// The input array must have dim0*dim1*dim2*dim3 elements.
    /// </remarks>
    [Pure]
    public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, long dim3, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => tensor(rawArray, stackalloc long[] { dim0, dim1, dim2, dim3 }, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a two-dimensional tensor whose shape is taken from a two-dimensional BFloat16 array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a three-dimensional tensor whose shape is taken from a three-dimensional BFloat16 array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a four-dimensional tensor whose shape is taken from a four-dimensional BFloat16 array.
    /// </summary>
    [Pure]
    public static Tensor tensor(BFloat16[,,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

    /// <summary>
    /// Creates a tensor from a Memory&lt;BFloat16&gt; block, using the supplied shape.
    /// </summary>
    [Pure]
    public static Tensor tensor(Memory<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
        => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
}
125+
}

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,9 @@ internal void ValidateType(Type dotnetType)
475475
throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
476476
break;
477477
case ScalarType.BFloat16:
478-
throw new ArgumentException($"No support for {dtype.ToString()} in TorchSharp");
478+
if (dotnetType != typeof(BFloat16))
479+
throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
480+
break;
479481
case ScalarType.Float16:
480482
#if NET6_0_OR_GREATER
481483
if (dotnetType != typeof(Half))
@@ -7471,6 +7473,7 @@ public enum ScalarType : sbyte
74717473
{ typeof(short), ScalarType.Int16 },
74727474
{ typeof(int), ScalarType.Int32 },
74737475
{ typeof(long), ScalarType.Int64 },
7476+
{ typeof(BFloat16), ScalarType.BFloat16 },
74747477
#if NET6_0_OR_GREATER
74757478
{ typeof(Half), ScalarType.Float16 },
74767479
#endif

src/TorchSharp/Tensor/TensorExtensionMethods.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,8 @@ true when typeof(T) == typeof(float) => tensor((array as float[])!, dimensions,
578578
requires_grad: requires_grad),
579579
true when typeof(T) == typeof(bool) => tensor((array as bool[])!, dimensions,
580580
requires_grad: requires_grad),
581+
true when typeof(T) == typeof(BFloat16) => tensor((array as BFloat16[])!, dimensions,
582+
requires_grad: requires_grad),
581583
_ => throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.")
582584
};
583585
}
@@ -592,7 +594,8 @@ true when typeof(T) == typeof(bool) => tensor((array as bool[])!, dimensions,
592594
/// <returns></returns>
593595
public static Tensor ToTensor<T>(this T scalar, Device? device = null, bool requires_grad = false) where T : struct
594596
{
595-
if (requires_grad && typeof(T) != typeof(float) && typeof(T) != typeof(double)) {
597+
if (requires_grad && typeof(T) != typeof(float) && typeof(T) != typeof(double)
598+
&& typeof(T) != typeof(BFloat16)) {
596599
throw new ArgumentException(nameof(requires_grad), "Only floating point types support gradients.");
597600
}
598601

@@ -610,6 +613,8 @@ public static Tensor ToTensor<T>(this T scalar, Device? device = null, bool requ
610613
return tensor((float)(object)scalar, float32, device, requires_grad);
611614
if (typeof(T) == typeof(double))
612615
return tensor((double)(object)scalar, float64, device, requires_grad);
616+
if (typeof(T) == typeof(BFloat16))
617+
return tensor(new BFloat16[] { (BFloat16)(object)scalar }, bfloat16, device, requires_grad);
613618
throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
614619
}
615620

@@ -633,6 +638,12 @@ public static Tensor ToTensor<T>(this T scalar, Device? device = null, bool requ
633638
public static Half ToHalf(this Tensor value) => value.ToScalar().ToHalf();
634639
#endif
635640

641+
/// <summary>
/// Reads a singleton tensor as a BFloat16 value by going through its Scalar representation.
/// </summary>
/// <param name="value">The input tensor</param>
public static BFloat16 ToBFloat16(this Tensor value)
{
    var scalar = value.ToScalar();
    return scalar.ToBFloat16();
}
646+
636647
/// <summary>
637648
/// Explicitly convert a singleton tensor to a .NET scalar value.
638649
/// </summary>

0 commit comments

Comments
 (0)