yuantao-infini-ai commited on
Commit
cf1798b
1 Parent(s): 4f617e5

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +3 -9
  2. __init__.py +0 -0
  3. __pycache__/__init__.cpython-310.pyc +0 -0
  4. __pycache__/__init__.cpython-311.pyc +0 -0
  5. __pycache__/api_provider.cpython-310.pyc +0 -0
  6. __pycache__/base_model_worker.cpython-310.pyc +0 -0
  7. __pycache__/cli.cpython-310.pyc +0 -0
  8. __pycache__/cli.cpython-311.pyc +0 -0
  9. __pycache__/controller.cpython-310.pyc +0 -0
  10. __pycache__/gradio_web_server.cpython-310.pyc +0 -0
  11. __pycache__/inference.cpython-310.pyc +0 -0
  12. __pycache__/model_worker.cpython-310.pyc +0 -0
  13. __pycache__/test_message.cpython-310.pyc +0 -0
  14. api_provider.py +130 -0
  15. base_model_worker.py +239 -0
  16. cli.py +313 -0
  17. controller.py +348 -0
  18. gateway/README.md +57 -0
  19. gateway/nginx.conf +97 -0
  20. gradio_block_arena_anony.py +608 -0
  21. gradio_block_arena_named.py +458 -0
  22. gradio_web_server.py +883 -0
  23. gradio_web_server_multi.py +270 -0
  24. huggingface_api.py +73 -0
  25. huggingface_api_worker.py +391 -0
  26. inference.py +596 -0
  27. launch_all_serve.py +284 -0
  28. model_worker.py +363 -0
  29. monitor/basic_stats.py +210 -0
  30. monitor/clean_battle_data.py +269 -0
  31. monitor/clean_chat_data.py +171 -0
  32. monitor/dataset_release_scripts/arena_33k/count_unique_users.py +25 -0
  33. monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py +155 -0
  34. monitor/dataset_release_scripts/arena_33k/merge_field.py +25 -0
  35. monitor/dataset_release_scripts/arena_33k/sample.py +32 -0
  36. monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py +9 -0
  37. monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py +13 -0
  38. monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py +119 -0
  39. monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py +148 -0
  40. monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py +27 -0
  41. monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md +23 -0
  42. monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py +45 -0
  43. monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh +18 -0
  44. monitor/dataset_release_scripts/lmsys_chat_1m/sample.py +32 -0
  45. monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py +17 -0
  46. monitor/elo_analysis.py +303 -0
  47. monitor/inspect_conv.py +87 -0
  48. monitor/intersect_conv_file.py +25 -0
  49. monitor/leaderboard_csv_to_html.py +51 -0
  50. monitor/monitor.py +313 -0
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Demo Test
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.39.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: demo_test
3
+ app_file: gradio_web_server.py
 
 
4
  sdk: gradio
5
+ sdk_version: 3.45.0
 
 
6
  ---
 
 
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes). View file
 
__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes). View file
 
__pycache__/api_provider.cpython-310.pyc ADDED
Binary file (2.69 kB). View file
 
__pycache__/base_model_worker.cpython-310.pyc ADDED
Binary file (7.01 kB). View file
 
__pycache__/cli.cpython-310.pyc ADDED
Binary file (9 kB). View file
 
__pycache__/cli.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
__pycache__/controller.cpython-310.pyc ADDED
Binary file (9.35 kB). View file
 
__pycache__/gradio_web_server.cpython-310.pyc ADDED
Binary file (20.6 kB). View file
 
__pycache__/inference.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
__pycache__/model_worker.cpython-310.pyc ADDED
Binary file (9.37 kB). View file
 
__pycache__/test_message.cpython-310.pyc ADDED
Binary file (2.22 kB). View file
 
api_provider.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Call API providers."""
2
+
3
+ import os
4
+ import random
5
+ import time
6
+
7
+ from fastchat.utils import build_logger
8
+ from fastchat.constants import WORKER_API_TIMEOUT
9
+
10
+
11
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
12
+
13
+
14
+ def openai_api_stream_iter(
15
+ model_name,
16
+ messages,
17
+ temperature,
18
+ top_p,
19
+ max_new_tokens,
20
+ api_base=None,
21
+ api_key=None,
22
+ ):
23
+ import openai
24
+
25
+ openai.api_base = api_base or "https://api.openai.com/v1"
26
+ openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
27
+ if model_name == "gpt-4-turbo":
28
+ model_name = "gpt-4-1106-preview"
29
+
30
+ # Make requests
31
+ gen_params = {
32
+ "model": model_name,
33
+ "prompt": messages,
34
+ "temperature": temperature,
35
+ "top_p": top_p,
36
+ "max_new_tokens": max_new_tokens,
37
+ }
38
+ logger.info(f"==== request ====\n{gen_params}")
39
+
40
+ res = openai.ChatCompletion.create(
41
+ model=model_name,
42
+ messages=messages,
43
+ temperature=temperature,
44
+ max_tokens=max_new_tokens,
45
+ stream=True,
46
+ )
47
+ text = ""
48
+ for chunk in res:
49
+ text += chunk["choices"][0]["delta"].get("content", "")
50
+ data = {
51
+ "text": text,
52
+ "error_code": 0,
53
+ }
54
+ yield data
55
+
56
+
57
+ def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
58
+ import anthropic
59
+
60
+ c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
61
+
62
+ # Make requests
63
+ gen_params = {
64
+ "model": model_name,
65
+ "prompt": prompt,
66
+ "temperature": temperature,
67
+ "top_p": top_p,
68
+ "max_new_tokens": max_new_tokens,
69
+ }
70
+ logger.info(f"==== request ====\n{gen_params}")
71
+
72
+ res = c.completions.create(
73
+ prompt=prompt,
74
+ stop_sequences=[anthropic.HUMAN_PROMPT],
75
+ max_tokens_to_sample=max_new_tokens,
76
+ temperature=temperature,
77
+ top_p=top_p,
78
+ model=model_name,
79
+ stream=True,
80
+ )
81
+ text = ""
82
+ for chunk in res:
83
+ text += chunk.completion
84
+ data = {
85
+ "text": text,
86
+ "error_code": 0,
87
+ }
88
+ yield data
89
+
90
+
91
+ def init_palm_chat(model_name):
92
+ import vertexai # pip3 install google-cloud-aiplatform
93
+ from vertexai.preview.language_models import ChatModel
94
+
95
+ project_id = os.environ["GCP_PROJECT_ID"]
96
+ location = "us-central1"
97
+ vertexai.init(project=project_id, location=location)
98
+
99
+ chat_model = ChatModel.from_pretrained(model_name)
100
+ chat = chat_model.start_chat(examples=[])
101
+ return chat
102
+
103
+
104
+ def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens):
105
+ parameters = {
106
+ "temperature": temperature,
107
+ "top_p": top_p,
108
+ "max_output_tokens": max_new_tokens,
109
+ }
110
+ gen_params = {
111
+ "model": "palm-2",
112
+ "prompt": message,
113
+ }
114
+ gen_params.update(parameters)
115
+ logger.info(f"==== request ====\n{gen_params}")
116
+
117
+ response = chat.send_message(message, **parameters)
118
+ content = response.text
119
+
120
+ pos = 0
121
+ while pos < len(content):
122
+ # This is a fancy way to simulate token generation latency combined
123
+ # with a Poisson process.
124
+ pos += random.randint(10, 20)
125
+ time.sleep(random.expovariate(50))
126
+ data = {
127
+ "text": content[:pos],
128
+ "error_code": 0,
129
+ }
130
+ yield data
base_model_worker.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import threading
3
+ import time
4
+ from typing import List
5
+
6
+ from fastapi import FastAPI, Request, BackgroundTasks
7
+ from fastapi.responses import StreamingResponse, JSONResponse
8
+ import requests
9
+
10
+ from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
11
+ from fastchat.conversation import Conversation
12
+ from fastchat.utils import pretty_print_semaphore, build_logger
13
+
14
+
15
+ worker = None
16
+ logger = None
17
+
18
+ app = FastAPI()
19
+
20
+
21
+ def heart_beat_worker(obj):
22
+ while True:
23
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
24
+ obj.send_heart_beat()
25
+
26
+
27
+ class BaseModelWorker:
28
+ def __init__(
29
+ self,
30
+ controller_addr: str,
31
+ worker_addr: str,
32
+ worker_id: str,
33
+ model_path: str,
34
+ model_names: List[str],
35
+ limit_worker_concurrency: int,
36
+ conv_template: str = None,
37
+ ):
38
+ global logger, worker
39
+
40
+ self.controller_addr = controller_addr
41
+ self.worker_addr = worker_addr
42
+ self.worker_id = worker_id
43
+ if model_path.endswith("/"):
44
+ model_path = model_path[:-1]
45
+ self.model_names = model_names or [model_path.split("/")[-1]]
46
+ self.limit_worker_concurrency = limit_worker_concurrency
47
+ self.conv = self.make_conv_template(conv_template, model_path)
48
+ self.conv.sep_style = int(self.conv.sep_style)
49
+ self.tokenizer = None
50
+ self.context_len = None
51
+ self.call_ct = 0
52
+ self.semaphore = None
53
+
54
+ self.heart_beat_thread = None
55
+
56
+ if logger is None:
57
+ logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
58
+ if worker is None:
59
+ worker = self
60
+
61
+ def make_conv_template(
62
+ self,
63
+ conv_template: str = None,
64
+ model_path: str = None,
65
+ ) -> Conversation:
66
+ """
67
+ can be overrided to costomize the conversation template for different model workers.
68
+ """
69
+ from fastchat.conversation import get_conv_template
70
+ from fastchat.model.model_adapter import get_conversation_template
71
+
72
+ if conv_template:
73
+ conv = get_conv_template(conv_template)
74
+ else:
75
+ conv = get_conversation_template(model_path)
76
+ print(conv)
77
+ return conv
78
+
79
+ def init_heart_beat(self):
80
+ self.register_to_controller()
81
+ self.heart_beat_thread = threading.Thread(
82
+ target=heart_beat_worker,
83
+ args=(self,),
84
+ daemon=True,
85
+ )
86
+ self.heart_beat_thread.start()
87
+
88
+ def register_to_controller(self):
89
+ logger.info("Register to controller")
90
+
91
+ url = self.controller_addr + "/register_worker"
92
+ data = {
93
+ "worker_name": self.worker_addr,
94
+ "check_heart_beat": True,
95
+ "worker_status": self.get_status(),
96
+ }
97
+ r = requests.post(url, json=data)
98
+ assert r.status_code == 200
99
+
100
+ def send_heart_beat(self):
101
+ logger.info(
102
+ f"Send heart beat. Models: {self.model_names}. "
103
+ f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
104
+ f"call_ct: {self.call_ct}. "
105
+ f"worker_id: {self.worker_id}. "
106
+ )
107
+
108
+ url = self.controller_addr + "/receive_heart_beat"
109
+
110
+ while True:
111
+ try:
112
+ ret = requests.post(
113
+ url,
114
+ json={
115
+ "worker_name": self.worker_addr,
116
+ "queue_length": self.get_queue_length(),
117
+ },
118
+ timeout=5,
119
+ )
120
+ exist = ret.json()["exist"]
121
+ break
122
+ except (requests.exceptions.RequestException, KeyError) as e:
123
+ logger.error(f"heart beat error: {e}")
124
+ time.sleep(5)
125
+
126
+ if not exist:
127
+ self.register_to_controller()
128
+
129
+ def get_queue_length(self):
130
+ if (
131
+ self.semaphore is None
132
+ or self.semaphore._value is None
133
+ or self.semaphore._waiters is None
134
+ ):
135
+ return 0
136
+ else:
137
+ return (
138
+ self.limit_worker_concurrency
139
+ - self.semaphore._value
140
+ + len(self.semaphore._waiters)
141
+ )
142
+
143
+ def get_status(self):
144
+ return {
145
+ "model_names": self.model_names,
146
+ "speed": 1,
147
+ "queue_length": self.get_queue_length(),
148
+ }
149
+
150
+ def count_token(self, params):
151
+ prompt = params["prompt"]
152
+
153
+ try:
154
+ input_ids = self.tokenizer(prompt).input_ids
155
+ input_echo_len = len(input_ids)
156
+ except TypeError:
157
+ input_echo_len = self.tokenizer.num_tokens(prompt)
158
+
159
+ ret = {
160
+ "count": input_echo_len,
161
+ "error_code": 0,
162
+ }
163
+ return ret
164
+
165
+ def get_conv_template(self):
166
+ return {"conv": self.conv}
167
+
168
+ def generate_stream_gate(self, params):
169
+ raise NotImplementedError
170
+
171
+ def generate_gate(self, params):
172
+ raise NotImplementedError
173
+
174
+ def get_embeddings(self, params):
175
+ raise NotImplementedError
176
+
177
+
178
+ def release_worker_semaphore():
179
+ worker.semaphore.release()
180
+
181
+
182
+ def acquire_worker_semaphore():
183
+ if worker.semaphore is None:
184
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
185
+ return worker.semaphore.acquire()
186
+
187
+
188
+ def create_background_tasks():
189
+ background_tasks = BackgroundTasks()
190
+ background_tasks.add_task(release_worker_semaphore)
191
+ return background_tasks
192
+
193
+
194
+ @app.post("/worker_generate_stream")
195
+ async def api_generate_stream(request: Request):
196
+ params = await request.json()
197
+ await acquire_worker_semaphore()
198
+ generator = worker.generate_stream_gate(params)
199
+ background_tasks = create_background_tasks()
200
+ return StreamingResponse(generator, background=background_tasks)
201
+
202
+
203
+ @app.post("/worker_generate")
204
+ async def api_generate(request: Request):
205
+ params = await request.json()
206
+ await acquire_worker_semaphore()
207
+ output = await asyncio.to_thread(worker.generate_gate, params)
208
+ release_worker_semaphore()
209
+ return JSONResponse(output)
210
+
211
+
212
+ @app.post("/worker_get_embeddings")
213
+ async def api_get_embeddings(request: Request):
214
+ params = await request.json()
215
+ await acquire_worker_semaphore()
216
+ embedding = worker.get_embeddings(params)
217
+ release_worker_semaphore()
218
+ return JSONResponse(content=embedding)
219
+
220
+
221
+ @app.post("/worker_get_status")
222
+ async def api_get_status(request: Request):
223
+ return worker.get_status()
224
+
225
+
226
+ @app.post("/count_token")
227
+ async def api_count_token(request: Request):
228
+ params = await request.json()
229
+ return worker.count_token(params)
230
+
231
+
232
+ @app.post("/worker_get_conv_template")
233
+ async def api_get_conv(request: Request):
234
+ return worker.get_conv_template()
235
+
236
+
237
+ @app.post("/model_details")
238
+ async def api_model_details(request: Request):
239
+ return {"context_length": worker.context_len}
cli.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat with a model with command line interface.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
6
+ python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0
7
+
8
+ Other commands:
9
+ - Type "!!exit" or an empty line to exit.
10
+ - Type "!!reset" to start a new conversation.
11
+ - Type "!!remove" to remove the last prompt.
12
+ - Type "!!regen" to regenerate the last message.
13
+ - Type "!!save <filename>" to save the conversation history to a json file.
14
+ - Type "!!load <filename>" to load a conversation history from a json file.
15
+ """
16
+ import argparse
17
+ import os
18
+ import re
19
+ import sys
20
+
21
+ from prompt_toolkit import PromptSession
22
+ from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
23
+ from prompt_toolkit.completion import WordCompleter
24
+ from prompt_toolkit.history import InMemoryHistory
25
+ from prompt_toolkit.key_binding import KeyBindings
26
+ from rich.console import Console
27
+ from rich.live import Live
28
+ from rich.markdown import Markdown
29
+ import torch
30
+
31
+ from fastchat.model.model_adapter import add_model_args
32
+ from fastchat.modules.awq import AWQConfig
33
+ from fastchat.modules.exllama import ExllamaConfig
34
+ from fastchat.modules.xfastertransformer import XftConfig
35
+ from fastchat.modules.gptq import GptqConfig
36
+ from fastchat.serve.inference import ChatIO, chat_loop
37
+ from fastchat.utils import str_to_torch_dtype
38
+
39
+
40
+ class SimpleChatIO(ChatIO):
41
+ def __init__(self, multiline: bool = False, prefix: str = ''):
42
+ self._multiline = multiline
43
+ self.prefix = prefix
44
+
45
+ def prompt_for_input(self, role) -> str:
46
+ if not self._multiline:
47
+ return input(f"{role}: {self.prefix}")
48
+
49
+ prompt_data = []
50
+ line = input(f"{role} [ctrl-d/z on empty line to end]: ")
51
+ while True:
52
+ prompt_data.append(line.strip())
53
+ try:
54
+ line = input()
55
+ except EOFError as e:
56
+ break
57
+ return f"\n{self.prefix}".join(prompt_data)
58
+
59
+ def prompt_for_output(self, role: str):
60
+ print(f"{role}: ", end="", flush=True)
61
+
62
+ def stream_output(self, output_stream):
63
+ pre = 0
64
+ for outputs in output_stream:
65
+ output_text = outputs["text"]
66
+ output_text = output_text.strip().split(" ")
67
+ now = len(output_text) - 1
68
+ if now > pre:
69
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
70
+ pre = now
71
+ print(" ".join(output_text[pre:]), flush=True)
72
+ return " ".join(output_text)
73
+
74
+ def print_output(self, text: str):
75
+ print(text)
76
+
77
+
78
+ class RichChatIO(ChatIO):
79
+ bindings = KeyBindings()
80
+
81
+ @bindings.add("escape", "enter")
82
+ def _(event):
83
+ event.app.current_buffer.newline()
84
+
85
+ def __init__(self, multiline: bool = False, mouse: bool = False):
86
+ self._prompt_session = PromptSession(history=InMemoryHistory())
87
+ self._completer = WordCompleter(
88
+ words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
89
+ pattern=re.compile("$"),
90
+ )
91
+ self._console = Console()
92
+ self._multiline = multiline
93
+ self._mouse = mouse
94
+
95
+ def prompt_for_input(self, role) -> str:
96
+ self._console.print(f"[bold]{role}:")
97
+ # TODO(suquark): multiline input has some issues. fix it later.
98
+ prompt_input = self._prompt_session.prompt(
99
+ completer=self._completer,
100
+ multiline=False,
101
+ mouse_support=self._mouse,
102
+ auto_suggest=AutoSuggestFromHistory(),
103
+ key_bindings=self.bindings if self._multiline else None,
104
+ )
105
+ self._console.print()
106
+ return prompt_input
107
+
108
+ def prompt_for_output(self, role: str):
109
+ self._console.print(f"[bold]{role.replace('/', '|')}:")
110
+
111
+ def stream_output(self, output_stream):
112
+ """Stream output from a role."""
113
+ # TODO(suquark): the console flickers when there is a code block
114
+ # above it. We need to cut off "live" when a code block is done.
115
+
116
+ # Create a Live context for updating the console output
117
+ with Live(console=self._console, refresh_per_second=4) as live:
118
+ # Read lines from the stream
119
+ for outputs in output_stream:
120
+ if not outputs:
121
+ continue
122
+ text = outputs["text"]
123
+ # Render the accumulated text as Markdown
124
+ # NOTE: this is a workaround for the rendering "unstandard markdown"
125
+ # in rich. The chatbots output treat "\n" as a new line for
126
+ # better compatibility with real-world text. However, rendering
127
+ # in markdown would break the format. It is because standard markdown
128
+ # treat a single "\n" in normal text as a space.
129
+ # Our workaround is adding two spaces at the end of each line.
130
+ # This is not a perfect solution, as it would
131
+ # introduce trailing spaces (only) in code block, but it works well
132
+ # especially for console output, because in general the console does not
133
+ # care about trailing spaces.
134
+ lines = []
135
+ for line in text.splitlines():
136
+ lines.append(line)
137
+ if line.startswith("```"):
138
+ # Code block marker - do not add trailing spaces, as it would
139
+ # break the syntax highlighting
140
+ lines.append("\n")
141
+ else:
142
+ lines.append(" \n")
143
+ markdown = Markdown("".join(lines))
144
+ # Update the Live console output
145
+ live.update(markdown)
146
+ self._console.print()
147
+ return text
148
+
149
+ def print_output(self, text: str):
150
+ self.stream_output([{"text": text}])
151
+
152
+
153
+ class ProgrammaticChatIO(ChatIO):
154
+ def prompt_for_input(self, role) -> str:
155
+ contents = ""
156
+ # `end_sequence` signals the end of a message. It is unlikely to occur in
157
+ # message content.
158
+ end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
159
+ len_end = len(end_sequence)
160
+ while True:
161
+ if len(contents) >= len_end:
162
+ last_chars = contents[-len_end:]
163
+ if last_chars == end_sequence:
164
+ break
165
+ try:
166
+ char = sys.stdin.read(1)
167
+ contents = contents + char
168
+ except EOFError:
169
+ continue
170
+ contents = contents[:-len_end]
171
+ print(f"[!OP:{role}]: {contents}", flush=True)
172
+ return contents
173
+
174
+ def prompt_for_output(self, role: str):
175
+ print(f"[!OP:{role}]: ", end="", flush=True)
176
+
177
+ def stream_output(self, output_stream):
178
+ pre = 0
179
+ for outputs in output_stream:
180
+ output_text = outputs["text"]
181
+ output_text = output_text.strip().split(" ")
182
+ now = len(output_text) - 1
183
+ if now > pre:
184
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
185
+ pre = now
186
+ print(" ".join(output_text[pre:]), flush=True)
187
+ return " ".join(output_text)
188
+
189
+ def print_output(self, text: str):
190
+ print(text)
191
+
192
+
193
+ def main(args):
194
+ if args.gpus:
195
+ if len(args.gpus.split(",")) < args.num_gpus:
196
+ raise ValueError(
197
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
198
+ )
199
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
200
+ os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
201
+ if args.enable_exllama:
202
+ exllama_config = ExllamaConfig(
203
+ max_seq_len=args.exllama_max_seq_len,
204
+ gpu_split=args.exllama_gpu_split,
205
+ )
206
+ else:
207
+ exllama_config = None
208
+ if args.enable_xft:
209
+ xft_config = XftConfig(
210
+ max_seq_len=args.xft_max_seq_len,
211
+ data_type=args.xft_dtype,
212
+ )
213
+ if args.device != "cpu":
214
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
215
+ args.device = "cpu"
216
+ else:
217
+ xft_config = None
218
+ if args.style == "simple":
219
+ chatio = SimpleChatIO(args.multiline)
220
+ elif args.style == "rich":
221
+ chatio = RichChatIO(args.multiline, args.mouse)
222
+ elif args.style == "programmatic":
223
+ chatio = ProgrammaticChatIO()
224
+ else:
225
+ raise ValueError(f"Invalid style for console: {args.style}")
226
+ try:
227
+ if args.upload_file_path:
228
+ prefix = open(args.upload_file_path, 'r').read()
229
+ args.conv_system_msg = prefix[:20000]
230
+ chat_loop(
231
+ args.model_path,
232
+ args.device,
233
+ args.num_gpus,
234
+ args.max_gpu_memory,
235
+ str_to_torch_dtype(args.dtype),
236
+ args.load_8bit,
237
+ args.cpu_offloading,
238
+ args.conv_template,
239
+ args.conv_system_msg,
240
+ args.temperature,
241
+ args.repetition_penalty,
242
+ args.max_new_tokens,
243
+ chatio,
244
+ gptq_config=GptqConfig(
245
+ ckpt=args.gptq_ckpt or args.model_path,
246
+ wbits=args.gptq_wbits,
247
+ groupsize=args.gptq_groupsize,
248
+ act_order=args.gptq_act_order,
249
+ ),
250
+ awq_config=AWQConfig(
251
+ ckpt=args.awq_ckpt or args.model_path,
252
+ wbits=args.awq_wbits,
253
+ groupsize=args.awq_groupsize,
254
+ ),
255
+ exllama_config=exllama_config,
256
+ xft_config=xft_config,
257
+ revision=args.revision,
258
+ judge_sent_end=args.judge_sent_end,
259
+ debug=args.debug,
260
+ history=not args.no_history,
261
+ )
262
+ except KeyboardInterrupt:
263
+ print("exit...")
264
+
265
+
266
+ if __name__ == "__main__":
267
+ parser = argparse.ArgumentParser()
268
+ add_model_args(parser)
269
+ parser.add_argument(
270
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
271
+ )
272
+ parser.add_argument(
273
+ "--conv-system-msg", type=str, default=None, help="Conversation system message."
274
+ )
275
+ parser.add_argument("--temperature", type=float, default=0.7)
276
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
277
+ parser.add_argument("--max-new-tokens", type=int, default=512)
278
+ parser.add_argument("--no-history", action="store_true")
279
+ parser.add_argument(
280
+ "--style",
281
+ type=str,
282
+ default="simple",
283
+ choices=["simple", "rich", "programmatic"],
284
+ help="Display style.",
285
+ )
286
+ parser.add_argument(
287
+ "--multiline",
288
+ action="store_true",
289
+ help="Enable multiline input. Use ESC+Enter for newline.",
290
+ )
291
+ parser.add_argument(
292
+ "--mouse",
293
+ action="store_true",
294
+ help="[Rich Style]: Enable mouse support for cursor positioning.",
295
+ )
296
+ parser.add_argument(
297
+ "--judge-sent-end",
298
+ action="store_true",
299
+ help="Whether enable the correction logic that interrupts the output of sentences due to EOS.",
300
+ )
301
+ parser.add_argument(
302
+ "--debug",
303
+ action="store_true",
304
+ help="Print useful debug information (e.g., prompts)",
305
+ )
306
+ parser.add_argument(
307
+ "--upload-file-path",
308
+ type=str,
309
+ default="",
310
+ help="upload long txt for summary.",
311
+ )
312
+ args = parser.parse_args()
313
+ main(args)
controller.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import os
12
+ import time
13
+ from typing import List, Union
14
+ import threading
15
+
16
+ from fastapi import FastAPI, Request
17
+ from fastapi.responses import StreamingResponse
18
+ import numpy as np
19
+ import requests
20
+ import uvicorn
21
+
22
+ from fastchat.constants import (
23
+ CONTROLLER_HEART_BEAT_EXPIRATION,
24
+ WORKER_API_TIMEOUT,
25
+ ErrorCode,
26
+ SERVER_ERROR_MSG,
27
+ )
28
+ from fastchat.utils import build_logger
29
+
30
+
31
+ logger = build_logger("controller", "controller.log")
32
+
33
+
34
+ class DispatchMethod(Enum):
35
+ LOTTERY = auto()
36
+ SHORTEST_QUEUE = auto()
37
+
38
+ @classmethod
39
+ def from_str(cls, name):
40
+ if name == "lottery":
41
+ return cls.LOTTERY
42
+ elif name == "shortest_queue":
43
+ return cls.SHORTEST_QUEUE
44
+ else:
45
+ raise ValueError(f"Invalid dispatch method")
46
+
47
+
48
+ @dataclasses.dataclass
49
+ class WorkerInfo:
50
+ model_names: List[str]
51
+ speed: int
52
+ queue_length: int
53
+ check_heart_beat: bool
54
+ last_heart_beat: str
55
+
56
+
57
+ def heart_beat_controller(controller):
58
+ while True:
59
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
60
+ controller.remove_stale_workers_by_expiration()
61
+
62
+
63
+ class Controller:
64
+ def __init__(self, dispatch_method: str):
65
+ # Dict[str -> WorkerInfo]
66
+ self.worker_info = {}
67
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
68
+
69
+ self.heart_beat_thread = threading.Thread(
70
+ target=heart_beat_controller, args=(self,)
71
+ )
72
+ self.heart_beat_thread.start()
73
+
74
+ def register_worker(
75
+ self, worker_name: str, check_heart_beat: bool, worker_status: dict
76
+ ):
77
+ if worker_name not in self.worker_info:
78
+ logger.info(f"Register a new worker: {worker_name}")
79
+ else:
80
+ logger.info(f"Register an existing worker: {worker_name}")
81
+
82
+ if not worker_status:
83
+ worker_status = self.get_worker_status(worker_name)
84
+ if not worker_status:
85
+ return False
86
+
87
+ self.worker_info[worker_name] = WorkerInfo(
88
+ worker_status["model_names"],
89
+ worker_status["speed"],
90
+ worker_status["queue_length"],
91
+ check_heart_beat,
92
+ time.time(),
93
+ )
94
+
95
+ logger.info(f"Register done: {worker_name}, {worker_status}")
96
+ return True
97
+
98
+ def get_worker_status(self, worker_name: str):
99
+ try:
100
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
101
+ except requests.exceptions.RequestException as e:
102
+ logger.error(f"Get status fails: {worker_name}, {e}")
103
+ return None
104
+
105
+ if r.status_code != 200:
106
+ logger.error(f"Get status fails: {worker_name}, {r}")
107
+ return None
108
+
109
+ return r.json()
110
+
111
+ def remove_worker(self, worker_name: str):
112
+ del self.worker_info[worker_name]
113
+
114
+ def refresh_all_workers(self):
115
+ old_info = dict(self.worker_info)
116
+ self.worker_info = {}
117
+
118
+ for w_name, w_info in old_info.items():
119
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
120
+ logger.info(f"Remove stale worker: {w_name}")
121
+
122
+ def list_models(self):
123
+ model_names = set()
124
+
125
+ for w_name, w_info in self.worker_info.items():
126
+ model_names.update(w_info.model_names)
127
+
128
+ return list(model_names)
129
+
130
+ def get_worker_address(self, model_name: str):
131
+ if self.dispatch_method == DispatchMethod.LOTTERY:
132
+ worker_names = []
133
+ worker_speeds = []
134
+ for w_name, w_info in self.worker_info.items():
135
+ if model_name in w_info.model_names:
136
+ worker_names.append(w_name)
137
+ worker_speeds.append(w_info.speed)
138
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
139
+ norm = np.sum(worker_speeds)
140
+ if norm < 1e-4:
141
+ return ""
142
+ worker_speeds = worker_speeds / norm
143
+ if True: # Directly return address
144
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
145
+ worker_name = worker_names[pt]
146
+ return worker_name
147
+
148
+ # Check status before returning
149
+ while True:
150
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
151
+ worker_name = worker_names[pt]
152
+
153
+ if self.get_worker_status(worker_name):
154
+ break
155
+ else:
156
+ self.remove_worker(worker_name)
157
+ worker_speeds[pt] = 0
158
+ norm = np.sum(worker_speeds)
159
+ if norm < 1e-4:
160
+ return ""
161
+ worker_speeds = worker_speeds / norm
162
+ continue
163
+ return worker_name
164
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
165
+ worker_names = []
166
+ worker_qlen = []
167
+ for w_name, w_info in self.worker_info.items():
168
+ if model_name in w_info.model_names:
169
+ worker_names.append(w_name)
170
+ worker_qlen.append(w_info.queue_length / w_info.speed)
171
+ if len(worker_names) == 0:
172
+ return ""
173
+ min_index = np.argmin(worker_qlen)
174
+ w_name = worker_names[min_index]
175
+ self.worker_info[w_name].queue_length += 1
176
+ logger.info(
177
+ f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
178
+ )
179
+ return w_name
180
+ else:
181
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
182
+
183
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
184
+ if worker_name not in self.worker_info:
185
+ logger.info(f"Receive unknown heart beat. {worker_name}")
186
+ return False
187
+
188
+ self.worker_info[worker_name].queue_length = queue_length
189
+ self.worker_info[worker_name].last_heart_beat = time.time()
190
+ logger.info(f"Receive heart beat. {worker_name}")
191
+ return True
192
+
193
+ def remove_stale_workers_by_expiration(self):
194
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
195
+ to_delete = []
196
+ for worker_name, w_info in self.worker_info.items():
197
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
198
+ to_delete.append(worker_name)
199
+
200
+ for worker_name in to_delete:
201
+ self.remove_worker(worker_name)
202
+
203
+ def handle_no_worker(self, params):
204
+ logger.info(f"no worker: {params['model']}")
205
+ ret = {
206
+ "text": SERVER_ERROR_MSG,
207
+ "error_code": ErrorCode.CONTROLLER_NO_WORKER,
208
+ }
209
+ return json.dumps(ret).encode() + b"\0"
210
+
211
+ def handle_worker_timeout(self, worker_address):
212
+ logger.info(f"worker timeout: {worker_address}")
213
+ ret = {
214
+ "text": SERVER_ERROR_MSG,
215
+ "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
216
+ }
217
+ return json.dumps(ret).encode() + b"\0"
218
+
219
+ # Let the controller act as a worker to achieve hierarchical
220
+ # management. This can be used to connect isolated sub networks.
221
+ def worker_api_get_status(self):
222
+ model_names = set()
223
+ speed = 0
224
+ queue_length = 0
225
+
226
+ for w_name in self.worker_info:
227
+ worker_status = self.get_worker_status(w_name)
228
+ if worker_status is not None:
229
+ model_names.update(worker_status["model_names"])
230
+ speed += worker_status["speed"]
231
+ queue_length += worker_status["queue_length"]
232
+
233
+ model_names = sorted(list(model_names))
234
+ return {
235
+ "model_names": model_names,
236
+ "speed": speed,
237
+ "queue_length": queue_length,
238
+ }
239
+
240
+ def worker_api_generate_stream(self, params):
241
+ worker_addr = self.get_worker_address(params["model"])
242
+ if not worker_addr:
243
+ yield self.handle_no_worker(params)
244
+
245
+ try:
246
+ response = requests.post(
247
+ worker_addr + "/worker_generate_stream",
248
+ json=params,
249
+ stream=True,
250
+ timeout=WORKER_API_TIMEOUT,
251
+ )
252
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
253
+ if chunk:
254
+ yield chunk + b"\0"
255
+ except requests.exceptions.RequestException as e:
256
+ yield self.handle_worker_timeout(worker_addr)
257
+
258
+
259
+ app = FastAPI()
260
+
261
+
262
+ @app.post("/register_worker")
263
+ async def register_worker(request: Request):
264
+ data = await request.json()
265
+ controller.register_worker(
266
+ data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
267
+ )
268
+
269
+
270
+ @app.post("/refresh_all_workers")
271
+ async def refresh_all_workers():
272
+ models = controller.refresh_all_workers()
273
+
274
+
275
+ @app.post("/list_models")
276
+ async def list_models():
277
+ models = controller.list_models()
278
+ return {"models": models}
279
+
280
+
281
+ @app.post("/get_worker_address")
282
+ async def get_worker_address(request: Request):
283
+ data = await request.json()
284
+ addr = controller.get_worker_address(data["model"])
285
+ return {"address": addr}
286
+
287
+
288
+ @app.post("/receive_heart_beat")
289
+ async def receive_heart_beat(request: Request):
290
+ data = await request.json()
291
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
292
+ return {"exist": exist}
293
+
294
+
295
+ @app.post("/worker_generate_stream")
296
+ async def worker_api_generate_stream(request: Request):
297
+ params = await request.json()
298
+ generator = controller.worker_api_generate_stream(params)
299
+ return StreamingResponse(generator)
300
+
301
+
302
+ @app.post("/worker_get_status")
303
+ async def worker_api_get_status(request: Request):
304
+ return controller.worker_api_get_status()
305
+
306
+
307
+ @app.get("/test_connection")
308
+ async def worker_api_get_status(request: Request):
309
+ return "success"
310
+
311
+
312
+ def create_controller():
313
+ parser = argparse.ArgumentParser()
314
+ parser.add_argument("--host", type=str, default="localhost")
315
+ parser.add_argument("--port", type=int, default=21001)
316
+ parser.add_argument(
317
+ "--dispatch-method",
318
+ type=str,
319
+ choices=["lottery", "shortest_queue"],
320
+ default="shortest_queue",
321
+ )
322
+ parser.add_argument(
323
+ "--ssl",
324
+ action="store_true",
325
+ required=False,
326
+ default=False,
327
+ help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
328
+ )
329
+ args = parser.parse_args()
330
+ logger.info(f"args: {args}")
331
+
332
+ controller = Controller(args.dispatch_method)
333
+ return args, controller
334
+
335
+
336
+ if __name__ == "__main__":
337
+ args, controller = create_controller()
338
+ if args.ssl:
339
+ uvicorn.run(
340
+ app,
341
+ host=args.host,
342
+ port=args.port,
343
+ log_level="info",
344
+ ssl_keyfile=os.environ["SSL_KEYFILE"],
345
+ ssl_certfile=os.environ["SSL_CERTFILE"],
346
+ )
347
+ else:
348
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
gateway/README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fastchat Nginx Gateway
2
+
3
+ ## Purpose of the Gateway
4
+
5
+ The Nginx gateway serves the following purposes:
6
+
7
+ 1. Protects Gradio servers by acting as a firewall.
8
+ 2. Facilitates dynamic mounting and unmounting of Gradio servers.
9
+ 3. Provides load balancing for Gradio servers.
10
+ 4. Offers additional security features, such as total connection limit.
11
+ 5. Reduces attack surface by requiring only a single public port to be exposed for serving.
12
+
13
+ ## Deployment and Updating of the Gateway
14
+
15
+ ### Installing Nginx
16
+
17
+ On Debian-based distributions (e.g., Ubuntu):
18
+
19
+ ```bash
20
+ sudo apt update
21
+ sudo apt install nginx
22
+ ```
23
+ On Red Hat-based distributions (e.g., CentOS, Fedora):
24
+
25
+ ```bash
26
+ sudo yum install epel-release
27
+ sudo yum install nginx
28
+ ```
29
+
30
+ ### Deployment
31
+
32
+ Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission).
33
+
34
+ Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server.
35
+
36
+ Modify `upstream websocket` to configure Gradio servers behind the gateway.
37
+
38
+ Lastly, update Nginx.
39
+
40
+
41
+ ### HTTPS Deployment with a Public Domain URL
42
+
43
+ Make sure you obtain the HTTPS certificate and the private key used to generate the certificate.
44
+
45
+ Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
46
+
47
+ If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url.
48
+
49
+ ### Updating
50
+
51
+ Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service:
52
+
53
+ ```bash
54
+ sudo nginx -t # check `/etc/nginx/nginx.conf`
55
+ sudo systemctl reload nginx # restart Nginx service to load the new config
56
+ sudo systemctl status nginx # check the status of the Nginx service. It should be active (running).
57
+ ```
gateway/nginx.conf ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ user www-data;
2
+ worker_processes auto;
3
+ pid /run/nginx.pid;
4
+ include /etc/nginx/modules-enabled/*.conf;
5
+
6
+ events {
7
+ worker_connections 1024; # maximum number of connections that a worker process can handle concurrently
8
+ # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle
9
+
10
+ }
11
+
12
+ http {
13
+ ##
14
+ # Basic Settings
15
+ ##
16
+
17
+ sendfile on; # enable sendfile for performance optimization
18
+ tcp_nopush on; # enable TCP no-pushing
19
+ tcp_nodelay on; # enable TCP no-delay
20
+ keepalive_timeout 65; # sets the timeout for keep-alive connections
21
+ types_hash_max_size 2048; # maximum size of the types hash table
22
+ # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security
23
+
24
+ # server_names_hash_bucket_size 64;
25
+ # server_name_in_redirect off;
26
+
27
+ include /etc/nginx/mime.types; # include MIME types file
28
+ default_type application/octet-stream; # default MIME type for unknown file types
29
+
30
+ ##
31
+ # SSL Settings
32
+ ##
33
+
34
+ ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use
35
+ ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers
36
+
37
+ ##
38
+ # Logging Settings
39
+ ##
40
+
41
+ access_log /var/log/nginx/access.log; # path to access log file
42
+ error_log /var/log/nginx/error.log; # path to error log file
43
+
44
+ ##
45
+ # Gzip Settings
46
+ ##
47
+ gzip on; # enable Gzip compression
48
+
49
+ ##
50
+ # Virtual Host Configs
51
+ ##
52
+
53
+ include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory
54
+ include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files
55
+
56
+ # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/
57
+ map $http_upgrade $connection_upgrade {
58
+ default upgrade;
59
+ '' close;
60
+ }
61
+
62
+ upstream websocket {
63
+ ip_hash; # load balancing by IP to guarantee session persistence
64
+ server localhost:7860; # The port should be the gradio web server port
65
+ # server localhost:7861; # extra gradio server if more than one
66
+ }
67
+
68
+ limit_conn_status 429;
69
+ limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP
70
+ limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server
71
+
72
+ server {
73
+ listen 443 ssl; # the listening port of our server
74
+ ssl_certificate [PATH_TO_SSL_CERT];
75
+ ssl_certificate_key [PATH_TO_PRIVATE_KEY];
76
+ server_name chat.lmsys.org; # replace the url with your own domain url
77
+ limit_conn perserver 1024; # connections per server
78
+ location / {
79
+ proxy_pass http://websocket; # proxy all requests to the defined upstream server
80
+ limit_conn perip 5; # connections per IP
81
+ proxy_set_header Host $host; # set the Host header for the upstream server
82
+ proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server
83
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header
84
+ proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication
85
+ proxy_set_header Upgrade $http_upgrade;
86
+ proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication
87
+ }
88
+ }
89
+
90
+ # the following block routes all HTTP traffic to HTTPS via nginx
91
+ server {
92
+ listen 80;
93
+ server_name chat.lmsys.org;
94
+ return 301 https://chat.lmsys.org$request_uri;
95
+ }
96
+
97
+ }
gradio_block_arena_anony.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatbot Arena (battle) tab.
3
+ Users chat with two anonymous models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+
12
+ from fastchat.constants import (
13
+ MODERATION_MSG,
14
+ CONVERSATION_LIMIT_MSG,
15
+ SLOW_MODEL_MSG,
16
+ INPUT_CHAR_LEN_LIMIT,
17
+ CONVERSATION_TURN_LIMIT,
18
+ )
19
+ from fastchat.model.model_adapter import get_conversation_template
20
+ from fastchat.serve.gradio_block_arena_named import flash_buttons
21
+ from fastchat.serve.gradio_web_server import (
22
+ State,
23
+ bot_response,
24
+ get_conv_log_filename,
25
+ no_change_btn,
26
+ enable_btn,
27
+ disable_btn,
28
+ invisible_btn,
29
+ acknowledgment_md,
30
+ ip_expiration_dict,
31
+ get_ip,
32
+ )
33
+ from fastchat.utils import (
34
+ build_logger,
35
+ moderation_filter,
36
+ )
37
+
38
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
39
+
40
+ num_sides = 2
41
+ enable_moderation = False
42
+ anony_names = ["", ""]
43
+ models = []
44
+
45
+
46
+ def set_global_vars_anony(enable_moderation_):
47
+ global enable_moderation
48
+ enable_moderation = enable_moderation_
49
+
50
+
51
+ def load_demo_side_by_side_anony(models_, url_params):
52
+ global models
53
+ models = models_
54
+
55
+ states = (None,) * num_sides
56
+ selector_updates = (
57
+ gr.Markdown.update(visible=True),
58
+ gr.Markdown.update(visible=True),
59
+ )
60
+
61
+ return states + selector_updates
62
+
63
+
64
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
65
+ with open(get_conv_log_filename(), "a") as fout:
66
+ data = {
67
+ "tstamp": round(time.time(), 4),
68
+ "type": vote_type,
69
+ "models": [x for x in model_selectors],
70
+ "states": [x.dict() for x in states],
71
+ "ip": get_ip(request),
72
+ }
73
+ fout.write(json.dumps(data) + "\n")
74
+
75
+ if ":" not in model_selectors[0]:
76
+ for i in range(15):
77
+ names = (
78
+ "### Model A: " + states[0].model_name,
79
+ "### Model B: " + states[1].model_name,
80
+ )
81
+ yield names + ("",) + (disable_btn,) * 4
82
+ time.sleep(0.2)
83
+ else:
84
+ names = (
85
+ "### Model A: " + states[0].model_name,
86
+ "### Model B: " + states[1].model_name,
87
+ )
88
+ yield names + ("",) + (disable_btn,) * 4
89
+
90
+
91
+ def leftvote_last_response(
92
+ state0, state1, model_selector0, model_selector1, request: gr.Request
93
+ ):
94
+ logger.info(f"leftvote (anony). ip: {get_ip(request)}")
95
+ for x in vote_last_response(
96
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
97
+ ):
98
+ yield x
99
+
100
+
101
+ def rightvote_last_response(
102
+ state0, state1, model_selector0, model_selector1, request: gr.Request
103
+ ):
104
+ logger.info(f"rightvote (anony). ip: {get_ip(request)}")
105
+ for x in vote_last_response(
106
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
107
+ ):
108
+ yield x
109
+
110
+
111
+ def tievote_last_response(
112
+ state0, state1, model_selector0, model_selector1, request: gr.Request
113
+ ):
114
+ logger.info(f"tievote (anony). ip: {get_ip(request)}")
115
+ for x in vote_last_response(
116
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
117
+ ):
118
+ yield x
119
+
120
+
121
+ def bothbad_vote_last_response(
122
+ state0, state1, model_selector0, model_selector1, request: gr.Request
123
+ ):
124
+ logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
125
+ for x in vote_last_response(
126
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
127
+ ):
128
+ yield x
129
+
130
+
131
+ def regenerate(state0, state1, request: gr.Request):
132
+ logger.info(f"regenerate (anony). ip: {get_ip(request)}")
133
+ states = [state0, state1]
134
+ for i in range(num_sides):
135
+ states[i].conv.update_last_message(None)
136
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
137
+
138
+
139
+ def clear_history(request: gr.Request):
140
+ logger.info(f"clear_history (anony). ip: {get_ip(request)}")
141
+ return (
142
+ [None] * num_sides
143
+ + [None] * num_sides
144
+ + anony_names
145
+ + [""]
146
+ + [invisible_btn] * 4
147
+ + [disable_btn] * 2
148
+ + [""]
149
+ )
150
+
151
+
152
+ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
153
+ logger.info(f"share (anony). ip: {get_ip(request)}")
154
+ if state0 is not None and state1 is not None:
155
+ vote_last_response(
156
+ [state0, state1], "share", [model_selector0, model_selector1], request
157
+ )
158
+
159
+
160
+ SAMPLING_WEIGHTS = {
161
+ # tier 0
162
+ "gpt-4": 4,
163
+ "gpt-4-turbo": 4,
164
+ "gpt-3.5-turbo": 2,
165
+ "gpt-3.5-turbo-1106": 2,
166
+ "claude-2": 8,
167
+ "claude-1": 2,
168
+ "claude-instant-1": 8,
169
+ "zephyr-7b-beta": 2,
170
+ "openchat-3.5": 2,
171
+ # tier 1
172
+ "deluxe-chat-v1.1": 2,
173
+ "palm-2": 1.5,
174
+ "llama-2-70b-chat": 1.5,
175
+ "llama-2-13b-chat": 1.5,
176
+ "codellama-34b-instruct": 1.5,
177
+ "vicuna-33b": 8,
178
+ "vicuna-13b": 1.5,
179
+ "wizardlm-70b": 1.5,
180
+ "wizardlm-13b": 1.5,
181
+ "qwen-14b-chat": 1.5,
182
+ "mistral-7b-instruct": 1.5,
183
+ # tier 2
184
+ "vicuna-7b": 1.0,
185
+ "llama-2-7b-chat": 1.0,
186
+ "chatglm2-6b": 1.0,
187
+ # deprecated
188
+ "zephyr-7b-alpha": 1.5,
189
+ "codellama-13b-instruct": 1.0,
190
+ "mpt-30b-chat": 1.5,
191
+ "guanaco-33b": 1.0,
192
+ "fastchat-t5-3b": 0.5,
193
+ "alpaca-13b": 0.5,
194
+ "mpt-7b-chat": 0.1,
195
+ "oasst-pythia-12b": 0.1,
196
+ "RWKV-4-Raven-14B": 0.1,
197
+ "gpt4all-13b-snoozy": 0.1,
198
+ "koala-13b": 0.1,
199
+ "stablelm-tuned-alpha-7b": 0.1,
200
+ "dolly-v2-12b": 0.1,
201
+ "llama-13b": 0.1,
202
+ "chatglm-6b": 0.5,
203
+ "deluxe-chat-v1": 4,
204
+ }
205
+
206
+ # target model sampling weights will be boosted.
207
+ BATTLE_TARGETS = {
208
+ "gpt-4": {"claude-2"},
209
+ "gpt-4-turbo": {"gpt-4", "gpt-3.5-turbo"},
210
+ "gpt-3.5-turbo": {"claude-instant-1", "gpt-4", "claude-2"},
211
+ "claude-2": {"gpt-4", "gpt-3.5-turbo", "claude-1"},
212
+ "claude-1": {"claude-2", "gpt-4", "gpt-3.5-turbo"},
213
+ "claude-instant-1": {"gpt-3.5-turbo", "claude-2"},
214
+ "deluxe-chat-v1.1": {"gpt-4"},
215
+ "openchat-3.5": {"gpt-3.5-turbo", "llama-2-70b-chat", "zephyr-7b-beta"},
216
+ "qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"},
217
+ "zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"},
218
+ "zephyr-7b-beta": {
219
+ "mistral-7b-instruct",
220
+ "llama-2-13b-chat",
221
+ "llama-2-7b-chat",
222
+ "wizardlm-13b",
223
+ },
224
+ "llama-2-70b-chat": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"},
225
+ "llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"},
226
+ "llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"},
227
+ "mistral-7b-instruct": {
228
+ "llama-2-7b-chat",
229
+ "llama-2-13b-chat",
230
+ "llama-2-70b-chat",
231
+ },
232
+ "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo", "claude-instant-1"},
233
+ "vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"},
234
+ "vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"},
235
+ "wizardlm-70b": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"},
236
+ "palm-2": {"llama-2-13b-chat", "gpt-3.5-turbo"},
237
+ }
238
+
239
+ SAMPLING_BOOST_MODELS = ["openchat-3.5", "gpt-4-turbo", "gpt-3.5-turbo-1106"]
240
+
241
+ # outage models won't be sampled.
242
+ OUTAGE_MODELS = []
243
+
244
+
245
+ def get_sample_weight(model):
246
+ if model in OUTAGE_MODELS:
247
+ return 0
248
+ weight = SAMPLING_WEIGHTS.get(model, 1.0)
249
+ if model in SAMPLING_BOOST_MODELS:
250
+ weight *= 5
251
+ return weight
252
+
253
+
254
+ def get_battle_pair():
255
+ if len(models) == 1:
256
+ return models[0], models[0]
257
+
258
+ model_weights = []
259
+ for model in models:
260
+ weight = get_sample_weight(model)
261
+ model_weights.append(weight)
262
+ total_weight = np.sum(model_weights)
263
+ model_weights = model_weights / total_weight
264
+ chosen_idx = np.random.choice(len(models), p=model_weights)
265
+ chosen_model = models[chosen_idx]
266
+
267
+ rival_models = []
268
+ rival_weights = []
269
+ for model in models:
270
+ if model == chosen_model:
271
+ continue
272
+ weight = get_sample_weight(model)
273
+ if (
274
+ weight != 0
275
+ and chosen_model in BATTLE_TARGETS
276
+ and model in BATTLE_TARGETS[chosen_model]
277
+ ):
278
+ # boost to 50% chance
279
+ weight = total_weight / len(BATTLE_TARGETS[chosen_model])
280
+ rival_models.append(model)
281
+ rival_weights.append(weight)
282
+ # for p, w in zip(rival_models, rival_weights):
283
+ # print(p, w)
284
+ rival_weights = rival_weights / np.sum(rival_weights)
285
+ rival_idx = np.random.choice(len(rival_models), p=rival_weights)
286
+ rival_model = rival_models[rival_idx]
287
+
288
+ swap = np.random.randint(2)
289
+ if swap == 0:
290
+ return chosen_model, rival_model
291
+ else:
292
+ return rival_model, chosen_model
293
+
294
+
295
+ def add_text(
296
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
297
+ ):
298
+ ip = get_ip(request)
299
+ logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
300
+ states = [state0, state1]
301
+ model_selectors = [model_selector0, model_selector1]
302
+
303
+ # Init states if necessary
304
+ if states[0] is None:
305
+ assert states[1] is None
306
+
307
+ model_left, model_right = get_battle_pair()
308
+ states = [
309
+ State(model_left),
310
+ State(model_right),
311
+ ]
312
+
313
+ if len(text) <= 0:
314
+ for i in range(num_sides):
315
+ states[i].skip_next = True
316
+ return (
317
+ states
318
+ + [x.to_gradio_chatbot() for x in states]
319
+ + [""]
320
+ + [
321
+ no_change_btn,
322
+ ]
323
+ * 6
324
+ + [""]
325
+ )
326
+
327
+ model_list = [states[i].model_name for i in range(num_sides)]
328
+ flagged = moderation_filter(text, model_list)
329
+ if flagged:
330
+ logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
331
+ # overwrite the original text
332
+ text = MODERATION_MSG
333
+
334
+ conv = states[0].conv
335
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
336
+ logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
337
+ for i in range(num_sides):
338
+ states[i].skip_next = True
339
+ return (
340
+ states
341
+ + [x.to_gradio_chatbot() for x in states]
342
+ + [CONVERSATION_LIMIT_MSG]
343
+ + [
344
+ no_change_btn,
345
+ ]
346
+ * 6
347
+ + [""]
348
+ )
349
+
350
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
351
+ for i in range(num_sides):
352
+ states[i].conv.append_message(states[i].conv.roles[0], text)
353
+ states[i].conv.append_message(states[i].conv.roles[1], None)
354
+ states[i].skip_next = False
355
+
356
+ slow_model_msg = ""
357
+ for i in range(num_sides):
358
+ if "deluxe" in states[i].model_name:
359
+ slow_model_msg = SLOW_MODEL_MSG
360
+ return (
361
+ states
362
+ + [x.to_gradio_chatbot() for x in states]
363
+ + [""]
364
+ + [
365
+ disable_btn,
366
+ ]
367
+ * 6
368
+ + [slow_model_msg]
369
+ )
370
+
371
+
372
+ def bot_response_multi(
373
+ state0,
374
+ state1,
375
+ temperature,
376
+ top_p,
377
+ max_new_tokens,
378
+ request: gr.Request,
379
+ ):
380
+ logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
381
+
382
+ if state0 is None or state0.skip_next:
383
+ # This generate call is skipped due to invalid inputs
384
+ yield (
385
+ state0,
386
+ state1,
387
+ state0.to_gradio_chatbot(),
388
+ state1.to_gradio_chatbot(),
389
+ ) + (no_change_btn,) * 6
390
+ return
391
+
392
+ states = [state0, state1]
393
+ gen = []
394
+ for i in range(num_sides):
395
+ gen.append(
396
+ bot_response(
397
+ states[i],
398
+ temperature,
399
+ top_p,
400
+ max_new_tokens,
401
+ request,
402
+ )
403
+ )
404
+
405
+ chatbots = [None] * num_sides
406
+ while True:
407
+ stop = True
408
+ for i in range(num_sides):
409
+ try:
410
+ ret = next(gen[i])
411
+ states[i], chatbots[i] = ret[0], ret[1]
412
+ stop = False
413
+ except StopIteration:
414
+ pass
415
+ yield states + chatbots + [disable_btn] * 6
416
+ if stop:
417
+ break
418
+
419
+
420
+ def build_side_by_side_ui_anony(models):
421
+ notice_markdown = """
422
+ # ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild
423
+ | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
424
+
425
+ ## 📜 Rules
426
+ - Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one!
427
+ - You can continue chatting until you identify a winner.
428
+ - Vote won't be counted if model identity is revealed during conversation.
429
+
430
+ ## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)
431
+ We use **100K** human votes to compile an Elo-based LLM leaderboard.
432
+ Find out who is the 🥇LLM Champion!
433
+
434
+ ## 👇 Chat now!
435
+
436
+ """
437
+
438
+ states = [gr.State() for _ in range(num_sides)]
439
+ model_selectors = [None] * num_sides
440
+ chatbots = [None] * num_sides
441
+
442
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
443
+
444
+ with gr.Box(elem_id="share-region-anony"):
445
+ with gr.Row():
446
+ for i in range(num_sides):
447
+ label = "Model A" if i == 0 else "Model B"
448
+ with gr.Column():
449
+ chatbots[i] = gr.Chatbot(
450
+ label=label, elem_id=f"chatbot", height=550
451
+ )
452
+
453
+ with gr.Row():
454
+ for i in range(num_sides):
455
+ with gr.Column():
456
+ model_selectors[i] = gr.Markdown(anony_names[i])
457
+ with gr.Row():
458
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
459
+
460
+ with gr.Row():
461
+ leftvote_btn = gr.Button(
462
+ value="👈 A is better", visible=False, interactive=False
463
+ )
464
+ rightvote_btn = gr.Button(
465
+ value="👉 B is better", visible=False, interactive=False
466
+ )
467
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
468
+ bothbad_btn = gr.Button(
469
+ value="👎 Both are bad", visible=False, interactive=False
470
+ )
471
+
472
+ with gr.Row():
473
+ with gr.Column(scale=20):
474
+ textbox = gr.Textbox(
475
+ show_label=False,
476
+ placeholder="👉 Enter your prompt and press ENTER",
477
+ container=False,
478
+ elem_id="input_box",
479
+ )
480
+ with gr.Column(scale=1, min_width=50):
481
+ send_btn = gr.Button(value="Send", variant="primary")
482
+
483
+ with gr.Row() as button_row:
484
+ clear_btn = gr.Button(value="🎲 New Round", interactive=False)
485
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
486
+ share_btn = gr.Button(value="📷 Share")
487
+
488
+ with gr.Accordion("Parameters", open=False) as parameter_row:
489
+ temperature = gr.Slider(
490
+ minimum=0.0,
491
+ maximum=1.0,
492
+ value=0.7,
493
+ step=0.1,
494
+ interactive=True,
495
+ label="Temperature",
496
+ )
497
+ top_p = gr.Slider(
498
+ minimum=0.0,
499
+ maximum=1.0,
500
+ value=1.0,
501
+ step=0.1,
502
+ interactive=True,
503
+ label="Top P",
504
+ )
505
+ max_output_tokens = gr.Slider(
506
+ minimum=16,
507
+ maximum=1024,
508
+ value=512,
509
+ step=64,
510
+ interactive=True,
511
+ label="Max output tokens",
512
+ )
513
+
514
+ gr.Markdown(acknowledgment_md)
515
+
516
+ # Register listeners
517
+ btn_list = [
518
+ leftvote_btn,
519
+ rightvote_btn,
520
+ tie_btn,
521
+ bothbad_btn,
522
+ regenerate_btn,
523
+ clear_btn,
524
+ ]
525
+ leftvote_btn.click(
526
+ leftvote_last_response,
527
+ states + model_selectors,
528
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
529
+ )
530
+ rightvote_btn.click(
531
+ rightvote_last_response,
532
+ states + model_selectors,
533
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
534
+ )
535
+ tie_btn.click(
536
+ tievote_last_response,
537
+ states + model_selectors,
538
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
539
+ )
540
+ bothbad_btn.click(
541
+ bothbad_vote_last_response,
542
+ states + model_selectors,
543
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
544
+ )
545
+ regenerate_btn.click(
546
+ regenerate, states, states + chatbots + [textbox] + btn_list
547
+ ).then(
548
+ bot_response_multi,
549
+ states + [temperature, top_p, max_output_tokens],
550
+ states + chatbots + btn_list,
551
+ ).then(
552
+ flash_buttons, [], btn_list
553
+ )
554
+ clear_btn.click(
555
+ clear_history,
556
+ None,
557
+ states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning],
558
+ )
559
+
560
+ share_js = """
561
+ function (a, b, c, d) {
562
+ const captureElement = document.querySelector('#share-region-anony');
563
+ html2canvas(captureElement)
564
+ .then(canvas => {
565
+ canvas.style.display = 'none'
566
+ document.body.appendChild(canvas)
567
+ return canvas
568
+ })
569
+ .then(canvas => {
570
+ const image = canvas.toDataURL('image/png')
571
+ const a = document.createElement('a')
572
+ a.setAttribute('download', 'chatbot-arena.png')
573
+ a.setAttribute('href', image)
574
+ a.click()
575
+ canvas.remove()
576
+ });
577
+ return [a, b, c, d];
578
+ }
579
+ """
580
+ share_btn.click(share_click, states + model_selectors, [], _js=share_js)
581
+
582
+ textbox.submit(
583
+ add_text,
584
+ states + model_selectors + [textbox],
585
+ states + chatbots + [textbox] + btn_list + [slow_warning],
586
+ ).then(
587
+ bot_response_multi,
588
+ states + [temperature, top_p, max_output_tokens],
589
+ states + chatbots + btn_list,
590
+ ).then(
591
+ flash_buttons,
592
+ [],
593
+ btn_list,
594
+ )
595
+
596
+ send_btn.click(
597
+ add_text,
598
+ states + model_selectors + [textbox],
599
+ states + chatbots + [textbox] + btn_list,
600
+ ).then(
601
+ bot_response_multi,
602
+ states + [temperature, top_p, max_output_tokens],
603
+ states + chatbots + btn_list,
604
+ ).then(
605
+ flash_buttons, [], btn_list
606
+ )
607
+
608
+ return states + model_selectors
gradio_block_arena_named.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatbot Arena (side-by-side) tab.
3
+ Users chat with two chosen models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+
12
+ from fastchat.constants import (
13
+ MODERATION_MSG,
14
+ CONVERSATION_LIMIT_MSG,
15
+ INPUT_CHAR_LEN_LIMIT,
16
+ CONVERSATION_TURN_LIMIT,
17
+ )
18
+ from fastchat.model.model_adapter import get_conversation_template
19
+ from fastchat.serve.gradio_web_server import (
20
+ State,
21
+ bot_response,
22
+ get_conv_log_filename,
23
+ no_change_btn,
24
+ enable_btn,
25
+ disable_btn,
26
+ invisible_btn,
27
+ acknowledgment_md,
28
+ get_model_description_md,
29
+ ip_expiration_dict,
30
+ get_ip,
31
+ )
32
+ from fastchat.utils import (
33
+ build_logger,
34
+ moderation_filter,
35
+ )
36
+
37
+
38
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
39
+
40
+ num_sides = 2
41
+ enable_moderation = False
42
+
43
+
44
+ def set_global_vars_named(enable_moderation_):
45
+ global enable_moderation
46
+ enable_moderation = enable_moderation_
47
+
48
+
49
+ def load_demo_side_by_side_named(models, url_params):
50
+ states = (None,) * num_sides
51
+
52
+ model_left = models[0] if len(models) > 0 else ""
53
+ if len(models) > 1:
54
+ weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1]
55
+ weights = weights / np.sum(weights)
56
+ model_right = np.random.choice(models[1:], p=weights)
57
+ else:
58
+ model_right = model_left
59
+
60
+ selector_updates = (
61
+ gr.Dropdown.update(choices=models, value=model_left, visible=True),
62
+ gr.Dropdown.update(choices=models, value=model_right, visible=True),
63
+ )
64
+
65
+ return states + selector_updates
66
+
67
+
68
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
69
+ with open(get_conv_log_filename(), "a") as fout:
70
+ data = {
71
+ "tstamp": round(time.time(), 4),
72
+ "type": vote_type,
73
+ "models": [x for x in model_selectors],
74
+ "states": [x.dict() for x in states],
75
+ "ip": get_ip(request),
76
+ }
77
+ fout.write(json.dumps(data) + "\n")
78
+
79
+
80
+ def leftvote_last_response(
81
+ state0, state1, model_selector0, model_selector1, request: gr.Request
82
+ ):
83
+ logger.info(f"leftvote (named). ip: {get_ip(request)}")
84
+ vote_last_response(
85
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
86
+ )
87
+ return ("",) + (disable_btn,) * 4
88
+
89
+
90
+ def rightvote_last_response(
91
+ state0, state1, model_selector0, model_selector1, request: gr.Request
92
+ ):
93
+ logger.info(f"rightvote (named). ip: {get_ip(request)}")
94
+ vote_last_response(
95
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
96
+ )
97
+ return ("",) + (disable_btn,) * 4
98
+
99
+
100
+ def tievote_last_response(
101
+ state0, state1, model_selector0, model_selector1, request: gr.Request
102
+ ):
103
+ logger.info(f"tievote (named). ip: {get_ip(request)}")
104
+ vote_last_response(
105
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
106
+ )
107
+ return ("",) + (disable_btn,) * 4
108
+
109
+
110
+ def bothbad_vote_last_response(
111
+ state0, state1, model_selector0, model_selector1, request: gr.Request
112
+ ):
113
+ logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
114
+ vote_last_response(
115
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
116
+ )
117
+ return ("",) + (disable_btn,) * 4
118
+
119
+
120
+ def regenerate(state0, state1, request: gr.Request):
121
+ logger.info(f"regenerate (named). ip: {get_ip(request)}")
122
+ states = [state0, state1]
123
+ for i in range(num_sides):
124
+ states[i].conv.update_last_message(None)
125
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
126
+
127
+
128
+ def clear_history(request: gr.Request):
129
+ logger.info(f"clear_history (named). ip: {get_ip(request)}")
130
+ return (
131
+ [None] * num_sides
132
+ + [None] * num_sides
133
+ + [""]
134
+ + [invisible_btn] * 4
135
+ + [disable_btn] * 2
136
+ )
137
+
138
+
139
+ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
140
+ logger.info(f"share (named). ip: {get_ip(request)}")
141
+ if state0 is not None and state1 is not None:
142
+ vote_last_response(
143
+ [state0, state1], "share", [model_selector0, model_selector1], request
144
+ )
145
+
146
+
147
+ def add_text(
148
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
149
+ ):
150
+ ip = get_ip(request)
151
+ logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
152
+ states = [state0, state1]
153
+ model_selectors = [model_selector0, model_selector1]
154
+
155
+ # Init states if necessary
156
+ for i in range(num_sides):
157
+ if states[i] is None:
158
+ states[i] = State(model_selectors[i])
159
+
160
+ if len(text) <= 0:
161
+ for i in range(num_sides):
162
+ states[i].skip_next = True
163
+ return (
164
+ states
165
+ + [x.to_gradio_chatbot() for x in states]
166
+ + [""]
167
+ + [
168
+ no_change_btn,
169
+ ]
170
+ * 6
171
+ )
172
+
173
+ model_list = [states[i].model_name for i in range(num_sides)]
174
+ flagged = moderation_filter(text, model_list)
175
+ if flagged:
176
+ logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
177
+ # overwrite the original text
178
+ text = MODERATION_MSG
179
+
180
+ conv = states[0].conv
181
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
182
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
183
+ for i in range(num_sides):
184
+ states[i].skip_next = True
185
+ return (
186
+ states
187
+ + [x.to_gradio_chatbot() for x in states]
188
+ + [CONVERSATION_LIMIT_MSG]
189
+ + [
190
+ no_change_btn,
191
+ ]
192
+ * 6
193
+ )
194
+
195
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
196
+ for i in range(num_sides):
197
+ states[i].conv.append_message(states[i].conv.roles[0], text)
198
+ states[i].conv.append_message(states[i].conv.roles[1], None)
199
+ states[i].skip_next = False
200
+
201
+ return (
202
+ states
203
+ + [x.to_gradio_chatbot() for x in states]
204
+ + [""]
205
+ + [
206
+ disable_btn,
207
+ ]
208
+ * 6
209
+ )
210
+
211
+
212
+ def bot_response_multi(
213
+ state0,
214
+ state1,
215
+ temperature,
216
+ top_p,
217
+ max_new_tokens,
218
+ request: gr.Request,
219
+ ):
220
+ logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
221
+
222
+ if state0.skip_next:
223
+ # This generate call is skipped due to invalid inputs
224
+ yield (
225
+ state0,
226
+ state1,
227
+ state0.to_gradio_chatbot(),
228
+ state1.to_gradio_chatbot(),
229
+ ) + (no_change_btn,) * 6
230
+ return
231
+
232
+ states = [state0, state1]
233
+ gen = []
234
+ for i in range(num_sides):
235
+ gen.append(
236
+ bot_response(
237
+ states[i],
238
+ temperature,
239
+ top_p,
240
+ max_new_tokens,
241
+ request,
242
+ )
243
+ )
244
+
245
+ chatbots = [None] * num_sides
246
+ while True:
247
+ stop = True
248
+ for i in range(num_sides):
249
+ try:
250
+ ret = next(gen[i])
251
+ states[i], chatbots[i] = ret[0], ret[1]
252
+ stop = False
253
+ except StopIteration:
254
+ pass
255
+ yield states + chatbots + [disable_btn] * 6
256
+ if stop:
257
+ break
258
+
259
+
260
+ def flash_buttons():
261
+ btn_updates = [
262
+ [disable_btn] * 4 + [enable_btn] * 2,
263
+ [enable_btn] * 6,
264
+ ]
265
+ for i in range(4):
266
+ yield btn_updates[i % 2]
267
+ time.sleep(0.5)
268
+
269
+
270
+ def build_side_by_side_ui_named(models):
271
+ notice_markdown = """
272
+ # ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild
273
+ | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
274
+
275
+ ## 📜 Rules
276
+ - Chat with any two models side-by-side and vote!
277
+ - You can continue chatting for multiple rounds.
278
+ - Click "Clear history" to start a new round.
279
+
280
+ ## 🤖 Choose two models to compare
281
+ """
282
+
283
+ states = [gr.State() for _ in range(num_sides)]
284
+ model_selectors = [None] * num_sides
285
+ chatbots = [None] * num_sides
286
+
287
+ model_description_md = get_model_description_md(models)
288
+ notice = gr.Markdown(
289
+ notice_markdown + model_description_md, elem_id="notice_markdown"
290
+ )
291
+
292
+ with gr.Box(elem_id="share-region-named"):
293
+ with gr.Row():
294
+ for i in range(num_sides):
295
+ with gr.Column():
296
+ model_selectors[i] = gr.Dropdown(
297
+ choices=models,
298
+ value=models[i] if len(models) > i else "",
299
+ interactive=True,
300
+ show_label=False,
301
+ container=False,
302
+ )
303
+
304
+ with gr.Row():
305
+ for i in range(num_sides):
306
+ label = "Model A" if i == 0 else "Model B"
307
+ with gr.Column():
308
+ chatbots[i] = gr.Chatbot(
309
+ label=label, elem_id=f"chatbot", height=550
310
+ )
311
+
312
+ with gr.Row():
313
+ leftvote_btn = gr.Button(
314
+ value="👈 A is better", visible=False, interactive=False
315
+ )
316
+ rightvote_btn = gr.Button(
317
+ value="👉 B is better", visible=False, interactive=False
318
+ )
319
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
320
+ bothbad_btn = gr.Button(
321
+ value="👎 Both are bad", visible=False, interactive=False
322
+ )
323
+
324
+ with gr.Row():
325
+ with gr.Column(scale=20):
326
+ textbox = gr.Textbox(
327
+ show_label=False,
328
+ placeholder="Enter your prompt here and press ENTER",
329
+ container=False,
330
+ elem_id="input_box",
331
+ )
332
+ with gr.Column(scale=1, min_width=50):
333
+ send_btn = gr.Button(value="Send", variant="primary")
334
+
335
+ with gr.Row() as button_row:
336
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
337
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
338
+ share_btn = gr.Button(value="📷 Share")
339
+
340
+ with gr.Accordion("Parameters", open=False) as parameter_row:
341
+ temperature = gr.Slider(
342
+ minimum=0.0,
343
+ maximum=1.0,
344
+ value=0.7,
345
+ step=0.1,
346
+ interactive=True,
347
+ label="Temperature",
348
+ )
349
+ top_p = gr.Slider(
350
+ minimum=0.0,
351
+ maximum=1.0,
352
+ value=1.0,
353
+ step=0.1,
354
+ interactive=True,
355
+ label="Top P",
356
+ )
357
+ max_output_tokens = gr.Slider(
358
+ minimum=16,
359
+ maximum=1024,
360
+ value=512,
361
+ step=64,
362
+ interactive=True,
363
+ label="Max output tokens",
364
+ )
365
+
366
+ gr.Markdown(acknowledgment_md)
367
+
368
+ # Register listeners
369
+ btn_list = [
370
+ leftvote_btn,
371
+ rightvote_btn,
372
+ tie_btn,
373
+ bothbad_btn,
374
+ regenerate_btn,
375
+ clear_btn,
376
+ ]
377
+ leftvote_btn.click(
378
+ leftvote_last_response,
379
+ states + model_selectors,
380
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
381
+ )
382
+ rightvote_btn.click(
383
+ rightvote_last_response,
384
+ states + model_selectors,
385
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
386
+ )
387
+ tie_btn.click(
388
+ tievote_last_response,
389
+ states + model_selectors,
390
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
391
+ )
392
+ bothbad_btn.click(
393
+ bothbad_vote_last_response,
394
+ states + model_selectors,
395
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
396
+ )
397
+ regenerate_btn.click(
398
+ regenerate, states, states + chatbots + [textbox] + btn_list
399
+ ).then(
400
+ bot_response_multi,
401
+ states + [temperature, top_p, max_output_tokens],
402
+ states + chatbots + btn_list,
403
+ ).then(
404
+ flash_buttons, [], btn_list
405
+ )
406
+ clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
407
+
408
+ share_js = """
409
+ function (a, b, c, d) {
410
+ const captureElement = document.querySelector('#share-region-named');
411
+ html2canvas(captureElement)
412
+ .then(canvas => {
413
+ canvas.style.display = 'none'
414
+ document.body.appendChild(canvas)
415
+ return canvas
416
+ })
417
+ .then(canvas => {
418
+ const image = canvas.toDataURL('image/png')
419
+ const a = document.createElement('a')
420
+ a.setAttribute('download', 'chatbot-arena.png')
421
+ a.setAttribute('href', image)
422
+ a.click()
423
+ canvas.remove()
424
+ });
425
+ return [a, b, c, d];
426
+ }
427
+ """
428
+ share_btn.click(share_click, states + model_selectors, [], _js=share_js)
429
+
430
+ for i in range(num_sides):
431
+ model_selectors[i].change(
432
+ clear_history, None, states + chatbots + [textbox] + btn_list
433
+ )
434
+
435
+ textbox.submit(
436
+ add_text,
437
+ states + model_selectors + [textbox],
438
+ states + chatbots + [textbox] + btn_list,
439
+ ).then(
440
+ bot_response_multi,
441
+ states + [temperature, top_p, max_output_tokens],
442
+ states + chatbots + btn_list,
443
+ ).then(
444
+ flash_buttons, [], btn_list
445
+ )
446
+ send_btn.click(
447
+ add_text,
448
+ states + model_selectors + [textbox],
449
+ states + chatbots + [textbox] + btn_list,
450
+ ).then(
451
+ bot_response_multi,
452
+ states + [temperature, top_p, max_output_tokens],
453
+ states + chatbots + btn_list,
454
+ ).then(
455
+ flash_buttons, [], btn_list
456
+ )
457
+
458
+ return states + model_selectors
gradio_web_server.py ADDED
@@ -0,0 +1,883 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The gradio demo server for chatting with a single model.
3
+ """
4
+
5
+ import argparse
6
+ from collections import defaultdict
7
+ import datetime
8
+ import json
9
+ import os
10
+ import random
11
+ import time
12
+ import uuid
13
+
14
+ import gradio as gr
15
+ import requests
16
+
17
+ from fastchat.conversation import SeparatorStyle
18
+ from fastchat.constants import (
19
+ LOGDIR,
20
+ WORKER_API_TIMEOUT,
21
+ ErrorCode,
22
+ MODERATION_MSG,
23
+ CONVERSATION_LIMIT_MSG,
24
+ SERVER_ERROR_MSG,
25
+ INPUT_CHAR_LEN_LIMIT,
26
+ CONVERSATION_TURN_LIMIT,
27
+ SESSION_EXPIRATION_TIME,
28
+ )
29
+ from fastchat.model.model_adapter import get_conversation_template
30
+ from fastchat.conversation import get_conv_template
31
+ from fastchat.model.model_registry import get_model_info, model_info
32
+ from fastchat.serve.api_provider import (
33
+ anthropic_api_stream_iter,
34
+ openai_api_stream_iter,
35
+ palm_api_stream_iter,
36
+ init_palm_chat,
37
+ )
38
+ from fastchat.utils import (
39
+ build_logger,
40
+ moderation_filter,
41
+ get_window_url_params_js,
42
+ get_window_url_params_with_tos_js,
43
+ parse_gradio_auth_creds,
44
+ )
45
+
46
+ CONV_TEMPLATE = ''
47
+
48
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
49
+
50
+ headers = {"User-Agent": "FastChat Client"}
51
+
52
+ no_change_btn = gr.Button.update()
53
+ enable_btn = gr.Button.update(interactive=True, visible=True)
54
+ disable_btn = gr.Button.update(interactive=False)
55
+ invisible_btn = gr.Button.update(interactive=False, visible=False)
56
+
57
+ controller_url = None
58
+ enable_moderation = False
59
+
60
+ acknowledgment_md = """
61
+ ### Acknowledgment
62
+ <div class="image-container">
63
+ <p> We thank <a href="https://www.kaggle.com/" target="_blank">Kaggle</a>, <a href="https://mbzuai.ac.ae/" target="_blank">MBZUAI</a>, <a href="https://www.anyscale.com/" target="_blank">AnyScale</a>, and <a href="https://huggingface.co/" target="_blank">HuggingFace</a> for their <a href="https://lmsys.org/donations/" target="_blank">sponsorship</a>. </p>
64
+ <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/7/7c/Kaggle_logo.png/400px-Kaggle_logo.png" alt="Image 1">
65
+ <img src="https://mma.prnewswire.com/media/1227419/MBZUAI_Logo.jpg?p=facebookg" alt="Image 2">
66
+ <img src="https://docs.anyscale.com/site-assets/logo.png" alt="Image 3">
67
+ <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-title.png" alt="Image 4">
68
+ </div>
69
+ """
70
+
71
+ ip_expiration_dict = defaultdict(lambda: 0)
72
+
73
+ # Information about custom OpenAI compatible API models.
74
+ # JSON file format:
75
+ # {
76
+ # "vicuna-7b": {
77
+ # "model_name": "vicuna-7b-v1.5",
78
+ # "api_base": "http://8.8.8.55:5555/v1",
79
+ # "api_key": "password"
80
+ # },
81
+ # }
82
+ openai_compatible_models_info = {}
83
+
84
+
85
+ class State:
86
+ def __init__(self, model_name):
87
+ # if model_name=='checkpoint-800':
88
+ # self.conv = get_conv_template(CONV_TEMPLATE)
89
+ # elif model_name=='MiniCPM-2B-sft-bf16':
90
+ ret = requests.post(
91
+ controller_url + "/get_worker_address", json={"model": model_name}
92
+ )
93
+ worker_addr = ret.json()["address"]
94
+ conv_name = requests.post(
95
+ worker_addr + "/worker_get_conv_template",
96
+ ).json()['conv']['name']
97
+ self.conv = get_conv_template(conv_name)
98
+ # self.conv = get_conv_template('minicpm')
99
+ # print(self.conv)
100
+ # self.conv = get_conversation_template(model_name)
101
+ self.conv_id = uuid.uuid4().hex
102
+ self.skip_next = False
103
+ self.model_name = model_name
104
+
105
+ if model_name == "palm-2":
106
+ # According to release note, "chat-bison@001" is PaLM 2 for chat.
107
+ # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
108
+ self.palm_chat = init_palm_chat("chat-bison@001")
109
+
110
+ def to_gradio_chatbot(self):
111
+ return self.conv.to_gradio_chatbot()
112
+
113
+ def dict(self):
114
+ base = self.conv.dict()
115
+ base.update(
116
+ {
117
+ "conv_id": self.conv_id,
118
+ "model_name": self.model_name,
119
+ }
120
+ )
121
+ return base
122
+
123
+
124
+ def set_global_vars(controller_url_, enable_moderation_):
125
+ global controller_url, enable_moderation
126
+ controller_url = controller_url_
127
+ enable_moderation = enable_moderation_
128
+
129
+
130
+ def get_conv_log_filename():
131
+ t = datetime.datetime.now()
132
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
133
+ return name
134
+
135
+
136
+ def get_model_list(
137
+ controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm
138
+ ):
139
+ if controller_url:
140
+ ret = requests.post(controller_url + "/refresh_all_workers")
141
+ assert ret.status_code == 200
142
+ ret = requests.post(controller_url + "/list_models")
143
+ # ret = requests.post(controller_url + "/get_worker_address")
144
+ # ret = requests.post(controller_url + "/worker_get_status")
145
+ models = ret.json()["models"]
146
+ else:
147
+ models = []
148
+
149
+ # Add API providers
150
+ if register_openai_compatible_models:
151
+ global openai_compatible_models_info
152
+ openai_compatible_models_info = json.load(
153
+ open(register_openai_compatible_models)
154
+ )
155
+ models += list(openai_compatible_models_info.keys())
156
+
157
+ if add_chatgpt:
158
+ models += ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]
159
+ if add_claude:
160
+ models += ["claude-2", "claude-instant-1"]
161
+ if add_palm:
162
+ models += ["palm-2"]
163
+ models = list(set(models))
164
+
165
+ if "deluxe-chat-v1" in models:
166
+ del models[models.index("deluxe-chat-v1")]
167
+ if "deluxe-chat-v1.1" in models:
168
+ del models[models.index("deluxe-chat-v1.1")]
169
+
170
+ priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
171
+ models.sort(key=lambda x: priority.get(x, x))
172
+ logger.info(f"Models: {models}")
173
+ return models
174
+
175
+
176
+ def load_demo_single(models, url_params):
177
+ selected_model = models[0] if len(models) > 0 else ""
178
+ if "model" in url_params:
179
+ model = url_params["model"]
180
+ if model in models:
181
+ selected_model = model
182
+
183
+ dropdown_update = gr.Dropdown.update(
184
+ choices=models, value=selected_model, visible=True
185
+ )
186
+
187
+ state = None
188
+ return state, dropdown_update
189
+
190
+
191
+ def load_demo(url_params, request: gr.Request):
192
+ global models
193
+
194
+ ip = get_ip(request)
195
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
196
+ ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
197
+
198
+ if args.model_list_mode == "reload":
199
+ models = get_model_list(
200
+ controller_url,
201
+ args.register_openai_compatible_models,
202
+ args.add_chatgpt,
203
+ args.add_claude,
204
+ args.add_palm,
205
+ )
206
+
207
+ return load_demo_single(models, url_params)
208
+
209
+
210
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
211
+ with open('./web_chat_downvote.jsonl', "a+") as fout:
212
+ # data = {
213
+ # "tstamp": round(time.time(), 4),
214
+ # "type": vote_type,
215
+ # "model": model_selector,
216
+ # "state": state.dict(),
217
+ # "ip": get_ip(request),
218
+ # }
219
+ conversations = []
220
+ for i, turn in enumerate(state.dict()['messages']):
221
+ role = 'user' if i % 2 == 0 else 'assistant'
222
+ conversations.append({'role': role, 'content': turn[1]})
223
+ data = {
224
+ 'conversations': conversations,
225
+ 'idx': state.dict()['conv_id'],
226
+ 'tinder': 'badcase',
227
+ 'model': state.dict()['model_name'],
228
+ 'tokens_in': -1,
229
+ 'tokens_out': -1,
230
+ }
231
+ fout.write(json.dumps(data, ensure_ascii=False) + "\n")
232
+
233
+
234
+ def upvote_last_response(state, model_selector, request: gr.Request):
235
+ ip = get_ip(request)
236
+ logger.info(f"upvote. ip: {ip}")
237
+ vote_last_response(state, "upvote", model_selector, request)
238
+ return ("",) + (disable_btn,) * 3
239
+
240
+
241
+ def downvote_last_response(state, model_selector, request: gr.Request):
242
+ ip = get_ip(request)
243
+ logger.info(f"downvote. ip: {ip}")
244
+ vote_last_response(state, "downvote", model_selector, request)
245
+ return ("",) + (disable_btn,) * 3
246
+
247
+
248
+ def flag_last_response(state, model_selector, request: gr.Request):
249
+ ip = get_ip(request)
250
+ logger.info(f"flag. ip: {ip}")
251
+ vote_last_response(state, "flag", model_selector, request)
252
+ return ("",) + (disable_btn,) * 3
253
+
254
+
255
+ def regenerate(state, request: gr.Request):
256
+ ip = get_ip(request)
257
+ logger.info(f"regenerate. ip: {ip}")
258
+ state.conv.update_last_message(None)
259
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
260
+
261
+
262
+ def clear_history(request: gr.Request):
263
+ ip = get_ip(request)
264
+ logger.info(f"clear_history. ip: {ip}")
265
+ state = None
266
+ return (state, [], "") + (disable_btn,) * 5
267
+
268
+
269
+ def get_ip(request: gr.Request):
270
+ if "cf-connecting-ip" in request.headers:
271
+ ip = request.headers["cf-connecting-ip"]
272
+ else:
273
+ ip = request.client.host
274
+ return ip
275
+
276
+
277
+ def add_text(state, model_selector, text, request: gr.Request):
278
+ ip = get_ip(request)
279
+ logger.info(f"add_text. ip: {ip}. len: {len(text)}")
280
+
281
+ if state is None:
282
+ state = State(model_selector)
283
+
284
+ if len(text) <= 0:
285
+ state.skip_next = True
286
+ return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
287
+
288
+ flagged = moderation_filter(text, [state.model_name])
289
+ if flagged:
290
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
291
+ # overwrite the original text
292
+ text = MODERATION_MSG
293
+
294
+ conv = state.conv
295
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
296
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
297
+ state.skip_next = True
298
+ return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
299
+ no_change_btn,
300
+ ) * 5
301
+
302
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
303
+ conv.append_message(conv.roles[0], text)
304
+ conv.append_message(conv.roles[1], None)
305
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
306
+
307
+
308
+ def post_process_code(code):
309
+ sep = "\n```"
310
+ if sep in code:
311
+ blocks = code.split(sep)
312
+ if len(blocks) % 2 == 1:
313
+ for i in range(1, len(blocks), 2):
314
+ blocks[i] = blocks[i].replace("\\_", "_")
315
+ code = sep.join(blocks)
316
+ return code
317
+
318
+
319
+ def model_worker_stream_iter(
320
+ conv,
321
+ model_name,
322
+ worker_addr,
323
+ prompt,
324
+ temperature,
325
+ repetition_penalty,
326
+ top_p,
327
+ max_new_tokens,
328
+ ):
329
+ # Make requests
330
+ gen_params = {
331
+ "model": model_name,
332
+ "prompt": prompt,
333
+ "temperature": temperature,
334
+ "repetition_penalty": repetition_penalty,
335
+ "top_p": top_p,
336
+ "max_new_tokens": max_new_tokens,
337
+ "stop": conv.stop_str,
338
+ "stop_token_ids": conv.stop_token_ids,
339
+ "echo": False,
340
+ }
341
+ logger.info(f"==== request ====\n{gen_params}")
342
+
343
+ # Stream output
344
+ response = requests.post(
345
+ worker_addr + "/worker_generate_stream",
346
+ headers=headers,
347
+ json=gen_params,
348
+ stream=True,
349
+ timeout=WORKER_API_TIMEOUT,
350
+ )
351
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
352
+ if chunk:
353
+ data = json.loads(chunk.decode())
354
+ yield data
355
+
356
+
357
+ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
358
+ ip = get_ip(request)
359
+ logger.info(f"bot_response. ip: {ip}")
360
+ start_tstamp = time.time()
361
+ temperature = float(temperature)
362
+ top_p = float(top_p)
363
+ max_new_tokens = int(max_new_tokens)
364
+
365
+ if state.skip_next:
366
+ # This generate call is skipped due to invalid inputs
367
+ state.skip_next = False
368
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
369
+ return
370
+
371
+ conv, model_name = state.conv, state.model_name
372
+ if model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]:
373
+ prompt = conv.to_openai_api_messages()
374
+ stream_iter = openai_api_stream_iter(
375
+ model_name, prompt, temperature, top_p, max_new_tokens
376
+ )
377
+ elif model_name in ["claude-2", "claude-1", "claude-instant-1"]:
378
+ prompt = conv.get_prompt()
379
+ stream_iter = anthropic_api_stream_iter(
380
+ model_name, prompt, temperature, top_p, max_new_tokens
381
+ )
382
+ elif model_name == "palm-2":
383
+ stream_iter = palm_api_stream_iter(
384
+ state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
385
+ )
386
+ elif model_name in openai_compatible_models_info:
387
+ model_info = openai_compatible_models_info[model_name]
388
+ prompt = conv.to_openai_api_messages()
389
+ stream_iter = openai_api_stream_iter(
390
+ model_info["model_name"],
391
+ prompt,
392
+ temperature,
393
+ top_p,
394
+ max_new_tokens,
395
+ api_base=model_info["api_base"],
396
+ api_key=model_info["api_key"],
397
+ )
398
+ else:
399
+ # Query worker address
400
+ ret = requests.post(
401
+ controller_url + "/get_worker_address", json={"model": model_name}
402
+ )
403
+ worker_addr = ret.json()["address"]
404
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
405
+
406
+ # No available worker
407
+ if worker_addr == "":
408
+ conv.update_last_message(SERVER_ERROR_MSG)
409
+ yield (
410
+ state,
411
+ state.to_gradio_chatbot(),
412
+ disable_btn,
413
+ disable_btn,
414
+ disable_btn,
415
+ enable_btn,
416
+ enable_btn,
417
+ )
418
+ return
419
+
420
+ # Construct prompt.
421
+ # We need to call it here, so it will not be affected by "▌".
422
+ prompt = conv.get_prompt()
423
+ # Set repetition_penalty
424
+ if "t5" in model_name:
425
+ repetition_penalty = 1.2
426
+ else:
427
+ repetition_penalty = 1.0
428
+
429
+ stream_iter = model_worker_stream_iter(
430
+ conv,
431
+ model_name,
432
+ worker_addr,
433
+ prompt,
434
+ temperature,
435
+ repetition_penalty,
436
+ top_p,
437
+ max_new_tokens,
438
+ )
439
+
440
+ conv.update_last_message("▌")
441
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
442
+
443
+ try:
444
+ for i, data in enumerate(stream_iter):
445
+ if data["error_code"] == 0:
446
+ output = data["text"].strip()
447
+ conv.update_last_message(output + "▌")
448
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
449
+ else:
450
+ output = data["text"] + f"\n\n(error_code: {data['error_code']})"
451
+ conv.update_last_message(output)
452
+ yield (state, state.to_gradio_chatbot()) + (
453
+ disable_btn,
454
+ disable_btn,
455
+ disable_btn,
456
+ enable_btn,
457
+ enable_btn,
458
+ )
459
+ return
460
+ output = data["text"].strip()
461
+ if "vicuna" in model_name:
462
+ output = post_process_code(output)
463
+ conv.update_last_message(output)
464
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
465
+ except requests.exceptions.RequestException as e:
466
+ conv.update_last_message(
467
+ f"{SERVER_ERROR_MSG}\n\n"
468
+ f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
469
+ )
470
+ yield (state, state.to_gradio_chatbot()) + (
471
+ disable_btn,
472
+ disable_btn,
473
+ disable_btn,
474
+ enable_btn,
475
+ enable_btn,
476
+ )
477
+ return
478
+ except Exception as e:
479
+ conv.update_last_message(
480
+ f"{SERVER_ERROR_MSG}\n\n"
481
+ f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
482
+ )
483
+ yield (state, state.to_gradio_chatbot()) + (
484
+ disable_btn,
485
+ disable_btn,
486
+ disable_btn,
487
+ enable_btn,
488
+ enable_btn,
489
+ )
490
+ return
491
+
492
+ finish_tstamp = time.time()
493
+ logger.info(f"{output}")
494
+
495
+ with open(get_conv_log_filename(), "a") as fout:
496
+ data = {
497
+ "tstamp": round(finish_tstamp, 4),
498
+ "type": "chat",
499
+ "model": model_name,
500
+ "gen_params": {
501
+ "temperature": temperature,
502
+ "top_p": top_p,
503
+ "max_new_tokens": max_new_tokens,
504
+ },
505
+ "start": round(start_tstamp, 4),
506
+ "finish": round(finish_tstamp, 4),
507
+ "state": state.dict(),
508
+ "ip": get_ip(request),
509
+ }
510
+ fout.write(json.dumps(data) + "\n")
511
+
512
+
513
+ block_css = """
514
+ #notice_markdown {
515
+ font-size: 110%
516
+ }
517
+ #notice_markdown th {
518
+ display: none;
519
+ }
520
+ #notice_markdown td {
521
+ padding-top: 6px;
522
+ padding-bottom: 6px;
523
+ }
524
+ #leaderboard_markdown {
525
+ font-size: 110%
526
+ }
527
+ #leaderboard_markdown td {
528
+ padding-top: 6px;
529
+ padding-bottom: 6px;
530
+ }
531
+ #leaderboard_dataframe td {
532
+ line-height: 0.1em;
533
+ }
534
+ #about_markdown {
535
+ font-size: 110%
536
+ }
537
+ #input_box textarea {
538
+ }
539
+ footer {
540
+ display:none !important
541
+ }
542
+ .image-container {
543
+ display: flex;
544
+ align-items: center;
545
+ padding: 1px;
546
+ }
547
+ .image-container img {
548
+ margin: 0 30px;
549
+ height: 20px;
550
+ max-height: 100%;
551
+ width: auto;
552
+ max-width: 20%;
553
+ }
554
+ .image-about img {
555
+ margin: 0 30px;
556
+ margin-top: 30px;
557
+ height: 60px;
558
+ max-height: 100%;
559
+ width: auto;
560
+ float: left;
561
+ }
562
+ """
563
+
564
+
565
+ def get_model_description_md(models):
566
+ model_description_md = """
567
+ | | | |
568
+ | ---- | ---- | ---- |
569
+ """
570
+ ct = 0
571
+ visited = set()
572
+ for i, name in enumerate(models):
573
+ minfo = get_model_info(name)
574
+ if minfo.simple_name in visited:
575
+ continue
576
+ visited.add(minfo.simple_name)
577
+ one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
578
+
579
+ if ct % 3 == 0:
580
+ model_description_md += "|"
581
+ model_description_md += f" {one_model_md} |"
582
+ if ct % 3 == 2:
583
+ model_description_md += "\n"
584
+ ct += 1
585
+ return model_description_md
586
+
587
+
588
+ def build_about():
589
+ about_markdown = f"""
590
+ # About Us
591
+ Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our code at [GitHub](https://github.com/lm-sys/FastChat) and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
592
+
593
+ ## Read More
594
+ - Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
595
+ - LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
596
+
597
+ ## Core Members
598
+ [Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
599
+
600
+ ## Advisors
601
+ [Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
602
+
603
+ ## Contact Us
604
+ - Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at lmsys.org@gmail.com
605
+ - File issues on [GitHub](https://github.com/lm-sys/FastChat)
606
+ - Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
607
+
608
+ ## Sponsors
609
+ We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship.
610
+ Learn more about partnership [here](https://lmsys.org/donations/).
611
+
612
+ <div class="image-about">
613
+ <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/7/7c/Kaggle_logo.png/400px-Kaggle_logo.png" alt="Image 1">
614
+ <img src="https://upload.wikimedia.org/wikipedia/en/5/55/Mohamed_bin_Zayed_University_of_Artificial_Intelligence_logo.png" alt="Image 2">
615
+ <img src="https://docs.anyscale.com/site-assets/logo.png" alt="Image 3">
616
+ <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png" alt="Image 4">
617
+ </div>
618
+ """
619
+
620
+ # state = gr.State()
621
+ gr.Markdown(about_markdown, elem_id="about_markdown")
622
+
623
+ # return [state]
624
+
625
+
626
+ def build_single_model_ui(models, add_promotion_links=False):
627
+ promotion = (
628
+ """
629
+ - | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
630
+ - Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
631
+ - Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
632
+ """
633
+ if add_promotion_links
634
+ else ""
635
+ )
636
+
637
+ notice_markdown = f"""
638
+ # 🏔️ Chat with Open Large Language Models
639
+ {promotion}
640
+
641
+ ## 👉 Choose any model to chat
642
+ """
643
+
644
+ state = gr.State()
645
+ model_description_md = get_model_description_md(models)
646
+ gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown")
647
+
648
+ with gr.Row(elem_id="model_selector_row"):
649
+ model_selector = gr.Dropdown(
650
+ choices=models,
651
+ value=models[0] if len(models) > 0 else "",
652
+ interactive=True,
653
+ show_label=False,
654
+ container=False,
655
+ )
656
+
657
+ chatbot = gr.Chatbot(
658
+ elem_id="chatbot",
659
+ label="Scroll down and start chatting",
660
+ height=550,
661
+ )
662
+ with gr.Row():
663
+ with gr.Column(scale=20):
664
+ textbox = gr.Textbox(
665
+ show_label=False,
666
+ placeholder="Enter your prompt here and press ENTER",
667
+ container=False,
668
+ elem_id="input_box",
669
+ )
670
+ with gr.Column(scale=1, min_width=50):
671
+ send_btn = gr.Button(value="Send", variant="primary")
672
+
673
+ with gr.Row() as button_row:
674
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
675
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
676
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
677
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
678
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
679
+
680
+ with gr.Accordion("Parameters", open=False) as parameter_row:
681
+ temperature = gr.Slider(
682
+ minimum=0.0,
683
+ maximum=1.0,
684
+ value=0.7,
685
+ step=0.1,
686
+ interactive=True,
687
+ label="Temperature",
688
+ )
689
+ top_p = gr.Slider(
690
+ minimum=0.0,
691
+ maximum=1.0,
692
+ value=1.0,
693
+ step=0.1,
694
+ interactive=True,
695
+ label="Top P",
696
+ )
697
+ max_output_tokens = gr.Slider(
698
+ minimum=16,
699
+ maximum=3072,
700
+ value=2048,
701
+ step=1,
702
+ interactive=True,
703
+ label="Max output tokens",
704
+ )
705
+
706
+ if add_promotion_links:
707
+ gr.Markdown(acknowledgment_md)
708
+
709
+ # Register listeners
710
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
711
+ upvote_btn.click(
712
+ upvote_last_response,
713
+ [state, model_selector],
714
+ [textbox, upvote_btn, downvote_btn, flag_btn],
715
+ )
716
+ downvote_btn.click(
717
+ downvote_last_response,
718
+ [state, model_selector],
719
+ [textbox, upvote_btn, downvote_btn, flag_btn],
720
+ )
721
+ flag_btn.click(
722
+ flag_last_response,
723
+ [state, model_selector],
724
+ [textbox, upvote_btn, downvote_btn, flag_btn],
725
+ )
726
+ regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
727
+ bot_response,
728
+ [state, temperature, top_p, max_output_tokens],
729
+ [state, chatbot] + btn_list,
730
+ )
731
+ clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
732
+
733
+ model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
734
+
735
+ textbox.submit(
736
+ add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
737
+ ).then(
738
+ bot_response,
739
+ [state, temperature, top_p, max_output_tokens],
740
+ [state, chatbot] + btn_list,
741
+ )
742
+ send_btn.click(
743
+ add_text,
744
+ [state, model_selector, textbox],
745
+ [state, chatbot, textbox] + btn_list,
746
+ ).then(
747
+ bot_response,
748
+ [state, temperature, top_p, max_output_tokens],
749
+ [state, chatbot] + btn_list,
750
+ )
751
+
752
+ return [state, model_selector]
753
+
754
+
755
+ def build_demo(models):
756
+ with gr.Blocks(
757
+ title="Chat with Open Large Language Models",
758
+ theme=gr.themes.Default(),
759
+ css=block_css,
760
+ ) as demo:
761
+ url_params = gr.JSON(visible=False)
762
+
763
+ state, model_selector = build_single_model_ui(models)
764
+
765
+ if args.model_list_mode not in ["once", "reload"]:
766
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
767
+
768
+ if args.show_terms_of_use:
769
+ load_js = get_window_url_params_with_tos_js
770
+ else:
771
+ load_js = get_window_url_params_js
772
+
773
+ demo.load(
774
+ load_demo,
775
+ [url_params],
776
+ [
777
+ state,
778
+ model_selector,
779
+ ],
780
+ _js=load_js,
781
+ )
782
+
783
+ return demo
784
+
785
+
786
+ if __name__ == "__main__":
787
+ parser = argparse.ArgumentParser()
788
+ parser.add_argument("--host", type=str, default="0.0.0.0")
789
+ parser.add_argument("--port", type=int)
790
+ parser.add_argument(
791
+ "--conv-template",
792
+ type=str,
793
+ default="megrez",
794
+ help="The address of the controller",
795
+ )
796
+ parser.add_argument(
797
+ "--share",
798
+ action="store_true",
799
+ help="Whether to generate a public, shareable link",
800
+ )
801
+ parser.add_argument(
802
+ "--controller-url",
803
+ type=str,
804
+ default="http://localhost:21001",
805
+ help="The address of the controller",
806
+ )
807
+ parser.add_argument(
808
+ "--concurrency-count",
809
+ type=int,
810
+ default=10,
811
+ help="The concurrency count of the gradio queue",
812
+ )
813
+ parser.add_argument(
814
+ "--model-list-mode",
815
+ type=str,
816
+ default="once",
817
+ choices=["once", "reload"],
818
+ help="Whether to load the model list once or reload the model list every time",
819
+ )
820
+ parser.add_argument(
821
+ "--moderate",
822
+ action="store_true",
823
+ help="Enable content moderation to block unsafe inputs",
824
+ )
825
+ parser.add_argument(
826
+ "--show-terms-of-use",
827
+ action="store_true",
828
+ help="Shows term of use before loading the demo",
829
+ )
830
+ parser.add_argument(
831
+ "--add-chatgpt",
832
+ action="store_true",
833
+ help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
834
+ )
835
+ parser.add_argument(
836
+ "--add-claude",
837
+ action="store_true",
838
+ help="Add Anthropic's Claude models (claude-2, claude-instant-1)",
839
+ )
840
+ parser.add_argument(
841
+ "--add-palm",
842
+ action="store_true",
843
+ help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)",
844
+ )
845
+ parser.add_argument(
846
+ "--register-openai-compatible-models",
847
+ type=str,
848
+ help="Register custom OpenAI API compatible models by loading them from a JSON file",
849
+ )
850
+ parser.add_argument(
851
+ "--gradio-auth-path",
852
+ type=str,
853
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
854
+ )
855
+ args = parser.parse_args()
856
+ logger.info(f"args: {args}")
857
+ CONV_TEMPLATE = args.conv_template
858
+ # Set global variables
859
+ set_global_vars(args.controller_url, args.moderate)
860
+ models = get_model_list(
861
+ args.controller_url,
862
+ args.register_openai_compatible_models,
863
+ args.add_chatgpt,
864
+ args.add_claude,
865
+ args.add_palm,
866
+ )
867
+ # Set authorization credentials
868
+ auth = None
869
+ if args.gradio_auth_path is not None:
870
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
871
+
872
+ # Launch the demo
873
+ demo = build_demo(models)
874
+ ret = demo.queue(
875
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
876
+ ).launch(
877
+ server_name=args.host,
878
+ server_port=args.port,
879
+ share=args.share,
880
+ max_threads=200,
881
+ auth=auth,
882
+ )
883
+ from IPython import embed;embed()
gradio_web_server_multi.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The gradio demo server with multiple tabs.
3
+ It supports chatting with a single model or chatting with two models side-by-side.
4
+ """
5
+
6
+ import argparse
7
+ import pickle
8
+ import time
9
+
10
+ import gradio as gr
11
+
12
+ from fastchat.constants import (
13
+ SESSION_EXPIRATION_TIME,
14
+ )
15
+ from fastchat.serve.gradio_block_arena_anony import (
16
+ build_side_by_side_ui_anony,
17
+ load_demo_side_by_side_anony,
18
+ set_global_vars_anony,
19
+ )
20
+ from fastchat.serve.gradio_block_arena_named import (
21
+ build_side_by_side_ui_named,
22
+ load_demo_side_by_side_named,
23
+ set_global_vars_named,
24
+ )
25
+ from fastchat.serve.gradio_web_server import (
26
+ set_global_vars,
27
+ block_css,
28
+ build_single_model_ui,
29
+ build_about,
30
+ get_model_list,
31
+ load_demo_single,
32
+ ip_expiration_dict,
33
+ get_ip,
34
+ )
35
+ from fastchat.serve.monitor.monitor import build_leaderboard_tab
36
+ from fastchat.utils import (
37
+ build_logger,
38
+ get_window_url_params_js,
39
+ get_window_url_params_with_tos_js,
40
+ parse_gradio_auth_creds,
41
+ )
42
+
43
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
44
+
45
+
46
+ def load_demo(url_params, request: gr.Request):
47
+ global models
48
+
49
+ ip = get_ip(request)
50
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
51
+ ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
52
+
53
+ selected = 0
54
+ if "arena" in url_params:
55
+ selected = 0
56
+ elif "compare" in url_params:
57
+ selected = 1
58
+ elif "single" in url_params:
59
+ selected = 2
60
+ elif "leaderboard" in url_params:
61
+ selected = 3
62
+
63
+ if args.model_list_mode == "reload":
64
+ if args.anony_only_for_proprietary_model:
65
+ models = get_model_list(
66
+ args.controller_url,
67
+ args.register_openai_compatible_models,
68
+ False,
69
+ False,
70
+ False,
71
+ )
72
+ else:
73
+ models = get_model_list(
74
+ args.controller_url,
75
+ args.register_openai_compatible_models,
76
+ args.add_chatgpt,
77
+ args.add_claude,
78
+ args.add_palm,
79
+ )
80
+
81
+ single_updates = load_demo_single(models, url_params)
82
+
83
+ models_anony = list(models)
84
+ if args.anony_only_for_proprietary_model:
85
+ # Only enable these models in anony battles.
86
+ if args.add_chatgpt:
87
+ models_anony += [
88
+ "gpt-4",
89
+ "gpt-3.5-turbo",
90
+ "gpt-4-turbo",
91
+ "gpt-3.5-turbo-1106",
92
+ ]
93
+ if args.add_claude:
94
+ models_anony += ["claude-2", "claude-1", "claude-instant-1"]
95
+ if args.add_palm:
96
+ models_anony += ["palm-2"]
97
+ models_anony = list(set(models_anony))
98
+
99
+ side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params)
100
+ side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
101
+ return (
102
+ (gr.Tabs.update(selected=selected),)
103
+ + single_updates
104
+ + side_by_side_anony_updates
105
+ + side_by_side_named_updates
106
+ )
107
+
108
+
109
+ def build_demo(models, elo_results_file, leaderboard_table_file):
110
+ text_size = gr.themes.sizes.text_md
111
+ with gr.Blocks(
112
+ title="Chat with Open Large Language Models",
113
+ theme=gr.themes.Default(text_size=text_size),
114
+ css=block_css,
115
+ ) as demo:
116
+ with gr.Tabs() as tabs:
117
+ with gr.Tab("Arena (battle)", id=0):
118
+ side_by_side_anony_list = build_side_by_side_ui_anony(models)
119
+
120
+ with gr.Tab("Arena (side-by-side)", id=1):
121
+ side_by_side_named_list = build_side_by_side_ui_named(models)
122
+
123
+ with gr.Tab("Direct Chat", id=2):
124
+ single_model_list = build_single_model_ui(
125
+ models, add_promotion_links=True
126
+ )
127
+ if elo_results_file:
128
+ with gr.Tab("Leaderboard", id=3):
129
+ build_leaderboard_tab(elo_results_file, leaderboard_table_file)
130
+ with gr.Tab("About Us", id=4):
131
+ about = build_about()
132
+
133
+ url_params = gr.JSON(visible=False)
134
+
135
+ if args.model_list_mode not in ["once", "reload"]:
136
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
137
+
138
+ if args.show_terms_of_use:
139
+ load_js = get_window_url_params_with_tos_js
140
+ else:
141
+ load_js = get_window_url_params_js
142
+
143
+ demo.load(
144
+ load_demo,
145
+ [url_params],
146
+ [tabs]
147
+ + single_model_list
148
+ + side_by_side_anony_list
149
+ + side_by_side_named_list,
150
+ _js=load_js,
151
+ )
152
+
153
+ return demo
154
+
155
+
156
+ if __name__ == "__main__":
157
+ parser = argparse.ArgumentParser()
158
+ parser.add_argument("--host", type=str, default="0.0.0.0")
159
+ parser.add_argument("--port", type=int)
160
+ parser.add_argument(
161
+ "--share",
162
+ action="store_true",
163
+ help="Whether to generate a public, shareable link",
164
+ )
165
+ parser.add_argument(
166
+ "--controller-url",
167
+ type=str,
168
+ default="http://localhost:21001",
169
+ help="The address of the controller",
170
+ )
171
+ parser.add_argument(
172
+ "--concurrency-count",
173
+ type=int,
174
+ default=10,
175
+ help="The concurrency count of the gradio queue",
176
+ )
177
+ parser.add_argument(
178
+ "--model-list-mode",
179
+ type=str,
180
+ default="once",
181
+ choices=["once", "reload"],
182
+ help="Whether to load the model list once or reload the model list every time.",
183
+ )
184
+ parser.add_argument(
185
+ "--moderate",
186
+ action="store_true",
187
+ help="Enable content moderation to block unsafe inputs",
188
+ )
189
+ parser.add_argument(
190
+ "--show-terms-of-use",
191
+ action="store_true",
192
+ help="Shows term of use before loading the demo",
193
+ )
194
+ parser.add_argument(
195
+ "--add-chatgpt",
196
+ action="store_true",
197
+ help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
198
+ )
199
+ parser.add_argument(
200
+ "--add-claude",
201
+ action="store_true",
202
+ help="Add Anthropic's Claude models (claude-2, claude-instant-1)",
203
+ )
204
+ parser.add_argument(
205
+ "--add-palm",
206
+ action="store_true",
207
+ help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)",
208
+ )
209
+ parser.add_argument(
210
+ "--anony-only-for-proprietary-model",
211
+ action="store_true",
212
+ help="Only add ChatGPT, Claude, Bard under anony battle tab",
213
+ )
214
+ parser.add_argument(
215
+ "--register-openai-compatible-models",
216
+ type=str,
217
+ help="Register custom OpenAI API compatible models by loading them from a JSON file",
218
+ )
219
+ parser.add_argument(
220
+ "--gradio-auth-path",
221
+ type=str,
222
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
223
+ default=None,
224
+ )
225
+ parser.add_argument(
226
+ "--elo-results-file", type=str, help="Load leaderboard results and plots"
227
+ )
228
+ parser.add_argument(
229
+ "--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
230
+ )
231
+ args = parser.parse_args()
232
+ logger.info(f"args: {args}")
233
+
234
+ # Set global variables
235
+ set_global_vars(args.controller_url, args.moderate)
236
+ set_global_vars_named(args.moderate)
237
+ set_global_vars_anony(args.moderate)
238
+ if args.anony_only_for_proprietary_model:
239
+ models = get_model_list(
240
+ args.controller_url,
241
+ args.register_openai_compatible_models,
242
+ False,
243
+ False,
244
+ False,
245
+ )
246
+ else:
247
+ models = get_model_list(
248
+ args.controller_url,
249
+ args.register_openai_compatible_models,
250
+ args.add_chatgpt,
251
+ args.add_claude,
252
+ args.add_palm,
253
+ )
254
+
255
+ # Set authorization credentials
256
+ auth = None
257
+ if args.gradio_auth_path is not None:
258
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
259
+
260
+ # Launch the demo
261
+ demo = build_demo(models, args.elo_results_file, args.leaderboard_table_file)
262
+ demo.queue(
263
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
264
+ ).launch(
265
+ server_name=args.host,
266
+ server_port=args.port,
267
+ share=args.share,
268
+ max_threads=200,
269
+ auth=auth,
270
+ )
huggingface_api.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Use FastChat with Hugging Face generation APIs.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5
6
+ python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0
7
+ """
8
+ import argparse
9
+
10
+ import torch
11
+
12
+ from fastchat.model import load_model, get_conversation_template, add_model_args
13
+
14
+
15
+ @torch.inference_mode()
16
+ def main(args):
17
+ # Load model
18
+ model, tokenizer = load_model(
19
+ args.model_path,
20
+ device=args.device,
21
+ num_gpus=args.num_gpus,
22
+ max_gpu_memory=args.max_gpu_memory,
23
+ load_8bit=args.load_8bit,
24
+ cpu_offloading=args.cpu_offloading,
25
+ revision=args.revision,
26
+ debug=args.debug,
27
+ )
28
+
29
+ # Build the prompt with a conversation template
30
+ msg = args.message
31
+ conv = get_conversation_template(args.model_path)
32
+ conv.append_message(conv.roles[0], msg)
33
+ conv.append_message(conv.roles[1], None)
34
+ prompt = conv.get_prompt()
35
+
36
+ # Run inference
37
+ inputs = tokenizer([prompt], return_tensors="pt").to(args.device)
38
+ output_ids = model.generate(
39
+ **inputs,
40
+ do_sample=True if args.temperature > 1e-5 else False,
41
+ temperature=args.temperature,
42
+ repetition_penalty=args.repetition_penalty,
43
+ max_new_tokens=args.max_new_tokens,
44
+ )
45
+
46
+ if model.config.is_encoder_decoder:
47
+ output_ids = output_ids[0]
48
+ else:
49
+ output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
50
+ outputs = tokenizer.decode(
51
+ output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
52
+ )
53
+
54
+ # Print results
55
+ print(f"{conv.roles[0]}: {msg}")
56
+ print(f"{conv.roles[1]}: {outputs}")
57
+
58
+
59
+ if __name__ == "__main__":
60
+ parser = argparse.ArgumentParser()
61
+ add_model_args(parser)
62
+ parser.add_argument("--temperature", type=float, default=0.7)
63
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
64
+ parser.add_argument("--max-new-tokens", type=int, default=512)
65
+ parser.add_argument("--debug", action="store_true")
66
+ parser.add_argument("--message", type=str, default="Hello! Who are you?")
67
+ args = parser.parse_args()
68
+
69
+ # Reset default repetition penalty for T5 models.
70
+ if "t5" in args.model_path and args.repetition_penalty == 1.0:
71
+ args.repetition_penalty = 1.2
72
+
73
+ main(args)
huggingface_api_worker.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker that calls huggingface inference endpoint.
3
+
4
+ Register models in a JSON file with the following format:
5
+ {
6
+ "falcon-180b-chat": {
7
+ "model_path": "tiiuae/falcon-180B-chat",
8
+ "api_base": "https://api-inference.huggingface.co/models",
9
+ "token": "hf_xxx",
10
+ "context_length": 2048,
11
+ "model_names": "falcon-180b-chat",
12
+ "conv_template": null
13
+ }
14
+ }
15
+
16
+ "model_path", "api_base", "token", and "context_length" are necessary, while others are optional.
17
+ """
18
+ import argparse
19
+ import asyncio
20
+ import json
21
+ import uuid
22
+ from typing import List, Optional
23
+
24
+ import requests
25
+ import uvicorn
26
+ from fastapi import BackgroundTasks, FastAPI, Request
27
+ from fastapi.responses import JSONResponse, StreamingResponse
28
+ from huggingface_hub import InferenceClient
29
+
30
+ from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
31
+ from fastchat.serve.base_model_worker import BaseModelWorker
32
+ from fastchat.utils import build_logger
33
+
34
+ worker_id = str(uuid.uuid4())[:8]
35
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
36
+
37
+ workers = []
38
+ worker_map = {}
39
+ app = FastAPI()
40
+
41
+
42
+ # reference to
43
+ # https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
44
+ def get_gen_kwargs(
45
+ params,
46
+ seed: Optional[int] = None,
47
+ ):
48
+ stop = params.get("stop", None)
49
+ if isinstance(stop, list):
50
+ stop_sequences = stop
51
+ elif isinstance(stop, str):
52
+ stop_sequences = [stop]
53
+ else:
54
+ stop_sequences = []
55
+ gen_kwargs = {
56
+ "do_sample": True,
57
+ "return_full_text": bool(params.get("echo", False)),
58
+ "max_new_tokens": int(params.get("max_new_tokens", 256)),
59
+ "top_p": float(params.get("top_p", 1.0)),
60
+ "temperature": float(params.get("temperature", 1.0)),
61
+ "stop_sequences": stop_sequences,
62
+ "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
63
+ "top_k": params.get("top_k", None),
64
+ "seed": seed,
65
+ }
66
+ if gen_kwargs["top_p"] == 1:
67
+ gen_kwargs["top_p"] = 0.9999999
68
+ if gen_kwargs["top_p"] == 0:
69
+ gen_kwargs.pop("top_p")
70
+ if gen_kwargs["temperature"] == 0:
71
+ gen_kwargs.pop("temperature")
72
+ gen_kwargs["do_sample"] = False
73
+ return gen_kwargs
74
+
75
+
76
+ def could_be_stop(text, stop):
77
+ for s in stop:
78
+ if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
79
+ return True
80
+ return False
81
+
82
+
83
+ class HuggingfaceApiWorker(BaseModelWorker):
84
+ def __init__(
85
+ self,
86
+ controller_addr: str,
87
+ worker_addr: str,
88
+ worker_id: str,
89
+ model_path: str,
90
+ api_base: str,
91
+ token: str,
92
+ context_length: int,
93
+ model_names: List[str],
94
+ limit_worker_concurrency: int,
95
+ no_register: bool,
96
+ conv_template: Optional[str] = None,
97
+ seed: Optional[int] = None,
98
+ **kwargs,
99
+ ):
100
+ super().__init__(
101
+ controller_addr,
102
+ worker_addr,
103
+ worker_id,
104
+ model_path,
105
+ model_names,
106
+ limit_worker_concurrency,
107
+ conv_template=conv_template,
108
+ )
109
+
110
+ self.model_path = model_path
111
+ self.api_base = api_base
112
+ self.token = token
113
+ self.context_len = context_length
114
+ self.seed = seed
115
+
116
+ logger.info(
117
+ f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
118
+ )
119
+
120
+ if not no_register:
121
+ self.init_heart_beat()
122
+
123
+ def count_token(self, params):
124
+ # No tokenizer here
125
+ ret = {
126
+ "count": 0,
127
+ "error_code": 0,
128
+ }
129
+ return ret
130
+
131
+ def generate_stream_gate(self, params):
132
+ self.call_ct += 1
133
+
134
+ prompt = params["prompt"]
135
+ gen_kwargs = get_gen_kwargs(params, seed=self.seed)
136
+ stop = gen_kwargs["stop_sequences"]
137
+ if "falcon" in self.model_path and "chat" in self.model_path:
138
+ stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
139
+ stop = list(set(stop))
140
+ gen_kwargs["stop_sequences"] = stop
141
+
142
+ logger.info(f"prompt: {prompt}")
143
+ logger.info(f"gen_kwargs: {gen_kwargs}")
144
+
145
+ try:
146
+ if self.model_path == "":
147
+ url = f"{self.api_base}"
148
+ else:
149
+ url = f"{self.api_base}/{self.model_path}"
150
+ client = InferenceClient(url, token=self.token)
151
+ res = client.text_generation(
152
+ prompt, stream=True, details=True, **gen_kwargs
153
+ )
154
+
155
+ reason = None
156
+ text = ""
157
+ for chunk in res:
158
+ if chunk.token.special:
159
+ continue
160
+ text += chunk.token.text
161
+
162
+ s = next((x for x in stop if text.endswith(x)), None)
163
+ if s is not None:
164
+ text = text[: -len(s)]
165
+ reason = "stop"
166
+ break
167
+ if could_be_stop(text, stop):
168
+ continue
169
+ if (
170
+ chunk.details is not None
171
+ and chunk.details.finish_reason is not None
172
+ ):
173
+ reason = chunk.details.finish_reason
174
+ if reason not in ["stop", "length"]:
175
+ reason = None
176
+ ret = {
177
+ "text": text,
178
+ "error_code": 0,
179
+ "finish_reason": reason,
180
+ }
181
+ yield json.dumps(ret).encode() + b"\0"
182
+ except Exception as e:
183
+ ret = {
184
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
185
+ "error_code": ErrorCode.INTERNAL_ERROR,
186
+ }
187
+ yield json.dumps(ret).encode() + b"\0"
188
+
189
+ def generate_gate(self, params):
190
+ for x in self.generate_stream_gate(params):
191
+ pass
192
+ return json.loads(x[:-1].decode())
193
+
194
+ def get_embeddings(self, params):
195
+ raise NotImplementedError()
196
+
197
+
198
+ def release_worker_semaphore(worker):
199
+ worker.semaphore.release()
200
+
201
+
202
+ def acquire_worker_semaphore(worker):
203
+ if worker.semaphore is None:
204
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
205
+ return worker.semaphore.acquire()
206
+
207
+
208
+ def create_background_tasks(worker):
209
+ background_tasks = BackgroundTasks()
210
+ background_tasks.add_task(lambda: release_worker_semaphore(worker))
211
+ return background_tasks
212
+
213
+
214
+ @app.post("/worker_generate_stream")
215
+ async def api_generate_stream(request: Request):
216
+ params = await request.json()
217
+ worker = worker_map[params["model"]]
218
+ await acquire_worker_semaphore(worker)
219
+ generator = worker.generate_stream_gate(params)
220
+ background_tasks = create_background_tasks(worker)
221
+ return StreamingResponse(generator, background=background_tasks)
222
+
223
+
224
+ @app.post("/worker_generate")
225
+ async def api_generate(request: Request):
226
+ params = await request.json()
227
+ worker = worker_map[params["model"]]
228
+ await acquire_worker_semaphore(worker)
229
+ output = worker.generate_gate(params)
230
+ release_worker_semaphore(worker)
231
+ return JSONResponse(output)
232
+
233
+
234
+ @app.post("/worker_get_embeddings")
235
+ async def api_get_embeddings(request: Request):
236
+ params = await request.json()
237
+ worker = worker_map[params["model"]]
238
+ await acquire_worker_semaphore(worker)
239
+ embedding = worker.get_embeddings(params)
240
+ release_worker_semaphore(worker)
241
+ return JSONResponse(content=embedding)
242
+
243
+
244
+ @app.post("/worker_get_status")
245
+ async def api_get_status(request: Request):
246
+ return {
247
+ "model_names": [m for w in workers for m in w.model_names],
248
+ "speed": 1,
249
+ "queue_length": sum([w.get_queue_length() for w in workers]),
250
+ }
251
+
252
+
253
+ @app.post("/count_token")
254
+ async def api_count_token(request: Request):
255
+ params = await request.json()
256
+ worker = worker_map[params["model"]]
257
+ return worker.count_token(params)
258
+
259
+
260
+ @app.post("/worker_get_conv_template")
261
+ async def api_get_conv(request: Request):
262
+ params = await request.json()
263
+ worker = worker_map[params["model"]]
264
+ return worker.get_conv_template()
265
+
266
+
267
+ @app.post("/model_details")
268
+ async def api_model_details(request: Request):
269
+ params = await request.json()
270
+ worker = worker_map[params["model"]]
271
+ return {"context_length": worker.context_len}
272
+
273
+
274
+ def create_huggingface_api_worker():
275
+ parser = argparse.ArgumentParser()
276
+ parser.add_argument("--host", type=str, default="localhost")
277
+ parser.add_argument("--port", type=int, default=21002)
278
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
279
+ parser.add_argument(
280
+ "--controller-address", type=str, default="http://localhost:21001"
281
+ )
282
+ # all model-related parameters are listed in --model-info-file
283
+ parser.add_argument(
284
+ "--model-info-file",
285
+ type=str,
286
+ required=True,
287
+ help="Huggingface API model's info file path",
288
+ )
289
+
290
+ parser.add_argument(
291
+ "--limit-worker-concurrency",
292
+ type=int,
293
+ default=5,
294
+ help="Limit the model concurrency to prevent OOM.",
295
+ )
296
+ parser.add_argument("--no-register", action="store_true")
297
+ parser.add_argument(
298
+ "--seed",
299
+ type=int,
300
+ default=None,
301
+ help="Overwrite the random seed for each generation.",
302
+ )
303
+ args = parser.parse_args()
304
+
305
+ with open(args.model_info_file, "r", encoding="UTF-8") as f:
306
+ model_info = json.load(f)
307
+
308
+ logger.info(f"args: {args}")
309
+
310
+ model_path_list = []
311
+ api_base_list = []
312
+ token_list = []
313
+ context_length_list = []
314
+ model_names_list = []
315
+ conv_template_list = []
316
+
317
+ for m in model_info:
318
+ model_path_list.append(model_info[m]["model_path"])
319
+ api_base_list.append(model_info[m]["api_base"])
320
+ token_list.append(model_info[m]["token"])
321
+
322
+ context_length = model_info[m]["context_length"]
323
+ model_names = model_info[m].get("model_names", [m.split("/")[-1]])
324
+ if isinstance(model_names, str):
325
+ model_names = [model_names]
326
+ conv_template = model_info[m].get("conv_template", None)
327
+
328
+ context_length_list.append(context_length)
329
+ model_names_list.append(model_names)
330
+ conv_template_list.append(conv_template)
331
+
332
+ logger.info(f"Model paths: {model_path_list}")
333
+ logger.info(f"API bases: {api_base_list}")
334
+ logger.info(f"Tokens: {token_list}")
335
+ logger.info(f"Context lengths: {context_length_list}")
336
+ logger.info(f"Model names: {model_names_list}")
337
+ logger.info(f"Conv templates: {conv_template_list}")
338
+
339
+ for (
340
+ model_names,
341
+ conv_template,
342
+ model_path,
343
+ api_base,
344
+ token,
345
+ context_length,
346
+ ) in zip(
347
+ model_names_list,
348
+ conv_template_list,
349
+ model_path_list,
350
+ api_base_list,
351
+ token_list,
352
+ context_length_list,
353
+ ):
354
+ m = HuggingfaceApiWorker(
355
+ args.controller_address,
356
+ args.worker_address,
357
+ worker_id,
358
+ model_path,
359
+ api_base,
360
+ token,
361
+ context_length,
362
+ model_names,
363
+ args.limit_worker_concurrency,
364
+ no_register=args.no_register,
365
+ conv_template=conv_template,
366
+ seed=args.seed,
367
+ )
368
+ workers.append(m)
369
+ for name in model_names:
370
+ worker_map[name] = m
371
+
372
+ # register all the models
373
+ url = args.controller_address + "/register_worker"
374
+ data = {
375
+ "worker_name": workers[0].worker_addr,
376
+ "check_heart_beat": not args.no_register,
377
+ "worker_status": {
378
+ "model_names": [m for w in workers for m in w.model_names],
379
+ "speed": 1,
380
+ "queue_length": sum([w.get_queue_length() for w in workers]),
381
+ },
382
+ }
383
+ r = requests.post(url, json=data)
384
+ assert r.status_code == 200
385
+
386
+ return args, workers
387
+
388
+
389
+ if __name__ == "__main__":
390
+ args, workers = create_huggingface_api_worker()
391
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
inference.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference for FastChat models."""
2
+ import abc
3
+ import gc
4
+ import json
5
+ import math
6
+ import os
7
+ import sys
8
+ import time
9
+ from typing import Iterable, Optional, Dict
10
+ import warnings
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import (
15
+ AutoTokenizer,
16
+ AutoModelForCausalLM,
17
+ LlamaTokenizer,
18
+ LlamaForCausalLM,
19
+ AutoModel,
20
+ AutoModelForSeq2SeqLM,
21
+ T5Tokenizer,
22
+ AutoConfig,
23
+ )
24
+ from transformers.generation.logits_process import (
25
+ LogitsProcessorList,
26
+ RepetitionPenaltyLogitsProcessor,
27
+ TemperatureLogitsWarper,
28
+ TopKLogitsWarper,
29
+ TopPLogitsWarper,
30
+ )
31
+
32
+ from fastchat.conversation import get_conv_template, SeparatorStyle
33
+ from fastchat.model.model_adapter import (
34
+ load_model,
35
+ get_conversation_template,
36
+ get_generate_stream_function,
37
+ )
38
+ from fastchat.modules.awq import AWQConfig
39
+ from fastchat.modules.gptq import GptqConfig
40
+ from fastchat.modules.exllama import ExllamaConfig
41
+ from fastchat.modules.xfastertransformer import XftConfig
42
+ from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
43
+
44
+
45
+ def prepare_logits_processor(
46
+ temperature: float, repetition_penalty: float, top_p: float, top_k: int
47
+ ) -> LogitsProcessorList:
48
+ processor_list = LogitsProcessorList()
49
+ # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
50
+ if temperature >= 1e-5 and temperature != 1.0:
51
+ processor_list.append(TemperatureLogitsWarper(temperature))
52
+ if repetition_penalty > 1.0:
53
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
54
+ if 1e-8 <= top_p < 1.0:
55
+ processor_list.append(TopPLogitsWarper(top_p))
56
+ if top_k > 0:
57
+ processor_list.append(TopKLogitsWarper(top_k))
58
+ return processor_list
59
+
60
+
61
+ @torch.inference_mode()
62
+ def generate_stream(
63
+ model,
64
+ tokenizer,
65
+ params: Dict,
66
+ device: str,
67
+ context_len: int,
68
+ stream_interval: int = 2,
69
+ judge_sent_end: bool = False,
70
+ ):
71
+ if hasattr(model, "device"):
72
+ device = model.device
73
+
74
+ # Read parameters
75
+ prompt = params["prompt"]
76
+ len_prompt = len(prompt)
77
+ temperature = float(params.get("temperature", 1.0))
78
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
79
+ top_p = float(params.get("top_p", 1.0))
80
+ top_k = int(params.get("top_k", -1)) # -1 means disable
81
+ max_new_tokens = int(params.get("max_new_tokens", 256))
82
+ logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
83
+ echo = bool(params.get("echo", True))
84
+ stop_str = params.get("stop", None)
85
+ stop_token_ids = params.get("stop_token_ids", None) or []
86
+ if tokenizer.eos_token_id not in stop_token_ids:
87
+ stop_token_ids.append(tokenizer.eos_token_id)
88
+ if params.get('none_stop'):
89
+ stop_token_ids = []
90
+ skip_special_tokens = params.get('skip_special_tokens')
91
+
92
+ logits_processor = prepare_logits_processor(
93
+ temperature, repetition_penalty, top_p, top_k
94
+ )
95
+ input_ids = tokenizer(prompt).input_ids
96
+
97
+ if model.config.is_encoder_decoder:
98
+ max_src_len = context_len
99
+ else: # truncate
100
+ max_src_len = context_len - max_new_tokens - 1
101
+
102
+ input_ids = input_ids[-max_src_len:]
103
+ output_ids = list(input_ids)
104
+ input_echo_len = len(input_ids)
105
+
106
+ if model.config.is_encoder_decoder:
107
+ if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
108
+ raise NotImplementedError
109
+ encoder_output = model.encoder(
110
+ input_ids=torch.as_tensor([input_ids], device=device)
111
+ )[0]
112
+ start_ids = torch.as_tensor(
113
+ [[model.generation_config.decoder_start_token_id]],
114
+ dtype=torch.int64,
115
+ device=device,
116
+ )
117
+ else:
118
+ start_ids = torch.as_tensor([input_ids], device=device)
119
+
120
+ past_key_values = out = None
121
+ token_logprobs = [None] # The first token has no logprobs.
122
+ sent_interrupt = False
123
+ finish_reason = None
124
+ for i in range(max_new_tokens):
125
+ if i == 0: # prefill
126
+ if model.config.is_encoder_decoder:
127
+ out = model.decoder(
128
+ input_ids=start_ids,
129
+ encoder_hidden_states=encoder_output,
130
+ use_cache=True,
131
+ )
132
+ logits = model.lm_head(out[0])
133
+ else:
134
+ out = model(input_ids=start_ids, use_cache=True)
135
+ logits = out.logits
136
+ past_key_values = out.past_key_values
137
+
138
+ if logprobs is not None:
139
+ # Prefull logprobs for the prompt.
140
+ shift_input_ids = start_ids[..., 1:].contiguous()
141
+ shift_logits = logits[..., :-1, :].contiguous()
142
+ shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
143
+ for label_id, logit in zip(
144
+ shift_input_ids[0].tolist(), shift_logits[0]
145
+ ):
146
+ token_logprobs.append(logit[label_id])
147
+ else: # decoding
148
+ if model.config.is_encoder_decoder:
149
+ out = model.decoder(
150
+ input_ids=torch.as_tensor(
151
+ [[token] if not sent_interrupt else output_ids],
152
+ device=device,
153
+ ),
154
+ encoder_hidden_states=encoder_output,
155
+ use_cache=True,
156
+ past_key_values=past_key_values if not sent_interrupt else None,
157
+ )
158
+ sent_interrupt = False
159
+
160
+ logits = model.lm_head(out[0])
161
+ else:
162
+ out = model(
163
+ input_ids=torch.as_tensor(
164
+ [[token] if not sent_interrupt else output_ids],
165
+ device=device,
166
+ ),
167
+ use_cache=True,
168
+ past_key_values=past_key_values if not sent_interrupt else None,
169
+ )
170
+ sent_interrupt = False
171
+ logits = out.logits
172
+ past_key_values = out.past_key_values
173
+
174
+ if logits_processor:
175
+ if repetition_penalty > 1.0:
176
+ tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
177
+ else:
178
+ tmp_output_ids = None
179
+ last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
180
+ else:
181
+ last_token_logits = logits[0, -1, :]
182
+
183
+ if device == "mps":
184
+ # Switch to CPU by avoiding some bugs in mps backend.
185
+ last_token_logits = last_token_logits.float().to("cpu")
186
+
187
+ if temperature < 1e-5 or top_p < 1e-8: # greedy
188
+ _, indices = torch.topk(last_token_logits, 2)
189
+ tokens = [int(index) for index in indices.tolist()]
190
+ else:
191
+ probs = torch.softmax(last_token_logits, dim=-1)
192
+ indices = torch.multinomial(probs, num_samples=2)
193
+ tokens = [int(token) for token in indices.tolist()]
194
+ token = tokens[0]
195
+ output_ids.append(token)
196
+ if logprobs is not None:
197
+ # Cannot use last_token_logits because logprobs is based on raw logits.
198
+ token_logprobs.append(
199
+ torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
200
+ )
201
+
202
+ if token in stop_token_ids:
203
+ stopped = True
204
+ else:
205
+ stopped = False
206
+
207
+ # Yield the output tokens
208
+ if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
209
+ if echo:
210
+ tmp_output_ids = output_ids
211
+ rfind_start = len_prompt
212
+ else:
213
+ tmp_output_ids = output_ids[input_echo_len:]
214
+ rfind_start = 0
215
+
216
+ output = tokenizer.decode(
217
+ tmp_output_ids,
218
+ skip_special_tokens=skip_special_tokens,
219
+ spaces_between_special_tokens=False,
220
+ clean_up_tokenization_spaces=True,
221
+ )
222
+ ret_logprobs = None
223
+ if logprobs is not None:
224
+ ret_logprobs = {
225
+ "text_offset": [],
226
+ "tokens": [
227
+ tokenizer.decode(token)
228
+ for token in (
229
+ output_ids if echo else output_ids[input_echo_len:]
230
+ )
231
+ ],
232
+ "token_logprobs": token_logprobs
233
+ if echo
234
+ else token_logprobs[input_echo_len:],
235
+ "top_logprobs": [{}]
236
+ * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
237
+ }
238
+ # Compute text_offset
239
+ curr_pos = 0
240
+ for text in ret_logprobs["tokens"]:
241
+ ret_logprobs["text_offset"].append(curr_pos)
242
+ curr_pos += len(text)
243
+
244
+ # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
245
+ if judge_sent_end and stopped and not is_sentence_complete(output):
246
+ if len(tokens) > 1:
247
+ token = tokens[1]
248
+ output_ids[-1] = token
249
+ else:
250
+ output_ids.pop()
251
+ stopped = False
252
+ sent_interrupt = True
253
+
254
+ partially_stopped = False
255
+ if stop_str:
256
+ if isinstance(stop_str, str):
257
+ pos = output.rfind(stop_str, rfind_start)
258
+ if pos != -1:
259
+ output = output[:pos]
260
+ stopped = True
261
+ else:
262
+ partially_stopped = is_partial_stop(output, stop_str)
263
+ elif isinstance(stop_str, Iterable):
264
+ for each_stop in stop_str:
265
+ pos = output.rfind(each_stop, rfind_start)
266
+ if pos != -1:
267
+ output = output[:pos]
268
+ stopped = True
269
+ break
270
+ else:
271
+ partially_stopped = is_partial_stop(output, each_stop)
272
+ if partially_stopped:
273
+ break
274
+ else:
275
+ raise ValueError("Invalid stop field type.")
276
+
277
+ # Prevent yielding partial stop sequence
278
+ if not partially_stopped:
279
+ yield {
280
+ "text": output,
281
+ "logprobs": ret_logprobs,
282
+ "usage": {
283
+ "prompt_tokens": input_echo_len,
284
+ "completion_tokens": i,
285
+ "total_tokens": input_echo_len + i,
286
+ },
287
+ "finish_reason": None,
288
+ }
289
+
290
+ if stopped:
291
+ break
292
+
293
+ # Finish stream event, which contains finish reason
294
+ else:
295
+ finish_reason = "length"
296
+
297
+ if stopped:
298
+ finish_reason = "stop"
299
+
300
+ yield {
301
+ "text": output,
302
+ "logprobs": ret_logprobs,
303
+ "usage": {
304
+ "prompt_tokens": input_echo_len,
305
+ "completion_tokens": i,
306
+ "total_tokens": input_echo_len + i,
307
+ },
308
+ "finish_reason": finish_reason,
309
+ }
310
+
311
+ # Clean
312
+ del past_key_values, out
313
+ gc.collect()
314
+ torch.cuda.empty_cache()
315
+ if device == "xpu":
316
+ torch.xpu.empty_cache()
317
+ if device == "npu":
318
+ torch.npu.empty_cache()
319
+
320
+
321
+ class ChatIO(abc.ABC):
322
+ @abc.abstractmethod
323
+ def prompt_for_input(self, role: str) -> str:
324
+ """Prompt for input from a role."""
325
+
326
+ @abc.abstractmethod
327
+ def prompt_for_output(self, role: str):
328
+ """Prompt for output from a role."""
329
+
330
+ @abc.abstractmethod
331
+ def stream_output(self, output_stream):
332
+ """Stream output."""
333
+
334
+ @abc.abstractmethod
335
+ def print_output(self, text: str):
336
+ """Print output."""
337
+
338
+
339
+ def convert_message_format(message):
340
+ formated_message = []
341
+ for i, turn in enumerate(message):
342
+ role = 'user' if i % 2 == 0 else 'assistant'
343
+ formated_message.append({'role': role, 'content': turn[1]})
344
+
345
+ data = {
346
+ 'conversations': formated_message,
347
+ 'idx': -1,
348
+ 'tinder': 'badcase',
349
+ 'model': '',
350
+ 'tokens_in': 0,
351
+ 'tokens_out': 0,
352
+ }
353
+
354
+ return data
355
+
356
+
357
+ def chat_loop(
358
+ model_path: str,
359
+ device: str,
360
+ num_gpus: int,
361
+ max_gpu_memory: str,
362
+ dtype: Optional[torch.dtype],
363
+ load_8bit: bool,
364
+ cpu_offloading: bool,
365
+ conv_template: Optional[str],
366
+ conv_system_msg: Optional[str],
367
+ temperature: float,
368
+ repetition_penalty: float,
369
+ max_new_tokens: int,
370
+ chatio: ChatIO,
371
+ gptq_config: Optional[GptqConfig] = None,
372
+ awq_config: Optional[AWQConfig] = None,
373
+ exllama_config: Optional[ExllamaConfig] = None,
374
+ xft_config: Optional[XftConfig] = None,
375
+ revision: str = "main",
376
+ judge_sent_end: bool = True,
377
+ debug: bool = True,
378
+ history: bool = True,
379
+ ):
380
+ # Model
381
+ model, tokenizer = load_model(
382
+ model_path,
383
+ device=device,
384
+ num_gpus=num_gpus,
385
+ max_gpu_memory=max_gpu_memory,
386
+ dtype=dtype,
387
+ load_8bit=load_8bit,
388
+ cpu_offloading=cpu_offloading,
389
+ gptq_config=gptq_config,
390
+ awq_config=awq_config,
391
+ exllama_config=exllama_config,
392
+ xft_config=xft_config,
393
+ revision=revision,
394
+ debug=debug,
395
+ )
396
+ generate_stream_func = get_generate_stream_function(model, model_path)
397
+
398
+ model_type = str(type(model)).lower()
399
+ is_t5 = "t5" in model_type
400
+ is_codet5p = "codet5p" in model_type
401
+ is_xft = "xft" in model_type
402
+
403
+ # Hardcode T5's default repetition penalty to be 1.2
404
+ if is_t5 and repetition_penalty == 1.0:
405
+ repetition_penalty = 1.2
406
+
407
+ # Set context length
408
+ context_len = get_context_length(model.config)
409
+
410
+ # Chat
411
+ def new_chat():
412
+ if conv_template:
413
+ conv = get_conv_template(conv_template)
414
+ else:
415
+ conv = get_conversation_template(model_path)
416
+ if conv_system_msg is not None:
417
+ conv.set_system_message(conv_system_msg)
418
+ return conv
419
+
420
+ def reload_conv(conv):
421
+ """
422
+ Reprints the conversation from the start.
423
+ """
424
+ for message in conv.messages[conv.offset :]:
425
+ chatio.prompt_for_output(message[0])
426
+ chatio.print_output(message[1])
427
+
428
+ conv = None
429
+
430
+ while True:
431
+ if not history or not conv:
432
+ conv = new_chat()
433
+
434
+ try:
435
+ inp = chatio.prompt_for_input(conv.roles[0])
436
+ except EOFError:
437
+ inp = ""
438
+
439
+ if inp == "!!exit":# or not inp:
440
+ print("exit...")
441
+ break
442
+ elif inp == "!!reset":
443
+ print("resetting...")
444
+ conv = new_chat()
445
+ continue
446
+ elif inp == "!!remove":
447
+ print("removing last message...")
448
+ if len(conv.messages) > conv.offset:
449
+ # Assistant
450
+ if conv.messages[-1][0] == conv.roles[1]:
451
+ conv.messages.pop()
452
+ # User
453
+ if conv.messages[-1][0] == conv.roles[0]:
454
+ conv.messages.pop()
455
+ reload_conv(conv)
456
+ else:
457
+ print("No messages to remove.")
458
+ continue
459
+ elif inp == "!!regen":
460
+ print("regenerating last message...")
461
+ if len(conv.messages) > conv.offset:
462
+ # Assistant
463
+ if conv.messages[-1][0] == conv.roles[1]:
464
+ conv.messages.pop()
465
+ # User
466
+ if conv.messages[-1][0] == conv.roles[0]:
467
+ reload_conv(conv)
468
+ # Set inp to previous message
469
+ inp = conv.messages.pop()[1]
470
+ else:
471
+ # Shouldn't happen in normal circumstances
472
+ print("No user message to regenerate from.")
473
+ continue
474
+ else:
475
+ print("No messages to regenerate.")
476
+ continue
477
+ elif inp.startswith("!!save"):
478
+ args = inp.split(" ", 1)
479
+
480
+ if len(args) != 2:
481
+ print("usage: !!save <filename>")
482
+ continue
483
+ else:
484
+ filename = args[1]
485
+
486
+ # Add .json if extension not present
487
+ if not "." in filename:
488
+ filename += ".json"
489
+
490
+ print("saving...", filename)
491
+ with open(filename, "w", encoding="utf-8") as outfile:
492
+ json.dump(conv.dict(), outfile, ensure_ascii=False)
493
+ continue
494
+ elif inp.startswith("!!badcase"):
495
+ args = inp.split(" ", 1)
496
+
497
+ if len(args) != 2:
498
+ print("usage: !!save <filename>")
499
+ continue
500
+ else:
501
+ filename = args[1]
502
+
503
+ # Add .json if extension not present
504
+ if not "." in filename:
505
+ filename += ".jsonl"
506
+
507
+ print("saving...", filename)
508
+ with open(filename, "a+", encoding="utf-8") as outfile:
509
+ data = convert_message_format(conv.messages)
510
+ json.dump(data, outfile, ensure_ascii=False)
511
+ outfile.write('\n')
512
+ continue
513
+ elif inp.startswith("!!load"):
514
+ args = inp.split(" ", 1)
515
+
516
+ if len(args) != 2:
517
+ print("usage: !!load <filename>")
518
+ continue
519
+ else:
520
+ filename = args[1]
521
+
522
+ # Check if file exists and add .json if needed
523
+ if not os.path.exists(filename):
524
+ if (not filename.endswith(".json")) and os.path.exists(
525
+ filename + ".json"
526
+ ):
527
+ filename += ".json"
528
+ else:
529
+ print("file not found:", filename)
530
+ continue
531
+
532
+ print("loading...", filename)
533
+ with open(filename, "r") as infile:
534
+ new_conv = json.load(infile)
535
+
536
+ conv = get_conv_template(new_conv["template_name"])
537
+ conv.set_system_message(new_conv["system_message"])
538
+ conv.messages = new_conv["messages"]
539
+ reload_conv(conv)
540
+ continue
541
+
542
+ conv.append_message(conv.roles[0], inp)
543
+ conv.append_message(conv.roles[1], None)
544
+ prompt = conv.get_prompt(tokenizer)
545
+
546
+ if is_codet5p: # codet5p is a code completion model.
547
+ prompt = inp
548
+
549
+ gen_params = {
550
+ "model": model_path,
551
+ "prompt": prompt,
552
+ "temperature": temperature,
553
+ "repetition_penalty": repetition_penalty,
554
+ "max_new_tokens": max_new_tokens,
555
+ "stop": conv.stop_str,
556
+ "stop_token_ids": conv.stop_token_ids,
557
+ "none_stop": conv.none_stop,
558
+ "skip_special_tokens": conv.skip_special_tokens,
559
+ "echo": False,
560
+ }
561
+
562
+ try:
563
+ chatio.prompt_for_output(conv.roles[1])
564
+ output_stream = generate_stream_func(
565
+ model,
566
+ tokenizer,
567
+ gen_params,
568
+ device,
569
+ context_len=context_len,
570
+ judge_sent_end=judge_sent_end,
571
+ )
572
+ t = time.time()
573
+ outputs = chatio.stream_output(output_stream)
574
+ duration = time.time() - t
575
+ conv.update_last_message(outputs.strip())
576
+
577
+ if debug:
578
+ num_tokens = len(tokenizer.encode(outputs))
579
+ msg = {
580
+ "conv_template": conv.name,
581
+ "prompt": prompt,
582
+ "outputs": outputs,
583
+ "speed (token/s)": round(num_tokens / duration, 2),
584
+ }
585
+ print(f"\n{msg}\n")
586
+
587
+ except KeyboardInterrupt:
588
+ print("stopped generation.")
589
+ # If generation didn't finish
590
+ if conv.messages[-1][1] is None:
591
+ conv.messages.pop()
592
+ # Remove last user message, so there isn't a double up
593
+ if conv.messages[-1][0] == conv.roles[0]:
594
+ conv.messages.pop()
595
+
596
+ reload_conv(conv)
launch_all_serve.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
3
+
4
+ Workers are listed in format of `model-path`@`host`@`port`
5
+
6
+ The key mechanism behind this scripts is:
7
+ 1, execute shell cmd to launch the controller/worker/openai-api-server;
8
+ 2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly.
9
+ Note that a few of non-critical `fastchat.serve` cmd options are not supported currently.
10
+ """
11
+ import sys
12
+ import os
13
+
14
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
15
+
16
+ import subprocess
17
+ import re
18
+ import argparse
19
+
20
+ LOGDIR = "./logs/"
21
+
22
+ if not os.path.exists(LOGDIR):
23
+ os.makedirs(LOGDIR)
24
+
25
+ parser = argparse.ArgumentParser()
26
+ # ------multi worker-----------------
27
+ parser.add_argument(
28
+ "--model-path-address",
29
+ default="THUDM/chatglm2-6b@localhost@20002",
30
+ nargs="+",
31
+ type=str,
32
+ help="model path, host, and port, formatted as model-path@host@port",
33
+ )
34
+ # ---------------controller-------------------------
35
+
36
+ parser.add_argument("--controller-host", type=str, default="localhost")
37
+ parser.add_argument("--controller-port", type=int, default=21001)
38
+ parser.add_argument(
39
+ "--dispatch-method",
40
+ type=str,
41
+ choices=["lottery", "shortest_queue"],
42
+ default="shortest_queue",
43
+ )
44
+ controller_args = ["controller-host", "controller-port", "dispatch-method"]
45
+
46
+ # ----------------------worker------------------------------------------
47
+
48
+ parser.add_argument("--worker-host", type=str, default="localhost")
49
+ parser.add_argument("--worker-port", type=int, default=21002)
50
+ # parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
51
+ # parser.add_argument(
52
+ # "--controller-address", type=str, default="http://localhost:21001"
53
+ # )
54
+ parser.add_argument(
55
+ "--model-path",
56
+ type=str,
57
+ default="lmsys/vicuna-7b-v1.5",
58
+ help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
59
+ )
60
+ parser.add_argument(
61
+ "--revision",
62
+ type=str,
63
+ default="main",
64
+ help="Hugging Face Hub model revision identifier",
65
+ )
66
+ parser.add_argument(
67
+ "--device",
68
+ type=str,
69
+ choices=["cpu", "cuda", "mps", "xpu", "npu"],
70
+ default="cuda",
71
+ help="The device type",
72
+ )
73
+ parser.add_argument(
74
+ "--gpus",
75
+ type=str,
76
+ default="0",
77
+ help="A single GPU like 1 or multiple GPUs like 0,2",
78
+ )
79
+ parser.add_argument("--num-gpus", type=int, default=1)
80
+ parser.add_argument(
81
+ "--max-gpu-memory",
82
+ type=str,
83
+ help="The maximum memory per gpu. Use a string like '13Gib'",
84
+ )
85
+ parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")
86
+ parser.add_argument(
87
+ "--cpu-offloading",
88
+ action="store_true",
89
+ help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
90
+ )
91
+ parser.add_argument(
92
+ "--gptq-ckpt",
93
+ type=str,
94
+ default=None,
95
+ help="Load quantized model. The path to the local GPTQ checkpoint.",
96
+ )
97
+ parser.add_argument(
98
+ "--gptq-wbits",
99
+ type=int,
100
+ default=16,
101
+ choices=[2, 3, 4, 8, 16],
102
+ help="#bits to use for quantization",
103
+ )
104
+ parser.add_argument(
105
+ "--gptq-groupsize",
106
+ type=int,
107
+ default=-1,
108
+ help="Groupsize to use for quantization; default uses full row.",
109
+ )
110
+ parser.add_argument(
111
+ "--gptq-act-order",
112
+ action="store_true",
113
+ help="Whether to apply the activation order GPTQ heuristic",
114
+ )
115
+ parser.add_argument(
116
+ "--model-names",
117
+ type=lambda s: s.split(","),
118
+ help="Optional display comma separated names",
119
+ )
120
+ parser.add_argument(
121
+ "--limit-worker-concurrency",
122
+ type=int,
123
+ default=5,
124
+ help="Limit the model concurrency to prevent OOM.",
125
+ )
126
+ parser.add_argument("--stream-interval", type=int, default=2)
127
+ parser.add_argument("--no-register", action="store_true")
128
+
129
+ worker_args = [
130
+ "worker-host",
131
+ "worker-port",
132
+ "model-path",
133
+ "revision",
134
+ "device",
135
+ "gpus",
136
+ "num-gpus",
137
+ "max-gpu-memory",
138
+ "load-8bit",
139
+ "cpu-offloading",
140
+ "gptq-ckpt",
141
+ "gptq-wbits",
142
+ "gptq-groupsize",
143
+ "gptq-act-order",
144
+ "model-names",
145
+ "limit-worker-concurrency",
146
+ "stream-interval",
147
+ "no-register",
148
+ "controller-address",
149
+ ]
150
+ # -----------------openai server---------------------------
151
+
152
+ parser.add_argument("--server-host", type=str, default="localhost", help="host name")
153
+ parser.add_argument("--server-port", type=int, default=8001, help="port number")
154
+ parser.add_argument(
155
+ "--allow-credentials", action="store_true", help="allow credentials"
156
+ )
157
+ # parser.add_argument(
158
+ # "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
159
+ # )
160
+ # parser.add_argument(
161
+ # "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
162
+ # )
163
+ # parser.add_argument(
164
+ # "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
165
+ # )
166
+ parser.add_argument(
167
+ "--api-keys",
168
+ type=lambda s: s.split(","),
169
+ help="Optional list of comma separated API keys",
170
+ )
171
+ server_args = [
172
+ "server-host",
173
+ "server-port",
174
+ "allow-credentials",
175
+ "api-keys",
176
+ "controller-address",
177
+ ]
178
+
179
+ args = parser.parse_args()
180
+
181
+ args = argparse.Namespace(
182
+ **vars(args),
183
+ **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
184
+ )
185
+
186
+ if args.gpus:
187
+ if len(args.gpus.split(",")) < args.num_gpus:
188
+ raise ValueError(
189
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
190
+ )
191
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
192
+
193
+ # 0,controller, model_worker, openai_api_server
194
+ # 1, cmd options
195
+ # 2,LOGDIR
196
+ # 3, log file name
197
+ base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
198
+
199
+ # 0 LOGDIR
200
+ #! 1 log file name
201
+ # 2 controller, worker, openai_api_server
202
+ base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
203
+ sleep 1s;
204
+ echo "wait {2} running"
205
+ done
206
+ echo '{2} running' """
207
+
208
+
209
+ def string_args(args, args_list):
210
+ args_str = ""
211
+ for key, value in args._get_kwargs():
212
+ key = key.replace("_", "-")
213
+ if key not in args_list:
214
+ continue
215
+
216
+ key = key.split("-")[-1] if re.search("port|host", key) else key
217
+ if not value:
218
+ pass
219
+ # 1==True -> True
220
+ elif isinstance(value, bool) and value == True:
221
+ args_str += f" --{key} "
222
+ elif (
223
+ isinstance(value, list)
224
+ or isinstance(value, tuple)
225
+ or isinstance(value, set)
226
+ ):
227
+ value = " ".join(value)
228
+ args_str += f" --{key} {value} "
229
+ else:
230
+ args_str += f" --{key} {value} "
231
+
232
+ return args_str
233
+
234
+
235
+ def launch_worker(item):
236
+ log_name = (
237
+ item.split("/")[-1]
238
+ .split("\\")[-1]
239
+ .replace("-", "_")
240
+ .replace("@", "_")
241
+ .replace(".", "_")
242
+ )
243
+
244
+ args.model_path, args.worker_host, args.worker_port = item.split("@")
245
+ print("*" * 80)
246
+ worker_str_args = string_args(args, worker_args)
247
+ print(worker_str_args)
248
+ worker_sh = base_launch_sh.format(
249
+ "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
250
+ )
251
+ worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
252
+ subprocess.run(worker_sh, shell=True, check=True)
253
+ subprocess.run(worker_check_sh, shell=True, check=True)
254
+
255
+
256
+ def launch_all():
257
+ controller_str_args = string_args(args, controller_args)
258
+ controller_sh = base_launch_sh.format(
259
+ "controller", controller_str_args, LOGDIR, "controller"
260
+ )
261
+ controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
262
+ subprocess.run(controller_sh, shell=True, check=True)
263
+ subprocess.run(controller_check_sh, shell=True, check=True)
264
+
265
+ if isinstance(args.model_path_address, str):
266
+ launch_worker(args.model_path_address)
267
+ else:
268
+ for idx, item in enumerate(args.model_path_address):
269
+ print(f"loading {idx}th model:{item}")
270
+ launch_worker(item)
271
+
272
+ server_str_args = string_args(args, server_args)
273
+ server_sh = base_launch_sh.format(
274
+ "openai_api_server", server_str_args, LOGDIR, "openai_api_server"
275
+ )
276
+ server_check_sh = base_check_sh.format(
277
+ LOGDIR, "openai_api_server", "openai_api_server"
278
+ )
279
+ subprocess.run(server_sh, shell=True, check=True)
280
+ subprocess.run(server_check_sh, shell=True, check=True)
281
+
282
+
283
+ if __name__ == "__main__":
284
+ launch_all()
model_worker.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker that executes the model.
3
+ """
4
+ import argparse
5
+ import base64
6
+ import gc
7
+ import json
8
+ import os
9
+ from typing import List, Optional
10
+ import uuid
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from transformers import set_seed
15
+ import uvicorn
16
+
17
+ from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
18
+ from fastchat.model.model_adapter import (
19
+ load_model,
20
+ add_model_args,
21
+ get_generate_stream_function,
22
+ )
23
+ from fastchat.modules.awq import AWQConfig
24
+ from fastchat.modules.exllama import ExllamaConfig
25
+ from fastchat.modules.xfastertransformer import XftConfig
26
+ from fastchat.modules.gptq import GptqConfig
27
+ from fastchat.serve.base_model_worker import BaseModelWorker, app
28
+ from fastchat.utils import (
29
+ build_logger,
30
+ get_context_length,
31
+ str_to_torch_dtype,
32
+ )
33
+
34
+
35
+ worker_id = str(uuid.uuid4())[:8]
36
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
37
+
38
+
39
+ class ModelWorker(BaseModelWorker):
40
+ def __init__(
41
+ self,
42
+ controller_addr: str,
43
+ worker_addr: str,
44
+ worker_id: str,
45
+ model_path: str,
46
+ model_names: List[str],
47
+ limit_worker_concurrency: int,
48
+ no_register: bool,
49
+ device: str,
50
+ num_gpus: int,
51
+ max_gpu_memory: str,
52
+ dtype: Optional[torch.dtype] = None,
53
+ load_8bit: bool = False,
54
+ cpu_offloading: bool = False,
55
+ gptq_config: Optional[GptqConfig] = None,
56
+ awq_config: Optional[AWQConfig] = None,
57
+ exllama_config: Optional[ExllamaConfig] = None,
58
+ xft_config: Optional[XftConfig] = None,
59
+ stream_interval: int = 2,
60
+ conv_template: Optional[str] = None,
61
+ embed_in_truncate: bool = False,
62
+ seed: Optional[int] = None,
63
+ debug: bool = False,
64
+ **kwargs,
65
+ ):
66
+ super().__init__(
67
+ controller_addr,
68
+ worker_addr,
69
+ worker_id,
70
+ model_path,
71
+ model_names,
72
+ limit_worker_concurrency,
73
+ conv_template=conv_template,
74
+ )
75
+
76
+ logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
77
+ self.model, self.tokenizer = load_model(
78
+ model_path,
79
+ device=device,
80
+ num_gpus=num_gpus,
81
+ max_gpu_memory=max_gpu_memory,
82
+ dtype=dtype,
83
+ load_8bit=load_8bit,
84
+ cpu_offloading=cpu_offloading,
85
+ gptq_config=gptq_config,
86
+ awq_config=awq_config,
87
+ exllama_config=exllama_config,
88
+ xft_config=xft_config,
89
+ debug=debug,
90
+ model_name=model_names[0],
91
+ )
92
+ self.device = device
93
+ if self.tokenizer.pad_token == None:
94
+ self.tokenizer.pad_token = self.tokenizer.eos_token
95
+ self.context_len = get_context_length(self.model.config)
96
+ self.generate_stream_func = get_generate_stream_function(self.model, model_path)
97
+ self.stream_interval = stream_interval
98
+ self.embed_in_truncate = embed_in_truncate
99
+ self.seed = seed
100
+
101
+ if not no_register:
102
+ self.init_heart_beat()
103
+
104
+ def generate_stream_gate(self, params):
105
+ self.call_ct += 1
106
+
107
+ try:
108
+ if self.seed is not None:
109
+ set_seed(self.seed)
110
+ for output in self.generate_stream_func(
111
+ self.model,
112
+ self.tokenizer,
113
+ params,
114
+ self.device,
115
+ self.context_len,
116
+ self.stream_interval,
117
+ ):
118
+ ret = {
119
+ "text": output["text"],
120
+ "error_code": 0,
121
+ }
122
+ if "usage" in output:
123
+ ret["usage"] = output["usage"]
124
+ if "finish_reason" in output:
125
+ ret["finish_reason"] = output["finish_reason"]
126
+ if "logprobs" in output:
127
+ ret["logprobs"] = output["logprobs"]
128
+ yield json.dumps(ret).encode() + b"\0"
129
+ except torch.cuda.OutOfMemoryError as e:
130
+ ret = {
131
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
132
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
133
+ }
134
+ yield json.dumps(ret).encode() + b"\0"
135
+ except (ValueError, RuntimeError) as e:
136
+ ret = {
137
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
138
+ "error_code": ErrorCode.INTERNAL_ERROR,
139
+ }
140
+ yield json.dumps(ret).encode() + b"\0"
141
+
142
+ def generate_gate(self, params):
143
+ for x in self.generate_stream_gate(params):
144
+ pass
145
+ return json.loads(x[:-1].decode())
146
+
147
+ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
148
+ if model_type_dict.get("is_bert"):
149
+ model_output = self.model(input_ids)
150
+ if model_type_dict.get("is_robert"):
151
+ data = model_output.last_hidden_state
152
+ else:
153
+ data = model_output[0]
154
+ elif model_type_dict.get("is_t5"):
155
+ model_output = self.model(input_ids, decoder_input_ids=input_ids)
156
+ data = model_output.encoder_last_hidden_state
157
+ else:
158
+ model_output = self.model(input_ids, output_hidden_states=True)
159
+ if model_type_dict.get("is_chatglm"):
160
+ data = model_output.hidden_states[-1].transpose(0, 1)
161
+ else:
162
+ data = model_output.hidden_states[-1]
163
+ mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
164
+ masked_embeddings = data * mask
165
+ sum_embeddings = torch.sum(masked_embeddings, dim=1)
166
+ token_num = torch.sum(attention_mask).item()
167
+
168
+ return sum_embeddings, token_num
169
+
170
+ def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
171
+ embeddings = embeddings.cpu()
172
+ return [
173
+ base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
174
+ ]
175
+
176
+ @torch.inference_mode()
177
+ def get_embeddings(self, params):
178
+ self.call_ct += 1
179
+
180
+ try:
181
+ tokenizer = self.tokenizer
182
+ ret = {"embedding": [], "token_num": 0}
183
+
184
+ model_type_dict = {
185
+ "is_llama": "llama" in str(type(self.model)),
186
+ "is_t5": "t5" in str(type(self.model)),
187
+ "is_chatglm": "chatglm" in str(type(self.model)),
188
+ "is_bert": "bert" in str(type(self.model)),
189
+ "is_robert": "robert" in str(type(self.model)),
190
+ }
191
+
192
+ if self.embed_in_truncate:
193
+ encoding = tokenizer.batch_encode_plus(
194
+ params["input"],
195
+ padding=True,
196
+ truncation="longest_first",
197
+ return_tensors="pt",
198
+ max_length=self.context_len,
199
+ )
200
+ else:
201
+ encoding = tokenizer.batch_encode_plus(
202
+ params["input"], padding=True, return_tensors="pt"
203
+ )
204
+ input_ids = encoding["input_ids"].to(self.device)
205
+ attention_mask = input_ids != tokenizer.pad_token_id
206
+
207
+ base64_encode = params.get("encoding_format", None)
208
+
209
+ if self.embed_in_truncate:
210
+ chunk_embeddings, token_num = self.__process_embed_chunk(
211
+ input_ids, attention_mask, **model_type_dict
212
+ )
213
+ embedding = chunk_embeddings / token_num
214
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
215
+ ret["token_num"] = token_num
216
+ else:
217
+ all_embeddings = []
218
+ all_token_num = 0
219
+ for i in range(0, input_ids.size(1), self.context_len):
220
+ chunk_input_ids = input_ids[:, i : i + self.context_len]
221
+ chunk_attention_mask = attention_mask[:, i : i + self.context_len]
222
+
223
+ chunk_embeddings, token_num = self.__process_embed_chunk(
224
+ chunk_input_ids, chunk_attention_mask, **model_type_dict
225
+ )
226
+ all_embeddings.append(chunk_embeddings)
227
+ all_token_num += token_num
228
+
229
+ all_embeddings_tensor = torch.stack(all_embeddings)
230
+ embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
231
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
232
+
233
+ ret["token_num"] = all_token_num
234
+
235
+ if base64_encode == "base64":
236
+ out_embeddings = self.__encode_base64(normalized_embeddings)
237
+ else:
238
+ out_embeddings = normalized_embeddings.tolist()
239
+ ret["embedding"] = out_embeddings
240
+
241
+ gc.collect()
242
+ torch.cuda.empty_cache()
243
+ if self.device == "xpu":
244
+ torch.xpu.empty_cache()
245
+ if self.device == "npu":
246
+ torch.npu.empty_cache()
247
+ except torch.cuda.OutOfMemoryError as e:
248
+ ret = {
249
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
250
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
251
+ }
252
+ except (ValueError, RuntimeError) as e:
253
+ ret = {
254
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
255
+ "error_code": ErrorCode.INTERNAL_ERROR,
256
+ }
257
+ return ret
258
+
259
+
260
+ def create_model_worker():
261
+ parser = argparse.ArgumentParser()
262
+ parser.add_argument("--host", type=str, default="localhost")
263
+ parser.add_argument("--port", type=int, default=21002)
264
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
265
+ parser.add_argument(
266
+ "--controller-address", type=str, default="http://localhost:21001"
267
+ )
268
+ add_model_args(parser)
269
+ parser.add_argument(
270
+ "--model-names",
271
+ type=lambda s: s.split(","),
272
+ help="Optional display comma separated names",
273
+ )
274
+ parser.add_argument(
275
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
276
+ )
277
+ parser.add_argument("--embed-in-truncate", action="store_true")
278
+ parser.add_argument(
279
+ "--limit-worker-concurrency",
280
+ type=int,
281
+ default=5,
282
+ help="Limit the model concurrency to prevent OOM.",
283
+ )
284
+ parser.add_argument("--stream-interval", type=int, default=2)
285
+ parser.add_argument("--no-register", action="store_true")
286
+ parser.add_argument(
287
+ "--seed",
288
+ type=int,
289
+ default=None,
290
+ help="Overwrite the random seed for each generation.",
291
+ )
292
+ parser.add_argument(
293
+ "--debug", type=bool, default=False, help="Print debugging messages"
294
+ )
295
+ args = parser.parse_args()
296
+ logger.info(f"args: {args}")
297
+
298
+ if args.gpus:
299
+ if len(args.gpus.split(",")) < args.num_gpus:
300
+ raise ValueError(
301
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
302
+ )
303
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
304
+
305
+ gptq_config = GptqConfig(
306
+ ckpt=args.gptq_ckpt or args.model_path,
307
+ wbits=args.gptq_wbits,
308
+ groupsize=args.gptq_groupsize,
309
+ act_order=args.gptq_act_order,
310
+ )
311
+ awq_config = AWQConfig(
312
+ ckpt=args.awq_ckpt or args.model_path,
313
+ wbits=args.awq_wbits,
314
+ groupsize=args.awq_groupsize,
315
+ )
316
+ if args.enable_exllama:
317
+ exllama_config = ExllamaConfig(
318
+ max_seq_len=args.exllama_max_seq_len,
319
+ gpu_split=args.exllama_gpu_split,
320
+ )
321
+ else:
322
+ exllama_config = None
323
+ if args.enable_xft:
324
+ xft_config = XftConfig(
325
+ max_seq_len=args.xft_max_seq_len,
326
+ data_type=args.xft_dtype,
327
+ )
328
+ if args.device != "cpu":
329
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
330
+ args.device = "cpu"
331
+ else:
332
+ xft_config = None
333
+
334
+ worker = ModelWorker(
335
+ args.controller_address,
336
+ args.worker_address,
337
+ worker_id,
338
+ args.model_path,
339
+ args.model_names,
340
+ args.limit_worker_concurrency,
341
+ no_register=args.no_register,
342
+ device=args.device,
343
+ num_gpus=args.num_gpus,
344
+ max_gpu_memory=args.max_gpu_memory,
345
+ dtype=str_to_torch_dtype(args.dtype),
346
+ load_8bit=args.load_8bit,
347
+ cpu_offloading=args.cpu_offloading,
348
+ gptq_config=gptq_config,
349
+ awq_config=awq_config,
350
+ exllama_config=exllama_config,
351
+ xft_config=xft_config,
352
+ stream_interval=args.stream_interval,
353
+ conv_template=args.conv_template,
354
+ embed_in_truncate=args.embed_in_truncate,
355
+ seed=args.seed,
356
+ debug=args.debug,
357
+ )
358
+ return args, worker
359
+
360
+
361
+ if __name__ == "__main__":
362
+ args, worker = create_model_worker()
363
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
monitor/basic_stats.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import code
3
+ import datetime
4
+ import json
5
+ import os
6
+ from pytz import timezone
7
+ import time
8
+
9
+ import pandas as pd # pandas>=2.0.3
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from tqdm import tqdm
13
+
14
+
15
+ NUM_SERVERS = 14
16
+
17
+
18
+ def get_log_files(max_num_files=None):
19
+ dates = []
20
+ for month in range(4, 12):
21
+ for day in range(1, 33):
22
+ dates.append(f"2023-{month:02d}-{day:02d}")
23
+
24
+ filenames = []
25
+ for d in dates:
26
+ for i in range(NUM_SERVERS):
27
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
28
+ if os.path.exists(name):
29
+ filenames.append(name)
30
+ max_num_files = max_num_files or len(filenames)
31
+ filenames = filenames[-max_num_files:]
32
+ return filenames
33
+
34
+
35
+ def load_log_files(log_files):
36
+ data = []
37
+ for filename in tqdm(log_files, desc="read files"):
38
+ for retry in range(5):
39
+ try:
40
+ lines = open(filename).readlines()
41
+ break
42
+ except FileNotFoundError:
43
+ time.sleep(2)
44
+
45
+ for l in lines:
46
+ row = json.loads(l)
47
+
48
+ data.append(
49
+ dict(
50
+ type=row["type"],
51
+ tstamp=row["tstamp"],
52
+ model=row.get("model", ""),
53
+ models=row.get("models", ["", ""]),
54
+ )
55
+ )
56
+
57
+ return data
58
+
59
+
60
+ def get_anony_vote_df(df):
61
+ anony_vote_df = df[
62
+ df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
63
+ ]
64
+ anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")]
65
+ return anony_vote_df
66
+
67
+
68
+ def merge_counts(series, on, names):
69
+ ret = pd.merge(series[0], series[1], on=on)
70
+ for i in range(2, len(series)):
71
+ ret = pd.merge(ret, series[i], on=on)
72
+ ret = ret.reset_index()
73
+ old_names = list(ret.columns)[-len(series) :]
74
+ rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
75
+ ret = ret.rename(columns=rename)
76
+ return ret
77
+
78
+
79
+ def report_basic_stats(log_files):
80
+ df_all = load_log_files(log_files)
81
+ df_all = pd.DataFrame(df_all)
82
+ now_t = df_all["tstamp"].max()
83
+ df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
84
+ df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
85
+ anony_vote_df_all = get_anony_vote_df(df_all)
86
+
87
+ # Chat trends
88
+ chat_dates = [
89
+ datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
90
+ "%Y-%m-%d"
91
+ )
92
+ for x in df_all[df_all["type"] == "chat"]["tstamp"]
93
+ ]
94
+ chat_dates_counts = pd.value_counts(chat_dates)
95
+ vote_dates = [
96
+ datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
97
+ "%Y-%m-%d"
98
+ )
99
+ for x in anony_vote_df_all["tstamp"]
100
+ ]
101
+ vote_dates_counts = pd.value_counts(vote_dates)
102
+ chat_dates_bar = go.Figure(
103
+ data=[
104
+ go.Bar(
105
+ name="Anony. Vote",
106
+ x=vote_dates_counts.index,
107
+ y=vote_dates_counts,
108
+ text=[f"{val:.0f}" for val in vote_dates_counts],
109
+ textposition="auto",
110
+ ),
111
+ go.Bar(
112
+ name="Chat",
113
+ x=chat_dates_counts.index,
114
+ y=chat_dates_counts,
115
+ text=[f"{val:.0f}" for val in chat_dates_counts],
116
+ textposition="auto",
117
+ ),
118
+ ]
119
+ )
120
+ chat_dates_bar.update_layout(
121
+ barmode="stack",
122
+ xaxis_title="Dates",
123
+ yaxis_title="Count",
124
+ height=300,
125
+ width=1200,
126
+ )
127
+
128
+ # Model call counts
129
+ model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
130
+ model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
131
+ model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
132
+ model_hist = merge_counts(
133
+ [model_hist_all, model_hist_1_day, model_hist_1_hour],
134
+ on="model",
135
+ names=["All", "Last Day", "Last Hour"],
136
+ )
137
+ model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")
138
+
139
+ # Action counts
140
+ action_hist_all = df_all["type"].value_counts()
141
+ action_hist_1_day = df_1_day["type"].value_counts()
142
+ action_hist_1_hour = df_1_hour["type"].value_counts()
143
+ action_hist = merge_counts(
144
+ [action_hist_all, action_hist_1_day, action_hist_1_hour],
145
+ on="type",
146
+ names=["All", "Last Day", "Last Hour"],
147
+ )
148
+ action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")
149
+
150
+ # Anony vote counts
151
+ anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
152
+ anony_vote_df_1_day = get_anony_vote_df(df_1_day)
153
+ anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
154
+ # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
155
+ # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
156
+ anony_vote_hist = merge_counts(
157
+ [anony_vote_hist_all, anony_vote_hist_1_day],
158
+ on="type",
159
+ names=["All", "Last Day"],
160
+ )
161
+ anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")
162
+
163
+ # Last 24 hours
164
+ chat_1_day = df_1_day[df_1_day["type"] == "chat"]
165
+ num_chats_last_24_hours = []
166
+ base = df_1_day["tstamp"].min()
167
+ for i in range(24, 0, -1):
168
+ left = base + (i - 1) * 3600
169
+ right = base + i * 3600
170
+ num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
171
+ num_chats_last_24_hours.append(num)
172
+ times = [
173
+ datetime.datetime.fromtimestamp(
174
+ base + i * 3600, tz=timezone("US/Pacific")
175
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
176
+ for i in range(24, 0, -1)
177
+ ]
178
+ last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
179
+ last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")
180
+
181
+ # Last update datetime
182
+ last_updated_tstamp = now_t
183
+ last_updated_datetime = datetime.datetime.fromtimestamp(
184
+ last_updated_tstamp, tz=timezone("US/Pacific")
185
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
186
+
187
+ # code.interact(local=locals())
188
+
189
+ return {
190
+ "chat_dates_bar": chat_dates_bar,
191
+ "model_hist_md": model_hist_md,
192
+ "action_hist_md": action_hist_md,
193
+ "anony_vote_hist_md": anony_vote_hist_md,
194
+ "num_chats_last_24_hours": last_24_hours_md,
195
+ "last_updated_datetime": last_updated_datetime,
196
+ }
197
+
198
+
199
+ if __name__ == "__main__":
200
+ parser = argparse.ArgumentParser()
201
+ parser.add_argument("--max-num-files", type=int)
202
+ args = parser.parse_args()
203
+
204
+ log_files = get_log_files(args.max_num_files)
205
+ basic_stats = report_basic_stats(log_files)
206
+
207
+ print(basic_stats["action_hist_md"] + "\n")
208
+ print(basic_stats["model_hist_md"] + "\n")
209
+ print(basic_stats["anony_vote_hist_md"] + "\n")
210
+ print(basic_stats["num_chats_last_24_hours"] + "\n")
monitor/clean_battle_data.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Clean chatbot arena battle log.
3
+
4
+ Usage:
5
+ python3 clean_battle_data.py --mode conv_release
6
+ """
7
+ import argparse
8
+ import datetime
9
+ import json
10
+ import os
11
+ from pytz import timezone
12
+ import time
13
+
14
+ from tqdm import tqdm
15
+
16
+ from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS
17
+ from fastchat.utils import detect_language
18
+
19
+
20
+ VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
21
+ IDENTITY_WORDS = [
22
+ "vicuna",
23
+ "lmsys",
24
+ "koala",
25
+ "uc berkeley",
26
+ "open assistant",
27
+ "laion",
28
+ "chatglm",
29
+ "chatgpt",
30
+ "openai",
31
+ "anthropic",
32
+ "claude",
33
+ "bard",
34
+ "palm",
35
+ "lamda",
36
+ "google",
37
+ "llama",
38
+ "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
39
+ "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
40
+ ]
41
+
42
+ for i in range(len(IDENTITY_WORDS)):
43
+ IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
44
+
45
+
46
+ def get_log_files(max_num_files=None):
47
+ dates = []
48
+ for month in range(4, 12):
49
+ for day in range(1, 33):
50
+ dates.append(f"2023-{month:02d}-{day:02d}")
51
+
52
+ filenames = []
53
+ for d in dates:
54
+ for i in range(NUM_SERVERS):
55
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
56
+ if os.path.exists(name):
57
+ filenames.append(name)
58
+ max_num_files = max_num_files or len(filenames)
59
+ filenames = filenames[-max_num_files:]
60
+ return filenames
61
+
62
+
63
+ def remove_html(raw):
64
+ if raw.startswith("<h3>"):
65
+ return raw[raw.find(": ") + 2 : -len("</h3>\n")]
66
+ return raw
67
+
68
+
69
+ def to_openai_format(messages):
70
+ roles = ["user", "assistant"]
71
+ ret = []
72
+ for i, x in enumerate(messages):
73
+ ret.append({"role": roles[i % 2], "content": x[1]})
74
+ return ret
75
+
76
+
77
+ def replace_model_name(old_name):
78
+ return (
79
+ old_name.replace("bard", "palm-2")
80
+ .replace("claude-v1", "claude-1")
81
+ .replace("claude-instant-v1", "claude-instant-1")
82
+ .replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b")
83
+ )
84
+
85
+
86
+ def clean_battle_data(log_files, exclude_model_names):
87
+ data = []
88
+ for filename in tqdm(log_files, desc="read files"):
89
+ for retry in range(5):
90
+ try:
91
+ lines = open(filename).readlines()
92
+ break
93
+ except FileNotFoundError:
94
+ time.sleep(2)
95
+
96
+ for l in lines:
97
+ row = json.loads(l)
98
+ if row["type"] in VOTES:
99
+ data.append(row)
100
+
101
+ convert_type = {
102
+ "leftvote": "model_a",
103
+ "rightvote": "model_b",
104
+ "tievote": "tie",
105
+ "bothbad_vote": "tie (bothbad)",
106
+ }
107
+
108
+ all_models = set()
109
+ all_ips = dict()
110
+ ct_anony = 0
111
+ ct_invalid = 0
112
+ ct_leaked_identity = 0
113
+ battles = []
114
+ for row in data:
115
+ if row["models"][0] is None or row["models"][1] is None:
116
+ continue
117
+
118
+ # Resolve model names
119
+ models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
120
+ if "model_name" in row["states"][0]:
121
+ models_hidden = [
122
+ row["states"][0]["model_name"],
123
+ row["states"][1]["model_name"],
124
+ ]
125
+ if models_hidden[0] is None:
126
+ models_hidden = models_public
127
+ else:
128
+ models_hidden = models_public
129
+
130
+ if (models_public[0] == "" and models_public[1] != "") or (
131
+ models_public[1] == "" and models_public[0] != ""
132
+ ):
133
+ ct_invalid += 1
134
+ continue
135
+
136
+ if models_public[0] == "" or models_public[0] == "Model A":
137
+ anony = True
138
+ models = models_hidden
139
+ ct_anony += 1
140
+ else:
141
+ anony = False
142
+ models = models_public
143
+ if not models_public == models_hidden:
144
+ ct_invalid += 1
145
+ continue
146
+
147
+ # Detect langauge
148
+ state = row["states"][0]
149
+ if state["offset"] >= len(state["messages"]):
150
+ ct_invalid += 1
151
+ continue
152
+ lang_code = detect_language(state["messages"][state["offset"]][1])
153
+
154
+ # Drop conversations if the model names are leaked
155
+ leaked_identity = False
156
+ messages = ""
157
+ for i in range(2):
158
+ state = row["states"][i]
159
+ for role, msg in state["messages"][state["offset"] :]:
160
+ if msg:
161
+ messages += msg.lower()
162
+ for word in IDENTITY_WORDS:
163
+ if word in messages:
164
+ leaked_identity = True
165
+ break
166
+
167
+ if leaked_identity:
168
+ ct_leaked_identity += 1
169
+ continue
170
+
171
+ # Replace bard with palm
172
+ models = [replace_model_name(m) for m in models]
173
+
174
+ # Exclude certain models
175
+ if any(x in exclude_model_names for x in models):
176
+ ct_invalid += 1
177
+ continue
178
+
179
+ question_id = row["states"][0]["conv_id"]
180
+ conversation_a = to_openai_format(
181
+ row["states"][0]["messages"][row["states"][0]["offset"] :]
182
+ )
183
+ conversation_b = to_openai_format(
184
+ row["states"][1]["messages"][row["states"][1]["offset"] :]
185
+ )
186
+
187
+ ip = row["ip"]
188
+ if ip not in all_ips:
189
+ all_ips[ip] = len(all_ips)
190
+ user_id = all_ips[ip]
191
+
192
+ # Save the results
193
+ battles.append(
194
+ dict(
195
+ question_id=question_id,
196
+ model_a=models[0],
197
+ model_b=models[1],
198
+ winner=convert_type[row["type"]],
199
+ judge=f"arena_user_{user_id}",
200
+ conversation_a=conversation_a,
201
+ conversation_b=conversation_b,
202
+ turn=len(conversation_a) // 2,
203
+ anony=anony,
204
+ language=lang_code,
205
+ tstamp=row["tstamp"],
206
+ )
207
+ )
208
+
209
+ all_models.update(models_hidden)
210
+ battles.sort(key=lambda x: x["tstamp"])
211
+ last_updated_tstamp = battles[-1]["tstamp"]
212
+
213
+ last_updated_datetime = datetime.datetime.fromtimestamp(
214
+ last_updated_tstamp, tz=timezone("US/Pacific")
215
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
216
+
217
+ print(
218
+ f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
219
+ f"#leaked_identity: {ct_leaked_identity}"
220
+ )
221
+ print(f"#battles: {len(battles)}, #anony: {ct_anony}")
222
+ print(f"#models: {len(all_models)}, {all_models}")
223
+ print(f"last-updated: {last_updated_datetime}")
224
+
225
+ return battles
226
+
227
+
228
+ if __name__ == "__main__":
229
+ parser = argparse.ArgumentParser()
230
+ parser.add_argument("--max-num-files", type=int)
231
+ parser.add_argument(
232
+ "--mode", type=str, choices=["simple", "conv_release"], default="simple"
233
+ )
234
+ parser.add_argument("--exclude-model-names", type=str, nargs="+")
235
+ args = parser.parse_args()
236
+
237
+ log_files = get_log_files(args.max_num_files)
238
+ battles = clean_battle_data(log_files, args.exclude_model_names or [])
239
+ last_updated_tstamp = battles[-1]["tstamp"]
240
+ cutoff_date = datetime.datetime.fromtimestamp(
241
+ last_updated_tstamp, tz=timezone("US/Pacific")
242
+ ).strftime("%Y%m%d")
243
+
244
+ if args.mode == "simple":
245
+ for x in battles:
246
+ for key in [
247
+ "conversation_a",
248
+ "conversation_b",
249
+ "question_id",
250
+ ]:
251
+ del x[key]
252
+ print("Samples:")
253
+ for i in range(4):
254
+ print(battles[i])
255
+ output = f"clean_battle_{cutoff_date}.json"
256
+ elif args.mode == "conv_release":
257
+ new_battles = []
258
+ for x in battles:
259
+ if not x["anony"]:
260
+ continue
261
+ for key in []:
262
+ del x[key]
263
+ new_battles.append(x)
264
+ battles = new_battles
265
+ output = f"clean_battle_conv_{cutoff_date}.json"
266
+
267
+ with open(output, "w") as fout:
268
+ json.dump(battles, fout, indent=2, ensure_ascii=False)
269
+ print(f"Write cleaned data to {output}")
monitor/clean_chat_data.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Clean chatbot arena chat log.
3
+
4
+ Usage:
5
+ python3 clean_chat_data.py --mode conv_release
6
+ """
7
+ import argparse
8
+ import datetime
9
+ import json
10
+ import os
11
+ from pytz import timezone
12
+ import time
13
+
14
+ from tqdm import tqdm
15
+
16
+ from fastchat.serve.monitor.basic_stats import NUM_SERVERS
17
+ from fastchat.serve.monitor.clean_battle_data import (
18
+ to_openai_format,
19
+ replace_model_name,
20
+ )
21
+ from fastchat.utils import detect_language
22
+
23
+
24
+ NETWORK_ERROR_MSG = (
25
+ "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
26
+ )
27
+
28
+
29
+ def get_log_files(max_num_files=None):
30
+ dates = []
31
+ for month in range(4, 12):
32
+ for day in range(1, 33):
33
+ dates.append(f"2023-{month:02d}-{day:02d}")
34
+
35
+ filenames = []
36
+ for d in dates:
37
+ for i in range(NUM_SERVERS):
38
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
39
+ if os.path.exists(name):
40
+ filenames.append(name)
41
+ max_num_files = max_num_files or len(filenames)
42
+ # filenames = list(reversed(filenames))
43
+ filenames = filenames[-max_num_files:]
44
+ return filenames
45
+
46
+
47
+ def clean_chat_data(log_files, action_type):
48
+ raw_data = []
49
+ for filename in tqdm(log_files, desc="read files"):
50
+ for retry in range(5):
51
+ try:
52
+ lines = open(filename).readlines()
53
+ break
54
+ except FileNotFoundError:
55
+ time.sleep(2)
56
+
57
+ for l in lines:
58
+ row = json.loads(l)
59
+ if row["type"] == action_type:
60
+ raw_data.append(row)
61
+
62
+ all_models = set()
63
+ all_ips = dict()
64
+ chats = []
65
+ ct_invalid_conv_id = 0
66
+ ct_invalid = 0
67
+ ct_network_error = 0
68
+ for row in raw_data:
69
+ try:
70
+ if action_type in ["chat", "upvote", "downvote"]:
71
+ state = row["state"]
72
+ model = row["model"]
73
+ elif action_type == "leftvote":
74
+ state = row["states"][0]
75
+ model = row["states"][0]["model_name"]
76
+ elif action_type == "rightvote":
77
+ state = row["states"][1]
78
+ model = row["states"][1]["model_name"]
79
+ conversation_id = state["conv_id"]
80
+ except KeyError:
81
+ ct_invalid_conv_id += 1
82
+ continue
83
+
84
+ if conversation_id is None:
85
+ ct_invalid_conv_id += 1
86
+ continue
87
+
88
+ conversation = to_openai_format(state["messages"][state["offset"] :])
89
+ if not isinstance(model, str):
90
+ ct_invalid += 1
91
+ continue
92
+ model = replace_model_name(model)
93
+
94
+ try:
95
+ lang_code = detect_language(state["messages"][state["offset"]][1])
96
+ except IndexError:
97
+ ct_invalid += 1
98
+ continue
99
+
100
+ if not all(isinstance(x["content"], str) for x in conversation):
101
+ ct_invalid += 1
102
+ continue
103
+
104
+ messages = "".join([x["content"] for x in conversation]).lower()
105
+ if NETWORK_ERROR_MSG in messages:
106
+ ct_network_error += 1
107
+ continue
108
+
109
+ ip = row["ip"]
110
+ if ip not in all_ips:
111
+ all_ips[ip] = len(all_ips)
112
+ user_id = all_ips[ip]
113
+
114
+ chats.append(
115
+ dict(
116
+ conversation_id=conversation_id,
117
+ model=model,
118
+ conversation=conversation,
119
+ turn=len(conversation) // 2,
120
+ language=lang_code,
121
+ user_id=user_id,
122
+ tstamp=row["tstamp"],
123
+ )
124
+ )
125
+
126
+ all_models.update([model])
127
+
128
+ chats.sort(key=lambda x: x["tstamp"])
129
+ last_updated_tstamp = chats[-1]["tstamp"]
130
+ last_updated_datetime = datetime.datetime.fromtimestamp(
131
+ last_updated_tstamp, tz=timezone("US/Pacific")
132
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
133
+
134
+ # Deduplication
135
+ dedup_chats = []
136
+ visited_conv_ids = set()
137
+ for i in reversed(range(len(chats))):
138
+ if chats[i]["conversation_id"] in visited_conv_ids:
139
+ continue
140
+ visited_conv_ids.add(chats[i]["conversation_id"])
141
+ dedup_chats.append(chats[i])
142
+
143
+ print(
144
+ f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}"
145
+ )
146
+ print(
147
+ f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}"
148
+ )
149
+ print(f"#models: {len(all_models)}, {all_models}")
150
+ print(f"last-updated: {last_updated_datetime}")
151
+
152
+ return list(reversed(dedup_chats))
153
+
154
+
155
+ if __name__ == "__main__":
156
+ parser = argparse.ArgumentParser()
157
+ parser.add_argument("--action-type", type=str, default="chat")
158
+ parser.add_argument("--max-num-files", type=int)
159
+ args = parser.parse_args()
160
+
161
+ log_files = get_log_files(args.max_num_files)
162
+ chats = clean_chat_data(log_files, args.action_type)
163
+ last_updated_tstamp = chats[-1]["tstamp"]
164
+ cutoff_date = datetime.datetime.fromtimestamp(
165
+ last_updated_tstamp, tz=timezone("US/Pacific")
166
+ ).strftime("%Y%m%d")
167
+
168
+ output = f"clean_{args.action_type}_conv_{cutoff_date}.json"
169
+ with open(output, "w") as fout:
170
+ json.dump(chats, fout, indent=2, ensure_ascii=False)
171
+ print(f"Write cleaned data to {output}")
monitor/dataset_release_scripts/arena_33k/count_unique_users.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Count the unique users in a battle log file."""
2
+
3
+ import argparse
4
+ import json
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("--input", type=str)
10
+ args = parser.parse_args()
11
+
12
+ lines = json.load(open(args.input))
13
+ ct_anony_votes = 0
14
+ all_users = set()
15
+ all_models = set()
16
+ for l in lines:
17
+ if not l["anony"]:
18
+ continue
19
+ all_users.add(l["judge"])
20
+ all_models.add(l["model_a"])
21
+ all_models.add(l["model_b"])
22
+ ct_anony_votes += 1
23
+
24
+ print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}")
25
+ print(f"#model: {len(all_models)}")
monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Filter conversations for release.
3
+
4
+ Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json
5
+ """
6
+ import argparse
7
+ from collections import defaultdict
8
+ from enum import Enum, auto
9
+ import json
10
+ import os
11
+ import random
12
+
13
+ from tqdm import tqdm
14
+
15
+ BLOCKED_WORDS_FILENAME = "blocked_words.json"
16
+ blocked_words = []
17
+ frequency = defaultdict(lambda: 0)
18
+
19
+
20
+ class TypeCode(Enum):
21
+ CORRECT = auto()
22
+ ANONYMIZED = auto()
23
+ REDACTED = auto()
24
+ BAD_FORMAT = auto()
25
+ BLOCKED_WORD = auto()
26
+ BLOCKED_MODEL = auto()
27
+ TOO_SHORT = auto()
28
+ TOO_FREQUENT = auto()
29
+
30
+
31
+ def detect_type(conv):
32
+ for key in ["conversation_a", "conversation_b"]:
33
+ messages = [row["content"] for row in conv[key]]
34
+ for msg in messages:
35
+ if not isinstance(msg, str):
36
+ return TypeCode.BAD_FORMAT
37
+
38
+ user_prompts = [
39
+ row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
40
+ ]
41
+ if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts):
42
+ return TypeCode.TOO_SHORT
43
+
44
+ if all(x in frequent_prompts for x in user_prompts):
45
+ return TypeCode.TOO_FREQUENT
46
+
47
+ for msg in messages:
48
+ msg = msg.lower()
49
+ if "<anonymized>" in msg:
50
+ return TypeCode.ANONYMIZED
51
+ if "<redacted>" in msg:
52
+ return TypeCode.REDACTED
53
+
54
+ for w in blocked_words:
55
+ if w in msg:
56
+ return TypeCode.BLOCKED_WORD
57
+
58
+ for key in ["model_a", "model_b"]:
59
+ if conv[key] in ["vicuna-33b", "mpt-30b-chat"]:
60
+ return TypeCode.BLOCKED_MODEL
61
+
62
+ return TypeCode.CORRECT
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument("--in-file", type=str, required=True)
68
+ parser.add_argument("--sample", type=int)
69
+ args = parser.parse_args()
70
+
71
+ # Read conversations
72
+ convs = json.load(open(args.in_file))
73
+ print(f"#conv: {len(convs)}")
74
+
75
+ # Read blocked words
76
+ if os.path.exists(BLOCKED_WORDS_FILENAME):
77
+ blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
78
+
79
+ # Count frequency
80
+ for conv in convs:
81
+ for key in ["conversation_a", "conversation_b"]:
82
+ messages = [row["content"] for row in conv[key] if row["role"] == "user"]
83
+ for msg in messages:
84
+ if not isinstance(msg, str):
85
+ continue
86
+ msg = msg.lower().strip()
87
+ frequency[msg] += 1
88
+
89
+ keys = list(frequency.keys())
90
+ keys.sort(key=lambda x: -frequency[x])
91
+ frequent_prompts = keys[:10]
92
+ frequent_prompts = set(frequent_prompts)
93
+ frequent_prompts.add("")
94
+
95
+ # Start filter
96
+ ct_bad_format = 0
97
+ ct_anonymized = 0
98
+ ct_redacted = 0
99
+ ct_error = 0
100
+ ct_lang_filter = 0
101
+ ct_flagged = 0
102
+ ct_blocked_word = 0
103
+ ct_blocked_model = 0
104
+ ct_too_short = 0
105
+ ct_too_frequent = 0
106
+
107
+ new_convs = []
108
+ for conv in tqdm(convs):
109
+ type_code = detect_type(conv)
110
+
111
+ if type_code == TypeCode.BAD_FORMAT:
112
+ ct_bad_format += 1
113
+ continue
114
+
115
+ if type_code == TypeCode.ANONYMIZED:
116
+ ct_anonymized += 1
117
+ continue
118
+ elif type_code == TypeCode.REDACTED:
119
+ ct_redacted += 1
120
+ continue
121
+ elif type_code == TypeCode.BLOCKED_WORD:
122
+ ct_blocked_word += 1
123
+ continue
124
+ elif type_code == TypeCode.BLOCKED_MODEL:
125
+ ct_blocked_model += 1
126
+ continue
127
+ elif type_code == TypeCode.TOO_SHORT:
128
+ ct_too_short += 1
129
+ continue
130
+ elif type_code == TypeCode.TOO_FREQUENT:
131
+ ct_too_frequent += 1
132
+ continue
133
+
134
+ if conv["openai_moderation"]["flagged"]:
135
+ ct_flagged += 1
136
+ continue
137
+
138
+ if type_code in [TypeCode.CORRECT]:
139
+ new_convs.append(conv)
140
+
141
+ if args.sample:
142
+ # random.seed(0)
143
+ # random.shuffle(new_convs)
144
+ new_convs = new_convs[: args.sample]
145
+
146
+ print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
147
+ print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
148
+ print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
149
+ print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}")
150
+ print(f"new_conv: {len(new_convs)}")
151
+
152
+ out_file = args.in_file.replace(".json", ".out.json")
153
+ print(f"Output to {out_file}")
154
+ with open(out_file, "w") as fout:
155
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
monitor/dataset_release_scripts/arena_33k/merge_field.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Count the unique users in a battle log file."""
2
+
3
+ import argparse
4
+ import json
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("--input", type=str)
10
+ parser.add_argument("--tag-file", type=str)
11
+ args = parser.parse_args()
12
+
13
+ # build index
14
+ objs = json.load(open(args.tag_file))
15
+ new_field_dict = {}
16
+ for obj in objs:
17
+ new_field_dict[obj["question_id"]] = obj["toxic_chat"]
18
+
19
+ objs = json.load(open(args.input))
20
+ for obj in objs:
21
+ obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]]
22
+
23
+ output = args.input.replace(".json", "_added.json")
24
+ with open(output, "w") as fout:
25
+ json.dump(objs, fout, indent=2, ensure_ascii=False)
monitor/dataset_release_scripts/arena_33k/sample.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Count the unique users in a battle log file.
3
+
4
+ Usage:
5
+ python3 -input in.json --number 1000
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import random
11
+
12
+ K = 1000
13
+
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--input", type=str)
17
+ parser.add_argument("--number", type=int, nargs="+")
18
+ args = parser.parse_args()
19
+
20
+ convs = json.load(open(args.input))
21
+ random.seed(0)
22
+ random.shuffle(convs)
23
+
24
+ for number in args.number:
25
+ new_convs = convs[:number]
26
+
27
+ output = args.input.replace(".json", f"_{number//K}k.json")
28
+ with open(output, "w") as fout:
29
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
30
+
31
+ print(f"#in: {len(convs)}, #out: {len(new_convs)}")
32
+ print(f"Write to file: {output}")
monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Upload to huggingface.
3
+ """
4
+ import json
5
+ from datasets import Dataset, DatasetDict, load_dataset
6
+
7
+ objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json"))
8
+ data = Dataset.from_list(objs)
9
+ data.push_to_hub("lmsys/chatbot_arena_conversations", private=True)
monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ headers = {"authorization": "Bearer hf_XXX"}
4
+
5
+ url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending"
6
+ a = requests.get(url, headers=headers)
7
+
8
+ for u in a.json():
9
+ user = u["user"]["user"]
10
+ url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant"
11
+ ret = requests.post(url, headers=headers, json={"user": user})
12
+ print(user, ret.status_code)
13
+ assert ret.status_code == 200
monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ From colab:
3
+ https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing
4
+ """
5
+ import argparse
6
+ import datetime
7
+ import json
8
+ import os
9
+ from pytz import timezone
10
+ import time
11
+
12
+ import kaleido
13
+ import numpy as np
14
+ import pandas as pd
15
+ import plotly.express as px
16
+ import plotly.graph_objects as go
17
+ from tqdm import tqdm
18
+
19
+ import plotly.io as pio
20
+
21
+ pio.kaleido.scope.mathjax = None
22
+
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--in-file", type=str, required=True)
25
+ parser.add_argument("--scale", type=int, required=True)
26
+ args = parser.parse_args()
27
+
28
+ filename = args.in_file
29
+ scale = args.scale
30
+ convs = json.load(open(filename))
31
+ df = pd.DataFrame(convs)
32
+ df
33
+
34
+ print(f"#ips: {df['user_id'].nunique() * scale}")
35
+ print(f"#models: {df['model'].nunique()}")
36
+ print(f"#language: {df['language'].nunique()}")
37
+ print(f"#turns: {df['turn'].mean()}")
38
+
39
+ model_counts = df["model"].value_counts() * scale
40
+ # print("model counts", model_counts)
41
+ fig = px.bar(x=model_counts.index, y=model_counts)
42
+ fig.update_layout(
43
+ xaxis_title=None,
44
+ yaxis_title="Count",
45
+ height=200,
46
+ width=950,
47
+ margin=dict(l=0, r=0, t=0, b=0),
48
+ )
49
+ fig.show()
50
+ fig.write_image("model_count.pdf")
51
+
52
+
53
+ model_counts = df["language"].value_counts().head(25) * scale
54
+ fig = px.bar(x=model_counts.index, y=model_counts)
55
+ fig.update_layout(
56
+ xaxis_title=None,
57
+ yaxis_title="Count",
58
+ height=200,
59
+ width=950,
60
+ margin=dict(l=0, r=0, t=0, b=0),
61
+ )
62
+ fig.show()
63
+ fig.write_image("language_count.pdf")
64
+
65
+ chat_dates = [
66
+ datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d")
67
+ for x in df["tstamp"]
68
+ ]
69
+
70
+
71
+ def to_remove(x):
72
+ for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]:
73
+ if d in x:
74
+ return True
75
+ return False
76
+
77
+
78
+ chat_dates = [x for x in chat_dates if not to_remove(x)]
79
+
80
+ chat_dates_counts = pd.value_counts(chat_dates) * scale
81
+ print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}")
82
+
83
+ fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts)
84
+ fig.update_layout(
85
+ xaxis_title="Dates",
86
+ yaxis_title="Count",
87
+ height=200,
88
+ width=950,
89
+ margin=dict(l=0, r=0, t=0, b=0),
90
+ )
91
+ fig.show()
92
+ fig.write_image("daily_conversation_count.pdf")
93
+
94
+ import transformers
95
+
96
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
97
+ "lmsys/vicuna-7b-v1.5", use_fast=False
98
+ )
99
+
100
+ prompts = []
101
+ responses = []
102
+ for conv in df["conversation"]:
103
+ for row in conv:
104
+ if row["role"] == "user":
105
+ prompts.append(row["content"])
106
+ else:
107
+ responses.append(row["content"])
108
+
109
+ print(f"#prompts: {len(prompts)}")
110
+ print(f"#responses: {len(responses)}")
111
+
112
+
113
+ prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)]
114
+ print()
115
+ print(f"mean prompt len: {np.mean(prompt_lens):.2f}")
116
+
117
+ response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)]
118
+ print()
119
+ print(f"mean response len: {np.mean(response_lens):.2f}")
monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Filter conversations for release.
3
+
4
+ Dependency:
5
+ pip install opencc-python-reimplementedpip install opencc-python-reimplemented
6
+
7
+ Usage:
8
+ python3 filter_bad_conv_lmsys_chat_1m.py --in clean_battle_conv_20230630_tagged_v1_pii.json
9
+ """
10
+ import argparse
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from collections import defaultdict
13
+ from enum import Enum, auto
14
+ import json
15
+ import os
16
+ import random
17
+
18
+ from tqdm import tqdm
19
+ import opencc
20
+
21
+ BLOCKED_WORDS_FILENAME = "blocked_words.json"
22
+ blocked_words = []
23
+ frequency = defaultdict(lambda: 0)
24
+
25
+ cc_converter = opencc.OpenCC("t2s")
26
+
27
+
28
+ class TypeCode(Enum):
29
+ CORRECT = auto()
30
+ ANONYMIZED = auto()
31
+ REDACTED = auto()
32
+ BAD_FORMAT = auto()
33
+ BLOCKED_WORD = auto()
34
+ BLOCKED_MODEL = auto()
35
+ TOO_SHORT = auto()
36
+ TOO_FREQUENT = auto()
37
+
38
+
39
+ def detect_type(conv):
40
+ for key in ["conversation_a", "conversation_b", "conversation"]:
41
+ if key not in conv:
42
+ continue
43
+
44
+ messages = [row["content"] for row in conv[key]]
45
+ for msg in messages:
46
+ if not isinstance(msg, str):
47
+ return TypeCode.BAD_FORMAT
48
+
49
+ if len(messages) == 0:
50
+ return TypeCode.BAD_FORMAT
51
+
52
+ user_prompts = [
53
+ row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
54
+ ]
55
+
56
+ for msg in messages:
57
+ msg = cc_converter.convert(msg.lower())
58
+ if "<anonymized>" in msg:
59
+ return TypeCode.ANONYMIZED
60
+ if "<redacted>" in msg:
61
+ return TypeCode.REDACTED
62
+
63
+ for w in blocked_words:
64
+ if w in msg:
65
+ return TypeCode.BLOCKED_WORD
66
+
67
+ return TypeCode.CORRECT
68
+
69
+
70
+ if __name__ == "__main__":
71
+ parser = argparse.ArgumentParser()
72
+ parser.add_argument("--in-file", type=str, required=True)
73
+ parser.add_argument("--sample", type=int)
74
+ args = parser.parse_args()
75
+
76
+ # Read conversations
77
+ convs = json.load(open(args.in_file))
78
+ print(f"#conv: {len(convs)}")
79
+
80
+ # Read blocked words
81
+ if os.path.exists(BLOCKED_WORDS_FILENAME):
82
+ blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
83
+ blocked_words = [cc_converter.convert(w) for w in blocked_words]
84
+
85
+ # Start filter
86
+ ct_bad_format = 0
87
+ ct_anonymized = 0
88
+ ct_redacted = 0
89
+ ct_error = 0
90
+ ct_lang_filter = 0
91
+ ct_flagged = 0
92
+ ct_blocked_word = 0
93
+ ct_blocked_model = 0
94
+ ct_too_short = 0
95
+ ct_too_frequent = 0
96
+
97
+ type_codes = []
98
+ with ProcessPoolExecutor() as executor:
99
+ for result in tqdm(executor.map(detect_type, convs), total=len(convs)):
100
+ type_codes.append(result)
101
+
102
+ new_convs = []
103
+ for conv, type_code in zip(convs, type_codes):
104
+ if type_code == TypeCode.BAD_FORMAT:
105
+ ct_bad_format += 1
106
+ continue
107
+
108
+ if type_code == TypeCode.ANONYMIZED:
109
+ ct_anonymized += 1
110
+ continue
111
+ elif type_code == TypeCode.REDACTED:
112
+ ct_redacted += 1
113
+ continue
114
+ elif type_code == TypeCode.BLOCKED_WORD:
115
+ ct_blocked_word += 1
116
+ continue
117
+ elif type_code == TypeCode.BLOCKED_MODEL:
118
+ ct_blocked_model += 1
119
+ continue
120
+ elif type_code == TypeCode.TOO_SHORT:
121
+ ct_too_short += 1
122
+ continue
123
+ elif type_code == TypeCode.TOO_FREQUENT:
124
+ ct_too_frequent += 1
125
+ continue
126
+
127
+ if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]:
128
+ ct_flagged += 1
129
+ continue
130
+
131
+ if type_code in [TypeCode.CORRECT]:
132
+ new_convs.append(conv)
133
+
134
+ if args.sample:
135
+ random.seed(42)
136
+ random.shuffle(new_convs)
137
+ new_convs = new_convs[: args.sample]
138
+
139
+ print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
140
+ print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
141
+ print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
142
+ print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}")
143
+ print(f"new_conv: {len(new_convs)}")
144
+
145
+ out_file = args.in_file.replace(".json", ".s1.json")
146
+ print(f"Output to {out_file}")
147
+ with open(out_file, "w") as fout:
148
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+
7
+
8
+ if __name__ == "__main__":
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--in-file", type=str, required=True)
11
+ args = parser.parse_args()
12
+
13
+ # Read conversations
14
+ convs = json.load(open(args.in_file))
15
+ print(f"#conv: {len(convs)}")
16
+
17
+ # Delete some fileds
18
+ for c in convs:
19
+ del c["tstamp"]
20
+ del c["user_id"]
21
+
22
+ # Write
23
+ print(f"#out conv: {len(convs)}")
24
+ out_file = args.in_file.replace(".json", ".s2.json")
25
+ print(f"Output to {out_file}")
26
+ with open(out_file, "w") as fout:
27
+ json.dump(convs, fout, indent=2, ensure_ascii=False)
monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```
2
+ export BASE=clean_conv_20230809_100k_pii
3
+ export SCALE=10
4
+
5
+ # filter words
6
+ python3 filter_bad_conv.py --in $BASE.json
7
+
8
+ # Clean up some fileds (e.g., timestamps)
9
+ python3 final_post_processing.py --in $BASE.s1.json
10
+
11
+ # upload to hf
12
+ python3 upload_hf_dataset.py --in $BASE.s1.s2.json
13
+
14
+ # Make another version with openai moderation tag
15
+ python3 merge_oai_tag.py --in $BASE.s1.s2.json
16
+
17
+ # Make visualizations
18
+ python3 compute_stats.py --in $BASE.s1.json --scale $SCALE
19
+
20
+ # Copy figures
21
+ scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" .
22
+ ```
23
+
monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ if __name__ == "__main__":
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--in-file", type=str, required=True)
11
+ parser.add_argument("--sample", type=int)
12
+ args = parser.parse_args()
13
+
14
+ tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json"
15
+ # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json"
16
+ in_file = args.in_file
17
+ tic = time.time()
18
+
19
+ # Load tags
20
+ print("Load tags...")
21
+ tag_data = json.load(open(tag_file))
22
+ tag_dict = {}
23
+ for c in tqdm(tag_data):
24
+ tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]]
25
+ print(f"elapsed: {time.time() - tic:.2f} s")
26
+
27
+ # Append to input_file
28
+ print("Load inputs...")
29
+ input_data = json.load(open(in_file))
30
+ for c in tqdm(input_data):
31
+ cid = c["conversation_id"]
32
+ if cid in tag_dict:
33
+ c["openai_moderation"] = tag_dict[cid]
34
+ else:
35
+ print(f"missing tag for conv {cid}")
36
+ exit()
37
+ print(f"elapsed: {time.time() - tic:.2f} s")
38
+
39
+ # Write output
40
+ print("Write outputs...")
41
+ out_file = in_file.replace(".json", ".with_tag.json")
42
+ print(f"Output to {out_file}")
43
+ with open(out_file, "w") as fout:
44
+ json.dump(input_data, fout, indent=2, ensure_ascii=False)
45
+ print(f"elapsed: {time.time() - tic:.2f} s")
monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export BASE=clean_conv_20230809_1.5M_pii
2
+ #export BASE=clean_conv_20230809_100k_pii
3
+ export SCALE=1
4
+
5
+ # Filter words
6
+ python3 filter_bad_conv.py --in $BASE.json --sample 1000000
7
+
8
+ # Clean up some fileds (e.g., timestamps)
9
+ python3 final_post_processing.py --in $BASE.s1.json
10
+
11
+ # Upload to hf
12
+ python3 upload_hf_dataset.py --in $BASE.s1.s2.json
13
+
14
+ # Make another version with openai moderation tag
15
+ python3 merge_oai_tag.py --in $BASE.s1.s2.json
16
+
17
+ # Make visualizations
18
+ python3 compute_stats.py --in $BASE.s1.json --scale $SCALE
monitor/dataset_release_scripts/lmsys_chat_1m/sample.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Count the unique users in a battle log file.
3
+
4
+ Usage:
5
+ python3 -input in.json --number 1000
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import random
11
+
12
+ K = 1000
13
+
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--input", type=str)
17
+ parser.add_argument("--number", type=int, nargs="+")
18
+ args = parser.parse_args()
19
+
20
+ convs = json.load(open(args.input))
21
+ random.seed(42)
22
+ random.shuffle(convs)
23
+
24
+ for number in args.number:
25
+ new_convs = convs[:number]
26
+
27
+ output = args.input.replace(".json", f"_{number//K}k.json")
28
+ with open(output, "w") as fout:
29
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
30
+
31
+ print(f"#in: {len(convs)}, #out: {len(new_convs)}")
32
+ print(f"Write to file: {output}")
monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Upload to huggingface.
3
+ """
4
+ import argparse
5
+ import json
6
+ from datasets import Dataset, DatasetDict, load_dataset
7
+
8
+
9
+ if __name__ == "__main__":
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--in-file", type=str, required=True)
12
+ args = parser.parse_args()
13
+
14
+ objs = json.load(open(args.in_file))
15
+ print(f"#convs: {len(objs)}")
16
+ data = Dataset.from_list(objs)
17
+ data.push_to_hub("lmsys/lmsys-chat-1m", private=True)
monitor/elo_analysis.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from collections import defaultdict
3
+ import datetime
4
+ import json
5
+ import math
6
+ import pickle
7
+ from pytz import timezone
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import plotly.express as px
12
+ from tqdm import tqdm
13
+
14
+ from fastchat.model.model_registry import get_model_info
15
+ from fastchat.serve.monitor.basic_stats import get_log_files
16
+ from fastchat.serve.monitor.clean_battle_data import clean_battle_data
17
+
18
+
19
+ pd.options.display.float_format = "{:.2f}".format
20
+
21
+
22
+ def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
23
+ rating = defaultdict(lambda: INIT_RATING)
24
+
25
+ for rd, model_a, model_b, winner in battles[
26
+ ["model_a", "model_b", "winner"]
27
+ ].itertuples():
28
+ ra = rating[model_a]
29
+ rb = rating[model_b]
30
+ ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
31
+ eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
32
+ if winner == "model_a":
33
+ sa = 1
34
+ elif winner == "model_b":
35
+ sa = 0
36
+ elif winner == "tie" or winner == "tie (bothbad)":
37
+ sa = 0.5
38
+ else:
39
+ raise Exception(f"unexpected vote {winner}")
40
+ rating[model_a] += K * (sa - ea)
41
+ rating[model_b] += K * (1 - sa - eb)
42
+
43
+ return dict(rating)
44
+
45
+
46
+ def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
47
+ rows = []
48
+ for i in tqdm(range(num_round), desc="bootstrap"):
49
+ tmp_battles = battles.sample(frac=1.0, replace=True)
50
+ rows.append(func_compute_elo(tmp_battles))
51
+ df = pd.DataFrame(rows)
52
+ return df[df.median().sort_values(ascending=False).index]
53
+
54
+
55
+ def get_median_elo_from_bootstrap(bootstrap_df):
56
+ median = dict(bootstrap_df.quantile(0.5))
57
+ median = {k: int(v + 0.5) for k, v in median.items()}
58
+ return median
59
+
60
+
61
+ def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None):
62
+ # Times each model wins as Model A
63
+ a_win_ptbl = pd.pivot_table(
64
+ battles[battles["winner"] == "model_a"],
65
+ index="model_a",
66
+ columns="model_b",
67
+ aggfunc="size",
68
+ fill_value=0,
69
+ )
70
+
71
+ # Table counting times each model wins as Model B
72
+ b_win_ptbl = pd.pivot_table(
73
+ battles[battles["winner"] == "model_b"],
74
+ index="model_a",
75
+ columns="model_b",
76
+ aggfunc="size",
77
+ fill_value=0,
78
+ )
79
+
80
+ # Table counting number of A-B pairs
81
+ num_battles_ptbl = pd.pivot_table(
82
+ battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
83
+ )
84
+
85
+ # Computing the proportion of wins for each model as A and as B
86
+ # against all other models
87
+ row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
88
+ num_battles_ptbl + num_battles_ptbl.T
89
+ )
90
+
91
+ if model_order is None:
92
+ prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
93
+ model_order = list(prop_wins.keys())
94
+
95
+ if limit_show_number is not None:
96
+ model_order = model_order[:limit_show_number]
97
+
98
+ # Arrange ordering according to proprition of wins
99
+ row_beats_col = row_beats_col_freq.loc[model_order, model_order]
100
+ return row_beats_col
101
+
102
+
103
+ def visualize_leaderboard_table(rating):
104
+ models = list(rating.keys())
105
+ models.sort(key=lambda k: -rating[k])
106
+
107
+ emoji_dict = {
108
+ 1: "🥇",
109
+ 2: "🥈",
110
+ 3: "🥉",
111
+ }
112
+
113
+ md = ""
114
+ md += "| Rank | Model | Elo Rating | Description |\n"
115
+ md += "| --- | --- | --- | --- |\n"
116
+ for i, model in enumerate(models):
117
+ rank = i + 1
118
+ minfo = get_model_info(model)
119
+ emoji = emoji_dict.get(rank, "")
120
+ md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"
121
+
122
+ return md
123
+
124
+
125
+ def visualize_pairwise_win_fraction(battles, model_order):
126
+ row_beats_col = compute_pairwise_win_fraction(battles, model_order)
127
+ fig = px.imshow(
128
+ row_beats_col,
129
+ color_continuous_scale="RdBu",
130
+ text_auto=".2f",
131
+ height=700,
132
+ width=700,
133
+ )
134
+ fig.update_layout(
135
+ xaxis_title="Model B",
136
+ yaxis_title="Model A",
137
+ xaxis_side="top",
138
+ title_y=0.07,
139
+ title_x=0.5,
140
+ )
141
+ fig.update_traces(
142
+ hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Fraction of A Wins: %{z}<extra></extra>"
143
+ )
144
+
145
+ return fig
146
+
147
+
148
+ def visualize_battle_count(battles, model_order):
149
+ ptbl = pd.pivot_table(
150
+ battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
151
+ )
152
+ battle_counts = ptbl + ptbl.T
153
+ fig = px.imshow(
154
+ battle_counts.loc[model_order, model_order],
155
+ text_auto=True,
156
+ height=700,
157
+ width=700,
158
+ )
159
+ fig.update_layout(
160
+ xaxis_title="Model B",
161
+ yaxis_title="Model A",
162
+ xaxis_side="top",
163
+ title_y=0.07,
164
+ title_x=0.5,
165
+ )
166
+ fig.update_traces(
167
+ hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Count: %{z}<extra></extra>"
168
+ )
169
+ return fig
170
+
171
+
172
+ def visualize_average_win_rate(battles, limit_show_number):
173
+ row_beats_col_freq = compute_pairwise_win_fraction(
174
+ battles, None, limit_show_number=limit_show_number
175
+ )
176
+ fig = px.bar(
177
+ row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
178
+ text_auto=".2f",
179
+ height=500,
180
+ width=700,
181
+ )
182
+ fig.update_layout(
183
+ yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
184
+ )
185
+ return fig
186
+
187
+
188
+ def visualize_bootstrap_elo_rating(df, limit_show_number):
189
+ bars = (
190
+ pd.DataFrame(
191
+ dict(
192
+ lower=df.quantile(0.025),
193
+ rating=df.quantile(0.5),
194
+ upper=df.quantile(0.975),
195
+ )
196
+ )
197
+ .reset_index(names="model")
198
+ .sort_values("rating", ascending=False)
199
+ )
200
+ bars = bars[:limit_show_number]
201
+ bars["error_y"] = bars["upper"] - bars["rating"]
202
+ bars["error_y_minus"] = bars["rating"] - bars["lower"]
203
+ bars["rating_rounded"] = np.round(bars["rating"], 2)
204
+ fig = px.scatter(
205
+ bars,
206
+ x="model",
207
+ y="rating",
208
+ error_y="error_y",
209
+ error_y_minus="error_y_minus",
210
+ text="rating_rounded",
211
+ height=500,
212
+ width=700,
213
+ )
214
+ fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
215
+ return fig
216
+
217
+
218
+ def report_elo_analysis_results(battles_json):
219
+ battles = pd.DataFrame(battles_json)
220
+ battles = battles.sort_values(ascending=True, by=["tstamp"])
221
+ # Only use anonymous votes
222
+ battles = battles[battles["anony"]].reset_index(drop=True)
223
+ battles_no_ties = battles[~battles["winner"].str.contains("tie")]
224
+
225
+ # Online update
226
+ elo_rating_online = compute_elo(battles)
227
+
228
+ # Bootstrap
229
+ bootstrap_df = get_bootstrap_result(battles, compute_elo)
230
+ elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
231
+ model_order = list(elo_rating_median.keys())
232
+ model_order.sort(key=lambda k: -elo_rating_median[k])
233
+
234
+ limit_show_number = 25 # limit show number to make plots smaller
235
+ model_order = model_order[:limit_show_number]
236
+
237
+ # Plots
238
+ leaderboard_table = visualize_leaderboard_table(elo_rating_median)
239
+ win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
240
+ battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
241
+ average_win_rate_bar = visualize_average_win_rate(
242
+ battles_no_ties, limit_show_number
243
+ )
244
+ bootstrap_elo_rating = visualize_bootstrap_elo_rating(
245
+ bootstrap_df, limit_show_number
246
+ )
247
+
248
+ last_updated_tstamp = battles["tstamp"].max()
249
+ last_updated_datetime = datetime.datetime.fromtimestamp(
250
+ last_updated_tstamp, tz=timezone("US/Pacific")
251
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
252
+
253
+ return {
254
+ "elo_rating_online": elo_rating_online,
255
+ "elo_rating_median": elo_rating_median,
256
+ "leaderboard_table": leaderboard_table,
257
+ "win_fraction_heatmap": win_fraction_heatmap,
258
+ "battle_count_heatmap": battle_count_heatmap,
259
+ "average_win_rate_bar": average_win_rate_bar,
260
+ "bootstrap_elo_rating": bootstrap_elo_rating,
261
+ "last_updated_datetime": last_updated_datetime,
262
+ "last_updated_tstamp": last_updated_tstamp,
263
+ }
264
+
265
+
266
+ def pretty_print_elo_rating(rating):
267
+ model_order = list(rating.keys())
268
+ model_order.sort(key=lambda k: -rating[k])
269
+ for i, model in enumerate(model_order):
270
+ print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}")
271
+
272
+
273
+ if __name__ == "__main__":
274
+ parser = argparse.ArgumentParser()
275
+ parser.add_argument("--clean-battle-file", type=str)
276
+ parser.add_argument("--max-num-files", type=int)
277
+ args = parser.parse_args()
278
+
279
+ np.random.seed(42)
280
+
281
+ if args.clean_battle_file:
282
+ # Read data from a cleaned battle files
283
+ battles = pd.read_json(args.clean_battle_file)
284
+ else:
285
+ # Read data from all log files
286
+ log_files = get_log_files(args.max_num_files)
287
+ battles = clean_battle_data(log_files)
288
+
289
+ results = report_elo_analysis_results(battles)
290
+
291
+ print("# Online")
292
+ pretty_print_elo_rating(results["elo_rating_online"])
293
+ print("# Median")
294
+ pretty_print_elo_rating(results["elo_rating_median"])
295
+ print(f"last update : {results['last_updated_datetime']}")
296
+
297
+ last_updated_tstamp = results["last_updated_tstamp"]
298
+ cutoff_date = datetime.datetime.fromtimestamp(
299
+ last_updated_tstamp, tz=timezone("US/Pacific")
300
+ ).strftime("%Y%m%d")
301
+
302
+ with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout:
303
+ pickle.dump(results, fout)
monitor/inspect_conv.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import code
3
+ import datetime
4
+ import json
5
+ import os
6
+ from pytz import timezone
7
+ import time
8
+
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+
12
+
13
+ def get_log_files(max_num_files=None):
14
+ dates = []
15
+ for month in [4, 5]:
16
+ for day in range(1, 32):
17
+ dates.append(f"2023-{month:02d}-{day:02d}")
18
+
19
+ num_servers = 14
20
+ filenames = []
21
+ for d in dates:
22
+ for i in range(num_servers):
23
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
24
+ if os.path.exists(name):
25
+ filenames.append(name)
26
+ max_num_files = max_num_files or len(filenames)
27
+ filenames = filenames[-max_num_files:]
28
+ return filenames
29
+
30
+
31
+ def pretty_print_conversation(messages):
32
+ for role, msg in messages:
33
+ print(f"[[{role}]]: {msg}")
34
+
35
+
36
+ def inspect_convs(log_files):
37
+ data = []
38
+ for filename in tqdm(log_files, desc="read files"):
39
+ for retry in range(5):
40
+ try:
41
+ lines = open(filename).readlines()
42
+ break
43
+ except FileNotFoundError:
44
+ time.sleep(2)
45
+
46
+ for l in lines:
47
+ row = json.loads(l)
48
+
49
+ if "states" not in row:
50
+ continue
51
+ if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]:
52
+ continue
53
+
54
+ model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
55
+ if row["type"] == "leftvote":
56
+ winner, loser = model_names[0], model_names[1]
57
+ winner_conv, loser_conv = row["states"][0], row["states"][1]
58
+ elif row["type"] == "rightvote":
59
+ loser, winner = model_names[0], model_names[1]
60
+ loser_conv, winner_conv = row["states"][0], row["states"][1]
61
+
62
+ if loser == "bard" and winner == "vicuna-13b":
63
+ print("=" * 20)
64
+ print(f"Winner: {winner}")
65
+ pretty_print_conversation(winner_conv["messages"])
66
+ print(f"Loser: {loser}")
67
+ pretty_print_conversation(loser_conv["messages"])
68
+ print("=" * 20)
69
+ input()
70
+
71
+ # if row["type"] == "bothbad_vote" and "gpt-4" in model_names:
72
+ # print("=" * 20)
73
+ # print(f"Model A: {model_names[0]}")
74
+ # pretty_print_conversation(row["states"][0]["messages"])
75
+ # print(f"Model B: {model_names[1]}")
76
+ # pretty_print_conversation(row["states"][1]["messages"])
77
+ # print("=" * 20)
78
+ # input()
79
+
80
+
81
+ if __name__ == "__main__":
82
+ parser = argparse.ArgumentParser()
83
+ parser.add_argument("--max-num-files", type=int)
84
+ args = parser.parse_args()
85
+
86
+ log_files = get_log_files(args.max_num_files)
87
+ inspect_convs(log_files)
monitor/intersect_conv_file.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Take the intersection of two conversation files.
3
+
4
+ Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json
5
+ """
6
+
7
+ import argparse
8
+ import json
9
+
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--input", type=str, required=True)
14
+ parser.add_argument("--conv-id", type=str, required=True)
15
+ parser.add_argument("--out-file", type=str, default="intersect.json")
16
+ args = parser.parse_args()
17
+
18
+ conv_id_objs = json.load(open(args.conv_id, "r"))
19
+ conv_ids = set(x["conversation_id"] for x in conv_id_objs)
20
+
21
+ objs = json.load(open(args.input, "r"))
22
+ after_objs = [x for x in objs if x["conversation_id"] in conv_ids]
23
+
24
+ print(f"#in: {len(objs)}, #out: {len(after_objs)}")
25
+ json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False)
monitor/leaderboard_csv_to_html.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a leaderboard csv file to html table used in the blog.
3
+
4
+ Usage:
5
+ python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv
6
+ """
7
+ import argparse
8
+
9
+ import numpy as np
10
+
11
+ from fastchat.serve.monitor.monitor import load_leaderboard_table_csv
12
+
13
+
14
+ def model_hyperlink(model_name, link):
15
+ return f'<a target="_blank" href="{link}"> {model_name} </a>'
16
+
17
+
18
+ if __name__ == "__main__":
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--input", type=str, required=True)
21
+ args = parser.parse_args()
22
+
23
+ data = load_leaderboard_table_csv(args.input, add_hyperlink=False)
24
+ headers = [
25
+ "Model",
26
+ "MT-bench (score)",
27
+ "Arena Elo rating",
28
+ "MMLU",
29
+ "License",
30
+ ]
31
+ values = []
32
+ for item in data:
33
+ row = []
34
+ for key in headers:
35
+ value = item[key]
36
+ row.append(value)
37
+ row[0] = model_hyperlink(item["Model"], item["Link"])
38
+ values.append(row)
39
+ values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
40
+
41
+ for value in values:
42
+ row = "<tr>"
43
+ for x in value:
44
+ try:
45
+ if np.isnan(x):
46
+ x = "-"
47
+ except TypeError:
48
+ pass
49
+ row += f" <td>{x}</td> "
50
+ row += "</tr>"
51
+ print(row)
monitor/monitor.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Live monitor of the website statistics and leaderboard.
3
+
4
+ Dependency:
5
+ sudo apt install pkg-config libicu-dev
6
+ pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
7
+ """
8
+
9
+ import argparse
10
+ import ast
11
+ import pickle
12
+ import os
13
+ import threading
14
+ import time
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+
19
+ from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files
20
+ from fastchat.serve.monitor.clean_battle_data import clean_battle_data
21
+ from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results
22
+ from fastchat.utils import build_logger, get_window_url_params_js
23
+
24
+
25
+ notebook_url = "https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing"
26
+
27
+
28
+ basic_component_values = [None] * 6
29
+ leader_component_values = [None] * 5
30
+
31
+
32
+ def make_leaderboard_md(elo_results):
33
+ leaderboard_md = f"""
34
+ # 🏆 Chatbot Arena Leaderboard
35
+ | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
36
+
37
+ This leaderboard is based on the following three benchmarks.
38
+ - [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings.
39
+ - [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses.
40
+ - [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks.
41
+
42
+ 💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023.
43
+ """
44
+ return leaderboard_md
45
+
46
+
47
+ def make_leaderboard_md_live(elo_results):
48
+ leaderboard_md = f"""
49
+ # Leaderboard
50
+ Last updated: {elo_results["last_updated_datetime"]}
51
+ {elo_results["leaderboard_table"]}
52
+ """
53
+ return leaderboard_md
54
+
55
+
56
+ def update_elo_components(max_num_files, elo_results_file):
57
+ log_files = get_log_files(max_num_files)
58
+
59
+ # Leaderboard
60
+ if elo_results_file is None: # Do live update
61
+ battles = clean_battle_data(log_files, [])
62
+ elo_results = report_elo_analysis_results(battles)
63
+
64
+ leader_component_values[0] = make_leaderboard_md_live(elo_results)
65
+ leader_component_values[1] = elo_results["win_fraction_heatmap"]
66
+ leader_component_values[2] = elo_results["battle_count_heatmap"]
67
+ leader_component_values[3] = elo_results["bootstrap_elo_rating"]
68
+ leader_component_values[4] = elo_results["average_win_rate_bar"]
69
+
70
+ # Basic stats
71
+ basic_stats = report_basic_stats(log_files)
72
+ md0 = f"Last updated: {basic_stats['last_updated_datetime']}"
73
+
74
+ md1 = "### Action Histogram\n"
75
+ md1 += basic_stats["action_hist_md"] + "\n"
76
+
77
+ md2 = "### Anony. Vote Histogram\n"
78
+ md2 += basic_stats["anony_vote_hist_md"] + "\n"
79
+
80
+ md3 = "### Model Call Histogram\n"
81
+ md3 += basic_stats["model_hist_md"] + "\n"
82
+
83
+ md4 = "### Model Call (Last 24 Hours)\n"
84
+ md4 += basic_stats["num_chats_last_24_hours"] + "\n"
85
+
86
+ basic_component_values[0] = md0
87
+ basic_component_values[1] = basic_stats["chat_dates_bar"]
88
+ basic_component_values[2] = md1
89
+ basic_component_values[3] = md2
90
+ basic_component_values[4] = md3
91
+ basic_component_values[5] = md4
92
+
93
+
94
+ def update_worker(max_num_files, interval, elo_results_file):
95
+ while True:
96
+ tic = time.time()
97
+ update_elo_components(max_num_files, elo_results_file)
98
+ durtaion = time.time() - tic
99
+ print(f"update duration: {durtaion:.2f} s")
100
+ time.sleep(max(interval - durtaion, 0))
101
+
102
+
103
+ def load_demo(url_params, request: gr.Request):
104
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
105
+ return basic_component_values + leader_component_values
106
+
107
+
108
+ def model_hyperlink(model_name, link):
109
+ return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
110
+
111
+
112
+ def load_leaderboard_table_csv(filename, add_hyperlink=True):
113
+ lines = open(filename).readlines()
114
+ heads = [v.strip() for v in lines[0].split(",")]
115
+ rows = []
116
+ for i in range(1, len(lines)):
117
+ row = [v.strip() for v in lines[i].split(",")]
118
+ for j in range(len(heads)):
119
+ item = {}
120
+ for h, v in zip(heads, row):
121
+ if h == "Arena Elo rating":
122
+ if v != "-":
123
+ v = int(ast.literal_eval(v))
124
+ else:
125
+ v = np.nan
126
+ elif h == "MMLU":
127
+ if v != "-":
128
+ v = round(ast.literal_eval(v) * 100, 1)
129
+ else:
130
+ v = np.nan
131
+ elif h == "MT-bench (win rate %)":
132
+ if v != "-":
133
+ v = round(ast.literal_eval(v[:-1]), 1)
134
+ else:
135
+ v = np.nan
136
+ elif h == "MT-bench (score)":
137
+ if v != "-":
138
+ v = round(ast.literal_eval(v), 2)
139
+ else:
140
+ v = np.nan
141
+ item[h] = v
142
+ if add_hyperlink:
143
+ item["Model"] = model_hyperlink(item["Model"], item["Link"])
144
+ rows.append(item)
145
+
146
+ return rows
147
+
148
+
149
+ def build_basic_stats_tab():
150
+ empty = "Loading ..."
151
+ basic_component_values[:] = [empty, None, empty, empty, empty, empty]
152
+
153
+ md0 = gr.Markdown(empty)
154
+ gr.Markdown("#### Figure 1: Number of model calls and votes")
155
+ plot_1 = gr.Plot(show_label=False)
156
+ with gr.Row():
157
+ with gr.Column():
158
+ md1 = gr.Markdown(empty)
159
+ with gr.Column():
160
+ md2 = gr.Markdown(empty)
161
+ with gr.Row():
162
+ with gr.Column():
163
+ md3 = gr.Markdown(empty)
164
+ with gr.Column():
165
+ md4 = gr.Markdown(empty)
166
+ return [md0, plot_1, md1, md2, md3, md4]
167
+
168
+
169
+ def build_leaderboard_tab(elo_results_file, leaderboard_table_file):
170
+ if elo_results_file is None: # Do live update
171
+ md = "Loading ..."
172
+ p1 = p2 = p3 = p4 = None
173
+ else:
174
+ with open(elo_results_file, "rb") as fin:
175
+ elo_results = pickle.load(fin)
176
+
177
+ md = make_leaderboard_md(elo_results)
178
+ p1 = elo_results["win_fraction_heatmap"]
179
+ p2 = elo_results["battle_count_heatmap"]
180
+ p3 = elo_results["bootstrap_elo_rating"]
181
+ p4 = elo_results["average_win_rate_bar"]
182
+
183
+ md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
184
+
185
+ if leaderboard_table_file:
186
+ data = load_leaderboard_table_csv(leaderboard_table_file)
187
+ headers = [
188
+ "Model",
189
+ "Arena Elo rating",
190
+ "MT-bench (score)",
191
+ "MMLU",
192
+ "License",
193
+ ]
194
+ values = []
195
+ for item in data:
196
+ row = []
197
+ for key in headers:
198
+ value = item[key]
199
+ row.append(value)
200
+ values.append(row)
201
+ values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
202
+
203
+ headers[1] = "⭐ " + headers[1]
204
+ headers[2] = "📈 " + headers[2]
205
+
206
+ gr.Dataframe(
207
+ headers=headers,
208
+ datatype=["markdown", "number", "number", "number", "str"],
209
+ value=values,
210
+ elem_id="leaderboard_dataframe",
211
+ )
212
+ gr.Markdown(
213
+ """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis!
214
+ If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model).
215
+ """,
216
+ elem_id="leaderboard_markdown",
217
+ )
218
+ else:
219
+ pass
220
+
221
+ leader_component_values[:] = [md, p1, p2, p3, p4]
222
+
223
+ """
224
+ with gr.Row():
225
+ with gr.Column():
226
+ gr.Markdown(
227
+ "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles"
228
+ )
229
+ plot_1 = gr.Plot(p1, show_label=False)
230
+ with gr.Column():
231
+ gr.Markdown(
232
+ "#### Figure 2: Battle Count for Each Combination of Models (without Ties)"
233
+ )
234
+ plot_2 = gr.Plot(p2, show_label=False)
235
+ with gr.Row():
236
+ with gr.Column():
237
+ gr.Markdown(
238
+ "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)"
239
+ )
240
+ plot_3 = gr.Plot(p3, show_label=False)
241
+ with gr.Column():
242
+ gr.Markdown(
243
+ "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)"
244
+ )
245
+ plot_4 = gr.Plot(p4, show_label=False)
246
+ """
247
+
248
+ from fastchat.serve.gradio_web_server import acknowledgment_md
249
+
250
+ gr.Markdown(acknowledgment_md)
251
+
252
+ # return [md_1, plot_1, plot_2, plot_3, plot_4]
253
+ return [md_1]
254
+
255
+
256
+ def build_demo(elo_results_file, leaderboard_table_file):
257
+ from fastchat.serve.gradio_web_server import block_css
258
+
259
+ text_size = gr.themes.sizes.text_lg
260
+
261
+ with gr.Blocks(
262
+ title="Monitor",
263
+ theme=gr.themes.Base(text_size=text_size),
264
+ css=block_css,
265
+ ) as demo:
266
+ with gr.Tabs() as tabs:
267
+ with gr.Tab("Leaderboard", id=0):
268
+ leader_components = build_leaderboard_tab(
269
+ elo_results_file, leaderboard_table_file
270
+ )
271
+
272
+ with gr.Tab("Basic Stats", id=1):
273
+ basic_components = build_basic_stats_tab()
274
+
275
+ url_params = gr.JSON(visible=False)
276
+ demo.load(
277
+ load_demo,
278
+ [url_params],
279
+ basic_components + leader_components,
280
+ _js=get_window_url_params_js,
281
+ )
282
+
283
+ return demo
284
+
285
+
286
+ if __name__ == "__main__":
287
+ parser = argparse.ArgumentParser()
288
+ parser.add_argument("--host", type=str, default="0.0.0.0")
289
+ parser.add_argument("--port", type=int)
290
+ parser.add_argument("--share", action="store_true")
291
+ parser.add_argument("--concurrency-count", type=int, default=10)
292
+ parser.add_argument("--update-interval", type=int, default=300)
293
+ parser.add_argument("--max-num-files", type=int)
294
+ parser.add_argument("--elo-results-file", type=str)
295
+ parser.add_argument("--leaderboard-table-file", type=str)
296
+ args = parser.parse_args()
297
+
298
+ logger = build_logger("monitor", "monitor.log")
299
+ logger.info(f"args: {args}")
300
+
301
+ if args.elo_results_file is None: # Do live update
302
+ update_thread = threading.Thread(
303
+ target=update_worker,
304
+ args=(args.max_num_files, args.update_interval, args.elo_results_file),
305
+ )
306
+ update_thread.start()
307
+
308
+ demo = build_demo(args.elo_results_file, args.leaderboard_table_file)
309
+ demo.queue(
310
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
311
+ ).launch(
312
+ server_name=args.host, server_port=args.port, share=args.share, max_threads=200
313
+ )