Akbartus commited on
Commit
06f0d67
·
1 Parent(s): e93e03d

Create question_generation.py

Browse files
Files changed (1) hide show
  1. question_generation.py +97 -0
question_generation.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+
4
+ @torch.no_grad()
5
+ def question_generation_sampling(
6
+ g1_model,
7
+ g1_tokenizer,
8
+ g2_model,
9
+ g2_tokenizer,
10
+ context,
11
+ num_questions,
12
+ device,
13
+ ):
14
+ qa_input_ids = prepare_qa_input(
15
+ g1_tokenizer,
16
+ context=context,
17
+ device=device,
18
+ )
19
+ max_repeated_sampling = int(num_questions * 1.5) # sometimes generated question+answer is invalid
20
+ num_valid_questions = 0
21
+ questions = []
22
+ for q_ in range(max_repeated_sampling):
23
+ # Stage G.1: question+answer generation
24
+ outputs = g1_model.generate(
25
+ qa_input_ids,
26
+ max_new_tokens=128,
27
+ do_sample=True,
28
+ )
29
+ question_answer = g1_tokenizer.decode(outputs[0], skip_special_tokens=False)
30
+ question_answer = question_answer.replace(g1_tokenizer.pad_token, "").replace(g1_tokenizer.eos_token, "")
31
+ question_answer_split = question_answer.split(g1_tokenizer.sep_token)
32
+ if len(question_answer_split) == 2:
33
+ # valid Question + Annswer output
34
+ num_valid_questions += 1
35
+ else:
36
+ continue
37
+ question = question_answer_split[0].strip()
38
+ answer = question_answer_split[1].strip()
39
+
40
+ # Stage G.2: Distractor Generation
41
+ distractor_input_ids = prepare_distractor_input(
42
+ g2_tokenizer,
43
+ context = context,
44
+ question = question,
45
+ answer = answer,
46
+ device = device,
47
+ separator = g2_tokenizer.sep_token,
48
+ )
49
+ outputs = g2_model.generate(
50
+ distractor_input_ids,
51
+ max_new_tokens=128,
52
+ do_sample=True,
53
+ )
54
+ distractors = g2_tokenizer.decode(outputs[0], skip_special_tokens=False)
55
+ distractors = distractors.replace(g2_tokenizer.pad_token, "").replace(g2_tokenizer.eos_token, "")
56
+ distractors = re.sub("<extra\S+>", g2_tokenizer.sep_token, distractors)
57
+ distractors = [y.strip() for y in distractors.split(g2_tokenizer.sep_token)]
58
+ options = [answer] + distractors
59
+
60
+ while len(options) < 4:
61
+ options.append(options[-1])
62
+
63
+ question_item = {
64
+ 'question': question,
65
+ 'options': options,
66
+ }
67
+ questions.append(question_item)
68
+ if num_valid_questions == num_questions:
69
+ break
70
+ return questions
71
+
72
+
73
+ def prepare_qa_input(t5_tokenizer, context, device):
74
+ """
75
+ input: context
76
+ output: question <sep> answer
77
+ """
78
+ encoding = t5_tokenizer(
79
+ [context],
80
+ return_tensors="pt",
81
+ )
82
+ input_ids = encoding.input_ids.to(device)
83
+ return input_ids
84
+
85
+
86
+ def prepare_distractor_input(t5_tokenizer, context, question, answer, device, separator='<sep>'):
87
+ """
88
+ input: question <sep> answer <sep> article
89
+ output: distractor1 <sep> distractor2 <sep> distractor3
90
+ """
91
+ input_text = question + ' ' + separator + ' ' + answer + ' ' + separator + ' ' + context
92
+ encoding = t5_tokenizer(
93
+ [input_text],
94
+ return_tensors="pt",
95
+ )
96
+ input_ids = encoding.input_ids.to(device)
97
+ return input_ids