Description
Describe the bug
Hi, I am trying to train a deit_small_patch16_224 model on a Google Cloud TPU v3 VM instance using the bits_and_tpu branch. The strange thing is that, although the code runs, the final accuracy of the model is non-trivially lower than the counterpart trained on GPU: 73.1 vs 74.6 with Repeated Augmentation, and 74.7 vs over 76 without Repeated Augmentation (I manually merged the RAS sampler from the master branch). I wonder if there is any way to locate the problem and fix it?
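For context, this is roughly the repeated-augmentation sampling idea I merged in. It is a simplified sketch of my own, loosely following the RepeatAugSampler in timm's master branch, not the exact code from the repo:

```python
import math
import torch
from torch.utils.data import Sampler

class SimpleRepeatAugSampler(Sampler):
    """Simplified repeated-augmentation sampler (sketch only).

    Each selected sample is drawn `num_repeats` times per epoch, so
    different augmented views of the same image land in the same
    global batch across the `num_replicas` processes.
    """
    def __init__(self, dataset, num_replicas, rank, num_repeats=3, shuffle=True):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.shuffle = shuffle
        self.epoch = 0
        # total index count after repetition, padded so it divides evenly across replicas
        self.num_samples = int(math.ceil(len(dataset) * num_repeats / num_replicas))
        self.total_size = self.num_samples * num_replicas
        # each replica still only consumes ~len(dataset) / num_replicas samples per epoch
        self.num_selected_samples = int(math.floor(len(dataset) / num_replicas))

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        # repeat each index, keeping the repeats adjacent
        indices = [i for i in indices for _ in range(self.num_repeats)]
        # pad so the list splits evenly across replicas
        indices += indices[: self.total_size - len(indices)]
        # interleave across replicas, then truncate to the per-epoch budget
        indices = indices[self.rank : self.total_size : self.num_replicas]
        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
```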
To Reproduce
The training command I used is as follows (note that AMP is disabled, since AMP is not currently supported on TPU):
python3 launch_xla.py --num-devices 8 train.py /mnt/disks/persist/ImageNet --val-split val --model deit_small_patch16_224 --batch-size 128 --opt adamw --opt-eps 1e-8 --momentum 0.9 --weight-decay 0.05 --sched cosine --lr 5e-4 --lr-cycle-decay 1.0 --warmup-lr 1e-6 --epochs 100 --decay-epochs 30 --cooldown-epochs 0 --aa rand-m9-mstd0.5-inc1 (--aug-repeats 3) --reprob 0.25 --mixup 0.8 --cutmix 1.0 --smoothing 0.1 --train-interpolation bicubic --drop-path 0.1 --workers 12 --seed 0 --pin-mem --output output --experiment deit_small_patch16_224_default_noamp_epoch100 --log-interval 200 --recovery-interval 12500
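For anyone reproducing this outside the repo's launcher: as far as I understand, launch_xla.py --num-devices 8 amounts to spawning one training process per TPU core, roughly like the sketch below (_train_one_process is a placeholder, not the actual train.py entry point):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _train_one_process(index):
    # each spawned process gets its own XLA device (one TPU core)
    device = xm.xla_device()
    xm.master_print(f'process {index} running on {device}')
    # ... build model / data / optimizer and run the training loop here ...

if __name__ == '__main__':
    # rough equivalent of `launch_xla.py --num-devices 8`
    xmp.spawn(_train_one_process, args=(), nprocs=8)
```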
Training Log
Here is the log for repeat-aug-training on TPU:
deit_small_patch16_224_repeated_aug_epoch100.csv
Here is the log for repeat-aug-training on GPU:
deit_small_patch16_224_repeat_aug_epoch100.txt
Here is the log for no-repeat-aug-training on TPU:
deit_small_patch16_224_no_repeated_aug_epoch100.csv
Sorry, I don't have an exact log for no-repeat-aug training on GPU, but it should closely resemble this:
similar_deit_small_patch16_224_no_repeat_aug_epoch100.txt
Server
- TPU V3
- Ubuntu 20.04.2 LTS
- PyTorch 1.9.0, PyTorch/XLA 1.9
Additional Comment
Maybe there is some data augmentation or optimization strategy I am missing here. However, I have also tried modifying the original DeiT codebase with the same PyTorch/XLA technique (so that the data augmentations and optimization are exactly the same as in the original DeiT), and the training results are still consistently lower than the GPU counterpart for 100-epoch training (around 2% absolute). For 300-epoch training, the TPU model's accuracy is lower than the GPU model's for the first 80-100 epochs, but catches up by around the 150th epoch.
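For reference, the "PyTorch/XLA technique" I applied to the DeiT codebase is essentially the standard adaptation below (a simplified sketch; model, loader, optimizer, and criterion stand in for DeiT's actual objects):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def train_one_epoch(model, loader, optimizer, criterion):
    device = xm.xla_device()
    model = model.to(device)
    # wrap the existing DataLoader so batches are moved onto the TPU core
    device_loader = pl.MpDeviceLoader(loader, device)
    model.train()
    for images, targets in device_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        # all-reduces gradients across cores, then steps the optimizer
        xm.optimizer_step(optimizer)
```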