-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
334 lines (301 loc) · 10.9 KB
/
base.py
File metadata and controls
334 lines (301 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from tqdm import tqdm
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
from utils import model_size_b, MiB
class Sampleable(ABC):
    """
    Interface for any distribution from which batches can be drawn.
    """
    @abstractmethod
    def sample(self, key: jax.random.PRNGKey, num_samples: int) -> Tuple[jax.Array, Optional[jax.Array]]:
        """
        Draw a batch of samples and, when the distribution is labeled,
        their labels.

        Args:
            - key: JAX PRNG key
            - num_samples: the desired number of samples
        Returns:
            - samples: shape (batch_size, ...)
            - labels: shape (batch_size, label_dim), or None when unlabeled
        """
        ...
# Abstract class for scheduler functions
class Alpha(ABC):
    """
    Scheduler function alpha_t of a probability path.

    Any concrete schedule must satisfy the boundary conditions
    alpha_0 = 0 and alpha_1 = 1, which are checked at construction time.
    """
    def __init__(self):
        # Check alpha_t(0) = 0
        assert jnp.allclose(
            self(jnp.zeros((1,1,1,1))), jnp.zeros((1,1,1,1))
        )
        # Check alpha_1 = 1
        assert jnp.allclose(
            self(jnp.ones((1,1,1,1))), jnp.ones((1,1,1,1))
        )

    @abstractmethod
    def __call__(self, t: jax.Array) -> jax.Array:
        """
        Evaluates alpha_t. Should satisfy: self(0.0) = 0.0, self(1.0) = 1.0.
        Args:
            - t: time (num_samples, 1, 1, 1)
        Returns:
            - alpha_t (num_samples, 1, 1, 1)
        """
        pass

    def dt(self, t: jax.Array) -> jax.Array:
        """
        Evaluates d/dt alpha_t via automatic differentiation.
        Args:
            - t: time (num_samples, 1, 1, 1)
        Returns:
            - d/dt alpha_t (num_samples, 1, 1, 1)
        """
        # Flatten with reshape(-1) rather than squeeze(): for a batch of
        # size 1, squeeze() yields a 0-d array which jax.vmap cannot map
        # over, whereas reshape(-1) always produces a 1-d array.
        t_flat = t.reshape(-1)
        dt = jax.vmap(jax.grad(lambda s: self(s[None, None, None, None]).squeeze()))(t_flat)
        return dt.reshape(-1, 1, 1, 1)
class Beta(ABC):
    """
    Scheduler function beta_t of a probability path.

    Any concrete schedule must satisfy the boundary conditions
    beta_0 = 1 and beta_1 = 0, which are checked at construction time.
    """
    def __init__(self):
        # Check beta_0 = 1
        assert jnp.allclose(
            self(jnp.zeros((1,1,1,1))), jnp.ones((1,1,1,1))
        )
        # Check beta_1 = 0
        assert jnp.allclose(
            self(jnp.ones((1,1,1,1))), jnp.zeros((1,1,1,1))
        )

    @abstractmethod
    def __call__(self, t: jax.Array) -> jax.Array:
        """
        Evaluates beta_t. Should satisfy: self(0.0) = 1.0, self(1.0) = 0.0.
        Args:
            - t: time (num_samples, 1, 1, 1)
        Returns:
            - beta_t (num_samples, 1, 1, 1)
        """
        pass

    def dt(self, t: jax.Array) -> jax.Array:
        """
        Evaluates d/dt beta_t via automatic differentiation.
        Args:
            - t: time (num_samples, 1, 1, 1)
        Returns:
            - d/dt beta_t (num_samples, 1, 1, 1)
        """
        # Flatten with reshape(-1) rather than squeeze(): for a batch of
        # size 1, squeeze() yields a 0-d array which jax.vmap cannot map
        # over, whereas reshape(-1) always produces a 1-d array.
        t_flat = t.reshape(-1)
        dt = jax.vmap(jax.grad(lambda s: self(s[None, None, None, None]).squeeze()))(t_flat)
        return dt.reshape(-1, 1, 1, 1)
# Abstract class for both ODE and SDE
class ODE(ABC):
    """
    Abstract ordinary differential equation, fully specified by its drift.
    """
    @abstractmethod
    def drift_coefficient(self, xt: jax.Array, t: jax.Array, **kwargs) -> jax.Array:
        """
        Evaluate the drift term of the ODE at state xt and time t.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: time, shape (bs, 1, 1, 1)
        Returns:
            - drift_coefficient: shape (bs, c, h, w)
        """
        ...
class SDE(ABC):
    """
    Abstract stochastic differential equation, specified by its drift and
    diffusion coefficients.
    """
    @abstractmethod
    def drift_coefficient(self, xt: jax.Array, t: jax.Array, **kwargs) -> jax.Array:
        """
        Evaluate the drift term of the SDE at state xt and time t.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: time, shape (bs, 1, 1, 1)
        Returns:
            - drift_coefficient: shape (bs, c, h, w)
        """
        ...

    @abstractmethod
    def diffusion_coefficient(self, xt: jax.Array, t: jax.Array, **kwargs) -> jax.Array:
        """
        Evaluate the diffusion term of the SDE at state xt and time t.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: time, shape (bs, 1, 1, 1)
        Returns:
            - diffusion_coefficient: shape (bs, c, h, w)
        """
        ...
# Abstract class for simulators
class Simulator(ABC):
    """
    Base class for numerical integrators: subclasses define a single
    discretization step, and this class drives it over a time grid.
    """
    @abstractmethod
    def step(self, xt: jax.Array, t: jax.Array, dt: jax.Array, key: jax.Array, **kwargs) -> jax.Array:
        """
        Advance the state by a single discretization step.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: time, shape (bs, 1, 1, 1)
            - dt: time step, shape (bs, 1, 1, 1)
            - key: JAX PRNG key (for SDE simulators that need randomness)
        Returns:
            - nxt: state at time t + dt (bs, c, h, w)
        """
        ...

    def simulate(self, x: jax.Array, ts: jax.Array, key: jax.Array, **kwargs) -> jax.Array:
        """
        Integrate from ts[:, 0] to ts[:, -1], returning only the final state.

        Args:
            - x: initial state, shape (bs, c, h, w)
            - ts: timesteps, shape (bs, nts, 1, 1, 1)
            - key: JAX PRNG key
        Returns:
            - x_final: final state at time ts[-1], shape (bs, c, h, w)
        """
        num_steps = ts.shape[1] - 1
        # One independent subkey per step so stochastic steps stay reproducible.
        step_keys = jax.random.split(key, num_steps)
        for idx in tqdm(range(num_steps)):
            t_cur = ts[:, idx]
            dt_cur = ts[:, idx + 1] - t_cur
            x = self.step(x, t_cur, dt_cur, step_keys[idx], **kwargs)
        return x

    def simulate_with_trajectory(self, x: jax.Array, ts: jax.Array, key: jax.Array, **kwargs) -> jax.Array:
        """
        Integrate from ts[:, 0] to ts[:, -1], recording every intermediate state.

        Args:
            - x: initial state, shape (bs, c, h, w)
            - ts: timesteps, shape (bs, nts, 1, 1, 1)
            - key: JAX PRNG key
        Returns:
            - xs: trajectory of xts over ts, shape (batch_size, nts, c, h, w)
        """
        trajectory = [x]
        num_steps = ts.shape[1] - 1
        step_keys = jax.random.split(key, num_steps)
        for idx in tqdm(range(num_steps)):
            t_cur = ts[:, idx]
            dt_cur = ts[:, idx + 1] - t_cur
            x = self.step(x, t_cur, dt_cur, step_keys[idx], **kwargs)
            trajectory.append(x)
        return jnp.stack(trajectory, axis=1)
# Abstract class for conditional
class ConditionalProbabilityPath(ABC):
"""
Abstract base class for conditional probability paths
"""
def __init__(self, p_simple: Sampleable, p_data: Sampleable):
self.p_simple = p_simple
self.p_data = p_data
def sample_marginal_path(self, t: jax.Array, key: jax.Array) -> jax.Array:
"""
Samples from the marginal distribution p_t(x) = p_t(x|z) p(z)
Args:
- t: time (num_samples, 1, 1, 1)
- key: JAX PRNG key
Returns:
- x: samples from p_t(x), (num_samples, c, h, w)
"""
num_samples = t.shape[0]
key1, key2 = jax.random.split(key)
# Sample conditioning variable z ~ p(z)
z, _ = self.sample_conditioning_variable(key1, num_samples) # (num_samples, c, h, w)
# Sample conditional probability path x ~ p_t(x|z)
x = self.sample_conditional_path(z, t, key2) # (num_samples, c, h, w)
return x
@abstractmethod
def sample_conditioning_variable(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, Optional[jax.Array]]:
"""
Samples the conditioning variable z and label y
Args:
- key: JAX PRNG key
- num_samples: the number of samples
Returns:
- z: (num_samples, c, h, w)
- y: (num_samples, label_dim)
"""
pass
@abstractmethod
def sample_conditional_path(self, z: jax.Array, t: jax.Array, key: jax.Array) -> jax.Array:
"""
Samples from the conditional distribution p_t(x|z)
Args:
- z: conditioning variable (num_samples, c, h, w)
- t: time (num_samples, 1, 1, 1)
- key: JAX PRNG key
Returns:
- x: samples from p_t(x|z), (num_samples, c, h, w)
"""
pass
@abstractmethod
def conditional_vector_field(self, x: jax.Array, z: jax.Array, t: jax.Array) -> jax.Array:
"""
Evaluates the conditional vector field u_t(x|z)
Args:
- x: position variable (num_samples, c, h, w)
- z: conditioning variable (num_samples, c, h, w)
- t: time (num_samples, 1, 1, 1)
Returns:
- conditional_vector_field: conditional vector field (num_samples, c, h, w)
"""
pass
@abstractmethod
def conditional_score(self, x: jax.Array, z: jax.Array, t: jax.Array) -> jax.Array:
"""
Evaluates the conditional score of p_t(x|z)
Args:
- x: position variable (num_samples, c, h, w)
- z: conditioning variable (num_samples, c, h, w)
- t: time (num_samples, 1, 1, 1)
Returns:
- conditional_score: conditional score (num_samples, c, h, w)
"""
pass
class ConditionalVectorField(ABC):
    """
    MLP-parameterization of the learned vector field u_t^theta(x).
    """
    @abstractmethod
    def __call__(self, x: jax.Array, t: jax.Array, y: jax.Array) -> jax.Array:
        """
        Evaluate the learned vector field at (x, t) under condition y.

        Args:
            - x: (bs, c, h, w)
            - t: (bs, 1, 1, 1)
            - y: (bs,)
        Returns:
            - u_t^theta(x|y): (bs, c, h, w)
        """
        ...
# Abstract class for training models
class Trainer(ABC):
    """
    Abstract training harness for an Equinox model.

    Subclasses supply the data (sample_batch) and the objective
    (get_train_loss); train() runs the optimization loop and keeps the
    current model on self.model.
    """
    def __init__(self, model: eqx.Module):
        super().__init__()
        # Current model; train() rebinds this after every optimizer step.
        self.model = model

    @abstractmethod
    def sample_batch(self, key: jax.random.PRNGKey, batch_size: int):
        """Sample a batch of training data. Called outside JIT."""
        pass

    @abstractmethod
    def get_train_loss(self, model: eqx.Module, *args, **kwargs) -> jax.Array:
        """Compute loss given a model and pre-sampled data. Must be a pure function."""
        pass

    def get_optimizer(self, lr: float):
        # Default optimizer; subclasses may override (e.g. schedules, clipping).
        return optax.adam(lr)

    def train(self, num_epochs: int, lr: float = 1e-3, checkpoint_callback=None, **kwargs):
        """
        Run the training loop for num_epochs steps.

        Args:
            - num_epochs: number of optimization steps (one batch each)
            - lr: learning rate passed to get_optimizer
            - checkpoint_callback: optional fn(epoch, model, opt_state, loss)
              invoked after every step
            - kwargs: may contain 'key' (PRNG key, default PRNGKey(0)) and
              'batch_size' (default 32); these are popped before use
        """
        # Report model size
        size_b = model_size_b(self.model)
        print(f'Training model with size: {size_b / MiB:.3f} MiB')
        # Initialize optimizer
        opt = self.get_optimizer(lr)
        # Only array leaves participate in optimization state.
        opt_state = opt.init(eqx.filter(self.model, eqx.is_array))
        @eqx.filter_jit
        def train_step(model, opt_state, *args):
            # One JIT-compiled update: loss/grads, optimizer transform, apply.
            # NOTE(review): closes over self.get_train_loss and opt — both must
            # stay pure/static across calls for the compiled step to be valid.
            loss, grads = eqx.filter_value_and_grad(self.get_train_loss)(model, *args)
            updates, opt_state = opt.update(grads, opt_state)
            model = eqx.apply_updates(model, updates)
            return model, opt_state, loss
        key = kwargs.pop('key', jax.random.PRNGKey(0))
        batch_size = kwargs.pop('batch_size', 32)
        # Train loop
        pbar = tqdm(range(num_epochs))
        for epoch in pbar:
            # Generate new keys for data sampling and model
            key, data_key, model_key = jax.random.split(key, 3)
            # Sample batch OUTSIDE JIT
            batch_data = self.sample_batch(data_key, batch_size)
            # model_key is forwarded as the last positional arg of get_train_loss.
            self.model, opt_state, loss = train_step(self.model, opt_state, *batch_data, model_key)
            # float() blocks until the async JIT'd step has finished.
            loss_val = float(loss)
            pbar.set_description(f'Epoch {epoch}, loss: {loss_val:.3f}')
            if checkpoint_callback is not None:
                checkpoint_callback(epoch, self.model, opt_state, loss_val)