Skip to content

Commit e5a241e

Browse files
committed
Avoid having intermediary vectors for the NN-based TPC PID
1 parent 0bba10f commit e5a241e

4 files changed

Lines changed: 721 additions & 72 deletions

File tree

Common/Tools/PID/pidTPCModule.h

Lines changed: 198 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ class pidTPCModule
414414
network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0));
415415
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
416416
network.evalModel(dummyInput); /// Init the model evaluations
417+
setupColumnInputNetwork();
417418
LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]);
418419
} else {
419420
LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available";
@@ -427,6 +428,7 @@ class pidTPCModule
427428
network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value);
428429
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
429430
network.evalModel(dummyInput); // This is an initialisation and might reduce the overhead of the model
431+
setupColumnInputNetwork();
430432
}
431433
} else {
432434
return;
@@ -438,6 +440,110 @@ class pidTPCModule
438440
}
439441
} // end init
440442

443+
//__________________________________________________
444+
void setupColumnInputNetwork()
445+
{
446+
using PI = o2::ml::OnnxModel::PreprocInput;
447+
using PF = o2::ml::OnnxModel::PreprocFeature;
448+
const int nFeat = network.getNumInputNodes(); // # network features (6..9), original model
449+
450+
// Raw graph inputs (this order defines the tensor feeding order in
451+
// createNetworkPrediction). All per-track columns are wrapped zero-copy from
452+
// the Arrow buffers; nclNorm/hrDivisor/hrFallback are per-DF runtime scalars.
453+
std::vector<PI> in;
454+
in.push_back({"tpcInnerParam", PI::Type::TrackFloat});
455+
in.push_back({"tgl", PI::Type::TrackFloat});
456+
in.push_back({"signed1Pt", PI::Type::TrackFloat});
457+
in.push_back({"mass", PI::Type::ScalarFloat});
458+
in.push_back({"collisionId", PI::Type::TrackInt32});
459+
in.push_back({"multArray", PI::Type::CollisionFloat});
460+
in.push_back({"nclNorm", PI::Type::ScalarFloat});
461+
in.push_back({"nclsFindable", PI::Type::TrackUint8});
462+
in.push_back({"nclsFMF", PI::Type::TrackInt8});
463+
if (nFeat >= 7) {
464+
in.push_back({"occArray", PI::Type::CollisionFloat});
465+
}
466+
if (nFeat >= 8) {
467+
in.push_back({"hrArray", PI::Type::CollisionFloat});
468+
in.push_back({"hrDivisor", PI::Type::ScalarFloat});
469+
in.push_back({"hrFallback", PI::Type::ScalarFloat});
470+
}
471+
if (nFeat >= 9) {
472+
in.push_back({"phi", PI::Type::TrackFloat});
473+
}
474+
in.push_back({"validMask", PI::Type::TrackBool});
475+
476+
// Per-feature preprocessing (exactly nFeat entries, in the training order).
477+
std::vector<PF> feat;
478+
{
479+
PF f;
480+
f.op = PF::Op::Passthrough;
481+
f.a = "tpcInnerParam";
482+
feat.push_back(f);
483+
}
484+
{
485+
PF f;
486+
f.op = PF::Op::Passthrough;
487+
f.a = "tgl";
488+
feat.push_back(f);
489+
}
490+
{
491+
PF f;
492+
f.op = PF::Op::Passthrough;
493+
f.a = "signed1Pt";
494+
feat.push_back(f);
495+
}
496+
{
497+
PF f;
498+
f.op = PF::Op::BroadcastScalar;
499+
f.a = "mass";
500+
f.shapeRef = "collisionId";
501+
feat.push_back(f);
502+
}
503+
{
504+
PF f;
505+
f.op = PF::Op::GatherNormWhere;
506+
f.a = "multArray";
507+
f.b = "collisionId";
508+
f.c = {11000.f, 1.f, 0.f};
509+
feat.push_back(f);
510+
}
511+
{
512+
PF f;
513+
f.op = PF::Op::NClsSqrtRecip;
514+
f.a = "nclsFindable";
515+
f.b = "nclsFMF";
516+
f.scaleInput = "nclNorm";
517+
feat.push_back(f);
518+
}
519+
if (nFeat >= 7) {
520+
PF f;
521+
f.op = PF::Op::GatherNormWhere;
522+
f.a = "occArray";
523+
f.b = "collisionId";
524+
f.c = {60000.f, 1.f, 0.f};
525+
feat.push_back(f);
526+
}
527+
if (nFeat >= 8) {
528+
PF f;
529+
f.op = PF::Op::GatherNormWhere;
530+
f.a = "hrArray";
531+
f.b = "collisionId";
532+
f.scaleInput = "hrDivisor";
533+
f.fallbackInput = "hrFallback";
534+
feat.push_back(f);
535+
}
536+
if (nFeat >= 9) {
537+
PF f;
538+
f.op = PF::Op::Mod2;
539+
f.a = "phi";
540+
f.c = {2.f * static_cast<float>(M_PI), 2.f * static_cast<float>(M_PI), static_cast<float>(M_PI) / 9.0f};
541+
feat.push_back(f);
542+
}
543+
544+
network.setupColumnInputs(in, feat, "validMask");
545+
}
546+
441547
//__________________________________________________
442548
template <typename TCCDB, typename M, typename T, typename B>
443549
std::vector<float> createNetworkPrediction(TCCDB& ccdb, soa::Join<aod::Collisions, aod::EvSels> const& collisions, M const& mults, T const& tracks, B const& bcs, const size_t size)
@@ -489,6 +595,7 @@ class pidTPCModule
489595
network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0));
490596
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
491597
network.evalModel(dummyInput);
598+
setupColumnInputNetwork();
492599
LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, NN-Version number {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]);
493600
} else {
494601
LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available";
@@ -497,19 +604,14 @@ class pidTPCModule
497604
}
498605

499606
// Defining some network parameters
500-
int input_dimensions = network.getNumInputNodes();
607+
const int nFeat = network.getNumFeatures();
501608
int output_dimensions = network.getNumOutputNodes();
502-
const uint64_t track_prop_size = input_dimensions * size;
503609
const uint64_t prediction_size = output_dimensions * size;
504610

505611
network_prediction = std::vector<float>(prediction_size * 9); // For each mass hypotheses
506612
const float nNclNormalization = response->GetNClNormalization();
507613
float duration_network = 0;
508614

509-
std::vector<float> track_properties(track_prop_size);
510-
uint64_t counter_track_props = 0;
511-
int loop_counter = 0;
512-
513615
// To load the Hadronic rate once for each collision
514616
float hadronicRateBegin = 0.;
515617
std::vector<float> hadronicRateForCollision(collisions.size(), 0.0f);
@@ -530,88 +632,113 @@ class pidTPCModule
530632
hadronicRateBegin = 0.0f;
531633
}
532634

533-
// Filling a std::vector<float> to be evaluated by the network
534-
// Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
535635
static constexpr int NParticleTypes = 9;
536636
constexpr int ExpectedInputDimensionsNNV2 = 7;
537637
constexpr int ExpectedInputDimensionsNNV3 = 8;
538638
constexpr int ExpectedInputDimensionsNNV4 = 9;
539-
constexpr auto NetworkVersionV2 = "2";
540-
constexpr auto NetworkVersionV3 = "3";
541-
constexpr auto NetworkVersionV4 = "4";
542-
for (int j = 0; j < NParticleTypes; j++) { // Loop over particle number for which network correction is used
543-
for (auto const& trk : tracks) {
544-
if (!trk.hasTPC()) {
545-
continue;
546-
}
547-
if (pidTPCopts.skipTPCOnly) {
548-
if (!trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) {
549-
continue;
550-
}
551-
}
552-
track_properties[counter_track_props] = trk.tpcInnerParam();
553-
track_properties[counter_track_props + 1] = trk.tgl();
554-
track_properties[counter_track_props + 2] = trk.signed1Pt();
555-
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[j];
556-
track_properties[counter_track_props + 4] = trk.has_collision() ? mults[trk.collisionId()] / 11000. : 1.;
557-
track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound());
558-
if (input_dimensions == ExpectedInputDimensionsNNV2 && networkVersion == NetworkVersionV2) {
559-
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
560-
}
561-
if (input_dimensions == ExpectedInputDimensionsNNV3 && networkVersion == NetworkVersionV3) {
562-
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
563-
if (trk.has_collision()) {
564-
if (collsys == CollisionSystemType::kCollSyspp) {
565-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.;
566-
} else {
567-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.;
568-
}
569-
} else {
570-
// asign Hadronic Rate at beginning of run if track does not belong to a collision
571-
if (collsys == CollisionSystemType::kCollSyspp) {
572-
track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.;
573-
} else {
574-
track_properties[counter_track_props + 7] = hadronicRateBegin / 50.;
575-
}
576-
}
639+
640+
const float hadronicRateDivisor = (collsys == CollisionSystemType::kCollSyspp) ? 1500.f : 50.f;
641+
642+
// Per-collision arrays (O(nColl)); gathered per track inside the model via the
643+
// collisionId column, then normalised in-graph.
644+
const int64_t nColl = static_cast<int64_t>(collisions.size());
645+
std::vector<float> multArray(nColl);
646+
std::vector<float> occArray(nFeat >= ExpectedInputDimensionsNNV2 ? nColl : 0);
647+
{
648+
int64_t c = 0;
649+
for (const auto& col : collisions) {
650+
multArray[c] = static_cast<float>(mults[c]);
651+
if (nFeat >= ExpectedInputDimensionsNNV2) {
652+
occArray[c] = col.ft0cOccupancyInTimeRange();
577653
}
654+
++c;
655+
}
656+
}
578657

579-
if (input_dimensions == ExpectedInputDimensionsNNV4 && networkVersion == NetworkVersionV4) {
580-
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
581-
if (trk.has_collision()) {
582-
if (collsys == CollisionSystemType::kCollSyspp) {
583-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.;
584-
} else {
585-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.;
586-
}
587-
} else {
588-
// asign Hadronic Rate at beginning of run if track does not belong to a collision
589-
if (collsys == CollisionSystemType::kCollSyspp) {
590-
track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.;
591-
} else {
592-
track_properties[counter_track_props + 7] = hadronicRateBegin / 50.;
593-
}
594-
}
595-
track_properties[counter_track_props + 8] = std::fmod(std::fmod(trk.phi(), 2 * M_PI) + 2 * M_PI, M_PI / 9.0);
658+
// Raw per-track Arrow column buffers (zero-copy; one chunk per DataFrame).
659+
auto arrowTable = tracks.asArrowTable();
660+
auto chunk0 = [&](const char* name) -> std::shared_ptr<arrow::Array> {
661+
const int idx = arrowTable->schema()->GetFieldIndex(name);
662+
if (idx < 0) {
663+
LOG(fatal) << "createNetworkPrediction: column '" << name << "' not found in tracks table";
664+
}
665+
auto col = arrowTable->column(idx);
666+
if (col->num_chunks() != 1) {
667+
LOG(fatal) << "createNetworkPrediction: column '" << name << "' has " << col->num_chunks()
668+
<< " chunks; a single chunk per DataFrame is required for zero-copy input";
669+
}
670+
return col->chunk(0);
671+
};
672+
const int64_t nTrk = static_cast<int64_t>(tracks.size());
673+
const float* pTpcInner = std::static_pointer_cast<arrow::FloatArray>(chunk0("fTPCInnerParam"))->raw_values();
674+
const float* pTgl = std::static_pointer_cast<arrow::FloatArray>(chunk0("fTgl"))->raw_values();
675+
const float* pSigned1Pt = std::static_pointer_cast<arrow::FloatArray>(chunk0("fSigned1Pt"))->raw_values();
676+
const int32_t* pCollId = std::static_pointer_cast<arrow::Int32Array>(chunk0("fIndexCollisions"))->raw_values();
677+
const uint8_t* pFindable = std::static_pointer_cast<arrow::UInt8Array>(chunk0("fTPCNClsFindable"))->raw_values();
678+
const int8_t* pFMF = std::static_pointer_cast<arrow::Int8Array>(chunk0("fTPCNClsFindableMinusFound"))->raw_values();
679+
const float* pPhi = (nFeat >= ExpectedInputDimensionsNNV4)
680+
? std::static_pointer_cast<arrow::FloatArray>(chunk0("fPhi"))->raw_values()
681+
: nullptr;
682+
683+
// Single boolean mask of the tracks the network runs on; the model Compress'es
684+
// to exactly these rows so the output is compact and the consumer's
685+
// count_tracks indexing is unchanged. Condition matches process()'s counter.
686+
std::vector<uint8_t> validMask(nTrk);
687+
{
688+
int64_t t = 0;
689+
for (auto const& trk : tracks) {
690+
bool valid = trk.hasTPC();
691+
if (valid && pidTPCopts.skipTPCOnly && !trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) {
692+
valid = false;
596693
}
597-
counter_track_props += input_dimensions;
694+
validMask[t++] = valid ? 1 : 0;
695+
}
696+
}
697+
698+
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
699+
const int64_t one = 1;
700+
float massVal = 0.f;
701+
float nclNormVal = nNclNormalization;
702+
float hrDivisorVal = hadronicRateDivisor;
703+
float hrFallbackVal = hadronicRateBegin / hadronicRateDivisor;
704+
705+
// Evaluate once per mass hypothesis; only the mass scalar input changes.
706+
for (int j = 0; j < NParticleTypes; j++) {
707+
massVal = o2::track::pid_constants::sMasses[j];
708+
709+
std::vector<Ort::Value> inputTensors;
710+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, const_cast<float*>(pTpcInner), nTrk, &nTrk, 1));
711+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, const_cast<float*>(pTgl), nTrk, &nTrk, 1));
712+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, const_cast<float*>(pSigned1Pt), nTrk, &nTrk, 1));
713+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, &massVal, 1, &one, 1));
714+
inputTensors.emplace_back(Ort::Value::CreateTensor<int32_t>(memInfo, const_cast<int32_t*>(pCollId), nTrk, &nTrk, 1));
715+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, multArray.data(), nColl, &nColl, 1));
716+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, &nclNormVal, 1, &one, 1));
717+
inputTensors.emplace_back(Ort::Value::CreateTensor<uint8_t>(memInfo, const_cast<uint8_t*>(pFindable), nTrk, &nTrk, 1));
718+
inputTensors.emplace_back(Ort::Value::CreateTensor<int8_t>(memInfo, const_cast<int8_t*>(pFMF), nTrk, &nTrk, 1));
719+
if (nFeat >= ExpectedInputDimensionsNNV2) {
720+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, occArray.data(), nColl, &nColl, 1));
721+
}
722+
if (nFeat >= ExpectedInputDimensionsNNV3) {
723+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, hadronicRateForCollision.data(), nColl, &nColl, 1));
724+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, &hrDivisorVal, 1, &one, 1));
725+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, &hrFallbackVal, 1, &one, 1));
598726
}
727+
if (nFeat >= ExpectedInputDimensionsNNV4) {
728+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, const_cast<float*>(pPhi), nTrk, &nTrk, 1));
729+
}
730+
inputTensors.emplace_back(Ort::Value::CreateTensor<bool>(memInfo, reinterpret_cast<bool*>(validMask.data()), nTrk, &nTrk, 1));
599731

600732
auto start_network_eval = std::chrono::high_resolution_clock::now();
601-
float* output_network = network.evalModel(track_properties);
733+
float* output_network = network.evalModel<float>(inputTensors);
602734
auto stop_network_eval = std::chrono::high_resolution_clock::now();
603735
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
604736
for (uint64_t k = 0; k < prediction_size; k += output_dimensions) {
605737
for (int l = 0; l < output_dimensions; l++) {
606-
network_prediction[k + l + prediction_size * loop_counter] = output_network[k + l];
738+
network_prediction[k + l + prediction_size * j] = output_network[k + l];
607739
}
608740
}
609-
610-
counter_track_props = 0;
611-
loop_counter += 1;
612741
}
613-
track_properties.clear();
614-
615742
auto stop_network_total = std::chrono::high_resolution_clock::now();
616743
LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval ONNX): " << duration_network / (size * 9) << "ns ; Total time (eval ONNX): " << duration_network / 1000000000 << " s";
617744
LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval + overhead): " << std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_total - start_network_total).count() / (size * 9) << "ns ; Total time (eval + overhead): " << std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_total - start_network_total).count() / 1000000000 << " s";

Tools/ML/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111

1212
o2physics_add_library(MLCore
1313
SOURCES model.cxx
14-
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime
14+
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime ONNX::onnx_proto
1515
)

0 commit comments

Comments
 (0)