songhune commited on
Commit
72a7b9b
·
1 Parent(s): d04b805

애초에 함수 자체가 동기적으로 되어있으니까...

Browse files
Files changed (2) hide show
  1. chatbot_utils.py +16 -4
  2. gradio_interface.py +14 -20
chatbot_utils.py CHANGED
@@ -3,6 +3,7 @@ from openai import OpenAI
3
  import json
4
  from datetime import datetime
5
  from scenario_handler import ScenarioHandler
 
6
 
7
  client = OpenAI(api_key="sk-proj-3IEelWYK3Wl251k9qNriT3BlbkFJ9M7GpUGBijobUj1LETdu")
8
 
@@ -31,6 +32,18 @@ def chatbot_response(response, handler_type='offender', n=1):
31
  choices = [choice.message.content for choice in api_response.choices]
32
  return choices[0], choices
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def save_history(history):
35
  os.makedirs('logs', exist_ok=True)
36
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -39,13 +52,12 @@ def save_history(history):
39
  json.dump(history, file, ensure_ascii=False, indent=4)
40
  print(f"History saved to {filename}")
41
 
42
- def process_user_input(user_input, chatbot_history):
43
  if user_input.strip().lower() == "종료":
44
  save_history(chatbot_history)
45
  return chatbot_history + [("종료", "실험에 참가해 주셔서 감사합니다. 후속 지시를 따라주세요")], []
46
 
47
- offender_response, _ = chatbot_response(user_input, 'offender', n=1)
48
- new_history = chatbot_history + [(user_input, offender_response)]
49
 
50
- _, victim_choices = chatbot_response(offender_response, 'victim', n=3)
51
  return new_history, victim_choices
 
3
  import json
4
  from datetime import datetime
5
  from scenario_handler import ScenarioHandler
6
+ import asyncio
7
 
8
  client = OpenAI(api_key="sk-proj-3IEelWYK3Wl251k9qNriT3BlbkFJ9M7GpUGBijobUj1LETdu")
9
 
 
32
  choices = [choice.message.content for choice in api_response.choices]
33
  return choices[0], choices
34
 
35
+ async def async_chatbot_response(response, handler_type='offender', n=1):
36
+ return chatbot_response(response, handler_type, n)
37
+
38
+ async def get_both_responses(user_input):
39
+ victim_task = asyncio.create_task(async_chatbot_response(user_input, 'victim', n=3))
40
+ offender_task = asyncio.create_task(async_chatbot_response(user_input, 'offender', n=1))
41
+
42
+ victim_response, victim_choices = await victim_task
43
+ offender_response, _ = await offender_task
44
+
45
+ return victim_response, victim_choices, offender_response
46
+
47
  def save_history(history):
48
  os.makedirs('logs', exist_ok=True)
49
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
52
  json.dump(history, file, ensure_ascii=False, indent=4)
53
  print(f"History saved to {filename}")
54
 
55
+ async def process_user_input(user_input, chatbot_history):
56
  if user_input.strip().lower() == "종료":
57
  save_history(chatbot_history)
58
  return chatbot_history + [("종료", "실험에 참가해 주셔서 감사합니다. 후속 지시를 따라주세요")], []
59
 
60
+ victim_response, victim_choices, offender_response = await get_both_responses(user_input)
61
+ new_history = chatbot_history + [(user_input, victim_response), (None, offender_response)]
62
 
 
63
  return new_history, victim_choices
gradio_interface.py CHANGED
@@ -1,35 +1,24 @@
1
  import gradio as gr
2
- from chatbot_utils import process_user_input, chatbot_response
 
3
 
4
  def create_interface():
5
- def handle_user_response(user_input, selected_response, chatbot_history):
6
  input_text = user_input if user_input else selected_response
7
 
8
  if input_text.strip().lower() == "종료":
9
- new_history = chatbot_history + [("종료", "실험에 참가해 주셔서 감사합니다. 후속 지시를 따라주세요")]
10
  return new_history, gr.update(choices=[], interactive=False)
11
 
12
- # Get victim's response first
13
- victim_response, victim_choices = chatbot_response(input_text, 'victim', n=3)
14
-
15
- # Then get offender's response
16
- offender_response, _ = chatbot_response(input_text, 'offender', n=1)
17
-
18
- new_history = chatbot_history + [(input_text, victim_response), (None, offender_response)]
19
  return new_history, gr.update(choices=victim_choices)
20
 
21
- def handle_case_selection():
22
  initial_message = "발표가 망한 건 제 잘못도 좀 있지만, 팀장님은 아무것도 안 하면서 이러는 건 선 넘은거죠"
23
- chatbot_history = [(initial_message, None)]
24
 
25
- # Get victim's response first
26
- victim_response, victim_choices = chatbot_response(initial_message, 'victim', n=3)
27
- chatbot_history.append((None, victim_response))
28
-
29
- # Then get offender's response
30
- offender_response, _ = chatbot_response(initial_message, 'offender', n=1)
31
- chatbot_history.append((None, offender_response))
32
 
 
33
  return chatbot_history, gr.update(choices=victim_choices)
34
 
35
  with gr.Blocks() as demo:
@@ -44,4 +33,9 @@ def create_interface():
44
 
45
  submit_button.click(handle_user_response, inputs=[user_input, response_choices, screen], outputs=[screen, response_choices])
46
 
47
- return demo
 
 
 
 
 
 
1
  import gradio as gr
2
+ from chatbot_utils import process_user_input, get_both_responses
3
+ import asyncio
4
 
5
  def create_interface():
6
+ async def handle_user_response(user_input, selected_response, chatbot_history):
7
  input_text = user_input if user_input else selected_response
8
 
9
  if input_text.strip().lower() == "종료":
10
+ new_history = chatbot_history + [(input_text, "실험에 참가해 주셔서 감사합니다. 후속 지시를 따라주세요")]
11
  return new_history, gr.update(choices=[], interactive=False)
12
 
13
+ new_history, victim_choices = await process_user_input(input_text, chatbot_history)
 
 
 
 
 
 
14
  return new_history, gr.update(choices=victim_choices)
15
 
16
+ async def handle_case_selection():
17
  initial_message = "발표가 망한 건 제 잘못도 좀 있지만, 팀장님은 아무것도 안 하면서 이러는 건 선 넘은거죠"
 
18
 
19
+ victim_response, victim_choices, offender_response = await get_both_responses(initial_message)
 
 
 
 
 
 
20
 
21
+ chatbot_history = [(initial_message, victim_response), (None, offender_response)]
22
  return chatbot_history, gr.update(choices=victim_choices)
23
 
24
  with gr.Blocks() as demo:
 
33
 
34
  submit_button.click(handle_user_response, inputs=[user_input, response_choices, screen], outputs=[screen, response_choices])
35
 
36
+ return demo
37
+
38
+ # Gradio의 launch 함수를 사용하여 인터페이스를 실행합니다.
39
+ if __name__ == "__main__":
40
+ demo = create_interface()
41
+ demo.launch()