Triang-jyed-driung committed
Commit 35429bb · Parent(s): 3a6cb37

Added chat template and attention mask
Files changed:
- __init__.py +0 -0
- added_tokens.json +1 -1
- config.json +1 -0
- generation_config.json +5 -4
- hf_rwkv_tokenizer.py +1 -1
- modeling_rwkv7.py +16 -5
- special_tokens_map.json +3 -3
- tokenizer_config.json +6 -5
__init__.py
ADDED
File without changes
added_tokens.json
CHANGED
@@ -1,3 +1,3 @@
 {
-  "
+  "<|rwkv_tokenizer_end_of_text|>": 0
 }
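The single entry registered here maps the RWKV end-of-text marker to id 0, which matches the bos/eos/pad ids set in generation_config.json below. A minimal sanity check, run from the repository root with nothing beyond the standard library:

    import json

    # added_tokens.json after this commit: the end-of-text marker at id 0.
    with open("added_tokens.json") as f:
        added = json.load(f)

    print(added)  # {'<|rwkv_tokenizer_end_of_text|>': 0}
    assert added["<|rwkv_tokenizer_end_of_text|>"] == 0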
config.json
CHANGED
@@ -4,6 +4,7 @@
   ],
   "attention_hidden_size": 768,
   "auto_map": {
+    "AutoModel": "modeling_rwkv7.Rwkv7Model",
     "AutoConfig": "configuration_rwkv7.Rwkv7Config",
     "AutoModelForCausalLM": "modeling_rwkv7.Rwkv7ForCausalLM"
   },
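The new "AutoModel" entry in auto_map lets the bare Rwkv7Model backbone be resolved through the Auto classes, alongside the existing causal-LM mapping. A minimal sketch of the intended usage; the repository id below is a placeholder for this repo, and trust_remote_code is required because the model code ships with the repository:

    from transformers import AutoModel, AutoModelForCausalLM

    repo_id = "path/to/this-rwkv7-repo"  # placeholder for this repository's id

    # Resolves to modeling_rwkv7.Rwkv7Model via the newly added mapping.
    backbone = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

    # Resolves to modeling_rwkv7.Rwkv7ForCausalLM, as it did before this commit.
    lm = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)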
generation_config.json
CHANGED
@@ -1,12 +1,13 @@
 {
   "chat_format": "chatml",
+  "bos_token_id": 0,
   "eos_token_id": 0,
   "pad_token_id": 0,
-  "max_window_size":
+  "max_window_size": 2147483647,
   "max_new_tokens": 4096,
   "do_sample": true,
-  "top_k":
-  "top_p": 0
-  "
+  "top_k": 65536,
+  "top_p": 1.0,
+  "temperature": 1.0,
   "transformers_version": "4.31.1"
 }
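With do_sample=true, top_k=65536, top_p=1.0 and temperature=1.0, the new defaults amount to sampling from the untruncated, untempered distribution, and bos/eos/pad all point at token 0. A hedged usage sketch; the repo id and prompt are placeholders, and any of these defaults can still be overridden per call:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "path/to/this-rwkv7-repo"  # placeholder for this repository's id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    inputs = tokenizer("User: What is RWKV?\n\nAssistant:", return_tensors="pt")

    # Uses the shipped sampling defaults (temperature 1.0, effectively no
    # top-k/top-p truncation) and stops on eos_token_id = 0.
    out = model.generate(**inputs, max_new_tokens=64)

    # Per-call arguments still take precedence over generation_config.json.
    out = model.generate(**inputs, max_new_tokens=64, temperature=0.7, top_p=0.9)
    print(tokenizer.decode(out[0]))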
hf_rwkv_tokenizer.py
CHANGED
@@ -145,7 +145,7 @@ class Rwkv6Tokenizer(PreTrainedTokenizer):
     model_input_names = ["input_ids", "attention_mask"]
 
     def __init__(
-        self, vocab_file, bos_token="
+        self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
     ):
         if not os.path.isfile(vocab_file):
             raise ValueError(
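After this change the tokenizer no longer needs explicit special-token arguments: bos, eos and unk all default to the RWKV end-of-text marker. A small sketch of the intended effect; the vocab file name is a placeholder for whatever vocabulary file ships with this repository:

    from hf_rwkv_tokenizer import Rwkv6Tokenizer

    # Placeholder file name; point this at the vocabulary bundled with the repo.
    tok = Rwkv6Tokenizer(vocab_file="rwkv_vocab_v20230424.txt")

    # All three special tokens now default to the end-of-text marker.
    print(tok.bos_token, tok.eos_token, tok.unk_token)
    # <|rwkv_tokenizer_end_of_text|> <|rwkv_tokenizer_end_of_text|> <|rwkv_tokenizer_end_of_text|>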
modeling_rwkv7.py
CHANGED
@@ -317,7 +317,7 @@ class Rwkv7SelfAttention(nn.Module):
         self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)
 
 
-    def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
+    def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True, attention_mask=None):
         # Mix hidden with the previous timestep to produce key, value, receptance
         if hidden.size(1) == 1 and state is not None:
             shifted = state[0][self.layer_id]
@@ -371,6 +371,8 @@ class Rwkv7SelfAttention(nn.Module):
             rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)
 
         xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
+        if attention_mask is not None:
+            xx *= attention_mask.unsqueeze(-1)
         #x = x + ((r * k * self.r_k).view(B,T,H,N).sum(dim=-1, keepdim=True) * v.view(B,T,H,N)).view(B,T,H*N)
         xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
         xx = self.output(xx * g)
@@ -435,11 +437,15 @@ class Rwkv7Block(nn.Module):
         self.attention = Rwkv7SelfAttention(config, layer_id)
         self.feed_forward = Rwkv7FeedForward(config, layer_id)
 
-    def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
-        attention, state, v_first = self.attention(
+    def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True, attention_mask=None):
+        attention, state, v_first = self.attention(
+            self.ln1(hidden) if attention_mask is None else self.ln1(hidden) * attention_mask.unsqueeze(-1),
+            state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode, attention_mask=attention_mask)
         hidden = hidden + attention
 
-        feed_forward, state = self.feed_forward(
+        feed_forward, state = self.feed_forward(
+            self.ln2(hidden) if attention_mask is None else self.ln2(hidden) * attention_mask.unsqueeze(-1),
+            state=state)
         hidden = hidden + feed_forward
 
         outputs = (hidden, state, v_first)
@@ -743,13 +749,15 @@ class Rwkv7Model(Rwkv7PreTrainedModel):
 
         seq_mode = inputs_embeds.shape[1] > 1
         hidden_states = self.pre_ln(inputs_embeds)
+        if attention_mask is not None:
+            hidden_states *= attention_mask.unsqueeze(-1)
         v_first = None
 
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
             hidden_states, state, v_first, attentions = block(
-                hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
+                hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode, attention_mask=attention_mask,
             )
 
             if output_hidden_states:
@@ -759,6 +767,8 @@ class Rwkv7Model(Rwkv7PreTrainedModel):
                 all_self_attentions = all_self_attentions + (attentions,)
 
         hidden_states = self.ln_out(hidden_states)
+        if attention_mask is not None:
+            hidden_states *= attention_mask.unsqueeze(-1)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -846,6 +856,7 @@ class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attention_mask=attention_mask,
         )
         hidden_states = outputs[0]
 
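All of the masking added above follows one pattern: a (batch, seq_len) attention_mask of ones and zeros is broadcast over the channel dimension and multiplied into the hidden states at several points, so padded positions are zeroed out. A standalone illustration of that broadcast (not the model code itself; shapes are arbitrary):

    import torch

    B, T, C = 2, 5, 8                        # batch, sequence length, hidden size
    hidden = torch.randn(B, T, C)

    # Left-padded batch: the first sequence starts with two padding positions.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]], dtype=hidden.dtype)

    masked = hidden * attention_mask.unsqueeze(-1)   # (B, T, 1) broadcasts over C
    assert torch.all(masked[0, :2] == 0)             # padded positions are zeroed
    assert torch.equal(masked[1], hidden[1])         # unmasked rows are untouched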
special_tokens_map.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "bos_token": "
-  "eos_token": "
-  "unk_token": "
+  "bos_token": "<|rwkv_tokenizer_end_of_text|>",
+  "eos_token": "<|rwkv_tokenizer_end_of_text|>",
+  "unk_token": "<|rwkv_tokenizer_end_of_text|>"
 }
tokenizer_config.json
CHANGED
@@ -2,7 +2,7 @@
   "add_prefix_space": false,
   "added_tokens_decoder": {
     "0": {
-      "content": "
+      "content": "<|rwkv_tokenizer_end_of_text|>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
@@ -16,11 +16,12 @@
       null
     ]
   },
-  "bos_token": "
+  "bos_token": "<|rwkv_tokenizer_end_of_text|>",
   "clean_up_tokenization_spaces": false,
-  "eos_token": "
+  "eos_token": "<|rwkv_tokenizer_end_of_text|>",
   "model_max_length": 1000000000000000019884624838656,
   "tokenizer_class": "Rwkv6Tokenizer",
-  "unk_token": "
-  "use_fast": false
+  "unk_token": "<|rwkv_tokenizer_end_of_text|>",
+  "use_fast": false,
+  "chat_template": "{{ '<|rwkv_tokenizer_end_of_text|>' }}{% for message in messages %}{% if message['role'] == 'user' %}{{'User: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'system' %}{{'System: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'Assistant: ' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
 }
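The new chat_template prepends the end-of-text marker and renders messages as "System:", "User:" and "Assistant:" turns separated by blank lines, with a trailing "Assistant:" when add_generation_prompt is set. The same result comes from tokenizer.apply_chat_template once the updated config is loaded; the sketch below renders the template string from the diff directly with jinja2 so it can be checked in isolation:

    from jinja2 import Template

    # Template string copied from the diff above.
    chat_template = (
        "{{ '<|rwkv_tokenizer_end_of_text|>' }}"
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{'User: ' + message['content'] + '\n\n'}}"
        "{% elif message['role'] == 'system' %}{{'System: ' + message['content'] + '\n\n'}}"
        "{% elif message['role'] == 'assistant' %}{{'Assistant: ' + message['content'] + '\n\n'}}"
        "{% endif %}{% endfor %}"
        "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is RWKV?"},
    ]
    prompt = Template(chat_template).render(messages=messages, add_generation_prompt=True)
    print(prompt)
    # <|rwkv_tokenizer_end_of_text|>System: You are a helpful assistant.
    #
    # User: What is RWKV?
    #
    # Assistant: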