""" This module contains functions for generating responses using LLMs. """ import enum import logging from random import sample from typing import List from uuid import uuid4 from firebase_admin import firestore import gradio as gr from leaderboard import db from model import ContextWindowExceededError from model import Model from model import supported_models logging.basicConfig() logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def get_history_collection(category: str): if category == Category.SUMMARIZE.value: return db.collection("arena-summarization-history") if category == Category.TRANSLATE.value: return db.collection("arena-translation-history") def create_history(category: str, model_name: str, instruction: str, prompt: str, response: str): doc_id = uuid4().hex doc = { "id": doc_id, "model": model_name, "instruction": instruction, "prompt": prompt, "response": response, "timestamp": firestore.SERVER_TIMESTAMP } doc_ref = get_history_collection(category).document(doc_id) doc_ref.set(doc) class Category(enum.Enum): SUMMARIZE = "Summarize" TRANSLATE = "Translate" # TODO(#31): Let the model builders set the instruction. def get_instruction(category: str, model: Model, source_lang: str, target_lang: str): if category == Category.SUMMARIZE.value: return model.summarize_instruction if category == Category.TRANSLATE.value: return model.translate_instruction.format(source_lang=source_lang, target_lang=target_lang) def get_responses(prompt: str, category: str, source_lang: str, target_lang: str): if not category: raise gr.Error("Please select a category.") if category == Category.TRANSLATE.value and (not source_lang or not target_lang): raise gr.Error("Please select source and target languages.") models: List[Model] = sample(list(supported_models), 2) responses = [] for model in models: instruction = get_instruction(category, model, source_lang, target_lang) try: # TODO(#1): Allow user to set configuration. response = model.completion(messages=[{ "role": "system", "content": instruction }, { "role": "user", "content": prompt }]) create_history(category, model.name, instruction, prompt, response) responses.append(response) except ContextWindowExceededError as e: logger.exception("Context window exceeded for model %s.", model.name) raise gr.Error( "The prompt is too long. Please try again with a shorter prompt." ) from e except Exception as e: logger.exception("Failed to get response from model %s.", model.name) raise gr.Error("Failed to get response. Please try again.") from e model_names = [model.name for model in models] # It simulates concurrent stream response generation. max_response_length = max(len(response) for response in responses) for i in range(max_response_length): yield [response[:i + 1] for response in responses ] + model_names + [instruction] yield responses + model_names + [instruction]