Skip to content

Commit cb5013b

Browse files
timmoon10, zhongbozh, ksivaman
authored
[PyTorch] Refactor C++ quantizer infrastructure (NVIDIA#1952)
* remove reciprocal op Signed-off-by: zhongboz <zhongboz@nvidia.com> * Refactor Quantizer::create_tensor function Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug when constructing FP8 tensor Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add quantize function to C++ quantizers Signed-off-by: Tim Moon <tmoon@nvidia.com> * Prototype function to coerce Python quantized tensors to match quantizer Signed-off-by: Tim Moon <tmoon@nvidia.com> * Use quantizer class in tex.quantize Signed-off-by: Tim Moon <tmoon@nvidia.com> * Add FP8 current scaling support for activation backward Signed-off-by: Tim Moon <tmoon@nvidia.com> * Disable quantized GEMM output with FP8 current scaling Signed-off-by: Tim Moon <tmoon@nvidia.com> * Add coerce_tensor functions for MXFP8 and DSv3 Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Avoid quantizing empty tensors Signed-off-by: Tim Moon <tmoon@nvidia.com> * Use consistent shapes for FP8 transposes Signed-off-by: Tim Moon <tmoon@nvidia.com> * In attention impl, construct FP8 tensors with pre-initialized scale-invs Signed-off-by: Tim Moon <tmoon@nvidia.com> * Initialize MXFP8 scales to zero Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Store copy of quantizer when creating quantized tensors Signed-off-by: Tim Moon <tmoon@nvidia.com> * Fix linter warnings Signed-off-by: Tim Moon <tmoon@nvidia.com> * Make sure quantized tensors have private quantizer Avoid problems with in-place ops after quantizer usages are changed externally. 
Signed-off-by: Tim Moon <tmoon@nvidia.com> * Rename "coerce_tensor" to "convert_and_update_tensor" Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make sure CUDA context is available when launching NVRTC kernel Signed-off-by: Tim Moon <tmoon@nvidia.com> * Expose CUDA context creation function externally Signed-off-by: Tim Moon <tmoon@nvidia.com> --------- Signed-off-by: zhongboz <zhongboz@nvidia.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: zhongboz <zhongboz@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 5a495a3 commit cb5013b

20 files changed

+864
-376
lines changed

tests/pytorch/test_fusible_ops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,10 +837,9 @@ def _test_basic_linear(
837837
pytest.skip("FP8 output is only supported with FP8 GEMMs")
838838
if quantized_grad_input and not quantized_compute:
839839
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
840-
if quantization == "mxfp8" and quantized_output:
841-
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
842-
if quantization == "mxfp8" and quantized_grad_input:
843-
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
840+
if quantization not in (None, "fp8"):
841+
if quantized_output or quantized_grad_input:
842+
pytest.skip("Recipe does not support quantized GEMM output")
844843

845844
# Random data
846845
x_ref, x_test = make_reference_and_test_tensors(

transformer_engine/common/libtransformer_engine.version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
transformer_engine::cuda::stream_priority_range*;
99
transformer_engine::cuda::current_device*;
1010
transformer_engine::cuda_driver::get_symbol*;
11+
transformer_engine::cuda_driver::ensure_context_exists*;
1112
transformer_engine::ubuf_built_with_mpi*;
1213
*transformer_engine::rtc*;
1314
transformer_engine::nvte_cudnn_handle_init*;

transformer_engine/common/util/cuda_driver.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) {
4444
return entry_point;
4545
}
4646

47+
void ensure_context_exists() {
48+
CUcontext context;
49+
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
50+
if (context == nullptr) {
51+
// Add primary context to context stack
52+
CUdevice device;
53+
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
54+
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
55+
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
56+
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
57+
}
58+
}
59+
4760
} // namespace cuda_driver
4861

4962
} // namespace transformer_engine

transformer_engine/common/util/cuda_driver.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) {
3939
return (*func)(args...);
4040
}
4141

42+
/*! \brief Ensure that the calling thread has a CUDA context
43+
*
44+
* Each thread maintains a stack of CUDA contexts. If the calling
45+
* thread has an empty stack, the primary context is added to the
46+
* stack.
47+
*/
48+
void ensure_context_exists();
49+
4250
} // namespace cuda_driver
4351

4452
} // namespace transformer_engine

transformer_engine/common/util/rtc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class Kernel {
5959
template <typename... ArgTs>
6060
void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
6161
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
62+
cuda_driver::ensure_context_exists();
6263
void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
6364
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
6465
grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,

transformer_engine/pytorch/csrc/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
namespace transformer_engine::pytorch {
1414

15-
std::vector<size_t> getTensorShape(at::Tensor t) {
15+
std::vector<size_t> getTensorShape(const at::Tensor& t) {
1616
std::vector<size_t> shape;
1717
for (auto s : t.sizes()) {
1818
shape.push_back(s);

transformer_engine/pytorch/csrc/common.h

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,21 @@ class Quantizer {
9898

9999
virtual void set_quantization_params(TensorWrapper* tensor) const = 0;
100100

101-
virtual std::pair<TensorWrapper, py::object> create_tensor(
102-
const std::vector<size_t>& shape, DType dtype,
103-
std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
101+
/*! @brief Construct a tensor with uninitialized data */
102+
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
103+
DType dtype) const = 0;
104+
105+
/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
106+
*
107+
* The PyTorch tensor's attributes are modified to match the
108+
* quantizer's configuration.
109+
*/
110+
virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
111+
py::object tensor) const = 0;
112+
113+
/*! @brief Convert to a quantized data format */
114+
virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
115+
const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;
104116

105117
virtual ~Quantizer() = default;
106118

@@ -121,9 +133,17 @@ class NoneQuantizer : public Quantizer {
121133

122134
void set_quantization_params(TensorWrapper* tensor) const override {}
123135

124-
std::pair<TensorWrapper, py::object> create_tensor(
125-
const std::vector<size_t>& shape, DType dtype,
126-
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
136+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
137+
DType dtype) const override;
138+
139+
/*! @brief Construct a tensor with pre-initialized data */
140+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
141+
at::Tensor data) const;
142+
143+
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
144+
145+
void quantize(const TensorWrapper& input, TensorWrapper& out,
146+
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
127147
};
128148

129149
class Float8Quantizer : public Quantizer {
@@ -139,9 +159,19 @@ class Float8Quantizer : public Quantizer {
139159

140160
void set_quantization_params(TensorWrapper* tensor) const override;
141161

142-
std::pair<TensorWrapper, py::object> create_tensor(
143-
const std::vector<size_t>& shape, DType dtype,
144-
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
162+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
163+
DType dtype) const override;
164+
165+
/*! @brief Construct a tensor with pre-initialized data */
166+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
167+
std::optional<at::Tensor> data,
168+
std::optional<at::Tensor> transpose,
169+
std::optional<at::Tensor> scale_inv) const;
170+
171+
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
172+
173+
void quantize(const TensorWrapper& input, TensorWrapper& out,
174+
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
145175
};
146176

147177
class Float8CurrentScalingQuantizer : public Quantizer {
@@ -161,9 +191,13 @@ class Float8CurrentScalingQuantizer : public Quantizer {
161191

162192
void set_quantization_params(TensorWrapper* tensor) const override;
163193

164-
std::pair<TensorWrapper, py::object> create_tensor(
165-
const std::vector<size_t>& shape, DType dtype,
166-
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
194+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
195+
DType dtype) const override;
196+
197+
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
198+
199+
void quantize(const TensorWrapper& input, TensorWrapper& out,
200+
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
167201
};
168202

169203
class Float8BlockQuantizer : public Quantizer {
@@ -195,9 +229,13 @@ class Float8BlockQuantizer : public Quantizer {
195229
// Create a python Float8BlockQuantized tensor and C++ wrapper
196230
// for the tensor. Should set quantized data, scales for rowwise
197231
// and optionally columnwise usage.
198-
std::pair<TensorWrapper, py::object> create_tensor(
199-
const std::vector<size_t>& shape, DType dtype,
200-
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
232+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
233+
DType dtype) const override;
234+
235+
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
236+
237+
void quantize(const TensorWrapper& input, TensorWrapper& out,
238+
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
201239

202240
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
203241
};
@@ -212,16 +250,20 @@ class MXFP8Quantizer : public Quantizer {
212250

213251
void set_quantization_params(TensorWrapper* tensor) const override;
214252

215-
std::pair<TensorWrapper, py::object> create_tensor(
216-
const std::vector<size_t>& shape, DType dtype,
217-
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
253+
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
254+
DType dtype) const override;
255+
256+
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
257+
258+
void quantize(const TensorWrapper& input, TensorWrapper& out,
259+
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
218260

219261
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
220262
};
221263

222264
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
223265

224-
std::vector<size_t> getTensorShape(at::Tensor t);
266+
std::vector<size_t> getTensorShape(const at::Tensor& t);
225267

226268
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
227269
const std::string& fp8_recipe);

transformer_engine/pytorch/csrc/extensions/activation.cpp

Lines changed: 57 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,87 +13,74 @@ namespace transformer_engine::pytorch {
1313
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
1414
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
1515
init_extension();
16-
auto my_quantizer = convert_quantizer(quantizer);
17-
auto input_tensor = input.contiguous();
18-
19-
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
20-
const auto& te_input_shape = te_input.shape();
21-
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
22-
input_shape[input_shape.size() - 1] /= shape_divisor;
23-
auto fake_tensor_type = input.scalar_type();
24-
25-
auto [te_output, out] =
26-
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
27-
28-
// for current scaling, we need to compute amax first and then quantize
29-
// because cache cannot fit in the entire tensor to compute amax and quantize
30-
// the quantizer should not need amax reduction, no process group needed here
31-
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
32-
// activation function might change the input data range, we need to first call the activation function
33-
// and then find the amax and scale of that and then do the quantization
34-
// get a NoneQuantizer to calculate amax of activation output
35-
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
36-
auto [te_output_act, out_act] =
37-
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
38-
39-
NVTE_SCOPED_GIL_RELEASE({
40-
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
41-
// use te_output_act as input to the compute amax and find the amax of activated tensor
42-
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
43-
});
4416

45-
// my_quantizer here has to be a Float8CurrentScalingQuantizer
46-
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
47-
if (my_quantizer_cs->with_amax_reduction) {
48-
NVTE_ERROR(
49-
"per-tensor current scaling amax reduction is not supported in activation functions.");
50-
}
51-
QuantizationConfigWrapper quant_config;
52-
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
53-
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
54-
55-
NVTE_SCOPED_GIL_RELEASE({
56-
nvte_compute_scale_from_amax(te_output.data(), quant_config,
57-
at::cuda::getCurrentCUDAStream());
58-
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
59-
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
60-
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
61-
at::cuda::getCurrentCUDAStream());
62-
});
63-
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
64-
// sanity check, since activation fusion is not supported for blockwise quantization yet
65-
// need to raise an error here instead of silently going into act_func with wrong numerics
66-
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
17+
// Input tensor
18+
auto input_tensor = input.contiguous();
19+
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
20+
21+
// Construct output tensor
22+
auto quantizer_cpp = convert_quantizer(quantizer);
23+
const auto input_shape = input_cpp.shape();
24+
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
25+
output_shape.back() /= shape_divisor;
26+
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
27+
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
28+
29+
// Compute activation
30+
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
31+
detail::IsMXFP8Quantizers(quantizer.ptr())) {
32+
// Compute activation directly
33+
NVTE_SCOPED_GIL_RELEASE(
34+
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
6735
} else {
36+
// Compute activation in high-precision, then quantize
37+
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
6838
NVTE_SCOPED_GIL_RELEASE(
69-
{ act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
39+
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
40+
quantizer_cpp->quantize(temp_cpp, out_cpp);
7041
}
7142

72-
return out;
43+
return out_py;
7344
}
7445

75-
template <void (*act_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
76-
py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
46+
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
47+
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
7748
py::handle quantizer) {
7849
init_extension();
79-
auto my_quantizer = convert_quantizer(quantizer);
80-
auto input_tensor = input.contiguous();
81-
auto grad_tensor = grad.contiguous();
82-
83-
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
84-
const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor);
85-
const auto& te_input_shape = te_input.shape();
86-
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
87-
auto fake_tensor_type = input.scalar_type();
88-
89-
auto [te_output, out] =
90-
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
9150

92-
NVTE_SCOPED_GIL_RELEASE({
93-
act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
94-
});
51+
// Grad output and input tensors
52+
auto grad_output_tensor = grad_output.contiguous();
53+
auto input_tensor = input.contiguous();
54+
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
55+
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
56+
57+
// Construct grad input tensor
58+
auto quantizer_cpp = convert_quantizer(quantizer);
59+
const auto input_shape_te = input_cpp.shape();
60+
const std::vector<size_t> input_shape(input_shape_te.data,
61+
input_shape_te.data + input_shape_te.ndim);
62+
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
63+
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
64+
65+
// Compute activation backward
66+
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
67+
detail::IsMXFP8Quantizers(quantizer.ptr())) {
68+
// Compute activation backward directly
69+
NVTE_SCOPED_GIL_RELEASE({
70+
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
71+
at::cuda::getCurrentCUDAStream());
72+
});
73+
} else {
74+
// Compute activation backward in high-precision, then quantize
75+
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
76+
NVTE_SCOPED_GIL_RELEASE({
77+
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
78+
at::cuda::getCurrentCUDAStream());
79+
});
80+
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
81+
}
9582

96-
return out;
83+
return grad_input_py;
9784
}
9885

9986
py::object gelu(const at::Tensor& input, py::handle quantizer) {

0 commit comments

Comments
 (0)