XFious committed
Commit 4ae913a · 1 Parent(s): 41be08c

first-commit

.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ /flagged/
2
+ /__pycache__/
3
+ ```
app.py ADDED
@@ -0,0 +1,98 @@
1
+ import gradio as gr
2
+ import transformers
3
+ import torch
4
+ import yaml
5
+
6
+ from dearth_config import DearthConfig
7
+ from dearth_model import DearthForCausalLM
8
+
9
+ import random
10
+
11
+
12
+ tk = transformers.AutoTokenizer.from_pretrained("./tk")
13
+
14
+ #model_path = "./ts100-re2-h1-4000.pt"
15
+ model_path = "./ts100-re2-h1-4000-model.pt"
16
+ yml_path = "./ts100-re2-h1.yml"
17
+ with open(yml_path, "r") as f:
18
+ config = yaml.load(f, Loader=yaml.FullLoader)['model']
19
+ if "vocab_size" not in config:
20
+ config['vocab_size'] = tk.vocab_size
21
+ config["attn_window_size"] = 500
22
+ print(config)
23
+ config = DearthConfig(**config)
24
+ model = DearthForCausalLM(config)
25
+ states = torch.load(model_path, map_location="cpu")
26
+ model_states = states
27
+ unwanted_prefix_dueto_compile = '_orig_mod.'
28
+ unwanted_prefix_dueto_ddp = 'module.'
29
+ unwanted_prefix_dueto_ddp_compiled = 'module._orig_mod.'
30
+
31
+ for k,v in list(model_states.items()):
32
+ if k.startswith(unwanted_prefix_dueto_ddp_compiled):
33
+ new_key = k[len(unwanted_prefix_dueto_ddp_compiled):]
34
+ model_states[new_key] = model_states.pop(k)
35
+ elif k.startswith(unwanted_prefix_dueto_ddp):
36
+ new_key = k[len(unwanted_prefix_dueto_ddp):]
37
+ model_states[new_key] = model_states.pop(k)
38
+ elif k.startswith(unwanted_prefix_dueto_compile):
39
+ new_key = k[len(unwanted_prefix_dueto_compile):]
40
+ model_states[new_key] = model_states.pop(k)
41
+
42
+ model.load_state_dict(model_states)
43
+
44
+
45
+ def generate(input, num_more_tokens):
46
+ num_more_tokens = int(num_more_tokens)
47
+ print(input)
48
+ input = input.strip()
49
+ input_ids = tk.encode(input)
50
+ input_ids = [tk.bos_token_id] + input_ids
51
+ input_ids = torch.tensor(input_ids, dtype=torch.long).view(1, -1)
52
+ print(input_ids)
53
+
54
+ output_ids = input_ids.squeeze(0).tolist()
55
+ for i in range(num_more_tokens):
56
+ input = torch.tensor(output_ids, dtype=torch.long).view(1, -1)
57
+ with torch.no_grad():
58
+ output = model(input)[0]
59
+ last_token_logits = output[0, -1, :]
60
+ last_token_logits_topk = torch.topk(last_token_logits, k=8, dim=-1)
61
+ probs = torch.softmax(last_token_logits_topk.values, dim=-1)
62
+ new_token = torch.multinomial(probs, num_samples=1).item()
63
+ new_token = last_token_logits_topk.indices[new_token].item()
64
+ if new_token == tk.eos_token_id:
65
+ break
66
+ output_ids.append(new_token)
67
+
68
+ print(output_ids)
69
+ print(tk.decode(output_ids))
70
+ output_ids = output_ids[1:]
71
+
72
+ return tk.decode(output_ids)
73
+
74
+ example_input = ["Once upon a time, there was a little girl",
75
+ "John and Sarah were playing together in their backyard when",
76
+ "It was a warm summer day when Billy and",
77
+ ]
78
+
79
+
80
+ Description = """
81
+ This is a small language model with 11M parameters, trained on the TinyStories dataset and distilled from a 28M-parameter teacher model.\n
82
+ This model was trained on 512M tokens, which is about 0.9 epochs of the TinyStories dataset.\n
83
+ The perplexity (PPL) on the validation set is 1.7; in comparison, the teacher model has a PPL of 0.9. Lower PPL means better performance.\n
84
+ """
85
+
86
+
87
+ server = gr.Interface(
88
+ fn=generate,
89
+ title="Tinystories LM 11M",
90
+ description=Description,
91
+ inputs=[
92
+ gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)]),
93
+ gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
94
+ ],
95
+ outputs="text"
96
+ )
97
+
98
+ server.launch()
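For reference, the sampling step inside `generate()` above can be read as a standalone top-k sampler. This is only an illustrative sketch (the helper name `sample_top_k` and its `k=8` default mirror the loop but are not part of the commit):

```python
import torch

def sample_top_k(last_token_logits: torch.Tensor, k: int = 8) -> int:
    """Sample one token id from the k highest-scoring logits.

    Mirrors the inner loop of generate(): keep the top-k logits,
    softmax over them, draw one sample, then map it back to the
    original vocabulary index.
    """
    topk = torch.topk(last_token_logits, k=k, dim=-1)
    probs = torch.softmax(topk.values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1).item()
    return topk.indices[choice].item()
```

The EOS check and the growing `output_ids` list stay in the caller, exactly as in the loop above.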
dearth_config.py ADDED
@@ -0,0 +1,106 @@
1
+ from transformers import PretrainedConfig
2
+
3
+ class DearthConfig(PretrainedConfig):
4
+ model_type = "dearth"
5
+ def __init__(
6
+ self,
7
+ max_token_len: int = 8192,
8
+ vocab_size: int = None, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
9
+ n_layer: int = None,
10
+ n_head: int = None,
11
+ n_kv_head: int = None,
12
+ dim: int = None,
13
+ dim_qk_head = None,
14
+ hidden_dim: int = None,
15
+ multiple_of: int = None,
16
+ dropout_rate: float = 0.0,
17
+ layer_init_factor: float = None,
18
+ residual_factor: float = None, # should be > 1.0
19
+ sliding_window_size: int = 4096,
20
+ front_window_size: int = 256,
21
+ use_rotary: bool = True,
22
+ rope_theta: float = 10000.0,
23
+ use_alibi: bool = False,
24
+
25
+ mimic_attn_layer: int = None, # 1-based, counting from the bottom; the first layer is 1, not 0
26
+ mimic_n_head: int = None,
27
+ mimic_n_kv_head: int = None,
28
+ mimic_attn_dropout: float = None,
29
+ mimic_dim_qk_head: int = None,
30
+ mimic_use_rotary: bool = True,
31
+ mimic_use_alibi: bool = False,
32
+
33
+ pad_token_id=None,
34
+ bos_token_id=1,
35
+ eos_token_id=2,
36
+ tie_word_embeddings=False,
37
+ **kwargs,
38
+ ):
39
+ self.max_token_len = max_token_len
40
+ self.vocab_size = vocab_size
41
+ self.n_layer = n_layer
42
+ self.n_head = n_head
43
+ self.n_kv_head = n_kv_head
44
+ self.dim = dim
45
+ self.dim_qk_head = dim_qk_head
46
+ self.hidden_dim = hidden_dim
47
+ if hidden_dim is None:
48
+ self.hidden_dim = dim * 4
49
+ print(f"hidden_dim is not specified. Set to {self.hidden_dim}")
50
+ self.multiple_of = multiple_of
51
+ self.dropout_rate = dropout_rate
52
+ self.layer_init_factor = layer_init_factor
53
+ self.residual_factor = residual_factor
54
+ self.sliding_window_size = sliding_window_size
55
+ self.front_window_size = front_window_size
56
+ self.use_rotary = use_rotary
57
+ self.rope_theta = rope_theta
58
+ self.use_alibi = use_alibi
59
+
60
+ self.mimic_attn_layer = mimic_attn_layer
61
+ self.mimic_n_head = mimic_n_head
62
+ self.mimic_n_kv_head = mimic_n_kv_head
63
+ self.mimic_attn_dropout = mimic_attn_dropout
64
+ self.mimic_dim_qk_head = mimic_dim_qk_head
65
+ self.mimic_use_rotary = mimic_use_rotary
66
+ self.mimic_use_alibi = mimic_use_alibi
67
+
68
+ if "attn_window_size" in kwargs:
69
+ print("Warning: attn_window_size is deprecated. Please use sliding_window_size instead.")
70
+ self.sliding_window_size = kwargs["attn_window_size"]
71
+
72
+ super().__init__(
73
+ pad_token_id=pad_token_id,
74
+ bos_token_id=bos_token_id,
75
+ eos_token_id=eos_token_id,
76
+ tie_word_embeddings=tie_word_embeddings,
77
+ **kwargs,
78
+ )
79
+
80
+ def __str__(self) -> str:
81
+ return f"""
82
+ max_token_len = {self.max_token_len}
83
+ vocab_size = {self.vocab_size}
84
+ n_layer = {self.n_layer}
85
+ n_head = {self.n_head}
86
+ n_kv_head = {self.n_kv_head}
87
+ dim = {self.dim}
88
+ dim_qk_head = {self.dim_qk_head}
89
+ hidden_dim = {self.hidden_dim}
90
+ multiple_of = {self.multiple_of}
91
+ dropout_rate = {self.dropout_rate}
92
+ layer_init_factor = {self.layer_init_factor}
93
+ residual_factor = {self.residual_factor}
94
+ sliding_window_size = {self.sliding_window_size}
95
+ front_window_size = {self.front_window_size}
96
+ use_rotary = {self.use_rotary}
97
+ use_alibi = {self.use_alibi}
98
+
99
+ mimic_attn_layer = {self.mimic_attn_layer}
100
+ mimic_n_head = {self.mimic_n_head}
101
+ mimic_n_kv_head = {self.mimic_n_kv_head}
102
+ mimic_attn_dropout = {self.mimic_attn_dropout}
103
+ mimic_dim_qk_head = {self.mimic_dim_qk_head}
104
+ mimic_use_rotary = {self.mimic_use_rotary}
105
+ mimic_use_alibi = {self.mimic_use_alibi}
106
+ """
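A small usage sketch for the class above, with values copied from ts100-re2-h1.yml (the vocab size is assumed to be the GPT-Neo tokenizer's 50257, as set in app.py). It also shows that the deprecated `attn_window_size` kwarg is still accepted and copied into `sliding_window_size`:

```python
from dearth_config import DearthConfig

config = DearthConfig(
    vocab_size=50257,        # taken from the tokenizer in app.py
    n_layer=24,
    n_head=4,
    n_kv_head=2,
    dim=128,
    attn_window_size=500,    # deprecated alias; emits a warning
)
assert config.sliding_window_size == 500
```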
dearth_model.py ADDED
@@ -0,0 +1,777 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch import Tensor
5
+ from typing import Optional, Tuple
6
+ import math
7
+
8
+ import logging
9
+ import copy
10
+
11
+ from dearth_config import DearthConfig
12
+
13
+ _USE_FAST_ROPE = False
14
+
15
+ class RMSNorm(torch.nn.Module): # a variant of LayerNorm that is faster and more memory efficient
16
+ def __init__(self, dim: int, eps: float = 1e-5):
17
+ super().__init__()
18
+ self.eps = eps
19
+ # set the weight to be 1 initially
20
+ self.weight = nn.Parameter(torch.ones(dim))
21
+
22
+ def _norm(self, x):
23
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
24
+
25
+ def forward(self, x):
26
+ output = self._norm(x.float()).type_as(x)
27
+ return output * self.weight
28
+
29
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
30
+ class RotaryEmbedding(nn.Module):
31
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
32
+ super().__init__()
33
+
34
+ self.dim = dim
35
+ self.max_position_embeddings = max_position_embeddings
36
+ self.base = base
37
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
38
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
39
+
40
+ # Build here to make `torch.jit.trace` work.
41
+ self._set_cos_sin_cache(
42
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
43
+ )
44
+
45
+ self.register_buffer("default_pos_ids",
46
+ torch.arange(0, self.max_position_embeddings, dtype=torch.long).view(-1, self.max_position_embeddings),
47
+ persistent=False)
48
+
49
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
50
+ self.max_seq_len_cached = seq_len
51
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
52
+
53
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
54
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
55
+ emb = torch.cat((freqs, freqs), dim=-1)
56
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) # shape: (max_seq_len_cached, dim // 2)
57
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
58
+
59
+ def forward(self, x, seq_len=None):
60
+ # x: [bs, num_attention_heads, seq_len, head_size]
61
+ if seq_len > self.max_seq_len_cached:
62
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
63
+
64
+ return (
65
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
66
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
67
+ )
68
+
69
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
70
+ def rotate_half(x):
71
+ """Rotates half the hidden dims of the input."""
72
+ x1 = x[..., : x.shape[-1] // 2]
73
+ x2 = x[..., x.shape[-1] // 2 :]
74
+ return torch.cat((-x2, x1), dim=-1)
75
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
76
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
77
+ cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
78
+ sin = sin[position_ids].unsqueeze(1)
79
+ q_embed = (q * cos) + (rotate_half(q) * sin)
80
+ k_embed = (k * cos) + (rotate_half(k) * sin)
81
+ return q_embed, k_embed
82
+
83
+
84
+
85
+
86
+ class FastRope(nn.Module):
87
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
88
+ super().__init__()
89
+
90
+ self.dim = dim
91
+ self.max_position_embeddings = max_position_embeddings
92
+ cis = precompute_freqs_cis(dim, max_position_embeddings, theta=base)
93
+ self.register_buffer("cis", cis, persistent=False)
94
+
95
+ def forward(self, start_idx, seq_len):
96
+ return self.cis[start_idx:start_idx+seq_len, :]
97
+
98
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
99
+ """
100
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
101
+
102
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
103
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
104
+ The returned tensor contains complex values in complex64 data type.
105
+
106
+ Args:
107
+ dim (int): Dimension of the frequency tensor.
108
+ end (int): End index for precomputing frequencies.
109
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
110
+
111
+ Returns:
112
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
113
+ """
114
+ with torch.no_grad():
115
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
116
+ t = torch.arange(end, device=freqs.device) # type: ignore
117
+ freqs = torch.outer(t, freqs).float() # type: ignore
118
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
119
+ return freqs_cis
120
+
121
+ def apply_rotary_emb(
122
+ xq: torch.Tensor,
123
+ xk: torch.Tensor,
124
+ freqs_cis: torch.Tensor,
125
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
126
+ """
127
+ Apply rotary embeddings to input tensors using the given frequency tensor.
128
+
129
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
130
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
131
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
132
+ returned as real tensors.
133
+
134
+ Args:
135
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
136
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
137
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
138
+
139
+ Returns:
140
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
141
+ """
142
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
143
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
144
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
145
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
146
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
147
+ return xq_out.type_as(xq), xk_out.type_as(xk)
148
+
149
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
150
+ """
151
+ Reshape frequency tensor for broadcasting it with another tensor.
152
+
153
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
154
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
155
+
156
+ Args:
157
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
158
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
159
+
160
+ Returns:
161
+ torch.Tensor: Reshaped frequency tensor.
162
+
163
+ Raises:
164
+ AssertionError: If the frequency tensor doesn't match the expected shape.
165
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
166
+ """
167
+ ndim = x.ndim
168
+ assert 0 <= 1 < ndim
169
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f"freqs_cis.shape: {freqs_cis.shape}, x.shape: {x.shape}"
170
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
171
+ return freqs_cis.view(*shape)
172
+
173
+
174
+
175
+
176
+ class AttentionMask(nn.Module):
177
+ attn_mask: torch.Tensor = None
178
+ def __init__(self, config: DearthConfig):
179
+ super().__init__()
180
+ self.config = config
181
+ self.sliding_window_size = config.sliding_window_size
182
+ self.front_window_size = config.front_window_size
183
+ if self.attn_mask is None:
184
+ tmp_attn_mask = self.build_causal_and_window_mask(config.max_token_len, config.sliding_window_size, config.front_window_size)
185
+ self.attn_mask = tmp_attn_mask.requires_grad_(False) # shape: (max_token_len, max_token_len)
186
+ #self.register_buffer("attn_mask", self.build_causal_and_window_mask(config.max_token_len, config.sliding_window_size, config.front_window_size).requires_grad_(False), persistent=False)
187
+
188
+ def forward(self, bz, n_head, q_seq_len, kv_seq_len, q_start_idx: int, device, dtype) -> torch.Tensor:
189
+ if self.attn_mask.device != device or self.attn_mask.dtype != dtype:
190
+ self.attn_mask = self.attn_mask.to(device=device, dtype=dtype).requires_grad_(False)
191
+ end_idx = q_start_idx + q_seq_len
192
+ q_k_diff_len = kv_seq_len - q_seq_len # it should be >= 0, because it is meaningless to attend to future tokens
193
+ top = q_start_idx
194
+ bottom = end_idx
195
+ if q_start_idx == 0 and q_k_diff_len == 0:
196
+ # assume: sliding window size = 100, front window size = 50
197
+ # case 1: training: q_start_idx = 0, q_seq_len = 1000, kv_seq_len = 1000
198
+ mask = self.attn_mask[:end_idx, :end_idx]
199
+ elif q_k_diff_len > 0 and q_start_idx > 0 and end_idx >= kv_seq_len:
200
+ # TODO: not allow in training; remove this line after testing
201
+ raise RuntimeError(f"NOT FOR TRAINING: q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
202
+ if end_idx > self.front_window_size + self.sliding_window_size:
203
+ # case 2: qsl < kvsl: q_start_idx = 190, q_seq_len = 10, kv_seq_len = 150, end_idx = 200
204
+ # mask = self.attn_mask[top:bottom, :self.front_window_size] + \
205
+ # self.attn_mask[q_start_idx:end_idx, end_idx - (kv_seq_len - self.front_window_size):end_idx]
206
+ mask = torch.cat([self.attn_mask[top:bottom, :self.front_window_size], self.attn_mask[top:bottom, end_idx - (kv_seq_len - self.front_window_size):end_idx]], dim=-1)
207
+ elif end_idx <= self.front_window_size + self.sliding_window_size:
208
+ # case 3: qsl < kvsl: q_start_idx = 140, q_seq_len = 10, kv_seq_len = 150, end_idx = 150
209
+ mask = self.attn_mask[top:bottom, :end_idx]
210
+ else:
211
+ raise RuntimeError(f"q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
212
+ return mask.expand(bz, n_head, q_seq_len, kv_seq_len).detach()
213
+
214
+
215
+ @staticmethod
216
+ def build_causal_and_window_mask(seq_len, sliding_window_size, front_window_size) -> torch.Tensor:
217
+ mask = torch.ones(seq_len, seq_len)
218
+ if seq_len > sliding_window_size: # need to apply the sliding window mask, because the sequence is too long
219
+ mask = torch.triu(mask, diagonal=-sliding_window_size+1)
220
+ if front_window_size > 0:
221
+ tmp_front_mask = torch.cat([torch.ones(seq_len, front_window_size), torch.zeros(seq_len, seq_len-front_window_size)], dim=-1)
222
+ tmp_front_mask = torch.tril(tmp_front_mask, diagonal=-sliding_window_size)
223
+ mask = mask + tmp_front_mask
224
+ # apply causal mask
225
+ mask = mask.tril(diagonal=0)
226
+ mask = mask.log() # map 0 to -inf, 1 to 0
227
+ # print(f"mask.shape: {mask.shape}, and mask")
228
+ # print(mask)
229
+ return mask
230
+
231
+
232
+ class SharedAttentionMask(nn.Module):
233
+ def __init__(self, config: DearthConfig):
234
+ super().__init__()
235
+ self.config = config
236
+ self.sliding_window_size = config.sliding_window_size
237
+ self.front_window_size = config.front_window_size
238
+ tmp_attn_mask = self.build_causal_and_window_mask(config.max_token_len, config.sliding_window_size, config.front_window_size)
239
+ self.register_buffer("attn_mask", tmp_attn_mask, persistent=False)
240
+
241
+ def forward(self, q_seq_len, kv_seq_len, q_start_idx: int) -> torch.Tensor:
242
+ end_idx = q_start_idx + q_seq_len
243
+ q_k_diff_len = kv_seq_len - q_seq_len # it should be >= 0, because it is meaningless to attend to future tokens
244
+ top = q_start_idx
245
+ bottom = end_idx
246
+ if q_start_idx == 0 and q_k_diff_len == 0:
247
+ # assume: sliding window size = 100, front window size = 50
248
+ # case 1: training: q_start_idx = 0, q_seq_len = 1000, kv_seq_len = 1000
249
+ mask = self.attn_mask[:end_idx, :end_idx]
250
+ elif q_k_diff_len > 0 and q_start_idx > 0 and end_idx >= kv_seq_len:
251
+ # TODO: not allow in training; remove this line after testing
252
+ raise RuntimeError(f"NOT FOR TRAINING: q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
253
+ if end_idx > self.front_window_size + self.sliding_window_size:
254
+ # case 2: qsl < kvsl: q_start_idx = 190, q_seq_len = 10, kv_seq_len = 150, end_idx = 200
255
+ # mask = self.attn_mask[top:bottom, :self.front_window_size] + \
256
+ # self.attn_mask[q_start_idx:end_idx, end_idx - (kv_seq_len - self.front_window_size):end_idx]
257
+ mask = torch.cat([self.attn_mask[top:bottom, :self.front_window_size], self.attn_mask[top:bottom, end_idx - (kv_seq_len - self.front_window_size):end_idx]], dim=-1)
258
+ elif end_idx <= self.front_window_size + self.sliding_window_size:
259
+ # case 3: qsl < kvsl: q_start_idx = 140, q_seq_len = 10, kv_seq_len = 150, end_idx = 150
260
+ mask = self.attn_mask[top:bottom, :end_idx]
261
+ else:
262
+ raise RuntimeError(f"q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
263
+ return mask.detach() # shape: (1, 1, seqlen, seqlen)
264
+
265
+
266
+ @staticmethod
267
+ def build_causal_and_window_mask(seq_len, sliding_window_size, front_window_size) -> torch.Tensor:
268
+ mask = torch.ones(seq_len, seq_len)
269
+ if seq_len > sliding_window_size: # need to apply the sliding window mask, because the sequence is too long
270
+ mask = torch.triu(mask, diagonal=-sliding_window_size+1)
271
+ if front_window_size > 0:
272
+ tmp_front_mask = torch.cat([torch.ones(seq_len, front_window_size), torch.zeros(seq_len, seq_len-front_window_size)], dim=-1)
273
+ tmp_front_mask = torch.tril(tmp_front_mask, diagonal=-sliding_window_size)
274
+ mask = mask + tmp_front_mask
275
+ # apply causal mask
276
+ mask = mask.tril(diagonal=0)
277
+ mask = mask.log() # map 0 to -inf, 1 to 0
278
+ # print(f"mask.shape: {mask.shape}, and mask")
279
+ # print(mask)
280
+ return mask
281
+
282
+
283
+
284
+ def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
285
+ r"""
286
+ Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
287
+ relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
288
+ the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
289
+ https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
290
+
291
+ return shape: (1, num_heads, 1, sequence_length)
292
+ """
293
+ alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
294
+ num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
295
+
296
+ base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device)
297
+ base = base * (alibi_bias_max / num_heads_power_of_2)
298
+
299
+ slopes = 1.0 / torch.pow(2, base)
300
+ slopes = slopes.view(1, num_heads, 1, 1)
301
+
302
+ if num_heads_power_of_2 != num_heads:
303
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:num_heads]
304
+
305
+ alibi = alibi * slopes
306
+ return alibi
307
+
308
+
309
+ # def build_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
310
+ # r"""
311
+ # Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
312
+ # relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
313
+ # the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
314
+ # https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
315
+
316
+ # return shape: (1, num_heads, 1, sequence_length)
317
+ # """
318
+ # slope = []
319
+ # m_power = (-8/num_heads)
320
+ # m_increace = -8/num_heads
321
+ # for i in range(num_heads):
322
+ # slope.append(m_power)
323
+ # m_power += m_increace
324
+ # slope = torch.tensor(slope, device=device)
325
+ # alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
326
+ # alibi = alibi * slope.view(1, num_heads, 1, 1)
327
+ # return alibi
328
+
329
+ def compute_alibi(num_heads, sequence_length, alibi_bias_max=8, device=None):
330
+ r"""
331
+ Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
332
+ relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
333
+ the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
334
+ https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
335
+
336
+ return shape: (1, num_heads, 1, sequence_length)
337
+ """
338
+ slope = []
339
+ m_power = (-8/num_heads)
340
+ m_increace = -8/num_heads
341
+ for i in range(num_heads):
342
+ slope.append(2 ** m_power)
343
+ m_power += m_increace
344
+ slope = torch.tensor(slope, device=device)
345
+ alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
346
+ alibi = alibi * slope.view(1, num_heads, 1, 1)
347
+ return alibi
348
+
349
+
350
+ class Attention(nn.Module):
351
+ def __init__(self, config: DearthConfig):
352
+ super().__init__()
353
+ assert config.dim % config.n_head == 0
354
+
355
+ # regularization
356
+ self.n_head = config.n_head
357
+ self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
358
+ self.dim = config.dim
359
+ assert config.dim % config.n_head == 0
360
+ self.dim_qk_head = config.dim_qk_head if config.dim_qk_head is not None else config.dim // config.n_head
361
+ self.dim_v_head = config.dim // config.n_head
362
+ assert config.n_kv_head <= config.n_head and config.n_head % config.n_kv_head == 0
363
+ self.n_kv_group = config.n_head // config.n_kv_head
364
+ self.dropout_rate = config.dropout_rate
365
+
366
+ self.alibi_emb = None
367
+ self.pos_emb = None
368
+
369
+ self.sliding_window_size = config.sliding_window_size
370
+
371
+ def _fill_with_neg_inf(t):
372
+ """FP16-compatible function that fills a tensor with -inf."""
373
+ return t.float().fill_(float("-inf")).type_as(t)
374
+
375
+ # neg_inf_mask = _fill_with_neg_inf(torch.ones_like(torch.empty(config.max_token_len, config.max_token_len)))
376
+ # window_size_mask = torch.triu(neg_inf_mask, diagonal=1)
377
+ # if config.sliding_window_size is not None and config.max_token_len > config.sliding_window_size:
378
+ # window_size_mask = window_size_mask + torch.tril(neg_inf_mask, diagonal=-config.sliding_window_size)
379
+ # self.register_buffer("window_size_mask", window_size_mask, persistent=False)
380
+ # if config.use_alibi:
381
+ # alibi_emb = compute_alibi(config.n_head, config.max_token_len) # shape: (1, n_head, 1, seqlen)
382
+ # #self.alibi_emb = self.alibi_emb.expand(1, config.n_head, config.max_token_len, config.max_token_len) # shape: (1, n_head, seqlen, seqlen)
383
+ # self.register_buffer("alibi_emb", alibi_emb, persistent=False)
384
+
385
+ self.window_size_mask = AttentionMask(config)
386
+
387
+ if config.use_rotary:
388
+ if not _USE_FAST_ROPE:
389
+ self.pos_emb = RotaryEmbedding(
390
+ self.dim_qk_head,
391
+ max_position_embeddings=config.max_token_len,
392
+ base=config.rope_theta,
393
+ )
394
+ if _USE_FAST_ROPE:
395
+ self.pos_emb = FastRope(
396
+ self.dim_qk_head,
397
+ max_position_embeddings=config.max_token_len,
398
+ base=config.rope_theta,
399
+ )
400
+
401
+ # query, key, values projections for all heads
402
+ self.wq = nn.Linear(self.dim, self.n_head * self.dim_qk_head, bias=True)
403
+ self.wk = nn.Linear(self.dim, self.n_kv_head * self.dim_qk_head, bias=True)
404
+ self.wv = nn.Linear(self.dim, self.dim // self.n_kv_group, bias=False)
405
+ self.wo = nn.Linear(self.dim, self.dim, bias=False)
406
+
407
+
408
+ def forward(self, x: Tensor, attn_mask: Tensor, start_idx: Optional[int] = 0):
409
+ batch_size, seqlen, emb_dim = x.size() # batch size, sequence length, embedding dimensionality (dim)
410
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
411
+
412
+ # split embedding dim into number of heads
413
+ xq = xq.view(batch_size, seqlen, self.n_head, self.dim_qk_head)
414
+ xk = xk.view(batch_size, seqlen, self.n_kv_head, self.dim_qk_head)
415
+ xv = xv.view(batch_size, seqlen, self.n_kv_head, self.dim_v_head)
416
+
417
+ if self.pos_emb is not None and _USE_FAST_ROPE:
418
+ xq, xk = apply_rotary_emb(xq, xk, self.pos_emb(start_idx, seqlen))
419
+
420
+ # transpose to get dimensions batch_size * n_head * seqlen * emb_dim
421
+ xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
422
+ kv_seqlen = xk.size(2)
423
+
424
+ # apply positional embeddings
425
+ if self.pos_emb is not None and not _USE_FAST_ROPE:
426
+ # self.pos_emb = self.pos_emb.to(x.device, dtype=x.dtype)
427
+ # xq, xk = apply_rotary_pos_emb(xq, xk, self.pos_emb[start_idx:start_idx+seqlen])
428
+ cos, sin = self.pos_emb(xv, seq_len=kv_seqlen)
429
+ xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin, self.pos_emb.default_pos_ids[:, :kv_seqlen])
430
+
431
+ # TODO: add cache for fast inference
432
+
433
+
434
+ # grouped query
435
+ xk = repeat_kv(xk, self.n_kv_group)
436
+ xv = repeat_kv(xv, self.n_kv_group)
437
+
438
+ # self.window_size_mask = self.window_size_mask.to(x.device, dtype=x.dtype)
439
+ # attn_mask = self.window_size_mask[start_idx:start_idx+seqlen, start_idx:start_idx+kv_seqlen]
440
+ # attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) # shape: (1, 1, seqlen, seqlen)
441
+ # attn_mask = attn_mask.expand(batch_size, self.n_head, seqlen, kv_seqlen) # shape: (batch_size, n_head, seqlen, seqlen)
442
+ # if self.alibi_emb is not None:
443
+ # self.alibi_emb = self.alibi_emb.to(x.device, dtype=x.dtype)
444
+ # attn_mask = attn_mask + self.alibi_emb[:,:,:,:kv_seqlen]
445
+
446
+ #attn_mask = self.window_size_mask(batch_size, self.n_head, seqlen, kv_seqlen, start_idx, x.device, x.dtype) # -inf or 0
447
+
448
+ # efficient attention using Flash Attention CUDA kernels
449
+ y = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout_rate if self.training else 0)
450
+ y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, emb_dim) # merge heads
451
+
452
+ # output projection
453
+ return self.wo(y)
454
+
455
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
456
+ """
457
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
458
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
459
+ hidden_states.shape = (batch, n_kv_head, seqlen, head_dim)
460
+ """
461
+ # if n_rep == 1:
462
+ # return hidden_states
463
+ # return torch.repeat_interleave(hidden_states, n_rep, dim=1)
464
+
465
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
466
+ if n_rep == 1:
467
+ return hidden_states
468
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
469
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
470
+
471
+ # def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
472
+ # """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
473
+ # bs, slen, n_kv_heads, head_dim = x.shape
474
+ # if n_rep == 1:
475
+ # return x
476
+ # return (
477
+ # x[:, :, :, None, :]
478
+ # .expand(bs, slen, n_kv_heads, n_rep, head_dim)
479
+ # .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
480
+ # )
481
+
482
+ class MLP(nn.Module):
483
+ def __init__(self, config):
484
+ super().__init__()
485
+ dim = config.dim
486
+ hidden_dim = config.dim * 4 if config.hidden_dim is None else config.hidden_dim
487
+ multiple_of = 64 if config.multiple_of is None else config.multiple_of
488
+ hidden_dim = int(2 * hidden_dim / 3)
489
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) # round up to nearest multiple of multiple_of
490
+
491
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
492
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
493
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
494
+
495
+ def forward(self, x):
496
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
497
+
498
+ class Mimic_Attn(Attention):
499
+ def __init__(self, config):
500
+ new_config = copy.deepcopy(config)
501
+ new_config.n_head = config.mimic_n_head if config.mimic_n_head is not None else config.n_head
502
+ new_config.n_kv_head = config.mimic_n_kv_head if config.mimic_n_kv_head is not None else config.n_kv_head
503
+ new_config.dim_qk_head = config.mimic_dim_qk_head if config.mimic_dim_qk_head is not None else config.dim_qk_head
504
+ new_config.dropout_rate = config.mimic_attn_dropout if config.mimic_attn_dropout is not None else 0.0
505
+ new_config.use_rotary = config.mimic_use_rotary if config.mimic_use_rotary is not None else config.use_rotary
506
+ new_config.use_alibi = config.mimic_use_alibi if config.mimic_use_alibi is not None else config.use_alibi
507
+
508
+ super().__init__(new_config)
509
+ self.saved_q = None
510
+ self.saved_k = None
511
+ self.saved_v = None
512
+ self.saved_attn_map = None
513
+
514
+ def forward(self, x: Tensor, attn_mask: Tensor, start_idx: Optional[int] = 0): # shape of attn_mask: (bz, n_head, q_seq_len, kv_seq_len)
515
+ batch_size, seqlen, emb_dim = x.size() # batch size, sequence length, embedding dimensionality (dim)
516
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
517
+ self.saved_v = xv
518
+
519
+ # split embedding dim into number of heads
520
+ xq = xq.view(batch_size, seqlen, self.n_head, self.dim_qk_head)
521
+ xk = xk.view(batch_size, seqlen, self.n_kv_head, self.dim_qk_head)
522
+ xv = xv.view(batch_size, seqlen, self.n_kv_head, self.dim_v_head)
523
+
524
+ if self.pos_emb is not None and _USE_FAST_ROPE:
525
+ xq, xk = apply_rotary_emb(xq, xk, self.pos_emb(start_idx, seqlen))
526
+
527
+ # transpose to get dimensions batch_size * n_head * seqlen * emb_dim
528
+ xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
529
+ kv_seqlen = xk.size(2)
530
+
531
+ # # apply positional embeddings
532
+ # if self.pos_emb is not None:
533
+ # self.pos_emb = self.pos_emb.to(x.device)
534
+ # xq, xk = apply_pos_emb(xq, xk, self.pos_emb[start_idx:start_idx+seqlen])
535
+ if self.pos_emb is not None and not _USE_FAST_ROPE:
536
+ cos, sin = self.pos_emb(xv, seq_len=kv_seqlen)
537
+ xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin, self.pos_emb.default_pos_ids[:, :kv_seqlen])
538
+
539
+ # TODO: add cache for fast inference
540
+
541
+ # grouped query
542
+ xk = repeat_kv(xk, self.n_kv_group)
543
+ xv = repeat_kv(xv, self.n_kv_group)
544
+
545
+ # self.window_size_mask = self.window_size_mask.to(x.device)
546
+ # kv_seqlen = xk.size(2)
547
+ # attn_mask = self.window_size_mask[start_idx:start_idx+seqlen, start_idx:start_idx+kv_seqlen]
548
+ # attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) # shape: (1, 1, seqlen, seqlen)
549
+ # attn_mask = attn_mask.expand(batch_size, self.n_head, seqlen, kv_seqlen) # shape: (batch_size, n_head, seqlen, seqlen)
550
+ # if self.alibi_emb is not None:
551
+ # self.alibi_emb = self.alibi_emb.to(x.device)
552
+ # attn_mask = attn_mask + self.alibi_emb[:,:,:,:kv_seqlen]
553
+
554
+ #attn_mask = self.window_size_mask(batch_size, self.n_head, seqlen, kv_seqlen, start_idx, x.device, x.dtype) # -inf or 0
555
+
556
+ attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * (1 / math.sqrt(self.dim_qk_head)) # shape: (batch_size, n_head, seqlen, seqlen)
557
+ attn_weights = attn_weights + attn_mask.expand(batch_size, self.n_head, seqlen, kv_seqlen) # shape: (batch_size, n_head, seqlen, seqlen)
558
+ attn_weights = F.softmax(attn_weights.float(), dim=-1).to(xq.dtype) # shape: (batch_size, n_head, seqlen, seqlen)
559
+ # use log_softmax to avoid overflow
560
+ #attn_weights = F.log_softmax(attn_weights, dim=-1).exp() # shape: (batch_size, n_head, seqlen, seqlen)
561
+ self.saved_attn_map = attn_weights
562
+
563
+ attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
564
+
565
+ y = torch.matmul(attn_weights, xv) # shape: (batch_size, n_head, seqlen, head_dim)
566
+
567
+ y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, emb_dim) # merge heads
568
+
569
+ # output projection
570
+ return self.wo(y)
571
+
572
+ def get_intermediate_attn_v(self):
573
+ return self.saved_attn_map, self.saved_v
574
+
575
+
576
+ class TransformerBlock(nn.Module):
577
+ def __init__(self, config):
578
+ super().__init__()
579
+ self.ln_1 = RMSNorm(config.dim)
580
+ self.attn = Attention(config)
581
+ self.ln_2 = RMSNorm(config.dim)
582
+ self.mlp = MLP(config)
583
+
584
+ self.residual_factor = config.residual_factor
585
+
586
+ def forward(self, x: Tensor, attn_mask: Tensor, start_idx: int):
587
+ # post-LN
588
+ residual = x
589
+ x = self.attn(x, attn_mask, start_idx=start_idx)
590
+ x = self.ln_1(self.residual_connection(x, residual))
591
+
592
+ residual = x
593
+ x = self.mlp(x)
594
+ x = self.ln_2(self.residual_connection(x, residual))
595
+
596
+ return x
597
+
598
+ def residual_connection(self, x, residual):
599
+ # residual factor should > 1.0
600
+ return residual * self.residual_factor + x
601
+
602
+
603
+
604
+ class DearthModel(nn.Module):
605
+ def __init__(self, config: DearthConfig):
606
+ super().__init__()
607
+ assert config.vocab_size is not None
608
+ assert config.max_token_len is not None
609
+
610
+ self.layer_init_factor = config.layer_init_factor if config.layer_init_factor is not None else float(config.n_layer * 8) ** (-1/2)
611
+ self.residual_factor = config.residual_factor if config.residual_factor is not None else float(config.n_layer * 2) ** (1/4)
612
+ if config.residual_factor is None:
613
+ config.residual_factor = self.residual_factor
614
+ logging.warning(f"residual_factor is not set, using default value {self.residual_factor} = (2 * n_layer) ** 1/4")
615
+ if config.layer_init_factor is None:
616
+ config.layer_init_factor = self.layer_init_factor
617
+ logging.warning(f"layer_init_factor is not set, using default value {self.layer_init_factor} = (n_layer * 8) ** -1/2")
618
+
619
+ self.config = config
620
+
621
+ layers = []
622
+ for i in range(config.n_layer):
623
+ if config.mimic_attn_layer is not None and i+1 == config.mimic_attn_layer:
624
+ new_layer = TransformerBlock(config)
625
+ new_layer.attn = Mimic_Attn(config)
626
+ layers.append(new_layer)
627
+ else:
628
+ layers.append(TransformerBlock(config))
629
+
630
+ self.layers = nn.ModuleList(layers)
631
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
632
+ self.ln_before = RMSNorm(config.dim)
633
+ self.shared_attn_mask = SharedAttentionMask(config)
634
+
635
+ if config.mimic_attn_layer is not None and config.mimic_attn_layer > 0 and config.mimic_attn_layer <= config.n_layer:
636
+ self.mimic_attn = self.layers[config.mimic_attn_layer-1].attn
637
+ else:
638
+ self.mimic_attn = None
639
+
640
+ # initialize weights
641
+ _init_weight(self, self.layer_init_factor)
642
+
643
+ def get_input_device(self):
644
+ return self.embed_tokens.weight.device
645
+
646
+ # def _init_weights(self, module):
647
+ # if isinstance(module, nn.Linear):
648
+ # torch.nn.init.xavier_normal_(module.weight, gain=self.layer_init_factor)
649
+ # if module.bias is not None:
650
+ # torch.nn.init.zeros_(module.bias)
651
+ # elif isinstance(module, nn.Embedding):
652
+ # torch.nn.init.xavier_normal_(module.weight, gain=1)
653
+ # elif isinstance(module, RMSNorm):
654
+ # module.weight.data.fill_(1.0)
655
+
656
+
657
+
658
+ def forward(self, tokens, start_idx=0): # return all logits
659
+ batch_size, seqlen = tokens.size()
660
+ if seqlen > self.config.max_token_len:
661
+ raise ValueError(f"input sequence length {seqlen} exceeds maximum sequence length {self.config.max_token_len}")
662
+
663
+ # create token embeddings from token table; x.shape = (batch_size, seqlen, dim)
664
+ h = self.embed_tokens(tokens)
665
+ assert h.size() == (batch_size, seqlen, self.config.dim)
666
+
667
+ h = self.ln_before(h)
668
+
669
+ # transformer layers
670
+ attn_mask = self.shared_attn_mask(seqlen, seqlen, q_start_idx=start_idx) # TODO: it will not work if q_seq_len != kv_seq_len
671
+ for layer in self.layers:
672
+ h = layer(h, attn_mask, start_idx=start_idx) # h.shape = (batch_size, seqlen, dim)
673
+
674
+ return h, None
675
+
676
+
677
+ def get_num_params(self):
678
+ """
679
+ Return the number of parameters in the model.
680
+ For non-embedding count (default), the position embeddings get subtracted.
681
+ The token embeddings would too, except due to the parameter sharing these
682
+ params are actually used as weights in the final layer, so we include them.
683
+ """
684
+ #n_params = sum(p.numel() for p in self.parameters())
685
+ n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
686
+ return int(n_params)
687
+
688
+
689
+ def get_intermediate_attn_v(self):
690
+ if self.mimic_attn is None:
691
+ return torch.zeros(1, 1, 1, 1), torch.zeros(1, 1, 1, 1)
692
+ return self.mimic_attn.get_intermediate_attn_v()
693
+
694
+
695
+ class DearthForCausalLM(nn.Module):
696
+ _tied_weights_keys = ["lm_head.weight"]
697
+
698
+ def __init__(self, config: DearthConfig):
699
+ super().__init__()
700
+ self.model = DearthModel(config)
701
+ self.dearth_config = config
702
+ self.vocab_size = config.vocab_size
703
+ self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
704
+ torch.nn.init.xavier_normal_(self.lm_head.weight, gain=1)
705
+
706
+ self.front_window_size = config.front_window_size
707
+ self.sliding_window_size = config.sliding_window_size
708
+
709
+ def get_input_device(self):
710
+ return self.model.get_input_device()
711
+
712
+ def get_intermediate_attn_v(self):
713
+ return self.model.get_intermediate_attn_v()
714
+
715
+ def print_all_params(self):
716
+ for name, param in self.named_parameters():
717
+ print(f"name: {name}, param.shape: {param.shape}")
718
+
719
+ def forward(
720
+ self,
721
+ input_ids: torch.LongTensor = None,
722
+ use_cache: Optional[bool] = False,
723
+ ) -> Tuple: # -> Union[Tuple, CausalLMOutputWithPast]:
724
+ r"""
725
+ Args:
726
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
727
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
728
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
729
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
730
+
731
+ Returns:
732
+
733
+ Example:
734
+
735
+ ```python
736
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
737
+
738
+ >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
739
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
740
+
741
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
742
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
743
+
744
+ >>> # Generate
745
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
746
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
747
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
748
+ ```"""
749
+ outputs = self.model(
750
+ tokens=input_ids
751
+ )
752
+
753
+ hidden_states = outputs[0]
754
+ logits = self.lm_head(hidden_states)
755
+
756
+ output = (logits,) + outputs[1:]
757
+ return output
758
+
759
+
760
+ def _init_weight(model, weight_init_factor): # TODO: fix this part if change any model structure
761
+ small_list = {'wv', 'wo', 'w1', 'w2', 'w3'}
762
+ norm_list = {'ln_before', 'ln_2', 'ln_1'}
763
+ for name, p in model.named_parameters():
764
+ percise_name = name.split(".")[-2]
765
+ if "bias" in name:
766
+ logging.debug(f"the parameter {name} is initialized with 0.0")
767
+ p.data.fill_(0.0)
768
+ elif percise_name in small_list:
769
+ logging.debug(f"the parameter {name} is initialized with gain={weight_init_factor}")
770
+ torch.nn.init.xavier_normal_(p, gain=weight_init_factor)
771
+ elif percise_name in norm_list:
772
+ logging.debug(f"the parameter {name} is initialized with 1.0")
773
+ p.data.fill_(1.0)
774
+ else:
775
+ logging.debug(f"the parameter {name} is initialized with gain=1.0")
776
+ torch.nn.init.xavier_normal_(p, gain=1)
777
+
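A minimal smoke-test sketch for the model classes above. The configuration here is deliberately tiny and hypothetical (it is not the shipped ts100-re2-h1 setup); it only exercises one forward pass:

```python
import torch
from dearth_config import DearthConfig
from dearth_model import DearthForCausalLM

config = DearthConfig(
    vocab_size=256, n_layer=2, n_head=4, n_kv_head=2, dim=64,
    max_token_len=128, sliding_window_size=64, front_window_size=0,
)
model = DearthForCausalLM(config)

tokens = torch.randint(0, 256, (1, 16))  # (batch, seq_len)
logits = model(tokens)[0]                # forward() returns (logits, None)
print(logits.shape)                      # expected: torch.Size([1, 16, 256])
```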
extract_model.ipynb ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "model_path = \"./ts100-re2-h1-4000.pt\"\n",
11
+ "model = torch.load(model_path, map_location=torch.device('cpu'))"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 6,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "pure_model = model['model']"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 7,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "new_model_name = \"ts100-re2-h1-4000-model.pt\"\n",
30
+ "torch.save(pure_model, new_model_name)"
31
+ ]
32
+ }
33
+ ],
34
+ "metadata": {
35
+ "kernelspec": {
36
+ "display_name": "pytorch",
37
+ "language": "python",
38
+ "name": "python3"
39
+ },
40
+ "language_info": {
41
+ "codemirror_mode": {
42
+ "name": "ipython",
43
+ "version": 3
44
+ },
45
+ "file_extension": ".py",
46
+ "mimetype": "text/x-python",
47
+ "name": "python",
48
+ "nbconvert_exporter": "python",
49
+ "pygments_lexer": "ipython3",
50
+ "version": "3.10.11"
51
+ }
52
+ },
53
+ "nbformat": 4,
54
+ "nbformat_minor": 2
55
+ }
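The notebook above unwraps the training checkpoint (the weights live under the "model" key) and re-saves only the state dict, which is the file app.py loads. A plain-script sketch of the same step, using the file names from the notebook:

```python
import torch

checkpoint = torch.load("./ts100-re2-h1-4000.pt", map_location="cpu")
torch.save(checkpoint["model"], "./ts100-re2-h1-4000-model.pt")
```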
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ transformers
2
+ torch
3
+ gradio
tk/config.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "_name_or_path": "EleutherAI/gpt-neo-125M",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPTNeoForCausalLM"
6
+ ],
7
+ "attention_dropout": 0,
8
+ "attention_layers": [
9
+ "global",
10
+ "local",
11
+ "global",
12
+ "local"
13
+ ],
14
+ "attention_types": [
15
+ [
16
+ [
17
+ "global",
18
+ "local"
19
+ ],
20
+ 2
21
+ ]
22
+ ],
23
+ "bos_token_id": 50256,
24
+ "embed_dropout": 0,
25
+ "eos_token_id": 50256,
26
+ "gradient_checkpointing": false,
27
+ "hidden_size": 768,
28
+ "initializer_range": 0.02,
29
+ "intermediate_size": null,
30
+ "layer_norm_epsilon": 1e-05,
31
+ "max_position_embeddings": 2048,
32
+ "model_type": "gpt_neo",
33
+ "num_heads": 16,
34
+ "num_layers": 4,
35
+ "resid_dropout": 0,
36
+ "summary_activation": null,
37
+ "summary_first_dropout": 0.1,
38
+ "summary_proj_to_labels": true,
39
+ "summary_type": "cls_index",
40
+ "summary_use_proj": true,
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.28.0",
43
+ "use_cache": true,
44
+ "vocab_size": 50257,
45
+ "window_size": 256
46
+ }
tk/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tk/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tk/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 2048,
22
+ "pad_token": null,
23
+ "special_tokens_map_file": null,
24
+ "tokenizer_class": "GPT2Tokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tk/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
ts100-re2-h1-4000-model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d4f2675a1c7d191bb30f8cf02c1049e6cbd22ebafb360e3c7541f027751278b
3
+ size 35395438
ts100-re2-h1.yml ADDED
@@ -0,0 +1,49 @@
1
+ model:
2
+ max_token_len: 1024 # should be at least the training sequence length
3
+ #vocab_size: 32000
4
+ n_layer: 24
5
+ n_head: 4
6
+ n_kv_head: 2 # grouped-query attention (2 KV heads shared by the 4 query heads)
7
+ dim: 128
8
+ #dim_qk_head: 32 # usually set to dim // n_head, but can be different
9
+ #hidden_dim: # defaults to dim * 4; the MLP after the attention layer
10
+ #multiple_of: 64 # round the MLP hidden_dim up to a multiple of this number; because a SiLU (SwiGLU-style) MLP is used, the hidden size is rescaled before rounding
11
+ dropout_rate: 0.05 # for the attention map
12
+ #layer_init_factor: 0.1 # by default = (n_layer * 8) ** -1/2; should use the default value, based on the Microsoft DeepNet paper
13
+ #residual_factor: 2 # by default = (2 * n_layer) ** 1/4 (see dearth_model.py); should use the default value
14
+ attn_window_size: 512
15
+ front_window_size: 0
16
+ use_rotary: True
17
+ use_alibi: False
18
+
19
+ mimic_attn_layer: 21 # replace this layer with a distillation target that mimics the teacher's attention; this special layer should use settings similar to the teacher's
20
+ mimic_n_head: 16
21
+ mimic_n_kv_head: 16
22
+ #mimic_sliding_window_size: 1024
23
+ mimic_attn_dropout: 0.0
24
+ mimic_dim_qk_head: 16
25
+ mimic_use_rotary: True
26
+ mimic_use_alibi: False
27
+
28
+ opt:
29
+ gradient_clip: 1.0
30
+ lr: 1
31
+ beta1: 0.9
32
+ beta2: 0.99
33
+ weight_decay: 0.2
34
+ opt_name: sophia
35
+
36
+ loss:
37
+ soft_loss_weight: 0.0
38
+ hard_loss_weight: 1.0
39
+ mimic_loss_weight: 0.0
40
+ virtual_v_head_num: 16 # based on MiniLM v2; similar to attention, but only v is used for self-attention, which makes the student's x_v similar to the teacher's x_v
41
+ loss_soft_temperature: 1 # temperature for the soft loss; smooths the softmax so it is more sensitive to small logits
42
+
43
+ scheduler:
44
+ slr_seg:
45
+ # - [0.0000001, 0.0005, 300]
46
+ # - [0.0005, 0.0005, 2000]
47
+ - [0.0005, 0.00025, 1000]
48
+
49
+
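For reference, the defaults referred to in the layer_init_factor and residual_factor comments (formulas from dearth_model.py) work out as follows for n_layer = 24; a quick sketch:

```python
n_layer = 24
layer_init_factor = (n_layer * 8) ** -0.5  # (192)^(-1/2) ≈ 0.072
residual_factor = (n_layer * 2) ** 0.25    # (48)^(1/4)  ≈ 2.63
print(layer_init_factor, residual_factor)
```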