|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" PyTorch LLaMA model.""" |
|
import math |
|
from typing import List, Optional, Tuple, Union |
|
import faiss |
|
from einops import rearrange |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.linalg import vector_norm |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
SequenceClassifierOutputWithPast, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from .configuration_llama import ExtendedLlamaConfig |
|
|
|
|
|
# Module-level logger from transformers' logging utilities; used below for the
# gradient-checkpointing / use_cache warning in ExtendedLlamaModel.forward.
logger = logging.get_logger(__name__)

# Config class name for generated docs; presumably passed to `replace_return_docstrings`
# further down the file (not visible in this chunk) — confirm.
_CONFIG_FOR_DOC = "ExtendedLlamaConfig"
|
|
|
|
|
|
|
def _make_causal_mask( |
|
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 |
|
): |
|
""" |
|
Make causal mask used for bi-directional self-attention. |
|
""" |
|
bsz, tgt_len = input_ids_shape |
|
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) |
|
mask_cond = torch.arange(mask.size(-1), device=device) |
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) |
|
mask = mask.to(dtype) |
|
|
|
if past_key_values_length > 0: |
|
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) |
|
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
|
|
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): |
|
""" |
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
|
""" |
|
bsz, src_len = mask.size() |
|
tgt_len = tgt_len if tgt_len is not None else src_len |
|
|
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
|
|
|
inverted_mask = 1.0 - expanded_mask |
|
|
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) |
|
|
|
|
|
class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale (no centering, no bias).

    Equivalent to T5LayerNorm. The mean square is computed in float32 and the result
    is cast back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the last (normalised) dimension; the scale starts at ones.
            eps: added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inverse_rms
        scaled = self.weight * normalized
        return scaled.to(original_dtype)
|
|
|
|
|
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) with a lazily grown cos/sin cache.

    ``cos_cached`` / ``sin_cached`` hold ``(1, 1, max_seq_len_cached, dim)`` tables. They are
    built for ``max_position_embeddings`` positions at construction and rebuilt (in the
    dtype/device of the incoming tensor) whenever a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: rotary dimension (the head dimension); expected to be even.
            max_position_embeddings: number of positions cached up front.
            base: RoPE frequency base.
            device: device for the initial ``inv_freq`` and cache tensors.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._set_cos_sin_cache(
            max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1`` in ``dtype`` on ``device``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Both halves of the last dim share the same frequencies (paired with rotate_half).
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each ``(1, 1, seq_len, dim)`` in ``x.dtype``.

        Args:
            x: tensor supplying dtype/device. When ``seq_len`` is omitted its second-to-last
                dimension is used as the sequence length (as for the
                ``(bsz, heads, seq, head_dim)`` value states passed in this file).
            seq_len: number of positions required.
        """
        if seq_len is None:
            # Infer the length instead of failing on ``None > int``.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
|
|
|
class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    """Rotary embedding that rescales itself once a longer sequence is requested.

    Up to ``max_position_embeddings`` positions it behaves like plain RoPE. When ``forward``
    is asked for a longer ``seq_len`` the cos/sin cache is rebuilt:

    * ``ntk`` truthy: the rotary base is enlarged to
      ``base * (ntk * seq_len / max_position_embeddings - (ntk - 1)) ** (dim / (dim - 2))``
      and ``inv_freq`` is recomputed from it (``ntk`` acts as the scaling factor).
    * ``ntk`` falsy: linear position interpolation; positions are multiplied by
      ``max_position_embeddings / seq_len``.

    The cache only ever grows: after a long sequence, shorter ones reuse the rescaled tables.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Initial cache: the trained context length, unscaled.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each ``(1, 1, seq_len, dim)`` in ``x.dtype``.

        Args:
            x: tensor supplying dtype/device; when ``seq_len`` is omitted its second-to-last
                dimension is used as the sequence length.
            seq_len: number of positions required.
        """
        if seq_len is None:
            # Infer the length instead of failing on ``None > int``.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                base = self.base * (
                    (self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)
                ) ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # Linear interpolation: squeeze seq_len positions into the trained range.
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
|
|
|
class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    """Rotary embedding with linear position interpolation.

    Every position index is divided by ``scale`` before the angles are computed, so position
    ``p`` gets the angles plain RoPE would give position ``p / scale``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        """
        Args:
            dim: rotary dimension (the head dimension); expected to be even.
            max_position_embeddings: number of positions cached up front.
            base: RoPE frequency base.
            scale: divisor applied to every position index.
            device: device for the initial ``inv_freq`` and cache tensors.
        """
        super().__init__()
        self.scale = scale
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._set_cos_sin_cache(
            max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build scaled cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t /= self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each ``(1, 1, seq_len, dim)`` in ``x.dtype``.

        Args:
            x: tensor supplying dtype/device; when ``seq_len`` is omitted its second-to-last
                dimension is used as the sequence length.
            seq_len: number of positions required.
        """
        if seq_len is None:
            # Infer the length instead of failing on ``None > int``.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
|
|
|
class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):
    """Rotary embedding with a statically NTK-scaled base.

    The frequency base is enlarged once, at construction, to ``base * alpha ** (dim / (dim - 2))``;
    the cache is otherwise built and grown exactly like plain RoPE.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
        """
        Args:
            dim: rotary dimension (the head dimension); expected to be even.
            max_position_embeddings: number of positions cached up front.
            base: RoPE frequency base before scaling.
            alpha: NTK scaling factor (1 leaves the base unchanged).
            device: device for the initial ``inv_freq`` and cache tensors.
        """
        super().__init__()
        base = base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._set_cos_sin_cache(
            max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each ``(1, 1, seq_len, dim)`` in ``x.dtype``.

        Args:
            x: tensor supplying dtype/device; when ``seq_len`` is omitted its second-to-last
                dimension is used as the sequence length.
            seq_len: number of positions required.
        """
        if seq_len is None:
            # Infer the length instead of failing on ``None > int``.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
|
|
|
def rotate_half(x):
    """Swap and negate the two halves of the last dimension: ``[a, b] -> [-b, a]``.

    ``a`` is the first ``x.shape[-1] // 2`` entries and ``b`` the remainder.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate queries and keys with the rotary tables.

    Keys are rotated at every position in ``position_ids``. Queries use only the trailing
    ``q.size(-2)`` entries, so queries may be shorter than keys (new tokens attending over
    cached keys).

    Args:
        q, k: query / key tensors whose dim -2 is the sequence axis (in this file they are
            ``(bsz, heads, seq, head_dim)``).
        cos, sin: ``(1, 1, seq, dim)`` tables as returned by the rotary modules in this file;
            both leading unit dims are squeezed before indexing.
        position_ids: 2-D integer tensor ``(bsz, kv_len)`` indexing rows of the tables.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)

    def _rotate(tensor, ids):
        # Gather per-position rows and add a broadcast axis for the heads.
        c = cos_table[ids].unsqueeze(1)
        s = sin_table[ids].unsqueeze(1)
        return (tensor * c) + (rotate_half(tensor) * s)

    query_len = q.size(-2)
    q_embed = _rotate(q, position_ids[:, -query_len:])
    k_embed = _rotate(k, position_ids)
    return q_embed, k_embed
|
|
|
|
|
class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``, all projections bias-free."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        """
        Args:
            hidden_size: model width (input and output size).
            intermediate_size: width of the gated hidden layer.
            hidden_act: key into ``ACT2FN`` selecting the gate activation.
        """
        super().__init__()
        # Registration order (gate, down, up) is kept so parameter order and init RNG draws match.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
|
|
|
|
|
class ExtendedLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, extended so each query can
    also attend to ``topk`` retrieved external memories.

    Memories come either from an explicit ``long_range_past_key_value`` cache (retrieved by
    cosine similarity against the un-rotated query) or from a ``(key_index, key_value_index)``
    pair of faiss indexes. Retrieved keys are rotated as if at position
    ``max_position_embeddings - 1`` and their scores are placed before the ordinary keys.
    """

    def __init__(self, config: ExtendedLlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.num_hidden_layers = config.num_hidden_layers

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, hidden)`` to contiguous ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    @staticmethod
    def _active_externalism_mask(k, s_q, device, min_val):
        """Additive mask letting query ``i`` see only its own ``k`` retrieved memories.

        Returns a float32 ``(s_q, s_q * k)`` tensor that is 0 on columns ``i*k .. (i+1)*k - 1``
        of row ``i`` and ``min_val`` everywhere else.
        """
        rows = torch.arange(s_q, device=device).unsqueeze(1)
        owner = torch.arange(s_q * k, device=device).unsqueeze(0) // k
        blocked = owner != rows
        return torch.zeros(s_q, s_q * k, device=device, dtype=torch.float32).masked_fill(blocked, min_val)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        long_range_past_key_value=None,
        faiss_indexes=None,
        mask_by_sim=False,
        sim_threshold=0.0,
        topk=None,
        current_layer=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Self-attention over cached + new keys, optionally augmented with retrieved memories.

        Args:
            hidden_states: ``(bsz, q_len, hidden_size)``.
            attention_mask: additive mask ``(bsz, 1, q_len, kv_seq_len)``; when memories are used
                it is widened with a per-query memory mask (this path squeezes dims 0/1, i.e.
                assumes ``bsz == 1``).
            position_ids: ``(bsz, kv_seq_len)`` positions for all keys; queries use the trailing
                ``q_len`` entries.
            past_key_value: ``(keys, values)`` from earlier steps, stored *before* rotary.
            output_attentions: return the attention probabilities.
            use_cache: return the updated (un-rotated) ``(keys, values)``.
            long_range_past_key_value: ``(k_cache, v_cache)`` memory tensors gathered along
                dim -2 (presumably ``(bsz, num_heads, mem_len, head_dim)`` — confirm).
            faiss_indexes: ``(kn_index, kv_index)``; ``kn_index`` is searched with the normalised
                query plus a scaled one-hot (layer, head) code, and ``kv_index.reconstruct_batch``
                rows are ``[key (head_dim) | value]``. This path assumes ``bsz == 1``.
            mask_by_sim: in the ``long_range_past_key_value`` path, block memories whose cosine
                similarity is not above ``sim_threshold``. Not supported with faiss.
            sim_threshold: similarity threshold for ``mask_by_sim``.
            topk: memories retrieved per query (clipped to the cache length in the explicit-cache
                path). Required whenever a memory source is given.
            current_layer: layer index selecting this layer's one-hot block in the faiss path.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value or None, reshaped_idx)`` where
            ``reshaped_idx`` is ``(bsz, num_heads, q_len * topk)`` indices of the retrieved
            memories (cache positions, or faiss ids in the faiss path), or ``None`` when no
            memory source is used.

        Raises:
            ValueError: on unexpected attention weight, mask or output shapes.
            NotImplementedError: if ``mask_by_sim`` is combined with ``faiss_indexes``.
        """
        bsz, q_len, _ = hidden_states.size()
        uses_memory = long_range_past_key_value is not None or faiss_indexes is not None
        # Indices of retrieved memories; stays None when no memory source is used.
        reshaped_idx = None

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Keys are cached before rotary is applied; every call re-rotates the full key sequence.
        past_key_value = (key_states, value_states) if use_cache else None

        kv_seq_len = key_states.shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Memory retrieval compares memories against the query *before* rotation.
        query_states_no_rotary = query_states.clone()

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        bsz, nh, s_q, hd = query_states.shape
        s_k = key_states.size(-2)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if uses_memory:
            if long_range_past_key_value is not None:
                k_cache, v_cache = long_range_past_key_value
                s_cache = k_cache.size(-2)

                k_cache = k_cache.to(key_states.device)
                v_cache = v_cache.to(key_states.device)

                # Cosine similarity between every un-rotated query and every cached key.
                q_n = query_states_no_rotary / vector_norm(query_states_no_rotary, ord=2, dim=-1, keepdim=True)
                k_n = k_cache / vector_norm(k_cache, ord=2, dim=-1, keepdim=True)
                sim = q_n.matmul(k_n.transpose(2, 3))
                # Cannot retrieve more memories than the cache holds.
                if s_cache < topk:
                    topk = s_cache
                val, idx = torch.topk(sim, k=topk, dim=-1)

                # Memory slot j * topk + r holds the r-th best memory of query j.
                reshaped_idx = idx.reshape(bsz, nh, s_q * topk)

                # All retrieved keys are rotated at position max_position_embeddings - 1.
                cos_m, sin_m = self.rotary_emb(value_states, seq_len=self.max_position_embeddings)
                cos_m = cos_m[:, :, -1, ...].repeat(1, 1, s_q * topk, 1)
                sin_m = sin_m[:, :, -1, ...].repeat(1, 1, s_q * topk, 1)

                selected_k = k_cache.gather(dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd))
                _, selected_k = apply_rotary_pos_emb(
                    torch.ones(selected_k.shape, device=key_states.device),
                    selected_k,
                    cos_m,
                    sin_m,
                    position_ids=torch.arange(s_q * topk, device=key_states.device).unsqueeze(0),
                )

                selected_v = v_cache.gather(dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd))

                # True where a memory's similarity does not exceed sim_threshold.
                sim_mask = (
                    rearrange(~(val > sim_threshold).bool(), "b h s i -> b h (s i)")
                    .unsqueeze(-2)
                    .expand(-1, -1, s_q, -1)
                )
            else:
                if mask_by_sim:
                    raise NotImplementedError(
                        "mask_by_sim is only supported with long_range_past_key_value, not with faiss_indexes."
                    )
                kn_index, kv_index = faiss_indexes
                q_n = query_states_no_rotary / vector_norm(query_states_no_rotary, ord=2, dim=-1, keepdim=True)

                # Append a scaled one-hot (layer, head) code to each query row; presumably the key
                # index stores the same code so searches stay within this layer/head — confirm.
                one_hot_encodings = (
                    F.one_hot(torch.arange(0, nh * self.num_hidden_layers, device=query_states.device)) * 10
                )
                q_n = torch.concat(
                    [
                        rearrange(q_n, "b h s d -> b (h s) d", h=nh),
                        one_hot_encodings[nh * current_layer:nh * (current_layer + 1)]
                        .unsqueeze(0)
                        .repeat_interleave(repeats=query_states.size(-2), dim=-2),
                    ],
                    dim=-1,
                ).squeeze()

                _, I = kn_index.search(q_n.to("cpu").numpy(), k=topk)
                # Rows of I are ordered (head, query); flattening yields (head, query, rank).
                reshaped_idx = torch.as_tensor(I).reshape(1, nh, s_q * topk)
                # Reconstruct once: each row is [key (head_dim) | value].
                retrieved = torch.tensor(kv_index.reconstruct_batch(I.flatten()))

                selected_k = rearrange(retrieved[:, :hd], "(h s) d -> 1 h s d", h=nh).to(query_states.device)
                cos_m, sin_m = self.rotary_emb(value_states, seq_len=self.max_position_embeddings)
                cos_m = cos_m[:, :, -1, ...].repeat(1, 1, s_q * topk, 1)
                sin_m = sin_m[:, :, -1, ...].repeat(1, 1, s_q * topk, 1)

                _, selected_k = apply_rotary_pos_emb(
                    torch.ones(selected_k.shape, device=key_states.device),
                    selected_k,
                    cos_m,
                    sin_m,
                    position_ids=torch.arange(s_q * topk, device=key_states.device).unsqueeze(0),
                )

                selected_v = rearrange(retrieved[:, hd:], "(h s) d -> 1 h s d", h=nh).to(query_states.device)

            attn_weight_cache = torch.matmul(query_states, selected_k.transpose(2, 3)) / math.sqrt(self.head_dim)
            if mask_by_sim:
                attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, torch.finfo(selected_k.dtype).min)

            # Memory columns come first, followed by the ordinary keys.
            attn_weights = torch.cat([attn_weight_cache, attn_weights], dim=-1)
            value_states = torch.cat([selected_v, value_states], dim=-2)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            if uses_memory:
                # Widen the mask over the memory columns (both memory sources) so query i only
                # sees its own memories; squeeze(dim=[0, 1]) assumes bsz == 1.
                memory_mask = self._active_externalism_mask(
                    k=topk, s_q=s_q, device=attn_weights.device, min_val=torch.finfo(attn_weights.dtype).min
                )
                attention_mask = (
                    torch.cat([memory_mask, attention_mask[:, :, :, -s_k:].squeeze(dim=[0, 1])], dim=1)
                    .unsqueeze(dim=0)
                    .unsqueeze(dim=1)
                )

            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )

        # Softmax in float32 for stability, then back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value, reshaped_idx
|
|
|
class ExtendedLlamaDecoderLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> (memory-augmented) attention -> residual,
    then RMSNorm -> gated MLP -> residual."""

    def __init__(self, config: ExtendedLlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Submodule registration order kept unchanged (affects state_dict order and init RNG).
        self.self_attn = ExtendedLlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        long_range_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        faiss_indexes: Tuple = None,
        mask_by_sim: bool = False,
        sim_threshold: float = None,
        topk: int = None,
        current_layer=None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): additive attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): positions used for the rotary embedding.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states.
            output_attentions (`bool`, *optional*): also return the attention probabilities and the
                indices of the retrieved memories.
            use_cache (`bool`, *optional*): also return the updated key/value cache.
            long_range_past_key_value, faiss_indexes, mask_by_sim, sim_threshold, topk, current_layer:
                forwarded unchanged to `ExtendedLlamaAttention` to control memory retrieval.

        Returns:
            Tuple `(hidden_states[, attn_weights][, present_key_value][, selected_idx])`: the
            attention weights and selected memory indices are included only when
            `output_attentions` is set, the cache only when `use_cache` is set.
        """
        residual = hidden_states

        attn_out, attn_probs, new_cache, memory_idx = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            long_range_past_key_value=long_range_past_key_value,
            faiss_indexes=faiss_indexes,
            mask_by_sim=mask_by_sim,
            sim_threshold=sim_threshold,
            topk=topk,
            current_layer=current_layer,
        )
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))

        results = [hidden_states]
        if output_attentions:
            results.append(attn_probs)
        if use_cache:
            results.append(new_cache)
        if output_attentions:
            # The memory indices always come last so callers can read them with index -1.
            results.append(memory_idx)
        return tuple(results)
|
|
|
|
|
# Shared class-level documentation prepended via `add_start_docstrings` to the model classes below.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`ExtendedLlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
|
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    # Shared base: ties the extended models to ExtendedLlamaConfig and provides weight init
    # and the gradient-checkpointing toggle.
    config_class = ExtendedLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Class names consulted when an automatic device map is inferred; a listed module is kept
    # on a single device. It must name the decoder layer class actually used in this file.
    _no_split_modules = ["ExtendedLlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, config.initializer_range).

        Linear biases and the embedding row at `padding_idx` (when set) are zeroed.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `ExtendedLlamaModel` instances; other modules are left alone."""
        if isinstance(module, ExtendedLlamaModel):
            module.gradient_checkpointing = value
|
|
|
|
|
# Shared forward-argument documentation attached via `add_start_docstrings_to_model_forward`.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` holding the key and value states of the
            self-attention blocks, which can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers, together with the indices of
            the retrieved memories. See `attentions` under returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        use_active_externalism (`bool`, *optional*):
            Whether layers may use `long_range_past_key_values`. Defaults to `config.use_active_externalism`. A layer
            only receives its memories when this is `True` and `config.use_active_externalism_by_layer` is truthy for
            that layer. It does not gate `faiss_indexes`.
        long_range_past_key_values (`list(tuple(torch.FloatTensor))`, *optional*):
            Per-layer `(key, value)` memory tensors from which each query retrieves its `topk` most similar keys.
            Mutually exclusive with `faiss_indexes`.
        faiss_indexes (`tuple`, *optional*):
            `(key_index, key_value_index)` pair of faiss indexes used instead of `long_range_past_key_values` to
            retrieve memories. Mutually exclusive with `long_range_past_key_values`.
        topk (`int`, *optional*):
            Number of memories retrieved per query token. Defaults to `config.topk`.
"""
|
|
|
|
|
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class ExtendedLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is an
    [`ExtendedLlamaDecoderLayer`], whose attention can additionally attend to retrieved external memories.

    Args:
        config: ExtendedLlamaConfig
    """

    def __init__(self, config: ExtendedLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([ExtendedLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        # Memory-retrieval settings; `topk` and `use_active_externalism` can be overridden per call.
        self.mask_by_sim = config.mask_by_sim
        self.sim_threshold = config.sim_threshold
        self.topk = config.topk
        self.use_active_externalism = config.use_active_externalism
        # Indexed by layer number; gates `long_range_past_key_values` for that layer.
        self.use_active_externalism_by_layer = config.use_active_externalism_by_layer

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine the causal mask and the padding mask into one additive 4-D mask.

        Returns a `(bsz, 1, tgt_len, tgt_len + past_key_values_length)` mask in `inputs_embeds.dtype`,
        or `None` when there is a single new token and no `attention_mask`.
        """
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_active_externalism: Optional[bool] = None,
        long_range_past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        faiss_indexes: Tuple = None,
        topk: int = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_active_externalism = (
            use_active_externalism if use_active_externalism is not None else self.use_active_externalism
        )
        topk = topk if topk is not None else self.topk

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            # Default positions span past + new tokens: the attention layers re-apply rotary
            # embeddings to the whole (cached + new) key sequence on every call.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(seq_length_with_past, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length_with_past)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        all_idx = () if output_attentions else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            long_range_past_key_value = (
                long_range_past_key_values[idx]
                if (
                    long_range_past_key_values is not None
                    and self.use_active_externalism_by_layer[idx]
                    and use_active_externalism is True
                )
                else None
            )

            if long_range_past_key_value is not None and faiss_indexes is not None:
                raise NotImplementedError(
                    'Using faiss and passing key value pairs manually are mutually exclusive right now.')

            if self.gradient_checkpointing and self.training:
                # NOTE: this path passes no memory arguments to the layer, so retrieval is skipped
                # while training with gradient checkpointing.

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Positional: (hidden_states, attention_mask, position_ids, past_key_value,
                        # output_attentions, use_cache=None).
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    topk=topk,
                    long_range_past_key_value=long_range_past_key_value,
                    faiss_indexes=faiss_indexes,
                    mask_by_sim=self.mask_by_sim,
                    sim_threshold=self.sim_threshold,
                    current_layer=idx,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                # The layer appends the retrieved-memory indices last; their absolute position
                # depends on whether the cache entry is present, so index from the end.
                all_idx += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        # `attentions` carries (per-layer attention weights, per-layer retrieved-memory indices).
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=(all_self_attns, all_idx)
        )
|
|
|
|
|
class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
    """LLaMA causal language model that can attend to external memories.

    Raw memory token ids (``external_memories``) are encoded lazily into a cache on
    the first ``forward`` call. Depending on ``config.memory_type`` the cache is either
    a list with one ``(key, value)`` pair per layer (``'manual'``) or a
    ``(kn_index, kv_index)`` pair of faiss inner-product indexes (``'faiss'``).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config, external_memories=None, **kwargs):
        super().__init__(config)
        self.model = ExtendedLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.use_active_externalism = config.use_active_externalism
        self.memory_type = config.memory_type
        self.memory_device = config.memory_device

        # Raw memory token ids; encoded into `self.memories` on the first forward pass.
        self._memories = external_memories
        # Encoded memories: list of per-layer (k, v) for 'manual', (kn_index, kv_index) for 'faiss'.
        self.memories = None

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_active_externalism: Optional[bool] = None,
        topk: Optional[int] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_active_externalism (`bool`, *optional*):
                Whether the decoder may use the external memories. Defaults to `config.use_active_externalism`.
            topk (`int`, *optional*):
                Passed unchanged to the decoder layers (presumably the number of memory entries retrieved per
                query — confirm against the attention implementation).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Encode the raw external memories once, on first use.
        if self._memories is not None and self.memories is None:
            self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_active_externalism = (
            use_active_externalism if use_active_externalism is not None else self.use_active_externalism
        )

        # A list means manually cached (k, v) pairs; anything else non-None is the faiss index pair.
        memories = getattr(self, "memories", None)
        long_range_past_key_values = None
        faiss_indexes = None
        if isinstance(memories, list):
            long_range_past_key_values = memories
        elif memories is not None:
            faiss_indexes = memories

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            long_range_past_key_values=long_range_past_key_values,
            faiss_indexes=faiss_indexes,
            use_active_externalism=use_active_externalism,
            topk=topk,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism: labels may live on a different device than logits.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate_cache(self,
                       input_ids: torch.LongTensor,
                       stride: int = 512,
                       max_len: int = 2048,
                       cache_type: str = 'manual'):
        """Encode `input_ids` into a long-range memory using a sliding window.

        Windows of at most `max_len` tokens start every `stride` tokens. From each window,
        only the key/value states of tokens not covered by an earlier window are passed
        to :meth:`cache`, so every token is stored exactly once.

        Args:
            input_ids: memory token ids, sliced as `input_ids[:, a:b]` (i.e. 2-D). The faiss
                path presumably assumes batch size 1 — TODO confirm.
            stride: step between window starts; must be in `[1, max_len]`.
            max_len: maximum number of tokens fed to the model per window.
            cache_type: `'manual'` or `'faiss'`.

        Returns:
            For `'manual'`, a list with one `(key, value)` pair per layer. For `'faiss'`, a
            `(kn_index, kv_index)` tuple. `None` if `input_ids` has no tokens.

        Raises:
            NotImplementedError: if `cache_type` is not supported.
            ValueError: if `stride` is outside `[1, max_len]`. A larger stride would
                silently skip the tokens between windows.
        """
        if cache_type not in ['manual', 'faiss']:
            raise NotImplementedError(f"Cache type {cache_type} not implemented.")
        if stride <= 0 or stride > max_len:
            raise ValueError(
                f"stride ({stride}) must be in [1, max_len={max_len}]; "
                "otherwise tokens between windows would never be cached."
            )

        seq_len = input_ids.size(-1)
        prev_end_loc = 0
        long_range_past_key_values = None
        faiss_indexes = None
        for b_idx in range(0, seq_len, stride):
            end_loc = min(b_idx + max_len, seq_len)
            # Number of new tokens in this window (the rest were cached by earlier windows).
            trg_len = end_loc - prev_end_loc
            subseq = input_ids[:, b_idx:end_loc].to(self.model.device)
            with torch.no_grad():
                # return_dict=True so `.past_key_values` works regardless of config.use_return_dict.
                outputs = self.model(subseq, use_cache=True, use_active_externalism=False, return_dict=True)
            to_cache = [
                (kv[0][:, :, -trg_len:], kv[1][:, :, -trg_len:])
                for kv in outputs.past_key_values
            ]
            long_range_past_key_values, faiss_indexes = self.cache(
                to_cache,
                cache_type,
                long_range_past_key_values=long_range_past_key_values,
                faiss_indexes=faiss_indexes,
            )

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        if long_range_past_key_values is not None:
            return long_range_past_key_values
        return faiss_indexes

    def cache(self,
              to_cache: List,
              cache_type: str = 'manual',
              long_range_past_key_values: List = None,
              faiss_indexes: Optional[Tuple[faiss.IndexFlatIP, faiss.IndexFlatIP]] = None,
              max_length_cache=100000,
              verbose=False):
        """Add per-layer key/value states to a manual cache or to faiss indexes.

        Args:
            to_cache: one `(key, value)` pair per layer. Tensors are laid out
                `(b, h, s, d)`, as the `rearrange` patterns below require.
            cache_type: `'faiss'` adds rows to the faiss indexes. Any other value
                appends to the manual per-layer list.
            long_range_past_key_values: existing manual cache to extend. Mutually
                exclusive with `faiss_indexes`.
            faiss_indexes: existing `(kn_index, kv_index)` pair to extend.
                - `kn_index` rows are a normalized key plus a scaled one-hot
                  (layer, head) id.
                - `kv_index` rows are a value concatenated with its raw key.
            max_length_cache: the manual cache keeps only the most recent
                `max_length_cache` positions.
            verbose: print the cache size after the update.

        Returns:
            `(long_range_past_key_values, faiss_indexes)`. The second element is
            `(kn_index, kv_index)` when `cache_type == 'faiss'` and `None` otherwise.
        """
        if long_range_past_key_values is not None and faiss_indexes is not None:
            raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")

        if cache_type == 'faiss':
            n_heads = self.config.n_heads
            # One one-hot vector (scaled by 10) per (layer, head) pair, appended to each normalized key.
            # Presumably this biases inner-product search toward the querying layer/head — confirm on the
            # retrieval side.
            one_hot_encodings = F.one_hot(torch.arange(0, n_heads * self.config.num_hidden_layers)) * 10
            if faiss_indexes is None:
                faiss_indexes = (
                    faiss.IndexFlatIP(to_cache[0][0].size(-1) + one_hot_encodings.size(-1)),
                    faiss.IndexFlatIP(to_cache[0][1].size(-1) * 2),
                )
            kn_index, kv_index = faiss_indexes
            for layer_idx, (k, v) in enumerate(to_cache):
                # faiss works on float32 arrays, and `.numpy()` cannot convert bfloat16 at all,
                # so normalize and export in float32 regardless of the model dtype.
                k32 = k.float()
                k_n = (k32 / vector_norm(k32, ord=2, dim=-1, keepdim=True)).to('cpu')
                layer_one_hot = one_hot_encodings[n_heads * layer_idx:n_heads * (layer_idx + 1)]
                k_n = torch.concat(
                    [
                        rearrange(k_n, 'b h s d -> b (h s) d', h=n_heads),
                        # Repeat each head's one-hot s times to match the (h s) row order above.
                        layer_one_hot.unsqueeze(0).repeat_interleave(repeats=k.size(-2), dim=-2).to(k_n.dtype),
                    ],
                    dim=-1,
                )
                # squeeze(0) drops only the batch dim, so faiss always receives a 2-D array.
                kn_index.add(k_n.squeeze(0).numpy())

                k = rearrange(k, 'b h s d -> b (h s) d', h=n_heads)
                v = rearrange(v, 'b h s d -> b (h s) d', h=n_heads)
                kv_rows = torch.concat([v.squeeze(0), k.squeeze(0)], dim=1)
                kv_index.add(kv_rows.to(device='cpu', dtype=torch.float32).numpy())
        else:
            if long_range_past_key_values is None:
                long_range_past_key_values = [(k.to(self.memory_device), v.to(self.memory_device)) for k, v in to_cache]
            else:
                long_range_past_key_values = [
                    (
                        torch.concat([kv[0], to_cache[ind][0].to(self.memory_device)], dim=2),
                        torch.concat([kv[1], to_cache[ind][1].to(self.memory_device)], dim=2),
                    )
                    for ind, kv in enumerate(long_range_past_key_values)
                ]

        # Keep only the most recent `max_length_cache` positions of the manual cache.
        if long_range_past_key_values is not None:
            if long_range_past_key_values[0][0].size(-2) > max_length_cache:
                long_range_past_key_values = [
                    (kv[0][:, :, -max_length_cache:], kv[1][:, :, -max_length_cache:])
                    for kv in long_range_past_key_values
                ]

        if verbose:
            if cache_type == 'faiss':
                print(f"{kn_index.ntotal} keys in faiss index")
            else:
                print(f"{long_range_past_key_values[0][0].size(-2)} cached kvs")

        return long_range_past_key_values, ((kn_index, kv_index) if cache_type == 'faiss' else None)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the `forward` kwargs for one generation step.

        With a KV cache present, only the last token is fed. Position ids are derived
        from `attention_mask` when not given.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation (left padding).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # inputs_embeds are only used on the first step (no cache yet).
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                'use_active_externalism': kwargs.get('use_active_externalism'),
                'topk': kwargs.get('topk', None),
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along the batch dim to follow `beam_idx` (beam search)."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|
|