File size: 7,049 Bytes
85e407f 2cc4a47 85e407f 93c15cc 85e407f c7adb99 85e407f c7adb99 85e407f c7adb99 c29ddfe c7adb99 85e407f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
#!/usr/bin/env python3
import json, lzma, glob, sys, os, re, subprocess
from pprint import pprint
import torch, sys
import transformers
model_path = "e3.0"
print(f"Loading {model_path} ...")
model = transformers.AutoModelForCausalLM.from_pretrained(
model_path,
device_map = "auto",
torch_dtype = torch.bfloat16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(".")
from qwen_vocab import old2new, new2old
STOP_WORDS = "<|im_end|> <|endoftext|>".split()
def map_tids(map_dict, tids):
return [ map_dict[x] for x in tids if x in map_dict ]
class KeywordsStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, str):
self.keyword_ids = tokenizer.encode(str)
self.keyword_ids = map_tids(old2new, self.keyword_ids)
self.keyword_len = len(self.keyword_ids)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_token_ids = input_ids[0][-self.keyword_len:]
return last_token_ids.tolist() == self.keyword_ids
stop_criteria_list = transformers.StoppingCriteriaList(
[ KeywordsStoppingCriteria(x) for x in STOP_WORDS ]
)
def chat(q, temperature = 0.5):
prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant"
old_tids = tokenizer.encode(prompt)
new_tids = map_tids(old2new, old_tids)
new_old_tids = map_tids(new2old, new_tids)
new_prompt = tokenizer.decode(new_old_tids)
# if new_old_tids != old_tids:
# print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!")
# print(f"!!! old prompt: {prompt}")
# print(f"!!! new prompt: {new_prompt}")
inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)
assert inputs["input_ids"][0].tolist() == new_old_tids
for i, x in enumerate(new_tids):
inputs["input_ids"][0][i] = x
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=1024*4,
temperature=0.1,
top_p=1.0, top_k=30, do_sample=True,
repetition_penalty=1.3,
stopping_criteria=stop_criteria_list,
pad_token_id=tokenizer.pad_token_id,
)
answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] # bỏ đi prompt tokens
old_tids = map_tids(new2old, answer_tids.tolist())
return tokenizer.decode(old_tids).split("<|im_end|>")[0].strip()
envi = """
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
Ví dụ 1:
<|en|> Most languages have been developed using the same alphabet because of the popularity and prevalence of the latin-based English Alphabet. This alphabet is estimated to be used by around 2 billion people, and is used by many European, romance, African and Vietnamese languages.
<|vi|> Hầu hết các ngôn ngữ được phát triển sử dụng cùng một bảng chữ cái do sự phổ biến và thịnh hành của bảng chữ cái tiếng Anh dựa trên hệ Latin. Bảng chữ cái này ước tính được khoảng 2 tỷ người sử dụng[4], và được dùng trong nhiều ngôn ngữ châu Âu, ngôn ngữ lãng mạn, châu Phi và tiếng Việt.
Ví dụ 2:
<|en|> Do you have any fun expressions in your language to say you forget something? Share them in the comments below!
<|vi|> Bạn có câu nói vui nào trong ngôn ngữ của mình để diễn tả việc quên điều gì đó không? Hãy chia sẻ trong phần bình luận bên dưới!
Ví dụ 3:
<|en|> What is the scientific explanation for making us feel "cuteness" when we see something cute?
<|vi|> Giải thích khoa học về việc tại sao chúng ta cảm thấy "dễ thương" khi nhìn thấy thứ gì đó dễ thương là gì?
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
<|en|> {english}
<|vi|>
""".strip()
junks = """
Câu trả lời của tôi:
sang tiếng Việt:
sang tiếng Việt là:
dịch tiếng Việt:
dịch tiếng Việt là:
tiếng Việt như sau:
sang tiếng Việt sẽ là:
tiếng Việt của đoạn văn:
tiếng Việt của câu hỏi là:
tiếng Việt của câu trên là:
tiếng Việt của đoạn văn là:
tiếng Việt của đoạn văn trên:
tiếng Việt của đoạn văn như sau:
tiếng Việt của đoạn văn trên là:
dịch đoạn văn sau sang tiếng Việt:
tiếng Việt của đoạn văn bạn yêu cầu:
Bây giờ đến lượt bạn:
dịch sang tiếng Việt là
<|en|>
<|vi|>
""".strip().split("\n")
# print(junks)
def trans(prompt, temperinit = 0.2):
print("\n- - - - - -\n")
print(prompt, "\n==>\n" )
res = trans_(prompt, temperinit)
print(res, flush = True)
return res
def trans_(prompt, temperinit = 0.2):
if not isinstance(prompt, str):
return prompt
if len(prompt) < 8:
return prompt
trials = max_trials = 3
temperature = temperinit
temperdelta = 0.2
while trials > 0:
trials -= 1
n = max_trials - trials
if n > 1:
temperature += temperdelta
print(f"\033[91m{prompt}\033[0m => {x}") # Red then reset
print(f"\033[33mThử lại lần {n}\033[0m") # Yellow then reset
x = trans__(prompt, temperature = temperature).strip()
if x is not None and len(x) > 0:
for j in junks: # Loại bỏ những header thừa
x = x.split(j.strip())[-1].strip()
pp = prompt.lower()
if "tiếng việt" in pp or "vietnamese" in pp:
return x
xx = x.lower()
if "tiếng việt" not in xx:
return x
def trans__(prompt, temperature = 0.0):
# print("\n- - - - - -\n")
# print(prompt, "\n==>\n")
prompt = envi.format(english = prompt)
res = chat(prompt, temperature = temperature)
# print(res)
return res
# infile = args.input
infile = sys.argv[1]
outfile = infile.replace(".jsonl.xz", "__vi.jsonl")
if os.path.exists(outfile):
sources = [ json.loads(line)['source'] for line in open(outfile, "rt") ]
else:
sources = []
print(len(sources), sources[-1] if len(sources) > 0 else None)
for idx, line in enumerate(lzma.open(infile, "rt")):
source = f"{infile}:{idx}"
if source in sources: continue
print(source)
data = json.loads(line)
data["query"] = trans(data['query'])
if data["query"] is None: continue
for idx, x in enumerate( data["pos"] ):
data['pos'][idx] = trans(x)
if data['pos'][idx] is None: break
if data['pos'][idx] is None: continue
for idx, x in enumerate( data["neg"] ):
data['neg'][idx] = trans(x)
if data['neg'][idx] is None: break
if data['neg'][idx] is None: continue
with open(outfile, "at") as f:
data["source"] = source
f.write(json.dumps(data, ensure_ascii = False) + "\n")
|