Skip to content

Commit 8e1ecd1

Browse files
committed
Updated the spectra generation script to include validation data.
1 parent d482143 commit 8e1ecd1

5 files changed

Lines changed: 184 additions & 47 deletions

File tree

cosmopower/examples/2_create_spectra.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
tqdm = lambda x: x # noqa: E731
1515

1616
"""
17-
In the previous file, we generated an LHC and saved it to the file
17+
In the previous tutorial, we generated an LHC and saved it to the file
1818
`example/spectra/parameters.hdf5`. Now we'll generate the spectra associated
1919
with this dataset.
20+
21+
The result of this file will be several training spectra. You can compare the
22+
result of this tutorial with invoking the command
23+
24+
python -m cosmopower generate example.yaml
25+
26+
which generates both the LHC from tutorial 1 and the spectra from this
27+
tutorial.
2028
"""
2129
parser = YAMLParser("example.yaml")
2230

cosmopower/examples/3_train_emulator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
"""
1212
In the previous file, we created a training dataset for a linear P(k,z)
1313
emulator. Here, we will train the emulator over this dataset.
14+
15+
The result of this file will be an emulator trained over the spectra generated
16+
before. You can compare the result from this file with the results from
17+
invoking the command
18+
19+
python -m cosmopower train example.yaml
20+
21+
which loads the data, initializes the emulators, and trains them.
1422
"""
1523
parser = YAMLParser("example.yaml")
1624

@@ -42,7 +50,7 @@
4250
trainable=True,
4351
**settings.get("n_traits", {}))
4452

45-
with tf.device("/device:CPU:0"):
53+
with tf.device(None):
4654
network.train(training_data=datasets,
4755
filename_saved_model=output_file,
4856
validation=validation,

cosmopower/examples/example.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ emulated_code:
1414

1515
samples:
1616
# How many training and validation samples do we want to generate.
17-
Ntraining: 1000
18-
Nvalidation: 100
17+
Ntraining: 400
18+
Nvalidation: 25
1919

2020
# The parameters of the LHC over which the samples are generated.
2121
parameters:

cosmopower/spectra.py

Lines changed: 156 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
import h5py
99
from importlib import import_module
10-
from typing import Optional
10+
from typing import Optional, Tuple
11+
from types import ModuleType
1112

1213

1314
def setup_path(parser: YAMLParser, args: object) -> bool:
@@ -93,7 +94,8 @@ def get_boltzmann_spectra(parser: YAMLParser, state: dict, args: dict = {},
9394

9495

9596
def cycle_spectrum_file(parser: YAMLParser, quantity: str,
96-
fp: Optional[Dataset], n: int = 0) -> Dataset:
97+
fp: Optional[Dataset], n: int = 0,
98+
validation: bool = False) -> Dataset:
9799
"""
98100
Cycle the given spectrum file, i.e. open the next file we expect will
99101
contain spectrum data.
@@ -103,18 +105,77 @@ def cycle_spectrum_file(parser: YAMLParser, quantity: str,
103105
can contain spectra for quantity.
104106
If the resulting file does not exist, it will automatically create one.
105107
"""
108+
suffix = "_validation" if validation else ""
106109
if fp is None:
107110
dataset = Dataset(parser, quantity,
108-
quantity.replace("/", "_") + f".{n}.hdf5")
111+
quantity.replace("/", "_") + f"{suffix}.{n}.hdf5")
109112
else:
110113
i = int(fp.filename.split(".")[1]) + 1
111114
dataset = Dataset(parser, quantity,
112-
quantity.replace("/", "_") + f".{i}.hdf5")
115+
quantity.replace("/", "_") + f"{suffix}.{i}.hdf5")
113116
fp.close()
114117
dataset.open()
115118
return dataset
116119

117120

121+
def split_samples(MPI: Optional[ModuleType], parser: YAMLParser, samples: dict,
                  nsamples: int) -> Tuple:
    """
    Given a set of samples, check how to split them evenly between the MPI
    processes, and get the range over which this process needs to operate.

    Parameters
    ----------
    MPI: the mpi4py ``MPI`` module, or None when running serially.
    parser: YAMLParser providing ``max_filesize`` (samples per output file).
    samples: the parameter samples; only rank 0 needs the real dict, it is
        broadcast to all other ranks here.
    nsamples: total number of samples to divide over all processes.

    Returns
    -------
    (samples, first, last, first_file, last_file): the (broadcast) samples,
    the inclusive sample-index range for this process, and the inclusive
    file-index range this process may write to.
    """
    # Serial defaults: this process handles every sample and file.
    first, last = 0, nsamples - 1
    first_file, last_file = 0, last // parser.max_filesize

    if MPI is not None:
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        n_tot = comm.Get_size()

        # Rank 0 generated (or loaded) the samples; share with everyone.
        samples = comm.bcast(samples, root=0)

        # Samples per rank, rounded up so that the union of all per-rank
        # ranges covers 0..nsamples-1 (the old floor-based split dropped
        # the trailing samples; with a single rank it lost the last one).
        chunk = int(np.ceil(nsamples / n_tot))
        # File indices reserved per rank, so no two ranks ever write to
        # the same file.
        fcut = int(np.ceil(chunk / parser.max_filesize))

        first = rank * chunk
        # The last rank may get fewer samples than `chunk`.
        last = min(first + chunk - 1, nsamples - 1)
        first_file = rank * fcut
        last_file = first_file + fcut - 1

    return samples, first, last, first_file, last_file
156+
157+
158+
def check_open_files(parser: YAMLParser, files: dict, n: int,
                     first_file: int, validation: bool = False
                     ) -> Tuple[dict, list]:
    """
    Ensure an output file that can hold sample ``n`` is open for every
    quantity, and report which quantities still need to be computed.

    For each quantity the currently open file (if any) is cycled forward
    while it is full and only covers indices below ``n``. A quantity is
    scheduled for computation when index ``n`` is not yet stored in its
    file.

    Returns the (possibly updated) ``files`` mapping together with the
    list of quantities for which sample ``n`` must still be computed.
    """
    pending = []

    for quantity in parser.quantities:
        current = files[quantity]

        if current is None:
            # Nothing open yet for this quantity: open the first file.
            current = cycle_spectrum_file(parser, quantity, None,
                                          n=first_file,
                                          validation=validation)

        # Cycle onward while the open file is full and does not reach n.
        while current.empty_size == 0 and current.indices.max() < n:
            current = cycle_spectrum_file(parser, quantity, current,
                                          n=first_file,
                                          validation=validation)

        files[quantity] = current

        if n not in current.indices:
            pending.append(quantity)

    return files, pending
177+
178+
118179
def generate_spectra(args: list = None) -> None:
119180
"""
120181
Hook for the "generate spectra" command.
@@ -184,29 +245,22 @@ def generate_spectra(args: list = None) -> None:
184245

185246
if rank == 0:
186247
# TODO: Generate the validation samples.
187-
samples, validation_samples = \
188-
parser.get_parameter_samples(force_new = args.force)
248+
if parser.nvalidation:
249+
samples, validation_samples = \
250+
parser.get_parameter_samples(force_new=args.force_overwrite,
251+
return_validation=True)
252+
else:
253+
samples = \
254+
parser.get_parameter_samples(force_new=args.force_overwrite)
255+
validation_samples = {}
189256
else:
190-
samples = None
257+
samples, validation_samples = None, None
191258

192-
first = 0
193-
last = parser.nsamples
194-
first_file = 0
195-
last_file = (last // parser.max_filesize)
259+
samples, first, last, first_file, last_file = \
260+
split_samples(MPI, parser, samples, parser.nsamples)
196261

197-
# If mpi-ing, share the data here.
198-
if comm is not None:
199-
samples = comm.bcast(samples, root=0)
200-
# Amount of spectra/files handled by each runner
201-
cut = (last - first) // n_tot
202-
fcut = int(np.ceil(float(cut) / parser.max_filesize))
203-
204-
first, last = first + (rank) * cut, first + (rank + 1) * cut - 1
205-
first_file, last_file = (first_file + (rank) * fcut,
206-
first_file + (rank + 1) * fcut - 1)
207-
208-
print(f"[{rank}]: Iterating over samples {first}--{last} in files \
209-
{first_file}--{last_file}.")
262+
print(f"[{rank}]: Iterating over samples {first}--{last} in files " \
263+
f"{first_file}--{last_file}.")
210264

211265
state = init_boltzmann_code(parser)
212266
extra_args = parser.boltzmann_extra_args
@@ -221,20 +275,9 @@ def generate_spectra(args: list = None) -> None:
221275
+ f"{accepted/n:.1%} success rate")
222276

223277
boltzmann_params = {k: samples[k][n] for k in parser.boltzmann_inputs}
224-
quantities_to_be_computed = []
225-
226-
for q in parser.quantities:
227-
if files[q] is None:
228-
# Open first file to read from.
229-
files[q] = cycle_spectrum_file(parser, q, files[q],
230-
n=first_file)
231-
while files[q].empty_size == 0 and files[q].indices.max() < n:
232-
# Current file is full, so we have to cycle to the next one.
233-
files[q] = cycle_spectrum_file(parser, q, files[q],
234-
n=first_file)
235278

236-
if n not in files[q].indices:
237-
quantities_to_be_computed.append(q)
279+
files, quantities_to_be_computed = \
280+
check_open_files(parser, files, n, first_file,validation=False)
238281

239282
if len(quantities_to_be_computed) == 0:
240283
accepted += 1
@@ -254,10 +297,79 @@ def generate_spectra(args: list = None) -> None:
254297
else:
255298
spec = state.get(q, None)
256299

300+
if spec is None:
301+
continue
302+
303+
if parser.is_log(q):
304+
spec = np.log10(spec)
305+
306+
if np.any(np.isnan(spec)):
307+
continue
308+
257309
network_params = np.array([
258310
samples[k][n] for k in parser.network_input_parameters(q)
259311
])
260312

313+
if files[q] is None:
314+
files[q] = cycle_spectrum_file(parser, q, files[q],
315+
n=first_file,
316+
validation=False)
317+
318+
files[q].write_data(n, network_params, spec)
319+
320+
for q in files:
321+
if files[q] is not None and files[q].is_open:
322+
files[q].close()
323+
files[q] = None
324+
325+
if validation_samples == {}:
326+
if rank == 0:
327+
print(f"Finished generating {accepted} spectra.")
328+
print(f"You can now run\n\tcosmopower train {args.yamlfile}\n" \
329+
"to train the networks on this dataset.")
330+
return
331+
332+
# Do the exact same thing all over again, but for validation samples.
333+
validation_samples, first, last, first_file, last_file = \
334+
split_samples(MPI, parser, validation_samples, parser.nvalidation)
335+
336+
print(f"[{rank}]: Iterating over validation samples {first}--{last} in " \
337+
f"files {first_file}--{last_file}.")
338+
339+
accepted = 0
340+
tbar = tqdm.tqdm(np.arange(first, last + 1))
341+
342+
Barrier()
343+
344+
for n in tbar:
345+
tbar.set_description(("" if MPI is None else f"[{rank}] ")
346+
+ f"{accepted/n:.1%} success rate")
347+
348+
boltzmann_params = {
349+
k: validation_samples[k][n] for k in parser.boltzmann_inputs
350+
}
351+
352+
files, quantities_to_be_computed = \
353+
check_open_files(parser, files, n, first_file, validation=True)
354+
355+
if len(quantities_to_be_computed) == 0:
356+
accepted += 1
357+
continue
358+
359+
if get_boltzmann_spectra(parser, state, boltzmann_params,
360+
quantities_to_be_computed, extra_args):
361+
accepted += 1
362+
363+
for k in state["derived"]:
364+
validation_samples[k][n] = state["derived"][k]
365+
366+
for q in quantities_to_be_computed:
367+
if q == "derived":
368+
spec = np.asarray([state["derived"].get(p)
369+
for p in parser.computed_parameters])
370+
else:
371+
spec = state.get(q, None)
372+
261373
if spec is None:
262374
continue
263375

@@ -267,17 +379,18 @@ def generate_spectra(args: list = None) -> None:
267379
if np.any(np.isnan(spec)):
268380
continue
269381

382+
network_params = np.array([
383+
validation_samples[k][n]
384+
for k in parser.network_input_parameters(q)
385+
])
386+
270387
if files[q] is None:
271388
files[q] = cycle_spectrum_file(parser, q, files[q],
272-
n=first_file)
389+
n=first_file,
390+
validation=True)
273391

274392
files[q].write_data(n, network_params, spec)
275393

276394
for q in files:
277395
if files[q] is not None and files[q].is_open:
278396
files[q].close()
279-
280-
if rank == 0:
281-
print(f"Finished generating {accepted} spectra.")
282-
print(f"You can now run\n\tcosmopower train {args.yamlfile}\n\
283-
to train the networks on this dataset.")

cosmopower/train.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def train_network_NN(parser: YAMLParser, quantity: str, device: str = "",
5757
"_validation.*.hdf5"))
5858
validation = [Dataset(parser, quantity, os.path.basename(filename))
5959
for filename in filenames]
60+
61+
if len(validation) == 0:
62+
print(f"No validation data found? Defaulting to 10% split.")
63+
validation = 0.1
6064

6165
with tf.device(device):
6266
print("\tTraining NN.")
@@ -108,6 +112,10 @@ def train_network_PCAplusNN(parser: YAMLParser, quantity: str,
108112
"_validation.*.hdf5"))
109113
validation = [Dataset(parser, quantity, os.path.basename(filename))
110114
for filename in filenames]
115+
116+
if len(validation) == 0:
117+
print(f"No validation data found? Defaulting to 10% split.")
118+
validation = 0.1
111119

112120
cp_pca = cosmopower_PCA(parameters=parameters, modes=modes, n_pcas=n_pcas,
113121
n_batches=n_batches, verbose=True)

0 commit comments

Comments
 (0)