-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprobability_paths.py
More file actions
116 lines (104 loc) · 3.82 KB
/
probability_paths.py
File metadata and controls
116 lines (104 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from typing import List, Tuple, Optional
import jax
import jax.numpy as jnp
import jax.random as jrandom
from base import Alpha, Beta, ConditionalProbabilityPath, Sampleable
from distributions import IsotropicGaussian
class LinearAlpha(Alpha):
    """
    Linear schedule alpha_t = t, so alpha_0 = 0 and alpha_1 = 1.
    """

    def __call__(self, t: jax.Array) -> jax.Array:
        """
        Evaluate alpha_t.

        Args:
        - t: time, e.g. (num_samples, 1, 1, 1)
        Returns:
        - alpha_t: the input t itself, so same shape as t
        """
        alpha_t = t
        return alpha_t

    def dt(self, t: jax.Array) -> jax.Array:
        """
        Evaluate d/dt alpha_t, which is the constant 1 for a linear schedule.

        Args:
        - t: time, e.g. (num_samples, 1, 1, 1)
        Returns:
        - an array of ones with the same shape and dtype as t
        """
        derivative = jnp.ones_like(t)
        return derivative
class LinearBeta(Beta):
    """
    Linear schedule beta_t = 1 - t, so beta_0 = 1 and beta_1 = 0.

    Note: beta_1 == 0, so any consumer that divides by beta_t (e.g. a
    conditional vector field or score) is singular at t = 1.
    """

    def __call__(self, t: jax.Array) -> jax.Array:
        """
        Evaluates beta_t = 1 - t.

        Args:
        - t: time (num_samples, 1, 1, 1)
        Returns:
        - beta_t, same shape as t: (num_samples, 1, 1, 1)
        """
        return 1 - t

    def dt(self, t: jax.Array) -> jax.Array:
        """
        Evaluates d/dt beta_t, which is the constant -1 for a linear schedule.

        Args:
        - t: time (num_samples, 1, 1, 1)
        Returns:
        - d/dt beta_t: an array of -1 with the same shape and dtype as t,
          (num_samples, 1, 1, 1)
        """
        return -jnp.ones_like(t)
class GaussianConditionalProbabilityPath(ConditionalProbabilityPath):
    """
    Gaussian conditional probability path p_t(x|z) = N(alpha_t * z, beta_t**2 * I).

    Samples are produced as x = alpha_t * z + beta_t * eps with eps ~ N(0, I).
    The source distribution p_simple is a unit-std IsotropicGaussian.
    """

    def __init__(self, p_data: Sampleable, p_simple_shape: List[int], alpha: Alpha, beta: Beta):
        """
        Args:
        - p_data: data distribution, sampled to obtain the conditioning variable z
        - p_simple_shape: shape handed to the IsotropicGaussian used as p_simple
        - alpha: schedule object; called as alpha(t) and alpha.dt(t)
        - beta: schedule object; called as beta(t) and beta.dt(t)
        """
        super().__init__(IsotropicGaussian(shape = p_simple_shape, std = 1.0), p_data)
        self.alpha = alpha
        self.beta = beta

    def sample_conditioning_variable(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, Optional[jax.Array]]:
        """
        Draws conditioning variables by sampling p_data.

        Args:
        - key: JAX PRNG key
        - num_samples: how many samples to draw
        Returns:
        - the result of p_data.sample(key, num_samples); per the annotation a
          pair (z, y), expected as z: (num_samples, c, h, w) and
          y: (num_samples, label_dim) or None — shapes depend on p_data
        """
        drawn = self.p_data.sample(key, num_samples)
        return drawn

    def sample_conditional_path(self, z: jax.Array, t: jax.Array, key: jax.Array) -> jax.Array:
        """
        Draws x ~ p_t(x|z) = N(alpha_t * z, beta_t**2 * I).

        Args:
        - z: conditioning variable (num_samples, c, h, w)
        - t: time, broadcastable against z, e.g. (num_samples, 1, 1, 1)
        - key: JAX PRNG key used for the Gaussian noise
        Returns:
        - x: (num_samples, c, h, w)
        """
        mean = self.alpha(t) * z
        std = self.beta(t)
        eps = jax.random.normal(key, shape=z.shape)
        return mean + std * eps

    def conditional_vector_field(self, x: jax.Array, z: jax.Array, t: jax.Array) -> jax.Array:
        """
        Evaluates the conditional vector field u_t(x|z).

        Args:
        - x: position variable (num_samples, c, h, w)
        - z: conditioning variable (num_samples, c, h, w)
        - t: time (num_samples, 1, 1, 1)
        Returns:
        - u_t(x|z): (num_samples, c, h, w)
        """
        alpha_t, beta_t = self.alpha(t), self.beta(t)
        d_alpha, d_beta = self.alpha.dt(t), self.beta.dt(t)
        # From x = alpha_t z + beta_t eps:  eps = (x - alpha_t z) / beta_t, and
        # u = d_alpha z + d_beta eps = (d_alpha - (d_beta/beta_t) alpha_t) z + (d_beta/beta_t) x.
        # Divides by beta_t, so this is singular wherever beta_t == 0.
        beta_ratio = d_beta / beta_t
        return (d_alpha - beta_ratio * alpha_t) * z + beta_ratio * x

    def conditional_score(self, x: jax.Array, z: jax.Array, t: jax.Array) -> jax.Array:
        """
        Evaluates the conditional score grad_x log p_t(x|z).

        Args:
        - x: position variable (num_samples, c, h, w)
        - z: conditioning variable (num_samples, c, h, w)
        - t: time (num_samples, 1, 1, 1)
        Returns:
        - score: (num_samples, c, h, w)
        """
        alpha_t = self.alpha(t)
        # Score of N(alpha_t z, beta_t^2 I) is (alpha_t z - x) / beta_t^2.
        variance = self.beta(t) ** 2
        return (z * alpha_t - x) / variance