alizhgir commited on
Commit
f71220a
·
1 Parent(s): aed38e9

исправлена третья модель

Browse files
Files changed (1) hide show
  1. app.py +36 -16
app.py CHANGED
@@ -22,6 +22,8 @@ import json
22
  import gensim
23
  import torch.nn.functional as F
24
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 
25
 
26
 
27
  st.title('10-я неделя DS. Классификация отзывов, определение токсичности и генерация текста')
@@ -230,43 +232,57 @@ if page == "Определение токсичности":
230
  return model
231
 
232
  # Загрузка обученной модели
233
- clf = load_model('toxic/logistic_regression_model_toxic.pkl') # Укажите путь к файлу модели
234
 
235
  # Загрузка токенизатора и модели BERT
236
- tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny-toxicity")
237
- model = AutoModel.from_pretrained("cointegrated/rubert-tiny-toxicity")
238
 
239
  # Функция для предсказания токсичности сообщения
240
  def predict_toxicity(text):
241
- encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
242
  with torch.no_grad():
243
- outputs = model(**encoded)
244
  features = outputs.last_hidden_state[:, 0, :].numpy()
245
- prediction = clf.predict(features)
246
  return prediction[0]
247
 
 
 
 
 
 
 
 
 
 
 
248
  # Создание интерфейса Streamlit
249
  st.title("Оценка токсичности сообщения")
250
 
251
  # Текстовое поле для ввода сообщения
252
  user_input = st.text_area("Введите сообщение для оценки")
253
 
254
- if st.button("Оценить"):
255
  if user_input:
256
  # Оценка токсичности сообщения
257
- prediction = predict_toxicity(user_input)
258
- if prediction > 0.5:
259
- st.write("Сообщение токсично")
260
- st.write(prediction)
261
- else:
262
- st.write("Сообщение не токсично")
263
- st.write(prediction)
264
  else:
265
  st.write("Пожалуйста, введите сообщение")
 
 
 
 
 
 
 
 
266
 
267
 
268
 
269
  if page == "Генерация текста":
 
270
  # Путь к вашим весам модели
271
  model_weights_path = 'gpt-2/model.pt'
272
 
@@ -288,7 +304,7 @@ if page == "Генерация текста":
288
  st.title("Генератор плохих отзывов больниц от ruGPT3")
289
 
290
  # Ввод текста от пользователя
291
- user_prompt = st.text_area("Введите текст-промпт:", "Я была в этой клинике..")
292
 
293
  # Виджеты для динамической регуляции параметров
294
  max_length = st.slider("Выберите max_length:", 10, 300, 100)
@@ -312,4 +328,8 @@ if page == "Генерация текста":
312
  ).cpu().numpy()
313
  generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
314
  st.subheader("Сгенерированный текст:")
315
- st.write(generated_text)
 
 
 
 
 
22
  import gensim
23
  import torch.nn.functional as F
24
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
25
+ from transformers import AutoModelForSequenceClassification
26
+
27
 
28
 
29
  st.title('10-я неделя DS. Классификация отзывов, определение токсичности и генерация текста')
 
232
  return model
233
 
234
  # Загрузка обученной модели
235
+ clf_c = load_model('toxic/logistic_regression_model_toxic.pkl') # Укажите путь к файлу модели
236
 
237
  # Загрузка токенизатора и модели BERT
238
+ tokenizer_c = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny-toxicity")
239
+ model_c = AutoModel.from_pretrained("cointegrated/rubert-tiny-toxicity")
240
 
241
  # Функция для предсказания токсичности сообщения
242
  def predict_toxicity(text):
243
+ encoded = tokenizer_c(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
244
  with torch.no_grad():
245
+ outputs = model_c(**encoded)
246
  features = outputs.last_hidden_state[:, 0, :].numpy()
247
+ prediction = clf_c.predict_proba(features)
248
  return prediction[0]
249
 
250
+ model_checkpoint = 'cointegrated/rubert-tiny-toxicity'
251
+ tokenizer_b = AutoTokenizer.from_pretrained(model_checkpoint)
252
+ model_b = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
253
+
254
+ def text2toxicity(text):
255
+ with torch.no_grad():
256
+ inputs = tokenizer_b(text, return_tensors='pt', truncation=True, padding=True)
257
+ proba = torch.sigmoid(model_b(**inputs).logits).cpu().numpy()
258
+ return proba[0][1]
259
+
260
  # Создание интерфейса Streamlit
261
  st.title("Оценка токсичности сообщения")
262
 
263
  # Текстовое поле для ввода сообщения
264
  user_input = st.text_area("Введите сообщение для оценки")
265
 
266
+ if st.button("Оценить токсичность сообщения кастомизированной моделью"):
267
  if user_input:
268
  # Оценка токсичности сообщения
269
+ prediction = predict_toxicity(user_input)[1]
270
+ st.write(f'Вероятность токсичности согласно кастомизированной модели: {prediction:.4f}')
 
 
 
 
 
271
  else:
272
  st.write("Пожалуйста, введите сообщение")
273
+
274
+ if st.button('Определить токсичность базовой моделью'):
275
+ if user_input:
276
+ # Определение токсичности сообщения
277
+ proba_toxicity = text2toxicity(user_input)
278
+ st.write(f'Вероятность токсичности rubert-tiny-toxicity.pretrained: {proba_toxicity:.4f}')
279
+ else:
280
+ st.write('Пожалуйста, введите сообщение')
281
 
282
 
283
 
284
  if page == "Генерация текста":
285
+
286
  # Путь к вашим весам модели
287
  model_weights_path = 'gpt-2/model.pt'
288
 
 
304
  st.title("Генератор плохих отзывов больниц от ruGPT3")
305
 
306
  # Ввод текста от пользователя
307
+ user_prompt = st.text_area("Введите текст-промпт:", "Я была в этой клинике")
308
 
309
  # Виджеты для динамической регуляции параметров
310
  max_length = st.slider("Выберите max_length:", 10, 300, 100)
 
328
  ).cpu().numpy()
329
  generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
330
  st.subheader("Сгенерированный текст:")
331
+ st.write(generated_text)
332
+
333
+ if __name__ == "__main__":
334
+ main()
335
+