2121
2222import edward2 .jax as ed
2323
24+ import flax .linen as nn
25+
2426import jax
2527import jax .numpy as jnp
2628import numpy as np
@@ -94,6 +96,10 @@ def setUp(self):
9496 self .x_test = _generate_normal_data (
9597 self .num_test_sample , self .num_data_dim , seed = 21 )
9698
99+ # Uses classic RBF random feature distribution.
100+ self .hidden_kwargs = dict (
101+ kernel_init = nn .initializers .normal (stddev = 1. ), feature_scale = None )
102+
97103 self .rbf_approx_maximum_tol = 5e-3
98104 self .rbf_approx_average_tol = 5e-4
99105 self .primal_dual_maximum_diff = 1e-6
@@ -105,6 +111,7 @@ def one_step_rfgp_result(self, train_data, test_data, **eval_kwargs):
105111 features = 1 ,
106112 hidden_features = self .num_random_features ,
107113 normalize_input = False ,
114+ hidden_kwargs = self .hidden_kwargs ,
108115 covmat_kwargs = dict (ridge_penalty = self .ridge_penalty ))
109116
110117 # Computes posterior covariance on test data.
@@ -231,13 +238,19 @@ def setUp(self):
231238 self .x_test = _generate_normal_data (self .num_train_sample ,
232239 self .num_data_dim )
233240
241+ # Uses classic RBF random feature distribution.
242+ self .hidden_kwargs = dict (
243+ kernel_init = nn .initializers .normal (stddev = 1. ), feature_scale = None )
244+
234245 self .kernel_approx_tolerance = dict (atol = 5e-2 , rtol = 1e-2 )
235246
236247 def test_random_feature_mutable_collection (self ):
237248 """Tests if RFF variables are properly nested under a mutable collection."""
238249 rng = jax .random .PRNGKey (self .seed )
239250 rff_layer = ed .nn .RandomFourierFeatures (
240- features = self .num_random_features , collection_name = self .collection_name )
251+ features = self .num_random_features ,
252+ collection_name = self .collection_name ,
253+ ** self .hidden_kwargs )
241254
242255 # Computes forward pass with mutable collection specified.
243256 init_vars = rff_layer .init (rng , self .x_train )
@@ -260,7 +273,8 @@ def test_random_feature_mutable_collection(self):
260273 def test_random_feature_nd_input (self , input_shape ):
261274 rng = jax .random .PRNGKey (self .seed )
262275 x = jnp .ones (input_shape )
263- rff_layer = ed .nn .RandomFourierFeatures (features = self .num_random_features )
276+ rff_layer = ed .nn .RandomFourierFeatures (
277+ features = self .num_random_features , ** self .hidden_kwargs )
264278 y , _ = rff_layer .init_with_output (rng , x )
265279
266280 expected_output_shape = input_shape [:- 1 ] + (self .num_random_features ,)
@@ -270,7 +284,9 @@ def test_random_feature_kernel_approximation(self):
270284 """Tests if default RFF layer approximates a RBF kernel matrix."""
271285 rng = jax .random .PRNGKey (self .seed )
272286 rff_layer = ed .nn .RandomFourierFeatures (
273- features = self .num_random_features , collection_name = self .collection_name )
287+ features = self .num_random_features ,
288+ collection_name = self .collection_name ,
289+ ** self .hidden_kwargs )
274290
275291 # Extracts random features by computing forward pass.
276292 init_vars = rff_layer .init (rng , self .x_train )
0 commit comments