# cognitivess_model/modeling_cognitivess.py

import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_cognitivess import CognitivessConfig

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.0):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Projections for queries, keys, values and the attention output.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dense = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length, hidden_size = hidden_states.size()

        # Project and reshape to (batch, num_heads, seq_length, head_size).
        query_layer = self.query(hidden_states).view(
            batch_size, seq_length, self.num_attention_heads, self.attention_head_size
        ).transpose(1, 2)
        key_layer = self.key(hidden_states).view(
            batch_size, seq_length, self.num_attention_heads, self.attention_head_size
        ).transpose(1, 2)
        value_layer = self.value(hidden_states).view(
            batch_size, seq_length, self.num_attention_heads, self.attention_head_size
        ).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, num_heads, seq_length, seq_length).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # The mask is expected to be additive (0 for kept positions, a large
            # negative value for masked ones) and broadcastable to the scores.
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, merged back to (batch, seq_length, hidden_size).
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)
        output_layer = self.dense(context_layer)

        return output_layer

class FeedForward(nn.Module):
    def __init__(self, hidden_size, intermediate_size, hidden_act, mlp_bias):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        self.activation = nn.SiLU() if hidden_act == "silu" else nn.ReLU()
        self.output = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.output(hidden_states)
        return hidden_states

class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, intermediate_size, hidden_act, layer_norm_eps, mlp_bias, attention_dropout):
        super().__init__()
        self.attention = MultiHeadAttention(hidden_size, num_attention_heads, dropout_prob=attention_dropout)
        self.feed_forward = FeedForward(hidden_size, intermediate_size, hidden_act, mlp_bias)
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Residual dropout; the probability is hard-coded here rather than taken from the config.
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_states, attention_mask=None):
        # Attention sub-layer with residual connection and post-layer norm.
        attention_output = self.attention(hidden_states, attention_mask)
        hidden_states = self.layer_norm1(hidden_states + self.dropout(attention_output))

        # Feed-forward sub-layer with residual connection and post-layer norm.
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.layer_norm2(hidden_states + self.dropout(feed_forward_output))

        return hidden_states

class CognitivessModel(PreTrainedModel):
    config_class = CognitivessConfig

    def __init__(self, config: CognitivessConfig):
        super().__init__(config)
        self.config = config

        # Token embeddings and learned absolute position embeddings.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_attention_heads,
                config.intermediate_size,
                config.hidden_act,
                config.layer_norm_eps,
                config.mlp_bias,
                config.attention_dropout,
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.SiLU() if config.hidden_act == "silu" else nn.ReLU()

        # Initialize weights and apply final processing (PreTrainedModel hook).
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        # Embeddings: token embeddings plus learned absolute position embeddings.
        embeddings = self.embeddings(input_ids)
        position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = embeddings + position_embeddings

        # Convert a (batch, seq_length) padding mask of 1s and 0s into the additive
        # mask the attention layers expect: 0 where attended, a large negative value
        # where masked, broadcastable to (batch, num_heads, seq_length, seq_length).
        if attention_mask is not None:
            attention_mask = (1.0 - attention_mask[:, None, None, :].to(embeddings.dtype)) * torch.finfo(embeddings.dtype).min

        # Transformer layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Final layer norm, then pool the first token through the dense layer and activation.
        hidden_states = self.layer_norm(hidden_states)
        pooled_output = self.activation(self.pooler(hidden_states[:, 0]))
        return pooled_output
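
# Minimal usage sketch (not part of the model): builds a tiny configuration and
# runs one forward pass to check shapes. It assumes CognitivessConfig accepts
# these keyword arguments, as a typical transformers PretrainedConfig subclass
# would; the field names simply mirror the config attributes referenced above
# and the concrete values here are illustrative only.
if __name__ == "__main__":
    config = CognitivessConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        hidden_act="silu",
        layer_norm_eps=1e-6,
        mlp_bias=False,
        attention_dropout=0.0,
        max_position_embeddings=128,
        pad_token_id=0,
    )
    model = CognitivessModel(config)
    model.eval()

    # Batch of 2 sequences of length 16, all positions attended.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    with torch.no_grad():
        pooled = model(input_ids, attention_mask=attention_mask)
    print(pooled.shape)  # expected: torch.Size([2, 64])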