Skip to content
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -39,13 +39,12 @@ namespace Microsoft.Data.SqlTypes;

private SqlVector(int length)
{
if (length < 0)
(_elementType, _elementSize, int maxElements) = GetTypeFieldsOrThrow();
if (length < 0 || length > maxElements)
{
throw ADP.InvalidArraySize(nameof(length));
}

(_elementType, _elementSize) = GetTypeFieldsOrThrow();

IsNull = true;

Length = length;
Expand All @@ -61,7 +60,11 @@ private SqlVector(int length)
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml' path='docs/members[@name="SqlVector"]/ctor1/*' />
public SqlVector(ReadOnlyMemory<T> memory)
{
(_elementType, _elementSize) = GetTypeFieldsOrThrow();
(_elementType, _elementSize, int maxElements) = GetTypeFieldsOrThrow();
if (memory.Length > maxElements)
{
throw ADP.InvalidArraySize(nameof(memory));
}

IsNull = false;

Expand All @@ -74,7 +77,7 @@ public SqlVector(ReadOnlyMemory<T> memory)

internal SqlVector(byte[] tdsBytes)
{
(_elementType, _elementSize) = GetTypeFieldsOrThrow();
(_elementType, _elementSize, _) = GetTypeFieldsOrThrow();

(Length, _size) = GetCountsOrThrow(tdsBytes);

Expand Down Expand Up @@ -125,10 +128,11 @@ internal string GetString()

#region Helpers

private (byte, byte) GetTypeFieldsOrThrow()
private static (byte, byte, int) GetTypeFieldsOrThrow()
{
byte elementType;
byte elementSize;
int maxSize;

if (typeof(T) == typeof(float))
{
Expand All @@ -139,12 +143,17 @@ internal string GetString()
{
throw SQL.VectorTypeNotSupported(typeof(T).FullName);
}
// The size of a vector (including its header) must not exceed the maximum size of a TDS packet.
// Calculate the maximum number of elements to simplify the validation of input sizes in constructors.
maxSize = (TdsEnums.MAXSIZE - TdsEnums.VECTOR_HEADER_SIZE) / elementSize;

return (elementType, elementSize);
return (elementType, elementSize, maxSize);
}

private byte[] MakeTdsBytes(ReadOnlyMemory<T> values)
{
Debug.Assert(Length <= ushort.MaxValue);

//Refer to TDS section 2.2.5.5.7 for vector header format
// +------------------------+-----------------+----------------------+------------------+----------------------------+--------------+
// | Field | Size (bytes) | Example Value | Description |
Expand All @@ -158,32 +167,42 @@ private byte[] MakeTdsBytes(ReadOnlyMemory<T> values)
// +------------------------+-----------------+----------------------+--------------------------------------------------------------+

byte[] result = new byte[_size];
ReadOnlySpan<T> valueSpan = values.Span;

// Header Bytes
result[0] = VecHeaderMagicNo;
result[1] = VecVersionNo;
result[2] = (byte)(Length & 0xFF);
result[3] = (byte)((Length >> 8) & 0xFF);
BinaryPrimitives.WriteUInt16LittleEndian(result.AsSpan(2), (ushort)Length);
Comment thread
apoorvdeshmukh marked this conversation as resolved.
result[4] = _elementType;
result[5] = 0x00;
result[6] = 0x00;
result[7] = 0x00;

#if NETFRAMEWORK
// Copy data via marshaling.
if (MemoryMarshal.TryGetArray(values, out ArraySegment<T> segment))
// If .NET is running on a little-endian architecture, cast directly to a byte array and proceed.
// This optimisation relies upon the base type of the vector transporting values in a format and
// endianness which is identical to the client. This is true for all little-endian clients reading
// float32-based vectors.
Comment thread
apoorvdeshmukh marked this conversation as resolved.
if (BitConverter.IsLittleEndian)
{
Buffer.BlockCopy(segment.Array, segment.Offset * _elementSize, result, TdsEnums.VECTOR_HEADER_SIZE, segment.Count * _elementSize);
ReadOnlySpan<byte> valuesAsBytes = MemoryMarshal.AsBytes(valueSpan);

valuesAsBytes.CopyTo(result.AsSpan(TdsEnums.VECTOR_HEADER_SIZE));
}
else
{
Buffer.BlockCopy(values.ToArray(), 0, result, TdsEnums.VECTOR_HEADER_SIZE, values.Length * _elementSize);
if (typeof(T) == typeof(float))
{
for (int i = 0, currPosition = TdsEnums.VECTOR_HEADER_SIZE; i < values.Length; i++, currPosition += _elementSize)
{
#if NET
BinaryPrimitives.WriteSingleLittleEndian(result.AsSpan(currPosition), (float)(object)valueSpan[i]);
#else
BinaryPrimitives.WriteInt32LittleEndian(result.AsSpan(currPosition), BitConverterCompatible.SingleToInt32Bits((float)(object)valueSpan[i]));
#endif
}
}
}
#else
// Fast span-based copy.
var byteSpan = MemoryMarshal.AsBytes(values.Span);
byteSpan.CopyTo(result.AsSpan(TdsEnums.VECTOR_HEADER_SIZE));
#endif

return result;
}

Expand Down Expand Up @@ -227,15 +246,32 @@ private T[] MakeArray()
return Array.Empty<T>();
}

#if NETFRAMEWORK
// Allocate array and copy bytes into it
T[] result = new T[Length];
Buffer.BlockCopy(_tdsBytes, 8, result, 0, _elementSize * Length);

// See the comment in MakeTdsBytes for more information on this optimisation.
if (BitConverter.IsLittleEndian)
{
Span<byte> valuesAsBytes = MemoryMarshal.AsBytes(result.AsSpan());

_tdsBytes.AsSpan(TdsEnums.VECTOR_HEADER_SIZE).CopyTo(valuesAsBytes);
}
else
{
if (typeof(T) == typeof(float))
{
for (int i = 0, currPosition = TdsEnums.VECTOR_HEADER_SIZE; i < Length; i++, currPosition += _elementSize)
{
#if NET
result[i] = (T)(object)BinaryPrimitives.ReadSingleLittleEndian(_tdsBytes.AsSpan(currPosition));
#else
result[i] = (T)(object)BitConverterCompatible.Int32BitsToSingle(BinaryPrimitives.ReadInt32LittleEndian(_tdsBytes.AsSpan(currPosition)));
#endif
}
}
}
Comment thread
paulmedynski marked this conversation as resolved.

return result;
#else
ReadOnlySpan<byte> dataSpan = _tdsBytes.AsSpan(8, _elementSize * Length);
return MemoryMarshal.Cast<byte, T>(dataSpan).ToArray();
#endif
}

#endregion
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -28,6 +28,13 @@ public void Construct_Length_Negative()
Assert.Throws<ArgumentOutOfRangeException>(() => SqlVector<float>.CreateNull(-1));
}

[Fact]
public void Construct_Length_Exceeds_8000()
{
Assert.Throws<ArgumentOutOfRangeException>(() => SqlVector<float>.CreateNull(1999));
Comment thread
apoorvdeshmukh marked this conversation as resolved.
Assert.Throws<ArgumentOutOfRangeException>(() => SqlVector<float>.CreateNull(int.MaxValue / 2));
}

[Fact]
public void Construct_Length()
{
Expand Down
Loading