Skip to content

Commit bc65e84

Browse files
authored
Merge pull request #19 from deepflame-ai/augment-time-selectors
Add snapshot selection to augment CLI
2 parents c790050 + 5ce1030 commit bc65e84

6 files changed

Lines changed: 182 additions & 6 deletions

File tree

dfode_kit/cli/commands/augment.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ def add_command_parser(subparsers):
2626
help='Requested number of augmented rows.',
2727
)
2828
augment_parser.add_argument('--seed', type=int, help='Random seed for reproducible augmentation.')
29+
augment_parser.add_argument(
30+
'--time',
31+
action='append',
32+
help='Select time snapshots by ordered snapshot index expression, e.g. 0, -1, 0:12, or ::10. Repeatable.',
33+
)
2934
augment_parser.add_argument('--from-config', type=str, help='Load an augment plan/config JSON.')
3035
augment_parser.add_argument('--write-config', type=str, help='Write the resolved augment plan/config to JSON.')
3136
augment_parser.add_argument('--preview', action='store_true', help='Preview the resolved plan without executing augmentation.')
@@ -76,6 +81,12 @@ def _print_human_plan(plan: dict):
7681
print(f"save: {plan['save']}")
7782
print(f"target_size: {plan['target_size']}")
7883
print(f"seed: {plan['seed']}")
84+
print(f"time_selectors: {plan['time_selectors']}")
85+
print(f"resolved_snapshot_count: {plan['resolved_snapshot_count']}")
86+
if plan['resolved_snapshot_names']:
87+
print('resolved_snapshot_names:')
88+
for name in plan['resolved_snapshot_names']:
89+
print(f' - {name}')
7990
print('resolved:')
8091
for key in sorted(plan['resolved']):
8192
print(f" {key}: {plan['resolved'][key]}")

dfode_kit/cli/commands/augment_helpers.py

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from pathlib import Path
88
from typing import Any
99

10+
import h5py
11+
import numpy as np
12+
13+
from dfode_kit.data.contracts import SCALAR_FIELDS_GROUP, ordered_group_dataset_names, require_h5_group
14+
1015

1116
DEFAULT_AUGMENT_PRESET = 'random-local-combustion-v1'
1217

@@ -56,6 +61,7 @@ def resolve_augment_plan(args) -> dict[str, Any]:
5661
preset_name = args.preset or plan.get('preset', DEFAULT_AUGMENT_PRESET)
5762
target_size = args.target_size if args.target_size is not None else plan.get('target_size')
5863
seed = args.seed if args.seed is not None else plan.get('seed')
64+
time_selectors = args.time if args.time is not None else plan.get('time_selectors')
5965
else:
6066
_validate_required_args(args, ('source', 'mech', 'preset', 'target_size'))
6167
source = args.source
@@ -64,6 +70,7 @@ def resolve_augment_plan(args) -> dict[str, Any]:
6470
preset_name = args.preset
6571
target_size = args.target_size
6672
seed = args.seed
73+
time_selectors = args.time
6774

6875
if args.apply and not save:
6976
raise ValueError('The --save path is required when using --apply.')
@@ -77,6 +84,9 @@ def resolve_augment_plan(args) -> dict[str, Any]:
7784
if not mechanism_path.is_file():
7885
raise ValueError(f'Mechanism file does not exist: {mechanism_path}')
7986

87+
ordered_names = _read_ordered_snapshot_names(source_path)
88+
resolved_snapshot_names = _resolve_time_selectors(ordered_names, time_selectors)
89+
8090
plan = {
8191
'schema_version': 1,
8292
'command_type': 'augment',
@@ -87,6 +97,9 @@ def resolve_augment_plan(args) -> dict[str, Any]:
8797
'save': str(Path(save).resolve()) if save else None,
8898
'target_size': int(target_size),
8999
'seed': int(seed) if seed is not None else None,
100+
'time_selectors': list(time_selectors) if time_selectors else None,
101+
'resolved_snapshot_names': resolved_snapshot_names,
102+
'resolved_snapshot_count': len(resolved_snapshot_names),
90103
'config_path': str(Path(args.from_config).resolve()) if args.from_config else None,
91104
'notes': preset.notes,
92105
'resolved': dict(preset.resolved),
@@ -95,17 +108,15 @@ def resolve_augment_plan(args) -> dict[str, Any]:
95108

96109

97110
def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, Any]:
98-
import numpy as np
99-
100-
from dfode_kit.data import get_TPY_from_h5, random_perturb
111+
from dfode_kit.data import random_perturb
101112

102113
source_path = Path(plan['source']).resolve()
103114
output_path = Path(plan['save']).resolve()
104115
output_path.parent.mkdir(parents=True, exist_ok=True)
105116

106117
if quiet:
107118
with redirect_stdout(io.StringIO()):
108-
data = get_TPY_from_h5(source_path)
119+
data = _load_selected_tpy_from_h5(source_path, plan['resolved_snapshot_names'])
109120
augmented = random_perturb(
110121
data,
111122
plan['mechanism'],
@@ -117,7 +128,9 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A
117128
else:
118129
print('Handling augment command')
119130
print(f'Loading data from h5 file: {source_path}')
120-
data = get_TPY_from_h5(source_path)
131+
if plan['time_selectors']:
132+
print(f"Selecting snapshots with --time: {plan['time_selectors']}")
133+
data = _load_selected_tpy_from_h5(source_path, plan['resolved_snapshot_names'])
121134
print('Data shape:', data.shape)
122135
augmented = random_perturb(
123136
data,
@@ -141,6 +154,8 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A
141154
'returned_count': int(augmented.shape[0]),
142155
'feature_count': int(augmented.shape[1]) if augmented.ndim == 2 else None,
143156
'seed': plan.get('seed'),
157+
'resolved_snapshot_count': int(plan['resolved_snapshot_count']),
158+
'resolved_snapshot_names': list(plan['resolved_snapshot_names']),
144159
}
145160

146161

@@ -156,6 +171,72 @@ def load_plan_json(path: str | Path) -> dict[str, Any]:
156171
return json.loads(input_path.read_text(encoding='utf-8'))
157172

158173

174+
def _read_ordered_snapshot_names(source_path: str | Path) -> list[str]:
    """Return the snapshot dataset names stored in the scalar-fields group.

    The HDF5 file at ``source_path`` is opened read-only, the group named by
    ``SCALAR_FIELDS_GROUP`` is looked up through ``require_h5_group``, and the
    names are returned in whatever order ``ordered_group_dataset_names``
    produces for that group.
    """
    with h5py.File(source_path, 'r') as handle:
        names = ordered_group_dataset_names(
            require_h5_group(handle, SCALAR_FIELDS_GROUP)
        )
    return names
178+
179+
180+
def _load_selected_tpy_from_h5(source_path: str | Path, dataset_names: list[str]) -> np.ndarray:
    """Read the named scalar-field datasets and stack them along axis 0.

    Args:
        source_path: HDF5 file to open read-only.
        dataset_names: Names of datasets inside the ``SCALAR_FIELDS_GROUP``
            group, concatenated in the order given.

    Returns:
        One array holding the rows of every selected dataset.

    Raises:
        ValueError: If ``dataset_names`` is empty, so nothing was loaded.
    """
    loaded = []
    with h5py.File(source_path, 'r') as handle:
        group = require_h5_group(handle, SCALAR_FIELDS_GROUP)
        for dataset_name in dataset_names:
            # ``[:]`` copies the dataset into memory so it outlives the file handle.
            loaded.append(group[dataset_name][:])

    if loaded:
        return np.concatenate(loaded, axis=0)
    raise ValueError(f"No datasets selected from '{SCALAR_FIELDS_GROUP}' in {source_path}")
187+
188+
189+
def _resolve_time_selectors(ordered_names: list[str], selectors: list[str] | None) -> list[str]:
190+
if not ordered_names:
191+
raise ValueError('No scalar-field snapshots are available in the source HDF5.')
192+
if not selectors:
193+
return list(ordered_names)
194+
195+
selected_indices: list[int] = []
196+
seen = set()
197+
for selector in selectors:
198+
indices = _indices_from_selector(selector, len(ordered_names))
199+
for index in indices:
200+
if index not in seen:
201+
seen.add(index)
202+
selected_indices.append(index)
203+
204+
if not selected_indices:
205+
raise ValueError('The provided --time selectors resolved to zero snapshots.')
206+
207+
return [ordered_names[index] for index in selected_indices]
208+
209+
210+
def _indices_from_selector(selector: str, length: int) -> list[int]:
211+
text = selector.strip()
212+
if not text:
213+
raise ValueError('Empty --time selector is not allowed.')
214+
215+
if ':' in text:
216+
parts = text.split(':')
217+
if len(parts) > 3:
218+
raise ValueError(f'Invalid --time slice selector: {selector}')
219+
values = []
220+
for part in parts:
221+
if part == '':
222+
values.append(None)
223+
else:
224+
values.append(int(part))
225+
while len(values) < 3:
226+
values.append(None)
227+
start, stop, step = values
228+
if step == 0:
229+
raise ValueError(f'Invalid --time selector with zero step: {selector}')
230+
return list(range(length))[slice(start, stop, step)]
231+
232+
index = int(text)
233+
if index < 0:
234+
index += length
235+
if index < 0 or index >= length:
236+
raise ValueError(f'--time index out of range for {length} snapshots: {selector}')
237+
return [index]
238+
239+
159240
def _validate_required_args(args, names: tuple[str, ...]):
160241
missing = [f'--{name.replace("_", "-")}' for name in names if getattr(args, name) is None]
161242
if missing:

docs/augment.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dfode-kit augment [options]
4949
### Optional but high-value
5050

5151
- `--seed`
52+
- `--time` (repeatable snapshot index/slice selector)
5253

5354
## Current preset
5455

@@ -66,6 +67,7 @@ dfode-kit augment \
6667
--mech /path/to/gri30.yaml \
6768
--preset random-local-combustion-v1 \
6869
--target-size 20000 \
70+
--time 0:12 \
6971
--preview --json
7072
```
7173

@@ -90,6 +92,7 @@ dfode-kit augment \
9092
--save /path/to/aug.npy \
9193
--preset random-local-combustion-v1 \
9294
--target-size 20000 \
95+
--time ::10 \
9396
--seed 1234 \
9497
--apply
9598
```
@@ -103,6 +106,22 @@ dfode-kit augment \
103106
--apply
104107
```
105108

109+
## Time snapshot selection
110+
111+
When `--time` is omitted, augmentation uses all snapshots in the sampled HDF5 source.
112+
113+
When `--time` is provided, it selects snapshots from the ordered HDF5 snapshot list by index expression.
114+
115+
Supported forms include:
116+
117+
- single index: `--time 0`
118+
- negative index: `--time -1`
119+
- slice: `--time 0:12`
120+
- stride: `--time ::10`
121+
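- repeated selectors: `--time 0:5 --time -1` (selectors are combined in the order given; a snapshot matched by more than one selector is included only once)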
122+
123+
Selection is applied to snapshots only; all rows from each selected snapshot are included.
124+
106125
## Output behavior
107126

108127
### `--preview --json`
@@ -122,6 +141,8 @@ In `--json` mode, the command reports a structured completion record including:
122141
- requested row count
123142
- returned row count
124143
- seed
144+
- resolved snapshot count
145+
- resolved snapshot names
125146

126147
## Action rule
127148

docs/cli.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ dfode-kit augment \
8383
--save /path/to/augmented.npy \
8484
--preset random-local-combustion-v1 \
8585
--target-size 20000 \
86+
--time 0:12 \
8687
--apply
8788
```
8889

docs/data-workflow.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ dfode-kit augment \
7373
--save /path/to/data/ch4_phi1_aug.npy \
7474
--preset random-local-combustion-v1 \
7575
--target-size 20000 \
76+
--time 0:12 \
7677
--apply
7778
```
7879

tests/test_augment_cli.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
from pathlib import Path
33
from types import SimpleNamespace
44

5+
import h5py
6+
import numpy as np
7+
58
from dfode_kit.cli.commands import augment, augment_helpers
9+
import dfode_kit.data as data_module
610

711

812
class DummyArgs(SimpleNamespace):
@@ -11,7 +15,11 @@ class DummyArgs(SimpleNamespace):
1115

1216
def make_args(tmp_path, **overrides):
1317
source = tmp_path / 'sample.h5'
14-
source.write_text('stub', encoding='utf-8')
18+
with h5py.File(source, 'w') as h5:
19+
scalar = h5.create_group('scalar_fields')
20+
scalar.create_dataset('0.0', data=np.array([[1.0, 2.0], [3.0, 4.0]]))
21+
scalar.create_dataset('1.0', data=np.array([[5.0, 6.0], [7.0, 8.0]]))
22+
scalar.create_dataset('2.0', data=np.array([[9.0, 10.0], [11.0, 12.0]]))
1523
mech = tmp_path / 'mech.yaml'
1624
mech.write_text('stub', encoding='utf-8')
1725
data = {
@@ -27,6 +35,7 @@ def make_args(tmp_path, **overrides):
2735
'preview': True,
2836
'apply': False,
2937
'json': True,
38+
'time': None,
3039
}
3140
data.update(overrides)
3241
return DummyArgs(**data)
@@ -42,6 +51,7 @@ def test_resolve_augment_plan_uses_minimal_contract(tmp_path):
4251
assert plan['target_size'] == 12
4352
assert plan['seed'] == 123
4453
assert plan['resolved'] == {'heat_limit': False, 'element_limit': True}
54+
assert plan['resolved_snapshot_names'] == ['0.0', '1.0', '2.0']
4555

4656

4757
def test_resolve_augment_plan_from_config_allows_save_override(tmp_path):
@@ -61,6 +71,7 @@ def test_resolve_augment_plan_from_config_allows_save_override(tmp_path):
6171
from_config=str(config_path),
6272
preview=True,
6373
apply=False,
74+
time=None,
6475
)
6576

6677
loaded = augment_helpers.resolve_augment_plan(from_config_args)
@@ -70,6 +81,55 @@ def test_resolve_augment_plan_from_config_allows_save_override(tmp_path):
7081
assert loaded['seed'] == 123
7182

7283

84+
def test_resolve_augment_plan_time_selectors_support_index_and_slice(tmp_path):
    """A single index plus an open-ended slice together cover all three snapshots."""
    selectors = ['0', '1:']

    resolved = augment_helpers.resolve_augment_plan(make_args(tmp_path, time=selectors))

    assert resolved['time_selectors'] == ['0', '1:']
    assert resolved['resolved_snapshot_names'] == ['0.0', '1.0', '2.0']
    assert resolved['resolved_snapshot_count'] == 3
92+
93+
94+
def test_resolve_augment_plan_time_selector_can_stride(tmp_path):
    """A ``::2`` selector keeps every second snapshot, starting with the first."""
    strided_args = make_args(tmp_path, time=['::2'])

    resolved = augment_helpers.resolve_augment_plan(strided_args)
    selected_names = resolved['resolved_snapshot_names']

    assert selected_names == ['0.0', '2.0']
100+
101+
102+
def test_resolve_augment_plan_time_selector_out_of_range_fails(tmp_path):
    """An index beyond the three fixture snapshots is rejected with a ValueError."""
    bad_args = make_args(tmp_path, time=['10'])

    raised = None
    try:
        augment_helpers.resolve_augment_plan(bad_args)
    except ValueError as exc:
        raised = exc

    if raised is None:
        raise AssertionError('expected ValueError')
    assert 'out of range' in str(raised)
111+
112+
113+
def test_apply_augment_plan_uses_selected_snapshots_only(tmp_path, monkeypatch):
    """Applying a plan resolved with ``--time 1`` feeds only that snapshot's rows to the perturbation step."""
    plan = augment_helpers.resolve_augment_plan(make_args(tmp_path, time=['1']))
    captured = {}

    def fake_random_perturb(data, mech_path, dataset, heat_limit, element_limit, seed=None):
        # Record what the augmenter was handed, then echo it back unchanged.
        captured['data'] = data.copy()
        return data

    monkeypatch.setattr(data_module, 'random_perturb', fake_random_perturb)

    result = augment_helpers.apply_augment_plan(plan, quiet=True)
    seen = captured['data']

    assert result['resolved_snapshot_count'] == 1
    assert seen.shape == (2, 2)
    assert seen[0, 0] == 5.0
    assert Path(plan['save']).exists()
130+
131+
132+
73133
def test_handle_command_json_preview_and_apply(tmp_path, monkeypatch, capsys):
74134
args = make_args(tmp_path, preview=True, apply=True, json=True)
75135

@@ -88,6 +148,7 @@ def test_handle_command_json_preview_and_apply(tmp_path, monkeypatch, capsys):
88148
payload = json.loads(capsys.readouterr().out)
89149
assert payload['command_type'] == 'augment'
90150
assert payload['plan']['target_size'] == 12
151+
assert payload['plan']['resolved_snapshot_count'] == 3
91152
assert payload['apply']['returned_count'] == 9
92153

93154

0 commit comments

Comments
 (0)