File size: 5,273 Bytes
95167e7
 
 
 
c02118d
95167e7
 
 
 
 
 
0c74ca6
c02118d
514b2f9
 
caf7748
5510202
0c74ca6
1a1d724
95167e7
 
80dce94
 
4870a28
95167e7
 
 
 
 
 
 
 
 
514b2f9
c02118d
e2ad91e
37ba605
3bb6f14
e2ad91e
120bd05
514b2f9
95167e7
 
 
d762b96
04320ba
 
 
 
514b2f9
1a1d724
95167e7
 
c02118d
514b2f9
cc6ed2c
25f48f1
95167e7
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
c02118d
514b2f9
7822b29
 
df6c8dd
7822b29
df6c8dd
d762b96
5bd743c
8465750
7822b29
5510202
7822b29
95167e7
b05cdf7
444efa4
0299602
 
95167e7
 
 
0299602
95167e7
 
514b2f9
95167e7
 
 
f7d5ff7
5510202
514b2f9
5510202
95167e7
 
0299602
95167e7
 
 
 
 
 
 
 
 
1a1d724
b14f3c1
95167e7
 
 
 
 
514b2f9
b05cdf7
95167e7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
import re

# API_TOKEN authenticates requests to the Hugging Face Inference API.
# API_ENDPOINT is read from the environment but not referenced below.
API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")

# Shared secret: predict() refuses to run unless the caller supplies this
# exact value as `access_key`.
KEY = os.environ.get("KEY")

SPECIAL_SYMBOLS = ["⠀", "⠀"] # braille-blank delimiters wrapped around each chat turn; earlier alternatives: ["‹", "›"] ['"', '"']

DEFAULT_INPUT = "User: Hi!"
DEFAULT_WRAP = "Statical: %s"  # "%s" is replaced with the model's reply
DEFAULT_INSTRUCTION = "Statical is a helpful chatbot who is communicating with people."

# JSON-encoded array of stop sequences (parsed with json.loads in predict).
DEFAULT_STOPS = '["⠀", "⠀"]' # earlier alternatives: '["‹", "›"]' '[\"\\\"\"]'

# UI display name -> Hugging Face model id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Mistral-2": "mistralai/Mistral-7B-Instruct-v0.1",
}

CHOICES = []  # dropdown labels shown in the UI
CLIENTS = {}  # display name -> authenticated InferenceClient

# Build one bearer-token-authenticated client per configured model.
for model_name, model_endpoint in API_ENDPOINTS.items():
    CHOICES.append(model_name)
    CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })

def format(instruction, history, input, wrap):
    """Assemble the prompt sent to the model.

    Each turn is bracketed by the SPECIAL_SYMBOLS delimiters. The wrap
    template's prefix (everything before "%s", e.g. "Statical: ") is
    appended so the model continues the reply in-character.

    Returns (full_prompt, prompt_without_wrap_prefix).
    """
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
    # FIX: `wrap % ("")` raises TypeError ("not all arguments converted")
    # whenever wrap has no "%s" — predict() passes wrap = "" when the UI
    # field is blank. Fall back to the literal wrap text in that case.
    wrapped_input = wrap % ("") if "%s" in wrap else wrap
    formatted_history = "".join(f"{sy_l}{message[0]}{sy_r}\n{sy_l}{message[1]}{sy_r}\n" for message in history)
    formatted_input = f"{sy_l}INSTRUCTIONS: {instruction}{sy_r}\n{formatted_history}{sy_l}{input}{sy_r}\n{sy_l}"
    return f"{formatted_input}{wrapped_input}", formatted_input
    
def predict(access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one chat turn against the selected model.

    Parameters mirror the Gradio inputs: `access_key` must equal KEY;
    `history` is a list of [user, bot] pairs; `wrap` is a "%s" template
    applied to the model's reply; `stop_seqs` is a JSON array of stop
    strings; the remaining arguments are sampling controls.

    Returns (result, input, history) for the output box, input box, and
    chatbot component respectively.
    """
    # Reject requests that do not carry the shared secret.
    if access_key != KEY:
        print(f">>> MODEL FAILED: Input: {input}, Attempted Key: {access_key}")
        return ("[UNAUTHORIZED ACCESS]", input, [])

    # Fall back to defaults for any blank UI fields.
    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    input = input or ""
    wrap = wrap or ""
    stop_seqs = stop_seqs or DEFAULT_STOPS

    stops = json.loads(stop_seqs)  # JSON array of stop strings

    formatted_input, _ = format(instruction, history, input, wrap)
    print(seed)
    print(formatted_input)
    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    # Re-apply the wrap template around the raw completion. FIX: guard the
    # "%" substitution — a wrap without "%s" (e.g. the "" fallback above)
    # would raise TypeError on `wrap % response`.
    result = wrap % (response) if "%s" in wrap else response

    # Truncate at the first stop sequence. Splitting on every stop already
    # removes all later occurrences, so the original second pass of
    # `result.replace(symbol, '')` over the same list was dead code.
    for stop in stops:
        result = result.split(stop, 1)[0]

    history = history + [[input, result]]

    print(f"---\nUSER: {input}\nBOT: {result}\n---")

    return (result, input, history)

def clear_history():
    """Wipe the conversation log and hand the chatbot a fresh, empty history."""
    print(">>> HISTORY CLEARED!")
    return list()
     
def maintain_cloud():
    """Keep-alive handler: log the ping and reset both input and output fields."""
    print(">>> SPACE MAINTAINED!")
    status = "SUCCESS!"
    return (status, status)

# ---------------------------------------------------------------------------
# UI layout: left column holds the conversation widgets, right column the
# sampling controls; a full-width output box sits underneath.
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("✡️ This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            # FIX: keyword was misspelled `abel`, so the Chatbot never
            # received its label (and newer gradio versions raise TypeError
            # on the unknown keyword).
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            wrap = gr.Textbox(label = "Wrap", value = DEFAULT_WRAP, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        with gr.Column():
            # Sampling controls forwarded verbatim to text_generation().
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = DEFAULT_STOPS )
            seed = gr.Slider( minimum = 0, maximum = 9007199254740991, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # Wire the buttons to their handlers.
    run.click(predict, inputs = [access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history])
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)