from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

from modules.file import ExcelFileWriter
import os

from abc import ABC, abstractmethod
from typing import List
import re


class FilterPipeline:
    """Runs a list of Filters: encode in order, decode in reverse order."""

    def __init__(self, filter_list):
        self._filter_list: List[Filter] = filter_list

    def append(self, filter):
        self._filter_list.append(filter)

    def batch_encoder(self, inputs):
        for filter in self._filter_list:
            inputs = filter.encoder(inputs)
        return inputs

    def batch_decoder(self, inputs):
        # Decode in reverse order so each filter undoes its own encoding.
        for filter in reversed(self._filter_list):
            inputs = filter.decoder(inputs)
        return inputs
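

# Example (sketch): how a FilterPipeline round-trips placeholders around a
# translation step. The filter classes used here are defined below; the
# "translated" list stands in for real model output and is illustrative only.
#
#   pipeline = FilterPipeline([SperSignFilter(), ChevronsFilter()])
#   encoded = pipeline.batch_encoder(["Hello %s, see <b>docs</b>"])
#   # -> ["Hello *, see #docs#"]  (placeholders survive translation verbatim)
#   translated = ["Bonjour *, voir #docs#"]          # pretend model output
#   restored = pipeline.batch_decoder(translated)
#   # -> ["Bonjour %s, voir <b>docs</b>"]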


class Filter(ABC):
    def __init__(self):
        self.name = 'filter'
        self.code = []

    @abstractmethod
    def encoder(self, inputs):
        pass

    @abstractmethod
    def decoder(self, inputs):
        pass


class SpecialTokenFilter(Filter):
    def __init__(self):
        self.name = 'special token filter'
        self.code = []
        # Strings made up only of these characters are pulled out before
        # translation and re-inserted afterwards.
        self.special_tokens = ['!', '！', '-']

    def encoder(self, inputs):
        filtered_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if not all(char in self.special_tokens for char in input_str):
                filtered_inputs.append(input_str)
            else:
                # Remember the original position and text of the removed string.
                self.code.append([i, input_str])
        return filtered_inputs

    def decoder(self, inputs):
        original_inputs = inputs.copy()
        for removed_indice in self.code:
            original_inputs.insert(removed_indice[0], removed_indice[1])
        return original_inputs
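

# Example (sketch, illustrative values): strings that consist only of the
# special tokens are dropped before translation and re-inserted afterwards.
#
#   f = SpecialTokenFilter()
#   f.encoder(["!!!", "Hello", "---"])   # -> ["Hello"], f.code == [[0, "!!!"], [2, "---"]]
#   f.decoder(["Bonjour"])               # -> ["!!!", "Bonjour", "---"]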


class SperSignFilter(Filter):
    def __init__(self):
        self.name = 's percentage sign filter'
        self.code = []

    def encoder(self, inputs):
        encoded_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if '%s' in input_str:
                # Replace the printf-style placeholder with a character the model
                # is unlikely to alter, and remember where we did it.
                encoded_str = input_str.replace('%s', '*')
                self.code.append(i)
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('*', '%s')
        return decoded_inputs


class ParenSParenFilter(Filter):
    def __init__(self):
        self.name = 'Paren s paren filter'
        self.code = []

    def encoder(self, inputs):
        encoded_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if '(s)' in input_str:
                # Protect the English plural marker "(s)" behind a placeholder.
                encoded_str = input_str.replace('(s)', '$')
                self.code.append(i)
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)')
        return decoded_inputs


class ChevronsFilter(Filter):
    def __init__(self):
        self.name = 'chevrons filter'
        self.code = []

    def encoder(self, inputs):
        encoded_inputs = []
        self.code = []
        pattern = re.compile(r'<.*?>')
        for i, input_str in enumerate(inputs):
            if pattern.search(input_str):
                # Replace every <...> tag with '#' and remember the tags so the
                # decoder can put them back in order.
                matches = pattern.findall(input_str)
                encoded_str = pattern.sub('#', input_str)
                self.code.append((i, matches))
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i, matches in self.code:
            for match in matches:
                # Restore tags left to right, one '#' at a time.
                decoded_inputs[i] = decoded_inputs[i].replace('#', match, 1)
        return decoded_inputs


class SimilarFilter(Filter):
    def __init__(self):
        self.name = 'similar filter'
        self.code = []

    def is_similar(self, str1, str2):
        # Two strings are "similar" if they are identical once digits are removed.
        pattern = re.compile(r'\d+')
        return pattern.sub('', str1) == pattern.sub('', str2)

    def encoder(self, inputs):
        encoded_inputs = []
        self.code = []
        i = 0
        while i < len(inputs):
            encoded_inputs.append(inputs[i])
            # Collect the run of consecutive strings that differ only in digits;
            # only the first one is sent to the model.
            similar_strs = [inputs[i]]
            j = i + 1
            while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
                similar_strs.append(inputs[j])
                j += 1
            if len(similar_strs) > 1:
                self.code.append((i, similar_strs))
            i = j
        return encoded_inputs

    def decoder(self, inputs: List):
        # Note: this intentionally mutates `inputs` in place. Re-inserting the
        # skipped strings for earlier groups restores the original indexing, so
        # `inputs[i]` is valid again when later groups are processed.
        decoded_inputs = inputs
        for i, similar_strs in self.code:
            pattern = re.compile(r'\d+')
            for j in range(len(similar_strs)):
                if pattern.search(similar_strs[j]):
                    # Reuse the translated first string, swapping in this entry's number.
                    number = re.findall(r'\d+', similar_strs[j])[0]
                    new_str = pattern.sub(number, inputs[i])
                else:
                    new_str = inputs[i]
                if j > 0:
                    decoded_inputs.insert(i + j, new_str)
        return decoded_inputs
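

# Example (sketch, illustrative values): consecutive strings that differ only in
# their digits are collapsed to the first one before translation, then expanded
# again afterwards with each original number substituted back in.
#
#   f = SimilarFilter()
#   f.encoder(["Chapter 1", "Chapter 2", "Intro"])   # -> ["Chapter 1", "Intro"]
#   f.decoder(["Chapitre 1", "Intro"])               # -> ["Chapitre 1", "Chapitre 2", "Intro"]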


class ChineseFilter:
    # Not a Filter subclass, but it implements the same encoder/decoder
    # interface, so it can be used inside a FilterPipeline.
    def __init__(self, pinyin_lib_file='pinyin.txt'):
        self.name = 'chinese filter'
        self.code = []
        self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)

    def load_pinyin_lib(self, file_path):
        # script_dir is defined at module level below; it is resolved at call time.
        with open(os.path.join(script_dir, file_path), 'r', encoding='utf-8') as f:
            return set(line.strip().lower() for line in f)

    def is_valid_chinese(self, word):
        # Treat a single capitalized word as a romanized Chinese term if it can
        # be segmented entirely into pinyin syllables.
        if len(word.split()) == 1 and word[0].isupper():
            return self.is_pinyin(word.lower())
        return False

    def encoder(self, inputs):
        encoded_inputs = []
        self.code = []
        for i, word in enumerate(inputs):
            if self.is_valid_chinese(word):
                # Keep pinyin words out of the translation and restore them later.
                self.code.append((i, word))
            else:
                encoded_inputs.append(word)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i, word in self.code:
            decoded_inputs.insert(i, word)
        return decoded_inputs

    def is_pinyin(self, string):
        '''
        Judge whether a string is pinyin rather than an English word, by greedily
        matching syllables (longest first) against the pinyin library loaded from
        a text file.
        '''
        string = string.lower()
        stringlen = len(string)
        max_len = 6
        result = []
        n = 0
        while n < stringlen:
            matched = 0
            temp_result = []
            for i in range(max_len, 0, -1):
                s = string[0:i]
                if s in self.pinyin_lib:
                    temp_result.append(string[:i])
                    matched = i
                    break
                if i == 1 and len(temp_result) == 0:
                    # No syllable matched even at length 1: not pinyin.
                    return False
            result.extend(temp_result)
            string = string[matched:]
            n += matched
        return True
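

# Example (sketch, illustrative values): is_pinyin greedily segments a
# lower-cased word into syllables found in pinyin.txt (longest match first, up
# to 6 characters). If every character can be covered this way the word counts
# as pinyin, e.g. "beijing" -> "bei" + "jing" -> True, while a word with an
# unmatched remainder returns False. Whether a given word passes depends
# entirely on the contents of the pinyin library.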


# Resolved relative to this file; ChineseFilter.load_pinyin_lib and the Excel
# checkpointing below rely on these paths.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))


class Model:
    def __init__(self, modelname, selected_lora_model, selected_gpu):
        # Note: selected_lora_model is accepted but not used in this module.
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Look up the index of the target GPU in the GPU info list.

            Args:
                gpu_info (list): list of GPU device names
                target_gpu_name (str): name of the target GPU

            Returns:
                int: index of the target GPU, or -1 if it is not found
            """
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1

        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            # get_gpu_index returns -1 when the name is not found, which would make
            # "cuda:-1" an invalid device string.
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        filter_list = [SpecialTokenFilter(), SperSignFilter(), ParenSParenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
        filter_pipeline = FilterPipeline(filter_list)

        def language_mapping(original_language):
            # Map a human-readable language name to the model's language code.
            d = {
                "Achinese (Arabic script)": "ace_Arab",
                "Achinese (Latin script)": "ace_Latn",
                "Mesopotamian Arabic": "acm_Arab",
                "Ta'izzi-Adeni Arabic": "acq_Arab",
                "Tunisian Arabic": "aeb_Arab",
                "Afrikaans": "afr_Latn",
                "South Levantine Arabic": "ajp_Arab",
                "Akan": "aka_Latn",
                "Amharic": "amh_Ethi",
                "North Levantine Arabic": "apc_Arab",
                "Standard Arabic": "arb_Arab",
                "Najdi Arabic": "ars_Arab",
                "Moroccan Arabic": "ary_Arab",
                "Egyptian Arabic": "arz_Arab",
                "Assamese": "asm_Beng",
                "Asturian": "ast_Latn",
                "Awadhi": "awa_Deva",
                "Central Aymara": "ayr_Latn",
                "South Azerbaijani": "azb_Arab",
                "North Azerbaijani": "azj_Latn",
                "Bashkir": "bak_Cyrl",
                "Bambara": "bam_Latn",
                "Balinese": "ban_Latn",
                "Belarusian": "bel_Cyrl",
                "Bemba": "bem_Latn",
                "Bengali": "ben_Beng",
                "Bhojpuri": "bho_Deva",
                "Banjar (Arabic script)": "bjn_Arab",
                "Banjar (Latin script)": "bjn_Latn",
                "Tibetan": "bod_Tibt",
                "Bosnian": "bos_Latn",
                "Buginese": "bug_Latn",
                "Bulgarian": "bul_Cyrl",
                "Catalan": "cat_Latn",
                "Cebuano": "ceb_Latn",
                "Czech": "ces_Latn",
                "Chokwe": "cjk_Latn",
                "Central Kurdish": "ckb_Arab",
                "Crimean Tatar": "crh_Latn",
                "Welsh": "cym_Latn",
                "Danish": "dan_Latn",
                "German": "deu_Latn",
                "Dinka": "dik_Latn",
                "Jula": "dyu_Latn",
                "Dzongkha": "dzo_Tibt",
                "Greek": "ell_Grek",
                "English": "eng_Latn",
                "Esperanto": "epo_Latn",
                "Estonian": "est_Latn",
                "Basque": "eus_Latn",
                "Ewe": "ewe_Latn",
                "Faroese": "fao_Latn",
                "Persian": "pes_Arab",
                "Fijian": "fij_Latn",
                "Finnish": "fin_Latn",
                "Fon": "fon_Latn",
                "French": "fra_Latn",
                "Friulian": "fur_Latn",
                "Nigerian Fulfulde": "fuv_Latn",
                "Scottish Gaelic": "gla_Latn",
                "Irish": "gle_Latn",
                "Galician": "glg_Latn",
                "Guarani": "grn_Latn",
                "Gujarati": "guj_Gujr",
                "Haitian Creole": "hat_Latn",
                "Hausa": "hau_Latn",
                "Hebrew": "heb_Hebr",
                "Hindi": "hin_Deva",
                "Chhattisgarhi": "hne_Deva",
                "Croatian": "hrv_Latn",
                "Hungarian": "hun_Latn",
                "Armenian": "hye_Armn",
                "Igbo": "ibo_Latn",
                "Iloko": "ilo_Latn",
                "Indonesian": "ind_Latn",
                "Icelandic": "isl_Latn",
                "Italian": "ita_Latn",
                "Javanese": "jav_Latn",
                "Japanese": "jpn_Jpan",
                "Kabyle": "kab_Latn",
                "Kachin": "kac_Latn",
                "Arabic": "ar_AR",
                "Chinese": "zho_Hans",
                "Spanish": "spa_Latn",
                "Dutch": "nld_Latn",
                "Kazakh": "kaz_Cyrl",
                "Korean": "kor_Hang",
                "Lithuanian": "lit_Latn",
                "Malayalam": "mal_Mlym",
                "Marathi": "mar_Deva",
                "Nepali": "ne_NP",
                "Polish": "pol_Latn",
                "Portuguese": "por_Latn",
                "Russian": "rus_Cyrl",
                "Sinhala": "sin_Sinh",
                "Tamil": "tam_Taml",
                "Turkish": "tur_Latn",
                "Ukrainian": "ukr_Cyrl",
                "Urdu": "urd_Arab",
                "Vietnamese": "vie_Latn",
                "Thai": "tha_Thai",
                "Khmer": "khm_Khmr"
            }
            return d[original_language]

        def process_gpu_translate_result(temp_outputs):
            # Flatten the per-batch results accumulated so far and write them to
            # an Excel file as an intermediate checkpoint.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            excel_writer = ExcelFileWriter()
            excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))

        self.tokenizer.src_lang = language_mapping(original_language)
        if self.device_name == "cpu":
            # CPU path: translate everything in a single batch.
            # Note: without truncation=True the max_length argument only produces a
            # warning, so overly long inputs are not actually truncated here.
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # Id of the target-language token, forced as the first generated token.
                target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            # Regroup per input row: one list of {language, translation} per row.
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans['generated_translation'][i],
                    })
                outputs.append(temp)
            return outputs

        else:
            # GPU path: translate in batches, running each batch through the
            # filter pipeline before tokenization and undoing the filters on the
            # decoded output.
            print("length of inputs: ", len(inputs))
            batch_size = min(len(inputs), int(max_batch_size))
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            print("number of batches: ", len(batches))
            temp_outputs = []
            processed_num = 0
            for index, batch in enumerate(batches):
                print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
                print(len(batch))
                print(batch)
                batch = filter_pipeline.batch_encoder(batch)
                print(batch)
                temp = []
                if len(batch) > 0:
                    input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
                    for target_language in target_languages:
                        target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                        generated_tokens = self.model.generate(
                            **input_ids,
                            forced_bos_token_id=target_lang_code,
                        )
                        generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                        print(generated_translation)
                        # Restore placeholders and re-insert strings removed by the filters.
                        generated_translation = filter_pipeline.batch_decoder(generated_translation)
                        print(generated_translation)
                        print(len(generated_translation))
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                    input_ids.to('cpu')
                    del input_ids
                else:
                    # Every string in this batch was removed by the filters, so there is
                    # nothing to translate; decoding the empty batch restores the originals.
                    for target_language in target_languages:
                        generated_translation = filter_pipeline.batch_decoder(batch)
                        print(generated_translation)
                        print(len(generated_translation))
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                temp_outputs.append(temp)
                processed_num += len(batch)
                # Checkpoint to Excel roughly every 1000 processed rows.
                if (index + 1) * batch_size // 1000 - index * batch_size // 1000 == 1:
                    print("Already processed number: ", len(temp_outputs))
                    process_gpu_translate_result(temp_outputs)
            # Regroup per input row: one list of {language, translation} per row.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            return outputs
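

# Minimal usage sketch (not part of the module's API): assumes an NLLB-style
# checkpoint such as "facebook/nllb-200-distilled-600M" and a transformers
# version that still exposes tokenizer.lang_code_to_id, which the code above
# relies on. The model name, inputs and languages below are illustrative only.
if __name__ == "__main__":
    model = Model("facebook/nllb-200-distilled-600M", selected_lora_model=None, selected_gpu="cpu")
    results = model.generate(
        inputs=["Hello world", "Goodbye"],
        original_language="English",
        target_languages=["French"],
        max_batch_size=2,
    )
    print(results)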