File size: 4,929 Bytes
95167e7
 
 
 
c02118d
95167e7
 
 
 
 
 
c02118d
 
677e3b7
 
682b31c
5510202
95167e7
 
 
 
 
 
 
 
 
 
 
 
677e3b7
c02118d
04fbd0c
677e3b7
5510202
0299602
120bd05
677e3b7
95167e7
 
 
 
 
 
c02118d
677e3b7
95167e7
 
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
c02118d
 
 
 
 
0094eb2
57d5907
0299602
c02118d
5510202
aa862df
95167e7
b05cdf7
444efa4
0299602
 
95167e7
 
 
0299602
95167e7
 
120bd05
95167e7
 
 
f7d5ff7
5510202
0d2c418
5510202
95167e7
 
0299602
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677e3b7
b05cdf7
95167e7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
import re

# Deployment settings are read from the environment; each is None when the
# corresponding variable is not set.
API_TOKEN = os.getenv("API_TOKEN")
API_ENDPOINT = os.getenv("API_ENDPOINT")

# Value that a submitted access key is compared against.
KEY = os.getenv("KEY")

# Message delimiters used when building prompts: [opening, closing].
SPECIAL_SYMBOLS = ["‹", "›"]

# Initial values shown in the UI text boxes.
DEFAULT_INPUT = "You: Hi!"
DEFAULT_PREOUTPUT = "AI: "
DEFAULT_INSTRUCTION = (
    "You are an helpful chatbot. "
    "You must respond in this format: \"‹NAME: MESSAGE›\""
)

# Display name shown in the model dropdown -> Hugging Face model id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf"
}

# Only send an Authorization header when a token is actually configured.
# Formatting an unset token would otherwise produce the literal header
# value "Bearer None".
_AUTH_HEADERS = { "Authorization": f"Bearer {API_TOKEN}" } if API_TOKEN else None

# CHOICES: model names in declaration order; CLIENTS: one client per model name.
CHOICES = []
CLIENTS = {}

for model_name, model_endpoint in API_ENDPOINTS.items():
    CHOICES.append(model_name)
    CLIENTS[model_name] = InferenceClient(model_endpoint, headers = _AUTH_HEADERS)

def format(instruction = DEFAULT_INSTRUCTION, history = None, input = "", preoutput = ""):
    """Build the text prompt sent to the model.

    Every message is wrapped in the two SPECIAL_SYMBOLS delimiters and placed
    on its own line: the system instruction first, then the history, then the
    new input, and finally an opening delimiter followed by ``preoutput`` that
    is left unclosed for the model to continue.

    Args:
        instruction: System instruction text.
        history: Previous messages. Each entry is either a plain string or a
            ``[user, bot]`` pair (the shape ``predict`` appends); pairs are
            flattened into individual messages and ``None`` halves are skipped.
        input: The new user message.
        preoutput: Text that the model's reply should start with.

    Returns:
        The assembled prompt string. It is also printed to stdout.
    """
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]

    # Flatten [user, bot] pairs; interpolating a pair directly would embed the
    # Python list repr (e.g. "['You: Hi!', 'Hello']") into the prompt.
    messages = []
    for entry in history or []:
        if isinstance(entry, (list, tuple)):
            messages.extend(part for part in entry if part is not None)
        else:
            messages.append(entry)

    formatted_history = '\n'.join(f"{sy_l}{message}{sy_r}" for message in messages)
    formatted_input = f"{sy_l}System: {instruction}{sy_r}\n{formatted_history}\n{sy_l}{input}{sy_r}\n{sy_l}{preoutput}"
    print(formatted_input)
    return formatted_input
    
def predict(instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate one reply from the selected model and append it to the history.

    Args:
        instruction: System instruction passed to ``format``.
        history: Chat history as a list of ``[input, reply]`` pairs (may be None).
        input: The new user message.
        preoutput: Text the reply should start with, passed to ``format``.
        access_key: Key supplied by the caller; must equal ``KEY``.
        model: Key into ``CLIENTS`` selecting which client to call.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling settings
            forwarded to ``text_generation``.
        stop_seqs: JSON array of stop sequences, as text. Blank means none.

    Returns:
        A 3-tuple ``(reply_text, input, history)``. On a rejected access key
        the reply text is ``"[UNAUTHORIZED ACCESS]"`` and the history is
        returned unchanged.

    Raises:
        json.JSONDecodeError: If ``stop_seqs`` is non-blank but not valid JSON.
        KeyError: If ``model`` is not a key of ``CLIENTS``.
    """
    history = history or []

    # Fail closed when no KEY is configured, so a missing (None) key sent by a
    # caller can never compare equal to an unset KEY.
    if KEY is None or access_key != KEY:
        print(f">>> MODEL FAILED: Input: {input}, Attempted Key: {access_key}")
        # Return all three values: the caller wires this function to three outputs.
        return ("[UNAUTHORIZED ACCESS]", input, history)

    # A blank box means "no stop sequences" rather than a JSON parse error.
    stops = json.loads(stop_seqs) if stop_seqs and stop_seqs.strip() else []

    formatted_input = format(instruction, history, input, preoutput)

    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    # Keep only the text before the first closing delimiter (the model may
    # emit it as a stop sequence), then trim surrounding whitespace.
    sy_r = SPECIAL_SYMBOLS[1]
    get_result = response.partition(sy_r)[0].strip()

    history = history + [[input, get_result]]

    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")

    return (get_result, input, history)

def clear_history():
    """Log that the history was cleared and return a new empty list."""
    print(">>> HISTORY CLEARED!")
    empty_history = list()
    return empty_history
     
def maintain_cloud():
    """Log a maintenance ping and return the same status text twice."""
    print(">>> SPACE MAINTAINED!")
    status = "SUCCESS!"
    return (status, status)

# Build the UI: a banner row, a row with two columns (chat inputs, then
# sampling controls), and a final row holding the raw output box.
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            # Fixed keyword: this was misspelled `abel`, which is not the `label` argument.
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            preoutput = gr.Textbox(label = "Pre-Output", value = DEFAULT_PREOUTPUT, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            # Step was 64; counting from a minimum of 1 that grid (1, 65, 129, ...)
            # never includes the default of 32, so use a step of 1.
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 1, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = '["‹", "›"]')
            seed = gr.Slider( minimum = 0, maximum = 8192, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # predict returns (output text, input text, history), matching these three outputs.
    run.click(predict, inputs = [instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history])
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)