# filter_holodeck.py
import compress_pickle
import open_clip
import torch
from sentence_transformers import SentenceTransformer
import numpy as np
import torch.nn.functional as F
import pickle
import argparse
import os
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--path_to_asset", type=str, default="~/.objathor-assets")
    args = parser.parse_args()
    # Expand a leading "~" so the default asset path resolves when the files are opened below.
    args.path_to_asset = os.path.expanduser(args.path_to_asset)
    return args


args = parse_args()
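# Load the precomputed objathor asset features: CLIP image features and SBERT
# text features, both keyed by asset uid.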
objathor_clip_features_dict = compress_pickle.load(f"{args.path_to_asset}/2023_09_23/features/clip_features.pkl")
objathor_sbert_features_dict = compress_pickle.load(f"{args.path_to_asset}/2023_09_23/features/sbert_features.pkl")
with open("./scannet_classes.txt","r") as f:
scannet_cat = f.readlines()
scannet_cat = [item.strip() for item in scannet_cat]
with open("./scannetpp_classes.txt","r") as f:
scannetpp_cat = f.readlines()
scannetpp_cat = [item.strip() for item in scannetpp_cat]
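# Keep only the ScanNet++ categories that do not already appear in ScanNet,
# and turn each one into a retrieval prompt.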
scannetpp_cat = list(set(scannetpp_cat)-set(scannet_cat))
scannetpp_cat = ["a 3D model of "+item.strip() for item in scannetpp_cat]
objathor_uids = objathor_clip_features_dict["uids"]
objathor_clip_features = torch.from_numpy(objathor_clip_features_dict["img_features"].astype(np.float32))
objathor_clip_features = F.normalize(objathor_clip_features, p=2, dim=-1)
objathor_sbert_features = torch.from_numpy(objathor_sbert_features_dict["text_features"].astype(np.float32))
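# Text encoders used to embed the category prompts.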
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms("ViT-L-14", pretrained="laion2b_s32b_b82k")
clip_tokenizer = open_clip.get_tokenizer("ViT-L-14")
sbert_model = SentenceTransformer("all-mpnet-base-v2", device="cpu")
with torch.no_grad():
    query_feature_clip = clip_model.encode_text(
        clip_tokenizer(scannetpp_cat)
    )
    query_feature_clip = F.normalize(query_feature_clip, p=2, dim=-1)
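# CLIP score: compare each prompt against every image feature of every asset,
# then keep the best match per asset.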
clip_similarities = 100 * torch.einsum(
    "ij, lkj -> ilk", query_feature_clip, objathor_clip_features
)
clip_similarities = torch.max(clip_similarities, dim=-1).values
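# SBERT score between the prompts and the asset text embeddings.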
query_feature_sbert = sbert_model.encode(scannetpp_cat, convert_to_tensor=True, show_progress_bar=False)
sbert_similarities = query_feature_sbert @ objathor_sbert_features.T
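# Combine both scores; any asset whose combined score with some ScanNet++-only
# category reaches 35 is dropped, and the remaining assets are kept.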
similarities = clip_similarities + sbert_similarities
row_indices, col_indices = torch.where(similarities >= 35)
indices_left = list(set(range(len(objathor_uids)))-set(col_indices.numpy()))
objathor_uids_left = [objathor_uids[indice] for indice in indices_left]
objathor_clip_features_left = objathor_clip_features_dict["img_features"][indices_left]
objathor_sbert_features_left = objathor_sbert_features_dict["text_features"][indices_left]
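# Write the filtered feature dicts back over the original files.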
new_objathor_clip_features_dict = {"uids":objathor_uids_left, "img_features":objathor_clip_features_left}
new_objathor_sbert_features_dict = {"uids":objathor_uids_left, "text_features":objathor_sbert_features_left}
with open(f"{args.path_to_asset}/2023_09_23/features/clip_features.pkl","wb") as f:
pickle.dump(new_objathor_clip_features_dict, f)
with open(f"{args.path_to_asset}/2023_09_23/features/sbert_features.pkl","wb") as f:
pickle.dump(new_objathor_sbert_features_dict, f)