|
import gradio as gr |
|
import os |
|
from typing import List |
|
import logging |
|
import urllib.request |
|
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY |
|
from constant import js_code_label, HEADER_MD |
|
from openai import OpenAI |
|
import datetime |
|
|
|
logging.basicConfig(level=logging.INFO)

# --- URIAL in-context prompt --------------------------------------------------
# The URIAL prefix is downloaded once, at import time, from the Re-Align/URIAL
# GitHub repository; the app cannot serve requests without it.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# Seconds to wait for the download before failing startup instead of hanging forever.
URIAL_FETCH_TIMEOUT = 30

# Use a context manager so the HTTP response is always closed.
with urllib.request.urlopen(URIAL_URL, timeout=URIAL_FETCH_TIMEOUT) as _urial_resp:
    urial_prompt = _urial_resp.read().decode('utf-8')

# Replace markdown fences with triple double quotes; presumably so the prompt's
# delimiters match the '"""' entry in STOP_STRS below.
urial_prompt = urial_prompt.replace("```", '"""')

# Streaming output is cut as soon as any of these strings appears.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Per-client-address request counts; respond() clears them once a day has
# passed since LAST_UPDATE_TIME.
addr_limit_counter: dict = {}
LAST_UPDATE_TIME = datetime.datetime.now()
|
|
|
# Maximum requests per client address per day when the shared DEFAULT_API_KEY is used.
_DAILY_REQUEST_LIMIT = 100


def _earliest_stop_index(text, stop_strs):
    """Return the start index of the first stop string found in *text*, or -1 if none occurs."""
    hits = [pos for pos in (text.find(stop) for stop in stop_strs) if pos != -1]
    return min(hits) if hits else -1


def _hide_partial_quote_stop(text):
    """Return *text* without a trailing newline followed by one or two double quotes.

    Such a tail may be the beginning of the triple-quote stop marker split across
    streamed tokens. It is trimmed only from the displayed copy, never from the
    accumulated buffer, so the complete marker can still be detected later.
    """
    if text.endswith('\n""'):
        return text[:-2]
    if text.endswith('\n"'):
        return text[:-1]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key,
    request: gr.Request,
):
    """Stream a URIAL-prompted completion from a base LLM to the Gradio chat UI.

    Args:
        message: The user's latest chat message.
        history: Prior (user, assistant) turns; formatted into the prompt by
            ``urial_template`` together with the module-level ``urial_prompt``.
        max_tokens: Maximum tokens to generate (UI textbox value; parsed to int).
        temperature: Sampling temperature (UI textbox value; parsed to float).
        top_p: Nucleus-sampling threshold (UI textbox value; parsed to float).
        rp: Repetition penalty from the UI; currently overridden with 1.0.
        model_name: UI display name, mapped to a provider id by ``model_name_mapping``.
        together_api_key: The user's own key. It is used when it is exactly 64
            characters after stripping whitespace; otherwise the shared
            ``DEFAULT_API_KEY`` is used.
        request: Gradio request object; ``request.client.host`` keys the rate limit.

    Yields:
        The accumulated response text after each streamed chunk. Output is cut at
        the earliest stop string in ``STOP_STRS``. When the daily limit for the
        shared key is exhausted, a single explanatory message is yielded instead.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    # NOTE(review): the UI repetition penalty is deliberately ignored and replaced
    # with 1.0 (as in the original code) — confirm this is still intended.
    rp = 1.0
    # Gradio textboxes deliver strings; hand the request builder real numbers.
    max_tokens = int(float(max_tokens))
    temperature = float(temperature)
    top_p = float(top_p)

    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    user_key = (together_api_key or "").strip()
    using_default_key = len(user_key) != 64
    api_key = DEFAULT_API_KEY if using_default_key else user_key

    # Clear every per-address counter once more than a day has elapsed.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    # Only the shared key is rate limited; users with their own key are not.
    if using_default_key and addr_limit_counter.get(host_addr, 0) >= _DAILY_REQUEST_LIMIT:
        # This is a generator: a `return value` would never reach the UI, so yield it.
        yield (f"You have reached the limit of {_DAILY_REQUEST_LIMIT} requests for today. "
               "Please use your own API key.")
        return

    infer_request = openai_base_request(prompt=prompt, model=_model_name,
                                        temperature=temperature,
                                        max_tokens=max_tokens,
                                        top_p=top_p,
                                        repetition_penalty=rp,
                                        stop=STOP_STRS, api_key=api_key)
    if using_default_key:
        addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1

    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")

    buffer = ""  # raw accumulated text; never trimmed, so stop markers stay detectable
    for chunk in infer_request:
        # NOTE(review): assumes delta is subscriptable, as in the original code;
        # `or ""` guards a None content value (common on final stream chunks — confirm).
        buffer += chunk.choices[0].delta["content"] or ""
        cut = _earliest_stop_index(buffer, STOP_STRS)
        if cut != -1:
            # Show everything before the stop marker (including text that arrived
            # in the same chunk), then end the stream.
            yield buffer[:cut]
            return
        yield _hide_partial_quote_stop(buffer)
|
|
|
# Gradio UI: model picker, optional API key, sampling parameters, and a
# streaming chat panel whose replies come from respond().
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            # UI display names; respond() converts the selection via model_name_mapping().
            model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
                                   "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
                                  , value="Llama-3-8B", label="Base LLM name")
        with gr.Column():
            # Optional user key; respond() uses DEFAULT_API_KEY unless this is 64 characters long.
            # NOTE(review): the leading "π" in the label looks like a mis-encoded emoji — confirm.
            # NOTE(review): elem_id="api_key" is presumably referenced by js_code_label — confirm.
            together_api_key = gr.Textbox(label="π Together APIKey", placeholder="Enter your Together API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key")
            with gr.Column():
                with gr.Row():
                    # NOTE(review): Textbox values likely reach respond() as strings
                    # despite the numeric defaults — confirm downstream parsing.
                    max_tokens = gr.Textbox(value=256, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    # respond() currently overrides this value with 1.0.
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
    # The order of additional_inputs matches respond()'s parameters after (message, history).
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
    )
    # Adjust the chatbot component created by ChatInterface.
    chat.chatbot.label="Chat with Base LLMs via URIAL"
    chat.chatbot.height = 550
    chat.chatbot.show_copy_button = True
|
|
|
def _main():
    """Start the Gradio server for ``demo`` with the API docs page hidden."""
    demo.launch(show_api=False)


if __name__ == "__main__":
    _main()