Skip to content

Commit 00d78ef

Browse files
skudithaUditha Weerasinghe
authored andcommitted
feat: ALERT AI Pre PID (#1112)
* modified recobankwriter with dummy functions * added a call for AI PID --currently in testing mode * commented lines out for testing * commented lines out for testing * modified to include PrePID similar to TrackMatching * modified to include PrePID similar to TrackMatching * created PrePIDResult.java * created ModelPrePID.java * added ALERT PID bank to bankdefs * added ALERT PrePID bank to bankdefs * cleaning up * cleaning up * fix: added probabilities for prepid * fixed a typo in alert.json * fix: changed all prints to logger + minor cleanup * fix: added info to alert.json for AI PID bank * major change: switched from ATOF::clusters to ATOF::hits --------- Co-authored-by: Uditha Weerasinghe <skuditha@jlab.org>
1 parent d801dcc commit 00d78ef

File tree

5 files changed

+275
-1
lines changed

5 files changed

+275
-1
lines changed

etc/bankdefs/hipo4/alert.json

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,38 @@
6565
{"name": "trackid", "type": "I", "info": "track id"},
6666
{"name": "matched_atof_hit_id", "type": "I", "info": "id of the matched ATOF hit, -1 if no hit was matched"}
6767
]
68+
},
69+
{
70+
"name": "ALERT::ai:prepid",
71+
"group": 23000,
72+
"item": 33,
73+
"info": "ALERT AI-assisted PrePID to be used for Kalman Filter",
74+
"entries": [
75+
{"name":"trackid", "type":"I", "info":"AHDC trackid"},
76+
{"name":"clusterid", "type":"I", "info":"ATOF cluster id"},
77+
{"name":"prepid", "type":"I", "info":"argmax PID"},
78+
{"name":"p2212", "type":"F", "info":"P(pid=2212)"},
79+
{"name":"p45", "type":"F", "info":"P(pid=45)"},
80+
{"name":"p46", "type":"F", "info":"P(pid=46)"},
81+
{"name":"p47", "type":"F", "info":"P(pid=47)"},
82+
{"name":"p49", "type":"F", "info":"P(pid=49)"}
83+
]
84+
},
85+
{
86+
"name": "ALERT::ai:pid",
87+
"group": 23000,
88+
"item": 34,
89+
"info": "AI-assisted PID for ALERT",
90+
"entries": [
91+
{"name":"trackid", "type":"I", "info":"AHDC trackid"},
92+
{"name":"clusterid", "type":"I", "info":"ATOF cluster id"},
93+
{"name":"pid", "type":"I", "info":"argmax PID"},
94+
{"name":"prob_2212", "type":"F", "info":"P(pid=2212)"},
95+
{"name":"prob_45", "type":"F", "info":"P(pid=45)"},
96+
{"name":"prob_46", "type":"F", "info":"P(pid=46)"},
97+
{"name":"prob_47", "type":"F", "info":"P(pid=47)"},
98+
{"name":"prob_49", "type":"F", "info":"P(pid=49)"}
99+
]
68100
},
69101
{
70102
"name": "ATOF::hits",
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package org.jlab.rec.alert.AIPID;
2+
3+
import ai.djl.MalformedModelException;
4+
import ai.djl.inference.Predictor;
5+
import ai.djl.ndarray.NDArray;
6+
import ai.djl.ndarray.NDList;
7+
import ai.djl.ndarray.NDManager;
8+
import ai.djl.ndarray.types.Shape;
9+
import ai.djl.repository.zoo.Criteria;
10+
import ai.djl.repository.zoo.ModelNotFoundException;
11+
import ai.djl.repository.zoo.ZooModel;
12+
import ai.djl.training.util.ProgressBar;
13+
import ai.djl.translate.TranslateException;
14+
import ai.djl.translate.Translator;
15+
import ai.djl.translate.TranslatorContext;
16+
17+
import org.jlab.utils.CLASResources;
18+
19+
import java.io.IOException;
20+
import java.nio.file.Paths;
21+
import java.util.logging.Logger;
22+
23+
public class ModelPrePID {
24+
25+
static final Logger LOGGER = Logger.getLogger(ModelPrePID.class.getName());
26+
// Must match training class order
27+
private static final int[] CLASS_IDS = new int[]{2212, 45, 46, 47, 49};
28+
29+
private final ZooModel<float[], float[]> model;
30+
31+
public ModelPrePID() {
32+
33+
Translator<float[], float[]> my_translator = new Translator<>() {
34+
35+
@Override
36+
public NDList processInput(TranslatorContext ctx, float[] floats) {
37+
NDManager manager = ctx.getNDManager();
38+
39+
// IMPORTANT: model expects (batch, 23). Provide (1, 23).
40+
NDArray x = manager.create(floats, new Shape(1, 23));
41+
return new NDList(x);
42+
}
43+
44+
@Override
45+
public float[] processOutput(TranslatorContext ctx, NDList ndList) {
46+
NDArray logits = ndList.get(0); // (1,5)
47+
NDArray probs = logits.softmax(1); // (1,5)
48+
49+
float[] p = probs.toFloatArray(); // length 5 (row-major)
50+
51+
// argmax
52+
int bestIdx = 0;
53+
float best = p[0];
54+
for (int k = 1; k < 5; k++) {
55+
if (p[k] > best) { best = p[k]; bestIdx = k; }
56+
}
57+
int prepid = CLASS_IDS[bestIdx];
58+
59+
// Return: prepid + probabilities in fixed class order
60+
return new float[]{
61+
(float) prepid,
62+
p[0], p[1], p[2], p[3], p[4]
63+
};
64+
}
65+
};
66+
67+
System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
68+
System.setProperty("ai.djl.pytorch.num_threads", "1");
69+
System.setProperty("ai.djl.pytorch.graph_optimizer", "false");
70+
71+
String path = CLASResources.getResourcePath("etc/data/nnet/rg-l/model_PrePID/");
72+
73+
Criteria<float[], float[]> criteria = Criteria.builder()
74+
.setTypes(float[].class, float[].class)
75+
.optModelPath(Paths.get(path))
76+
.optEngine("PyTorch")
77+
.optTranslator(my_translator)
78+
.optProgress(new ProgressBar())
79+
.build();
80+
81+
try {
82+
model = criteria.loadModel();
83+
} catch (IOException | ModelNotFoundException | MalformedModelException e) {
84+
throw new RuntimeException(e);
85+
}
86+
}
87+
88+
public ZooModel<float[], float[]> getModel() {
89+
return model;
90+
}
91+
92+
/** Returns float[]{prepid} where prepid in {2212,45,46,47,49}.
93+
* @param features23
94+
* @return
95+
* @throws ai.djl.translate.TranslateException */
96+
public float[] prediction(float[] features23) throws TranslateException {
97+
if (features23 == null || features23.length != 23) {
98+
LOGGER.warning("PrePID input must be float[23]");
99+
return null;
100+
}
101+
Predictor<float[], float[]> predictor = model.newPredictor();
102+
return predictor.predict(features23);
103+
}
104+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package org.jlab.rec.alert.AIPID;
2+
3+
public class PrePIDResult {
4+
public final int trackid;
5+
public final int clusterid;
6+
public final int prepid;
7+
public final float p2212, p45, p46, p47, p49;
8+
9+
public PrePIDResult(int trackid, int clusterid, int prepid, float p2212, float p45, float p46, float p47, float p49) {
10+
this.trackid = trackid;
11+
this.clusterid = clusterid;
12+
this.prepid = prepid;
13+
this.p2212 = p2212;
14+
this.p45 = p45;
15+
this.p46 = p46;
16+
this.p47 = p47;
17+
this.p49 = p49;
18+
}
19+
}

reconstruction/alert/src/main/java/org/jlab/rec/alert/banks/RecoBankWriter.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package org.jlab.rec.alert.banks;
22

33
import java.util.ArrayList;
4+
import java.util.List;
45
import org.jlab.io.base.DataBank;
56
import org.jlab.io.base.DataEvent;
67
import org.jlab.rec.alert.projections.TrackProjection;
8+
//import org.jlab.rec.alert.AIpid.PIDResult;
79

810
import ai.djl.util.Pair;
911

@@ -89,6 +91,31 @@ public int appendTrackMatchingAIBank(DataEvent event, ArrayList<Pair<Integer, In
8991

9092
return 0;
9193
}
94+
95+
public int appendPrePIDBank(DataEvent event, ArrayList<org.jlab.rec.alert.AIPID.PrePIDResult> results) {
96+
97+
DataBank bank = event.createBank("ALERT::ai:prepid", results.size());
98+
if (bank == null) {
99+
System.err.println("COULD NOT CREATE A ALERT::ai:prepid BANK!!!!!!");
100+
return 1;
101+
}
102+
103+
for (int i = 0; i < results.size(); i++) {
104+
org.jlab.rec.alert.AIPID.PrePIDResult r = results.get(i);
105+
bank.setInt("trackid", i, r.trackid);
106+
bank.setInt("clusterid", i, r.clusterid);
107+
bank.setInt("prepid", i, r.prepid);
108+
bank.setFloat("p2212", i, r.p2212);
109+
bank.setFloat("p45", i, r.p45);
110+
bank.setFloat("p46", i, r.p46);
111+
bank.setFloat("p47", i, r.p47);
112+
bank.setFloat("p49", i, r.p49);
113+
}
114+
115+
event.appendBank(bank);
116+
return 0;
117+
}
118+
92119

93120
/**
94121
* @param args the command line arguments

reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package org.jlab.service.alert;
22

3+
import ai.djl.repository.zoo.ZooModel;
4+
import ai.djl.translate.TranslateException;
35
import java.io.File;
46
import java.util.ArrayList;
57
import java.util.concurrent.atomic.AtomicInteger;
@@ -14,6 +16,7 @@
1416
import org.jlab.io.hipo.HipoDataSource;
1517
import org.jlab.io.hipo.HipoDataSync;
1618
import org.jlab.rec.alert.TrackMatchingAI.ModelTrackMatching;
19+
import org.jlab.rec.alert.AIPID.ModelPrePID;
1720
import org.jlab.rec.alert.banks.RecoBankWriter;
1821
import org.jlab.rec.alert.projections.TrackProjector;
1922
import org.jlab.rec.atof.hit.ATOFHit;
@@ -24,10 +27,12 @@
2427
import org.jlab.rec.ahdc.Track.Track;
2528
import org.jlab.clas.pdg.PDGDatabase;
2629
import org.jlab.clas.pdg.PDGParticle;
30+
import java.util.logging.Logger;
2731

2832

2933

3034
import ai.djl.util.Pair;
35+
import org.jlab.rec.alert.AIPID.PrePIDResult;
3136

3237

3338
/**
@@ -51,7 +56,7 @@ public class ALERTEngine extends ReconstructionEngine {
5156
*
5257
*/
5358
private RecoBankWriter rbc;
54-
59+
static final Logger LOGGER = Logger.getLogger(ModelPrePID.class.getName());
5560
Detector ATOF; // ALERT ATOF detector
5661
private AlertDCDetector AHDC; // ALERT AHDC detector
5762

@@ -64,6 +69,7 @@ public class ALERTEngine extends ReconstructionEngine {
6469
private double b; //Magnetic field
6570

6671
private ModelTrackMatching modelTrackMatching;
72+
private ModelPrePID modelPrePID;
6773

6874
public void setB(double B) {
6975
this.b = B;
@@ -90,6 +96,7 @@ public boolean init() {
9096
rbc = new RecoBankWriter();
9197

9298
modelTrackMatching = new ModelTrackMatching();
99+
modelPrePID = new ModelPrePID();
93100

94101
AlertTOFFactory factory = new AlertTOFFactory();
95102
DatabaseConstantProvider cp = new DatabaseConstantProvider(11, "default");
@@ -222,7 +229,92 @@ public boolean processDataEvent(DataEvent event) {
222229
}
223230
}
224231
rbc.appendTrackMatchingAIBank(event, matched_ATOF_hit_id);
232+
233+
// ---------------------------------------------------------------------------------------
234+
// PrePID using AI (AHDC::track + ATOF::clusters matched via ALERT::ai:projections)
235+
// ---------------------------------------------------------------------------------------
236+
if (event.hasBank("ALERT::ai:projections") && event.hasBank("AHDC::track") && event.hasBank("ATOF::hits")) {
237+
238+
DataBank bankProj = event.getBank("ALERT::ai:projections");
239+
DataBank bankTrk = event.getBank("AHDC::track");
240+
DataBank bankHit = event.getBank("ATOF::hits");
241+
242+
ArrayList<PrePIDResult> prepid_results = new ArrayList<>();
243+
244+
for (int i = 0; i < bankProj.rows(); i++) {
245+
246+
int trackid = bankProj.getInt("trackid", i);
247+
int hitid = bankProj.getInt("matched_atof_hit_id", i); // TODO: Fix to hit_id instead of clusterid
248+
249+
// TODO: refactor this to replace this with single line
250+
int trkRow = -1;
251+
for (int r = 0; r < bankTrk.rows(); r++) {
252+
if (bankTrk.getInt("trackid", r) == trackid) { trkRow = r; break; }
253+
}
254+
if (trkRow < 0) continue;
225255

256+
int hitRow = -1;
257+
for (int r = 0; r < bankHit.rows(); r++) {
258+
if (bankHit.getInt("id", r) == hitid) { hitRow = r; break; }
259+
}
260+
if (hitRow < 0) continue;
261+
262+
// Build feature vector float[23] in the exact training order
263+
float[] x = new float[23];
264+
265+
// AHDC::track (13)
266+
x[0] = bankTrk.getFloat("x", trkRow);
267+
x[1] = bankTrk.getFloat("y", trkRow);
268+
x[2] = bankTrk.getFloat("z", trkRow);
269+
x[3] = bankTrk.getFloat("px", trkRow);
270+
x[4] = bankTrk.getFloat("py", trkRow);
271+
x[5] = bankTrk.getFloat("pz", trkRow);
272+
x[6] = bankTrk.getInt("n_hits", trkRow);
273+
x[7] = bankTrk.getInt("sum_adc", trkRow);
274+
x[8] = bankTrk.getFloat("path", trkRow);
275+
x[9] = bankTrk.getFloat("dEdx", trkRow);
276+
x[10] = bankTrk.getFloat("p_drift", trkRow);
277+
x[11] = bankTrk.getFloat("chi2", trkRow);
278+
x[12] = bankTrk.getFloat("sum_residuals", trkRow);
279+
280+
/*// ATOF::clusters (10)
281+
x[13] = bankClu.getInt("n_bar", cluRow);
282+
x[14] = bankClu.getInt("n_wedge", cluRow);
283+
x[15] = bankClu.getFloat("time", cluRow);
284+
x[16] = bankClu.getFloat("x", cluRow);
285+
x[17] = bankClu.getFloat("y", cluRow);
286+
x[18] = bankClu.getFloat("z", cluRow);
287+
x[19] = bankClu.getFloat("energy", cluRow);
288+
x[20] = bankClu.getFloat("pathlength", cluRow);
289+
x[21] = bankClu.getFloat("inpathlength", cluRow);
290+
x[22] = bankClu.getInt("projID", cluRow);*/
291+
292+
// ATOF::Hits (Temporarily updating to the same 10 slots as ATOF Clusters would have if it worked)
293+
x[13] = 0f;
294+
x[14] = 0f;
295+
x[15] = bankHit.getFloat("time", hitRow);
296+
x[16] = bankHit.getFloat("x", hitRow);
297+
x[17] = bankHit.getFloat("y", hitRow);
298+
x[18] = bankHit.getFloat("z", hitRow);
299+
x[19] = bankHit.getFloat("energy", hitRow);
300+
x[20] = 0f;
301+
x[21] = 0f;
302+
x[22] = 0f;
303+
304+
try {
305+
float[] pred = modelPrePID.prediction(x);
306+
int prepid = (int) pred[0];
307+
prepid_results.add(new PrePIDResult(trackid, hitid, prepid, pred[1], pred[2], pred[3], pred[4], pred[5]));
308+
} catch (TranslateException ex) {
309+
LOGGER.warning(() -> "Exception in ALERTEngine PrePID: " + ex);
310+
}
311+
}
312+
313+
rbc.appendPrePIDBank(event, prepid_results);
314+
}
315+
316+
317+
226318
///////////////////////////////////////////
227319
/// Kalmam Filter
228320
/// ///////////////////////////////////////

0 commit comments

Comments
 (0)