diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp
index 582a8aa4731..93c032e4ef6 100644
--- a/python/python_direct/ir.cpp
+++ b/python/python_direct/ir.cpp
@@ -231,6 +231,24 @@ Returns
 -------
 list of Val
     The shape of this tensor.
+)")
+    .def(
+        "dtype",
+        [](TensorView* self) -> PrimDataType {
+          DataType dt = self->dtype();
+          NVF_CHECK(
+              std::holds_alternative<PrimDataType>(dt.type),
+              "Expected PrimDataType but got type: ",
+              dt);
+          return std::get<PrimDataType>(dt.type);
+        },
+        R"(
+Get the data type of this tensor.
+
+Returns
+-------
+DataType
+    The data type of this tensor.
 )")
     .def("has_root", &TensorView::hasRoot, R"(
 Check if this tensor has a root domain.
diff --git a/python/python_direct/ops.cpp b/python/python_direct/ops.cpp
index ef7f5db265f..11ffe6a264b 100644
--- a/python/python_direct/ops.cpp
+++ b/python/python_direct/ops.cpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include

 namespace nvfuser::python {

@@ -2418,46 +2419,31 @@ TensorView* expand_fn(TensorView* arg, ShapeType generic_new_shape) {

 template <typename ShapeType>
 TensorView* broadcast_in_dim_fn(
-    TensorView* arg,
+    TensorView* input,
     ShapeType generic_output_shape,
-    std::vector<int64_t>& broadcast_dims) {
+    const std::vector<int64_t>& nonbroadcast_dims) {
   std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
-  NVF_CHECK(
-      output_shape.size() >= broadcast_dims.size(),
-      "broadcast_dims vector size is too big for output shape!");
+  NVF_CHECK_GE(output_shape.size(), nonbroadcast_dims.size());

-  const auto arg_ndims = static_cast<size_t>(std::ranges::distance(
-      arg->getLoopDomain() | TensorDomain::kNoReductions));
-  NVF_CHECK(
-      output_shape.size() >= broadcast_dims.size(),
-      "The new shape is expected to be greater-then-or-equal to the input: ",
-      output_shape.size(),
-      " vs ",
-      arg_ndims);
-  NVF_CHECK(
-      arg_ndims == broadcast_dims.size(),
-      "The broadcast dimensions should match the input dimensions: ",
-      arg_ndims,
-      " vs ",
-      broadcast_dims.size(),
-      ". arg = ",
-      arg->toString());
+  const auto input_ndim = std::ranges::distance(
+      input->getLogicalDomain() | TensorDomain::kNoReductions);
+  NVF_CHECK_GE(std::ssize(output_shape), input_ndim);
+  NVF_CHECK_EQ(input_ndim, std::ssize(nonbroadcast_dims));

   std::vector<bool> is_broadcast_dim(output_shape.size(), true);
-  for (const auto idx : arange(broadcast_dims.size())) {
-    if (idx > 0) {
-      NVF_CHECK(
-          broadcast_dims[idx - 1] < broadcast_dims[idx],
-          "Broadcast dimension is not greater than the previous value.");
-    }
+  for (int64_t nonbroadcast_dim : nonbroadcast_dims) {
+    nonbroadcast_dim = wrapDim(nonbroadcast_dim, std::ssize(output_shape));
     NVF_CHECK(
-        broadcast_dims[idx] < static_cast<int64_t>(output_shape.size()),
-        "Invalid broadcast_dims value.");
-    is_broadcast_dim.at(broadcast_dims[idx]) = false;
+        is_broadcast_dim.at(nonbroadcast_dim),
+        "nonbroadcast_dim (",
+        nonbroadcast_dim,
+        ") is specified more than once.");
+    is_broadcast_dim.at(nonbroadcast_dim) = false;
   }

-  auto bcast_output = broadcast(arg, is_broadcast_dim);
-  return expand(bcast_output, output_shape);
+  TensorView* output = broadcast(input, is_broadcast_dim);
+  output = expand(output, output_shape);
+  return output;
 }

 template <typename ShapeType>
diff --git a/tests/python/direct/test_alphafold3.py b/tests/python/direct/test_alphafold3.py
index c18388bb203..fe9254dde97 100644
--- a/tests/python/direct/test_alphafold3.py
+++ b/tests/python/direct/test_alphafold3.py
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto

-from nvfuser_direct import FusionDefinition, DataType
+from nvfuser_direct import FusionDefinition, DataType, TensorView


 @dataclass
@@ -28,14 +28,157 @@ class Direction(Enum):
     OUTGOING = auto()  # aka starting node


+def layer_norm(
+    fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
+) -> TensorView:
+    io_dtype = x.dtype()
+    x = fd.ops.cast(x, dtype=DataType.Float)
+    var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
+    y = fd.ops.sub(x, mean)
+    var = fd.ops.add(var, fd.define_scalar(1e-5))
+    y = fd.ops.mul(y, fd.ops.rsqrt(var))
+    shape = fd.ops.shape(x)
+    w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.mul(y, w)
+    b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.add(y, b)
+    y = fd.ops.cast(y, dtype=io_dtype)
+    return y
+
+
+def gating(
+    fd: FusionDefinition,
+    z: TensorView,
+    w_p: TensorView,
+    z_in: TensorView,
+    w_g: TensorView,
+) -> TensorView:
+    io_dtype = z.dtype()
+    p = fd.ops.linear(z, w_p)
+    g = fd.ops.linear(z_in, w_g)
+    g = fd.ops.sigmoid(g)
+    z = fd.ops.mul(p, g)
+    return fd.ops.cast(z, dtype=io_dtype)
+
+
+# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.5 for details)
 @pytest.mark.parametrize(
     "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
 )
 def test_triangle_updates(direction):
-    pass
+    c_z = _DEFAULT_CONFIG.c_z
+
+    with FusionDefinition() as fd:
+        z_in = fd.define_tensor(
+            shape=[-1, -1, -1, c_z],
+            dtype=DataType.BFloat16,
+            contiguity=True,
+        )  # [b, i, j, c_z]
+        w_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        # Masking is used in an internal implementation: http://nv/e-4
+        mask = fd.define_tensor(
+            shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
+        )  # [b, i, j]
+
+        batch_size = fd.ops.size(z_in, 0)
+        n_tokens = fd.ops.size(z_in, 1)
+
+        z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in)
+        z = gating(fd, z_in, w_p_in, z_in, w_g_in)
+        mask = fd.ops.broadcast_in_dim(
+            mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2]
+        )
+        z = fd.ops.where(mask, z, 0.0)
+        a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
+        b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])
+
+        match direction:
+            case Direction.OUTGOING:
+                # z_out = einsum("bikc,bjkc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
+            case Direction.INCOMING:
+                # z_out = einsum("bkic,bkjc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
+        z = fd.ops.matmul(a, b)  # [b, c, i, j]
+        z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
+
+        z = layer_norm(fd, z, w_norm_out, b_norm_out)
+        z = gating(fd, z, w_p_out, z_in, w_g_out)
+        fd.add_output(z)
+
+    batch_size = 3
+    n_tokens = 5
+    z_in = torch.testing.make_tensor(
+        batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_g_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    mask = torch.testing.make_tensor(
+        batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda"
+    )
+    (z_out,) = fd.execute(
+        [
+            z_in,
+            w_norm_in,
+            b_norm_in,
+            w_p_in,
+            w_g_in,
+            w_norm_out,
+            b_norm_out,
+            w_p_out,
+            w_g_out,
+            mask,
+        ]
+    )
+    assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)


 # https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-attention
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.6 for details)
 @pytest.mark.parametrize(
     "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
 )
@@ -52,8 +195,8 @@ def test_triangle_attention(direction):
             dtype=DataType.BFloat16,
             contiguity=True,
         )  # [b, i, j, c_z]
-        if direction == Direction.INCOMING:
-            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
+        w_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
+        b_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
         w_q = fd.define_tensor(
             shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
         )
@@ -64,8 +207,6 @@ def test_triangle_attention(direction):
         mask = fd.define_tensor(
             shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
         )  # [b, i, j]
-        if direction == Direction.INCOMING:
-            mask = fd.ops.permute(mask, [0, 2, 1])
         w_v = fd.define_tensor(
             shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
         )
@@ -79,6 +220,9 @@ def test_triangle_attention(direction):
         batch_size = fd.ops.size(z_in, 0)
         n_tokens = fd.ops.size(z_in, 1)

+        if direction == Direction.INCOMING:
+            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
+        z_in = layer_norm(fd, z_in, w_norm, b_norm)
         q = fd.ops.linear(z_in, w_q)
         q_h = fd.ops.reshape(
             q, [batch_size, n_tokens, n_tokens, h, -1]
@@ -99,6 +243,8 @@ def test_triangle_attention(direction):
             broadcast_dims=[0, 2, 3, 4],
         )  # [b, 1, h, j, k]

+        if direction == Direction.INCOMING:
+            mask = fd.ops.permute(mask, [0, 2, 1])
         mask = fd.ops.broadcast_in_dim(
             mask,
             shape=[batch_size, n_tokens, 1, 1, n_tokens],
@@ -142,6 +288,8 @@ def test_triangle_attention(direction):
     z_in = torch.testing.make_tensor(
         batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
     )
+    w_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
     w_q = torch.testing.make_tensor(
         h * c_hidden, c_z, dtype=torch.bfloat16, device="cuda"
     )
@@ -161,5 +309,5 @@ def test_triangle_attention(direction):
     w_o = torch.testing.make_tensor(
         c_z, h * c_hidden, dtype=torch.bfloat16, device="cuda"
     )
-    (z_out,) = fd.execute([z_in, w_q, w_k, w_b, mask, w_v, w_g, w_o])
+    (z_out,) = fd.execute([z_in, w_norm, b_norm, w_q, w_k, w_b, mask, w_v, w_g, w_o])
     assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
diff --git a/tests/python/opinfo/opinfo_input_generators.py b/tests/python/opinfo/opinfo_input_generators.py
index 1efd8dea6b8..d1291a1f944 100644
--- a/tests/python/opinfo/opinfo_input_generators.py
+++ b/tests/python/opinfo/opinfo_input_generators.py
@@ -217,21 +217,14 @@ def broadcast_in_dim_error_generator(
         "The new shape is expected to be greater-then-or-equal to the input",
     )

-    # 3. broadcast_dimensions is an ascending sequence of integers.
-    descending_broadcast_dimensions = (
-        ([2, 2], [2, 2], [1, 0]),
-        RuntimeError,
-        "Broadcast dimension is not greater than the previous value.",
-    )
-
-    # 4. Each broadcast dimension is within the new shape.
+    # 3. Each broadcast dimension is within the new shape.
     out_of_bounds_broadcast_dimensions = (
         ([2, 2], [2, 2], [0, 2]),
         RuntimeError,
         "Invalid broadcast_dims value.",
     )

-    # 5. The original tensor is not broadcastable to desired shape.
+    # 4. The original tensor is not broadcastable to desired shape.
     # tensor.shape[idx] == 1 or tensor.shape[idx] == output_shape[new_idx]
     #
     # Jax Exception:
@@ -244,7 +237,7 @@ def broadcast_in_dim_error_generator(
         "Invalid broadcast_dims value.",
     )

-    # 6. TypeError: broadcast_in_dim shape must have every element be nonnegative, got (-1, 2, 3).
+    # 5. TypeError: broadcast_in_dim shape must have every element be nonnegative, got (-1, 2, 3).
     negative_shape = (
         ([2, 3], [2, 3, -1], [0, 1]),
         RuntimeError,
@@ -255,7 +248,6 @@ def broadcast_in_dim_error_generator(

     error_cases = [
         missing_axis_in_bcast_dims,
         fewer_dims_in_output_shape,
-        descending_broadcast_dimensions,
         out_of_bounds_broadcast_dimensions,
         # not_broadcastable,
         # negative_shape,
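
A minimal sketch (not part of the patch) of the broadcast_in_dim behavior the ops.cpp change enables: broadcast_dims entries are now wrapped with wrapDim, so negative indices such as -1 are accepted and ordering is no longer enforced, while duplicate entries are rejected. The snippet assumes an nvfuser_direct build that contains this patch and a CUDA device; it only uses API calls already exercised by the tests above.

import torch
from nvfuser_direct import FusionDefinition, DataType

with FusionDefinition() as fd:
    x = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float, contiguity=True)
    w = fd.define_tensor(shape=[-1], dtype=DataType.Float, contiguity=True)
    # broadcast_dims names the output dimensions that come from the input;
    # -1 refers to the last output dimension, mirroring layer_norm above.
    w2d = fd.ops.broadcast_in_dim(w, shape=fd.ops.shape(x), broadcast_dims=[-1])
    fd.add_output(fd.ops.mul(x, w2d))

x = torch.randn(4, 8, device="cuda")
w = torch.randn(8, device="cuda")
(out,) = fd.execute([x, w])
assert out.shape == (4, 8)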