Skip to content

Fine-tune VIT model with BYOL method #600

@khawar-islam

Description

@khawar-islam

I am fine-tuning the dataset on VIT using the below line

model = timm.create_model('vit_base_resnet50_384', pretrained=True, num_classes=7)
The accuracy is not that much good so I decided to integrate BYOL paper which is very easy to integrate with VIT.
https://github.com/lucidrains/byol-pytorch

Code

self.learner = BYOL(
            model,
            image_size=224,
            hidden_layer=model.cls_token
        )
        optimizer = optim.SGD(model, weight_decay=.0005, momentum=.9, nesterov=args.nesterov, lr=args.lr) 

    def _do_epoch(self, epoch=None):
        criterion = nn.CrossEntropyLoss()

        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(
                self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()

            data_flip = torch.flip(data, (3,)).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))

            class_logit = self.model(data, class_l, True, epoch)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss

            loss.backward()
            self.optimizer.step()
            self.learner.update_moving_average() #byol code

            self.logger.log(it, len(self.source_loader),
                            {"class": class_loss.item()},
                            {"class": torch.sum(cls_pred == class_l.data).item(), }, data.shape[0])
            del loss, class_


Traceback:

Traceback (most recent call last):
  File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 193, in <module>
    main()
  File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 187, in main
    trainer = Trainer(args, device)
  File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 87, in __init__
    hidden_layer=model.cls_token
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 211, in __init__
    self.forward(torch.randn(2, 3, image_size, image_size, device=device))
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 239, in forward
    online_proj_one, _ = self.online_encoder(image_one)
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 149, in forward
    representation = self.get_representation(x)
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 134, in get_representation
    if self.layer == -1:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancementNew feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions