nico-che commited on
Commit
82b1566
1 Parent(s): 33d59e2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, pipeline
2
+
3
+ import gradio as gr
4
+
5
+ model_name = "gpt2-large"
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
8
+ trust_remote_code=True
9
+ )
10
+ tokenizer.pad_token = tokenizer.eos_token
11
+ generator = pipeline(task="text-generation",
12
+ model=model_name,
13
+ tokenizer=tokenizer,
14
+ trust_remote_code=True
15
+ )
16
+
17
+ def nb_tokens(input):
18
+ return len(tokenizer(input)['input_ids'])
19
+
20
+ def client_generate(input, max_new_tokens=256, stop_sequences=[]):
21
+ output = generator(
22
+ input,
23
+ max_length=max_new_tokens+nb_tokens(input),
24
+ pad_token_id=50256,
25
+ num_return_sequences=1,
26
+ )
27
+ if len(output)==0 or 'generated_text' not in output[0]:
28
+ return {'text': input, 'generated_text': ''}
29
+ response = output[0]['generated_text'].split(input)[1].strip()
30
+ if type(stop_sequences)==list and len(stop_sequences)>0:
31
+ for seq in stop_sequences:
32
+ response = response[:response.find(seq)]
33
+ return {'text': input, 'generated_text': response}
34
+
35
+ def respond(message, chat_history, model_name=modelname_new, max_tokens=128):
36
+ bot_message = client_generate(reshape_prompt(message, model_name),
37
+ max_new_tokens=max_tokens,#1024,
38
+ stop_sequences=["."], #stop_sequences to not generate the user answer
39
+ )['generated_text']
40
+ chat_history.append((message, f"{bot_message}."))
41
+ return "", chat_history
42
+
43
+ with gr.Blocks(
44
+ title='RugbyXpert',
45
+ # theme='sudeepshouche/minimalist', # https://www.gradio.app/guides/theming-guide
46
+ ) as demo:
47
+ gr.Markdown(
48
+ """
49
+ # RugbyXpert
50
+ """
51
+ )
52
+ chatbot = gr.Chatbot(
53
+ height=500, # just to fit the notebook
54
+ )
55
+ msg = gr.Textbox(label="Pose-moi une question sur le rugby pendant la saison 2022-2023")
56
+ with gr.Row():
57
+ with gr.Column():
58
+ btn = gr.Button("Submit", variant="primary")
59
+ with gr.Column():
60
+ clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")
61
+ gr.Examples([
62
+ "Tu peux me donner le 21 de Vannes lors du match les opposant à Aurillac du vendredi 24 février 2023 ?",
63
+ "Tu peux me retrouver le score final du match opposant Soyaux-Angoulême à Grenoble le vendredi 17 mars 2023 ?",
64
+ "Dis-moi le score final du match opposant Vannes à Aurillac le vendredi 24 février 2023 ?",
65
+ ], [msg])
66
+
67
+ btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
68
+ msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot]) #Press enter to submit
69
+
70
+ demo.launch()
71
+