jcsagar commited on
Commit
ad7d308
1 Parent(s): 5a20da3

Update CXR_LLAVA_HF.py

Browse files
Files changed (1) hide show
  1. CXR_LLAVA_HF.py +2 -8
CXR_LLAVA_HF.py CHANGED
@@ -615,11 +615,8 @@ class CXRLLAVAModel(PreTrainedModel):
615
  images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
616
  images = images.to(self.device)
617
  input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
618
- if self.device == 'cuda':
619
- input_ids = input_ids.cuda()
620
- print('using cuda')
621
- else:
622
- print(f'using device {self.device}')
623
  stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
624
 
625
  image_args = {"images": images}
@@ -642,11 +639,8 @@ class CXRLLAVAModel(PreTrainedModel):
642
  ))
643
  thread.start()
644
  generated_text = ""
645
- text_len = 0
646
  for new_text in streamer:
647
  generated_text += new_text
648
- text_len += 1
649
- if text_len > 200: break
650
 
651
  return generated_text
652
 
 
615
  images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
616
  images = images.to(self.device)
617
  input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
618
+ input_ids = input_ids.to(self.device)
619
+ # print(f'using device {self.device}')
 
 
 
620
  stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
621
 
622
  image_args = {"images": images}
 
639
  ))
640
  thread.start()
641
  generated_text = ""
 
642
  for new_text in streamer:
643
  generated_text += new_text
 
 
644
 
645
  return generated_text
646