Update CXR_LLAVA_HF.py
Browse files- CXR_LLAVA_HF.py +2 -1
CXR_LLAVA_HF.py
CHANGED
@@ -295,7 +295,8 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
295 |
cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
|
296 |
cur_labels = cur_labels[image_token_start + 2:]
|
297 |
else:
|
298 |
-
print(self.llama.device, cur_input_ids)
|
|
|
299 |
cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start]))
|
300 |
cur_new_input_embeds.append(cur_image_features)
|
301 |
if labels is not None:
|
|
|
295 |
cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
|
296 |
cur_labels = cur_labels[image_token_start + 2:]
|
297 |
else:
|
298 |
+
# print(self.llama.device, cur_input_ids)
|
299 |
+
cur_input_ids = cur_input_ids.to(self.llama.device)
|
300 |
cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start]))
|
301 |
cur_new_input_embeds.append(cur_image_features)
|
302 |
if labels is not None:
|