fruitpicker01 commited on
Commit
677b493
1 Parent(s): a15d783

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -0
app.py CHANGED
@@ -16,6 +16,11 @@ from collections import defaultdict
16
  import requests
17
  import base64
18
  import io
 
 
 
 
 
19
 
20
  MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')
21
  token = os.getenv('GITHUB_TOKEN')
@@ -783,6 +788,24 @@ def generate_all_messages(desc, benefits, key_message, gender, generation, psych
783
  time.sleep(1)
784
  save_statistics_to_github(approach_stats)
785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
 
787
  # ФУНКЦИИ ПРОВЕРОК (НАЧАЛО)
788
 
@@ -1763,6 +1786,7 @@ with gr.Blocks() as demo:
1763
  with gr.Row():
1764
  non_personalized_messages = gr.Textbox(label="Стандартные сообщения", lines=12, interactive=False)
1765
  personalized_messages = gr.Textbox(label="Персонализированные сообщения", lines=12, interactive=False)
 
1766
 
1767
  # Сначала переключаем вкладку, потом запускаем генерацию сообщений
1768
  btn_to_prompts.click(
@@ -1779,4 +1803,10 @@ with gr.Blocks() as demo:
1779
  ]
1780
  )
1781
 
 
 
 
 
 
 
1782
  demo.launch()
 
16
  import requests
17
  import base64
18
  import io
19
+ from transformers import AutoTokenizer, AutoModel
20
+ from utils import best_text_choice
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
23
+ model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")
24
 
25
  MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')
26
  token = os.getenv('GITHUB_TOKEN')
 
788
  time.sleep(1)
789
  save_statistics_to_github(approach_stats)
790
 
791
+ def rank_messages(non_personalized_messages, personalized_messages):
792
+ # Предполагается, что у вас есть DataFrame unique_sms_df, используемый в функции best_text_choice
793
+ unique_sms_df = pd.read_parquet('unique_texts.parquet')
794
+
795
+ # Разделяем сообщения на отдельные строки
796
+ non_personalized_list = [msg.strip() for msg in non_personalized_messages.strip().split('\n\n') if msg.strip()]
797
+ personalized_list = [msg.strip() for msg in personalized_messages.strip().split('\n\n') if msg.strip()]
798
+
799
+ # Ранжируем неперсонализированные сообщения
800
+ ranked_non_personalized = best_text_choice(non_personalized_list, unique_sms_df, tokenizer, model)
801
+ # Ранжируем персонализированные сообщения
802
+ ranked_personalized = best_text_choice(personalized_list, unique_sms_df, tokenizer, model)
803
+
804
+ # Формируем строки для отображения
805
+ ranked_non_personalized_messages = '\n\n'.join(ranked_non_personalized)
806
+ ranked_personalized_messages = '\n\n'.join(ranked_personalized)
807
+
808
+ return ranked_non_personalized_messages, ranked_personalized_messages
809
 
810
  # ФУНКЦИИ ПРОВЕРОК (НАЧАЛО)
811
 
 
1786
  with gr.Row():
1787
  non_personalized_messages = gr.Textbox(label="Стандартные сообщения", lines=12, interactive=False)
1788
  personalized_messages = gr.Textbox(label="Персонализированные сообщения", lines=12, interactive=False)
1789
+ rank_button = gr.Button("Ранжировать")
1790
 
1791
  # Сначала переключаем вкладку, потом запускаем генерацию сообщений
1792
  btn_to_prompts.click(
 
1803
  ]
1804
  )
1805
 
1806
+ rank_button.click(
1807
+ fn=rank_messages,
1808
+ inputs=[non_personalized_messages, personalized_messages],
1809
+ outputs=[non_personalized_messages, personalized_messages]
1810
+ )
1811
+
1812
  demo.launch()