|
|
|
|
|
import numpy as np |
|
from typing import List, Optional, Tuple, Union, Dict |
|
from tqdm import tqdm |
|
from einops import rearrange, repeat |
|
import torch |
|
from torch import nn |
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache, StaticCache |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
) |
|
from transformers import AutoConfig |
|
from transformers import AutoModel |
|
from transformers.modeling_utils import PreTrainedModel |
|
try: |
|
from flash_attn.flash_attn_interface import flash_attn_func |
|
except Exception as e: |
|
    print(f"Could not import flash attention: {e}")
|
flash_attn_func = None |
|
|
|
PHARIAEMBED_TYPE = "phariaembed" |
|
|
|
class RotaryConfig:
|
def __init__( |
|
self, |
|
dimensions: int = 0, |
|
base: int = 10000, |
|
max_seq_length: int = 2048 |
|
): |
|
self.dimensions = dimensions |
|
self.base = base |
|
self.max_seq_length = max_seq_length |
|
|
|
class PhariaAdapterConfig: |
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
mlp_bias: bool, |
|
hidden_act: str |
|
): |
|
self.hidden_size = hidden_size |
|
self.intermediate_size = intermediate_size |
|
self.mlp_bias = mlp_bias |
|
self.hidden_act = hidden_act |
|
|
|
|
|
def to_dict(self): |
|
return { |
|
"hidden_size": self.hidden_size, |
|
"intermediate_size": self.intermediate_size, |
|
"mlp_bias": self.mlp_bias, |
|
"hidden_act": self.hidden_act, |
|
} |
|
|
|
@classmethod |
|
def from_dict(cls, config_dict): |
|
return cls(**config_dict) |
|
|
|
|
|
|
|
class PhariaConfig(PretrainedConfig): |
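    """Configuration for the Pharia embedding model: a decoder-style transformer with rotary position
    embeddings, grouped key/value heads, optional MLP/attention adapters, and a pooling-based embedding head."""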
|
model_type = "phariaembed" |
|
|
|
def __init__( |
|
self, |
|
pad_token_id=None, |
|
bos_token_id=1, |
|
eos_token_id=2, |
|
hidden_act="gelu", |
|
hidden_size=512, |
|
bias_name=None, |
|
initializer_range=0.02, |
|
intermediate_size=2048, |
|
max_position_embeddings=8192, |
|
|
|
model_type="phariaembed", |
|
num_attention_heads=4, |
|
num_hidden_layers=4, |
|
num_key_value_heads=2, |
|
torch_dtype="bfloat16", |
|
transformers_version="4.31.0.dev0", |
|
use_cache=True, |
|
vocab_size=128000, |
|
mlp_bias=True, |
|
attention_bias=True, |
|
tie_word_embeddings=False, |
|
attention_dropout=0.0, |
|
causal_attention=True, |
|
rope_theta=1000000, |
|
rope_scaling=None, |
|
mlp_adapter_config=None, |
|
attn_adapter_config=None, |
|
_attn_implementation='eager', |
|
embedding_head_out=1024, |
|
lora_config=None, |
|
pooling_method=None, |
|
layer_norm_epsilon=1e-05, |
|
**kwargs, |
|
): |
|
super().__init__( |
|
pad_token_id=pad_token_id, |
|
bos_token_id=bos_token_id, |
|
eos_token_id=eos_token_id, |
|
tie_word_embeddings=tie_word_embeddings, |
|
**kwargs, |
|
) |
|
|
|
self.pad_token_id = pad_token_id |
|
self.bos_token_id = bos_token_id |
|
self.eos_token_id = eos_token_id |
|
self.hidden_act = hidden_act |
|
self.hidden_size = hidden_size |
|
self.initializer_range = initializer_range |
|
self.intermediate_size = intermediate_size |
|
self.max_position_embeddings = max_position_embeddings |
|
self.model_type = model_type |
|
self.num_attention_heads = num_attention_heads |
|
self.num_hidden_layers = num_hidden_layers |
|
self.num_key_value_heads = num_key_value_heads |
|
self.torch_dtype = torch_dtype |
|
self.causal_attention = causal_attention |
|
self.attn_adapter_config = attn_adapter_config |
|
self.mlp_adapter_config = mlp_adapter_config |
|
self.bias_name = bias_name |
|
self.transformers_version = transformers_version |
|
self.use_cache = use_cache |
|
self.vocab_size = vocab_size |
|
self.mlp_bias = mlp_bias |
|
self.attention_bias = attention_bias |
|
self.tie_word_embeddings = tie_word_embeddings |
|
self.attention_dropout = attention_dropout |
|
self.rope_theta = rope_theta |
|
self.rope_scaling = rope_scaling |
|
self.embedding_head_out = embedding_head_out |
|
self.pooling_method = pooling_method |
|
self.lora_config = lora_config |
|
self._attn_implementation = _attn_implementation |
|
self.layer_norm_epsilon = layer_norm_epsilon |
|
|
|
|
|
def to_dict(self): |
|
output = super(PhariaConfig, self).to_dict() |
|
if self.mlp_adapter_config is not None: |
|
output["mlp_adapter_config"] = self.mlp_adapter_config.to_dict() |
|
if self.attn_adapter_config is not None: |
|
output["attn_adapter_config"] = self.attn_adapter_config.to_dict() |
|
return output |
|
|
|
@classmethod |
|
def from_dict(cls, config_dict, **kwargs): |
|
if 'use_cache' in config_dict: |
|
del config_dict['use_cache'] |
|
|
|
if 'mlp_adapter_config' in config_dict and config_dict["mlp_adapter_config"] is not None: |
|
config_dict["mlp_adapter_config"] = PhariaAdapterConfig.from_dict(config_dict["mlp_adapter_config"]) |
|
if 'attn_adapter_config' in config_dict and config_dict["attn_adapter_config"] is not None: |
|
config_dict["attn_adapter_config"] = PhariaAdapterConfig.from_dict(config_dict["attn_adapter_config"]) |
|
return cls(**config_dict, **kwargs) |
|
|
|
|
|
def reshape_complex_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: |
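    """Reshape `freqs_cis` of shape (seq_len, head_dim/2) so it broadcasts against `x` of shape
    (batch, seq_len, n_heads, head_dim/2)."""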
|
ndim = x.ndim |
|
    assert ndim > 1
|
assert freqs_cis.shape[0] == x.shape[1] |
|
assert freqs_cis.shape[1] == x.shape[-1] |
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
|
return freqs_cis.view(*shape) |
|
|
|
def precompute_freqs_cis( |
|
dim: int, |
|
end: int, |
|
theta: float, |
|
device: torch.device, |
|
) -> torch.Tensor: |
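    """Precompute the complex rotary frequencies exp(i * t / theta^(2k/dim)) for positions t in [0, end)."""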
|
theta = float(theta) |
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)).to(device) |
|
t = torch.arange(end, device=device) |
|
freqs = torch.outer(t, freqs).float() |
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
|
return freqs_cis.to(device) |
|
|
|
|
|
def apply_complex_rotary_emb( |
|
xq: torch.Tensor, |
|
xk: torch.Tensor, |
|
freqs_cis: torch.Tensor, |
|
query_position_ids: Optional[torch.Tensor], |
|
key_position_ids: Optional[torch.Tensor], |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
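    """Rotate query and key tensors in the complex domain.

    When position ids are given, per-token frequencies are gathered; otherwise the precomputed
    frequencies are broadcast over the batch and head dimensions.
    """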
|
xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
|
xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) |
|
|
|
if query_position_ids is None: |
|
freqs_cis_q = reshape_complex_for_broadcast(freqs_cis, xq_complex) |
|
else: |
|
freqs_cis_q = vector_gather_complex(freqs_cis, query_position_ids) |
|
|
|
if key_position_ids is None: |
|
        freqs_cis_k = reshape_complex_for_broadcast(freqs_cis, xk_complex)
|
else: |
|
freqs_cis_k = vector_gather_complex(freqs_cis, key_position_ids) |
|
|
|
xq_out = torch.view_as_real(xq_complex * freqs_cis_q).flatten(3) |
|
xk_out = torch.view_as_real(xk_complex * freqs_cis_k).flatten(3) |
|
return xq_out.type_as(xq), xk_out.type_as(xk) |
|
|
|
|
|
class RotaryEmbeddingComplex(torch.nn.Module): |
|
""" |
|
Relative rotary position embedding based on |
|
* RoFormer: Enhanced Transformer with Rotary Position Embedding (https://arxiv.org/abs/2104.09864) |
|
* Rotary Embeddings: A Relative Revolution (https://blog.eleuther.ai/rotary-embeddings/) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
config: RotaryConfig, |
|
device: torch.device, |
|
) -> None: |
|
super().__init__() |
|
assert config.dimensions > 1, "RotaryEmbedding cannot use `dim` == 1, this results in weird reshape errors" |
|
|
|
freqs_cis = precompute_freqs_cis( |
|
dim=config.dimensions, |
|
end=config.max_seq_length, |
|
theta=config.base, |
|
device=device, |
|
) |
|
|
|
|
|
self.freqs_cis_real = freqs_cis.real |
|
self.freqs_cis_imag = freqs_cis.imag |
|
|
|
def forward( |
|
self, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
query_position_ids: Optional[torch.Tensor] = None, |
|
key_position_ids: Optional[torch.Tensor] = None, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
query, key = apply_complex_rotary_emb( |
|
xq=rearrange(query, "sq b nh hh -> b sq nh hh"), |
|
xk=rearrange(key, "sq b nh hh -> b sq nh hh"), |
|
freqs_cis=torch.complex(self.freqs_cis_real.float(), self.freqs_cis_imag.float()), |
|
query_position_ids=query_position_ids, |
|
key_position_ids=key_position_ids, |
|
) |
|
return rearrange(query, "b sq nh hh -> sq b nh hh"), rearrange(key, "b sq nh hh -> sq b nh hh") |
|
|
|
def vector_gather(vectors: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Gathers (batched) vectors according to indices. |
|
""" |
|
vectors = repeat(vectors, "sq b nh d -> sq b B nh d", B=indices.shape[1]).squeeze(1) |
|
indices = repeat( |
|
indices, |
|
"sq b -> sq b nh d", |
|
nh=vectors.shape[-2], |
|
d=vectors.shape[-1], |
|
) |
|
|
|
out = torch.gather(vectors, dim=0, index=indices) |
|
|
|
return out |
|
|
|
|
|
def vector_gather_complex(vectors: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Gathers (batched) vectors according to indices. |
|
""" |
|
vectors = repeat(vectors, "sq d -> sq B nh d", B=indices.shape[1], nh=1) |
|
indices = repeat( |
|
indices, |
|
"sq b -> sq b nh d", |
|
nh=1, |
|
d=vectors.shape[-1], |
|
) |
|
|
|
out = torch.gather(vectors, dim=0, index=indices) |
|
|
|
out = rearrange(out, "sq b nh hh -> b sq nh hh") |
|
|
|
return out |
|
|
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)""" |
|
bs, slen, n_kv_heads, head_dim = x.shape |
|
if n_rep == 1: |
|
return x |
|
return ( |
|
x[:, :, :, None, :] |
|
.expand(bs, slen, n_kv_heads, n_rep, head_dim) |
|
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) |
|
) |
|
|
|
|
|
|
|
class PhariaAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__(self, config: PhariaConfig, layer_idx: Optional[int] = None): |
|
super().__init__() |
|
self.config = config |
|
self.layer_idx = layer_idx |
|
self.attention_dropout = config.attention_dropout |
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.max_position_embeddings = config.max_position_embeddings |
|
self.rope_theta = config.rope_theta |
|
self.is_causal = config.causal_attention |
|
self.query_key_scaling_factor = 1 / (self.head_dim ** 0.5) |
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size: |
|
raise ValueError( |
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
|
|
self.q_proj = nn.Linear( |
|
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias |
|
) |
|
self.k_proj = nn.Linear( |
|
self.hidden_size, |
|
self.num_key_value_heads * self.head_dim, |
|
bias=config.attention_bias, |
|
) |
|
self.v_proj = nn.Linear( |
|
self.hidden_size, |
|
self.num_key_value_heads * self.head_dim, |
|
bias=config.attention_bias, |
|
) |
|
self.o_proj = nn.Linear( |
|
self.hidden_size, self.hidden_size, bias=config.attention_bias |
|
) |
|
|
|
self._init_rope() |
|
|
|
def _init_rope(self): |
|
self.rotary_emb = RotaryEmbeddingComplex( |
|
config=RotaryConfig( |
|
dimensions=self.head_dim, |
|
max_seq_length=self.max_position_embeddings, |
|
base=self.rope_theta |
|
), |
|
            # Precompute the rotary frequencies on the GPU when available, otherwise on CPU.
            device='cuda:0' if torch.cuda.is_available() else 'cpu'
|
) |
|
|
|
def prepare_query_key_value( |
|
self, |
|
hidden_states: torch.Tensor, |
|
position_ids: torch.Tensor, |
|
past_key_value: Optional[Cache] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
): |
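        """Project hidden states to Q/K/V, apply rotary embeddings, update the KV cache (if any),
        and repeat the key/value heads to match the number of query heads."""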
|
query_states = rearrange(self.q_proj(hidden_states), "b sq (np hn) -> sq b np hn", np=self.num_heads) |
|
key_states = rearrange(self.k_proj(hidden_states), "b sq (np hn) -> sq b np hn", np=self.num_key_value_heads) |
|
value_states = rearrange(self.v_proj(hidden_states), "b sq (np hn) -> sq b np hn", np=self.num_key_value_heads) |
|
|
|
|
|
position_ids = rearrange(position_ids, 'b sq -> sq b') |
|
query_states, key_states = self.rotary_emb( |
|
query_states, key_states, query_position_ids=position_ids, key_position_ids=position_ids |
|
) |
|
|
|
if past_key_value is not None: |
|
|
|
cache_kwargs = {"cache_position": cache_position} |
|
key_states, value_states = past_key_value.update( |
|
key_states, value_states, self.layer_idx, cache_kwargs |
|
) |
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
return query_states, key_states, value_states |
|
|
|
    def forward(
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
softmax_in_fp32: Optional[bool] = False |
|
): |
|
bsz, _, _ = hidden_states.size() |
|
query, key, value = self.prepare_query_key_value( |
|
hidden_states, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
cache_position=cache_position |
|
) |
|
seq_length, batch_size, _, head_dim = query.shape |
|
|
|
query = rearrange(query, "sq bs nh hd -> sq (bs nh) hd") |
|
key = rearrange(key, "sq bs nh hd -> sq (bs nh) hd") |
|
value = rearrange(value, "sq bs nh hd -> sq (bs nh) hd") |
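        # Allocate a (b*nh, s_q, s_k) buffer for the attention scores; it is filled below by
        # baddbmm as alpha * Q @ K^T (beta=0.0, so the uninitialized contents are ignored).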
|
|
|
matmul_result = torch.empty( |
|
query.size(1), |
|
query.size(0), |
|
key.size(0), |
|
dtype=query.dtype, |
|
device=query.device, |
|
) |
|
|
|
|
|
matmul_result = torch.baddbmm( |
|
matmul_result, |
|
query.transpose(0, 1), |
|
key.transpose(0, 1).transpose(1, 2), |
|
beta=0.0, |
|
alpha=self.query_key_scaling_factor, |
|
) |
|
|
|
attention_scores = rearrange(matmul_result, "(b n) s_q s_k -> b n s_q s_k", b=batch_size) |
|
if softmax_in_fp32 and attention_scores.dtype != torch.float32: |
|
input_dtype = attention_scores.dtype |
|
attention_scores = attention_scores.float() |
|
else: |
|
input_dtype = None |
|
|
|
|
|
        # Apply the prepared additive attention mask ([b, 1, s_q, s_k]) when one is provided;
        # otherwise fall back to a plain causal mask if this layer is configured as causal.
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask[..., : attention_scores.shape[-2], : attention_scores.shape[-1]]
        elif self.is_causal:
            causal_mask = torch.triu(
                torch.ones(seq_length, seq_length, device=attention_scores.device),
                diagonal=1,
            ).bool()
            attention_scores.masked_fill_(causal_mask, -10000.0)
|
probs = torch.nn.functional.softmax(attention_scores, dim=-1) |
|
if softmax_in_fp32 and input_dtype is not None: |
|
probs = probs.to(input_dtype) |
|
|
|
|
|
probs = rearrange(probs, "b n s_q s_k -> (b n) s_q s_k") |
|
hidden_state = torch.bmm(probs.to(dtype=value.dtype), value.transpose(0, 1)) |
|
attn_output = rearrange(hidden_state, "(b np) sq hn -> b sq (np hn)", b=bsz) |
|
|
|
|
|
        attn_output = nn.functional.linear(attn_output, self.o_proj.weight, None)
        if self.o_proj.bias is not None:
            attn_output = attn_output + self.o_proj.bias
|
|
|
        # Attention probabilities are not returned from this eager implementation.
        return attn_output, None, past_key_value
|
|
|
class PhariaFlashAttention2(PhariaAttention): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
@staticmethod |
|
def get_max_seq_length(cumulative_seq_lengths: torch.Tensor) -> int: |
|
return int((cumulative_seq_lengths[1:] - cumulative_seq_lengths[:-1]).max().item()) |
|
|
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
softmax_in_fp32: Optional[bool] = False |
|
): |
|
assert flash_attn_func is not None, "Please install Flash Attention via optimization requirements" |
|
query, key, value = self.prepare_query_key_value(hidden_states, position_ids=position_ids) |
|
|
|
batch_size = query.shape[1] |
|
|
|
|
|
query = rearrange(query, "s_q b n h -> b s_q n h") |
|
key = rearrange(key, "s_k b n h -> b s_k n h") |
|
value = rearrange(value, "s_k b n h -> b s_k n h") |
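        # Note: flash_attn_func is called without a padding mask (non-varlen API), so padded
        # positions are not excluded from attention here; padding is only handled at pooling time.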
|
|
|
attention_output = flash_attn_func( |
|
q=query, |
|
k=key, |
|
v=value, |
|
causal=self.is_causal, |
|
softmax_scale=self.query_key_scaling_factor |
|
) |
|
attention_output = rearrange(attention_output, "b sq np hn -> b sq (np hn)", b=batch_size) |
|
|
|
        attention_output = nn.functional.linear(attention_output, self.o_proj.weight, None)
        if self.o_proj.bias is not None:
            attention_output = attention_output + self.o_proj.bias
|
|
|
        # Flash attention does not expose attention probabilities.
        attn_weights = None
|
|
|
return attention_output, attn_weights, past_key_value |
|
|
|
|
|
ATTN_IMPLEMENTATION = { |
|
'flash_attention_2': PhariaFlashAttention2, |
|
'sdpa': PhariaAttention, |
|
'eager': PhariaAttention |
|
} |
|
|
|
|
|
class PhariaMLP(nn.Module): |
|
def __init__(self, config, layer_idx: int): |
|
super().__init__() |
|
self.layer_idx = layer_idx |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
self.up_proj = nn.Linear( |
|
self.hidden_size, self.intermediate_size, bias=config.mlp_bias |
|
) |
|
self.down_proj = nn.Linear( |
|
self.intermediate_size, self.hidden_size, bias=config.mlp_bias |
|
) |
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
def forward(self, x): |
|
x = self.up_proj(x) |
|
x = self.act_fn(x) |
|
        if self.down_proj.bias is not None:
|
|
|
o = nn.functional.linear(x, self.down_proj.weight, None) + self.down_proj.bias |
|
else: |
|
o = self.down_proj(x) |
|
return o |
|
|
|
|
|
class PhariaDecoderLayer(nn.Module): |
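    """A single transformer block: pre-LayerNorm self-attention and MLP with residual connections,
    plus optional post-attention and post-MLP adapter MLPs applied as extra residual branches."""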
|
def __init__(self, config: PhariaConfig, layer_idx: int): |
|
super().__init__() |
|
self.hidden_size = config.hidden_size |
|
self.self_attn = ATTN_IMPLEMENTATION[config._attn_implementation](config=config, layer_idx=layer_idx) |
|
|
|
self.post_mlp_adapter = None |
|
if config.mlp_adapter_config: |
|
self.post_mlp_adapter = PhariaMLP(config.mlp_adapter_config, layer_idx=layer_idx) |
|
self.post_attn_adapter = None |
|
if config.attn_adapter_config: |
|
self.post_attn_adapter = PhariaMLP(config.attn_adapter_config, layer_idx=layer_idx) |
|
|
|
self.mlp = PhariaMLP(config, layer_idx=layer_idx) |
|
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.layer_idx = layer_idx |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
) -> Tuple[ |
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] |
|
]: |
|
residual = hidden_states |
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
) |
|
|
|
hidden_states = residual + hidden_states |
|
|
|
if self.post_attn_adapter: |
|
hidden_states = self.post_attn_adapter(hidden_states) + hidden_states |
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
|
hidden_states = self.mlp(hidden_states) |
|
|
|
hidden_states = residual + hidden_states |
|
if self.post_mlp_adapter: |
|
hidden_states = self.post_mlp_adapter(hidden_states) + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
class PhariaPreTrainedModel(PreTrainedModel): |
|
config_class = PhariaConfig |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = False |
|
_no_split_modules = ["PhariaDecoderLayer"] |
|
_skip_keys_device_placement = ["past_key_values"] |
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = True |
|
_supports_cache_class = True |
|
_supports_static_cache = True |
|
|
|
|
|
def _init_weights(self, module): |
|
std = self.config.initializer_range |
|
if isinstance(module, nn.Linear): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
class PhariaModel(PhariaPreTrainedModel): |
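    """Backbone model: token embeddings, a stack of PhariaDecoderLayer blocks, and a final LayerNorm."""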
|
config_class = PhariaConfig |
|
|
|
def __init__(self, config: PhariaConfig): |
|
super().__init__(config) |
|
self.padding_idx = config.pad_token_id |
|
self.vocab_size = config.vocab_size |
|
|
|
self.embed_tokens = nn.Embedding( |
|
config.vocab_size, config.hidden_size, self.padding_idx |
|
) |
|
|
|
self.layers = nn.ModuleList( |
|
[ |
|
PhariaDecoderLayer(config, layer_idx) |
|
for layer_idx in range(config.num_hidden_layers) |
|
] |
|
) |
|
|
|
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPast]: |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
raise ValueError( |
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" |
|
) |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
return_legacy_cache = False |
|
if use_cache and not isinstance( |
|
past_key_values, Cache |
|
): |
|
return_legacy_cache = True |
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
|
|
|
if cache_position is None: |
|
past_seen_tokens = ( |
|
past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
) |
|
cache_position = torch.arange( |
|
past_seen_tokens, |
|
past_seen_tokens + inputs_embeds.shape[1], |
|
device=inputs_embeds.device, |
|
) |
|
if position_ids is None: |
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
if self.config.causal_attention: |
|
mask = self._update_causal_mask( |
|
attention_mask, |
|
inputs_embeds, |
|
cache_position, |
|
past_key_values, |
|
output_attentions, |
|
) |
|
else: |
|
mask = self._create_bidirectional_attention_mask( |
|
attention_mask, |
|
inputs_embeds.dtype |
|
) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = None |
|
|
|
for decoder_layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if use_cache: |
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
if return_legacy_cache: |
|
next_cache = next_cache.to_legacy_cache() |
|
|
|
if not return_dict: |
|
return tuple( |
|
v |
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] |
|
if v is not None |
|
) |
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
    def _create_bidirectional_attention_mask(self, attention_mask: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
        # Without a padding mask, every position may attend to every other position.
        if attention_mask is None:
            return None

        # Additive mask of shape [b, 1, s, s]: 0 where both positions are real tokens,
        # dtype-min where either position is padding, broadcastable over the head dimension.
        bidirectional_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2).to(dtype)
        bidirectional_mask = 1 - bidirectional_mask
        dtype_min_value = torch.finfo(dtype).min
        bidirectional_mask = bidirectional_mask.masked_fill(bidirectional_mask == 1, dtype_min_value)

        return bidirectional_mask.unsqueeze(1)
|
|
|
|
|
def _update_causal_mask( |
|
self, |
|
attention_mask: torch.Tensor, |
|
input_tensor: torch.Tensor, |
|
cache_position: torch.Tensor, |
|
past_key_values: Cache, |
|
output_attentions: bool, |
|
): |
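        """Prepare the attention mask for the configured attention implementation: the 2D padding
        mask (or None) for flash attention, possibly None for SDPA, otherwise a 4D additive causal
        mask merged with any padding mask."""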
|
|
|
|
|
|
|
|
|
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
if attention_mask is not None and 0.0 in attention_mask: |
|
return attention_mask |
|
return None |
|
|
|
|
|
|
|
|
|
past_seen_tokens = ( |
|
past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
) |
|
using_static_cache = isinstance(past_key_values, StaticCache) |
|
|
|
|
|
if ( |
|
self.config._attn_implementation == "sdpa" |
|
and not using_static_cache |
|
and not output_attentions |
|
): |
|
if AttentionMaskConverter._ignore_causal_mask_sdpa( |
|
attention_mask, |
|
inputs_embeds=input_tensor, |
|
past_key_values_length=past_seen_tokens, |
|
is_training=self.training, |
|
): |
|
return None |
|
|
|
dtype, device = input_tensor.dtype, input_tensor.device |
|
min_dtype = torch.finfo(dtype).min |
|
sequence_length = input_tensor.shape[1] |
|
if using_static_cache: |
|
target_length = past_key_values.get_max_length() |
|
else: |
|
target_length = ( |
|
attention_mask.shape[-1] |
|
if isinstance(attention_mask, torch.Tensor) |
|
else past_seen_tokens + sequence_length + 1 |
|
) |
|
|
|
if attention_mask is not None and attention_mask.dim() == 4: |
|
|
|
if attention_mask.max() != 0: |
|
raise ValueError( |
|
"Custom 4D attention mask should be passed in inverted form with max==0`" |
|
) |
|
causal_mask = attention_mask |
|
else: |
|
causal_mask = torch.full( |
|
(sequence_length, target_length), |
|
fill_value=min_dtype, |
|
dtype=dtype, |
|
device=device, |
|
) |
|
if sequence_length != 1: |
|
causal_mask = torch.triu(causal_mask, diagonal=1) |
|
causal_mask *= torch.arange( |
|
target_length, device=device |
|
) > cache_position.reshape(-1, 1) |
|
causal_mask = causal_mask[None, None, :, :].expand( |
|
input_tensor.shape[0], 1, -1, -1 |
|
) |
|
if attention_mask is not None: |
|
causal_mask = ( |
|
causal_mask.clone() |
|
) |
|
mask_length = attention_mask.shape[-1] |
|
padding_mask = ( |
|
causal_mask[:, :, :, :mask_length] |
|
+ attention_mask[:, None, None, :] |
|
) |
|
padding_mask = padding_mask == 0 |
|
causal_mask[:, :, :, :mask_length] = causal_mask[ |
|
:, :, :, :mask_length |
|
].masked_fill(padding_mask, min_dtype) |
|
if ( |
|
self.config._attn_implementation == "sdpa" |
|
and attention_mask is not None |
|
and attention_mask.device.type == "cuda" |
|
and not output_attentions |
|
): |
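            # Fully masked rows can make SDPA's memory-efficient kernel produce NaNs; unmask
            # those rows so attention over them stays well defined.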
|
|
|
|
|
|
|
causal_mask = AttentionMaskConverter._unmask_unattended( |
|
causal_mask, min_dtype |
|
) |
|
|
|
return causal_mask |
|
|
|
class Embeddinghead(torch.nn.Module): |
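    """Pools per-token hidden states into one embedding per sequence using 'cls', 'lasttoken',
    'mean', or 'weighted_mean'."""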
|
def __init__( |
|
self, |
|
pooling_method: str |
|
): |
|
super().__init__() |
|
self.pooling_method = pooling_method |
|
|
|
def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: |
|
""" |
|
Args: |
|
hidden_state: [b, n, d] |
|
attention_mask: [b, n] |
|
""" |
|
        if attention_mask is not None:
            hidden_state = hidden_state.to(attention_mask.device)
|
if self.pooling_method == 'cls': |
|
embedding = hidden_state[:, 0] |
|
elif self.pooling_method == 'lasttoken': |
|
b, n, d = hidden_state.size() |
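            # Locate the last non-padding token per sequence: flip the mask and take the
            # position of the first 1 from the end.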
|
|
|
reversed_mask = torch.flip(attention_mask, dims=(1,)) |
|
argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False) |
|
|
|
gather_indices = attention_mask.size(1) - argmax_reverse - 1 |
|
gather_indices = torch.clamp(gather_indices, min=0) |
|
gather_indices = gather_indices.unsqueeze(-1).repeat(1, d) |
|
gather_indices = gather_indices.unsqueeze(1) |
|
assert gather_indices.shape == (b, 1, d) |
|
|
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float() |
|
embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1) |
|
|
|
elif self.pooling_method in ['mean', 'weighted_mean']: |
|
if self.pooling_method == 'weighted_mean': |
|
attention_mask *= attention_mask.cumsum(dim=1) |
|
s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1) |
|
d = attention_mask.sum(dim=1, keepdim=True).float() |
|
embedding = s / d |
|
        else:
            raise NotImplementedError(f"Unknown pooling method: {self.pooling_method}")
|
|
|
return embedding |
|
|
|
|
|
|
|
class PhariaForEmbedding(PhariaPreTrainedModel): |
|
def __init__(self, config, tokenizer): |
|
super().__init__(config) |
|
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" |
|
self._use_sdpa = config._attn_implementation == "sdpa" |
|
self.model = PhariaModel(config) |
|
self.tokenizer = tokenizer |
|
self.tokenizer.pad_token_id = 1 |
|
|
|
self.embedding_head = Embeddinghead(pooling_method=self.config.pooling_method) |
|
|
|
def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray: |
|
"""Used for encoding the queries of retrieval or reranking tasks""" |
|
return self.encode(queries, **kwargs) |
|
|
|
def encode_corpus(self, corpus: Union[List[str], str, List[Dict[str, str]]], **kwargs) -> np.ndarray: |
|
"""Used for encoding the corpus of retrieval tasks""" |
|
if isinstance(corpus, dict): |
|
corpus = [corpus] |
|
if isinstance(corpus, list) and isinstance(corpus[0], dict): |
|
corpus = [ |
|
doc["text"] for doc in corpus |
|
] |
|
return self.encode(corpus, **kwargs) |
|
|
|
@torch.no_grad() |
|
def encode( |
|
self, |
|
sentences: Union[List[str], str], |
|
batch_size: int = 256, |
|
max_length: int = 512, |
|
instruction: str = "", |
|
user_token: str = "<|start_header_id|>user<|end_header_id|>", |
|
embed_instruction: bool = False, |
|
embed_eos_token: str = "\n<|embed|>\n", |
|
convert_to_tensor: bool = False, |
|
add_special_tokens: bool = True, |
|
**kwargs, |
|
) -> np.ndarray: |
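        """Tokenize ``sentences`` (each prefixed with ``user_token + instruction + embed_eos_token``),
        run the backbone, pool with the configured embedding head, and return the embeddings as a
        numpy array (or a torch tensor when ``convert_to_tensor=True``). For mean pooling, the
        instruction prefix is excluded from pooling unless ``embed_instruction=True``."""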
|
|
|
input_was_string = False |
|
if isinstance(sentences, str): |
|
sentences = [sentences] |
|
input_was_string = True |
|
|
|
        all_embeddings = []
|
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256): |
|
sentences_batch = [ |
|
user_token + instruction + embed_eos_token + s for s in sentences[start_index:start_index + batch_size] |
|
] |
|
|
|
inputs = self.tokenizer( |
|
sentences_batch, |
|
padding=True, |
|
truncation=True, |
|
return_tensors='pt', |
|
max_length=max_length, |
|
add_special_tokens=add_special_tokens, |
|
).to(self.device) |
|
|
|
            last_hidden_state = self.model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )['last_hidden_state']
|
|
|
if ("mean" in self.embedding_head.pooling_method) and not embed_instruction: |
|
instruct_with_special_tokens = user_token + instruction + embed_eos_token |
|
|
|
instruction_tokens = self.tokenizer( |
|
instruct_with_special_tokens, |
|
padding=False, |
|
truncation=True, |
|
max_length=max_length, |
|
add_special_tokens=add_special_tokens, |
|
)["input_ids"] |
|
inputs['attention_mask'][:, :len(instruction_tokens)] = 0 |
|
|
|
embeddings = self.embedding_head(last_hidden_state, inputs['attention_mask']) |
|
|
|
if convert_to_tensor: |
|
all_embeddings.append(embeddings) |
|
else: |
|
|
|
all_embeddings.append(embeddings.cpu().to(torch.float32).numpy()) |
|
|
|
all_embeddings = ( |
|
torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0) |
|
) |
|
if input_was_string: |
|
all_embeddings = all_embeddings[0] |
|
|
|
return all_embeddings |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AutoModel.register(PhariaConfig, PhariaForEmbedding) |
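
# Assumed intent (the AutoConfig import and the PHARIAEMBED_TYPE constant are otherwise unused):
# register the config type so the Auto classes can resolve "phariaembed".
AutoConfig.register(PHARIAEMBED_TYPE, PhariaConfig)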
|
|
|
PhariaForEmbedding.register_for_auto_class("AutoModel") |
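

# Minimal usage sketch (assumptions: the tokenizer path below is a hypothetical placeholder, and the
# model is freshly initialized here rather than loaded from a published checkpoint).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/pharia-tokenizer")  # placeholder path
    config = PhariaConfig(pooling_method="mean")
    model = PhariaForEmbedding(config, tokenizer).eval()
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    embeddings = model.encode(["What is rotary position embedding?"], batch_size=2, max_length=128)
    print(embeddings.shape)  # (num_sentences, hidden_size)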
|
|