forked from matthew-lowery/kernel_neural_operator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernels.py
More file actions
150 lines (128 loc) · 4.99 KB
/
kernels.py
File metadata and controls
150 lines (128 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from jax import random as jr, numpy as jnp
import jax
import equinox as eqx
from typing import List, Callable
from functools import partial
from abc import ABC, abstractmethod
class KernelBaseClass(ABC):
@abstractmethod
def __init__(self, *args, **kwargs):
pass
### to be implemented by each subclass, evaluate the kernel on a single pair of coordinates x,y, of shape (ndims,)
@abstractmethod
def eval(self, x, y):
pass
### make kernel/Gram matrix given vectors x,y of coordinates, same for all kernels
def __call__(self,
x,
y):
if x.ndim == 1 or y.ndim == 1:
ndims = 1
else:
ndims = x.shape[-1]
X,Y = x.reshape(-1, ndims), y.reshape(-1, ndims)
k_xy = jax.vmap(jax.vmap(self.eval, (0, None)), (None, 0))(Y,X)
return k_xy
class GaussianKernel(eqx.Module, KernelBaseClass):
scale: jax.Array
def __init__(self, *, key, **kwargs):
key,_ = jr.split(key)
self.scale = jr.uniform(key,minval=-3.,maxval=0.0)
def eval(self, x, y,):
return (jnp.exp(- 0.5*(x - y)**2 / jnp.exp(2*self.scale))).sum()
class AnisotropicGaussianKernel(eqx.Module, KernelBaseClass):
scale: jax.Array
ndims: int
place: Callable
def __init__(self, *, ndims, key, **kwargs):
keys = jr.split(key)
self.scale = jr.uniform(keys[0], (int(ndims*(ndims+1)/2),), minval=-3., maxval=0.)
self.ndims = ndims
self.place = lambda vals: jnp.zeros((ndims,ndims)).at[jnp.tril_indices(ndims)].set(vals)
### one pair of pts
def eval(self, x, y,):
L = self.place(self.scale)
scale = L @ L.T
r_scaled = (x-y) @ scale @ (x-y)
return jnp.exp(-1/2 * r_scaled)
class NonstationaryGaussianKernel(eqx.Module, KernelBaseClass):
scale: eqx.Module
ndims: int
def __init__(self, *, ndims, latent_dim, key, **kwargs):
keys = jr.split(key)
self.scale = eqx.nn.Sequential(
[
eqx.nn.Linear(ndims, latent_dim, key=keys[0]),
eqx.nn.Lambda(jax.nn.selu),
eqx.nn.Linear(latent_dim, 1, key=keys[1]),
eqx.nn.Lambda(jax.nn.softplus),
]
)
self.ndims = ndims
def eval(self, x, y,):
sx = self.scale(x).squeeze()
sy = self.scale(y).squeeze()
k_gibbs_t1 = jnp.sqrt((2 * sx * sy) / (sx**2 + sy**2))
k_gibbs_t2 = jnp.exp(
-jnp.sum((x - y) ** 2) / (sx**2 + sy**2)
)
k_gibbs = k_gibbs_t1 * k_gibbs_t2
return k_gibbs
class GaussianSpectralMixtureKernel(eqx.Module, KernelBaseClass):
q: int
weights: jax.Array
freqs: jax.Array
scales: jax.Array
def __init__(self, *, base_kernel, q, ndims, key, **kwargs):
key1, key2, key3 = jr.split(key, 3)
self.q = q
self.weights = jr.uniform(key1, (q,1), maxval=0.1, minval=-3.)
self.freqs = jr.uniform(key2, (q,ndims), maxval=0.1, minval=-3.)
self.scales = jr.uniform(key, (q,), maxval=0.1, minval=-3,)
def eval(self, x,y):
weights = jax.nn.softplus(self.weights)
freqs = jax.nn.softplus(self.freqs)
scales = jax.nn.softplus(self.scales)
tau = x-y
cos = jnp.cos(freqs @ tau)
gauss = eqx.filter_vmap(lambda scale: (jnp.exp(- 0.5*tau**2 / jnp.exp(2*scale))).sum())(scales)
return jnp.sum(weights * cos * gauss)
class NonstationaryGaussianSpectralMixtureKernel(eqx.Module, KernelBaseClass):
weights: eqx.Module
q: int
def __init__(
self,
*,
ndims: int,
q: int,
latent_dim: int,
key,
**kwargs):
self.q = q
key,_ = jr.split(key)
self.weights = eqx.nn.MLP(key=key,
in_size=ndims,
out_size=q + q + (q*ndims),
width_size=latent_dim,
depth=1,
activation=jax.nn.selu,
final_activation=jax.nn.softplus,
)
def eval(self, x, y):
q = self.q
all_x, all_y = self.weights(x), self.weights(y)
wx, wy = all_x[:q], all_y[:q]
sx,sy = all_x[q:2*q], all_y[q:2*q]
fx,fy = all_x[2*q:].reshape(q,-1), all_y[2*q:].reshape(q,-1)
k_gibbs = (jnp.sqrt(2 * sx * sy) / (sx**2 + sy**2)) * jnp.exp(
-(jnp.sum((x - y) ** 2)) / (sx**2 + sy**2)
)
cosine = jnp.cos(2 * jnp.pi * (fx @ x - fy @ y))
k_xy = (wx * wy * k_gibbs * cosine).sum() # sum over mixtures
return k_xy
kernels = {'g': GaussianKernel,
'a_g': AnisotropicGaussianKernel,
'ns_g': partial(NonstationaryGaussianKernel, latent_dim=8),
'gsm': partial(GaussianSpectralMixtureKernel, base_kernel=GaussianKernel, q=2),
'ns_gsm': partial(NonstationaryGaussianSpectralMixtureKernel, latent_dim=8, q=2),
}