Commit 54b5302

Add QInt8, QUInt8, QInt32 quantized scalar types with full quantization support
- Uncomment and fix ScalarType enum entries (QInt8 = 12, QUInt8 = 13, QInt32 = 14)
- Fix incorrect QUInt32 name to QInt32 to match PyTorch
- Add torch.qint8, torch.quint8, torch.qint32 dtype aliases
- Add torch.is_quantized() and Tensor.is_quantized() methods
- Add IsQuantized() extension method on ScalarType

Native quantization bindings:
- Add THSTensor_quantize_per_tensor, THSTensor_quantize_per_channel (C++ and P/Invoke)
- Add THSTensor_dequantize (C++ and P/Invoke)
- Add THSTensor_q_scale, THSTensor_q_zero_point (C++ and P/Invoke)
- Add THSTensor_int_repr (C++ and P/Invoke)
- Add THSTensor_q_per_channel_scales, THSTensor_q_per_channel_zero_points, THSTensor_q_per_channel_axis (C++ and P/Invoke)

Managed API:
- Add torch.quantize_per_tensor() and torch.quantize_per_channel() static methods
- Add torch.dequantize() static method
- Add Tensor.dequantize(), Tensor.q_scale(), Tensor.q_zero_point() instance methods
- Add Tensor.int_repr() instance method
- Add Tensor.q_per_channel_scales(), Tensor.q_per_channel_zero_points(), Tensor.q_per_channel_axis() instance methods

Unit tests for all new functionality (13 tests)
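For orientation, a minimal usage sketch of the new per-tensor path (not part of the diff; the tensor values and printed results are illustrative, and this assumes an ordinary TorchSharp console program):

using System;
using TorchSharp;

// Quantize a float tensor with a single scale/zero-point pair:
// stored integer = round(x / scale) + zero_point.
var x = torch.tensor(new float[] { 0.0f, 0.5f, 1.0f, 1.5f });
var q = torch.quantize_per_tensor(x, 0.5, 0, torch.qint8);

Console.WriteLine(q.is_quantized());   // True
Console.WriteLine(q.q_scale());        // 0.5
Console.WriteLine(q.q_zero_point());   // 0

var ints = q.int_repr();               // underlying 8-bit values: 0, 1, 2, 3
var back = torch.dequantize(q);        // back to float: 0.0, 0.5, 1.0, 1.5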
1 parent 80291c7 commit 54b5302

7 files changed

Lines changed: 468 additions & 3 deletions


src/Native/LibTorchSharp/THSTensor.cpp

Lines changed: 45 additions & 0 deletions
@@ -2278,3 +2278,48 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_
 
     return nullptr;
 }
+
+Tensor THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type)
+{
+    CATCH_TENSOR(torch::quantize_per_tensor(*tensor, scale, zero_point, at::ScalarType(scalar_type)));
+}
+
+Tensor THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type)
+{
+    CATCH_TENSOR(torch::quantize_per_channel(*tensor, *scales, *zero_points, axis, at::ScalarType(scalar_type)));
+}
+
+Tensor THSTensor_dequantize(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->dequantize());
+}
+
+double THSTensor_q_scale(const Tensor tensor)
+{
+    CATCH_RETURN(double, 0.0, tensor->q_scale());
+}
+
+int64_t THSTensor_q_zero_point(const Tensor tensor)
+{
+    CATCH_RETURN(int64_t, 0, tensor->q_zero_point());
+}
+
+Tensor THSTensor_int_repr(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->int_repr());
+}
+
+Tensor THSTensor_q_per_channel_scales(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->q_per_channel_scales());
+}
+
+Tensor THSTensor_q_per_channel_zero_points(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->q_per_channel_zero_points());
+}
+
+int64_t THSTensor_q_per_channel_axis(const Tensor tensor)
+{
+    CATCH_RETURN(int64_t, 0, tensor->q_per_channel_axis());
+}

src/Native/LibTorchSharp/THSTensor.h

Lines changed: 12 additions & 0 deletions
@@ -1790,3 +1790,15 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou
 
 EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
 EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);
+
+// Quantization Ops
+
+EXPORT_API(Tensor) THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type);
+EXPORT_API(Tensor) THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type);
+EXPORT_API(Tensor) THSTensor_dequantize(const Tensor tensor);
+EXPORT_API(double) THSTensor_q_scale(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_q_zero_point(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_int_repr(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_q_per_channel_scales(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_q_per_channel_zero_points(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_q_per_channel_axis(const Tensor tensor);

src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs

Lines changed: 27 additions & 0 deletions
@@ -2176,6 +2176,33 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
         internal static extern IntPtr THSTensor_histogram_out_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
         [DllImport("LibTorchSharp")]
         internal static extern IntPtr THSTensor_histogram_out_i(IntPtr input, long bins, IntPtr range, int length, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_quantize_per_tensor(IntPtr tensor, double scale, long zero_point, sbyte scalar_type);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_quantize_per_channel(IntPtr tensor, IntPtr scales, IntPtr zero_points, long axis, sbyte scalar_type);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_dequantize(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern double THSTensor_q_scale(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern long THSTensor_q_zero_point(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_int_repr(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_q_per_channel_scales(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_q_per_channel_zero_points(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern long THSTensor_q_per_channel_axis(IntPtr tensor);
     }
 #pragma warning restore CA2101
 }

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 108 additions & 3 deletions
@@ -271,6 +271,95 @@ internal IntPtr MoveHandle()
         /// </summary>
         public bool is_complex() => torch.is_complex(dtype);
 
+        /// <summary>
+        /// Returns True if the data type of input is a quantized data type i.e., one of torch.qint8, torch.quint8, and torch.qint32.
+        /// </summary>
+        public bool is_quantized() => torch.is_quantized(dtype);
+
+        /// <summary>
+        /// Given a quantized Tensor, returns a dequantized (float) Tensor.
+        /// </summary>
+        public Tensor dequantize()
+        {
+            var res = NativeMethods.THSTensor_dequantize(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor, returns the scale of the quantization as a double.
+        /// </summary>
+        public double q_scale()
+        {
+            var res = NativeMethods.THSTensor_q_scale(Handle);
+            CheckForErrors();
+            return res;
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor, returns the zero_point of the quantization as a long.
+        /// </summary>
+        public long q_zero_point()
+        {
+            var res = NativeMethods.THSTensor_q_zero_point(Handle);
+            CheckForErrors();
+            return res;
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor, returns a Tensor of the underlying integer representation.
+        /// </summary>
+        public Tensor int_repr()
+        {
+            var res = NativeMethods.THSTensor_int_repr(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor quantized per channel, returns a Tensor of the scales of the quantization for each channel.
+        /// </summary>
+        public Tensor q_per_channel_scales()
+        {
+            var res = NativeMethods.THSTensor_q_per_channel_scales(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor quantized per channel, returns a Tensor of the zero points of the quantization for each channel.
+        /// </summary>
+        public Tensor q_per_channel_zero_points()
+        {
+            var res = NativeMethods.THSTensor_q_per_channel_zero_points(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor quantized per channel, returns the axis along which per channel quantization is applied.
+        /// </summary>
+        public long q_per_channel_axis()
+        {
+            var res = NativeMethods.THSTensor_q_per_channel_axis(Handle);
+            CheckForErrors();
+            return res;
+        }
+
+        internal Tensor _quantize_per_tensor(double scale, long zero_point, ScalarType dtype)
+        {
+            var res = NativeMethods.THSTensor_quantize_per_tensor(Handle, scale, zero_point, (sbyte)dtype);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        internal Tensor _quantize_per_channel(Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
+        {
+            var res = NativeMethods.THSTensor_quantize_per_channel(Handle, scales.Handle, zero_points.Handle, axis, (sbyte)dtype);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
         /// <summary>
         /// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
         /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).

@@ -7279,9 +7368,9 @@ public enum ScalarType : sbyte
         ComplexFloat32 = 9,
         ComplexFloat64 = 10,
         Bool = 11,
-        //QInt8 = 12,
-        //QUInt8 = 13,
-        //QUInt32 = 14,
+        QInt8 = 12,
+        QUInt8 = 13,
+        QInt32 = 14,
         BFloat16 = 15
     }
 

@@ -7413,6 +7502,18 @@ public static bool is_complex(ScalarType type)
             }
         }
 
+        public static bool is_quantized(ScalarType type)
+        {
+            switch (type) {
+                case ScalarType.QInt8:
+                case ScalarType.QUInt8:
+                case ScalarType.QInt32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+
         public static long max_int_value(ScalarType type)
         {
             switch (type) {

@@ -7463,6 +7564,10 @@ public static long max_int_value(ScalarType type)
         public static ScalarType cfloat = ScalarType.ComplexFloat32;
         public static ScalarType cdouble = ScalarType.ComplexFloat64;
 
+        public static ScalarType qint8 = ScalarType.QInt8;
+        public static ScalarType quint8 = ScalarType.QUInt8;
+        public static ScalarType qint32 = ScalarType.QInt32;
+
         /// <summary>
         /// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
         /// be automatically disposed once the dispose scope is disposed.
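To make the enum change concrete, a small hedged sketch of the re-enabled type codes and the new aliases (torch.float32 is assumed to be an existing alias following the same pattern as torch.cfloat above):

using System;
using TorchSharp;

// The aliases map straight onto the re-enabled enum members (PyTorch codes 12-14).
Console.WriteLine((sbyte)torch.qint8);   // 12
Console.WriteLine((sbyte)torch.quint8);  // 13
Console.WriteLine((sbyte)torch.qint32);  // 14

Console.WriteLine(torch.is_quantized(torch.qint8));    // True
Console.WriteLine(torch.is_quantized(torch.float32));  // False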

src/TorchSharp/Tensor/TensorExtensionMethods.cs

Lines changed: 17 additions & 0 deletions
@@ -368,6 +368,23 @@ internal static bool IsComplex(this ScalarType type)
             }
         }
 
+        /// <summary>
+        /// Indicates whether a given element type is quantized.
+        /// </summary>
+        /// <param name="type">The input type.</param>
+        /// <returns></returns>
+        internal static bool IsQuantized(this ScalarType type)
+        {
+            switch (type) {
+                case ScalarType.QInt8:
+                case ScalarType.QUInt8:
+                case ScalarType.QInt32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+
         /// <summary>
         /// Save the tensor in a .NET-specific format.
         /// </summary>
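Since IsQuantized() is internal, it is only reachable from inside the TorchSharp assembly; a hedged sketch of the kind of call site it enables (this helper is hypothetical and not part of the commit):

using System;

// Hypothetical internal guard, written as it might appear inside TorchSharp,
// where the IsQuantized extension method is in scope.
internal static class QuantGuards
{
    internal static void ThrowIfQuantized(torch.ScalarType type, string paramName)
    {
        if (type.IsQuantized())
            throw new ArgumentException("Quantized element types are not supported here.", paramName);
    }
}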

src/TorchSharp/Tensor/torch.PointwiseOps.cs

Lines changed: 41 additions & 0 deletions
@@ -761,6 +761,47 @@ public static Tensor fake_quantize_per_channel_affine(Tensor input, Tensor scale
         public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, Tensor zero_point, long quant_min, long quant_max)
             => throw new NotImplementedException();
 
+        // https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor
+        /// <summary>
+        /// Converts a float tensor to a quantized tensor with given scale and zero point.
+        /// </summary>
+        /// <param name="input">Float tensor to quantize</param>
+        /// <param name="scale">Scale to apply in quantization formula</param>
+        /// <param name="zero_point">Offset in integer value that maps to float zero</param>
+        /// <param name="dtype">The desired data type of returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
+        /// <returns>A newly quantized tensor</returns>
+        public static Tensor quantize_per_tensor(Tensor input, double scale, long zero_point, ScalarType dtype)
+        {
+            if (!is_quantized(dtype))
+                throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
+            return input._quantize_per_tensor(scale, zero_point, dtype);
+        }
+
+        // https://pytorch.org/docs/stable/generated/torch.quantize_per_channel
+        /// <summary>
+        /// Converts a float tensor to a per-channel quantized tensor with given scales and zero points.
+        /// </summary>
+        /// <param name="input">Float tensor to quantize</param>
+        /// <param name="scales">Float 1D tensor of scales to use, size should match input.size(axis)</param>
+        /// <param name="zero_points">Integer 1D tensor of offsets to use, size should match input.size(axis)</param>
+        /// <param name="axis">Dimension on which to apply per-channel quantization</param>
+        /// <param name="dtype">The desired data type of returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
+        /// <returns>A newly quantized tensor</returns>
+        public static Tensor quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
+        {
+            if (!is_quantized(dtype))
+                throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
+            return input._quantize_per_channel(scales, zero_points, axis, dtype);
+        }
+
+        // https://pytorch.org/docs/stable/generated/torch.dequantize
+        /// <summary>
+        /// Returns an fp32 Tensor by dequantizing a quantized Tensor.
+        /// </summary>
+        /// <param name="input">A quantized tensor</param>
+        /// <returns>A dequantized (float) tensor</returns>
+        public static Tensor dequantize(Tensor input) => input.dequantize();
+
         // https://pytorch.org/docs/stable/generated/torch.fix
         /// <summary>
         /// Returns a new tensor with the truncated integer values of the elements of input.
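And a hedged sketch of the per-channel path added above (values illustrative; a 2×2 tensor quantized along dim 0, with one scale/zero-point pair per row):

using System;
using TorchSharp;

var w = torch.tensor(new float[] { -1.0f, 0.0f, 0.0f, 1.0f }).reshape(2, 2);
var scales = torch.tensor(new double[] { 0.1, 0.01 });       // one scale per row
var zero_points = torch.tensor(new long[] { 0, 0 });         // one offset per row

var qw = torch.quantize_per_channel(w, scales, zero_points, 0, torch.qint8);

Console.WriteLine(qw.q_per_channel_axis());      // 0
var chScales = qw.q_per_channel_scales();        // tensor holding 0.1, 0.01
var chZeros = qw.q_per_channel_zero_points();    // tensor holding 0, 0

var roundTrip = qw.dequantize();                 // approximately w again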
