@@ -414,6 +414,7 @@ class pidTPCModule
414414 network.initModel (pidTPCopts.networkPathLocally .value , pidTPCopts.enableNetworkOptimizations .value , pidTPCopts.networkSetNumThreads .value , strtoul (headers[" Valid-From" ].c_str (), NULL , 0 ), strtoul (headers[" Valid-Until" ].c_str (), NULL , 0 ));
415415 std::vector<float > dummyInput (network.getNumInputNodes (), 1 .);
416416 network.evalModel (dummyInput); // / Init the model evaluations
417+ setupColumnInputNetwork ();
417418 LOGP (info, " Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}" , headers[" LPMProductionTag" ], headers[" RecoPassName" ], headers[" NN-Version" ]);
418419 } else {
419420 LOG (fatal) << " No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata[" RecoPassName" ] << " . Please ensure that the requested pass has dedicated NN corrections available" ;
@@ -427,6 +428,7 @@ class pidTPCModule
427428 network.initModel (pidTPCopts.networkPathLocally .value , pidTPCopts.enableNetworkOptimizations .value , pidTPCopts.networkSetNumThreads .value );
428429 std::vector<float > dummyInput (network.getNumInputNodes (), 1 .);
429430 network.evalModel (dummyInput); // This is an initialisation and might reduce the overhead of the model
431+ setupColumnInputNetwork ();
430432 }
431433 } else {
432434 return ;
@@ -438,6 +440,110 @@ class pidTPCModule
438440 }
439441 } // end init
440442
443+ // __________________________________________________
444+ void setupColumnInputNetwork ()
445+ {
446+ using PI = o2::ml::OnnxModel::PreprocInput;
447+ using PF = o2::ml::OnnxModel::PreprocFeature;
448+ const int nFeat = network.getNumInputNodes (); // # network features (6..9), original model
449+
450+ // Raw graph inputs (this order defines the tensor feeding order in
451+ // createNetworkPrediction). All per-track columns are wrapped zero-copy from
452+ // the Arrow buffers; nclNorm/hrDivisor/hrFallback are per-DF runtime scalars.
453+ std::vector<PI> in;
454+ in.push_back ({" tpcInnerParam" , PI::Type::TrackFloat});
455+ in.push_back ({" tgl" , PI::Type::TrackFloat});
456+ in.push_back ({" signed1Pt" , PI::Type::TrackFloat});
457+ in.push_back ({" mass" , PI::Type::ScalarFloat});
458+ in.push_back ({" collisionId" , PI::Type::TrackInt32});
459+ in.push_back ({" multArray" , PI::Type::CollisionFloat});
460+ in.push_back ({" nclNorm" , PI::Type::ScalarFloat});
461+ in.push_back ({" nclsFindable" , PI::Type::TrackUint8});
462+ in.push_back ({" nclsFMF" , PI::Type::TrackInt8});
463+ if (nFeat >= 7 ) {
464+ in.push_back ({" occArray" , PI::Type::CollisionFloat});
465+ }
466+ if (nFeat >= 8 ) {
467+ in.push_back ({" hrArray" , PI::Type::CollisionFloat});
468+ in.push_back ({" hrDivisor" , PI::Type::ScalarFloat});
469+ in.push_back ({" hrFallback" , PI::Type::ScalarFloat});
470+ }
471+ if (nFeat >= 9 ) {
472+ in.push_back ({" phi" , PI::Type::TrackFloat});
473+ }
474+ in.push_back ({" validMask" , PI::Type::TrackBool});
475+
476+ // Per-feature preprocessing (exactly nFeat entries, in the training order).
477+ std::vector<PF> feat;
478+ {
479+ PF f;
480+ f.op = PF::Op::Passthrough;
481+ f.a = " tpcInnerParam" ;
482+ feat.push_back (f);
483+ }
484+ {
485+ PF f;
486+ f.op = PF::Op::Passthrough;
487+ f.a = " tgl" ;
488+ feat.push_back (f);
489+ }
490+ {
491+ PF f;
492+ f.op = PF::Op::Passthrough;
493+ f.a = " signed1Pt" ;
494+ feat.push_back (f);
495+ }
496+ {
497+ PF f;
498+ f.op = PF::Op::BroadcastScalar;
499+ f.a = " mass" ;
500+ f.shapeRef = " collisionId" ;
501+ feat.push_back (f);
502+ }
503+ {
504+ PF f;
505+ f.op = PF::Op::GatherNormWhere;
506+ f.a = " multArray" ;
507+ f.b = " collisionId" ;
508+ f.c = {11000 .f , 1 .f , 0 .f };
509+ feat.push_back (f);
510+ }
511+ {
512+ PF f;
513+ f.op = PF::Op::NClsSqrtRecip;
514+ f.a = " nclsFindable" ;
515+ f.b = " nclsFMF" ;
516+ f.scaleInput = " nclNorm" ;
517+ feat.push_back (f);
518+ }
519+ if (nFeat >= 7 ) {
520+ PF f;
521+ f.op = PF::Op::GatherNormWhere;
522+ f.a = " occArray" ;
523+ f.b = " collisionId" ;
524+ f.c = {60000 .f , 1 .f , 0 .f };
525+ feat.push_back (f);
526+ }
527+ if (nFeat >= 8 ) {
528+ PF f;
529+ f.op = PF::Op::GatherNormWhere;
530+ f.a = " hrArray" ;
531+ f.b = " collisionId" ;
532+ f.scaleInput = " hrDivisor" ;
533+ f.fallbackInput = " hrFallback" ;
534+ feat.push_back (f);
535+ }
536+ if (nFeat >= 9 ) {
537+ PF f;
538+ f.op = PF::Op::Mod2;
539+ f.a = " phi" ;
540+ f.c = {2 .f * static_cast <float >(M_PI), 2 .f * static_cast <float >(M_PI), static_cast <float >(M_PI) / 9 .0f };
541+ feat.push_back (f);
542+ }
543+
544+ network.setupColumnInputs (in, feat, " validMask" );
545+ }
546+
441547 // __________________________________________________
442548 template <typename TCCDB, typename M, typename T, typename B>
443549 std::vector<float > createNetworkPrediction (TCCDB& ccdb, soa::Join<aod::Collisions, aod::EvSels> const & collisions, M const & mults, T const & tracks, B const & bcs, const size_t size)
@@ -489,6 +595,7 @@ class pidTPCModule
489595 network.initModel (pidTPCopts.networkPathLocally .value , pidTPCopts.enableNetworkOptimizations .value , pidTPCopts.networkSetNumThreads .value , strtoul (headers[" Valid-From" ].c_str (), NULL , 0 ), strtoul (headers[" Valid-Until" ].c_str (), NULL , 0 ));
490596 std::vector<float > dummyInput (network.getNumInputNodes (), 1 .);
491597 network.evalModel (dummyInput);
598+ setupColumnInputNetwork ();
492599 LOGP (info, " Retrieved NN corrections for production tag {}, pass number {}, NN-Version number {}" , headers[" LPMProductionTag" ], headers[" RecoPassName" ], headers[" NN-Version" ]);
493600 } else {
494601 LOG (fatal) << " No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata[" RecoPassName" ] << " . Please ensure that the requested pass has dedicated NN corrections available" ;
@@ -497,19 +604,14 @@ class pidTPCModule
497604 }
498605
499606 // Defining some network parameters
500- int input_dimensions = network.getNumInputNodes ();
607+ const int nFeat = network.getNumFeatures ();
501608 int output_dimensions = network.getNumOutputNodes ();
502- const uint64_t track_prop_size = input_dimensions * size;
503609 const uint64_t prediction_size = output_dimensions * size;
504610
505611 network_prediction = std::vector<float >(prediction_size * 9 ); // For each mass hypotheses
506612 const float nNclNormalization = response->GetNClNormalization ();
507613 float duration_network = 0 ;
508614
509- std::vector<float > track_properties (track_prop_size);
510- uint64_t counter_track_props = 0 ;
511- int loop_counter = 0 ;
512-
513615 // To load the Hadronic rate once for each collision
514616 float hadronicRateBegin = 0 .;
515617 std::vector<float > hadronicRateForCollision (collisions.size (), 0 .0f );
@@ -530,88 +632,113 @@ class pidTPCModule
530632 hadronicRateBegin = 0 .0f ;
531633 }
532634
533- // Filling a std::vector<float> to be evaluated by the network
534- // Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
535635 static constexpr int NParticleTypes = 9 ;
536636 constexpr int ExpectedInputDimensionsNNV2 = 7 ;
537637 constexpr int ExpectedInputDimensionsNNV3 = 8 ;
538638 constexpr int ExpectedInputDimensionsNNV4 = 9 ;
539- constexpr auto NetworkVersionV2 = " 2" ;
540- constexpr auto NetworkVersionV3 = " 3" ;
541- constexpr auto NetworkVersionV4 = " 4" ;
542- for (int j = 0 ; j < NParticleTypes; j++) { // Loop over particle number for which network correction is used
543- for (auto const & trk : tracks) {
544- if (!trk.hasTPC ()) {
545- continue ;
546- }
547- if (pidTPCopts.skipTPCOnly ) {
548- if (!trk.hasITS () && !trk.hasTRD () && !trk.hasTOF ()) {
549- continue ;
550- }
551- }
552- track_properties[counter_track_props] = trk.tpcInnerParam ();
553- track_properties[counter_track_props + 1 ] = trk.tgl ();
554- track_properties[counter_track_props + 2 ] = trk.signed1Pt ();
555- track_properties[counter_track_props + 3 ] = o2::track::pid_constants::sMasses [j];
556- track_properties[counter_track_props + 4 ] = trk.has_collision () ? mults[trk.collisionId ()] / 11000 . : 1 .;
557- track_properties[counter_track_props + 5 ] = std::sqrt (nNclNormalization / trk.tpcNClsFound ());
558- if (input_dimensions == ExpectedInputDimensionsNNV2 && networkVersion == NetworkVersionV2) {
559- track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
560- }
561- if (input_dimensions == ExpectedInputDimensionsNNV3 && networkVersion == NetworkVersionV3) {
562- track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
563- if (trk.has_collision ()) {
564- if (collsys == CollisionSystemType::kCollSyspp ) {
565- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 1500 .;
566- } else {
567- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 50 .;
568- }
569- } else {
570- // asign Hadronic Rate at beginning of run if track does not belong to a collision
571- if (collsys == CollisionSystemType::kCollSyspp ) {
572- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 1500 .;
573- } else {
574- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 50 .;
575- }
576- }
639+
640+ const float hadronicRateDivisor = (collsys == CollisionSystemType::kCollSyspp ) ? 1500 .f : 50 .f ;
641+
642+ // Per-collision arrays (O(nColl)); gathered per track inside the model via the
643+ // collisionId column, then normalised in-graph.
644+ const int64_t nColl = static_cast <int64_t >(collisions.size ());
645+ std::vector<float > multArray (nColl);
646+ std::vector<float > occArray (nFeat >= ExpectedInputDimensionsNNV2 ? nColl : 0 );
647+ {
648+ int64_t c = 0 ;
649+ for (const auto & col : collisions) {
650+ multArray[c] = static_cast <float >(mults[c]);
651+ if (nFeat >= ExpectedInputDimensionsNNV2) {
652+ occArray[c] = col.ft0cOccupancyInTimeRange ();
577653 }
654+ ++c;
655+ }
656+ }
578657
579- if (input_dimensions == ExpectedInputDimensionsNNV4 && networkVersion == NetworkVersionV4) {
580- track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
581- if (trk.has_collision ()) {
582- if (collsys == CollisionSystemType::kCollSyspp ) {
583- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 1500 .;
584- } else {
585- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 50 .;
586- }
587- } else {
588- // asign Hadronic Rate at beginning of run if track does not belong to a collision
589- if (collsys == CollisionSystemType::kCollSyspp ) {
590- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 1500 .;
591- } else {
592- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 50 .;
593- }
594- }
595- track_properties[counter_track_props + 8 ] = std::fmod (std::fmod (trk.phi (), 2 * M_PI) + 2 * M_PI, M_PI / 9.0 );
658+ // Raw per-track Arrow column buffers (zero-copy; one chunk per DataFrame).
659+ auto arrowTable = tracks.asArrowTable ();
660+ auto chunk0 = [&](const char * name) -> std::shared_ptr<arrow::Array> {
661+ const int idx = arrowTable->schema ()->GetFieldIndex (name);
662+ if (idx < 0 ) {
663+ LOG (fatal) << " createNetworkPrediction: column '" << name << " ' not found in tracks table" ;
664+ }
665+ auto col = arrowTable->column (idx);
666+ if (col->num_chunks () != 1 ) {
667+ LOG (fatal) << " createNetworkPrediction: column '" << name << " ' has " << col->num_chunks ()
668+ << " chunks; a single chunk per DataFrame is required for zero-copy input" ;
669+ }
670+ return col->chunk (0 );
671+ };
672+ const int64_t nTrk = static_cast <int64_t >(tracks.size ());
673+ const float * pTpcInner = std::static_pointer_cast<arrow::FloatArray>(chunk0 (" fTPCInnerParam" ))->raw_values ();
674+ const float * pTgl = std::static_pointer_cast<arrow::FloatArray>(chunk0 (" fTgl" ))->raw_values ();
675+ const float * pSigned1Pt = std::static_pointer_cast<arrow::FloatArray>(chunk0 (" fSigned1Pt" ))->raw_values ();
676+ const int32_t * pCollId = std::static_pointer_cast<arrow::Int32Array>(chunk0 (" fIndexCollisions" ))->raw_values ();
677+ const uint8_t * pFindable = std::static_pointer_cast<arrow::UInt8Array>(chunk0 (" fTPCNClsFindable" ))->raw_values ();
678+ const int8_t * pFMF = std::static_pointer_cast<arrow::Int8Array>(chunk0 (" fTPCNClsFindableMinusFound" ))->raw_values ();
679+ const float * pPhi = (nFeat >= ExpectedInputDimensionsNNV4)
680+ ? std::static_pointer_cast<arrow::FloatArray>(chunk0 (" fPhi" ))->raw_values ()
681+ : nullptr ;
682+
683+ // Single boolean mask of the tracks the network runs on; the model Compress'es
684+ // to exactly these rows so the output is compact and the consumer's
685+ // count_tracks indexing is unchanged. Condition matches process()'s counter.
686+ std::vector<uint8_t > validMask (nTrk);
687+ {
688+ int64_t t = 0 ;
689+ for (auto const & trk : tracks) {
690+ bool valid = trk.hasTPC ();
691+ if (valid && pidTPCopts.skipTPCOnly && !trk.hasITS () && !trk.hasTRD () && !trk.hasTOF ()) {
692+ valid = false ;
596693 }
597- counter_track_props += input_dimensions;
694+ validMask[t++] = valid ? 1 : 0 ;
695+ }
696+ }
697+
698+ auto memInfo = Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
699+ const int64_t one = 1 ;
700+ float massVal = 0 .f ;
701+ float nclNormVal = nNclNormalization;
702+ float hrDivisorVal = hadronicRateDivisor;
703+ float hrFallbackVal = hadronicRateBegin / hadronicRateDivisor;
704+
705+ // Evaluate once per mass hypothesis; only the mass scalar input changes.
706+ for (int j = 0 ; j < NParticleTypes; j++) {
707+ massVal = o2::track::pid_constants::sMasses [j];
708+
709+ std::vector<Ort::Value> inputTensors;
710+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, const_cast <float *>(pTpcInner), nTrk, &nTrk, 1 ));
711+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, const_cast <float *>(pTgl), nTrk, &nTrk, 1 ));
712+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, const_cast <float *>(pSigned1Pt), nTrk, &nTrk, 1 ));
713+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, &massVal, 1 , &one, 1 ));
714+ inputTensors.emplace_back (Ort::Value::CreateTensor<int32_t >(memInfo, const_cast <int32_t *>(pCollId), nTrk, &nTrk, 1 ));
715+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, multArray.data (), nColl, &nColl, 1 ));
716+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, &nclNormVal, 1 , &one, 1 ));
717+ inputTensors.emplace_back (Ort::Value::CreateTensor<uint8_t >(memInfo, const_cast <uint8_t *>(pFindable), nTrk, &nTrk, 1 ));
718+ inputTensors.emplace_back (Ort::Value::CreateTensor<int8_t >(memInfo, const_cast <int8_t *>(pFMF), nTrk, &nTrk, 1 ));
719+ if (nFeat >= ExpectedInputDimensionsNNV2) {
720+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, occArray.data (), nColl, &nColl, 1 ));
721+ }
722+ if (nFeat >= ExpectedInputDimensionsNNV3) {
723+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, hadronicRateForCollision.data (), nColl, &nColl, 1 ));
724+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, &hrDivisorVal, 1 , &one, 1 ));
725+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, &hrFallbackVal, 1 , &one, 1 ));
598726 }
727+ if (nFeat >= ExpectedInputDimensionsNNV4) {
728+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, const_cast <float *>(pPhi), nTrk, &nTrk, 1 ));
729+ }
730+ inputTensors.emplace_back (Ort::Value::CreateTensor<bool >(memInfo, reinterpret_cast <bool *>(validMask.data ()), nTrk, &nTrk, 1 ));
599731
600732 auto start_network_eval = std::chrono::high_resolution_clock::now ();
601- float * output_network = network.evalModel (track_properties );
733+ float * output_network = network.evalModel < float >(inputTensors );
602734 auto stop_network_eval = std::chrono::high_resolution_clock::now ();
603735 duration_network += std::chrono::duration<float , std::ratio<1 , 1000000000 >>(stop_network_eval - start_network_eval).count ();
604736 for (uint64_t k = 0 ; k < prediction_size; k += output_dimensions) {
605737 for (int l = 0 ; l < output_dimensions; l++) {
606- network_prediction[k + l + prediction_size * loop_counter ] = output_network[k + l];
738+ network_prediction[k + l + prediction_size * j ] = output_network[k + l];
607739 }
608740 }
609-
610- counter_track_props = 0 ;
611- loop_counter += 1 ;
612741 }
613- track_properties.clear ();
614-
615742 auto stop_network_total = std::chrono::high_resolution_clock::now ();
616743 LOG (debug) << " Neural Network for the TPC PID response correction: Time per track (eval ONNX): " << duration_network / (size * 9 ) << " ns ; Total time (eval ONNX): " << duration_network / 1000000000 << " s" ;
617744 LOG (debug) << " Neural Network for the TPC PID response correction: Time per track (eval + overhead): " << std::chrono::duration<float , std::ratio<1 , 1000000000 >>(stop_network_total - start_network_total).count () / (size * 9 ) << " ns ; Total time (eval + overhead): " << std::chrono::duration<float , std::ratio<1 , 1000000000 >>(stop_network_total - start_network_total).count () / 1000000000 << " s" ;
0 commit comments