fix(modeling_phi): Fixes initial generation with length larger than context length.
Browse files- modeling_phi.py +7 -6
modeling_phi.py
CHANGED
@@ -170,11 +170,11 @@ def _apply_rotary_emb_qkv(
|
|
170 |
|
171 |
class RotaryEmbedding(nn.Module):
|
172 |
"""Rotary positional embedding (RoPE).
|
173 |
-
|
174 |
Reference:
|
175 |
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
176 |
https://arxiv.org/pdf/2104.09864.pdf.
|
177 |
-
|
178 |
"""
|
179 |
|
180 |
def __init__(
|
@@ -495,9 +495,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
|
|
495 |
sequence_start = inference_params.seqlen_offset
|
496 |
sequence_end = sequence_start + kv.shape[1]
|
497 |
|
498 |
-
# When the current sequence length is
|
499 |
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
500 |
-
if sequence_end
|
501 |
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
|
502 |
|
503 |
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
@@ -863,9 +863,10 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
863 |
**kwargs,
|
864 |
) -> Dict[str, Any]:
|
865 |
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
|
|
866 |
past_key_values = InferenceParams(
|
867 |
-
max_seqlen=self.config.n_positions,
|
868 |
-
max_batch_size=
|
869 |
seqlen_offset=0,
|
870 |
batch_size_offset=0,
|
871 |
key_value_memory_dict={},
|
|
|
170 |
|
171 |
class RotaryEmbedding(nn.Module):
|
172 |
"""Rotary positional embedding (RoPE).
|
173 |
+
|
174 |
Reference:
|
175 |
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
176 |
https://arxiv.org/pdf/2104.09864.pdf.
|
177 |
+
|
178 |
"""
|
179 |
|
180 |
def __init__(
|
|
|
495 |
sequence_start = inference_params.seqlen_offset
|
496 |
sequence_end = sequence_start + kv.shape[1]
|
497 |
|
498 |
+
# When the current sequence length is larger than the maximum sequence length,
|
499 |
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
500 |
+
if sequence_end > inference_params.max_seqlen:
|
501 |
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
|
502 |
|
503 |
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
|
|
863 |
**kwargs,
|
864 |
) -> Dict[str, Any]:
|
865 |
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
866 |
+
max_batch_size, max_seqlen = input_ids.shape
|
867 |
past_key_values = InferenceParams(
|
868 |
+
max_seqlen=max(max_seqlen, self.config.n_positions),
|
869 |
+
max_batch_size=max_batch_size,
|
870 |
seqlen_offset=0,
|
871 |
batch_size_offset=0,
|
872 |
key_value_memory_dict={},
|