---
language:
- zh
- en
license: mit
datasets:
- TigerResearch/tigerbot-zhihu-zh-10k
- TigerResearch/tigerbot-book-qa-1k
- TigerResearch/sft_zh
pipeline_tag: text-generation
---

# Chinese Text Generation

## 1 Usage

### 1.1 Initialization

!pip install transformers[torch]

```
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained('Hollway/gpt2_finetune')
model = GPT2LMHeadModel.from_pretrained('Hollway/gpt2_finetune').to(device)
```

### 1.2 Inference

```
def generate(text):
    # Basic text-continuation task: predict what follows the prompt.
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            # GPT-2 has no pad token by default, so fall back to the EOS token.
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(tokens[0], skip_special_tokens=True)

generate("派蒙是应急食品,但是不能吃派蒙,请分析不能吃的原因。")
```

### 1.3 Chatbot

```
def chat(turns=5):
    # Multi-turn dialogue, implemented by concatenating the prompt history each turn.
    chat_history_ids = None
    for step in range(turns):
        query = input(">> 用户:")  # the model expects the Chinese 用户/系统 dialogue format
        new_user_input_ids = tokenizer.encode(
            f"用户: {query}\n\n系统: ", return_tensors='pt').to(device)
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids],
                                  dim=-1) if step > 0 else new_user_input_ids
        base_tokens = bot_input_ids.shape[-1]
        chat_history_ids = model.generate(
            bot_input_ids,
            max_length=base_tokens + 64,  # maximum number of tokens per single reply
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id)
        response = tokenizer.decode(
            chat_history_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True)
        print(f"系统: {response}\n")

chat(turns=5)
```
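Both `generate` and `chat` pass `do_sample=True`, which makes the model draw each next token from its predicted probability distribution instead of always taking the single most likely token (greedy decoding). A minimal self-contained sketch of the difference, using toy logits over a 5-token vocabulary (the values are hypothetical, not from the model):

```python
import torch
import torch.nn.functional as F

# Toy logits for a 5-token vocabulary (hypothetical values, for illustration only).
logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])

# Greedy decoding (do_sample=False): always pick the argmax token.
greedy_token = int(torch.argmax(logits))

# Sampling (do_sample=True): draw from the softmax distribution instead,
# so lower-probability tokens can also be chosen, adding diversity.
probs = F.softmax(logits, dim=-1)
sampled_token = int(torch.multinomial(probs, num_samples=1))

print(greedy_token)            # always 0 for these logits
print(0 <= sampled_token < 5)  # sampled index varies between runs
```

Greedy decoding is deterministic, which tends to produce repetitive chat replies; sampling trades some coherence for variety, which is why both functions here enable it.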
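The multi-turn loop in `chat()` works purely by tensor concatenation: each encoded user turn is appended to the running token history, and the model's reply is recovered by slicing off everything up to the prompt length. The mechanics can be sketched with dummy token ids (the id values below are made up for illustration; no model or tokenizer is needed):

```python
import torch

# Dummy token ids standing in for tokenizer.encode() output (hypothetical values).
chat_history_ids = torch.tensor([[11, 12, 13]])   # conversation so far
new_user_input_ids = torch.tensor([[21, 22]])     # freshly encoded user turn

# Each turn, the new user tokens are appended to the running history,
# exactly as chat() does with torch.cat along the sequence dimension.
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
base_tokens = bot_input_ids.shape[-1]

# Pretend the model generated two reply tokens on top of the prompt.
generated = torch.cat([bot_input_ids, torch.tensor([[31, 32]])], dim=-1)

# The reply alone is recovered by slicing off the prompt, which is what
# chat_history_ids[:, bot_input_ids.shape[-1]:] does in chat().
response_ids = generated[:, base_tokens:]
print(response_ids.tolist())  # [[31, 32]]
```

Because the full history is re-fed to `model.generate` each turn, the prompt grows with every exchange; `max_length=base_tokens + 64` caps only the new tokens added per reply, not the total conversation length.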