from rbms.custom_fn import log2cosh
77
88
9- @torch .jit .script
109def _sample_hiddens (
1110 v : Tensor , weight_matrix : Tensor , hbias : Tensor , beta : float = 1.0
1211) -> Tuple [Tensor , Tensor ]:
1312 mh = hbias + (v @ weight_matrix )
14- h = torch .randn_like (mh ) / torch .sqrt (weight_matrix .shape [0 ]) + mh
13+ h = (
14+ torch .randn_like (mh ) / torch .sqrt (torch .ones_like (mh ) * weight_matrix .shape [0 ])
15+ + mh
16+ )
1517 return h , mh
1618
1719
18- @torch .jit .script
1920def _sample_visibles (
2021 h : Tensor , weight_matrix : Tensor , vbias : Tensor , beta : float = 1.0
2122) -> Tuple [Tensor , Tensor ]:
@@ -25,7 +26,6 @@ def _sample_visibles(
2526 return v , mv
2627
2728
28- @torch .jit .script
2929def _compute_energy (
3030 v : Tensor ,
3131 h : Tensor ,
@@ -43,7 +43,6 @@ def _compute_energy(
4343 return - fields - interaction + quad
4444
4545
46- @torch .jit .script
4746def _compute_energy_visibles (
4847 v : Tensor , vbias : Tensor , hbias : Tensor , weight_matrix : Tensor , const : Tensor
4948) -> Tensor :
@@ -53,7 +52,6 @@ def _compute_energy_visibles(
5352 return - field - quad_term + const
5453
5554
56- @torch .jit .script
5755def _compute_energy_hiddens (
5856 h : Tensor , vbias : Tensor , hbias : Tensor , weight_matrix : Tensor
5957) -> Tensor :
@@ -65,7 +63,6 @@ def _compute_energy_hiddens(
6563 return - field - log_term .sum (1 ) + quad
6664
6765
68- @torch .jit .script
6966def _compute_gradient (
7067 v_data : Tensor ,
7168 mh_data : Tensor ,
@@ -121,12 +118,11 @@ def _compute_gradient(
121118 hbias .shape [0 ], device = hbias .device , dtype = hbias .dtype
122119 ) # No training on biases
123120
124- weight_matrix .grad . set_ ( grad_weight_matrix )
125- vbias .grad . set_ ( grad_vbias )
126- hbias .grad . set_ ( grad_hbias )
121+ weight_matrix .grad = grad_weight_matrix
122+ vbias .grad = grad_vbias
123+ hbias .grad = grad_hbias
127124
128125
129- @torch .jit .script
130126def _init_chains (
131127 num_samples : int ,
132128 weight_matrix : Tensor ,
0 commit comments