diff --git a/AtData/AtDataLinkDef.h b/AtData/AtDataLinkDef.h index b1cc7f9da..411d87117 100644 --- a/AtData/AtDataLinkDef.h +++ b/AtData/AtDataLinkDef.h @@ -43,5 +43,12 @@ #pragma link C++ enum AtPatterns::PatternType; #pragma link C++ function AtPatterns::CreatePattern; +#pragma link C++ class AtFitMetadata + ; +#pragma link C++ class AtFitTrackMetadata + ; #pragma link C++ class MCFitter::AtMCResult + ; + +#pragma link C++ class AtTrackingEventOld + ; +#pragma link C++ class AtFittedTrackOld + ; +#pragma link C++ class MCFitter::AtMCResultOld + ; + #endif diff --git a/AtData/AtFitMetadata.cxx b/AtData/AtFitMetadata.cxx new file mode 100644 index 000000000..8e6906a91 --- /dev/null +++ b/AtData/AtFitMetadata.cxx @@ -0,0 +1,3 @@ +#include "AtFitMetadata.h" + +ClassImp(AtFitMetadata); diff --git a/AtData/AtFitMetadata.h b/AtData/AtFitMetadata.h new file mode 100644 index 000000000..da8fa1c3c --- /dev/null +++ b/AtData/AtFitMetadata.h @@ -0,0 +1,58 @@ +#ifndef ATFITMETADATA_H +#define ATFITMETADATA_H + +#include "AtFitTrackMetadata.h" + +#include + +#include // for Double_t, THashConsistencyHolder, ClassDefOverride +#include + +#include +#include +#include +#include + +class TBuffer; +class TClass; +class TMemberInspector; + +/** + * Class for storing the result of the fit for the entire AtTrackingEvent from an AtFitter class. + */ +class AtFitMetadata : public TObject { +public: + using TrackMetadataPtr = std::unique_ptr; + using TrackMetadatasVector = std::vector; + using MetadatasMap = std::map; + +protected: + /** + * Map to store the metadatas for all different fits done to all tracks in the event. + * The Int_t corresponds to the trackID for which the metadatas correspond to. + * The vector of AtFitTrackMetadata contains the different metadatas for all fits + * that have been done for the track (for example, different assumptions for the + * particles of the track, different initial conditions, etc...). + */ + MetadatasMap fMetadatas; + + // Event ID for which this fit was done. + ULong_t fEventID; + +public: + AtFitMetadata() = default; + ~AtFitMetadata() = default; + + void SetTrackMetadatasVector(Int_t trackID, TrackMetadatasVector metadatas) + { + fMetadatas.insert({trackID, std::move(metadatas)}); + } + + void SetEventID(ULong_t id) { fEventID = id; } + + TrackMetadatasVector &GetTrackMetadatasVector(Int_t trackID) { return fMetadatas.at(trackID); } + + ClassDefOverride(AtFitMetadata, 1); +}; + +#endif diff --git a/AtData/AtFitTrackMetadata.cxx b/AtData/AtFitTrackMetadata.cxx new file mode 100644 index 000000000..d2627dfcb --- /dev/null +++ b/AtData/AtFitTrackMetadata.cxx @@ -0,0 +1,16 @@ +#include "AtFitTrackMetadata.h" + +#include +ClassImp(AtFitTrackMetadata); + +void AtFitTrackMetadata::Print() const + +{ + std::cout << " Fit metadata for track with ID " << fTrackID << ":" << std::endl; + + std::cout << " Statistics: " << std::endl; + std::cout << " PValue = " << fPValue << std::endl; + std::cout << " Chi2 = " << fChi2 << std::endl; + std::cout << " NDF = " << fNdf << std::endl; + std::cout << " Converged = " << fFitConverged << std::endl; +} diff --git a/AtData/AtFitTrackMetadata.h b/AtData/AtFitTrackMetadata.h new file mode 100644 index 000000000..b21ad6729 --- /dev/null +++ b/AtData/AtFitTrackMetadata.h @@ -0,0 +1,46 @@ +#ifndef ATFITTRACKMETADATA_H +#define ATFITTRACKMETADATA_H + +#include // for Double_t, THashConsistencyHolder, ClassDefOverride +#include + +class TBuffer; +class TClass; +class TMemberInspector; + +/** + * Class for storing the result of the fit of an AtTrack from an AtFitter class. + */ +class AtFitTrackMetadata : public TObject { +protected: + // Statistics parameters of the fit. + Double_t fPValue{0}; + Double_t fChi2{0}; + Int_t fNdf{0}; + Bool_t fFitConverged{false}; + + // The track ID for which this fit was done for. + Int_t fTrackID{-1}; + +public: + AtFitTrackMetadata() = default; + ~AtFitTrackMetadata() = default; + + void SetPValue(Double_t value) { fPValue = value; } + void SetChi2(Double_t value) { fChi2 = value; } + void SetNdf(Int_t value) { fNdf = value; } + void SetFitConverged(Bool_t value) { fFitConverged = value; } + void SetTrackID(Int_t value) { fTrackID = value; } + + Double_t GetPValue() const { return fPValue; } + Double_t GetChi2() const { return fChi2; } + Int_t GetNdf() const { return fNdf; } + Bool_t GetFitConverged() const { return fFitConverged; } + Int_t GetTrackID() const { return fTrackID; } + + virtual void Print() const; + + ClassDefOverride(AtFitTrackMetadata, 1); +}; + +#endif diff --git a/AtData/AtFittedTrack.cxx b/AtData/AtFittedTrack.cxx index 4481bf8f7..940a59002 100644 --- a/AtData/AtFittedTrack.cxx +++ b/AtData/AtFittedTrack.cxx @@ -1,45 +1,49 @@ #include "AtFittedTrack.h" -#include - -#include -#include - ClassImp(AtFittedTrack); -using XYZVector = ROOT::Math::XYZVector; - -const std::tuple AtFittedTrack::GetEnergyAngles() -{ - return std::forward_as_tuple(fEnergy, fEnergyXtr, fTheta, fPhi, fEnergyPRA, fThetaPRA, fPhiPRA); -} - -const std::tuple AtFittedTrack::GetVertices() +void AtFittedTrack::SetKinematics(int particleIdx, Double_t energy, Double_t theta, Double_t phi) { - return std::forward_as_tuple(fInitialPos, fInitialPosPRA, fInitialPosXtr); + while (particleIdx >= fKinematics.size()) { + Kinematics newKinematics; + fKinematics.push_back(newKinematics); + } + + fKinematics[particleIdx].kineticEnergy = energy; + fKinematics[particleIdx].theta = theta; + fKinematics[particleIdx].phi = phi; } -const std::tuple AtFittedTrack::GetStats() +void AtFittedTrack::SetKinematicsXtr(int particleIdx, Double_t energyxtr, Double_t thetaxtr, Double_t phixtr) { - return std::forward_as_tuple(fPValue, fChi2, fBChi2, fNdf, fBNdf, fFitConverged); + while (particleIdx >= fKinematicsXtr.size()) { + Kinematics newKinematics; + fKinematicsXtr.push_back(newKinematics); + } + + fKinematicsXtr[particleIdx].kineticEnergy = energyxtr; + fKinematicsXtr[particleIdx].theta = thetaxtr; + fKinematicsXtr[particleIdx].phi = phixtr; } -const std::tuple AtFittedTrack::GetTrackProperties() +void AtFittedTrack::SetParticleInfo(int particleIdx, std::string pdg, Int_t charge, Double_t mass) { - return std::forward_as_tuple(fCharge, fBrho, fELossADC, fDEdxADC, fPDG, fTrackPoints); + while (particleIdx >= fParticleInfo.size()) { + ParticleInfo newParticleInfo; + fParticleInfo.push_back(newParticleInfo); + } + + fParticleInfo[particleIdx].idPDG = TString(pdg); + fParticleInfo[particleIdx].charge = charge; + fParticleInfo[particleIdx].mass = mass; } -const std::tuple AtFittedTrack::GetIonChamber() +void AtFittedTrack::SetVertex(int particleIdx, XYZVector point) { - return std::forward_as_tuple(fIonChamberEnergy, fIonChamberTime); -} + while (particleIdx >= fVertex.size()) { + XYZVector newVertex; + fVertex.push_back(newVertex); + } -const std::tuple AtFittedTrack::GetExcitationEnergy() -{ - return std::forward_as_tuple(fExcitationEnergy, fExcitationEnergyXtr); + fVertex[particleIdx] = point; } - -const std::tuple AtFittedTrack::GetDistances() -{ - return std::forward_as_tuple(fDistanceXtr, fTrackLength, fPOCAXtr); -} \ No newline at end of file diff --git a/AtData/AtFittedTrack.h b/AtData/AtFittedTrack.h index 4cb0c1044..62727a81d 100644 --- a/AtData/AtFittedTrack.h +++ b/AtData/AtFittedTrack.h @@ -1,6 +1,8 @@ #ifndef ATFITTEDTRACK_H #define ATFITTEDTRACK_H +#include "AtFitTrackMetadata.h" + #include #include #include @@ -8,6 +10,7 @@ #include #include #include +#include #include #include @@ -20,125 +23,233 @@ class TClass; class TMemberInspector; class AtFittedTrack : public TObject { - +public: using XYZVector = ROOT::Math::XYZVector; + using TrackMetadataPtr = std::unique_ptr; + + struct Kinematics { + Double_t kineticEnergy{-1}; // Kinetic energy + Double_t theta{-1}; // Theta scattering angle + Double_t phi{-1}; // Phi scattering angle + }; + + struct ParticleInfo { + TString idPDG{""}; // PDG code of the particle + Int_t charge{0}; // Charge number of the particle + Double_t mass{-1}; // Mass of the particle in amu + }; + + struct TrackProperties { + XYZVector initialPosition; // Position of the first hit + XYZVector initialPositionXtr; // Position of the point closest to (0,0) + Double_t extrapolatedDistance{-1}; // Distance initialPosition->initialPositionXtr along the pattern + Double_t distancePOCA{-1}; // Distance initialPositionXtr->(0,0) + Double_t trackLength{-1}; // Distance initialPosition->End of charge + Double_t trackLengthXtr{-1}; // Distance initialPositionXtr->End of charge + Double_t estimateTotalCharge{-1}; // Sum of the charge of all hits + Double_t estimateDeDx{-1}; // Sum of the charge of all hits divided by range + Int_t trackPoints{-1}; // Number of hits in the track + }; private: Int_t fTrackID{-1}; //< Track ID from pattern recognition - Float_t fEnergy{0}; - Float_t fTheta{0}; - Float_t fPhi{0}; - Float_t fEnergyPRA{0}; - Float_t fThetaPRA{0}; - Float_t fPhiPRA{0}; + // Kinematic variables obtained by the fit. + std::vector fKinematics; + std::vector fKinematicsXtr; + + // Particle information. + std::vector fParticleInfo; + + // Vertex where the particle has originated from. + std::vector fVertex; + + // Track properties. + TrackProperties fTrackProperties; + + // Copy of the AtFitTrackResult object corresponding to the fit used for this track. + TrackMetadataPtr fTrackMetadata{nullptr}; + + // Deprecated members needed to keep support of deprecated methods. + [[deprecated("No PRA information is supposed to be kept in AtFittedTrack.")]] Double_t fEnergyPRA{0}; + [[deprecated("No PRA information is supposed to be kept in AtFittedTrack.")]] Double_t fThetaPRA{0}; + [[deprecated("No PRA information is supposed to be kept in AtFittedTrack.")]] Double_t fPhiPRA{0}; + [[deprecated("No PRA information is supposed to be kept in AtFittedTrack.")]] XYZVector fInitialPosPRA; + [[deprecated("No backwards statistics is supposed to be stored. May not be needed in all fitters.")]] Float_t fBChi2{ + 0}; + [[deprecated("No backwards statistics is supposed to be stored. May not be needed in all fitters.")]] Float_t fBNdf{ + 0}; + [[deprecated("Not all experiments are done with magnetic field.")]] Float_t fBrho{0}; + [[deprecated("No IC information is supposed to be kept in AtFittedTrack.")]] Float_t fIonChamberEnergy{0}; + [[deprecated("No IC information is supposed to be kept in AtFittedTrack.")]] Int_t fIonChamberTime{0}; + [[deprecated("No excitation energy is supposed to be kept in AtFittedTrack.")]] Float_t fExcitationEnergy{0}; + [[deprecated("No excitation energy is supposed to be kept in AtFittedTrack.")]] Float_t fExcitationEnergyXtr{0}; + +public: + AtFittedTrack() = default; + ~AtFittedTrack() = default; + + void SetTrackID(Int_t trackid) { fTrackID = trackid; } - Float_t fExcitationEnergy{0}; + void SetKinematics(int particleIdx, Double_t energy, Double_t theta, Double_t phi); + void SetKinematicsXtr(int particleIdx, Double_t energyxtr, Double_t thetaxtr, Double_t phixtr); + void SetParticleInfo(int particleIdx, std::string pdg, Int_t charge, Double_t mass); + void SetVertex(int particleIdx, XYZVector point); + + void SetKinematics(Double_t energy, Double_t theta, Double_t phi) { SetKinematics(0, energy, theta, phi); } + void SetKinematicsXtr(Double_t energyxtr, Double_t thetaxtr, Double_t phixtr) + { + SetKinematicsXtr(0, energyxtr, thetaxtr, phixtr); + } - XYZVector fInitialPos; // xiniFitVec,yiniFitVec,ziniFitVec; - XYZVector fInitialPosPRA; // xiniPRAVec,yiniPRAVec,ziniPRAVec; - XYZVector fInitialPosXtr; + void SetParticleInfo(std::string pdg, Int_t charge, Double_t mass) { SetParticleInfo(0, pdg, charge, mass); } - Float_t fIonChamberEnergy{0}; - Int_t fIonChamberTime{0}; + void SetVertex(XYZVector point) { SetVertex(0, point); } - Float_t fEnergyXtr{0}; - Float_t fExcitationEnergyXtr{0}; + void SetTrackPropertiesStruct(XYZVector initialPosition, XYZVector initialPositionXtr, Double_t extrapolatedDistance, + Double_t distancePOCA, Double_t trackLength, Double_t trackLengthXtr, + Double_t estimateTotalCharge, Int_t trackPoints) + { + fTrackProperties.initialPosition = initialPosition; + fTrackProperties.initialPositionXtr = initialPositionXtr; + fTrackProperties.extrapolatedDistance = extrapolatedDistance; + fTrackProperties.distancePOCA = distancePOCA; + fTrackProperties.trackLength = trackLength; + fTrackProperties.trackLengthXtr = trackLengthXtr; + fTrackProperties.estimateTotalCharge = estimateTotalCharge; + fTrackProperties.estimateDeDx = estimateTotalCharge / trackLengthXtr; + fTrackProperties.trackPoints = trackPoints; + } - Float_t fDistanceXtr{0}; - Float_t fTrackLength{0}; - Float_t fPOCAXtr{0}; + void SetTrackMetadata(TrackMetadataPtr trackMetadata) { fTrackMetadata = std::move(trackMetadata); } - Float_t fPValue{0}; - Float_t fChi2{0}; - Float_t fBChi2{0}; - Float_t fNdf{0}; - Float_t fBNdf{0}; - Bool_t fFitConverged{0}; + const Int_t GetTrackID() { return fTrackID; } - Int_t fCharge{0}; - Float_t fBrho{0}; - Float_t fELossADC{0}; - Float_t fDEdxADC{0}; - std::string fPDG{0}; - Int_t fTrackPoints{0}; + const Kinematics GetKinematics(int particleIdx = 0) { return fKinematics[particleIdx]; } + const Kinematics GetKinematicsXtr(int particleIdx = 0) { return fKinematicsXtr[particleIdx]; } + const ParticleInfo GetParticleInfo(int particleIdx = 0) { return fParticleInfo[particleIdx]; } + const XYZVector GetVertex(int particleIdx = 0) { return fVertex[particleIdx]; } -public: - AtFittedTrack() = default; - ~AtFittedTrack() = default; + const TrackProperties GetTrackPropertiesStruct() { return fTrackProperties; } - inline void SetTrackID(Int_t trackid) { fTrackID = trackid; } + TrackMetadataPtr &GetTrackMetadata() { return fTrackMetadata; } - inline void SetEnergyAngles(Float_t energy, Float_t energyxtr, Float_t theta, Float_t phi, Float_t energypra, - Float_t thetapra, Float_t phipra) + // Old deprecated methods. + [[deprecated("Replaced by SetKinematics() and SetKinematicsXtr().")]] void + SetEnergyAngles(Float_t energy, Float_t energyxtr, Float_t theta, Float_t phi, Float_t energypra, Float_t thetapra, + Float_t phipra) { - fEnergy = energy; - fEnergyXtr = energyxtr; - fTheta = theta; - fPhi = phi; + fKinematics[0].kineticEnergy = energy; + fKinematics[0].theta = theta; + fKinematics[0].phi = phi; + fKinematicsXtr[0].kineticEnergy = energyxtr; fEnergyPRA = energypra; fThetaPRA = thetapra; - fPhiPRA = phi; + fPhiPRA = phipra; } - - inline void SetVertexPosition(XYZVector inipos, XYZVector iniposPRA, XYZVector iniposxtr) + [[deprecated("This information now lives in the TrackProperties struct. Check the new SetTrackProperties().")]] void + SetVertexPosition(XYZVector inipos, XYZVector iniposPRA, XYZVector iniposxtr) { - fInitialPos = inipos; + fTrackProperties.initialPosition = inipos; fInitialPosPRA = iniposPRA; - fInitialPosXtr = iniposxtr; + fVertex[0] = iniposxtr; } - - inline void SetStats(Float_t pvalue, Float_t chi2, Float_t bchi2, Float_t ndf, Float_t bndf, Bool_t conv) + [[deprecated("Statistics now live inside the AtFitTrackMetadata. Check SetTrackMetadata().")]] void + SetStats(Float_t pvalue, Float_t chi2, Float_t bchi2, Float_t ndf, Float_t bndf, Bool_t conv) { - fPValue = pvalue; - fChi2 = chi2; + if (fTrackMetadata == nullptr) + fTrackMetadata = std::make_unique(); + fTrackMetadata->SetPValue(pvalue); + fTrackMetadata->SetChi2(chi2); + fTrackMetadata->SetNdf(ndf); + fTrackMetadata->SetFitConverged(conv); + fTrackMetadata->SetTrackID(fTrackID); + fBChi2 = bchi2; - fNdf = ndf; fBNdf = bndf; - fFitConverged = conv; } - - inline void + [[deprecated("The TrackProperties have changed. Please check the new SetTrackProperties() method.")]] void SetTrackProperties(Int_t charge, Float_t brho, Float_t eloss, Float_t dedx, std::string pdg, Int_t points) { - fCharge = charge; + fParticleInfo[0].charge = charge; fBrho = brho; - fELossADC = eloss; - fDEdxADC = dedx; - fPDG = pdg; - fTrackPoints = points; + fTrackProperties.estimateTotalCharge = eloss; + fTrackProperties.estimateDeDx = dedx; + fParticleInfo[0].idPDG = pdg; + fTrackProperties.trackPoints = points; } - - inline void SetIonChamber(Float_t icenergy, Int_t ictime) + [[deprecated("Ion chamber information is no longer saved in the AtFittedTrack. This has been saved here still, but " + "refrain from using this further.")]] void + SetIonChamber(Float_t icenergy, Int_t ictime) { fIonChamberEnergy = icenergy; fIonChamberTime = ictime; } - - inline void SetExcitationEnergy(Float_t exenergy, Float_t exenergyxtr) + [[deprecated("Excitation energy is no longer saved in the AtFittedTrack. This has been saved here still, but " + "refrain from using this further.")]] void + SetExcitationEnergy(Float_t exenergy, Float_t exenergyxtr) { fExcitationEnergy = exenergy; fExcitationEnergyXtr = exenergyxtr; } - - inline void SetDistances(Float_t distancextr, Float_t length, Float_t poca) + [[deprecated("This information now lives in the TrackProperties struct. Check the new SetTrackProperties().")]] void + SetDistances(Float_t distancextr, Float_t length, Float_t poca) { - fDistanceXtr = distancextr; - fTrackLength = length; - fPOCAXtr = poca; + fTrackProperties.extrapolatedDistance = distancextr; + fTrackProperties.trackLength = length; + fTrackProperties.distancePOCA = poca; } - const Int_t GetTrackID() { return fTrackID; } - - const std::tuple GetEnergyAngles(); - const std::tuple GetVertices(); - const std::tuple GetStats(); - const std::tuple GetTrackProperties(); - const std::tuple GetIonChamber(); - const std::tuple GetExcitationEnergy(); - const std::tuple GetDistances(); + [[deprecated("Please check GetKinematics() and GetKinematicsXtr().")]] const std::tuple< + Float_t, Float_t, Float_t, Float_t, Float_t, Float_t, Float_t> + GetEnergyAngles() + { + return std::forward_as_tuple(fKinematics[0].kineticEnergy, fKinematicsXtr[0].kineticEnergy, fKinematics[0].theta, + fKinematics[0].phi, fEnergyPRA, fThetaPRA, fPhiPRA); + } + [[deprecated("Please check GetVertex().")]] const std::tuple GetVertices() + { + return std::forward_as_tuple(fTrackProperties.initialPosition, fInitialPosPRA, fVertex[0]); + } + [[deprecated("Statistics now live inside the AtFitTrackMetadata. Check GetTrackMetadata().")]] const std::tuple< + Float_t, Float_t, Float_t, Float_t, Float_t, Bool_t> + GetStats() + { + if (fTrackMetadata == nullptr) + return std::forward_as_tuple(0, 0, 0, 0, 0, 0); + return std::forward_as_tuple(fTrackMetadata->GetPValue(), fTrackMetadata->GetChi2(), fBChi2, + fTrackMetadata->GetNdf(), fBNdf, fTrackMetadata->GetFitConverged()); + } + [[deprecated("The TrackProperties have changed. Please check the new GetTrackProperties() method.")]] const std:: + tuple + GetTrackProperties() + { + return std::forward_as_tuple(fParticleInfo[0].charge, fBrho, fTrackProperties.estimateTotalCharge, + fTrackProperties.estimateDeDx, fParticleInfo[0].idPDG.Data(), + fTrackProperties.trackPoints); + } + [[deprecated("Ion chamber information is no longer saved in the AtFittedTrack. This has been saved here still, but " + "refrain from using this further.")]] const std::tuple + GetIonChamber() + { + return std::forward_as_tuple(fIonChamberEnergy, fIonChamberTime); + } + [[deprecated("Excitation energy is no longer saved in the AtFittedTrack. This has been saved here still, but " + "refrain from using this further.")]] const std::tuple + GetExcitationEnergy() + { + return std::forward_as_tuple(fExcitationEnergy, fExcitationEnergyXtr); + } + [[deprecated( + "This information now lives in the TrackProperties struct. Check the new SetTrackProperties().")]] const std:: + tuple + GetDistances() + { + return std::forward_as_tuple(fTrackProperties.extrapolatedDistance, fTrackProperties.trackLength, + fTrackProperties.distancePOCA); + } - ClassDef(AtFittedTrack, 1); + ClassDef(AtFittedTrack, 2); }; -#endif \ No newline at end of file +#endif diff --git a/AtData/AtFittedTrackOld.cxx b/AtData/AtFittedTrackOld.cxx new file mode 100644 index 000000000..8b222a55e --- /dev/null +++ b/AtData/AtFittedTrackOld.cxx @@ -0,0 +1,45 @@ +#include "AtFittedTrackOld.h" + +#include + +#include +#include + +ClassImp(AtFittedTrackOld); + +using XYZVector = ROOT::Math::XYZVector; + +const std::tuple AtFittedTrackOld::GetEnergyAngles() +{ + return std::forward_as_tuple(fEnergy, fEnergyXtr, fTheta, fPhi, fEnergyPRA, fThetaPRA, fPhiPRA); +} + +const std::tuple AtFittedTrackOld::GetVertices() +{ + return std::forward_as_tuple(fInitialPos, fInitialPosPRA, fInitialPosXtr); +} + +const std::tuple AtFittedTrackOld::GetStats() +{ + return std::forward_as_tuple(fPValue, fChi2, fBChi2, fNdf, fBNdf, fFitConverged); +} + +const std::tuple AtFittedTrackOld::GetTrackProperties() +{ + return std::forward_as_tuple(fCharge, fBrho, fELossADC, fDEdxADC, fPDG, fTrackPoints); +} + +const std::tuple AtFittedTrackOld::GetIonChamber() +{ + return std::forward_as_tuple(fIonChamberEnergy, fIonChamberTime); +} + +const std::tuple AtFittedTrackOld::GetExcitationEnergy() +{ + return std::forward_as_tuple(fExcitationEnergy, fExcitationEnergyXtr); +} + +const std::tuple AtFittedTrackOld::GetDistances() +{ + return std::forward_as_tuple(fDistanceXtr, fTrackLength, fPOCAXtr); +} diff --git a/AtData/AtFittedTrackOld.h b/AtData/AtFittedTrackOld.h new file mode 100644 index 000000000..4ef1ffcfb --- /dev/null +++ b/AtData/AtFittedTrackOld.h @@ -0,0 +1,144 @@ +#ifndef ATFITTEDTRACKOLD_H +#define ATFITTEDTRACKOLD_H + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +class TBuffer; +class TClass; +class TMemberInspector; + +class [[deprecated]] AtFittedTrackOld : public TObject { + + using XYZVector = ROOT::Math::XYZVector; + +private: + Int_t fTrackID{-1}; //< Track ID from pattern recognition + + Float_t fEnergy{0}; + Float_t fTheta{0}; + Float_t fPhi{0}; + Float_t fEnergyPRA{0}; + Float_t fThetaPRA{0}; + Float_t fPhiPRA{0}; + + Float_t fExcitationEnergy{0}; + + XYZVector fInitialPos; // xiniFitVec,yiniFitVec,ziniFitVec; + XYZVector fInitialPosPRA; // xiniPRAVec,yiniPRAVec,ziniPRAVec; + XYZVector fInitialPosXtr; + + Float_t fIonChamberEnergy{0}; + Int_t fIonChamberTime{0}; + + Float_t fEnergyXtr{0}; + Float_t fExcitationEnergyXtr{0}; + + Float_t fDistanceXtr{0}; + Float_t fTrackLength{0}; + Float_t fPOCAXtr{0}; + + Float_t fPValue{0}; + Float_t fChi2{0}; + Float_t fBChi2{0}; + Float_t fNdf{0}; + Float_t fBNdf{0}; + Bool_t fFitConverged{0}; + + Int_t fCharge{0}; + Float_t fBrho{0}; + Float_t fELossADC{0}; + Float_t fDEdxADC{0}; + std::string fPDG{0}; + Int_t fTrackPoints{0}; + +public: + AtFittedTrackOld() = default; + ~AtFittedTrackOld() = default; + + inline void SetTrackID(Int_t trackid) { fTrackID = trackid; } + + inline void SetEnergyAngles(Float_t energy, Float_t energyxtr, Float_t theta, Float_t phi, Float_t energypra, + Float_t thetapra, Float_t phipra) + { + fEnergy = energy; + fEnergyXtr = energyxtr; + fTheta = theta; + fPhi = phi; + fEnergyPRA = energypra; + fThetaPRA = thetapra; + fPhiPRA = phi; + } + + inline void SetVertexPosition(XYZVector inipos, XYZVector iniposPRA, XYZVector iniposxtr) + { + fInitialPos = inipos; + fInitialPosPRA = iniposPRA; + fInitialPosXtr = iniposxtr; + } + + inline void SetStats(Float_t pvalue, Float_t chi2, Float_t bchi2, Float_t ndf, Float_t bndf, Bool_t conv) + { + fPValue = pvalue; + fChi2 = chi2; + fBChi2 = bchi2; + fNdf = ndf; + fBNdf = bndf; + fFitConverged = conv; + } + + inline void + SetTrackProperties(Int_t charge, Float_t brho, Float_t eloss, Float_t dedx, std::string pdg, Int_t points) + { + fCharge = charge; + fBrho = brho; + fELossADC = eloss; + fDEdxADC = dedx; + fPDG = pdg; + fTrackPoints = points; + } + + inline void SetIonChamber(Float_t icenergy, Int_t ictime) + { + fIonChamberEnergy = icenergy; + fIonChamberTime = ictime; + } + + inline void SetExcitationEnergy(Float_t exenergy, Float_t exenergyxtr) + { + fExcitationEnergy = exenergy; + fExcitationEnergyXtr = exenergyxtr; + } + + inline void SetDistances(Float_t distancextr, Float_t length, Float_t poca) + { + fDistanceXtr = distancextr; + fTrackLength = length; + fPOCAXtr = poca; + } + + const Int_t GetTrackID() { return fTrackID; } + + const std::tuple GetEnergyAngles(); + const std::tuple GetVertices(); + const std::tuple GetStats(); + const std::tuple GetTrackProperties(); + const std::tuple GetIonChamber(); + const std::tuple GetExcitationEnergy(); + const std::tuple GetDistances(); + + ClassDef(AtFittedTrackOld, 1); +}; + +#endif diff --git a/AtData/AtMCResult.cxx b/AtData/AtMCResult.cxx index 23884150b..23de778d3 100644 --- a/AtData/AtMCResult.cxx +++ b/AtData/AtMCResult.cxx @@ -7,8 +7,12 @@ namespace MCFitter { void AtMCResult::Print() const { - std::cout << "Objective: " << fObjective << " Iteration: " << fIterNum << std::endl; + AtFitTrackMetadata::Print(); + std::cout << " MC fit specifics: " << std::endl; + + std::cout << " Iteration = " << fIterNum << std::endl; + std::cout << " Parameters:" << std::endl; for (auto &[name, val] : fParameters) - std::cout << name << ": " << val << std::endl; + std::cout << " " << name << " = " << val << std::endl; } } // namespace MCFitter diff --git a/AtData/AtMCResult.h b/AtData/AtMCResult.h index 74ea5d9f3..98c28c405 100644 --- a/AtData/AtMCResult.h +++ b/AtData/AtMCResult.h @@ -1,6 +1,8 @@ #ifndef ATMCRESULT_H #define ATMCRESULT_H +#include "AtFitTrackMetadata.h" + #include // for Double_t, THashConsistencyHolder, ClassDefOverride #include @@ -15,19 +17,19 @@ namespace MCFitter { /** * Class for storing the result of an iteration in the AtMCFitter method. */ -class AtMCResult : public TObject { +class AtMCResult : public AtFitTrackMetadata { public: using ParamMap = std::map; - Double_t fObjective; //< Value f the objective function for this iteration ParamMap fParameters; //< Parameters used in simulation Int_t fIterNum; //< Iteration number. Used to map with the simulated event ID in the TTree. AtMCResult() = default; + ~AtMCResult() = default; - void Print() const; + void Print() const override; - ClassDefOverride(AtMCResult, 1); + ClassDefOverride(AtMCResult, 2); }; } // namespace MCFitter diff --git a/AtData/AtMCResultOld.cxx b/AtData/AtMCResultOld.cxx new file mode 100644 index 000000000..54481b7d4 --- /dev/null +++ b/AtData/AtMCResultOld.cxx @@ -0,0 +1,14 @@ +#include "AtMCResultOld.h" + +#include +ClassImp(MCFitter::AtMCResultOld); +namespace MCFitter { + +void AtMCResultOld::Print() const + +{ + std::cout << "Objective: " << fObjective << " Iteration: " << fIterNum << std::endl; + for (auto &[name, val] : fParameters) + std::cout << name << ": " << val << std::endl; +} +} // namespace MCFitter diff --git a/AtData/AtMCResultOld.h b/AtData/AtMCResultOld.h new file mode 100644 index 000000000..e095c7497 --- /dev/null +++ b/AtData/AtMCResultOld.h @@ -0,0 +1,35 @@ +#ifndef ATMCRESULTOLD_H +#define ATMCRESULTOLD_H + +#include // for Double_t, THashConsistencyHolder, ClassDefOverride +#include + +#include +#include // for string +class TBuffer; +class TClass; +class TMemberInspector; + +namespace MCFitter { + +/** + * Class for storing the result of an iteration in the AtMCFitter method. + */ +class [[deprecated]] AtMCResultOld : public TObject { +public: + using ParamMap = std::map; + + Double_t fObjective; //< Value f the objective function for this iteration + ParamMap fParameters; //< Parameters used in simulation + Int_t fIterNum; //< Iteration number. Used to map with the simulated event ID in the TTree. + + AtMCResultOld() = default; + + void Print() const; + + ClassDefOverride(AtMCResultOld, 1); +}; + +} // namespace MCFitter + +#endif // #ifndef ATMCRESULT_H diff --git a/AtData/AtTrackingEventOld.cxx b/AtData/AtTrackingEventOld.cxx new file mode 100644 index 000000000..8da4cc097 --- /dev/null +++ b/AtData/AtTrackingEventOld.cxx @@ -0,0 +1,48 @@ +#include "AtTrackingEventOld.h" + +#include +#include + +#include + +ClassImp(AtTrackingEventOld); + +AtTrackingEventOld::AtTrackingEventOld() : AtBaseEvent("Tracking Event") {} + +void AtTrackingEventOld::SetTrackArray(std::vector *trackArray) +{ + fTrackArray = *trackArray; +} +void AtTrackingEventOld::SetTrack(AtTrack *track) +{ + fTrackArray.push_back(*track); +} +void AtTrackingEventOld::SetVertex(Double_t vertex) +{ + fVertex = vertex; +} +void AtTrackingEventOld::SetGeoVertex(TVector3 vertex) +{ + fGeoVertex = vertex; +} +void AtTrackingEventOld::SetVertexEnergy(Double_t vertexEner) +{ + fVertexEnergy = vertexEner; +} + +/*std::vector AtTrackingEventOld::GetTrackArray() +{ + return fTrackArray; +}*/ +Double_t AtTrackingEventOld::GetVertex() +{ + return fVertex; +} +Double_t AtTrackingEventOld::GetVertexEnergy() +{ + return fVertexEnergy; +} +TVector3 AtTrackingEventOld::GetGeoVertex() +{ + return fGeoVertex; +} diff --git a/AtData/AtTrackingEventOld.h b/AtData/AtTrackingEventOld.h new file mode 100644 index 000000000..599f71b44 --- /dev/null +++ b/AtData/AtTrackingEventOld.h @@ -0,0 +1,59 @@ +#ifndef AtTRACKINGEVENTOLD_H +#define AtTRACKINGEVENTOLD_H + +#include "AtBaseEvent.h" +#include "AtFittedTrackOld.h" +#include "AtTrack.h" + +#include +#include +#include + +#include + +class TBuffer; +class TClass; +class TMemberInspector; + +class [[deprecated]] AtTrackingEventOld : public AtBaseEvent { + + using FTrackPtr = std::unique_ptr; + using FTrackVector = std::vector; + +public: + AtTrackingEventOld(); + virtual ~AtTrackingEventOld() = default; + + void SetTrackArray(std::vector *trackArray); + void SetTrack(AtTrack *track); + void SetVertex(Double_t vertex); + void SetGeoVertex(TVector3 vertex); + void SetVertexEnergy(Double_t vertexEner); + + AtFittedTrackOld &AddFittedTrack(std::unique_ptr ptr) + { + fFittedTrackArray.push_back(std::move(ptr)); + if (fFittedTrackArray.back()->GetTrackID() == -1) + fFittedTrackArray.back()->SetTrackID(fFittedTrackArray.size() - 1); + LOG(debug) << "Adding Track with ID " << fFittedTrackArray.back()->GetTrackID() << " to event " << fEventID; + + return *(fFittedTrackArray.back()); + } + + Double_t GetVertex(); + Double_t GetVertexEnergy(); + TVector3 GetGeoVertex(); + std::vector GetTrackArray(); + const FTrackVector &GetFittedTracks() const { return fFittedTrackArray; } + +private: + std::vector fTrackArray; + FTrackVector fFittedTrackArray; + Double_t fVertex{-10.0}; + Double_t fVertexEnergy{-10.0}; + TVector3 fGeoVertex; + + ClassDef(AtTrackingEventOld, 1); +}; + +#endif diff --git a/AtData/CMakeLists.txt b/AtData/CMakeLists.txt index 445cb359f..3ab791cff 100644 --- a/AtData/CMakeLists.txt +++ b/AtData/CMakeLists.txt @@ -41,7 +41,6 @@ set(SRCS AtProtoQuadrant.cxx AtTrack.cxx AtContainerManip.cxx - AtMCResult.cxx AtFissionEvent.cxx @@ -55,7 +54,16 @@ set(SRCS AtPattern/AtPadPlaneElement.cxx AtFittedTrack.cxx - + AtFitMetadata.cxx + AtFitTrackMetadata.cxx + AtMCResult.cxx + + + # Deprecated things to be removed eventually. + AtTrackingEventOld.cxx + AtFittedTrackOld.cxx + AtMCResultOld.cxx + ) set(TEST_SRCS diff --git a/AtReconstruction/AtFitter/AtFitter.cxx b/AtReconstruction/AtFitter/AtFitter.cxx index bbb806185..2dd2582d9 100644 --- a/AtReconstruction/AtFitter/AtFitter.cxx +++ b/AtReconstruction/AtFitter/AtFitter.cxx @@ -1,268 +1,34 @@ #include "AtFitter.h" -#include "AtHit.h" -#include "AtHitCluster.h" // for AtHitCluster -#include "AtTrack.h" +#include "AtPatternEvent.h" +#include "AtTrackingEvent.h" -#include // for Cartesian3D, operator-, PositionVector3D -#include // for DisplacementVector3D -#include - -#include -#include // for sqrt -#include // for operator<<, basic_ostream::operator<< -#include // for shared_ptr, __shared_ptr_access, __sha... -#include // for pair - -constexpr auto cRED = "\033[1;31m"; -constexpr auto cYELLOW = "\033[1;33m"; -constexpr auto cNORMAL = "\033[0m"; -constexpr auto cGREEN = "\033[1;32m"; - -ClassImp(AtFITTER::AtFitter); - -AtFITTER::AtFitter::AtFitter() = default; - -AtFITTER::AtFitter::~AtFitter() = default; - -std::tuple AtFITTER::AtFitter::GetMomFromBrho(Double_t M, Double_t Z, Double_t brho) +void EventFit::AtFitter::FitEvent(AtTrackingEvent *trackingEvent, AtPatternEvent *patternEvent, + AtFitMetadata *fitMetadata, AtRawEvent *rawEvent, AtEvent *event) { - - const Double_t M_Ener = M * 931.49401 / 1000.0; - Double_t p = brho * Z * (2.99792458 / 10.0); // In GeV - Double_t E = TMath::Sqrt(TMath::Power(p, 2) + TMath::Power(M_Ener, 2)) - M_Ener; - // std::cout << " Brho : " << brho << " - p : " << p << " - E : " << E << "\n"; - return std::make_tuple(p, E); -} - -Bool_t AtFITTER::AtFitter::FindVertexTrack(AtTrack *trA, AtTrack *trB) -{ - // Determination of first hit distance. NB: Assuming both tracks have the same angle sign - Double_t vertexA = 0.0; - Double_t vertexB = 0.0; - if (trA->GetGeoTheta() * TMath::RadToDeg() < 90) { - auto iniClusterA = trA->GetHitClusterArray()->back(); - auto iniClusterB = trB->GetHitClusterArray()->back(); - vertexA = 1000.0 - iniClusterA.GetPosition().Z(); - vertexB = 1000.0 - iniClusterB.GetPosition().Z(); - } else if (trA->GetGeoTheta() * TMath::RadToDeg() > 90) { - auto iniClusterA = trA->GetHitClusterArray()->front(); - auto iniClusterB = trB->GetHitClusterArray()->front(); - vertexA = iniClusterA.GetPosition().Z(); - vertexB = iniClusterB.GetPosition().Z(); + // Check for nullptr. + if (trackingEvent == nullptr) { + LOG(error) << " Tracking event is nullptr! The fitter can not fit this event. Maybe the tracking event is not " + "being constructed properly in the fitter task."; + return; } - return vertexA < vertexB; -} - -Bool_t AtFITTER::AtFitter::MergeTracks(std::vector *trackCandSource, std::vector *trackDest, - Bool_t enableSingleVertexTrack, Double_t clusterRadius, Double_t clusterDistance) -{ - - Bool_t toMerge = kFALSE; - - Int_t addHitCnt = 0; - // Find the track closer to vertex - std::sort(trackCandSource->begin(), trackCandSource->end(), - [this](AtTrack *trA, AtTrack *trB) { return FindVertexTrack(trA, trB); }); - - // Track stitching from vertex - AtTrack *vertexTrack = *trackCandSource->begin(); - - if (enableSingleVertexTrack) { - - // Mark all tracks as merged - for (auto track : *trackCandSource) - track->SetIsMerged(kTRUE); - - trackDest->push_back(*vertexTrack); - return true; + if (patternEvent == nullptr) { + LOG(error) << " Pattern event is nullptr! The fitter can not fit this event."; + return; } - // Check if the candidate vertex track was merged - if (vertexTrack->GetIsMerged()) - return kFALSE; - else - vertexTrack->SetIsMerged(kTRUE); - - // If enabled, choose only the track closest to vertex (i.e. first one of the collection of candidates) - // TODO: Select by number of points - - for (auto it = trackCandSource->begin() + 1; it != trackCandSource->end(); ++it) { - // NB: These tracks were previously marked to merge. If merging fails they should be discarded. - AtTrack *trackToMerge = *(it); - toMerge = kFALSE; - - // Skip trackes flagged as merged - if (!trackToMerge->GetIsMerged()) { - trackToMerge->SetIsMerged(kTRUE); - } else - continue; - - Double_t endVertexZ = 0.0; - Double_t iniMergeZ = 0.0; - std::cout << " Trying to merge ... " - << "\n"; - std::cout << " Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " << trackToMerge->GetTrackID() - << "\n"; - // Check relative position between end and begin of each track using Hit Clusters - std::cout << " Vertex angle " << vertexTrack->GetGeoTheta() * TMath::RadToDeg() << "\n"; - if (vertexTrack->GetGeoTheta() * TMath::RadToDeg() < 90) { - auto endClusterVertex = vertexTrack->GetHitClusterArray()->front(); - auto iniClusterMerge = trackToMerge->GetHitClusterArray()->back(); - // Check separation and relative distance - endVertexZ = 1000.0 - endClusterVertex.GetPosition().Z(); - iniMergeZ = 1000.0 - iniClusterMerge.GetPosition().Z(); - - Double_t distance = std::sqrt((iniClusterMerge.GetPosition() - endClusterVertex.GetPosition()).Mag2()); - // std::cout << " Distance between tracks " << distance << "\n"; - // std::cout << " Ini Merge " << iniMergeZ << " - endVertexZ " << endVertexZ << "\n"; - if (((iniMergeZ + 10.0) > endVertexZ) && distance < 200) { - toMerge = kTRUE; - } - - } else if (vertexTrack->GetGeoTheta() * TMath::RadToDeg() > 90) { - auto endClusterVertex = vertexTrack->GetHitClusterArray()->back(); - auto iniClusterMerge = trackToMerge->GetHitClusterArray()->front(); - // Check separation and relative distance - endVertexZ = endClusterVertex.GetPosition().Z(); - iniMergeZ = iniClusterMerge.GetPosition().Z(); - Double_t distance = std::sqrt((iniClusterMerge.GetPosition() - endClusterVertex.GetPosition()).Mag2()); - // std::cout<<" Distance between tracks "< endVertexZ) && - distance < 100) { // NB: Distance between parts of the backward tracks is more critical - toMerge = kTRUE; - } - } - - if (toMerge) { - - std::cout << " --- Merging Succeeded! Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " - << trackToMerge->GetTrackID() << "\n"; - for (const auto &hit : trackToMerge->GetHitArray()) { - - vertexTrack->AddHit(hit->Clone()); // TODO: Look at code and see if this can be a move instead of a copy - ++addHitCnt; - } - - // Reclusterize after merging - vertexTrack->SortHitArrayTime(); - vertexTrack->ResetHitClusterArray(); - fTrackTransformer->ClusterizeSmooth3D( - *vertexTrack, clusterRadius, - clusterDistance); // NB: It can be removed if we force reclusterization for any track in the mina program + // Extract the candidate AtTracks. If there are not any tracks, return earlier. + std::vector tracks = patternEvent->GetTrackCand(); + if (!tracks.size()) + return; - // TODO: Check if phi recalculatio is needed + // Save the original AtTracks to the AtTrackingEvent. + trackingEvent->SetTrackArray(&tracks); - } else { - std::cout << " --- Merging Failed ! Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " - << trackToMerge->GetTrackID() << "\n"; - } + // Iterate over the AtTracks and store the AtFittedTracks in the AtTrackingEvent. + for (auto track : tracks) { + std::unique_ptr fittedTrack(GetFittedTrack(&track, fitMetadata, rawEvent, event)); + trackingEvent->AddFittedTrack(std::move(fittedTrack)); } - - trackDest->push_back(*vertexTrack); - - return toMerge; -} -[[deprecated]] void -AtFITTER::AtFitter::MergeTracks(std::vector *trackCandSource, std::vector *trackJunkSource, - std::vector *trackDest, bool fitDirection, bool simulationConv) -{ - // DEPRECATED - // Track destination are the merged tracks. - // Track candidate source are the main tracks identified as candidates. - // Track junk source are the tracks from which clusters will be extracted. Boundary conditions are applied: Proximity - // in space, angle and center. - // NB: Only works for backward tracks - - Double_t trackDist = 20.0; // Distance between clusters in mm - Double_t angleSpread = 5.0; // Maximum angular spread between clusters - Double_t centerSpread = 20.0; // Maximim distance between track centers for merging. - Int_t minClusters = 3; - Int_t trackSize = 0; - - for (auto trackCand : *trackCandSource) { - Double_t thetaCand = trackCand.GetGeoTheta(); - auto &hitArrayCand = trackCand.GetHitArray(); - std::pair centerCand = trackCand.GetGeoCenter(); - - AtTrack track = trackCand; - Int_t jnkCnt = 0; - Int_t jnkHitCnt = 0; - - if (simulationConv) { - thetaCand = 180.0 - thetaCand * TMath::RadToDeg(); - - } else { - thetaCand = thetaCand * TMath::RadToDeg(); - } - - for (auto trackJunk : *trackJunkSource) { - Double_t thetaJunk = trackJunk.GetGeoTheta(); - auto &hitArrayJunk = trackJunk.GetHitArray(); - std::pair centerJunk = trackJunk.GetGeoCenter(); - - if (simulationConv) { - - thetaJunk = 180.0 - thetaJunk * TMath::RadToDeg(); - } else { - - thetaJunk = thetaJunk * TMath::RadToDeg(); - } - - if (thetaCand > 90) { - - /*if ((thetaCand + angleSpread) > thetaJunk && (thetaCand - angleSpread) < thetaJunk) { - for (auto hit : *hitArrayJunk) { - - track.AddHit(&hit); - ++jnkHitCnt; - } - }*/ - - } else if (thetaCand < 90 && thetaJunk < 90) { - - if ((thetaCand + angleSpread) > thetaJunk && (thetaCand - angleSpread) < thetaJunk) { // Check angle - std::cout << " Center cand : " << centerCand.first << " - " << centerCand.second << " " << thetaCand - << "\n"; - std::cout << " Center junk : " << centerJunk.first << " - " << centerJunk.second << " " << thetaJunk - << "\n"; - Double_t centerDistance = TMath::Sqrt(TMath::Power(centerCand.first - centerJunk.first, 2) + - TMath::Power(centerCand.second - centerJunk.second, 2)); - std::cout << " Distance " << centerDistance << "\n"; - if (centerDistance < 50) // Check quadrant - - for (const auto &hit : hitArrayJunk) { - - track.AddHit(hit->Clone()); - ++jnkHitCnt; - } - } - } - - ++jnkCnt; - - } // Junk track - - // track.SortClusterHitArrayZ(); - track.SortHitArrayTime(); - - // Prune if other tracks were added track - - if (trackJunkSource->size() > 0) { - Double_t pruneFraction = 10.0; - if (jnkHitCnt > hitArrayCand.size()) - pruneFraction = 4.0; // 25% - else - pruneFraction = 10.0; // 10% - - Int_t numHits = (Int_t)track.GetHitArray().size() / pruneFraction; - for (auto iHit = 0; iHit < numHits; ++iHit) - track.GetHitArray().pop_back(); - } - - trackDest->push_back(track); - - } // Source track } diff --git a/AtReconstruction/AtFitter/AtFitter.h b/AtReconstruction/AtFitter/AtFitter.h index 9d1e8274a..18b8f5190 100644 --- a/AtReconstruction/AtFitter/AtFitter.h +++ b/AtReconstruction/AtFitter/AtFitter.h @@ -1,53 +1,47 @@ #ifndef AtFITTER_H #define AtFITTER_H -#include "AtTrackTransformer.h" +#include "AtFitMetadata.h" #include #include +#include #include +#include #include #include class AtDigiPar; class AtTrack; class AtFittedTrack; -class FairLogger; class TBuffer; class TClass; class TMemberInspector; +class AtRawEvent; +class AtEvent; +class AtFitMetadata; +class AtPatternEvent; +class AtRawEvent; +class AtTrackingEvent; -namespace genfit { -class Track; -} // namespace genfit - -namespace AtFITTER { +namespace EventFit { class AtFitter : public TObject { - public: - AtFitter(); - virtual ~AtFitter(); - virtual std::vector> ProcessTracks(std::vector &tracks) = 0; - virtual void Init() = 0; + AtFitter() = default; + ~AtFitter() = default; - void MergeTracks(std::vector *trackCandSource, std::vector *trackJunkSource, - std::vector *trackDest, bool fitDirection, bool simulationConv); - Bool_t MergeTracks(std::vector *trackCandSource, std::vector *trackDest, - Bool_t enableSingleVertexTrack, Double_t clusterRadius, Double_t clusterDistance); + virtual void FitEvent(AtTrackingEvent *trackingEvent, AtPatternEvent *patternEvent, + AtFitMetadata *fitMetadata = nullptr, AtRawEvent *rawEvent = nullptr, + AtEvent *event = nullptr); + virtual void Init() = 0; protected: - FairLogger *fLogger{}; ///< logger pointer - AtDigiPar *fPar{}; ///< parameter container - std::unique_ptr fTrackTransformer{std::make_unique()}; - std::tuple - GetMomFromBrho(Double_t A, Double_t Z, - Double_t brho); ///< Returns momentum (in GeV) from Brho assuming M (amu) and Z; - Bool_t FindVertexTrack(AtTrack *trA, AtTrack *trB); ///< Lambda function to find track closer to vertex - ClassDef(AtFitter, 1); + virtual AtFittedTrack *GetFittedTrack(AtTrack *track, AtFitMetadata *fitMetadata = nullptr, + AtRawEvent *rawEvent = nullptr, AtEvent *event = nullptr) = 0; }; -} // namespace AtFITTER +} // namespace EventFit #endif diff --git a/AtReconstruction/AtFitter/AtFitterOld.cxx b/AtReconstruction/AtFitter/AtFitterOld.cxx new file mode 100644 index 000000000..a663bb92a --- /dev/null +++ b/AtReconstruction/AtFitter/AtFitterOld.cxx @@ -0,0 +1,269 @@ +#include "AtFitterOld.h" + +#include "AtHit.h" +#include "AtHitCluster.h" // for AtHitCluster +#include "AtTrack.h" + +#include // for Cartesian3D, operator-, PositionVector3D +#include // for DisplacementVector3D +#include + +#include +#include // for sqrt +#include // for operator<<, basic_ostream::operator<< +#include // for shared_ptr, __shared_ptr_access, __sha... +#include // for pair + +constexpr auto cRED = "\033[1;31m"; +constexpr auto cYELLOW = "\033[1;33m"; +constexpr auto cNORMAL = "\033[0m"; +constexpr auto cGREEN = "\033[1;32m"; + +ClassImp(AtFITTER::AtFitterOld); + +AtFITTER::AtFitterOld::AtFitterOld() = default; + +AtFITTER::AtFitterOld::~AtFitterOld() = default; + +std::tuple AtFITTER::AtFitterOld::GetMomFromBrho(Double_t M, Double_t Z, Double_t brho) +{ + + const Double_t M_Ener = M * 931.49401 / 1000.0; + Double_t p = brho * Z * (2.99792458 / 10.0); // In GeV + Double_t E = TMath::Sqrt(TMath::Power(p, 2) + TMath::Power(M_Ener, 2)) - M_Ener; + // std::cout << " Brho : " << brho << " - p : " << p << " - E : " << E << "\n"; + return std::make_tuple(p, E); +} + +Bool_t AtFITTER::AtFitterOld::FindVertexTrack(AtTrack *trA, AtTrack *trB) +{ + // Determination of first hit distance. NB: Assuming both tracks have the same angle sign + Double_t vertexA = 0.0; + Double_t vertexB = 0.0; + if (trA->GetGeoTheta() * TMath::RadToDeg() < 90) { + auto iniClusterA = trA->GetHitClusterArray()->back(); + auto iniClusterB = trB->GetHitClusterArray()->back(); + vertexA = 1000.0 - iniClusterA.GetPosition().Z(); + vertexB = 1000.0 - iniClusterB.GetPosition().Z(); + } else if (trA->GetGeoTheta() * TMath::RadToDeg() > 90) { + auto iniClusterA = trA->GetHitClusterArray()->front(); + auto iniClusterB = trB->GetHitClusterArray()->front(); + vertexA = iniClusterA.GetPosition().Z(); + vertexB = iniClusterB.GetPosition().Z(); + } + + return vertexA < vertexB; +} + +Bool_t AtFITTER::AtFitterOld::MergeTracks(std::vector *trackCandSource, std::vector *trackDest, + Bool_t enableSingleVertexTrack, Double_t clusterRadius, + Double_t clusterDistance) +{ + + Bool_t toMerge = kFALSE; + + Int_t addHitCnt = 0; + // Find the track closer to vertex + std::sort(trackCandSource->begin(), trackCandSource->end(), + [this](AtTrack *trA, AtTrack *trB) { return FindVertexTrack(trA, trB); }); + + // Track stitching from vertex + AtTrack *vertexTrack = *trackCandSource->begin(); + + if (enableSingleVertexTrack) { + + // Mark all tracks as merged + for (auto track : *trackCandSource) + track->SetIsMerged(kTRUE); + + trackDest->push_back(*vertexTrack); + return true; + } + + // Check if the candidate vertex track was merged + if (vertexTrack->GetIsMerged()) + return kFALSE; + else + vertexTrack->SetIsMerged(kTRUE); + + // If enabled, choose only the track closest to vertex (i.e. first one of the collection of candidates) + // TODO: Select by number of points + + for (auto it = trackCandSource->begin() + 1; it != trackCandSource->end(); ++it) { + // NB: These tracks were previously marked to merge. If merging fails they should be discarded. + AtTrack *trackToMerge = *(it); + toMerge = kFALSE; + + // Skip trackes flagged as merged + if (!trackToMerge->GetIsMerged()) { + trackToMerge->SetIsMerged(kTRUE); + } else + continue; + + Double_t endVertexZ = 0.0; + Double_t iniMergeZ = 0.0; + std::cout << " Trying to merge ... " + << "\n"; + std::cout << " Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " << trackToMerge->GetTrackID() + << "\n"; + // Check relative position between end and begin of each track using Hit Clusters + std::cout << " Vertex angle " << vertexTrack->GetGeoTheta() * TMath::RadToDeg() << "\n"; + if (vertexTrack->GetGeoTheta() * TMath::RadToDeg() < 90) { + auto endClusterVertex = vertexTrack->GetHitClusterArray()->front(); + auto iniClusterMerge = trackToMerge->GetHitClusterArray()->back(); + // Check separation and relative distance + endVertexZ = 1000.0 - endClusterVertex.GetPosition().Z(); + iniMergeZ = 1000.0 - iniClusterMerge.GetPosition().Z(); + + Double_t distance = std::sqrt((iniClusterMerge.GetPosition() - endClusterVertex.GetPosition()).Mag2()); + // std::cout << " Distance between tracks " << distance << "\n"; + // std::cout << " Ini Merge " << iniMergeZ << " - endVertexZ " << endVertexZ << "\n"; + if (((iniMergeZ + 10.0) > endVertexZ) && distance < 200) { + toMerge = kTRUE; + } + + } else if (vertexTrack->GetGeoTheta() * TMath::RadToDeg() > 90) { + auto endClusterVertex = vertexTrack->GetHitClusterArray()->back(); + auto iniClusterMerge = trackToMerge->GetHitClusterArray()->front(); + // Check separation and relative distance + endVertexZ = endClusterVertex.GetPosition().Z(); + iniMergeZ = iniClusterMerge.GetPosition().Z(); + Double_t distance = std::sqrt((iniClusterMerge.GetPosition() - endClusterVertex.GetPosition()).Mag2()); + // std::cout<<" Distance between tracks "< endVertexZ) && + distance < 100) { // NB: Distance between parts of the backward tracks is more critical + toMerge = kTRUE; + } + } + + if (toMerge) { + + std::cout << " --- Merging Succeeded! Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " + << trackToMerge->GetTrackID() << "\n"; + for (const auto &hit : trackToMerge->GetHitArray()) { + + vertexTrack->AddHit(hit->Clone()); // TODO: Look at code and see if this can be a move instead of a copy + ++addHitCnt; + } + + // Reclusterize after merging + vertexTrack->SortHitArrayTime(); + vertexTrack->ResetHitClusterArray(); + fTrackTransformer->ClusterizeSmooth3D( + *vertexTrack, clusterRadius, + clusterDistance); // NB: It can be removed if we force reclusterization for any track in the mina program + + // TODO: Check if phi recalculatio is needed + + } else { + std::cout << " --- Merging Failed ! Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " + << trackToMerge->GetTrackID() << "\n"; + } + } + + trackDest->push_back(*vertexTrack); + + return toMerge; +} +[[deprecated]] void +AtFITTER::AtFitterOld::MergeTracks(std::vector *trackCandSource, std::vector *trackJunkSource, + std::vector *trackDest, bool fitDirection, bool simulationConv) +{ + // DEPRECATED + // Track destination are the merged tracks. + // Track candidate source are the main tracks identified as candidates. + // Track junk source are the tracks from which clusters will be extracted. Boundary conditions are applied: Proximity + // in space, angle and center. + // NB: Only works for backward tracks + + Double_t trackDist = 20.0; // Distance between clusters in mm + Double_t angleSpread = 5.0; // Maximum angular spread between clusters + Double_t centerSpread = 20.0; // Maximim distance between track centers for merging. + Int_t minClusters = 3; + Int_t trackSize = 0; + + for (auto trackCand : *trackCandSource) { + Double_t thetaCand = trackCand.GetGeoTheta(); + auto &hitArrayCand = trackCand.GetHitArray(); + std::pair centerCand = trackCand.GetGeoCenter(); + + AtTrack track = trackCand; + Int_t jnkCnt = 0; + Int_t jnkHitCnt = 0; + + if (simulationConv) { + thetaCand = 180.0 - thetaCand * TMath::RadToDeg(); + + } else { + thetaCand = thetaCand * TMath::RadToDeg(); + } + + for (auto trackJunk : *trackJunkSource) { + Double_t thetaJunk = trackJunk.GetGeoTheta(); + auto &hitArrayJunk = trackJunk.GetHitArray(); + std::pair centerJunk = trackJunk.GetGeoCenter(); + + if (simulationConv) { + + thetaJunk = 180.0 - thetaJunk * TMath::RadToDeg(); + } else { + + thetaJunk = thetaJunk * TMath::RadToDeg(); + } + + if (thetaCand > 90) { + + /*if ((thetaCand + angleSpread) > thetaJunk && (thetaCand - angleSpread) < thetaJunk) { + for (auto hit : *hitArrayJunk) { + + track.AddHit(&hit); + ++jnkHitCnt; + } + }*/ + + } else if (thetaCand < 90 && thetaJunk < 90) { + + if ((thetaCand + angleSpread) > thetaJunk && (thetaCand - angleSpread) < thetaJunk) { // Check angle + std::cout << " Center cand : " << centerCand.first << " - " << centerCand.second << " " << thetaCand + << "\n"; + std::cout << " Center junk : " << centerJunk.first << " - " << centerJunk.second << " " << thetaJunk + << "\n"; + Double_t centerDistance = TMath::Sqrt(TMath::Power(centerCand.first - centerJunk.first, 2) + + TMath::Power(centerCand.second - centerJunk.second, 2)); + std::cout << " Distance " << centerDistance << "\n"; + if (centerDistance < 50) // Check quadrant + + for (const auto &hit : hitArrayJunk) { + + track.AddHit(hit->Clone()); + ++jnkHitCnt; + } + } + } + + ++jnkCnt; + + } // Junk track + + // track.SortClusterHitArrayZ(); + track.SortHitArrayTime(); + + // Prune if other tracks were added track + + if (trackJunkSource->size() > 0) { + Double_t pruneFraction = 10.0; + if (jnkHitCnt > hitArrayCand.size()) + pruneFraction = 4.0; // 25% + else + pruneFraction = 10.0; // 10% + + Int_t numHits = (Int_t)track.GetHitArray().size() / pruneFraction; + for (auto iHit = 0; iHit < numHits; ++iHit) + track.GetHitArray().pop_back(); + } + + trackDest->push_back(track); + + } // Source track +} diff --git a/AtReconstruction/AtFitter/AtFitterOld.h b/AtReconstruction/AtFitter/AtFitterOld.h new file mode 100644 index 000000000..1f08e1252 --- /dev/null +++ b/AtReconstruction/AtFitter/AtFitterOld.h @@ -0,0 +1,53 @@ +#ifndef AtFITTEROLD_H +#define AtFITTEROLD_H + +#include "AtTrackTransformer.h" + +#include +#include + +#include +#include +#include + +class AtDigiPar; +class AtTrack; +class AtFittedTrack; +class FairLogger; +class TBuffer; +class TClass; +class TMemberInspector; + +namespace genfit { +class Track; +} // namespace genfit + +namespace AtFITTER { + +class [[deprecated]] AtFitterOld : public TObject { + +public: + AtFitterOld(); + virtual ~AtFitterOld(); + virtual std::vector> ProcessTracks(std::vector &tracks) = 0; + virtual void Init() = 0; + + void MergeTracks(std::vector *trackCandSource, std::vector *trackJunkSource, + std::vector *trackDest, bool fitDirection, bool simulationConv); + Bool_t MergeTracks(std::vector *trackCandSource, std::vector *trackDest, + Bool_t enableSingleVertexTrack, Double_t clusterRadius, Double_t clusterDistance); + +protected: + FairLogger *fLogger{}; ///< logger pointer + AtDigiPar *fPar{}; ///< parameter container + std::unique_ptr fTrackTransformer{std::make_unique()}; + std::tuple + GetMomFromBrho(Double_t A, Double_t Z, + Double_t brho); ///< Returns momentum (in GeV) from Brho assuming M (amu) and Z; + Bool_t FindVertexTrack(AtTrack *trA, AtTrack *trB); ///< Lambda function to find track closer to vertex + ClassDef(AtFitterOld, 1); +}; + +} // namespace AtFITTER + +#endif diff --git a/AtReconstruction/AtFitter/AtGenfit.h b/AtReconstruction/AtFitter/AtGenfit.h index e42aaa271..60c97ba59 100644 --- a/AtReconstruction/AtFitter/AtGenfit.h +++ b/AtReconstruction/AtFitter/AtGenfit.h @@ -1,7 +1,7 @@ #ifndef ATGENFIT_H #define ATGENFIT_H -#include "AtFitter.h" +#include "AtFitterOld.h" #include "AtFormat.h" #include "AtKinematics.h" #include "AtParsers.h" @@ -55,7 +55,9 @@ class MeasurementFactory; namespace AtFITTER { -class AtGenfit : public AtFitter { +class [[deprecated( + "This still derives from the old AtFitter. Please consider updating it to the new AtFitter at some point")]] AtGenfit + : public AtFitterOld { private: std::shared_ptr fKalmanFitter; TClonesArray *fGenfitTrackArray; diff --git a/AtReconstruction/AtFitter/AtMCFitter.cxx b/AtReconstruction/AtFitter/AtMCFitter.cxx index bef820f99..ea455eccd 100644 --- a/AtReconstruction/AtFitter/AtMCFitter.cxx +++ b/AtReconstruction/AtFitter/AtMCFitter.cxx @@ -30,7 +30,7 @@ namespace MCFitter { AtMCFitter::AtMCFitter(SimPtr sim, ClusterPtr cluster, PulsePtr pulse) : fMap(pulse->GetMap()), fSim(move(sim)), fClusterize(move(cluster)), fPulse(move(pulse)), - fResults([](const AtMCResult &a, const AtMCResult &b) { return a.fObjective < b.fObjective; }) + fResults([](const AtMCResult &a, const AtMCResult &b) { return a.GetChi2() < b.GetChi2(); }) { } @@ -84,7 +84,7 @@ void AtMCFitter::RunIterRange(int startIter, int numIter, AtPulse *pulse) double obj = ObjectiveFunction(*fCurrentEvent, idx, result); result.fIterNum = idx; - result.fObjective = obj; + result.SetChi2(obj); // result.Print(); { std::lock_guard lk(fResultMutex); diff --git a/AtReconstruction/AtFitter/AtMCFitterOld.cxx b/AtReconstruction/AtFitter/AtMCFitterOld.cxx new file mode 100644 index 000000000..87aa7985f --- /dev/null +++ b/AtReconstruction/AtFitter/AtMCFitterOld.cxx @@ -0,0 +1,216 @@ +#include "AtMCFitterOld.h" +// IWYU pragma: no_include + +#include "AtClusterize.h" // for AtClusterize +#include "AtDigiPar.h" // for AtDigiPar +#include "AtEvent.h" // for AtEvent +#include "AtMCResultOld.h" +#include "AtPSA.h" // for AtPSA +#include "AtParameterDistribution.h" +#include "AtPatternEvent.h" // for AtPatternEvent +#include "AtPulse.h" // for AtPulse +#include "AtRawEvent.h" // for AtRawEvent +#include "AtSimpleSimulation.h" // for AtSimpleSimulation +#include "AtSimulatedPoint.h" // IWYU pragma: keep +#include "AtSpaceChargeModel.h" + +#include // for LOG, Logger +#include // for FairParSet +#include // for FairRunAna +#include // for FairRuntimeDb + +#include + +#include // for max +#include +#include +#include +using std::move; +namespace MCFitter { + +AtMCFitterOld::AtMCFitterOld(SimPtr sim, ClusterPtr cluster, PulsePtr pulse) + : fMap(pulse->GetMap()), fSim(move(sim)), fClusterize(move(cluster)), fPulse(move(pulse)), + fResults([](const AtMCResultOld &a, const AtMCResultOld &b) { return a.fObjective < b.fObjective; }) +{ +} + +AtMCFitterOld::ParamPtr AtMCFitterOld::GetParameter(const std::string &name) const +{ + if (fParameters.find(name) != fParameters.end()) { + return fParameters.at(name); + } + return nullptr; +} + +void AtMCFitterOld::SetNumThreads(int num) +{ + if (num > 1) + ROOT::EnableThreadSafety(); + fNumThreads = num; +} + +void AtMCFitterOld::Init() +{ + CreateParamDistros(); + + FairRunAna *ana = FairRunAna::Instance(); + FairRuntimeDb *rtdb = ana->GetRuntimeDb(); + fPar = dynamic_cast(rtdb->getContainer("AtDigiPar")); + + fPulse->SetParameters(fPar); + fClusterize->GetParameters(fPar); + if (fPSA) + fPSA->Init(); + if (fSim->GetSpaceChargeModel()) + fSim->GetSpaceChargeModel()->LoadParameters(fPar); + + fThPulse.resize(fNumThreads); + for (int i = 0; i < fNumThreads; ++i) + fThPulse[i] = fPulse->Clone(); +} + +void AtMCFitterOld::RunIterRange(int startIter, int numIter, AtPulse *pulse) +{ + // Here we should copy each thread their own version of the clusterize, pulse, and simulation + // objects (only if the number of threads is greater than 1). Needs to be deep copies + + for (int i = 0; i < numIter; ++i) { + + int idx = startIter + i; + auto result = DefineEvent(); + auto mcPoints = SimulateEvent(result); + + DigitizeEvent(mcPoints, idx, pulse); + double obj = ObjectiveFunction(*fCurrentEvent, idx, result); + + result.fIterNum = idx; + result.fObjective = obj; + // result.Print(); + { + std::lock_guard lk(fResultMutex); + fResults.insert(result); + } + } + LOG(debug) << "Done with run iter range"; +} + +void AtMCFitterOld::Exec(const AtPatternEvent &event) +{ + fRawEventArray.clear(); + fEventArray.clear(); + fResults.clear(); + + SetParamDistributions(event); + + // Set the conditions for simulating the event + fCurrentEvent = &event; + + // Make sure the event arrays are large enough so no resizing will happen + fRawEventArray.resize(fNumIter); + fEventArray.resize(fNumIter); + + for (int i = 0; i < fNumRounds; ++i) { + RunRound(); + RecenterParamDistributions(); + } +} +void AtMCFitterOld::RunRound() +{ + // Begining of round + auto start = std::chrono::high_resolution_clock::now(); + + // Get what iterations to do on what thread. + std::vector> threadParam; + int iterPerTh = fNumIter / fNumThreads; + for (int i = 0; i < fNumThreads; ++i) + threadParam.emplace_back(0, iterPerTh); + for (int i = 0; i < fNumIter % fNumThreads; ++i) + threadParam[i].second++; + for (int i = 1; i < fNumThreads; ++i) + threadParam[i].first = threadParam[i - 1].first + threadParam[i - 1].second; + + for (int i = 0; i < threadParam.size(); ++i) { + LOG(info) << i << ": " << threadParam[i].first << " " << threadParam[i].second; + } + + std::vector threads; + for (int i = 0; i < fNumThreads; ++i) { + LOG(debug) << "Creating thread " << i << " with " << threadParam[i].first << " " << threadParam[i].second + << " and " << fPulse.get(); + + // Spawn a thread to call RunIterRange. + threads.emplace_back( + [this](std::pair param, AtPulse *pulse) { this->RunIterRange(param.first, param.second, pulse); }, + threadParam[i], fThPulse[i].get()); + } + + // Wait for all threads to finish + for (auto &th : threads) + th.join(); + + auto stop = std::chrono::high_resolution_clock::now(); + + if (fTimeEvent) + LOG(info) << "Simulation of " << fNumIter << " events took " + << std::chrono::duration_cast(stop - start).count() << " ms."; +} + +int AtMCFitterOld::DigitizeEvent(const TClonesArray &points, int idx, AtPulse *pulse) +{ + // Event has been simulated and is sitting in the fSim + auto vec = fClusterize->ProcessEvent(points); + LOG(debug) << "Digitizing event at " << idx; + + fRawEventArray[idx] = pulse->GenerateEvent(vec); + + if (fPSA) { + LOG(debug) << "Running PSA at " << idx; + fEventArray[idx] = fPSA->Analyze(fRawEventArray[idx]); + } + LOG(debug) << "Done digitizing event at " << idx; + return idx; +} + +/** + * Fill the TClonesArray in order of smallest to largest chi2. + */ +void AtMCFitterOld::FillResultArrays(TClonesArray &resultArray, TClonesArray &simEvent, TClonesArray &simRawEvent) +{ + resultArray.Delete(); + simEvent.Delete(); + simRawEvent.Delete(); + + for (auto &res : fResults) { + + int clonesIdx = resultArray.GetEntries(); + int eventIdx = res.fIterNum; + LOG(debug) << "Filling iteration " << eventIdx << " at index " << resultArray.GetEntries(); + + new (resultArray[clonesIdx]) AtMCResultOld(std::move(res)); + if (clonesIdx < fNumEventsToSave) { + new (simEvent[clonesIdx]) AtEvent(std::move(fEventArray[eventIdx])); + new (simRawEvent[clonesIdx]) AtRawEvent(std::move(fRawEventArray[eventIdx])); + } + } + + fEventArray.clear(); + fRawEventArray.clear(); +} + +AtMCResultOld AtMCFitterOld::DefineEvent() +{ + AtMCResultOld result; + for (auto &[name, distro] : fParameters) + result.fParameters[name] = distro->Sample(); + return result; +} +void AtMCFitterOld::RecenterParamDistributions() +{ + for (auto &[name, distro] : fParameters) { + AtMCResultOld result = *fResults.begin(); + distro->SetMean(result.fParameters[name]); + distro->TruncateSpace(); + } +} + +} // namespace MCFitter diff --git a/AtReconstruction/AtFitter/AtMCFitterOld.h b/AtReconstruction/AtFitter/AtMCFitterOld.h new file mode 100644 index 000000000..9cbdd6995 --- /dev/null +++ b/AtReconstruction/AtFitter/AtMCFitterOld.h @@ -0,0 +1,136 @@ +#ifndef ATMCFITTEROLD_H +#define ATMCFITTEROLD_H + +#include "AtEvent.h" +#include "AtMCResultOld.h" // for AtMCResult +#include "AtRawEvent.h" + +#include // for TClonesArray + +#include // for function +#include // for map +#include // for shared_ptr +#include // for mutex +#include // for set +#include // for string +#include // for pair +#include // for vector + +class AtBaseEvent; // lines 13-13 +class AtClusterize; // lines 15-15 +class AtMap; // lines 18-18 +class AtPatternEvent; // lines 12-12 +class AtPulse; // lines 16-16 +class AtSimpleSimulation; // lines 14-14 +class AtDigiPar; +class AtPSA; + +namespace MCFitter { +class AtParameterDistribution; + +class [[deprecated]] AtMCFitterOld { +protected: + using ParamPtr = std::shared_ptr; + using SimPtr = std::shared_ptr; + + using ClusterPtr = std::shared_ptr; + using PulsePtr = std::shared_ptr; + using MapPtr = std::shared_ptr; + using PsaPtr = std::shared_ptr; + using ObjPair = std::pair; //< Iteration number and objective function value + + std::map fParameters; + + MapPtr fMap; + SimPtr fSim; + ClusterPtr fClusterize; + PulsePtr fPulse; + PsaPtr fPSA{nullptr}; + + int fNumIter{1}; + int fNumRounds{1}; + int fNumEventsToSave{10}; + bool fTimeEvent{false}; + int fNumThreads{1}; + + // Things used by threads excecuting that are either expensive to create and delete + // or unaccessable due to FairRoot design choices + const AtPatternEvent *fCurrentEvent{nullptr}; + std::vector fThPulse; //< Cached because it is expensive to create and delete. + const AtDigiPar *fPar{nullptr}; // fRawEventArray; + std::vector fEventArray; + + /** Things below here need to be written to by threads and will be locked using a shared mutex ***/ + /// Store the iteration number sorted by lowest objective funtion + std::mutex fResultMutex; + std::set> fResults; + +public: + AtMCFitterOld(SimPtr sim, ClusterPtr cluster, PulsePtr pulse); + virtual ~AtMCFitterOld() = default; + + void Init(); + void SetPSA(PsaPtr psa) { fPSA = psa; } + void Exec(const AtPatternEvent &event); + + ParamPtr GetParameter(const std::string &name) const; + void FillResultArrays(TClonesArray &resultArray, TClonesArray &simEvent, TClonesArray &simRawEvent); + void SetNumIter(int iter) { fNumIter = iter; } + + /// Set number of times to run fNumIter iterations and then re-center and truncate the parameter space. + void SetNumRounds(int rounds) { fNumRounds = rounds; } + void SetTimeEvent(bool val) { fTimeEvent = val; } + void SetNumEventsToSave(int num) { fNumEventsToSave = num; } + void SetNumThreads(int num); + +protected: + void RunRound(); + void RunIterRange(int startIter, int numIter, AtPulse *pulse); + + /** + *@brief Create the parameter distributions to use for the fit. + */ + virtual void CreateParamDistros() = 0; + + /** + * @brief Set parameter distributions (mean/spread) from the event. + */ + virtual void SetParamDistributions(const AtPatternEvent &event) = 0; + + /** + * @brief This is the thing we are minimizing between events (SimEventID is index in TClonesArray) + */ + virtual double ObjectiveFunction(const AtBaseEvent &expEvent, int SimEventID, AtMCResultOld &definition) = 0; + + /** + * Simulate an event using the parameters in the passed AtMCResult class and return an array of + * the AtMCPoints to then digitize. + */ + virtual TClonesArray SimulateEvent(AtMCResultOld &definition) = 0; + + /** + * Sample parameter distributions and constrain the system to simulate an event. + * The parameters in AtMCResult will be used to then simulate an event. + * This function calls Sample() on all the parameter distributions and saves them. + */ + virtual AtMCResultOld DefineEvent(); + + /** + * Recenter the parameter distributions around the best result and truncate the parameter space. + */ + virtual void RecenterParamDistributions(); + + /** + * Create the AtRawEvent and AtEvent from fSim + * returns the index of the event in the TClonesArray + */ + int DigitizeEvent(const TClonesArray &points, int idx, AtPulse *pulse); +}; + +} // namespace MCFitter + +#endif // ATMCFITTEROLD_H diff --git a/AtReconstruction/AtFitter/AtMCFitterTaskOld.cxx b/AtReconstruction/AtFitter/AtMCFitterTaskOld.cxx new file mode 100644 index 000000000..10df98724 --- /dev/null +++ b/AtReconstruction/AtFitter/AtMCFitterTaskOld.cxx @@ -0,0 +1,51 @@ +#include "AtMCFitterTaskOld.h" + +#include "AtMCFitterOld.h" +#include "AtPatternEvent.h" + +#include // for LOG, Logger +#include // for FairRootManager + +#include // for TClonesArray +#include + +#include + +AtMCFitterTaskOld::AtMCFitterTaskOld(std::shared_ptr fitter) + : fFitter(std::move(fitter)), fResultArray("MCFitter::AtMCResultOld"), fSimEventArray("AtEvent"), + fSimRawEventArray("AtRawEvent") +{ +} + +InitStatus AtMCFitterTaskOld::Init() +{ + LOG(debug) << "Initialing fitter"; + fFitter->Init(); + + FairRootManager *ioman = FairRootManager::Instance(); + ioman->Register("SimEvent", "cbmsim", &fSimEventArray, fSaveEvent); + ioman->Register("SimRawEvent", "cbmsim", &fSimRawEventArray, fSaveRawEvent); + ioman->Register("AtMCResultOld", "cbmsim", &fResultArray, fSaveResult); + + fPatternArray = dynamic_cast(ioman->GetObject(fPatternBranchName)); + if (fPatternArray == nullptr) + LOG(fatal) << "Failed to load branch " << fPatternBranchName; + + LOG(debug) << "Done with sim init"; + return kSUCCESS; +} + +void AtMCFitterTaskOld::Exec(Option_t *) +{ + LOG(debug) << "Exec"; + auto patEvent = dynamic_cast(fPatternArray->At(0)); + if (!patEvent->IsGood()) + return; + + fFitter->Exec(*patEvent); + fResultArray.Delete(); + fSimEventArray.Delete(); + fSimRawEventArray.Delete(); + + fFitter->FillResultArrays(fResultArray, fSimEventArray, fSimRawEventArray); +} diff --git a/AtReconstruction/AtFitter/AtMCFitterTaskOld.h b/AtReconstruction/AtFitter/AtMCFitterTaskOld.h new file mode 100644 index 000000000..29d115203 --- /dev/null +++ b/AtReconstruction/AtFitter/AtMCFitterTaskOld.h @@ -0,0 +1,43 @@ +#ifndef ATMCFITTERTASKOLD_H +#define ATMCFITTERTASKOLD_H + +#include + +#include // for Option_t +#include +#include // for TString + +#include // for shared_ptr + +namespace MCFitter { +class AtMCFitterOld; +} + +class [[deprecated]] AtMCFitterTaskOld : public FairTask { + + std::shared_ptr fFitter; //! + TString fPatternBranchName{"AtPatternEvent"}; + TClonesArray *fPatternArray{nullptr}; + + TClonesArray fResultArray; //< Output of task + TClonesArray fSimEventArray; + TClonesArray fSimRawEventArray; + + Bool_t fSaveResult{true}; + Bool_t fSaveEvent{false}; + Bool_t fSaveRawEvent{false}; + +public: + AtMCFitterTaskOld(std::shared_ptr fitter); + + InitStatus Init() override; + void Exec(Option_t *option = "") override; + void Finish() override{}; + + void SetPatternBranchName(TString name) { fPatternBranchName = name; } + void SetSaveResult(bool val) { fSaveResult = val; } + void SetSaveEvent(bool val) { fSaveEvent = val; } + void SetSaveRawEvent(bool val) { fSaveRawEvent = val; } +}; + +#endif // ATMCFITTERTASKOLD_H diff --git a/AtReconstruction/AtFitterTask.cxx b/AtReconstruction/AtFitterTask.cxx index 8bef7da49..c4ce51af5 100644 --- a/AtReconstruction/AtFitterTask.cxx +++ b/AtReconstruction/AtFitterTask.cxx @@ -1,10 +1,12 @@ #include "AtFitterTask.h" #include "AtDigiPar.h" +#include "AtEvent.h" +#include "AtFitMetadata.h" #include "AtFitter.h" -#include "AtGenfit.h" #include "AtParsers.h" #include "AtPatternEvent.h" +#include "AtRawEvent.h" #include "AtTrackingEvent.h" #include @@ -15,7 +17,6 @@ #include #include -#include #include #include @@ -23,11 +24,10 @@ class AtTrack; class AtFittedTrack; -ClassImp(AtFitterTask); - -AtFitterTask::AtFitterTask(std::unique_ptr fitter) +AtFitterTask::AtFitterTask(std::unique_ptr fitter) : fInputBranchName("AtPatternEvent"), fOutputBranchName("AtTrackingEvent"), fIsPersistence(kFALSE), - fTrackingEventArray(TClonesArray("AtTrackingEvent", 1)), fFitter(std::move(fitter)) + fTrackingEventArray(TClonesArray("AtTrackingEvent", 1)), fFitter(std::move(fitter)), fRawEventBranchName(""), + fEventBranchName(""), fFitMetadataBranchName(""), fFitMetadataArray(TClonesArray("AtFitMetadata", 1)) { } @@ -46,6 +46,22 @@ void AtFitterTask::SetOutputBranch(TString branchName) fOutputBranchName = branchName; } +void AtFitterTask::SetRawEventBranch(TString branchName) +{ + fRawEventBranchName = branchName; +} + +void AtFitterTask::SetEventBranch(TString branchName) +{ + fEventBranchName = branchName; +} + +void AtFitterTask::SetFitMetadataBranch(TString branchName) +{ + fFitMetadataBranchName = branchName; + fSaveFitMetadata = true; +} + InitStatus AtFitterTask::Init() { FairRootManager *ioMan = FairRootManager::Instance(); @@ -54,13 +70,24 @@ InitStatus AtFitterTask::Init() return kERROR; } - fPatternEventArray = dynamic_cast(ioMan->GetObject("AtPatternEvent")); + fPatternEventArray = dynamic_cast(ioMan->GetObject(fInputBranchName)); if (fPatternEventArray == nullptr) { LOG(error) << "Cannot find AtPatternEvent array!"; return kERROR; } ioMan->Register(fOutputBranchName, "AtTPC", &fTrackingEventArray, fIsPersistence); + ioMan->Register(fFitMetadataBranchName, "AtTPC", &fFitMetadataArray, fIsPersistence && fSaveFitMetadata); + + fRawEventArray = dynamic_cast(ioMan->GetObject(fRawEventBranchName)); + if (fRawEventArray == nullptr) { + LOG(info) << "AtRawEvent branch name was not set. No AtRawEvent will be passed to the fitter."; + } + + fEventArray = dynamic_cast(ioMan->GetObject(fEventBranchName)); + if (fEventArray == nullptr) { + LOG(info) << "AtEvent branch name was not set. No AtEvent will be passed to the fitter."; + } return kSUCCESS; } @@ -87,22 +114,33 @@ void AtFitterTask::Exec(Option_t *option) if (fPatternEventArray->GetEntriesFast() == 0) return; + // If there is AtRawEvent available, get it so it can be passed to the fitter. + AtRawEvent *rawEvent = nullptr; + if (fRawEventArray) + rawEvent = dynamic_cast(fRawEventArray->At(0)); + + // If there is AtEvent available, get it so it can be passed to the fitter. + AtEvent *event = nullptr; + if (fEventArray) + event = dynamic_cast(fEventArray->At(0)); + fTrackingEventArray.Delete(); + fFitMetadataArray.Delete(); auto trackingEvent = dynamic_cast(fTrackingEventArray.ConstructedAt(0)); + auto fitMetadata = dynamic_cast(fFitMetadataArray.ConstructedAt(0)); - std::cout << " Event Counter " << fEventCnt << "\n"; + LOG(info) << " Fitting event " << fEventCnt; - AtPatternEvent &patternEvent = *(dynamic_cast(fPatternEventArray->At(0))); - std::vector &tracks = patternEvent.GetTrackCand(); - std::cout << " AtFitterTask:Exec - Number of candidate tracks : " << tracks.size() << "\n"; + AtPatternEvent *patternEvent = dynamic_cast(fPatternEventArray->At(0)); + std::vector &tracks = patternEvent->GetTrackCand(); + LOG(info) << " Number of candidate tracks : " << tracks.size(); - auto fittedTracks = fFitter->ProcessTracks(tracks); + fFitter->FitEvent(trackingEvent, patternEvent, fitMetadata, rawEvent, event); - std::cout << " Number of fitted tracks " << fittedTracks.size() << "\n"; + auto &fittedTracks = trackingEvent->GetFittedTracks(); - for (auto &fittedTrack : fittedTracks) - trackingEvent->AddFittedTrack(std::move(fittedTrack)); + LOG(info) << " Number of fitted tracks : " << fittedTracks.size(); ++fEventCnt; } diff --git a/AtReconstruction/AtFitterTask.h b/AtReconstruction/AtFitterTask.h index dc3256b12..771ad9a8d 100644 --- a/AtReconstruction/AtFitterTask.h +++ b/AtReconstruction/AtFitterTask.h @@ -15,9 +15,8 @@ #include #include +#include -#include "EventDisplay.h" -#include "Exception.h" #include "FairLogger.h" #include "FairRootManager.h" #include "FairRun.h" @@ -38,24 +37,29 @@ class AtTrack; namespace AtTools { class AtTrackTransformer; } // namespace AtTools -namespace AtFITTER { +namespace EventFit { class AtFitter; -} // namespace AtFITTER -namespace genfit { -class Track; -} // namespace genfit - +} // namespace EventFit + +/** + * Task that takes a certain AtFitter and uses to fit an AtPatternEvent. The AtFitter may need access to the AtRawEvent + * or AtEvent as well, so pointers to them are also read and passed to the AtFitter. An AtFitMetadata object may also be + * written, which would contain the fit metadata information for all fits done to all AtTracks. + * Specific logic of the fitting is contained in AtFitter and derived classes. + */ class AtFitterTask : public FairTask { - public: - // AtFitterTask(); + AtFitterTask(std::unique_ptr fitter); ~AtFitterTask() = default; - AtFitterTask(std::unique_ptr fitter); void SetInputBranch(TString branchName); void SetOutputBranch(TString branchName); void SetPersistence(Bool_t value = kTRUE); + void SetRawEventBranch(TString branchName); + void SetEventBranch(TString branchName); + void SetFitMetadataBranch(TString branchName); + virtual InitStatus Init(); virtual void SetParContainers(); virtual void Exec(Option_t *opt); @@ -66,14 +70,23 @@ class AtFitterTask : public FairTask { Bool_t fIsPersistence; //!< Persistence check variable - std::unique_ptr fFitter; + std::unique_ptr fFitter; AtDigiPar *fPar{nullptr}; TClonesArray *fPatternEventArray; TClonesArray fTrackingEventArray; std::size_t fEventCnt{0}; - ClassDef(AtFitterTask, 1); + // Include the option to input AtRawEvent and AtEvent in case some specific AtFitter needs it. + TString fRawEventBranchName; + TString fEventBranchName; + TClonesArray *fRawEventArray; + TClonesArray *fEventArray; + + // Include the option to store all the fit metadata in a AtFitResult branch. + TString fFitMetadataBranchName; + TClonesArray fFitMetadataArray; + bool fSaveFitMetadata{false}; }; #endif diff --git a/AtReconstruction/AtFitterTaskOld.cxx b/AtReconstruction/AtFitterTaskOld.cxx new file mode 100644 index 000000000..7c5be94e0 --- /dev/null +++ b/AtReconstruction/AtFitterTaskOld.cxx @@ -0,0 +1,108 @@ +#include "AtFitterTaskOld.h" + +#include "AtDigiPar.h" +#include "AtFitter.h" +#include "AtGenfit.h" +#include "AtParsers.h" +#include "AtPatternEvent.h" +#include "AtTrackingEvent.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +class AtTrack; +class AtFittedTrackOld; + +ClassImp(AtFitterTaskOld); + +AtFitterTaskOld::AtFitterTaskOld(std::unique_ptr fitter) + : fInputBranchName("AtPatternEvent"), fOutputBranchName("AtTrackingEvent"), fIsPersistence(kFALSE), + fTrackingEventArray(TClonesArray("AtTrackingEvent", 1)), fFitter(std::move(fitter)) +{ +} + +void AtFitterTaskOld::SetPersistence(Bool_t value) +{ + fIsPersistence = value; +} + +void AtFitterTaskOld::SetInputBranch(TString branchName) +{ + fInputBranchName = branchName; +} + +void AtFitterTaskOld::SetOutputBranch(TString branchName) +{ + fOutputBranchName = branchName; +} + +InitStatus AtFitterTaskOld::Init() +{ + FairRootManager *ioMan = FairRootManager::Instance(); + if (ioMan == nullptr) { + LOG(error) << "Cannot find RootManager!"; + return kERROR; + } + + fPatternEventArray = dynamic_cast(ioMan->GetObject("AtPatternEvent")); + if (fPatternEventArray == nullptr) { + LOG(error) << "Cannot find AtPatternEvent array!"; + return kERROR; + } + + ioMan->Register(fOutputBranchName, "AtTPC", &fTrackingEventArray, fIsPersistence); + + return kSUCCESS; +} + +void AtFitterTaskOld::SetParContainers() +{ + LOG(debug) << "SetParContainers of AtFitterTask"; + + FairRun *run = FairRun::Instance(); + if (!run) + LOG(fatal) << "No analysis run!"; + + FairRuntimeDb *db = run->GetRuntimeDb(); // NOLINT + if (!db) + LOG(fatal) << "No runtime database!"; + + fPar = (AtDigiPar *)db->getContainer("AtDigiPar"); // NOLINT + if (!fPar) + LOG(fatal) << "AtDigiPar not found!!"; +} + +void AtFitterTaskOld::Exec(Option_t *option) +{ + if (fPatternEventArray->GetEntriesFast() == 0) + return; + + fTrackingEventArray.Delete(); + + auto trackingEvent = dynamic_cast(fTrackingEventArray.ConstructedAt(0)); + + std::cout << " Event Counter " << fEventCnt << "\n"; + + AtPatternEvent &patternEvent = *(dynamic_cast(fPatternEventArray->At(0))); + std::vector &tracks = patternEvent.GetTrackCand(); + std::cout << " AtFitterTask:Exec - Number of candidate tracks : " << tracks.size() << "\n"; + + auto fittedTracks = fFitter->ProcessTracks(tracks); + + std::cout << " Number of fitted tracks " << fittedTracks.size() << "\n"; + + for (auto &fittedTrack : fittedTracks) + trackingEvent->AddFittedTrack(std::move(fittedTrack)); + + ++fEventCnt; +} diff --git a/AtReconstruction/AtFitterTaskOld.h b/AtReconstruction/AtFitterTaskOld.h new file mode 100644 index 000000000..c627a37b9 --- /dev/null +++ b/AtReconstruction/AtFitterTaskOld.h @@ -0,0 +1,79 @@ +/********************************************************************* + * Fitter Task AtFitterTask.hh * + * Author: Y. Ayyad ayyadlim@frib.msu.edu * + * Log: 3/10/2021 * + * * + *********************************************************************/ + +#ifndef ATFITTERTASKOLD +#define ATFITTERTASKOLD + +#include "AtFormat.h" +#include "AtKinematics.h" +#include "AtParsers.h" + +#include + +#include + +#include "EventDisplay.h" +#include "Exception.h" +#include "FairLogger.h" +#include "FairRootManager.h" +#include "FairRun.h" +#include "FairRunAna.h" + +#include +#include +#include + +class AtDigiPar; +class FairLogger; +class TBuffer; +class TClass; +class TClonesArray; +class TMemberInspector; +class AtTrack; + +namespace AtTools { +class AtTrackTransformer; +} // namespace AtTools +namespace AtFITTER { +class AtFitterOld; +} // namespace AtFITTER +namespace genfit { +class Track; +} // namespace genfit + +class [[deprecated]] AtFitterTaskOld : public FairTask { + +public: + // AtFitterTask(); + ~AtFitterTaskOld() = default; + AtFitterTaskOld(std::unique_ptr fitter); + + void SetInputBranch(TString branchName); + void SetOutputBranch(TString branchName); + void SetPersistence(Bool_t value = kTRUE); + + virtual InitStatus Init(); + virtual void SetParContainers(); + virtual void Exec(Option_t *opt); + +private: + TString fInputBranchName; + TString fOutputBranchName; + + Bool_t fIsPersistence; //!< Persistence check variable + + std::unique_ptr fFitter; + AtDigiPar *fPar{nullptr}; + TClonesArray *fPatternEventArray; + TClonesArray fTrackingEventArray; + + std::size_t fEventCnt{0}; + + ClassDef(AtFitterTaskOld, 1); +}; + +#endif diff --git a/AtReconstruction/AtReconstructionLinkDef.h b/AtReconstruction/AtReconstructionLinkDef.h index 1526e0531..8e543d0ac 100755 --- a/AtReconstruction/AtReconstructionLinkDef.h +++ b/AtReconstruction/AtReconstructionLinkDef.h @@ -46,20 +46,26 @@ #pragma link C++ namespace kf; #pragma link C++ namespace kf::util; +#pragma link C++ namespace EventFit; +#pragma link C++ class AtFitterTask + ; +#pragma link C++ class EventFit::AtFitter - !; + /* Classes that depend on Genfit2 */ +#pragma link C++ namespace AtFITTER; #pragma link C++ class genfit::AtSpacepointMeasurement + ; -#pragma link C++ class AtFITTER::AtFitter + ; +#pragma link C++ class AtFITTER::AtFitterOld + ; +#pragma link C++ class AtFitterTaskOld + ; #pragma link C++ class AtFITTER::AtGenfit + ; -#pragma link C++ namespace AtFITTER; -#pragma link C++ class AtFitterTask + ; #pragma link C++ namespace MCFitter; #pragma link C++ class MCFitter::AtParameterDistribution - !; #pragma link C++ class MCFitter::AtUniformDistribution - !; #pragma link C++ class MCFitter::AtStudentDistribution - !; #pragma link C++ class MCFitter::AtMCFitter - !; +#pragma link C++ class MCFitter::AtMCFitterOld - !; #pragma link C++ class MCFitter::AtMCFission - !; #pragma link C++ class AtMCFitterTask + ; +#pragma link C++ class AtMCFitterTaskOld + ; /* Tasks in AtReconstruction */ #pragma link C++ class AtPSAtask + ; diff --git a/AtReconstruction/CMakeLists.txt b/AtReconstruction/CMakeLists.txt index a4bc5296b..cc40e6d9d 100755 --- a/AtReconstruction/CMakeLists.txt +++ b/AtReconstruction/CMakeLists.txt @@ -111,9 +111,16 @@ set(SRCS AtFitter/ParameterDistributions/AtUniformDistribution.cxx AtFitter/ParameterDistributions/AtStudentDistribution.cxx + AtFitter/AtFitter.cxx AtFitter/AtMCFitter.cxx AtFitter/AtMCFitterTask.cxx AtFitter/AtMCFission.cxx + AtFitterTask.cxx + + # Deprecated... + AtFitter/AtMCFitterOld.cxx + AtFitter/AtMCFitterTaskOld.cxx + ) @@ -153,10 +160,12 @@ endif() if(GENFIT2_FOUND) set(SRCS ${SRCS} - AtFitter/AtFitter.cxx AtFitter/AtGenfit.cxx AtFitter/AtSpacePointMeasurement.cxx - AtFitterTask.cxx + + # Deprecated, to be removed: + AtFitter/AtFitterOld.cxx + AtFitter/AtFitterTaskOld.cxx ) set(DEPENDENCIES ${DEPENDENCIES} GENFIT2::genfit2 diff --git a/AtTools/AtKinematics.cxx b/AtTools/AtKinematics.cxx index 2431d34f2..d924e5814 100644 --- a/AtTools/AtKinematics.cxx +++ b/AtTools/AtKinematics.cxx @@ -304,4 +304,13 @@ double EtoA(double mass) return mass / 931.5; } +std::tuple GetMomFromBrho(double mass, int Z, double brho) +{ + const Double_t M_Ener = mass * 931.49401; // In MeV + Double_t p = brho * Z * (2.99792458 / 10.0) * 1000.0; // In MeV + Double_t E = TMath::Sqrt(TMath::Power(p, 2) + TMath::Power(M_Ener, 2)) - M_Ener; // In MeV + + return std::make_tuple(p, E); +} + } // namespace AtTools::Kinematics diff --git a/AtTools/AtKinematics.h b/AtTools/AtKinematics.h index b0d343fd2..274dea9b0 100644 --- a/AtTools/AtKinematics.h +++ b/AtTools/AtKinematics.h @@ -73,6 +73,7 @@ double GetBeta(double p, double mass); double GetRelMom(double gamma, double mass); double AtoE(double Amu); double EtoA(double mass); +std::tuple GetMomFromBrho(double mass, int Z, double brho); template ROOT::Math::PxPyPzEVector Get4Vector(Vector mom, double m) diff --git a/AtTools/AtTrackTransformer.cxx b/AtTools/AtTrackTransformer.cxx index fb49ae267..b3c80d8c5 100644 --- a/AtTools/AtTrackTransformer.cxx +++ b/AtTools/AtTrackTransformer.cxx @@ -266,6 +266,26 @@ void AtTools::AtTrackTransformer::ClusterizeSmooth3D(AtTrack &track, Float_t rad } // if array size } +Bool_t AtTools::AtTrackTransformer::FindVertexTrack(AtTrack *trA, AtTrack *trB) +{ + // Determination of first hit distance. NB: Assuming both tracks have the same angle sign + Double_t vertexA = 0.0; + Double_t vertexB = 0.0; + if (trA->GetGeoTheta() * TMath::RadToDeg() < 90) { + auto iniClusterA = trA->GetHitClusterArray()->back(); + auto iniClusterB = trB->GetHitClusterArray()->back(); + vertexA = 1000.0 - iniClusterA.GetPosition().Z(); + vertexB = 1000.0 - iniClusterB.GetPosition().Z(); + } else if (trA->GetGeoTheta() * TMath::RadToDeg() > 90) { + auto iniClusterA = trA->GetHitClusterArray()->front(); + auto iniClusterB = trB->GetHitClusterArray()->front(); + vertexA = iniClusterA.GetPosition().Z(); + vertexB = iniClusterB.GetPosition().Z(); + } + + return vertexA < vertexB; +} + const std::tuple AtTools::AtTrackTransformer::GetPIDFromHits(AtTrack &track, Double_t theta) { @@ -310,3 +330,115 @@ const std::tuple AtTools::AtTrackTransformer::GetPIDFromHits return std::forward_as_tuple(dedx, eloss); } + +Bool_t AtTools::AtTrackTransformer::MergeTracks(std::vector *trackCandSource, + std::vector *trackDest, Bool_t enableSingleVertexTrack, + Double_t clusterRadius, Double_t clusterDistance) +{ + + Bool_t toMerge = kFALSE; + + Int_t addHitCnt = 0; + // Find the track closer to vertex + std::sort(trackCandSource->begin(), trackCandSource->end(), + [this](AtTrack *trA, AtTrack *trB) { return FindVertexTrack(trA, trB); }); + + // Track stitching from vertex + AtTrack *vertexTrack = *trackCandSource->begin(); + + if (enableSingleVertexTrack) { + + // Mark all tracks as merged + for (auto track : *trackCandSource) + track->SetIsMerged(kTRUE); + + trackDest->push_back(*vertexTrack); + return true; + } + + // Check if the candidate vertex track was merged + if (vertexTrack->GetIsMerged()) + return kFALSE; + else + vertexTrack->SetIsMerged(kTRUE); + + // If enabled, choose only the track closest to vertex (i.e. first one of the collection of candidates) + // TODO: Select by number of points + + for (auto it = trackCandSource->begin() + 1; it != trackCandSource->end(); ++it) { + // NB: These tracks were previously marked to merge. If merging fails they should be discarded. + AtTrack *trackToMerge = *(it); + toMerge = kFALSE; + + // Skip trackes flagged as merged + if (!trackToMerge->GetIsMerged()) { + trackToMerge->SetIsMerged(kTRUE); + } else + continue; + + Double_t endVertexZ = 0.0; + Double_t iniMergeZ = 0.0; + std::cout << " Trying to merge ... " + << "\n"; + std::cout << " Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " << trackToMerge->GetTrackID() + << "\n"; + // Check relative position between end and begin of each track using Hit Clusters + std::cout << " Vertex angle " << vertexTrack->GetGeoTheta() * TMath::RadToDeg() << "\n"; + if (vertexTrack->GetGeoTheta() * TMath::RadToDeg() < 90) { + auto endClusterVertex = vertexTrack->GetHitClusterArray()->front(); + auto iniClusterMerge = trackToMerge->GetHitClusterArray()->back(); + // Check separation and relative distance + endVertexZ = 1000.0 - endClusterVertex.GetPosition().Z(); + iniMergeZ = 1000.0 - iniClusterMerge.GetPosition().Z(); + + Double_t distance = std::sqrt((iniClusterMerge.GetPosition() - endClusterVertex.GetPosition()).Mag2()); + // std::cout << " Distance between tracks " << distance << "\n"; + // std::cout << " Ini Merge " << iniMergeZ << " - endVertexZ " << endVertexZ << "\n"; + if (((iniMergeZ + 10.0) > endVertexZ) && distance < 200) { + toMerge = kTRUE; + } + + } else if (vertexTrack->GetGeoTheta() * TMath::RadToDeg() > 90) { + auto endClusterVertex = vertexTrack->GetHitClusterArray()->back(); + auto iniClusterMerge = trackToMerge->GetHitClusterArray()->front(); + // Check separation and relative distance + endVertexZ = endClusterVertex.GetPosition().Z(); + iniMergeZ = iniClusterMerge.GetPosition().Z(); + Double_t distance = std::sqrt((iniClusterMerge.GetPosition() - endClusterVertex.GetPosition()).Mag2()); + // std::cout<<" Distance between tracks "< endVertexZ) && + distance < 100) { // NB: Distance between parts of the backward tracks is more critical + toMerge = kTRUE; + } + } + + if (toMerge) { + + std::cout << " --- Merging Succeeded! Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " + << trackToMerge->GetTrackID() << "\n"; + for (const auto &hit : trackToMerge->GetHitArray()) { + + vertexTrack->AddHit(hit->Clone()); // TODO: Look at code and see if this can be a move instead of a copy + ++addHitCnt; + } + + // Reclusterize after merging + vertexTrack->SortHitArrayTime(); + vertexTrack->ResetHitClusterArray(); + ClusterizeSmooth3D( + *vertexTrack, clusterRadius, + clusterDistance); // NB: It can be removed if we force reclusterization for any track in the mina program + + // TODO: Check if phi recalculatio is needed + + } else { + std::cout << " --- Merging Failed ! Vertex track " << vertexTrack->GetTrackID() << " - Track to Merge " + << trackToMerge->GetTrackID() << "\n"; + } + } + + trackDest->push_back(*vertexTrack); + + return toMerge; +} diff --git a/AtTools/AtTrackTransformer.h b/AtTools/AtTrackTransformer.h index b0c20db3c..acb93b78a 100644 --- a/AtTools/AtTrackTransformer.h +++ b/AtTools/AtTrackTransformer.h @@ -16,6 +16,11 @@ class AtTrackTransformer { void ClusterizeSmooth3D(AtTrack &track, Float_t radius, Float_t distance); const std::tuple GetPIDFromHits(AtTrack &track, Double_t theta); + Bool_t FindVertexTrack(AtTrack *trA, AtTrack *trB); + + Bool_t MergeTracks(std::vector *trackCandSource, std::vector *trackDest, + Bool_t enableSingleVertexTrack, Double_t clusterRadius, Double_t clusterDistance); + private: };