77from pathlib import Path
88from typing import Any
99
10+ import h5py
11+ import numpy as np
12+
13+ from dfode_kit .data .contracts import SCALAR_FIELDS_GROUP , ordered_group_dataset_names , require_h5_group
14+
1015
1116DEFAULT_AUGMENT_PRESET = 'random-local-combustion-v1'
1217
@@ -56,6 +61,7 @@ def resolve_augment_plan(args) -> dict[str, Any]:
5661 preset_name = args .preset or plan .get ('preset' , DEFAULT_AUGMENT_PRESET )
5762 target_size = args .target_size if args .target_size is not None else plan .get ('target_size' )
5863 seed = args .seed if args .seed is not None else plan .get ('seed' )
64+ time_selectors = args .time if args .time is not None else plan .get ('time_selectors' )
5965 else :
6066 _validate_required_args (args , ('source' , 'mech' , 'preset' , 'target_size' ))
6167 source = args .source
@@ -64,6 +70,7 @@ def resolve_augment_plan(args) -> dict[str, Any]:
6470 preset_name = args .preset
6571 target_size = args .target_size
6672 seed = args .seed
73+ time_selectors = args .time
6774
6875 if args .apply and not save :
6976 raise ValueError ('The --save path is required when using --apply.' )
@@ -77,6 +84,9 @@ def resolve_augment_plan(args) -> dict[str, Any]:
7784 if not mechanism_path .is_file ():
7885 raise ValueError (f'Mechanism file does not exist: { mechanism_path } ' )
7986
87+ ordered_names = _read_ordered_snapshot_names (source_path )
88+ resolved_snapshot_names = _resolve_time_selectors (ordered_names , time_selectors )
89+
8090 plan = {
8191 'schema_version' : 1 ,
8292 'command_type' : 'augment' ,
@@ -87,6 +97,9 @@ def resolve_augment_plan(args) -> dict[str, Any]:
8797 'save' : str (Path (save ).resolve ()) if save else None ,
8898 'target_size' : int (target_size ),
8999 'seed' : int (seed ) if seed is not None else None ,
100+ 'time_selectors' : list (time_selectors ) if time_selectors else None ,
101+ 'resolved_snapshot_names' : resolved_snapshot_names ,
102+ 'resolved_snapshot_count' : len (resolved_snapshot_names ),
90103 'config_path' : str (Path (args .from_config ).resolve ()) if args .from_config else None ,
91104 'notes' : preset .notes ,
92105 'resolved' : dict (preset .resolved ),
@@ -95,17 +108,15 @@ def resolve_augment_plan(args) -> dict[str, Any]:
95108
96109
97110def apply_augment_plan (plan : dict [str , Any ], quiet : bool = False ) -> dict [str , Any ]:
98- import numpy as np
99-
100- from dfode_kit .data import get_TPY_from_h5 , random_perturb
111+ from dfode_kit .data import random_perturb
101112
102113 source_path = Path (plan ['source' ]).resolve ()
103114 output_path = Path (plan ['save' ]).resolve ()
104115 output_path .parent .mkdir (parents = True , exist_ok = True )
105116
106117 if quiet :
107118 with redirect_stdout (io .StringIO ()):
108- data = get_TPY_from_h5 (source_path )
119+ data = _load_selected_tpy_from_h5 (source_path , plan [ 'resolved_snapshot_names' ] )
109120 augmented = random_perturb (
110121 data ,
111122 plan ['mechanism' ],
@@ -117,7 +128,9 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A
117128 else :
118129 print ('Handling augment command' )
119130 print (f'Loading data from h5 file: { source_path } ' )
120- data = get_TPY_from_h5 (source_path )
131+ if plan ['time_selectors' ]:
132+ print (f"Selecting snapshots with --time: { plan ['time_selectors' ]} " )
133+ data = _load_selected_tpy_from_h5 (source_path , plan ['resolved_snapshot_names' ])
121134 print ('Data shape:' , data .shape )
122135 augmented = random_perturb (
123136 data ,
@@ -141,6 +154,8 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A
141154 'returned_count' : int (augmented .shape [0 ]),
142155 'feature_count' : int (augmented .shape [1 ]) if augmented .ndim == 2 else None ,
143156 'seed' : plan .get ('seed' ),
157+ 'resolved_snapshot_count' : int (plan ['resolved_snapshot_count' ]),
158+ 'resolved_snapshot_names' : list (plan ['resolved_snapshot_names' ]),
144159 }
145160
146161
@@ -156,6 +171,72 @@ def load_plan_json(path: str | Path) -> dict[str, Any]:
156171 return json .loads (input_path .read_text (encoding = 'utf-8' ))
157172
158173
def _read_ordered_snapshot_names(source_path: str | Path) -> list[str]:
    """Return the snapshot dataset names stored in the scalar-fields group.

    The file at ``source_path`` is opened read-only, the group named by
    ``SCALAR_FIELDS_GROUP`` is fetched through ``require_h5_group``, and the
    result of ``ordered_group_dataset_names`` for that group is returned.

    NOTE(review): the ordering rule and the behaviour for a missing group are
    defined by the helpers in ``dfode_kit.data.contracts`` -- confirm there.
    """
    with h5py.File(source_path, 'r') as h5_handle:
        return ordered_group_dataset_names(
            require_h5_group(h5_handle, SCALAR_FIELDS_GROUP)
        )
178+
179+
def _load_selected_tpy_from_h5(source_path: str | Path, dataset_names: list[str]) -> np.ndarray:
    """Load the named scalar-field datasets and join them along axis 0.

    Args:
        source_path: HDF5 file to open read-only.
        dataset_names: Names of datasets inside the ``SCALAR_FIELDS_GROUP``
            group, read in the order given.

    Returns:
        One array built by concatenating every selected dataset along its
        first axis.

    Raises:
        ValueError: If ``dataset_names`` is empty, so nothing was loaded.
    """
    with h5py.File(source_path, 'r') as h5_handle:
        fields = require_h5_group(h5_handle, SCALAR_FIELDS_GROUP)

        snapshots = []
        for dataset_name in dataset_names:
            # ``[:]`` takes a full slice of each dataset.
            snapshots.append(fields[dataset_name][:])

        if len(snapshots) == 0:
            raise ValueError(f"No datasets selected from '{SCALAR_FIELDS_GROUP}' in {source_path}")

        return np.concatenate(snapshots, axis=0)
187+
188+
189+ def _resolve_time_selectors (ordered_names : list [str ], selectors : list [str ] | None ) -> list [str ]:
190+ if not ordered_names :
191+ raise ValueError ('No scalar-field snapshots are available in the source HDF5.' )
192+ if not selectors :
193+ return list (ordered_names )
194+
195+ selected_indices : list [int ] = []
196+ seen = set ()
197+ for selector in selectors :
198+ indices = _indices_from_selector (selector , len (ordered_names ))
199+ for index in indices :
200+ if index not in seen :
201+ seen .add (index )
202+ selected_indices .append (index )
203+
204+ if not selected_indices :
205+ raise ValueError ('The provided --time selectors resolved to zero snapshots.' )
206+
207+ return [ordered_names [index ] for index in selected_indices ]
208+
209+
210+ def _indices_from_selector (selector : str , length : int ) -> list [int ]:
211+ text = selector .strip ()
212+ if not text :
213+ raise ValueError ('Empty --time selector is not allowed.' )
214+
215+ if ':' in text :
216+ parts = text .split (':' )
217+ if len (parts ) > 3 :
218+ raise ValueError (f'Invalid --time slice selector: { selector } ' )
219+ values = []
220+ for part in parts :
221+ if part == '' :
222+ values .append (None )
223+ else :
224+ values .append (int (part ))
225+ while len (values ) < 3 :
226+ values .append (None )
227+ start , stop , step = values
228+ if step == 0 :
229+ raise ValueError (f'Invalid --time selector with zero step: { selector } ' )
230+ return list (range (length ))[slice (start , stop , step )]
231+
232+ index = int (text )
233+ if index < 0 :
234+ index += length
235+ if index < 0 or index >= length :
236+ raise ValueError (f'--time index out of range for { length } snapshots: { selector } ' )
237+ return [index ]
238+
239+
159240def _validate_required_args (args , names : tuple [str , ...]):
160241 missing = [f'--{ name .replace ("_" , "-" )} ' for name in names if getattr (args , name ) is None ]
161242 if missing :
0 commit comments