Tags: Text Generation, Transformers, PyTorch, Chinese, English, gpt2, text-generation-inference, Inference Endpoints

Chinese text generation

1 Usage

1.1 Initialization

!pip install transformers[torch]

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained('Hollway/gpt2_finetune')
model = GPT2LMHeadModel.from_pretrained('Hollway/gpt2_finetune').to(device)

1.2 Inference: basic generation

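def generate(text):  # basic continuation task: predict the text that follows the prompt
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            # GPT-2 tokenizers usually define no pad token (pad_token_id is None),
            # so reuse the EOS token for padding, as chat() below does
            pad_token_id=tokenizer.eos_token_id,
        )
    # the decoded string contains the prompt followed by the generated continuation
    return tokenizer.decode(tokens[0], skip_special_tokens=True)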
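With do_sample=True, generate() draws each next token at random from the model's predicted distribution instead of always taking the most likely token, which is why repeated calls return different continuations. The pure-Python sketch below illustrates a single sampling step with temperature scaling and top-k filtering. It is a simplified illustration, not the transformers implementation; sample_next_token is a hypothetical helper and the toy logits are made up. In transformers the same knobs are exposed as the temperature and top_k arguments of generate(); check the model's generation config for the defaults actually in effect.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=50, rng=random):
    """Draw one token id from raw logits (hypothetical helper, simplified sketch).

    Steps: scale by temperature, keep the top_k highest-scoring ids,
    renormalize with a softmax, then sample one id.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Top-k filtering: keep only the k ids with the largest scaled logits.
    k = min(top_k, len(scaled))
    top_ids = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Numerically stable softmax over the surviving ids.
    m = max(scaled[i] for i in top_ids)
    weights = [math.exp(scaled[i] - m) for i in top_ids]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Sample one id with probability proportional to its softmax weight.
    return rng.choices(top_ids, weights=probs, k=1)[0]


if __name__ == "__main__":
    rng = random.Random(0)
    toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
    # With top_k=2 only ids 0 and 1 can ever be drawn.
    print([sample_next_token(toy_logits, top_k=2, rng=rng) for _ in range(10)])
```

Lowering the temperature sharpens the distribution toward the most likely token, and top-k truncation discards the low-probability tail, which tends to reduce incoherent output at the cost of some diversity.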

# Prompt: "Paimon is emergency food, but Paimon cannot be eaten. Please analyze why."
generate("派蒙是应急食品,但是不能吃派蒙,请分析不能吃的原因。")

1.3 Chatbot: multi-turn chat

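def chat(turns=5):  # multi-turn dialogue, implemented by concatenating each turn onto the history
    for step in range(turns):
        query = input(">> 用户:")  # "用户" means "User"
        # wrap the query in the dialogue template; "系统" means "System" (the bot)
        new_user_input_ids = tokenizer.encode(
            f"用户: {query}\n\n系统: ", return_tensors='pt').to(device)
        # from the second turn on, prepend the full history (earlier prompts and replies);
        # note that the history is never trimmed, so very long chats can exceed the context window
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

        base_tokens = bot_input_ids.shape[-1]
        chat_history_ids = model.generate(
            bot_input_ids,
            max_length=base_tokens + 64,  # at most 64 new tokens per reply
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id)

        # generate() returns prompt + reply; decode only the newly generated tail
        response = tokenizer.decode(
            chat_history_ids[:, base_tokens:][0],
            skip_special_tokens=True)

        print(f"系统: {response}\n")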

chat(turns=5)
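chat() keeps the whole conversation in chat_history_ids and never trims it, so after enough turns the prompt plus the reply budget (max_length=base_tokens+64 amounts to at most 64 new tokens) can exceed the model's context window. The plain-Python sketch below, working on lists of token ids, shows one common remedy, dropping the oldest tokens first, together with the prompt/reply split that chat() performs by slicing. Both function names are hypothetical, left-truncation is a technique added here rather than part of the original code, and max_context=1024 assumes the fine-tuned model keeps standard GPT-2's 1024 positions (compare model.config.n_positions).

```python
def build_chat_input(history_ids, user_turn_ids, max_context=1024, max_new_tokens=64):
    """Append the new user turn to the history and left-truncate (hypothetical helper).

    Keeps at most max_context - max_new_tokens prompt tokens, so the prompt plus
    the reply budget fits in the context window. The oldest tokens are dropped first.
    """
    budget = max_context - max_new_tokens
    if budget <= 0:
        raise ValueError("max_new_tokens must be smaller than max_context")
    ids = list(history_ids) + list(user_turn_ids)
    return ids[-budget:]  # unchanged if already within budget


def split_reply(output_ids, prompt_len):
    """generate() returns prompt + continuation; the reply is the tail after the prompt."""
    return list(output_ids)[prompt_len:]


if __name__ == "__main__":
    history = list(range(10))  # toy token ids standing in for earlier turns
    prompt = build_chat_input(history, [100, 101], max_context=8, max_new_tokens=3)
    print(prompt)  # → [7, 8, 9, 100, 101]
    output = prompt + [200, 201]  # pretend the model generated two tokens
    print(split_reply(output, len(prompt)))  # → [200, 201]
```

Token-level truncation can cut a turn in half and drop its "用户:" / "系统:" marker; trimming whole turns (for example, keeping the token ids of each turn in a list and removing the oldest entries) keeps the dialogue template intact at the cost of a little extra bookkeeping.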
