import PIL.Image
import torch

from .prompts import GetPromptList

# The twelve bird parts: ORG_PART_ORDER follows the order used by the stored
# descriptors, ORDERED_PARTS lists the same parts from head to tail.
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']

def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: list[str], device: str, max_batch_size: int = 512):
    """Encode text descriptions in batches with the OWL-ViT text encoder."""
    total_num_batches = (len(descs) + max_batch_size - 1) // max_batch_size  # ceiling division avoids a trailing empty batch
    with torch.no_grad():
        text_embeds = []
        for batch_idx in range(total_num_batches):
            query_descs = descs[batch_idx * max_batch_size:(batch_idx + 1) * max_batch_size]
            query_tokens = owlvit_det_processor(text=query_descs, padding="max_length", truncation=True, return_tensors="pt").to(device)
            query_embeds = model.owlvit.get_text_features(**query_tokens)
            text_embeds.append(query_embeds.cpu().float())
    text_embeds = torch.cat(text_embeds, dim=0)
    return text_embeds.to(device)
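
# Example usage (a minimal sketch; the checkpoint name and `model` here are
# illustrative assumptions, not fixtures of this module -- any OWL-ViT-based
# model exposing `owlvit.get_text_features` would do):
#
#   from transformers import OwlViTProcessor
#   processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
#   embeds = encode_descs_xclip(processor, model, ["red crown", "long tail"], device="cuda")
#   # embeds: (num_descs, embed_dim) tensor on `device`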

# def encode_descs_clip(model: callable, descs: list[str], device: str, max_batch_size: int = 512):
#     total_num_batches = (len(descs) + max_batch_size - 1) // max_batch_size  # ceiling division
#     with torch.no_grad():
#         text_embeds = []
#         for batch_idx in range(total_num_batches):
#             desc = descs[batch_idx*max_batch_size:(batch_idx+1)*max_batch_size]
#             query_tokens = clip.tokenize(desc).to(device)
#             text_embeds.append(model.encode_text(query_tokens).cpu().float())
#     text_embeds = torch.cat(text_embeds, dim=0)
#     text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
#     return text_embeds.to(device)

def xclip_pred(new_desc: dict,
               new_part_mask: dict,
               new_class: str,
               org_desc: str,
               image: PIL.Image.Image,
               model: callable,
               owlvit_processor: callable,
               device: str,
               return_img_embeds: bool = False,
               use_precompute_embeddings: bool = True,
               image_name: str = None):
    # reorder the new description and its mask into the stored part order
    if new_class is not None:
        new_desc_ = {k: new_desc[k] for k in ORG_PART_ORDER}
        new_part_mask_ = {k: new_part_mask[k] for k in ORG_PART_ORDER}
        desc_mask = list(new_part_mask_.values())
    else:
        desc_mask = [1] * len(ORG_PART_ORDER)  # no class edit: keep every part

    # update the class's description if it already exists, otherwise register a new class
    getprompt = GetPromptList(org_desc)
    if new_class is not None and new_class not in getprompt.desc:
        getprompt.name2idx[new_class] = len(getprompt.name2idx)
    if new_class is not None:
        getprompt.desc[new_class] = list(new_desc_.values())
    
    idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
    modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None
    
    n_classes = len(getprompt.name2idx)
    model.cls_head.num_classes = n_classes
    
    descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
    query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
    
    with torch.no_grad():
        image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
        # image_input['pixel_values'] = image_input['pixel_values'].squeeze(1)

        part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
        if return_img_embeds:
            feature_map, _ = model.image_embedder(pixel_values=image_input['pixel_values'])
        if use_precompute_embeddings:
            image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt').to(device)
            pred_logits, part_logits, _ = model(image_embeds, part_embeds, query_embeds, None)
        else:
            pred_logits, part_logits, _ = model(image_input, part_embeds, query_embeds, None)

        b, c, _ = part_logits.shape
        mask = torch.tensor(desc_mask, dtype=part_logits.dtype).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
        # zero out masked-off parts, then recompute class logits as the sum of part logits
        part_logits = part_logits * mask
        pred_logits = torch.sum(part_logits, dim=-1)
        
        pred_class_idx = torch.argmax(pred_logits, dim=-1).cpu()
        pred_class_name = idx2name[pred_class_idx.item()]
        
        softmax_scores = torch.softmax(pred_logits, dim=-1).cpu()
        softmax_score_top1 = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item()
        
        part_scores = part_logits[0, pred_class_idx].cpu().squeeze(0)
        part_scores_dict = dict(zip(ORG_PART_ORDER, part_scores.tolist()))
        
        if modified_class_idx is not None:
            modified_score = softmax_scores[0, modified_class_idx].item()
            modified_part_scores = part_logits[0, modified_class_idx].cpu().squeeze(0)
            modified_part_scores_dict = dict(zip(ORG_PART_ORDER, modified_part_scores.tolist()))
        else:
            modified_score = None
            modified_part_scores_dict = None
        
    output_dict = {"pred_class": pred_class_name,
                   "pred_score": softmax_score_top1,
                   "pred_desc_scores": part_scores_dict,
                   "descriptions": getprompt.desc[pred_class_name],
                   "modified_class": new_class,
                   "modified_score": modified_score,
                   "modified_desc_scores": modified_part_scores_dict,
                   "modified_descriptions": getprompt.desc[new_class] if new_class is not None else None,
                   }
    return output_dict if not return_img_embeds else (output_dict, feature_map)
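
# Example call (a minimal sketch; `model`, `processor`, and the concrete paths
# and class/description values below are illustrative assumptions, not fixtures
# of this module):
#
#   image = PIL.Image.open("bird.jpg")
#   result = xclip_pred(
#       new_desc={part: "..." for part in ORG_PART_ORDER},
#       new_part_mask={part: 1 for part in ORG_PART_ORDER},
#       new_class="Painted Bunting",
#       org_desc="path/to/descriptors.json",
#       image=image,
#       model=model,
#       owlvit_processor=processor,
#       device="cuda",
#       use_precompute_embeddings=False,
#   )
#   print(result["pred_class"], result["pred_score"])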


# def sachit_pred(new_desc: list, 
#                 new_class: str,
#                 org_desc: str,
#                 image: PIL.Image,
#                 model: callable,
#                 preprocess: callable,
#                 device: str,
#                 ):

#     # replace the description if the new class is in the description, otherwise add a new class
#     getprompt = GetPromptList(org_desc)
    
#     if new_class not in getprompt.desc:
#         getprompt.name2idx[new_class] = len(getprompt.name2idx)
#     getprompt.desc[new_class] = new_desc
    
#     idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
#     modified_class_idx = getprompt.name2idx[new_class]
    
#     descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('Sachit-descriptors', max_len=12, pad=True)
    
#     text_embeds = encode_descs_clip(model, descs, device)
    
#     with torch.no_grad():
#         image_embed = model.encode_image(preprocess(image).unsqueeze(0).to(device))
#         desc_mask = torch.tensor(class_idxs)
#         desc_mask = torch.where(desc_mask == -1, 0, 1).unsqueeze(0).to(device)
        
#         sim = torch.matmul(image_embed.float(), text_embeds.T)
#         sim = (sim * desc_mask).view(1, -1, 12)
#         pred_scores = torch.sum(sim, dim=-1)
#         pred_class_idx = torch.argmax(pred_scores, dim=-1).cpu()
#         pred_class = idx2name[pred_class_idx.item()]
        
#         softmax_scores = torch.nn.functional.softmax(pred_scores, dim=-1).cpu()
#         top1_score = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item()
#         modified_score = softmax_scores[0, modified_class_idx].item()
        
#         pred_desc_scores = sim[0, pred_class_idx].cpu().squeeze(0)
#         modified_class_scores = sim[0, modified_class_idx].cpu().squeeze(0)
        
    
#     output_dict = {"pred_class": pred_class,
#                    "pred_score": top1_score,
#                    "pred_desc_scores": pred_desc_scores.tolist(),
#                    "descriptions": getprompt.desc[pred_class],
#                    "modified_class": new_class,
#                    "modified_score": modified_score,
#                    "modified_desc_scores": modified_class_scores.tolist(),
#                    "modified_descriptions": getprompt.desc[new_class],
#                    }
    
#     return output_dict