Akbartus commited on
Commit
3695fb9
·
1 Parent(s): 06f0d67

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ from question_generation import question_generation_sampling
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ g1_tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-squad-QuestionAnswer")
9
+ g1_model = AutoModelForSeq2SeqLM.from_pretrained("potsawee/t5-large-generation-squad-QuestionAnswer")
10
+ g2_tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-race-Distractor")
11
+ g2_model = AutoModelForSeq2SeqLM.from_pretrained("potsawee/t5-large-generation-race-Distractor")
12
+ g1_model.eval()
13
+ g2_model.eval()
14
+ g1_model.to(device)
15
+ g2_model.to(device)
16
+
17
+
18
+ def generate_multiple_choice_question(
19
+ context
20
+ ):
21
+ num_questions = 1
22
+ question_item = question_generation_sampling(
23
+ g1_model, g1_tokenizer,
24
+ g2_model, g2_tokenizer,
25
+ context, num_questions, device
26
+ )[0]
27
+ question = question_item['question']
28
+ options = question_item['options']
29
+ options[0] = f"{options[0]} [ANSWER]"
30
+ random.shuffle(options)
31
+ output_string = f"Question: {question}\n[A] {options[0]}\n[B] {options[1]}\n[C] {options[2]}\n[D] {options[3]}"
32
+ return output_string
33
+
34
+ demo = gr.Interface(
35
+ fn=generate_multiple_choice_question,
36
+ inputs=gr.Textbox(lines=8, placeholder="Context Here..."),
37
+ outputs=gr.Textbox(lines=5, placeholder="Question: \n[A] \n[B] \n[C] \n[D] "),
38
+ title="Multiple-choice Question Generator",
39
+ description="Provide some context (e.g. news article or any passage) in the context box and click **Submit**. The models currently support English only. This demo is a part of MQAG - https://github.com/potsawee/mqag0.",
40
+ allow_flagging='never'
41
+ )
42
+ demo.launch()