mollysama committed

Commit d08f1a8 · verified · 1 parent: a346ab5

Upload folder using huggingface_hub

added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "<s>": 0
+ }
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "architectures": [
+     "Rwkv7ForCausalLM"
+   ],
+   "attention_hidden_size": 1024,
+   "auto_map": {
+     "AutoConfig": "configuration_rwkv7.Rwkv7Config",
+     "AutoModelForCausalLM": "modeling_rwkv7.Rwkv7ForCausalLM"
+   },
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "head_size": 64,
+   "hidden_size": 1024,
+   "intermediate_size": null,
+   "layer_norm_epsilon": 1e-05,
+   "lora_rank_decay": null,
+   "lora_rank_gate": 128,
+   "lora_rank_iclr": null,
+   "lora_rank_value_residual_mix": null,
+   "model_type": "rwkv7",
+   "num_hidden_layers": 24,
+   "tie_word_embeddings": false,
+   "transformers_version": "4.45.2",
+   "use_cache": true,
+   "vocab_size": 65536
+ }
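Because `auto_map` points `AutoConfig`/`AutoModelForCausalLM` at the custom `configuration_rwkv7.py` and `modeling_rwkv7.py` modules added in this commit, the checkpoint is meant to be loaded with `trust_remote_code=True`. A minimal loading sketch; the path `"path/to/this/repo"` is a placeholder for the downloaded folder or the actual repo id, not something defined by this commit:

```python
# Minimal sketch: load the RWKV7 checkpoint through the classes registered in auto_map.
# "path/to/this/repo" is a placeholder, not a real repo id.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(config.model_type, config.hidden_size, config.num_hidden_layers)  # rwkv7 1024 24

model = AutoModelForCausalLM.from_pretrained("path/to/this/repo", trust_remote_code=True)
```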
configuration_rwkv7.py ADDED
@@ -0,0 +1,116 @@
+ # coding=utf-8
+ # Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """RWKV7 configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ RWKV7_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+ class Rwkv7Config(PretrainedConfig):
+     """
+     This is the configuration class to store the configuration of a [`Rwkv7Model`]. It is used to instantiate a RWKV7
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+     defaults will yield a configuration similar to that of the RWKV-7
+     [RWKV/v7-Goose-1.6B-Pile-HF](https://huggingface.co/RWKV/v7-Goose-1.6B-Pile-HF) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 65536):
+             Vocabulary size of the RWKV7 model. Defines the number of different tokens that can be represented by the
+             `input_ids` passed when calling [`Rwkv7Model`].
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the embeddings and hidden states.
+         num_hidden_layers (`int`, *optional*, defaults to 24):
+             Number of hidden layers in the model.
+         attention_hidden_size (`int`, *optional*):
+             Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
+         head_size (`int`, *optional*, defaults to 64):
+             Head size of the rwkv7 self-attention module; the number of heads is `attention_hidden_size // head_size`.
+         intermediate_size (`int`, *optional*):
+             Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
+         lora_rank_decay (`int`, *optional*):
+             Rank of the low-rank projection for the decay term. Derived from `hidden_size` if unset.
+         lora_rank_iclr (`int`, *optional*):
+             Rank of the low-rank projection for the in-context learning rate. Derived from `hidden_size` if unset.
+         lora_rank_value_residual_mix (`int`, *optional*):
+             Rank of the low-rank projection for the value residual mix. Derived from `hidden_size` if unset.
+         lora_rank_gate (`int`, *optional*):
+             Rank of the low-rank projection for the gate. Derived from `hidden_size` if unset.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+             The epsilon to use in the layer normalization layers.
+         bos_token_id (`int`, *optional*, defaults to 0):
+             The id of the beginning-of-sentence token in the vocabulary.
+         eos_token_id (`int`, *optional*, defaults to 0):
+             The id of the end-of-sentence token in the vocabulary.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether or not to tie the word embeddings with the input token embeddings.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last state.
+
+     Example:
+
+     ```python
+     >>> from transformers import Rwkv7Config, Rwkv7Model
+
+     >>> # Initializing a Rwkv7 configuration
+     >>> configuration = Rwkv7Config()
+
+     >>> # Initializing a model (with random weights) from the configuration
+     >>> model = Rwkv7Model(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "rwkv7"
+
+     def __init__(
+         self,
+         vocab_size=65536,
+         hidden_size=768,
+         num_hidden_layers=24,
+         attention_hidden_size=None,
+         head_size=64,
+         intermediate_size=None,
+         lora_rank_decay=None,
+         lora_rank_iclr=None,
+         lora_rank_value_residual_mix=None,
+         lora_rank_gate=None,
+         layer_norm_epsilon=1e-5,
+         bos_token_id=0,
+         eos_token_id=0,
+         tie_word_embeddings=False,
+         use_cache=True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
+         self.head_size = head_size
+         self.intermediate_size = intermediate_size
+         self.lora_rank_decay = lora_rank_decay
+         self.lora_rank_iclr = lora_rank_iclr
+         self.lora_rank_value_residual_mix = lora_rank_value_residual_mix
+         self.lora_rank_gate = lora_rank_gate
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.use_cache = use_cache
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
+         )
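As a quick illustration of the defaulting logic above (a sketch, not part of the upload): `attention_hidden_size` falls back to `hidden_size` when left unset, while the `lora_rank_*` fields stay `None` in the config and are only resolved to concrete ranks inside the attention module. The import below assumes `configuration_rwkv7.py` is on the local path.

```python
# Sketch of how the defaults in Rwkv7Config.__init__ resolve; values mirror config.json.
from configuration_rwkv7 import Rwkv7Config  # assumes this file is importable locally

cfg = Rwkv7Config(hidden_size=1024, num_hidden_layers=24, lora_rank_gate=128)
assert cfg.attention_hidden_size == 1024  # defaulted from hidden_size
assert cfg.intermediate_size is None      # resolved to 4 * hidden_size in the feed-forward module
assert cfg.lora_rank_decay is None        # resolved from hidden_size in the attention module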
generation_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "chat_format": "chatml",
+   "eos_token_id": 0,
+   "pad_token_id": 0,
+   "max_window_size": 4096,
+   "max_new_tokens": 4096,
+   "do_sample": true,
+   "top_k": 0,
+   "top_p": 0.1,
+   "repetition_penalty": 1.0,
+   "transformers_version": "4.31.1"
+ }
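`chat_format` and `max_window_size` are not standard `transformers` generation arguments (they look like chat metadata carried over from another repo), so the sketch below only covers the keys that `generate()` actually consumes. It is an illustration of what these defaults mean, not part of the upload:

```python
# Sketch: the sampling-related keys above deserialize into a GenerationConfig like this one.
from transformers import GenerationConfig

gen_cfg = GenerationConfig(
    do_sample=True,          # sample instead of greedy decoding
    top_k=0,                 # 0 disables top-k filtering
    top_p=0.1,               # nucleus sampling keeps only the top 10% of probability mass
    repetition_penalty=1.0,  # no repetition penalty
    max_new_tokens=4096,
    eos_token_id=0,
    pad_token_id=0,
)
print(gen_cfg)
```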
hf_rwkv_tokenizer.py ADDED
@@ -0,0 +1,279 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for RWKV6."""
+
+ import os
+ import re
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+ from transformers.utils import logging
+
+
+ if TYPE_CHECKING:
+     pass
+
+ logger = logging.get_logger(__name__)
+
+
+ VOCAB_FILES_NAMES = {
+     "vocab_file": "rwkv_vocab_v20230424.txt",
+ }
+
+ class TRIE:
+     __slots__ = tuple("ch,to,values,front".split(","))
+     to: list
+     values: set
+
+     def __init__(self, front=None, ch=None):
+         self.ch = ch
+         self.to = [None for ch in range(256)]
+         self.values = set()
+         self.front = front
+
+     def __repr__(self):
+         fr = self
+         ret = []
+         while fr != None:
+             if fr.ch != None:
+                 ret.append(fr.ch)
+             fr = fr.front
+         return "<TRIE %s %s>" % (ret[::-1], self.values)
+
+     def add(self, key: bytes, idx: int = 0, val=None):
+         if idx == len(key):
+             if val is None:
+                 val = key
+             self.values.add(val)
+             return self
+         ch = key[idx]
+         if self.to[ch] is None:
+             self.to[ch] = TRIE(front=self, ch=ch)
+         return self.to[ch].add(key, idx=idx + 1, val=val)
+
+     def find_longest(self, key: bytes, idx: int = 0):
+         u: TRIE = self
+         ch: int = key[idx]
+
+         while u.to[ch] is not None:
+             u = u.to[ch]
+             idx += 1
+             if u.values:
+                 ret = idx, u, u.values
+             if idx == len(key):
+                 break
+             ch = key[idx]
+         return ret
+
+
+ class RWKV_TOKENIZER:
+     def __init__(self, file_name):
+         self.idx2token = {}
+         sorted = []  # must be already sorted
+         with open(file_name, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+         for l in lines:
+             idx = int(l[: l.index(" ")])
+             x = eval(l[l.index(" ") : l.rindex(" ")])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+
+             assert len(x) == int(l[l.rindex(" ") :])
+             sorted += [x]
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k, v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         self.root = TRIE()
+         for t, i in self.token2idx.items():
+             _ = self.root.add(t, val=(t, i))
+
+     def encodeBytes(self, src: bytes):
+         idx: int = 0
+         tokens = []
+         while idx < len(src):
+             _idx: int = idx
+             idx, _, values = self.root.find_longest(src, idx)
+             assert idx != _idx
+             _, token = next(iter(values))
+             tokens.append(token)
+         return tokens
+
+     def decodeBytes(self, tokens):
+         return b"".join(map(lambda i: self.idx2token[i], tokens))
+
+     def encode(self, src):
+         if isinstance(src, str):
+             return [self.encodeBytes(src.encode("utf-8"))]
+         elif isinstance(src, list):
+             return [self.encodeBytes(s.encode("utf-8")) for s in src]
+
+     def decode(self, tokens):
+         return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
+         # try:
+         #     return self.decodeBytes(tokens).decode('utf-8')
+         # except:
+         #     return '\ufffd'  # bad utf-8
+
+     def printTokens(self, tokens):
+         for i in tokens:
+             s = self.idx2token[i]
+             try:
+                 s = s.decode("utf-8")
+             except:
+                 pass
+             print(f"{repr(s)}{i}", end=" ")
+         print()
+
+
+ class Rwkv6Tokenizer(PreTrainedTokenizer):
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self, vocab_file, bos_token="<s>", eos_token="<s>", unk_token="<s>", **kwargs
+     ):
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+                 " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+             )
+
+         with open(vocab_file, "r", encoding="utf-8") as reader:
+             tokens = reader.readlines()
+
+         if "add_bos_token" in kwargs:
+             self.add_bos_token = kwargs["add_bos_token"]
+         else:
+             self.add_bos_token = False
+         self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
+         vocab = self.trie_tokenizer.token2idx
+         self.encoder = vocab
+         self.decoder = {v: k for k, v in vocab.items()}
+         self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
+         super().__init__(
+             bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self.encoder)
+
+     def get_vocab(self):
+         vocab = {str(self.convert_ids_to_tokens(i)): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text, split_special_tokens=False):
+         # return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
+         return self.trie_tokenizer.encode(text)[0]
+
+     def _convert_token_to_id(self, token):
+         return token
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) in a token (byte) using the vocab."""
+         token = self.decoder.get(index, self.unk_token)
+         if isinstance(token, (bytes)):
+             token = token.decode("utf-8", errors="replace")
+         return token
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
+         out_string = b"".join(
+             [k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
+         ).decode("utf-8")
+         return out_string
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ) -> Tuple[str]:
+         index = 0
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory,
+                 (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
+             )
+         else:
+             vocab_file = (
+                 filename_prefix + "-" if filename_prefix else ""
+             ) + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(
+                 self.encoder.items(), key=lambda kv: kv[1]
+             ):
+                 if index != token_index:
+                     logger.warning(
+                         f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!"
+                     )
+                     index = token_index
+                 writer.write(str(token) + "\n")
+                 index += 1
+         return (vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         if self.add_bos_token:
+             bos_token_ids = [self.bos_token_id]
+         else:
+             bos_token_ids = []
+
+         output = bos_token_ids + token_ids_0
+
+         if token_ids_1 is None:
+             return output
+
+         return output + bos_token_ids + token_ids_1
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=True,
+             )
+
+         if not self.add_bos_token:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=False,
+             )
+
+         if token_ids_1 is None:
+             return [1] + ([0] * len(token_ids_0))
+         return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
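The tokenizer above is a plain byte-level trie: the vocabulary maps ids to byte strings, and encoding repeatedly takes the longest vocabulary entry matching at the current position (`TRIE.find_longest` / `RWKV_TOKENIZER.encodeBytes`). A tiny self-contained sketch of that greedy longest-match idea, with a made-up toy vocabulary rather than the real 65536-entry `rwkv_vocab_v20230424.txt`:

```python
# Toy illustration of the greedy longest-match encoding used by RWKV_TOKENIZER.
# The vocabulary here is invented for the example.
toy_vocab = {b"a": 1, b"ab": 2, b"abc": 3, b"b": 4, b"c": 5}

def greedy_encode(data: bytes) -> list:
    tokens, i = [], 0
    while i < len(data):
        # Try the longest possible match first, mirroring TRIE.find_longest.
        for j in range(len(data), i, -1):
            if data[i:j] in toy_vocab:
                tokens.append(toy_vocab[data[i:j]])
                i = j
                break
        else:
            raise ValueError("byte not covered by vocabulary")
    return tokens

print(greedy_encode(b"abcab"))  # [3, 2] -- "abc" then "ab"
```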
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8e455a2f154c9cd0d307813f63a8e733f1edde12e3d16c598902ac843d71616
+ size 901619456
modeling_rwkv7.py ADDED
@@ -0,0 +1,874 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The RWKV team and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RWKV7 World model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ from pathlib import Path
21
+
22
+ import math
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
+
29
+ from transformers.modeling_utils import PreTrainedModel, GenerationMixin, _init_weights
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_ninja_available,
36
+ is_torch_cuda_available,
37
+ logging,
38
+ )
39
+
40
+ from .configuration_rwkv7 import Rwkv7Config
41
+
42
+ # MIT License
43
+
44
+ # Copyright (c) 2024 Songlin Yang
45
+
46
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
47
+ # of this software and associated documentation files (the "Software"), to deal
48
+ # in the Software without restriction, including without limitation the rights
49
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
50
+ # copies of the Software, and to permit persons to whom the Software is
51
+ # furnished to do so, subject to the following conditions:
52
+
53
+ # The above copyright notice and this permission notice shall be included in all
54
+ # copies or substantial portions of the Software.
55
+
56
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
57
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
58
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
59
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
60
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
61
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
62
+ # SOFTWARE.
63
+
64
+ # Copyright (c) 2024, Johan Sokrates Wind
65
+
66
+ import torch as th
67
+ import triton
68
+ import triton.language as tl
69
+
70
+ @triton.jit
71
+ def IND4(a,b,c,d,nb,nc,nd):
72
+ return ((a*nb+b)*nc+c)*nd+d
73
+ @triton.jit
74
+ def IND5(a,b,c,d,e,nb,nc,nd,ne):
75
+ return (((a*nb+b)*nc+c)*nd+d)*ne+e
76
+
77
+ @triton.jit
78
+ def _prod(a,b): return a*b
79
+
80
+ # inv(I-A) where A is a strictly lower triangular nxn matrix
81
+ @triton.jit
82
+ def tri_minv(A, n:tl.constexpr, prec:tl.constexpr):
83
+ i = tl.arange(0,n)
84
+ prod = (i[None,:]==i[:,None]).to(tl.float32)
85
+ for j in range(n-1):
86
+ prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans())
87
+ return prod.trans()
88
+
89
+ @triton.jit
90
+ def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
91
+ bi = tl.program_id(1)
92
+ hi = tl.program_id(0)
93
+
94
+ i = tl.arange(0,C)[None,:]
95
+ state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
96
+ for t0 in range(T//dT):
97
+ t = t0*dT+tl.arange(0,dT)[:,None]
98
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
99
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
100
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
101
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
102
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
103
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
104
+
105
+ w = (-sw.exp()).exp()
106
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
107
+ incl_pref = tl.cumprod(w,axis=0)
108
+ non_incl_pref = incl_pref / w
109
+ inv_incl_pref = 1 / incl_pref
110
+
111
+ wq = sq * incl_pref
112
+ wa = sa * non_incl_pref
113
+ kwi = sk * inv_incl_pref
114
+ bwi = sb * inv_incl_pref
115
+
116
+ mask1 = (t > t.trans())
117
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
118
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
119
+
120
+ ab_inv = tri_minv(ab, dT, prec)
121
+
122
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
123
+ u = tl_dot(prec, ab_inv, ab_u)
124
+ mask2 = (t >= t.trans())
125
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
126
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
127
+ yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans())
128
+ tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
129
+
130
+ tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32))
131
+ state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
132
+ tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16))
133
+
134
+ @triton.jit
135
+ def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
136
+ bi = tl.program_id(1)
137
+ hi = tl.program_id(0)
138
+
139
+ i = tl.arange(0,C)[None,:]
140
+ dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
141
+
142
+ for t0 in range(T//dT-1,-1,-1):
143
+ t = t0*dT+tl.arange(0,dT)[:,None]
144
+
145
+ state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32)
146
+
147
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
148
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
149
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
150
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
151
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
152
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
153
+ sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
154
+
155
+ dw_fac = -sw.exp()
156
+ w = dw_fac.exp()
157
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
158
+ incl_pref = tl.cumprod(w,axis=0)
159
+ non_incl_pref = incl_pref / w
160
+ inv_incl_pref = 1 / incl_pref
161
+
162
+ wq = sq * incl_pref
163
+ wa = sa * non_incl_pref
164
+ kwi = sk * inv_incl_pref
165
+ bwi = sb * inv_incl_pref
166
+
167
+ mask1 = (t > t.trans())
168
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
169
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
170
+
171
+ ab_inv = tri_minv(ab, dT, prec)
172
+
173
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
174
+ u = tl_dot(prec, ab_inv, ab_u)
175
+ mask2 = (t >= t.trans())
176
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
177
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
178
+
179
+ du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans())
180
+ dab_u = tl_dot(prec, ab_inv.trans(), du)
181
+
182
+ dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u)
183
+ tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
184
+
185
+ dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1
186
+ dak = tl_dot(prec, dab_u, sv.trans()) * mask1
187
+ dab_u_state = tl_dot(prec, dab_u, state)
188
+ da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
189
+ tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16))
190
+
191
+ dqb = tl_dot(prec, sdy, u.trans()) * mask2
192
+ dqk = tl_dot(prec, sdy, sv.trans()) * mask2
193
+ dy_state = tl_dot(prec, sdy, state)
194
+ dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
195
+ tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16))
196
+
197
+ fw_u_dstate = fw * tl_dot(prec, u, dstate)
198
+ db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
199
+ tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16))
200
+
201
+ fw_v_dstate = fw * tl_dot(prec, sv, dstate)
202
+ dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
203
+ tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16))
204
+
205
+ dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True)
206
+ for k in range(t0*dT,t0*dT+dT):
207
+ lmask = (t<k).trans()
208
+ A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
209
+ A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
210
+ A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
211
+ A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
212
+ dw = tl.sum(A, axis=0,keep_dims=True) + dw0
213
+
214
+ wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32)
215
+ dw *= -wk.exp()
216
+ tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16))
217
+
218
+ dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
219
+ tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16))
220
+
221
+
222
+ class TritonRWKV7(th.autograd.Function):
223
+ @staticmethod
224
+ def forward(ctx, w,q,k,v,z,b,s0, dot_prec):
225
+ K = 16
226
+ B,T,H,C = w.shape
227
+ s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
228
+ y = th.empty_like(v)
229
+ sT = th.empty_like(s0)
230
+ s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
231
+ fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec)
232
+ ctx.dot_prec = dot_prec
233
+ ctx.save_for_backward(w,q,k,v,z,b,s)
234
+ return y, sT
235
+ @staticmethod
236
+ def backward(ctx, dy, dsT):
237
+ K = 16
238
+ w,q,k,v,z,b,s = ctx.saved_tensors
239
+ B,T,H,C = w.shape
240
+ dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]]
241
+ bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec)
242
+ return dw,dq,dk,dv,dz,db,ds0,None
243
+
244
+ @triton.jit
245
+ def tl_dot(prec:tl.constexpr, a, b) -> torch.Tensor:
246
+ if prec == 'fp32':
247
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False)
248
+ elif prec == 'tf32':
249
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True)
250
+ elif prec == 'bf16':
251
+ return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True)
252
+ else:
253
+ tl.static_assert(False)
254
+
255
+ def rwkv7_attn_triton(r,w,k,v,a,b, HEAD_SIZE, dot_prec = 'fp32'):
256
+ B,T,HC = w.shape
257
+ C = HEAD_SIZE
258
+ H = HC//C
259
+ r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]]
260
+ s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device)
261
+ return TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)[0].view(B,T,HC)
262
+
263
+ logger = logging.get_logger(__name__)
264
+
265
+ _CHECKPOINT_FOR_DOC = "RWKV/v7-Goose-1.6B-Pile-HF"
266
+ _CONFIG_FOR_DOC = "Rwkv7Config"
267
+
268
+ class Rwkv7SelfAttention(nn.Module):
269
+ def __init__(self, config, layer_id=0):
270
+ super().__init__()
271
+ self.config = config
272
+ self.layer_id = layer_id
273
+ C = hidden_size = config.hidden_size
274
+ attention_hidden_size = config.attention_hidden_size
275
+ self.attention_hidden_size = attention_hidden_size
276
+ H = self.num_heads = attention_hidden_size // config.head_size
277
+ N = self.head_size = config.head_size
278
+
279
+ calc_lora_rank = lambda exponent, multiplier: max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
280
+ lora_rank_decay = config.lora_rank_decay or calc_lora_rank(0.5, 1.8)
281
+ lora_rank_iclr = config.lora_rank_iclr or calc_lora_rank(0.5, 1.8)
282
+ lora_rank_value_residual_mix = config.lora_rank_value_residual_mix or calc_lora_rank(0.5, 1.3)
283
+ lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
284
+
285
+ self.x_r = nn.Parameter(torch.empty(1,1,C))
286
+ self.x_w = nn.Parameter(torch.empty(1,1,C))
287
+ self.x_k = nn.Parameter(torch.empty(1,1,C))
288
+ self.x_v = nn.Parameter(torch.empty(1,1,C))
289
+ self.x_a = nn.Parameter(torch.empty(1,1,C))
290
+ self.x_g = nn.Parameter(torch.empty(1,1,C))
291
+
292
+ self.w0 = nn.Parameter(torch.empty(1,1,C))
293
+ self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay))
294
+ self.w2 = nn.Parameter(torch.empty(lora_rank_decay, C))
295
+
296
+ self.a0 = nn.Parameter(torch.empty(1,1,C))
297
+ self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr))
298
+ self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, C))
299
+
300
+ if layer_id > 0:
301
+ self.v0 = nn.Parameter(torch.empty(1,1,C))
302
+ self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix))
303
+ self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, C))
304
+
305
+ self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate))
306
+ self.g2 = nn.Parameter(torch.empty(lora_rank_gate, C))
307
+
308
+ self.k_k = nn.Parameter(torch.empty(1,1,C))
309
+ self.k_a = nn.Parameter(torch.empty(1,1,C))
310
+ self.r_k = nn.Parameter(torch.empty(H,N))
311
+
312
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
313
+ self.receptance = nn.Linear(C, C, bias=False)
314
+ self.key = nn.Linear(C, C, bias=False)
315
+ self.value = nn.Linear(C, C, bias=False)
316
+ self.output = nn.Linear(C, C, bias=False)
317
+ self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)
318
+
319
+
320
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
321
+ # Mix hidden with the previous timestep to produce key, value, receptance
322
+ if hidden.size(1) == 1 and state is not None:
323
+ shifted = state[0][self.layer_id]
324
+ else:
325
+ shifted = self.time_shift(hidden)
326
+ if state is not None:
327
+ shifted[:, 0] = state[0][self.layer_id]
328
+ if len(shifted.size()) == 2:
329
+ shifted = shifted.unsqueeze(1)
330
+
331
+ x = hidden
332
+
333
+ B, T, C = hidden.shape
334
+ H = self.num_heads
335
+ N = self.head_size
336
+
337
+ xx = shifted - x
338
+
339
+ xr = x+xx*self.x_r
340
+ xw = x+xx*self.x_w
341
+ xk = x+xx*self.x_k
342
+ xv = x+xx*self.x_v
343
+ xa = x+xx*self.x_a
344
+ xg = x+xx*self.x_g
345
+
346
+ r = self.receptance(xr)
347
+ w = torch.tanh(xw @ self.w1) @ self.w2
348
+ k = self.key(xk)
349
+ v = self.value(xv)
350
+ a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
351
+ g = torch.sigmoid(xg @ self.g1) @ self.g2
352
+
353
+ kk = torch.nn.functional.normalize((k * self.k_k).view(B,T,H,-1), dim=-1, p=2.0).view(B,T,-1)
354
+ k = k * (1 + (a-1) * self.k_a)
355
+ if self.layer_id == 0: v_first = v
356
+ else: v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
357
+
358
+ if T == 1 or not self.training:
359
+ w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
360
+ vk_state = state[1][self.layer_id]
361
+ for t in range(T):
362
+ r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
363
+ vk = v_.view(B,H,N,1) @ k_.view(B,H,1,N)
364
+ ab = (-kk_).view(B,H,N,1) @ (kk_*a_).view(B,H,1,N)
365
+ vk_state = vk_state * w_.view(B,H,1,N) + vk_state @ ab.float() + vk.float()
366
+ xx[:,t] = (vk_state.to(dtype=x.dtype) @ r_.view(B,H,N,1)).view(B,H*N)
367
+ state[1][self.layer_id] = vk_state
368
+ # FIXME - support fast triton kernel for non-training pre-fill with state in and out
369
+ else:
370
+ w = -torch.nn.functional.softplus(-(self.w0 + w)) - 0.5
371
+ xx = rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)
372
+
373
+ xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
374
+ xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
375
+ xx = self.output(xx * g)
376
+
377
+ if state is not None:
378
+ state[0][self.layer_id] = hidden[:, -1]
379
+
380
+ return xx, state, v_first
381
+
382
+
383
+ class Rwkv7FeedForward(nn.Module):
384
+ def __init__(self, config, layer_id=0):
385
+ super().__init__()
386
+ self.config = config
387
+ self.layer_id = layer_id
388
+ hidden_size = config.hidden_size
389
+ intermediate_size = (
390
+ config.intermediate_size
391
+ if config.intermediate_size is not None
392
+ else int(config.hidden_size * 4)
393
+ )
394
+
395
+
396
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
397
+
398
+ self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size))
399
+
400
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
401
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
402
+
403
+ def forward(self, hidden, state=None):
404
+ if hidden.size(1) == 1 and state is not None:
405
+ shifted = state[2][self.layer_id]
406
+ else:
407
+ shifted = self.time_shift(hidden)
408
+ if state is not None:
409
+ shifted[:, 0] = state[2][self.layer_id]
410
+ if len(shifted.size()) == 2:
411
+ shifted = shifted.unsqueeze(1)
412
+
413
+ delta_hidden_to_shifted = shifted - hidden
414
+ key = hidden + delta_hidden_to_shifted * self.x_k
415
+
416
+ key = torch.square(torch.relu(self.key(key)))
417
+ value = self.value(key)
418
+
419
+ if state is not None:
420
+ state[2][self.layer_id] = hidden[:, -1]
421
+
422
+ return value, state
423
+
424
+
425
+ class Rwkv7Block(nn.Module):
426
+ def __init__(self, config, layer_id):
427
+ super().__init__()
428
+ self.config = config
429
+ self.layer_id = layer_id
430
+
431
+ self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
432
+ self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
433
+
434
+ self.attention = Rwkv7SelfAttention(config, layer_id)
435
+ self.feed_forward = Rwkv7FeedForward(config, layer_id)
436
+
437
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
438
+ attention, state, v_first = self.attention(self.ln1(hidden), state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode)
439
+ hidden = hidden + attention
440
+
441
+ feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
442
+ hidden = hidden + feed_forward
443
+
444
+ outputs = (hidden, state, v_first)
445
+ if output_attentions:
446
+ outputs += (attention,)
447
+ else:
448
+ outputs += (None,)
449
+
450
+ return outputs
451
+
452
+
453
+ class Rwkv7PreTrainedModel(PreTrainedModel):
454
+ """
455
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
456
+ models.
457
+ """
458
+
459
+ config_class = Rwkv7Config
460
+ base_model_prefix = "rwkv7"
461
+ _no_split_modules = ["Rwkv7Block"]
462
+ _keep_in_fp32_modules = []
463
+ supports_gradient_checkpointing = True
464
+
465
+ def _init_weights(self, module):
466
+ return
467
+
468
+ """Initialize the weights."""
469
+ if isinstance(module, Rwkv7SelfAttention):
470
+ layer_id = module.layer_id
471
+ num_hidden_layers = module.config.num_hidden_layers
472
+ hidden_size = module.config.hidden_size
473
+ attention_hidden_size = module.attention_hidden_size
474
+ head_size = module.config.head_size
475
+ num_heads = attention_hidden_size // head_size
476
+
477
+ ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
478
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
479
+
480
+ time_weight = torch.tensor(
481
+ [i / hidden_size for i in range(hidden_size)],
482
+ dtype=module.x_k.dtype,
483
+ device=module.x_k.device,
484
+ )
485
+ time_weight = time_weight[None, None, :]
486
+
487
+ decay_speed = [
488
+ -7.0 + 5.0 * (n / (attention_hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5)
489
+ for n in range(attention_hidden_size)
490
+ ]
491
+ decay_speed = torch.tensor(decay_speed, dtype=module.w0.dtype, device=module.w0.device)
492
+
493
+ with torch.no_grad():
494
+ module.x_r.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
495
+ module.x_w.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
496
+ module.x_k.copy_( 1.0 - (torch.pow(time_weight, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1) )
497
+ module.x_v.copy_( 1.0 - (torch.pow(time_weight, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1) )
498
+ module.x_a.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
499
+ module.x_g.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
500
+
501
+ def ortho_init(x, scale):
502
+ with torch.no_grad():
503
+ shape = x.shape
504
+ if len(shape) == 2:
505
+ gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
506
+ nn.init.orthogonal_(x, gain=gain * scale)
507
+ elif len(shape) == 3:
508
+ gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
509
+ for i in range(shape[0]):
510
+ nn.init.orthogonal_(x[i], gain=gain * scale)
511
+ else:
512
+ assert False
513
+ return x
514
+
515
+ module.w0.copy_(decay_speed.reshape(1,1,attention_hidden_size) + 0.5) # !!! 0.5 comes from F.softplus !!!
516
+ module.w1.zero_()
517
+ ortho_init(module.w2, 0.1)
518
+
519
+ module.a0.zero_()
520
+ module.a1.zero_()
521
+ ortho_init(module.a2, 0.1)
522
+
523
+ module.v0.copy_(1.0)
524
+ module.v1.zero_()
525
+ ortho_init(module.v2, 0.1)
526
+
527
+ module.g1.zero_()
528
+ ortho_init(module.g2, 0.1)
529
+
530
+ self.k_k.copy_(0.85)
531
+ self.k_a.copy_(1.0)
532
+ self.r_k.zero_()
533
+
534
+ module.receptance.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
535
+ module.key.weight.data.uniform_(-0.05/(hidden_size**0.5), 0.05/(attention_hidden_size**0.5))
536
+ module.value.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
537
+ module.output.weight.data.zero_()
538
+
539
+ elif isinstance(module, Rwkv7FeedForward):
540
+ layer_id = module.layer_id
541
+ num_hidden_layers = module.config.num_hidden_layers
542
+ hidden_size = module.config.hidden_size
543
+
544
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
545
+
546
+ time_weight = torch.tensor(
547
+ [i / hidden_size for i in range(hidden_size)],
548
+ dtype=module.x_k.dtype,
549
+ device=module.x_k.device,
550
+ )
551
+ time_weight = time_weight[None, None, :]
552
+
553
+ with torch.no_grad():
554
+ module.x_k.copy_( 1.0 - torch.pow(time_weight, ratio_1_to_almost0**4) )
555
+
556
+ self.key.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(hidden_size**0.5))
557
+ self.value.weight.data.zero_()
558
+
559
+ @dataclass
560
+ class Rwkv7Output(ModelOutput):
561
+ """
562
+ Class for the RWKV model outputs.
563
+ Args:
564
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
565
+ Sequence of hidden-states at the output of the last layer of the model.
566
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
567
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
568
+ avoid providing the old `input_ids`.
569
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
570
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
571
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
572
+ the model at the output of each layer plus the optional initial embedding outputs.
573
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
574
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
575
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
576
+ the self-attention heads.
577
+ """
578
+
579
+ last_hidden_state: torch.FloatTensor = None
580
+ state: Optional[List[torch.FloatTensor]] = None
581
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
582
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
583
+
584
+
585
+ @dataclass
586
+ class Rwkv7CausalLMOutput(ModelOutput):
587
+ """
588
+ Base class for causal language model (or autoregressive) outputs.
589
+ Args:
590
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
591
+ Language modeling loss (for next-token prediction).
592
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
593
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
594
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
595
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
596
+ avoid providing the old `input_ids`.
597
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
598
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
599
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
600
+ the model at the output of each layer plus the optional initial embedding outputs.
601
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
602
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
603
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
604
+ the self-attention heads.
605
+ """
606
+
607
+ loss: Optional[torch.FloatTensor] = None
608
+ logits: torch.FloatTensor = None
609
+ state: Optional[List[torch.FloatTensor]] = None
610
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
611
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
612
+
613
+
614
+ RWKV7_START_DOCSTRING = r"""
615
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
616
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
617
+ etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
618
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
619
+ general usage and behavior.
620
+ Parameters:
621
+ config ([`Rwkv7Config`]): Model configuration class with all the parameters of the model.
622
+ Initializing with a config file does not load the weights associated with the model, only the
623
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
624
+ """
625
+
626
+ RWKV7_INPUTS_DOCSTRING = r"""
627
+ Args:
628
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
629
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
630
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
631
+ sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
632
+ past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
633
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
634
+ IDs?](../glossary#input-ids)
635
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
636
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
637
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
638
+ model's internal embedding lookup matrix.
639
+ state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
640
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
641
+ `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
642
+ use_cache (`bool`, *optional*):
643
+ If set to `True`, the last state is returned and can be used to quickly generate the next logits.
644
+ output_attentions (`bool`, *optional*):
645
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
646
+ tensors for more detail.
647
+ output_hidden_states (`bool`, *optional*):
648
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
649
+ more detail.
650
+ return_dict (`bool`, *optional*):
651
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
652
+ """
653
+
654
+
655
+ @add_start_docstrings(
656
+ "The bare RWKV7 Model transformer outputting raw hidden-states without any specific head on top.",
657
+ RWKV7_START_DOCSTRING,
658
+ )
659
+ class Rwkv7Model(Rwkv7PreTrainedModel):
660
+ def __init__(self, config):
661
+ super().__init__(config)
662
+
663
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
664
+ self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
665
+ self.blocks = nn.ModuleList([Rwkv7Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
666
+ self.ln_out = nn.LayerNorm(config.hidden_size)
667
+
668
+ self.gradient_checkpointing = False
669
+
670
+ # Initialize weights and apply final processing
671
+ self.post_init()
672
+
673
+ def get_input_embeddings(self):
674
+ return self.embeddings
675
+
676
+ def set_input_embeddings(self, new_embeddings):
677
+ self.embeddings = new_embeddings
678
+
679
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
680
+ @add_code_sample_docstrings(
681
+ checkpoint=_CHECKPOINT_FOR_DOC,
682
+ output_type=Rwkv7Output,
683
+ config_class=_CONFIG_FOR_DOC,
684
+ )
685
+ def forward(
686
+ self,
687
+ input_ids: Optional[torch.LongTensor] = None,
688
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
689
+ inputs_embeds: Optional[torch.FloatTensor] = None,
690
+ state: Optional[List[torch.FloatTensor]] = None,
691
+ use_cache: Optional[bool] = None,
692
+ output_attentions: Optional[bool] = None,
693
+ output_hidden_states: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ ) -> Union[Tuple, Rwkv7Output]:
696
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
697
+ output_hidden_states = (
698
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
699
+ )
700
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
701
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
702
+
703
+ if input_ids is not None and inputs_embeds is not None:
704
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
705
+ elif input_ids is None and inputs_embeds is None:
706
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
707
+
708
+ if inputs_embeds is None:
709
+ inputs_embeds = self.embeddings(input_ids)
710
+
711
+ if state is None:
712
+ state = []
713
+ head_size = self.config.head_size
714
+ num_heads = self.config.attention_hidden_size // head_size
715
+ state_attn_x = torch.zeros(
716
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
717
+ dtype=inputs_embeds.dtype,
718
+ requires_grad=False,
719
+ device=inputs_embeds.device,
720
+ ).contiguous()
721
+ state_attn_vk = torch.zeros(
722
+ (
723
+ self.config.num_hidden_layers,
724
+ inputs_embeds.size(0),
725
+ num_heads,
726
+ head_size,
727
+ head_size,
728
+ ),
729
+ dtype=torch.float32,
730
+ requires_grad=False,
731
+ device=inputs_embeds.device,
732
+ ).contiguous()
733
+ state_ffn_x = torch.zeros(
734
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
735
+ dtype=inputs_embeds.dtype,
736
+ requires_grad=False,
737
+ device=inputs_embeds.device,
738
+ ).contiguous()
739
+ state.append(state_attn_x)
740
+ state.append(state_attn_vk)
741
+ state.append(state_ffn_x)
742
+
743
+ seq_mode = inputs_embeds.shape[1] > 1
744
+ hidden_states = self.pre_ln(inputs_embeds)
745
+ v_first = None
746
+
747
+ all_self_attentions = () if output_attentions else None
748
+ all_hidden_states = () if output_hidden_states else None
749
+ for idx, block in enumerate(self.blocks):
750
+ hidden_states, state, v_first, attentions = block(
751
+ hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
752
+ )
753
+
754
+ if output_hidden_states:
755
+ all_hidden_states = all_hidden_states + (hidden_states,)
756
+
757
+ if output_attentions:
758
+ all_self_attentions = all_self_attentions + (attentions,)
759
+
760
+ hidden_states = self.ln_out(hidden_states)
761
+
762
+ if output_hidden_states:
763
+ all_hidden_states = all_hidden_states + (hidden_states,)
764
+
765
+ if not return_dict:
766
+ return (hidden_states, state, all_hidden_states, all_self_attentions)
767
+
768
+ return Rwkv7Output(
769
+ last_hidden_state=hidden_states,
770
+ state=state,
771
+ hidden_states=all_hidden_states, # None
772
+ attentions=all_self_attentions, # None
773
+ )
774
+
775
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
776
+ @add_start_docstrings(
777
+ """
778
+ The RWKV7 Model transformer with a language modeling head on top (linear layer with weights tied to the input
779
+ embeddings).
780
+ """,
781
+ RWKV7_START_DOCSTRING,
782
+ )
783
+ class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
784
+ _tied_weights_keys = ["head.weight"]
785
+
786
+ def __init__(self, config):
787
+ super().__init__(config)
788
+ self.model = Rwkv7Model(config)
789
+ self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
+
791
+ # Initialize weights and apply final processing
792
+ self.post_init()
793
+
794
+ def get_output_embeddings(self):
795
+ return self.head
796
+
797
+ def set_output_embeddings(self, new_embeddings):
798
+ self.head = new_embeddings
799
+
800
+ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
801
+ # only last token for inputs_ids if the state is passed along.
802
+ if state is not None:
803
+ input_ids = input_ids[:, -1].unsqueeze(-1)
804
+
805
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
806
+ if inputs_embeds is not None and state is None:
807
+ model_inputs = {"inputs_embeds": inputs_embeds}
808
+ else:
809
+ model_inputs = {"input_ids": input_ids}
810
+
811
+ model_inputs["state"] = state
812
+ return model_inputs
813
+
814
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
815
+ @add_code_sample_docstrings(
816
+ checkpoint=_CHECKPOINT_FOR_DOC,
817
+ output_type=Rwkv7CausalLMOutput,
818
+ config_class=_CONFIG_FOR_DOC,
819
+ )
820
+ def forward(
821
+ self,
822
+ input_ids: Optional[torch.LongTensor] = None,
823
+ attention_mask: Optional[torch.LongTensor] = None,
824
+ inputs_embeds: Optional[torch.FloatTensor] = None,
825
+ state: Optional[List[torch.FloatTensor]] = None,
826
+ labels: Optional[torch.LongTensor] = None,
827
+ use_cache: Optional[bool] = None,
828
+ output_attentions: Optional[bool] = None,
829
+ output_hidden_states: Optional[bool] = None,
830
+ return_dict: Optional[bool] = None,
831
+ ) -> Union[Tuple, Rwkv7CausalLMOutput]:
832
+ r"""
833
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
834
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
835
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
836
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
837
+ """
838
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
839
+
840
+ outputs = self.model(
841
+ input_ids,
842
+ inputs_embeds=inputs_embeds,
843
+ state=state,
844
+ use_cache=use_cache,
845
+ output_attentions=output_attentions,
846
+ output_hidden_states=output_hidden_states,
847
+ return_dict=return_dict,
848
+ )
849
+ hidden_states = outputs[0]
850
+
851
+ logits = self.head(hidden_states)
852
+
853
+ loss = None
854
+ if labels is not None:
855
+ # move labels to correct device to enable model parallelism
856
+ labels = labels.to(logits.device)
857
+ # Shift so that tokens < n predict n
858
+ shift_logits = logits[..., :-1, :].contiguous()
859
+ shift_labels = labels[..., 1:].contiguous()
860
+ # Flatten the tokens
861
+ loss_fct = CrossEntropyLoss()
862
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
863
+
864
+ if not return_dict:
865
+ output = (logits,) + outputs[1:]
866
+ return ((loss,) + output) if loss is not None else output
867
+
868
+ return Rwkv7CausalLMOutput(
869
+ loss=loss,
870
+ logits=logits,
871
+ state=outputs.state,
872
+ hidden_states=outputs.hidden_states,
873
+ attentions=outputs.attentions,
874
+ )
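Putting the files in this commit together, here is a hedged end-to-end usage sketch. The repo path is a placeholder; the custom tokenizer and model classes are resolved through the `auto_map` entries in `tokenizer_config.json` and `config.json`, and the recurrent `state` returned by `Rwkv7ForCausalLM` plays the role of a KV cache during incremental decoding:

```python
# End-to-end sketch: tokenize, run Rwkv7ForCausalLM, and reuse the returned state
# for step-by-step decoding. "path/to/this/repo" is a placeholder path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "path/to/this/repo"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).eval()

prompt = "The Eiffel Tower is located in"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    # Prefill: the model returns `state`, the recurrent analogue of a KV cache.
    out = model(input_ids=inputs["input_ids"], use_cache=True)
    state = out.state
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    generated = [next_id.item()]
    for _ in range(16):
        # Decode one token at a time, feeding only the newest token plus the state.
        out = model(input_ids=next_id, state=state, use_cache=True)
        state = out.state
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_id.item())

print(prompt + tokenizer.decode(generated))
```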
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "bos_token": "<s>",
+   "eos_token": "<s>",
+   "unk_token": "<s>"
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "name_or_path": "rwkv-6-tokenizer",
+   "add_prefix_space": false,
+   "tokenizer_class": "Rwkv6Tokenizer",
+   "use_fast": false,
+   "auto_map": {
+     "AutoTokenizer": [
+       "hf_rwkv_tokenizer.Rwkv6Tokenizer",
+       null
+     ]
+   }
+ }
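Since `auto_map` routes `AutoTokenizer` to the custom slow tokenizer in `hf_rwkv_tokenizer.py` (no fast tokenizer is registered), loading it likewise requires `trust_remote_code`. A short sketch with the same placeholder path as above:

```python
# Sketch: the auto_map above resolves AutoTokenizer to hf_rwkv_tokenizer.Rwkv6Tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True, use_fast=False)
ids = tok("Hello world").input_ids
print(ids, tok.decode(ids))
```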
vocab.txt ADDED
The diff for this file is too large to render. See raw diff