-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsearch_joint.py
More file actions
73 lines (51 loc) · 2.18 KB
/
search_joint.py
File metadata and controls
73 lines (51 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Joint optimization: tree + seq params optimized together with a single optimizer."""
import optax
from jaxopt import OptaxSolver
from jax import jit, vmap
from common import *
# ── Objective ──
def objective(params, data):
    """Evaluate the joint loss for a merged parameter dict.

    The single optimizer works on one dict carrying both the tree (``'t'``)
    and sequence (``'s'``) parameters; this wrapper splits it back into the
    two separate dicts that ``compute_loss_optimized`` takes.

    Args:
        params: Merged parameter dict; only its ``'t'`` and ``'s'`` entries
            are read.
        data: ``(seqs, temp, epoch)`` — unpacked and forwarded unchanged.

    Returns:
        The value of ``compute_loss_optimized`` (this function is the
        solver's ``fun``).
    """
    seqs, temp, epoch = data
    return compute_loss_optimized(
        {'t': params['t']},  # tree half
        {'s': params['s']},  # sequence half
        seqs,
        temp,
        epoch,
    )
# ── Setup ──
# Parse CLI args into a plain dict; from here on the dict returned by
# sanity_check is the one used.
args = vars(parse_args())
args = sanity_check(args)
metadata = build_metadata(args)
setup_device(args)
print(pretty_print_dict(metadata))
# generate_data yields the input sequences plus reference artifacts used later
# by run_search (base_tree, tree, sm, sankoff_cost, gt_cost).
# NOTE(review): gt_seqs is never used in this script.
seqs, gt_seqs, tree, base_tree, sm, sankoff_cost, gt_cost = generate_data(metadata)
# Called only for its side effect (return value discarded); presumably strips
# metadata entries that cannot go through jit — confirm in common.
clear_metadata_for_jit(metadata, args)
tree_params, seq_params = init_params(metadata, seqs, metadata['n_leaves'], metadata['n_ancestors'], metadata['init_count'])
if args['initialize_tree']:
    # Warm-start 't' from the given tree: drop its last row and keep the
    # columns from n_leaves onward, scaled by 100 (presumably to make the
    # initial logits sharp — TODO confirm the constant).
    # NOTE(review): the optimizer below vmaps every param leaf over axis 0
    # (the init_count dimension); this slice looks like it has no such
    # leading axis — confirm shapes against init_params.
    tree_params['t'] = tree[0:-1, metadata['n_leaves']:] * 100
# Merge into a single param dict so one optimizer updates tree ('t') and
# sequence ('s') parameters together.
params = {'t': tree_params['t'], 's': seq_params['s']}
# ── Optimizer ──
# One Adam-based solver drives both the tree ('t') and sequence ('s') entries
# of the merged params dict, so they are optimized jointly.
optimizer = OptaxSolver(opt=optax.adam(metadata['lr']), fun=objective)
# Batch over the init_count dimension: every leaf of the merged params dict is
# mapped along axis 0 (assumes init_params returns leaves with a leading
# init_count axis — TODO confirm). The loss inputs [seqs, temp, epoch] are
# shared by all batch members (in_axes=None).
vmap_keys = {k: 0 for k in params}
vmap_init = vmap(optimizer.init_state, (vmap_keys, None), 0)
opt_state = vmap_init(params, [seqs, metadata['tLs'][0], 0])
# Batched, jitted solver step: params and optimizer state are batched along
# axis 0 and the loss inputs are broadcast; outputs are batched along axis 0.
jitted_update = jit(vmap(optimizer.update, (vmap_keys, 0, None), 0))
# ── Update step ──
def update_step(tree_params, seq_params, seqs, metadata, epoch):
    """Run one jitted, batched joint optimizer step.

    The solver state lives on the function object itself
    (``update_step.state['opt']``), so the caller only passes and receives
    the tree / sequence parameter dicts.

    Args:
        tree_params: Dict whose ``'t'`` entry is read.
        seq_params: Dict whose ``'s'`` entry is read.
        seqs: Forwarded as the first element of the loss data.
        metadata: Must provide ``'tLs'``; only ``metadata['tLs'][0]`` is used.
        epoch: Forwarded as the third element of the loss data.

    Returns:
        ``(new_tree_params, new_seq_params)`` as fresh ``{'t': ...}`` and
        ``{'s': ...}`` dicts holding the updated values.
    """
    state = update_step.state
    # Join both halves so the single optimizer sees one parameter pytree.
    joint = {'t': tree_params['t'], 's': seq_params['s']}
    # NOTE(review): the temperature is always tLs[0]; if tLs is meant to be a
    # per-epoch schedule it is never advanced here — confirm intent.
    loss_data = [seqs, metadata['tLs'][0], epoch]
    joint, state['opt'] = jitted_update(joint, state['opt'], loss_data)
    update_step.state = state
    return {'t': joint['t']}, {'s': joint['s']}
# Seed the optimizer state that update_step reads and replaces on every call.
update_step.state = {'opt': opt_state}
# ── Run ──
# Hand the joint update step to the shared search driver from common, along
# with the reference artifacts produced by generate_data.
best_cost, best_tree, best_seq = run_search(
    update_step,
    tree_params,
    seq_params,
    seqs,
    metadata,
    sm,
    base_tree=base_tree,
    tree=tree,
    sankoff_cost=sankoff_cost,
    gt_cost=gt_cost,
)