Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- ko
|
5 |
+
tags:
|
6 |
+
- rwkv
|
7 |
+
- KoRWKV
|
8 |
+
---
|
9 |
+
|
10 |
+
# KoRWKV
|
11 |
+
|
12 |
+
[RWKV-Runner](https://github.com/josStorer/RWKV-Runner)에서 사용하기 위해 변환한 모델 파일
|
13 |
+
|
14 |
+
- [beomi/KoAlpaca-KoRWKV-6B](https://huggingface.co/beomi/KoAlpaca-KoRWKV-6B)
|
15 |
+
- [beomi/KoRWKV-6B](https://huggingface.co/beomi/KoRWKV-6B)
|
16 |
+
|
17 |
+
```py
|
18 |
+
import re
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from transformers import RwkvForCausalLM
|
23 |
+
|
24 |
+
def convert_state_dict(state_dict):
    """Rename short RWKV parameter keys to their long, ``rwkv.``-prefixed form.

    The mapping is edited in place: each entry is popped and stored again
    under its new key, and the very same object is returned. Values are never
    inspected or modified; only the key strings change.

    Renames applied, in this order:
      * keys starting with ``emb.``: ``emb.`` -> ``embeddings.``
      * keys starting with ``blocks.0.ln0``: ``blocks.0.ln0`` -> ``blocks.0.pre_ln``
      * ``blocks.<n>.att`` -> ``blocks.<n>.attention``
      * ``blocks.<n>.ffn`` -> ``blocks.<n>.feed_forward``
      * keys ending in ``.time_mix_k`` / ``.time_mix_v`` / ``.time_mix_r``:
        suffix -> ``.time_mix_key`` / ``.time_mix_value`` / ``.time_mix_receptance``
      * every key except ``head.weight`` gains the ``rwkv.`` prefix

    ``revert_state_dict`` performs the opposite renaming.

    Args:
        state_dict: Mutable mapping from parameter name to weight.

    Returns:
        The same ``state_dict`` object, now keyed by the converted names.
    """
    block_renames = (
        (r"blocks\.(\d+)\.att", r"blocks.\1.attention"),
        (r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward"),
    )
    time_mix_renames = (
        (".time_mix_k", ".time_mix_key"),
        (".time_mix_v", ".time_mix_value"),
        (".time_mix_r", ".time_mix_receptance"),
    )

    # Walk a snapshot of the keys: entries are removed and re-added as we go.
    for old_key in list(state_dict):
        tensor = state_dict.pop(old_key)
        key = old_key

        if key.startswith("emb."):
            key = key.replace("emb.", "embeddings.")
        # Only the ln0 entries under block 0 are renamed to pre_ln.
        if key.startswith("blocks.0.ln0"):
            key = key.replace("blocks.0.ln0", "blocks.0.pre_ln")

        for pattern, replacement in block_renames:
            key = re.sub(pattern, replacement, key)

        # Pure renames: the weights themselves are not reshaped here.
        for short_suffix, long_suffix in time_mix_renames:
            if key.endswith(short_suffix):
                key = key.replace(short_suffix, long_suffix)

        # head.weight is the one key that stays outside the "rwkv." namespace.
        if key != "head.weight":
            key = "rwkv." + key

        state_dict[key] = tensor

    return state_dict
|
53 |
+
|
54 |
+
|
55 |
+
def revert_state_dict(state_dict):
    """Rename long, ``rwkv.``-prefixed parameter keys back to the short form.

    This is the reverse of ``convert_state_dict``. The mapping is edited in
    place: each entry is popped and stored again under its new key, and the
    very same object is returned. Values are passed through untouched.

    Renames applied, in this order:
      * a leading ``rwkv.`` prefix is stripped (once)
      * keys starting with ``embeddings.``: ``embeddings.`` -> ``emb.``
      * keys starting with ``blocks.0.pre_ln``: ``blocks.0.pre_ln`` -> ``blocks.0.ln0``
      * ``blocks.<n>.attention`` -> ``blocks.<n>.att``
      * ``blocks.<n>.feed_forward`` -> ``blocks.<n>.ffn``
      * keys ending in ``.time_mix_key`` / ``.time_mix_value`` /
        ``.time_mix_receptance``: suffix -> ``.time_mix_k`` / ``.time_mix_v`` /
        ``.time_mix_r``

    Args:
        state_dict: Mutable mapping from parameter name to weight.

    Returns:
        The same ``state_dict`` object, now keyed by the short names.
    """
    prefix_renames = (
        ("embeddings.", "emb."),
        ("blocks.0.pre_ln", "blocks.0.ln0"),
    )
    block_renames = (
        (r"blocks\.(\d+)\.attention", r"blocks.\1.att"),
        (r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn"),
    )
    suffix_renames = (
        (".time_mix_key", ".time_mix_k"),
        (".time_mix_value", ".time_mix_v"),
        (".time_mix_receptance", ".time_mix_r"),
    )

    # Walk a snapshot of the keys: entries are removed and re-added as we go.
    for long_key in list(state_dict):
        tensor = state_dict.pop(long_key)
        key = long_key.removeprefix("rwkv.")

        for long_prefix, short_prefix in prefix_renames:
            if key.startswith(long_prefix):
                key = key.replace(long_prefix, short_prefix)

        for pattern, replacement in block_renames:
            key = re.sub(pattern, replacement, key)

        # Pure renames: the weights themselves are not reshaped here.
        for long_suffix, short_suffix in suffix_renames:
            if key.endswith(long_suffix):
                key = key.replace(long_suffix, short_suffix)

        state_dict[key] = tensor

    return state_dict
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
    # Pick the checkpoint to export; swap the two assignments to export the
    # other repo instead.
    # repo = "beomi/KoRWKV-6B"
    repo = "beomi/KoAlpaca-KoRWKV-6B"

    # Load the model, requesting bfloat16 weights.
    model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)

    # Rename the parameter keys back to the short form before saving.
    converted = revert_state_dict(model.state_dict())

    # Output file is named after the last path component of the repo id,
    # e.g. "KoAlpaca-KoRWKV-6B.bf16.pth", and written to the working directory.
    name = repo.rsplit("/", 1)[-1] + ".bf16.pth"
    torch.save(converted, name)
|
95 |
+
```
|