pain committed on
Commit
81e4da0
1 Parent(s): 24682d3

Delete utils.py

Files changed (1)
  1. utils.py +0 -200
utils.py DELETED
@@ -1,200 +0,0 @@
-import os
-import numpy as np
-import pickle
-import torch
-import transformers
-from PIL import Image
-from open_clip import create_model_from_pretrained, create_model_and_transforms
-import json
-
-# XLM model functions
-from multilingual_clip import pt_multilingual_clip
-
-from model_loading import load_model
-
-
-class CustomDataSet(torch.utils.data.Dataset):
-    """Minimal dataset that returns preprocessed images by index."""
-
-    def __init__(self, main_dir, compose, image_name_list):
-        self.main_dir = main_dir
-        self.transform = compose
-        self.total_imgs = image_name_list
-
-    def __len__(self):
-        return len(self.total_imgs)
-
-    def get_image_name(self, idx):
-        return self.total_imgs[idx]
-
-    def __getitem__(self, idx):
-        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
-        # Convert to RGB so the preprocessing transform always gets 3 channels
-        image = Image.open(img_loc).convert("RGB")
-        return self.transform(image)
-
-
-def features_pickle(file_path):
-    """Load embeddings that were cached to disk with pickle."""
-    with open(file_path, 'rb') as handle:
-        features = pickle.load(handle)
-    return features
-
-
-def dataset_loading():
-    """Load the en/ar XTD10 annotations, sorted by image id."""
-    with open("/home/think3/Desktop/2. tf_testing_araclip/XTD_dataset/en_ar_XTD10_edited_v2.jsonl") as filino:
-        data = [json.loads(file_i) for file_i in filino]
-
-    sorted_data = sorted(data, key=lambda x: x['id'])
-    image_name_list = [lin["image_name"] for lin in sorted_data]
-
-    return sorted_data, image_name_list
-
-
-def text_encoder(language_model, text):
-    """Embed the text query and also return the L2-normalized embedding."""
-    embedding = language_model(text)
-    norm_embedding = embedding / np.linalg.norm(embedding)
-    return embedding, norm_embedding
-
-
-def compare_embeddings(logit_scale, img_embs, txt_embs):
-    """Scaled cosine similarity between each text and each image embedding."""
-    image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
-    text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
-    logits_per_text = logit_scale * text_features @ image_features.t()
-    return logits_per_text
-
-
-def compare_embeddings_text(full_text_embds, txt_embs):
-    """Cosine similarity between the query embedding and the cached caption embeddings."""
-    full_text_embds_features = full_text_embds / full_text_embds.norm(dim=-1, keepdim=True)
-    text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
-    logits_per_text_full = text_features @ full_text_embds_features.t()
-    return logits_per_text_full
-
-
-def find_image(language_model, clip_model, text_query, dataset, image_features, text_features_new, sorted_data, num=1):
-    """Return the top-`num` images (path, Arabic caption) and the nearest captions for a text query."""
-    embedding, _ = text_encoder(language_model, text_query)
-    logit_scale = clip_model.logit_scale.exp().float().to('cpu')
-
-    # Single-entry dicts, keyed by language, so further languages could be added
-    language_logits, text_logits = {}, {}
-    language_logits["Arabic"] = compare_embeddings(logit_scale, torch.from_numpy(image_features), torch.from_numpy(embedding))
-    text_logits["Arabic_text"] = compare_embeddings_text(torch.from_numpy(text_features_new), torch.from_numpy(embedding))
-
-    file_paths = []
-    labels, json_data = {}, {}
-    for _, txt_logits in language_logits.items():
-        # Shape (n_images, 1): probability of each image given the query
-        probs = txt_logits.softmax(dim=-1).cpu().detach().numpy().T
-
-        for i in range(1, num + 1):
-            # Index of the i-th most probable image
-            idx = np.argsort(probs, axis=0)[-i, 0]
-            path = 'photos/XTD10_dataset/' + dataset.get_image_name(idx)
-            path_l = (path, f"{sorted_data[idx]['caption_ar']}")
-
-            labels[f" Image # {i}"] = probs[idx]
-            json_data[f" Image # {i}"] = sorted_data[idx]
-            file_paths.append(path_l)
-
-    json_text = {}
-    for _, txt_logits_full in text_logits.items():
-        probs_text = txt_logits_full.softmax(dim=-1).cpu().detach().numpy().T
-        for j in range(1, num + 1):
-            idx = np.argsort(probs_text, axis=0)[-j, 0]
-            json_text[f" Text # {j}"] = sorted_data[idx]
-
-    return file_paths, labels, json_data, json_text
-
-
-class AraClip:
-    def __init__(self):
-        # AraBERT text tower projected into the SigLIP image embedding space
-        self.text_model = load_model('bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M', in_features=768, out_features=768)
-        self.language_model = lambda queries: np.asarray(self.text_model(queries).detach().to('cpu'))
-        self.clip_model, self.compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')
-        self.sorted_data, self.image_name_list = dataset_loading()
-
-    def load_images(self):
-        # Pre-computed image features for the 1000 XTD images
-        image_features_new = features_pickle('testing_pickle_files_images_text/image_features_XTD_1000_images_arabert_siglib_best_model.pickle')
-        return image_features_new
-
-    def load_text(self):
-        # Pre-computed caption features for the same images
-        text_features_new = features_pickle('testing_pickle_files_images_text/text_features_XTD_1000_images_arabert_siglib_best_model.pickle')
-        return text_features_new
-
-    def load_dataset(self):
-        dataset = CustomDataSet("photos/XTD10_dataset", self.compose, self.image_name_list)
-        return dataset
-
-
-araclip = AraClip()
-
-
-def predict(text, num):
-    image_paths, labels, json_data, json_text = find_image(
-        araclip.language_model, araclip.clip_model, text, araclip.load_dataset(),
-        araclip.load_images(), araclip.load_text(), araclip.sorted_data, num=int(num))
-    return image_paths, labels, json_data, json_text
-
-
-class Mclip:
-    def __init__(self) -> None:
-        # Multilingual CLIP baseline: XLM-R text tower with an OpenCLIP image tower
-        self.tokenizer_mclip = transformers.AutoTokenizer.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
-        self.text_model_mclip = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
-        self.language_model_mclip = lambda queries: np.asarray(self.text_model_mclip.forward(queries, self.tokenizer_mclip).detach().to('cpu'))
-        self.clip_model_mclip, _, self.compose_mclip = create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
-        self.sorted_data, self.image_name_list = dataset_loading()
-
-    def load_images(self):
-        # Pre-computed image features for the 1000 XTD images
-        image_features_mclip = features_pickle('Cach_embeddings/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
-        return image_features_mclip
-
-    def load_text(self):
-        # Pre-computed caption features for the same images
-        text_features_new_mclip = features_pickle('Cach_embeddings/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
-        return text_features_new_mclip
-
-    def load_dataset(self):
-        dataset_mclip = CustomDataSet("photos/XTD10_dataset", self.compose_mclip, self.image_name_list)
-        return dataset_mclip
-
-
-mclip = Mclip()
-
-
-def predict_mclip(text, num):
-    # load_images() supplies the image features; load_text() the caption features
-    image_paths, labels, json_data, json_text = find_image(
-        mclip.language_model_mclip, mclip.clip_model_mclip, text, mclip.load_dataset(),
-        mclip.load_images(), mclip.load_text(), mclip.sorted_data, num=int(num))
-    return image_paths, labels, json_data, json_text
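
The deleted module's entry points were `predict` (AraCLIP) and `predict_mclip` (the M-CLIP baseline), each returning image paths with Arabic captions plus label probabilities and the matching JSON records. A minimal sketch of a caller, assuming the cached pickle files and the photos/XTD10_dataset images are present locally and the file is importable as `utils`; the Arabic query string is an illustrative placeholder, not from this file:

    # Hypothetical usage of the (now deleted) module; paths and query are assumptions
    from utils import predict

    # Top-3 AraCLIP retrievals for an Arabic query ("a cat sitting on the couch")
    image_paths, labels, json_data, json_text = predict("قطة تجلس على الأريكة", num=3)
    for path, caption_ar in image_paths:
        print(path, caption_ar)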