Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tmva/sofie/test/TestRModelParserPyTorch.C
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,96 @@ TEST(RModelParser_PyTorch, CONVOLUTION_MODEL)
EXPECT_LE(std::abs(outputConv[i] - pOutputConv[i]), TOLERANCE);
}
}

TEST(RModelParser_PyTorch, ACTIVATION_MODEL)
{
constexpr float TOLERANCE = 1.E-3;
std::vector<float> inputActivation ={-1.6207, 0.6133,
0.5058, -1.2560,
-0.7750, -1.6701,
0.8171, -0.2858};

Py_Initialize();
if (gSystem->AccessPathName("PyTorchModelActivation.pt",kFileExists))
GenerateModels();

std::vector<size_t> inputTensorShapeActivation{2,4};
std::vector<std::vector<size_t>> inputShapesActivation{inputTensorShapeActivation};
TMVA::Experimental::RSofieReader s("PyTorchModelActivation.pt", inputShapesActivation);
std::vector<float> outputActivation = s.Compute(inputActivation);


PyObject* main = PyImport_AddModule("__main__");
PyObject* fGlobalNS = PyModule_GetDict(main);
PyObject* fLocalNS = PyDict_New();
if (!fGlobalNS) {
throw std::runtime_error("Can't init global namespace for Python");
}
if (!fLocalNS) {
throw std::runtime_error("Can't init local namespace for Python");
}
PyRun_String("import torch",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("model=torch.jit.load('PyTorchModelActivation.pt')",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("input=torch.reshape(torch.FloatTensor([-1.6207, 0.6133,"
" 0.5058, -1.2560,"
"-0.7750, -1.6701,"
" 0.8171, -0.2858]),(2,4))",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("output=model(input).detach().numpy().reshape(2,6)",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("outputSize=output.size",Py_single_input,fGlobalNS,fLocalNS);
std::size_t pOutputActivationSize=(std::size_t)PyLong_AsLong(PyDict_GetItemString(fLocalNS,"outputSize"));

//Testing the actual and expected output tensor sizes
EXPECT_EQ(outputActivation.size(), pOutputActivationSize);

PyArrayObject* pActivationValues=(PyArrayObject*)PyDict_GetItemString(fLocalNS,"output");
float* pOutputActivation=(float*)PyArray_DATA(pActivationValues);

//Testing the actual and expected output tensor values
for (size_t i = 0; i < outputActivation.size(); ++i) {
EXPECT_LE(std::abs(outputActivation[i] - pOutputActivation[i]), TOLERANCE);
}
}

TEST(RModelParser_PyTorch, BATCHNORM_MODEL)
{
constexpr float TOLERANCE = 1.E-3;
std::vector<float> inputBatchNorm(2*4*3);
std::iota(inputBatchNorm.begin(), inputBatchNorm.end(), 1.0f);

Py_Initialize();
if (gSystem->AccessPathName("PyTorchModelBatchNorm.pt",kFileExists))
GenerateModels();

std::vector<size_t> inputTensorShapeBatchNorm{2, 4, 3};
std::vector<std::vector<size_t>> inputShapesBatchNorm{inputTensorShapeBatchNorm};
TMVA::Experimental::RSofieReader s("PyTorchModelBatchNorm.pt", inputShapesBatchNorm);
std::vector<float> outputBatchNorm = s.Compute(inputBatchNorm);


PyObject* main = PyImport_AddModule("__main__");
PyObject* fGlobalNS = PyModule_GetDict(main);
PyObject* fLocalNS = PyDict_New();
if (!fGlobalNS) {
throw std::runtime_error("Can't init global namespace for Python");
}
if (!fLocalNS) {
throw std::runtime_error("Can't init local namespace for Python");
}
PyRun_String("import torch",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("model=torch.jit.load('PyTorchModelBatchNorm.pt')",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("input=torch.arange(1,25,dtype=torch.float).reshape(2,4,3)",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("output=model(input).detach().numpy().reshape(2,6)",Py_single_input,fGlobalNS,fLocalNS);
PyRun_String("outputSize=output.size",Py_single_input,fGlobalNS,fLocalNS);
std::size_t pOutputBatchNormSize=(std::size_t)PyLong_AsLong(PyDict_GetItemString(fLocalNS,"outputSize"));

//Testing the actual and expected output tensor sizes
EXPECT_EQ(outputBatchNorm.size(), pOutputBatchNormSize);

PyArrayObject* pBatchNormValues=(PyArrayObject*)PyDict_GetItemString(fLocalNS,"output");
float* pOutputBatchNorm=(float*)PyArray_DATA(pBatchNormValues);

//Testing the actual and expected output tensor values
for (size_t i = 0; i < outputBatchNorm.size(); ++i) {
EXPECT_LE(std::abs(outputBatchNorm[i] - pOutputBatchNorm[i]), TOLERANCE);
}
}
77 changes: 77 additions & 0 deletions tmva/sofie/test/generatePyTorchModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,83 @@ def generateConvolutionModel():
m = torch.jit.script(model)
torch.jit.save(m,"PyTorchModelConvolution.pt")


def generateActivationModel():
# Model using Tanh, LeakyReLU, and Softmax activations
model = nn.Sequential(
nn.Linear(4,8),
nn.Tanh(),
nn.Linear(8,6),
nn.LeakyReLU(0.01),
nn.Linear(6,6),
nn.Softmax(dim=1)
)

#Construct loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

#Constructing random test dataset
x=torch.randn(2,4)
y=torch.randn(2,6)

#Training the model
for i in range(2000):
y_pred = model(x)
loss = criterion(y_pred,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

#Saving the trained model
model.eval()
m = torch.jit.script(model)
torch.jit.save(m,"PyTorchModelActivation.pt")


def generateBatchNormModel():
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.bn = nn.BatchNorm1d(4)
self.fc = nn.Linear(12,6)
self.scale = nn.Parameter(torch.ones(6))
self.bias2 = nn.Parameter(torch.zeros(6))

def forward(self, x):
x = self.bn(x)
x = torch.flatten(x, 1)
x = self.fc(x)
x = x * self.scale
x = x + self.bias2
return x

model = Model()

#Construct loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

#Constructing random test dataset
x=torch.randn(2, 4, 3)
y=torch.randn(2, 6)

#Training the model
for i in range(2000):
y_pred = model(x)
loss = criterion(y_pred,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

#Saving the trained model
model.eval()
m = torch.jit.script(model)
torch.jit.save(m,"PyTorchModelBatchNorm.pt")


generateSequentialModel()
generateModuleModel()
generateConvolutionModel()
generateActivationModel()
generateBatchNormModel()
Loading