File size: 6,664 Bytes
0dd5c06
 
 
cf196b3
 
73e8b86
cf196b3
67812d2
3b5ccf8
73e8b86
a19f11e
076f69b
65566f3
2a0aa5a
 
3c495cc
 
a19f11e
3b5ccf8
 
cf196b3
 
 
 
 
 
 
000d4f2
43c8549
47db0c3
cf196b3
 
a089fa0
f00b7ff
a089fa0
3c495cc
 
43c8549
3c495cc
 
 
 
 
 
 
 
 
 
3b5ccf8
 
 
47db0c3
3b5ccf8
 
3c495cc
a089fa0
f00b7ff
cf196b3
3c495cc
a089fa0
 
 
47db0c3
a089fa0
 
3c495cc
73e8b86
f00b7ff
a089fa0
 
 
73e8b86
b272a27
 
 
 
 
 
 
 
 
cf196b3
3c495cc
8a26fe6
 
3c495cc
 
cf196b3
 
 
8a26fe6
cf196b3
 
 
 
 
 
8a26fe6
cf196b3
 
 
 
 
3c495cc
 
cf196b3
871741c
 
cf196b3
 
3c495cc
 
cf196b3
6b89337
000d4f2
6b89337
43c8549
6b89337
 
f00b7ff
 
61f356c
 
 
f00b7ff
 
 
 
73e8b86
4b31650
cf196b3
a089fa0
 
cf196b3
3c495cc
 
60e8d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9342b4
 
60e8d65
 
a9342b4
60e8d65
43c8549
 
 
 
60e8d65
 
 
 
 
 
 
 
 
 
000d4f2
a9342b4
 
 
 
 
 
 
 
 
 
 
 
 
73e8b86
076f69b
a19f11e
73e8b86
2a0aa5a
fa7ac61
73e8b86
 
6b89337
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
It provides a platform for comparing the responses of two LLMs. 
"""
import enum
from uuid import uuid4

from firebase_admin import firestore
import gradio as gr
import lingua

from leaderboard import build_leaderboard
from leaderboard import db
from leaderboard import SUPPORTED_TRANSLATION_LANGUAGES
from model import check_models
from model import supported_models
import response
from response import get_responses

# Language detector used to tag summarization responses with the language
# they were written in; built once at module load and shared by all votes.
detector = lingua.LanguageDetectorBuilder.from_all_languages().build()


class VoteOptions(enum.Enum):
  """Labels of the three vote buttons shown to the user.

  The member name (lowercased) is what `vote` stores in Firestore as the
  "winner" field; the member value is the visible button text.
  """
  MODEL_A = "Model A is better"
  MODEL_B = "Model B is better"
  TIE = "Tie"


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         prompt, instruction, category, source_lang, target_lang):
  """Persists a user's vote to Firestore and deactivates the vote buttons.

  Args:
    vote_button: Label of the clicked button; must be a VoteOptions value.
    response_a: Model A's response text.
    response_b: Model B's response text.
    model_a_name: Name of model A.
    model_b_name: Name of model B.
    prompt: The user-entered prompt.
    instruction: The instruction that was sent to the models.
    category: The selected category (a response.Category value).
    source_lang: Source language for translation; may be None/empty.
    target_lang: Target language for translation; may be None/empty.

  Returns:
    Gradio updates deactivating the three vote buttons and revealing the
    model-name row.

  Raises:
    gr.Error: If the translation languages are missing, or the category is
      not a recognized response type.
  """
  doc_id = uuid4().hex
  winner = VoteOptions(vote_button).name.lower()

  deactivated_buttons = [gr.Button(interactive=False) for _ in range(3)]
  outputs = deactivated_buttons + [gr.Row(visible=True)]

  doc = {
      "id": doc_id,
      "prompt": prompt,
      "instruction": instruction,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      "winner": winner,
      "timestamp": firestore.SERVER_TIMESTAMP
  }

  if category == response.Category.SUMMARIZE.value:
    # lingua returns None when it cannot reliably identify a language
    # (e.g. an empty or very short response), so guard before reading
    # .name — otherwise the vote would be lost to an AttributeError.
    language_a = detector.detect_language_of(response_a)
    language_b = detector.detect_language_of(response_b)

    doc["model_a_response_language"] = (language_a.name.lower()
                                        if language_a else None)
    doc["model_b_response_language"] = (language_b.name.lower()
                                        if language_b else None)
    db.collection("arena-summarizations").document(doc_id).set(doc)

    return outputs

  if category == response.Category.TRANSLATE.value:
    if not source_lang or not target_lang:
      raise gr.Error("Please select source and target languages.")

    doc["source_language"] = source_lang.lower()
    doc["target_language"] = target_lang.lower()
    db.collection("arena-translations").document(doc_id).set(doc)

    return outputs

  raise gr.Error("Please select a response type.")


# Removes the persistent orange border from the leaderboard, which
# appears due to the 'generating' class when using the 'every' parameter.
css = """
.leaderboard .generating {
  border: none;
}
"""

with gr.Blocks(title="Arena", css=css) as app:
  with gr.Row():
    # The selected category drives which instruction get_responses sends
    # to the models (see response.Category).
    category_radio = gr.Radio(
        choices=[category.value for category in response.Category],
        value=response.Category.SUMMARIZE.value,
        label="Category",
        info="The chosen category determines the instruction sent to the LLMs.")

    # Language dropdowns start hidden; they are only shown for the
    # translation category (toggled by update_language_visibility below).
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        value="English",
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        value="Spanish",
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(category):
      """Shows the language dropdowns only for the translate category."""
      visible = category == response.Category.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    category_radio.change(update_language_visibility, category_radio,
                          [source_language, target_language])

  # Placeholder States; each is replaced below by a real Textbox once the
  # layout rows exist (gr.State here is just a pre-declaration).
  model_names = [gr.State(None), gr.State(None)]
  response_boxes = [gr.State(None), gr.State(None)]

  prompt_textarea = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  with gr.Group():
    with gr.Row():
      response_boxes[0] = gr.Textbox(label="Model A", interactive=False)

      response_boxes[1] = gr.Textbox(label="Model B", interactive=False)

    # Model names stay hidden until after the user has voted, to avoid
    # biasing the vote.
    with gr.Row(visible=False) as model_name_row:
      model_names[0] = gr.Textbox(show_label=False)
      model_names[1] = gr.Textbox(show_label=False)

  with gr.Row(visible=False) as vote_row:
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  # Holds the instruction actually sent to the models, so vote can log it.
  instruction_state = gr.State("")

  # The following elements need to be reset when the user changes
  # the category, source language, or target language.
  ui_elements = [
      response_boxes[0], response_boxes[1], model_names[0], model_names[1],
      instruction_state, model_name_row, vote_row
  ]

  def reset_ui():
    """Clears responses/model names and hides the name and vote rows."""
    # Order matches ui_elements: 4 textboxes, instruction state, 2 rows.
    return [gr.Textbox(value="") for _ in range(4)
           ] + [gr.State(""),
                gr.Row(visible=False),
                gr.Row(visible=False)]

  category_radio.change(fn=reset_ui, outputs=ui_elements)
  source_language.change(fn=reset_ui, outputs=ui_elements)
  target_language.change(fn=reset_ui, outputs=ui_elements)

  # On submit: (1) lock the inputs and hide stale vote/name rows, then
  # (2) stream the two model responses into the response boxes.
  submit_event = submit.click(
      fn=lambda: [
          gr.Radio(interactive=False),
          gr.Dropdown(interactive=False),
          gr.Dropdown(interactive=False),
          gr.Button(interactive=False),
          gr.Row(visible=False),
          gr.Row(visible=False),
      ] + [gr.Button(interactive=True) for _ in range(3)],
      outputs=[
          category_radio, source_language, target_language, submit, vote_row,
          model_name_row, option_a, option_b, tie
      ]).then(fn=get_responses,
              inputs=[
                  prompt_textarea, category_radio, source_language,
                  target_language
              ],
              outputs=response_boxes + model_names + [instruction_state])
  # .success fires only if get_responses raised no error, so the vote row
  # never appears for a failed generation; .then re-enables the inputs
  # regardless of success or failure.
  submit_event.success(fn=lambda: gr.Row(visible=True), outputs=vote_row)
  submit_event.then(
      fn=lambda: [
          gr.Radio(interactive=True),
          gr.Dropdown(interactive=True),
          gr.Dropdown(interactive=True),
          gr.Button(interactive=True)
      ],
      outputs=[category_radio, source_language, target_language, submit])

  def deactivate_after_voting(option_button: gr.Button):
    """Wires a vote button: record the vote, then disable all three."""
    option_button.click(
        fn=vote,
        inputs=[option_button] + response_boxes + model_names + [
            prompt_textarea, instruction_state, category_radio, source_language,
            target_language
        ],
        outputs=[option_a, option_b, tie, model_name_row]).then(
            fn=lambda: [gr.Button(interactive=False) for _ in range(3)],
            outputs=[option_a, option_b, tie])

  for option in [option_a, option_b, tie]:
    deactivate_after_voting(option)

  build_leaderboard()

if __name__ == "__main__":
  check_models(supported_models)

  # We need to enable queue to use generators.
  app.queue()
  app.launch(debug=True)