Kang Suhyun committed on
Commit
0ac094d
2 Parent(s): 5e33531 71d0339

Merge pull request #3 from Y-IAB/1-vote

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +108 -0
  3. requirments.txt +12 -0
.gitignore CHANGED
@@ -1 +1,2 @@
1
  venv
 
 
1
  venv
2
+ *.log
app.py CHANGED
@@ -2,15 +2,74 @@
2
  It provides a platform for comparing the responses of two LLMs.
3
  """
4
 
 
 
5
  from random import sample
 
6
 
7
  from fastchat.serve import gradio_web_server
8
  from fastchat.serve.gradio_web_server import bot_response
 
 
9
  import gradio as gr
10
 
 
 
 
11
  # TODO(#1): Add more models.
12
  SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def user(user_prompt):
16
  model_pair = sample(SUPPORTED_MODELS, 2)
@@ -85,6 +144,35 @@ def bot(state_a, state_b, request: gr.Request):
85
 
86
 
87
  with gr.Blocks() as app:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  model_names = [gr.State(None), gr.State(None)]
89
  responses = [gr.State(None), gr.State(None)]
90
 
@@ -98,6 +186,26 @@ with gr.Blocks() as app:
98
  responses[0] = gr.Textbox(label="Model A", interactive=False)
99
  responses[1] = gr.Textbox(label="Model B", interactive=False)
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  with gr.Accordion("Show models", open=False):
102
  with gr.Row():
103
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
 
2
  It provides a platform for comparing the responses of two LLMs.
3
  """
4
 
5
+ import enum
6
+ import json
7
  from random import sample
8
+ from uuid import uuid4
9
 
10
  from fastchat.serve import gradio_web_server
11
  from fastchat.serve.gradio_web_server import bot_response
12
+ import firebase_admin
13
+ from firebase_admin import firestore
14
  import gradio as gr
15
 
16
+ db_app = firebase_admin.initialize_app()
17
+ db = firestore.client()
18
+
19
  # TODO(#1): Add more models.
20
  SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
21
 
22
+ # TODO(#4): Add more languages.
23
+ SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
24
+
25
+
26
@enum.unique
class ResponseType(enum.Enum):
    """Kinds of response a user can request from the models.

    Each member's value is the human-readable label shown in the
    "Response type" radio group; vote() compares the selected label
    against these values to pick the Firestore collection to write to.
    """

    # @enum.unique makes an accidental duplicate label fail at import time
    # instead of silently turning one member into an alias of another.
    SUMMARIZE = "Summarize"
    TRANSLATE = "Translate"
29
+
30
+
31
@enum.unique
class VoteOptions(enum.Enum):
    """Choices a user can make when voting on a pair of responses.

    Each member's value is the text of the corresponding vote button.
    vote() looks a member up by that text and stores the lower-cased
    member name (e.g. "model_a") as the winner.
    """

    # Lookup is by value, so values must stay distinct: a duplicated label
    # would resolve to the first member and record the wrong winner.
    # @enum.unique enforces this at import time.
    MODEL_A = "Model A is better"
    MODEL_B = "Model B is better"
    TIE = "Tie"
35
+
36
+
37
def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
    """Store one user vote for a pair of model responses in Firestore.

    Args:
        state_a: Conversation state of model A. Must expose ``dict()`` and
            ``model_name``.
        state_b: Conversation state of model B. Same requirements as
            ``state_a``.
        vote_button: Text of the vote button that was clicked. Must equal
            one of the ``VoteOptions`` values.
        res_type: Selected response type as a ``ResponseType`` value
            string. For any other value (including ``None``) nothing is
            written.
        source_lang: Source language of the translation. Only used when
            ``res_type`` is the translate option.
        target_lang: Target language of the translation. Only used when
            ``res_type`` is the translate option.

    Raises:
        ValueError: If ``vote_button`` is not a ``VoteOptions`` value.
        gr.Error: If a translation vote is cast without both languages
            selected.
    """
    doc_id = uuid4().hex
    winner = VoteOptions(vote_button).name.lower()

    # Fields shared by every kind of vote record.
    record = {
        "id": doc_id,
        "model_a": state_a.model_name,
        "model_b": state_b.model_name,
        # The 'messages' field in the state is an array of arrays, which is
        # not supported by Firestore. Therefore, we convert it to a JSON
        # string.
        "model_a_conv": json.dumps(state_a.dict()),
        "model_b_conv": json.dumps(state_b.dict()),
        "winner": winner,
        "timestamp": firestore.SERVER_TIMESTAMP,
    }

    if res_type == ResponseType.SUMMARIZE.value:
        collection_name = "arena-summarizations"
    elif res_type == ResponseType.TRANSLATE.value:
        # The language dropdowns can presumably yield an empty value when
        # nothing has been picked; calling .lower() on that would fail with
        # an opaque AttributeError, so report the problem to the user.
        if not source_lang or not target_lang:
            raise gr.Error(
                "Please select both the source and target languages "
                "before voting.")
        record["source_language"] = source_lang.lower()
        record["target_language"] = target_lang.lower()
        collection_name = "arena-translations"
    else:
        # NOTE(review): a vote cast without a recognized response type is
        # dropped without feedback, matching the previous behavior; blocking
        # such votes in the UI is tracked by TODO(#6).
        return

    db.collection(collection_name).document(doc_id).set(record)
72
+
73
 
74
  def user(user_prompt):
75
  model_pair = sample(SUPPORTED_MODELS, 2)
 
144
 
145
 
146
  with gr.Blocks() as app:
147
+ with gr.Row():
148
+ response_type_radio = gr.Radio(
149
+ [response_type.value for response_type in ResponseType],
150
+ label="Response type",
151
+ info="Choose the type of response you want from the model.")
152
+
153
+ source_language = gr.Dropdown(
154
+ choices=SUPPORTED_TRANSLATION_LANGUAGES,
155
+ label="Source language",
156
+ info="Choose the source language for translation.",
157
+ interactive=True,
158
+ visible=False)
159
+ target_language = gr.Dropdown(
160
+ choices=SUPPORTED_TRANSLATION_LANGUAGES,
161
+ label="Target language",
162
+ info="Choose the target language for translation.",
163
+ interactive=True,
164
+ visible=False)
165
+
166
+ def update_language_visibility(response_type):
167
+ visible = response_type == ResponseType.TRANSLATE.value
168
+ return {
169
+ source_language: gr.Dropdown(visible=visible),
170
+ target_language: gr.Dropdown(visible=visible)
171
+ }
172
+
173
+ response_type_radio.change(update_language_visibility, response_type_radio,
174
+ [source_language, target_language])
175
+
176
  model_names = [gr.State(None), gr.State(None)]
177
  responses = [gr.State(None), gr.State(None)]
178
 
 
186
  responses[0] = gr.Textbox(label="Model A", interactive=False)
187
  responses[1] = gr.Textbox(label="Model B", interactive=False)
188
 
189
+ # TODO(#5): Display it only after the user submits the prompt.
190
+ # TODO(#6): Block voting if the response_type is not set.
191
+ # TODO(#6): Block voting if the user already voted.
192
+ with gr.Row():
193
+ option_a = gr.Button(VoteOptions.MODEL_A.value)
194
+ option_a.click(
195
+ vote, states +
196
+ [option_a, response_type_radio, source_language, target_language])
197
+
198
+ option_b = gr.Button("Model B is better")
199
+ option_b.click(
200
+ vote, states +
201
+ [option_b, response_type_radio, source_language, target_language])
202
+
203
+ tie = gr.Button("Tie")
204
+ tie.click(
205
+ vote,
206
+ states + [tie, response_type_radio, source_language, target_language])
207
+
208
+ # TODO(#7): Hide it until the user votes.
209
  with gr.Accordion("Show models", open=False):
210
  with gr.Row():
211
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
requirments.txt CHANGED
@@ -6,26 +6,33 @@ altair==5.2.0
6
  annotated-types==0.6.0
7
  anyio==4.2.0
8
  attrs==23.2.0
 
9
  cachetools==5.3.2
10
  certifi==2023.11.17
 
11
  charset-normalizer==3.3.2
12
  click==8.1.7
13
  colorama==0.4.6
14
  contourpy==1.2.0
 
15
  cycler==0.12.1
16
  distro==1.9.0
17
  fastapi==0.109.0
18
  ffmpy==0.3.1
19
  filelock==3.13.1
 
20
  fonttools==4.47.2
21
  frozenlist==1.4.1
22
  fschat==0.2.35
23
  fsspec==2023.12.2
24
  google-api-core==2.16.1
 
25
  google-auth==2.27.0
 
26
  google-cloud-aiplatform==1.40.0
27
  google-cloud-bigquery==3.17.1
28
  google-cloud-core==2.4.1
 
29
  google-cloud-resource-manager==1.11.0
30
  google-cloud-storage==2.14.0
31
  google-crc32c==1.5.0
@@ -38,6 +45,7 @@ grpcio==1.60.0
38
  grpcio-status==1.60.0
39
  h11==0.14.0
40
  httpcore==1.0.2
 
41
  httpx==0.26.0
42
  huggingface-hub==0.20.3
43
  idna==3.6
@@ -52,6 +60,7 @@ MarkupSafe==2.1.4
52
  matplotlib==3.8.2
53
  mdurl==0.1.2
54
  mpmath==1.3.0
 
55
  multidict==6.0.4
56
  networkx==3.2.1
57
  nh3==0.2.15
@@ -68,10 +77,12 @@ protobuf==4.25.2
68
  psutil==5.9.8
69
  pyasn1==0.5.1
70
  pyasn1-modules==0.3.0
 
71
  pydantic==1.10.14
72
  pydantic_core==2.16.1
73
  pydub==0.25.1
74
  Pygments==2.17.2
 
75
  pyparsing==3.1.1
76
  python-dateutil==2.8.2
77
  python-multipart==0.0.6
@@ -105,6 +116,7 @@ transformers==4.37.2
105
  typer==0.9.0
106
  typing_extensions==4.9.0
107
  tzdata==2023.4
 
108
  urllib3==2.2.0
109
  uvicorn==0.27.0.post1
110
  wavedrom==2.0.3.post3
 
6
  annotated-types==0.6.0
7
  anyio==4.2.0
8
  attrs==23.2.0
9
+ CacheControl==0.13.1
10
  cachetools==5.3.2
11
  certifi==2023.11.17
12
+ cffi==1.16.0
13
  charset-normalizer==3.3.2
14
  click==8.1.7
15
  colorama==0.4.6
16
  contourpy==1.2.0
17
+ cryptography==42.0.2
18
  cycler==0.12.1
19
  distro==1.9.0
20
  fastapi==0.109.0
21
  ffmpy==0.3.1
22
  filelock==3.13.1
23
+ firebase-admin==6.4.0
24
  fonttools==4.47.2
25
  frozenlist==1.4.1
26
  fschat==0.2.35
27
  fsspec==2023.12.2
28
  google-api-core==2.16.1
29
+ google-api-python-client==2.116.0
30
  google-auth==2.27.0
31
+ google-auth-httplib2==0.2.0
32
  google-cloud-aiplatform==1.40.0
33
  google-cloud-bigquery==3.17.1
34
  google-cloud-core==2.4.1
35
+ google-cloud-firestore==2.14.0
36
  google-cloud-resource-manager==1.11.0
37
  google-cloud-storage==2.14.0
38
  google-crc32c==1.5.0
 
45
  grpcio-status==1.60.0
46
  h11==0.14.0
47
  httpcore==1.0.2
48
+ httplib2==0.22.0
49
  httpx==0.26.0
50
  huggingface-hub==0.20.3
51
  idna==3.6
 
60
  matplotlib==3.8.2
61
  mdurl==0.1.2
62
  mpmath==1.3.0
63
+ msgpack==1.0.7
64
  multidict==6.0.4
65
  networkx==3.2.1
66
  nh3==0.2.15
 
77
  psutil==5.9.8
78
  pyasn1==0.5.1
79
  pyasn1-modules==0.3.0
80
+ pycparser==2.21
81
  pydantic==1.10.14
82
  pydantic_core==2.16.1
83
  pydub==0.25.1
84
  Pygments==2.17.2
85
+ PyJWT==2.8.0
86
  pyparsing==3.1.1
87
  python-dateutil==2.8.2
88
  python-multipart==0.0.6
 
116
  typer==0.9.0
117
  typing_extensions==4.9.0
118
  tzdata==2023.4
119
+ uritemplate==4.1.1
120
  urllib3==2.2.0
121
  uvicorn==0.27.0.post1
122
  wavedrom==2.0.3.post3