Skip to content

Commit 35881d5

Browse files
author
Nicolas Béreux
committed
remove compilation
1 parent 476671a commit 35881d5

20 files changed

Lines changed: 279 additions & 442 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"h5py>=3.12.0",
2525
"numpy>=2.0.0",
2626
"matplotlib>=3.8.0",
27-
"torch>=2.6.0",
27+
"torch>=2.10.0",
2828
"tqdm>=4.65.0",
2929
]
3030

@@ -85,4 +85,4 @@ docstring-code-format = false
8585
[dependency-groups]
8686
dev = [
8787
"pytest>=8.4.1",
88-
]
88+
]

rbms/bernoulli_bernoulli/implement.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _compute_energy_hiddens(
5858
return -field - log_term.sum(1)
5959

6060

61-
@torch.jit.script
61+
# @torch.jit.script
6262
def _compute_gradient(
6363
v_data: Tensor,
6464
mh_data: Tensor,
@@ -108,9 +108,9 @@ def _compute_gradient(
108108

109109
# Attach to the parameters
110110

111-
weight_matrix.grad.set_(grad_weight_matrix)
112-
vbias.grad.set_(grad_vbias)
113-
hbias.grad.set_(grad_hbias)
111+
weight_matrix.grad = grad_weight_matrix
112+
vbias.grad = grad_vbias
113+
hbias.grad = grad_hbias
114114

115115

116116
@torch.jit.script

rbms/bernoulli_gaussian/classes.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from botocore.vendored.six import u
23

34
import numpy as np
45
import torch
@@ -40,7 +41,8 @@ def __init__(
4041
self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype)
4142
self.vbias = vbias.to(device=self.device, dtype=self.dtype)
4243
self.hbias = hbias.to(device=self.device, dtype=self.dtype)
43-
log_two_pi = torch.log(2.0 * torch.pi, dtype=vbias.dtype, device=vbias.device)
44+
log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device))
45+
4446
self.const = (
4547
0.5
4648
* float(weight_matrix.shape[1])
@@ -191,20 +193,21 @@ def named_parameters(self):
191193
}
192194

193195
@property
194-
def num_hiddens(self):
196+
def num_hiddens(self) -> int:
195197
return self.hbias.shape[0]
196198

197199
@property
198-
def num_visibles(self):
200+
def num_visibles(self) -> int:
199201
return self.vbias.shape[0]
200202

201203
def parameters(self) -> list[Tensor]:
202204
# keep trainables only
203205
return [self.weight_matrix, self.vbias, self.hbias]
204206

205-
def ref_log_z(self):
206-
K = self.num_hiddens()
207-
Nv = self.num_visibles()
207+
@property
208+
def ref_log_z(self) -> float:
209+
K = self.num_hiddens
210+
Nv = self.num_visibles
208211
logZ_v = torch.log1p(torch.exp(self.vbias)).sum()
209212
inv_gamma = 1.0 / float(Nv)
210213
quad = 0.5 * inv_gamma * torch.dot(self.hbias, self.hbias)

rbms/bernoulli_gaussian/functional.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,10 @@ def compute_gradient(
7979
chains: dict[str, Tensor],
8080
params: BGRBM,
8181
centered: bool,
82-
lambda_l1: float = 0.0,
83-
lambda_l2: float = 0.0,
8482
) -> None:
8583
_compute_gradient(
8684
v_data=data["visible"],
87-
mh_data=data["hidden_mag"], # use conditional mean for positive phase
85+
h_data=data["hidden_mag"], # use conditional mean for positive phase
8886
w_data=data["weights"],
8987
v_chain=chains["visible"],
9088
h_chain=chains["hidden_mag"], # negative phase from chain samples
@@ -93,8 +91,6 @@ def compute_gradient(
9391
hbias=params.hbias,
9492
weight_matrix=params.weight_matrix,
9593
centered=centered,
96-
lambda_l1=lambda_l1,
97-
lambda_l2=lambda_l2,
9894
)
9995

10096

rbms/bernoulli_gaussian/implement.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from torch import Tensor
33

44

5-
@torch.jit.script
65
def _sample_hiddens(
76
v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0
87
) -> tuple[Tensor, Tensor]:
@@ -11,7 +10,6 @@ def _sample_hiddens(
1110
return h, mh
1211

1312

14-
@torch.jit.script
1513
def _sample_visibles(
1614
h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0
1715
) -> tuple[Tensor, Tensor]:
@@ -20,7 +18,6 @@ def _sample_visibles(
2018
return v, mv
2119

2220

23-
@torch.jit.script
2421
def _compute_energy(
2522
v: Tensor,
2623
h: Tensor,
@@ -39,7 +36,6 @@ def _compute_energy(
3936
return -fields - interaction + quad
4037

4138

42-
@torch.jit.script
4339
def _compute_energy_visibles(
4440
v: Tensor,
4541
vbias: Tensor,
@@ -54,7 +50,6 @@ def _compute_energy_visibles(
5450
return -field - quad_term + const
5551

5652

57-
@torch.jit.script
5853
def _compute_energy_hiddens(
5954
h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor
6055
) -> Tensor:
@@ -66,7 +61,6 @@ def _compute_energy_hiddens(
6661
return -field - log_term.sum(1) + quad
6762

6863

69-
@torch.jit.script
7064
def _compute_gradient(
7165
v_data: Tensor,
7266
h_data: Tensor,
@@ -113,12 +107,11 @@ def _compute_gradient(
113107
grad_hbias = h_data_mean - h_gen_mean
114108

115109
# Attach to the parameters
116-
weight_matrix.grad.set_(grad_weight_matrix)
117-
vbias.grad.set_(grad_vbias)
118-
hbias.grad.set_(grad_hbias)
110+
weight_matrix.grad = grad_weight_matrix
111+
vbias.grad = grad_vbias
112+
hbias.grad = grad_hbias
119113

120114

121-
@torch.jit.script
122115
def _init_chains(
123116
num_samples: int,
124117
weight_matrix: Tensor,

rbms/classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def named_parameters(self) -> dict[str, np.ndarray]: ...
353353
@abstractmethod
354354
def set_named_parameters(
355355
named_params: dict[str, np.ndarray],
356-
map_model: dict[str, EBM],
356+
map_model: dict[str, type[EBM]],
357357
device: torch.device | str,
358358
dtype: torch.dtype,
359359
) -> Sampler: ...

rbms/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def load_params(
6969
index: int,
7070
device: torch.device | str,
7171
dtype: torch.dtype,
72-
map_model: dict[str, EBM] = map_model,
72+
map_model: dict[str, type[EBM]] = map_model,
7373
) -> EBM:
7474
"""Load the parameters of the RBM from the specified archive at the given update index.
7575
@@ -97,7 +97,7 @@ def load_model(
9797
device: torch.device | str,
9898
dtype: torch.dtype,
9999
restore: bool = False,
100-
map_model: dict[str, EBM] = map_model,
100+
map_model: dict[str, type[EBM]] = map_model,
101101
) -> tuple[EBM, dict[str, Tensor], float]:
102102
"""Load a RBM from a h5 archive.
103103

rbms/ising_gaussian/classes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor:
114114
const=self.const,
115115
)
116116

117-
def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0):
117+
def compute_gradient(self, data, chains, centered=True):
118118
_compute_gradient(
119119
v_data=data["visible"],
120120
mh_data=data["hidden_mag"],
@@ -126,8 +126,6 @@ def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2
126126
hbias=self.hbias,
127127
weight_matrix=self.weight_matrix,
128128
centered=centered,
129-
lambda_l1=lambda_l1,
130-
lambda_l2=lambda_l2,
131129
)
132130

133131
def independent_model(self):

rbms/ising_gaussian/functional.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ def compute_gradient(
7676
chains: dict[str, Tensor],
7777
params: IGRBM,
7878
centered: bool,
79-
lambda_l1: float = 0.0,
80-
lambda_l2: float = 0.0,
8179
) -> None:
8280
_compute_gradient(
8381
v_data=data["visible"],
@@ -90,8 +88,6 @@ def compute_gradient(
9088
hbias=params.hbias,
9189
weight_matrix=params.weight_matrix,
9290
centered=centered,
93-
lambda_l1=lambda_l1,
94-
lambda_l2=lambda_l2,
9591
)
9692

9793

rbms/ising_gaussian/implement.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
from rbms.custom_fn import log2cosh
77

88

9-
@torch.jit.script
109
def _sample_hiddens(
1110
v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0
1211
) -> Tuple[Tensor, Tensor]:
1312
mh = hbias + (v @ weight_matrix)
14-
h = torch.randn_like(mh) / torch.sqrt(weight_matrix.shape[0]) + mh
13+
h = (
14+
torch.randn_like(mh) / torch.sqrt(torch.ones_like(mh) * weight_matrix.shape[0])
15+
+ mh
16+
)
1517
return h, mh
1618

1719

18-
@torch.jit.script
1920
def _sample_visibles(
2021
h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0
2122
) -> Tuple[Tensor, Tensor]:
@@ -25,7 +26,6 @@ def _sample_visibles(
2526
return v, mv
2627

2728

28-
@torch.jit.script
2929
def _compute_energy(
3030
v: Tensor,
3131
h: Tensor,
@@ -43,7 +43,6 @@ def _compute_energy(
4343
return -fields - interaction + quad
4444

4545

46-
@torch.jit.script
4746
def _compute_energy_visibles(
4847
v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor, const: Tensor
4948
) -> Tensor:
@@ -53,7 +52,6 @@ def _compute_energy_visibles(
5352
return -field - quad_term + const
5453

5554

56-
@torch.jit.script
5755
def _compute_energy_hiddens(
5856
h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor
5957
) -> Tensor:
@@ -65,7 +63,6 @@ def _compute_energy_hiddens(
6563
return -field - log_term.sum(1) + quad
6664

6765

68-
@torch.jit.script
6966
def _compute_gradient(
7067
v_data: Tensor,
7168
mh_data: Tensor,
@@ -121,12 +118,11 @@ def _compute_gradient(
121118
hbias.shape[0], device=hbias.device, dtype=hbias.dtype
122119
) # No training on biases
123120

124-
weight_matrix.grad.set_(grad_weight_matrix)
125-
vbias.grad.set_(grad_vbias)
126-
hbias.grad.set_(grad_hbias)
121+
weight_matrix.grad = grad_weight_matrix
122+
vbias.grad = grad_vbias
123+
hbias.grad = grad_hbias
127124

128125

129-
@torch.jit.script
130126
def _init_chains(
131127
num_samples: int,
132128
weight_matrix: Tensor,

0 commit comments

Comments (0)