Skip to content

Commit 70b1a4e

Browse files
Expose Storage.data_ptr() as public API
Change the access modifier of Storage.data_ptr() from protected to public, allowing external consumers to obtain the raw data pointer for interop with other native libraries (e.g. bitsandbytes). Fixes #1454 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent b2a2f20 commit 70b1a4e

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/TorchSharp/Tensor/Storage.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ protected static Tensor CreateTypedTensor<T>(ScalarType dtype, IList<T> rawArray
139139
/// A pointer to the raw data in memory.
140140
/// </summary>
141141
/// <returns></returns>
142-
protected IntPtr data_ptr()
142+
public IntPtr data_ptr()
143143
{
144144
if (_tensor_data_ptr != IntPtr.Zero)
145145
return _tensor_data_ptr;

test/TorchSharpTest/TestTorchTensor.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8078,6 +8078,20 @@ public void Storage_Copy()
80788078
Assert.Equal(data, x.data<long>().ToArray());
80798079
}
80808080

8081+
[Fact]
8082+
[TestOf(nameof(Tensor.storage))]
8083+
public void Storage_DataPtr()
8084+
{
8085+
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
8086+
var x = torch.tensor(data);
8087+
8088+
var st = x.storage<float>();
8089+
Assert.NotNull(st);
8090+
8091+
var ptr = st.data_ptr();
8092+
Assert.NotEqual(IntPtr.Zero, ptr);
8093+
}
8094+
80818095
[Fact]
80828096
[TestOf(nameof(torch.stft))]
80838097
public void Float32STFT()

0 commit comments

Comments
 (0)