Skip to content

Commit 7e48fa1

Browse files
[JAX] Debugging inspect utility (#2651)
* initial debug of inspect ffi Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * writing binary dumps of tensors works Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * loading works Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add tensor statistics Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Add cuda error check and tests Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Add __init__.py to debug folder Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Fix lint Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Fix lint Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address greptile comments Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Lint Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Gate tests behind fp8 support Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fa68781 commit 7e48fa1

8 files changed

Lines changed: 338 additions & 2 deletions

File tree

tests/jax/test_custom_call_compute.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,3 +1921,37 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
19211921
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
19221922
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
19231923
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
1924+
1925+
1926+
class TestDebugInspectFFI:
    """Tests for the experimental inspect/debug-dump FFI utilities."""

    @pytest_parametrize_wrapper("shape", [(256, 128)])
    @pytest_parametrize_wrapper(
        "dtype",
        [
            jnp.float32,
            jnp.bfloat16,
            jnp.float16,
            # Note: fp4 currently doesn't work
            # jnp.float4_e2m1fn
        ]
        + ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
    )
    def test_debug_inspect_ffi(self, shape, dtype):
        from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump

        # The value handed to inspect_array is (input + 1); the trailing +1
        # checks that the op composes as a pass-through inside a jitted fn.
        def add_inspect_add(inp):
            bumped = inspect_array(inp + 1, "my_array")
            return bumped + 1

        rng_key = jax.random.PRNGKey(0)
        x = jax.random.uniform(rng_key, shape, jnp.float32).astype(dtype)
        _ = jax.jit(add_inspect_add)(x)

        # The FFI writes the dump for device 0 under a fixed file name,
        # independent of the name passed to inspect_array.
        dumped = load_array_dump("my_tensor_gpu0.bin", shape, dtype)
        assert_allclose(dumped, x + 1, dtype=dtype)

transformer_engine/jax/csrc/extensions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
143143
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
144144
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
145145

146+
// Inspect
147+
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);
148+
146149
// Cudnn helpers
147150
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
148151

transformer_engine/jax/csrc/extensions/amax.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
************************************************************************/
66
#include <cuda_runtime.h>
77

8-
#include <iostream>
9-
108
#include "../extensions.h"
119
#include "transformer_engine/cast.h"
1210
#include "transformer_engine/hadamard_transform.h"
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
#include <cuda_runtime.h>
7+
8+
#include <fstream>
9+
#include <iostream>
10+
11+
#include "../extensions.h"
12+
#include "xla/ffi/api/c_api.h"
13+
14+
namespace transformer_engine {
15+
namespace jax {
16+
17+
// Debug FFI op: copies the input tensor and its pre-computed statistics from
// device to host, dumps the raw bytes plus a JSON metadata file (one pair of
// files per GPU), and logs a summary line to stdout.  The input buffer is
// aliased to the output buffer, so the op behaves as an identity in the
// computation graph.
//
// NOTE(review): the min/max/mean/std buffers are assumed to hold float32
// scalars; this is enforced by the abstract rule on the JAX side.
Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
                      Result_Type output_buf) {
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");

  // Stage the tensor bytes in host memory.  The destination is pageable, so
  // these async copies do not overlap with host work; that is acceptable for
  // a debugging-only code path.
  std::vector<uint8_t> input_data(input_buf.size_bytes());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
                                  input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));

  float min_val{}, max_val{}, mean_val{}, std_val{};
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));

  // Wait for all device-to-host copies to land before touching the host data.
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

  int device;
  NVTE_CHECK_CUDA(cudaGetDevice(&device));

  const auto dims = input_buf.dimensions();

  // Write the tensor data to a file as a binary blob.
  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
  std::ofstream file(filename, std::ios::binary);
  NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  // Detect silent write failures (e.g. disk full) instead of dropping them.
  NVTE_CHECK(file.good(), "Failed to write tensor data to file: ", filename);
  file.close();

  // Write out a metadata file describing shape, dtype, and the statistics.
  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
  std::ofstream meta_file(meta_filename);
  NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
  meta_file << "{";
  meta_file << "\"shape\": [";
  for (size_t i = 0; i < dims.size(); ++i) {
    meta_file << dims[i];
    // `i + 1 < size` avoids the unsigned underflow of `size - 1` on empty dims.
    if (i + 1 < dims.size()) {
      meta_file << ", ";
    }
  }
  meta_file << "], ";
  meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
  meta_file << ", \"min\": " << min_val;
  meta_file << ", \"max\": " << max_val;
  meta_file << ", \"mean\": " << mean_val;
  meta_file << ", \"std\": " << std_val;
  meta_file << "}";
  NVTE_CHECK(meta_file.good(), "Failed to write metadata to file: ", meta_filename);
  meta_file.close();

  // Log the tensor metadata to the console.
  printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
  for (size_t i = 0; i < dims.size(); ++i) {
    printf("%zu", static_cast<size_t>(dims[i]));
    if (i + 1 < dims.size()) {
      printf(", ");
    }
  }
  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

  return ffi_with_cuda_error_check();
}

// Handler registered as "te_inspect_ffi".  The single input buffer is donated
// to the output, so the op is a pass-through from XLA's point of view.
XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // min
                                  .Arg<Buffer_Type>()      // max
                                  .Arg<Buffer_Type>()      // mean
                                  .Arg<Buffer_Type>()      // std
                                  .Ret<Buffer_Type>()      // output
);
97+
98+
} // namespace jax
99+
} // namespace transformer_engine

transformer_engine/jax/csrc/extensions/pybind.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
8181
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
8282
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
8383

84+
dict["te_inspect_ffi"] =
85+
pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));
86+
8487
return dict;
8588
}
8689

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.

This API is experimental and may change or be removed without deprecation in future releases.
"""

# Only the `experimental` submodule is exported; it is imported lazily by
# consumers (e.g. `from transformer_engine.jax.debug import experimental`).
__all__ = ["experimental"]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.

This API is experimental and may change or be removed without deprecation in future releases.
"""

from .inspect import inspect_array, load_array_dump

# Public surface of the experimental debug package.
__all__ = ["inspect_array", "load_array_dump"]
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
"""Experimental JAX array inspection utilities."""
5+
6+
from functools import partial
7+
8+
import jax
9+
import jax.numpy as jnp
10+
from jax import ffi
11+
12+
from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
13+
14+
__all__ = ["inspect_array", "load_array_dump"]
15+
16+
17+
class InspectPrimitive(BasePrimitive):
    """No-op primitive used to inspect array values.

    The array is passed through unchanged (the input buffer is donated to the
    output); the attached float32 scalar statistics are consumed by the FFI
    handler on the host side.
    """

    name = "te_inspect_ffi"
    multiple_results = False
    impl_static_args = ()
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        x_min_aval,
        x_max_aval,
        x_mean_aval,
        x_std_aval,
    ):
        """Abstract rule: output matches the input aval; each statistic must be
        a float32 scalar."""
        stat_avals = (
            ("x_min", x_min_aval),
            ("x_max", x_max_aval),
            ("x_mean", x_mean_aval),
            ("x_std", x_std_aval),
        )
        for stat_name, stat_aval in stat_avals:
            assert (
                stat_aval.shape == () and stat_aval.dtype == jnp.float32
            ), f"{stat_name} must be a scalar with dtype float32"
        return x_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        x_min,
        x_max,
        x_mean,
        x_std,
    ):
        """Lowering rule: emit the FFI call, donating the input buffer to the
        output so the op is a pass-through."""
        lowering_fn = ffi.ffi_lowering(
            InspectPrimitive.name,
            operand_output_aliases={0: 0},  # donate input buffer to output buffer
        )
        return lowering_fn(ctx, x, x_min, x_max, x_mean, x_std)

    @staticmethod
    def impl(
        x,
        x_min,
        x_max,
        x_mean,
        x_std,
    ):
        """Implementation: bind the inner primitive and return the array."""
        assert InspectPrimitive.inner_primitive is not None
        out = InspectPrimitive.inner_primitive.bind(x, x_min, x_max, x_mean, x_std)
        return out


register_primitive(InspectPrimitive)
101+
102+
103+
def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    """Bind the inspect primitive, attaching float32 min/max/mean/std scalars."""
    assert InspectPrimitive.outer_primitive is not None, (
        "InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
        " and registered."
    )
    # mean/std are computed in float32 for numerical stability across dtypes.
    x_f32 = x.astype(jnp.float32)
    stats = (
        jnp.min(x).astype(jnp.float32),
        jnp.max(x).astype(jnp.float32),
        jnp.mean(x_f32),
        jnp.std(x_f32),
    )
    return InspectPrimitive.outer_primitive.bind(x, *stats)
115+
116+
117+
@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
    x,
):
    """Differentiable pass-through that triggers the inspect dump on forward."""
    output, _ = _inspect_fwd_rule(x)
    return output


def _inspect_fwd_rule(
    x,
):
    """Forward rule: run the inspect primitive; no residuals are needed."""
    return _inspect_array_inner(x), ()


def _inspect_bwd_rule(
    ctx,
    grad,
):
    """Backward rule: identity — the incoming cotangent is returned as-is."""
    del ctx
    return (grad,)


_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)
147+
148+
149+
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # Name is currently unused, but can be included in the future for more
    # informative output.
    del name
    return _inspect(x)
158+
159+
160+
def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
    """Utility function to load a JAX array from a dumped binary file.

    Args:
        filename (str): The path to the binary file containing the array data.
        shape (tuple): The shape of the array to be loaded.
        dtype (jnp.dtype): The data type of the array to be loaded.

    Returns:
        jnp.ndarray: The loaded JAX array.
    """
    with open(filename, "rb") as dump_file:
        raw_bytes = dump_file.read()
    return jnp.frombuffer(raw_bytes, dtype=dtype).reshape(shape)

0 commit comments

Comments
 (0)