From 6bf9819330e18bfd9faa9d5fc5de1e6102c0041d Mon Sep 17 00:00:00 2001 From: moneta Date: Mon, 16 Mar 2026 17:05:48 +0100 Subject: [PATCH] [tmva][sofie] Add new ScatterND operator Add an implementation of ScatterND operator which is needed to parse the MLPF model from CMS Include also 3 tests to probe the different types of scattering which can be performed --- tmva/sofie/CMakeLists.txt | 1 + tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx | 192 ++++++++++++++++++ tmva/sofie/test/TestCustomModelsFromONNX.cxx | 54 +++++ tmva/sofie/test/input_models/ScatterND_1.onnx | 21 ++ tmva/sofie/test/input_models/ScatterND_2.onnx | 22 ++ tmva/sofie/test/input_models/ScatterND_3.onnx | 22 ++ tmva/sofie_parsers/CMakeLists.txt | 1 + tmva/sofie_parsers/src/ParseScatterND.cxx | 58 ++++++ tmva/sofie_parsers/src/RModelParser_ONNX.cxx | 2 + 9 files changed, 373 insertions(+) create mode 100644 tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx create mode 100644 tmva/sofie/test/input_models/ScatterND_1.onnx create mode 100644 tmva/sofie/test/input_models/ScatterND_2.onnx create mode 100644 tmva/sofie/test/input_models/ScatterND_3.onnx create mode 100644 tmva/sofie_parsers/src/ParseScatterND.cxx diff --git a/tmva/sofie/CMakeLists.txt b/tmva/sofie/CMakeLists.txt index dc44ac0a59af2..6fdc7a46183ee 100644 --- a/tmva/sofie/CMakeLists.txt +++ b/tmva/sofie/CMakeLists.txt @@ -65,6 +65,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie TMVA/ROperator_Einsum.hxx TMVA/ROperator_Random.hxx TMVA/ROperator_ScatterElements.hxx + TMVA/ROperator_ScatterND.hxx TMVA/ROperator_Gather.hxx TMVA/ROperator_GatherND.hxx TMVA/ROperator_NonZero.hxx diff --git a/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx new file mode 100644 index 0000000000000..570b3f7a294aa --- /dev/null +++ b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx @@ -0,0 +1,192 @@ +#ifndef TMVA_SOFIE_ROPERATOR_ScatterND +#define TMVA_SOFIE_ROPERATOR_ScatterND + +#include "TMVA/SOFIE_common.hxx" +#include 
"TMVA/ROperator.hxx" +#include "TMVA/RModel.hxx" + +#include +#include +#include + +namespace TMVA{ +namespace Experimental{ +namespace SOFIE{ + +class ROperator_ScatterND final : public ROperator +{ +private: + + + std::string fNX; + std::string fNI; + std::string fNU; + std::string fNY; + std::string fReduction; + + std::vector fShapeX; + std::vector fShapeI; + std::vector fShapeY; + + + std::vector fIndices; // indices vector in case they are known at initialization + + std::string fType; + + +public: + ROperator_ScatterND(){} + ROperator_ScatterND(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY, + std::string reduction): + fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)), + fNY(UTILITY::Clean_name(nameY)), fReduction(reduction) + { + fInputTensorNames = { fNX, fNI, fNU }; + fOutputTensorNames = { fNY }; + } + + void Initialize(RModel& model) override { + + // input must be a graph input, or already initialized intermediate tensor + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNX + "is not found in model"); + } + if (!model.CheckIfTensorAlreadyExist(fNI)) { + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNI + "is not found in model"); + } + if (!model.CheckIfTensorAlreadyExist(fNU)) { + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNU + "is not found in model"); + } + //tbd check for constant tensors + + fShapeX = model.GetDimTensorShape(fNX); + fShapeI = model.GetDimTensorShape(fNI); + auto shapeU = model.GetDimTensorShape(fNU); + + // Validate inputs if fShapeI last is not dynamic + + //if (!model.IsDynamicTensor(fNI)) { + const size_t r = fShapeX.size(); // rank of data + const size_t q = fShapeI.size(); // rank of indices + if (!(fShapeI.back().isParam) ) { + const size_t k = fShapeI.back().dim; // 
index depth + + if (k > r) + throw std::invalid_argument( + "ScatterND: last dim of indices (" + std::to_string(k) + + ") must be <= rank of data (" + std::to_string(r) + ")"); + + // Expected updates rank = q - 1 + r - k + int64_t expected_updates_rank = q - 1 + r - k; + if ((int64_t) shapeU.size() != expected_updates_rank) + throw std::invalid_argument("ScatterND: updates rank mismatch"); + } else { + // Assumption is that last dimension of index shape is known (is not dynamic) + throw std::runtime_error("TMVA SOFIE ScatterND : Index_shape(-1) is not known. This case is not supported"); + } + + // output shape is equal to input shape + fShapeY = fShapeX; + + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY); + if (model.Verbose()) { + std::cout << "ScatterElements: input: " << ConvertDimShapeToString(fShapeX) + << " indices " << ConvertDimShapeToString(fShapeI) + << " update " << ConvertDimShapeToString(shapeU); + std::cout << "\t----> " << ConvertDimShapeToString(fShapeY) << std::endl; + } + } + + std::string Generate(std::string opName) override { + if (fIsOutputConstant) { + // no code to generate here for constant output. 
Tensor output is defined in Session constructor + return "//---------------------------------------\n"; + } + opName = "op_" + opName; + std::stringstream out; + out << "//--------- ScatterND " << opName << " --> " << ConvertDimShapeToString(fShapeY) << "\n"; + + size_t r = fShapeX.size(); + + // Strides + auto stridesX = UTILITY::ComputeStrideFromShape(fShapeX); + auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY); + auto stridesI = UTILITY::ComputeStrideFromShape(fShapeI); + + // case input_index_shape == rank of input + size_t k = fShapeI.back().dim; + + // Total number of index tuples = product of indices dims except last + std::vector shapeIndFirst(fShapeI.begin(), fShapeI.begin()+ fShapeI.size()-1); + auto num_index_tuples = ConvertDimShapeToLength(shapeIndFirst); + + //slice size (is product of input from k to r) + std::vector shapeSlice(fShapeX.begin()+k, fShapeX.end()); + auto slice_size = ConvertDimShapeToLength(shapeSlice); + + auto data_length = ConvertDimShapeToLength(fShapeX); + + //step1: input->output + out << SP << "// Step 1: copy input data to output\n"; + out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << data_length << ", tensor_" << fNY << ");\n"; + + // Step 2: Emit strides as a static constexpr array + out << SP << "// Step 2: data strides (row-major)\n"; + out << SP << "static constexpr int64_t " << opName << "_data_strides[" << r << "] = {"; + for (size_t i = 0; i < r; ++i) + out << stridesX[i] << (i + 1 < r ? 
", " : ""); + out << "};\n\n"; + + // Step 3: Scatter loop + out << SP << "// Step 3: scatter updates into output\n"; + out << SP << "for (int64_t idx = 0; idx < " << num_index_tuples << "; idx++) {\n"; + + // Resolve flat data offset from k-dimensional index tuple + out << SP << SP << "int64_t data_offset = 0;\n"; + for (size_t dim = 0; dim < k; ++dim) { + out << SP << SP << "{\n"; + out << SP << SP << SP << "int64_t coord = tensor_" << fNI + << "[idx * " << k << " + " << dim << "];\n"; + // Support negative indices + out << SP << SP << SP << "if (coord < 0) coord += " << fShapeX[dim] << ";\n"; + out << SP << SP << SP << "data_offset += coord * " + << opName << "_data_strides[" << dim << "];\n"; + out << SP << SP << "}\n"; + } + + // Apply updates with reduction + out << SP << SP << "for (int64_t s = 0; s < " << slice_size << "; s++) {\n"; + out << SP << SP << SP << "auto upd = tensor_" << fNU + << "[idx * " << slice_size << " + s];\n"; + + if (fReduction.empty() || fReduction == "none") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = upd;\n"; + } else if (fReduction == "add") { + out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] += upd;\n"; + } else if (fReduction == "mul") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] *= upd;\n"; + } else if (fReduction == "min") { + out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] = " + << "std::min(tensor_" << fNY << "[data_offset + s], upd);\n"; + } else if (fReduction == "max") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = " + << "std::max(tensor_" << fNY << "[data_offset + s], upd);\n"; + } else { + throw std::runtime_error( + "TMVA SOFIE ScatterND: unsupported reduction '" + fReduction + "'"); + } + + out << SP << SP << "}\n"; // end slice loop + out << SP << "}\n"; // end index tuple loop + + return out.str(); + } + +}; + +}//SOFIE +}//Experimental +}//TMVA + + +#endif //TMVA_SOFIE_ROPERATOR_RELU diff --git 
a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index 94993f601a3c4..25bf2350c5a61 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -3006,3 +3006,57 @@ TEST(ONNX, NotIsNaN) } } +TEST(ONNX, ScatterND_1) +{ + // test 1-D scatter (k=1, scalar slice) + std::vector input = {1.,2.,3.,4.,5.}; // shape {5} + std::vector indices = { 0, 2, 4}; // shape {3,1} + std::vector updates = { 10.,30.,50.}; // shape {3} + std::vector correct_output = {10., 2., 30., 4., 50.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_1", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + +TEST(ONNX, ScatterND_2) +{ + // test 2-d Scatter - scatter rows - reduction = 'add + std::vector input = {1.,1.,2.,2.,3.,3.}; // shape {3,2} + std::vector indices = { 0, 1}; // shape {2,1} + std::vector updates = { 10.,10.,20.,20.}; // shape { 2,2} + std::vector correct_output = {11., 11., 22., 22., 3., 3.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_2", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + +TEST(ONNX, ScatterND_3) +{ + // test element wise scatter (k==rank input) reduction = 'mul' + std::vector input = {1.,2.,3.,4.}; // shape {2,2} + std::vector indices = { 0,0, 1,1}; // shape {2,2} + std::vector updates = { 11.,22.}; // shape { 2} + std::vector correct_output = {11., 2., 3., 88.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_3", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; 
i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + diff --git a/tmva/sofie/test/input_models/ScatterND_1.onnx b/tmva/sofie/test/input_models/ScatterND_1.onnx new file mode 100644 index 0000000000000..6e6bd2b58c0f7 --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_1.onnx @@ -0,0 +1,21 @@ +  onnx-example:” ++ +data +indices +updatesoutput" ScatterND TestGraphZ +data + + +Z +indices +  + +Z +updates + + +b +output + + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/ScatterND_2.onnx b/tmva/sofie/test/input_models/ScatterND_2.onnx new file mode 100644 index 0000000000000..9211d555dffda --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_2.onnx @@ -0,0 +1,22 @@ +  onnx-example:µ +@ +data +indices +updatesoutput" ScatterND* + reduction"add  TestGraphZ +data +  + +Z +indices +  + +Z +updates +  + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/ScatterND_3.onnx b/tmva/sofie/test/input_models/ScatterND_3.onnx new file mode 100644 index 0000000000000..20d83a7dd1715 --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_3.onnx @@ -0,0 +1,22 @@ +  onnx-example:± +@ +data +indices +updatesoutput" ScatterND* + reduction"mul  TestGraphZ +data +  + +Z +indices +  + +Z +updates + + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt index 80069c44c6929..5dc49688dfa6f 100644 --- a/tmva/sofie_parsers/CMakeLists.txt +++ b/tmva/sofie_parsers/CMakeLists.txt @@ -76,6 +76,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser src/ParseEinsum.cxx src/ParseRandom.cxx src/ParseScatterElements.cxx + src/ParseScatterND.cxx src/ParseNonZero.cxx src/ParseNot.cxx ${PROTO_SRCS} diff --git a/tmva/sofie_parsers/src/ParseScatterND.cxx b/tmva/sofie_parsers/src/ParseScatterND.cxx new file mode 100644 index 0000000000000..feda091182c63 --- /dev/null +++ 
b/tmva/sofie_parsers/src/ParseScatterND.cxx @@ -0,0 +1,58 @@ +#include "TMVA/RModelParser_ONNX.hxx" +#include "TMVA/ROperator_ScatterND.hxx" +#include "onnx_proto3.pb.h" + +namespace TMVA { +namespace Experimental { +namespace SOFIE { + +ParserFuncSignature ParseScatterND = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + + if (nodeproto.input_size() != 3) { + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has invalid input size"); + } + // data is input 0 + if (!parser.IsRegisteredTensorType(nodeproto.input(0))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(0) + + " but its type is not yet registered"); + } + if (!parser.IsRegisteredTensorType(nodeproto.input(1))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(1) + + " but its type is not yet registered"); + } + if (!parser.IsRegisteredTensorType(nodeproto.input(2))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(2) + + " but its type is not yet registered"); + } + ETensorType input_type = parser.GetTensorType(nodeproto.input(0)); + if (parser.GetTensorType(nodeproto.input(2)) != input_type) { + throw std::runtime_error("TMVA::SOFIE ONNX parser ScatterND op has input tensors of different types: " + + nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) + + " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type)); + } + + std::string reduction; + for (int i = 0; i < nodeproto.attribute_size(); i++) { + std::string attribute_name = nodeproto.attribute(i).name(); + if (attribute_name == "reduction") + reduction = nodeproto.attribute(i).s(); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + op.reset(new ROperator_ScatterND(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), + output_name, reduction)); + + // Infer the output 
type + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + + +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index b77451da25c5b..6090b8a0799c6 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -94,6 +94,7 @@ extern ParserFuncSignature ParseWhere; extern ParserFuncSignature ParseEinsum; extern ParserFuncSignature ParseRandom; extern ParserFuncSignature ParseScatterElements; +extern ParserFuncSignature ParseScatterND; extern ParserFuncSignature ParseNonZero; // Declaration of fused operators extern ParserFuseFuncSignature ParseFuseConvAdd; @@ -250,6 +251,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("RandomUniform", ParseRandom); RegisterOperator("RandomUniformLike", ParseRandom); RegisterOperator("ScatterElements", ParseScatterElements); + RegisterOperator("ScatterND", ParseScatterND); RegisterOperator("NonZero", ParseNonZero); }