import gradio as gr
import os, gc, copy, torch
from huggingface_hub import hf_hub_download
from pynvml import *

HAS_GPU = False

ctx_limit = 2000
title = "RWKV-5-World-1B5-v2-Translator"
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"
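
# Probe for a CUDA-capable GPU via NVML; fall back to CPU when none is found.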
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as error:
    print(error)

os.environ["RWKV_JIT_ON"] = '1'
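
# Inference strategy: bf16 on CPU by default, bf16 on CUDA when a GPU exists.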
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0'

if HAS_GPU:
    os.environ["RWKV_CUDA_ON"] = '1'
    MODEL_STRAT = "cuda bf16"
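
# Download the checkpoint from the Hugging Face Hub and load it. The rwkv
# package reads RWKV_JIT_ON / RWKV_CUDA_ON at import time, so the import
# must come after the environment variables are set above.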
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)
from rwkv.utils import PIPELINE
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
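
# Clone tensors, copy lists element-wise, and deep-copy everything else.
# Used to duplicate the cached RWKV state for each request without
# mutating the shared PREFIX_STATE.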
def universal_deepcopy(obj):
    if hasattr(obj, 'clone'):
        return obj.clone()
    elif isinstance(obj, list):
        return [universal_deepcopy(item) for item in obj]
    else:
        return copy.deepcopy(obj)
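
# Debugging helper: recursively print the layout of a nested list/dict state.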
def inspect_structure(obj, depth=0):
    indent = " " * depth
    obj_type = type(obj).__name__

    if isinstance(obj, list):
        print(f"{indent}List (length {len(obj)}):")
        for item in obj:
            inspect_structure(item, depth + 1)
    elif isinstance(obj, dict):
        print(f"{indent}Dict (length {len(obj)}):")
        for key, value in obj.items():
            print(f"{indent} Key: {key}")
            inspect_structure(value, depth + 1)
    else:
        print(f"{indent}{obj_type}")
def precompute_state(text):
    state = None
    text_encoded = pipeline.encode(text)
    _, state = model.forward(text_encoded, state)
    return state

INSTRUCT_PREFIX = '''
You are a translator bot that can translate text to any language.
Respond only with the translated text, without additional comments.

## From English:
It is not enough to know, we must also apply; it is not enough to will, we must also do.
## To Polish:
Nie wystarczy wiedzieć, trzeba także zastosować; nie wystarczy chcieć, trzeba też działać.

## From Spanish:
La muerte no nos concierne, porque mientras existamos, la muerte no está aquí. Y cuando llega, ya no existimos.
## To English:
Death does not concern us, because as long as we exist, death is not here. And when it does come, we no longer exist.

'''

PREFIX_STATE = precompute_state(INSTRUCT_PREFIX)
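
# Translate `text` between the two selected languages, streaming partial
# output. Decoding uses top-p sampling with presence/frequency penalties,
# starting from the precomputed instruction-prefix state.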
def translate(
    text, source_language, target_language,
    inState=PREFIX_STATE,
    temperature=0.2,
    top_p=0.5,
    presencePenalty=0.1,
    countPenalty=0.1,
):
    prompt = f"## From {source_language}:\n{text}\n\n## To {target_language}:\n"
    ctx = prompt.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}

    alpha_frequency = countPenalty
    alpha_presence = presencePenalty

    state = None
    if inState is not None:
        state = universal_deepcopy(inState)

    gc.collect()
    if HAS_GPU:
        torch.cuda.empty_cache()

    for i in range(ctx_limit):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        for n in occurrence:
            out[n] -= (alpha_presence + occurrence[n] * alpha_frequency)
        token = pipeline.sample_logits(out, temperature=temperature, top_p=top_p)
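
        # Token 0 is the end-of-text token in the RWKV world vocabulary.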
        if token == 0:
            break

        all_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= 0.996
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
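
        # Flush decoded text only when it is valid UTF-8; a '\ufffd' means a
        # multi-byte character is still incomplete, so keep accumulating
        # tokens instead of aborting the stream.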
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            out_last = i + 1
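
        # Stop when the model starts a new prompt section or chat turn,
        # trimming the stop marker from the emitted text.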
        for stop in ("\n:", f"{source_language}:", f"{target_language}:",
                     "\nHuman:", "\nAssistant:", "\n#"):
            if stop in out_str:
                out_str = out_str.split(stop)[0]
                yield out_str.strip()
                return
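
        # Stream the partial translation to the UI after each flushed chunk.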
        yield out_str.strip()

    del out
    del state
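
    # Yield once more so the UI ends on the final, fully-flushed text; a
    # plain `return value` inside a generator would never reach Gradio.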
    yield out_str.strip()

LANGUAGES = [
    "English",
    "Chinese",
    "Spanish",
    "Bengali",
    "Hindi",
    "Portuguese",
    "Russian",
    "Japanese",
    "German",
    "Chinese (Wu)",
    "Javanese",
    "Korean",
    "French",
    "Vietnamese",
    "Telugu",
    "Chinese (Yue)",
    "Marathi",
    "Tamil",
    "Turkish",
    "Urdu",
    "Chinese (Min Nan)",
    "Chinese (Jin Yu)",
    "Gujarati",
    "Polish",
    "Arabic (Egyptian Spoken)",
    "Ukrainian",
    "Italian",
    "Chinese (Xiang)",
    "Malayalam",
    "Chinese (Hakka)",
    "Kannada",
    "Oriya",
    "Panjabi (Western)",
    "Panjabi (Eastern)",
    "Sunda",
    "Romanian",
    "Bhojpuri",
    "Azerbaijani (South)",
    "Farsi (Western)",
    "Maithili",
    "Hausa",
    "Arabic (Algerian Spoken)",
    "Burmese",
    "Serbo-Croatian",
    "Chinese (Gan)",
    "Awadhi",
    "Thai",
    "Dutch",
    "Yoruba",
    "Sindhi",
    "Arabic (Moroccan Spoken)",
    "Arabic (Saidi Spoken)",
    "Uzbek, Northern",
    "Malay",
    "Amharic",
    "Indonesian",
    "Igbo",
    "Tagalog",
    "Nepali",
    "Arabic (Sudanese Spoken)",
    "Saraiki",
    "Cebuano",
    "Arabic (North Levantine Spoken)",
    "Thai (Northeastern)",
    "Assamese",
    "Hungarian",
    "Chittagonian",
    "Arabic (Mesopotamian Spoken)",
    "Madura",
    "Sinhala",
    "Haryanvi",
    "Marwari",
    "Czech",
    "Greek",
    "Magahi",
    "Chhattisgarhi",
    "Deccan",
    "Chinese (Min Bei)",
    "Belarusan",
    "Zhuang (Northern)",
    "Arabic (Najdi Spoken)",
    "Pashto (Northern)",
    "Somali",
    "Malagasy",
    "Arabic (Tunisian Spoken)",
    "Rwanda",
    "Zulu",
    "Latin",
    "Bulgarian",
    "Swedish",
    "Lombard",
    "Oromo (West-central)",
    "Pashto (Southern)",
    "Kazakh",
    "Ilocano",
    "Tatar",
    "Fulfulde (Nigerian)",
    "Arabic (Sanaani Spoken)",
    "Uyghur",
    "Haitian Creole French",
    "Azerbaijani, North",
    "Napoletano-calabrese",
    "Khmer (Central)",
    "Farsi (Eastern)",
    "Akan",
    "Hiligaynon",
    "Kurmanji",
    "Shona",
]

EXAMPLES = [
    ["Többen tanulnának a hibáikból, ha nem lennének annyira elfoglalva, hogy tagadják azokat.", "Hungarian", "English"],
    ["La mejor venganza es el éxito masivo.", "Spanish", "English"],
    ["Tout est bien qui finit bien.", "French", "English"],
    ["Lasciate ogne speranza, voi ch'intrate.", "Italian", "English"],
    ["Errare humanum est.", "Latin", "English"],
]
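
# Gradio UI: source textbox, language dropdowns, translate button, and
# clickable example rows.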
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\"><h1>RWKV-5 World v2 - {title}</h1></div>")
    gr.Markdown("The RWKV-5 World v2 1B5 model tailored for translation tasks, running entirely on 8 vCPUs.")

    text = gr.Textbox(lines=5, label="Source Text", placeholder="Enter the text you want to translate...", value=EXAMPLES[0][0])
    source_language = gr.Dropdown(choices=LANGUAGES, label="Source Language", value=EXAMPLES[0][1])
    target_language = gr.Dropdown(choices=LANGUAGES, label="Target Language", value=EXAMPLES[0][2])
    output = gr.Textbox(lines=5, label="Translated Text")

    submit = gr.Button("Translate", variant="primary")

    data = gr.Dataset(components=[text, source_language, target_language], samples=EXAMPLES, label="Example Translations", headers=["Source Text", "Source Language", "Target Language"])

    submit.click(translate, [text, source_language, target_language], [output])
    data.click(lambda x: x, [data], [text, source_language, target_language])
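
# translate() is a generator, so the output textbox streams as tokens arrive;
# generator-backed events require the queue (Gradio 3.x API shown here).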
demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False, debug=True)