RNN_ECAL_test.C
#include <iostream>

#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TCut.h"

#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Types.h"

void RNN_ECAL_test() {
   // Open the input data file, which must exist in the local directory
   TFile *input = TFile::Open("EcalData.root");
   if (!input) {
      std::cout << "ERROR: could not open data file" << std::endl;
      exit(1);
   }
std::cout << "--- RNNClassification : Using input file: " << input->GetName() << std::endl;
// Create a ROOT output file where TMVA will store ntuples, histograms, etc.
TString outfileName( "TMVA_DNN.root" );
TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
// Creating the factory object
TMVA::Factory *factory = new TMVA::Factory( "TMVAClassification", outputFile,
"!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification:!ModelPersistence" );
TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
TTree *signalTree = (TTree*)input->Get("sig");
TTree *background = (TTree*)input->Get("bgk");
signalTree->Print();
background->Print();
   // Add the input variables: 40 ADC samples (indices 500-539) from each of
   // the two waveform branches EB_adc0 and EB_adc1 (time step 0 and time step 1)
   for (int i = 500; i < 540; ++i) {
      TString varName = TString::Format("EB_adc0[%d]", i);
      dataloader->AddVariable(varName, 'F');
   }
   for (int i = 500; i < 540; ++i) {
      TString varName = TString::Format("EB_adc1[%d]", i);
      dataloader->AddVariable(varName, 'F');
   }
   dataloader->AddSignalTree(signalTree, 1.0);
   dataloader->AddBackgroundTree(background, 1.0);

   // Check the given input
   auto &datainfo = dataloader->GetDataSetInfo();
   auto vars = datainfo.GetListOfVariables();
   std::cout << "number of variables is " << vars.size() << std::endl;
   for (auto &v : vars)
      std::cout << v << ",";
   std::cout << std::endl;
   int ntrainEvts = 500;
   int ntestEvts = 500;
   TString trainAndTestOpt = TString::Format(
      "nTrain_Signal=%d:nTrain_Background=%d:nTest_Signal=%d:nTest_Background=%d:"
      "SplitMode=Random:NormMode=NumEvents:!V",
      ntrainEvts, ntrainEvts, ntestEvts, ntestEvts);
   TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
   TCut mycutb = ""; // or a simple entry limit such as "Entry$<1000"
   dataloader->PrepareTrainingAndTestTree(mycuts, mycutb, trainAndTestOpt);
   std::cout << "prepared DATA LOADER " << std::endl;
   // Input layout
   TString inputLayoutString("InputLayout=1|2|40");
   // Batch layout
   TString batchLayoutString("BatchLayout=256|2|40");
   // General layout
   // TString layoutString("Layout=RNN|128|300|2|0,DENSE|64|TANH,DENSE|1|LINEAR");
   TString layoutString("Layout=RNN|128|40|2|0,RESHAPE|1|2|128|FLAT,DENSE|64|TANH,DENSE|1|LINEAR");
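   // A note on the layout strings as interpreted by the TMVA DL method:
   // InputLayout=1|2|40 describes each event as a 1 x 2 x 40 tensor, i.e.
   // 2 time steps with 40 inputs each; BatchLayout prepends the batch size
   // (256), which must match BatchSize in the training strategy below.
   // RNN|128|40|2|0 is a recurrent layer with state size 128, 40 inputs per
   // step, 2 time steps and no state memory between batches; its 2 x 128
   // output is flattened by the RESHAPE layer before the dense layers.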
   // Training strategies
   TString training0("LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
                     "ConvergenceSteps=100,BatchSize=256,TestRepetitions=1,"
                     "WeightDecay=1e-4,Regularization=L2,"
                     "DropConfig=0.0+0.5+0.5+0.5,Multithreading=True");
   TString training1("LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
                     "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
                     "WeightDecay=1e-4,Regularization=L2,"
                     "DropConfig=0.0+0.0+0.0+0.0,Multithreading=True");
   TString training2("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
                     "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
                     "WeightDecay=1e-4,Regularization=L2,"
                     "DropConfig=0.0+0.0+0.0+0.0,Multithreading=True");
   TString trainingStrategyString("TrainingStrategy=");
   trainingStrategyString += training0; // + "|" + training1 + "|" + training2;
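   // Only the first training phase is used here; the phases with decreasing
   // learning rates (training1, training2) could be chained by joining them
   // with "|", as indicated in the comment above. DropConfig sets a dropout
   // probability for each layer in turn (0.0+0.5+0.5+0.5 in phase 0).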
   // General options
   TString rnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
                      "WeightInitialization=XAVIERUNIFORM");
   rnnOptions.Append(":"); rnnOptions.Append(inputLayoutString);
   rnnOptions.Append(":"); rnnOptions.Append(batchLayoutString);
   rnnOptions.Append(":"); rnnOptions.Append(layoutString);
   rnnOptions.Append(":"); rnnOptions.Append(trainingStrategyString);
   rnnOptions.Append(":Architecture=CPU");

   factory->BookMethod(dataloader, TMVA::Types::kDL, "DNN_CPU", rnnOptions);
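   // If ROOT was built with CUDA support, the same method can be run on a GPU
   // by replacing Architecture=CPU with Architecture=GPU (and renaming the
   // method accordingly, e.g. "DNN_GPU").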
   // Train all booked MVA methods
   factory->TrainAllMethods();

   // Evaluate all MVAs using the set of test events
   factory->TestAllMethods();

   // Evaluate and compare performance of all configured MVAs
   factory->EvaluateAllMethods();
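   // Optionally, print the ROC integral of the booked method as a single
   // performance figure (a sketch, assuming a ROOT version that provides
   // TMVA::Factory::GetROCIntegral):
   // std::cout << "ROC integral for DNN_CPU: "
   //           << factory->GetROCIntegral(dataloader, "DNN_CPU") << std::endl;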
   outputFile->Close();

   delete factory;
   delete dataloader;
}
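// Usage: run the macro in batch mode with
//   root -l -b -q RNN_ECAL_test.C
// and afterwards inspect the results stored in TMVA_DNN.root with the
// standard TMVA GUI, e.g.
//   root -l -e 'TMVA::TMVAGui("TMVA_DNN.root")'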