diff --git a/examples/image_classification/imagenet_resnet50_data.py b/examples/image_classification/imagenet_resnet50_data.py index 9fdd023..574d399 100644 --- a/examples/image_classification/imagenet_resnet50_data.py +++ b/examples/image_classification/imagenet_resnet50_data.py @@ -83,10 +83,7 @@ def load(split: Split, is_training: bool, batch_dims: Sequence[int], tfds_data_d if is_training: ds = ds.repeat() ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0) - else: - if split.num_examples % total_batch_size != 0: - raise ValueError(f'Test set size must be divisible by {total_batch_size}') - + def preprocess(example): image = _preprocess_image(example['image'], is_training) image = tf.transpose(image, (2, 0, 1)) # transpose HWC image to CHW format