diff --git a/Common/Tools/PID/pidTPCModule.h b/Common/Tools/PID/pidTPCModule.h index 24c7683b70c..8de26239896 100644 --- a/Common/Tools/PID/pidTPCModule.h +++ b/Common/Tools/PID/pidTPCModule.h @@ -414,6 +414,7 @@ class pidTPCModule network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0)); std::vector dummyInput(network.getNumInputNodes(), 1.); network.evalModel(dummyInput); /// Init the model evaluations + setupColumnInputNetwork(); LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]); } else { LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available"; @@ -427,6 +428,7 @@ class pidTPCModule network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value); std::vector dummyInput(network.getNumInputNodes(), 1.); network.evalModel(dummyInput); // This is an initialisation and might reduce the overhead of the model + setupColumnInputNetwork(); } } else { return; @@ -438,6 +440,110 @@ class pidTPCModule } } // end init + //__________________________________________________ + void setupColumnInputNetwork() + { + using PI = o2::ml::OnnxModel::PreprocInput; + using PF = o2::ml::OnnxModel::PreprocFeature; + const int nFeat = network.getNumInputNodes(); // # network features (6..9), original model + + // Raw graph inputs (this order defines the tensor feeding order in + // createNetworkPrediction). All per-track columns are wrapped zero-copy from + // the Arrow buffers; nclNorm/hrDivisor/hrFallback are per-DF runtime scalars. 
+ std::vector in; + in.push_back({"tpcInnerParam", PI::Type::TrackFloat}); + in.push_back({"tgl", PI::Type::TrackFloat}); + in.push_back({"signed1Pt", PI::Type::TrackFloat}); + in.push_back({"mass", PI::Type::ScalarFloat}); + in.push_back({"collisionId", PI::Type::TrackInt32}); + in.push_back({"multArray", PI::Type::CollisionFloat}); + in.push_back({"nclNorm", PI::Type::ScalarFloat}); + in.push_back({"nclsFindable", PI::Type::TrackUint8}); + in.push_back({"nclsFMF", PI::Type::TrackInt8}); + if (nFeat >= 7) { + in.push_back({"occArray", PI::Type::CollisionFloat}); + } + if (nFeat >= 8) { + in.push_back({"hrArray", PI::Type::CollisionFloat}); + in.push_back({"hrDivisor", PI::Type::ScalarFloat}); + in.push_back({"hrFallback", PI::Type::ScalarFloat}); + } + if (nFeat >= 9) { + in.push_back({"phi", PI::Type::TrackFloat}); + } + in.push_back({"validMask", PI::Type::TrackBool}); + + // Per-feature preprocessing (exactly nFeat entries, in the training order). + std::vector feat; + { + PF f; + f.op = PF::Op::Passthrough; + f.a = "tpcInnerParam"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::Passthrough; + f.a = "tgl"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::Passthrough; + f.a = "signed1Pt"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::BroadcastScalar; + f.a = "mass"; + f.shapeRef = "collisionId"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::GatherNormWhere; + f.a = "multArray"; + f.b = "collisionId"; + f.c = {11000.f, 1.f, 0.f}; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::NClsSqrtRecip; + f.a = "nclsFindable"; + f.b = "nclsFMF"; + f.scaleInput = "nclNorm"; + feat.push_back(f); + } + if (nFeat >= 7) { + PF f; + f.op = PF::Op::GatherNormWhere; + f.a = "occArray"; + f.b = "collisionId"; + f.c = {60000.f, 1.f, 0.f}; + feat.push_back(f); + } + if (nFeat >= 8) { + PF f; + f.op = PF::Op::GatherNormWhere; + f.a = "hrArray"; + f.b = "collisionId"; + f.scaleInput = "hrDivisor"; + f.fallbackInput = "hrFallback"; + feat.push_back(f); + 
} + if (nFeat >= 9) { + PF f; + f.op = PF::Op::Mod2; + f.a = "phi"; + f.c = {2.f * static_cast(M_PI), 2.f * static_cast(M_PI), static_cast(M_PI) / 9.0f}; + feat.push_back(f); + } + + network.setupColumnInputs(in, feat, "validMask"); + } + //__________________________________________________ template std::vector createNetworkPrediction(TCCDB& ccdb, soa::Join const& collisions, M const& mults, T const& tracks, B const& bcs, const size_t size) @@ -489,6 +595,7 @@ class pidTPCModule network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0)); std::vector dummyInput(network.getNumInputNodes(), 1.); network.evalModel(dummyInput); + setupColumnInputNetwork(); LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, NN-Version number {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]); } else { LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". 
Please ensure that the requested pass has dedicated NN corrections available"; @@ -497,19 +604,14 @@ class pidTPCModule } // Defining some network parameters - int input_dimensions = network.getNumInputNodes(); + const int nFeat = network.getNumFeatures(); int output_dimensions = network.getNumOutputNodes(); - const uint64_t track_prop_size = input_dimensions * size; const uint64_t prediction_size = output_dimensions * size; network_prediction = std::vector(prediction_size * 9); // For each mass hypotheses const float nNclNormalization = response->GetNClNormalization(); float duration_network = 0; - std::vector track_properties(track_prop_size); - uint64_t counter_track_props = 0; - int loop_counter = 0; - // To load the Hadronic rate once for each collision float hadronicRateBegin = 0.; std::vector hadronicRateForCollision(collisions.size(), 0.0f); @@ -530,88 +632,113 @@ class pidTPCModule hadronicRateBegin = 0.0f; } - // Filling a std::vector to be evaluated by the network - // Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector static constexpr int NParticleTypes = 9; constexpr int ExpectedInputDimensionsNNV2 = 7; constexpr int ExpectedInputDimensionsNNV3 = 8; constexpr int ExpectedInputDimensionsNNV4 = 9; - constexpr auto NetworkVersionV2 = "2"; - constexpr auto NetworkVersionV3 = "3"; - constexpr auto NetworkVersionV4 = "4"; - for (int j = 0; j < NParticleTypes; j++) { // Loop over particle number for which network correction is used - for (auto const& trk : tracks) { - if (!trk.hasTPC()) { - continue; - } - if (pidTPCopts.skipTPCOnly) { - if (!trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) { - continue; - } - } - track_properties[counter_track_props] = trk.tpcInnerParam(); - track_properties[counter_track_props + 1] = trk.tgl(); - track_properties[counter_track_props + 2] = trk.signed1Pt(); - track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[j]; - track_properties[counter_track_props + 4] 
= trk.has_collision() ? mults[trk.collisionId()] / 11000. : 1.; - track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound()); - if (input_dimensions == ExpectedInputDimensionsNNV2 && networkVersion == NetworkVersionV2) { - track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.; - } - if (input_dimensions == ExpectedInputDimensionsNNV3 && networkVersion == NetworkVersionV3) { - track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.; - if (trk.has_collision()) { - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.; - } - } else { - // asign Hadronic Rate at beginning of run if track does not belong to a collision - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateBegin / 50.; - } - } + + const float hadronicRateDivisor = (collsys == CollisionSystemType::kCollSyspp) ? 1500.f : 50.f; + + // Per-collision arrays (O(nColl)); gathered per track inside the model via the + // collisionId column, then normalised in-graph. + const int64_t nColl = static_cast(collisions.size()); + std::vector multArray(nColl); + std::vector occArray(nFeat >= ExpectedInputDimensionsNNV2 ? 
nColl : 0); + { + int64_t c = 0; + for (const auto& col : collisions) { + multArray[c] = static_cast(mults[c]); + if (nFeat >= ExpectedInputDimensionsNNV2) { + occArray[c] = col.ft0cOccupancyInTimeRange(); } + ++c; + } + } - if (input_dimensions == ExpectedInputDimensionsNNV4 && networkVersion == NetworkVersionV4) { - track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.; - if (trk.has_collision()) { - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.; - } - } else { - // asign Hadronic Rate at beginning of run if track does not belong to a collision - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateBegin / 50.; - } - } - track_properties[counter_track_props + 8] = std::fmod(std::fmod(trk.phi(), 2 * M_PI) + 2 * M_PI, M_PI / 9.0); + // Raw per-track Arrow column buffers (zero-copy; one chunk per DataFrame). 
+ auto arrowTable = tracks.asArrowTable(); + auto chunk0 = [&](const char* name) -> std::shared_ptr { + const int idx = arrowTable->schema()->GetFieldIndex(name); + if (idx < 0) { + LOG(fatal) << "createNetworkPrediction: column '" << name << "' not found in tracks table"; + } + auto col = arrowTable->column(idx); + if (col->num_chunks() != 1) { + LOG(fatal) << "createNetworkPrediction: column '" << name << "' has " << col->num_chunks() + << " chunks; a single chunk per DataFrame is required for zero-copy input"; + } + return col->chunk(0); + }; + const int64_t nTrk = static_cast(tracks.size()); + const float* pTpcInner = std::static_pointer_cast(chunk0("fTPCInnerParam"))->raw_values(); + const float* pTgl = std::static_pointer_cast(chunk0("fTgl"))->raw_values(); + const float* pSigned1Pt = std::static_pointer_cast(chunk0("fSigned1Pt"))->raw_values(); + const int32_t* pCollId = std::static_pointer_cast(chunk0("fIndexCollisions"))->raw_values(); + const uint8_t* pFindable = std::static_pointer_cast(chunk0("fTPCNClsFindable"))->raw_values(); + const int8_t* pFMF = std::static_pointer_cast(chunk0("fTPCNClsFindableMinusFound"))->raw_values(); + const float* pPhi = (nFeat >= ExpectedInputDimensionsNNV4) + ? std::static_pointer_cast(chunk0("fPhi"))->raw_values() + : nullptr; + + // Single boolean mask of the tracks the network runs on; the model Compress'es + // to exactly these rows so the output is compact and the consumer's + // count_tracks indexing is unchanged. Condition matches process()'s counter. + std::vector validMask(nTrk); + { + int64_t t = 0; + for (auto const& trk : tracks) { + bool valid = trk.hasTPC(); + if (valid && pidTPCopts.skipTPCOnly && !trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) { + valid = false; } - counter_track_props += input_dimensions; + validMask[t++] = valid ? 
1 : 0; + } + } + + auto memInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + const int64_t one = 1; + float massVal = 0.f; + float nclNormVal = nNclNormalization; + float hrDivisorVal = hadronicRateDivisor; + float hrFallbackVal = hadronicRateBegin / hadronicRateDivisor; + + // Evaluate once per mass hypothesis; only the mass scalar input changes. + for (int j = 0; j < NParticleTypes; j++) { + massVal = o2::track::pid_constants::sMasses[j]; + + std::vector inputTensors; + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pTpcInner), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pTgl), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pSigned1Pt), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &massVal, 1, &one, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pCollId), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, multArray.data(), nColl, &nColl, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &nclNormVal, 1, &one, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pFindable), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pFMF), nTrk, &nTrk, 1)); + if (nFeat >= ExpectedInputDimensionsNNV2) { + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, occArray.data(), nColl, &nColl, 1)); + } + if (nFeat >= ExpectedInputDimensionsNNV3) { + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, hadronicRateForCollision.data(), nColl, &nColl, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &hrDivisorVal, 1, &one, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &hrFallbackVal, 1, &one, 1)); } + if (nFeat >= ExpectedInputDimensionsNNV4) { + 
inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pPhi), nTrk, &nTrk, 1)); + } + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, reinterpret_cast(validMask.data()), nTrk, &nTrk, 1)); auto start_network_eval = std::chrono::high_resolution_clock::now(); - float* output_network = network.evalModel(track_properties); + float* output_network = network.evalModel(inputTensors); auto stop_network_eval = std::chrono::high_resolution_clock::now(); duration_network += std::chrono::duration>(stop_network_eval - start_network_eval).count(); for (uint64_t k = 0; k < prediction_size; k += output_dimensions) { for (int l = 0; l < output_dimensions; l++) { - network_prediction[k + l + prediction_size * loop_counter] = output_network[k + l]; + network_prediction[k + l + prediction_size * j] = output_network[k + l]; } } - - counter_track_props = 0; - loop_counter += 1; } - track_properties.clear(); - auto stop_network_total = std::chrono::high_resolution_clock::now(); LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval ONNX): " << duration_network / (size * 9) << "ns ; Total time (eval ONNX): " << duration_network / 1000000000 << " s"; LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval + overhead): " << std::chrono::duration>(stop_network_total - start_network_total).count() / (size * 9) << "ns ; Total time (eval + overhead): " << std::chrono::duration>(stop_network_total - start_network_total).count() / 1000000000 << " s"; diff --git a/Tools/ML/CMakeLists.txt b/Tools/ML/CMakeLists.txt index b95c108584a..4dde4095cb6 100644 --- a/Tools/ML/CMakeLists.txt +++ b/Tools/ML/CMakeLists.txt @@ -11,5 +11,5 @@ o2physics_add_library(MLCore SOURCES model.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime ONNX::onnx_proto ) diff --git a/Tools/ML/model.h 
b/Tools/ML/model.h index 3be08e72fa9..8428dc492d9 100644 --- a/Tools/ML/model.h +++ b/Tools/ML/model.h @@ -22,15 +22,21 @@ #include +#include #include #include #include #include +#include #include #include +#include +#include #include +#include #include +#include #include #include @@ -40,6 +46,165 @@ namespace o2 namespace ml { +// ============================================================================ +// FusedPreprocLinear custom op +// ---------------------------------------------------------------------------- +// Fuses the per-column preprocessing AND the first linear layer into a single +// streaming kernel: for every selected (masked-in) track it computes the K +// features on a small stack and immediately applies W/b, writing one [H] output +// row. No [M, K] feature buffer and no per-op intermediate tensors are ever +// materialised, so the peak memory is just the [M, H] output. The op is generic +// and driven entirely by node attributes (the same spec as PreprocFeature), so +// the original network's layers 1..N are copied verbatim on top of its output. +// +// Op enum (matches OnnxModel::PreprocFeature::Op order): +// 0 Passthrough 1 BroadcastScalar 2 NClsSqrtRecip 3 Mod2 4 GatherNormWhere +// Inputs are a single heterogeneous variadic list: the raw spec inputs (in +// declaration order), then W [K,H] (transB already folded in), then optional b. 
+struct FusedPreprocLinearKernel {
+  // Node attributes, read once when the kernel is created. Every *Idx / *_idx value is a
+  // position in the node's variadic input list; a negative value means "not provided".
+  int mK = 0;         // number of features (rows of W)
+  int mH = 0;         // width of the first linear layer (columns of W)
+  int mMaskIdx = -1;  // bool [N] row-selection mask; < 0: every row is selected
+  int mWeightIdx = 0; // W [K, H]
+  int mBiasIdx = -1;  // b [H]; < 0: no bias
+  std::vector<int64_t> mOps, mA, mB, mScale, mFallback; // per-feature op code and input positions
+  std::vector<float> mConsts;                           // per-feature constants, 3 per feature (c0, c1, c2)
+
+  explicit FusedPreprocLinearKernel(const OrtKernelInfo* info)
+  {
+    Ort::ConstKernelInfo ki(info);
+    mK = static_cast<int>(ki.GetAttribute<int64_t>("n_features"));
+    mH = static_cast<int>(ki.GetAttribute<int64_t>("hidden"));
+    mMaskIdx = static_cast<int>(ki.GetAttribute<int64_t>("mask_index"));
+    mWeightIdx = static_cast<int>(ki.GetAttribute<int64_t>("weight_index"));
+    mBiasIdx = static_cast<int>(ki.GetAttribute<int64_t>("bias_index"));
+    mOps = ki.GetAttributes<int64_t>("ops");
+    mA = ki.GetAttributes<int64_t>("a_idx");
+    mB = ki.GetAttributes<int64_t>("b_idx");
+    mScale = ki.GetAttributes<int64_t>("scale_idx");
+    mFallback = ki.GetAttributes<int64_t>("fallback_idx");
+    mConsts = ki.GetAttributes<float>("consts");
+  }
+
+  /// Writes one [H] output row per selected input row: the K features of the row are built
+  /// from the raw inputs according to the per-feature spec and immediately multiplied with
+  /// W (plus b when a bias input is present). Output 0 has shape [M, H], M = selected rows.
+  void Compute(OrtKernelContext* context)
+  {
+    Ort::KernelContext ctx(context);
+    const size_t nIn = ctx.GetInputCount();
+    struct In {
+      const void* p = nullptr;
+      ONNXTensorElementDataType t = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+      int64_t count = 0; // number of elements, used to bounds-check gathers
+    };
+    std::vector<In> in(nIn);
+    std::vector<Ort::ConstValue> hold; // input handles backing the raw pointers stored in `in`
+    hold.reserve(nIn);
+    for (size_t i = 0; i < nIn; ++i) {
+      hold.emplace_back(ctx.GetInput(i));
+      const auto shapeInfo = hold[i].GetTensorTypeAndShapeInfo();
+      in[i] = {hold[i].GetTensorRawData(), shapeInfo.GetElementType(), static_cast<int64_t>(shapeInfo.GetElementCount())};
+    }
+
+    // Element n of input idx as float; element types not listed below read as 0.
+    auto readF = [&](int idx, int64_t n) -> float {
+      const In& x = in[idx];
+      switch (x.t) {
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+          return reinterpret_cast<const float*>(x.p)[n];
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+          return static_cast<float>(reinterpret_cast<const int32_t*>(x.p)[n]);
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+          return static_cast<float>(reinterpret_cast<const uint8_t*>(x.p)[n]);
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
+          return static_cast<float>(reinterpret_cast<const int8_t*>(x.p)[n]);
+        default:
+          return 0.f;
+      }
+    };
+    // Element n of input idx as integer; int32 is read exactly, anything else goes through readF.
+    auto readI = [&](int idx, int64_t n) -> int64_t {
+      const In& x = in[idx];
+      if (x.t == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
+        return reinterpret_cast<const int32_t*>(x.p)[n];
+      }
+      return static_cast<int64_t>(readF(idx, n));
+    };
+
+    // Row selection. mask_index is -1 when the model was built without a mask input; in that
+    // case every row is selected and the row count is taken from the first per-row input that
+    // the feature spec references (op 4 indexes rows through `b`, op 1 is a scalar, all other
+    // ops read rows from `a`).
+    const bool* mask = (mMaskIdx >= 0) ? reinterpret_cast<const bool*>(in[mMaskIdx].p) : nullptr;
+    int64_t N = 0;
+    if (mask) {
+      N = hold[mMaskIdx].GetTensorTypeAndShapeInfo().GetShape()[0];
+    } else {
+      for (int k = 0; k < mK; ++k) {
+        const int64_t src = (mOps[k] == 4) ? mB[k] : ((mOps[k] == 1) ? -1 : mA[k]);
+        if (src >= 0 && static_cast<size_t>(src) < nIn) {
+          N = in[src].count;
+          break;
+        }
+      }
+    }
+    int64_t M = N;
+    if (mask) {
+      M = 0;
+      for (int64_t n = 0; n < N; ++n) {
+        if (mask[n]) {
+          ++M;
+        }
+      }
+    }
+
+    const float* W = reinterpret_cast<const float*>(in[mWeightIdx].p); // [K, H]
+    const float* b = (mBiasIdx >= 0) ? reinterpret_cast<const float*>(in[mBiasIdx].p) : nullptr;
+
+    auto out = ctx.GetOutput(0, {M, static_cast<int64_t>(mH)});
+    float* o = out.GetTensorMutableData<float>();
+
+    std::vector<float> feat(mK);
+    int64_t m = 0; // index of the next output row
+    for (int64_t n = 0; n < N; ++n) {
+      if (mask && !mask[n]) {
+        continue;
+      }
+      for (int k = 0; k < mK; ++k) {
+        const float c0 = mConsts[3 * k], c1 = mConsts[3 * k + 1], c2 = mConsts[3 * k + 2];
+        switch (mOps[k]) {
+          case 0: // Passthrough
+            feat[k] = readF(static_cast<int>(mA[k]), n);
+            break;
+          case 1: // BroadcastScalar (e.g. mass, a [1] scalar)
+            feat[k] = readF(static_cast<int>(mA[k]), 0);
+            break;
+          case 2: { // NClsSqrtRecip: sqrt(numer / (a - b))
+            const float ncl = readF(static_cast<int>(mA[k]), n) - readF(static_cast<int>(mB[k]), n);
+            const float numer = mScale[k] >= 0 ? readF(static_cast<int>(mScale[k]), 0) : c0;
+            feat[k] = std::sqrt(numer / ncl);
+            break;
+          }
+          case 3: { // Mod2: fmod(fmod(a, c0) + c1, c2)
+            const float a = readF(static_cast<int>(mA[k]), n);
+            feat[k] = std::fmod(std::fmod(a, c0) + c1, c2);
+            break;
+          }
+          case 4: { // GatherNormWhere: idx outside the array ? fallback : array[idx] / denom
+            const int64_t idx = readI(static_cast<int>(mB[k]), n);
+            // idx < 0 marks "no entry"; idx past the end would read outside the gathered
+            // array, so it is given the same fallback instead of being dereferenced.
+            if (idx < 0 || idx >= in[static_cast<size_t>(mA[k])].count) {
+              feat[k] = mFallback[k] >= 0 ? readF(static_cast<int>(mFallback[k]), 0) : c1;
+            } else {
+              const float denom = mScale[k] >= 0 ? readF(static_cast<int>(mScale[k]), 0) : c0;
+              feat[k] = readF(static_cast<int>(mA[k]), idx) / denom;
+            }
+            break;
+          }
+          default:
+            feat[k] = 0.f;
+            break;
+        }
+      }
+      // orow = b + sum_k feat[k] * W[k, :]
+      float* orow = o + m * mH;
+      for (int h = 0; h < mH; ++h) {
+        orow[h] = b ? b[h] : 0.f;
+      }
+      for (int k = 0; k < mK; ++k) {
+        const float fk = feat[k];
+        const float* wrow = W + static_cast<int64_t>(k) * mH;
+        for (int h = 0; h < mH; ++h) {
+          orow[h] += fk * wrow[h];
+        }
+      }
+      ++m;
+    }
+  }
+};
+
+// Custom-op descriptor: a single variadic, heterogeneous input slot (raw inputs, W, optional b)
+// and one float output.
+struct FusedPreprocLinearOp : Ort::CustomOpBase<FusedPreprocLinearOp, FusedPreprocLinearKernel> {
+  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new FusedPreprocLinearKernel(info); }
+  const char* GetName() const { return "FusedPreprocLinear"; }
+  size_t GetInputTypeCount() const { return 1; }
+  ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }
+  bool GetVariadicInputHomogeneity() const { return false; } // heterogeneous: mixed input types
+  int GetVariadicInputMinArity() const { return 1; }
+  size_t GetOutputTypeCount() const { return 1; }
+  ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
+};
+
 class OnnxModel
 {
@@ -132,6 +297,360 @@ class OnnxModel
     mSession.reset(new Ort::Session{*mEnv, modelPath.c_str(), sessionOptions});
   }
+  // Declaration of a raw graph input fed directly from an Arrow buffer.
+  struct PreprocInput {
+    enum class Type { TrackFloat,     // per-track float column       [N]
+                      TrackInt32,     // per-track int32 column       [N]
+                      TrackUint8,     // per-track uint8 column       [N]
+                      TrackInt8,      // per-track int8 column        [N]
+                      TrackBool,      // per-track bool mask          [N]
+                      CollisionFloat, // per-collision float array    [C]
+                      ScalarFloat };  // single scalar (e.g. mass)    [1]
+    std::string name;
+    Type type;
+  };
+
+  // Preprocessing recipe for one network feature (produces a [N] float tensor
+  // that feeds column i of the decomposed first layer).
+  struct PreprocFeature {
+    // The enumerator order is part of the contract: setupColumnInputs stores the numeric
+    // value of `op` in the custom node's "ops" attribute and FusedPreprocLinearKernel
+    // switches on it (0 Passthrough ... 4 GatherNormWhere). Append new ops at the end only.
+    enum class Op {
+      Passthrough,     // feature = a                                   (a: TrackFloat)
+      BroadcastScalar, // feature = Expand(a, shape(shapeRef))          (a: ScalarFloat)
+      NClsSqrtRecip,   // feature = Sqrt(c0 / (float(a) - float(b)))    (a,b: Track int cols)
+      Mod2,            // feature = Mod(Mod(a, c0) + c1, c2)            (a: TrackFloat)
+      GatherNormWhere  // feature = Where(b<0, fb, Gather(a, b) / c0)   (a: CollisionFloat, b: TrackInt32)
+    };
+    Op op = Op::Passthrough;   // initialised so a default-constructed recipe never carries an indeterminate op
+    std::string a;             // primary input
+    std::string b;             // secondary input (NCls: 2nd col; Gather: index)
+    std::string shapeRef;      // BroadcastScalar: [N] input to size the Expand (not read by setupColumnInputs below)
+    std::string fallbackInput; // GatherNormWhere: scalar input for the b<0 fallback; "" => use c[1]
+    std::string scaleInput;    // NCls: numerator scalar input; Gather: divisor scalar input; "" => use c[0]
+    std::array<float, 3> c{};  // op constants (c0, c1, c2), zero unless set
+  };
+
+  // Rebuild the model so that the network reads its raw inputs (`inputDefs`) directly and
+  // performs the per-feature preprocessing (`features`) together with the first linear
+  // layer inside the graph. The first Gemm/MatMul of the original model is replaced by a
+  // single FusedPreprocLinear custom node that, for every selected row, evaluates
+  //   layer0 = X @ W + b = sum_i (feat_i * W_row_i) + b
+  // so no [N, K] feature buffer is ever materialised. The remaining original nodes are
+  // copied verbatim on top of that node's output, and the session is re-created from the
+  // serialised rebuilt model.
+  // If maskInput is non-empty it names a bool [N] input: only the rows where it is true
+  // are evaluated, and the output is the compact set of selected tracks, in order.
+ void setupColumnInputs(const std::vector& inputDefs, + const std::vector& features, + const std::string& maskInput = "") + { + const int numFeatures = static_cast(features.size()); + if (numFeatures != mInputShapes[0][1]) { + LOG(fatal) << "setupColumnInputs: expected " << mInputShapes[0][1] << " features, got " << numFeatures; + return; + } + + onnx::ModelProto onnxModel; + { + std::ifstream ifs(modelPath, std::ios::binary); + if (!ifs || !onnxModel.ParseFromIstream(&ifs)) { + LOG(fatal) << "setupColumnInputs: failed to parse ONNX model from " << modelPath; + return; + } + } + const auto& og = onnxModel.graph(); + + auto findInit = [&](const std::string& name) -> const onnx::TensorProto* { + for (int i = 0; i < og.initializer_size(); ++i) { + if (og.initializer(i).name() == name) { + return &og.initializer(i); + } + } + return nullptr; + }; + auto tensorFloats = [&](const onnx::TensorProto* t, int64_t n) { + std::vector v(n); + if (t->raw_data().size() > 0) { + std::memcpy(v.data(), t->raw_data().data(), n * sizeof(float)); + } else { + for (int64_t i = 0; i < n; ++i) { + v[i] = t->float_data(i); + } + } + return v; + }; + + // --- locate the first linear layer and extract its weights. The custom op + // reproduces this layer's pre-activation output, so the original layers + // 1..N are copied on top (the first layer and its input Cast are dropped). 
--- + const onnx::NodeProto* first = nullptr; + for (int i = 0; i < og.node_size(); ++i) { + const auto& n = og.node(i); + if (n.op_type() == "Gemm" || n.op_type() == "MatMul") { + first = &n; + break; + } + } + if (!first) { + LOG(fatal) << "setupColumnInputs: no Gemm/MatMul layer found in model"; + return; + } + const std::string layer0Out = first->output(0); + const onnx::TensorProto* wT = findInit(first->input(1)); + if (!wT || wT->dims_size() != 2) { + LOG(fatal) << "setupColumnInputs: first-layer weight initializer not found or not 2D"; + return; + } + bool transB = false; + if (first->op_type() == "Gemm") { + for (int i = 0; i < first->attribute_size(); ++i) { + if (first->attribute(i).name() == "transB") { + transB = (first->attribute(i).i() != 0); + } + } + } + const int K = transB ? static_cast(wT->dims(1)) : static_cast(wT->dims(0)); + const int H = transB ? static_cast(wT->dims(0)) : static_cast(wT->dims(1)); + if (K != numFeatures) { + LOG(fatal) << "setupColumnInputs: first-layer K=" << K << " != numFeatures=" << numFeatures; + return; + } + const std::vector wRaw = tensorFloats(wT, static_cast(K) * H); + std::vector wKH(static_cast(K) * H); // row-major [K, H], transB folded in + for (int k = 0; k < K; ++k) { + for (int h = 0; h < H; ++h) { + wKH[static_cast(k) * H + h] = + transB ? wRaw[static_cast(h) * K + k] : wRaw[static_cast(k) * H + h]; + } + } + std::vector bData; + if (first->op_type() == "Gemm" && first->input_size() >= 3 && !first->input(2).empty()) { + if (const onnx::TensorProto* bT = findInit(first->input(2))) { + bData = tensorFloats(bT, H); + } + } + + // --- map each raw input name to its position in the custom node's input list, + // and encode the per-feature preprocessing spec for the custom kernel. 
--- + std::vector rawInputNames; + rawInputNames.reserve(inputDefs.size()); + std::map inputIndex; + for (int i = 0; i < static_cast(inputDefs.size()); ++i) { + inputIndex[inputDefs[i].name] = i; + rawInputNames.push_back(inputDefs[i].name); + } + auto idxOf = [&](const std::string& name) -> int64_t { + auto it = inputIndex.find(name); + return it == inputIndex.end() ? -1 : static_cast(it->second); + }; + std::vector specOps(K), specA(K), specB(K), specScale(K), specFallback(K); + std::vector specConsts(static_cast(3) * K); + for (int i = 0; i < K; ++i) { + const auto& f = features[i]; + specOps[i] = static_cast(f.op); + specA[i] = idxOf(f.a); + specB[i] = f.b.empty() ? -1 : idxOf(f.b); + specScale[i] = f.scaleInput.empty() ? -1 : idxOf(f.scaleInput); + specFallback[i] = f.fallbackInput.empty() ? -1 : idxOf(f.fallbackInput); + specConsts[3 * i] = f.c[0]; + specConsts[3 * i + 1] = f.c[1]; + specConsts[3 * i + 2] = f.c[2]; + } + + // --- keep the original layers 1..N: every node reachable backwards from the + // graph outputs, except the replaced first layer (and the now-dead input + // Cast). They consume the custom node's pre-activation output. --- + std::set live; + for (int i = 0; i < og.output_size(); ++i) { + live.insert(og.output(i).name()); + } + std::vector keep(og.node_size(), false); + bool changed = true; + while (changed) { + changed = false; + for (int i = 0; i < og.node_size(); ++i) { + if (keep[i] || &og.node(i) == first) { + continue; + } + const auto& n = og.node(i); + bool produces = false; + for (int o = 0; o < n.output_size(); ++o) { + if (live.count(n.output(o))) { + produces = true; + break; + } + } + if (!produces) { + continue; + } + keep[i] = true; + changed = true; + for (int in = 0; in < n.input_size(); ++in) { + live.insert(n.input(in)); + } + } + } + + // --- assemble the rebuilt model as an ONNX protobuf and load it from bytes. 
+ // (The Model-Editor session path does not resolve user custom ops; the + // standard from-bytes session does.) Layers 1..N are copied verbatim; only + // the first Gemm + dead input Cast are replaced by the fused custom node. --- + onnx::ModelProto nm; + nm.set_ir_version(onnxModel.ir_version()); + for (const auto& oi : onnxModel.opset_import()) { + *nm.add_opset_import() = oi; + } + { + auto* oi = nm.add_opset_import(); + oi->set_domain("ai.o2.ml"); + oi->set_version(1); + } + onnx::GraphProto* ng = nm.mutable_graph(); + ng->set_name(og.name()); + + // raw graph inputs + for (const auto& pin : inputDefs) { + auto* vp = ng->add_input(); + vp->set_name(pin.name); + auto* tt = vp->mutable_type()->mutable_tensor_type(); + int et = onnx::TensorProto::FLOAT; + const char* dim = "N"; + bool scalar = false; + switch (pin.type) { + case PreprocInput::Type::TrackFloat: et = onnx::TensorProto::FLOAT; break; + case PreprocInput::Type::TrackInt32: et = onnx::TensorProto::INT32; break; + case PreprocInput::Type::TrackUint8: et = onnx::TensorProto::UINT8; break; + case PreprocInput::Type::TrackInt8: et = onnx::TensorProto::INT8; break; + case PreprocInput::Type::TrackBool: et = onnx::TensorProto::BOOL; break; + case PreprocInput::Type::CollisionFloat: et = onnx::TensorProto::FLOAT; dim = "C"; break; + case PreprocInput::Type::ScalarFloat: et = onnx::TensorProto::FLOAT; scalar = true; break; + } + tt->set_elem_type(et); + auto* sh = tt->mutable_shape(); + if (scalar) { + sh->add_dim()->set_dim_value(1); + } else { + sh->add_dim()->set_dim_param(dim); + } + } + + // W [K,H] (transB folded) and optional b [H] as initializers fed to the node. 
+ auto addRawInit = [&](const std::string& name, const std::vector& data, + const std::vector& shape) { + auto* t = ng->add_initializer(); + t->set_name(name); + t->set_data_type(onnx::TensorProto::FLOAT); + for (const auto d : shape) { + t->add_dims(d); + } + t->set_raw_data(std::string(reinterpret_cast(data.data()), data.size() * sizeof(float))); + }; + addRawInit("_l0_weight", wKH, {static_cast(K), static_cast(H)}); + std::vector nodeInputs = rawInputNames; + const int64_t weightIdx = static_cast(nodeInputs.size()); + nodeInputs.push_back("_l0_weight"); + int64_t biasIdx = -1; + if (!bData.empty()) { + addRawInit("_l0_bias", bData, {static_cast(H)}); + biasIdx = static_cast(nodeInputs.size()); + nodeInputs.push_back("_l0_bias"); + } + + // the fused preprocessing + first-layer custom node -> layer0Out [M, H] + { + auto* cn = ng->add_node(); + cn->set_op_type("FusedPreprocLinear"); + cn->set_domain("ai.o2.ml"); + cn->set_name("_fused_preproc_linear"); + for (const auto& in : nodeInputs) { + cn->add_input(in); + } + cn->add_output(layer0Out); + auto addAttrI = [&](const char* nm, int64_t v) { + auto* a = cn->add_attribute(); + a->set_name(nm); + a->set_type(onnx::AttributeProto::INT); + a->set_i(v); + }; + auto addAttrInts = [&](const char* nm, const std::vector& v) { + auto* a = cn->add_attribute(); + a->set_name(nm); + a->set_type(onnx::AttributeProto::INTS); + for (const auto x : v) { + a->add_ints(x); + } + }; + addAttrI("n_features", K); + addAttrI("hidden", H); + addAttrI("mask_index", maskInput.empty() ? 
-1 : idxOf(maskInput)); + addAttrI("weight_index", weightIdx); + addAttrI("bias_index", biasIdx); + addAttrInts("ops", specOps); + addAttrInts("a_idx", specA); + addAttrInts("b_idx", specB); + addAttrInts("scale_idx", specScale); + addAttrInts("fallback_idx", specFallback); + auto* ac = cn->add_attribute(); + ac->set_name("consts"); + ac->set_type(onnx::AttributeProto::FLOATS); + for (const auto x : specConsts) { + ac->add_floats(x); + } + } + + // kept original nodes (layers 1..N) + the initializers they reference + int keptNodes = 0; + std::set neededInits; + for (int i = 0; i < og.node_size(); ++i) { + if (!keep[i]) { + continue; + } + const auto& n = og.node(i); + *ng->add_node() = n; + ++keptNodes; + for (int in = 0; in < n.input_size(); ++in) { + if (findInit(n.input(in))) { + neededInits.insert(n.input(in)); + } + } + } + for (const auto& name : neededInits) { + *ng->add_initializer() = *findInit(name); + } + for (int i = 0; i < og.output_size(); ++i) { + *ng->add_output() = og.output(i); + } + + std::string modelBytes; + nm.SerializeToString(&modelBytes); + + // Register the FusedPreprocLinear custom op (once) on the session options. 
+ if (!mFusedOpRegistered) { + static FusedPreprocLinearOp fusedOp; + static Ort::CustomOpDomain fusedDomain("ai.o2.ml"); + static bool opAdded = false; + if (!opAdded) { + fusedDomain.Add(&fusedOp); + opAdded = true; + } + sessionOptions.Add(fusedDomain); + mFusedOpRegistered = true; + } + + mSession = std::make_shared(*mEnv, modelBytes.data(), modelBytes.size(), sessionOptions); + + mInputNames = rawInputNames; + mInputShapes.assign(rawInputNames.size(), std::vector{-1}); + mNumFeatures = K; + + LOG(info) << "setupColumnInputs: rebuilt model with " << rawInputNames.size() + << " raw inputs -> fused preprocessing + first layer (K=" << K << ", H=" << H + << ") in one custom op, " << keptNodes << " downstream nodes kept"; + } + // Getters & Setters Ort::SessionOptions* getSessionOptions() { return &sessionOptions; } // For optimizations in post std::shared_ptr getSession() @@ -139,6 +658,9 @@ class OnnxModel return mSession; } int getNumInputNodes() const { return mInputShapes[0][1]; } + bool hasColumnInputs() const { return mInputShapes.size() > 1 || (mInputShapes.size() == 1 && mInputShapes[0].size() == 1); } + int getNumColumns() const { return static_cast(mInputNames.size()); } + int getNumFeatures() const { return mNumFeatures; } std::vector> getInputShapes() const { return mInputShapes; } int getNumOutputNodes() const { return mOutputShapes[0][1]; } uint64_t getValidityFrom() const { return validFrom; } @@ -154,6 +676,8 @@ class OnnxModel // Input & Output specifications of the loaded network std::vector mInputNames; std::vector> mInputShapes; + int mNumFeatures = 0; + bool mFusedOpRegistered = false; std::vector mOutputNames; std::vector> mOutputShapes; diff --git a/dependencies/O2PhysicsDependencies.cmake b/dependencies/O2PhysicsDependencies.cmake index d1a2a6280ac..807755aa3b8 100644 --- a/dependencies/O2PhysicsDependencies.cmake +++ b/dependencies/O2PhysicsDependencies.cmake @@ -25,5 +25,6 @@ find_package(fjcontrib) set_package_properties(fjcontrib 
PROPERTIES TYPE REQUIRED) find_package(ONNXRuntime) +find_package(ONNX) feature_summary(WHAT ALL FATAL_ON_MISSING_REQUIRED_PACKAGES)