add focusing questions
Browse files- handler.py +74 -15
handler.py
CHANGED
@@ -13,9 +13,11 @@ from transformers import BertTokenizer, BertForSequenceClassification
|
|
13 |
|
14 |
transformers.logging.set_verbosity_debug()
|
15 |
|
16 |
-
UPTAKE_MODEL='ddemszky/uptake-model'
|
17 |
-
REASONING_MODEL ='ddemszky/student-reasoning'
|
18 |
-
QUESTION_MODEL ='ddemszky/question-detection'
|
|
|
|
|
19 |
|
20 |
class Utterance:
|
21 |
def __init__(self, speaker, text, uid=None,
|
@@ -31,6 +33,7 @@ class Utterance:
|
|
31 |
self.uptake = None
|
32 |
self.reasoning = None
|
33 |
self.question = None
|
|
|
34 |
|
35 |
def get_clean_text(self, remove_punct=False):
|
36 |
if remove_punct:
|
@@ -50,6 +53,7 @@ class Utterance:
|
|
50 |
'uptake': self.uptake,
|
51 |
'reasoning': self.reasoning,
|
52 |
'question': self.question,
|
|
|
53 |
**self.props
|
54 |
}
|
55 |
|
@@ -58,6 +62,7 @@ class Utterance:
|
|
58 |
f"text='{self.text}', uid={self.uid}," \
|
59 |
f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
|
60 |
|
|
|
61 |
class Transcript:
|
62 |
def __init__(self, **kwargs):
|
63 |
self.utterances = []
|
@@ -90,6 +95,7 @@ class Transcript:
|
|
90 |
def __repr__(self):
|
91 |
return f"Transcript(utterances={self.utterances}, custom_params={self.params})"
|
92 |
|
|
|
93 |
class QuestionModel:
|
94 |
def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
|
95 |
print("Loading models...")
|
@@ -97,10 +103,10 @@ class QuestionModel:
|
|
97 |
self.tokenizer = tokenizer
|
98 |
self.input_builder = input_builder
|
99 |
self.max_length = max_length
|
100 |
-
self.model = MultiHeadModel.from_pretrained(
|
|
|
101 |
self.model.to(self.device)
|
102 |
|
103 |
-
|
104 |
def run_inference(self, transcript):
|
105 |
self.model.eval()
|
106 |
with torch.no_grad():
|
@@ -114,12 +120,14 @@ class QuestionModel:
|
|
114 |
input_str=True)
|
115 |
output = self.get_prediction(instance)
|
116 |
print(output)
|
117 |
-
utt.question = np.argmax(
|
|
|
118 |
|
119 |
def get_prediction(self, instance):
|
120 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
121 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
122 |
-
instance[key] = torch.tensor(
|
|
|
123 |
instance[key].to(self.device)
|
124 |
|
125 |
output = self.model(input_ids=instance["input_ids"],
|
@@ -128,6 +136,7 @@ class QuestionModel:
|
|
128 |
return_pooler_output=False)
|
129 |
return output
|
130 |
|
|
|
131 |
class ReasoningModel:
|
132 |
def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
|
133 |
print("Loading models...")
|
@@ -152,7 +161,8 @@ class ReasoningModel:
|
|
152 |
def get_prediction(self, instance):
|
153 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
154 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
155 |
-
instance[key] = torch.tensor(
|
|
|
156 |
instance[key].to(self.device)
|
157 |
|
158 |
output = self.model(input_ids=instance["input_ids"],
|
@@ -160,6 +170,7 @@ class ReasoningModel:
|
|
160 |
token_type_ids=instance["token_type_ids"])
|
161 |
return output
|
162 |
|
|
|
163 |
class UptakeModel:
|
164 |
def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
|
165 |
print("Loading models...")
|
@@ -184,14 +195,16 @@ class UptakeModel:
|
|
184 |
input_str=True)
|
185 |
output = self.get_prediction(instance)
|
186 |
|
187 |
-
utt.uptake = int(
|
|
|
188 |
prev_num_words = utt.get_num_words()
|
189 |
prev_utt = utt
|
190 |
|
191 |
def get_prediction(self, instance):
|
192 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
193 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
194 |
-
instance[key] = torch.tensor(
|
|
|
195 |
instance[key].to(self.device)
|
196 |
|
197 |
output = self.model(input_ids=instance["input_ids"],
|
@@ -201,6 +214,44 @@ class UptakeModel:
|
|
201 |
return output
|
202 |
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
class EndpointHandler():
|
205 |
def __init__(self, path="."):
|
206 |
print("Loading models...")
|
@@ -231,18 +282,26 @@ class EndpointHandler():
|
|
231 |
transcript.add_utterance(Utterance(**utt))
|
232 |
|
233 |
print("Running inference on %d examples..." % transcript.length())
|
234 |
-
|
235 |
# Uptake
|
236 |
-
uptake_model = UptakeModel(
|
|
|
237 |
uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
|
238 |
-
uptake_speaker=
|
239 |
|
240 |
# Reasoning
|
241 |
-
reasoning_model = ReasoningModel(
|
|
|
242 |
reasoning_model.run_inference(transcript)
|
243 |
|
244 |
# Question
|
245 |
-
question_model = QuestionModel(
|
|
|
246 |
question_model.run_inference(transcript)
|
247 |
|
|
|
|
|
|
|
|
|
|
|
248 |
return transcript.to_dict()
|
|
|
13 |
|
14 |
transformers.logging.set_verbosity_debug()
|
15 |
|
16 |
+
UPTAKE_MODEL = 'ddemszky/uptake-model'
|
17 |
+
REASONING_MODEL = 'ddemszky/student-reasoning'
|
18 |
+
QUESTION_MODEL = 'ddemszky/question-detection'
|
19 |
+
FOCUSING_QUESTION_MODEL = 'ddemszky/focusing-questions'
|
20 |
+
|
21 |
|
22 |
class Utterance:
|
23 |
def __init__(self, speaker, text, uid=None,
|
|
|
33 |
self.uptake = None
|
34 |
self.reasoning = None
|
35 |
self.question = None
|
36 |
+
self.focusing_question = None
|
37 |
|
38 |
def get_clean_text(self, remove_punct=False):
|
39 |
if remove_punct:
|
|
|
53 |
'uptake': self.uptake,
|
54 |
'reasoning': self.reasoning,
|
55 |
'question': self.question,
|
56 |
+
'focusingquestion': self.focusing_question,
|
57 |
**self.props
|
58 |
}
|
59 |
|
|
|
62 |
f"text='{self.text}', uid={self.uid}," \
|
63 |
f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
|
64 |
|
65 |
+
|
66 |
class Transcript:
|
67 |
def __init__(self, **kwargs):
|
68 |
self.utterances = []
|
|
|
95 |
def __repr__(self):
|
96 |
return f"Transcript(utterances={self.utterances}, custom_params={self.params})"
|
97 |
|
98 |
+
|
99 |
class QuestionModel:
|
100 |
def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
|
101 |
print("Loading models...")
|
|
|
103 |
self.tokenizer = tokenizer
|
104 |
self.input_builder = input_builder
|
105 |
self.max_length = max_length
|
106 |
+
self.model = MultiHeadModel.from_pretrained(
|
107 |
+
path, head2size={"is_question": 2})
|
108 |
self.model.to(self.device)
|
109 |
|
|
|
110 |
def run_inference(self, transcript):
|
111 |
self.model.eval()
|
112 |
with torch.no_grad():
|
|
|
120 |
input_str=True)
|
121 |
output = self.get_prediction(instance)
|
122 |
print(output)
|
123 |
+
utt.question = np.argmax(
|
124 |
+
output["is_question_logits"][0].tolist())
|
125 |
|
126 |
def get_prediction(self, instance):
|
127 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
128 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
129 |
+
instance[key] = torch.tensor(
|
130 |
+
instance[key]).unsqueeze(0) # Batch size = 1
|
131 |
instance[key].to(self.device)
|
132 |
|
133 |
output = self.model(input_ids=instance["input_ids"],
|
|
|
136 |
return_pooler_output=False)
|
137 |
return output
|
138 |
|
139 |
+
|
140 |
class ReasoningModel:
|
141 |
def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
|
142 |
print("Loading models...")
|
|
|
161 |
def get_prediction(self, instance):
|
162 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
163 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
164 |
+
instance[key] = torch.tensor(
|
165 |
+
instance[key]).unsqueeze(0) # Batch size = 1
|
166 |
instance[key].to(self.device)
|
167 |
|
168 |
output = self.model(input_ids=instance["input_ids"],
|
|
|
170 |
token_type_ids=instance["token_type_ids"])
|
171 |
return output
|
172 |
|
173 |
+
|
174 |
class UptakeModel:
|
175 |
def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
|
176 |
print("Loading models...")
|
|
|
195 |
input_str=True)
|
196 |
output = self.get_prediction(instance)
|
197 |
|
198 |
+
utt.uptake = int(
|
199 |
+
softmax(output["nsp_logits"][0].tolist())[1] > .8)
|
200 |
prev_num_words = utt.get_num_words()
|
201 |
prev_utt = utt
|
202 |
|
203 |
def get_prediction(self, instance):
|
204 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
205 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
206 |
+
instance[key] = torch.tensor(
|
207 |
+
instance[key]).unsqueeze(0) # Batch size = 1
|
208 |
instance[key].to(self.device)
|
209 |
|
210 |
output = self.model(input_ids=instance["input_ids"],
|
|
|
214 |
return output
|
215 |
|
216 |
|
217 |
+
|
218 |
+
class FocusingQuestionModel:
    """Label utterances of one speaker as focusing questions or not.

    Wraps a ``BertForSequenceClassification`` checkpoint and writes the
    predicted class index onto ``Utterance.focusing_question``.
    """

    def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
        """Load the classifier from *path* and move it to *device*.

        Args:
            device: Torch device the model and its inputs are placed on.
            tokenizer: Stored on the instance; not used directly by this class.
            input_builder: Object exposing ``build_inputs(...)``, used to turn
                utterance text into model inputs.
            max_length: Forwarded to ``input_builder.build_inputs`` as
                ``max_length``.
            path: Checkpoint name or path given to ``from_pretrained``.
        """
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        self.max_length = max_length

    def run_inference(self, transcript, min_focusing_words=0, uptake_speaker=None):
        """Set ``focusing_question`` on every utterance in *transcript*.

        Utterances are skipped (``focusing_question = None``) when no
        ``uptake_speaker`` is given, when the utterance belongs to a different
        speaker, or when it has fewer than ``min_focusing_words`` words.
        All remaining utterances receive the argmax class index as an ``int``.

        Args:
            transcript: Object with an ``utterances`` iterable of utterances.
            min_focusing_words: Minimum word count for an utterance to be scored.
            uptake_speaker: Speaker whose utterances are scored.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if uptake_speaker is None or utt.speaker != uptake_speaker:
                    utt.focusing_question = None
                    continue
                if utt.get_num_words() < min_focusing_words:
                    utt.focusing_question = None
                    continue
                # No preceding context is supplied: the first argument is an
                # empty list and only the utterance text is encoded.
                instance = self.input_builder.build_inputs(
                    [], utt.text, max_length=self.max_length, input_str=True)
                output = self.get_prediction(instance)
                # Cast to a built-in int so the value is a plain Python number
                # (np.argmax returns a NumPy integer), matching how `uptake`
                # is stored.
                utt.focusing_question = int(
                    np.argmax(output["logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run one forward pass for a single encoded instance.

        *instance* is modified in place: an ``attention_mask`` of ones is
        added and the three input fields are replaced by tensors on
        ``self.device``.

        Args:
            instance: Mapping with ``input_ids`` and ``token_type_ids``.
                Presumably flat token-id lists from ``build_inputs``; verify
                against the input builder.

        Returns:
            The model output; ``run_inference`` reads its ``"logits"`` entry.
        """
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ("input_ids", "token_type_ids", "attention_mask"):
            # Tensor.to() is not in-place: it returns the moved tensor, so the
            # result has to be stored or the inputs stay on the CPU.
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0).to(self.device)  # Batch size = 1

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"])
        return output
|
253 |
+
|
254 |
+
|
255 |
class EndpointHandler():
|
256 |
def __init__(self, path="."):
|
257 |
print("Loading models...")
|
|
|
282 |
transcript.add_utterance(Utterance(**utt))
|
283 |
|
284 |
print("Running inference on %d examples..." % transcript.length())
|
285 |
+
uptake_speaker = params.pop("uptake_speaker", None)
|
286 |
# Uptake
|
287 |
+
uptake_model = UptakeModel(
|
288 |
+
self.device, self.tokenizer, self.input_builder)
|
289 |
uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
|
290 |
+
uptake_speaker=uptake_speaker)
|
291 |
|
292 |
# Reasoning
|
293 |
+
reasoning_model = ReasoningModel(
|
294 |
+
self.device, self.tokenizer, self.input_builder)
|
295 |
reasoning_model.run_inference(transcript)
|
296 |
|
297 |
# Question
|
298 |
+
question_model = QuestionModel(
|
299 |
+
self.device, self.tokenizer, self.input_builder)
|
300 |
question_model.run_inference(transcript)
|
301 |
|
302 |
+
# Focusing Question
|
303 |
+
focusing_question_model = FocusingQuestionModel(
|
304 |
+
self.device, self.tokenizer, self.input_builder)
|
305 |
+
focusing_question_model.run_inference(transcript, uptake_speaker=uptake_speaker)
|
306 |
+
|
307 |
return transcript.to_dict()
|