liujch1998 committed on
Commit
c752f9e
·
1 Parent(s): 0ef49e6
Files changed (1) hide show
  1. app.py +108 -32
app.py CHANGED
@@ -40,7 +40,8 @@ class Interactive:
40
  if MODE == 'debug':
41
  return
42
  self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
43
- self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1, dtype=self.model.dtype).to(device)
 
44
  self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
45
  self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
46
  self.model.eval()
@@ -64,52 +65,124 @@ class Interactive:
64
  score = logit.sigmoid()
65
  score_calibrated = logit_calibrated.sigmoid()
66
  return {
 
 
67
  'logit': logit.item(),
68
  'logit_calibrated': logit_calibrated.item(),
69
  'score': score.item(),
70
  'score_calibrated': score_calibrated.item(),
71
  }
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  interactive = Interactive()
74
 
75
- def predict(statement, do_save=True):
76
- result = interactive.run(statement)
77
- output = {
78
- 'True': result['score_calibrated'],
79
- 'False': 1 - result['score_calibrated'],
80
- }
81
- output_raw = {
82
- 'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
83
- 'statement': statement,
84
- }
85
- output_raw.update(result)
86
- if do_save:
87
- with open(DATA_PATH, 'a') as f:
88
- json.dump(output_raw, f, ensure_ascii=False)
89
- f.write('\n')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  commit_url = repo.push_to_hub()
91
  print('Logged statement to dataset:')
92
  print('Commit URL:', commit_url)
93
  print(output_raw)
94
  print()
95
- return output, output_raw, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value='Please provide your feedback before trying out another statement.')
96
-
97
- def record_feedback(output_raw, feedback, do_save=True):
98
- if do_save:
99
- output_raw.update({ 'feedback': feedback })
100
- with open(DATA_PATH, 'a') as f:
101
- json.dump(output_raw, f, ensure_ascii=False)
102
- f.write('\n')
 
 
 
 
 
 
103
  commit_url = repo.push_to_hub()
104
  print('Logged feedback to dataset:')
105
  print('Commit URL:', commit_url)
106
  print(output_raw)
107
  print()
108
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value='Thanks for your feedback! Now you can enter another statement.')
109
- def record_feedback_agree(output_raw, do_save=True):
110
- return record_feedback(output_raw, 'agree', do_save)
111
- def record_feedback_disagree(output_raw, do_save=True):
112
- return record_feedback(output_raw, 'disagree', do_save)
 
 
 
113
 
114
  examples = [
115
  # openbookqa
@@ -223,9 +296,12 @@ with gr.Blocks() as demo:
223
  cache_examples=False,
224
  run_on_click=False, # If we want this to be True, I suspect we need to enable the statement.submit()
225
  )
226
- submit.click(predict, inputs=[statement, do_save], outputs=[output, output_raw, submit, feedback_agree, feedback_disagree, feedback_ack])
227
  # statement.submit(predict, inputs=[statement], outputs=[output, output_raw])
228
- feedback_agree.click(record_feedback_agree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack])
229
- feedback_disagree.click(record_feedback_disagree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack])
230
 
231
  demo.queue(concurrency_count=16).launch(debug=True)
 
 
 
 
40
  if MODE == 'debug':
41
  return
42
  self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
43
+ self.model.D = self.model.shared.embedding_dim
44
+ self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype).to(device)
45
  self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
46
  self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
47
  self.model.eval()
 
65
  score = logit.sigmoid()
66
  score_calibrated = logit_calibrated.sigmoid()
67
  return {
68
+ 'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
69
+ 'statement': statement,
70
  'logit': logit.item(),
71
  'logit_calibrated': logit_calibrated.item(),
72
  'score': score.item(),
73
  'score_calibrated': score_calibrated.item(),
74
  }
75
 
76
+ def runs(self, statements):
77
+ if MODE == 'debug':
78
+ return [{
79
+ 'logit': 0.0,
80
+ 'logit_calibrated': 0.0,
81
+ 'score': 0.5,
82
+ 'score_calibrated': 0.5,
83
+ } for _ in statements]
84
+ tok = self.tokenizer.batch_encode_plus(statements, return_tensors='pt', padding='longest')
85
+ input_ids = tok.input_ids.to(device)
86
+ attention_mask = tok.attention_mask.to(device)
87
+ with torch.no_grad():
88
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)
89
+ last_indices = attention_mask.sum(dim=1, keepdim=True) - 1 # (B, 1)
90
+ last_indices = last_indices.unsqueeze(-1).expand(-1, -1, self.model.D) # (B, 1, D)
91
+ last_hidden_state = output.last_hidden_state.to(device) # (B, L, D)
92
+ hidden = last_hidden_state.gather(dim=1, index=last_indices).squeeze(1) # (B, D)
93
+ logits = self.linear(hidden).squeeze(-1) # (B)
94
+ logits_calibrated = logits / self.t
95
+ scores = logits.sigmoid()
96
+ scores_calibrated = logits_calibrated.sigmoid()
97
+ return [{
98
+ 'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
99
+ 'statement': statement,
100
+ 'logit': logit.item(),
101
+ 'logit_calibrated': logit_calibrated.item(),
102
+ 'score': score.item(),
103
+ 'score_calibrated': score_calibrated.item(),
104
+ } for statement, logit, logit_calibrated, score, score_calibrated in zip(statements, logits, logits_calibrated, scores, scores_calibrated)]
105
+
106
  interactive = Interactive()
107
 
108
+ # def predict(statement, do_save=True):
109
+ # output_raw = interactive.run(statement)
110
+ # output = {
111
+ # 'True': output_raw['score_calibrated'],
112
+ # 'False': 1 - output_raw['score_calibrated'],
113
+ # }
114
+ # if do_save:
115
+ # with open(DATA_PATH, 'a') as f:
116
+ # json.dump(output_raw, f, ensure_ascii=False)
117
+ # f.write('\n')
118
+ # commit_url = repo.push_to_hub()
119
+ # print('Logged statement to dataset:')
120
+ # print('Commit URL:', commit_url)
121
+ # print(output_raw)
122
+ # print()
123
+ # return output, output_raw, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value='Please provide your feedback before trying out another statement.')
124
+
125
+ # def record_feedback(output_raw, feedback, do_save=True):
126
+ # if do_save:
127
+ # output_raw.update({ 'feedback': feedback })
128
+ # with open(DATA_PATH, 'a') as f:
129
+ # json.dump(output_raw, f, ensure_ascii=False)
130
+ # f.write('\n')
131
+ # commit_url = repo.push_to_hub()
132
+ # print('Logged feedback to dataset:')
133
+ # print('Commit URL:', commit_url)
134
+ # print(output_raw)
135
+ # print()
136
+ # return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value='Thanks for your feedback! Now you can enter another statement.')
137
+ # def record_feedback_agree(output_raw, do_save=True):
138
+ # return record_feedback(output_raw, 'agree', do_save)
139
+ # def record_feedback_disagree(output_raw, do_save=True):
140
+ # return record_feedback(output_raw, 'disagree', do_save)
141
+
142
+ def predict(statements, do_saves):
143
+ output_raws = interactive.runs(statements)
144
+ outputs = [{
145
+ 'True': output_raw['score_calibrated'],
146
+ 'False': 1 - output_raw['score_calibrated'],
147
+ } for output_raw in output_raws]
148
+ for output_raw, do_save in zip(output_raws, do_saves):
149
+ if do_save:
150
+ with open(DATA_PATH, 'a') as f:
151
+ json.dump(output_raw, f, ensure_ascii=False)
152
+ f.write('\n')
153
+ if any(do_saves):
154
  commit_url = repo.push_to_hub()
155
  print('Logged statement to dataset:')
156
  print('Commit URL:', commit_url)
157
  print(output_raw)
158
  print()
159
+ return outputs, output_raws, \
160
+ [gr.update(visible=False) for _ in statements], \
161
+ [gr.update(visible=True) for _ in statements], \
162
+ [gr.update(visible=True) for _ in statements], \
163
+ [gr.update(value='Please provide your feedback before trying out another statement.') for _ in statements]
164
+
165
+ def record_feedback(output_raws, feedback, do_saves):
166
+ for output_raw, do_save in zip(output_raws, do_saves):
167
+ if do_save:
168
+ output_raw.update({ 'feedback': feedback })
169
+ with open(DATA_PATH, 'a') as f:
170
+ json.dump(output_raw, f, ensure_ascii=False)
171
+ f.write('\n')
172
+ if any(do_saves):
173
  commit_url = repo.push_to_hub()
174
  print('Logged feedback to dataset:')
175
  print('Commit URL:', commit_url)
176
  print(output_raw)
177
  print()
178
+ return [gr.update(visible=True) for _ in output_raws], \
179
+ [gr.update(visible=False) for _ in output_raws], \
180
+ [gr.update(visible=False) for _ in output_raws], \
181
+ [gr.update(value='Thanks for your feedback! Now you can enter another statement.') for _ in output_raws]
182
+ def record_feedback_agree(output_raws, do_saves):
183
+ return record_feedback(output_raws, 'agree', do_saves)
184
+ def record_feedback_disagree(output_raws, do_saves):
185
+ return record_feedback(output_raws, 'disagree', do_saves)
186
 
187
  examples = [
188
  # openbookqa
 
296
  cache_examples=False,
297
  run_on_click=False, # If we want this to be True, I suspect we need to enable the statement.submit()
298
  )
299
+ submit.click(predict, inputs=[statement, do_save], outputs=[output, output_raw, submit, feedback_agree, feedback_disagree, feedback_ack], batch=True, max_batch_size=16)
300
  # statement.submit(predict, inputs=[statement], outputs=[output, output_raw])
301
+ feedback_agree.click(record_feedback_agree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack], batch=True, max_batch_size=16)
302
+ feedback_disagree.click(record_feedback_disagree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack], batch=True, max_batch_size=16)
303
 
304
  demo.queue(concurrency_count=16).launch(debug=True)
305
+
306
+ # Concurrency, Batching
307
+ # Theme, CSS