Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions clip_count/models/clip_count_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

from clip_count.models.models_crossvit import CrossAttentionBlock, ConvCrossAttentionBlock

Expand All @@ -13,8 +14,9 @@
import einops
import functools
import operator
class CLIPCount(nn.Module):
def __init__(self, fim_depth:int=4,

class CLIPCount(nn.Module, PyTorchModelHubMixin):
def __init__(self, fim_depth:int=4,
fim_num_heads:int=8,
mlp_ratio:float=4.,
norm_layer=nn.LayerNorm,
Expand Down Expand Up @@ -44,6 +46,9 @@ def __init__(self, fim_depth:int=4,
unfreeze_vit: whether to finetune all clip vit parameters.
"""
super().__init__()
self.repo_url = "https://github.com/songrise/CLIP-Count/tree/main"
self.pipeline_tag = "zero-shot-image-classification"
self.license = "mit"

# --------------------------------------------------------------------------
# MAE encoder specifics
Expand Down
317 changes: 295 additions & 22 deletions optimized_text_to_image_objects_count.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def collate_fn(examples):
prompt = [class_name.split()[-1]]

with torch.cuda.amp.autocast():
orig_output = counting_model.forward(image, prompt)
orig_output = counting_model(image, prompt)

scale_factor = extract_clip_count_scale_factor(image_out.detach(), orig_output[0].detach(), yolo, yolo_image_processor, config.yolo_threshold) if config.is_dynamic_scale_factor else config.scale
output = torch.sum(orig_output[0] / scale_factor)
Expand Down
4 changes: 2 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def prepare_counting_model(config: RunConfig):
from transformers import CLIPModel
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").cuda()
case "clip-count":
from clip_count.run import Model
model = Model.load_from_checkpoint("clip_count/clipcount_pretrained.ckpt", strict=False).cuda()
from clip_count.models.clip_count_model import CLIPCount
model = CLIPCount.from_pretrained("ozzafar/clip-count-base", strict=False).cuda()
model.eval()
return model

Expand Down