Upload model
#2
by altndrr - opened

modeling_cased.py CHANGED (+8 -3)
@@ -212,6 +212,8 @@ class CaSEDModel(PreTrainedModel):
 
         vocabularies, samples_p = [], []
         for image_z in images_z:
+            image_z = image_z.unsqueeze(0)
+
             # generate a single text embedding from the unfiltered vocabulary
             vocabulary = self.query_index(image_z)
             text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
@@ -219,6 +221,9 @@ class CaSEDModel(PreTrainedModel):
             text["attention_mask"] = text["attention_mask"][:, :77].to(self.device)
             text_z = self.language_encoder(**text)[1]
             text_z = self.language_proj(text_z)
+            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+            text_z = text_z.mean(dim=0).unsqueeze(0)
+            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
 
             # filter the vocabulary, embed it, and get its mean embedding
             vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
@@ -231,8 +236,8 @@ class CaSEDModel(PreTrainedModel):
             # get the image and text predictions
             image_z = image_z / image_z.norm(dim=-1, keepdim=True)
             text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-            image_p = (
-            text_p = (
+            image_p = (self.logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1)
+            text_p = (self.logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1)
 
             # average the image and text predictions
             alpha = alpha or self.hparams["alpha"]
@@ -244,7 +249,7 @@ class CaSEDModel(PreTrainedModel):
 
         # get the scores
        samples_p = torch.stack(samples_p, dim=0)
-        scores = sample_p.cpu()
+        scores = samples_p.cpu()
 
         # define the results
         results = {"vocabularies": vocabularies, "scores": scores}
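For reviewers skimming the diff, a minimal runnable sketch of the pattern the added lines implement: per-caption text embeddings are L2-normalized, mean-pooled into a single query, re-normalized, and both the image and text embeddings are scored against the candidate vocabulary via a softmax over scaled cosine similarities. The tensor shapes, the logit_scale value, the alpha value, and the blending formula below are illustrative assumptions, not values taken from modeling_cased.py.

import torch

torch.manual_seed(0)

embed_dim, num_captions, vocab_size = 512, 10, 7
logit_scale = 100.0  # assumed scale; the model stores its own value
alpha = 0.5          # assumed blending weight (self.hparams["alpha"] in the model)

image_z = torch.randn(1, embed_dim)                # one image embedding (after unsqueeze(0))
text_z = torch.randn(num_captions, embed_dim)      # embeddings of the retrieved captions
vocabulary_z = torch.randn(vocab_size, embed_dim)  # embeddings of the candidate class names

# normalize, mean-pool, and re-normalize the caption embeddings (added lines 224-226)
text_z = text_z / text_z.norm(dim=-1, keepdim=True)
text_z = text_z.mean(dim=0).unsqueeze(0)
text_z = text_z / text_z.norm(dim=-1, keepdim=True)

# unit-normalize the remaining embeddings before computing cosine similarities
image_z = image_z / image_z.norm(dim=-1, keepdim=True)
vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)

# scaled cosine similarities -> per-vocabulary probabilities (added lines 239-240)
image_p = (logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1)
text_p = (logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1)

# blend the image and text predictions; the exact formula here is an assumption
sample_p = alpha * image_p + (1 - alpha) * text_p

print(sample_p.shape)  # torch.Size([1, 7])

Normalizing before and after the mean keeps each caption's contribution equal in weight and returns a unit-norm query, so the text-branch logits stay on the same scale as the image branch before the two softmax distributions are blended.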