from transformers import AutoModelForSeq2SeqLM, AutoTokenizer import torch from modules.file import ExcelFileWriter import os from abc import ABC, abstractmethod from typing import List import re class FilterPipeline(): def __init__(self, filter_list): self._filter_list:List[Filter] = filter_list def append(self, filter): self._filter_list.append(filter) def batch_encoder(self, inputs): for filter in self._filter_list: inputs = filter.encoder(inputs) return inputs def batch_decoder(self, inputs): for filter in reversed(self._filter_list): inputs = filter.decoder(inputs) return inputs class Filter(ABC): def __init__(self): self.name = 'filter' self.code = [] @abstractmethod def encoder(self, inputs): pass @abstractmethod def decoder(self, inputs): pass class SpecialTokenFilter(Filter): def __init__(self): self.name = 'special token filter' self.code = [] self.special_tokens = ['!', '!', '-'] def encoder(self, inputs): filtered_inputs = [] self.code = [] for i, input_str in enumerate(inputs): if not all(char in self.special_tokens for char in input_str): filtered_inputs.append(input_str) else: self.code.append([i, input_str]) return filtered_inputs def decoder(self, inputs): original_inputs = inputs.copy() for removed_indice in self.code: original_inputs.insert(removed_indice[0], removed_indice[1]) return original_inputs class SperSignFilter(Filter): def __init__(self): self.name = 's percentage sign filter' self.code = [] def encoder(self, inputs): encoded_inputs = [] self.code = [] # 清空 self.code for i, input_str in enumerate(inputs): if '%s' in input_str: encoded_str = input_str.replace('%s', '*') self.code.append(i) # 将包含 '%s' 的字符串的索引存储到 self.code 中 else: encoded_str = input_str encoded_inputs.append(encoded_str) return encoded_inputs def decoder(self, inputs): decoded_inputs = inputs.copy() for i in self.code: decoded_inputs[i] = decoded_inputs[i].replace('*', '%s') # 使用 self.code 中的索引还原原始字符串 return decoded_inputs class ParenSParenFilter(Filter): def __init__(self): self.name = 'Paren s paren filter' self.code = [] def encoder(self, inputs): encoded_inputs = [] self.code = [] # 清空 self.code for i, input_str in enumerate(inputs): if '(s)' in input_str: encoded_str = input_str.replace('(s)', '$') self.code.append(i) # 将包含 '(s)' 的字符串的索引存储到 self.code 中 else: encoded_str = input_str encoded_inputs.append(encoded_str) return encoded_inputs def decoder(self, inputs): decoded_inputs = inputs.copy() for i in self.code: decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)') # 使用 self.code 中的索引还原原始字符串 return decoded_inputs class ChevronsFilter(Filter): def __init__(self): self.name = 'chevrons filter' self.code = [] def encoder(self, inputs): encoded_inputs = [] self.code = [] # 清空 self.code pattern = re.compile(r'<.*?>') for i, input_str in enumerate(inputs): if pattern.search(input_str): matches = pattern.findall(input_str) encoded_str = pattern.sub('#', input_str) self.code.append((i, matches)) # 将包含匹配模式的字符串的索引和匹配列表存储到 self.code 中 else: encoded_str = input_str encoded_inputs.append(encoded_str) return encoded_inputs def decoder(self, inputs): decoded_inputs = inputs.copy() for i, matches in self.code: for match in matches: decoded_inputs[i] = decoded_inputs[i].replace('#', match, 1) # 使用 self.code 中的匹配列表依次还原原始字符串 return decoded_inputs class SimilarFilter(Filter): def __init__(self): self.name = 'similar filter' self.code = [] def is_similar(self, str1, str2): # 判断两个字符串是否相似(只有数字上有区别) pattern = re.compile(r'\d+') return pattern.sub('', str1) == pattern.sub('', str2) def encoder(self, inputs): encoded_inputs = [] self.code = [] # 清空 self.code i = 0 while i < len(inputs): encoded_inputs.append(inputs[i]) similar_strs = [inputs[i]] j = i + 1 while j < len(inputs) and self.is_similar(inputs[i], inputs[j]): similar_strs.append(inputs[j]) j += 1 if len(similar_strs) > 1: self.code.append((i, similar_strs)) # 将相似字符串的起始索引和实际字符串列表存储到 self.code 中 i = j return encoded_inputs def decoder(self, inputs:List): decoded_inputs = inputs for i, similar_strs in self.code: pattern = re.compile(r'\d+') for j in range(len(similar_strs)): if pattern.search(similar_strs[j]): number = re.findall(r'\d+', similar_strs[j])[0] # 获取相似字符串的数字部分 new_str = pattern.sub(number, inputs[i]) # 将新字符串的数字部分替换为相似字符串的数字部分 else: new_str = inputs[i] # 如果相似字符串不含数字,直接使用新字符串 if j > 0: decoded_inputs.insert(i+j, new_str) return decoded_inputs class ChineseFilter: def __init__(self, pinyin_lib_file='pinyin.txt'): self.name = 'chinese filter' self.code = [] self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file) def load_pinyin_lib(self, file_path): with open(os.path.join(script_dir,file_path), 'r', encoding='utf-8') as f: return set(line.strip().lower() for line in f) def is_valid_chinese(self, word): # 判断一个单词是否符合要求:只有一个单词构成,并且首字母大写 if len(word.split()) == 1 and word[0].isupper(): # 使用pinyin_or_word函数判断是否是合法的拼音 return self.is_pinyin(word.lower()) return False def encoder(self, inputs): encoded_inputs = [] self.code = [] # 清空 self.code for i, word in enumerate(inputs): if self.is_valid_chinese(word): self.code.append((i, word)) # 将需要过滤的中文单词的索引和拼音存储到 self.code 中 else: encoded_inputs.append(word) return encoded_inputs def decoder(self, inputs): decoded_inputs = inputs.copy() for i, word in self.code: decoded_inputs.insert(i, word) # 根据索引将过滤的中文单词还原到原位置 return decoded_inputs def is_pinyin(self, string): ''' judge a string is a pinyin or a english word. pinyin_Lib comes from a txt file. ''' string = string.lower() stringlen = len(string) max_len = 6 result = [] n = 0 while n < stringlen: matched = 0 temp_result = [] for i in range(max_len, 0, -1): s = string[0:i] if s in self.pinyin_lib: temp_result.append(string[:i]) matched = i break if i == 1 and len(temp_result) == 0: return False result.extend(temp_result) string = string[matched:] n += matched return True script_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir))) class Model(): def __init__(self, modelname, selected_lora_model, selected_gpu): def get_gpu_index(gpu_info, target_gpu_name): """ 从 GPU 信息中获取目标 GPU 的索引 Args: gpu_info (list): 包含 GPU 名称的列表 target_gpu_name (str): 目标 GPU 的名称 Returns: int: 目标 GPU 的索引,如果未找到则返回 -1 """ for i, name in enumerate(gpu_info): if target_gpu_name.lower() in name.lower(): return i return -1 if selected_gpu != "cpu": gpu_count = torch.cuda.device_count() gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)] selected_gpu_index = get_gpu_index(gpu_info, selected_gpu) self.device_name = f"cuda:{selected_gpu_index}" else: self.device_name = "cpu" print("device_name", self.device_name) self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name) self.tokenizer = AutoTokenizer.from_pretrained(modelname) # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device) def generate(self, inputs, original_language, target_languages, max_batch_size): filter_list = [SpecialTokenFilter(), SperSignFilter(), ParenSParenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()] filter_pipeline = FilterPipeline(filter_list) def language_mapping(original_language): d = { "Achinese (Arabic script)": "ace_Arab", "Achinese (Latin script)": "ace_Latn", "Mesopotamian Arabic": "acm_Arab", "Ta'izzi-Adeni Arabic": "acq_Arab", "Tunisian Arabic": "aeb_Arab", "Afrikaans": "afr_Latn", "South Levantine Arabic": "ajp_Arab", "Akan": "aka_Latn", "Amharic": "amh_Ethi", "North Levantine Arabic": "apc_Arab", "Standard Arabic": "arb_Arab", "Najdi Arabic": "ars_Arab", "Moroccan Arabic": "ary_Arab", "Egyptian Arabic": "arz_Arab", "Assamese": "asm_Beng", "Asturian": "ast_Latn", "Awadhi": "awa_Deva", "Central Aymara": "ayr_Latn", "South Azerbaijani": "azb_Arab", "North Azerbaijani": "azj_Latn", "Bashkir": "bak_Cyrl", "Bambara": "bam_Latn", "Balinese": "ban_Latn", "Belarusian": "bel_Cyrl", "Bemba": "bem_Latn", "Bengali": "ben_Beng", "Bhojpuri": "bho_Deva", "Banjar (Arabic script)": "bjn_Arab", "Banjar (Latin script)": "bjn_Latn", "Tibetan": "bod_Tibt", "Bosnian": "bos_Latn", "Buginese": "bug_Latn", "Bulgarian": "bul_Cyrl", "Catalan": "cat_Latn", "Cebuano": "ceb_Latn", "Czech": "ces_Latn", "Chokwe": "cjk_Latn", "Central Kurdish": "ckb_Arab", "Crimean Tatar": "crh_Latn", "Welsh": "cym_Latn", "Danish": "dan_Latn", "German": "deu_Latn", "Dinka": "dik_Latn", "Jula": "dyu_Latn", "Dzongkha": "dzo_Tibt", "Greek": "ell_Grek", "English": "eng_Latn", "Esperanto": "epo_Latn", "Estonian": "est_Latn", "Basque": "eus_Latn", "Ewe": "ewe_Latn", "Faroese": "fao_Latn", "Persian": "pes_Arab", "Fijian": "fij_Latn", "Finnish": "fin_Latn", "Fon": "fon_Latn", "French": "fra_Latn", "Friulian": "fur_Latn", "Nigerian Fulfulde": "fuv_Latn", "Scottish Gaelic": "gla_Latn", "Irish": "gle_Latn", "Galician": "glg_Latn", "Guarani": "grn_Latn", "Gujarati": "guj_Gujr", "Haitian Creole": "hat_Latn", "Hausa": "hau_Latn", "Hebrew": "heb_Hebr", "Hindi": "hin_Deva", "Chhattisgarhi": "hne_Deva", "Croatian": "hrv_Latn", "Hungarian": "hun_Latn", "Armenian": "hye_Armn", "Igbo": "ibo_Latn", "Iloko": "ilo_Latn", "Indonesian": "ind_Latn", "Icelandic": "isl_Latn", "Italian": "ita_Latn", "Javanese": "jav_Latn", "Japanese": "jpn_Jpan", "Kabyle": "kab_Latn", "Kachin": "kac_Latn", "Arabic": "ar_AR", "Chinese": "zho_Hans", "Spanish": "spa_Latn", "Dutch": "nld_Latn", "Kazakh": "kaz_Cyrl", "Korean": "kor_Hang", "Lithuanian": "lit_Latn", "Malayalam": "mal_Mlym", "Marathi": "mar_Deva", "Nepali": "ne_NP", "Polish": "pol_Latn", "Portuguese": "por_Latn", "Russian": "rus_Cyrl", "Sinhala": "sin_Sinh", "Tamil": "tam_Taml", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl", "Urdu": "urd_Arab", "Vietnamese": "vie_Latn", "Thai":"tha_Thai", "Khmer":"khm_Khmr" } return d[original_language] def process_gpu_translate_result(temp_outputs): outputs = [] for temp_output in temp_outputs: length = len(temp_output[0]["generated_translation"]) for i in range(length): temp = [] for trans in temp_output: temp.append({ "target_language": trans["target_language"], "generated_translation": trans['generated_translation'][i], }) outputs.append(temp) excel_writer = ExcelFileWriter() excel_writer.write_text(os.path.join(parent_dir,r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs)) self.tokenizer.src_lang = language_mapping(original_language) if self.device_name == "cpu": # Tokenize input input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name) output = [] for target_language in target_languages: # Get language code for the target language target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)] # Generate translation generated_tokens = self.model.generate( **input_ids, forced_bos_token_id=target_lang_code, max_length=128 ) generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) # Append result to output output.append({ "target_language": target_language, "generated_translation": generated_translation, }) outputs = [] length = len(output[0]["generated_translation"]) for i in range(length): temp = [] for trans in output: temp.append({ "target_language": trans["target_language"], "generated_translation": trans['generated_translation'][i], }) outputs.append(temp) return outputs else: # 最大批量大小 = 可用 GPU 内存字节数 / 4 / (张量大小 + 可训练参数) # max_batch_size = 10 # Ensure batch size is within model limits: print("length of inputs: ",len(inputs)) batch_size = min(len(inputs), int(max_batch_size)) batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)] print("length of batches size: ", len(batches)) temp_outputs = [] processed_num = 0 for index, batch in enumerate(batches): # Tokenize input print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") print(len(batch)) print(batch) batch = filter_pipeline.batch_encoder(batch) print(batch) temp = [] if len(batch) > 0: input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name) for target_language in target_languages: target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)] generated_tokens = self.model.generate( **input_ids, forced_bos_token_id=target_lang_code, ) generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) print(generated_translation) generated_translation = filter_pipeline.batch_decoder(generated_translation) print(generated_translation) print(len(generated_translation)) # Append result to output temp.append({ "target_language": target_language, "generated_translation": generated_translation, }) input_ids.to('cpu') del input_ids else: for target_language in target_languages: generated_translation = filter_pipeline.batch_decoder(batch) print(generated_translation) print(len(generated_translation)) # Append result to output temp.append({ "target_language": target_language, "generated_translation": generated_translation, }) temp_outputs.append(temp) processed_num += len(batch) if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1: print("Already processed number: ", len(temp_outputs)) process_gpu_translate_result(temp_outputs) outputs = [] for temp_output in temp_outputs: length = len(temp_output[0]["generated_translation"]) for i in range(length): temp = [] for trans in temp_output: temp.append({ "target_language": trans["target_language"], "generated_translation": trans['generated_translation'][i], }) outputs.append(temp) return outputs for filter in self._filter_list: inputs = filter.encoder(inputs) return inputs def batch_decoder(self, inputs): for filter in reversed(self._filter_list): inputs = filter.decoder(inputs) return inputs