File size: 7,049 Bytes
85e407f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cc4a47
85e407f
 
 
 
 
 
 
 
93c15cc
 
 
 
85e407f
 
 
 
 
 
 
 
 
 
 
 
c7adb99
85e407f
c7adb99
85e407f
 
 
 
 
 
 
 
 
 
c7adb99
c29ddfe
c7adb99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85e407f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#!/usr/bin/env python3
import json, lzma, glob, sys, os, re, subprocess
from pprint import pprint

import torch, sys
import transformers

model_path = "e3.0"
print(f"Loading {model_path} ...")

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_path, 
    device_map = "auto",
    torch_dtype = torch.bfloat16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(".")

from qwen_vocab import old2new, new2old
STOP_WORDS = "<|im_end|> <|endoftext|>".split()


def map_tids(map_dict, tids):
    return [ map_dict[x] for x in tids if x in map_dict ]


class KeywordsStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self, str):
        self.keyword_ids = tokenizer.encode(str)
        self.keyword_ids = map_tids(old2new, self.keyword_ids)
        self.keyword_len = len(self.keyword_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_token_ids = input_ids[0][-self.keyword_len:]
        return last_token_ids.tolist() == self.keyword_ids

stop_criteria_list = transformers.StoppingCriteriaList(
    [ KeywordsStoppingCriteria(x) for x in STOP_WORDS ]
)


def chat(q, temperature = 0.5):
    prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant"
    old_tids = tokenizer.encode(prompt)

    new_tids = map_tids(old2new, old_tids)
    new_old_tids = map_tids(new2old, new_tids)

    new_prompt = tokenizer.decode(new_old_tids)

    # if new_old_tids != old_tids:
    #     print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!")
    #     print(f"!!! old prompt: {prompt}")
    #     print(f"!!! new prompt: {new_prompt}")

    inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)

    assert inputs["input_ids"][0].tolist() == new_old_tids

    for i, x in enumerate(new_tids):
        inputs["input_ids"][0][i] = x

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024*4,
            temperature=0.1,
            top_p=1.0, top_k=30, do_sample=True,
            repetition_penalty=1.3,
            stopping_criteria=stop_criteria_list,
            pad_token_id=tokenizer.pad_token_id,
        )

    answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] # bỏ đi prompt tokens
    old_tids = map_tids(new2old, answer_tids.tolist())
    return tokenizer.decode(old_tids).split("<|im_end|>")[0].strip()


envi = """
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:

Ví dụ 1:
<|en|> Most languages have been developed using the same alphabet because of the popularity and prevalence of the latin-based English Alphabet. This alphabet is estimated to be used by around 2 billion people, and is used by many European, romance, African and Vietnamese languages.
<|vi|> Hầu hết các ngôn ngữ được phát triển sử dụng cùng một bảng chữ cái do sự phổ biến và thịnh hành của bảng chữ cái tiếng Anh dựa trên hệ Latin. Bảng chữ cái này ước tính được khoảng 2 tỷ người sử dụng[4], và được dùng trong nhiều ngôn ngữ châu Âu, ngôn ngữ lãng mạn, châu Phi và tiếng Việt.

Ví dụ 2:
<|en|> Do you have any fun expressions in your language to say you forget something? Share them in the comments below!
<|vi|> Bạn có câu nói vui nào trong ngôn ngữ của mình để diễn tả việc quên điều gì đó không? Hãy chia sẻ trong phần bình luận bên dưới!

Ví dụ 3:
<|en|> What is the scientific explanation for making us feel "cuteness" when we see something cute?
<|vi|> Giải thích khoa học về việc tại sao chúng ta cảm thấy "dễ thương" khi nhìn thấy thứ gì đó dễ thương là gì?

Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
<|en|> {english}
<|vi|>
""".strip()


junks = """
Câu trả lời của tôi:
sang tiếng Việt:
sang tiếng Việt là:
dịch tiếng Việt:
dịch tiếng Việt là:
tiếng Việt như sau:
sang tiếng Việt sẽ là:
tiếng Việt của đoạn văn:
tiếng Việt của câu hỏi là:
tiếng Việt của câu trên là:
tiếng Việt của đoạn văn là:
tiếng Việt của đoạn văn trên:
tiếng Việt của đoạn văn như sau:
tiếng Việt của đoạn văn trên là:
dịch đoạn văn sau sang tiếng Việt:
tiếng Việt của đoạn văn bạn yêu cầu:
Bây giờ đến lượt bạn:
dịch sang tiếng Việt là 
<|en|>
<|vi|>
""".strip().split("\n")

# print(junks)

def trans(prompt, temperinit = 0.2):
    print("\n- - - - - -\n")
    print(prompt, "\n==>\n" )

    res = trans_(prompt, temperinit)

    print(res, flush = True)
    return res


def trans_(prompt, temperinit = 0.2):

    if not isinstance(prompt, str):
        return prompt

    if len(prompt) < 8:
        return prompt 

    trials = max_trials = 3
    temperature = temperinit
    temperdelta = 0.2

    while trials > 0:
        trials -= 1
        n = max_trials - trials

        if n > 1:
            temperature += temperdelta
            print(f"\033[91m{prompt}\033[0m => {x}") # Red then reset
            print(f"\033[33mThử lại lần {n}\033[0m") # Yellow then reset

        x = trans__(prompt, temperature = temperature).strip()

        if x is not None and len(x) > 0:

            for j in junks: # Loại bỏ những header thừa
                x = x.split(j.strip())[-1].strip()

            pp = prompt.lower()
            if "tiếng việt" in pp or "vietnamese" in pp:
                return x

            xx = x.lower()
            if  "tiếng việt" not in xx:
                return x


def trans__(prompt, temperature = 0.0):
    # print("\n- - - - - -\n")
    # print(prompt, "\n==>\n")

    prompt = envi.format(english = prompt)
    res = chat(prompt, temperature = temperature)

    # print(res)
    return res


# infile = args.input
infile = sys.argv[1]
outfile = infile.replace(".jsonl.xz", "__vi.jsonl")


if os.path.exists(outfile):
    sources = [ json.loads(line)['source'] for line in open(outfile, "rt") ]
else:
    sources = []

print(len(sources), sources[-1] if len(sources) > 0 else None)



for idx, line in enumerate(lzma.open(infile, "rt")):

    source = f"{infile}:{idx}"
    if source in sources: continue
    print(source)

    data = json.loads(line)

    data["query"] = trans(data['query'])
    if data["query"] is None: continue

    for idx, x in enumerate( data["pos"] ):
        data['pos'][idx] = trans(x)
        if data['pos'][idx] is None: break

    if data['pos'][idx] is None: continue


    for idx, x in enumerate( data["neg"] ):
        data['neg'][idx] = trans(x)
        if data['neg'][idx] is None: break

    if data['neg'][idx] is None: continue

    with open(outfile, "at") as f:
        data["source"] = source
        f.write(json.dumps(data, ensure_ascii = False) + "\n")