diff --git a/README.md b/README.md index 189c5bed187485751804fd14f8e32cba7599bf04..f01612b9c5e2a28f216e4d45f668d834b75c06f8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,6 @@ --- -title: Demo Test -emoji: 🌍 -colorFrom: indigo -colorTo: purple +title: demo_test +app_file: gradio_web_server.py sdk: gradio -sdk_version: 4.39.0 -app_file: app.py -pinned: false +sdk_version: 3.45.0 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d264ee86b93bee08578afcdc7208f94c2646f31 Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ diff --git a/__pycache__/__init__.cpython-311.pyc b/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f388c132ae6e0ecca6b4fc940f10f593a6d0cfc Binary files /dev/null and b/__pycache__/__init__.cpython-311.pyc differ diff --git a/__pycache__/api_provider.cpython-310.pyc b/__pycache__/api_provider.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da748235a77b7a4b536f15110469a198967fc952 Binary files /dev/null and b/__pycache__/api_provider.cpython-310.pyc differ diff --git a/__pycache__/base_model_worker.cpython-310.pyc b/__pycache__/base_model_worker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa6afe37df86561fd3d0f4a5bc472310916224bb Binary files /dev/null and b/__pycache__/base_model_worker.cpython-310.pyc differ diff --git a/__pycache__/cli.cpython-310.pyc b/__pycache__/cli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19a460d5116b39b75cecf0909349e00af8c83c2f Binary files /dev/null and b/__pycache__/cli.cpython-310.pyc differ diff --git a/__pycache__/cli.cpython-311.pyc b/__pycache__/cli.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c9bfe97259e32b96368cb05fe6928a420b3d963 Binary files /dev/null and b/__pycache__/cli.cpython-311.pyc differ diff --git a/__pycache__/controller.cpython-310.pyc b/__pycache__/controller.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d476553e5f9bc9c7aa60ed6af5beb0111bdc44d Binary files /dev/null and b/__pycache__/controller.cpython-310.pyc differ diff --git a/__pycache__/gradio_web_server.cpython-310.pyc b/__pycache__/gradio_web_server.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa2fcdb21b9fd4940a345fc25d038d7e7e2b898 Binary files /dev/null and b/__pycache__/gradio_web_server.cpython-310.pyc differ diff --git a/__pycache__/inference.cpython-310.pyc b/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1d49d8c4d9fdf6f6e05f8ffc97fb53286cfff4e Binary files /dev/null and b/__pycache__/inference.cpython-310.pyc differ diff --git a/__pycache__/model_worker.cpython-310.pyc b/__pycache__/model_worker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..373080e38561629f28c7fe09b6b3b04b633073bd Binary files /dev/null and b/__pycache__/model_worker.cpython-310.pyc differ diff --git a/__pycache__/test_message.cpython-310.pyc b/__pycache__/test_message.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..981fe6f340d6ec7f63f1c9c5e0d68e83793b3073 Binary files /dev/null and b/__pycache__/test_message.cpython-310.pyc differ diff --git a/api_provider.py b/api_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbb8a69032a5ac3378a8de34774a2dcdf304223 --- /dev/null +++ b/api_provider.py @@ -0,0 +1,130 @@ +"""Call API providers.""" + +import os +import random +import time + +from fastchat.utils import build_logger +from fastchat.constants import WORKER_API_TIMEOUT + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + + +def openai_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + api_base=None, + api_key=None, +): + import openai + + openai.api_base = api_base or "https://api.openai.com/v1" + openai.api_key = api_key or os.environ["OPENAI_API_KEY"] + if model_name == "gpt-4-turbo": + model_name = "gpt-4-1106-preview" + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + res = openai.ChatCompletion.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + stream=True, + ) + text = "" + for chunk in res: + text += chunk["choices"][0]["delta"].get("content", "") + data = { + "text": text, + "error_code": 0, + } + yield data + + +def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): + import anthropic + + c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) + + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + res = c.completions.create( + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_new_tokens, + temperature=temperature, + top_p=top_p, + model=model_name, + stream=True, + ) + text = "" + for chunk in res: + text += chunk.completion + data = { + "text": text, + "error_code": 0, + } + yield data + + +def init_palm_chat(model_name): + import vertexai # pip3 install google-cloud-aiplatform + from vertexai.preview.language_models import ChatModel + + project_id = os.environ["GCP_PROJECT_ID"] + location = "us-central1" + vertexai.init(project=project_id, location=location) + + chat_model = ChatModel.from_pretrained(model_name) + chat = chat_model.start_chat(examples=[]) + return chat + + +def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens): + parameters = { + "temperature": temperature, + "top_p": top_p, + "max_output_tokens": max_new_tokens, + } + gen_params = { + "model": "palm-2", + "prompt": message, + } + gen_params.update(parameters) + logger.info(f"==== request ====\n{gen_params}") + + response = chat.send_message(message, **parameters) + content = response.text + + pos = 0 + while pos < len(content): + # This is a fancy way to simulate token generation latency combined + # with a Poisson process. 
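+        # Each step reveals another 10-20 characters and then sleeps for an
+        # exponentially distributed interval (rate 50, i.e. ~20 ms on average),
+        # so chunks are yielded at a roughly Poisson rate, mimicking streamed
+        # token generation.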
+ pos += random.randint(10, 20) + time.sleep(random.expovariate(50)) + data = { + "text": content[:pos], + "error_code": 0, + } + yield data diff --git a/base_model_worker.py b/base_model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..a297ab7e440cf727526a60e1481bbeac286544c9 --- /dev/null +++ b/base_model_worker.py @@ -0,0 +1,239 @@ +import asyncio +import threading +import time +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL +from fastchat.conversation import Conversation +from fastchat.utils import pretty_print_semaphore, build_logger + + +worker = None +logger = None + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + ): + global logger, worker + + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + self.conv = self.make_conv_template(conv_template, model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + if logger is None: + logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log") + if worker is None: + worker = self + + def make_conv_template( + self, + conv_template: str = None, + model_path: str = None, + ) -> Conversation: + """ + can be overrided to costomize the conversation template for different model workers. + """ + from fastchat.conversation import get_conv_template + from fastchat.model.model_adapter import get_conversation_template + + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + print(conv) + return conv + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, + args=(self,), + daemon=True, + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. " + f"worker_id: {self.worker_id}. 
" + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + self.semaphore is None + or self.semaphore._value is None + or self.semaphore._waiters is None + ): + return 0 + else: + return ( + self.limit_worker_concurrency + - self.semaphore._value + + len(self.semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + def generate_stream_gate(self, params): + raise NotImplementedError + + def generate_gate(self, params): + raise NotImplementedError + + def get_embeddings(self, params): + raise NotImplementedError + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await asyncio.to_thread(worker.generate_gate, params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} diff --git a/cli.py b/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..cb815cde6a751e283307c6d0cd5bbbde2fd062ed --- /dev/null +++ b/cli.py @@ -0,0 +1,313 @@ +""" +Chat with a model with command line interface. 
+ +Usage: +python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 + +Other commands: +- Type "!!exit" or an empty line to exit. +- Type "!!reset" to start a new conversation. +- Type "!!remove" to remove the last prompt. +- Type "!!regen" to regenerate the last message. +- Type "!!save " to save the conversation history to a json file. +- Type "!!load " to load a conversation history from a json file. +""" +import argparse +import os +import re +import sys + +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from prompt_toolkit.key_binding import KeyBindings +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +import torch + +from fastchat.model.model_adapter import add_model_args +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.inference import ChatIO, chat_loop +from fastchat.utils import str_to_torch_dtype + + +class SimpleChatIO(ChatIO): + def __init__(self, multiline: bool = False, prefix: str = ''): + self._multiline = multiline + self.prefix = prefix + + def prompt_for_input(self, role) -> str: + if not self._multiline: + return input(f"{role}: {self.prefix}") + + prompt_data = [] + line = input(f"{role} [ctrl-d/z on empty line to end]: ") + while True: + prompt_data.append(line.strip()) + try: + line = input() + except EOFError as e: + break + return f"\n{self.prefix}".join(prompt_data) + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +class RichChatIO(ChatIO): + bindings = KeyBindings() + + @bindings.add("escape", "enter") + def _(event): + event.app.current_buffer.newline() + + def __init__(self, multiline: bool = False, mouse: bool = False): + self._prompt_session = PromptSession(history=InMemoryHistory()) + self._completer = WordCompleter( + words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"], + pattern=re.compile("$"), + ) + self._console = Console() + self._multiline = multiline + self._mouse = mouse + + def prompt_for_input(self, role) -> str: + self._console.print(f"[bold]{role}:") + # TODO(suquark): multiline input has some issues. fix it later. + prompt_input = self._prompt_session.prompt( + completer=self._completer, + multiline=False, + mouse_support=self._mouse, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=self.bindings if self._multiline else None, + ) + self._console.print() + return prompt_input + + def prompt_for_output(self, role: str): + self._console.print(f"[bold]{role.replace('/', '|')}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # TODO(suquark): the console flickers when there is a code block + # above it. We need to cut off "live" when a code block is done. 
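+        # rich.Live repaints the rendered Markdown in place (refresh capped at
+        # 4 fps below), so the partial response is shown as it streams in.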
+ + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for outputs in output_stream: + if not outputs: + continue + text = outputs["text"] + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return text + + def print_output(self, text: str): + self.stream_output([{"text": text}]) + + +class ProgrammaticChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + contents = "" + # `end_sequence` signals the end of a message. It is unlikely to occur in + # message content. + end_sequence = " __END_OF_A_MESSAGE_47582648__\n" + len_end = len(end_sequence) + while True: + if len(contents) >= len_end: + last_chars = contents[-len_end:] + if last_chars == end_sequence: + break + try: + char = sys.stdin.read(1) + contents = contents + char + except EOFError: + continue + contents = contents[:-len_end] + print(f"[!OP:{role}]: {contents}", flush=True) + return contents + + def prompt_for_output(self, role: str): + print(f"[!OP:{role}]: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +def main(args): + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + os.environ["XPU_VISIBLE_DEVICES"] = args.gpus + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. 
Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + if args.style == "simple": + chatio = SimpleChatIO(args.multiline) + elif args.style == "rich": + chatio = RichChatIO(args.multiline, args.mouse) + elif args.style == "programmatic": + chatio = ProgrammaticChatIO() + else: + raise ValueError(f"Invalid style for console: {args.style}") + try: + if args.upload_file_path: + prefix = open(args.upload_file_path, 'r').read() + args.conv_system_msg = prefix[:20000] + chat_loop( + args.model_path, + args.device, + args.num_gpus, + args.max_gpu_memory, + str_to_torch_dtype(args.dtype), + args.load_8bit, + args.cpu_offloading, + args.conv_template, + args.conv_system_msg, + args.temperature, + args.repetition_penalty, + args.max_new_tokens, + chatio, + gptq_config=GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ), + awq_config=AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ), + exllama_config=exllama_config, + xft_config=xft_config, + revision=args.revision, + judge_sent_end=args.judge_sent_end, + debug=args.debug, + history=not args.no_history, + ) + except KeyboardInterrupt: + print("exit...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--conv-system-msg", type=str, default=None, help="Conversation system message." + ) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--no-history", action="store_true") + parser.add_argument( + "--style", + type=str, + default="simple", + choices=["simple", "rich", "programmatic"], + help="Display style.", + ) + parser.add_argument( + "--multiline", + action="store_true", + help="Enable multiline input. Use ESC+Enter for newline.", + ) + parser.add_argument( + "--mouse", + action="store_true", + help="[Rich Style]: Enable mouse support for cursor positioning.", + ) + parser.add_argument( + "--judge-sent-end", + action="store_true", + help="Whether enable the correction logic that interrupts the output of sentences due to EOS.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print useful debug information (e.g., prompts)", + ) + parser.add_argument( + "--upload-file-path", + type=str, + default="", + help="upload long txt for summary.", + ) + args = parser.parse_args() + main(args) diff --git a/controller.py b/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..a67da62c42d898c17c23c0cc7244770cfeb78746 --- /dev/null +++ b/controller.py @@ -0,0 +1,348 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. 
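+Workers register themselves and send periodic heart beats; stale workers are
+removed after a heart-beat expiration, and requests are dispatched to workers
+by either a lottery or a shortest-queue policy.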
+""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import os +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from fastchat.constants import ( + CONTROLLER_HEART_BEAT_EXPIRATION, + WORKER_API_TIMEOUT, + ErrorCode, + SERVER_ERROR_MSG, +) +from fastchat.utils import build_logger + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stale_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,) + ) + self.heart_beat_thread.start() + + def register_worker( + self, worker_name: str, check_heart_beat: bool, worker_status: dict + ): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + ) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), 
p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stale_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def handle_no_worker(self, params): + logger.info(f"no worker: {params['model']}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_NO_WORKER, + } + return json.dumps(ret).encode() + b"\0" + + def handle_worker_timeout(self, worker_address): + logger.info(f"worker timeout: {worker_address}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT, + } + return json.dumps(ret).encode() + b"\0" + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. 
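+    # The worker_api_* methods below aggregate the status of every registered
+    # worker and proxy generate requests, so an upper-level controller can
+    # treat this controller as if it were a single worker.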
+ def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + model_names = sorted(list(model_names)) + return { + "model_names": model_names, + "speed": speed, + "queue_length": queue_length, + } + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + yield self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + yield self.handle_worker_timeout(worker_addr) + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], data.get("worker_status", None) + ) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +@app.get("/test_connection") +async def worker_api_get_status(request: Request): + return "success" + + +def create_controller(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. 
Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + return args, controller + + +if __name__ == "__main__": + args, controller = create_controller() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/gateway/README.md b/gateway/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b3afaf171bc38b232b68609585244c9e76489da7 --- /dev/null +++ b/gateway/README.md @@ -0,0 +1,57 @@ +# fastchat Nginx Gateway + +## Purpose of the Gateway + +The Nginx gateway serves the following purposes: + +1. Protects Gradio servers by acting as a firewall. +2. Facilitates dynamic mounting and unmounting of Gradio servers. +3. Provides load balancing for Gradio servers. +4. Offers additional security features, such as total connection limit. +5. Reduces attack surface by requiring only a single public port to be exposed for serving. + +## Deployment and Updating of the Gateway + +### Installing Nginx + +On Debian-based distributions (e.g., Ubuntu): + +```bash +sudo apt update +sudo apt install nginx +``` +On Red Hat-based distributions (e.g., CentOS, Fedora): + +```bash +sudo yum install epel-release +sudo yum install nginx +``` + +### Deployment + +Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). + +Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. + +Modify `upstream websocket` to configure Gradio servers behind the gateway. + +Lastly, update Nginx. + + +### HTTPS Deployment with a Public Domain URL + +Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. + +Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. + +If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. + +### Updating + +Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: + +```bash +sudo nginx -t # check `/etc/nginx/nginx.conf` +sudo systemctl reload nginx # restart Nginx service to load the new config +sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). 
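+# `reload` applies the new config without dropping open connections; use
+# `sudo systemctl restart nginx` only if a full restart is required.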
+``` diff --git a/gateway/nginx.conf b/gateway/nginx.conf new file mode 100644 index 0000000000000000000000000000000000000000..b88ca8c50772421fca91f33ff77ef75f4d23ad4d --- /dev/null +++ b/gateway/nginx.conf @@ -0,0 +1,97 @@ +user www-data; +worker_processes auto; +pid /run/nginx.pid; +include /etc/nginx/modules-enabled/*.conf; + +events { + worker_connections 1024; # maximum number of connections that a worker process can handle concurrently + # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle + +} + +http { + ## + # Basic Settings + ## + + sendfile on; # enable sendfile for performance optimization + tcp_nopush on; # enable TCP no-pushing + tcp_nodelay on; # enable TCP no-delay + keepalive_timeout 65; # sets the timeout for keep-alive connections + types_hash_max_size 2048; # maximum size of the types hash table + # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security + + # server_names_hash_bucket_size 64; + # server_name_in_redirect off; + + include /etc/nginx/mime.types; # include MIME types file + default_type application/octet-stream; # default MIME type for unknown file types + + ## + # SSL Settings + ## + + ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use + ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers + + ## + # Logging Settings + ## + + access_log /var/log/nginx/access.log; # path to access log file + error_log /var/log/nginx/error.log; # path to error log file + + ## + # Gzip Settings + ## + gzip on; # enable Gzip compression + + ## + # Virtual Host Configs + ## + + include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory + include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files + + # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ + map $http_upgrade $connection_upgrade { + default upgrade; + '' close; + } + + upstream websocket { + ip_hash; # load balancing by IP to guarantee session persistence + server localhost:7860; # The port should be the gradio web server port + # server localhost:7861; # extra gradio server if more than one + } + + limit_conn_status 429; + limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP + limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server + + server { + listen 443 ssl; # the listening port of our server + ssl_certificate [PATH_TO_SSL_CERT]; + ssl_certificate_key [PATH_TO_PRIVATE_KEY]; + server_name chat.lmsys.org; # replace the url with your own domain url + limit_conn perserver 1024; # connections per server + location / { + proxy_pass http://websocket; # proxy all requests to the defined upstream server + limit_conn perip 5; # connections per IP + proxy_set_header Host $host; # set the Host header for the upstream server + proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header + proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication + } + } + + # the following block routes all HTTP traffic to HTTPS via nginx + server { + listen 80; + 
server_name chat.lmsys.org; + return 301 https://chat.lmsys.org$request_uri; + } + +} diff --git a/gradio_block_arena_anony.py b/gradio_block_arena_anony.py new file mode 100644 index 0000000000000000000000000000000000000000..48e49deef8818595ec8e8104972955d6be889b5f --- /dev/null +++ b/gradio_block_arena_anony.py @@ -0,0 +1,608 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + ip_expiration_dict, + get_ip, +) +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + + +def set_global_vars_anony(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown.update(visible=True), + gr.Markdown.update(visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + if ":" not in model_selectors[0]: + for i in range(15): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + time.sleep(0.2) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). 
ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {get_ip(request)}") + states = [state0, state1] + for i in range(num_sides): + states[i].conv.update_last_message(None) + return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [""] + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (anony). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +SAMPLING_WEIGHTS = { + # tier 0 + "gpt-4": 4, + "gpt-4-turbo": 4, + "gpt-3.5-turbo": 2, + "gpt-3.5-turbo-1106": 2, + "claude-2": 8, + "claude-1": 2, + "claude-instant-1": 8, + "zephyr-7b-beta": 2, + "openchat-3.5": 2, + # tier 1 + "deluxe-chat-v1.1": 2, + "palm-2": 1.5, + "llama-2-70b-chat": 1.5, + "llama-2-13b-chat": 1.5, + "codellama-34b-instruct": 1.5, + "vicuna-33b": 8, + "vicuna-13b": 1.5, + "wizardlm-70b": 1.5, + "wizardlm-13b": 1.5, + "qwen-14b-chat": 1.5, + "mistral-7b-instruct": 1.5, + # tier 2 + "vicuna-7b": 1.0, + "llama-2-7b-chat": 1.0, + "chatglm2-6b": 1.0, + # deprecated + "zephyr-7b-alpha": 1.5, + "codellama-13b-instruct": 1.0, + "mpt-30b-chat": 1.5, + "guanaco-33b": 1.0, + "fastchat-t5-3b": 0.5, + "alpaca-13b": 0.5, + "mpt-7b-chat": 0.1, + "oasst-pythia-12b": 0.1, + "RWKV-4-Raven-14B": 0.1, + "gpt4all-13b-snoozy": 0.1, + "koala-13b": 0.1, + "stablelm-tuned-alpha-7b": 0.1, + "dolly-v2-12b": 0.1, + "llama-13b": 0.1, + "chatglm-6b": 0.5, + "deluxe-chat-v1": 4, +} + +# target model sampling weights will be boosted. +BATTLE_TARGETS = { + "gpt-4": {"claude-2"}, + "gpt-4-turbo": {"gpt-4", "gpt-3.5-turbo"}, + "gpt-3.5-turbo": {"claude-instant-1", "gpt-4", "claude-2"}, + "claude-2": {"gpt-4", "gpt-3.5-turbo", "claude-1"}, + "claude-1": {"claude-2", "gpt-4", "gpt-3.5-turbo"}, + "claude-instant-1": {"gpt-3.5-turbo", "claude-2"}, + "deluxe-chat-v1.1": {"gpt-4"}, + "openchat-3.5": {"gpt-3.5-turbo", "llama-2-70b-chat", "zephyr-7b-beta"}, + "qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"}, + "zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"}, + "zephyr-7b-beta": { + "mistral-7b-instruct", + "llama-2-13b-chat", + "llama-2-7b-chat", + "wizardlm-13b", + }, + "llama-2-70b-chat": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"}, + "llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"}, + "llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"}, + "mistral-7b-instruct": { + "llama-2-7b-chat", + "llama-2-13b-chat", + "llama-2-70b-chat", + }, + "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo", "claude-instant-1"}, + "vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"}, + "vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"}, + "wizardlm-70b": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"}, + "palm-2": {"llama-2-13b-chat", "gpt-3.5-turbo"}, +} + +SAMPLING_BOOST_MODELS = ["openchat-3.5", "gpt-4-turbo", "gpt-3.5-turbo-1106"] + +# outage models won't be sampled. 
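+# Add a model name here to temporarily exclude it from battles (e.g. during an
+# API outage); get_sample_weight() returns 0 for any listed model.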
+OUTAGE_MODELS = [] + + +def get_sample_weight(model): + if model in OUTAGE_MODELS: + return 0 + weight = SAMPLING_WEIGHTS.get(model, 1.0) + if model in SAMPLING_BOOST_MODELS: + weight *= 5 + return weight + + +def get_battle_pair(): + if len(models) == 1: + return models[0], models[0] + + model_weights = [] + for model in models: + weight = get_sample_weight(model) + model_weights.append(weight) + total_weight = np.sum(model_weights) + model_weights = model_weights / total_weight + chosen_idx = np.random.choice(len(models), p=model_weights) + chosen_model = models[chosen_idx] + + rival_models = [] + rival_weights = [] + for model in models: + if model == chosen_model: + continue + weight = get_sample_weight(model) + if ( + weight != 0 + and chosen_model in BATTLE_TARGETS + and model in BATTLE_TARGETS[chosen_model] + ): + # boost to 50% chance + weight = total_weight / len(BATTLE_TARGETS[chosen_model]) + rival_models.append(model) + rival_weights.append(weight) + # for p, w in zip(rival_models, rival_weights): + # print(p, w) + rival_weights = rival_weights / np.sum(rival_weights) + rival_idx = np.random.choice(len(rival_models), p=rival_weights) + rival_model = rival_models[rival_idx] + + swap = np.random.randint(2) + if swap == 0: + return chosen_model, rival_model + else: + return rival_model, chosen_model + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + if states[0] is None: + assert states[1] is None + + model_left, model_right = get_battle_pair() + states = [ + State(model_left), + State(model_right), + ] + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + model_list = [states[i].model_name for i in range(num_sides)] + flagged = moderation_filter(text, model_list) + if flagged: + logger.info(f"violate moderation (anony). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + slow_model_msg = "" + for i in range(num_sides): + if "deluxe" in states[i].model_name: + slow_model_msg = SLOW_MODEL_MSG + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + + [slow_model_msg] + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (anony). 
ip: {get_ip(request)}") + + if state0 is None or state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + chatbots = [None] * num_sides + while True: + stop = True + for i in range(num_sides): + try: + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def build_side_by_side_ui_anony(models): + notice_markdown = """ +# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one! +- You can continue chatting until you identify a winner. +- Vote won't be counted if model identity is revealed during conversation. + +## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) +We use **100K** human votes to compile an Elo-based LLM leaderboard. +Find out who is the 🥇LLM Champion! + +## 👇 Chat now! + +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Box(elem_id="share-region-anony"): + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, elem_id=f"chatbot", height=550 + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown(anony_names[i]) + with gr.Row(): + slow_warning = gr.Markdown("", elem_id="notice_markdown") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + container=False, + elem_id="input_box", + ) + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + 
gr.Markdown(acknowledgment_md) + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], _js=share_js) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list + [slow_warning], + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/gradio_block_arena_named.py b/gradio_block_arena_named.py new file mode 100644 index 0000000000000000000000000000000000000000..c13283495aec8262de92d47fbd758e47204172d0 --- /dev/null +++ b/gradio_block_arena_named.py @@ -0,0 +1,458 @@ +""" +Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. 
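+Unlike the anonymous battle tab, the model identities are selected explicitly
+from dropdowns and remain visible while voting.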
+""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_model_description_md, + ip_expiration_dict, + get_ip, +) +from fastchat.utils import ( + build_logger, + moderation_filter, +) + + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False + + +def set_global_vars_named(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_named(models, url_params): + states = (None,) * num_sides + + model_left = models[0] if len(models) > 0 else "" + if len(models) > 1: + weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1] + weights = weights / np.sum(weights) + model_right = np.random.choice(models[1:], p=weights) + else: + model_right = model_left + + selector_updates = ( + gr.Dropdown.update(choices=models, value=model_left, visible=True), + gr.Dropdown.update(choices=models, value=model_right, visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {get_ip(request)}") + states = [state0, state1] + for i in range(num_sides): + states[i].conv.update_last_message(None) + return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). 
ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (named). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i]) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + flagged = moderation_filter(text, model_list) + if flagged: + logger.info(f"violate moderation (named). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (named). ip: {get_ip(request)}") + + if state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + chatbots = [None] * num_sides + while True: + stop = True + for i in range(num_sides): + try: + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def flash_buttons(): + btn_updates = [ + [disable_btn] * 4 + [enable_btn] * 2, + [enable_btn] * 6, + ] + for i in range(4): + yield btn_updates[i % 2] + time.sleep(0.5) + + +def build_side_by_side_ui_named(models): + notice_markdown = """ +# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Chat with any two models side-by-side and vote! +- You can continue chatting for multiple rounds. 
+- Click "Clear history" to start a new round. + +## 🤖 Choose two models to compare +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + model_description_md = get_model_description_md(models) + notice = gr.Markdown( + notice_markdown + model_description_md, elem_id="notice_markdown" + ) + + with gr.Box(elem_id="share-region-named"): + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, elem_id=f"chatbot", height=550 + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter your prompt here and press ENTER", + container=False, + elem_id="input_box", + ) + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row() as button_row: + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md) + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a 
= document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], _js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/gradio_web_server.py b/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..adab6e9063f6b60b8f446cc6df009a643fc910ad --- /dev/null +++ b/gradio_web_server.py @@ -0,0 +1,883 @@ +""" +The gradio demo server for chatting with a single model. +""" + +import argparse +from collections import defaultdict +import datetime +import json +import os +import random +import time +import uuid + +import gradio as gr +import requests + +from fastchat.conversation import SeparatorStyle +from fastchat.constants import ( + LOGDIR, + WORKER_API_TIMEOUT, + ErrorCode, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SERVER_ERROR_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SESSION_EXPIRATION_TIME, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.conversation import get_conv_template +from fastchat.model.model_registry import get_model_info, model_info +from fastchat.serve.api_provider import ( + anthropic_api_stream_iter, + openai_api_stream_iter, + palm_api_stream_iter, + init_palm_chat, +) +from fastchat.utils import ( + build_logger, + moderation_filter, + get_window_url_params_js, + get_window_url_params_with_tos_js, + parse_gradio_auth_creds, +) + +CONV_TEMPLATE = '' + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "FastChat Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True, visible=True) +disable_btn = gr.Button.update(interactive=False) +invisible_btn = gr.Button.update(interactive=False, visible=False) + +controller_url = None +enable_moderation = False + +acknowledgment_md = """ +### Acknowledgment +
+<div class="image-container">
+    <p> We thank Kaggle, MBZUAI, AnyScale, and HuggingFace for their sponsorship. </p>
+    <!-- sponsor logo images: Image 1, Image 2, Image 3, Image 4 -->
+</div>
+""" + +ip_expiration_dict = defaultdict(lambda: 0) + +# Information about custom OpenAI compatible API models. +# JSON file format: +# { +# "vicuna-7b": { +# "model_name": "vicuna-7b-v1.5", +# "api_base": "http://8.8.8.55:5555/v1", +# "api_key": "password" +# }, +# } +openai_compatible_models_info = {} + + +class State: + def __init__(self, model_name): + # if model_name=='checkpoint-800': + # self.conv = get_conv_template(CONV_TEMPLATE) + # elif model_name=='MiniCPM-2B-sft-bf16': + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + conv_name = requests.post( + worker_addr + "/worker_get_conv_template", + ).json()['conv']['name'] + self.conv = get_conv_template(conv_name) + # self.conv = get_conv_template('minicpm') + # print(self.conv) + # self.conv = get_conversation_template(model_name) + self.conv_id = uuid.uuid4().hex + self.skip_next = False + self.model_name = model_name + + if model_name == "palm-2": + # According to release note, "chat-bison@001" is PaLM 2 for chat. + # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023 + self.palm_chat = init_palm_chat("chat-bison@001") + + def to_gradio_chatbot(self): + return self.conv.to_gradio_chatbot() + + def dict(self): + base = self.conv.dict() + base.update( + { + "conv_id": self.conv_id, + "model_name": self.model_name, + } + ) + return base + + +def set_global_vars(controller_url_, enable_moderation_): + global controller_url, enable_moderation + controller_url = controller_url_ + enable_moderation = enable_moderation_ + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list( + controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm +): + if controller_url: + ret = requests.post(controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(controller_url + "/list_models") + # ret = requests.post(controller_url + "/get_worker_address") + # ret = requests.post(controller_url + "/worker_get_status") + models = ret.json()["models"] + else: + models = [] + + # Add API providers + if register_openai_compatible_models: + global openai_compatible_models_info + openai_compatible_models_info = json.load( + open(register_openai_compatible_models) + ) + models += list(openai_compatible_models_info.keys()) + + if add_chatgpt: + models += ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"] + if add_claude: + models += ["claude-2", "claude-instant-1"] + if add_palm: + models += ["palm-2"] + models = list(set(models)) + + if "deluxe-chat-v1" in models: + del models[models.index("deluxe-chat-v1")] + if "deluxe-chat-v1.1" in models: + del models[models.index("deluxe-chat-v1.1")] + + priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)} + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + + +def load_demo_single(models, url_params): + selected_model = models[0] if len(models) > 0 else "" + if "model" in url_params: + model = url_params["model"] + if model in models: + selected_model = model + + dropdown_update = gr.Dropdown.update( + choices=models, value=selected_model, visible=True + ) + + state = None + return state, dropdown_update + + +def load_demo(url_params, request: gr.Request): + global models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. 
params: {url_params}") + ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME + + if args.model_list_mode == "reload": + models = get_model_list( + controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + return load_demo_single(models, url_params) + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open('./web_chat_downvote.jsonl', "a+") as fout: + # data = { + # "tstamp": round(time.time(), 4), + # "type": vote_type, + # "model": model_selector, + # "state": state.dict(), + # "ip": get_ip(request), + # } + conversations = [] + for i, turn in enumerate(state.dict()['messages']): + role = 'user' if i % 2 == 0 else 'assistant' + conversations.append({'role': role, 'content': turn[1]}) + data = { + 'conversations': conversations, + 'idx': state.dict()['conv_id'], + 'tinder': 'badcase', + 'model': state.dict()['model_name'], + 'tokens_in': -1, + 'tokens_out': -1, + } + fout.write(json.dumps(data, ensure_ascii=False) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"upvote. ip: {ip}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"downvote. ip: {ip}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"flag. ip: {ip}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + ip = get_ip(request) + logger.info(f"regenerate. ip: {ip}") + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history. ip: {ip}") + state = None + return (state, [], "") + (disable_btn,) * 5 + + +def get_ip(request: gr.Request): + if "cf-connecting-ip" in request.headers: + ip = request.headers["cf-connecting-ip"] + else: + ip = request.client.host + return ip + + +def add_text(state, model_selector, text, request: gr.Request): + ip = get_ip(request) + logger.info(f"add_text. ip: {ip}. len: {len(text)}") + + if state is None: + state = State(model_selector) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 + + flagged = moderation_filter(text, [state.model_name]) + if flagged: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = state.conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + conv.append_message(conv.roles[0], text) + conv.append_message(conv.roles[1], None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def post_process_code(code): + sep = "\n```" + if sep in code: + blocks = code.split(sep) + if len(blocks) % 2 == 1: + for i in range(1, len(blocks), 2): + blocks[i] = blocks[i].replace("\\_", "_") + code = sep.join(blocks) + return code + + +def model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, +): + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + logger.info(f"==== request ====\n{gen_params}") + + # Stream output + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + yield data + + +def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request): + ip = get_ip(request) + logger.info(f"bot_response. ip: {ip}") + start_tstamp = time.time() + temperature = float(temperature) + top_p = float(top_p) + max_new_tokens = int(max_new_tokens) + + if state.skip_next: + # This generate call is skipped due to invalid inputs + state.skip_next = False + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + conv, model_name = state.conv, state.model_name + if model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]: + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_name in ["claude-2", "claude-1", "claude-instant-1"]: + prompt = conv.get_prompt() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_name == "palm-2": + stream_iter = palm_api_stream_iter( + state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens + ) + elif model_name in openai_compatible_models_info: + model_info = openai_compatible_models_info[model_name] + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_info["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_info["api_base"], + api_key=model_info["api_key"], + ) + else: + # Query worker address + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + conv.update_last_message(SERVER_ERROR_MSG) + yield ( + state, + state.to_gradio_chatbot(), + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + # Construct prompt. + # We need to call it here, so it will not be affected by "▌". 
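+        # ("▌" is the streaming-cursor marker appended to the last assistant
+        # message further below; building the prompt first keeps that marker
+        # out of the text sent to the worker.)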
+ prompt = conv.get_prompt() + # Set repetition_penalty + if "t5" in model_name: + repetition_penalty = 1.2 + else: + repetition_penalty = 1.0 + + stream_iter = model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + ) + + conv.update_last_message("▌") + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + for i, data in enumerate(stream_iter): + if data["error_code"] == 0: + output = data["text"].strip() + conv.update_last_message(output + "▌") + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f"\n\n(error_code: {data['error_code']})" + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + output = data["text"].strip() + if "vicuna" in model_name: + output = post_process_code(output) + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + except requests.exceptions.RequestException as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + except Exception as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + +block_css = """ +#notice_markdown { + font-size: 110% +} +#notice_markdown th { + display: none; +} +#notice_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#leaderboard_markdown { + font-size: 110% +} +#leaderboard_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#leaderboard_dataframe td { + line-height: 0.1em; +} +#about_markdown { + font-size: 110% +} +#input_box textarea { +} +footer { + display:none !important +} +.image-container { + display: flex; + align-items: center; + padding: 1px; +} +.image-container img { + margin: 0 30px; + height: 20px; + max-height: 100%; + width: auto; + max-width: 20%; +} +.image-about img { + margin: 0 30px; + margin-top: 30px; + height: 60px; + max-height: 100%; + width: auto; + float: left; +} +""" + + +def get_model_description_md(models): + model_description_md = """ +| | | | +| ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(models): + minfo = get_model_info(name) + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + + if ct % 3 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 3 == 2: + model_description_md += "\n" + ct += 1 + return model_description_md + + +def build_about(): + about_markdown = f""" +# About Us +Chatbot Arena is an open-source research project 
developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our code at [GitHub](https://github.com/lm-sys/FastChat) and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey! + +## Read More +- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/) +- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998) + +## Core Members +[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ) + +## Advisors +[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/) + +## Contact Us +- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at lmsys.org@gmail.com +- File issues on [GitHub](https://github.com/lm-sys/FastChat) +- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys) + +## Sponsors +We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. +Learn more about partnership [here](https://lmsys.org/donations/). + +
+    <!-- sponsor logo images: Image 1, Image 2, Image 3, Image 4 -->
+""" + + # state = gr.State() + gr.Markdown(about_markdown, elem_id="about_markdown") + + # return [state] + + +def build_single_model_ui(models, add_promotion_links=False): + promotion = ( + """ +- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | +- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/) +- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/) +""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Open Large Language Models +{promotion} + +## 👉 Choose any model to chat +""" + + state = gr.State() + model_description_md = get_model_description_md(models) + gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown") + + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + + chatbot = gr.Chatbot( + elem_id="chatbot", + label="Scroll down and start chatting", + height=550, + ) + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter your prompt here and press ENTER", + container=False, + elem_id="input_box", + ) + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row() as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=3072, + value=2048, + step=1, + interactive=True, + label="Max output tokens", + ) + + if add_promotion_links: + gr.Markdown(acknowledgment_md) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + send_btn.click( + add_text, + [state, model_selector, 
textbox], + [state, chatbot, textbox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + return [state, model_selector] + + +def build_demo(models): + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(), + css=block_css, + ) as demo: + url_params = gr.JSON(visible=False) + + state, model_selector = build_single_model_ui(models) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + demo.load( + load_demo, + [url_params], + [ + state, + model_selector, + ], + _js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--conv-template", + type=str, + default="megrez", + help="The address of the controller", + ) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--add-chatgpt", + action="store_true", + help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", + ) + parser.add_argument( + "--add-claude", + action="store_true", + help="Add Anthropic's Claude models (claude-2, claude-instant-1)", + ) + parser.add_argument( + "--add-palm", + action="store_true", + help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", + ) + parser.add_argument( + "--register-openai-compatible-models", + type=str, + help="Register custom OpenAI API compatible models by loading them from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. 
The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + ) + args = parser.parse_args() + logger.info(f"args: {args}") + CONV_TEMPLATE = args.conv_template + # Set global variables + set_global_vars(args.controller_url, args.moderate) + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo(models) + ret = demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + ) + from IPython import embed;embed() \ No newline at end of file diff --git a/gradio_web_server_multi.py b/gradio_web_server_multi.py new file mode 100644 index 0000000000000000000000000000000000000000..b918f9d6b65c91a2d0a8d1488e9290c207c39b30 --- /dev/null +++ b/gradio_web_server_multi.py @@ -0,0 +1,270 @@ +""" +The gradio demo server with multiple tabs. +It supports chatting with a single model or chatting with two models side-by-side. +""" + +import argparse +import pickle +import time + +import gradio as gr + +from fastchat.constants import ( + SESSION_EXPIRATION_TIME, +) +from fastchat.serve.gradio_block_arena_anony import ( + build_side_by_side_ui_anony, + load_demo_side_by_side_anony, + set_global_vars_anony, +) +from fastchat.serve.gradio_block_arena_named import ( + build_side_by_side_ui_named, + load_demo_side_by_side_named, + set_global_vars_named, +) +from fastchat.serve.gradio_web_server import ( + set_global_vars, + block_css, + build_single_model_ui, + build_about, + get_model_list, + load_demo_single, + ip_expiration_dict, + get_ip, +) +from fastchat.serve.monitor.monitor import build_leaderboard_tab +from fastchat.utils import ( + build_logger, + get_window_url_params_js, + get_window_url_params_with_tos_js, + parse_gradio_auth_creds, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + + +def load_demo(url_params, request: gr.Request): + global models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. params: {url_params}") + ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME + + selected = 0 + if "arena" in url_params: + selected = 0 + elif "compare" in url_params: + selected = 1 + elif "single" in url_params: + selected = 2 + elif "leaderboard" in url_params: + selected = 3 + + if args.model_list_mode == "reload": + if args.anony_only_for_proprietary_model: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + False, + False, + False, + ) + else: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + single_updates = load_demo_single(models, url_params) + + models_anony = list(models) + if args.anony_only_for_proprietary_model: + # Only enable these models in anony battles. 
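+        # (these API models are appended to models_anony only, so the named
+        # side-by-side and direct-chat tabs keep showing just the models
+        # fetched above, without the proprietary endpoints)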
+ if args.add_chatgpt: + models_anony += [ + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-3.5-turbo-1106", + ] + if args.add_claude: + models_anony += ["claude-2", "claude-1", "claude-instant-1"] + if args.add_palm: + models_anony += ["palm-2"] + models_anony = list(set(models_anony)) + + side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params) + side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + return ( + (gr.Tabs.update(selected=selected),) + + single_updates + + side_by_side_anony_updates + + side_by_side_named_updates + ) + + +def build_demo(models, elo_results_file, leaderboard_table_file): + text_size = gr.themes.sizes.text_md + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(text_size=text_size), + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Arena (battle)", id=0): + side_by_side_anony_list = build_side_by_side_ui_anony(models) + + with gr.Tab("Arena (side-by-side)", id=1): + side_by_side_named_list = build_side_by_side_ui_named(models) + + with gr.Tab("Direct Chat", id=2): + single_model_list = build_single_model_ui( + models, add_promotion_links=True + ) + if elo_results_file: + with gr.Tab("Leaderboard", id=3): + build_leaderboard_tab(elo_results_file, leaderboard_table_file) + with gr.Tab("About Us", id=4): + about = build_about() + + url_params = gr.JSON(visible=False) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + demo.load( + load_demo, + [url_params], + [tabs] + + single_model_list + + side_by_side_anony_list + + side_by_side_named_list, + _js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time.", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--add-chatgpt", + action="store_true", + help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", + ) + parser.add_argument( + "--add-claude", + action="store_true", + help="Add Anthropic's Claude models (claude-2, claude-instant-1)", + ) + parser.add_argument( + "--add-palm", + action="store_true", + help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", + ) + parser.add_argument( + "--anony-only-for-proprietary-model", + action="store_true", + help="Only add ChatGPT, Claude, Bard under anony battle tab", + ) + parser.add_argument( + "--register-openai-compatible-models", + type=str, + help="Register custom OpenAI API compatible models by loading them from 
a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + default=None, + ) + parser.add_argument( + "--elo-results-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--leaderboard-table-file", type=str, help="Load leaderboard results and plots" + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate) + set_global_vars_named(args.moderate) + set_global_vars_anony(args.moderate) + if args.anony_only_for_proprietary_model: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + False, + False, + False, + ) + else: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo(models, args.elo_results_file, args.leaderboard_table_file) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + ) diff --git a/huggingface_api.py b/huggingface_api.py new file mode 100644 index 0000000000000000000000000000000000000000..2a49bf5f175995d20c339906a84c71b97a49fae6 --- /dev/null +++ b/huggingface_api.py @@ -0,0 +1,73 @@ +""" +Use FastChat with Hugging Face generation APIs. + +Usage: +python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 +""" +import argparse + +import torch + +from fastchat.model import load_model, get_conversation_template, add_model_args + + +@torch.inference_mode() +def main(args): + # Load model + model, tokenizer = load_model( + args.model_path, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + revision=args.revision, + debug=args.debug, + ) + + # Build the prompt with a conversation template + msg = args.message + conv = get_conversation_template(args.model_path) + conv.append_message(conv.roles[0], msg) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + # Run inference + inputs = tokenizer([prompt], return_tensors="pt").to(args.device) + output_ids = model.generate( + **inputs, + do_sample=True if args.temperature > 1e-5 else False, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + max_new_tokens=args.max_new_tokens, + ) + + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(inputs["input_ids"][0]) :] + outputs = tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + + # Print results + print(f"{conv.roles[0]}: {msg}") + print(f"{conv.roles[1]}: {outputs}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--debug", 
action="store_true") + parser.add_argument("--message", type=str, default="Hello! Who are you?") + args = parser.parse_args() + + # Reset default repetition penalty for T5 models. + if "t5" in args.model_path and args.repetition_penalty == 1.0: + args.repetition_penalty = 1.2 + + main(args) diff --git a/huggingface_api_worker.py b/huggingface_api_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..7eef50e472ea739c17fac7ff586f775f8a9d1f4f --- /dev/null +++ b/huggingface_api_worker.py @@ -0,0 +1,391 @@ +""" +A model worker that calls huggingface inference endpoint. + +Register models in a JSON file with the following format: +{ + "falcon-180b-chat": { + "model_path": "tiiuae/falcon-180B-chat", + "api_base": "https://api-inference.huggingface.co/models", + "token": "hf_xxx", + "context_length": 2048, + "model_names": "falcon-180b-chat", + "conv_template": null + } +} + +"model_path", "api_base", "token", and "context_length" are necessary, while others are optional. +""" +import argparse +import asyncio +import json +import uuid +from typing import List, Optional + +import requests +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse +from huggingface_hub import InferenceClient + +from fastchat.constants import SERVER_ERROR_MSG, ErrorCode +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.utils import build_logger + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + +workers = [] +worker_map = {} +app = FastAPI() + + +# reference to +# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392 +def get_gen_kwargs( + params, + seed: Optional[int] = None, +): + stop = params.get("stop", None) + if isinstance(stop, list): + stop_sequences = stop + elif isinstance(stop, str): + stop_sequences = [stop] + else: + stop_sequences = [] + gen_kwargs = { + "do_sample": True, + "return_full_text": bool(params.get("echo", False)), + "max_new_tokens": int(params.get("max_new_tokens", 256)), + "top_p": float(params.get("top_p", 1.0)), + "temperature": float(params.get("temperature", 1.0)), + "stop_sequences": stop_sequences, + "repetition_penalty": float(params.get("repetition_penalty", 1.0)), + "top_k": params.get("top_k", None), + "seed": seed, + } + if gen_kwargs["top_p"] == 1: + gen_kwargs["top_p"] = 0.9999999 + if gen_kwargs["top_p"] == 0: + gen_kwargs.pop("top_p") + if gen_kwargs["temperature"] == 0: + gen_kwargs.pop("temperature") + gen_kwargs["do_sample"] = False + return gen_kwargs + + +def could_be_stop(text, stop): + for s in stop: + if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)): + return True + return False + + +class HuggingfaceApiWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + api_base: str, + token: str, + context_length: int, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: Optional[str] = None, + seed: Optional[int] = None, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + self.model_path = model_path + self.api_base = api_base + self.token = token + self.context_len = context_length + self.seed = seed + + logger.info( + f"Connecting with huggingface api 
{self.model_path} as {self.model_names} on worker {worker_id} ..." + ) + + if not no_register: + self.init_heart_beat() + + def count_token(self, params): + # No tokenizer here + ret = { + "count": 0, + "error_code": 0, + } + return ret + + def generate_stream_gate(self, params): + self.call_ct += 1 + + prompt = params["prompt"] + gen_kwargs = get_gen_kwargs(params, seed=self.seed) + stop = gen_kwargs["stop_sequences"] + if "falcon" in self.model_path and "chat" in self.model_path: + stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"]) + stop = list(set(stop)) + gen_kwargs["stop_sequences"] = stop + + logger.info(f"prompt: {prompt}") + logger.info(f"gen_kwargs: {gen_kwargs}") + + try: + if self.model_path == "": + url = f"{self.api_base}" + else: + url = f"{self.api_base}/{self.model_path}" + client = InferenceClient(url, token=self.token) + res = client.text_generation( + prompt, stream=True, details=True, **gen_kwargs + ) + + reason = None + text = "" + for chunk in res: + if chunk.token.special: + continue + text += chunk.token.text + + s = next((x for x in stop if text.endswith(x)), None) + if s is not None: + text = text[: -len(s)] + reason = "stop" + break + if could_be_stop(text, stop): + continue + if ( + chunk.details is not None + and chunk.details.finish_reason is not None + ): + reason = chunk.details.finish_reason + if reason not in ["stop", "length"]: + reason = None + ret = { + "text": text, + "error_code": 0, + "finish_reason": reason, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def get_embeddings(self, params): + raise NotImplementedError() + + +def release_worker_semaphore(worker): + worker.semaphore.release() + + +def acquire_worker_semaphore(worker): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(worker): + background_tasks = BackgroundTasks() + background_tasks.add_task(lambda: release_worker_semaphore(worker)) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks(worker) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + output = worker.generate_gate(params) + release_worker_semaphore(worker) + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + embedding = worker.get_embeddings(params) + release_worker_semaphore(worker) + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + 
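+# The endpoints below look up the owning worker in worker_map via the "model"
+# field of the request and return per-model metadata from that worker.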
+@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_huggingface_api_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + # all model-related parameters are listed in --model-info-file + parser.add_argument( + "--model-info-file", + type=str, + required=True, + help="Huggingface API model's info file path", + ) + + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + args = parser.parse_args() + + with open(args.model_info_file, "r", encoding="UTF-8") as f: + model_info = json.load(f) + + logger.info(f"args: {args}") + + model_path_list = [] + api_base_list = [] + token_list = [] + context_length_list = [] + model_names_list = [] + conv_template_list = [] + + for m in model_info: + model_path_list.append(model_info[m]["model_path"]) + api_base_list.append(model_info[m]["api_base"]) + token_list.append(model_info[m]["token"]) + + context_length = model_info[m]["context_length"] + model_names = model_info[m].get("model_names", [m.split("/")[-1]]) + if isinstance(model_names, str): + model_names = [model_names] + conv_template = model_info[m].get("conv_template", None) + + context_length_list.append(context_length) + model_names_list.append(model_names) + conv_template_list.append(conv_template) + + logger.info(f"Model paths: {model_path_list}") + logger.info(f"API bases: {api_base_list}") + logger.info(f"Tokens: {token_list}") + logger.info(f"Context lengths: {context_length_list}") + logger.info(f"Model names: {model_names_list}") + logger.info(f"Conv templates: {conv_template_list}") + + for ( + model_names, + conv_template, + model_path, + api_base, + token, + context_length, + ) in zip( + model_names_list, + conv_template_list, + model_path_list, + api_base_list, + token_list, + context_length_list, + ): + m = HuggingfaceApiWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + api_base, + token, + context_length, + model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + conv_template=conv_template, + seed=args.seed, + ) + workers.append(m) + for name in model_names: + worker_map[name] = m + + # register all the models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = 
requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_huggingface_api_worker() + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..888a13d00c0f8a723cf417035dcf9d0ab65e0076 --- /dev/null +++ b/inference.py @@ -0,0 +1,596 @@ +"""Inference for FastChat models.""" +import abc +import gc +import json +import math +import os +import sys +import time +from typing import Iterable, Optional, Dict +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + if hasattr(model, "device"): + device = model.device + + # Read parameters + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. 
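+    # (if logprobs is set, per-token log-probabilities are collected into
+    # token_logprobs below and returned in the "logprobs" field of each yield)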
+ echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + if params.get('none_stop'): + stop_token_ids = [] + skip_special_tokens = params.get('skip_special_tokens') + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + input_ids = tokenizer(prompt).input_ids + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. + raise NotImplementedError + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values = out = None + token_logprobs = [None] # The first token has no logprobs. + sent_interrupt = False + finish_reason = None + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(input_ids=start_ids, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. + shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. 
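+            # (token selection below, torch.topk for greedy decoding or
+            # torch.multinomial for sampling, then runs on this CPU tensor)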
+ last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. + token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs + if echo + else token_logprobs[input_echo_len:], + "top_logprobs": [{}] + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # Finish stream event, which contains finish reason + else: + finish_reason = "length" + + if stopped: + finish_reason = "stop" + + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def 
prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + @abc.abstractmethod + def print_output(self, text: str): + """Print output.""" + + +def convert_message_format(message): + formated_message = [] + for i, turn in enumerate(message): + role = 'user' if i % 2 == 0 else 'assistant' + formated_message.append({'role': role, 'content': turn[1]}) + + data = { + 'conversations': formated_message, + 'idx': -1, + 'tinder': 'badcase', + 'model': '', + 'tokens_in': 0, + 'tokens_out': 0, + } + + return data + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, +): + # Model + model, tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + revision=revision, + debug=debug, + ) + generate_stream_func = get_generate_stream_function(model, model_path) + + model_type = str(type(model)).lower() + is_t5 = "t5" in model_type + is_codet5p = "codet5p" in model_type + is_xft = "xft" in model_type + + # Hardcode T5's default repetition penalty to be 1.2 + if is_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Set context length + context_len = get_context_length(model.config) + + # Chat + def new_chat(): + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + if conv_system_msg is not None: + conv.set_system_message(conv_system_msg) + return conv + + def reload_conv(conv): + """ + Reprints the conversation from the start. 
+ """ + for message in conv.messages[conv.offset :]: + chatio.prompt_for_output(message[0]) + chatio.print_output(message[1]) + + conv = None + + while True: + if not history or not conv: + conv = new_chat() + + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + + if inp == "!!exit":# or not inp: + print("exit...") + break + elif inp == "!!reset": + print("resetting...") + conv = new_chat() + continue + elif inp == "!!remove": + print("removing last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + reload_conv(conv) + else: + print("No messages to remove.") + continue + elif inp == "!!regen": + print("regenerating last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + reload_conv(conv) + # Set inp to previous message + inp = conv.messages.pop()[1] + else: + # Shouldn't happen in normal circumstances + print("No user message to regenerate from.") + continue + else: + print("No messages to regenerate.") + continue + elif inp.startswith("!!save"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." in filename: + filename += ".json" + + print("saving...", filename) + with open(filename, "w", encoding="utf-8") as outfile: + json.dump(conv.dict(), outfile, ensure_ascii=False) + continue + elif inp.startswith("!!badcase"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." in filename: + filename += ".jsonl" + + print("saving...", filename) + with open(filename, "a+", encoding="utf-8") as outfile: + data = convert_message_format(conv.messages) + json.dump(data, outfile, ensure_ascii=False) + outfile.write('\n') + continue + elif inp.startswith("!!load"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!load ") + continue + else: + filename = args[1] + + # Check if file exists and add .json if needed + if not os.path.exists(filename): + if (not filename.endswith(".json")) and os.path.exists( + filename + ".json" + ): + filename += ".json" + else: + print("file not found:", filename) + continue + + print("loading...", filename) + with open(filename, "r") as infile: + new_conv = json.load(infile) + + conv = get_conv_template(new_conv["template_name"]) + conv.set_system_message(new_conv["system_message"]) + conv.messages = new_conv["messages"] + reload_conv(conv) + continue + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt(tokenizer) + + if is_codet5p: # codet5p is a code completion model. 
+ prompt = inp + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "none_stop": conv.none_stop, + "skip_special_tokens": conv.skip_special_tokens, + "echo": False, + } + + try: + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func( + model, + tokenizer, + gen_params, + device, + context_len=context_len, + judge_sent_end=judge_sent_end, + ) + t = time.time() + outputs = chatio.stream_output(output_stream) + duration = time.time() - t + conv.update_last_message(outputs.strip()) + + if debug: + num_tokens = len(tokenizer.encode(outputs)) + msg = { + "conv_template": conv.name, + "prompt": prompt, + "outputs": outputs, + "speed (token/s)": round(num_tokens / duration, 2), + } + print(f"\n{msg}\n") + + except KeyboardInterrupt: + print("stopped generation.") + # If generation didn't finish + if conv.messages[-1][1] is None: + conv.messages.pop() + # Remove last user message, so there isn't a double up + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + + reload_conv(conv) diff --git a/launch_all_serve.py b/launch_all_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4ad7b0b134d1699ff8ba0d95d8039ec3c1f204 --- /dev/null +++ b/launch_all_serve.py @@ -0,0 +1,284 @@ +""" +Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022" + +Workers are listed in format of `model-path`@`host`@`port` + +The key mechanism behind this scripts is: + 1, execute shell cmd to launch the controller/worker/openai-api-server; + 2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly. +Note that a few of non-critical `fastchat.serve` cmd options are not supported currently. +""" +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import subprocess +import re +import argparse + +LOGDIR = "./logs/" + +if not os.path.exists(LOGDIR): + os.makedirs(LOGDIR) + +parser = argparse.ArgumentParser() +# ------multi worker----------------- +parser.add_argument( + "--model-path-address", + default="THUDM/chatglm2-6b@localhost@20002", + nargs="+", + type=str, + help="model path, host, and port, formatted as model-path@host@port", +) +# ---------------controller------------------------- + +parser.add_argument("--controller-host", type=str, default="localhost") +parser.add_argument("--controller-port", type=int, default=21001) +parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", +) +controller_args = ["controller-host", "controller-port", "dispatch-method"] + +# ----------------------worker------------------------------------------ + +parser.add_argument("--worker-host", type=str, default="localhost") +parser.add_argument("--worker-port", type=int, default=21002) +# parser.add_argument("--worker-address", type=str, default="http://localhost:21002") +# parser.add_argument( +# "--controller-address", type=str, default="http://localhost:21001" +# ) +parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. 
This can be a local folder or a Hugging Face repo ID.", +) +parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", +) +parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", +) +parser.add_argument( + "--gpus", + type=str, + default="0", + help="A single GPU like 1 or multiple GPUs like 0,2", +) +parser.add_argument("--num-gpus", type=int, default=1) +parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per gpu. Use a string like '13Gib'", +) +parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") +parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", +) +parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Load quantized model. The path to the local GPTQ checkpoint.", +) +parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="#bits to use for quantization", +) +parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", +) +parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", +) +parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", +) +parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", +) +parser.add_argument("--stream-interval", type=int, default=2) +parser.add_argument("--no-register", action="store_true") + +worker_args = [ + "worker-host", + "worker-port", + "model-path", + "revision", + "device", + "gpus", + "num-gpus", + "max-gpu-memory", + "load-8bit", + "cpu-offloading", + "gptq-ckpt", + "gptq-wbits", + "gptq-groupsize", + "gptq-act-order", + "model-names", + "limit-worker-concurrency", + "stream-interval", + "no-register", + "controller-address", +] +# -----------------openai server--------------------------- + +parser.add_argument("--server-host", type=str, default="localhost", help="host name") +parser.add_argument("--server-port", type=int, default=8001, help="port number") +parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" +) +# parser.add_argument( +# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" +# ) +# parser.add_argument( +# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" +# ) +# parser.add_argument( +# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" +# ) +parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", +) +server_args = [ + "server-host", + "server-port", + "allow-credentials", + "api-keys", + "controller-address", +] + +args = parser.parse_args() + +args = argparse.Namespace( + **vars(args), + **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"}, +) + +if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 
+ ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + +# 0,controller, model_worker, openai_api_server +# 1, cmd options +# 2,LOGDIR +# 3, log file name +base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &" + +# 0 LOGDIR +#! 1 log file name +# 2 controller, worker, openai_api_server +base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do + sleep 1s; + echo "wait {2} running" + done + echo '{2} running' """ + + +def string_args(args, args_list): + args_str = "" + for key, value in args._get_kwargs(): + key = key.replace("_", "-") + if key not in args_list: + continue + + key = key.split("-")[-1] if re.search("port|host", key) else key + if not value: + pass + # 1==True -> True + elif isinstance(value, bool) and value == True: + args_str += f" --{key} " + elif ( + isinstance(value, list) + or isinstance(value, tuple) + or isinstance(value, set) + ): + value = " ".join(value) + args_str += f" --{key} {value} " + else: + args_str += f" --{key} {value} " + + return args_str + + +def launch_worker(item): + log_name = ( + item.split("/")[-1] + .split("\\")[-1] + .replace("-", "_") + .replace("@", "_") + .replace(".", "_") + ) + + args.model_path, args.worker_host, args.worker_port = item.split("@") + print("*" * 80) + worker_str_args = string_args(args, worker_args) + print(worker_str_args) + worker_sh = base_launch_sh.format( + "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}" + ) + worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker") + subprocess.run(worker_sh, shell=True, check=True) + subprocess.run(worker_check_sh, shell=True, check=True) + + +def launch_all(): + controller_str_args = string_args(args, controller_args) + controller_sh = base_launch_sh.format( + "controller", controller_str_args, LOGDIR, "controller" + ) + controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller") + subprocess.run(controller_sh, shell=True, check=True) + subprocess.run(controller_check_sh, shell=True, check=True) + + if isinstance(args.model_path_address, str): + launch_worker(args.model_path_address) + else: + for idx, item in enumerate(args.model_path_address): + print(f"loading {idx}th model:{item}") + launch_worker(item) + + server_str_args = string_args(args, server_args) + server_sh = base_launch_sh.format( + "openai_api_server", server_str_args, LOGDIR, "openai_api_server" + ) + server_check_sh = base_check_sh.format( + LOGDIR, "openai_api_server", "openai_api_server" + ) + subprocess.run(server_sh, shell=True, check=True) + subprocess.run(server_check_sh, shell=True, check=True) + + +if __name__ == "__main__": + launch_all() diff --git a/model_worker.py b/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..056c2f8954195d57a4d940c185b1504ae2862664 --- /dev/null +++ b/model_worker.py @@ -0,0 +1,363 @@ +""" +A model worker that executes the model. 
+""" +import argparse +import base64 +import gc +import json +import os +from typing import List, Optional +import uuid + +import torch +import torch.nn.functional as F +from transformers import set_seed +import uvicorn + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.base_model_worker import BaseModelWorker, app +from fastchat.utils import ( + build_logger, + get_context_length, + str_to_torch_dtype, +) + + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + stream_interval: int = 2, + conv_template: Optional[str] = None, + embed_in_truncate: bool = False, + seed: Optional[int] = None, + debug: bool = False, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + debug=debug, + model_name=model_names[0], + ) + self.device = device + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function(self.model, model_path) + self.stream_interval = stream_interval + self.embed_in_truncate = embed_in_truncate + self.seed = seed + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + self.call_ct += 1 + + try: + if self.seed is not None: + set_seed(self.seed) + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + 
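Each chunk yielded by `generate_stream_gate` above is a JSON payload terminated by a `b"\0"` delimiter, which is what FastChat-style clients split on. A rough client-side sketch follows; the `/worker_generate_stream` route and the payload keys are assumptions made for illustration, not taken from this diff.

```python
# Hypothetical streaming client: split the response on the b"\0" delimiter
# and decode each JSON chunk as it arrives.
import json
import requests

worker_addr = "http://localhost:21002"  # assumed worker address
payload = {"prompt": "Hello!", "temperature": 0.7, "max_new_tokens": 64, "echo": False}

with requests.post(
    f"{worker_addr}/worker_generate_stream", json=payload, stream=True, timeout=100
) as resp:
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        print(data["text"], end="\r")
print()
```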
def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): + if model_type_dict.get("is_bert"): + model_output = self.model(input_ids) + if model_type_dict.get("is_robert"): + data = model_output.last_hidden_state + else: + data = model_output[0] + elif model_type_dict.get("is_t5"): + model_output = self.model(input_ids, decoder_input_ids=input_ids) + data = model_output.encoder_last_hidden_state + else: + model_output = self.model(input_ids, output_hidden_states=True) + if model_type_dict.get("is_chatglm"): + data = model_output.hidden_states[-1].transpose(0, 1) + else: + data = model_output.hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + token_num = torch.sum(attention_mask).item() + + return sum_embeddings, token_num + + def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: + embeddings = embeddings.cpu() + return [ + base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings + ] + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + ret = {"embedding": [], "token_num": 0} + + model_type_dict = { + "is_llama": "llama" in str(type(self.model)), + "is_t5": "t5" in str(type(self.model)), + "is_chatglm": "chatglm" in str(type(self.model)), + "is_bert": "bert" in str(type(self.model)), + "is_robert": "robert" in str(type(self.model)), + } + + if self.embed_in_truncate: + encoding = tokenizer.batch_encode_plus( + params["input"], + padding=True, + truncation="longest_first", + return_tensors="pt", + max_length=self.context_len, + ) + else: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = input_ids != tokenizer.pad_token_id + + base64_encode = params.get("encoding_format", None) + + if self.embed_in_truncate: + chunk_embeddings, token_num = self.__process_embed_chunk( + input_ids, attention_mask, **model_type_dict + ) + embedding = chunk_embeddings / token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret["token_num"] = token_num + else: + all_embeddings = [] + all_token_num = 0 + for i in range(0, input_ids.size(1), self.context_len): + chunk_input_ids = input_ids[:, i : i + self.context_len] + chunk_attention_mask = attention_mask[:, i : i + self.context_len] + + chunk_embeddings, token_num = self.__process_embed_chunk( + chunk_input_ids, chunk_attention_mask, **model_type_dict + ) + all_embeddings.append(chunk_embeddings) + all_token_num += token_num + + all_embeddings_tensor = torch.stack(all_embeddings) + embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + + ret["token_num"] = all_token_num + + if base64_encode == "base64": + out_embeddings = self.__encode_base64(normalized_embeddings) + else: + out_embeddings = normalized_embeddings.tolist() + ret["embedding"] = out_embeddings + + gc.collect() + torch.cuda.empty_cache() + if self.device == "xpu": + torch.xpu.empty_cache() + if self.device == "npu": + torch.npu.empty_cache() + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: 
+ ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--embed-in-truncate", action="store_true") + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--debug", type=bool, default=False, help="Print debugging messages" + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. 
Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, + seed=args.seed, + debug=args.debug, + ) + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/monitor/basic_stats.py b/monitor/basic_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..e1934bb07863ddedb4b5dedba0b3e4724c78a765 --- /dev/null +++ b/monitor/basic_stats.py @@ -0,0 +1,210 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd # pandas>=2.0.3 +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + + +NUM_SERVERS = 14 + + +def get_log_files(max_num_files=None): + dates = [] + for month in range(4, 12): + for day in range(1, 33): + dates.append(f"2023-{month:02d}-{day:02d}") + + filenames = [] + for d in dates: + for i in range(NUM_SERVERS): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def load_log_files(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + data.append( + dict( + type=row["type"], + tstamp=row["tstamp"], + model=row.get("model", ""), + models=row.get("models", ["", ""]), + ) + ) + + return data + + +def get_anony_vote_df(df): + anony_vote_df = df[ + df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) + ] + anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")] + return anony_vote_df + + +def merge_counts(series, on, names): + ret = pd.merge(series[0], series[1], on=on) + for i in range(2, len(series)): + ret = pd.merge(ret, series[i], on=on) + ret = ret.reset_index() + old_names = list(ret.columns)[-len(series) :] + rename = {old_name: new_name for old_name, new_name in zip(old_names, names)} + ret = ret.rename(columns=rename) + return ret + + +def report_basic_stats(log_files): + df_all = load_log_files(log_files) + df_all = pd.DataFrame(df_all) + now_t = df_all["tstamp"].max() + df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] + df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)] + anony_vote_df_all = get_anony_vote_df(df_all) + + # Chat trends + chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in df_all[df_all["type"] == "chat"]["tstamp"] + ] + chat_dates_counts = pd.value_counts(chat_dates) + vote_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in 
anony_vote_df_all["tstamp"] + ] + vote_dates_counts = pd.value_counts(vote_dates) + chat_dates_bar = go.Figure( + data=[ + go.Bar( + name="Anony. Vote", + x=vote_dates_counts.index, + y=vote_dates_counts, + text=[f"{val:.0f}" for val in vote_dates_counts], + textposition="auto", + ), + go.Bar( + name="Chat", + x=chat_dates_counts.index, + y=chat_dates_counts, + text=[f"{val:.0f}" for val in chat_dates_counts], + textposition="auto", + ), + ] + ) + chat_dates_bar.update_layout( + barmode="stack", + xaxis_title="Dates", + yaxis_title="Count", + height=300, + width=1200, + ) + + # Model call counts + model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts() + model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts() + model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts() + model_hist = merge_counts( + [model_hist_all, model_hist_1_day, model_hist_1_hour], + on="model", + names=["All", "Last Day", "Last Hour"], + ) + model_hist_md = model_hist.to_markdown(index=False, tablefmt="github") + + # Action counts + action_hist_all = df_all["type"].value_counts() + action_hist_1_day = df_1_day["type"].value_counts() + action_hist_1_hour = df_1_hour["type"].value_counts() + action_hist = merge_counts( + [action_hist_all, action_hist_1_day, action_hist_1_hour], + on="type", + names=["All", "Last Day", "Last Hour"], + ) + action_hist_md = action_hist.to_markdown(index=False, tablefmt="github") + + # Anony vote counts + anony_vote_hist_all = anony_vote_df_all["type"].value_counts() + anony_vote_df_1_day = get_anony_vote_df(df_1_day) + anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts() + # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour) + # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts() + anony_vote_hist = merge_counts( + [anony_vote_hist_all, anony_vote_hist_1_day], + on="type", + names=["All", "Last Day"], + ) + anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github") + + # Last 24 hours + chat_1_day = df_1_day[df_1_day["type"] == "chat"] + num_chats_last_24_hours = [] + base = df_1_day["tstamp"].min() + for i in range(24, 0, -1): + left = base + (i - 1) * 3600 + right = base + i * 3600 + num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum() + num_chats_last_24_hours.append(num) + times = [ + datetime.datetime.fromtimestamp( + base + i * 3600, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + for i in range(24, 0, -1) + ] + last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours}) + last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github") + + # Last update datetime + last_updated_tstamp = now_t + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # code.interact(local=locals()) + + return { + "chat_dates_bar": chat_dates_bar, + "model_hist_md": model_hist_md, + "action_hist_md": action_hist_md, + "anony_vote_hist_md": anony_vote_hist_md, + "num_chats_last_24_hours": last_24_hours_md, + "last_updated_datetime": last_updated_datetime, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + basic_stats = report_basic_stats(log_files) + + print(basic_stats["action_hist_md"] + "\n") + print(basic_stats["model_hist_md"] + "\n") + 
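For context on what these tables summarize, `load_log_files` above expects each `*-conv.json` log to contain one JSON event per line, carrying at least the fields it extracts. A made-up record, shown only to illustrate the assumed schema:

```python
# Toy event record (values invented); only these four fields are read above.
import json

line = json.dumps(
    {
        "type": "chat",           # or leftvote / rightvote / tievote / bothbad_vote
        "tstamp": 1690000000.0,   # unix timestamp
        "model": "vicuna-13b",    # set for single-model chats
        "models": ["", ""],       # set for side-by-side battles
    }
)
row = json.loads(line)
print(row["type"], row.get("model", ""), row.get("models", ["", ""]))
```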
print(basic_stats["anony_vote_hist_md"] + "\n") + print(basic_stats["num_chats_last_24_hours"] + "\n") diff --git a/monitor/clean_battle_data.py b/monitor/clean_battle_data.py new file mode 100644 index 0000000000000000000000000000000000000000..23357d08cd2b24ca2bbecdd5c1434d1b7203a2f9 --- /dev/null +++ b/monitor/clean_battle_data.py @@ -0,0 +1,269 @@ +""" +Clean chatbot arena battle log. + +Usage: +python3 clean_battle_data.py --mode conv_release +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS +from fastchat.utils import detect_language + + +VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"] +IDENTITY_WORDS = [ + "vicuna", + "lmsys", + "koala", + "uc berkeley", + "open assistant", + "laion", + "chatglm", + "chatgpt", + "openai", + "anthropic", + "claude", + "bard", + "palm", + "lamda", + "google", + "llama", + "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.", + "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.", +] + +for i in range(len(IDENTITY_WORDS)): + IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower() + + +def get_log_files(max_num_files=None): + dates = [] + for month in range(4, 12): + for day in range(1, 33): + dates.append(f"2023-{month:02d}-{day:02d}") + + filenames = [] + for d in dates: + for i in range(NUM_SERVERS): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def remove_html(raw): + if raw.startswith("
<h3>Model A: "): + return raw[raw.find(": ") + 2 : -len("</h3>
\n")] + return raw + + +def to_openai_format(messages): + roles = ["user", "assistant"] + ret = [] + for i, x in enumerate(messages): + ret.append({"role": roles[i % 2], "content": x[1]}) + return ret + + +def replace_model_name(old_name): + return ( + old_name.replace("bard", "palm-2") + .replace("claude-v1", "claude-1") + .replace("claude-instant-v1", "claude-instant-1") + .replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b") + ) + + +def clean_battle_data(log_files, exclude_model_names): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + if row["type"] in VOTES: + data.append(row) + + convert_type = { + "leftvote": "model_a", + "rightvote": "model_b", + "tievote": "tie", + "bothbad_vote": "tie (bothbad)", + } + + all_models = set() + all_ips = dict() + ct_anony = 0 + ct_invalid = 0 + ct_leaked_identity = 0 + battles = [] + for row in data: + if row["models"][0] is None or row["models"][1] is None: + continue + + # Resolve model names + models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])] + if "model_name" in row["states"][0]: + models_hidden = [ + row["states"][0]["model_name"], + row["states"][1]["model_name"], + ] + if models_hidden[0] is None: + models_hidden = models_public + else: + models_hidden = models_public + + if (models_public[0] == "" and models_public[1] != "") or ( + models_public[1] == "" and models_public[0] != "" + ): + ct_invalid += 1 + continue + + if models_public[0] == "" or models_public[0] == "Model A": + anony = True + models = models_hidden + ct_anony += 1 + else: + anony = False + models = models_public + if not models_public == models_hidden: + ct_invalid += 1 + continue + + # Detect langauge + state = row["states"][0] + if state["offset"] >= len(state["messages"]): + ct_invalid += 1 + continue + lang_code = detect_language(state["messages"][state["offset"]][1]) + + # Drop conversations if the model names are leaked + leaked_identity = False + messages = "" + for i in range(2): + state = row["states"][i] + for role, msg in state["messages"][state["offset"] :]: + if msg: + messages += msg.lower() + for word in IDENTITY_WORDS: + if word in messages: + leaked_identity = True + break + + if leaked_identity: + ct_leaked_identity += 1 + continue + + # Replace bard with palm + models = [replace_model_name(m) for m in models] + + # Exclude certain models + if any(x in exclude_model_names for x in models): + ct_invalid += 1 + continue + + question_id = row["states"][0]["conv_id"] + conversation_a = to_openai_format( + row["states"][0]["messages"][row["states"][0]["offset"] :] + ) + conversation_b = to_openai_format( + row["states"][1]["messages"][row["states"][1]["offset"] :] + ) + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = len(all_ips) + user_id = all_ips[ip] + + # Save the results + battles.append( + dict( + question_id=question_id, + model_a=models[0], + model_b=models[1], + winner=convert_type[row["type"]], + judge=f"arena_user_{user_id}", + conversation_a=conversation_a, + conversation_b=conversation_b, + turn=len(conversation_a) // 2, + anony=anony, + language=lang_code, + tstamp=row["tstamp"], + ) + ) + + all_models.update(models_hidden) + battles.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = battles[-1]["tstamp"] + + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + 
).strftime("%Y-%m-%d %H:%M:%S %Z") + + print( + f"#votes: {len(data)}, #invalid votes: {ct_invalid}, " + f"#leaked_identity: {ct_leaked_identity}" + ) + print(f"#battles: {len(battles)}, #anony: {ct_anony}") + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + return battles + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + parser.add_argument( + "--mode", type=str, choices=["simple", "conv_release"], default="simple" + ) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files, args.exclude_model_names or []) + last_updated_tstamp = battles[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + if args.mode == "simple": + for x in battles: + for key in [ + "conversation_a", + "conversation_b", + "question_id", + ]: + del x[key] + print("Samples:") + for i in range(4): + print(battles[i]) + output = f"clean_battle_{cutoff_date}.json" + elif args.mode == "conv_release": + new_battles = [] + for x in battles: + if not x["anony"]: + continue + for key in []: + del x[key] + new_battles.append(x) + battles = new_battles + output = f"clean_battle_conv_{cutoff_date}.json" + + with open(output, "w") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/monitor/clean_chat_data.py b/monitor/clean_chat_data.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0c9bd4fa4cce17af03597dab12a9cdcc9453c5 --- /dev/null +++ b/monitor/clean_chat_data.py @@ -0,0 +1,171 @@ +""" +Clean chatbot arena chat log. + +Usage: +python3 clean_chat_data.py --mode conv_release +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import NUM_SERVERS +from fastchat.serve.monitor.clean_battle_data import ( + to_openai_format, + replace_model_name, +) +from fastchat.utils import detect_language + + +NETWORK_ERROR_MSG = ( + "NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.".lower() +) + + +def get_log_files(max_num_files=None): + dates = [] + for month in range(4, 12): + for day in range(1, 33): + dates.append(f"2023-{month:02d}-{day:02d}") + + filenames = [] + for d in dates: + for i in range(NUM_SERVERS): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + # filenames = list(reversed(filenames)) + filenames = filenames[-max_num_files:] + return filenames + + +def clean_chat_data(log_files, action_type): + raw_data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + if row["type"] == action_type: + raw_data.append(row) + + all_models = set() + all_ips = dict() + chats = [] + ct_invalid_conv_id = 0 + ct_invalid = 0 + ct_network_error = 0 + for row in raw_data: + try: + if action_type in ["chat", "upvote", "downvote"]: + state = row["state"] + model = row["model"] + elif action_type == "leftvote": + state = row["states"][0] + model = row["states"][0]["model_name"] + elif action_type == "rightvote": + state = row["states"][1] + model = row["states"][1]["model_name"] + conversation_id = state["conv_id"] + except KeyError: + ct_invalid_conv_id += 1 + continue + + if conversation_id is None: + ct_invalid_conv_id += 1 + continue + + conversation = to_openai_format(state["messages"][state["offset"] :]) + if not isinstance(model, str): + ct_invalid += 1 + continue + model = replace_model_name(model) + + try: + lang_code = detect_language(state["messages"][state["offset"]][1]) + except IndexError: + ct_invalid += 1 + continue + + if not all(isinstance(x["content"], str) for x in conversation): + ct_invalid += 1 + continue + + messages = "".join([x["content"] for x in conversation]).lower() + if NETWORK_ERROR_MSG in messages: + ct_network_error += 1 + continue + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = len(all_ips) + user_id = all_ips[ip] + + chats.append( + dict( + conversation_id=conversation_id, + model=model, + conversation=conversation, + turn=len(conversation) // 2, + language=lang_code, + user_id=user_id, + tstamp=row["tstamp"], + ) + ) + + all_models.update([model]) + + chats.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = chats[-1]["tstamp"] + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # Deduplication + dedup_chats = [] + visited_conv_ids = set() + for i in reversed(range(len(chats))): + if chats[i]["conversation_id"] in visited_conv_ids: + continue + visited_conv_ids.add(chats[i]["conversation_id"]) + dedup_chats.append(chats[i]) + + print( + f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" + ) + print( + f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" + ) + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + return list(reversed(dedup_chats)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--action-type", type=str, default="chat") + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + chats = clean_chat_data(log_files, args.action_type) + 
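The `to_openai_format` helper imported from `clean_battle_data` (and used above) maps FastChat's `[role, text]` message pairs to OpenAI-style turns purely by position, alternating user and assistant. A tiny illustration with invented messages:

```python
# Toy input in FastChat's [role, text] pair layout (values invented).
from fastchat.serve.monitor.clean_battle_data import to_openai_format

messages = [["USER", "What is 2 + 2?"], ["ASSISTANT", "4"]]
print(to_openai_format(messages))
# -> [{'role': 'user', 'content': 'What is 2 + 2?'},
#     {'role': 'assistant', 'content': '4'}]
```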
last_updated_tstamp = chats[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + output = f"clean_{args.action_type}_conv_{cutoff_date}.json" + with open(output, "w") as fout: + json.dump(chats, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/monitor/dataset_release_scripts/arena_33k/count_unique_users.py b/monitor/dataset_release_scripts/arena_33k/count_unique_users.py new file mode 100644 index 0000000000000000000000000000000000000000..8e94cf2756203f207e82cc7f31ff544ecdcc80f0 --- /dev/null +++ b/monitor/dataset_release_scripts/arena_33k/count_unique_users.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + args = parser.parse_args() + + lines = json.load(open(args.input)) + ct_anony_votes = 0 + all_users = set() + all_models = set() + for l in lines: + if not l["anony"]: + continue + all_users.add(l["judge"]) + all_models.add(l["model_a"]) + all_models.add(l["model_b"]) + ct_anony_votes += 1 + + print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") + print(f"#model: {len(all_models)}") diff --git a/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py b/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..6d12d7c652bc02bb7b5c9f65bce0e1644f739c1b --- /dev/null +++ b/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py @@ -0,0 +1,155 @@ +""" +Filter conversations for release. + +Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts): + return TypeCode.TOO_SHORT + + if all(x in frequent_prompts for x in user_prompts): + return TypeCode.TOO_FREQUENT + + for msg in messages: + msg = msg.lower() + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + for key in ["model_a", "model_b"]: + if conv[key] in ["vicuna-33b", "mpt-30b-chat"]: + return TypeCode.BLOCKED_MODEL + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + + # Count frequency + 
for conv in convs: + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key] if row["role"] == "user"] + for msg in messages: + if not isinstance(msg, str): + continue + msg = msg.lower().strip() + frequency[msg] += 1 + + keys = list(frequency.keys()) + keys.sort(key=lambda x: -frequency[x]) + frequent_prompts = keys[:10] + frequent_prompts = set(frequent_prompts) + frequent_prompts.add("") + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + new_convs = [] + for conv in tqdm(convs): + type_code = detect_type(conv) + + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + # random.seed(0) + # random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}") + print(f"new_conv: {len(new_convs)}") + + out_file = args.in_file.replace(".json", ".out.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/monitor/dataset_release_scripts/arena_33k/merge_field.py b/monitor/dataset_release_scripts/arena_33k/merge_field.py new file mode 100644 index 0000000000000000000000000000000000000000..5a88209bfcb58cb2131ce94d6eba03c899e74a0a --- /dev/null +++ b/monitor/dataset_release_scripts/arena_33k/merge_field.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--tag-file", type=str) + args = parser.parse_args() + + # build index + objs = json.load(open(args.tag_file)) + new_field_dict = {} + for obj in objs: + new_field_dict[obj["question_id"]] = obj["toxic_chat"] + + objs = json.load(open(args.input)) + for obj in objs: + obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] + + output = args.input.replace(".json", "_added.json") + with open(output, "w") as fout: + json.dump(objs, fout, indent=2, ensure_ascii=False) diff --git a/monitor/dataset_release_scripts/arena_33k/sample.py b/monitor/dataset_release_scripts/arena_33k/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd78b71e95a3034bf3440aee3557a38426d0244 --- /dev/null +++ b/monitor/dataset_release_scripts/arena_33k/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. 
+ +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(0) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py b/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e37aadcea65df7ca605369b88c068aa57c8f35f2 --- /dev/null +++ b/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py @@ -0,0 +1,9 @@ +""" +Upload to huggingface. +""" +import json +from datasets import Dataset, DatasetDict, load_dataset + +objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) +data = Dataset.from_list(objs) +data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py b/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py new file mode 100644 index 0000000000000000000000000000000000000000..a7084207309907dcb8fa37eccf55fd2a6b62ca48 --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py @@ -0,0 +1,13 @@ +import requests + +headers = {"authorization": "Bearer hf_XXX"} + +url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending" +a = requests.get(url, headers=headers) + +for u in a.json(): + user = u["user"]["user"] + url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant" + ret = requests.post(url, headers=headers, json={"user": user}) + print(user, ret.status_code) + assert ret.status_code == 200 diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py b/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..97abaaa0df053c93c3adb655f1b5c41af0aab00d --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py @@ -0,0 +1,119 @@ +""" +From colab: +https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +import kaleido +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + +import plotly.io as pio + +pio.kaleido.scope.mathjax = None + +parser = argparse.ArgumentParser() +parser.add_argument("--in-file", type=str, required=True) +parser.add_argument("--scale", type=int, required=True) +args = parser.parse_args() + +filename = args.in_file +scale = args.scale +convs = json.load(open(filename)) +df = pd.DataFrame(convs) +df + +print(f"#ips: {df['user_id'].nunique() * scale}") +print(f"#models: {df['model'].nunique()}") +print(f"#language: {df['language'].nunique()}") +print(f"#turns: {df['turn'].mean()}") + +model_counts = df["model"].value_counts() * scale +# print("model counts", model_counts) +fig = px.bar(x=model_counts.index, y=model_counts) 
+fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("model_count.pdf") + + +model_counts = df["language"].value_counts().head(25) * scale +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("language_count.pdf") + +chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d") + for x in df["tstamp"] +] + + +def to_remove(x): + for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]: + if d in x: + return True + return False + + +chat_dates = [x for x in chat_dates if not to_remove(x)] + +chat_dates_counts = pd.value_counts(chat_dates) * scale +print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}") + +fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts) +fig.update_layout( + xaxis_title="Dates", + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("daily_conversation_count.pdf") + +import transformers + +tokenizer = transformers.AutoTokenizer.from_pretrained( + "lmsys/vicuna-7b-v1.5", use_fast=False +) + +prompts = [] +responses = [] +for conv in df["conversation"]: + for row in conv: + if row["role"] == "user": + prompts.append(row["content"]) + else: + responses.append(row["content"]) + +print(f"#prompts: {len(prompts)}") +print(f"#responses: {len(responses)}") + + +prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)] +print() +print(f"mean prompt len: {np.mean(prompt_lens):.2f}") + +response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)] +print() +print(f"mean response len: {np.mean(response_lens):.2f}") diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py b/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccde1ca57546acf5d1131cae14a499f1228a02c --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py @@ -0,0 +1,148 @@ +""" +Filter conversations for release. 
+ +Dependency: +pip install opencc-python-reimplementedpip install opencc-python-reimplemented + +Usage: +python3 filter_bad_conv_lmsys_chat_1m.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm +import opencc + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + +cc_converter = opencc.OpenCC("t2s") + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b", "conversation"]: + if key not in conv: + continue + + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + if len(messages) == 0: + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + + for msg in messages: + msg = cc_converter.convert(msg.lower()) + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + blocked_words = [cc_converter.convert(w) for w in blocked_words] + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + type_codes = [] + with ProcessPoolExecutor() as executor: + for result in tqdm(executor.map(detect_type, convs), total=len(convs)): + type_codes.append(result) + + new_convs = [] + for conv, type_code in zip(convs, type_codes): + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + random.seed(42) + random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}") + print(f"new_conv: {len(new_convs)}") + + out_file = 
args.in_file.replace(".json", ".s1.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py b/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e368e92a1dcf260ecb5b175b77e85c6971809a3c --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py @@ -0,0 +1,27 @@ +import argparse +import json + +from tqdm import tqdm +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Delete some fileds + for c in convs: + del c["tstamp"] + del c["user_id"] + + # Write + print(f"#out conv: {len(convs)}") + out_file = args.in_file.replace(".json", ".s2.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(convs, fout, indent=2, ensure_ascii=False) diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md b/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md new file mode 100644 index 0000000000000000000000000000000000000000..4c439731f6aee43bd29e1a65576c5ae04ff59cfa --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md @@ -0,0 +1,23 @@ +``` +export BASE=clean_conv_20230809_100k_pii +export SCALE=10 + +# filter words +python3 filter_bad_conv.py --in $BASE.json + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE + +# Copy figures +scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" . 
+``` + diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py b/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py new file mode 100644 index 0000000000000000000000000000000000000000..18bef5f1962384d80f174aa22a7b6dcc867fe7c0 --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py @@ -0,0 +1,45 @@ +import argparse +import json +import time + +from tqdm import tqdm + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json" + # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json" + in_file = args.in_file + tic = time.time() + + # Load tags + print("Load tags...") + tag_data = json.load(open(tag_file)) + tag_dict = {} + for c in tqdm(tag_data): + tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]] + print(f"elapsed: {time.time() - tic:.2f} s") + + # Append to input_file + print("Load inputs...") + input_data = json.load(open(in_file)) + for c in tqdm(input_data): + cid = c["conversation_id"] + if cid in tag_dict: + c["openai_moderation"] = tag_dict[cid] + else: + print(f"missing tag for conv {cid}") + exit() + print(f"elapsed: {time.time() - tic:.2f} s") + + # Write output + print("Write outputs...") + out_file = in_file.replace(".json", ".with_tag.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(input_data, fout, indent=2, ensure_ascii=False) + print(f"elapsed: {time.time() - tic:.2f} s") diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh b/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..5bae9fbad221c57eba8f2cf5b7eb2779a6f040a8 --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh @@ -0,0 +1,18 @@ +export BASE=clean_conv_20230809_1.5M_pii +#export BASE=clean_conv_20230809_100k_pii +export SCALE=1 + +# Filter words +python3 filter_bad_conv.py --in $BASE.json --sample 1000000 + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# Upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py b/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6da455fc7bf8af1ce473f80440bff280c9366e --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. 
+ +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(42) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py b/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..41d0fbdb59b4c7dc8385bef87a1bf0c8ea6e7401 --- /dev/null +++ b/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py @@ -0,0 +1,17 @@ +""" +Upload to huggingface. +""" +import argparse +import json +from datasets import Dataset, DatasetDict, load_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + objs = json.load(open(args.in_file)) + print(f"#convs: {len(objs)}") + data = Dataset.from_list(objs) + data.push_to_hub("lmsys/lmsys-chat-1m", private=True) diff --git a/monitor/elo_analysis.py b/monitor/elo_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..e95f157c87e31297f7193ef9cecc21d5a90b8b01 --- /dev/null +++ b/monitor/elo_analysis.py @@ -0,0 +1,303 @@ +import argparse +from collections import defaultdict +import datetime +import json +import math +import pickle +from pytz import timezone + +import numpy as np +import pandas as pd +import plotly.express as px +from tqdm import tqdm + +from fastchat.model.model_registry import get_model_info +from fastchat.serve.monitor.basic_stats import get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data + + +pd.options.display.float_format = "{:.2f}".format + + +def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000): + rating = defaultdict(lambda: INIT_RATING) + + for rd, model_a, model_b, winner in battles[ + ["model_a", "model_b", "winner"] + ].itertuples(): + ra = rating[model_a] + rb = rating[model_b] + ea = 1 / (1 + BASE ** ((rb - ra) / SCALE)) + eb = 1 / (1 + BASE ** ((ra - rb) / SCALE)) + if winner == "model_a": + sa = 1 + elif winner == "model_b": + sa = 0 + elif winner == "tie" or winner == "tie (bothbad)": + sa = 0.5 + else: + raise Exception(f"unexpected vote {winner}") + rating[model_a] += K * (sa - ea) + rating[model_b] += K * (1 - sa - eb) + + return dict(rating) + + +def get_bootstrap_result(battles, func_compute_elo, num_round=1000): + rows = [] + for i in tqdm(range(num_round), desc="bootstrap"): + tmp_battles = battles.sample(frac=1.0, replace=True) + rows.append(func_compute_elo(tmp_battles)) + df = pd.DataFrame(rows) + return df[df.median().sort_values(ascending=False).index] + + +def get_median_elo_from_bootstrap(bootstrap_df): + median = dict(bootstrap_df.quantile(0.5)) + median = {k: int(v + 0.5) for k, v in median.items()} + return median + + +def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None): + # Times each model wins as Model A + a_win_ptbl = pd.pivot_table( + battles[battles["winner"] == 
"model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting times each model wins as Model B + b_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting number of A-B pairs + num_battles_ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + + # Computing the proportion of wins for each model as A and as B + # against all other models + row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / ( + num_battles_ptbl + num_battles_ptbl.T + ) + + if model_order is None: + prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False) + model_order = list(prop_wins.keys()) + + if limit_show_number is not None: + model_order = model_order[:limit_show_number] + + # Arrange ordering according to proprition of wins + row_beats_col = row_beats_col_freq.loc[model_order, model_order] + return row_beats_col + + +def visualize_leaderboard_table(rating): + models = list(rating.keys()) + models.sort(key=lambda k: -rating[k]) + + emoji_dict = { + 1: "🥇", + 2: "🥈", + 3: "🥉", + } + + md = "" + md += "| Rank | Model | Elo Rating | Description |\n" + md += "| --- | --- | --- | --- |\n" + for i, model in enumerate(models): + rank = i + 1 + minfo = get_model_info(model) + emoji = emoji_dict.get(rank, "") + md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n" + + return md + + +def visualize_pairwise_win_fraction(battles, model_order): + row_beats_col = compute_pairwise_win_fraction(battles, model_order) + fig = px.imshow( + row_beats_col, + color_continuous_scale="RdBu", + text_auto=".2f", + height=700, + width=700, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
<br>Model B: %{x}<br>
Fraction of A Wins: %{z}" + ) + + return fig + + +def visualize_battle_count(battles, model_order): + ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + battle_counts = ptbl + ptbl.T + fig = px.imshow( + battle_counts.loc[model_order, model_order], + text_auto=True, + height=700, + width=700, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
<br>Model B: %{x}<br>
Count: %{z}" + ) + return fig + + +def visualize_average_win_rate(battles, limit_show_number): + row_beats_col_freq = compute_pairwise_win_fraction( + battles, None, limit_show_number=limit_show_number + ) + fig = px.bar( + row_beats_col_freq.mean(axis=1).sort_values(ascending=False), + text_auto=".2f", + height=500, + width=700, + ) + fig.update_layout( + yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False + ) + return fig + + +def visualize_bootstrap_elo_rating(df, limit_show_number): + bars = ( + pd.DataFrame( + dict( + lower=df.quantile(0.025), + rating=df.quantile(0.5), + upper=df.quantile(0.975), + ) + ) + .reset_index(names="model") + .sort_values("rating", ascending=False) + ) + bars = bars[:limit_show_number] + bars["error_y"] = bars["upper"] - bars["rating"] + bars["error_y_minus"] = bars["rating"] - bars["lower"] + bars["rating_rounded"] = np.round(bars["rating"], 2) + fig = px.scatter( + bars, + x="model", + y="rating", + error_y="error_y", + error_y_minus="error_y_minus", + text="rating_rounded", + height=500, + width=700, + ) + fig.update_layout(xaxis_title="Model", yaxis_title="Rating") + return fig + + +def report_elo_analysis_results(battles_json): + battles = pd.DataFrame(battles_json) + battles = battles.sort_values(ascending=True, by=["tstamp"]) + # Only use anonymous votes + battles = battles[battles["anony"]].reset_index(drop=True) + battles_no_ties = battles[~battles["winner"].str.contains("tie")] + + # Online update + elo_rating_online = compute_elo(battles) + + # Bootstrap + bootstrap_df = get_bootstrap_result(battles, compute_elo) + elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) + model_order = list(elo_rating_median.keys()) + model_order.sort(key=lambda k: -elo_rating_median[k]) + + limit_show_number = 25 # limit show number to make plots smaller + model_order = model_order[:limit_show_number] + + # Plots + leaderboard_table = visualize_leaderboard_table(elo_rating_median) + win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order) + battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order) + average_win_rate_bar = visualize_average_win_rate( + battles_no_ties, limit_show_number + ) + bootstrap_elo_rating = visualize_bootstrap_elo_rating( + bootstrap_df, limit_show_number + ) + + last_updated_tstamp = battles["tstamp"].max() + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + return { + "elo_rating_online": elo_rating_online, + "elo_rating_median": elo_rating_median, + "leaderboard_table": leaderboard_table, + "win_fraction_heatmap": win_fraction_heatmap, + "battle_count_heatmap": battle_count_heatmap, + "average_win_rate_bar": average_win_rate_bar, + "bootstrap_elo_rating": bootstrap_elo_rating, + "last_updated_datetime": last_updated_datetime, + "last_updated_tstamp": last_updated_tstamp, + } + + +def pretty_print_elo_rating(rating): + model_order = list(rating.keys()) + model_order.sort(key=lambda k: -rating[k]) + for i, model in enumerate(model_order): + print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + np.random.seed(42) + + if args.clean_battle_file: + # Read data from a cleaned battle files + battles = pd.read_json(args.clean_battle_file) + else: + # Read data from all log 
files + log_files = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files) + + results = report_elo_analysis_results(battles) + + print("# Online") + pretty_print_elo_rating(results["elo_rating_online"]) + print("# Median") + pretty_print_elo_rating(results["elo_rating_median"]) + print(f"last update : {results['last_updated_datetime']}") + + last_updated_tstamp = results["last_updated_tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: + pickle.dump(results, fout) diff --git a/monitor/inspect_conv.py b/monitor/inspect_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a680a419bd9d11d0db85afbc21c0063a2ae36df7 --- /dev/null +++ b/monitor/inspect_conv.py @@ -0,0 +1,87 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd +from tqdm import tqdm + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4, 5]: + for day in range(1, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + + num_servers = 14 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def pretty_print_conversation(messages): + for role, msg in messages: + print(f"[[{role}]]: {msg}") + + +def inspect_convs(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + if "states" not in row: + continue + if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: + continue + + model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] + if row["type"] == "leftvote": + winner, loser = model_names[0], model_names[1] + winner_conv, loser_conv = row["states"][0], row["states"][1] + elif row["type"] == "rightvote": + loser, winner = model_names[0], model_names[1] + loser_conv, winner_conv = row["states"][0], row["states"][1] + + if loser == "bard" and winner == "vicuna-13b": + print("=" * 20) + print(f"Winner: {winner}") + pretty_print_conversation(winner_conv["messages"]) + print(f"Loser: {loser}") + pretty_print_conversation(loser_conv["messages"]) + print("=" * 20) + input() + + # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: + # print("=" * 20) + # print(f"Model A: {model_names[0]}") + # pretty_print_conversation(row["states"][0]["messages"]) + # print(f"Model B: {model_names[1]}") + # pretty_print_conversation(row["states"][1]["messages"]) + # print("=" * 20) + # input() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + inspect_convs(log_files) diff --git a/monitor/intersect_conv_file.py b/monitor/intersect_conv_file.py new file mode 100644 index 0000000000000000000000000000000000000000..9eadd7cd57510ecbbd23798d55b079c69aac1a12 --- /dev/null +++ b/monitor/intersect_conv_file.py @@ -0,0 +1,25 @@ +""" +Take the intersection of two conversation files. 
+ +Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument("--conv-id", type=str, required=True) + parser.add_argument("--out-file", type=str, default="intersect.json") + args = parser.parse_args() + + conv_id_objs = json.load(open(args.conv_id, "r")) + conv_ids = set(x["conversation_id"] for x in conv_id_objs) + + objs = json.load(open(args.input, "r")) + after_objs = [x for x in objs if x["conversation_id"] in conv_ids] + + print(f"#in: {len(objs)}, #out: {len(after_objs)}") + json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/monitor/leaderboard_csv_to_html.py b/monitor/leaderboard_csv_to_html.py new file mode 100644 index 0000000000000000000000000000000000000000..ad52e7b2b6e234ed33a51d516e9d682addd1e0eb --- /dev/null +++ b/monitor/leaderboard_csv_to_html.py @@ -0,0 +1,51 @@ +""" +Convert a leaderboard csv file to html table used in the blog. + +Usage: +python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv +""" +import argparse + +import numpy as np + +from fastchat.serve.monitor.monitor import load_leaderboard_table_csv + + +def model_hyperlink(model_name, link): + return f' {model_name} ' + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + args = parser.parse_args() + + data = load_leaderboard_table_csv(args.input, add_hyperlink=False) + headers = [ + "Model", + "MT-bench (score)", + "Arena Elo rating", + "MMLU", + "License", + ] + values = [] + for item in data: + row = [] + for key in headers: + value = item[key] + row.append(value) + row[0] = model_hyperlink(item["Model"], item["Link"]) + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + + for value in values: + row = "" + for x in value: + try: + if np.isnan(x): + x = "-" + except TypeError: + pass + row += f" {x} " + row += "" + print(row) diff --git a/monitor/monitor.py b/monitor/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..580a2c866ab77e92c1eab61c74ec8b96ce3d30ee --- /dev/null +++ b/monitor/monitor.py @@ -0,0 +1,313 @@ +""" +Live monitor of the website statistics and leaderboard. 
+ +Dependency: +sudo apt install pkg-config libicu-dev +pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate +""" + +import argparse +import ast +import pickle +import os +import threading +import time + +import gradio as gr +import numpy as np + +from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data +from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results +from fastchat.utils import build_logger, get_window_url_params_js + + +notebook_url = "https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing" + + +basic_component_values = [None] * 6 +leader_component_values = [None] * 5 + + +def make_leaderboard_md(elo_results): + leaderboard_md = f""" +# 🏆 Chatbot Arena Leaderboard +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +This leaderboard is based on the following three benchmarks. +- [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings. +- [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. +- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks. + +💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023. +""" + return leaderboard_md + + +def make_leaderboard_md_live(elo_results): + leaderboard_md = f""" +# Leaderboard +Last updated: {elo_results["last_updated_datetime"]} +{elo_results["leaderboard_table"]} +""" + return leaderboard_md + + +def update_elo_components(max_num_files, elo_results_file): + log_files = get_log_files(max_num_files) + + # Leaderboard + if elo_results_file is None: # Do live update + battles = clean_battle_data(log_files, []) + elo_results = report_elo_analysis_results(battles) + + leader_component_values[0] = make_leaderboard_md_live(elo_results) + leader_component_values[1] = elo_results["win_fraction_heatmap"] + leader_component_values[2] = elo_results["battle_count_heatmap"] + leader_component_values[3] = elo_results["bootstrap_elo_rating"] + leader_component_values[4] = elo_results["average_win_rate_bar"] + + # Basic stats + basic_stats = report_basic_stats(log_files) + md0 = f"Last updated: {basic_stats['last_updated_datetime']}" + + md1 = "### Action Histogram\n" + md1 += basic_stats["action_hist_md"] + "\n" + + md2 = "### Anony. 
Vote Histogram\n" + md2 += basic_stats["anony_vote_hist_md"] + "\n" + + md3 = "### Model Call Histogram\n" + md3 += basic_stats["model_hist_md"] + "\n" + + md4 = "### Model Call (Last 24 Hours)\n" + md4 += basic_stats["num_chats_last_24_hours"] + "\n" + + basic_component_values[0] = md0 + basic_component_values[1] = basic_stats["chat_dates_bar"] + basic_component_values[2] = md1 + basic_component_values[3] = md2 + basic_component_values[4] = md3 + basic_component_values[5] = md4 + + +def update_worker(max_num_files, interval, elo_results_file): + while True: + tic = time.time() + update_elo_components(max_num_files, elo_results_file) + durtaion = time.time() - tic + print(f"update duration: {durtaion:.2f} s") + time.sleep(max(interval - durtaion, 0)) + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return basic_component_values + leader_component_values + + +def model_hyperlink(model_name, link): + return f'{model_name}' + + +def load_leaderboard_table_csv(filename, add_hyperlink=True): + lines = open(filename).readlines() + heads = [v.strip() for v in lines[0].split(",")] + rows = [] + for i in range(1, len(lines)): + row = [v.strip() for v in lines[i].split(",")] + for j in range(len(heads)): + item = {} + for h, v in zip(heads, row): + if h == "Arena Elo rating": + if v != "-": + v = int(ast.literal_eval(v)) + else: + v = np.nan + elif h == "MMLU": + if v != "-": + v = round(ast.literal_eval(v) * 100, 1) + else: + v = np.nan + elif h == "MT-bench (win rate %)": + if v != "-": + v = round(ast.literal_eval(v[:-1]), 1) + else: + v = np.nan + elif h == "MT-bench (score)": + if v != "-": + v = round(ast.literal_eval(v), 2) + else: + v = np.nan + item[h] = v + if add_hyperlink: + item["Model"] = model_hyperlink(item["Model"], item["Link"]) + rows.append(item) + + return rows + + +def build_basic_stats_tab(): + empty = "Loading ..." + basic_component_values[:] = [empty, None, empty, empty, empty, empty] + + md0 = gr.Markdown(empty) + gr.Markdown("#### Figure 1: Number of model calls and votes") + plot_1 = gr.Plot(show_label=False) + with gr.Row(): + with gr.Column(): + md1 = gr.Markdown(empty) + with gr.Column(): + md2 = gr.Markdown(empty) + with gr.Row(): + with gr.Column(): + md3 = gr.Markdown(empty) + with gr.Column(): + md4 = gr.Markdown(empty) + return [md0, plot_1, md1, md2, md3, md4] + + +def build_leaderboard_tab(elo_results_file, leaderboard_table_file): + if elo_results_file is None: # Do live update + md = "Loading ..." 
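# In this live-update branch the leaderboard markdown and plots start as
# placeholders; the update_worker thread started in __main__ periodically calls
# update_elo_components, which recomputes the battles and refreshes
# leader_component_values for load_demo to hand back to the frontend.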
+ p1 = p2 = p3 = p4 = None + else: + with open(elo_results_file, "rb") as fin: + elo_results = pickle.load(fin) + + md = make_leaderboard_md(elo_results) + p1 = elo_results["win_fraction_heatmap"] + p2 = elo_results["battle_count_heatmap"] + p3 = elo_results["bootstrap_elo_rating"] + p4 = elo_results["average_win_rate_bar"] + + md_1 = gr.Markdown(md, elem_id="leaderboard_markdown") + + if leaderboard_table_file: + data = load_leaderboard_table_csv(leaderboard_table_file) + headers = [ + "Model", + "Arena Elo rating", + "MT-bench (score)", + "MMLU", + "License", + ] + values = [] + for item in data: + row = [] + for key in headers: + value = item[key] + row.append(value) + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + + headers[1] = "⭐ " + headers[1] + headers[2] = "📈 " + headers[2] + + gr.Dataframe( + headers=headers, + datatype=["markdown", "number", "number", "number", "str"], + value=values, + elem_id="leaderboard_dataframe", + ) + gr.Markdown( + """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis! + If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). + """, + elem_id="leaderboard_markdown", + ) + else: + pass + + leader_component_values[:] = [md, p1, p2, p3, p4] + + """ + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles" + ) + plot_1 = gr.Plot(p1, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Battle Count for Each Combination of Models (without Ties)" + ) + plot_2 = gr.Plot(p2, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)" + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)" + ) + plot_4 = gr.Plot(p4, show_label=False) + """ + + from fastchat.serve.gradio_web_server import acknowledgment_md + + gr.Markdown(acknowledgment_md) + + # return [md_1, plot_1, plot_2, plot_3, plot_4] + return [md_1] + + +def build_demo(elo_results_file, leaderboard_table_file): + from fastchat.serve.gradio_web_server import block_css + + text_size = gr.themes.sizes.text_lg + + with gr.Blocks( + title="Monitor", + theme=gr.themes.Base(text_size=text_size), + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Leaderboard", id=0): + leader_components = build_leaderboard_tab( + elo_results_file, leaderboard_table_file + ) + + with gr.Tab("Basic Stats", id=1): + basic_components = build_basic_stats_tab() + + url_params = gr.JSON(visible=False) + demo.load( + load_demo, + [url_params], + basic_components + leader_components, + _js=get_window_url_params_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--update-interval", type=int, default=300) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--elo-results-file", type=str) + parser.add_argument("--leaderboard-table-file", type=str) + args = parser.parse_args() + + logger = build_logger("monitor", "monitor.log") + 
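# Launch sketch (file names are placeholders; the flags are the ones parsed
# above; invoke the script by whatever path or module name your deployment uses):
#
#   # static leaderboard from precomputed artifacts
#   python3 monitor.py --elo-results-file elo_results_YYYYMMDD.pkl \
#       --leaderboard-table-file leaderboard_table_YYYYMMDD.csv
#
#   # live mode: omit --elo-results-file to recompute from the raw logs every
#   # --update-interval seconds
#   python3 monitor.py --max-num-files 1000 --update-interval 300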
logger.info(f"args: {args}") + + if args.elo_results_file is None: # Do live update + update_thread = threading.Thread( + target=update_worker, + args=(args.max_num_files, args.update_interval, args.elo_results_file), + ) + update_thread.start() + + demo = build_demo(args.elo_results_file, args.leaderboard_table_file) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/monitor/summarize_cluster.py b/monitor/summarize_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5fbcddc445b2434a35cefda5780f69a6cd8bca --- /dev/null +++ b/monitor/summarize_cluster.py @@ -0,0 +1,76 @@ +""" +Usage: +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100 +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200 +""" +import argparse +import pickle + +from fastchat.llm_judge.common import ( + chat_compeletion_openai, + chat_compeletion_openai_azure, + chat_compeletion_anthropic, +) +from fastchat.conversation import get_conv_template + + +def truncate_string(s, l): + half = int(l // 2) + return s[:half] + s[-half:] if len(s) > l else s + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="gpt-3.5-turbo") + parser.add_argument("--num-prompts", type=int, default=100) + args = parser.parse_args() + + model = args.model + + cluster_infos = pickle.load(open(args.input_file, "rb")) + num_total_prompts = sum([x[0] for x in cluster_infos]) + + topics = [] + percentages = [] + for i, info in enumerate(cluster_infos): + num_samples, topk_prompts, random_prompts = info + percentage = num_samples / num_total_prompts + print( + f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" + ) + instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." + split = int(args.num_prompts * 0.8) + prompt = "\n".join( + [truncate_string(x, l=200) for x in topk_prompts[:split]] + + [ + truncate_string(x, l=200) + for x in random_prompts[: args.num_prompts - split] + ] + ) + prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST." + + if "azure-" in model: + template_name = "chatgpt" + completion_func = chat_compeletion_openai_azure + elif "gpt" in model: + template_name = "chatgpt" + completion_func = chat_compeletion_openai + elif "claude" in model: + template_name = "claude" + completion_func = chat_compeletion_anthropic + + conv = get_conv_template(template_name) + conv.set_system_message(instruct) + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + + topic = completion_func(model, conv, temperature=0, max_tokens=256) + print(topic) + + topics.append(topic) + percentages.append(round(percentage, 6)) + + print() + print(f"topics: {topics}") + print(f"percentages: {percentages}") diff --git a/monitor/tag_openai_moderation.py b/monitor/tag_openai_moderation.py new file mode 100644 index 0000000000000000000000000000000000000000..b80703388b2a47bf372a09bbed81d7bede2bd412 --- /dev/null +++ b/monitor/tag_openai_moderation.py @@ -0,0 +1,63 @@ +""" +Add OpenAI moderation API results to all conversations. 
+""" +import argparse +from concurrent.futures import ThreadPoolExecutor +import json +import os +import time + +import openai +import requests +from tqdm import tqdm + + +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + + +def tag_moderation(text): + result = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + result = openai.Moderation.create(input=text)["results"][0] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return result + + +def tag_openai_moderation(x): + conv = x["conversation_a"] + user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) + result = tag_moderation(user_prompts) + x["openai_moderation"] = result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument("--first-n", type=int) + args = parser.parse_args() + + battles = json.load(open(args.input)) + + if args.first_n: + battles = battles[: args.first_n] + + with ThreadPoolExecutor(args.parallel) as executor: + for line in tqdm( + executor.map(tag_openai_moderation, battles), total=len(battles) + ): + pass + + output = args.input.replace(".json", "_tagged.json") + with open(output, "w") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/monitor/topic_clustering.py b/monitor/topic_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..dd15c6edca6666127f5ee441c41f55bb88878249 --- /dev/null +++ b/monitor/topic_clustering.py @@ -0,0 +1,267 @@ +""" + +Usage: +python3 topic_clustering.py --in arena.json --english-only --min-length 32 +python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536 +""" +import argparse +import json +import pickle +import string +import time + +import numpy as np +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim +from sklearn.cluster import KMeans, AgglomerativeClustering +import torch +from tqdm import tqdm + +from fastchat.utils import detect_language + + +def remove_punctuation(input_string): + # Make a translator object to remove all punctuation + translator = str.maketrans("", "", string.punctuation) + + # Use the translator object to remove the punctuation + no_punct = input_string.translate(translator) + return no_punct + + +def read_texts(input_file, min_length, max_length, english_only): + visited = set() + texts = [] + + lines = json.load(open(input_file, "r")) + + for l in tqdm(lines): + if "text" in l: + line_texts = [l["text"]] + elif "conversation_a" in l: + line_texts = [ + x["content"] for x in l["conversation_a"] if x["role"] == "user" + ] + elif "conversation" in l: + line_texts = [ + x["content"] for x in l["conversation"] if x["role"] == "user" + ] + + for text in line_texts: + text = text.strip() + + # Filter language + if english_only: + lang = detect_language(text) + if lang != "English": + continue + + # Filter short or long prompts + if min_length: + if len(text) < min_length: + continue + + if max_length: + if len(text) > max_length: + continue + + # De-duplication + words = sorted([x.lower() for x in remove_punctuation(text).split(" ")]) + words = "".join(words) + if words in visited: + continue + + visited.add(words) + texts.append(text) + return np.array(texts) + + 
+def get_embeddings(texts, model_name, batch_size): + model = SentenceTransformer(model_name) + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=True, + device="cuda", + convert_to_tensor=True, + ) + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings.cpu() + + +def run_k_means(embeddings, num_clusters): + np.random.seed(42) + clustering_model = KMeans(n_clusters=num_clusters, n_init="auto") + clustering_model.fit(embeddings.numpy()) + centers = torch.from_numpy(clustering_model.cluster_centers_) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + new_centers = torch.empty_like(centers) + for i, c in enumerate(classes): + new_labels[labels == c] = i + new_centers[i] = centers[c] + return new_centers, new_labels + + +def run_agg_cluster(embeddings, num_clusters): + np.random.seed(42) + clustering_model = AgglomerativeClustering(n_clusters=num_clusters) + clustering_model.fit(embeddings) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def run_hdbscan_cluster(embeddings): + import hdbscan + + np.random.seed(42) + clusterer = hdbscan.HDBSCAN(min_cluster_size=10) + labels = torch.from_numpy(clusterer.fit_predict(embeddings)) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def get_topk_indices(centers, labels, embeddings, topk): + indices = [] + arange = torch.arange(len(labels)) + counts = torch.unique(labels, return_counts=True)[1] + topk = min(topk, counts.min().item()) + for i in range(len(centers)): + tmp_indices = labels == i + tmp_arange = arange[tmp_indices] + tmp_embeddings = embeddings[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0)) + return torch.cat(indices) + + +def print_topk(texts, labels, topk_indices, show_cut_off): + ret = "" + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + + ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n" + for idx in topk_indices[k]: + ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n" + ret += "=" * 40 + "\n\n" + + return ret + + +def get_cluster_info(texts, labels, topk_indices): + np.random.seed(42) + + cluster_info = [] + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + topk_prompts = [] + for idx in topk_indices[k]: + topk_prompts.append(texts[idx]) + random_prompts 
= [] + for idx in range(len(topk_indices)): + random_prompts.append(np.random.choice(texts)) + cluster_info.append((num_samples, topk_prompts, random_prompts)) + + return cluster_info + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="all-mpnet-base-v2") + # default="all-MiniLM-L12-v2") + # default="multi-qa-distilbert-cos-v1") + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--min-length", type=int) + parser.add_argument("--max-length", type=int) + parser.add_argument("--english-only", action="store_true") + parser.add_argument("--num-clusters", type=int, default=20) + parser.add_argument( + "--cluster-alg", + type=str, + choices=["kmeans", "aggcls", "HDBSCAN"], + default="kmeans", + ) + parser.add_argument("--show-top-k", type=int, default=200) + parser.add_argument("--show-cut-off", type=int, default=512) + args = parser.parse_args() + + num_clusters = args.num_clusters + show_top_k = args.show_top_k + show_cut_off = args.show_cut_off + + texts = read_texts( + args.input_file, args.min_length, args.max_length, args.english_only + ) + print(f"#text: {len(texts)}") + + embeddings = get_embeddings(texts, args.model, args.batch_size) + if args.cluster_alg == "kmeans": + centers, labels = run_k_means(embeddings, num_clusters) + elif args.cluster_alg == "aggcls": + centers, labels = run_agg_cluster(embeddings, num_clusters) + elif args.cluster_alg == "HDBSCAN": + centers, labels = run_hdbscan_cluster(embeddings) + else: + raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}") + + topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k) + topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off) + num_clusters = len(centers) + + # Dump results + filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}" + print(topk_str) + with open(filename_prefix + "_topk.txt", "w") as fout: + fout.write(topk_str) + + with open(filename_prefix + "_all.txt", "w") as fout: + for i in range(len(centers)): + tmp_indices = labels == i + tmp_embeddings = embeddings[tmp_indices] + tmp_texts = texts[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + + for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]): + obj = {"cluster": i, "text": text, "sim": score.item()} + fout.write(json.dumps(obj, ensure_ascii=False) + "\n") + + cluster_info = get_cluster_info(texts, labels, topk_indices) + with open(filename_prefix + "_cluster.pkl", "wb") as fout: + pickle.dump(cluster_info, fout) diff --git a/multi_model_worker.py b/multi_model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..f77ff444790e51c7b4933765d6079fa41fe55ab1 --- /dev/null +++ b/multi_model_worker.py @@ -0,0 +1,282 @@ +""" +A multi-model worker that contains multiple sub-works one for each model. This +supports running a list of models on the same machine so that they can +(potentially) share the same background weights. + +Each model can have one or more model names. + +This multi-model worker assumes the models shares some underlying weights and +thus reports the combined queue lengths for health checks. + +We recommend using this with multiple Peft models (with `peft` in the name) +where all Peft models are trained on the exact same base model. 
+""" +import argparse +import asyncio +import dataclasses +import logging +import json +import os +import time +from typing import List, Union +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_conversation_template, +) +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.serve.inference import generate_stream +from fastchat.serve.model_worker import ModelWorker, worker_id, logger +from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length + + +# We store both the underlying workers and a mapping from their model names to +# the worker instance. This makes it easy to fetch the appropriate worker for +# each API call. +workers = [] +worker_map = {} +app = FastAPI() + + +def release_worker_semaphore(): + workers[0].semaphore.release() + + +def acquire_worker_semaphore(): + if workers[0].semaphore is None: + # Share the same semaphore for all workers because + # all workers share the same GPU. + semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency) + for w in workers: + w.semaphore = semaphore + return workers[0].semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +# Note: for all the calls below, we make a hard assumption that the caller +# includes the model name in the payload, otherwise we can't figure out which +# underlying sub-worker to call. 
+ + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + embedding = worker.get_embeddings(params) + background_tasks = create_background_tasks() + return JSONResponse(content=embedding, background=background_tasks) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_multi_model_worker(): + # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST + # of the model args but we'll override one to have an append action that + # supports multiple values. + parser = argparse.ArgumentParser(conflict_handler="resolve") + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + # Override the model path to be repeated and align it with model names. + parser.add_argument( + "--model-path", + type=str, + default=[], + action="append", + help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + action="append", + help="One or more model names. Values must be aligned with `--model-path` values.", + ) + parser.add_argument( + "--conv-template", + type=str, + default=None, + action="append", + help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 
+ ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + if args.model_names is None: + args.model_names = [[x.split("/")[-1]] for x in args.model_path] + + if args.conv_template is None: + args.conv_template = [None] * len(args.model_path) + elif len(args.conv_template) == 1: # Repeat the same template + args.conv_template = args.conv_template * len(args.model_path) + + # Launch all workers + workers = [] + for conv_template, model_path, model_names in zip( + args.conv_template, args.model_path, args.model_names + ): + w = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + model_names, + args.limit_worker_concurrency, + args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=conv_template, + ) + workers.append(w) + for model_name in model_names: + worker_map[model_name] = w + + # Register all models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_multi_model_worker() + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/openai_api_server.py b/openai_api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..c15527f4c12ee1c1261ae3ba005f6e1e530b483f --- /dev/null +++ b/openai_api_server.py @@ -0,0 +1,879 @@ +"""A server that provides OpenAI-compatible RESTful APIs. It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. 
(Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio +import argparse +import json +import logging +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import aiohttp +import fastapi +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +import httpx +from pydantic import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import ( + WORKER_API_TIMEOUT, + WORKER_API_EMBEDDING_BATCH_SIZE, + ErrorCode, +) +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) + +logger = logging.getLogger(__name__) + +conv_template_map = {} + +fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600) + + +async def fetch_remote(url, pload=None, name=None): + async with aiohttp.ClientSession(timeout=fetch_timeout) as session: + async with session.post(url, json=pload) as response: + chunks = [] + if response.status != 200: + ret = { + "text": f"{response.reason}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return json.dumps(ret) + + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks) + + if name is not None: + res = json.loads(output) + if name != "": + res = res[name] + return res + + return output + + +class AppSettings(BaseSettings): + # The address of the model controller. 
+ controller_address: str = "http://localhost:21001" + api_keys: Optional[List[str]] = None + + +app_settings = AppSettings() +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if app_settings.api_keys: + if auth is None or (token := auth.credentials) not in app_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).dict(), status_code=400 + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + + models = await fetch_remote(controller_address + "/list_models", None, "models") + if request.model not in models: + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens, worker_addr): + if ( + not isinstance(max_tokens, int) or max_tokens <= 0 + ): # model worker not support max_tokens=None + max_tokens = 1024 * 1024 + + context_len = await fetch_remote( + worker_addr + "/model_details", {"model": request.model}, "context_length" + ) + token_num = await fetch_remote( + worker_addr + "/count_token", + {"model": request.model, "prompt": prompt}, + "count", + ) + return min(max_tokens, context_len - token_num) + + +def check_requests(request) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. 
Either set top_k to -1 or >=1.", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + +def process_input(model_name, inp): + if isinstance(inp, str): + inp = [inp] + elif isinstance(inp, list): + if isinstance(inp[0], int): + decoding = tiktoken.model.encoding_for_model(model_name) + inp = [decoding.decode(inp)] + elif isinstance(inp[0], list): + decoding = tiktoken.model.encoding_for_model(model_name) + inp = [decoding.decode(text) for text in inp] + + return inp + + +def create_openai_logprobs(logprob_dict): + """Create OpenAI-style logprobs.""" + return LogProbs(**logprob_dict) if logprob_dict is not None else None + + +def _add_to_set(s, new_stop): + if not s: + return + if isinstance(s, str): + new_stop.add(s) + else: + new_stop.update(s) + + +async def get_gen_params( + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, +) -> Dict[str, Any]: + conv = await get_conv(model_name, worker_addr) + conv = Conversation( + name=conv["name"], + system_template=conv["system_template"], + system_message=conv["system_message"], + roles=conv["roles"], + messages=list(conv["messages"]), # prevent in-place modification + offset=conv["offset"], + sep_style=SeparatorStyle(conv["sep_style"]), + sep=conv["sep"], + sep2=conv["sep2"], + stop_str=conv["stop_str"], + stop_token_ids=conv["stop_token_ids"], + ) + + if isinstance(messages, str): + prompt = messages + else: + for message in messages: + msg_role = message["role"] + if msg_role == "system": + conv.set_system_message(message["content"]) + elif msg_role == "user": + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. 
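            # (The empty assistant turn makes conv.get_prompt() end with the
            # assistant role tag, so the model generates its reply from there.)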
+ conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "logprobs": logprobs, + "top_p": top_p, + "top_k": top_k, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "max_new_tokens": max_tokens, + "echo": echo, + "stop_token_ids": conv.stop_token_ids, + } + + if best_of is not None: + gen_params.update({"best_of": best_of}) + if use_beam_search is not None: + gen_params.update({"use_beam_search": use_beam_search}) + + new_stop = set() + _add_to_set(stop, new_stop) + _add_to_set(conv.stop_str, new_stop) + + gen_params["stop"] = list(new_stop) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def get_worker_address(model_name: str) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + worker_addr = await fetch_remote( + controller_address + "/get_worker_address", {"model": model_name}, "address" + ) + + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +async def get_conv(model_name: str, worker_addr: str): + conv_template = conv_template_map.get((worker_addr, model_name)) + if conv_template is None: + conv_template = await fetch_remote( + worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv" + ) + conv_template_map[(worker_addr, model_name)] = conv_template + return conv_template + + +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def show_available_models(): + controller_address = app_settings.controller_address + ret = await fetch_remote(controller_address + "/refresh_all_workers") + models = await fetch_remote(controller_address + "/list_models", None, "models") + + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + + +@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + gen_params["max_new_tokens"] = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = 
asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + if "usage" in content: + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=content.get("finish_reason", None), + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". 
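        # For reference, each event yielded above is one Server-Sent Events line of
        # the form (illustrative values only):
        #   data: {"id": "chatcmpl-...", "model": "vicuna-7b-v1.5", "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]}
        # followed by a blank line; the stream is closed with the "data: [DONE]" sentinel.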
+ for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +@app.post("/v1/completions", dependencies=[Depends(check_api_key)]) +async def create_completion(request: CompletionRequest): + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + request.prompt = process_input(request.model, request.prompt) + + worker_addr = await get_worker_address(request.model) + for text in request.prompt: + max_tokens = await check_length(request, text, request.max_tokens, worker_addr) + if isinstance(max_tokens, int) and max_tokens < request.max_tokens: + request.max_tokens = max_tokens + + if request.stream: + generator = generate_completion_stream_generator( + request, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for text in request.prompt: + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + best_of=request.best_of, + use_beam_search=request.use_beam_search, + ) + for i in range(request.n): + content = asyncio.create_task( + generate_completion(gen_params, worker_addr) + ) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) + ) + + +async def generate_completion_stream_generator( + request: CompletionRequest, n: int, worker_addr: str +): + model_name = request.model + id = f"cmpl-{shortuuid.random()}" + finish_stream_events = [] + for text in request.prompt: + for i in range(n): + previous_text = "" + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + ) + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + 
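                # Workers stream the full cumulative text on every chunk, so
                # delta_text above is only the newly generated suffix that has not
                # yet been sent to the client.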
# todo: index is not apparent + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, + object="text_completion", + choices=[choice_data], + model=model_name, + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + buffer = b"" + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while (chunk_end := buffer.find(delimiter)) >= 0: + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + if not chunk: + continue + yield json.loads(chunk.decode()) + + +async def generate_completion(payload: Dict[str, Any], worker_addr: str): + return await fetch_remote(worker_addr + "/worker_generate", payload, "") + + +@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)]) +@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)]) +async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): + """Creates embeddings for the text""" + if request.model is None: + request.model = model_name + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + request.input = process_input(request.model, request.input) + + data = [] + token_num = 0 + batch_size = WORKER_API_EMBEDDING_BATCH_SIZE + batches = [ + request.input[i : min(i + batch_size, len(request.input))] + for i in range(0, len(request.input), batch_size) + ] + for num_batch, batch in enumerate(batches): + payload = { + "model": request.model, + "input": batch, + "encoding_format": request.encoding_format, + } + embedding = await get_embedding(payload) + if "error_code" in embedding and embedding["error_code"] != 0: + return create_error_response(embedding["error_code"], embedding["text"]) + data += [ + { + "object": "embedding", + "embedding": emb, + "index": num_batch * batch_size + i, + } + for i, emb in enumerate(embedding["embedding"]) + ] + token_num += embedding["token_num"] + return EmbeddingsResponse( + data=data, + model=request.model, + usage=UsageInfo( + prompt_tokens=token_num, + total_tokens=token_num, + completion_tokens=None, + ), + ).dict(exclude_none=True) + + +async def get_embedding(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + model_name = payload["model"] + worker_addr = await get_worker_address(model_name) + + embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload) + return json.loads(embedding) + + +### GENERAL API - NOT OPENAI COMPATIBLE ### + + +@app.post("/api/v1/token_check") +async def count_tokens(request: APITokenCheckRequest): + """ + Checks the token 
count for each message in your list + This is not part of the OpenAI API spec. + """ + checkedList = [] + for item in request.prompts: + worker_addr = await get_worker_address(item.model) + + context_len = await fetch_remote( + worker_addr + "/model_details", + {"prompt": item.prompt, "model": item.model}, + "context_length", + ) + + token_num = await fetch_remote( + worker_addr + "/count_token", + {"prompt": item.prompt, "model": item.model}, + "count", + ) + + can_fit = True + if token_num + item.max_tokens > context_len: + can_fit = False + + checkedList.append( + APITokenCheckResponseItem( + fits=can_fit, contextLength=context_len, tokenCount=token_num + ) + ) + + return APITokenCheckResponse(prompts=checkedList) + + +@app.post("/api/v1/chat/completions") +async def create_chat_completion(request: APIChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + if request.repetition_penalty is not None: + gen_params["repetition_penalty"] = request.repetition_penalty + + gen_params["max_new_tokens"] = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +### END GENERAL API - NOT OPENAI COMPATIBLE ### + + +def create_openai_api_server(): + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." 
+ ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + app_settings.api_keys = args.api_keys + + logger.info(f"args: {args}") + return args + + +if __name__ == "__main__": + args = create_openai_api_server() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/register_worker.py b/register_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2 --- /dev/null +++ b/register_worker.py @@ -0,0 +1,26 @@ +""" +Manually register workers. 
+ +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/shutdown_serve.py b/shutdown_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..95e2b704f0b65584c5be15ce14b40bc150bd6009 --- /dev/null +++ b/shutdown_serve.py @@ -0,0 +1,24 @@ +""" +Usage: +python shutdown_serve.py --down all +options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers +""" + +import argparse +import os +import subprocess + +parser = argparse.ArgumentParser() +parser.add_argument( + "--down", choices=["all", "controller", "model_worker", "openai_api_server"] +) +args = parser.parse_args() +base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" +if args.down == "all": + shell_script = base_shell.format("") +else: + serve = f".{args.down}" + shell_script = base_shell.format(serve) +print(f"execute shell cmd: {shell_script}") +subprocess.run(shell_script, shell=True, check=True) +print(f"{args.down} has been shutdown!") diff --git a/test_message.py b/test_message.py new file mode 100644 index 0000000000000000000000000000000000000000..3d83f0fe7a6c06b285dc517d7b32e236d7867d88 --- /dev/null +++ b/test_message.py @@ -0,0 +1,82 @@ +"""Send a test message.""" +import argparse +import json + +import requests + +from fastchat.model.model_adapter import get_conversation_template +from fastchat.conversation import get_conv_template + +def main(): + model_name = args.model_name + conv_template = args.conv_template + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + print(f"No available workers for {model_name}") + return + + # conv = get_conversation_template(model_name) + conv = get_conv_template(conv_template) + conv.append_message(conv.roles[0], args.message) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + headers = {"User-Agent": "FastChat Client"} + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + ) + + print(f"{conv.roles[0]}: {args.message}") + print(f"{conv.roles[1]}: ", end="") + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + 
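            # Each b"\0"-delimited chunk decodes to a JSON object, e.g. (illustrative):
            #   {"text": "Once upon a time", "error_code": 0}
            # where "text" is the cumulative output so far, so only the tail past
            # `prev` is printed below.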
output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, required=True) + parser.add_argument("--conv-template", type=str, required=True) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument( + "--message", type=str, default="Tell me a story with more than 1000 words." + ) + args = parser.parse_args() + + main() diff --git a/test_throughput.py b/test_throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..3796a6e2a7cb53dc6921674fc4c488246e0b93c7 --- /dev/null +++ b/test_throughput.py @@ -0,0 +1,115 @@ +"""Benchmarking script to test the throughput of serving workers.""" +import argparse +import json + +import requests +import threading +import time + +from fastchat.conversation import get_conv_template + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") + prompt_template = conv.get_prompt() + prompts = [prompt_template for _ in range(args.n_thread)] + + headers = {"User-Agent": "fastchat Client"} + ploads = [ + { + "model": args.model_name, + "prompt": prompts[i], + "max_new_tokens": args.max_new_tokens, + "temperature": 0.0, + # "stop": conv.sep, + } + for i in range(len(prompts)) + ] + + def send_request(results, i): + if args.test_dispatch: + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + thread_worker_addr = ret.json()["address"] + else: + thread_worker_addr = worker_addr + print(f"thread {i} goes to {thread_worker_addr}") + response = requests.post( + thread_worker_addr + "/worker_generate_stream", + headers=headers, + json=ploads[i], + stream=False, + ) + k = list( + response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") + ) + # print(k) + response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + error_code = json.loads(k[-2].decode("utf-8"))["error_code"] + # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") + results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + + # use N threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_thread + for i in range(args.n_thread): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + # time.sleep(0.5) + threads.append(t) + + for t in threads: + t.join() + + print(f"Time (POST): {time.time() - tik} s") + # n_words = 0 + # for i, response in enumerate(results): + # # print(prompt[i].replace(conv.sep, "\n"), end="") + # # make sure the streaming finishes at EOS or stopping criteria + # k = list(response.iter_lines(chunk_size=8192, 
decode_unicode=False, delimiter=b"\0")) + # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + # # print(response_new_words) + # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + n_words = sum(results) + time_seconds = time.time() - tik + print( + f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " + f"throughput: {n_words / time_seconds} words/s." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="vicuna") + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--n-thread", type=int, default=8) + parser.add_argument("--test-dispatch", action="store_true") + args = parser.parse_args() + + main() diff --git a/vllm_worker.py b/vllm_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..6428d8b442f188add25b5db9ff04c86bc9bfc7fd --- /dev/null +++ b/vllm_worker.py @@ -0,0 +1,271 @@ +""" +A model worker that executes the model based on vLLM. + +See documentations at docs/vllm_integration.md +""" + +import argparse +import asyncio +import json +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from vllm import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length + + +app = FastAPI() + + +class VLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: AsyncLLMEngine, + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." 
+ ) + self.tokenizer = llm_engine.engine.tokenizer + self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + stop.add(self.tokenizer.decode(tid)) + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + n=1, + temperature=temperature, + top_p=top_p, + use_beam_search=use_beam_search, + stop=list(stop), + max_tokens=max_new_tokens, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + best_of=best_of, + ) + results_generator = engine.generate(context, sampling_params, request_id) + + async for request_output in results_generator: + prompt = request_output.prompt + if echo: + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + else: + text_outputs = [output.text for output in request_output.outputs] + text_outputs = " ".join(text_outputs) + # Note: usage is not supported yet + prompt_tokens = len(request_output.prompt_token_ids) + completion_tokens = sum( + len(output.token_ids) for output in request_output.outputs + ) + ret = { + "text": text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": [ + output.cumulative_logprob for output in request_output.outputs + ], + "finish_reason": request_output.outputs[0].finish_reason + if len(request_output.outputs) == 1 + else [output.finish_reason for output in request_output.outputs], + } + yield (json.dumps(ret) + "\0").encode() + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await engine.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = 
request_id + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + output = await worker.generate(params) + release_worker_semaphore() + await engine.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--gpu_memory_utilization", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.model_path: + args.model = args.model_path + if args.num_gpus > 1: + args.tensor_parallel_size = args.num_gpus + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + worker = VLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + engine, + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info")
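# Example end-to-end smoke test (a minimal sketch, assuming the standard FastChat
# module layout, default ports, and a locally available model; adjust as needed):
#
#   python3 -m fastchat.serve.controller
#   python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5
#   python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "vicuna-7b-v1.5", "messages": [{"role": "user", "content": "Hello!"}]}'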