zRzRzRzRzRzRzR vcadillo commited on
Commit
4a1c77f
1 Parent(s): 301cb65

Multiple GPU's issue. (#19)

Browse files

- Multiple GPU's issue. (73cd2bc90c8711e5de330bec279b118d58170258)


Co-authored-by: Victor Cadillo <vcadillo@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -858,7 +858,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
858
  self.config.eoi_token_id)
859
  assert eoi_token_pos - boi_token_pos == 2
860
  new_input_embeds.append(torch.cat(
861
- (inputs_embeds[i, :boi_token_pos], images_features[i], inputs_embeds[i, eoi_token_pos + 1:])))
862
  new_position_ids.append(torch.cat(
863
  (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
864
  position_ids[i, eoi_token_pos:])
 
858
  self.config.eoi_token_id)
859
  assert eoi_token_pos - boi_token_pos == 2
860
  new_input_embeds.append(torch.cat(
861
+ (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device), inputs_embeds[i, eoi_token_pos + 1:])))
862
  new_position_ids.append(torch.cat(
863
  (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
864
  position_ids[i, eoi_token_pos:])