# rugbyxpert-gpt2 / app_1.py
# Commit 255c79c by nico-che: "Rename app.py to app_1.py"
from transformers import AutoTokenizer, pipeline
import gradio as gr
# Checkpoint used for both the tokenizer and the generation pipeline.
model_name = "gpt2-large"

# NOTE(review): gpt2-large is a built-in transformers architecture, so
# trust_remote_code=True is presumably unnecessary (it allows executing code
# from the model repo); kept here to preserve current behaviour.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Padding reuses the end-of-sequence token (GPT-2's tokenizer has no
# dedicated pad token by default).
tokenizer.pad_token = tokenizer.eos_token

# Text-generation pipeline sharing the tokenizer configured above.
generator = pipeline(
    task="text-generation",
    model=model_name,
    tokenizer=tokenizer,
    trust_remote_code=True,
)
def nb_tokens(input):
    """Return the number of token ids the module-level ``tokenizer`` yields for ``input``."""
    encoding = tokenizer(input)
    token_ids = encoding['input_ids']
    return len(token_ids)
def client_generate(input, max_new_tokens=256, stop_sequences=None):
    """Generate a continuation of ``input`` with the module-level ``generator``.

    Args:
        input: Prompt text passed to the text-generation pipeline.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        stop_sequences: Optional list or tuple of strings. The response is cut
            just before the earliest occurrence of any of them; empty strings
            are ignored. ``None`` (the default), an empty sequence, or any
            other type disables truncation.

    Returns:
        A dict with ``'text'`` (the prompt, unchanged) and ``'generated_text'``
        (only the newly generated text, whitespace-stripped; ``''`` when the
        pipeline returns nothing).
    """
    output = generator(
        input,
        max_new_tokens=max_new_tokens,
        # Pad with the EOS id (50256 for GPT-2, the value previously
        # hard-coded); the module also uses EOS as tokenizer.pad_token.
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        # Ask the pipeline for the continuation only, instead of splitting the
        # prompt back out of the full text: that split fails on an empty
        # prompt, when the decoded prompt differs from ``input``, or when the
        # model repeats the prompt.
        return_full_text=False,
    )
    if not output or 'generated_text' not in output[0]:
        return {'text': input, 'generated_text': ''}

    response = output[0]['generated_text'].strip()

    if isinstance(stop_sequences, (list, tuple)):
        # str.find returns -1 when a sequence is absent; using that as a slice
        # bound would silently drop the last character, so skip misses and
        # cut at the earliest sequence actually found.
        found = [response.find(seq) for seq in stop_sequences if seq]
        found = [idx for idx in found if idx != -1]
        if found:
            response = response[:min(found)].rstrip()

    return {'text': input, 'generated_text': response}
def respond(message, chat_history, max_tokens=32):
    """Gradio callback: answer ``message`` and record the exchange.

    Args:
        message: The user's question from the textbox.
        chat_history: List of ``(user_message, bot_message)`` tuples; the new
            pair is appended to it in place.
        max_tokens: Maximum number of new tokens to generate for the reply.

    Returns:
        ``("", chat_history)``: the empty string clears the textbox, and the
        updated history refreshes the chatbot.
    """
    # Stop at the first "." so the model does not go on to write the user's
    # next turn; the period is put back when the reply is stored.
    result = client_generate(
        message,
        max_new_tokens=max_tokens,
        stop_sequences=["."],
    )
    reply_text = result['generated_text']
    chat_history.append((message, f"{reply_text}."))
    return "", chat_history
# Gradio interface: title, chat pane, question box, Submit/Clear buttons and
# example questions, all wired to ``respond``; launched at import time.
with gr.Blocks(
    title='RugbyXpert',
    # Optional theme (https://www.gradio.app/guides/theming-guide):
    # theme='sudeepshouche/minimalist',
) as demo:
    gr.Markdown("\n# RugbyXpert\n")

    # Fixed pane height (originally chosen to fit a notebook cell).
    chatbot = gr.Chatbot(height=310)
    msg = gr.Textbox(label="Pose-moi une question sur le rugby pendant la saison 2022-2023")

    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")

    # NOTE(review): the source indentation was lost; the examples are assumed
    # to sit at Blocks level, below the button row — confirm.
    gr.Examples(
        examples=[
            "Tu peux me donner le 21 de Vannes lors du match les opposant à Aurillac du vendredi 24 février 2023 ?",
            "Tu peux me retrouver le score final du match opposant Soyaux-Angoulême à Grenoble le vendredi 17 mars 2023 ?",
            "Dis-moi le score final du match opposant Vannes à Aurillac le vendredi 24 février 2023 ?",
        ],
        inputs=[msg],
    )

    # Clicking Submit and pressing Enter in the textbox both call ``respond``,
    # which clears the textbox and returns the updated chat history.
    for _listener in (btn.click, msg.submit):
        _listener(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

demo.launch()