update handler for new output
Browse files- handler.py +123 -9
handler.py
CHANGED
@@ -3,6 +3,9 @@ from scipy.special import softmax
|
|
3 |
import numpy as np
|
4 |
import weakref
|
5 |
import re
|
|
|
|
|
|
|
6 |
|
7 |
from utils import clean_str, clean_str_nopunct
|
8 |
import torch
|
@@ -10,7 +13,7 @@ from utils import MultiHeadModel, BertInputBuilder, get_num_words, MATH_PREFIXES
|
|
10 |
|
11 |
import transformers
|
12 |
from transformers import BertTokenizer, BertForSequenceClassification
|
13 |
-
|
14 |
|
15 |
transformers.logging.set_verbosity_debug()
|
16 |
|
@@ -30,9 +33,15 @@ class Utterance:
|
|
30 |
self.endtime = endtime
|
31 |
self.transcript = weakref.ref(transcript) if transcript else None
|
32 |
self.props = kwargs
|
|
|
|
|
|
|
|
|
|
|
33 |
self.num_math_terms = None
|
34 |
self.math_terms = None
|
35 |
|
|
|
36 |
self.uptake = None
|
37 |
self.reasoning = None
|
38 |
self.question = None
|
@@ -62,6 +71,21 @@ class Utterance:
|
|
62 |
**self.props
|
63 |
}
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
def __repr__(self):
|
66 |
return f"Utterance(speaker='{self.speaker}'," \
|
67 |
f"text='{self.text}', uid={self.uid}," \
|
@@ -91,6 +115,86 @@ class Transcript:
|
|
91 |
def length(self):
|
92 |
return len(self.utterances)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def to_dict(self):
|
95 |
return {
|
96 |
'utterances': [utterance.to_dict() for utterance in self.utterances],
|
@@ -218,8 +322,6 @@ class UptakeModel:
|
|
218 |
return_pooler_output=False)
|
219 |
return output
|
220 |
|
221 |
-
|
222 |
-
|
223 |
class FocusingQuestionModel:
|
224 |
def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
|
225 |
print("Loading models...")
|
@@ -254,8 +356,7 @@ class FocusingQuestionModel:
|
|
254 |
output = self.model(input_ids=instance["input_ids"],
|
255 |
attention_mask=instance["attention_mask"],
|
256 |
token_type_ids=instance["token_type_ids"])
|
257 |
-
return output
|
258 |
-
|
259 |
|
260 |
def load_math_terms():
|
261 |
math_terms = []
|
@@ -283,7 +384,7 @@ def run_math_density(transcript):
|
|
283 |
matches = [match for match in matches if not any(match.start() in range(existing[0], existing[1]) for existing in matched_positions)]
|
284 |
if len(matches) > 0:
|
285 |
match_list.append(math_terms_dict[term])
|
286 |
-
# Update
|
287 |
matched_positions.update((match.start(), match.end()) for match in matches)
|
288 |
num_matches += len(matches)
|
289 |
utt.num_math_terms = num_matches
|
@@ -319,13 +420,13 @@ class EndpointHandler():
|
|
319 |
transcript.add_utterance(Utterance(**utt))
|
320 |
|
321 |
print("Running inference on %d examples..." % transcript.length())
|
322 |
-
|
323 |
# Uptake
|
324 |
uptake_model = UptakeModel(
|
325 |
self.device, self.tokenizer, self.input_builder)
|
|
|
326 |
uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
|
327 |
uptake_speaker=uptake_speaker)
|
328 |
-
|
329 |
# Reasoning
|
330 |
reasoning_model = ReasoningModel(
|
331 |
self.device, self.tokenizer, self.input_builder)
|
@@ -343,4 +444,17 @@ class EndpointHandler():
|
|
343 |
|
344 |
run_math_density(transcript)
|
345 |
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import numpy as np
|
4 |
import weakref
|
5 |
import re
|
6 |
+
import nltk
|
7 |
+
from nltk.corpus import stopwords
|
8 |
+
nltk.download('stopwords')
|
9 |
|
10 |
from utils import clean_str, clean_str_nopunct
|
11 |
import torch
|
|
|
13 |
|
14 |
import transformers
|
15 |
from transformers import BertTokenizer, BertForSequenceClassification
|
16 |
+
from transformers.utils import logging
|
17 |
|
18 |
transformers.logging.set_verbosity_debug()
|
19 |
|
|
|
33 |
self.endtime = endtime
|
34 |
self.transcript = weakref.ref(transcript) if transcript else None
|
35 |
self.props = kwargs
|
36 |
+
self.role = None
|
37 |
+
self.word_count = self.get_num_words()
|
38 |
+
self.timestamp = [starttime, endtime]
|
39 |
+
self.unit_measure = None
|
40 |
+
self.aggregate_unit_measure = endtime
|
41 |
self.num_math_terms = None
|
42 |
self.math_terms = None
|
43 |
|
44 |
+
# moments
|
45 |
self.uptake = None
|
46 |
self.reasoning = None
|
47 |
self.question = None
|
|
|
71 |
**self.props
|
72 |
}
|
73 |
|
74 |
+
def to_talk_timeline_dict(self):
    """Serialize this utterance for the talk-timeline view.

    Returns a dict of the utterance's speaker/text/uid/role/timestamp plus a
    'moments' sub-dict of boolean flags for each detected moment type and the
    word/unit-measure counters.
    """
    moments = {
        'reasoning': bool(self.reasoning),
        'questioning': bool(self.question),
        'uptake': bool(self.uptake),
        'focusingQuestion': bool(self.focusing_question),
    }
    return {
        'speaker': self.speaker,
        'text': self.text,
        'uid': self.uid,
        'role': self.role,
        'timestamp': self.timestamp,
        'moments': moments,
        'unitMeasure': self.unit_measure,
        'aggregateUnitMeasure': self.aggregate_unit_measure,
        'wordCount': self.word_count,
        'numMathTerms': self.num_math_terms,
        'mathTerms': self.math_terms,
    }
|
88 |
+
|
89 |
def __repr__(self):
|
90 |
return f"Utterance(speaker='{self.speaker}'," \
|
91 |
f"text='{self.text}', uid={self.uid}," \
|
|
|
115 |
def length(self):
|
116 |
return len(self.utterances)
|
117 |
|
118 |
+
def update_utterance_roles(self, uptake_speaker):
    """Assign a role to every utterance: 'teacher' when its speaker matches
    uptake_speaker, 'student' otherwise."""
    for utterance in self.utterances:
        utterance.role = 'teacher' if utterance.speaker == uptake_speaker else 'student'
|
124 |
+
|
125 |
+
def get_talk_distribution_and_length(self, uptake_speaker):
    """Compute talk statistics split by role.

    As a side effect, assigns each utterance's role ('teacher' when its
    speaker matches uptake_speaker, else 'student').

    Returns:
        A pair of dicts:
          - distribution: integer percentage of words spoken by each role
          - length: average words per utterance for each role
        or None when uptake_speaker is None (roles cannot be split).
    """
    if uptake_speaker is None:
        return None
    teacher_words = 0
    teacher_utt_count = 0
    student_words = 0
    student_utt_count = 0
    for utt in self.utterances:
        if utt.speaker == uptake_speaker:
            utt.role = 'teacher'
            teacher_words += utt.get_num_words()
            teacher_utt_count += 1
        else:
            utt.role = 'student'
            student_words += utt.get_num_words()
            student_utt_count += 1
    total_words = teacher_words + student_words
    # Guard all divisions: an empty or one-sided transcript previously
    # raised ZeroDivisionError here.
    if total_words:
        teacher_percentage = round((teacher_words / total_words) * 100)
        student_percentage = 100 - teacher_percentage
    else:
        teacher_percentage = 0
        student_percentage = 0
    avg_teacher_length = teacher_words / teacher_utt_count if teacher_utt_count else 0
    avg_student_length = student_words / student_utt_count if student_utt_count else 0
    return ({'teacher': teacher_percentage, 'student': student_percentage},
            {'teacher': avg_teacher_length, 'student': avg_student_length})
|
147 |
+
|
148 |
+
def get_word_cloud_dicts(self):
    """Build word-frequency entries for the word-cloud views.

    Counts non-stopword tokens from each utterance's cleaned text, split by
    the utterance's role; teacher words from utterances with uptake == 1 are
    additionally tallied separately.

    Returns a pair of lists, each sorted by count descending and capped at
    the top 50 entries:
      - combined teacher + student counts
      - teacher counts restricted to high-uptake utterances
    """
    teacher_counts = {}
    student_counts = {}
    uptake_teacher_counts = {}
    stop_words = stopwords.words('english')
    for utt in self.utterances:
        for word in utt.get_clean_text(remove_punct=True).split(' '):
            if word in stop_words:
                continue
            if utt.role == 'teacher':
                teacher_counts[word] = teacher_counts.get(word, 0) + 1
                if utt.uptake == 1:
                    uptake_teacher_counts[word] = uptake_teacher_counts.get(word, 0) + 1
            else:
                student_counts[word] = student_counts.get(word, 0) + 1
    combined = [{'text': word, 'value': count, 'category': 'teacher'}
                for word, count in teacher_counts.items()]
    combined += [{'text': word, 'value': count, 'category': 'student'}
                 for word, count in student_counts.items()]
    uptake_entries = [{'text': word, 'value': count, 'category': 'teacher'}
                      for word, count in uptake_teacher_counts.items()]
    combined.sort(key=lambda entry: entry['value'], reverse=True)
    uptake_entries.sort(key=lambda entry: entry['value'], reverse=True)
    return combined[:50], uptake_entries[:50]
|
184 |
+
|
185 |
+
def get_talk_timeline(self):
    """Return per-utterance timeline dicts, in transcript order."""
    timeline = []
    for utterance in self.utterances:
        timeline.append(utterance.to_talk_timeline_dict())
    return timeline
|
187 |
+
|
188 |
+
def calculate_aggregate_word_count(self):
    """Fill in word-count unit measures when any utterance is missing one.

    If at least one utterance has unit_measure of None, every utterance's
    unit_measure is set to its own word count and aggregate_unit_measure
    to the running total of word counts up to and including it. Otherwise
    the utterances are left untouched.
    """
    if any(utt.unit_measure is None for utt in self.utterances):
        running_total = 0
        for utt in self.utterances:
            words = utt.get_num_words()
            running_total += words
            utt.unit_measure = words
            utt.aggregate_unit_measure = running_total
|
196 |
+
|
197 |
+
|
198 |
def to_dict(self):
|
199 |
return {
|
200 |
'utterances': [utterance.to_dict() for utterance in self.utterances],
|
|
|
322 |
return_pooler_output=False)
|
323 |
return output
|
324 |
|
|
|
|
|
325 |
class FocusingQuestionModel:
|
326 |
def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
|
327 |
print("Loading models...")
|
|
|
356 |
output = self.model(input_ids=instance["input_ids"],
|
357 |
attention_mask=instance["attention_mask"],
|
358 |
token_type_ids=instance["token_type_ids"])
|
359 |
+
return output
|
|
|
360 |
|
361 |
def load_math_terms():
|
362 |
math_terms = []
|
|
|
384 |
matches = [match for match in matches if not any(match.start() in range(existing[0], existing[1]) for existing in matched_positions)]
|
385 |
if len(matches) > 0:
|
386 |
match_list.append(math_terms_dict[term])
|
387 |
+
# Update matched positions
|
388 |
matched_positions.update((match.start(), match.end()) for match in matches)
|
389 |
num_matches += len(matches)
|
390 |
utt.num_math_terms = num_matches
|
|
|
420 |
transcript.add_utterance(Utterance(**utt))
|
421 |
|
422 |
print("Running inference on %d examples..." % transcript.length())
|
423 |
+
logging.set_verbosity_info()
|
424 |
# Uptake
|
425 |
uptake_model = UptakeModel(
|
426 |
self.device, self.tokenizer, self.input_builder)
|
427 |
+
uptake_speaker = params.pop("uptake_speaker", None)
|
428 |
uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
|
429 |
uptake_speaker=uptake_speaker)
|
|
|
430 |
# Reasoning
|
431 |
reasoning_model = ReasoningModel(
|
432 |
self.device, self.tokenizer, self.input_builder)
|
|
|
444 |
|
445 |
run_math_density(transcript)
|
446 |
|
447 |
+
transcript.update_utterance_roles(uptake_speaker)
|
448 |
+
transcript.calculate_aggregate_word_count()
|
449 |
+
return_dict = {'talkDistribution': None, 'talkLength': None, 'talkMoments': None, 'commonTopWords': None, 'uptakeTopWords': None}
|
450 |
+
talk_dist, talk_len = transcript.get_talk_distribution_and_length(uptake_speaker)
|
451 |
+
return_dict['talkDistribution'] = talk_dist
|
452 |
+
return_dict['talkLength'] = talk_len
|
453 |
+
talk_moments = transcript.get_talk_timeline()
|
454 |
+
return_dict['talkMoments'] = talk_moments
|
455 |
+
word_cloud, uptake_word_cloud = transcript.get_word_cloud_dicts()
|
456 |
+
return_dict['commonTopWords'] = word_cloud
|
457 |
+
return_dict['uptakeTopWords'] = uptake_word_cloud
|
458 |
+
|
459 |
+
|
460 |
+
return return_dict
|