Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion uniception/models/encoders/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
norm_returned_features: bool = True,
pretrained_checkpoint_path: str = None,
torch_hub_force_reload: bool = False,
torch_hub_pretrained: bool = True,
gradient_checkpointing: bool = False,
keep_first_n_layers: Optional[int] = None,
use_pytorch_sdpa=True,
Expand All @@ -43,6 +44,7 @@ def __init__(
with_registers (bool): Whether to use the DINOv2 model with registers. Default: False
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2. Default: None
torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
torch_hub_pretrained (bool): Whether to use the pretrained weights from torch hub. Default: True
gradient_checkpointing (bool): Whether to use gradient checkpointing to save GPU memory during backward call. Default: False
keep_first_n_layers (Optional[int]): If specified, only the first n layers of the model will be kept. Default: None
use_pytorch_sdpa (bool): Whether to use PyTorch native SDPA for attention layers. Default: True
Expand Down Expand Up @@ -90,9 +92,14 @@ def __init__(
"facebookresearch/dinov2",
DINO_MODELS[self.with_registers][self.version],
force_reload=torch_hub_force_reload,
pretrained=torch_hub_pretrained if pretrained_checkpoint_path is None else False,
)
except: # Load from cache
self.model = torch.hub.load("facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version])
self.model = torch.hub.load(
"facebookresearch/dinov2",
DINO_MODELS[self.with_registers][self.version],
pretrained=torch_hub_pretrained if pretrained_checkpoint_path is None else False,
)

del (
self.model.mask_token
Expand Down
2 changes: 2 additions & 0 deletions uniception/models/encoders/dune.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ def __init__(
"facebookresearch/dinov2",
DINO_MODELS[self.with_registers][self.version],
force_reload=torch_hub_force_reload,
pretrained=False,
)
except: # Load from cache
self.model = torch.hub.load(
"facebookresearch/dinov2",
DINO_MODELS[self.with_registers][self.version],
pretrained=False,
)

del (
Expand Down