---
license: apache-2.0
language:
- ko
tags:
- rwkv
- KoRWKV
---

# KoRWKV

Model files converted for use with [RWKV-Runner](https://github.com/josStorer/RWKV-Runner). Source models:

- [beomi/KoAlpaca-KoRWKV-6B](https://huggingface.co/beomi/KoAlpaca-KoRWKV-6B)
- [beomi/KoRWKV-6B](https://huggingface.co/beomi/KoRWKV-6B)
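
The script below performs the conversion. `convert_state_dict` renames keys from the original RWKV checkpoint layout to the one used by `transformers`' `RwkvForCausalLM`, and `revert_state_dict` is its inverse; the `__main__` block loads a Hugging Face checkpoint, reverts the key names, and saves the result as a bfloat16 `.pth` file in the original RWKV layout that RWKV-Runner expects.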

```py
import re

import torch

from transformers import RwkvForCausalLM

def convert_state_dict(state_dict):
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # emb -> embeddings
        if name.startswith("emb."):
            name = name.replace("emb.", "embeddings.")
        # ln0 -> pre_ln (only present in block 0)
        if name.startswith("blocks.0.ln0"):
            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
        # att -> attention
        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
        # ffn -> feed_forward
        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
        # time_mix_k -> time_mix_key
        if name.endswith(".time_mix_k"):
            name = name.replace(".time_mix_k", ".time_mix_key")
        # time_mix_v -> time_mix_value
        if name.endswith(".time_mix_v"):
            name = name.replace(".time_mix_v", ".time_mix_value")
        # time_mix_r -> time_mix_receptance
        if name.endswith(".time_mix_r"):
            name = name.replace(".time_mix_r", ".time_mix_receptance")

        # every key except the LM head gains the "rwkv." prefix in transformers
        if name != "head.weight":
            name = "rwkv." + name

        state_dict[name] = weight
    return state_dict


def revert_state_dict(state_dict):
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # strip the "rwkv." prefix added for transformers (head.weight carries
        # none); str.removeprefix requires Python 3.9+
        name = name.removeprefix("rwkv.")

        # embeddings -> emb
        if name.startswith("embeddings."):
            name = name.replace("embeddings.", "emb.")
        # pre_ln -> ln0 (only present in block 0)
        if name.startswith("blocks.0.pre_ln"):
            name = name.replace("blocks.0.pre_ln", "blocks.0.ln0")
        # attention -> att
        name = re.sub(r"blocks\.(\d+)\.attention", r"blocks.\1.att", name)
        # feed_forward -> ffn
        name = re.sub(r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn", name)
        # time_mix_key -> time_mix_k
        if name.endswith(".time_mix_key"):
            name = name.replace(".time_mix_key", ".time_mix_k")
        # time_mix_value -> time_mix_v
        if name.endswith(".time_mix_value"):
            name = name.replace(".time_mix_value", ".time_mix_v")
        # time_mix_receptance -> time_mix_r
        if name.endswith(".time_mix_receptance"):
            name = name.replace(".time_mix_receptance", ".time_mix_r")

        state_dict[name] = weight
    return state_dict


if __name__ == "__main__":
    # repo = "beomi/KoRWKV-6B"
    repo = "beomi/KoAlpaca-KoRWKV-6B"
    model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)

    # revert the transformers key names back to the original RWKV layout
    state_dict = model.state_dict()
    reverted = revert_state_dict(state_dict)
    out_name = repo.split("/")[-1] + ".bf16.pth"

    torch.save(reverted, out_name)
```
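
As a quick sanity check, the two functions should round-trip cleanly: reverting a `transformers` state dict and converting it back should reproduce the original key names exactly. A minimal sketch, assuming the functions and `model` above are in scope:

```py
# Round-trip check: revert to the original RWKV key layout, convert back,
# and verify the transformers key names are reproduced exactly.
original_keys = set(model.state_dict().keys())
round_tripped = convert_state_dict(revert_state_dict(model.state_dict()))
assert set(round_tripped.keys()) == original_keys
```

The saved `.pth` can then be loaded by RWKV-Runner, or directly with the `rwkv` pip package it builds on. A hypothetical sketch (the file name and strategy string are illustrative; the `rwkv` package conventionally takes the checkpoint path without the `.pth` extension):

```py
from rwkv.model import RWKV

# load the reverted checkpoint; pass the path without the ".pth" extension
model = RWKV(model="KoAlpaca-KoRWKV-6B.bf16", strategy="cpu fp32")
```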