diff --git a/uniception/models/encoders/dinov2.py b/uniception/models/encoders/dinov2.py index 275d4e3..df5204a 100644 --- a/uniception/models/encoders/dinov2.py +++ b/uniception/models/encoders/dinov2.py @@ -25,6 +25,7 @@ def __init__( norm_returned_features: bool = True, pretrained_checkpoint_path: str = None, torch_hub_force_reload: bool = False, + torch_hub_pretrained: bool = True, gradient_checkpointing: bool = False, keep_first_n_layers: Optional[int] = None, use_pytorch_sdpa=True, @@ -43,6 +44,7 @@ def __init__( with_registers (bool): Whether to use the DINOv2 model with registers. Default: False pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2. Default: None torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False + torch_hub_pretrained (bool): Whether to use the pretrained weights from torch hub. Default: True gradient_checkpointing (bool): Whether to use gradient checkpointing to save GPU memory during backward call. Default: False keep_first_n_layers (Optional[int]): If specified, only the first n layers of the model will be kept. Default: None use_pytorch_sdpa (bool): Whether to use PyTorch native SDPA for attention layers. Default: True @@ -90,9 +92,14 @@ def __init__( "facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version], force_reload=torch_hub_force_reload, + pretrained=torch_hub_pretrained if pretrained_checkpoint_path is None else False, ) except: # Load from cache - self.model = torch.hub.load("facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version]) + self.model = torch.hub.load( + "facebookresearch/dinov2", + DINO_MODELS[self.with_registers][self.version], + pretrained=torch_hub_pretrained if pretrained_checkpoint_path is None else False, + ) del ( self.model.mask_token diff --git a/uniception/models/encoders/dune.py b/uniception/models/encoders/dune.py index c7d1725..ffd2064 100644 --- a/uniception/models/encoders/dune.py +++ b/uniception/models/encoders/dune.py @@ -94,11 +94,13 @@ def __init__( "facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version], force_reload=torch_hub_force_reload, + pretrained=False, ) except: # Load from cache self.model = torch.hub.load( "facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version], + pretrained=False, ) del (