@@ -72,7 +72,7 @@ def accumulate_gradient_with_states(
7272 accum_steps ):
7373 """Improved version of `u.accumulate_gradient()` that allows for states."""
7474 # This function handles the `loss_and_grad_fn` function which takes a state
75- # arguement and returns ((losses, states), grads).
75+ # argument and returns ((losses, states), grads).
7676 if accum_steps and accum_steps > 1 :
7777 assert images .shape [0 ] % accum_steps == 0 , (
7878 f'Bad accum_steps { accum_steps } for batch size { images .shape [0 ]} ' )
@@ -102,27 +102,16 @@ def acc_grad_and_loss(i, l_s_g):
102102
103103
104104def get_gp_kwargs (gp_config ):
105- """Extract keyword arguement parameters for the Gaussian process layer."""
106- normalize_input = gp_config .get ('normalize_input' , True )
107- kernel_stddev = gp_config .get ('random_feature_stddev' , 1. )
108- feature_scale = gp_config .get ('random_feature_scale' , - 1. )
105+ """Extract keyword argument parameters for the Gaussian process layer."""
109106 covmat_momentum = gp_config .get ('covmat_momentum' , 0.999 )
110107
111- logging .info ('gp_config.normalize_input = %s' , normalize_input )
112- logging .info ('gp_config.random_feature_stddev = %s' , kernel_stddev )
113- logging .info ('gp_config.random_feature_scale = %s' , feature_scale )
108+ # Extracts model parameter.
114109 logging .info ('gp_config.covmat_momentum = %s' , covmat_momentum )
115-
116- feature_scale = None if feature_scale < 0. else feature_scale
117- kernel_init = nn .initializers .normal (stddev = kernel_stddev )
118- hidden_kwargs = dict (feature_scale = feature_scale , kernel_init = kernel_init )
110+ covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
119111 covmat_kwargs = dict (momentum = covmat_momentum )
120112
121- # Assemble into kwargs dictionary.
122- gp_layer_kwargs = dict (
123- normalize_input = normalize_input ,
124- hidden_kwargs = hidden_kwargs ,
125- covmat_kwargs = covmat_kwargs )
113+ # Assembles into kwargs dictionary.
114+ gp_layer_kwargs = dict (covmat_kwargs = covmat_kwargs )
126115
127116 return gp_layer_kwargs
128117
@@ -337,7 +326,7 @@ def representation_fn(params, images, labels, mask, states):
337326 @partial (jax .pmap , axis_name = 'batch' , donate_argnums = (0 ,))
338327 def update_fn (opt , states , lr , images , labels , rng ):
339328 """Update step."""
340-
329+ # TODO(jereliu): Expand to allow precision matrix resetting.
341330 measurements = {}
342331
343332 if config .get ('mixup' ) and config .mixup .p :
@@ -423,17 +412,17 @@ def decay_fn(v, wd):
423412 checkpoint ['states' ],
424413 checkpoint ['extra' ])
425414 elif config .get ('model_init' ):
426- write_note (f'Initialize model from { config .model_init } ...' )
427- raise ValueError (
428- 'Load from `config.model_init` checkpoint is currently not supported.' )
415+ # Load trainable parameters from the checkpoint.
416+ # This does not cause issues for SNGP since all non-trainable parameters
417+ # (random feature, precision matrix, etc) are last-layer parameters that
418+ # should be re-trained during fine-tuning.
419+ write_note (f'Initialize trainable parameters from { config .model_init } ...' )
429420 # TODO(dusenberrymw): Replace and test load function.
430- # pylint:disable=unreachable
431421 loaded = resformer .load (params_cpu , config .model_init , config .get ('model' ))
432422 opt_cpu = opt_cpu .replace (target = loaded )
433423 if jax .host_id () == 0 :
434424 logging .info ('Restored parameter overview:' )
435425 parameter_overview .log_parameter_overview (loaded )
436- # pylint:enable=unreachable
437426
438427 write_note ('Kicking off misc stuff...' )
439428 first_step = int (opt_cpu .state .step ) # Might be a DeviceArray type.
@@ -482,6 +471,7 @@ def decay_fn(v, wd):
482471 mw .step_start (step )
483472
484473 with jax .profiler .TraceContext ('train_step' , step_num = step , _r = 1 ):
474+ # TODO(jereliu): Expand to allow precision matrix resetting.
485475 (opt_repl , states_repl , loss_value , rngs_loop ,
486476 extra_measurements ) = update_fn (
487477 opt_repl ,
@@ -505,8 +495,9 @@ def decay_fn(v, wd):
505495 # alive while they'll be updated in a future step, creating hard to debug
506496 # memory errors (see b/160593526). Also, takes device 0's params only.
507497 # We will also do the same for untrainable parameters (`states`). This is
508- # ok since both `random features` and `predictive covariance` are frozen
509- # or task-specific parameters that are not important for pre-training.
498+ # ok since `random features` are frozen throughout pre-training, and
499+ # `predictive covariance` is irrelevant for downstream fine-tuning and
500+ # will be discarded anyway.
510501 opt_cpu = jax .tree_map (lambda x : np .array (x [0 ]), opt_repl )
511502 states_cpu = jax .tree_map (lambda x : np .array (x [0 ]), states_repl )
512503
0 commit comments