diff --git a/tmva/sofie/test/TestRModelParserPyTorch.C b/tmva/sofie/test/TestRModelParserPyTorch.C index 4dbe36b91d983..00a73fc912199 100644 --- a/tmva/sofie/test/TestRModelParserPyTorch.C +++ b/tmva/sofie/test/TestRModelParserPyTorch.C @@ -172,3 +172,96 @@ TEST(RModelParser_PyTorch, CONVOLUTION_MODEL) EXPECT_LE(std::abs(outputConv[i] - pOutputConv[i]), TOLERANCE); } } + +TEST(RModelParser_PyTorch, ACTIVATION_MODEL) +{ + constexpr float TOLERANCE = 1.E-3; + std::vector inputActivation ={-1.6207, 0.6133, + 0.5058, -1.2560, + -0.7750, -1.6701, + 0.8171, -0.2858}; + + Py_Initialize(); + if (gSystem->AccessPathName("PyTorchModelActivation.pt",kFileExists)) + GenerateModels(); + + std::vector inputTensorShapeActivation{2,4}; + std::vector> inputShapesActivation{inputTensorShapeActivation}; + TMVA::Experimental::RSofieReader s("PyTorchModelActivation.pt", inputShapesActivation); + std::vector outputActivation = s.Compute(inputActivation); + + + PyObject* main = PyImport_AddModule("__main__"); + PyObject* fGlobalNS = PyModule_GetDict(main); + PyObject* fLocalNS = PyDict_New(); + if (!fGlobalNS) { + throw std::runtime_error("Can't init global namespace for Python"); + } + if (!fLocalNS) { + throw std::runtime_error("Can't init local namespace for Python"); + } + PyRun_String("import torch",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("model=torch.jit.load('PyTorchModelActivation.pt')",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("input=torch.reshape(torch.FloatTensor([-1.6207, 0.6133," + " 0.5058, -1.2560," + "-0.7750, -1.6701," + " 0.8171, -0.2858]),(2,4))",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("output=model(input).detach().numpy().reshape(2,6)",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("outputSize=output.size",Py_single_input,fGlobalNS,fLocalNS); + std::size_t pOutputActivationSize=(std::size_t)PyLong_AsLong(PyDict_GetItemString(fLocalNS,"outputSize")); + + //Testing the actual and expected output tensor sizes + 
EXPECT_EQ(outputActivation.size(), pOutputActivationSize); + + PyArrayObject* pActivationValues=(PyArrayObject*)PyDict_GetItemString(fLocalNS,"output"); + float* pOutputActivation=(float*)PyArray_DATA(pActivationValues); + + //Testing the actual and expected output tensor values + for (size_t i = 0; i < outputActivation.size(); ++i) { + EXPECT_LE(std::abs(outputActivation[i] - pOutputActivation[i]), TOLERANCE); + } +} + +TEST(RModelParser_PyTorch, BATCHNORM_MODEL) +{ + constexpr float TOLERANCE = 1.E-3; + std::vector inputBatchNorm(2*4*3); + std::iota(inputBatchNorm.begin(), inputBatchNorm.end(), 1.0f); + + Py_Initialize(); + if (gSystem->AccessPathName("PyTorchModelBatchNorm.pt",kFileExists)) + GenerateModels(); + + std::vector inputTensorShapeBatchNorm{2, 4, 3}; + std::vector> inputShapesBatchNorm{inputTensorShapeBatchNorm}; + TMVA::Experimental::RSofieReader s("PyTorchModelBatchNorm.pt", inputShapesBatchNorm); + std::vector outputBatchNorm = s.Compute(inputBatchNorm); + + + PyObject* main = PyImport_AddModule("__main__"); + PyObject* fGlobalNS = PyModule_GetDict(main); + PyObject* fLocalNS = PyDict_New(); + if (!fGlobalNS) { + throw std::runtime_error("Can't init global namespace for Python"); + } + if (!fLocalNS) { + throw std::runtime_error("Can't init local namespace for Python"); + } + PyRun_String("import torch",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("model=torch.jit.load('PyTorchModelBatchNorm.pt')",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("input=torch.arange(1,25,dtype=torch.float).reshape(2,4,3)",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("output=model(input).detach().numpy().reshape(2,6)",Py_single_input,fGlobalNS,fLocalNS); + PyRun_String("outputSize=output.size",Py_single_input,fGlobalNS,fLocalNS); + std::size_t pOutputBatchNormSize=(std::size_t)PyLong_AsLong(PyDict_GetItemString(fLocalNS,"outputSize")); + + //Testing the actual and expected output tensor sizes + EXPECT_EQ(outputBatchNorm.size(), pOutputBatchNormSize); 
+ + PyArrayObject* pBatchNormValues=(PyArrayObject*)PyDict_GetItemString(fLocalNS,"output"); + float* pOutputBatchNorm=(float*)PyArray_DATA(pBatchNormValues); + + //Testing the actual and expected output tensor values + for (size_t i = 0; i < outputBatchNorm.size(); ++i) { + EXPECT_LE(std::abs(outputBatchNorm[i] - pOutputBatchNorm[i]), TOLERANCE); + } +} diff --git a/tmva/sofie/test/generatePyTorchModels.py b/tmva/sofie/test/generatePyTorchModels.py index 1f9f3bbbb6c6d..b55d0800f021e 100644 --- a/tmva/sofie/test/generatePyTorchModels.py +++ b/tmva/sofie/test/generatePyTorchModels.py @@ -104,6 +104,83 @@ def generateConvolutionModel(): m = torch.jit.script(model) torch.jit.save(m,"PyTorchModelConvolution.pt") + +def generateActivationModel(): + # Model using Tanh, LeakyReLU, and Softmax activations + model = nn.Sequential( + nn.Linear(4,8), + nn.Tanh(), + nn.Linear(8,6), + nn.LeakyReLU(0.01), + nn.Linear(6,6), + nn.Softmax(dim=1) + ) + + #Construct loss function and optimizer + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(),lr=0.01) + + #Constructing random test dataset + x=torch.randn(2,4) + y=torch.randn(2,6) + + #Training the model + for i in range(2000): + y_pred = model(x) + loss = criterion(y_pred,y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + #Saving the trained model + model.eval() + m = torch.jit.script(model) + torch.jit.save(m,"PyTorchModelActivation.pt") + + +def generateBatchNormModel(): + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.bn = nn.BatchNorm1d(4) + self.fc = nn.Linear(12,6) + self.scale = nn.Parameter(torch.ones(6)) + self.bias2 = nn.Parameter(torch.zeros(6)) + + def forward(self, x): + x = self.bn(x) + x = torch.flatten(x, 1) + x = self.fc(x) + x = x * self.scale + x = x + self.bias2 + return x + + model = Model() + + #Construct loss function and optimizer + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(),lr=0.01) + + 
#Constructing random test dataset + x=torch.randn(2, 4, 3) + y=torch.randn(2, 6) + + #Training the model + for i in range(2000): + y_pred = model(x) + loss = criterion(y_pred,y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + #Saving the trained model + model.eval() + m = torch.jit.script(model) + torch.jit.save(m,"PyTorchModelBatchNorm.pt") + + generateSequentialModel() generateModuleModel() generateConvolutionModel() +generateActivationModel() +generateBatchNormModel() diff --git a/tmva/sofie_parsers/src/RModelParser_PyTorch.cxx b/tmva/sofie_parsers/src/RModelParser_PyTorch.cxx index 34125e5f521d1..a00816f32141b 100644 --- a/tmva/sofie_parsers/src/RModelParser_PyTorch.cxx +++ b/tmva/sofie_parsers/src/RModelParser_PyTorch.cxx @@ -74,6 +74,16 @@ std::unique_ptr MakePyTorchRelu(PyObject* fNode); // For instant std::unique_ptr MakePyTorchSelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Selu operator std::unique_ptr MakePyTorchSigmoid(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Sigmoid operator std::unique_ptr MakePyTorchTranspose(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Transpose operator +std::unique_ptr MakePyTorchTanh(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Tanh operator +std::unique_ptr MakePyTorchSoftmax(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Softmax operator +std::unique_ptr MakePyTorchLeakyRelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's LeakyRelu operator +std::unique_ptr MakePyTorchAdd(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Add operator +std::unique_ptr MakePyTorchSub(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Sub operator +std::unique_ptr MakePyTorchMul(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Mul operator +std::unique_ptr MakePyTorchMatMul(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's 
MatMul operator +std::unique_ptr MakePyTorchFlatten(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Flatten operator +std::unique_ptr MakePyTorchReshape(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Reshape operator +std::unique_ptr MakePyTorchBatchNormalization(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's BatchNormalization operator // For mapping PyTorch ONNX Graph's Node with the preparatory functions for ROperators using PyTorchMethodMap = std::unordered_map (*)(PyObject* fNode)>; @@ -85,7 +95,17 @@ const PyTorchMethodMap mapPyTorchNode = {"onnx::Relu", &MakePyTorchRelu}, {"onnx::Selu", &MakePyTorchSelu}, {"onnx::Sigmoid", &MakePyTorchSigmoid}, - {"onnx::Transpose", &MakePyTorchTranspose} + {"onnx::Transpose", &MakePyTorchTranspose}, + {"onnx::Tanh", &MakePyTorchTanh}, + {"onnx::Softmax", &MakePyTorchSoftmax}, + {"onnx::LeakyRelu", &MakePyTorchLeakyRelu}, + {"onnx::Add", &MakePyTorchAdd}, + {"onnx::Sub", &MakePyTorchSub}, + {"onnx::Mul", &MakePyTorchMul}, + {"onnx::MatMul", &MakePyTorchMatMul}, + {"onnx::Flatten", &MakePyTorchFlatten}, + {"onnx::Reshape", &MakePyTorchReshape}, + {"onnx::BatchNormalization", &MakePyTorchBatchNormalization} }; @@ -332,6 +352,339 @@ std::unique_ptr MakePyTorchConv(PyObject* fNode){ } return op; } + + +////////////////////////////////////////////////////////////////////////////////// +/// \brief Prepares a ROperator_Tanh object +/// +/// \param[in] fNode Python PyTorch ONNX Graph node +/// \return Unique pointer to ROperator object +/// +/// For instantiating a ROperator_Tanh object, the names of +/// input & output tensors and the data-type of the Graph node +/// are extracted. 
+std::unique_ptr<ROperator> MakePyTorchTanh(PyObject* fNode){
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_Tanh<float>(fNameX,fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Tanh does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_Softmax object
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For instantiating a ROperator_Softmax object, the names of
+/// input & output tensors, the data-type, and the axis attribute
+/// are extracted. The axis defaults to -1 per the ONNX specification
+/// when the attribute is not present.
+std::unique_ptr<ROperator> MakePyTorchSoftmax(PyObject* fNode){
+   PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+
+   // Extract the axis attribute; default to -1 per ONNX spec
+   int64_t fAttrAxis = -1;
+   PyObject* fAxisKey = PyUnicode_FromString("axis");
+   if (PyDict_Contains(fAttributes, fAxisKey)) {
+      fAttrAxis = (int64_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"axis")));
+   }
+   Py_DECREF(fAxisKey);
+
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_Softmax<float>(fAttrAxis,fNameX,fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Softmax does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_LeakyRelu object
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For instantiating a ROperator_LeakyRelu object, the names of
+/// input & output tensors, the data-type, and the alpha attribute
+/// are extracted. The alpha defaults to 0.01 per the ONNX specification
+/// when the attribute is not present.
+std::unique_ptr<ROperator> MakePyTorchLeakyRelu(PyObject* fNode){
+   PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+
+   // Extract the alpha attribute; default to 0.01 per ONNX spec
+   float fAttrAlpha = 0.01f;
+   PyObject* fAlphaKey = PyUnicode_FromString("alpha");
+   if (PyDict_Contains(fAttributes, fAlphaKey)) {
+      fAttrAlpha = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"alpha")));
+   }
+   Py_DECREF(fAlphaKey);
+
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_LeakyRelu<float>(fAttrAlpha,fNameX,fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator LeakyRelu does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_BasicBinary object
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For instantiating a ROperator_BasicBinary, the names of the two
+/// input tensors and the output tensor are extracted.
+std::unique_ptr<ROperator> MakePyTorchAdd(PyObject* fNode){
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_BasicBinary<float, EBasicBinaryOperator::Add>(fNameA,fNameB,fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Add does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_BasicBinary object
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For instantiating a ROperator_BasicBinary, the names of the two
+/// input tensors and the output tensor are extracted.
+std::unique_ptr<ROperator> MakePyTorchSub(PyObject* fNode){
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_BasicBinary<float, EBasicBinaryOperator::Sub>(fNameA,fNameB,fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Sub does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_BasicBinary object
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For instantiating a ROperator_BasicBinary, the names of the two
+/// input tensors and the output tensor are extracted.
+std::unique_ptr<ROperator> MakePyTorchMul(PyObject* fNode){
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_BasicBinary<float, EBasicBinaryOperator::Mul>(fNameA,fNameB,fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Mul does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_Gemm object for MatMul
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For PyTorch's MatMul operation in its ONNX graph, the names of the two
+/// input tensors and the output tensor are extracted. MatMul is mapped to
+/// ROperator_Gemm with alpha=1.0, beta=0.0, no transpose, and no bias.
+std::unique_ptr<ROperator> MakePyTorchMatMul(PyObject* fNode){
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+   std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_Gemm<float>(1.0, 0.0, 0, 0, fNameA, fNameB, fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator MatMul does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_Reshape object for Flatten
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For PyTorch's Flatten operation in its ONNX graph, the axis attribute is
+/// extracted (default 1). Flatten is mapped to ROperator_Reshape with
+/// ReshapeOpMode Flatten and an empty shape tensor name.
+std::unique_ptr<ROperator> MakePyTorchFlatten(PyObject* fNode){
+   PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNameData = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+
+   // Extract axis attribute; default to 1 per ONNX spec
+   int fAttrAxis = 1;
+   PyObject* fAxisKey = PyUnicode_FromString("axis");
+   if (PyDict_Contains(fAttributes, fAxisKey)) {
+      fAttrAxis = (int)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"axis")));
+   }
+   Py_DECREF(fAxisKey);
+
+   std::unique_ptr<ROperator> op;
+   op.reset(new ROperator_Reshape<float>(Flatten, fAttrAxis, fNameData, "", fNameY));
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_Reshape object for Reshape
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For PyTorch's Reshape operation in its ONNX graph, the data tensor and shape
+/// tensor names are extracted from nodeInputs[0] and nodeInputs[1] respectively.
+/// The allowzero attribute is extracted (default 0).
+std::unique_ptr<ROperator> MakePyTorchReshape(PyObject* fNode){
+   PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNameData = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameShape = PyStringAsString(PyList_GetItem(fInputs,1));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+
+   // Extract allowzero attribute; default to 0 per ONNX spec
+   int fAttrAllowZero = 0;
+   PyObject* fAllowZeroKey = PyUnicode_FromString("allowzero");
+   if (PyDict_Contains(fAttributes, fAllowZeroKey)) {
+      fAttrAllowZero = (int)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"allowzero")));
+   }
+   Py_DECREF(fAllowZeroKey);
+
+   std::unique_ptr<ROperator> op;
+   op.reset(new ROperator_Reshape<float>(Reshape, fAttrAllowZero, fNameData, fNameShape, fNameY));
+   return op;
+}
+
+
+//////////////////////////////////////////////////////////////////////////////////
+/// \brief Prepares a ROperator_BatchNormalization object
+///
+/// \param[in] fNode Python PyTorch ONNX Graph node
+/// \return Unique pointer to ROperator object
+///
+/// For BatchNormalization, five inputs are extracted: X (data), scale, B (bias),
+/// mean and var. The epsilon and momentum attributes are extracted with ONNX
+/// default values of 1e-5 and 0.9 respectively. Training mode is set to 0
+/// as SOFIE is strictly an inference engine.
+std::unique_ptr<ROperator> MakePyTorchBatchNormalization(PyObject* fNode){
+   PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
+   PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
+   PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
+   std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
+
+   // 5 inputs: X, scale, B, mean, var
+   std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
+   std::string fNameScale = PyStringAsString(PyList_GetItem(fInputs,1));
+   std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,2));
+   std::string fNameMean = PyStringAsString(PyList_GetItem(fInputs,3));
+   std::string fNameVar = PyStringAsString(PyList_GetItem(fInputs,4));
+   std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
+
+   // Extract epsilon attribute; default to 1e-5 per ONNX spec
+   float fAttrEpsilon = 1e-5f;
+   PyObject* fEpsilonKey = PyUnicode_FromString("epsilon");
+   if (PyDict_Contains(fAttributes, fEpsilonKey)) {
+      fAttrEpsilon = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"epsilon")));
+   }
+   Py_DECREF(fEpsilonKey);
+
+   // Extract momentum attribute; default to 0.9 per ONNX spec
+   float fAttrMomentum = 0.9f;
+   PyObject* fMomentumKey = PyUnicode_FromString("momentum");
+   if (PyDict_Contains(fAttributes, fMomentumKey)) {
+      fAttrMomentum = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"momentum")));
+   }
+   Py_DECREF(fMomentumKey);
+
+   std::unique_ptr<ROperator> op;
+   switch(ConvertStringToType(fNodeDType)){
+      case ETensorType::FLOAT: {
+         op.reset(new ROperator_BatchNormalization<float>(fAttrEpsilon, fAttrMomentum, 0,
+                     fNameX, fNameScale, fNameB, fNameMean, fNameVar, fNameY));
+         break;
+      }
+      default:
+         throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator BatchNormalization does not yet support input type " + fNodeDType);
+   }
+   return op;
+}
 }//INTERNAL
@@ -460,10 +813,10 @@ RModel Parse(std::string filename, std::vector<std::vector<size_t>> inputShapes,
 "
nodeAttributes={j: _node_get(i, j) for j in nodeAttributeNames}\n" " nodeData['nodeAttributes']=nodeAttributes\n" " nodeInputs=[x for x in i.inputs()]\n" - " nodeInputNames=[x.debugName() for x in nodeInputs]\n" + " nodeInputNames=[x.debugName().replace('.','_') for x in nodeInputs]\n" " nodeData['nodeInputs']=nodeInputNames\n" " nodeOutputs=[x for x in i.outputs()]\n" - " nodeOutputNames=[x.debugName() for x in nodeOutputs]\n" + " nodeOutputNames=[x.debugName().replace('.','_') for x in nodeOutputs]\n" " nodeData['nodeOutputs']=nodeOutputNames\n" " nodeDType=[x.type().scalarType() for x in nodeOutputs]\n" " nodeData['nodeDType']=nodeDType\n" @@ -484,18 +837,28 @@ RModel Parse(std::string filename, std::vector> inputShapes, if(fNodeType == "onnx::Gemm"){ rmodel.AddBlasRoutines({"Gemm", "Gemv"}); } + else if(fNodeType == "onnx::MatMul"){ + rmodel.AddBlasRoutines({"Gemm", "Gemv"}); + } else if(fNodeType == "onnx::Selu" || fNodeType == "onnx::Sigmoid"){ rmodel.AddNeededStdLib("cmath"); } + else if(fNodeType == "onnx::Tanh" || fNodeType == "onnx::Softmax" + || fNodeType == "onnx::LeakyRelu"){ + rmodel.AddNeededStdLib("cmath"); + } else if (fNodeType == "onnx::Conv") { rmodel.AddBlasRoutines({"Gemm", "Axpy"}); } + else if(fNodeType == "onnx::BatchNormalization"){ + rmodel.AddNeededStdLib("cmath"); + } rmodel.AddOperator(INTERNAL::MakePyTorchNode(fNode)); } //Extracting model weights to add the initialized tensors to the RModel - PyRunString("weightNames=[k for k in graph[1].keys()]",fGlobalNS,fLocalNS); + PyRunString("weightNames=[k.replace('.','_') for k in graph[1].keys()]",fGlobalNS,fLocalNS); PyRunString("weights=[v.numpy() for v in graph[1].values()]",fGlobalNS,fLocalNS); PyRunString("weightDTypes=[v.type()[6:-6] for v in graph[1].values()]",fGlobalNS,fLocalNS); PyObject* fPWeightNames = PyDict_GetItemString(fLocalNS,"weightNames"); @@ -534,7 +897,7 @@ RModel Parse(std::string filename, std::vector> inputShapes, //Extracting Input tensor info PyRunString("inputs=[x 
for x in model.graph.inputs()]",fGlobalNS,fLocalNS);
    PyRunString("inputs=inputs[1:]",fGlobalNS,fLocalNS);
-   PyRunString("inputNames=[x.debugName() for x in inputs]",fGlobalNS,fLocalNS);
+   PyRunString("inputNames=[x.debugName().replace('.','_') for x in inputs]",fGlobalNS,fLocalNS);
    PyObject* fPInputs= PyDict_GetItemString(fLocalNS,"inputNames");
    std::string fInputName;
    std::vector<size_t>fInputShape;
@@ -557,7 +920,7 @@ RModel Parse(std::string filename, std::vector<std::vector<size_t>> inputShapes,
    //Extracting output tensor names
    PyRunString("outputs=[x for x in graph[0].outputs()]",fGlobalNS,fLocalNS);
-   PyRunString("outputNames=[x.debugName() for x in outputs]",fGlobalNS,fLocalNS);
+   PyRunString("outputNames=[x.debugName().replace('.','_') for x in outputs]",fGlobalNS,fLocalNS);
    PyObject* fPOutputs= PyDict_GetItemString(fLocalNS,"outputNames");
    std::vector<std::string> fOutputNames;
    for(Py_ssize_t outputIter = 0; outputIter < PyList_Size(fPOutputs);++outputIter){