# cognitivess_model/modeling_cognitivess.py
import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_cognitivess import CognitivessConfig


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    No causal mask is applied here; any masking must come from the
    ``attention_mask`` passed to ``forward``.

    Args:
        hidden_size: Width of the input and output features. Must be divisible
            by ``num_attention_heads``.
        num_attention_heads: Number of attention heads.
        dropout_prob: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.0):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            # Otherwise the heads would cover fewer than hidden_size features
            # and the reshape back to (batch, seq, hidden) in forward() fails.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        # all_head_size == hidden_size (validated above), so parameter shapes
        # are the same as before and existing checkpoints still load.
        self.dense = nn.Linear(self.all_head_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def _split_heads(self, projected, batch_size, seq_length):
        """Reshape (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        return projected.view(
            batch_size, seq_length, self.num_attention_heads, self.attention_head_size
        ).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, hidden_size)``.
            attention_mask: Optional *additive* mask, broadcastable to
                ``(batch, heads, seq, seq)``. It is added to the raw scores
                before the softmax, so blocked positions should hold large
                negative values.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        batch_size, seq_length, _ = hidden_states.size()

        query_layer = self._split_heads(self.query(hidden_states), batch_size, seq_length)
        key_layer = self._split_heads(self.key(hidden_states), batch_size, seq_length)
        value_layer = self._split_heads(self.value(hidden_states), batch_size, seq_length)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Python-float divisor: avoids allocating a CPU tensor on every call.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = (
            context_layer.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.all_head_size)
        )
        return self.dense(context_layer)


class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> activation -> Linear.

    Args:
        hidden_size: Input/output width.
        intermediate_size: Width of the inner projection.
        hidden_act: ``"silu"`` selects SiLU. NOTE(review): every other value
            (e.g. ``"gelu"``) silently falls back to ReLU. This is kept as-is
            because changing it would alter outputs of existing checkpoints.
        mlp_bias: Whether both linear layers carry a bias.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act, mlp_bias):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        self.activation = nn.SiLU() if hidden_act == "silu" else nn.ReLU()
        self.output = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)

    def forward(self, hidden_states):
        """Project up, apply the activation, and project back to hidden_size."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.output(hidden_states)
        return hidden_states


class TransformerBlock(nn.Module):
    """Post-LayerNorm transformer layer: LN(x + Attn(x)), then LN(h + FFN(h))."""

    def __init__(self, hidden_size, num_attention_heads, intermediate_size, hidden_act,
                 layer_norm_eps, mlp_bias, attention_dropout):
        super().__init__()
        self.attention = MultiHeadAttention(hidden_size, num_attention_heads,
                                            dropout_prob=attention_dropout)
        self.feed_forward = FeedForward(hidden_size, intermediate_size, hidden_act, mlp_bias)
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # NOTE(review): this module is never used in forward(). It is kept so
        # the attribute stays available to any external code that accesses it.
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_states, attention_mask=None):
        """Run attention and feed-forward sublayers, each with a residual + LayerNorm.

        ``attention_mask`` is forwarded unchanged to ``MultiHeadAttention``
        (an additive mask).
        """
        # Attention sublayer (residual, then LayerNorm).
        attention_output = self.attention(hidden_states, attention_mask)
        hidden_states = self.layer_norm1(hidden_states + attention_output)

        # Feed-forward sublayer (residual, then LayerNorm).
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.layer_norm2(hidden_states + feed_forward_output)
        return hidden_states


class CognitivessForCausalLM(PreTrainedModel):
    """Token + learned position embeddings, a stack of TransformerBlocks, and a
    linear pooler applied to the first position.

    NOTE(review): despite the class name, ``forward`` returns a pooled
    ``(batch, hidden_size)`` vector rather than vocabulary logits. It also
    applies no causal mask, so every position can attend to later positions
    unless the caller passes a causal additive mask.
    """

    config_class = CognitivessConfig
    # Keep each TransformerBlock whole on a single device under
    # device_map="auto". The residual additions inside a block require the
    # block's tensors to share a device.
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: CognitivessConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                       padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_attention_heads,
                config.intermediate_size,
                config.hidden_act,
                config.layer_norm_eps,
                config.mlp_bias,
                config.attention_dropout,
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        # NOTE(review): layer_norm and activation are not used in forward().
        # They are kept so existing checkpoints (layer_norm weights) still load.
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.SiLU() if config.hidden_act == "silu" else nn.ReLU()

    @staticmethod
    def _build_additive_attention_mask(attention_mask, batch_size, seq_length, dtype):
        """Turn an HF-style padding mask into an additive attention mask.

        A non-floating ``(batch, seq)`` mask (1/True = attend, 0/False = padding)
        becomes a ``(batch, 1, 1, seq)`` tensor holding 0 for kept keys and the
        most negative ``dtype`` value for padded keys. Floating masks, and
        masks of any other shape, are assumed to already be additive and are
        returned unchanged.
        """
        if attention_mask is None:
            return None
        is_padding_mask = (
            attention_mask.dim() == 2
            and tuple(attention_mask.shape) == (batch_size, seq_length)
            and not torch.is_floating_point(attention_mask)
        )
        if not is_padding_mask:
            return attention_mask
        keep = attention_mask[:, None, None, :].to(torch.bool)
        additive = torch.zeros(keep.shape, dtype=dtype, device=attention_mask.device)
        return additive.masked_fill(~keep, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` and return the pooled first-position state.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq)``.
            attention_mask: Optional mask. Either an HF-style ``(batch, seq)``
                integer/bool padding mask (1 = attend, 0 = padding), which is
                converted to an additive mask, or an additive float mask
                broadcastable to ``(batch, heads, seq, seq)``, which is used
                as-is.

        Returns:
            Tensor of shape ``(batch, hidden_size)``: ``pooler`` applied to the
            final hidden state at position 0.

        Raises:
            ValueError: If ``seq`` exceeds ``config.max_position_embeddings``.
        """
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        if seq_length > self.config.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_position_embeddings "
                f"({self.config.max_position_embeddings})"
            )

        # Embeddings: token embeddings plus learned absolute positions 0..seq-1.
        embeddings = self.embeddings(input_ids)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        embeddings = embeddings + self.position_embeddings(position_ids)

        extended_mask = self._build_additive_attention_mask(
            attention_mask, batch_size, seq_length, embeddings.dtype
        )

        # Transformer layers.
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_mask)

        # Pooler over the first position only.
        return self.pooler(hidden_states[:, 0])