1717
1818
1919@njit
20- def dict_to_points (points : NumbaTypedDict [int , int ]) -> np .ndarray :
20+ def dict_to_points (
21+ points : NumbaTypedDict [int , tuple [int , int ]],
22+ ) -> tuple [np .ndarray , np .ndarray ]:
2123 """
2224 Converts dictionary of N pad,tb keys with corresponding number of electrons
2325 to Nx3 array where each row is [pad, tb, e], now combined over pad/tb combos.
@@ -29,17 +31,19 @@ def dict_to_points(points: NumbaTypedDict[int, int]) -> np.ndarray:
2931
3032 Returns
3133 -------
32- point_array: numpy.ndarray
33- Array of points.
34+ tuple[numpy.ndarray, numpy.ndarray]
35+ Array of points and labels (in that order)
3436 """
3537 point_array = np .empty ((len (points ), 3 ), dtype = float )
36- for idx , point in enumerate (points .items ()):
37- tb , pad = unpair (point [0 ])
38+ label_array = np .empty (len (points ), dtype = types .int64 )
39+ for idx , (key , data ) in enumerate (points .items ()):
40+ tb , pad = unpair (key )
3841 point_array [idx , 0 ] = pad
3942 point_array [idx , 1 ] = tb
40- point_array [idx , 2 ] = point [1 ]
43+ point_array [idx , 2 ] = data [0 ]
44+ label_array [idx ] = data [1 ]
4145
42- return point_array
46+ return point_array , label_array
4347
4448
4549def simulate (
@@ -49,38 +53,31 @@ def simulate(
4953 mass_numbers : np .ndarray ,
5054 config : Config ,
5155 rng : Generator ,
52- indicies : list [int ] | None ,
53- ) -> np .ndarray :
54- nuclei_to_sim = None
55- if indicies is not None :
56- nuclei_to_sim = indicies
57- else :
58- # default nuclei to sim, all final outgoing particles
59- nuclei_to_sim = [idx for idx in range (2 , len (proton_numbers ), 2 )]
60- nuclei_to_sim .append (len (proton_numbers ) - 1 ) # add the last
61-
62- points = Dict .empty (key_type = types .int64 , value_type = types .int64 )
63-
64- for idx in nuclei_to_sim :
56+ indicies : list [int ],
57+ ) -> tuple [np .ndarray , np .ndarray ]:
58+ points = Dict .empty (
59+ key_type = types .int64 , value_type = types .Tuple (types = [types .int64 , types .int64 ])
60+ )
61+ for idx in indicies :
6562 if proton_numbers [idx ] == 0 :
6663 continue
6764 nucleus = nuclear_map .get_data (proton_numbers [idx ], mass_numbers [idx ])
6865 momentum = momenta [idx ]
69- generate_point_cloud (momentum , vertex , nucleus , config , rng , points )
66+ generate_point_cloud (momentum , vertex , nucleus , config , rng , points , idx )
7067
7168 # Convert to numpy array of [pad, tb, e], now combined over pad/tb combos
72- point_array = dict_to_points (points )
69+ point_array , label_array = dict_to_points (points )
7370
7471 # Wiggle point TBs over interval [0.0, 1.0). This simulates effect of converting
7572 # the (in principle) int TBs to floats.
7673 point_array [:, 1 ] += rng .uniform (low = 0.0 , high = 1.0 , size = len (point_array ))
7774
7875 # Remove points outside legal bounds in time. TODO check if this is needed
79- point_array = point_array [
80- np . logical_and ( 0 < = point_array [:, 1 ], point_array [:, 1 ] < NUM_TB )
81- ]
76+ mask = np . logical_and ( 0 <= point_array [:, 1 ], point_array [:, 1 ] < NUM_TB )
77+ point_array = point_array [mask ]
78+ label_array = label_array [ mask ]
8279
83- return point_array
80+ return point_array , label_array
8481
8582
8683def run_simulation (
@@ -107,9 +104,18 @@ def run_simulation(
107104 print (f"Applying detector effects to kinematics from file: { input_path } " )
108105 input = h5 .File (input_path , "r" )
109106 input_data_group : h5 .Group = input ["data" ] # type: ignore
110- proton_numbers = input_data_group .attrs ["proton_numbers" ]
107+ proton_numbers : np . ndarray = input_data_group .attrs ["proton_numbers" ] # type: ignore
111108 mass_numbers = input_data_group .attrs ["mass_numbers" ]
112109
110+ # Decide which nuclei to sim, either by user input or all reaction final products
111+ nuclei_to_sim = None
112+ if indicies is not None :
113+ nuclei_to_sim = indicies
114+ else :
115+ # default nuclei to sim, all final outgoing particles
116+ nuclei_to_sim = [idx for idx in range (2 , len (proton_numbers ), 2 )]
117+ nuclei_to_sim .append (len (proton_numbers ) - 1 ) # add the last
118+
113119 n_events : int = input_data_group .attrs ["n_events" ] # type: ignore
114120 miniters = int (0.01 * n_events )
115121 n_chunks : int = input_data_group .attrs ["n_chunks" ] # type: ignore
@@ -138,7 +144,7 @@ def run_simulation(
138144 dataset : h5 .Dataset = input_data_group [f"chunk_{ chunk } " ][ # type: ignore
139145 f"event_{ event_number } "
140146 ] # type: ignore
141- cloud = simulate (
147+ cloud , labels = simulate (
142148 dataset [:].copy (), # type: ignore
143149 np .array (
144150 [
@@ -151,13 +157,13 @@ def run_simulation(
151157 mass_numbers , # type: ignore
152158 config ,
153159 rng ,
154- indicies ,
160+ nuclei_to_sim ,
155161 )
156162
157163 if len (cloud ) == 0 :
158164 continue
159165
160- writer .write (cloud , config , event_number )
166+ writer .write (cloud , labels , config , event_number )
161167 writer .close ()
162168 print ("Done." )
163169 print ("----------------------------------------" )
0 commit comments