diff --git a/.gitignore b/.gitignore index 6b683f1..37e4b9d 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__ build *.log +.vscode/** \ No newline at end of file diff --git a/musa_ext/kernels/musa_cast_functor.h b/musa_ext/kernels/musa_cast_functor.h new file mode 100644 index 0000000..527e74e --- /dev/null +++ b/musa_ext/kernels/musa_cast_functor.h @@ -0,0 +1,29 @@ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +static Status CastTensor(OpKernelContext* ctx, const mTensor& input_mt, + DataType dst_dtype, mTensor* output_mt) { + ::musa::dnn::Unary op; + auto status = op.SetMode(::musa::dnn::Unary::Mode::CAST); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(GetHandleByCtx(ctx), *output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor Run failed. 
Status: ", + static_cast(status)); + } + return Status::OK(); +} + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc new file mode 100644 index 0000000..3ed69ae --- /dev/null +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -0,0 +1,850 @@ +#include "musa_einsum_op.h" + +#include +#include + +#include "../mu/device/musa_device.h" +#include "../utils/musa_einsum_op_util.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "musa_fill_functor.h" +#include "musa_reduce_functor.h" +#include "musa_transpose_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/matmul_bcast.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +using ShapeVec = gtl::InlinedVector; +using Labels = gtl::InlinedVector; +using OperandLabels = gtl::InlinedVector; +using LabelCounts = gtl::InlinedVector; +using OperandLabelCounts = gtl::InlinedVector; +using LabelToDimSizes = gtl::InlinedVector; + +struct EinsumHelper { + // Insert new (unnamed) broadcasting labels at the location of ellipsis. + static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels, + int ellipsis_axis, Labels* labels, + LabelCounts* label_counts) { + labels->erase(labels->begin() + ellipsis_axis); + labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0); + std::iota(labels->begin() + ellipsis_axis, + labels->begin() + ellipsis_axis + num_bcast_dims, + num_named_labels); + // Increment label counts. Since these are new labels, the count is set + // to 1. 
+ label_counts->resize(num_named_labels + num_bcast_dims, 1); + } + + // Record and validate the label to dimension mapping. Must be a named + // (non-broadcasting) label as broadcasting labels don't have a fixed + // dimension. + static Status RecordLabelToDimension(const int label, const int axis, + const Tensor& input, + LabelToDimSizes* label_to_dim_sizes) { + const int64_t input_dim = input.dim_size(axis); + // We know that label_to_dim_sizes has the size to accommodate named labels. + if (label_to_dim_sizes->at(label) != 0 && + label_to_dim_sizes->at(label) != input_dim) { + return errors::InvalidArgument( + "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", + axis, " of the input shaped ", input.shape().DebugString(), + " but got dimension ", input_dim); + } + (*label_to_dim_sizes)[label] = input_dim; + return Status::OK(); + } + + // Validate input dimensions and populate unnamed labels and their label + // counts. + static Status ProcessDimensions( + const OpInputList& inputs, + const gtl::InlinedVector& input_has_ellipsis, + const bool output_has_ellipsis, OperandLabels* input_labels, + Labels* output_labels, std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + LabelToDimSizes* label_to_dim_sizes) { + if (inputs.size() != input_labels->size()) { + return errors::InvalidArgument("Expected ", input_labels->size(), + " inputs but got: ", inputs.size()); + } + const int num_inputs = inputs.size(); + + // We infer the number of broadcasting dimensions by taking the maximum + // rank among the broadcasting subshapes of the input. 
+ int max_bcast_dims = 0; + const int num_named_labels = label_types->size(); + label_to_dim_sizes->resize(num_named_labels); + for (int i = 0; i < num_inputs; ++i) { + Labels* labels = &(*input_labels)[i]; + + if (!input_has_ellipsis[i]) { + if (inputs[i].dims() != labels->size()) { + return errors::InvalidArgument("Expected input ", i, " to have rank ", + labels->size(), + " but got: ", inputs[i].dims()); + } + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i], + label_to_dim_sizes)); + } + continue; + } + + // Input has an ellipsis. + if (inputs[i].dims() + 1 < labels->size()) { + return errors::InvalidArgument( + "Expected input ", i, " to have rank at least ", labels->size() - 1, + " but got: ", inputs[i].dims()); + } + int ellipsis_axis = -1; + const int num_bcast_dims = inputs[i].dims() - labels->size() + 1; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + if (label == kEllipsisLabel) { + ellipsis_axis = label_idx; + continue; + } + // Current label is not an ellipsis. + const int axis = + label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1); + TF_RETURN_IF_ERROR( + RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes)); + } + // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting + // dimensions. + if (ellipsis_axis != -1) { + InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, + labels, &input_label_counts->at(i)); + max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims); + } + } + if (!absl::c_linear_search(input_has_ellipsis, true) && + !output_has_ellipsis) { + return Status::OK(); + } + // Insert broadcasting dimensions in the output labels. 
+ auto it = + std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel); + if (it != output_labels->end()) { + const int ellipsis_axis = it - output_labels->begin(); + InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, + output_labels, output_label_counts); + } else if (max_bcast_dims > 0) { + return errors::InvalidArgument( + "Output contains ", max_bcast_dims, + " broadcasting dimension(s) but no ellipsis " + "(...) was found in the output subscripts."); + } + // Populate EinsumDimensionType for the new broadcasting labels. + label_types->resize(num_named_labels + max_bcast_dims, + EinsumDimensionType::kBroadcasting); + return Status::OK(); + } + + // Permutes the labels according to the given permutation. + static void PermuteLabels(const std::vector& permutation, + Labels* labels) { + Labels permuted_labels(labels->size()); + for (int i = 0; i < labels->size(); ++i) { + permuted_labels[i] = (*labels)[permutation[i]]; + } + labels->swap(permuted_labels); + } + + // Returns a reshaped input Tensor. The underlying buffer is not copied. + static Status CopyFrom(const Tensor& input, const TensorShape& shape, + Tensor* output) { + if (output->CopyFrom(input, shape)) return Status::OK(); + return errors::Internal( + "Encountered error while reshaping a Tensor of shape ", + input.shape().DebugString(), " to shape ", shape.DebugString()); + } + + // Returns true if a transpose is actually required. Returns false when it + // would be a no-op: the input has rank < 2 or the permutation is the identity.
+ static bool ShouldTranspose(const TensorShape& input_shape, + const std::vector& permutation) { + if (input_shape.dims() < 2) return false; + for (int i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) return true; + } + return false; + } + + static Status CastTensor(OpKernelContext* ctx, const Tensor& input, + DataType dst_dtype, Tensor* output) { + if (input.dtype() == dst_dtype) { + return CopyFrom(input, input.shape(), output); + } + + TF_RETURN_IF_ERROR(ctx->allocate_temp(dst_dtype, input.shape(), output)); + if (input.NumElements() == 0) return Status::OK(); + + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); + + ::musa::dnn::Unary op; + auto status = op.SetMode(::musa::dnn::Unary::Mode::CAST); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(GetHandleByCtx(ctx), output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor Run failed. Status: ", + static_cast(status)); + } + return Status::OK(); + } + + // Transpose the input given a permutation. Returns a reference to the input + // if transposing is not necessary. + template + static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, + const std::vector& permutation, + Tensor* output) { + if (!ShouldTranspose(input.shape(), permutation)) { + return CopyFrom(input, input.shape(), output); + } + TensorShape transposed_shape; + for (int i = 0; i < input.dims(); ++i) { + TF_RETURN_IF_ERROR( + transposed_shape.AddDimWithStatus(input.dim_size(permutation[i]))); + } + // For empty Tensors, just change the shape. E.g. we may need to transpose + // from shape [1, 0, 5] to [5, 1, 0]. 
+ if (input.NumElements() == 0) { + return CopyFrom(input, transposed_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); + mTensor input_mt = CreateMTensor(input); + mTensor output_mt = CreateMTensor(*output); + return TransposeFunctor::Compute(ctx, input_mt, permutation, output_mt); + } + + // If there are repeated labels in either the input or output, then this + // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. + template + static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, + const Labels& labels, + const LabelCounts& label_counts, + const bool should_inflate, Tensor* output) { + // Return early if there are no repeated indices. + if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { + return CopyFrom(input, input.shape(), output); + } + // We reshape so that each repeated label is compressed to one dimension. + // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27, + // 5]. Striding appropriately (in this case with strides 13 (=1+3+9) and 1) + // recovers the generalized diagonal of shape [3, 5]. + ShapeVec reshape; + ShapeVec strides; + // Strided and inflated shapes correspond to input and output shapes, + // respectively, if should_inflate is true (vice-versa if should_inflate is + // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example. + ShapeVec strided_shape; + ShapeVec inflated_shape; + for (int label : labels) { + const int count = label_counts[label]; + const int current_axis = + should_inflate ?
strided_shape.size() : inflated_shape.size(); + const int64_t dim = input.dim_size(current_axis); + strided_shape.push_back(dim); + inflated_shape.insert(inflated_shape.end(), count, dim); + const int64_t reshape_dim = MathUtil::IPow(dim, count); + reshape.push_back(reshape_dim); + // While taking the d-diagonal in a rank k Tensor, we take d + // equally-spaced elements including the first and last element. Then, (d + // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1). + const int64_t stride = + (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1; + strides.push_back(stride); + } + + const ShapeVec& output_shape_dims = + should_inflate ? inflated_shape : strided_shape; + TensorShape output_shape; + for (int64_t dim : output_shape_dims) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + + const int rank = reshape.size(); + if (rank == 0) return Status::OK(); + + auto compute_dense_strides = [](const ShapeVec& dims) { + ShapeVec dense_strides(dims.size(), 1); + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { + dense_strides[i] = dense_strides[i + 1] * dims[i + 1]; + } + return dense_strides; + }; + + const auto inflated_dense = compute_dense_strides(inflated_shape); + ShapeVec diagonal_strides; + diagonal_strides.reserve(rank); + int inflated_axis = 0; + for (int label : labels) { + const int count = label_counts[label]; + int64_t diagonal_stride = 0; + for (int i = 0; i < count; ++i) { + diagonal_stride += inflated_dense[inflated_axis + i]; + } + diagonal_strides.push_back(diagonal_stride); + inflated_axis += count; + } + + std::vector strided_dims_vec(strided_shape.begin(), + strided_shape.end()); + std::vector diagonal_strides_vec(diagonal_strides.begin(), + diagonal_strides.end()); + + auto& handle = GetHandleByCtx(ctx); + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); + + if (should_inflate) { +
// Zero-fill the inflated output first: the copy below writes only through + // the diagonal-strided view, so all other elements must already be zero. + // The fill status is propagated, matching ContractOperands. + TF_RETURN_IF_ERROR(SetZeroFunctor::Compute(ctx, &output_mt)); + } + + // Re-describe whichever side carries the repeated labels (the output when + // inflating, the input when striding) as a diagonal-strided view, and check + // the SetNdInfo result the same way MaterializeBroadcastedBatch does. + auto& diagonal_mt = should_inflate ? output_mt : input_mt; + auto status = diagonal_mt.SetNdInfo(rank, strided_dims_vec.data(), + diagonal_strides_vec.data()); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "Einsum StrideOrInflate SetNdInfo failed. Status: ", + static_cast<int>(status)); + } + + ::musa::dnn::Unary op; + status = op.SetMode(::musa::dnn::Unary::Mode::IDENTITY); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(handle, output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate Run failed. Status: ", + static_cast(status)); + } + return Status::OK(); + } + + // Returns true if the input dimensions are already sorted in the order + // [batch, contract, free, reduce]. Used to implement an optimization to avoid + // an extra transpose and instead uses (adj_x and adj_y) in BatchMatMul. + static bool ShouldSwapFreeAndContract( + const Labels& labels, + const std::vector& label_types) { + // Check that ordering is according to dimension type, with the role of + // free and contract dimensions swapped. + gtl::InlinedVector remap = {0, 1, 3, 2, 4}; + for (int i = 0; i + 1 < labels.size(); ++i) { + const int dimtype_a = remap[label_types[labels[i]]]; + const int dimtype_b = remap[label_types[labels[i + 1]]]; + if (dimtype_a > dimtype_b || + (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) { + return false; + } + } + return true; + } + + // TODO: BMatMul seems to perform worse when the input use half precision. + template + static Status BMatMul(OpKernelContext* ctx, const Tensor& lhs, + const Tensor& rhs, bool trans_a, bool trans_b, + Tensor* output) { + const Tensor& in0 = lhs; + const Tensor& in1 = rhs; + + int64 d0 = in0.dim_size(in0.dims() - 2); + int64 d1 = in0.dim_size(in0.dims() - 1); + int64 d2 = in1.dim_size(in1.dims() - 2); + int64 d3 = in1.dim_size(in1.dims() - 1); + + int64 m = trans_a ? d1 : d0; + int64 k = trans_a ?
d0 : d1; + int64 n = trans_b ? d2 : d3; + int64 k_check = trans_b ? d3 : d2; + + if (output->NumElements() == 0) return Status::OK(); + + // k and k_check are the contraction extents of lhs and rhs after applying + // trans_a / trans_b; reject a mismatch here instead of handing inconsistent + // shapes to the batched matmul. + if (k != k_check) { + return errors::InvalidArgument( + "Einsum BMatMul: contraction dimensions do not match: ", k, " vs. ", + k_check); + } + + auto& handle = GetHandleByCtx(ctx); + // TF32 is disabled unconditionally: per the original note, this static + // helper cannot read a per-kernel tf32_enabled_ member. + // TODO: honor the same TF32 setting as the MatMul op. + handle.SetAllowTF32(false); + + mTensor mt_a = CreateMTensor(in0); + mTensor mt_b = CreateMTensor(in1); + mTensor mt_out = CreateMTensor(*output); + + auto FixToBatchFormat = [](mTensor& mt, const Tensor& t) { + if (t.dims() == 2) { + int64_t rows = t.dim_size(0); + int64_t cols = t.dim_size(1); + mt.SetNdInfo({1, rows, cols}, {rows * cols, cols, 1}); + } + }; + + ::musa::dnn::Status status; + + mBatchMatMul op; + op.SetTranspose(trans_a, trans_b); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + FixToBatchFormat(mt_a, in0); + FixToBatchFormat(mt_b, in1); + FixToBatchFormat(mt_out, *output); + + status = op.Run(handle, mt_out, mt_a, mt_b); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA BatchMatMul execution failed. Status: ", + static_cast<int>(status)); + } + + return Status::OK(); + } + + template + static Status ReduceOperand( + OpKernelContext* ctx, const Tensor& input, + const std::vector& label_types, + const LabelCounts& label_counts, Labels* labels, Labels* free_labels, + bool* swap_free_and_contract, Tensor* output) { + // Find the permutation to transpose the input dimensions in the order of + // EinsumDimensionType; i.e. batch, free, contract and reduce dimensions. + // This makes it more convenient to invoke Reduce/Contract operations. + std::vector permutation(input.dims()); + absl::c_iota(permutation, 0); + Tensor input_transposed; + // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) + // flag during BatchMatMul.
This is an extra optimization not necessary for + // correctness. + if (ShouldSwapFreeAndContract(*labels, label_types)) { + *swap_free_and_contract = true; + } else { + absl::c_sort(permutation, [&](int i, int j) { + int label_i = (*labels)[i]; + int label_j = (*labels)[j]; + return std::tie(label_types[label_i], label_i) < + std::tie(label_types[label_j], label_j); + }); + } + // Transpose the input so that EinsumDimensionTypes are in order. + TF_RETURN_IF_ERROR( + TransposeOperand(ctx, input, permutation, &input_transposed)); + PermuteLabels(permutation, labels); + + // Take the generalized diagonal for dimensions with repeated axis labels. + Tensor input_deduped; + labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); + TF_RETURN_IF_ERROR( + StrideOrInflate(ctx, input_transposed, *labels, label_counts, + false /* should_inflate */, &input_deduped)); + + // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, + // reduce] where we've compacted the dimensions of each EinsumDimensionType. + gtl::InlinedVector reshape(5, 1); + // The output shape is [batch shape] + [free size, contract size] + // That is, the batch shape is preserved (for broadcasting while + // contracting) while the free dims and contract dims are compressed to one + // dimension each. 
+ TensorShape output_shape; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = labels->at(label_idx); + int64_t dim = input_deduped.dim_size(label_idx); + if (label_types[label] == EinsumDimensionType::kBroadcasting || + label_types[label] == EinsumDimensionType::kBatch) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } else if (label_types[label] == EinsumDimensionType::kFree) { + free_labels->push_back(label); + } + reshape[label_types[label]] *= dim; + } + if (*swap_free_and_contract) + std::swap(reshape[EinsumDimensionType::kFree], + reshape[EinsumDimensionType::kContract]); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kFree])); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kContract])); + + if (reshape[EinsumDimensionType::kReduce] == 1) { // No need to reduce. + return CopyFrom(input_deduped, output_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + const int64_t reduce_size = reshape[kReduce]; + const int64_t output_size = reshape[kBroadcasting] * reshape[kBatch] * + reshape[kFree] * reshape[kContract]; + + TensorShape input_flatten_shape; + TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(output_size)); + TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(reduce_size)); + Tensor input_flattened; + if (!input_flattened + .BitcastFrom(input_deduped, input_deduped.dtype(), + input_flatten_shape) + .ok()) { + return errors::Internal("Failed to reshape Einsum input for reduce"); + } + + TensorShape output_flatten_shape; + TF_RETURN_IF_ERROR(output_flatten_shape.AddDimWithStatus(output_size)); + Tensor output_flattened; + if (!output_flattened + .BitcastFrom(*output, output->dtype(), output_flatten_shape) + .ok()) { + return errors::Internal("Failed to reshape Einsum output for reduce"); + } + + auto input_mt = CreateMTensor(input_flattened); + auto output_mt = 
CreateMTensor(output_flattened); + + int reduce_dims[] = {1}; + return ReduceFunctor::Compute( + ctx, &output_mt, &input_mt, ::musa::dnn::Reduce::Mode::ADD, reduce_dims, + 1, "MUSA Reduce (sum) execution failed. Status: "); + } + + // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. + static Status ReshapeToRank3(const Tensor& input, int batch_size, + Tensor* output) { + const int rank = input.dims(); + TensorShape output_shape = {batch_size, input.dim_size(rank - 2), + input.dim_size(rank - 1)}; + return CopyFrom(input, output_shape, output); + } + + template + static Status MaterializeBroadcastedBatch( + OpKernelContext* ctx, const Tensor& input, + const TensorShape& output_batch_shape, Tensor* output) { + const int input_rank = input.dims(); + if (input_rank < 2) { + return errors::InvalidArgument( + "Einsum batch broadcast expects rank >= 2, got rank ", input_rank); + } + + const int input_batch_rank = input_rank - 2; + const int output_batch_rank = output_batch_shape.dims(); + if (output_batch_rank < input_batch_rank) { + return errors::Internal( + "Einsum batch broadcast: output batch rank ", output_batch_rank, + " is smaller than input batch rank ", input_batch_rank); + } + + TensorShape output_shape = output_batch_shape; + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(input.dim_size(input_rank - 2))); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(input.dim_size(input_rank - 1))); + + if (input.shape() == output_shape) { + return CopyFrom(input, output_shape, output); + } + + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + if (output->NumElements() == 0) return Status::OK(); + + std::vector input_dense_strides(input_rank, 1); + for (int axis = input_rank - 2; axis >= 0; --axis) { + input_dense_strides[axis] = + input_dense_strides[axis + 1] * input.dim_size(axis + 1); + } + + const int target_rank = output_batch_rank + 2; + std::vector target_dims(target_rank, 1); + std::vector 
target_strides(target_rank, 0); + const int batch_axis_offset = output_batch_rank - input_batch_rank; + + for (int out_axis = 0; out_axis < output_batch_rank; ++out_axis) { + const int64_t out_dim = output_batch_shape.dim_size(out_axis); + target_dims[out_axis] = out_dim; + + const int in_axis = out_axis - batch_axis_offset; + if (in_axis < 0) { + target_strides[out_axis] = 0; + continue; + } + + const int64_t in_dim = input.dim_size(in_axis); + if (in_dim != out_dim && in_dim != 1) { + return errors::Internal( + "Einsum batch broadcast: incompatible batch dim at axis ", out_axis, + ", input dim ", in_dim, ", output dim ", out_dim); + } + target_strides[out_axis] = + (in_dim == 1 && out_dim != 1) ? 0 : input_dense_strides[in_axis]; + } + + target_dims[output_batch_rank] = input.dim_size(input_rank - 2); + target_dims[output_batch_rank + 1] = input.dim_size(input_rank - 1); + target_strides[output_batch_rank] = input_dense_strides[input_rank - 2]; + target_strides[output_batch_rank + 1] = input_dense_strides[input_rank - 1]; + + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); + auto status = input_mt.SetNdInfo(target_rank, target_dims.data(), + target_strides.data()); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: SetNdInfo failed. Status: ", + static_cast(status)); + } + + auto& handle = GetHandleByCtx(ctx); + ::musa::dnn::Unary op; + status = op.SetMode(::musa::dnn::Unary::Mode::IDENTITY); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: Unary SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(handle, output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: Unary Run failed. 
Status: ", + static_cast(status)); + } + return Status::OK(); + } + + // Contracts the inputs along the last axis (or the second last if the + // corresponding value of swap_free_and_contract is true). The batch + // dimensions are broadcast to the output shape. + template + static Status ContractOperands(OpKernelContext* ctx, + absl::Span inputs, + absl::Span swap_free_and_contract, + Tensor* output) { + if (inputs.size() == 1) + return CopyFrom(inputs[0], inputs[0].shape(), output); + MatMulBCast bcast(inputs[0].shape().dim_sizes(), + inputs[1].shape().dim_sizes()); + if (!bcast.IsValid()) { + return errors::InvalidArgument( + "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), + " vs. ", inputs[1].shape().DebugString()); + } + + const TensorShape output_batch_shape = bcast.output_batch_shape(); + Tensor lhs_broadcasted; + TF_RETURN_IF_ERROR(MaterializeBroadcastedBatch( + ctx, inputs[0], output_batch_shape, &lhs_broadcasted)); + Tensor rhs_broadcasted; + TF_RETURN_IF_ERROR(MaterializeBroadcastedBatch( + ctx, inputs[1], output_batch_shape, &rhs_broadcasted)); + + Tensor lhs; + TF_RETURN_IF_ERROR( + ReshapeToRank3(lhs_broadcasted, bcast.output_batch_size(), &lhs)); + Tensor rhs; + TF_RETURN_IF_ERROR( + ReshapeToRank3(rhs_broadcasted, bcast.output_batch_size(), &rhs)); + + TensorShape output_shape = output_batch_shape; + for (int i = 0; i < inputs.size(); ++i) { + const int64_t free_axis = + inputs[i].dims() - (swap_free_and_contract[i] ? 
1 : 2); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(inputs[i].dim_size(free_axis))); + } + bool trans_x = swap_free_and_contract[0]; + bool trans_y = !swap_free_and_contract[1]; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { + mTensor output_mt = CreateMTensor(*output); + TF_RETURN_IF_ERROR(SetZeroFunctor::Compute(ctx, &output_mt)); + return Status::OK(); + } + Tensor output_reshaped; + TF_RETURN_IF_ERROR( + ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); + return BMatMul(ctx, lhs, rhs, trans_x, trans_y, &output_reshaped); + } +}; + +template +class MusaEinsumOp : public MusaOpKernel { + public: + explicit MusaEinsumOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("equation", &equation_)); + OP_REQUIRES_OK( + ctx, ParseEinsumEquation(equation_, &input_labels_, &output_labels_, + &label_types_, &input_label_counts_, + &output_label_counts_, &input_has_ellipsis_, + &output_has_ellipsis_)); + } + + void Compute(OpKernelContext* ctx) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); + + OperandLabels input_labels(input_labels_); + Labels output_labels(output_labels_); + std::vector label_types(label_types_); + OperandLabelCounts input_label_counts(input_label_counts_); + LabelCounts output_label_counts(output_label_counts_); + LabelToDimSizes label_to_dim_sizes; + + OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions( + inputs, input_has_ellipsis_, output_has_ellipsis_, + &input_labels, &output_labels, &label_types, + &input_label_counts, &output_label_counts, + &label_to_dim_sizes)); + + // The reduction phase (a) sums across reduction dimensions, (b) takes + // generalized diagonals, and (c) reshapes it into shape + // [(broadcasting) batch shape] + [F,C] + // where F and C denote the total (compacted) size of free and contract + // dimensions, 
respectively. + const int num_inputs = inputs.size(); + OperandLabels free_labels(num_inputs); + gtl::InlinedVector inputs_reduced(num_inputs); + gtl::InlinedVector swap_free_and_contract(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + OP_REQUIRES_OK(ctx, + EinsumHelper::ReduceOperand( + ctx, inputs[i], label_types, input_label_counts[i], + &input_labels[i], &free_labels[i], + &swap_free_and_contract[i], &inputs_reduced[i])); + } + + // After reduction, the inputs should be reshaped to Tensors suitable for + // contraction. If num_inputs is 1, the reduced input is simply forwarded to + // the output. + Tensor contraction_output_reshaped; + OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( + ctx, inputs_reduced, swap_free_and_contract, + &contraction_output_reshaped)); + + // Copy the batch labels from the contraction output. Recover the batch + // shape, which may have been broadcasted. + TensorShape result_shape = contraction_output_reshaped.shape(); + result_shape.RemoveLastDims(2); + + int num_labels = label_types.size(); + Labels result_labels; + // All batch dimensions should be present in the contracted result. First + // the broadcasting dimensions, then the named batch dimensions. + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBroadcasting) + result_labels.push_back(label); + } + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBatch) + result_labels.push_back(label); + } + for (int i = 0; i < num_inputs; ++i) { + for (int label : free_labels[i]) { + result_labels.push_back(label); + OP_REQUIRES_OK( + ctx, result_shape.AddDimWithStatus(label_to_dim_sizes[label])); + } + } + + // Reshape the contraction (or reduction) result to its expanded shape: + // [(broadcasted) batch shape] + [free shape 0] + [free shape 1]. 
+ Tensor contraction_output; + OP_REQUIRES_OK( + ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape, + &contraction_output)); + + // Inflate the output if necessary. (E.g. for the equation 'i->iii' which + // may arise while computing gradient of a regular Einsum). + Tensor output_inflated; + OP_REQUIRES_OK( + ctx, EinsumHelper::StrideOrInflate( + ctx, contraction_output, result_labels, output_label_counts, + true /* should_inflate */, &output_inflated)); + if (output_inflated.dims() > contraction_output.dims()) { + // We inflated the output. Modify result labels accordingly. + Labels inflated_labels; + for (int label : result_labels) { + inflated_labels.insert(inflated_labels.end(), + output_label_counts[label], label); + } + result_labels.swap(inflated_labels); + } + // Find the permutation to map the result labels to the output labels. Note + // that both the result and the final output may have the repeated labels, + // in which case the permutation preserves the left-to-right ordering. + // E.g. if result labels are [0, 0, 1] and output is [0, 1, 0] then the + // permutation should be [0, 2, 1]. We also use the fact that repeated + // labels in the result are adjacent to each other. + std::vector output_permutation(output_labels.size()); + std::vector label_to_position(num_labels, -1); + for (int i = 0; i < result_labels.size(); ++i) { + // Remember the position of only the leftmost result label. + if (label_to_position[result_labels[i]] == -1) { + label_to_position[result_labels[i]] = i; + } + } + for (int i = 0; i < output_labels.size(); ++i) { + output_permutation[i] = label_to_position[output_labels[i]]; + // We have found the leftmost occurrence. The next one would be adjacent.
+ label_to_position[output_labels[i]] += 1; + } + Tensor output; + OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( + ctx, output_inflated, output_permutation, &output)); + ctx->set_output(0, output); + } + + string TraceString(const OpKernelContext& ctx, bool verbose) const override { + string op = profiler::TraceMeOp(name_view(), type_string_view()); + string equation = strings::StrCat("(", equation_, ")"); + if (verbose) { + string shape = ShapeTraceString(ctx); + if (!shape.empty()) { + return profiler::TraceMeEncode( + std::move(op), {{"equation", equation}, {"shape", shape}}); + } + } + return profiler::TraceMeEncode(std::move(op), {{"equation", equation}}); + } + + private: + string equation_; + OperandLabels input_labels_; + Labels output_labels_; + std::vector label_types_; + OperandLabelCounts input_label_counts_; + LabelCounts output_label_counts_; + gtl::InlinedVector input_has_ellipsis_; + bool output_has_ellipsis_ = false; +}; // class MusaEinsumOp + +#define REGISTER_MUSA_EINSUM(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Einsum").Device("MUSA").TypeConstraint("T"), \ + MusaEinsumOp); + +REGISTER_MUSA_EINSUM(float); +REGISTER_MUSA_EINSUM(double); +REGISTER_MUSA_EINSUM(int32); +REGISTER_MUSA_EINSUM(int64); +REGISTER_MUSA_EINSUM(Eigen::half); +REGISTER_MUSA_EINSUM(bfloat16); + +#undef REGISTER_MUSA_EINSUM + +} // namespace musa +} // namespace tensorflow diff --git a/musa_ext/kernels/musa_einsum_op.h b/musa_ext/kernels/musa_einsum_op.h new file mode 100644 index 0000000..3ff546b --- /dev/null +++ b/musa_ext/kernels/musa_einsum_op.h @@ -0,0 +1,62 @@ +#include "musa_stride_inflate_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +template +struct StrideFunctor { + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + const Eigen::DenseIndex num_elements = output.size(); + if 
(num_elements == 0) return; + + DimSizeArray dims = {}; + DimSizeArray stride_dims = {}; + for (int i = 0; i < N; ++i) { + dims.value[i] = static_cast(output.dimension(i)); + stride_dims.value[i] = static_cast(strides[i]); + } + + auto stream = GetMusaStreamByCtx(ctx); + MusaStrideKernelLauncher(stream, static_cast(num_elements), + input.data(), output.data(), dims, stride_dims, + N); + } +}; + +template +struct InflateFunctor { + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + const Eigen::DenseIndex input_elements = input.size(); + const Eigen::DenseIndex output_elements = output.size(); + const int64_t output_count = static_cast(output_elements); + + auto stream = GetMusaStreamByCtx(ctx); + if (output_count > 0) { + const uint64_t bytes = output_count * sizeof(T); + musaMemsetAsync(output.data(), 0, bytes, stream); + } + if (input_elements == 0 || output_elements == 0) return; + + DimSizeArray input_dims = {}; + DimSizeArray stride_dims = {}; + for (int i = 0; i < N; ++i) { + input_dims.value[i] = static_cast(input.dimension(i)); + stride_dims.value[i] = static_cast(strides[i]); + } + + MusaInflateKernelLauncher(stream, static_cast(input_elements), + input.data(), output.data(), input_dims, + stride_dims, N, output_count); + } +}; + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_fill_functor.h b/musa_ext/kernels/musa_fill_functor.h new file mode 100644 index 0000000..a3095d2 --- /dev/null +++ b/musa_ext/kernels/musa_fill_functor.h @@ -0,0 +1,56 @@ +/* + Be advised: + + This file is implemented in aim to support the einsum operator. + For now it only contains the SetZeroFunctor, which is used to set the output + tensor to zero before accumulating the results of the einsum computation. 
+ +*/ + +#include + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +template +Status MusaFillCall(mTensor* out_mt, T value, OpKernelContext* context) { + mFill op; + mHandle& h = GetHandleByCtx(context); + + if (std::is_integral::value) { + if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { + return errors::Internal("mtdnn set value (int) error!"); + } + } else if (std::is_floating_point::value || + std::is_same::value || + std::is_same::value) { + if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { + return errors::Internal("mtdnn set value (float) error!"); + } + } else { + return errors::Unimplemented("Data type not supported in MTGPU Fill."); + } + + if (mStatus::SUCCESS != op.Run(h, *out_mt)) { + return errors::Internal("mtdnn run op error!"); + } + + return Status::OK(); +} + +struct SetZeroFunctor { + // Computes on device "d": out = out.setZero(), + template + static Status Compute(OpKernelContext* ctx, mTensor* out_mt) { + return MusaFillCall(out_mt, T(0), ctx); + } +}; + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_fill_op.cc b/musa_ext/kernels/musa_fill_op.cc deleted file mode 100755 index d085cc6..0000000 --- a/musa_ext/kernels/musa_fill_op.cc +++ /dev/null @@ -1,120 +0,0 @@ -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "utils_op.h" - -namespace tensorflow { -namespace musa { - -namespace { - -template -struct is_any : std::false_type {}; - -template -struct is_any : std::is_same {}; - -template -struct is_any - : 
std::integral_constant::value || - is_any::value> {}; - -template -Status MusaFillCall(Tensor* out, T value, OpKernelContext* context) { - mFill op; - mHandle& h = GetHandleByCtx(context); - auto out_mt = CreateMTensor(*out); - - if (is_any::value) { - if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { - return errors::Internal("mtdnn set value (int) error!"); - } - } else if (is_any::value) { - if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { - return errors::Internal("mtdnn set value (float) error!"); - } - } else { - return errors::Unimplemented("Data type not supported in MTGPU Fill."); - } - - if (mStatus::SUCCESS != op.Run(h, out_mt)) { - return errors::Internal("mtdnn run op error!"); - } - - return Status::OK(); -} - -} // namespace - -template -class MusaFillOp : public MusaOpKernel { - public: - explicit MusaFillOp(OpKernelConstruction* context) : MusaOpKernel(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor& Tdims = context->input(0); - const Tensor& Tvalue = context->input(1); - - OP_REQUIRES( - context, - (TensorShapeUtils::IsVector(Tdims.shape()) || - TensorShapeUtils::IsScalar(Tdims.shape())), - errors::InvalidArgument("dims must represent a vector, got shape ", - Tdims.shape().DebugString())); - - OP_REQUIRES( - context, - TensorShapeUtils::IsScalar(Tvalue.shape()) || - (TensorShapeUtils::IsVector(Tvalue.shape()) && - Tvalue.shape().dim_size(0) == 1), - errors::InvalidArgument("value must represent a scalar, got shape ", - Tvalue.shape().DebugString())); - - auto dims_vec = Tdims.flat(); - TensorShape shape; - OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - reinterpret_cast(dims_vec.data()), - dims_vec.size(), &shape)); - - Tensor* out = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out)); - - if (shape.num_elements() == 0) return; - - OP_REQUIRES_OK( - context, MusaFillCall(out, static_cast(Tvalue.data())[0], context)); - } -}; - -#define REGISTER_FILL_KERNEL(type) \ 
- REGISTER_KERNEL_BUILDER(Name("Fill") \ - .Device("MUSA") \ - .TypeConstraint("T") \ - .TypeConstraint("index_type") \ - .HostMemory("dims") \ - .HostMemory("value"), \ - MusaFillOp); \ - REGISTER_KERNEL_BUILDER(Name("Fill") \ - .Device("MUSA") \ - .TypeConstraint("T") \ - .TypeConstraint("index_type") \ - .HostMemory("dims") \ - .HostMemory("value"), \ - MusaFillOp); - -REGISTER_FILL_KERNEL(float); -REGISTER_FILL_KERNEL(double); -REGISTER_FILL_KERNEL(int32); -REGISTER_FILL_KERNEL(int64); -REGISTER_FILL_KERNEL(Eigen::half); -REGISTER_FILL_KERNEL(Eigen::bfloat16); -REGISTER_FILL_KERNEL(bool); - -#undef REGISTER_FILL_KERNEL - -} // namespace musa -} // namespace tensorflow diff --git a/musa_ext/kernels/musa_reduce_functor.h b/musa_ext/kernels/musa_reduce_functor.h new file mode 100644 index 0000000..ef9902d --- /dev/null +++ b/musa_ext/kernels/musa_reduce_functor.h @@ -0,0 +1,74 @@ +#ifndef MUSA_PLUGIN_SRC_KERNELS_MUSA_REDUCE_FUNCTOR_H_ +#define MUSA_PLUGIN_SRC_KERNELS_MUSA_REDUCE_FUNCTOR_H_ + +#include +#include + +#include "musa_cast_functor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +struct ReduceFunctor { + template + static Status Compute(OpKernelContext* ctx, mTensor* output, mTensor* input, + ::musa::dnn::Reduce::Mode mode, const int* reduce_dims, + int reduce_dim_count, const char* error_prefix) { + auto& handle = GetHandleByCtx(ctx); + + mReduce op; + op.SetMode(mode); + op.SetDim(reduce_dim_count, reduce_dims); + + tensorflow::Allocator* tf_allocator = + ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + auto alloc_func = + [tf_allocator]( + size_t size) -> std::unique_ptr> { + void* ptr = tf_allocator->AllocateRaw(256, size); + std::function deleter = [tf_allocator](void* p) { + if (p) tf_allocator->DeallocateRaw(p); + }; + return std::unique_ptr>(ptr, deleter); + }; + 
::musa::dnn::MemoryMaintainer mm(alloc_func); + + auto status = op.Run(handle, *output, *input, mm); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal(error_prefix, static_cast(status)); + } + return Status::OK(); + } +}; + +// Given the fact that bf16 does not work well with reduction, we compute the +// reduction in fp32 and cast the result back to bf16. +// This conversion is aligned with tensorflow's convention of promoting bf16 to +// fp32 for ReduceFunctor. +template <> +Status ReduceFunctor::Compute(OpKernelContext* ctx, + mTensor* output_mt, mTensor* input_mt, + ::musa::dnn::Reduce::Mode mode, + const int* reduce_dims, + int reduce_dim_count, + const char* error_prefix) { + mTensor input_fp32; + TF_RETURN_IF_ERROR(CastTensor(ctx, *input_mt, DT_FLOAT, &input_fp32)); + + mTensor output_fp32; + TF_RETURN_IF_ERROR(Compute(ctx, &output_fp32, &input_fp32, mode, + reduce_dims, reduce_dim_count, + error_prefix)); + + return CastTensor(ctx, output_fp32, DataTypeToEnum::value, + output_mt); +} + +} // namespace musa +} // namespace tensorflow + +#endif // MUSA_PLUGIN_SRC_KERNELS_MUSA_REDUCE_FUNCTOR_H_ \ No newline at end of file diff --git a/musa_ext/kernels/musa_stride_inflate_kernel.h b/musa_ext/kernels/musa_stride_inflate_kernel.h new file mode 100644 index 0000000..32c2caa --- /dev/null +++ b/musa_ext/kernels/musa_stride_inflate_kernel.h @@ -0,0 +1,31 @@ +#ifndef MUSA_PLUGIN_SRC_KERNELS_MUSA_STRIDE_INFLATE_KERNEL_H_ +#define MUSA_PLUGIN_SRC_KERNELS_MUSA_STRIDE_INFLATE_KERNEL_H_ + +#include + +#include + +namespace tensorflow { +namespace musa { + +constexpr int kMaxStrideInflateDims = 8; + +struct DimSizeArray { + int64_t value[kMaxStrideInflateDims]; +}; + +template +void MusaStrideKernelLauncher(musaStream_t stream, int64_t size, const T* in, + T* out, DimSizeArray dims, DimSizeArray strides, + int ndims); + +template +void MusaInflateKernelLauncher(musaStream_t stream, int64_t input_size, + const T* in, T* out, DimSizeArray input_dims, 
+ DimSizeArray strides, int ndims, + int64_t output_size); + +} // namespace musa +} // namespace tensorflow + +#endif // MUSA_PLUGIN_SRC_KERNELS_MUSA_STRIDE_INFLATE_KERNEL_H_ diff --git a/musa_ext/kernels/musa_stride_inflate_kernel.mu b/musa_ext/kernels/musa_stride_inflate_kernel.mu new file mode 100644 index 0000000..0f464a1 --- /dev/null +++ b/musa_ext/kernels/musa_stride_inflate_kernel.mu @@ -0,0 +1,114 @@ +#include + +#include "tensorflow/core/framework/bfloat16.h" + +#include "musa_stride_inflate_kernel.h" + +namespace tensorflow { +namespace musa { + +namespace { + +template +__global__ void MusaStrideKernel(int64_t size, const T* __restrict__ in, + T* __restrict__ out, DimSizeArray dims, + DimSizeArray strides, int ndims) { + const int64_t block_stride = static_cast(blockDim.x) * gridDim.x; + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + while (idx < size) { + int64_t coords[kMaxStrideInflateDims] = {0}; + int64_t tmp = idx; + for (int dim = ndims - 1; dim >= 0; --dim) { + const int64_t dim_size = dims.value[dim]; + if (dim_size > 0) { + coords[dim] = tmp % dim_size; + tmp /= dim_size; + } + } + int64_t in_index = 0; + for (int dim = 0; dim < ndims; ++dim) { + in_index += coords[dim] * strides.value[dim]; + } + out[idx] = in[in_index]; + idx += block_stride; + } +} + +template +__global__ void MusaInflateKernel(int64_t size, const T* __restrict__ in, + T* __restrict__ out, DimSizeArray in_dims, + DimSizeArray strides, int ndims, + int64_t out_size) { + const int64_t block_stride = static_cast(blockDim.x) * gridDim.x; + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + while (idx < size) { + int64_t coords[kMaxStrideInflateDims] = {0}; + int64_t tmp = idx; + for (int dim = ndims - 1; dim >= 0; --dim) { + const int64_t dim_size = in_dims.value[dim]; + if (dim_size > 0) { + coords[dim] = tmp % dim_size; + tmp /= dim_size; + } + } + int64_t out_index = 0; + for (int dim = 0; dim < ndims; ++dim) { + out_index += 
coords[dim] * strides.value[dim]; + } + if (out_index >= 0 && out_index < out_size) { + out[out_index] = in[idx]; + } + idx += block_stride; + } +} + +} // namespace + +template +void MusaStrideKernelLauncher(musaStream_t stream, int64_t size, const T* in, + T* out, DimSizeArray dims, + DimSizeArray strides, int ndims) { + if (size == 0) { + return; + } + constexpr int block_size = 256; + const int grid_size = static_cast( + (size + block_size - 1) / block_size); + MusaStrideKernel<<>>( + size, in, out, dims, strides, ndims); +} + +template +void MusaInflateKernelLauncher(musaStream_t stream, int64_t input_size, + const T* in, T* out, DimSizeArray in_dims, + DimSizeArray strides, int ndims, + int64_t out_size) { + if (input_size == 0) { + return; + } + constexpr int block_size = 256; + const int grid_size = static_cast( + (input_size + block_size - 1) / block_size); + MusaInflateKernel<<>>( + input_size, in, out, in_dims, strides, ndims, out_size); +} + +#define INSTANTIATE_STRIDE_INFLATE(T) \ + template void MusaStrideKernelLauncher(musaStream_t, int64_t, const T*, \ + T*, DimSizeArray, DimSizeArray, \ + int); \ + template void MusaInflateKernelLauncher( \ + musaStream_t, int64_t, const T*, T*, DimSizeArray, DimSizeArray, int, \ + int64_t) + +INSTANTIATE_STRIDE_INFLATE(float); +INSTANTIATE_STRIDE_INFLATE(double); +INSTANTIATE_STRIDE_INFLATE(int32); +INSTANTIATE_STRIDE_INFLATE(int64); +INSTANTIATE_STRIDE_INFLATE(Eigen::half); +INSTANTIATE_STRIDE_INFLATE(Eigen::bfloat16); + +#undef INSTANTIATE_STRIDE_INFLATE + +} // namespace musa +} // namespace tensorflow diff --git a/musa_ext/kernels/musa_transpose_functor.h b/musa_ext/kernels/musa_transpose_functor.h new file mode 100644 index 0000000..950169d --- /dev/null +++ b/musa_ext/kernels/musa_transpose_functor.h @@ -0,0 +1,30 @@ +#include + +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +struct TransposeFunctor { + static Status Compute(OpKernelContext* ctx, mTensor& in_mt, + const std::vector& 
permutation, + mTensor& out_mt) { + mHandle& h = GetHandleByCtx(ctx); + + ::musa::dnn::Permute pop; + + if (::musa::dnn::Status::SUCCESS != + pop.ConfigDimStride(out_mt, in_mt, static_cast(permutation.size()), + permutation.data())) { + return errors::Internal("muDNN Permute ConfigDimStride failed!"); + } + + if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { + return errors::Internal("muDNN Permute Run failed!"); + } + return Status::OK(); + } +}; + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_transpose_op.cc b/musa_ext/kernels/musa_transpose_op.cc old mode 100755 new mode 100644 index 52d1d4f..b4a0c9c --- a/musa_ext/kernels/musa_transpose_op.cc +++ b/musa_ext/kernels/musa_transpose_op.cc @@ -2,6 +2,7 @@ #include +#include "musa_transpose_functor.h" #include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -75,25 +76,12 @@ class MusaTransposeOp : public MusaOpKernel { if (output->NumElements() == 0) return; - mHandle& h = GetHandleByCtx(ctx); - - mTensor in_mt = CreateMTensor(input, format_); - mTensor out_mt = CreateMTensor(*output, format_); - - ::musa::dnn::Permute pop; - - if (::musa::dnn::Status::SUCCESS != - pop.ConfigDimStride(out_mt, in_mt, - static_cast(permutation_64.size()), - permutation_64.data())) { - ctx->CtxFailure( - errors::Internal("muDNN Permute ConfigDimStride failed!")); - return; - } - - if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { - ctx->CtxFailure(errors::Internal("muDNN Permute Run failed!")); - return; + mTensor input_mt = CreateMTensor(input); + mTensor output_mt = CreateMTensor(*output); + Status status = + TransposeFunctor::Compute(ctx, input_mt, permutation_64, output_mt); + if (!status.ok()) { + ctx->CtxFailure(status); } } }; diff --git a/musa_ext/utils/musa_einsum_op_util.h b/musa_ext/utils/musa_einsum_op_util.h new file mode 100644 index 
0000000..01116ba --- /dev/null +++ b/musa_ext/utils/musa_einsum_op_util.h @@ -0,0 +1,153 @@ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace musa { + +using Labels = absl::InlinedVector; +using OperandLabels = absl::InlinedVector; +using LabelCounts = absl::InlinedVector; +using OperandLabelCounts = absl::InlinedVector; + +// Dummy axis label used to denote an ellipsis in an input or output subscript. +constexpr int kEllipsisLabel = -1; + +enum EinsumDimensionType { + // Batch dimensions are those present in two inputs as well as the output. + // They are part of the batch dimensions during Tensor contraction. Such + // dimensions may be broadcasting dimensions (those mapping to ellipsis) + // or explicit batch dimensions corresponding to named axis labels. + kBroadcasting = 0, + kBatch = 1, + // Free dimensions are present in exactly one of the inputs, and also the + // output. These are non-contracted axes in the Tensor contraction. + kFree = 2, + // Contract dimensions are present in two inputs, but not the output. These + // dimensions are contracted in Tensor contraction. + kContract = 3, + // Reduce dimensions are present in exactly one input; and not in the output + // and are summed over prior to Tensor contraction. + kReduce = 4, +}; + +// Returns the EinsumDimensionType given whether the corresponding label is +// present in exactly one input subscript (is_unique) and whether it is absent +// from the output subscripts (is_removed). Does not handle broadcasting +// dimensions. 
+inline EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) { + if (!is_removed && !is_unique) + return kBatch; + else if (!is_removed && is_unique) + return kFree; + else if (is_removed && !is_unique) + return kContract; + else // is_removed && is_unique + return kReduce; +} + +inline Status ValidateEinsumEquation( + const std::string& equation, + absl::InlinedVector* input_subscripts, + std::string* output_subscript) { + absl::InlinedVector inputs_and_output_subscripts = + absl::StrSplit(equation, "->"); + if (inputs_and_output_subscripts.size() != 2) { + return errors::InvalidArgument( + "Expecting exactly one '->' in einsum equation: ", equation); + } + *output_subscript = std::move(inputs_and_output_subscripts[1]); + *input_subscripts = + absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ','); + if (input_subscripts->size() != 1 && input_subscripts->size() != 2) { + return errors::InvalidArgument( + "Expecting 1 or 2 input subscripts in equation '", equation, + "' but got: ", input_subscripts->size()); + } + return Status::OK(); +} + +// Maps the character labels to consecutive integers. +inline void MapToLabels(const std::string& subscript, Labels* labels, + absl::flat_hash_map* label_mapping) { + for (int i = 0; i < subscript.size(); ++i) { + const char label_char = subscript[i]; + if (label_char == '.') { + labels->push_back(kEllipsisLabel); + i += 2; // Skip next 2 characters as well. 
+ continue; + } + if (!label_mapping->contains(label_char)) { + const int next_label = label_mapping->size(); + (*label_mapping)[label_char] = next_label; + } + const int mapped_label = (*label_mapping)[label_char]; + labels->push_back(mapped_label); + } +} + +inline Status ParseEinsumEquation( + const std::string& equation, OperandLabels* input_labels, + Labels* output_labels, std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + absl::InlinedVector* input_has_ellipsis, + bool* output_has_ellipsis) { + absl::InlinedVector input_str; + std::string output_str; + TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str)); + + // Temporary map from single character labels to (consecutive) integer labels. + absl::flat_hash_map label_mapping; + int num_inputs = input_str.size(); + input_labels->resize(num_inputs); + + // Map from single characters to integer labels. + for (int i = 0; i < num_inputs; ++i) { + MapToLabels(input_str[i], &input_labels->at(i), &label_mapping); + } + MapToLabels(output_str, output_labels, &label_mapping); + + // Compute counts for input and output labels. + int num_labels = label_mapping.size(); + input_label_counts->resize(num_inputs); + input_has_ellipsis->resize(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + input_label_counts->at(i).resize(num_labels); + input_has_ellipsis->at(i) = false; + for (const int label : input_labels->at(i)) { + if (label != kEllipsisLabel) + input_label_counts->at(i)[label] += 1; + else + input_has_ellipsis->at(i) = true; + } + } + output_label_counts->resize(num_labels); + *output_has_ellipsis = false; + for (const int label : *output_labels) { + if (label != kEllipsisLabel) + output_label_counts->at(label) += 1; + else + *output_has_ellipsis = true; + } + + // Map each label to a unique EinsumDimensionType. 
+ label_types->resize(num_labels); + for (int label = 0; label < num_labels; ++label) { + if (label == kEllipsisLabel) continue; + bool removed = (*output_label_counts)[label] == 0; + bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 || + (*input_label_counts)[1][label] == 0; + (*label_types)[label] = GetDimensionType(removed, unique); + } + return Status::OK(); +} + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/test/einsum_op_test.py b/test/einsum_op_test.py new file mode 100644 index 0000000..779a062 --- /dev/null +++ b/test/einsum_op_test.py @@ -0,0 +1,121 @@ +"""Tests for MUSA Einsum operator.""" + +import numpy as np +import tensorflow as tf + +from musa_test_utils import MUSATestCase + + +class EinsumOpTest(MUSATestCase): + """Tests for MUSA Einsum operator with TensorFlow-compatible behavior.""" + + def _run_cpu_musa(self, equation, inputs, rtol=1e-5, atol=1e-8): + """Run einsum on CPU and MUSA then compare outputs.""" + with tf.device('/CPU:0'): + cpu_result = tf.einsum(equation, *inputs) + + with tf.device('/device:MUSA:0'): + musa_result = tf.einsum(equation, *inputs) + + if cpu_result.dtype in [tf.float16, tf.bfloat16]: + cpu_result = tf.cast(cpu_result, tf.float32) + musa_result = tf.cast(musa_result, tf.float32) + + self.assertAllEqual(cpu_result.shape, musa_result.shape) + self.assertAllClose(cpu_result.numpy(), musa_result.numpy(), rtol=rtol, atol=atol) + + def _assert_error_consistency(self, equation, inputs): + """Assert CPU and MUSA both raise TensorFlow exceptions for invalid inputs.""" + cpu_error = None + musa_error = None + + try: + with tf.device('/CPU:0'): + tf.einsum(equation, *inputs) + except Exception as e: # pylint: disable=broad-except + cpu_error = e + + try: + with tf.device('/device:MUSA:0'): + tf.einsum(equation, *inputs) + except Exception as e: # pylint: disable=broad-except + musa_error = e + + self.assertIsNotNone(cpu_error, "CPU should raise for invalid einsum 
input") + self.assertIsNotNone(musa_error, "MUSA should raise for invalid einsum input") + self.assertEqual(type(cpu_error), type(musa_error)) + + def testEinsumMatrixMultiplication(self): + """Test classic matrix multiplication: ij,jk->ik.""" + for dtype in [tf.float32, tf.float16, tf.bfloat16]: + a_np = np.random.uniform(-1.0, 1.0, size=(8, 16)).astype(np.float32) + b_np = np.random.uniform(-1.0, 1.0, size=(16, 10)).astype(np.float32) + a = tf.constant(a_np, dtype=dtype) + b = tf.constant(b_np, dtype=dtype) + + rtol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-6 + self._run_cpu_musa('ij,jk->ik', [a, b], rtol=rtol, atol=atol) + + def testEinsumBatchMatMul(self): + """Test batched matmul: bij,bjk->bik.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 6, 8)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 8, 5)).astype(np.float32)) + self._run_cpu_musa('bij,bjk->bik', [a, b]) + + def testEinsumBroadcastWithEllipsis(self): + """Test ellipsis broadcasting: ...ij,...jk->...ik.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(2, 3, 4, 6)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(1, 3, 6, 5)).astype(np.float32)) + self._run_cpu_musa('...ij,...jk->...ik', [a, b]) + + def testEinsumImplicitOutput(self): + """Test implicit output mode: ij,jk (without ->).""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(7, 11)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(11, 9)).astype(np.float32)) + self._run_cpu_musa('ij,jk', [a, b]) + + def testEinsumTranspose(self): + """Test axis permutation with single input: ijk->kji.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(3, 5, 7)).astype(np.float32)) + self._run_cpu_musa('ijk->kji', [x]) + + def testEinsumDiagonalExtraction(self): + """Test repeated indices for diagonal extraction: ii->i.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(9, 
9)).astype(np.float32)) + self._run_cpu_musa('ii->i', [x]) + + def testEinsumReduction(self): + """Test reduction to scalar: ij->.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(13, 17)).astype(np.float32)) + self._run_cpu_musa('ij->', [x]) + + def testEinsumOuterProduct(self): + """Test outer product: i,j->ij.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(12,)).astype(np.float32)) + y = tf.constant(np.random.uniform(-1.0, 1.0, size=(8,)).astype(np.float32)) + self._run_cpu_musa('i,j->ij', [x, y]) + + def testEinsumThreeOperands(self): + """Test multi-operand contraction: ab,bc,cd->ad.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 6)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(6, 5)).astype(np.float32)) + c = tf.constant(np.random.uniform(-1.0, 1.0, size=(5, 3)).astype(np.float32)) + self._run_cpu_musa('ab,bc,cd->ad', [a, b, c]) + + def testEinsumInvalidEquation(self): + """Invalid equation should raise consistently on CPU and MUSA.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(2, 2)).astype(np.float32)) + self._assert_error_consistency('ij->ik', [x]) + + def testEinsumMismatchedDimensions(self): + """Dimension mismatch should raise consistently on CPU and MUSA.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 5)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(6, 3)).astype(np.float32)) + self._assert_error_consistency('ij,jk->ik', [a, b]) + + +if __name__ == "__main__": + np.random.seed(2026) + tf.random.set_seed(2026) + tf.test.main()