from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Chinese GPT-2 model trained on CLUECorpusSmall.
tokenizer = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
model = AutoModelForCausalLM.from_pretrained("uer/gpt2-chinese-cluecorpussmall").to(device)


def generate_text(prompt, length=500):
    """Continue `prompt` with beam search and return the cleaned-up story."""
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(device)
    output_ids = model.generate(
        inputs["input_ids"],
        max_length=length,
        num_beams=2,
        no_repeat_ngram_size=2,
        early_stopping=True,
        pad_token_id=0,
    )[0]
    txt = tokenizer.decode(output_ids)
    # Strip the tokenizer's special-token markers from the decoded text
    # before showing it to the user.
    for marker in ("[CLS]", "[SEP]", "[UNK]", "[PAD]"):
        txt = txt.replace(marker, "")
    return txt
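
# A minimal sanity check (the prompt below is just the app's default opening,
# used here as an illustrative example); uncomment to try the generator
# without launching the Gradio UI:
# print(generate_text("在空中飞翔", length=120))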


# Gradio UI: an input box for the opening of the story, a submit button, and an
# output box for the generated continuation.
with gr.Blocks() as web:
    gr.Markdown("<h1><center>Andrew Lim Chinese stories</center></h1>")
    # Heading: "让人工智能讲故事" = "Let AI tell a story".
    gr.Markdown("""<h2><center>让人工智能讲故事:<br><br>
    <img src="https://images.unsplash.com/photo-1550450339-e7a4787a2074?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1252&q=80"></center></h2>""")
    gr.Markdown("""<center>******</center>""")
    # Label "故事的开始" = "The beginning of the story"; default prompt "在空中飞翔" = "flying in the air".
    input_text = gr.Textbox(label="故事的开始", value="在空中飞翔", lines=6)
    button = gr.Button("Submit")
    # Label "人工智能讲一个故事" = "The AI tells a story".
    output_text = gr.Textbox(lines=6, label="人工智能讲一个故事:")
    button.click(generate_text, inputs=[input_text], outputs=output_text)

web.launch()