Merged
26 changes: 11 additions & 15 deletions baselines/jft/experiments/jft300m_vit_base16_sngp.py
@@ -40,17 +40,17 @@ def get_config():

pp_common = '|value_range(-1, 1)'
pp_common += f'|onehot({config.num_classes})'
# To use ancestor "smearing", use this line instead:
# pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels") # pylint: disable=line-too-long
# To use ancestor 'smearing', use this line instead:
# pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long
pp_common += '|keep("image", "labels")'
config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.

config.log_training_steps = 50
config.log_eval_steps = 1000
# NOTE: eval is very fast O(seconds) so it's fine to run it often.
config.checkpoint_steps = 1000
# NOTE: For pretraining, save infrequently to prevent crowding diskspace.
config.checkpoint_steps = 517790

# Model section
config.model = ml_collections.ConfigDict()
@@ -66,11 +66,11 @@ def get_config():
config.model.classifier = 'token' # Or 'gap'
config.model.representation_size = 768

# GP layer parameters.
# Gaussian process layer parameters.
config.gp_layer = ml_collections.ConfigDict()
config.gp_layer.normalize_input = True
config.gp_layer.random_feature_scale = 1. # 1. or None
config.gp_layer.random_feature_stddev = 0.025 # 1. or 0.025
# Use momentum for pre-training to prevent numeric error when inverting a
# precision matrix accumulated over 300M data.
config.gp_layer.covmat_momentum = .999

# Optimizer section
config.optim_name = 'Adam'
@@ -82,7 +82,8 @@

# TODO(lbeyer): make a mini-language like preprocessings.
config.lr = ml_collections.ConfigDict()
config.lr.base = 8e-4 # LR has to be lower for larger models!
# LR has to be lower for GP layer and on larger models.
config.lr.base = 4e-4
config.lr.warmup_steps = 10_000
config.lr.decay_type = 'linear'
config.lr.linear_end = 1e-5
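
The `config.lr` block above only names the schedule; as a reference, here is a minimal sketch of the usual reading of these fields: linear warmup from 0 to `base`, then linear decay down to `linear_end`. This is an illustration under that assumption, not the schedule code used by the repo, and `total_steps` is a hypothetical argument standing in for the training length.

    import jax.numpy as jnp

    def linear_warmup_then_decay(step, total_steps, base=4e-4,
                                 warmup_steps=10_000, linear_end=1e-5):
      """Linear warmup from 0 to `base`, then linear decay to `linear_end`."""
      step = jnp.minimum(step, total_steps)
      warmup_lr = base * step / warmup_steps
      decay_frac = (step - warmup_steps) / (total_steps - warmup_steps)
      decay_lr = base + decay_frac * (linear_end - base)
      return jnp.where(step < warmup_steps, warmup_lr, decay_lr)
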
@@ -96,9 +97,4 @@


def get_sweep(hyper):
# lr_grid = [3e-4, 4e-4, 5e-4, 6e-4]
# stddev_grid = [0.01, 0.02, 0.03, 0.04, 0.05]
return hyper.product([
# hyper.sweep('config.lr.base', lr_grid),
# hyper.sweep('config.gp_layer.random_feature_stddev', stddev_grid)
])
return hyper.product([])
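
A note on the `covmat_momentum` setting above: the in-config comment explains that momentum is used so that the precision matrix stays numerically well behaved when accumulated over roughly 300M examples. As a rough illustration only (an assumed sketch of the momentum idea, not the edward2/flax covariance layer itself), the per-minibatch update looks like the following, where an exponential moving average replaces exact summation:

    import jax.numpy as jnp

    def update_precision(precision, phi, momentum=0.999):
      # precision: [D, D] running precision matrix over D random features.
      # phi:       [B, D] random-feature activations for one minibatch.
      minibatch_precision = phi.T @ phi
      if momentum is None:
        # Exact accumulation: the matrix keeps growing with the number of
        # examples seen, so inverting it after ~300M examples becomes
        # ill-conditioned.
        return precision + minibatch_precision
      # Exponential moving average keeps the matrix at a bounded scale.
      return momentum * precision + (1. - momentum) * minibatch_precision
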
42 changes: 16 additions & 26 deletions baselines/jft/sngp.py
@@ -27,7 +27,6 @@
from clu import periodic_actions
import flax
import flax.jax_utils as flax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
@@ -72,7 +71,7 @@ def accumulate_gradient_with_states(
accum_steps):
"""Improved version of `u.accumulate_gradient()` that allows for states."""
# This function handles the `loss_and_grad_fn` function which takes a state
# arguement and returns ((losses, states), grads).
# argument and returns ((losses, states), grads).
if accum_steps and accum_steps > 1:
assert images.shape[0] % accum_steps == 0, (
f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
@@ -102,27 +101,16 @@ def acc_grad_and_loss(i, l_s_g):


def get_gp_kwargs(gp_config):
"""Extract keyword arguement parameters for the Gaussian process layer."""
normalize_input = gp_config.get('normalize_input', True)
kernel_stddev = gp_config.get('random_feature_stddev', 1.)
feature_scale = gp_config.get('random_feature_scale', -1.)
"""Extract keyword argument parameters for the Gaussian process layer."""
covmat_momentum = gp_config.get('covmat_momentum', 0.999)

logging.info('gp_config.normalize_input = %s', normalize_input)
logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
logging.info('gp_config.random_feature_scale = %s', feature_scale)
# Extracts model parameter.
logging.info('gp_config.covmat_momentum = %s', covmat_momentum)

feature_scale = None if feature_scale < 0. else feature_scale
kernel_init = nn.initializers.normal(stddev=kernel_stddev)
hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
covmat_kwargs = dict(momentum=covmat_momentum)

# Assemble into kwargs dictionary.
gp_layer_kwargs = dict(
normalize_input=normalize_input,
hidden_kwargs=hidden_kwargs,
covmat_kwargs=covmat_kwargs)
# Assembles into kwargs dictionary.
gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)

return gp_layer_kwargs

@@ -337,7 +325,7 @@ def representation_fn(params, images, labels, mask, states):
@partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
def update_fn(opt, states, lr, images, labels, rng):
"""Update step."""

# TODO(jereliu): Expand to allow precision matrix resetting.
measurements = {}

if config.get('mixup') and config.mixup.p:
@@ -423,17 +411,17 @@ def decay_fn(v, wd):
checkpoint['states'],
checkpoint['extra'])
elif config.get('model_init'):
write_note(f'Initialize model from {config.model_init}...')
raise ValueError(
'Load from `config.model_init` checkpoint is currently not supported.')
# Load trainable parameters from the checkpoint.
# This does not cause issue for SNGP since all non-trainable parameters
# (random feature, precision matrix, etc) are last-layer parameters that
# should be re-trained during fine-tuning.
write_note(f'Initialize trainable parameters from {config.model_init}...')
# TODO(dusenberrymw): Replace and test load function.
# pylint:disable=unreachable
loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
opt_cpu = opt_cpu.replace(target=loaded)
if jax.host_id() == 0:
logging.info('Restored parameter overview:')
parameter_overview.log_parameter_overview(loaded)
# pylint:enable=unreachable

write_note('Kicking off misc stuff...')
first_step = int(opt_cpu.state.step) # Might be a DeviceArray type.
@@ -482,6 +470,7 @@ def decay_fn(v, wd):
mw.step_start(step)

with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
# TODO(jereliu): Expand to allow precision matrix resetting.
(opt_repl, states_repl, loss_value, rngs_loop,
extra_measurements) = update_fn(
opt_repl,
@@ -505,8 +494,9 @@
# alive while they'll be updated in a future step, creating hard to debug
# memory errors (see b/160593526). Also, takes device 0's params only.
# We will also do the same for untrainable parameters (`states`). This is
# ok since both `random features` and `predictive covariance` are frozen
# or task-specific parameters that are not important for pre-training.
# ok since `random features` are frozen throughout pre-training, and
# `predictive covariance` are irrelevant for downstream finetuning and
# will be discarded anyway.
opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

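For reference, a short usage sketch of the simplified `get_gp_kwargs` in this file. The ConfigDict values are illustrative, the import path is an assumption, and the expected outputs follow directly from the new function body (a negative momentum is mapped to None, i.e. exact, un-discounted accumulation):

    import ml_collections
    from baselines.jft.sngp import get_gp_kwargs  # assumed import path

    gp_config = ml_collections.ConfigDict()
    gp_config.covmat_momentum = .999
    print(get_gp_kwargs(gp_config))   # {'covmat_kwargs': {'momentum': 0.999}}

    gp_config.covmat_momentum = -1.   # disable momentum
    print(get_gp_kwargs(gp_config))   # {'covmat_kwargs': {'momentum': None}}
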
8 changes: 4 additions & 4 deletions baselines/jft/sngp_test.py
@@ -115,10 +115,10 @@ def get_config(classifier, representation_size):
class SNGPTest(parameterized.TestCase, tf.test.TestCase):

@parameterized.parameters(
('token', 2, 1111.4404296875, 16258.519965277777, 0.16999999806284904),
('token', None, 13992.8515625, 3621.3713107638887, 0.20999999344348907),
('gap', 2, 8779.61328125, 3998.798285590278, 0.12999999895691872),
('gap', None, 11279.3515625, 3212.2536892361113, 0.2199999988079071),
('token', 2, 916.2851, 1954.3369140625, 0.16999999806284904),
('token', None, 290.0307, 915.987548828125, 0.20999999344348907),
('gap', 2, 695.6460, 600.8613823784722, 0.12999999895691872),
('gap', None, 192.9434, 341.7078450520833, 0.2199999988079071),
)
def test_sngp_script(self, classifier, representation_size,
correct_train_loss, correct_val_loss,