Update modeling_mplug_owl2.py
Browse files- modeling_mplug_owl2.py +2 -2
modeling_mplug_owl2.py
CHANGED
@@ -270,7 +270,7 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
|
|
270 |
def score(self, images,
|
271 |
task_: str = "quality",
|
272 |
input_: str = "image",
|
273 |
-
return_dict=False,
|
274 |
image_tensor = None,
|
275 |
):
|
276 |
if not hasattr(self, "weight_tensor"):
|
@@ -279,8 +279,8 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
|
|
279 |
if input_ == "image":
|
280 |
if image_tensor is None:
|
281 |
images = [expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in images]
|
282 |
-
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
283 |
image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
|
|
|
284 |
with torch.inference_mode():
|
285 |
output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
|
286 |
images=image_tensor)["logits"][:,-1, self.preferential_ids_]
|
|
|
270 |
def score(self, images,
|
271 |
task_: str = "quality",
|
272 |
input_: str = "image",
|
273 |
+
return_dict = False,
|
274 |
image_tensor = None,
|
275 |
):
|
276 |
if not hasattr(self, "weight_tensor"):
|
|
|
279 |
if input_ == "image":
|
280 |
if image_tensor is None:
|
281 |
images = [expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in images]
|
|
|
282 |
image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
|
283 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
284 |
with torch.inference_mode():
|
285 |
output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
|
286 |
images=image_tensor)["logits"][:,-1, self.preferential_ids_]
|