77import numpy as np
88import h5py
99from importlib import import_module
10- from typing import Optional
10+ from typing import Optional , Tuple
11+ from types import ModuleType
1112
1213
1314def setup_path (parser : YAMLParser , args : object ) -> bool :
@@ -93,7 +94,8 @@ def get_boltzmann_spectra(parser: YAMLParser, state: dict, args: dict = {},
9394
9495
def cycle_spectrum_file(parser: YAMLParser, quantity: str,
                        fp: Optional[Dataset], n: int = 0,
                        validation: bool = False) -> Dataset:
    """
    Open the next file that is expected to contain spectrum data for the
    given quantity, creating it if it does not exist yet.

    If no file is currently open (``fp is None``), file number ``n`` is
    opened. Otherwise, the file number is parsed from the name of the
    currently open file, the file with the next number is opened instead,
    and the old file is closed. Validation spectra live in separate files
    tagged with a ``_validation`` suffix in the filename.
    """
    suffix = "_validation" if validation else ""
    base = quantity.replace("/", "_")

    if fp is None:
        # Nothing open yet: start at file number n.
        dataset = Dataset(parser, quantity, f"{base}{suffix}.{n}.hdf5")
    else:
        # The file number is the second dot-separated field of the name,
        # e.g. "Cl_tt_validation.3.hdf5" -> 3.
        next_index = int(fp.filename.split(".")[1]) + 1
        dataset = Dataset(parser, quantity,
                          f"{base}{suffix}.{next_index}.hdf5")
        fp.close()

    dataset.open()
    return dataset
116119
117120
def split_samples(MPI: Optional[ModuleType], parser: "YAMLParser",
                  samples: dict, nsamples: int) -> Tuple:
    """
    Given a set of samples, determine how to split them evenly between the
    MPI processes, and get the range over which this process needs to
    operate.

    Parameters
    ----------
    MPI : module or None
        The ``mpi4py.MPI`` module when running under MPI, else None.
    parser : YAMLParser
        Run settings; only ``max_filesize`` (samples per file) is read.
    samples : dict
        Parameter samples. Only rank 0 needs to provide them; they are
        broadcast to all other ranks here.
    nsamples : int
        Total number of samples to divide over the processes.

    Returns
    -------
    (samples, first, last, first_file, last_file) : tuple
        The (broadcast) samples, the inclusive sample index range
        [first, last] for this rank, and the inclusive file index range
        [first_file, last_file] this rank will write to.
    """
    if MPI:
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        n_tot = comm.Get_size()
    else:
        comm = None
        rank = 0
        n_tot = 1

    # If mpi-ing, share the samples from the root rank here.
    if comm is not None:
        samples = comm.bcast(samples, root=0)

    # Samples handled by each rank. Use the ceiling so trailing samples
    # are not silently dropped: the previous floor of
    # (nsamples - 1) // n_tot could never reach index nsamples - 1.
    cut = int(np.ceil(nsamples / n_tot))
    # Files each rank needs to hold its `cut` samples.
    fcut = int(np.ceil(float(cut) / parser.max_filesize))

    first = rank * cut
    # Clamp so the final rank does not run past the end of the samples
    # when nsamples is not an exact multiple of the number of ranks.
    last = min(nsamples, (rank + 1) * cut) - 1
    first_file = rank * fcut
    last_file = (rank + 1) * fcut - 1

    return samples, first, last, first_file, last_file
156+
157+
def check_open_files(parser: YAMLParser, files: dict, n: int,
                     first_file: int, validation: bool = False
                     ) -> Tuple[dict, list]:
    """
    Make sure an output file is open for every quantity, and collect the
    quantities for which sample ``n`` still has to be computed.

    For each quantity, the first spectrum file is opened if none is open
    yet, and files are cycled forward while the current one is full and
    does not reach sample index ``n``. Any quantity whose open file does
    not already contain an entry for ``n`` is reported as still to be
    computed.
    """
    to_compute = []

    for quantity in parser.quantities:
        handle = files[quantity]
        if handle is None:
            # No file open yet for this quantity: open the first one.
            handle = cycle_spectrum_file(parser, quantity, None,
                                         n=first_file,
                                         validation=validation)
        while handle.empty_size == 0 and handle.indices.max() < n:
            # The current file is full and ends before sample n, so move
            # on to the next file in the sequence.
            handle = cycle_spectrum_file(parser, quantity, handle,
                                         n=first_file,
                                         validation=validation)
        files[quantity] = handle

        if n not in handle.indices:
            to_compute.append(quantity)

    return files, to_compute
177+
178+
118179def generate_spectra (args : list = None ) -> None :
119180 """
120181 Hook for the "generate spectra" command.
@@ -184,29 +245,22 @@ def generate_spectra(args: list = None) -> None:
184245
185246 if rank == 0 :
186247 # TODO: Generate the validation samples.
187- samples , validation_samples = \
188- parser .get_parameter_samples (force_new = args .force )
248+ if parser .nvalidation :
249+ samples , validation_samples = \
250+ parser .get_parameter_samples (force_new = args .force_overwrite ,
251+ return_validation = True )
252+ else :
253+ samples = \
254+ parser .get_parameter_samples (force_new = args .force_overwrite )
255+ validation_samples = {}
189256 else :
190- samples = None
257+ samples , validation_samples = None , None
191258
192- first = 0
193- last = parser .nsamples
194- first_file = 0
195- last_file = (last // parser .max_filesize )
259+ samples , first , last , first_file , last_file = \
260+ split_samples (MPI , parser , samples , parser .nsamples )
196261
197- # If mpi-ing, share the data here.
198- if comm is not None :
199- samples = comm .bcast (samples , root = 0 )
200- # Amount of spectra/files handled by each runner
201- cut = (last - first ) // n_tot
202- fcut = int (np .ceil (float (cut ) / parser .max_filesize ))
203-
204- first , last = first + (rank ) * cut , first + (rank + 1 ) * cut - 1
205- first_file , last_file = (first_file + (rank ) * fcut ,
206- first_file + (rank + 1 ) * fcut - 1 )
207-
208- print (f"[{ rank } ]: Iterating over samples { first } --{ last } in files \
209- { first_file } --{ last_file } ." )
262+ print (f"[{ rank } ]: Iterating over samples { first } --{ last } in files " \
263+ f"{ first_file } --{ last_file } ." )
210264
211265 state = init_boltzmann_code (parser )
212266 extra_args = parser .boltzmann_extra_args
@@ -221,20 +275,9 @@ def generate_spectra(args: list = None) -> None:
221275 + f"{ accepted / n :.1%} success rate" )
222276
223277 boltzmann_params = {k : samples [k ][n ] for k in parser .boltzmann_inputs }
224- quantities_to_be_computed = []
225-
226- for q in parser .quantities :
227- if files [q ] is None :
228- # Open first file to read from.
229- files [q ] = cycle_spectrum_file (parser , q , files [q ],
230- n = first_file )
231- while files [q ].empty_size == 0 and files [q ].indices .max () < n :
232- # Current file is full, so we have to cycle to the next one.
233- files [q ] = cycle_spectrum_file (parser , q , files [q ],
234- n = first_file )
235278
236- if n not in files [ q ]. indices :
237- quantities_to_be_computed . append ( q )
279+ files , quantities_to_be_computed = \
280+ check_open_files ( parser , files , n , first_file , validation = False )
238281
239282 if len (quantities_to_be_computed ) == 0 :
240283 accepted += 1
@@ -254,10 +297,79 @@ def generate_spectra(args: list = None) -> None:
254297 else :
255298 spec = state .get (q , None )
256299
300+ if spec is None :
301+ continue
302+
303+ if parser .is_log (q ):
304+ spec = np .log10 (spec )
305+
306+ if np .any (np .isnan (spec )):
307+ continue
308+
257309 network_params = np .array ([
258310 samples [k ][n ] for k in parser .network_input_parameters (q )
259311 ])
260312
313+ if files [q ] is None :
314+ files [q ] = cycle_spectrum_file (parser , q , files [q ],
315+ n = first_file ,
316+ validation = False )
317+
318+ files [q ].write_data (n , network_params , spec )
319+
320+ for q in files :
321+ if files [q ] is not None and files [q ].is_open :
322+ files [q ].close ()
323+ files [q ] = None
324+
325+ if validation_samples == {}:
326+ if rank == 0 :
327+ print (f"Finished generating { accepted } spectra." )
328+ print (f"You can now run\n \t cosmopower train { args .yamlfile } \n " \
329+ "to train the networks on this dataset." )
330+ return
331+
332+ # Do the exact same thing all over again, but for validation samples.
333+ validation_samples , first , last , first_file , last_file = \
334+ split_samples (MPI , parser , validation_samples , parser .nvalidation )
335+
336+ print (f"[{ rank } ]: Iterating over validation samples { first } --{ last } in " \
337+ f"files { first_file } --{ last_file } ." )
338+
339+ accepted = 0
340+ tbar = tqdm .tqdm (np .arange (first , last + 1 ))
341+
342+ Barrier ()
343+
344+ for n in tbar :
345+ tbar .set_description (("" if MPI is None else f"[{ rank } ] " )
346+ + f"{ accepted / n :.1%} success rate" )
347+
348+ boltzmann_params = {
349+ k : validation_samples [k ][n ] for k in parser .boltzmann_inputs
350+ }
351+
352+ files , quantities_to_be_computed = \
353+ check_open_files (parser , files , n , first_file , validation = True )
354+
355+ if len (quantities_to_be_computed ) == 0 :
356+ accepted += 1
357+ continue
358+
359+ if get_boltzmann_spectra (parser , state , boltzmann_params ,
360+ quantities_to_be_computed , extra_args ):
361+ accepted += 1
362+
363+ for k in state ["derived" ]:
364+ validation_samples [k ][n ] = state ["derived" ][k ]
365+
366+ for q in quantities_to_be_computed :
367+ if q == "derived" :
368+ spec = np .asarray ([state ["derived" ].get (p )
369+ for p in parser .computed_parameters ])
370+ else :
371+ spec = state .get (q , None )
372+
261373 if spec is None :
262374 continue
263375
@@ -267,17 +379,18 @@ def generate_spectra(args: list = None) -> None:
267379 if np .any (np .isnan (spec )):
268380 continue
269381
382+ network_params = np .array ([
383+ validation_samples [k ][n ]
384+ for k in parser .network_input_parameters (q )
385+ ])
386+
270387 if files [q ] is None :
271388 files [q ] = cycle_spectrum_file (parser , q , files [q ],
272- n = first_file )
389+ n = first_file ,
390+ validation = True )
273391
274392 files [q ].write_data (n , network_params , spec )
275393
276394 for q in files :
277395 if files [q ] is not None and files [q ].is_open :
278396 files [q ].close ()
279-
280- if rank == 0 :
281- print (f"Finished generating { accepted } spectra." )
282- print (f"You can now run\n \t cosmopower train { args .yamlfile } \n \
283- to train the networks on this dataset." )