File size: 5,634 Bytes
95167e7
e8f3085
95167e7
 
c02118d
95167e7
 
 
 
 
2a52702
 
c02118d
514b2f9
 
d1f0e94
5510202
2a52702
1a1d724
95167e7
5dc03de
 
80dce94
5dc03de
 
53af226
5dc03de
30e967f
94d32f5
59cd481
fac2414
9b09f94
95167e7
 
 
 
 
 
 
 
 
514b2f9
2a52702
 
e2ad91e
2a52702
94c41a5
e2ad91e
120bd05
514b2f9
c3fca4f
95167e7
 
d762b96
04320ba
 
 
 
514b2f9
1a1d724
95167e7
 
c02118d
514b2f9
69d55a1
cc6ed2c
25f48f1
69d55a1
 
95167e7
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
514b2f9
7822b29
 
df6c8dd
7822b29
df6c8dd
d762b96
5bd743c
8465750
7822b29
5510202
7822b29
95167e7
b05cdf7
444efa4
0299602
 
c194e76
 
0299602
95167e7
 
ae89068
95167e7
 
 
7a9ac91
5510202
514b2f9
5510202
95167e7
 
0299602
c194e76
95167e7
 
e458a61
95167e7
 
 
 
 
c3fca4f
b14f3c1
95167e7
 
 
 
 
d4bf10a
3ca2ec3
c194e76
95167e7
39f1f96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import hmac
import json
import os
import re

import gradio as gr
from huggingface_hub import InferenceClient

# Hugging Face Inference API token read from the environment; None when API_TOKEN is unset
# (NOTE(review): the client headers are then built as "Bearer None" — confirm that is acceptable).
API_TOKEN = os.environ.get("API_TOKEN")

# Shared secret that callers of predict must supply as the access key; None when KEY is unset.
KEY = os.environ.get("KEY")

# Opening/closing delimiter characters wrapped around AI turns and user turns in the prompt.
# They are blank-looking Unicode characters; DEFAULT_STOPS below lists the same characters
# so generation stops when the model starts a new turn.
SPECIAL_SYMBOLS_AI = ["ㅤ", "ㅤ"]
SPECIAL_SYMBOLS_USER = ["⠀", "⠀"] # ["‹", "›"] ['"', '"']

DEFAULT_INPUT = "User: Hi!"
# %-style template: the text before %s is appended to the prompt, and the model's reply is
# substituted for %s to produce the displayed output.
DEFAULT_WRAP = "Statical: %s"
DEFAULT_INSTRUCTION = "Conversation: Statical is a helpful chatbot who is communicating with people."

# JSON array of stop sequences (parsed in predict).
DEFAULT_STOPS = '["ㅤ", "⠀"]' # '["‹", "›"]' '[\"\\\"\"]'

# Dropdown label -> Hugging Face model repo id.
# Some ids are stored reversed and decoded with [::-1] at import time; plain ids must NOT be
# reversed (the "LL" entry previously was, producing an invalid repo id).
API_ENDPOINTS = {
    "Falcon*": "tiiuae/falcon-180B-chat",
    "Llama*": "meta-llama/Llama-2-70b-chat-hf",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Mistral_Chat": "mistralai/Mistral-7B-Instruct-v0.1",
    "Xistral_Chat": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "CodeLlama*": "codellama/CodeLlama-70b-Instruct-hf",
    "RX": "esab-xrbd/skcirbatad"[::-1],
    "CH": "sulp-r-dnammoc-ia4c/IAroFerehoC"[::-1],
    "MX": "QWA-1.0v-B22x8-lartxiM/ytinummoc-lartsim"[::-1],
    "ZE": "1.0v-b53A-b141-opro-ryhpez/4HecaFgnigguH"[::-1],
    "LL": "meta-llama/Meta-Llama-3-70B-Instruct",
}

# Model labels offered in the dropdown, in API_ENDPOINTS insertion order.
CHOICES = list(API_ENDPOINTS)

# One inference client per model label; each sends the API token as a bearer header.
CLIENTS = {
    label: InferenceClient(repo_id, headers = { "Authorization": f"Bearer {API_TOKEN}" })
    for label, repo_id in API_ENDPOINTS.items()
}

def format(instruction, history, input, wrap):
    """Build the raw prompt string for the model.

    Layout: the instruction inside AI delimiters, then each past [user, ai] pair as a
    user turn (user delimiters) followed by an AI turn (AI delimiters), then the current
    user input and an opening AI delimiter.

    Args:
        instruction: System/instruction text placed at the start of the prompt.
        history: Sequence of [user_message, ai_message] pairs; None is treated as empty.
        input: The current user message.
        wrap: %-style template with one %s; the text around %s with an empty reply is
            appended to the prompt so the model continues from it. Empty means no wrap.

    Returns:
        (prompt, prompt_base): prompt_base ends with the opening AI delimiter;
        prompt is prompt_base followed by the wrap prefix.
    """
    ai_open, ai_close = SPECIAL_SYMBOLS_AI[0], SPECIAL_SYMBOLS_AI[1]
    user_open, user_close = SPECIAL_SYMBOLS_USER[0], SPECIAL_SYMBOLS_USER[1]

    # `"" % ""` raises TypeError, so an empty template contributes nothing.
    wrap_prefix = wrap % ("",) if wrap else ""

    past_turns = "".join(
        f"{user_open}{message[0]}{user_close}{ai_open}{message[1]}{ai_close}"
        for message in (history or ())
    )
    prompt_base = f"{ai_open}{instruction}{ai_close}{past_turns}{user_open}{input}{user_close}{ai_open}"
    return f"{prompt_base}{wrap_prefix}", prompt_base
    
def _is_authorized(access_key):
    """Return True only if *access_key* equals the configured KEY.

    Uses a constant-time comparison so response timing does not leak how much of the
    key matched. When KEY is unset (None), or the supplied key is not a string,
    access is denied.
    """
    if KEY is None or not isinstance(access_key, str):
        return False
    return hmac.compare_digest(access_key.encode("utf-8"), KEY.encode("utf-8"))


def _parse_stops(stop_seqs):
    """Parse *stop_seqs* (JSON array of strings) into a list of non-empty stop strings.

    A bare JSON string is accepted as a single stop. Returns None when the text is not
    valid JSON or does not describe strings.
    """
    try:
        parsed = json.loads(stop_seqs)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        return None
    if isinstance(parsed, str):
        parsed = [parsed]
    if not isinstance(parsed, list) or not all(isinstance(stop, str) for stop in parsed):
        return None
    # An empty stop would make str.split raise ValueError and can never be a real stop.
    return [stop for stop in parsed if stop]


def _trim_at_stops(text, stops):
    """Truncate *text* at the first occurrence of each stop sequence, applied in order."""
    for stop in stops:
        text = text.split(stop, 1)[0]
    return text


def predict(access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate one chatbot reply and append the exchange to the history.

    Args:
        access_key: Secret that must match KEY.
        instruction: Prompt instruction; falls back to DEFAULT_INSTRUCTION when empty.
        history: List of [user, ai] pairs; falls back to [] when empty.
        input: Current user message; falls back to "" when empty.
        wrap: %-style template with one %s applied to the reply; empty means no wrap.
        model: Key into CLIENTS selecting the inference client.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling parameters passed
            through to text_generation.
        stop_seqs: JSON array of stop strings; falls back to DEFAULT_STOPS when empty.

    Returns:
        (output_text, input, new_history). On an unauthorized key the history is
        returned as []. On invalid stop sequences the history is returned unchanged.
    """
    if not _is_authorized(access_key):
        # Do not log the attempted key: a near-miss would leak the real secret into logs.
        print(f">>> MODEL FAILED: Input: {input}, unauthorized access attempt")
        return ("[UNAUTHORIZED ACCESS]", input, [])

    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    input = input or ""
    # Empty wrap -> identity template ("" % text raises TypeError; "%s" passes text through).
    wrap = wrap or "%s"
    stop_seqs = stop_seqs or DEFAULT_STOPS

    stops = _parse_stops(stop_seqs)
    if stops is None:
        print(f">>> MODEL FAILED: Invalid stop sequences: {stop_seqs}")
        return ("[INVALID STOP SEQUENCES]", input, history)

    formatted_input, _ = format(instruction, history, input, wrap)

    print(seed)
    print(formatted_input)
    print(model)

    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    # Apply the wrap template to the reply, then cut at any stop sequence the model
    # (or the template) produced. After truncation no stop remains, so no extra
    # removal pass is needed.
    result = _trim_at_stops(wrap % (response,), stops)

    history = history + [[input, result]]

    print(f"---\nUSER: {input}\nBOT: {result}\n---")

    return (result, input, history)

def clear_history():
    """Log that the chat was reset and hand back a fresh, empty history list."""
    notice = ">>> HISTORY CLEARED!"
    print(notice)
    return list()
     
def cloud():
    """Print a fixed maintenance message; takes no input and returns None."""
    status = "[CLOUD] | Space maintained."
    print(status)
    return None

with gr.Blocks() as demo:
    # Header banner.
    with gr.Row(variant = "panel"):
        gr.Markdown("✨ A LLM space owned within Statical.")

    with gr.Row():
        # Left column: conversation, prompt pieces, credentials and action buttons.
        with gr.Column():
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            wrap = gr.Textbox(label = "Wrap", value = DEFAULT_WRAP, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            # Masked so the secret key is not shown on screen.
            access_key = gr.Textbox(label = "Access Key", lines = 1, type = "password")
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            maintain = gr.Button("☁️")

        # Right column: model selection and generation parameters.
        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            # step = 1 so the default (32) and the maximum (2048) are valid positions; the
            # previous step of 64 from minimum 1 only allowed 1, 65, 129, ...
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 1, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox( value = DEFAULT_STOPS, interactive = True, label = "Stop Sequences ( JSON Array / 4 Max )" )
            # Upper bound is 2**53 - 1, the largest integer exactly representable as a JS number.
            seed = gr.Slider( minimum = 0, maximum = 9007199254740991, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # predict returns (output text, input echo, updated history) in this order.
    run.click(predict, inputs = [access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history], queue = False)
    clear.click(clear_history, [], history, queue = False)
    maintain.click(cloud, inputs = [], outputs = [], queue = False)
    
# Start the server only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or tests) does not launch a second server.
if __name__ == "__main__":
    demo.launch(show_api = True)