"""Gradio demo: chat with *base* (non-instruction-tuned) LLMs aligned in-context
via a URIAL prompt, served through the Together API."""

import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
from constant import js_code_label, HEADER_MD
from openai import OpenAI
import datetime

# add logging info to console
logging.basicConfig(level=logging.INFO)

# URIAL prompt: a fixed in-context alignment prefix, fetched once at startup.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
urial_prompt = urllib.request.urlopen(URIAL_URL).read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""')  # new version of URIAL uses """ instead of ```

# Base models are not chat-tuned, so generation is cut at these URIAL markers.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Per-client-IP request counters for the shared API key, reset every 24 hours.
addr_limit_counter = {}
LAST_UPDATE_TIME = datetime.datetime.now()

# Daily request budget per IP when using the shared (default) API key.
RATE_LIMIT = 100


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key,
    request: gr.Request,
):
    """Stream a URIAL-prompted completion for `message`.

    Yields the partial response string after each streamed token so Gradio's
    ChatInterface can render it incrementally.

    Note: `max_tokens`, `temperature`, and `top_p` arrive as *strings* because
    they are bound to gr.Textbox components; they are cast to numbers here.
    """
    global addr_limit_counter, LAST_UPDATE_TIME

    rp = 1.0  # repetition penalty is intentionally pinned; the UI value is ignored
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    # Use the caller's Together key when it looks valid (64 chars); otherwise
    # fall back to the shared, rate-limited default key.
    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    # Reset all per-IP counters once 24 hours have passed since the last reset.
    if datetime.datetime.now() - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = datetime.datetime.now()

    host_addr = request.client.host
    if host_addr not in addr_limit_counter:
        addr_limit_counter[host_addr] = 0
    # BUG FIX: this function is a generator, so a bare `return "..."` discarded
    # the message (Gradio never saw it) — it must be yielded. Also `> 100`
    # allowed 101 requests while the message promised a limit of 100.
    if addr_limit_counter[host_addr] >= RATE_LIMIT:
        yield f"You have reached the limit of {RATE_LIMIT} requests for today. Please use your own API key."
        return

    # BUG FIX: Textbox values are strings; cast to the numeric types the API expects.
    infer_request = openai_base_request(
        prompt=prompt,
        model=_model_name,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        repetition_penalty=rp,
        stop=STOP_STRS,
        api_key=api_key,
    )
    addr_limit_counter[host_addr] += 1
    logging.info("Requesting chat completion from OpenAI API with model %s", _model_name)
    logging.info("addr_limit_counter: %s; Last update time: %s;", addr_limit_counter, LAST_UPDATE_TIME)

    response = ""
    for msg in infer_request:
        token = msg.choices[0].delta["content"]
        if token is None:
            # BUG FIX: role-only / final stream chunks carry no content;
            # `response += None` would raise TypeError.
            continue
        # Stop as soon as any stop marker would appear in the accumulated text.
        if any(stop in response + token for stop in STOP_STRS):
            break
        response += token
        # Trim a trailing quote fragment left over from the URIAL """ delimiter.
        if response.endswith('\n"'):
            response = response[:-1]
        elif response.endswith('\n""'):
            response = response[:-2]
        yield response


with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(
                ["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1", "Mixtral-8x22B",
                 "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"],
                value="Llama-3-8B",
                label="Base LLM name",
            )
        with gr.Column():
            together_api_key = gr.Textbox(
                label="🔑 Together APIKey",
                placeholder="Enter your Together API Key. Leave it blank to use our key with limited usage.",
                type="password",
                elem_id="api_key",
            )
        with gr.Column():
            with gr.Row():
                max_tokens = gr.Textbox(value=256, label="Max tokens")
                temperature = gr.Textbox(value=0.5, label="Temperature")
                top_p = gr.Textbox(value=0.9, label="Top-p")
                rp = gr.Textbox(value=1.1, label="Repetition penalty")

    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
        # additional_inputs_accordion="⚙️ Parameters",
        # fill_height=True,
    )
    chat.chatbot.label = "Chat with Base LLMs via URIAL"
    chat.chatbot.height = 550
    chat.chatbot.show_copy_button = True

if __name__ == "__main__":
    demo.launch(show_api=False)