File size: 4,408 Bytes
0231e6a
 
 
 
54f7da0
8df0f23
 
1a5890e
a1a9059
0231e6a
 
 
 
54f7da0
 
 
 
0231e6a
a1a9059
 
460e2a9
 
 
 
 
 
 
0231e6a
 
a1a9059
 
0231e6a
a1a9059
0231e6a
 
8df0f23
0231e6a
8df0f23
460e2a9
0231e6a
 
 
 
460e2a9
a1a9059
 
 
 
 
 
 
 
 
 
 
 
0231e6a
 
 
 
54f7da0
a1a9059
 
 
 
460e2a9
a1a9059
0231e6a
 
 
54f7da0
9e82682
0231e6a
 
 
 
9e82682
 
 
 
 
 
0231e6a
9e82682
0231e6a
 
8df0f23
a19b85a
9374880
a19b85a
0231e6a
302880f
0231e6a
 
8359b58
0231e6a
 
 
 
 
 
 
 
 
d5ed8fa
f358cdd
d5ed8fa
460e2a9
a1a9059
ddd17ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr 
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
from constant import js_code_label, HEADER_MD
from openai import OpenAI
import datetime
# add logging info to console 
logging.basicConfig(level=logging.INFO)

# The URIAL prompt is downloaded once, at import time, from the URL below.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# Use a context manager so the HTTP response is closed deterministically
# instead of being left for the garbage collector.
with urllib.request.urlopen(URIAL_URL) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
# Generation is cut off as soon as any of these markers shows up in the output.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Request count per client host address; cleared by respond() once a day.
addr_limit_counter = {}
# Time at which addr_limit_counter was last reset.
LAST_UPDATE_TIME = datetime.datetime.now() 

# Maximum number of requests a single client address may make per day
# when it relies on the shared DEFAULT_API_KEY.
DAILY_REQUEST_LIMIT = 100


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key,
    request:gr.Request
):  
    """Stream a chat reply for the Gradio chat interface.

    Builds a URIAL prompt from ``history`` and ``message``, sends it through
    ``openai_base_request`` and yields the reply as it grows, stopping as soon
    as one of ``STOP_STRS`` would appear in the text.

    Args:
        message: The latest user message.
        history: Previous turns of the conversation.
        max_tokens: Forwarded to the request as ``max_tokens``.
        temperature: Forwarded to the request as ``temperature``.
        top_p: Forwarded to the request as ``top_p``.
        rp: Repetition penalty from the UI; currently overridden (see below).
        model_name: Display name, translated with ``model_name_mapping``.
        together_api_key: User-supplied key; used only when it is exactly
            64 characters long, otherwise ``DEFAULT_API_KEY`` is used.
        request: The Gradio request, used to read the client host address.

    Yields:
        The accumulated response text after each streamed token, or a single
        quota message when the client has exhausted the shared-key limit.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    # NOTE(review): the value coming from the UI is discarded and the
    # repetition penalty is pinned to 1.0 -- confirm this is intentional.
    rp = 1.0
    # NOTE(review): max_tokens/temperature/top_p come from Textbox inputs and
    # are forwarded unchanged; presumably openai_base_request accepts them
    # as-is -- verify.
    prompt = urial_template(urial_prompt, history, message)

    _model_name = model_name_mapping(model_name)

    uses_own_key = bool(together_api_key) and len(together_api_key) == 64
    api_key = together_api_key if uses_own_key else DEFAULT_API_KEY

    # The daily quota protects the shared key only: callers who bring their
    # own key are neither counted nor blocked, which is what the quota
    # message tells them to do.
    host_addr = None
    if not uses_own_key:
        now = datetime.datetime.now()
        # if already 24 hours passed, reset the counter
        if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
            addr_limit_counter = {}
            LAST_UPDATE_TIME = now
        host_addr = request.client.host
        if addr_limit_counter.get(host_addr, 0) >= DAILY_REQUEST_LIMIT:
            # This function is a generator, so the message must be yielded;
            # a plain ``return value`` would never reach the chat window.
            yield f"You have reached the limit of {DAILY_REQUEST_LIMIT} requests for today. Please use your own API key."
            return

    infer_request = openai_base_request(prompt=prompt, model=_model_name, 
                                   temperature=temperature, 
                                   max_tokens=max_tokens, 
                                   top_p=top_p, 
                                   repetition_penalty=rp,
                                   stop=STOP_STRS, api_key=api_key)  
    if not uses_own_key:
        addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")

    response = ""
    for msg in infer_request:
        # A chunk may carry no text; treat None as an empty token so the
        # concatenation below cannot raise TypeError.
        token = msg.choices[0].delta["content"] or ""
        # Stop before emitting anything that would complete a stop marker.
        if any(_stop in response + token for _stop in STOP_STRS):
            break
        response += token
        # Drop a partially generated closing '"""' that follows a newline so
        # stray quote characters are not shown to the user.
        if response.endswith('\n"'):
            response = response[:-1]
        elif response.endswith('\n""'):
            response = response[:-2]
        yield response
 
# Base models offered in the radio selector, in display order.
_base_model_choices = [
    "Llama-3-8B",
    "Llama-3-70B",
    "Mistral-7B-v0.1",
    "Mixtral-8x22B",
    "Qwen1.5-72B",
    "Yi-34B",
    "Llama-2-7B",
    "Llama-2-70B",
    "OLMO",
]
# (default value, label) for each sampling-parameter textbox, in display order.
_sampling_fields = (
    (256, "Max tokens"),
    (0.5, "Temperature"),
    (0.9, "Top-p"),
    (1.1, "Repetition penalty"),
)

with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        # First column: header text and the base-model selector.
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(
                _base_model_choices,
                value="Llama-3-8B",
                label="Base LLM name",
            )
        # Second column: API key field followed by the sampling parameters.
        with gr.Column():
            together_api_key = gr.Textbox(
                label="🔑 Together APIKey",
                placeholder="Enter your Together API Key. Leave it blank to use our key with limited usage.",
                type="password",
                elem_id="api_key",
            )
            with gr.Column(), gr.Row():
                max_tokens, temperature, top_p, rp = [
                    gr.Textbox(value=_default, label=_label)
                    for _default, _label in _sampling_fields
                ]
    # The extra inputs are handed to respond() after message and history,
    # in this order.
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[
            max_tokens,
            temperature,
            top_p,
            rp,
            model_name,
            together_api_key,
        ],
    )
    # Adjust the chatbot component created by the chat interface.
    _chatbot = chat.chatbot
    _chatbot.label = "Chat with Base LLMs via URIAL"
    _chatbot.height = 550
    _chatbot.show_copy_button = True

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_api=False)