File size: 5,218 Bytes
0dd5c06
 
 
cf196b3
f796553
 
cf196b3
73e8b86
cf196b3
f796553
cf196b3
67812d2
73e8b86
a19f11e
3c495cc
 
a19f11e
f796553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000d4f2
f796553
cf196b3
 
8ee349a
 
 
cf196b3
 
 
 
 
 
 
 
000d4f2
3c495cc
47db0c3
cf196b3
 
a089fa0
f00b7ff
a089fa0
3c495cc
 
 
 
 
 
 
 
 
 
 
 
 
47db0c3
3c495cc
a089fa0
f00b7ff
cf196b3
3c495cc
a089fa0
 
 
47db0c3
a089fa0
 
3c495cc
73e8b86
f00b7ff
a089fa0
 
 
73e8b86
300b938
cf196b3
3c495cc
 
 
 
cf196b3
 
 
 
 
 
 
 
 
 
 
 
 
 
3c495cc
 
cf196b3
871741c
 
cf196b3
 
3c495cc
 
cf196b3
6b89337
000d4f2
6b89337
 
 
 
f00b7ff
 
 
 
 
 
 
 
73e8b86
71d0339
cf196b3
 
a089fa0
 
cf196b3
a089fa0
3c495cc
 
a089fa0
 
f00b7ff
 
000d4f2
 
3c495cc
 
000d4f2
f00b7ff
 
 
 
73e8b86
a19f11e
 
73e8b86
 
 
6b89337
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
This app provides a platform for comparing the responses of two LLMs.
"""
import enum
import json
import os
from uuid import uuid4

import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import gradio as gr

from leaderboard import build_leaderboard
import response
from response import get_responses

# Local development: path to a credentials file on disk (optional).
CREDENTIALS_PATH = os.getenv("CREDENTIALS_PATH")

# Deployment: the same credentials supplied inline as a JSON string through
# the environment, so the secret never needs to be stored alongside the code.
CREDENTIALS = os.getenv("CREDENTIALS")


def get_credentials():
  """Build the Firebase credential used to initialize the app.

  A credentials file at CREDENTIALS_PATH takes precedence when it exists;
  otherwise the JSON document held in the CREDENTIALS environment variable
  is used.

  Returns:
    A `firebase_admin.credentials.Certificate`.

  Raises:
    ValueError: If no credentials file is available and CREDENTIALS is unset
      or empty, or if CREDENTIALS does not contain valid JSON.
  """
  # Prefer a local credentials file when one is configured and present.
  if CREDENTIALS_PATH and os.path.exists(CREDENTIALS_PATH):
    return credentials.Certificate(CREDENTIALS_PATH)

  # Otherwise fall back to the environment variable, as credentials should
  # not be public. Fail with an actionable message instead of the opaque
  # TypeError that json.loads(None) would raise.
  if not CREDENTIALS:
    raise ValueError(
        "Firebase credentials not found: set CREDENTIALS_PATH to an existing "
        "credentials file or CREDENTIALS to the credentials JSON.")

  json_cred = json.loads(CREDENTIALS)
  return credentials.Certificate(json_cred)


# Initialize the default Firebase app only once. If this module is executed
# again in the same process (e.g. an auto-reload, cf. #21), a second
# `initialize_app` call would raise ValueError because the default app
# already exists; `get_app` raises ValueError only when it does not exist yet.
try:
  firebase_admin.get_app()
except ValueError:
  firebase_admin.initialize_app(get_credentials())
db = firestore.client()

# Languages offered in the source/target dropdowns for translation prompts.
SUPPORTED_TRANSLATION_LANGUAGES = [
    "Korean",
    "English",
    "Chinese",
    "Japanese",
    "Spanish",
    "French",
]


class VoteOptions(enum.Enum):
  """Choices a user can make when judging two model responses.

  Each value doubles as the label of its vote button, and the lower-cased
  member name is what gets recorded as the vote's winner.
  """
  # Left-hand response preferred.
  MODEL_A = "Model A is better"
  # Right-hand response preferred.
  MODEL_B = "Model B is better"
  # Neither response preferred.
  TIE = "Tie"


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         user_prompt, instruction, category, source_lang, target_lang):
  """Store a user's vote in Firestore and lock the voting UI.

  Args:
    vote_button: Label of the clicked button; one of the VoteOptions values.
    response_a: Text shown in the "Model A" box.
    response_b: Text shown in the "Model B" box.
    model_a_name: Name of the model that produced response_a.
    model_b_name: Name of the model that produced response_b.
    user_prompt: Contents of the prompt box.
    instruction: Instruction value produced by `get_responses` on submit.
    category: Selected `response.Category` value.
    source_lang: Source language; required for translation votes.
    target_lang: Target language; required for translation votes.

  Returns:
    Updates disabling the three vote buttons, followed by an update that
    makes the model-name row visible.

  Raises:
    gr.Error: If no responses have been generated yet, the category is not
      a supported one, or a translation vote is missing a language.
  """
  # The vote buttons are clickable before any prompt has been submitted
  # (see TODO(#5) in the layout); without this guard an empty comparison
  # with blank model names would be stored.
  if not model_a_name or not model_b_name:
    raise gr.Error("Please submit a prompt before voting.")

  # Each supported category is stored in its own collection.
  collection_by_category = {
      response.Category.SUMMARIZE.value: "arena-summarizations",
      response.Category.TRANSLATE.value: "arena-translations",
  }
  collection = collection_by_category.get(category)
  if collection is None:
    raise gr.Error("Please select a response type.")

  doc_id = uuid4().hex
  doc = {
      "id": doc_id,
      "prompt": user_prompt,
      "instruction": instruction,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      "winner": VoteOptions(vote_button).name.lower(),
      "timestamp": firestore.SERVER_TIMESTAMP
  }

  if category == response.Category.TRANSLATE.value:
    if not source_lang or not target_lang:
      raise gr.Error("Please select source and target languages.")
    doc["source_language"] = source_lang.lower()
    doc["target_language"] = target_lang.lower()

  db.collection(collection).document(doc_id).set(doc)

  # One vote per comparison: disable the buttons and reveal the model names.
  deactivated_buttons = [gr.Button(interactive=False) for _ in range(3)]
  return deactivated_buttons + [gr.Row(visible=True)]


with gr.Blocks(title="Arena") as app:
  with gr.Row():
    category_radio = gr.Radio(
        [category.value for category in response.Category],
        label="Category",
        info="The chosen category determines the instruction sent to the LLMs.")

    # The language pickers only apply to translation, so they start hidden
    # and are toggled by `update_language_visibility` below.
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(category):
      """Show the language dropdowns only for the translation category."""
      visible = category == response.Category.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    category_radio.change(update_language_visibility, category_radio,
                          [source_language, target_language])

  prompt = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  # Build the component lists directly; creating placeholder gr.State objects
  # first would register unused components in the Blocks.
  with gr.Group():
    with gr.Row():
      response_boxes = [
          gr.Textbox(label="Model A", interactive=False),
          gr.Textbox(label="Model B", interactive=False),
      ]

    # Hidden until a vote is cast (the vote handler returns a visible Row).
    with gr.Row(visible=False) as model_name_row:
      model_names = [
          gr.Textbox(show_label=False),
          gr.Textbox(show_label=False),
      ]

  # TODO(#5): Display it only after the user submits the prompt.
  with gr.Row():
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  vote_buttons = [option_a, option_b, tie]
  # Written by `get_responses` (it is one of the submit outputs below) and
  # read back when a vote is recorded.
  instruction_state = gr.State("")

  submit.click(
      get_responses, [prompt, category_radio, source_language, target_language],
      response_boxes + model_names + vote_buttons +
      [instruction_state, model_name_row])

  # NOTE(review): prompt, category and languages are read from the live
  # widgets when a vote is cast, so edits made after submitting are what get
  # stored. Consider snapshotting them into gr.State at submit time.
  common_inputs = response_boxes + model_names + [
      prompt, instruction_state, category_radio, source_language,
      target_language
  ]
  common_outputs = vote_buttons + [model_name_row]
  option_a.click(vote, [option_a] + common_inputs, common_outputs)
  option_b.click(vote, [option_b] + common_inputs, common_outputs)
  tie.click(vote, [tie] + common_inputs, common_outputs)

  build_leaderboard(db)

if __name__ == "__main__":
  # Generator (streaming) handlers require the queue, so enable it before
  # launching; `queue()` returns the Blocks instance, allowing the chain.
  app.queue().launch(debug=True)