jcsagar commited on
Commit
5a20da3
1 Parent(s): b74d72d

Update CXR_LLAVA_HF.py

Browse files
Files changed (1) hide show
  1. CXR_LLAVA_HF.py +2 -1
CXR_LLAVA_HF.py CHANGED
@@ -295,7 +295,8 @@ class CXRLLAVAModel(PreTrainedModel):
295
  cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
296
  cur_labels = cur_labels[image_token_start + 2:]
297
  else:
298
- print(self.llama.device, cur_input_ids)
 
299
  cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start]))
300
  cur_new_input_embeds.append(cur_image_features)
301
  if labels is not None:
 
295
  cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
296
  cur_labels = cur_labels[image_token_start + 2:]
297
  else:
298
+ # print(self.llama.device, cur_input_ids)
299
+ cur_input_ids = cur_input_ids.to(self.llama.device)
300
  cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start]))
301
  cur_new_input_embeds.append(cur_image_features)
302
  if labels is not None: