File size: 9,106 Bytes
92ef913
 
 
 
 
 
 
 
711211a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92ef913
711211a
 
 
 
 
 
 
 
 
 
395d6df
 
 
 
711211a
 
 
 
 
 
 
 
395d6df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92ef913
395d6df
 
 
 
 
 
 
711211a
 
 
 
 
 
 
 
66ba241
395d6df
 
 
711211a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395d6df
711211a
 
 
395d6df
711211a
395d6df
711211a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
try:
    import spaces
    gpu_decorator = spaces.GPU
except ImportError:
    def gpu_decorator(func=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` package is unavailable.

        Accepts the same call forms as ``spaces.GPU`` so decorated call sites work
        either way: ``@gpu_decorator``, ``@gpu_decorator()`` and
        ``@gpu_decorator(duration=...)`` all leave the function unchanged.
        """
        if func is None:
            # Used with (possibly empty) arguments: hand back the real decorator.
            return lambda fn: fn
        return func
    
import PIL
import torch

from .prompts import GetPromptList

# Canonical part order: used as the part-query text and as the order in which
# per-part descriptions, masks and scores are arranged by `xclip_pred`.
ORG_PART_ORDER = [
    'back',
    'beak',
    'belly',
    'breast',
    'crown',
    'forehead',
    'eyes',
    'legs',
    'wings',
    'nape',
    'tail',
    'throat',
]
# The same twelve parts listed roughly head-to-tail. Not referenced in this
# section; presumably a display order used elsewhere (confirm).
ORDERED_PARTS = [
    'crown',
    'forehead',
    'nape',
    'eyes',
    'beak',
    'throat',
    'breast',
    'belly',
    'back',
    'wings',
    'legs',
    'tail',
]

def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: list[str], device: str, max_batch_size: int = 512):
    """Encode text descriptions with ``model.owlvit.get_text_features`` in mini-batches.

    Args:
        owlvit_det_processor: processor called as
            ``owlvit_det_processor(text=..., padding="max_length", truncation=True, return_tensors="pt")``;
            its result must support ``.to(device)`` and ``**`` unpacking.
        model: model exposing ``model.owlvit.get_text_features``.
        descs: descriptions to encode, in order.
        device: device the tokens are moved to and the result is returned on.
        max_batch_size: maximum number of descriptions per forward pass.

    Returns:
        The per-batch text features concatenated along dim 0 (in ``descs``
        order), cast to float32 and placed on ``device``.

    Raises:
        ValueError: if ``descs`` is empty or ``max_batch_size`` is not positive.
    """
    if max_batch_size <= 0:
        raise ValueError(f"max_batch_size must be positive, got {max_batch_size}")
    if not descs:
        raise ValueError("descs must contain at least one description")

    text_embeds = []
    with torch.no_grad():
        # Step through descs in chunks; unlike `len // size + 1` batches, this
        # never sends an empty trailing batch to the processor.
        for start in range(0, len(descs), max_batch_size):
            query_descs = descs[start:start + max_batch_size]
            query_tokens = owlvit_det_processor(text=query_descs, padding="max_length", truncation=True, return_tensors="pt").to(device)
            query_embeds = model.owlvit.get_text_features(**query_tokens)
            # Accumulate on CPU (presumably to bound device memory across batches).
            text_embeds.append(query_embeds.cpu().float())
    return torch.cat(text_embeds, dim=0).to(device)

# def encode_descs_clip(model: callable, descs: list[str], device: str, max_batch_size: int = 512):
#     total_num_batches = len(descs) // max_batch_size + 1
#     with torch.no_grad():
#         text_embeds = []
#         for batch_idx in range(total_num_batches):
#             desc = descs[batch_idx*max_batch_size:(batch_idx+1)*max_batch_size]
#             query_tokens = clip.tokenize(desc).to(device)
#             text_embeds.append(model.encode_text(query_tokens).cpu().float())
#     text_embeds = torch.cat(text_embeds, dim=0)
#     text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
#     return text_embeds.to(device)
@gpu_decorator
def xclip_pred(new_desc: dict, 
               new_part_mask: dict, 
               new_class: str, 
               org_desc: str, 
               image: PIL.Image, 
               model: callable, 
               owlvit_processor: callable,
               device: str,
               return_img_embeds: bool = False,
               use_precompute_embeddings = True,
               image_name: str = None,
               cub_embeds: torch.Tensor = None,
               cub_idx2name: dict = None,
               descriptors: dict = None):
    """Classify ``image`` by summing per-part description scores, optionally with one edited class.

    Two modes are supported:

    * ``cub_embeds is None``: class descriptions are built from ``org_desc`` via
      ``GetPromptList`` and encoded on the fly.
    * ``cub_embeds`` given: precomputed text embeddings are used for the existing
      classes (``cub_idx2name`` maps class index -> name, ``descriptors`` maps
      class name -> descriptions); only the new class, if any, is encoded here.

    Args:
        new_desc: part name -> description for the added/edited class; must
            contain every key of ``ORG_PART_ORDER``. Ignored if ``new_class`` is None.
        new_part_mask: part name -> value multiplied into that part's score;
            same keys as ``new_desc``. Ignored if ``new_class`` is None.
        new_class: name of the class to add or override, or None to classify
            against the existing classes only. In precomputed mode, a name that
            clashes with an existing class is renamed to ``"<name>_custom"``.
        org_desc: description source passed to ``GetPromptList`` (on-the-fly mode only).
        image: input image; only used when ``use_precompute_embeddings`` is False.
        model: model exposing ``owlvit.get_text_features``, ``image_embedder``,
            ``cls_head`` and a forward returning ``(pred_logits, part_logits, extra)``.
        owlvit_processor: processor used for both text and image inputs.
        device: device the tensors are moved to.
        return_img_embeds: also return the image embeddings that were used.
        use_precompute_embeddings: load image embeddings from
            ``data/image_embeddings/{image_name}.pt`` instead of encoding ``image``.
        image_name: file stem of the precomputed image embedding.
        cub_embeds: precomputed text embeddings of the existing classes.
        cub_idx2name: class index -> class name for ``cub_embeds``.
        descriptors: class name -> descriptions used to fill the output. It is
            never modified; the edited class is merged into a copy.

    Returns:
        A dict with keys ``pred_class``, ``pred_score``, ``pred_desc_scores``,
        ``descriptions``, ``modified_class``, ``modified_score``,
        ``modified_desc_scores`` and ``modified_descriptions``; or the tuple
        ``(that_dict, image_embeds)`` when ``return_img_embeds`` is True.

    Raises:
        ValueError: if ``cub_embeds`` is given without ``cub_idx2name`` and ``descriptors``.

    Note:
        Sets ``model.cls_head.num_classes`` as a side effect.
    """
    # Reorder the edited class's descriptions and part mask into the canonical part order.
    if new_class is not None:
        new_descs = [new_desc[part] for part in ORG_PART_ORDER]
        desc_mask = [new_part_mask[part] for part in ORG_PART_ORDER]
    else:
        desc_mask = [1] * len(ORG_PART_ORDER)

    if cub_embeds is None:
        # Replace the class's descriptions if it already exists, otherwise register a new class.
        getprompt = GetPromptList(org_desc)
        if new_class is not None:
            if new_class not in getprompt.desc:
                getprompt.name2idx[new_class] = len(getprompt.name2idx)
            getprompt.desc[new_class] = new_descs

        idx2name = {idx: name for name, idx in getprompt.name2idx.items()}
        modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None

        n_classes = len(getprompt.name2idx)
        descs, *_ = getprompt('chatgpt-no-template', max_len=12, pad=True)
        query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)

        # The output needs the descriptions of whichever class is predicted,
        # including the edited one; never mutate the caller's dict.
        if descriptors is None:
            descriptors = dict(getprompt.desc)
        elif new_class is not None:
            descriptors = descriptors | {new_class: new_descs}
    else:
        if cub_idx2name is None or descriptors is None:
            raise ValueError("cub_idx2name and descriptors are required when cub_embeds is given")
        # 200 for CUB; assumes class indices are 0..n_existing-1.
        n_existing = len(cub_idx2name)
        if new_class is not None:
            if new_class in cub_idx2name.values():
                new_class = f"{new_class}_custom"
            modified_class_idx = n_existing
            idx2name = cub_idx2name | {modified_class_idx: new_class}
            # Merge into a copy: an in-place `|=` would leak custom classes into
            # the caller's (possibly shared) dict across calls.
            descriptors = descriptors | {new_class: new_descs}
            n_classes = n_existing + 1
            with torch.no_grad():
                query_tokens = owlvit_processor(text=new_descs, padding="max_length", truncation=True, return_tensors="pt").to(device)
                new_class_embed = model.owlvit.get_text_features(**query_tokens)
            query_embeds = torch.cat([cub_embeds.to(device), new_class_embed], dim=0)
        else:
            n_classes = n_existing
            query_embeds = cub_embeds.to(device)
            idx2name = cub_idx2name
            modified_class_idx = None

    model.cls_head.num_classes = n_classes

    with torch.no_grad():
        part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
        if use_precompute_embeddings:
            # map_location lets GPU-saved embeddings load on CPU-only hosts.
            # NOTE(review): torch.load unpickles the file; image_name must come
            # from a trusted set, never from raw user input.
            image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt', map_location=device)
        else:
            image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
            image_embeds, _ = model.image_embedder(pixel_values=image_input['pixel_values'])

        pred_logits, part_logits, _ = model(image_embeds, part_embeds, query_embeds, None)

        # Weight each part score by its mask value (the 1-D mask broadcasts over
        # the leading dims), then recompute class logits as the sum over parts.
        mask = torch.tensor(desc_mask, dtype=float, device=device)
        part_logits = part_logits * mask
        pred_logits = torch.sum(part_logits, dim=-1)

        # .item() below assumes a single image in the batch.
        pred_class_idx = torch.argmax(pred_logits, dim=-1).cpu()
        pred_class_name = idx2name[pred_class_idx.item()]

        softmax_scores = torch.softmax(pred_logits, dim=-1).cpu()
        softmax_score_top1 = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item()

        part_scores = part_logits[0, pred_class_idx].cpu().squeeze(0)
        part_scores_dict = dict(zip(ORG_PART_ORDER, part_scores.tolist()))

        if modified_class_idx is not None:
            modified_score = softmax_scores[0, modified_class_idx].item()
            modified_part_scores = part_logits[0, modified_class_idx].cpu().squeeze(0)
            modified_part_scores_dict = dict(zip(ORG_PART_ORDER, modified_part_scores.tolist()))
        else:
            modified_score = None
            modified_part_scores_dict = None

    output_dict = {"pred_class": pred_class_name,
                   "pred_score": softmax_score_top1,
                   "pred_desc_scores": part_scores_dict,
                   "descriptions": descriptors[pred_class_name],
                   "modified_class": new_class,
                   "modified_score": modified_score,
                   "modified_desc_scores": modified_part_scores_dict,
                   "modified_descriptions": descriptors.get(new_class),
                   }
    return (output_dict, image_embeds) if return_img_embeds else output_dict


# def sachit_pred(new_desc: list, 
#                 new_class: str,
#                 org_desc: str,
#                 image: PIL.Image,
#                 model: callable,
#                 preprocess: callable,
#                 device: str,
#                 ):

#     # replace the description if the new class is in the description, otherwise add a new class
#     getprompt = GetPromptList(org_desc)
    
#     if new_class not in getprompt.desc:
#         getprompt.name2idx[new_class] = len(getprompt.name2idx)
#     getprompt.desc[new_class] = new_desc
    
#     idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
#     modified_class_idx = getprompt.name2idx[new_class]
    
#     descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('Sachit-descriptors', max_len=12, pad=True)
    
#     text_embeds = encode_descs_clip(model, descs, device)
    
#     with torch.no_grad():
#         image_embed = model.encode_image(preprocess(image).unsqueeze(0).to(device))
#         desc_mask = torch.tensor(class_idxs)
#         desc_mask = torch.where(desc_mask == -1, 0, 1).unsqueeze(0).to(device)
        
#         sim = torch.matmul(image_embed.float(), text_embeds.T)
#         sim = (sim * desc_mask).view(1, -1, 12)
#         pred_scores = torch.sum(sim, dim=-1)
#         pred_class_idx = torch.argmax(pred_scores, dim=-1).cpu()
#         pred_class = idx2name[pred_class_idx.item()]
        
#         softmax_scores = torch.nn.functional.softmax(pred_scores, dim=-1).cpu()
#         top1_score = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item()
#         modified_score = softmax_scores[0, modified_class_idx].item()
        
#         pred_desc_scores = sim[0, pred_class_idx].cpu().squeeze(0)
#         modified_class_scores = sim[0, modified_class_idx].cpu().squeeze(0)
        
    
#     output_dict = {"pred_class": pred_class,
#                    "pred_score": top1_score,
#                    "pred_desc_scores": pred_desc_scores.tolist(),
#                    "descriptions": getprompt.desc[pred_class],
#                    "modified_class": new_class,
#                    "modified_score": modified_score,
#                    "modified_desc_scores": modified_class_scores.tolist(),
#                    "modified_descriptions": getprompt.desc[new_class],
#                    }
    
#     return output_dict