Update CXR_LLAVA_HF.py
Browse files
- CXR_LLAVA_HF.py +2 -8
CXR_LLAVA_HF.py
CHANGED
@@ -615,11 +615,8 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
615 |
images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
|
616 |
images = images.to(self.device)
|
617 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
618 |
-
|
619 |
-
|
620 |
-
print('using cuda')
|
621 |
-
else:
|
622 |
-
print(f'using device {self.device}')
|
623 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
624 |
|
625 |
image_args = {"images": images}
|
@@ -642,11 +639,8 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
642 |
))
|
643 |
thread.start()
|
644 |
generated_text = ""
|
645 |
-
text_len = 0
|
646 |
for new_text in streamer:
|
647 |
generated_text += new_text
|
648 |
-
text_len += 1
|
649 |
-
if text_len > 200: break
|
650 |
|
651 |
return generated_text
|
652 |
|
|
|
615 |
images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
|
616 |
images = images.to(self.device)
|
617 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
618 |
+
input_ids = input_ids.to(self.device)
|
619 |
+
# print(f'using device {self.device}')
|
|
|
|
|
|
|
620 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
621 |
|
622 |
image_args = {"images": images}
|
|
|
639 |
))
|
640 |
thread.start()
|
641 |
generated_text = ""
|
|
|
642 |
for new_text in streamer:
|
643 |
generated_text += new_text
|
|
|
|
|
644 |
|
645 |
return generated_text
|
646 |
|