From bd61e586ad4f0412dd5e07c5e41f4c3156a6542e Mon Sep 17 00:00:00 2001 From: albert Date: Wed, 25 Feb 2026 08:59:36 +0000 Subject: [PATCH 01/16] init einsum with tensorflow impl --- musa_ext/kernels/musa_einsum_op.cc | 645 +++++++++++++++++++++++++++ musa_ext/kernels/musa_einsum_op.h | 29 ++ musa_ext/kernels/musa_fill_functor.h | 24 + musa_ext/utils/musa_einsum_op_util.h | 154 +++++++ 4 files changed, 852 insertions(+) create mode 100644 musa_ext/kernels/musa_einsum_op.cc create mode 100644 musa_ext/kernels/musa_einsum_op.h create mode 100644 musa_ext/kernels/musa_fill_functor.h create mode 100644 musa_ext/utils/musa_einsum_op_util.h diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc new file mode 100644 index 0000000..3c3490e --- /dev/null +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -0,0 +1,645 @@ +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ + +// #define EIGEN_USE_THREADS +// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// #define EIGEN_USE_GPU +// #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "musa_einsum_op.h" + +#include "../utils/musa_einsum_op_util.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "musa_fill_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/matmul_bcast.h" +#include "utils_op.h" + +// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// #include "tensorflow/core/kernels/reduction_ops_common_gpu.h" +// #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { +namespace musa { + +using ShapeVec = gtl::InlinedVector; +using Labels = gtl::InlinedVector; +using OperandLabels = gtl::InlinedVector; +using 
LabelCounts = gtl::InlinedVector; +using OperandLabelCounts = gtl::InlinedVector; +using LabelToDimSizes = gtl::InlinedVector; + +struct EinsumHelper { + // Insert new (unnamed) broadcasting labels at the location of ellipsis. + static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels, + int ellipsis_axis, Labels* labels, + LabelCounts* label_counts) { + labels->erase(labels->begin() + ellipsis_axis); + labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0); + std::iota(labels->begin() + ellipsis_axis, + labels->begin() + ellipsis_axis + num_bcast_dims, + num_named_labels); + // Increment label counts. Since these are new labels, the count is set + // to 1. + label_counts->resize(num_named_labels + num_bcast_dims, 1); + } + + // Record and validate the label to dimension mapping. Must be a named + // (non-broadcasting) label as broadcasting labels don't have a fixed + // dimension. + static Status RecordLabelToDimension(const int label, const int axis, + const Tensor& input, + LabelToDimSizes* label_to_dim_sizes) { + const int64_t input_dim = input.dim_size(axis); + // We know that label_to_dim_sizes has the size to accommodate named labels. + if (label_to_dim_sizes->at(label) != 0 && + label_to_dim_sizes->at(label) != input_dim) { + return errors::InvalidArgument( + "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", + axis, " of the input shaped ", input.shape().DebugString(), + " but got dimension ", input_dim); + } + (*label_to_dim_sizes)[label] = input_dim; + return Status::OK(); + } + + // Validate input dimensions and populate unnamed labels and their label + // counts. 
+ static Status ProcessDimensions( + const OpInputList& inputs, + const gtl::InlinedVector& input_has_ellipsis, + const bool output_has_ellipsis, OperandLabels* input_labels, + Labels* output_labels, std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + LabelToDimSizes* label_to_dim_sizes) { + if (inputs.size() != input_labels->size()) { + return errors::InvalidArgument("Expected ", input_labels->size(), + " inputs but got: ", inputs.size()); + } + const int num_inputs = inputs.size(); + + // We infer the number of broadcasting dimensions by taking the maximum + // rank among the broadcasting subshapes of the input. + int max_bcast_dims = 0; + const int num_named_labels = label_types->size(); + label_to_dim_sizes->resize(num_named_labels); + for (int i = 0; i < num_inputs; ++i) { + Labels* labels = &(*input_labels)[i]; + + if (!input_has_ellipsis[i]) { + if (inputs[i].dims() != labels->size()) { + return errors::InvalidArgument("Expected input ", i, " to have rank ", + labels->size(), + " but got: ", inputs[i].dims()); + } + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i], + label_to_dim_sizes)); + } + continue; + } + + // Input has an ellipsis. + if (inputs[i].dims() + 1 < labels->size()) { + return errors::InvalidArgument( + "Expected input ", i, " to have rank at least ", labels->size() - 1, + " but got: ", inputs[i].dims()); + } + int ellipsis_axis = -1; + const int num_bcast_dims = inputs[i].dims() - labels->size() + 1; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + if (label == kEllipsisLabel) { + ellipsis_axis = label_idx; + continue; + } + // Current label is not an ellipsis. + const int axis = + label_idx + (ellipsis_axis == -1 ? 
0 : num_bcast_dims - 1); + TF_RETURN_IF_ERROR( + RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes)); + } + // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting + // dimensions. + if (ellipsis_axis != -1) { + InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, + labels, &input_label_counts->at(i)); + max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims); + } + } + if (!absl::c_linear_search(input_has_ellipsis, true) && + !output_has_ellipsis) { + return Status::OK(); + } + // Insert broadcasting dimensions in the output labels. + auto it = + std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel); + if (it != output_labels->end()) { + const int ellipsis_axis = it - output_labels->begin(); + InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, + output_labels, output_label_counts); + } else if (max_bcast_dims > 0) { + return errors::InvalidArgument( + "Output contains ", max_bcast_dims, + " broadcasting dimension(s) but no ellipsis " + "(...) was found in the output subscripts."); + } + // Populate EinsumDimensionType for the new broadcasting labels. + label_types->resize(num_named_labels + max_bcast_dims, + EinsumDimensionType::kBroadcasting); + return Status::OK(); + } + + // Permutes the labels according to the given permutation. + static void PermuteLabels(const std::vector& permutation, + Labels* labels) { + Labels permuted_labels(labels->size()); + for (int i = 0; i < labels->size(); ++i) { + permuted_labels[i] = (*labels)[permutation[i]]; + } + labels->swap(permuted_labels); + } + + // Returns a reshaped input Tensor. The underlying buffer is not copied. 
+ static Status CopyFrom(const Tensor& input, const TensorShape& shape, + Tensor* output) { + if (output->CopyFrom(input, shape)) return Status::OK(); + return errors::Internal( + "Encountered error while reshaping a Tensor of shape ", + input.shape().DebugString(), " to shape ", shape.DebugString()); + } + + // Returns whether transposing would be a no-op; whether input has rank < 2 or + // the permutation is the identity permutation. + static bool ShouldTranspose(const TensorShape& input_shape, + const std::vector& permutation) { + if (input_shape.dims() < 2) return false; + for (int i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) return true; + } + return false; + } + + // Transpose the input given a permutation. Returns a reference to the input + // if transposing is not necessary. + template + static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, + const std::vector& permutation, + Tensor* output) { + if (!ShouldTranspose(input.shape(), permutation)) { + return CopyFrom(input, input.shape(), output); + } + TensorShape transposed_shape; + for (int i = 0; i < input.dims(); ++i) { + TF_RETURN_IF_ERROR( + transposed_shape.AddDimWithStatus(input.dim_size(permutation[i]))); + } + // For empty Tensors, just change the shape. E.g. we may need to transpose + // from shape [1, 0, 5] to [5, 1, 0]. + if (input.NumElements() == 0) { + return CopyFrom(input, transposed_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); + const Device& device = ctx->eigen_device(); + TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); + return Status::OK(); + } + + // If there are repeated labels in either the input or output, then this + // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. 
+ template + static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, + const Labels& labels, + const LabelCounts& label_counts, + const bool should_inflate, Tensor* output) { + // Return early if there are no repeated indices. + if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { + return CopyFrom(input, input.shape(), output); + } + // We reshape so that each repeated label is compressed to one dimension. + // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27, + // 5]. Striding appropriately (in this case with strides 14 (=1+3+9) and 1) + // recovers the generalized diagonal of shape [3, 5]. + ShapeVec reshape; + ShapeVec strides; + // Strided and inflated shapes correspond to input and output shapes, + // respectively, should_inflate is true (vice-versa if should_inflate is + // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example. + ShapeVec strided_shape; + ShapeVec inflated_shape; + for (int label : labels) { + const int count = label_counts[label]; + const int current_axis = + should_inflate ? strided_shape.size() : inflated_shape.size(); + const int64_t dim = input.dim_size(current_axis); + strided_shape.push_back(dim); + inflated_shape.insert(inflated_shape.end(), count, dim); + const int64_t reshape_dim = MathUtil::IPow(dim, count); + reshape.push_back(reshape_dim); + // While taking the d-diagonal in a rank k Tensor, we take d + // equally-spaced elements including the first and last element. Then, (k + // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1). + const int64_t stride = + (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1; + strides.push_back(stride); + } + + TensorShape output_shape = + TensorShape(should_inflate ? 
inflated_shape : strided_shape); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + const Device& device = ctx->eigen_device(); + switch (reshape.size()) { +#define NDIMS_CASE(N) \ + case N: { \ + if (should_inflate) { \ + auto output_map = output->shaped(reshape); \ + auto input_map = input.shaped(strided_shape); \ + InflateFunctor()(device, input_map, \ + TensorShape(strides).AsEigenDSizes(), \ + output_map); \ + } else { \ + auto input_map = input.shaped(reshape); \ + auto output_map = output->shaped(strided_shape); \ + StrideFunctor()(device, input_map, \ + TensorShape(strides).AsEigenDSizes(), \ + output_map); \ + } \ + } break; + NDIMS_CASE(1); + NDIMS_CASE(2); + NDIMS_CASE(3); + NDIMS_CASE(4); + NDIMS_CASE(5); + NDIMS_CASE(6); + default: + return errors::Unimplemented( + "Unsupported rank: ", reshape.size(), + " while handling repeated indices. Up to rank 6 is supported."); +#undef NDIMS_CASE + } + return Status::OK(); + } + + // Returns true if the input dimensions are already sorted in the order + // [batch, contract, free, reduce]. Used to implement an optimization to avoid + // an extra transpose and instead uses (adj_x and adj_y) in BatchMatMul. + static bool ShouldSwapFreeAndContract( + const Labels& labels, + const std::vector& label_types) { + // Check that ordering is according to dimension type, with the role of + // free and contract dimensions swapped. 
+ gtl::InlinedVector remap = {0, 1, 3, 2, 4}; + for (int i = 0; i + 1 < labels.size(); ++i) { + const int dimtype_a = remap[label_types[labels[i]]]; + const int dimtype_b = remap[label_types[labels[i + 1]]]; + if (dimtype_a > dimtype_b || + (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) { + return false; + } + } + return true; + } + + template + static Status ReduceOperand( + OpKernelContext* ctx, const Tensor& input, + const std::vector& label_types, + const LabelCounts& label_counts, Labels* labels, Labels* free_labels, + bool* swap_free_and_contract, Tensor* output) { + // Find the permutation to transpose the input dimensions in the order of + // EinsumDimensionType; i.e. batch, free, contract and reduce dimensions. + // This makes it more convenient to invoke Reduce/Contract operations. + std::vector permutation(input.dims()); + absl::c_iota(permutation, 0); + Tensor input_transposed; + // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) + // flag during BatchMatMul. This is an extra optimization not necessary for + // correctness. + if (ShouldSwapFreeAndContract(*labels, label_types)) { + *swap_free_and_contract = true; + } else { + absl::c_sort(permutation, [&](int i, int j) { + int label_i = (*labels)[i]; + int label_j = (*labels)[j]; + return std::tie(label_types[label_i], label_i) < + std::tie(label_types[label_j], label_j); + }); + } + // Transpose the input so that EinsumDimensionTypes are in order. + TF_RETURN_IF_ERROR(TransposeOperand(ctx, input, permutation, + &input_transposed)); + PermuteLabels(permutation, labels); + + // Take the generalized diagonal for dimensions with repeated axis labels. 
+ Tensor input_deduped; + labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); + TF_RETURN_IF_ERROR( + StrideOrInflate(ctx, input_transposed, *labels, label_counts, + false /* should_inflate */, &input_deduped)); + + // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, + // reduce] where we've compacted the dimensions of each EinsumDimensionType. + gtl::InlinedVector reshape(5, 1); + // The output shape is [batch shape] + [free size, contract size] + // That is, the batch shape is preserved (for broadcasting while + // contracting) while the free dims and contract dims are compressed to one + // dimension each. + TensorShape output_shape; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = labels->at(label_idx); + int64_t dim = input_deduped.dim_size(label_idx); + if (label_types[label] == EinsumDimensionType::kBroadcasting || + label_types[label] == EinsumDimensionType::kBatch) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } else if (label_types[label] == EinsumDimensionType::kFree) { + free_labels->push_back(label); + } + reshape[label_types[label]] *= dim; + } + if (*swap_free_and_contract) + std::swap(reshape[EinsumDimensionType::kFree], + reshape[EinsumDimensionType::kContract]); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kFree])); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kContract])); + + if (reshape[EinsumDimensionType::kReduce] == + 1) { // No need to actually reduce. + return CopyFrom(input_deduped, output_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + using Reducer = Eigen::internal::SumReducer; + using Index = typename TTypes::Tensor::Index; + // Reduce along the last axis (i.e axis 1) of the rank-2 Tensor. 
+ const int64_t output_size = reshape[kBroadcasting] * reshape[kBatch] * + reshape[kFree] * reshape[kContract]; + // functor::ReduceFunctor::Reduce( + // ctx, output->shaped({output_size}), + // const_cast(input_deduped) + // .shaped({output_size, reshape[kReduce]}), + // Eigen::array({1}), Reducer()); + return Status::OK(); + } + + // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. + static Status ReshapeToRank3(const Tensor& input, int batch_size, + Tensor* output) { + const int rank = input.dims(); + TensorShape output_shape = {batch_size, input.dim_size(rank - 2), + input.dim_size(rank - 1)}; + return CopyFrom(input, output_shape, output); + } + + // Contracts the inputs along the last axis (or the second last if the + // corresponding value of swap_free_and_contract is true). The batch + // dimensions are broadcast to the output shape. + template + static Status ContractOperands(OpKernelContext* ctx, + absl::Span inputs, + absl::Span swap_free_and_contract, + Tensor* output) { + if (inputs.size() == 1) + return CopyFrom(inputs[0], inputs[0].shape(), output); + MatMulBCast bcast(inputs[0].shape().dim_sizes(), + inputs[1].shape().dim_sizes()); + if (!bcast.IsValid()) { + return errors::InvalidArgument( + "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), + " vs. ", inputs[1].shape().DebugString()); + } + Tensor lhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs)); + Tensor rhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs)); + TensorShape output_shape = bcast.output_batch_shape(); + for (int i = 0; i < inputs.size(); ++i) { + const int64_t free_axis = + inputs[i].dims() - (swap_free_and_contract[i] ? 
1 : 2); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(inputs[i].dim_size(free_axis))); + } + bool trans_x = swap_free_and_contract[0]; + bool trans_y = !swap_free_and_contract[1]; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { + SetZeroFunctor set_zero; + set_zero(ctx->eigen_device(), output->flat()); + return Status::OK(); + } + Tensor output_reshaped; + TF_RETURN_IF_ERROR( + ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); + + // -------- THIS SHOULD BE REPLACED BY MUSA BATCHMATMUL FUNCTOR -------- + // LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, + // /*adj_y=*/false, trans_x, trans_y, + // /*grad_x=*/false, /*grad_y=*/false, + // bcast, &output_reshaped); + // ---------------------------------------------------------------------- + return Status::OK(); + } +}; + +template +class MusaEinsumOp : public MusaOpKernel { + public: + explicit MusaEinsumOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("equation", &equation_)); + OP_REQUIRES_OK( + ctx, ParseEinsumEquation(equation_, &input_labels_, &output_labels_, + &label_types_, &input_label_counts_, + &output_label_counts_, &input_has_ellipsis_, + &output_has_ellipsis_)); + } + + void Compute(OpKernelContext* ctx) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); + + OperandLabels input_labels(input_labels_); + Labels output_labels(output_labels_); + std::vector label_types(label_types_); + OperandLabelCounts input_label_counts(input_label_counts_); + LabelCounts output_label_counts(output_label_counts_); + LabelToDimSizes label_to_dim_sizes; + + OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions( + inputs, input_has_ellipsis_, output_has_ellipsis_, + &input_labels, &output_labels, &label_types, + &input_label_counts, &output_label_counts, + &label_to_dim_sizes)); + + // The reduction phase (a) 
sums across reduction dimensions, (b) takes + // generalized diagonals, and (c) reshapes it into shape + // [(broadcasting) batch shape] + [F,C] + // where F and C denote the total (compacted) size of free and contract + // dimensions, respectively. + const int num_inputs = inputs.size(); + OperandLabels free_labels(num_inputs); + gtl::InlinedVector inputs_reduced(num_inputs); + gtl::InlinedVector swap_free_and_contract(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + OP_REQUIRES_OK(ctx, + EinsumHelper::ReduceOperand( + ctx, inputs[i], label_types, input_label_counts[i], + &input_labels[i], &free_labels[i], + &swap_free_and_contract[i], &inputs_reduced[i])); + } + + // After reduction, the inputs should be reshaped to Tensors suitable for + // contraction. If num_inputs is 1, the reduced input is simply forwarded to + // the output. + Tensor contraction_output_reshaped; + OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( + ctx, inputs_reduced, swap_free_and_contract, + &contraction_output_reshaped)); + + // Copy the batch labels from the contraction output. Recover the batch + // shape, which may have been broadcasted. + TensorShape result_shape = contraction_output_reshaped.shape(); + result_shape.RemoveLastDims(2); + + int num_labels = label_types.size(); + Labels result_labels; + // All batch dimensions should be present in the contracted result. First + // the broadcasting dimensions, then the named batch dimensions. 
+ for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBroadcasting) + result_labels.push_back(label); + } + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBatch) + result_labels.push_back(label); + } + for (int i = 0; i < num_inputs; ++i) { + for (int label : free_labels[i]) { + result_labels.push_back(label); + OP_REQUIRES_OK( + ctx, result_shape.AddDimWithStatus(label_to_dim_sizes[label])); + } + } + + // Reshape the contraction (or reduction) result to its expanded shape: + // [(broadcasted) batch shape] + [free shape 0] + [free shape 1]. + Tensor contraction_output; + OP_REQUIRES_OK( + ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape, + &contraction_output)); + + // Inflate the output if necessary. (E.g. for the equation 'i->iii' which + // may arise while computing gradient of a regular Einsum). + Tensor output_inflated; + OP_REQUIRES_OK( + ctx, EinsumHelper::StrideOrInflate( + ctx, contraction_output, result_labels, output_label_counts, + true /* should_inflate */, &output_inflated)); + if (output_inflated.dims() > contraction_output.dims()) { + // We inflated the output. Modify result labels accordingly. + Labels inflated_labels; + for (int label : result_labels) { + inflated_labels.insert(inflated_labels.end(), + output_label_counts[label], label); + } + result_labels.swap(inflated_labels); + } + // Find the permutation to map the result labels to the output labels. Note + // that both the result and the final output may have the repeated labels, + // in which case the permutation preserves the left-to-right ordering. + // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the + // permutation should be [0, 2, 1]. We also use the fact that repeated + // labels in the result are adjacent to each other. 
+ std::vector output_permutation(output_labels.size()); + std::vector label_to_position(num_labels, -1); + for (int i = 0; i < result_labels.size(); ++i) { + // Remember the position of only the leftmost result label. + if (label_to_position[result_labels[i]] == -1) { + label_to_position[result_labels[i]] = i; + } + } + for (int i = 0; i < output_labels.size(); ++i) { + output_permutation[i] = label_to_position[output_labels[i]]; + // We have found the leftmost occurrence. The next one would be adjacent. + label_to_position[output_labels[i]] += 1; + } + Tensor output; + OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( + ctx, output_inflated, output_permutation, &output)); + ctx->set_output(0, output); + } + + string TraceString(const OpKernelContext& ctx, bool verbose) const override { + string op = profiler::TraceMeOp(name_view(), type_string_view()); + string equation = strings::StrCat("(", equation_, ")"); + if (verbose) { + string shape = ShapeTraceString(ctx); + if (!shape.empty()) { + return profiler::TraceMeEncode( + std::move(op), {{"equation", equation}, {"shape", shape}}); + } + } + return profiler::TraceMeEncode(std::move(op), {{"equation", equation}}); + } + + private: + string equation_; + OperandLabels input_labels_; + Labels output_labels_; + std::vector label_types_; + OperandLabelCounts input_label_counts_; + LabelCounts output_label_counts_; + gtl::InlinedVector input_has_ellipsis_; + bool output_has_ellipsis_ = false; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// Forward declarations of the functor specializations for GPU. 
+namespace functor { +#define DECLARE_GPU_SPEC(T, N) \ + template <> \ + void StrideFunctor::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + typename TTypes::Tensor output); \ + extern template struct StrideFunctor; \ + template <> \ + void InflateFunctor::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + typename TTypes::Tensor output); \ + extern template struct InflateFunctor; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); \ + DECLARE_GPU_SPEC(T, 6); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +// TODO(rocm): Enable once complex types are supported. +#if GOOGLE_CUDA +DECLARE_GPU_SPECS(complex64); +DECLARE_GPU_SPECS(complex128); +#endif +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPECS +} // namespace functor +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace musa +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ \ No newline at end of file diff --git a/musa_ext/kernels/musa_einsum_op.h b/musa_ext/kernels/musa_einsum_op.h new file mode 100644 index 0000000..8b3f380 --- /dev/null +++ b/musa_ext/kernels/musa_einsum_op.h @@ -0,0 +1,29 @@ +#include "utils_op.h" + +// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// #include "tensorflow/core/kernels/reduction_ops_common_gpu.h" +// #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { +namespace musa { + +template +struct StrideFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + output.device(d) = input.stride(strides); + } +}; + +template +struct InflateFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + 
output.device(d) = input.inflate(strides); + } +}; + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_fill_functor.h b/musa_ext/kernels/musa_fill_functor.h new file mode 100644 index 0000000..662e50f --- /dev/null +++ b/musa_ext/kernels/musa_fill_functor.h @@ -0,0 +1,24 @@ +/* + Be advised: + + This file is implemented in aim to support the einsum operator. + For now it only contains the SetZeroFunctor, which is used to set the output + tensor to zero before accumulating the results of the einsum computation. + +*/ + +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace musa { + +template +struct SetZeroFunctor { + // Computes on device "d": out = out.setZero(), + void operator()(const Device& d, typename TTypes::Flat out) { + out.device(d) = out.constant(T(0)); + } +}; + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/utils/musa_einsum_op_util.h b/musa_ext/utils/musa_einsum_op_util.h new file mode 100644 index 0000000..ec536d8 --- /dev/null +++ b/musa_ext/utils/musa_einsum_op_util.h @@ -0,0 +1,154 @@ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace musa { + +using Labels = absl::InlinedVector; +using OperandLabels = absl::InlinedVector; +using LabelCounts = absl::InlinedVector; +using OperandLabelCounts = absl::InlinedVector; + +// Dummy axis label used to denote an ellipsis in an input or output subscript. +constexpr int kEllipsisLabel = -1; + +enum EinsumDimensionType { + // Batch dimensions are those present in two inputs as well as the output. + // They are part of the batch dimensions during Tensor contraction. 
Such + // dimensions may be broadcasting dimensions (those mapping to ellipsis) + // or explicit batch dimensions corresponding to named axis labels. + kBroadcasting = 0, + kBatch = 1, + // Free dimensions are present in exactly one of the inputs, and also the + // output. These are non-contracted axes in the Tensor contraction. + kFree = 2, + // Contract dimensions are present in two inputs, but not the output. These + // dimensions are contracted in Tensor contraction. + kContract = 3, + // Reduce dimensions are present in exactly one input; and not in the output + // and are summed over prior to Tensor contraction. + kReduce = 4, +}; + +// Returns the EinsumDimensionType given whether the corresponding label is +// present in exactly one input subscript (is_unique) and whether it is absent +// from the output subscripts (is_removed). Does not handle broadcasting +// dimensions. +EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) { + if (!is_removed && !is_unique) + return kBatch; + else if (!is_removed && is_unique) + return kFree; + else if (is_removed && !is_unique) + return kContract; + else // is_removed && is_unique + return kReduce; +} + +Status ValidateEinsumEquation( + const std::string& equation, + absl::InlinedVector* input_subscripts, + std::string* output_subscript) { + absl::InlinedVector inputs_and_output_subscripts = + absl::StrSplit(equation, "->"); + if (inputs_and_output_subscripts.size() != 2) { + return errors::InvalidArgument( + "Expecting exactly one '->' in einsum equation: ", equation); + } + *output_subscript = std::move(inputs_and_output_subscripts[1]); + *input_subscripts = + absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ','); + if (input_subscripts->size() != 1 && input_subscripts->size() != 2) { + return errors::InvalidArgument( + "Expecting 1 or 2 input subscripts in equation '", equation, + "' but got: ", input_subscripts->size()); + } + return Status::OK(); +} + +// Maps the character labels to 
consecutive integers. +void MapToLabels(const std::string& subscript, Labels* labels, + absl::flat_hash_map* label_mapping) { + for (int i = 0; i < subscript.size(); ++i) { + const char label_char = subscript[i]; + if (label_char == '.') { + labels->push_back(kEllipsisLabel); + i += 2; // Skip next 2 characters as well. + continue; + } + if (!label_mapping->contains(label_char)) { + const int next_label = label_mapping->size(); + (*label_mapping)[label_char] = next_label; + } + const int mapped_label = (*label_mapping)[label_char]; + labels->push_back(mapped_label); + } +} + +Status ParseEinsumEquation(const std::string& equation, + OperandLabels* input_labels, Labels* output_labels, + std::vector* label_types, + OperandLabelCounts* input_label_counts, + LabelCounts* output_label_counts, + absl::InlinedVector* input_has_ellipsis, + bool* output_has_ellipsis) { + absl::InlinedVector input_str; + std::string output_str; + TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str)); + + // Temporary map from single character labels to (consecutive) integer labels. + absl::flat_hash_map label_mapping; + int num_inputs = input_str.size(); + input_labels->resize(num_inputs); + + // Map from single characters to integer labels. + for (int i = 0; i < num_inputs; ++i) { + MapToLabels(input_str[i], &input_labels->at(i), &label_mapping); + } + MapToLabels(output_str, output_labels, &label_mapping); + + // Compute counts for input and output labels. 
+ int num_labels = label_mapping.size(); + input_label_counts->resize(num_inputs); + input_has_ellipsis->resize(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + input_label_counts->at(i).resize(num_labels); + input_has_ellipsis->at(i) = false; + for (const int label : input_labels->at(i)) { + if (label != kEllipsisLabel) + input_label_counts->at(i)[label] += 1; + else + input_has_ellipsis->at(i) = true; + } + } + output_label_counts->resize(num_labels); + *output_has_ellipsis = false; + for (const int label : *output_labels) { + if (label != kEllipsisLabel) + output_label_counts->at(label) += 1; + else + *output_has_ellipsis = true; + } + + // Map each label to a unique EinsumDimensionType. + label_types->resize(num_labels); + for (int label = 0; label < num_labels; ++label) { + if (label == kEllipsisLabel) continue; + bool removed = (*output_label_counts)[label] == 0; + bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 || + (*input_label_counts)[1][label] == 0; + (*label_types)[label] = GetDimensionType(removed, unique); + } + return Status::OK(); +} + +} // namespace musa +} // namespace tensorflow \ No newline at end of file From 6861a60830b070a47cf6af6a81e4e6527cb0b133 Mon Sep 17 00:00:00 2001 From: albert Date: Wed, 25 Feb 2026 10:48:03 +0000 Subject: [PATCH 02/16] init test scripts & operator supporting musa --- musa_ext/kernels/musa_einsum_op.cc | 134 ++++++++++++++--------------- test/einsum_op_test.py | 88 +++++++++++++++++++ 2 files changed, 155 insertions(+), 67 deletions(-) create mode 100644 test/einsum_op_test.py diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 3c3490e..64a1771 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -1,13 +1,8 @@ -#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ -#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ - -// #define EIGEN_USE_THREADS -// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// #define 
EIGEN_USE_GPU -// #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - #include "musa_einsum_op.h" +#include +#include + #include "../utils/musa_einsum_op_util.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" @@ -21,10 +16,6 @@ #include "tensorflow/core/util/matmul_bcast.h" #include "utils_op.h" -// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// #include "tensorflow/core/kernels/reduction_ops_common_gpu.h" -// #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - namespace tensorflow { namespace musa { @@ -373,22 +364,62 @@ struct EinsumHelper { TF_RETURN_IF_ERROR( output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kContract])); - if (reshape[EinsumDimensionType::kReduce] == - 1) { // No need to actually reduce. + if (reshape[EinsumDimensionType::kReduce] == 1) { // No need to reduce. return CopyFrom(input_deduped, output_shape, output); } TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); - using Reducer = Eigen::internal::SumReducer; - using Index = typename TTypes::Tensor::Index; - // Reduce along the last axis (i.e axis 1) of the rank-2 Tensor. 
+ const int64_t reduce_size = reshape[kReduce]; const int64_t output_size = reshape[kBroadcasting] * reshape[kBatch] * reshape[kFree] * reshape[kContract]; - // functor::ReduceFunctor::Reduce( - // ctx, output->shaped({output_size}), - // const_cast(input_deduped) - // .shaped({output_size, reshape[kReduce]}), - // Eigen::array({1}), Reducer()); + + TensorShape input_flatten_shape; + TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(output_size)); + TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(reduce_size)); + Tensor input_flattened; + if (!input_flattened + .BitcastFrom(input_deduped, input_deduped.dtype(), + input_flatten_shape) + .ok()) { + return errors::Internal("Failed to reshape Einsum input for reduce"); + } + + TensorShape output_flatten_shape; + TF_RETURN_IF_ERROR(output_flatten_shape.AddDimWithStatus(output_size)); + Tensor output_flattened; + if (!output_flattened + .BitcastFrom(*output, output->dtype(), output_flatten_shape) + .ok()) { + return errors::Internal("Failed to reshape Einsum output for reduce"); + } + + auto input_mt = CreateMTensor(input_flattened); + auto output_mt = CreateMTensor(output_flattened); + + auto& handle = GetHandleByCtx(ctx); + mReduce op; + op.SetMode(::musa::dnn::Reduce::Mode::ADD); + int reduce_dims[] = {1}; + op.SetDim(1, reduce_dims); + + tensorflow::Allocator* tf_allocator = + ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + auto alloc_func = + [tf_allocator]( + size_t size) -> std::unique_ptr> { + void* ptr = tf_allocator->AllocateRaw(256, size); + std::function deleter = [tf_allocator](void* p) { + if (p) tf_allocator->DeallocateRaw(p); + }; + return std::unique_ptr>(ptr, deleter); + }; + ::musa::dnn::MemoryMaintainer mm(alloc_func); + + auto status = op.Run(handle, output_mt, input_mt, mm); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA Reduce (sum) execution failed. 
Status: ", + static_cast(status)); + } return Status::OK(); } @@ -442,12 +473,19 @@ struct EinsumHelper { TF_RETURN_IF_ERROR( ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); - // -------- THIS SHOULD BE REPLACED BY MUSA BATCHMATMUL FUNCTOR -------- - // LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, - // /*adj_y=*/false, trans_x, trans_y, - // /*grad_x=*/false, /*grad_y=*/false, - // bcast, &output_reshaped); - // ---------------------------------------------------------------------- + auto& handle = GetHandleByCtx(ctx); + mBatchMatMul op; + op.SetTranspose(trans_x, trans_y); + op.SetAlpha(1.0); + op.SetBeta(0.0); + auto lhs_mt = CreateMTensor(lhs); + auto rhs_mt = CreateMTensor(rhs); + auto out_mt = CreateMTensor(output_reshaped); + auto status = op.Run(handle, out_mt, lhs_mt, rhs_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA BatchMatMul execution failed. Status: ", + static_cast(status)); + } return Status::OK(); } }; @@ -601,45 +639,7 @@ class MusaEinsumOp : public MusaOpKernel { LabelCounts output_label_counts_; gtl::InlinedVector input_has_ellipsis_; bool output_has_ellipsis_ = false; -}; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// Forward declarations of the functor specializations for GPU. 
-namespace functor { -#define DECLARE_GPU_SPEC(T, N) \ - template <> \ - void StrideFunctor::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - const Eigen::DSizes& strides, \ - typename TTypes::Tensor output); \ - extern template struct StrideFunctor; \ - template <> \ - void InflateFunctor::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - const Eigen::DSizes& strides, \ - typename TTypes::Tensor output); \ - extern template struct InflateFunctor; - -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPEC(T, 1); \ - DECLARE_GPU_SPEC(T, 2); \ - DECLARE_GPU_SPEC(T, 3); \ - DECLARE_GPU_SPEC(T, 4); \ - DECLARE_GPU_SPEC(T, 5); \ - DECLARE_GPU_SPEC(T, 6); - -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); -// TODO(rocm): Enable once complex types are supported. -#if GOOGLE_CUDA -DECLARE_GPU_SPECS(complex64); -DECLARE_GPU_SPECS(complex128); -#endif -#undef DECLARE_GPU_SPEC -#undef DECLARE_GPU_SPECS -} // namespace functor -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +}; // class MusaEinsumOp } // namespace musa } // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ \ No newline at end of file diff --git a/test/einsum_op_test.py b/test/einsum_op_test.py new file mode 100644 index 0000000..60805e8 --- /dev/null +++ b/test/einsum_op_test.py @@ -0,0 +1,88 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for the MUSA Einsum operator.""" + +import numpy as np +import tensorflow as tf + +from musa_test_utils import MUSATestCase + + +class EinsumOpTest(MUSATestCase): + """Tests for the MUSA Einsum operator.""" + + def _random_inputs(self, shapes, dtype): + """Generate random inputs for the requested shapes and dtype.""" + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + return [ + tf.constant( + np.random.uniform(-1.0, 1.0, size=shape).astype(np_dtype), + dtype=dtype) + for shape in shapes + ] + + def _test_einsum(self, equation, shapes, dtype, rtol=1e-5, atol=1e-8): + """Compare CPU vs MUSA for the given einsum equation.""" + inputs = self._random_inputs(shapes, dtype) + op = lambda *tensors: tf.einsum(equation, *tensors) + self._compare_cpu_musa_results(op, inputs, dtype, rtol=rtol, atol=atol) + + def testMatrixMultiplication(self): + """Matrix multiplication with explicit contraction indices.""" + equation = "ij,jk->ik" + shapes = [(128, 64), (64, 96)] + for dtype in [tf.float32, tf.float16, tf.bfloat16]: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) + + def testBatchBroadcastContraction(self): + """Batch contraction with broadcasting over leading dims.""" + equation = "bij,jk->bik" + shapes = [(4, 16, 32), (32, 64)] + for dtype in [tf.float32, tf.float16, tf.bfloat16]: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) + + def testDiagonalAndBroadcast(self): + """Repeated indices that take diagonals and broadcast shapes.""" + equation = "iij,ij->ij" + shapes = [(4, 4, 6), (4, 6)] + for dtype in [tf.float32, tf.float16, tf.bfloat16]: + rtol = 1e-2 if dtype in 
[tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) + + def testEllipsisBroadcast(self): + """Ellipsis handling with mixed-rank operands.""" + equation = "...i,i->..." + shapes = [(2, 3, 5), (5,)] + for dtype in [tf.float32, tf.float16, tf.bfloat16]: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) + + def testMultipleSummations(self): + """Multiple contraction indices with more than two inputs.""" + equation = "abc,acd,db->bd" + shapes = [(3, 4, 5), (4, 5, 6), (6, 3)] + for dtype in [tf.float32]: + self._test_einsum(equation, shapes, dtype) + + +if __name__ == "__main__": + tf.test.main() From 6b487239afa84b12ea6d78da48457c1b1477f31c Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 26 Feb 2026 03:33:47 +0000 Subject: [PATCH 03/16] update fill operator, which is dependied by einsum --- musa_ext/kernels/musa_einsum_op.cc | 102 ++++++++++------ musa_ext/kernels/musa_einsum_op.h | 53 ++++++-- musa_ext/kernels/musa_fill_functor.h | 9 +- musa_ext/kernels/musa_fill_op.cc | 4 +- musa_ext/kernels/musa_stride_inflate_kernel.h | 31 +++++ .../kernels/musa_stride_inflate_kernel.mu | 114 ++++++++++++++++++ 6 files changed, 263 insertions(+), 50 deletions(-) create mode 100644 musa_ext/kernels/musa_stride_inflate_kernel.h create mode 100644 musa_ext/kernels/musa_stride_inflate_kernel.mu diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 64a1771..29894c3 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -3,12 +3,14 @@ #include #include +#include "../mu/device/musa_device.h" #include "../utils/musa_einsum_op_util.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" #include "musa_fill_functor.h" #include 
"tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/types.h" @@ -180,7 +182,7 @@ struct EinsumHelper { // Transpose the input given a permutation. Returns a reference to the input // if transposing is not necessary. - template + template static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, const std::vector& permutation, Tensor* output) { @@ -199,14 +201,16 @@ struct EinsumHelper { } TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); - const Device& device = ctx->eigen_device(); - TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); + // ------- TODO: Replace with valid MUSA implementation ------- + // const Device& device = ctx->eigen_device(); + // TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); + // ------------------------------------------------------------ return Status::OK(); } // If there are repeated labels in either the input or output, then this // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. - template + template static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, const Labels& labels, const LabelCounts& label_counts, @@ -243,27 +247,41 @@ struct EinsumHelper { strides.push_back(stride); } - TensorShape output_shape = - TensorShape(should_inflate ? inflated_shape : strided_shape); + const ShapeVec& output_shape_dims = + should_inflate ? 
inflated_shape : strided_shape; + TensorShape output_shape; + for (int64_t dim : output_shape_dims) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } + auto to_int64 = [](const ShapeVec& dims) { + gtl::InlinedVector converted; + converted.reserve(dims.size()); + for (int64_t dim : dims) { + converted.push_back(static_cast(dim)); + } + return converted; + }; + const auto reshape_int64 = to_int64(reshape); + const auto strided_int64 = to_int64(strided_shape); + const auto strides_int64 = to_int64(strides); + const gtl::ArraySlice reshape_slice(reshape_int64); + const gtl::ArraySlice strided_slice(strided_int64); + const TensorShape strides_shape{gtl::ArraySlice(strides_int64)}; TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); - const Device& device = ctx->eigen_device(); switch (reshape.size()) { -#define NDIMS_CASE(N) \ - case N: { \ - if (should_inflate) { \ - auto output_map = output->shaped(reshape); \ - auto input_map = input.shaped(strided_shape); \ - InflateFunctor()(device, input_map, \ - TensorShape(strides).AsEigenDSizes(), \ - output_map); \ - } else { \ - auto input_map = input.shaped(reshape); \ - auto output_map = output->shaped(strided_shape); \ - StrideFunctor()(device, input_map, \ - TensorShape(strides).AsEigenDSizes(), \ - output_map); \ - } \ +#define NDIMS_CASE(N) \ + case N: { \ + const auto strides_dsizes = strides_shape.AsEigenDSizes(); \ + if (should_inflate) { \ + auto output_map = output->shaped(reshape_slice); \ + auto input_map = input.shaped(strided_slice); \ + InflateFunctor()(ctx, input_map, strides_dsizes, output_map); \ + } else { \ + auto input_map = input.shaped(reshape_slice); \ + auto output_map = output->shaped(strided_slice); \ + StrideFunctor()(ctx, input_map, strides_dsizes, output_map); \ + } \ } break; NDIMS_CASE(1); NDIMS_CASE(2); @@ -300,7 +318,7 @@ struct EinsumHelper { return true; } - template + template static Status ReduceOperand( OpKernelContext* ctx, const Tensor& 
input, const std::vector& label_types, @@ -326,16 +344,16 @@ struct EinsumHelper { }); } // Transpose the input so that EinsumDimensionTypes are in order. - TF_RETURN_IF_ERROR(TransposeOperand(ctx, input, permutation, - &input_transposed)); + TF_RETURN_IF_ERROR( + TransposeOperand(ctx, input, permutation, &input_transposed)); PermuteLabels(permutation, labels); // Take the generalized diagonal for dimensions with repeated axis labels. Tensor input_deduped; labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); TF_RETURN_IF_ERROR( - StrideOrInflate(ctx, input_transposed, *labels, label_counts, - false /* should_inflate */, &input_deduped)); + StrideOrInflate(ctx, input_transposed, *labels, label_counts, + false /* should_inflate */, &input_deduped)); // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, // reduce] where we've compacted the dimensions of each EinsumDimensionType. @@ -402,8 +420,10 @@ struct EinsumHelper { int reduce_dims[] = {1}; op.SetDim(1, reduce_dims); + // ------- TODO: Not sure if this would work in MUSA env ------- tensorflow::Allocator* tf_allocator = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + // ------------------------------------------------------------- auto alloc_func = [tf_allocator]( size_t size) -> std::unique_ptr> { @@ -435,7 +455,7 @@ struct EinsumHelper { // Contracts the inputs along the last axis (or the second last if the // corresponding value of swap_free_and_contract is true). The batch // dimensions are broadcast to the output shape. 
- template + template static Status ContractOperands(OpKernelContext* ctx, absl::Span inputs, absl::Span swap_free_and_contract, @@ -465,8 +485,8 @@ struct EinsumHelper { TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { - SetZeroFunctor set_zero; - set_zero(ctx->eigen_device(), output->flat()); + SetZeroFunctor set_zero; + set_zero(ctx, output); return Status::OK(); } Tensor output_reshaped; @@ -490,7 +510,7 @@ struct EinsumHelper { } }; -template +template class MusaEinsumOp : public MusaOpKernel { public: explicit MusaEinsumOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { @@ -530,7 +550,7 @@ class MusaEinsumOp : public MusaOpKernel { gtl::InlinedVector swap_free_and_contract(num_inputs); for (int i = 0; i < num_inputs; ++i) { OP_REQUIRES_OK(ctx, - EinsumHelper::ReduceOperand( + EinsumHelper::ReduceOperand( ctx, inputs[i], label_types, input_label_counts[i], &input_labels[i], &free_labels[i], &swap_free_and_contract[i], &inputs_reduced[i])); @@ -540,7 +560,7 @@ class MusaEinsumOp : public MusaOpKernel { // contraction. If num_inputs is 1, the reduced input is simply forwarded to // the output. Tensor contraction_output_reshaped; - OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( + OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( ctx, inputs_reduced, swap_free_and_contract, &contraction_output_reshaped)); @@ -580,7 +600,7 @@ class MusaEinsumOp : public MusaOpKernel { // may arise while computing gradient of a regular Einsum). 
Tensor output_inflated; OP_REQUIRES_OK( - ctx, EinsumHelper::StrideOrInflate( + ctx, EinsumHelper::StrideOrInflate( ctx, contraction_output, result_labels, output_label_counts, true /* should_inflate */, &output_inflated)); if (output_inflated.dims() > contraction_output.dims()) { @@ -612,7 +632,7 @@ class MusaEinsumOp : public MusaOpKernel { label_to_position[output_labels[i]] += 1; } Tensor output; - OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( + OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( ctx, output_inflated, output_permutation, &output)); ctx->set_output(0, output); } @@ -641,5 +661,17 @@ class MusaEinsumOp : public MusaOpKernel { bool output_has_ellipsis_ = false; }; // class MusaEinsumOp +#define REGISTER_MUSA_EINSUM(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Einsum").Device("MUSA").TypeConstraint("T"), \ + MusaEinsumOp); + +REGISTER_MUSA_EINSUM(float); +REGISTER_MUSA_EINSUM(double); +REGISTER_MUSA_EINSUM(int32); +REGISTER_MUSA_EINSUM(int64); +REGISTER_MUSA_EINSUM(Eigen::half); +REGISTER_MUSA_EINSUM(bfloat16); + } // namespace musa } // namespace tensorflow diff --git a/musa_ext/kernels/musa_einsum_op.h b/musa_ext/kernels/musa_einsum_op.h index 8b3f380..3ff546b 100644 --- a/musa_ext/kernels/musa_einsum_op.h +++ b/musa_ext/kernels/musa_einsum_op.h @@ -1,27 +1,60 @@ +#include "musa_stride_inflate_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" #include "utils_op.h" -// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// #include "tensorflow/core/kernels/reduction_ops_common_gpu.h" -// #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - namespace tensorflow { namespace musa { -template +template struct StrideFunctor { - void operator()(const Device& d, typename TTypes::ConstTensor input, + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor input, const Eigen::DSizes& strides, typename TTypes::Tensor output) { - output.device(d) = input.stride(strides); + const Eigen::DenseIndex num_elements = output.size(); + if (num_elements 
== 0) return; + + DimSizeArray dims = {}; + DimSizeArray stride_dims = {}; + for (int i = 0; i < N; ++i) { + dims.value[i] = static_cast(output.dimension(i)); + stride_dims.value[i] = static_cast(strides[i]); + } + + auto stream = GetMusaStreamByCtx(ctx); + MusaStrideKernelLauncher(stream, static_cast(num_elements), + input.data(), output.data(), dims, stride_dims, + N); } }; -template +template struct InflateFunctor { - void operator()(const Device& d, typename TTypes::ConstTensor input, + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor input, const Eigen::DSizes& strides, typename TTypes::Tensor output) { - output.device(d) = input.inflate(strides); + const Eigen::DenseIndex input_elements = input.size(); + const Eigen::DenseIndex output_elements = output.size(); + const int64_t output_count = static_cast(output_elements); + + auto stream = GetMusaStreamByCtx(ctx); + if (output_count > 0) { + const uint64_t bytes = output_count * sizeof(T); + musaMemsetAsync(output.data(), 0, bytes, stream); + } + if (input_elements == 0 || output_elements == 0) return; + + DimSizeArray input_dims = {}; + DimSizeArray stride_dims = {}; + for (int i = 0; i < N; ++i) { + input_dims.value[i] = static_cast(input.dimension(i)); + stride_dims.value[i] = static_cast(strides[i]); + } + + MusaInflateKernelLauncher(stream, static_cast(input_elements), + input.data(), output.data(), input_dims, + stride_dims, N, output_count); } }; diff --git a/musa_ext/kernels/musa_fill_functor.h b/musa_ext/kernels/musa_fill_functor.h index 662e50f..23072dc 100644 --- a/musa_ext/kernels/musa_fill_functor.h +++ b/musa_ext/kernels/musa_fill_functor.h @@ -12,11 +12,14 @@ namespace tensorflow { namespace musa { -template +template +Status MusaFillCall(Tensor* out, T value, OpKernelContext* context); + +template struct SetZeroFunctor { // Computes on device "d": out = out.setZero(), - void operator()(const Device& d, typename TTypes::Flat out) { - out.device(d) = out.constant(T(0)); + void 
operator()(OpKernelContext* ctx, Tensor* out) { + MusaFillCall(out, T(0), ctx); } }; diff --git a/musa_ext/kernels/musa_fill_op.cc b/musa_ext/kernels/musa_fill_op.cc index d085cc6..86ed3e0 100755 --- a/musa_ext/kernels/musa_fill_op.cc +++ b/musa_ext/kernels/musa_fill_op.cc @@ -22,6 +22,8 @@ struct is_any : std::integral_constant::value || is_any::value> {}; +} // namespace + template Status MusaFillCall(Tensor* out, T value, OpKernelContext* context) { mFill op; @@ -48,8 +50,6 @@ Status MusaFillCall(Tensor* out, T value, OpKernelContext* context) { return Status::OK(); } -} // namespace - template class MusaFillOp : public MusaOpKernel { public: diff --git a/musa_ext/kernels/musa_stride_inflate_kernel.h b/musa_ext/kernels/musa_stride_inflate_kernel.h new file mode 100644 index 0000000..32c2caa --- /dev/null +++ b/musa_ext/kernels/musa_stride_inflate_kernel.h @@ -0,0 +1,31 @@ +#ifndef MUSA_PLUGIN_SRC_KERNELS_MUSA_STRIDE_INFLATE_KERNEL_H_ +#define MUSA_PLUGIN_SRC_KERNELS_MUSA_STRIDE_INFLATE_KERNEL_H_ + +#include + +#include + +namespace tensorflow { +namespace musa { + +constexpr int kMaxStrideInflateDims = 8; + +struct DimSizeArray { + int64_t value[kMaxStrideInflateDims]; +}; + +template +void MusaStrideKernelLauncher(musaStream_t stream, int64_t size, const T* in, + T* out, DimSizeArray dims, DimSizeArray strides, + int ndims); + +template +void MusaInflateKernelLauncher(musaStream_t stream, int64_t input_size, + const T* in, T* out, DimSizeArray input_dims, + DimSizeArray strides, int ndims, + int64_t output_size); + +} // namespace musa +} // namespace tensorflow + +#endif // MUSA_PLUGIN_SRC_KERNELS_MUSA_STRIDE_INFLATE_KERNEL_H_ diff --git a/musa_ext/kernels/musa_stride_inflate_kernel.mu b/musa_ext/kernels/musa_stride_inflate_kernel.mu new file mode 100644 index 0000000..0f464a1 --- /dev/null +++ b/musa_ext/kernels/musa_stride_inflate_kernel.mu @@ -0,0 +1,114 @@ +#include + +#include "tensorflow/core/framework/bfloat16.h" + +#include 
"musa_stride_inflate_kernel.h" + +namespace tensorflow { +namespace musa { + +namespace { + +template +__global__ void MusaStrideKernel(int64_t size, const T* __restrict__ in, + T* __restrict__ out, DimSizeArray dims, + DimSizeArray strides, int ndims) { + const int64_t block_stride = static_cast(blockDim.x) * gridDim.x; + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + while (idx < size) { + int64_t coords[kMaxStrideInflateDims] = {0}; + int64_t tmp = idx; + for (int dim = ndims - 1; dim >= 0; --dim) { + const int64_t dim_size = dims.value[dim]; + if (dim_size > 0) { + coords[dim] = tmp % dim_size; + tmp /= dim_size; + } + } + int64_t in_index = 0; + for (int dim = 0; dim < ndims; ++dim) { + in_index += coords[dim] * strides.value[dim]; + } + out[idx] = in[in_index]; + idx += block_stride; + } +} + +template +__global__ void MusaInflateKernel(int64_t size, const T* __restrict__ in, + T* __restrict__ out, DimSizeArray in_dims, + DimSizeArray strides, int ndims, + int64_t out_size) { + const int64_t block_stride = static_cast(blockDim.x) * gridDim.x; + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + while (idx < size) { + int64_t coords[kMaxStrideInflateDims] = {0}; + int64_t tmp = idx; + for (int dim = ndims - 1; dim >= 0; --dim) { + const int64_t dim_size = in_dims.value[dim]; + if (dim_size > 0) { + coords[dim] = tmp % dim_size; + tmp /= dim_size; + } + } + int64_t out_index = 0; + for (int dim = 0; dim < ndims; ++dim) { + out_index += coords[dim] * strides.value[dim]; + } + if (out_index >= 0 && out_index < out_size) { + out[out_index] = in[idx]; + } + idx += block_stride; + } +} + +} // namespace + +template +void MusaStrideKernelLauncher(musaStream_t stream, int64_t size, const T* in, + T* out, DimSizeArray dims, + DimSizeArray strides, int ndims) { + if (size == 0) { + return; + } + constexpr int block_size = 256; + const int grid_size = static_cast( + (size + block_size - 1) / block_size); + MusaStrideKernel<<>>( + 
size, in, out, dims, strides, ndims); +} + +template +void MusaInflateKernelLauncher(musaStream_t stream, int64_t input_size, + const T* in, T* out, DimSizeArray in_dims, + DimSizeArray strides, int ndims, + int64_t out_size) { + if (input_size == 0) { + return; + } + constexpr int block_size = 256; + const int grid_size = static_cast( + (input_size + block_size - 1) / block_size); + MusaInflateKernel<<>>( + input_size, in, out, in_dims, strides, ndims, out_size); +} + +#define INSTANTIATE_STRIDE_INFLATE(T) \ + template void MusaStrideKernelLauncher(musaStream_t, int64_t, const T*, \ + T*, DimSizeArray, DimSizeArray, \ + int); \ + template void MusaInflateKernelLauncher( \ + musaStream_t, int64_t, const T*, T*, DimSizeArray, DimSizeArray, int, \ + int64_t) + +INSTANTIATE_STRIDE_INFLATE(float); +INSTANTIATE_STRIDE_INFLATE(double); +INSTANTIATE_STRIDE_INFLATE(int32); +INSTANTIATE_STRIDE_INFLATE(int64); +INSTANTIATE_STRIDE_INFLATE(Eigen::half); +INSTANTIATE_STRIDE_INFLATE(Eigen::bfloat16); + +#undef INSTANTIATE_STRIDE_INFLATE + +} // namespace musa +} // namespace tensorflow From f00e103659e57a73f3aa628ae71e11439439a879 Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 26 Feb 2026 05:45:56 +0000 Subject: [PATCH 04/16] update transpose functor --- musa_ext/kernels/musa_einsum_op.cc | 16 ++++----- musa_ext/kernels/musa_transpose_functor.h | 8 +++++ musa_ext/kernels/musa_transpose_op.cc | 44 ++++++++++++----------- 3 files changed, 39 insertions(+), 29 deletions(-) create mode 100644 musa_ext/kernels/musa_transpose_functor.h mode change 100755 => 100644 musa_ext/kernels/musa_transpose_op.cc diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 29894c3..dea0f02 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -8,6 +8,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" #include "musa_fill_functor.h" +#include "musa_transpose_functor.h" #include 
"tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -151,7 +152,7 @@ struct EinsumHelper { } // Permutes the labels according to the given permutation. - static void PermuteLabels(const std::vector& permutation, + static void PermuteLabels(const std::vector& permutation, Labels* labels) { Labels permuted_labels(labels->size()); for (int i = 0; i < labels->size(); ++i) { @@ -172,7 +173,7 @@ struct EinsumHelper { // Returns whether transposing would be a no-op; whether input has rank < 2 or // the permutation is the identity permutation. static bool ShouldTranspose(const TensorShape& input_shape, - const std::vector& permutation) { + const std::vector& permutation) { if (input_shape.dims() < 2) return false; for (int i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) return true; @@ -184,7 +185,7 @@ struct EinsumHelper { // if transposing is not necessary. template static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, + const std::vector& permutation, Tensor* output) { if (!ShouldTranspose(input.shape(), permutation)) { return CopyFrom(input, input.shape(), output); @@ -201,10 +202,7 @@ struct EinsumHelper { } TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); - // ------- TODO: Replace with valid MUSA implementation ------- - // const Device& device = ctx->eigen_device(); - // TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); - // ------------------------------------------------------------ + DoTranspose(ctx, input, permutation, output); return Status::OK(); } @@ -327,7 +325,7 @@ struct EinsumHelper { // Find the permutation to transpose the input dimensions in the order of // EinsumDimensionType; i.e. batch, free, contract and reduce dimensions. // This makes it more convenient to invoke Reduce/Contract operations. 
- std::vector permutation(input.dims()); + std::vector permutation(input.dims()); absl::c_iota(permutation, 0); Tensor input_transposed; // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) @@ -618,7 +616,7 @@ class MusaEinsumOp : public MusaOpKernel { // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the // permutation should be [0, 2, 1]. We also use the fact that repeated // labels in the result are adjacent to each other. - std::vector output_permutation(output_labels.size()); + std::vector output_permutation(output_labels.size()); std::vector label_to_position(num_labels, -1); for (int i = 0; i < result_labels.size(); ++i) { // Remember the position of only the leftmost result label. diff --git a/musa_ext/kernels/musa_transpose_functor.h b/musa_ext/kernels/musa_transpose_functor.h new file mode 100644 index 0000000..2992b52 --- /dev/null +++ b/musa_ext/kernels/musa_transpose_functor.h @@ -0,0 +1,8 @@ +namespace tensorflow { +namespace musa { + +void DoTranspose(OpKernelContext* ctx, const Tensor& input, + const std::vector& permutation, Tensor* output); + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_transpose_op.cc b/musa_ext/kernels/musa_transpose_op.cc old mode 100755 new mode 100644 index 52d1d4f..37a0b88 --- a/musa_ext/kernels/musa_transpose_op.cc +++ b/musa_ext/kernels/musa_transpose_op.cc @@ -2,6 +2,7 @@ #include +#include "musa_transpose_functor.h" #include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -13,6 +14,28 @@ namespace tensorflow { namespace musa { +void DoTranspose(OpKernelContext* ctx, const Tensor& input, + const std::vector& permutation, Tensor* output) { + mHandle& h = GetHandleByCtx(ctx); + + mTensor in_mt = CreateMTensor(input); + mTensor out_mt = CreateMTensor(*output); + + ::musa::dnn::Permute pop; + + if 
(::musa::dnn::Status::SUCCESS != + pop.ConfigDimStride(out_mt, in_mt, static_cast(permutation.size()), + permutation.data())) { + ctx->CtxFailure(errors::Internal("muDNN Permute ConfigDimStride failed!")); + return; + } + + if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { + ctx->CtxFailure(errors::Internal("muDNN Permute Run failed!")); + return; + } +} + template class MusaTransposeOp : public MusaOpKernel { public: @@ -75,26 +98,7 @@ class MusaTransposeOp : public MusaOpKernel { if (output->NumElements() == 0) return; - mHandle& h = GetHandleByCtx(ctx); - - mTensor in_mt = CreateMTensor(input, format_); - mTensor out_mt = CreateMTensor(*output, format_); - - ::musa::dnn::Permute pop; - - if (::musa::dnn::Status::SUCCESS != - pop.ConfigDimStride(out_mt, in_mt, - static_cast(permutation_64.size()), - permutation_64.data())) { - ctx->CtxFailure( - errors::Internal("muDNN Permute ConfigDimStride failed!")); - return; - } - - if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { - ctx->CtxFailure(errors::Internal("muDNN Permute Run failed!")); - return; - } + DoTranspose(ctx, input, permutation_64, output); } }; From 2e09a945a454866ab6b2f9e556464b6fac59042b Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 26 Feb 2026 07:55:35 +0000 Subject: [PATCH 05/16] update matmul --- musa_ext/kernels/musa_einsum_op.cc | 110 ++++++++++++++++++---- musa_ext/kernels/musa_transpose_functor.h | 8 -- musa_ext/kernels/musa_transpose_op.cc | 1 - 3 files changed, 94 insertions(+), 25 deletions(-) delete mode 100644 musa_ext/kernels/musa_transpose_functor.h diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index dea0f02..605ae82 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -8,7 +8,6 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" #include "musa_fill_functor.h" -#include "musa_transpose_functor.h" #include "tensorflow/core/lib/core/errors.h" #include 
"tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -22,6 +21,9 @@ namespace tensorflow { namespace musa { +void DoTranspose(OpKernelContext* ctx, const Tensor& input, + const std::vector& permutation, Tensor* output); + using ShapeVec = gtl::InlinedVector; using Labels = gtl::InlinedVector; using OperandLabels = gtl::InlinedVector; @@ -316,6 +318,91 @@ struct EinsumHelper { return true; } + template + static Status BMatMul(OpKernelContext* ctx, TensorShape out_shape, + Tensor& lhs, const Tensor& rhs, bool trans_a, + bool trans_b, Tensor* output) { + const Tensor& in0 = lhs; + const Tensor& in1 = rhs; + + int64 d0 = in0.dim_size(in0.dims() - 2); + int64 d1 = in0.dim_size(in0.dims() - 1); + int64 d2 = in1.dim_size(in1.dims() - 2); + int64 d3 = in1.dim_size(in1.dims() - 1); + + int64 m = trans_a ? d1 : d0; + int64 k = trans_a ? d0 : d1; + int64 n = trans_b ? d2 : d3; + int64 k_check = trans_b ? d3 : d2; + + if (k != k_check) { + return errors::InvalidArgument( + "Matrix size-incompatible: In[0] mismatch In[1]"); + } + + out_shape.AddDim(m); + out_shape.AddDim(n); + + // output is not allocated yet, we need to allocate it here. + // The problem description says "output" is a pointer. + // Usually helper functions allocate their output. + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, out_shape, output)); + + if (output->NumElements() == 0) return Status::OK(); + + auto& handle = GetHandleByCtx(ctx); + // Use TF32 setting if needed, but here we can just default or use env like + // matmul op. Since this is a static method, we don't have access to member + // tf32_enabled_. Let's check environment variable again or just assume + // default precision. For now, let's keep it simple and consistent with + // standard usage. 
+ + mTensor mt_a = CreateMTensor(in0); + mTensor mt_b = CreateMTensor(in1); + mTensor mt_out = CreateMTensor(*output); + + auto FixToBatchFormat = [](mTensor& mt, const Tensor& t) { + if (t.dims() == 2) { + int64_t rows = t.dim_size(0); + int64_t cols = t.dim_size(1); + mt.SetNdInfo({1, rows, cols}, {rows * cols, cols, 1}); + } + }; + + ::musa::dnn::Status status; + + if (in0.dims() == 2 && in1.dims() == 2) { + mMatMul op; + op.SetTranspose(trans_a, trans_b); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + status = op.Run(handle, mt_out, mt_a, mt_b); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "MUSA MatMul (2D High Precision) execution failed. Status: ", + (int)status); + } + } else { + mBatchMatMul op; + op.SetTranspose(trans_a, trans_b); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + FixToBatchFormat(mt_a, in0); + FixToBatchFormat(mt_b, in1); + FixToBatchFormat(mt_out, *output); + + status = op.Run(handle, mt_out, mt_a, mt_b); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA BatchMatMul execution failed. Status: ", + (int)status); + } + } + return Status::OK(); + } + template static Status ReduceOperand( OpKernelContext* ctx, const Tensor& input, @@ -490,21 +577,12 @@ struct EinsumHelper { Tensor output_reshaped; TF_RETURN_IF_ERROR( ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); - - auto& handle = GetHandleByCtx(ctx); - mBatchMatMul op; - op.SetTranspose(trans_x, trans_y); - op.SetAlpha(1.0); - op.SetBeta(0.0); - auto lhs_mt = CreateMTensor(lhs); - auto rhs_mt = CreateMTensor(rhs); - auto out_mt = CreateMTensor(output_reshaped); - auto status = op.Run(handle, out_mt, lhs_mt, rhs_mt); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("MUSA BatchMatMul execution failed. 
Status: ", - static_cast(status)); - } - return Status::OK(); + // LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, + // /*adj_y=*/false, trans_x, trans_y, + // /*grad_x=*/false, /*grad_y=*/false, + // bcast, &output_reshaped); + return BMatMul(ctx, output_shape, lhs, rhs, trans_x, trans_y, + &output_reshaped); } }; diff --git a/musa_ext/kernels/musa_transpose_functor.h b/musa_ext/kernels/musa_transpose_functor.h deleted file mode 100644 index 2992b52..0000000 --- a/musa_ext/kernels/musa_transpose_functor.h +++ /dev/null @@ -1,8 +0,0 @@ -namespace tensorflow { -namespace musa { - -void DoTranspose(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, Tensor* output); - -} // namespace musa -} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_transpose_op.cc b/musa_ext/kernels/musa_transpose_op.cc index 37a0b88..648f25d 100644 --- a/musa_ext/kernels/musa_transpose_op.cc +++ b/musa_ext/kernels/musa_transpose_op.cc @@ -2,7 +2,6 @@ #include -#include "musa_transpose_functor.h" #include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" From 9826ed9963b4c0867e3aac981039877fcb15aa9a Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 26 Feb 2026 08:12:38 +0000 Subject: [PATCH 06/16] fix bugs of wrong dimension (still have bugs with half precision) --- musa_ext/kernels/musa_einsum_op.cc | 121 +++++++++++++++++++---------- test/einsum_op_test.py | 4 +- 2 files changed, 82 insertions(+), 43 deletions(-) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 605ae82..d6d503a 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -7,6 +7,7 @@ #include "../utils/musa_einsum_op_util.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" +#include "mu/device/musa_memcpy.h" #include "musa_fill_functor.h" #include 
"tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -264,37 +265,74 @@ struct EinsumHelper { const auto reshape_int64 = to_int64(reshape); const auto strided_int64 = to_int64(strided_shape); const auto strides_int64 = to_int64(strides); - const gtl::ArraySlice reshape_slice(reshape_int64); - const gtl::ArraySlice strided_slice(strided_int64); - const TensorShape strides_shape{gtl::ArraySlice(strides_int64)}; TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); - switch (reshape.size()) { -#define NDIMS_CASE(N) \ - case N: { \ - const auto strides_dsizes = strides_shape.AsEigenDSizes(); \ - if (should_inflate) { \ - auto output_map = output->shaped(reshape_slice); \ - auto input_map = input.shaped(strided_slice); \ - InflateFunctor()(ctx, input_map, strides_dsizes, output_map); \ - } else { \ - auto input_map = input.shaped(reshape_slice); \ - auto output_map = output->shaped(strided_slice); \ - StrideFunctor()(ctx, input_map, strides_dsizes, output_map); \ - } \ - } break; - NDIMS_CASE(1); - NDIMS_CASE(2); - NDIMS_CASE(3); - NDIMS_CASE(4); - NDIMS_CASE(5); - NDIMS_CASE(6); - default: - return errors::Unimplemented( - "Unsupported rank: ", reshape.size(), - " while handling repeated indices. 
Up to rank 6 is supported."); -#undef NDIMS_CASE + + const int rank = reshape.size(); + if (rank == 0) return Status::OK(); + + auto compute_dense_strides = [](const std::vector& dims) { + std::vector dense_strides(dims.size(), 1); + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { + dense_strides[i] = dense_strides[i + 1] * dims[i + 1]; + } + return dense_strides; + }; + + std::vector reshape_dims(reshape_int64.begin(), + reshape_int64.end()); + std::vector strided_dims(strided_int64.begin(), + strided_int64.end()); + std::vector step_dims(strides_int64.begin(), strides_int64.end()); + + const auto reshape_dense = compute_dense_strides(reshape_dims); + const auto strided_dense = compute_dense_strides(strided_dims); + + std::vector h_input(input.NumElements()); + std::vector h_output(output->NumElements(), static_cast(0)); + + musaStream_t stream = GetMusaStreamByCtx(ctx); + mStatus d2h_status = + MusaMemcpyAsyncD2H(h_input.data(), input.flat().data(), + input.NumElements() * sizeof(T), stream); + if (d2h_status != mStatus::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate D2H memcpy failed"); + } + musaStreamSynchronize(stream); + + if (should_inflate) { + for (int64_t linear = 0; linear < static_cast(h_input.size()); + ++linear) { + int64_t remain = linear; + int64_t out_linear = 0; + for (int axis = 0; axis < rank; ++axis) { + const int64_t coord = remain / strided_dense[axis]; + remain %= strided_dense[axis]; + out_linear += (coord * step_dims[axis]) * reshape_dense[axis]; + } + h_output[out_linear] = h_input[linear]; + } + } else { + for (int64_t linear = 0; linear < static_cast(h_output.size()); + ++linear) { + int64_t remain = linear; + int64_t in_linear = 0; + for (int axis = 0; axis < rank; ++axis) { + const int64_t coord = remain / strided_dense[axis]; + remain %= strided_dense[axis]; + in_linear += (coord * step_dims[axis]) * reshape_dense[axis]; + } + h_output[linear] = h_input[in_linear]; + } + } + + mStatus h2d_status = + 
MusaMemcpyAsyncH2D(output->flat().data(), h_output.data(), + output->NumElements() * sizeof(T), stream); + if (h2d_status != mStatus::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate H2D memcpy failed"); } + musaStreamSynchronize(stream); return Status::OK(); } @@ -319,9 +357,8 @@ struct EinsumHelper { } template - static Status BMatMul(OpKernelContext* ctx, TensorShape out_shape, - Tensor& lhs, const Tensor& rhs, bool trans_a, - bool trans_b, Tensor* output) { + static Status BMatMul(OpKernelContext* ctx, Tensor& lhs, const Tensor& rhs, + bool trans_a, bool trans_b, Tensor* output) { const Tensor& in0 = lhs; const Tensor& in1 = rhs; @@ -340,18 +377,21 @@ struct EinsumHelper { "Matrix size-incompatible: In[0] mismatch In[1]"); } - out_shape.AddDim(m); - out_shape.AddDim(n); - - // output is not allocated yet, we need to allocate it here. - // The problem description says "output" is a pointer. - // Usually helper functions allocate their output. - TF_RETURN_IF_ERROR( - ctx->allocate_temp(DataTypeToEnum::value, out_shape, output)); + if (output->dims() < 2) { + return errors::Internal( + "Einsum output tensor rank must be at least 2, got ", output->dims()); + } + if (output->dim_size(output->dims() - 2) != m || + output->dim_size(output->dims() - 1) != n) { + return errors::Internal( + "Einsum output tensor shape mismatch, expected tail [", m, ", ", n, + "], got ", output->shape().DebugString()); + } if (output->NumElements() == 0) return Status::OK(); auto& handle = GetHandleByCtx(ctx); + handle.SetAllowTF32(false); // Use TF32 setting if needed, but here we can just default or use env like // matmul op. Since this is a static method, we don't have access to member // tf32_enabled_. 
Let's check environment variable again or just assume @@ -581,8 +621,7 @@ struct EinsumHelper { // /*adj_y=*/false, trans_x, trans_y, // /*grad_x=*/false, /*grad_y=*/false, // bcast, &output_reshaped); - return BMatMul(ctx, output_shape, lhs, rhs, trans_x, trans_y, - &output_reshaped); + return BMatMul(ctx, lhs, rhs, trans_x, trans_y, &output_reshaped); } }; diff --git a/test/einsum_op_test.py b/test/einsum_op_test.py index 60805e8..b40bd35 100644 --- a/test/einsum_op_test.py +++ b/test/einsum_op_test.py @@ -79,8 +79,8 @@ def testEllipsisBroadcast(self): def testMultipleSummations(self): """Multiple contraction indices with more than two inputs.""" equation = "abc,acd,db->bd" - shapes = [(3, 4, 5), (4, 5, 6), (6, 3)] - for dtype in [tf.float32]: + shapes = [(3, 4, 5), (3, 5, 6), (6, 4)] + for dtype in [tf.float32, tf.float16, tf.bfloat16]: self._test_einsum(equation, shapes, dtype) From 161356b399d434e575b6563732df2a82979fcadb Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 26 Feb 2026 08:48:36 +0000 Subject: [PATCH 07/16] optimize StrideOrInflate --- musa_ext/kernels/musa_einsum_op.cc | 97 ++++++++++++------------ 1 file changed, 38 insertions(+), 59 deletions(-) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index d6d503a..a605548 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -254,85 +254,64 @@ struct EinsumHelper { for (int64_t dim : output_shape_dims) { TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); } - auto to_int64 = [](const ShapeVec& dims) { - gtl::InlinedVector converted; - converted.reserve(dims.size()); - for (int64_t dim : dims) { - converted.push_back(static_cast(dim)); - } - return converted; - }; - const auto reshape_int64 = to_int64(reshape); - const auto strided_int64 = to_int64(strided_shape); - const auto strides_int64 = to_int64(strides); TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); const int rank = 
reshape.size(); if (rank == 0) return Status::OK(); - auto compute_dense_strides = [](const std::vector& dims) { - std::vector dense_strides(dims.size(), 1); + auto compute_dense_strides = [](const ShapeVec& dims) { + ShapeVec dense_strides(dims.size(), 1); for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { dense_strides[i] = dense_strides[i + 1] * dims[i + 1]; } return dense_strides; }; - std::vector reshape_dims(reshape_int64.begin(), - reshape_int64.end()); - std::vector strided_dims(strided_int64.begin(), - strided_int64.end()); - std::vector step_dims(strides_int64.begin(), strides_int64.end()); - - const auto reshape_dense = compute_dense_strides(reshape_dims); - const auto strided_dense = compute_dense_strides(strided_dims); + const auto inflated_dense = compute_dense_strides(inflated_shape); + ShapeVec diagonal_strides; + diagonal_strides.reserve(rank); + int inflated_axis = 0; + for (int label : labels) { + const int count = label_counts[label]; + int64_t diagonal_stride = 0; + for (int i = 0; i < count; ++i) { + diagonal_stride += inflated_dense[inflated_axis + i]; + } + diagonal_strides.push_back(diagonal_stride); + inflated_axis += count; + } - std::vector h_input(input.NumElements()); - std::vector h_output(output->NumElements(), static_cast(0)); + std::vector strided_dims_vec(strided_shape.begin(), + strided_shape.end()); + std::vector diagonal_strides_vec(diagonal_strides.begin(), + diagonal_strides.end()); - musaStream_t stream = GetMusaStreamByCtx(ctx); - mStatus d2h_status = - MusaMemcpyAsyncD2H(h_input.data(), input.flat().data(), - input.NumElements() * sizeof(T), stream); - if (d2h_status != mStatus::SUCCESS) { - return errors::Internal("Einsum StrideOrInflate D2H memcpy failed"); - } - musaStreamSynchronize(stream); + auto& handle = GetHandleByCtx(ctx); + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); if (should_inflate) { - for (int64_t linear = 0; linear < static_cast(h_input.size()); - ++linear) { - 
int64_t remain = linear; - int64_t out_linear = 0; - for (int axis = 0; axis < rank; ++axis) { - const int64_t coord = remain / strided_dense[axis]; - remain %= strided_dense[axis]; - out_linear += (coord * step_dims[axis]) * reshape_dense[axis]; - } - h_output[out_linear] = h_input[linear]; - } + SetZeroFunctor set_zero; + set_zero(ctx, output); + output_mt.SetNdInfo(rank, strided_dims_vec.data(), + diagonal_strides_vec.data()); } else { - for (int64_t linear = 0; linear < static_cast(h_output.size()); - ++linear) { - int64_t remain = linear; - int64_t in_linear = 0; - for (int axis = 0; axis < rank; ++axis) { - const int64_t coord = remain / strided_dense[axis]; - remain %= strided_dense[axis]; - in_linear += (coord * step_dims[axis]) * reshape_dense[axis]; - } - h_output[linear] = h_input[in_linear]; - } + input_mt.SetNdInfo(rank, strided_dims_vec.data(), + diagonal_strides_vec.data()); } - mStatus h2d_status = - MusaMemcpyAsyncH2D(output->flat().data(), h_output.data(), - output->NumElements() * sizeof(T), stream); - if (h2d_status != mStatus::SUCCESS) { - return errors::Internal("Einsum StrideOrInflate H2D memcpy failed"); + ::musa::dnn::Unary op; + auto status = op.SetMode(::musa::dnn::Unary::Mode::IDENTITY); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(handle, output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate Run failed. 
Status: ", + static_cast(status)); } - musaStreamSynchronize(stream); return Status::OK(); } From 53c3b8c86b81c603a49cd77242da3abdda588922 Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 26 Feb 2026 10:04:42 +0000 Subject: [PATCH 08/16] remove comments --- musa_ext/kernels/musa_einsum_op.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index a605548..6823a6b 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -524,10 +524,8 @@ struct EinsumHelper { int reduce_dims[] = {1}; op.SetDim(1, reduce_dims); - // ------- TODO: Not sure if this would work in MUSA env ------- tensorflow::Allocator* tf_allocator = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); - // ------------------------------------------------------------- auto alloc_func = [tf_allocator]( size_t size) -> std::unique_ptr> { @@ -596,10 +594,6 @@ struct EinsumHelper { Tensor output_reshaped; TF_RETURN_IF_ERROR( ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); - // LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, - // /*adj_y=*/false, trans_x, trans_y, - // /*grad_x=*/false, /*grad_y=*/false, - // bcast, &output_reshaped); return BMatMul(ctx, lhs, rhs, trans_x, trans_y, &output_reshaped); } }; From 3fd0f4083a7410a9c3846cc016cbc4b1075e94f6 Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 03:15:28 +0000 Subject: [PATCH 09/16] support for half precision --- musa_ext/kernels/musa_einsum_op.cc | 11 +- musa_ext/kernels/musa_einsum_op_half.cc | 862 ++++++++++++++++++++++++ musa_ext/utils/musa_einsum_op_util.h | 21 +- 3 files changed, 881 insertions(+), 13 deletions(-) create mode 100644 musa_ext/kernels/musa_einsum_op_half.cc diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 6823a6b..2d9a4b0 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -614,6 +614,13 
@@ class MusaEinsumOp : public MusaOpKernel { OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); + // Take ...i,i->... as an example. After parsing the equation, we have + // input_labels = [kEllipsisLabel, 0] + // output_labels = [kEllipsisLabel] + // label_types = [${EinsumDimensionType for label 0, which is kContract}] + // input_label_counts = [1, 1], which use the default value of 1 for the + // every label. output_label_counts = [1] label_to_dim_sizes = {}, which + // will be populated during dimension processing. OperandLabels input_labels(input_labels_); Labels output_labels(output_labels_); std::vector label_types(label_types_); @@ -758,8 +765,8 @@ REGISTER_MUSA_EINSUM(float); REGISTER_MUSA_EINSUM(double); REGISTER_MUSA_EINSUM(int32); REGISTER_MUSA_EINSUM(int64); -REGISTER_MUSA_EINSUM(Eigen::half); -REGISTER_MUSA_EINSUM(bfloat16); + +#undef REGISTER_MUSA_EINSUM } // namespace musa } // namespace tensorflow diff --git a/musa_ext/kernels/musa_einsum_op_half.cc b/musa_ext/kernels/musa_einsum_op_half.cc new file mode 100644 index 0000000..a948b1e --- /dev/null +++ b/musa_ext/kernels/musa_einsum_op_half.cc @@ -0,0 +1,862 @@ +#include +#include +#include + +#include "../mu/device/musa_device.h" +#include "../utils/musa_einsum_op_util.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "mu/device/musa_memcpy.h" +#include "musa_einsum_op.h" +#include "musa_fill_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/matmul_bcast.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +void DoTranspose(OpKernelContext* ctx, const Tensor& input, + const 
std::vector& permutation, Tensor* output); + +using ShapeVec = gtl::InlinedVector; +using Labels = gtl::InlinedVector; +using OperandLabels = gtl::InlinedVector; +using LabelCounts = gtl::InlinedVector; +using OperandLabelCounts = gtl::InlinedVector; +using LabelToDimSizes = gtl::InlinedVector; + +struct EinsumHelper { + // Insert new (unnamed) broadcasting labels at the location of ellipsis. + static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels, + int ellipsis_axis, Labels* labels, + LabelCounts* label_counts) { + labels->erase(labels->begin() + ellipsis_axis); + labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0); + std::iota(labels->begin() + ellipsis_axis, + labels->begin() + ellipsis_axis + num_bcast_dims, + num_named_labels); + // Increment label counts. Since these are new labels, the count is set + // to 1. + label_counts->resize(num_named_labels + num_bcast_dims, 1); + } + + // Record and validate the label to dimension mapping. Must be a named + // (non-broadcasting) label as broadcasting labels don't have a fixed + // dimension. + static Status RecordLabelToDimension(const int label, const int axis, + const Tensor& input, + LabelToDimSizes* label_to_dim_sizes) { + const int64_t input_dim = input.dim_size(axis); + // We know that label_to_dim_sizes has the size to accommodate named labels. + if (label_to_dim_sizes->at(label) != 0 && + label_to_dim_sizes->at(label) != input_dim) { + return errors::InvalidArgument( + "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", + axis, " of the input shaped ", input.shape().DebugString(), + " but got dimension ", input_dim); + } + (*label_to_dim_sizes)[label] = input_dim; + return Status::OK(); + } + + // Validate input dimensions and populate unnamed labels and their label + // counts. 
+ static Status ProcessDimensions( + const OpInputList& inputs, + const gtl::InlinedVector& input_has_ellipsis, + const bool output_has_ellipsis, OperandLabels* input_labels, + Labels* output_labels, std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + LabelToDimSizes* label_to_dim_sizes) { + if (inputs.size() != input_labels->size()) { + return errors::InvalidArgument("Expected ", input_labels->size(), + " inputs but got: ", inputs.size()); + } + const int num_inputs = inputs.size(); + + // We infer the number of broadcasting dimensions by taking the maximum + // rank among the broadcasting subshapes of the input. + int max_bcast_dims = 0; + const int num_named_labels = label_types->size(); + label_to_dim_sizes->resize(num_named_labels); + for (int i = 0; i < num_inputs; ++i) { + Labels* labels = &(*input_labels)[i]; + + if (!input_has_ellipsis[i]) { + if (inputs[i].dims() != labels->size()) { + return errors::InvalidArgument("Expected input ", i, " to have rank ", + labels->size(), + " but got: ", inputs[i].dims()); + } + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i], + label_to_dim_sizes)); + } + continue; + } + + // Input has an ellipsis. + if (inputs[i].dims() + 1 < labels->size()) { + return errors::InvalidArgument( + "Expected input ", i, " to have rank at least ", labels->size() - 1, + " but got: ", inputs[i].dims()); + } + int ellipsis_axis = -1; + const int num_bcast_dims = inputs[i].dims() - labels->size() + 1; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + if (label == kEllipsisLabel) { + ellipsis_axis = label_idx; + continue; + } + // Current label is not an ellipsis. + const int axis = + label_idx + (ellipsis_axis == -1 ? 
0 : num_bcast_dims - 1); + TF_RETURN_IF_ERROR( + RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes)); + } + // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting + // dimensions. + if (ellipsis_axis != -1) { + InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, + labels, &input_label_counts->at(i)); + max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims); + } + } + if (!absl::c_linear_search(input_has_ellipsis, true) && + !output_has_ellipsis) { + return Status::OK(); + } + // Insert broadcasting dimensions in the output labels. + auto it = + std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel); + if (it != output_labels->end()) { + const int ellipsis_axis = it - output_labels->begin(); + InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, + output_labels, output_label_counts); + } else if (max_bcast_dims > 0) { + return errors::InvalidArgument( + "Output contains ", max_bcast_dims, + " broadcasting dimension(s) but no ellipsis " + "(...) was found in the output subscripts."); + } + // Populate EinsumDimensionType for the new broadcasting labels. + label_types->resize(num_named_labels + max_bcast_dims, + EinsumDimensionType::kBroadcasting); + return Status::OK(); + } + + // Permutes the labels according to the given permutation. + static void PermuteLabels(const std::vector& permutation, + Labels* labels) { + Labels permuted_labels(labels->size()); + for (int i = 0; i < labels->size(); ++i) { + permuted_labels[i] = (*labels)[permutation[i]]; + } + labels->swap(permuted_labels); + } + + // Returns a reshaped input Tensor. The underlying buffer is not copied. 
+ static Status CopyFrom(const Tensor& input, const TensorShape& shape, + Tensor* output) { + if (output->CopyFrom(input, shape)) return Status::OK(); + return errors::Internal( + "Encountered error while reshaping a Tensor of shape ", + input.shape().DebugString(), " to shape ", shape.DebugString()); + } + + // Returns whether transposing would be a no-op; whether input has rank < 2 or + // the permutation is the identity permutation. + static bool ShouldTranspose(const TensorShape& input_shape, + const std::vector& permutation) { + if (input_shape.dims() < 2) return false; + for (int i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) return true; + } + return false; + } + + // Transpose the input given a permutation. Returns a reference to the input + // if transposing is not necessary. + template + static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, + const std::vector& permutation, + Tensor* output) { + if (!ShouldTranspose(input.shape(), permutation)) { + return CopyFrom(input, input.shape(), output); + } + TensorShape transposed_shape; + for (int i = 0; i < input.dims(); ++i) { + TF_RETURN_IF_ERROR( + transposed_shape.AddDimWithStatus(input.dim_size(permutation[i]))); + } + // For empty Tensors, just change the shape. E.g. we may need to transpose + // from shape [1, 0, 5] to [5, 1, 0]. + if (input.NumElements() == 0) { + return CopyFrom(input, transposed_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); + DoTranspose(ctx, input, permutation, output); + return Status::OK(); + } + + // If there are repeated labels in either the input or output, then this + // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. 
+ template + static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, + const Labels& labels, + const LabelCounts& label_counts, + const bool should_inflate, Tensor* output) { + // Return early if there are no repeated indices. + if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { + return CopyFrom(input, input.shape(), output); + } + // We reshape so that each repeated label is compressed to one dimension. + // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27, + // 5]. Striding appropriately (in this case with strides 14 (=1+3+9) and 1) + // recovers the generalized diagonal of shape [3, 5]. + ShapeVec reshape; + ShapeVec strides; + // Strided and inflated shapes correspond to input and output shapes, + // respectively, should_inflate is true (vice-versa if should_inflate is + // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example. + ShapeVec strided_shape; + ShapeVec inflated_shape; + for (int label : labels) { + const int count = label_counts[label]; + const int current_axis = + should_inflate ? strided_shape.size() : inflated_shape.size(); + const int64_t dim = input.dim_size(current_axis); + strided_shape.push_back(dim); + inflated_shape.insert(inflated_shape.end(), count, dim); + const int64_t reshape_dim = MathUtil::IPow(dim, count); + reshape.push_back(reshape_dim); + // While taking the d-diagonal in a rank k Tensor, we take d + // equally-spaced elements including the first and last element. Then, (k + // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1). + const int64_t stride = + (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1; + strides.push_back(stride); + } + + const ShapeVec& output_shape_dims = + should_inflate ? 
inflated_shape : strided_shape; + TensorShape output_shape; + for (int64_t dim : output_shape_dims) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + + const int rank = reshape.size(); + if (rank == 0) return Status::OK(); + + auto compute_dense_strides = [](const ShapeVec& dims) { + ShapeVec dense_strides(dims.size(), 1); + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { + dense_strides[i] = dense_strides[i + 1] * dims[i + 1]; + } + return dense_strides; + }; + + const auto inflated_dense = compute_dense_strides(inflated_shape); + ShapeVec diagonal_strides; + diagonal_strides.reserve(rank); + int inflated_axis = 0; + for (int label : labels) { + const int count = label_counts[label]; + int64_t diagonal_stride = 0; + for (int i = 0; i < count; ++i) { + diagonal_stride += inflated_dense[inflated_axis + i]; + } + diagonal_strides.push_back(diagonal_stride); + inflated_axis += count; + } + + std::vector strided_dims_vec(strided_shape.begin(), + strided_shape.end()); + std::vector diagonal_strides_vec(diagonal_strides.begin(), + diagonal_strides.end()); + + auto& handle = GetHandleByCtx(ctx); + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); + + if (should_inflate) { + SetZeroFunctor set_zero; + set_zero(ctx, output); + output_mt.SetNdInfo(rank, strided_dims_vec.data(), + diagonal_strides_vec.data()); + } else { + input_mt.SetNdInfo(rank, strided_dims_vec.data(), + diagonal_strides_vec.data()); + } + + ::musa::dnn::Unary op; + auto status = op.SetMode(::musa::dnn::Unary::Mode::IDENTITY); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(handle, output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum StrideOrInflate Run failed. 
Status: ", + static_cast(status)); + } + return Status::OK(); + } + + // Returns true if the input dimensions are already sorted in the order + // [batch, contract, free, reduce]. Used to implement an optimization to avoid + // an extra transpose and instead uses (adj_x and adj_y) in BatchMatMul. + static bool ShouldSwapFreeAndContract( + const Labels& labels, + const std::vector& label_types) { + // Check that ordering is according to dimension type, with the role of + // free and contract dimensions swapped. + gtl::InlinedVector remap = {0, 1, 3, 2, 4}; + for (int i = 0; i + 1 < labels.size(); ++i) { + const int dimtype_a = remap[label_types[labels[i]]]; + const int dimtype_b = remap[label_types[labels[i + 1]]]; + if (dimtype_a > dimtype_b || + (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) { + return false; + } + } + return true; + } + + template + static Status BMatMul(OpKernelContext* ctx, Tensor& lhs, const Tensor& rhs, + bool trans_a, bool trans_b, Tensor* output) { + const Tensor& in0 = lhs; + const Tensor& in1 = rhs; + + int64 d0 = in0.dim_size(in0.dims() - 2); + int64 d1 = in0.dim_size(in0.dims() - 1); + int64 d2 = in1.dim_size(in1.dims() - 2); + int64 d3 = in1.dim_size(in1.dims() - 1); + + int64 m = trans_a ? d1 : d0; + int64 k = trans_a ? d0 : d1; + int64 n = trans_b ? d2 : d3; + int64 k_check = trans_b ? 
d3 : d2; + + if (k != k_check) { + return errors::InvalidArgument( + "Matrix size-incompatible: In[0] mismatch In[1]"); + } + + if (output->dims() < 2) { + return errors::Internal( + "Einsum output tensor rank must be at least 2, got ", output->dims()); + } + if (output->dim_size(output->dims() - 2) != m || + output->dim_size(output->dims() - 1) != n) { + return errors::Internal( + "Einsum output tensor shape mismatch, expected tail [", m, ", ", n, + "], got ", output->shape().DebugString()); + } + + if (output->NumElements() == 0) return Status::OK(); + + if (std::is_same::value || + std::is_same::value) { + const int64_t batch_a = in0.dims() == 2 ? 1 : in0.dim_size(0); + const int64_t batch_b = in1.dims() == 2 ? 1 : in1.dim_size(0); + const int64_t batch_out = output->dims() == 2 ? 1 : output->dim_size(0); + + const int64_t a_rows = + in0.dims() == 2 ? in0.dim_size(0) : in0.dim_size(1); + const int64_t a_cols = + in0.dims() == 2 ? in0.dim_size(1) : in0.dim_size(2); + const int64_t b_rows = + in1.dims() == 2 ? in1.dim_size(0) : in1.dim_size(1); + const int64_t b_cols = + in1.dims() == 2 ? 
in1.dim_size(1) : in1.dim_size(2); + + const int64_t elem_a = in0.NumElements(); + const int64_t elem_b = in1.NumElements(); + const int64_t elem_out = output->NumElements(); + + std::vector host_a(elem_a); + std::vector host_b(elem_b); + std::vector host_out(elem_out); + + mStatus memcpy_status = MusaMemcpyD2H(host_a.data(), in0.flat().data(), + elem_a * sizeof(T)); + if (memcpy_status != mStatus::SUCCESS) { + return errors::Internal("Einsum half path: MusaMemcpyD2H A failed"); + } + memcpy_status = MusaMemcpyD2H(host_b.data(), in1.flat().data(), + elem_b * sizeof(T)); + if (memcpy_status != mStatus::SUCCESS) { + return errors::Internal("Einsum half path: MusaMemcpyD2H B failed"); + } + + auto index_a = [&](int64_t batch, int64_t row, int64_t col) { + if (in0.dims() == 2) { + return row * a_cols + col; + } + return (batch * a_rows + row) * a_cols + col; + }; + auto index_b = [&](int64_t batch, int64_t row, int64_t col) { + if (in1.dims() == 2) { + return row * b_cols + col; + } + return (batch * b_rows + row) * b_cols + col; + }; + + for (int64_t bo = 0; bo < batch_out; ++bo) { + const int64_t ba = batch_a == 1 ? 0 : bo; + const int64_t bb = batch_b == 1 ? 0 : bo; + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + T sum = static_cast(0); + for (int64_t kk = 0; kk < k; ++kk) { + const int64_t ar = trans_a ? kk : i; + const int64_t ac = trans_a ? i : kk; + const int64_t br = trans_b ? j : kk; + const int64_t bc = trans_b ? 
kk : j; + const T av = host_a[index_a(ba, ar, ac)]; + const T bv = host_b[index_b(bb, br, bc)]; + sum = static_cast(sum + static_cast(av * bv)); + } + + if (output->dims() == 2) { + host_out[i * n + j] = sum; + } else { + host_out[(bo * m + i) * n + j] = sum; + } + } + } + } + + memcpy_status = MusaMemcpyH2D(output->flat().data(), host_out.data(), + elem_out * sizeof(T)); + if (memcpy_status != mStatus::SUCCESS) { + return errors::Internal( + "Einsum half path: MusaMemcpyH2D output failed"); + } + return Status::OK(); + } + + auto& handle = GetHandleByCtx(ctx); + handle.SetAllowTF32(false); + // Use TF32 setting if needed, but here we can just default or use env like + // matmul op. Since this is a static method, we don't have access to member + // tf32_enabled_. Let's check environment variable again or just assume + // default precision. For now, let's keep it simple and consistent with + // standard usage. + + mTensor mt_a = CreateMTensor(in0); + mTensor mt_b = CreateMTensor(in1); + mTensor mt_out = CreateMTensor(*output); + + auto FixToBatchFormat = [](mTensor& mt, const Tensor& t) { + if (t.dims() == 2) { + int64_t rows = t.dim_size(0); + int64_t cols = t.dim_size(1); + mt.SetNdInfo({1, rows, cols}, {rows * cols, cols, 1}); + } + }; + + ::musa::dnn::Status status; + + { + mBatchMatMul op; + status = op.SetDeterministic(true); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "MUSA BatchMatMul SetDeterministic(true) failed. Status: ", + (int)status); + } + + if (std::is_same::value || + std::is_same::value) { + status = op.SetComputeMode(mBatchMatMul::ComputeMode::SCALAR); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "MUSA BatchMatMul SetComputeMode(SCALAR) failed. Status: ", + (int)status); + } + status = op.SetMpCountTarget(1); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "MUSA BatchMatMul SetMpCountTarget(1) failed. 
Status: ", + (int)status); + } + } + + op.SetTranspose(trans_a, trans_b); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + FixToBatchFormat(mt_a, in0); + FixToBatchFormat(mt_b, in1); + FixToBatchFormat(mt_out, *output); + + status = op.Run(handle, mt_out, mt_a, mt_b); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA BatchMatMul execution failed. Status: ", + (int)status); + } + } + return Status::OK(); + } + + template + static Status ReduceOperand( + OpKernelContext* ctx, const Tensor& input, + const std::vector& label_types, + const LabelCounts& label_counts, Labels* labels, Labels* free_labels, + bool* swap_free_and_contract, Tensor* output) { + // Find the permutation to transpose the input dimensions in the order of + // EinsumDimensionType; i.e. batch, free, contract and reduce dimensions. + // This makes it more convenient to invoke Reduce/Contract operations. + std::vector permutation(input.dims()); + absl::c_iota(permutation, 0); + Tensor input_transposed; + // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) + // flag during BatchMatMul. This is an extra optimization not necessary for + // correctness. + if (ShouldSwapFreeAndContract(*labels, label_types)) { + *swap_free_and_contract = true; + } else { + absl::c_sort(permutation, [&](int i, int j) { + int label_i = (*labels)[i]; + int label_j = (*labels)[j]; + return std::tie(label_types[label_i], label_i) < + std::tie(label_types[label_j], label_j); + }); + } + // Transpose the input so that EinsumDimensionTypes are in order. + TF_RETURN_IF_ERROR( + TransposeOperand(ctx, input, permutation, &input_transposed)); + PermuteLabels(permutation, labels); + + // Take the generalized diagonal for dimensions with repeated axis labels. 
+ Tensor input_deduped; + labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); + TF_RETURN_IF_ERROR( + StrideOrInflate(ctx, input_transposed, *labels, label_counts, + false /* should_inflate */, &input_deduped)); + + // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, + // reduce] where we've compacted the dimensions of each EinsumDimensionType. + gtl::InlinedVector reshape(5, 1); + // The output shape is [batch shape] + [free size, contract size] + // That is, the batch shape is preserved (for broadcasting while + // contracting) while the free dims and contract dims are compressed to one + // dimension each. + TensorShape output_shape; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = labels->at(label_idx); + int64_t dim = input_deduped.dim_size(label_idx); + if (label_types[label] == EinsumDimensionType::kBroadcasting || + label_types[label] == EinsumDimensionType::kBatch) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } else if (label_types[label] == EinsumDimensionType::kFree) { + free_labels->push_back(label); + } + reshape[label_types[label]] *= dim; + } + if (*swap_free_and_contract) + std::swap(reshape[EinsumDimensionType::kFree], + reshape[EinsumDimensionType::kContract]); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kFree])); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kContract])); + + if (reshape[EinsumDimensionType::kReduce] == 1) { // No need to reduce. 
+ return CopyFrom(input_deduped, output_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + const int64_t reduce_size = reshape[kReduce]; + const int64_t output_size = reshape[kBroadcasting] * reshape[kBatch] * + reshape[kFree] * reshape[kContract]; + + TensorShape input_flatten_shape; + TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(output_size)); + TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(reduce_size)); + Tensor input_flattened; + if (!input_flattened + .BitcastFrom(input_deduped, input_deduped.dtype(), + input_flatten_shape) + .ok()) { + return errors::Internal("Failed to reshape Einsum input for reduce"); + } + + TensorShape output_flatten_shape; + TF_RETURN_IF_ERROR(output_flatten_shape.AddDimWithStatus(output_size)); + Tensor output_flattened; + if (!output_flattened + .BitcastFrom(*output, output->dtype(), output_flatten_shape) + .ok()) { + return errors::Internal("Failed to reshape Einsum output for reduce"); + } + + auto input_mt = CreateMTensor(input_flattened); + auto output_mt = CreateMTensor(output_flattened); + + auto& handle = GetHandleByCtx(ctx); + mReduce op; + op.SetMode(::musa::dnn::Reduce::Mode::ADD); + int reduce_dims[] = {1}; + op.SetDim(1, reduce_dims); + + tensorflow::Allocator* tf_allocator = + ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + auto alloc_func = + [tf_allocator]( + size_t size) -> std::unique_ptr> { + void* ptr = tf_allocator->AllocateRaw(256, size); + std::function deleter = [tf_allocator](void* p) { + if (p) tf_allocator->DeallocateRaw(p); + }; + return std::unique_ptr>(ptr, deleter); + }; + ::musa::dnn::MemoryMaintainer mm(alloc_func); + + auto status = op.Run(handle, output_mt, input_mt, mm); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA Reduce (sum) execution failed. 
Status: ", + static_cast(status)); + } + return Status::OK(); + } + + // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. + static Status ReshapeToRank3(const Tensor& input, int batch_size, + Tensor* output) { + const int rank = input.dims(); + TensorShape output_shape = {batch_size, input.dim_size(rank - 2), + input.dim_size(rank - 1)}; + return CopyFrom(input, output_shape, output); + } + + // Contracts the inputs along the last axis (or the second last if the + // corresponding value of swap_free_and_contract is true). The batch + // dimensions are broadcast to the output shape. + template + static Status ContractOperands(OpKernelContext* ctx, + absl::Span inputs, + absl::Span swap_free_and_contract, + Tensor* output) { + if (inputs.size() == 1) + return CopyFrom(inputs[0], inputs[0].shape(), output); + MatMulBCast bcast(inputs[0].shape().dim_sizes(), + inputs[1].shape().dim_sizes()); + if (!bcast.IsValid()) { + return errors::InvalidArgument( + "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), + " vs. ", inputs[1].shape().DebugString()); + } + Tensor lhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs)); + Tensor rhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs)); + TensorShape output_shape = bcast.output_batch_shape(); + for (int i = 0; i < inputs.size(); ++i) { + const int64_t free_axis = + inputs[i].dims() - (swap_free_and_contract[i] ? 
1 : 2); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(inputs[i].dim_size(free_axis))); + } + bool trans_x = swap_free_and_contract[0]; + bool trans_y = !swap_free_and_contract[1]; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { + SetZeroFunctor set_zero; + set_zero(ctx, output); + return Status::OK(); + } + Tensor output_reshaped; + TF_RETURN_IF_ERROR( + ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); + return BMatMul(ctx, lhs, rhs, trans_x, trans_y, &output_reshaped); + } +}; + +template +class MusaEinsumOp : public MusaOpKernel { + public: + explicit MusaEinsumOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("equation", &equation_)); + OP_REQUIRES_OK( + ctx, ParseEinsumEquation(equation_, &input_labels_, &output_labels_, + &label_types_, &input_label_counts_, + &output_label_counts_, &input_has_ellipsis_, + &output_has_ellipsis_)); + } + + void Compute(OpKernelContext* ctx) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); + + // Take ...i,i->... as an example. After parsing the equation, we have + // input_labels = [kEllipsisLabel, 0] + // output_labels = [kEllipsisLabel] + // label_types = [${EinsumDimensionType for label 0, which is kContract}] + // input_label_counts = [1, 1], which use the default value of 1 for the + // every label. output_label_counts = [1] label_to_dim_sizes = {}, which + // will be populated during dimension processing. 
+ OperandLabels input_labels(input_labels_); + Labels output_labels(output_labels_); + std::vector label_types(label_types_); + OperandLabelCounts input_label_counts(input_label_counts_); + LabelCounts output_label_counts(output_label_counts_); + LabelToDimSizes label_to_dim_sizes; + + OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions( + inputs, input_has_ellipsis_, output_has_ellipsis_, + &input_labels, &output_labels, &label_types, + &input_label_counts, &output_label_counts, + &label_to_dim_sizes)); + + // The reduction phase (a) sums across reduction dimensions, (b) takes + // generalized diagonals, and (c) reshapes it into shape + // [(broadcasting) batch shape] + [F,C] + // where F and C denote the total (compacted) size of free and contract + // dimensions, respectively. + const int num_inputs = inputs.size(); + OperandLabels free_labels(num_inputs); + gtl::InlinedVector inputs_reduced(num_inputs); + gtl::InlinedVector swap_free_and_contract(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + OP_REQUIRES_OK(ctx, + EinsumHelper::ReduceOperand( + ctx, inputs[i], label_types, input_label_counts[i], + &input_labels[i], &free_labels[i], + &swap_free_and_contract[i], &inputs_reduced[i])); + } + + // After reduction, the inputs should be reshaped to Tensors suitable for + // contraction. If num_inputs is 1, the reduced input is simply forwarded to + // the output. + Tensor contraction_output_reshaped; + OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( + ctx, inputs_reduced, swap_free_and_contract, + &contraction_output_reshaped)); + + // Copy the batch labels from the contraction output. Recover the batch + // shape, which may have been broadcasted. + TensorShape result_shape = contraction_output_reshaped.shape(); + result_shape.RemoveLastDims(2); + + int num_labels = label_types.size(); + Labels result_labels; + // All batch dimensions should be present in the contracted result. First + // the broadcasting dimensions, then the named batch dimensions. 
+ for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBroadcasting) + result_labels.push_back(label); + } + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBatch) + result_labels.push_back(label); + } + for (int i = 0; i < num_inputs; ++i) { + for (int label : free_labels[i]) { + result_labels.push_back(label); + OP_REQUIRES_OK( + ctx, result_shape.AddDimWithStatus(label_to_dim_sizes[label])); + } + } + + // Reshape the contraction (or reduction) result to its expanded shape: + // [(broadcasted) batch shape] + [free shape 0] + [free shape 1]. + Tensor contraction_output; + OP_REQUIRES_OK( + ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape, + &contraction_output)); + + // Inflate the output if necessary. (E.g. for the equation 'i->iii' which + // may arise while computing gradient of a regular Einsum). + Tensor output_inflated; + OP_REQUIRES_OK( + ctx, EinsumHelper::StrideOrInflate( + ctx, contraction_output, result_labels, output_label_counts, + true /* should_inflate */, &output_inflated)); + if (output_inflated.dims() > contraction_output.dims()) { + // We inflated the output. Modify result labels accordingly. + Labels inflated_labels; + for (int label : result_labels) { + inflated_labels.insert(inflated_labels.end(), + output_label_counts[label], label); + } + result_labels.swap(inflated_labels); + } + // Find the permutation to map the result labels to the output labels. Note + // that both the result and the final output may have the repeated labels, + // in which case the permutation preserves the left-to-right ordering. + // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the + // permutation should be [0, 2, 1]. We also use the fact that repeated + // labels in the result are adjacent to each other. 
+ std::vector output_permutation(output_labels.size()); + std::vector label_to_position(num_labels, -1); + for (int i = 0; i < result_labels.size(); ++i) { + // Remember the position of only the leftmost result label. + if (label_to_position[result_labels[i]] == -1) { + label_to_position[result_labels[i]] = i; + } + } + for (int i = 0; i < output_labels.size(); ++i) { + output_permutation[i] = label_to_position[output_labels[i]]; + // We have found the leftmost occurrence. The next one would be adjacent. + label_to_position[output_labels[i]] += 1; + } + Tensor output; + OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( + ctx, output_inflated, output_permutation, &output)); + ctx->set_output(0, output); + } + + string TraceString(const OpKernelContext& ctx, bool verbose) const override { + string op = profiler::TraceMeOp(name_view(), type_string_view()); + string equation = strings::StrCat("(", equation_, ")"); + if (verbose) { + string shape = ShapeTraceString(ctx); + if (!shape.empty()) { + return profiler::TraceMeEncode( + std::move(op), {{"equation", equation}, {"shape", shape}}); + } + } + return profiler::TraceMeEncode(std::move(op), {{"equation", equation}}); + } + + private: + string equation_; + OperandLabels input_labels_; + Labels output_labels_; + std::vector label_types_; + OperandLabelCounts input_label_counts_; + LabelCounts output_label_counts_; + gtl::InlinedVector input_has_ellipsis_; + bool output_has_ellipsis_ = false; +}; // class MusaEinsumOp + +#define REGISTER_MUSA_EINSUM(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Einsum").Device("MUSA").TypeConstraint("T"), \ + MusaEinsumOp); + +REGISTER_MUSA_EINSUM(Eigen::half); +REGISTER_MUSA_EINSUM(bfloat16); + +#undef REGISTER_MUSA_EINSUM + +} // namespace musa +} // namespace tensorflow diff --git a/musa_ext/utils/musa_einsum_op_util.h b/musa_ext/utils/musa_einsum_op_util.h index ec536d8..01116ba 100644 --- a/musa_ext/utils/musa_einsum_op_util.h +++ b/musa_ext/utils/musa_einsum_op_util.h @@ -42,7 
+42,7 @@ enum EinsumDimensionType { // present in exactly one input subscript (is_unique) and whether it is absent // from the output subscripts (is_removed). Does not handle broadcasting // dimensions. -EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) { +inline EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) { if (!is_removed && !is_unique) return kBatch; else if (!is_removed && is_unique) @@ -53,7 +53,7 @@ EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) { return kReduce; } -Status ValidateEinsumEquation( +inline Status ValidateEinsumEquation( const std::string& equation, absl::InlinedVector* input_subscripts, std::string* output_subscript) { @@ -75,8 +75,8 @@ Status ValidateEinsumEquation( } // Maps the character labels to consecutive integers. -void MapToLabels(const std::string& subscript, Labels* labels, - absl::flat_hash_map* label_mapping) { +inline void MapToLabels(const std::string& subscript, Labels* labels, + absl::flat_hash_map* label_mapping) { for (int i = 0; i < subscript.size(); ++i) { const char label_char = subscript[i]; if (label_char == '.') { @@ -93,13 +93,12 @@ void MapToLabels(const std::string& subscript, Labels* labels, } } -Status ParseEinsumEquation(const std::string& equation, - OperandLabels* input_labels, Labels* output_labels, - std::vector* label_types, - OperandLabelCounts* input_label_counts, - LabelCounts* output_label_counts, - absl::InlinedVector* input_has_ellipsis, - bool* output_has_ellipsis) { +inline Status ParseEinsumEquation( + const std::string& equation, OperandLabels* input_labels, + Labels* output_labels, std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + absl::InlinedVector* input_has_ellipsis, + bool* output_has_ellipsis) { absl::InlinedVector input_str; std::string output_str; TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str)); From c640be5fbc9c05e35a30c0ef5a5df0f27c8598b5 
Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 06:42:30 +0000 Subject: [PATCH 10/16] update test cases --- .gitignore | 1 + musa_ext/kernels/musa_einsum_op.cc | 212 ++++-- musa_ext/kernels/musa_einsum_op_half.cc | 862 ------------------------ test/einsum_op_test.py | 187 ++--- 4 files changed, 278 insertions(+), 984 deletions(-) delete mode 100644 musa_ext/kernels/musa_einsum_op_half.cc diff --git a/.gitignore b/.gitignore index 6b683f1..37e4b9d 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__ build *.log +.vscode/** \ No newline at end of file diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 2d9a4b0..20f278e 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -1,7 +1,9 @@ #include "musa_einsum_op.h" +#include #include #include +#include #include "../mu/device/musa_device.h" #include "../utils/musa_einsum_op_util.h" @@ -351,24 +353,93 @@ struct EinsumHelper { int64 n = trans_b ? d2 : d3; int64 k_check = trans_b ? d3 : d2; - if (k != k_check) { - return errors::InvalidArgument( - "Matrix size-incompatible: In[0] mismatch In[1]"); - } - - if (output->dims() < 2) { - return errors::Internal( - "Einsum output tensor rank must be at least 2, got ", output->dims()); - } - if (output->dim_size(output->dims() - 2) != m || - output->dim_size(output->dims() - 1) != n) { - return errors::Internal( - "Einsum output tensor shape mismatch, expected tail [", m, ", ", n, - "], got ", output->shape().DebugString()); - } - if (output->NumElements() == 0) return Status::OK(); + // // USE CPU to compute half and bfloat16 bMatMul + // if (std::is_same::value || + // std::is_same::value) { + // const int64_t batch_a = in0.dims() == 2 ? 1 : in0.dim_size(0); + // const int64_t batch_b = in1.dims() == 2 ? 1 : in1.dim_size(0); + // const int64_t batch_out = output->dims() == 2 ? 1 : + // output->dim_size(0); + + // const int64_t a_rows = + // in0.dims() == 2 ? 
in0.dim_size(0) : in0.dim_size(1); + // const int64_t a_cols = + // in0.dims() == 2 ? in0.dim_size(1) : in0.dim_size(2); + // const int64_t b_rows = + // in1.dims() == 2 ? in1.dim_size(0) : in1.dim_size(1); + // const int64_t b_cols = + // in1.dims() == 2 ? in1.dim_size(1) : in1.dim_size(2); + + // const int64_t elem_a = in0.NumElements(); + // const int64_t elem_b = in1.NumElements(); + // const int64_t elem_out = output->NumElements(); + + // std::vector host_a(elem_a); + // std::vector host_b(elem_b); + // std::vector host_out(elem_out); + + // mStatus memcpy_status = MusaMemcpyD2H(host_a.data(), + // in0.flat().data(), + // elem_a * sizeof(T)); + // if (memcpy_status != mStatus::SUCCESS) { + // return errors::Internal("Einsum half path: MusaMemcpyD2H A failed"); + // } + // memcpy_status = MusaMemcpyD2H(host_b.data(), in1.flat().data(), + // elem_b * sizeof(T)); + // if (memcpy_status != mStatus::SUCCESS) { + // return errors::Internal("Einsum half path: MusaMemcpyD2H B failed"); + // } + + // auto index_a = [&](int64_t batch, int64_t row, int64_t col) { + // if (in0.dims() == 2) { + // return row * a_cols + col; + // } + // return (batch * a_rows + row) * a_cols + col; + // }; + // auto index_b = [&](int64_t batch, int64_t row, int64_t col) { + // if (in1.dims() == 2) { + // return row * b_cols + col; + // } + // return (batch * b_rows + row) * b_cols + col; + // }; + + // for (int64_t bo = 0; bo < batch_out; ++bo) { + // const int64_t ba = batch_a == 1 ? 0 : bo; + // const int64_t bb = batch_b == 1 ? 0 : bo; + // for (int64_t i = 0; i < m; ++i) { + // for (int64_t j = 0; j < n; ++j) { + // T sum = static_cast(0); + // for (int64_t kk = 0; kk < k; ++kk) { + // const int64_t ar = trans_a ? kk : i; + // const int64_t ac = trans_a ? i : kk; + // const int64_t br = trans_b ? j : kk; + // const int64_t bc = trans_b ? 
kk : j; + // const T av = host_a[index_a(ba, ar, ac)]; + // const T bv = host_b[index_b(bb, br, bc)]; + // sum = static_cast(sum + static_cast(av * bv)); + // } + + // if (output->dims() == 2) { + // host_out[i * n + j] = sum; + // } else { + // host_out[(bo * m + i) * n + j] = sum; + // } + // } + // } + // } + + // memcpy_status = MusaMemcpyH2D(output->flat().data(), + // host_out.data(), + // elem_out * sizeof(T)); + // if (memcpy_status != mStatus::SUCCESS) { + // return errors::Internal( + // "Einsum half path: MusaMemcpyH2D output failed"); + // } + // return Status::OK(); + // } + auto& handle = GetHandleByCtx(ctx); handle.SetAllowTF32(false); // Use TF32 setting if needed, but here we can just default or use env like @@ -391,34 +462,21 @@ struct EinsumHelper { ::musa::dnn::Status status; - if (in0.dims() == 2 && in1.dims() == 2) { - mMatMul op; - op.SetTranspose(trans_a, trans_b); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - status = op.Run(handle, mt_out, mt_a, mt_b); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal( - "MUSA MatMul (2D High Precision) execution failed. Status: ", - (int)status); - } - } else { - mBatchMatMul op; - op.SetTranspose(trans_a, trans_b); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - FixToBatchFormat(mt_a, in0); - FixToBatchFormat(mt_b, in1); - FixToBatchFormat(mt_out, *output); - - status = op.Run(handle, mt_out, mt_a, mt_b); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("MUSA BatchMatMul execution failed. Status: ", - (int)status); - } + mBatchMatMul op; + op.SetTranspose(trans_a, trans_b); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + FixToBatchFormat(mt_a, in0); + FixToBatchFormat(mt_b, in1); + FixToBatchFormat(mt_out, *output); + + status = op.Run(handle, mt_out, mt_a, mt_b); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("MUSA BatchMatMul execution failed. 
Status: ", + (int)status); } + return Status::OK(); } @@ -554,6 +612,62 @@ struct EinsumHelper { return CopyFrom(input, output_shape, output); } + template + static Status MaterializeBroadcastedBatch( + OpKernelContext* ctx, const Tensor& input, int64_t input_batch_size, + int64_t output_batch_size, const std::vector& batch_indices, + Tensor* output) { + Tensor input_rank3; + TF_RETURN_IF_ERROR(ReshapeToRank3(input, static_cast(input_batch_size), + &input_rank3)); + + TensorShape output_shape = {output_batch_size, input_rank3.dim_size(1), + input_rank3.dim_size(2)}; + + if (input_batch_size == output_batch_size && batch_indices.empty()) { + return CopyFrom(input_rank3, output_shape, output); + } + + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + if (output->NumElements() == 0) return Status::OK(); + + const int64_t elems_per_batch = + input_rank3.dim_size(1) * input_rank3.dim_size(2); + std::vector host_input(input_rank3.NumElements()); + std::vector host_output(output->NumElements()); + + mStatus memcpy_status = + MusaMemcpyD2H(host_input.data(), input_rank3.flat().data(), + input_rank3.NumElements() * sizeof(T)); + if (memcpy_status != mStatus::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: MusaMemcpyD2H input failed"); + } + + for (int64_t out_batch = 0; out_batch < output_batch_size; ++out_batch) { + const int64_t in_batch = + batch_indices.empty() ? 
out_batch : batch_indices[out_batch]; + if (in_batch < 0 || in_batch >= input_batch_size) { + return errors::Internal("Einsum batch broadcast: invalid batch index ", + in_batch, " for input batch size ", + input_batch_size); + } + const T* src = host_input.data() + in_batch * elems_per_batch; + T* dst = host_output.data() + out_batch * elems_per_batch; + std::memcpy(dst, src, elems_per_batch * sizeof(T)); + } + + memcpy_status = MusaMemcpyH2D(output->flat().data(), host_output.data(), + output->NumElements() * sizeof(T)); + if (memcpy_status != mStatus::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: MusaMemcpyH2D output failed"); + } + + return Status::OK(); + } + // Contracts the inputs along the last axis (or the second last if the // corresponding value of swap_free_and_contract is true). The batch // dimensions are broadcast to the output shape. @@ -571,10 +685,16 @@ struct EinsumHelper { "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), " vs. ", inputs[1].shape().DebugString()); } + Tensor lhs; - TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs)); + TF_RETURN_IF_ERROR(MaterializeBroadcastedBatch( + ctx, inputs[0], bcast.x_batch_size(), bcast.output_batch_size(), + bcast.x_batch_indices(), &lhs)); Tensor rhs; - TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs)); + TF_RETURN_IF_ERROR(MaterializeBroadcastedBatch( + ctx, inputs[1], bcast.y_batch_size(), bcast.output_batch_size(), + bcast.y_batch_indices(), &rhs)); + TensorShape output_shape = bcast.output_batch_shape(); for (int i = 0; i < inputs.size(); ++i) { const int64_t free_axis = @@ -765,6 +885,8 @@ REGISTER_MUSA_EINSUM(float); REGISTER_MUSA_EINSUM(double); REGISTER_MUSA_EINSUM(int32); REGISTER_MUSA_EINSUM(int64); +REGISTER_MUSA_EINSUM(Eigen::half); +REGISTER_MUSA_EINSUM(bfloat16); #undef REGISTER_MUSA_EINSUM diff --git a/musa_ext/kernels/musa_einsum_op_half.cc b/musa_ext/kernels/musa_einsum_op_half.cc deleted file mode 
100644 index a948b1e..0000000 --- a/musa_ext/kernels/musa_einsum_op_half.cc +++ /dev/null @@ -1,862 +0,0 @@ -#include -#include -#include - -#include "../mu/device/musa_device.h" -#include "../utils/musa_einsum_op_util.h" -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_split.h" -#include "mu/device/musa_memcpy.h" -#include "musa_einsum_op.h" -#include "musa_fill_functor.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/util/matmul_bcast.h" -#include "utils_op.h" - -namespace tensorflow { -namespace musa { - -void DoTranspose(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, Tensor* output); - -using ShapeVec = gtl::InlinedVector; -using Labels = gtl::InlinedVector; -using OperandLabels = gtl::InlinedVector; -using LabelCounts = gtl::InlinedVector; -using OperandLabelCounts = gtl::InlinedVector; -using LabelToDimSizes = gtl::InlinedVector; - -struct EinsumHelper { - // Insert new (unnamed) broadcasting labels at the location of ellipsis. - static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels, - int ellipsis_axis, Labels* labels, - LabelCounts* label_counts) { - labels->erase(labels->begin() + ellipsis_axis); - labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0); - std::iota(labels->begin() + ellipsis_axis, - labels->begin() + ellipsis_axis + num_bcast_dims, - num_named_labels); - // Increment label counts. Since these are new labels, the count is set - // to 1. - label_counts->resize(num_named_labels + num_bcast_dims, 1); - } - - // Record and validate the label to dimension mapping. 
Must be a named - // (non-broadcasting) label as broadcasting labels don't have a fixed - // dimension. - static Status RecordLabelToDimension(const int label, const int axis, - const Tensor& input, - LabelToDimSizes* label_to_dim_sizes) { - const int64_t input_dim = input.dim_size(axis); - // We know that label_to_dim_sizes has the size to accommodate named labels. - if (label_to_dim_sizes->at(label) != 0 && - label_to_dim_sizes->at(label) != input_dim) { - return errors::InvalidArgument( - "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", - axis, " of the input shaped ", input.shape().DebugString(), - " but got dimension ", input_dim); - } - (*label_to_dim_sizes)[label] = input_dim; - return Status::OK(); - } - - // Validate input dimensions and populate unnamed labels and their label - // counts. - static Status ProcessDimensions( - const OpInputList& inputs, - const gtl::InlinedVector& input_has_ellipsis, - const bool output_has_ellipsis, OperandLabels* input_labels, - Labels* output_labels, std::vector* label_types, - OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, - LabelToDimSizes* label_to_dim_sizes) { - if (inputs.size() != input_labels->size()) { - return errors::InvalidArgument("Expected ", input_labels->size(), - " inputs but got: ", inputs.size()); - } - const int num_inputs = inputs.size(); - - // We infer the number of broadcasting dimensions by taking the maximum - // rank among the broadcasting subshapes of the input. 
- int max_bcast_dims = 0; - const int num_named_labels = label_types->size(); - label_to_dim_sizes->resize(num_named_labels); - for (int i = 0; i < num_inputs; ++i) { - Labels* labels = &(*input_labels)[i]; - - if (!input_has_ellipsis[i]) { - if (inputs[i].dims() != labels->size()) { - return errors::InvalidArgument("Expected input ", i, " to have rank ", - labels->size(), - " but got: ", inputs[i].dims()); - } - for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { - const int label = (*labels)[label_idx]; - TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i], - label_to_dim_sizes)); - } - continue; - } - - // Input has an ellipsis. - if (inputs[i].dims() + 1 < labels->size()) { - return errors::InvalidArgument( - "Expected input ", i, " to have rank at least ", labels->size() - 1, - " but got: ", inputs[i].dims()); - } - int ellipsis_axis = -1; - const int num_bcast_dims = inputs[i].dims() - labels->size() + 1; - for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { - const int label = (*labels)[label_idx]; - if (label == kEllipsisLabel) { - ellipsis_axis = label_idx; - continue; - } - // Current label is not an ellipsis. - const int axis = - label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1); - TF_RETURN_IF_ERROR( - RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes)); - } - // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting - // dimensions. - if (ellipsis_axis != -1) { - InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, - labels, &input_label_counts->at(i)); - max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims); - } - } - if (!absl::c_linear_search(input_has_ellipsis, true) && - !output_has_ellipsis) { - return Status::OK(); - } - // Insert broadcasting dimensions in the output labels. 
- auto it = - std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel); - if (it != output_labels->end()) { - const int ellipsis_axis = it - output_labels->begin(); - InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, - output_labels, output_label_counts); - } else if (max_bcast_dims > 0) { - return errors::InvalidArgument( - "Output contains ", max_bcast_dims, - " broadcasting dimension(s) but no ellipsis " - "(...) was found in the output subscripts."); - } - // Populate EinsumDimensionType for the new broadcasting labels. - label_types->resize(num_named_labels + max_bcast_dims, - EinsumDimensionType::kBroadcasting); - return Status::OK(); - } - - // Permutes the labels according to the given permutation. - static void PermuteLabels(const std::vector& permutation, - Labels* labels) { - Labels permuted_labels(labels->size()); - for (int i = 0; i < labels->size(); ++i) { - permuted_labels[i] = (*labels)[permutation[i]]; - } - labels->swap(permuted_labels); - } - - // Returns a reshaped input Tensor. The underlying buffer is not copied. - static Status CopyFrom(const Tensor& input, const TensorShape& shape, - Tensor* output) { - if (output->CopyFrom(input, shape)) return Status::OK(); - return errors::Internal( - "Encountered error while reshaping a Tensor of shape ", - input.shape().DebugString(), " to shape ", shape.DebugString()); - } - - // Returns whether transposing would be a no-op; whether input has rank < 2 or - // the permutation is the identity permutation. - static bool ShouldTranspose(const TensorShape& input_shape, - const std::vector& permutation) { - if (input_shape.dims() < 2) return false; - for (int i = 0; i < permutation.size(); ++i) { - if (permutation[i] != i) return true; - } - return false; - } - - // Transpose the input given a permutation. Returns a reference to the input - // if transposing is not necessary. 
- template - static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, - Tensor* output) { - if (!ShouldTranspose(input.shape(), permutation)) { - return CopyFrom(input, input.shape(), output); - } - TensorShape transposed_shape; - for (int i = 0; i < input.dims(); ++i) { - TF_RETURN_IF_ERROR( - transposed_shape.AddDimWithStatus(input.dim_size(permutation[i]))); - } - // For empty Tensors, just change the shape. E.g. we may need to transpose - // from shape [1, 0, 5] to [5, 1, 0]. - if (input.NumElements() == 0) { - return CopyFrom(input, transposed_shape, output); - } - TF_RETURN_IF_ERROR( - ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); - DoTranspose(ctx, input, permutation, output); - return Status::OK(); - } - - // If there are repeated labels in either the input or output, then this - // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. - template - static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, - const Labels& labels, - const LabelCounts& label_counts, - const bool should_inflate, Tensor* output) { - // Return early if there are no repeated indices. - if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { - return CopyFrom(input, input.shape(), output); - } - // We reshape so that each repeated label is compressed to one dimension. - // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27, - // 5]. Striding appropriately (in this case with strides 14 (=1+3+9) and 1) - // recovers the generalized diagonal of shape [3, 5]. - ShapeVec reshape; - ShapeVec strides; - // Strided and inflated shapes correspond to input and output shapes, - // respectively, should_inflate is true (vice-versa if should_inflate is - // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example. 
- ShapeVec strided_shape; - ShapeVec inflated_shape; - for (int label : labels) { - const int count = label_counts[label]; - const int current_axis = - should_inflate ? strided_shape.size() : inflated_shape.size(); - const int64_t dim = input.dim_size(current_axis); - strided_shape.push_back(dim); - inflated_shape.insert(inflated_shape.end(), count, dim); - const int64_t reshape_dim = MathUtil::IPow(dim, count); - reshape.push_back(reshape_dim); - // While taking the d-diagonal in a rank k Tensor, we take d - // equally-spaced elements including the first and last element. Then, (k - // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1). - const int64_t stride = - (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1; - strides.push_back(stride); - } - - const ShapeVec& output_shape_dims = - should_inflate ? inflated_shape : strided_shape; - TensorShape output_shape; - for (int64_t dim : output_shape_dims) { - TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); - } - TF_RETURN_IF_ERROR( - ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); - - const int rank = reshape.size(); - if (rank == 0) return Status::OK(); - - auto compute_dense_strides = [](const ShapeVec& dims) { - ShapeVec dense_strides(dims.size(), 1); - for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { - dense_strides[i] = dense_strides[i + 1] * dims[i + 1]; - } - return dense_strides; - }; - - const auto inflated_dense = compute_dense_strides(inflated_shape); - ShapeVec diagonal_strides; - diagonal_strides.reserve(rank); - int inflated_axis = 0; - for (int label : labels) { - const int count = label_counts[label]; - int64_t diagonal_stride = 0; - for (int i = 0; i < count; ++i) { - diagonal_stride += inflated_dense[inflated_axis + i]; - } - diagonal_strides.push_back(diagonal_stride); - inflated_axis += count; - } - - std::vector strided_dims_vec(strided_shape.begin(), - strided_shape.end()); - std::vector diagonal_strides_vec(diagonal_strides.begin(), - 
diagonal_strides.end()); - - auto& handle = GetHandleByCtx(ctx); - auto input_mt = CreateMTensor(input); - auto output_mt = CreateMTensor(*output); - - if (should_inflate) { - SetZeroFunctor set_zero; - set_zero(ctx, output); - output_mt.SetNdInfo(rank, strided_dims_vec.data(), - diagonal_strides_vec.data()); - } else { - input_mt.SetNdInfo(rank, strided_dims_vec.data(), - diagonal_strides_vec.data()); - } - - ::musa::dnn::Unary op; - auto status = op.SetMode(::musa::dnn::Unary::Mode::IDENTITY); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("Einsum StrideOrInflate SetMode failed. Status: ", - static_cast(status)); - } - status = op.Run(handle, output_mt, input_mt); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("Einsum StrideOrInflate Run failed. Status: ", - static_cast(status)); - } - return Status::OK(); - } - - // Returns true if the input dimensions are already sorted in the order - // [batch, contract, free, reduce]. Used to implement an optimization to avoid - // an extra transpose and instead uses (adj_x and adj_y) in BatchMatMul. - static bool ShouldSwapFreeAndContract( - const Labels& labels, - const std::vector& label_types) { - // Check that ordering is according to dimension type, with the role of - // free and contract dimensions swapped. 
- gtl::InlinedVector remap = {0, 1, 3, 2, 4}; - for (int i = 0; i + 1 < labels.size(); ++i) { - const int dimtype_a = remap[label_types[labels[i]]]; - const int dimtype_b = remap[label_types[labels[i + 1]]]; - if (dimtype_a > dimtype_b || - (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) { - return false; - } - } - return true; - } - - template - static Status BMatMul(OpKernelContext* ctx, Tensor& lhs, const Tensor& rhs, - bool trans_a, bool trans_b, Tensor* output) { - const Tensor& in0 = lhs; - const Tensor& in1 = rhs; - - int64 d0 = in0.dim_size(in0.dims() - 2); - int64 d1 = in0.dim_size(in0.dims() - 1); - int64 d2 = in1.dim_size(in1.dims() - 2); - int64 d3 = in1.dim_size(in1.dims() - 1); - - int64 m = trans_a ? d1 : d0; - int64 k = trans_a ? d0 : d1; - int64 n = trans_b ? d2 : d3; - int64 k_check = trans_b ? d3 : d2; - - if (k != k_check) { - return errors::InvalidArgument( - "Matrix size-incompatible: In[0] mismatch In[1]"); - } - - if (output->dims() < 2) { - return errors::Internal( - "Einsum output tensor rank must be at least 2, got ", output->dims()); - } - if (output->dim_size(output->dims() - 2) != m || - output->dim_size(output->dims() - 1) != n) { - return errors::Internal( - "Einsum output tensor shape mismatch, expected tail [", m, ", ", n, - "], got ", output->shape().DebugString()); - } - - if (output->NumElements() == 0) return Status::OK(); - - if (std::is_same::value || - std::is_same::value) { - const int64_t batch_a = in0.dims() == 2 ? 1 : in0.dim_size(0); - const int64_t batch_b = in1.dims() == 2 ? 1 : in1.dim_size(0); - const int64_t batch_out = output->dims() == 2 ? 1 : output->dim_size(0); - - const int64_t a_rows = - in0.dims() == 2 ? in0.dim_size(0) : in0.dim_size(1); - const int64_t a_cols = - in0.dims() == 2 ? in0.dim_size(1) : in0.dim_size(2); - const int64_t b_rows = - in1.dims() == 2 ? in1.dim_size(0) : in1.dim_size(1); - const int64_t b_cols = - in1.dims() == 2 ? 
in1.dim_size(1) : in1.dim_size(2); - - const int64_t elem_a = in0.NumElements(); - const int64_t elem_b = in1.NumElements(); - const int64_t elem_out = output->NumElements(); - - std::vector host_a(elem_a); - std::vector host_b(elem_b); - std::vector host_out(elem_out); - - mStatus memcpy_status = MusaMemcpyD2H(host_a.data(), in0.flat().data(), - elem_a * sizeof(T)); - if (memcpy_status != mStatus::SUCCESS) { - return errors::Internal("Einsum half path: MusaMemcpyD2H A failed"); - } - memcpy_status = MusaMemcpyD2H(host_b.data(), in1.flat().data(), - elem_b * sizeof(T)); - if (memcpy_status != mStatus::SUCCESS) { - return errors::Internal("Einsum half path: MusaMemcpyD2H B failed"); - } - - auto index_a = [&](int64_t batch, int64_t row, int64_t col) { - if (in0.dims() == 2) { - return row * a_cols + col; - } - return (batch * a_rows + row) * a_cols + col; - }; - auto index_b = [&](int64_t batch, int64_t row, int64_t col) { - if (in1.dims() == 2) { - return row * b_cols + col; - } - return (batch * b_rows + row) * b_cols + col; - }; - - for (int64_t bo = 0; bo < batch_out; ++bo) { - const int64_t ba = batch_a == 1 ? 0 : bo; - const int64_t bb = batch_b == 1 ? 0 : bo; - for (int64_t i = 0; i < m; ++i) { - for (int64_t j = 0; j < n; ++j) { - T sum = static_cast(0); - for (int64_t kk = 0; kk < k; ++kk) { - const int64_t ar = trans_a ? kk : i; - const int64_t ac = trans_a ? i : kk; - const int64_t br = trans_b ? j : kk; - const int64_t bc = trans_b ? 
kk : j; - const T av = host_a[index_a(ba, ar, ac)]; - const T bv = host_b[index_b(bb, br, bc)]; - sum = static_cast(sum + static_cast(av * bv)); - } - - if (output->dims() == 2) { - host_out[i * n + j] = sum; - } else { - host_out[(bo * m + i) * n + j] = sum; - } - } - } - } - - memcpy_status = MusaMemcpyH2D(output->flat().data(), host_out.data(), - elem_out * sizeof(T)); - if (memcpy_status != mStatus::SUCCESS) { - return errors::Internal( - "Einsum half path: MusaMemcpyH2D output failed"); - } - return Status::OK(); - } - - auto& handle = GetHandleByCtx(ctx); - handle.SetAllowTF32(false); - // Use TF32 setting if needed, but here we can just default or use env like - // matmul op. Since this is a static method, we don't have access to member - // tf32_enabled_. Let's check environment variable again or just assume - // default precision. For now, let's keep it simple and consistent with - // standard usage. - - mTensor mt_a = CreateMTensor(in0); - mTensor mt_b = CreateMTensor(in1); - mTensor mt_out = CreateMTensor(*output); - - auto FixToBatchFormat = [](mTensor& mt, const Tensor& t) { - if (t.dims() == 2) { - int64_t rows = t.dim_size(0); - int64_t cols = t.dim_size(1); - mt.SetNdInfo({1, rows, cols}, {rows * cols, cols, 1}); - } - }; - - ::musa::dnn::Status status; - - { - mBatchMatMul op; - status = op.SetDeterministic(true); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal( - "MUSA BatchMatMul SetDeterministic(true) failed. Status: ", - (int)status); - } - - if (std::is_same::value || - std::is_same::value) { - status = op.SetComputeMode(mBatchMatMul::ComputeMode::SCALAR); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal( - "MUSA BatchMatMul SetComputeMode(SCALAR) failed. Status: ", - (int)status); - } - status = op.SetMpCountTarget(1); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal( - "MUSA BatchMatMul SetMpCountTarget(1) failed. 
Status: ", - (int)status); - } - } - - op.SetTranspose(trans_a, trans_b); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - FixToBatchFormat(mt_a, in0); - FixToBatchFormat(mt_b, in1); - FixToBatchFormat(mt_out, *output); - - status = op.Run(handle, mt_out, mt_a, mt_b); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("MUSA BatchMatMul execution failed. Status: ", - (int)status); - } - } - return Status::OK(); - } - - template - static Status ReduceOperand( - OpKernelContext* ctx, const Tensor& input, - const std::vector& label_types, - const LabelCounts& label_counts, Labels* labels, Labels* free_labels, - bool* swap_free_and_contract, Tensor* output) { - // Find the permutation to transpose the input dimensions in the order of - // EinsumDimensionType; i.e. batch, free, contract and reduce dimensions. - // This makes it more convenient to invoke Reduce/Contract operations. - std::vector permutation(input.dims()); - absl::c_iota(permutation, 0); - Tensor input_transposed; - // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) - // flag during BatchMatMul. This is an extra optimization not necessary for - // correctness. - if (ShouldSwapFreeAndContract(*labels, label_types)) { - *swap_free_and_contract = true; - } else { - absl::c_sort(permutation, [&](int i, int j) { - int label_i = (*labels)[i]; - int label_j = (*labels)[j]; - return std::tie(label_types[label_i], label_i) < - std::tie(label_types[label_j], label_j); - }); - } - // Transpose the input so that EinsumDimensionTypes are in order. - TF_RETURN_IF_ERROR( - TransposeOperand(ctx, input, permutation, &input_transposed)); - PermuteLabels(permutation, labels); - - // Take the generalized diagonal for dimensions with repeated axis labels. 
- Tensor input_deduped; - labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); - TF_RETURN_IF_ERROR( - StrideOrInflate(ctx, input_transposed, *labels, label_counts, - false /* should_inflate */, &input_deduped)); - - // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, - // reduce] where we've compacted the dimensions of each EinsumDimensionType. - gtl::InlinedVector reshape(5, 1); - // The output shape is [batch shape] + [free size, contract size] - // That is, the batch shape is preserved (for broadcasting while - // contracting) while the free dims and contract dims are compressed to one - // dimension each. - TensorShape output_shape; - for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { - const int label = labels->at(label_idx); - int64_t dim = input_deduped.dim_size(label_idx); - if (label_types[label] == EinsumDimensionType::kBroadcasting || - label_types[label] == EinsumDimensionType::kBatch) { - TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); - } else if (label_types[label] == EinsumDimensionType::kFree) { - free_labels->push_back(label); - } - reshape[label_types[label]] *= dim; - } - if (*swap_free_and_contract) - std::swap(reshape[EinsumDimensionType::kFree], - reshape[EinsumDimensionType::kContract]); - TF_RETURN_IF_ERROR( - output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kFree])); - TF_RETURN_IF_ERROR( - output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kContract])); - - if (reshape[EinsumDimensionType::kReduce] == 1) { // No need to reduce. 
- return CopyFrom(input_deduped, output_shape, output); - } - TF_RETURN_IF_ERROR( - ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); - const int64_t reduce_size = reshape[kReduce]; - const int64_t output_size = reshape[kBroadcasting] * reshape[kBatch] * - reshape[kFree] * reshape[kContract]; - - TensorShape input_flatten_shape; - TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(output_size)); - TF_RETURN_IF_ERROR(input_flatten_shape.AddDimWithStatus(reduce_size)); - Tensor input_flattened; - if (!input_flattened - .BitcastFrom(input_deduped, input_deduped.dtype(), - input_flatten_shape) - .ok()) { - return errors::Internal("Failed to reshape Einsum input for reduce"); - } - - TensorShape output_flatten_shape; - TF_RETURN_IF_ERROR(output_flatten_shape.AddDimWithStatus(output_size)); - Tensor output_flattened; - if (!output_flattened - .BitcastFrom(*output, output->dtype(), output_flatten_shape) - .ok()) { - return errors::Internal("Failed to reshape Einsum output for reduce"); - } - - auto input_mt = CreateMTensor(input_flattened); - auto output_mt = CreateMTensor(output_flattened); - - auto& handle = GetHandleByCtx(ctx); - mReduce op; - op.SetMode(::musa::dnn::Reduce::Mode::ADD); - int reduce_dims[] = {1}; - op.SetDim(1, reduce_dims); - - tensorflow::Allocator* tf_allocator = - ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); - auto alloc_func = - [tf_allocator]( - size_t size) -> std::unique_ptr> { - void* ptr = tf_allocator->AllocateRaw(256, size); - std::function deleter = [tf_allocator](void* p) { - if (p) tf_allocator->DeallocateRaw(p); - }; - return std::unique_ptr>(ptr, deleter); - }; - ::musa::dnn::MemoryMaintainer mm(alloc_func); - - auto status = op.Run(handle, output_mt, input_mt, mm); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("MUSA Reduce (sum) execution failed. 
Status: ", - static_cast(status)); - } - return Status::OK(); - } - - // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. - static Status ReshapeToRank3(const Tensor& input, int batch_size, - Tensor* output) { - const int rank = input.dims(); - TensorShape output_shape = {batch_size, input.dim_size(rank - 2), - input.dim_size(rank - 1)}; - return CopyFrom(input, output_shape, output); - } - - // Contracts the inputs along the last axis (or the second last if the - // corresponding value of swap_free_and_contract is true). The batch - // dimensions are broadcast to the output shape. - template - static Status ContractOperands(OpKernelContext* ctx, - absl::Span inputs, - absl::Span swap_free_and_contract, - Tensor* output) { - if (inputs.size() == 1) - return CopyFrom(inputs[0], inputs[0].shape(), output); - MatMulBCast bcast(inputs[0].shape().dim_sizes(), - inputs[1].shape().dim_sizes()); - if (!bcast.IsValid()) { - return errors::InvalidArgument( - "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), - " vs. ", inputs[1].shape().DebugString()); - } - Tensor lhs; - TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs)); - Tensor rhs; - TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs)); - TensorShape output_shape = bcast.output_batch_shape(); - for (int i = 0; i < inputs.size(); ++i) { - const int64_t free_axis = - inputs[i].dims() - (swap_free_and_contract[i] ? 
1 : 2); - TF_RETURN_IF_ERROR( - output_shape.AddDimWithStatus(inputs[i].dim_size(free_axis))); - } - bool trans_x = swap_free_and_contract[0]; - bool trans_y = !swap_free_and_contract[1]; - TF_RETURN_IF_ERROR( - ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); - if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { - SetZeroFunctor set_zero; - set_zero(ctx, output); - return Status::OK(); - } - Tensor output_reshaped; - TF_RETURN_IF_ERROR( - ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); - return BMatMul(ctx, lhs, rhs, trans_x, trans_y, &output_reshaped); - } -}; - -template -class MusaEinsumOp : public MusaOpKernel { - public: - explicit MusaEinsumOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("equation", &equation_)); - OP_REQUIRES_OK( - ctx, ParseEinsumEquation(equation_, &input_labels_, &output_labels_, - &label_types_, &input_label_counts_, - &output_label_counts_, &input_has_ellipsis_, - &output_has_ellipsis_)); - } - - void Compute(OpKernelContext* ctx) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); - - // Take ...i,i->... as an example. After parsing the equation, we have - // input_labels = [kEllipsisLabel, 0] - // output_labels = [kEllipsisLabel] - // label_types = [${EinsumDimensionType for label 0, which is kContract}] - // input_label_counts = [1, 1], which use the default value of 1 for the - // every label. output_label_counts = [1] label_to_dim_sizes = {}, which - // will be populated during dimension processing. 
- OperandLabels input_labels(input_labels_); - Labels output_labels(output_labels_); - std::vector label_types(label_types_); - OperandLabelCounts input_label_counts(input_label_counts_); - LabelCounts output_label_counts(output_label_counts_); - LabelToDimSizes label_to_dim_sizes; - - OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions( - inputs, input_has_ellipsis_, output_has_ellipsis_, - &input_labels, &output_labels, &label_types, - &input_label_counts, &output_label_counts, - &label_to_dim_sizes)); - - // The reduction phase (a) sums across reduction dimensions, (b) takes - // generalized diagonals, and (c) reshapes it into shape - // [(broadcasting) batch shape] + [F,C] - // where F and C denote the total (compacted) size of free and contract - // dimensions, respectively. - const int num_inputs = inputs.size(); - OperandLabels free_labels(num_inputs); - gtl::InlinedVector inputs_reduced(num_inputs); - gtl::InlinedVector swap_free_and_contract(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - OP_REQUIRES_OK(ctx, - EinsumHelper::ReduceOperand( - ctx, inputs[i], label_types, input_label_counts[i], - &input_labels[i], &free_labels[i], - &swap_free_and_contract[i], &inputs_reduced[i])); - } - - // After reduction, the inputs should be reshaped to Tensors suitable for - // contraction. If num_inputs is 1, the reduced input is simply forwarded to - // the output. - Tensor contraction_output_reshaped; - OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( - ctx, inputs_reduced, swap_free_and_contract, - &contraction_output_reshaped)); - - // Copy the batch labels from the contraction output. Recover the batch - // shape, which may have been broadcasted. - TensorShape result_shape = contraction_output_reshaped.shape(); - result_shape.RemoveLastDims(2); - - int num_labels = label_types.size(); - Labels result_labels; - // All batch dimensions should be present in the contracted result. First - // the broadcasting dimensions, then the named batch dimensions. 
- for (int label = 0; label < num_labels; ++label) { - if (label_types[label] == EinsumDimensionType::kBroadcasting) - result_labels.push_back(label); - } - for (int label = 0; label < num_labels; ++label) { - if (label_types[label] == EinsumDimensionType::kBatch) - result_labels.push_back(label); - } - for (int i = 0; i < num_inputs; ++i) { - for (int label : free_labels[i]) { - result_labels.push_back(label); - OP_REQUIRES_OK( - ctx, result_shape.AddDimWithStatus(label_to_dim_sizes[label])); - } - } - - // Reshape the contraction (or reduction) result to its expanded shape: - // [(broadcasted) batch shape] + [free shape 0] + [free shape 1]. - Tensor contraction_output; - OP_REQUIRES_OK( - ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape, - &contraction_output)); - - // Inflate the output if necessary. (E.g. for the equation 'i->iii' which - // may arise while computing gradient of a regular Einsum). - Tensor output_inflated; - OP_REQUIRES_OK( - ctx, EinsumHelper::StrideOrInflate( - ctx, contraction_output, result_labels, output_label_counts, - true /* should_inflate */, &output_inflated)); - if (output_inflated.dims() > contraction_output.dims()) { - // We inflated the output. Modify result labels accordingly. - Labels inflated_labels; - for (int label : result_labels) { - inflated_labels.insert(inflated_labels.end(), - output_label_counts[label], label); - } - result_labels.swap(inflated_labels); - } - // Find the permutation to map the result labels to the output labels. Note - // that both the result and the final output may have the repeated labels, - // in which case the permutation preserves the left-to-right ordering. - // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the - // permutation should be [0, 2, 1]. We also use the fact that repeated - // labels in the result are adjacent to each other. 
- std::vector output_permutation(output_labels.size()); - std::vector label_to_position(num_labels, -1); - for (int i = 0; i < result_labels.size(); ++i) { - // Remember the position of only the leftmost result label. - if (label_to_position[result_labels[i]] == -1) { - label_to_position[result_labels[i]] = i; - } - } - for (int i = 0; i < output_labels.size(); ++i) { - output_permutation[i] = label_to_position[output_labels[i]]; - // We have found the leftmost occurrence. The next one would be adjacent. - label_to_position[output_labels[i]] += 1; - } - Tensor output; - OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( - ctx, output_inflated, output_permutation, &output)); - ctx->set_output(0, output); - } - - string TraceString(const OpKernelContext& ctx, bool verbose) const override { - string op = profiler::TraceMeOp(name_view(), type_string_view()); - string equation = strings::StrCat("(", equation_, ")"); - if (verbose) { - string shape = ShapeTraceString(ctx); - if (!shape.empty()) { - return profiler::TraceMeEncode( - std::move(op), {{"equation", equation}, {"shape", shape}}); - } - } - return profiler::TraceMeEncode(std::move(op), {{"equation", equation}}); - } - - private: - string equation_; - OperandLabels input_labels_; - Labels output_labels_; - std::vector label_types_; - OperandLabelCounts input_label_counts_; - LabelCounts output_label_counts_; - gtl::InlinedVector input_has_ellipsis_; - bool output_has_ellipsis_ = false; -}; // class MusaEinsumOp - -#define REGISTER_MUSA_EINSUM(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Einsum").Device("MUSA").TypeConstraint("T"), \ - MusaEinsumOp); - -REGISTER_MUSA_EINSUM(Eigen::half); -REGISTER_MUSA_EINSUM(bfloat16); - -#undef REGISTER_MUSA_EINSUM - -} // namespace musa -} // namespace tensorflow diff --git a/test/einsum_op_test.py b/test/einsum_op_test.py index b40bd35..779a062 100644 --- a/test/einsum_op_test.py +++ b/test/einsum_op_test.py @@ -1,19 +1,4 @@ -# Copyright 2026 The TensorFlow MUSA Authors. 
All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for the MUSA Einsum operator.""" +"""Tests for MUSA Einsum operator.""" import numpy as np import tensorflow as tf @@ -22,67 +7,115 @@ class EinsumOpTest(MUSATestCase): - """Tests for the MUSA Einsum operator.""" - - def _random_inputs(self, shapes, dtype): - """Generate random inputs for the requested shapes and dtype.""" - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - return [ - tf.constant( - np.random.uniform(-1.0, 1.0, size=shape).astype(np_dtype), - dtype=dtype) - for shape in shapes - ] - - def _test_einsum(self, equation, shapes, dtype, rtol=1e-5, atol=1e-8): - """Compare CPU vs MUSA for the given einsum equation.""" - inputs = self._random_inputs(shapes, dtype) - op = lambda *tensors: tf.einsum(equation, *tensors) - self._compare_cpu_musa_results(op, inputs, dtype, rtol=rtol, atol=atol) - - def testMatrixMultiplication(self): - """Matrix multiplication with explicit contraction indices.""" - equation = "ij,jk->ik" - shapes = [(128, 64), (64, 96)] - for dtype in [tf.float32, tf.float16, tf.bfloat16]: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) - - def testBatchBroadcastContraction(self): - """Batch contraction with broadcasting 
over leading dims.""" - equation = "bij,jk->bik" - shapes = [(4, 16, 32), (32, 64)] - for dtype in [tf.float32, tf.float16, tf.bfloat16]: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) - - def testDiagonalAndBroadcast(self): - """Repeated indices that take diagonals and broadcast shapes.""" - equation = "iij,ij->ij" - shapes = [(4, 4, 6), (4, 6)] - for dtype in [tf.float32, tf.float16, tf.bfloat16]: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) - - def testEllipsisBroadcast(self): - """Ellipsis handling with mixed-rank operands.""" - equation = "...i,i->..." - shapes = [(2, 3, 5), (5,)] - for dtype in [tf.float32, tf.float16, tf.bfloat16]: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - self._test_einsum(equation, shapes, dtype, rtol=rtol, atol=atol) - - def testMultipleSummations(self): - """Multiple contraction indices with more than two inputs.""" - equation = "abc,acd,db->bd" - shapes = [(3, 4, 5), (3, 5, 6), (6, 4)] - for dtype in [tf.float32, tf.float16, tf.bfloat16]: - self._test_einsum(equation, shapes, dtype) + """Tests for MUSA Einsum operator with TensorFlow-compatible behavior.""" + + def _run_cpu_musa(self, equation, inputs, rtol=1e-5, atol=1e-8): + """Run einsum on CPU and MUSA then compare outputs.""" + with tf.device('/CPU:0'): + cpu_result = tf.einsum(equation, *inputs) + + with tf.device('/device:MUSA:0'): + musa_result = tf.einsum(equation, *inputs) + + if cpu_result.dtype in [tf.float16, tf.bfloat16]: + cpu_result = tf.cast(cpu_result, tf.float32) + musa_result = tf.cast(musa_result, tf.float32) + + self.assertAllEqual(cpu_result.shape, musa_result.shape) + 
self.assertAllClose(cpu_result.numpy(), musa_result.numpy(), rtol=rtol, atol=atol) + + def _assert_error_consistency(self, equation, inputs): + """Assert CPU and MUSA both raise TensorFlow exceptions for invalid inputs.""" + cpu_error = None + musa_error = None + + try: + with tf.device('/CPU:0'): + tf.einsum(equation, *inputs) + except Exception as e: # pylint: disable=broad-except + cpu_error = e + + try: + with tf.device('/device:MUSA:0'): + tf.einsum(equation, *inputs) + except Exception as e: # pylint: disable=broad-except + musa_error = e + + self.assertIsNotNone(cpu_error, "CPU should raise for invalid einsum input") + self.assertIsNotNone(musa_error, "MUSA should raise for invalid einsum input") + self.assertEqual(type(cpu_error), type(musa_error)) + + def testEinsumMatrixMultiplication(self): + """Test classic matrix multiplication: ij,jk->ik.""" + for dtype in [tf.float32, tf.float16, tf.bfloat16]: + a_np = np.random.uniform(-1.0, 1.0, size=(8, 16)).astype(np.float32) + b_np = np.random.uniform(-1.0, 1.0, size=(16, 10)).astype(np.float32) + a = tf.constant(a_np, dtype=dtype) + b = tf.constant(b_np, dtype=dtype) + + rtol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-6 + self._run_cpu_musa('ij,jk->ik', [a, b], rtol=rtol, atol=atol) + + def testEinsumBatchMatMul(self): + """Test batched matmul: bij,bjk->bik.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 6, 8)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 8, 5)).astype(np.float32)) + self._run_cpu_musa('bij,bjk->bik', [a, b]) + + def testEinsumBroadcastWithEllipsis(self): + """Test ellipsis broadcasting: ...ij,...jk->...ik.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(2, 3, 4, 6)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(1, 3, 6, 5)).astype(np.float32)) + self._run_cpu_musa('...ij,...jk->...ik', [a, b]) + + def testEinsumImplicitOutput(self): + """Test 
implicit output mode: ij,jk (without ->).""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(7, 11)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(11, 9)).astype(np.float32)) + self._run_cpu_musa('ij,jk', [a, b]) + + def testEinsumTranspose(self): + """Test axis permutation with single input: ijk->kji.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(3, 5, 7)).astype(np.float32)) + self._run_cpu_musa('ijk->kji', [x]) + + def testEinsumDiagonalExtraction(self): + """Test repeated indices for diagonal extraction: ii->i.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(9, 9)).astype(np.float32)) + self._run_cpu_musa('ii->i', [x]) + + def testEinsumReduction(self): + """Test reduction to scalar: ij->.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(13, 17)).astype(np.float32)) + self._run_cpu_musa('ij->', [x]) + + def testEinsumOuterProduct(self): + """Test outer product: i,j->ij.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(12,)).astype(np.float32)) + y = tf.constant(np.random.uniform(-1.0, 1.0, size=(8,)).astype(np.float32)) + self._run_cpu_musa('i,j->ij', [x, y]) + + def testEinsumThreeOperands(self): + """Test multi-operand contraction: ab,bc,cd->ad.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 6)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 1.0, size=(6, 5)).astype(np.float32)) + c = tf.constant(np.random.uniform(-1.0, 1.0, size=(5, 3)).astype(np.float32)) + self._run_cpu_musa('ab,bc,cd->ad', [a, b, c]) + + def testEinsumInvalidEquation(self): + """Invalid equation should raise consistently on CPU and MUSA.""" + x = tf.constant(np.random.uniform(-1.0, 1.0, size=(2, 2)).astype(np.float32)) + self._assert_error_consistency('ij->ik', [x]) + + def testEinsumMismatchedDimensions(self): + """Dimension mismatch should raise consistently on CPU and MUSA.""" + a = tf.constant(np.random.uniform(-1.0, 1.0, size=(4, 5)).astype(np.float32)) + b = tf.constant(np.random.uniform(-1.0, 
1.0, size=(6, 3)).astype(np.float32)) + self._assert_error_consistency('ij,jk->ik', [a, b]) if __name__ == "__main__": - tf.test.main() + np.random.seed(2026) + tf.random.set_seed(2026) + tf.test.main() From 443d5655b3818aaf1f2930d08154df79f10f10fb Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 07:23:43 +0000 Subject: [PATCH 11/16] clean codes --- musa_ext/kernels/musa_einsum_op.cc | 86 +----------------------------- 1 file changed, 1 insertion(+), 85 deletions(-) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 20f278e..402c26a 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -337,6 +337,7 @@ struct EinsumHelper { return true; } + // TODO: BMatMul seems to perform worse when the input use half precision. template static Status BMatMul(OpKernelContext* ctx, Tensor& lhs, const Tensor& rhs, bool trans_a, bool trans_b, Tensor* output) { @@ -355,91 +356,6 @@ struct EinsumHelper { if (output->NumElements() == 0) return Status::OK(); - // // USE CPU to compute half and bfloat16 bMatMul - // if (std::is_same::value || - // std::is_same::value) { - // const int64_t batch_a = in0.dims() == 2 ? 1 : in0.dim_size(0); - // const int64_t batch_b = in1.dims() == 2 ? 1 : in1.dim_size(0); - // const int64_t batch_out = output->dims() == 2 ? 1 : - // output->dim_size(0); - - // const int64_t a_rows = - // in0.dims() == 2 ? in0.dim_size(0) : in0.dim_size(1); - // const int64_t a_cols = - // in0.dims() == 2 ? in0.dim_size(1) : in0.dim_size(2); - // const int64_t b_rows = - // in1.dims() == 2 ? in1.dim_size(0) : in1.dim_size(1); - // const int64_t b_cols = - // in1.dims() == 2 ? 
in1.dim_size(1) : in1.dim_size(2); - - // const int64_t elem_a = in0.NumElements(); - // const int64_t elem_b = in1.NumElements(); - // const int64_t elem_out = output->NumElements(); - - // std::vector host_a(elem_a); - // std::vector host_b(elem_b); - // std::vector host_out(elem_out); - - // mStatus memcpy_status = MusaMemcpyD2H(host_a.data(), - // in0.flat().data(), - // elem_a * sizeof(T)); - // if (memcpy_status != mStatus::SUCCESS) { - // return errors::Internal("Einsum half path: MusaMemcpyD2H A failed"); - // } - // memcpy_status = MusaMemcpyD2H(host_b.data(), in1.flat().data(), - // elem_b * sizeof(T)); - // if (memcpy_status != mStatus::SUCCESS) { - // return errors::Internal("Einsum half path: MusaMemcpyD2H B failed"); - // } - - // auto index_a = [&](int64_t batch, int64_t row, int64_t col) { - // if (in0.dims() == 2) { - // return row * a_cols + col; - // } - // return (batch * a_rows + row) * a_cols + col; - // }; - // auto index_b = [&](int64_t batch, int64_t row, int64_t col) { - // if (in1.dims() == 2) { - // return row * b_cols + col; - // } - // return (batch * b_rows + row) * b_cols + col; - // }; - - // for (int64_t bo = 0; bo < batch_out; ++bo) { - // const int64_t ba = batch_a == 1 ? 0 : bo; - // const int64_t bb = batch_b == 1 ? 0 : bo; - // for (int64_t i = 0; i < m; ++i) { - // for (int64_t j = 0; j < n; ++j) { - // T sum = static_cast(0); - // for (int64_t kk = 0; kk < k; ++kk) { - // const int64_t ar = trans_a ? kk : i; - // const int64_t ac = trans_a ? i : kk; - // const int64_t br = trans_b ? j : kk; - // const int64_t bc = trans_b ? 
kk : j; - // const T av = host_a[index_a(ba, ar, ac)]; - // const T bv = host_b[index_b(bb, br, bc)]; - // sum = static_cast(sum + static_cast(av * bv)); - // } - - // if (output->dims() == 2) { - // host_out[i * n + j] = sum; - // } else { - // host_out[(bo * m + i) * n + j] = sum; - // } - // } - // } - // } - - // memcpy_status = MusaMemcpyH2D(output->flat().data(), - // host_out.data(), - // elem_out * sizeof(T)); - // if (memcpy_status != mStatus::SUCCESS) { - // return errors::Internal( - // "Einsum half path: MusaMemcpyH2D output failed"); - // } - // return Status::OK(); - // } - auto& handle = GetHandleByCtx(ctx); handle.SetAllowTF32(false); // Use TF32 setting if needed, but here we can just default or use env like From 354a3d1617df4fd39982a3c3d886a06abf8b2f3f Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 07:35:47 +0000 Subject: [PATCH 12/16] optimize broadcast with mudnn --- musa_ext/kernels/musa_einsum_op.cc | 136 +++++++++++++++++++---------- 1 file changed, 90 insertions(+), 46 deletions(-) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 402c26a..6d34c80 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -1,6 +1,5 @@ #include "musa_einsum_op.h" -#include #include #include #include @@ -9,7 +8,6 @@ #include "../utils/musa_einsum_op_util.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" -#include "mu/device/musa_memcpy.h" #include "musa_fill_functor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -339,8 +337,9 @@ struct EinsumHelper { // TODO: BMatMul seems to perform worse when the input use half precision. 
template - static Status BMatMul(OpKernelContext* ctx, Tensor& lhs, const Tensor& rhs, - bool trans_a, bool trans_b, Tensor* output) { + static Status BMatMul(OpKernelContext* ctx, const Tensor& lhs, + const Tensor& rhs, bool trans_a, bool trans_b, + Tensor* output) { const Tensor& in0 = lhs; const Tensor& in1 = rhs; @@ -530,57 +529,96 @@ struct EinsumHelper { template static Status MaterializeBroadcastedBatch( - OpKernelContext* ctx, const Tensor& input, int64_t input_batch_size, - int64_t output_batch_size, const std::vector& batch_indices, - Tensor* output) { - Tensor input_rank3; - TF_RETURN_IF_ERROR(ReshapeToRank3(input, static_cast(input_batch_size), - &input_rank3)); + OpKernelContext* ctx, const Tensor& input, + const TensorShape& output_batch_shape, Tensor* output) { + const int input_rank = input.dims(); + if (input_rank < 2) { + return errors::InvalidArgument( + "Einsum batch broadcast expects rank >= 2, got rank ", input_rank); + } - TensorShape output_shape = {output_batch_size, input_rank3.dim_size(1), - input_rank3.dim_size(2)}; + const int input_batch_rank = input_rank - 2; + const int output_batch_rank = output_batch_shape.dims(); + if (output_batch_rank < input_batch_rank) { + return errors::Internal( + "Einsum batch broadcast: output batch rank ", output_batch_rank, + " is smaller than input batch rank ", input_batch_rank); + } - if (input_batch_size == output_batch_size && batch_indices.empty()) { - return CopyFrom(input_rank3, output_shape, output); + TensorShape output_shape = output_batch_shape; + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(input.dim_size(input_rank - 2))); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(input.dim_size(input_rank - 1))); + + if (input.shape() == output_shape) { + return CopyFrom(input, output_shape, output); } TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); if (output->NumElements() == 0) return Status::OK(); - const int64_t elems_per_batch = - 
input_rank3.dim_size(1) * input_rank3.dim_size(2); - std::vector host_input(input_rank3.NumElements()); - std::vector host_output(output->NumElements()); - - mStatus memcpy_status = - MusaMemcpyD2H(host_input.data(), input_rank3.flat().data(), - input_rank3.NumElements() * sizeof(T)); - if (memcpy_status != mStatus::SUCCESS) { - return errors::Internal( - "Einsum batch broadcast: MusaMemcpyD2H input failed"); + std::vector input_dense_strides(input_rank, 1); + for (int axis = input_rank - 2; axis >= 0; --axis) { + input_dense_strides[axis] = + input_dense_strides[axis + 1] * input.dim_size(axis + 1); } - for (int64_t out_batch = 0; out_batch < output_batch_size; ++out_batch) { - const int64_t in_batch = - batch_indices.empty() ? out_batch : batch_indices[out_batch]; - if (in_batch < 0 || in_batch >= input_batch_size) { - return errors::Internal("Einsum batch broadcast: invalid batch index ", - in_batch, " for input batch size ", - input_batch_size); + const int target_rank = output_batch_rank + 2; + std::vector target_dims(target_rank, 1); + std::vector target_strides(target_rank, 0); + const int batch_axis_offset = output_batch_rank - input_batch_rank; + + for (int out_axis = 0; out_axis < output_batch_rank; ++out_axis) { + const int64_t out_dim = output_batch_shape.dim_size(out_axis); + target_dims[out_axis] = out_dim; + + const int in_axis = out_axis - batch_axis_offset; + if (in_axis < 0) { + target_strides[out_axis] = 0; + continue; } - const T* src = host_input.data() + in_batch * elems_per_batch; - T* dst = host_output.data() + out_batch * elems_per_batch; - std::memcpy(dst, src, elems_per_batch * sizeof(T)); + + const int64_t in_dim = input.dim_size(in_axis); + if (in_dim != out_dim && in_dim != 1) { + return errors::Internal( + "Einsum batch broadcast: incompatible batch dim at axis ", out_axis, + ", input dim ", in_dim, ", output dim ", out_dim); + } + target_strides[out_axis] = + (in_dim == 1 && out_dim != 1) ? 
0 : input_dense_strides[in_axis]; } - memcpy_status = MusaMemcpyH2D(output->flat().data(), host_output.data(), - output->NumElements() * sizeof(T)); - if (memcpy_status != mStatus::SUCCESS) { + target_dims[output_batch_rank] = input.dim_size(input_rank - 2); + target_dims[output_batch_rank + 1] = input.dim_size(input_rank - 1); + target_strides[output_batch_rank] = input_dense_strides[input_rank - 2]; + target_strides[output_batch_rank + 1] = input_dense_strides[input_rank - 1]; + + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); + auto status = input_mt.SetNdInfo(target_rank, target_dims.data(), + target_strides.data()); + if (status != ::musa::dnn::Status::SUCCESS) { return errors::Internal( - "Einsum batch broadcast: MusaMemcpyH2D output failed"); + "Einsum batch broadcast: SetNdInfo failed. Status: ", + static_cast(status)); } + auto& handle = GetHandleByCtx(ctx); + ::musa::dnn::Unary op; + status = op.SetMode(::musa::dnn::Unary::Mode::IDENTITY); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: Unary SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(handle, output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal( + "Einsum batch broadcast: Unary Run failed. Status: ", + static_cast(status)); + } return Status::OK(); } @@ -602,16 +640,22 @@ struct EinsumHelper { " vs. 
", inputs[1].shape().DebugString()); } - Tensor lhs; + const TensorShape output_batch_shape = bcast.output_batch_shape(); + Tensor lhs_broadcasted; TF_RETURN_IF_ERROR(MaterializeBroadcastedBatch( - ctx, inputs[0], bcast.x_batch_size(), bcast.output_batch_size(), - bcast.x_batch_indices(), &lhs)); - Tensor rhs; + ctx, inputs[0], output_batch_shape, &lhs_broadcasted)); + Tensor rhs_broadcasted; TF_RETURN_IF_ERROR(MaterializeBroadcastedBatch( - ctx, inputs[1], bcast.y_batch_size(), bcast.output_batch_size(), - bcast.y_batch_indices(), &rhs)); + ctx, inputs[1], output_batch_shape, &rhs_broadcasted)); + + Tensor lhs; + TF_RETURN_IF_ERROR( + ReshapeToRank3(lhs_broadcasted, bcast.output_batch_size(), &lhs)); + Tensor rhs; + TF_RETURN_IF_ERROR( + ReshapeToRank3(rhs_broadcasted, bcast.output_batch_size(), &rhs)); - TensorShape output_shape = bcast.output_batch_shape(); + TensorShape output_shape = output_batch_shape; for (int i = 0; i < inputs.size(); ++i) { const int64_t free_axis = inputs[i].dims() - (swap_free_and_contract[i] ? 
1 : 2); From 094fbab380c1aa2705a33d1dfd4f5f820050348b Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 09:25:40 +0000 Subject: [PATCH 13/16] reinforce support for bf16, which would be convert into float when reducing (which bf16 is bad at) --- musa_ext/kernels/musa_einsum_op.cc | 43 ++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 6d34c80..731f197 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -184,6 +184,32 @@ struct EinsumHelper { return false; } + static Status CastTensor(OpKernelContext* ctx, const Tensor& input, + DataType dst_dtype, Tensor* output) { + if (input.dtype() == dst_dtype) { + return CopyFrom(input, input.shape(), output); + } + + TF_RETURN_IF_ERROR(ctx->allocate_temp(dst_dtype, input.shape(), output)); + if (input.NumElements() == 0) return Status::OK(); + + auto input_mt = CreateMTensor(input); + auto output_mt = CreateMTensor(*output); + + ::musa::dnn::Unary op; + auto status = op.SetMode(::musa::dnn::Unary::Mode::CAST); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(GetHandleByCtx(ctx), output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor Run failed. Status: ", + static_cast(status)); + } + return Status::OK(); + } + // Transpose the input given a permutation. Returns a reference to the input // if transposing is not necessary. 
template @@ -678,6 +704,23 @@ struct EinsumHelper { } }; +template <> +Status EinsumHelper::ReduceOperand( + OpKernelContext* ctx, const Tensor& input, + const std::vector& label_types, + const LabelCounts& label_counts, Labels* labels, Labels* free_labels, + bool* swap_free_and_contract, Tensor* output) { + Tensor input_fp32; + TF_RETURN_IF_ERROR(CastTensor(ctx, input, DT_FLOAT, &input_fp32)); + + Tensor output_fp32; + TF_RETURN_IF_ERROR( + ReduceOperand(ctx, input_fp32, label_types, label_counts, labels, + free_labels, swap_free_and_contract, &output_fp32)); + + return CastTensor(ctx, output_fp32, DataTypeToEnum::value, output); +} + template class MusaEinsumOp : public MusaOpKernel { public: From e27a6a5b7159c6613009da6a1e6017e871ab2178 Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 10:26:19 +0000 Subject: [PATCH 14/16] Format: implement reduce & fill functor --- musa_ext/kernels/musa_cast_functor.h | 29 ++++++ musa_ext/kernels/musa_einsum_op.cc | 53 ++--------- musa_ext/kernels/musa_fill_functor.h | 37 +++++++- musa_ext/kernels/musa_fill_op.cc | 120 ------------------------- musa_ext/kernels/musa_reduce_functor.h | 70 +++++++++++++++ 5 files changed, 139 insertions(+), 170 deletions(-) create mode 100644 musa_ext/kernels/musa_cast_functor.h delete mode 100755 musa_ext/kernels/musa_fill_op.cc create mode 100644 musa_ext/kernels/musa_reduce_functor.h diff --git a/musa_ext/kernels/musa_cast_functor.h b/musa_ext/kernels/musa_cast_functor.h new file mode 100644 index 0000000..527e74e --- /dev/null +++ b/musa_ext/kernels/musa_cast_functor.h @@ -0,0 +1,29 @@ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +static Status CastTensor(OpKernelContext* ctx, const mTensor& input_mt, + DataType dst_dtype, mTensor* output_mt) { + ::musa::dnn::Unary op; + auto status = 
op.SetMode(::musa::dnn::Unary::Mode::CAST); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor SetMode failed. Status: ", + static_cast(status)); + } + status = op.Run(GetHandleByCtx(ctx), *output_mt, input_mt); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal("Einsum CastTensor Run failed. Status: ", + static_cast(status)); + } + return Status::OK(); +} + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 731f197..4185413 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -1,6 +1,5 @@ #include "musa_einsum_op.h" -#include #include #include @@ -9,6 +8,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" #include "musa_fill_functor.h" +#include "musa_reduce_functor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -318,8 +318,7 @@ struct EinsumHelper { auto output_mt = CreateMTensor(*output); if (should_inflate) { - SetZeroFunctor set_zero; - set_zero(ctx, output); + SetZeroFunctor::Compute(ctx, &output_mt); output_mt.SetNdInfo(rank, strided_dims_vec.data(), diagonal_strides_vec.data()); } else { @@ -517,31 +516,10 @@ struct EinsumHelper { auto input_mt = CreateMTensor(input_flattened); auto output_mt = CreateMTensor(output_flattened); - auto& handle = GetHandleByCtx(ctx); - mReduce op; - op.SetMode(::musa::dnn::Reduce::Mode::ADD); int reduce_dims[] = {1}; - op.SetDim(1, reduce_dims); - - tensorflow::Allocator* tf_allocator = - ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); - auto alloc_func = - [tf_allocator]( - size_t size) -> std::unique_ptr> { - void* ptr = tf_allocator->AllocateRaw(256, size); - std::function deleter = [tf_allocator](void* p) { - if (p) tf_allocator->DeallocateRaw(p); - }; - return 
std::unique_ptr>(ptr, deleter); - }; - ::musa::dnn::MemoryMaintainer mm(alloc_func); - - auto status = op.Run(handle, output_mt, input_mt, mm); - if (status != ::musa::dnn::Status::SUCCESS) { - return errors::Internal("MUSA Reduce (sum) execution failed. Status: ", - static_cast(status)); - } - return Status::OK(); + return ReduceFunctor::Compute( + ctx, &output_mt, &input_mt, ::musa::dnn::Reduce::Mode::ADD, reduce_dims, + 1, "MUSA Reduce (sum) execution failed. Status: "); } // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. @@ -693,8 +671,8 @@ struct EinsumHelper { TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { - SetZeroFunctor set_zero; - set_zero(ctx, output); + mTensor output_mt = CreateMTensor(*output); + TF_RETURN_IF_ERROR(SetZeroFunctor::Compute(ctx, &output_mt)); return Status::OK(); } Tensor output_reshaped; @@ -704,23 +682,6 @@ struct EinsumHelper { } }; -template <> -Status EinsumHelper::ReduceOperand( - OpKernelContext* ctx, const Tensor& input, - const std::vector& label_types, - const LabelCounts& label_counts, Labels* labels, Labels* free_labels, - bool* swap_free_and_contract, Tensor* output) { - Tensor input_fp32; - TF_RETURN_IF_ERROR(CastTensor(ctx, input, DT_FLOAT, &input_fp32)); - - Tensor output_fp32; - TF_RETURN_IF_ERROR( - ReduceOperand(ctx, input_fp32, label_types, label_counts, labels, - free_labels, swap_free_and_contract, &output_fp32)); - - return CastTensor(ctx, output_fp32, DataTypeToEnum::value, output); -} - template class MusaEinsumOp : public MusaOpKernel { public: diff --git a/musa_ext/kernels/musa_fill_functor.h b/musa_ext/kernels/musa_fill_functor.h index 23072dc..a3095d2 100644 --- a/musa_ext/kernels/musa_fill_functor.h +++ b/musa_ext/kernels/musa_fill_functor.h @@ -7,19 +7,48 @@ */ +#include + #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include 
"tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "utils_op.h" namespace tensorflow { namespace musa { template -Status MusaFillCall(Tensor* out, T value, OpKernelContext* context); +Status MusaFillCall(mTensor* out_mt, T value, OpKernelContext* context) { + mFill op; + mHandle& h = GetHandleByCtx(context); + + if (std::is_integral::value) { + if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { + return errors::Internal("mtdnn set value (int) error!"); + } + } else if (std::is_floating_point::value || + std::is_same::value || + std::is_same::value) { + if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { + return errors::Internal("mtdnn set value (float) error!"); + } + } else { + return errors::Unimplemented("Data type not supported in MTGPU Fill."); + } + + if (mStatus::SUCCESS != op.Run(h, *out_mt)) { + return errors::Internal("mtdnn run op error!"); + } + + return Status::OK(); +} -template struct SetZeroFunctor { // Computes on device "d": out = out.setZero(), - void operator()(OpKernelContext* ctx, Tensor* out) { - MusaFillCall(out, T(0), ctx); + template + static Status Compute(OpKernelContext* ctx, mTensor* out_mt) { + return MusaFillCall(out_mt, T(0), ctx); } }; diff --git a/musa_ext/kernels/musa_fill_op.cc b/musa_ext/kernels/musa_fill_op.cc deleted file mode 100755 index 86ed3e0..0000000 --- a/musa_ext/kernels/musa_fill_op.cc +++ /dev/null @@ -1,120 +0,0 @@ -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "utils_op.h" - -namespace tensorflow { -namespace musa { - -namespace { - -template -struct is_any : std::false_type {}; - -template -struct is_any : std::is_same {}; - -template -struct is_any - : std::integral_constant::value || - is_any::value> {}; - 
-} // namespace - -template -Status MusaFillCall(Tensor* out, T value, OpKernelContext* context) { - mFill op; - mHandle& h = GetHandleByCtx(context); - auto out_mt = CreateMTensor(*out); - - if (is_any::value) { - if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { - return errors::Internal("mtdnn set value (int) error!"); - } - } else if (is_any::value) { - if (mStatus::SUCCESS != op.SetValue(static_cast(value))) { - return errors::Internal("mtdnn set value (float) error!"); - } - } else { - return errors::Unimplemented("Data type not supported in MTGPU Fill."); - } - - if (mStatus::SUCCESS != op.Run(h, out_mt)) { - return errors::Internal("mtdnn run op error!"); - } - - return Status::OK(); -} - -template -class MusaFillOp : public MusaOpKernel { - public: - explicit MusaFillOp(OpKernelConstruction* context) : MusaOpKernel(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor& Tdims = context->input(0); - const Tensor& Tvalue = context->input(1); - - OP_REQUIRES( - context, - (TensorShapeUtils::IsVector(Tdims.shape()) || - TensorShapeUtils::IsScalar(Tdims.shape())), - errors::InvalidArgument("dims must represent a vector, got shape ", - Tdims.shape().DebugString())); - - OP_REQUIRES( - context, - TensorShapeUtils::IsScalar(Tvalue.shape()) || - (TensorShapeUtils::IsVector(Tvalue.shape()) && - Tvalue.shape().dim_size(0) == 1), - errors::InvalidArgument("value must represent a scalar, got shape ", - Tvalue.shape().DebugString())); - - auto dims_vec = Tdims.flat(); - TensorShape shape; - OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - reinterpret_cast(dims_vec.data()), - dims_vec.size(), &shape)); - - Tensor* out = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out)); - - if (shape.num_elements() == 0) return; - - OP_REQUIRES_OK( - context, MusaFillCall(out, static_cast(Tvalue.data())[0], context)); - } -}; - -#define REGISTER_FILL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("Fill") \ - 
.Device("MUSA") \ - .TypeConstraint("T") \ - .TypeConstraint("index_type") \ - .HostMemory("dims") \ - .HostMemory("value"), \ - MusaFillOp); \ - REGISTER_KERNEL_BUILDER(Name("Fill") \ - .Device("MUSA") \ - .TypeConstraint("T") \ - .TypeConstraint("index_type") \ - .HostMemory("dims") \ - .HostMemory("value"), \ - MusaFillOp); - -REGISTER_FILL_KERNEL(float); -REGISTER_FILL_KERNEL(double); -REGISTER_FILL_KERNEL(int32); -REGISTER_FILL_KERNEL(int64); -REGISTER_FILL_KERNEL(Eigen::half); -REGISTER_FILL_KERNEL(Eigen::bfloat16); -REGISTER_FILL_KERNEL(bool); - -#undef REGISTER_FILL_KERNEL - -} // namespace musa -} // namespace tensorflow diff --git a/musa_ext/kernels/musa_reduce_functor.h b/musa_ext/kernels/musa_reduce_functor.h new file mode 100644 index 0000000..04ae40e --- /dev/null +++ b/musa_ext/kernels/musa_reduce_functor.h @@ -0,0 +1,70 @@ +#ifndef MUSA_PLUGIN_SRC_KERNELS_MUSA_REDUCE_FUNCTOR_H_ +#define MUSA_PLUGIN_SRC_KERNELS_MUSA_REDUCE_FUNCTOR_H_ + +#include +#include + +#include "musa_cast_functor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +struct ReduceFunctor { + template + static Status Compute(OpKernelContext* ctx, mTensor* output, mTensor* input, + ::musa::dnn::Reduce::Mode mode, const int* reduce_dims, + int reduce_dim_count, const char* error_prefix) { + auto& handle = GetHandleByCtx(ctx); + + mReduce op; + op.SetMode(mode); + op.SetDim(reduce_dim_count, reduce_dims); + + tensorflow::Allocator* tf_allocator = + ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + auto alloc_func = + [tf_allocator]( + size_t size) -> std::unique_ptr> { + void* ptr = tf_allocator->AllocateRaw(256, size); + std::function deleter = [tf_allocator](void* p) { + if (p) tf_allocator->DeallocateRaw(p); + }; + return std::unique_ptr>(ptr, deleter); + }; + ::musa::dnn::MemoryMaintainer mm(alloc_func); 
+ + auto status = op.Run(handle, *output, *input, mm); + if (status != ::musa::dnn::Status::SUCCESS) { + return errors::Internal(error_prefix, static_cast(status)); + } + return Status::OK(); + } +}; + +template <> +Status ReduceFunctor::Compute(OpKernelContext* ctx, + mTensor* output_mt, mTensor* input_mt, + ::musa::dnn::Reduce::Mode mode, + const int* reduce_dims, + int reduce_dim_count, + const char* error_prefix) { + mTensor input_fp32; + TF_RETURN_IF_ERROR(CastTensor(ctx, *input_mt, DT_FLOAT, &input_fp32)); + + mTensor output_fp32; + TF_RETURN_IF_ERROR(Compute(ctx, &output_fp32, &input_fp32, mode, + reduce_dims, reduce_dim_count, + error_prefix)); + + return CastTensor(ctx, output_fp32, DataTypeToEnum::value, + output_mt); +} + +} // namespace musa +} // namespace tensorflow + +#endif // MUSA_PLUGIN_SRC_KERNELS_MUSA_REDUCE_FUNCTOR_H_ \ No newline at end of file From 2c2ed1041a8f6acae134002e9373d3ee7e66118d Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 10:41:11 +0000 Subject: [PATCH 15/16] clean comments & implement transpose functor --- musa_ext/kernels/musa_einsum_op.cc | 15 ++++-------- musa_ext/kernels/musa_reduce_functor.h | 4 ++++ musa_ext/kernels/musa_transpose_functor.h | 28 +++++++++++++++++++++++ musa_ext/kernels/musa_transpose_op.cc | 28 +++++------------------ 4 files changed, 43 insertions(+), 32 deletions(-) create mode 100644 musa_ext/kernels/musa_transpose_functor.h diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 4185413..5ff2852 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -22,8 +22,8 @@ namespace tensorflow { namespace musa { -void DoTranspose(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, Tensor* output); +void DoTranspose(OpKernelContext* ctx, mTensor& in_mt, + const std::vector& permutation, mTensor& out_mt); using ShapeVec = gtl::InlinedVector; using Labels = gtl::InlinedVector; @@ -231,7 +231,9 @@ struct 
EinsumHelper { } TF_RETURN_IF_ERROR( ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); - DoTranspose(ctx, input, permutation, output); + mTensor input_mt = CreateMTensor(input); + mTensor output_mt = CreateMTensor(*output); + DoTranspose(ctx, input_mt, permutation, output_mt); return Status::OK(); } @@ -698,13 +700,6 @@ class MusaEinsumOp : public MusaOpKernel { OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); - // Take ...i,i->... as an example. After parsing the equation, we have - // input_labels = [kEllipsisLabel, 0] - // output_labels = [kEllipsisLabel] - // label_types = [${EinsumDimensionType for label 0, which is kContract}] - // input_label_counts = [1, 1], which use the default value of 1 for the - // every label. output_label_counts = [1] label_to_dim_sizes = {}, which - // will be populated during dimension processing. OperandLabels input_labels(input_labels_); Labels output_labels(output_labels_); std::vector label_types(label_types_); diff --git a/musa_ext/kernels/musa_reduce_functor.h b/musa_ext/kernels/musa_reduce_functor.h index 04ae40e..ef9902d 100644 --- a/musa_ext/kernels/musa_reduce_functor.h +++ b/musa_ext/kernels/musa_reduce_functor.h @@ -45,6 +45,10 @@ struct ReduceFunctor { } }; +// Given the fact that bf16 does not work well with reduction, we compute the +// reduction in fp32 and cast the result back to bf16. +// This conversion is aligned with tensorflow's convention of promoting bf16 to +// fp32 for ReduceFunctor. 
template <> Status ReduceFunctor::Compute(OpKernelContext* ctx, mTensor* output_mt, mTensor* input_mt, diff --git a/musa_ext/kernels/musa_transpose_functor.h b/musa_ext/kernels/musa_transpose_functor.h new file mode 100644 index 0000000..c2019c7 --- /dev/null +++ b/musa_ext/kernels/musa_transpose_functor.h @@ -0,0 +1,28 @@ +#include + +#include "utils_op.h" + +namespace tensorflow { +namespace musa { + +void DoTranspose(OpKernelContext* ctx, mTensor& in_mt, + const std::vector& permutation, mTensor& out_mt) { + mHandle& h = GetHandleByCtx(ctx); + + ::musa::dnn::Permute pop; + + if (::musa::dnn::Status::SUCCESS != + pop.ConfigDimStride(out_mt, in_mt, static_cast(permutation.size()), + permutation.data())) { + ctx->CtxFailure(errors::Internal("muDNN Permute ConfigDimStride failed!")); + return; + } + + if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { + ctx->CtxFailure(errors::Internal("muDNN Permute Run failed!")); + return; + } +} + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_transpose_op.cc b/musa_ext/kernels/musa_transpose_op.cc index 648f25d..c6c11b6 100644 --- a/musa_ext/kernels/musa_transpose_op.cc +++ b/musa_ext/kernels/musa_transpose_op.cc @@ -2,6 +2,7 @@ #include +#include "musa_transpose_functor.h" #include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -13,27 +14,8 @@ namespace tensorflow { namespace musa { -void DoTranspose(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, Tensor* output) { - mHandle& h = GetHandleByCtx(ctx); - - mTensor in_mt = CreateMTensor(input); - mTensor out_mt = CreateMTensor(*output); - - ::musa::dnn::Permute pop; - - if (::musa::dnn::Status::SUCCESS != - pop.ConfigDimStride(out_mt, in_mt, static_cast(permutation.size()), - permutation.data())) { - ctx->CtxFailure(errors::Internal("muDNN Permute ConfigDimStride failed!")); - 
return; - } - - if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { - ctx->CtxFailure(errors::Internal("muDNN Permute Run failed!")); - return; - } -} +void DoTranspose(OpKernelContext* ctx, mTensor& in_mt, + const std::vector& permutation, mTensor& out_mt); template class MusaTransposeOp : public MusaOpKernel { @@ -97,7 +79,9 @@ class MusaTransposeOp : public MusaOpKernel { if (output->NumElements() == 0) return; - DoTranspose(ctx, input, permutation_64, output); + mTensor input_mt = CreateMTensor(input); + mTensor output_mt = CreateMTensor(*output); + DoTranspose(ctx, input_mt, permutation_64, output_mt); } }; From e777e0d5d70942fd48432923812e758f122f7c57 Mon Sep 17 00:00:00 2001 From: albert Date: Fri, 27 Feb 2026 10:49:01 +0000 Subject: [PATCH 16/16] format project --- musa_ext/kernels/musa_einsum_op.cc | 7 ++---- musa_ext/kernels/musa_transpose_functor.h | 30 ++++++++++++----------- musa_ext/kernels/musa_transpose_op.cc | 9 ++++--- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/musa_ext/kernels/musa_einsum_op.cc b/musa_ext/kernels/musa_einsum_op.cc index 5ff2852..3ed69ae 100644 --- a/musa_ext/kernels/musa_einsum_op.cc +++ b/musa_ext/kernels/musa_einsum_op.cc @@ -9,6 +9,7 @@ #include "absl/strings/str_split.h" #include "musa_fill_functor.h" #include "musa_reduce_functor.h" +#include "musa_transpose_functor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -22,9 +23,6 @@ namespace tensorflow { namespace musa { -void DoTranspose(OpKernelContext* ctx, mTensor& in_mt, - const std::vector& permutation, mTensor& out_mt); - using ShapeVec = gtl::InlinedVector; using Labels = gtl::InlinedVector; using OperandLabels = gtl::InlinedVector; @@ -233,8 +231,7 @@ struct EinsumHelper { ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); mTensor input_mt = CreateMTensor(input); mTensor output_mt = CreateMTensor(*output); - 
DoTranspose(ctx, input_mt, permutation, output_mt); - return Status::OK(); + return TransposeFunctor::Compute(ctx, input_mt, permutation, output_mt); } // If there are repeated labels in either the input or output, then this diff --git a/musa_ext/kernels/musa_transpose_functor.h b/musa_ext/kernels/musa_transpose_functor.h index c2019c7..950169d 100644 --- a/musa_ext/kernels/musa_transpose_functor.h +++ b/musa_ext/kernels/musa_transpose_functor.h @@ -5,24 +5,26 @@ namespace tensorflow { namespace musa { -void DoTranspose(OpKernelContext* ctx, mTensor& in_mt, - const std::vector& permutation, mTensor& out_mt) { - mHandle& h = GetHandleByCtx(ctx); +struct TransposeFunctor { + static Status Compute(OpKernelContext* ctx, mTensor& in_mt, + const std::vector& permutation, + mTensor& out_mt) { + mHandle& h = GetHandleByCtx(ctx); - ::musa::dnn::Permute pop; + ::musa::dnn::Permute pop; - if (::musa::dnn::Status::SUCCESS != - pop.ConfigDimStride(out_mt, in_mt, static_cast(permutation.size()), - permutation.data())) { - ctx->CtxFailure(errors::Internal("muDNN Permute ConfigDimStride failed!")); - return; - } + if (::musa::dnn::Status::SUCCESS != + pop.ConfigDimStride(out_mt, in_mt, static_cast(permutation.size()), + permutation.data())) { + return errors::Internal("muDNN Permute ConfigDimStride failed!"); + } - if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { - ctx->CtxFailure(errors::Internal("muDNN Permute Run failed!")); - return; + if (::musa::dnn::Status::SUCCESS != pop.Run(h, out_mt, in_mt)) { + return errors::Internal("muDNN Permute Run failed!"); + } + return Status::OK(); } -} +}; } // namespace musa } // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_transpose_op.cc b/musa_ext/kernels/musa_transpose_op.cc index c6c11b6..b4a0c9c 100644 --- a/musa_ext/kernels/musa_transpose_op.cc +++ b/musa_ext/kernels/musa_transpose_op.cc @@ -14,9 +14,6 @@ namespace tensorflow { namespace musa { -void DoTranspose(OpKernelContext* 
ctx, mTensor& in_mt, - const std::vector& permutation, mTensor& out_mt); - template class MusaTransposeOp : public MusaOpKernel { public: @@ -81,7 +78,11 @@ class MusaTransposeOp : public MusaOpKernel { mTensor input_mt = CreateMTensor(input); mTensor output_mt = CreateMTensor(*output); - DoTranspose(ctx, input_mt, permutation_64, output_mt); + Status status = + TransposeFunctor::Compute(ctx, input_mt, permutation_64, output_mt); + if (!status.ok()) { + ctx->CtxFailure(status); + } } };