# Source: cognitivess / cognitivess_model / modeling_cognitivess.py
# Last commit: 0cfbce2 (verified) - "Update cognitivess_model/modeling_cognitivess.py"
# File size: 5.47 kB
# cognitivess_model/modeling_cognitivess.py
import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_cognitivess import CognitivessConfig
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        hidden_size: Feature width of the input and of the returned tensor.
        num_attention_heads: Number of heads; must evenly divide ``hidden_size``.
        dropout_prob: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not a positive multiple of
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.0):
        super().__init__()
        # Fail fast with a readable message; otherwise the mismatch only shows
        # up later as an opaque ``view`` error inside forward().
        if (
            num_attention_heads <= 0
            or hidden_size <= 0
            or hidden_size % num_attention_heads != 0
        ):
            raise ValueError(
                f"hidden_size ({hidden_size}) must be a positive multiple of "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # sqrt(d_head) is a constant: compute it once instead of allocating a
        # new tensor on every forward pass.
        self.attention_scale = math.sqrt(self.attention_head_size)
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        # The concatenated heads have width all_head_size (== hidden_size here).
        self.dense = nn.Linear(self.all_head_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def _split_heads(self, projected, batch_size, seq_length):
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        return projected.view(
            batch_size, seq_length, self.num_attention_heads, self.attention_head_size
        ).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` of shape ``(batch, seq, hidden_size)``.

        Args:
            hidden_states: Input tensor of shape ``(batch, seq, hidden_size)``.
            attention_mask: Optional additive mask, added as-is to the raw scores
                of shape ``(batch, heads, seq, seq)``; it must broadcast to that
                shape (e.g. ``0`` to keep a position, a large negative value to
                mask it).

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        batch_size, seq_length, _ = hidden_states.size()
        query_layer = self._split_heads(self.query(hidden_states), batch_size, seq_length)
        key_layer = self._split_heads(self.key(hidden_states), batch_size, seq_length)
        value_layer = self._split_heads(self.value(hidden_states), batch_size, seq_length)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / self.attention_scale
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = (
            context_layer.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.all_head_size)
        )
        return self.dense(context_layer)
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: ``Linear -> activation -> Linear``.

    Args:
        hidden_size: Input and output feature width.
        intermediate_size: Width of the hidden projection.
        hidden_act: ``"silu"`` selects SiLU; every other value falls back to ReLU.
        mlp_bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act, mlp_bias):
        super().__init__()
        # Registration order (dense, activation, output) is kept stable so
        # parameter initialisation and state_dict ordering do not change.
        self.dense = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        if hidden_act == "silu":
            self.activation = nn.SiLU()
        else:
            # Any activation name other than "silu" silently maps to ReLU.
            self.activation = nn.ReLU()
        self.output = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)

    def forward(self, hidden_states):
        """Expand, activate, and project ``hidden_states`` back to ``hidden_size``."""
        expanded = self.dense(hidden_states)
        return self.output(self.activation(expanded))
class TransformerBlock(nn.Module):
    """One post-norm Transformer layer.

    Computes ``h = layer_norm1(x + attention(x, mask))`` and then
    ``out = layer_norm2(h + feed_forward(h))``.

    Args:
        hidden_size: Model width.
        num_attention_heads: Number of attention heads.
        intermediate_size: Width of the feed-forward hidden layer.
        hidden_act: Activation name forwarded to :class:`FeedForward`.
        layer_norm_eps: Epsilon for both LayerNorms.
        mlp_bias: Whether the feed-forward linears use bias.
        attention_dropout: Dropout probability on the attention weights.
    """

    def __init__(self, hidden_size, num_attention_heads, intermediate_size, hidden_act, layer_norm_eps, mlp_bias, attention_dropout):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.attention = MultiHeadAttention(
            hidden_size, num_attention_heads, dropout_prob=attention_dropout
        )
        self.feed_forward = FeedForward(
            hidden_size, intermediate_size, hidden_act, mlp_bias
        )
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # NOTE(review): registered but never applied in forward().
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_states, attention_mask=None):
        """Apply attention and feed-forward sub-layers, each with residual + LayerNorm."""
        residual = hidden_states
        attended = self.layer_norm1(residual + self.attention(residual, attention_mask))
        return self.layer_norm2(attended + self.feed_forward(attended))
class CognitivessForCausalLM(PreTrainedModel):
    """Cognitivess Transformer stack with learned absolute position embeddings.

    Token and position embeddings are summed, passed through
    ``config.num_hidden_layers`` post-norm ``TransformerBlock`` layers, and the
    final hidden state of the first position is projected by ``self.pooler``.

    NOTE(review): despite the class name, ``forward`` applies no causal mask and
    returns a pooled ``(batch, hidden_size)`` vector rather than vocabulary
    logits. ``self.layer_norm`` and ``self.activation`` are constructed but are
    not used by ``forward``.
    """

    config_class = CognitivessConfig
    _no_split_modules = []  # This line ensures that device_map='auto' works

    def __init__(self, config: CognitivessConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_attention_heads,
                config.intermediate_size,
                config.hidden_act,
                config.layer_norm_eps,
                config.mlp_bias,
                config.attention_dropout,
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        # Not used in forward(); kept because layer_norm's parameters are part of
        # the state_dict that existing checkpoints were saved with.
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.SiLU() if config.hidden_act == "silu" else nn.ReLU()

    @staticmethod
    def _expand_padding_mask(attention_mask, dtype):
        """Convert a 2-D bool/int padding mask into an additive ``(batch, 1, 1, seq)`` mask.

        ``1``/``True`` marks a token to attend to and ``0``/``False`` a padding
        token. Masks that are ``None``, floating point, or not 2-D are returned
        unchanged, so callers that already pass additive masks keep working.
        """
        if (
            attention_mask is None
            or attention_mask.dim() != 2
            or torch.is_floating_point(attention_mask)
        ):
            return attention_mask
        keep = attention_mask.to(torch.bool)[:, None, None, :]
        additive = torch.zeros(keep.shape, dtype=dtype, device=attention_mask.device)
        # Padding keys get the most negative finite value so softmax ignores them.
        return additive.masked_fill(~keep, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` and return the pooled first-position representation.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            attention_mask: Optional mask. A 2-D bool/integer padding mask
                (``1`` = attend, ``0`` = padding) of shape ``(batch, seq_len)``
                is converted to an additive mask. Any other tensor (e.g. a
                float additive mask) is added to the attention scores as-is.

        Returns:
            Tensor of shape ``(batch, hidden_size)``: ``self.pooler`` applied to
            the final hidden state at position 0.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of position embeddings.
        """
        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            # Without this check the embedding lookup fails with an IndexError
            # (or an unrecoverable device-side assert on GPU).
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_position_embeddings ({max_positions})"
            )

        # Embeddings
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        hidden_states = self.embeddings(input_ids) + self.position_embeddings(position_ids)

        # Transformer Layers
        attention_mask = self._expand_padding_mask(attention_mask, hidden_states.dtype)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Pooler
        return self.pooler(hidden_states[:, 0])