chenlin commited on
Commit
b1f0832
1 Parent(s): 8f9c756

update for batch infer

Browse files
Files changed (1) hide show
  1. modeling_InternLM.py +4 -5
modeling_InternLM.py CHANGED
@@ -1,15 +1,14 @@
1
  import math
2
- from typing import List, Union
3
- from typing import Optional, Tuple
4
 
5
  import torch
6
  import torch.utils.checkpoint
7
- import torch.utils.checkpoint
8
  from einops import rearrange
9
  from torch import nn
10
  from torch.nn import CrossEntropyLoss
11
  from transformers.activations import ACT2FN
12
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
13
  from transformers.modeling_utils import PreTrainedModel
14
  from transformers.utils import logging
15
 
@@ -863,6 +862,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
863
  reordered_past = ()
864
  for layer_past in past_key_values:
865
  reordered_past += (tuple(
866
- past_state.index_select(0, beam_idx)
867
  for past_state in layer_past), )
868
  return reordered_past
 
1
  import math
2
+ from typing import List, Optional, Tuple, Union
 
3
 
4
  import torch
5
  import torch.utils.checkpoint
 
6
  from einops import rearrange
7
  from torch import nn
8
  from torch.nn import CrossEntropyLoss
9
  from transformers.activations import ACT2FN
10
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
11
+ CausalLMOutputWithPast)
12
  from transformers.modeling_utils import PreTrainedModel
13
  from transformers.utils import logging
14
 
 
862
  reordered_past = ()
863
  for layer_past in past_key_values:
864
  reordered_past += (tuple(
865
+ past_state.index_select(0, beam_idx.to(past_state.device))
866
  for past_state in layer_past), )
867
  return reordered_past