# Story-Teller / app.py
# Source: Hugging Face Space file view (author: 3v324v23, commit ae0b701, 1.27 kB).
from gradio import Interface, components
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and causal-LM weights from the same Hugging Face Hub repo.
_MODEL_ID = "raincandy-u/TinyStories-656K"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)
# Define the application's inference function.
def generate_story(input_text):
    """Continue a story from *input_text* using the TinyStories model.

    The text is prefixed with the ``<|start_story|>`` marker, tokenized, and
    continued with sampled decoding (top-k + nucleus sampling).

    Args:
        input_text: Opening words of the story. ``None`` (which a Gradio API
            client may send for an empty textbox) is treated as an empty
            prompt rather than being rendered as the literal text "None".

    Returns:
        The decoded first generated sequence (prompt followed by the sampled
        continuation), with special tokens skipped during decoding.
    """
    # `or ""` keeps None from turning into the prompt "<|start_story|>None".
    prompt = f"<|start_story|>{input_text or ''}"
    encoded_input = tokenizer(prompt, return_tensors="pt")
    output_sequences = model.generate(
        **encoded_input,
        # Pad with EOS; presumably the tokenizer/config has no dedicated pad
        # token for this model — TODO confirm.
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=True,
        top_k=40,
        top_p=0.9,
        temperature=0.65,
    )
    # A single prompt was encoded, so the first sequence is the story.
    return tokenizer.decode(output_sequences[0], skip_special_tokens=True)
# Define the UI components.
input_component = components.Textbox(lines=10)

# Define the Interface: one multi-line text input, text output.
interface = Interface(
    fn=generate_story,
    inputs=input_component,
    outputs="textbox",
    title="TinyStories-656K",
    description="Try it!\nNote: Most of the time the default beginning works well.",
    # Each example is a list holding one value per input component.
    examples=[['Once upon a time, there was a girl '], ['Long time ago, ']],
    theme="gradio/light",
)
interface.launch()