Skip to content
183 changes: 107 additions & 76 deletions addons/ONNXRuntime/python/jetFlavourHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,50 @@
ROOT.gROOT.SetBatch(True)

class JetFlavourHelper:
def __init__(self, coll, jet, jetc, tag=""):
'''
NOTE (May 2025): Once the full sim tagger has been retrained with the new naming convention (see https://github.com/key4hep/k4MLJetTagger?tab=readme-ov-file#open-problems--further-work), the names defined here must be updated. They will then also match the names used in ReconstructedParticle2Track.
'''
def __init__(self, coll, jet, jetc, tag="", sim_type="fast"):
'''
sim_type: must be either "fast" or "full"; any other value prints an error and exits
'''
# check if sim_type is valid
if sim_type not in ["fast", "full"]:
print("ERROR: sim_type must be either 'fast' or 'full'")
sys.exit()

self.jet = jet
self.const = jetc

self.tag = tag
if tag != "":
self.tag = "_{}".format(tag)
self.sim_type = sim_type

self.particle = coll["GenParticles"]
self.pfcand = coll["PFParticles"]
self.pftrack = coll["PFTracks"]
self.pfphoton = coll["PFPhotons"]
self.pfnh = coll["PFNeutralHadrons"]
self.trackstate = coll["TrackState"]

self.trackstate = coll["TrackStates"]
self.tracks = coll["Tracks"]

self.trackerhits = coll["TrackerHits"]
self.calohits = coll["CalorimeterHits"]
self.dndx = coll["dNdx"]
self.l = coll["PathLength"]
self.bz = coll["Bz"]

if sim_type == "fast":
self.dndx = coll["dNdx"]
self.bz = coll["Bz"]
elif sim_type == "full":
self.bz = "2.0" # CLD #FIXME: this should be read from the geometry
self.dndx = None
self.primvertex = coll["PV"]
self.definition = dict()

# ===== VERTEX
# MC primary vertex
self.definition["pv{}".format(self.tag)] = "FCCAnalyses::MCParticle::get_EventPrimaryVertexP4()( {} )".format(
self.particle
# ===== VERTEX (reconstructed)
self.definition["pv{}".format(self.tag)] = "JetConstituentsUtils::get_primary_vertex({})".format(
self.primvertex
)

# build jet constituents lists
Expand Down Expand Up @@ -69,105 +86,114 @@ def __init__(self, coll, jet, jetc, tag=""):
self.definition["pfcand_phirel{}".format(self.tag)] = "JetConstituentsUtils::get_phirel_cluster({}, {})".format(
jet, self.const
)
if self.sim_type == "fast":
self.definition["Bz{}".format(self.tag)] = "{}[0]".format(self.bz)
self.definition[
"pfcand_dndx{}".format(self.tag)
] = "JetConstituentsUtils::get_dndx({}, {}, {}, pfcand_isChargedHad{})".format(
self.const, self.dndx, self.pftrack, self.tag
)

self.definition[
"pfcand_dndx{}".format(self.tag)
] = "JetConstituentsUtils::get_dndx({}, {}, {}, pfcand_isChargedHad{})".format(
self.const, self.dndx, self.pftrack, self.tag
)
self.definition[
"pfcand_mtof{}".format(self.tag)
] = "JetConstituentsUtils::get_mtof({}, {}, {}, {}, {}, {}, {}, pv{})".format(
self.const, self.l, self.pftrack, self.trackerhits, self.pfphoton, self.pfnh, self.calohits, self.tag
)
elif self.sim_type == "full":
self.definition["Bz{}".format(self.tag)] = self.bz
# fill the dNdx and mtof variables with 0
self.definition[
"pfcand_dndx{}".format(self.tag)
] = "JetConstituentsUtils::get_dndx_dummy({})".format(self.const)

self.definition[
"pfcand_mtof{}".format(self.tag)
] = "JetConstituentsUtils::get_mtof({}, {}, {}, {}, {}, {}, {}, pv{})".format(
self.const, self.l, self.pftrack, self.trackerhits, self.pfphoton, self.pfnh, self.calohits, self.tag
)
self.definition[
"pfcand_mtof{}".format(self.tag)
] = "JetConstituentsUtils::get_mtof_dummy({})".format(self.const)

self.definition["Bz{}".format(self.tag)] = "{}[0]".format(self.bz)

self.definition[
"pfcand_dxy{}".format(self.tag)
] = "JetConstituentsUtils::XPtoPar_dxy({}, {}, pv{}, Bz{})".format(
self.const, self.trackstate, self.tag, self.tag
self.definition["pfcand_dxy{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_dxy({}, {}, {}, pv{}, Bz{})".format(
self.const, self.trackstate, self.tracks, self.tag, self.tag
)

self.definition["pfcand_dz{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_dz({}, {}, pv{}, Bz{})".format(
self.const, self.trackstate, self.tag, self.tag
self.definition["pfcand_dz{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_dz({}, {}, {}, pv{}, Bz{})".format(
self.const, self.trackstate, self.tracks, self.tag, self.tag
)

self.definition[
"pfcand_phi0{}".format(self.tag)
] = "JetConstituentsUtils::XPtoPar_phi({}, {}, pv{}, Bz{})".format(
self.const, self.trackstate, self.tag, self.tag
self.definition["pfcand_phi0{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_phi({}, {}, {}, pv{}, Bz{})".format(
self.const, self.trackstate, self.tracks, self.tag, self.tag
)

self.definition["pfcand_C{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_C({}, {}, Bz{})".format(
self.const, self.trackstate, self.tag
self.definition["pfcand_C{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_C({}, {}, {}, Bz{})".format(
self.const, self.trackstate, self.tracks, self.tag
)

self.definition["pfcand_ct{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_ct({}, {}, Bz{})".format(
self.const, self.trackstate, self.tag
self.definition["pfcand_ct{}".format(self.tag)] = "JetConstituentsUtils::XPtoPar_ct({}, {}, {}, Bz{})".format(
self.const, self.trackstate, self.tracks, self.tag
)

self.definition["pfcand_dptdpt{}".format(self.tag)] = "JetConstituentsUtils::get_omega_cov({}, {})".format(
self.const, self.trackstate
# covariance matrix (fixed track state problem)

self.definition["pfcand_dptdpt{}".format(self.tag)] = "JetConstituentsUtils::get_omega_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_dxydxy{}".format(self.tag)] = "JetConstituentsUtils::get_d0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_dxydxy{}".format(self.tag)] = "JetConstituentsUtils::get_d0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_dzdz{}".format(self.tag)] = "JetConstituentsUtils::get_z0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_dzdz{}".format(self.tag)] = "JetConstituentsUtils::get_z0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_dphidphi{}".format(self.tag)] = "JetConstituentsUtils::get_phi0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_dphidphi{}".format(self.tag)] = "JetConstituentsUtils::get_phi0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition[
"pfcand_detadeta{}".format(self.tag)
] = "JetConstituentsUtils::get_tanlambda_cov({}, {})".format(self.const, self.trackstate)
self.definition["pfcand_detadeta{}".format(self.tag)] = "JetConstituentsUtils::get_tanlambda_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_dxydz{}".format(self.tag)] = "JetConstituentsUtils::get_d0_z0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_dxydz{}".format(self.tag)] = "JetConstituentsUtils::get_d0_z0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_dphidxy{}".format(self.tag)] = "JetConstituentsUtils::get_phi0_d0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_dphidxy{}".format(self.tag)] = "JetConstituentsUtils::get_phi0_d0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_phidz{}".format(self.tag)] = "JetConstituentsUtils::get_phi0_z0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_phidz{}".format(self.tag)] = "JetConstituentsUtils::get_phi0_z0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition[
"pfcand_phictgtheta{}".format(self.tag)
] = "JetConstituentsUtils::get_tanlambda_phi0_cov({}, {})".format(self.const, self.trackstate)
self.definition["pfcand_phictgtheta{}".format(self.tag)] = "JetConstituentsUtils::get_tanlambda_phi0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition[
"pfcand_dxyctgtheta{}".format(self.tag)
] = "JetConstituentsUtils::get_tanlambda_d0_cov({}, {})".format(self.const, self.trackstate)
self.definition["pfcand_dxyctgtheta{}".format(self.tag)] = "JetConstituentsUtils::get_tanlambda_d0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition[
"pfcand_dlambdadz{}".format(self.tag)
] = "JetConstituentsUtils::get_tanlambda_z0_cov({}, {})".format(self.const, self.trackstate)
self.definition["pfcand_dlambdadz{}".format(self.tag)] = "JetConstituentsUtils::get_tanlambda_z0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition[
"pfcand_cctgtheta{}".format(self.tag)
] = "JetConstituentsUtils::get_omega_tanlambda_cov({}, {})".format(self.const, self.trackstate)
self.definition["pfcand_cctgtheta{}".format(self.tag)] = "JetConstituentsUtils::get_omega_tanlambda_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_phic{}".format(self.tag)] = "JetConstituentsUtils::get_omega_phi0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_phic{}".format(self.tag)] = "JetConstituentsUtils::get_omega_phi0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_dxyc{}".format(self.tag)] = "JetConstituentsUtils::get_omega_d0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_dxyc{}".format(self.tag)] = "JetConstituentsUtils::get_omega_d0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

self.definition["pfcand_cdz{}".format(self.tag)] = "JetConstituentsUtils::get_omega_z0_cov({}, {})".format(
self.const, self.trackstate
self.definition["pfcand_cdz{}".format(self.tag)] = "JetConstituentsUtils::get_omega_z0_cov({}, {}, {})".format(
self.const, self.tracks, self.trackstate
)

# impact parameters

self.definition[
"pfcand_btagSip2dVal{}".format(self.tag)
] = "JetConstituentsUtils::get_Sip2dVal_clusterV({}, pfcand_dxy{}, pfcand_phi0{}, Bz{})".format(
Expand All @@ -176,7 +202,7 @@ def __init__(self, coll, jet, jetc, tag=""):

self.definition[
"pfcand_btagSip2dSig{}".format(self.tag)
] = "JetConstituentsUtils::get_Sip2dSig(pfcand_btagSip2dVal{}, pfcand_dxydxy{})".format(self.tag, self.tag)
] = 'JetConstituentsUtils::get_Sip2dSig(pfcand_btagSip2dVal{}, pfcand_dxydxy{}, "{}")'.format(self.tag, self.tag, self.sim_type)

self.definition[
"pfcand_btagSip3dVal{}".format(self.tag)
Expand All @@ -186,8 +212,8 @@ def __init__(self, coll, jet, jetc, tag=""):

self.definition[
"pfcand_btagSip3dSig{}".format(self.tag)
] = "JetConstituentsUtils::get_Sip3dSig(pfcand_btagSip3dVal{}, pfcand_dxydxy{}, pfcand_dzdz{})".format(
self.tag, self.tag, self.tag
] = 'JetConstituentsUtils::get_Sip3dSig(pfcand_btagSip3dVal{}, pfcand_dxydxy{}, pfcand_dzdz{}, "{}")'.format(
self.tag, self.tag, self.tag, self.sim_type
)

self.definition[
Expand All @@ -198,10 +224,12 @@ def __init__(self, coll, jet, jetc, tag=""):

self.definition[
"pfcand_btagJetDistSig{}".format(self.tag)
] = "JetConstituentsUtils::get_JetDistSig(pfcand_btagJetDistVal{}, pfcand_dxydxy{}, pfcand_dzdz{})".format(
self.tag, self.tag, self.tag
] = 'JetConstituentsUtils::get_JetDistSig(pfcand_btagJetDistVal{}, pfcand_dxydxy{}, pfcand_dzdz{}, "{}")'.format(
self.tag, self.tag, self.tag, self.sim_type
)

# count number of particles in the jet

self.definition["jet_nmu{}".format(self.tag)] = "JetConstituentsUtils::count_type(pfcand_isMu{})".format(
self.tag
)
Expand Down Expand Up @@ -249,13 +277,16 @@ def inference(self, jsonCfg, onnxCfg, df):
# convert to tuple
initvars = tuple(initvars)

# then funcs
# check if all variables are defined
for varname in self.variables:
matches = [obs for obs in self.definition.keys() if obs == varname]
if len(matches) != 1:
print("ERROR: {} variables was not defined.".format(varname))
sys.exit()

# TODO: also check that the variables are actually filled with values (how to do this is still open)


self.get_weight_str = "JetFlavourUtils::get_weights(rdfslot_, "
for var in self.variables:
self.get_weight_str += "{},".format(var)
Expand Down
Loading