|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel |
|
from .configuration_cognitivess import CognitivessConfig |
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into per-head queries, keys and values, attention
    weights are computed with a softmax over the key positions, and the
    concatenated head outputs are mixed by a final linear layer.

    Args:
        hidden_size: Size of the last dimension of the input and of the output.
        num_attention_heads: Number of attention heads. Must be positive and
            divide ``hidden_size`` evenly.
        dropout_prob: Dropout probability applied to the attention
            probabilities.

    Raises:
        ValueError: If ``num_attention_heads`` is not positive or does not
            divide ``hidden_size``.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.0):
        super().__init__()

        # Fail fast: with a non-divisible configuration the head reshape in
        # ``forward`` cannot be mapped back onto ``hidden_size`` and would
        # only blow up later with an opaque shape error.
        if num_attention_heads <= 0 or hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be a positive multiple of "
                f"num_attention_heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        # all_head_size == hidden_size is guaranteed by the check above, so
        # this layer has the same shape as before.
        self.dense = nn.Linear(self.all_head_size, hidden_size)

        self.dropout = nn.Dropout(dropout_prob)

    def _split_heads(self, projected, batch_size, seq_length):
        """Reshape (batch, seq, all_head_size) to (batch, heads, seq, head_size)."""
        return projected.view(
            batch_size,
            seq_length,
            self.num_attention_heads,
            self.attention_head_size,
        ).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: 3-D tensor unpacked as (batch, seq, hidden).
            attention_mask: Optional tensor that is *added* to the raw
                attention scores before the softmax; it must broadcast to
                (batch, heads, seq, seq).

        Returns:
            Tensor of shape (batch, seq, hidden_size).
        """
        batch_size, seq_length, _ = hidden_states.size()

        query_layer = self._split_heads(self.query(hidden_states), batch_size, seq_length)
        key_layer = self._split_heads(self.key(hidden_states), batch_size, seq_length)
        value_layer = self._split_heads(self.value(hidden_states), batch_size, seq_length)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Scale by sqrt(head_size) using a Python scalar: no per-call tensor
        # allocation and no dependence on the device/dtype of a helper tensor.
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back: (batch, heads, seq, head) -> (batch, seq, all_head_size).
        context_layer = (
            context_layer.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.all_head_size)
        )

        return self.dense(context_layer)
|
|
|
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of the input.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Feature size between the two linear layers.
        hidden_act: Activation name; ``"silu"`` selects ``nn.SiLU``, any
            other value selects ``nn.ReLU``.
        mlp_bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act, mlp_bias):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        if hidden_act == "silu":
            self.activation = nn.SiLU()
        else:
            # Every name other than "silu" falls back to ReLU.
            self.activation = nn.ReLU()
        self.output = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)

    def forward(self, hidden_states):
        """Return ``output(activation(dense(hidden_states)))``."""
        expanded = self.dense(hidden_states)
        activated = self.activation(expanded)
        return self.output(activated)
|
|
|
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each followed by
    a residual addition and a LayerNorm (normalisation after the residual).

    Args:
        hidden_size: Feature size of the hidden states.
        num_attention_heads: Number of heads passed to ``MultiHeadAttention``.
        intermediate_size: Inner size passed to ``FeedForward``.
        hidden_act: Activation name passed to ``FeedForward``.
        layer_norm_eps: Epsilon used by both LayerNorm modules.
        mlp_bias: Bias flag passed to ``FeedForward``.
        attention_dropout: Dropout probability passed to ``MultiHeadAttention``.
    """

    def __init__(self, hidden_size, num_attention_heads, intermediate_size, hidden_act, layer_norm_eps, mlp_bias, attention_dropout):
        super().__init__()
        self.attention = MultiHeadAttention(
            hidden_size,
            num_attention_heads,
            dropout_prob=attention_dropout,
        )
        self.feed_forward = FeedForward(
            hidden_size,
            intermediate_size,
            hidden_act,
            mlp_bias,
        )
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # NOTE(review): this module is registered but never called in
        # ``forward``; kept so the module tree stays unchanged.
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_states, attention_mask=None):
        """Run the attention and feed-forward sub-layers.

        Args:
            hidden_states: Input tensor forwarded to the attention module.
            attention_mask: Optional mask forwarded unchanged to the
                attention module.

        Returns:
            The hidden states after both residual + LayerNorm steps.
        """
        # Sub-layer 1: self-attention with residual connection.
        after_attention = self.layer_norm1(
            hidden_states + self.attention(hidden_states, attention_mask)
        )
        # Sub-layer 2: feed-forward with residual connection.
        return self.layer_norm2(
            after_attention + self.feed_forward(after_attention)
        )
|
|
|
class CognitivessForCausalLM(PreTrainedModel):
    """Cognitivess transformer stack with a linear pooling head.

    Token and learned position embeddings are summed, passed through
    ``config.num_hidden_layers`` ``TransformerBlock`` layers, and the hidden
    state at sequence position 0 is projected by ``self.pooler``.

    NOTE(review): ``self.layer_norm`` and ``self.activation`` are created but
    not used in ``forward``; they are kept so the registered modules (and
    therefore saved checkpoints) stay unchanged.
    """

    config_class = CognitivessConfig
    _no_split_modules = []

    def __init__(self, config: CognitivessConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_attention_heads,
                config.intermediate_size,
                config.hidden_act,
                config.layer_norm_eps,
                config.mlp_bias,
                config.attention_dropout,
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.SiLU() if config.hidden_act == "silu" else nn.ReLU()

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` and return the pooled first-position state.

        Args:
            input_ids: Integer tensor indexed as (batch, seq).
            attention_mask: Optional mask. A 2-D (batch, seq) mask is treated
                as a padding mask where non-zero means "attend" and zero means
                "ignore", and is converted to an additive mask. A mask of any
                other rank is assumed to already be additive and is passed to
                the layers unchanged.

        Returns:
            Tensor of shape (batch, hidden_size): ``self.pooler`` applied to
            the final hidden state at sequence position 0.
        """
        token_embeddings = self.embeddings(input_ids)
        position_ids = torch.arange(
            input_ids.size(1), dtype=torch.long, device=input_ids.device
        )
        # (seq, hidden) broadcasts over the batch dimension.
        hidden_states = token_embeddings + self.position_embeddings(position_ids)

        if attention_mask is not None and attention_mask.dim() == 2:
            # The attention layers *add* the mask to scores shaped
            # (batch, heads, seq, seq). A raw (batch, seq) 0/1 mask neither
            # broadcasts correctly nor suppresses padded keys, so expand it
            # to (batch, 1, 1, seq) with 0 for kept positions and the most
            # negative representable value for masked ones.
            keep = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - keep) * torch.finfo(hidden_states.dtype).min

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        pooled_output = self.pooler(hidden_states[:, 0])
        return pooled_output
|
|