diff --git a/src/TorchSharp/Tensor/Storage.cs b/src/TorchSharp/Tensor/Storage.cs index 35515c054..7ee1488e2 100644 --- a/src/TorchSharp/Tensor/Storage.cs +++ b/src/TorchSharp/Tensor/Storage.cs @@ -138,8 +138,30 @@ protected static Tensor CreateTypedTensor(ScalarType dtype, IList rawArray /// /// A pointer to the raw data in memory. /// - /// - protected IntPtr data_ptr() + /// + /// + /// The returned pointer refers to the underlying storage buffer of the associated tensor. Depending on the + /// device of the tensor, this pointer may reference CPU memory or device-specific memory (for example, CUDA + /// device memory). Callers must not assume that the pointer is always CPU-accessible. + /// + /// + /// The lifetime of the pointer is strictly tied to the lifetime and layout of the owning tensor/storage. The + /// pointer becomes invalid if the underlying tensor or storage is moved, reallocated, resized in a way that + /// changes the underlying buffer, or disposed. Code that dereferences the pointer must ensure that the + /// associated tensor/storage instance remains alive and is not modified or disposed for the entire duration + /// of pointer use (for example, by keeping a strong reference and/or using ). + /// + /// + /// This API is intended for advanced interop and unsafe scenarios. The pointer must not be cached and reused + /// across operations that may cause the storage to be reallocated or moved, and callers are responsible for + /// any necessary synchronization with device operations before reading from or writing to the memory. + /// + /// + /// + /// An that points to the first element in the storage buffer. The pointer is only valid + /// while the underlying tensor/storage remains alive and its storage is not reallocated or disposed. + /// + public IntPtr data_ptr() { if (_tensor_data_ptr != IntPtr.Zero) return _tensor_data_ptr; diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index bd75ecbb2..62358791f 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -8078,6 +8078,20 @@ public void Storage_Copy() Assert.Equal(data, x.data().ToArray()); } + [Fact] + [TestOf(nameof(Tensor.storage))] + public void Storage_DataPtr() + { + var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + var x = torch.tensor(data); + + var st = x.storage(); + Assert.NotNull(st); + + var ptr = st.data_ptr(); + Assert.NotEqual(IntPtr.Zero, ptr); + } + [Fact] [TestOf(nameof(torch.stft))] public void Float32STFT()