File size: 4,605 Bytes
95167e7
 
 
 
c02118d
95167e7
 
 
 
 
 
c02118d
 
95167e7
 
 
 
 
 
 
 
 
 
 
 
c02118d
 
04fbd0c
0299602
 
120bd05
c02118d
95167e7
 
 
 
 
 
c02118d
 
0299602
 
95167e7
 
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
c02118d
 
 
 
 
 
0299602
 
c02118d
95167e7
aa862df
95167e7
b05cdf7
0299602
 
95167e7
 
 
0299602
95167e7
 
120bd05
95167e7
 
 
120bd05
 
c02118d
120bd05
95167e7
 
0299602
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa862df
b05cdf7
95167e7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
import re

# Credentials and endpoint configuration come from the Space's environment.
API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")

# Shared secret checked by predict() before any generation is run.
KEY = os.environ.get("KEY")

# Delimiter pair wrapped around every message when building the prompt
# and used again to extract the reply from the model output.
SPECIAL_SYMBOLS = ["‹", "›"]

# Display name -> HF model repository id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf"
}

# Dropdown choices and one authenticated client per configured model.
CHOICES = list(API_ENDPOINTS)
CLIENTS = {
    name: InferenceClient(endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
    for name, endpoint in API_ENDPOINTS.items()
}

def format(instruction = "", history = "", input = "", preinput = ""):
    """Assemble the prompt sent to the model.

    Produces a "System: ‹instruction›" line, one delimiter-wrapped line per
    history entry, the delimiter-wrapped new input, and finally the raw
    pre-input primer, all joined with newlines.
    """
    left, right = SPECIAL_SYMBOLS
    wrapped_history = '\n'.join(f"{left}{entry}{right}" for entry in history)
    parts = [
        f"System: {left}{instruction}{right}",
        wrapped_history,
        f"{left}{input}{right}",
        preinput,
    ]
    return '\n'.join(parts)
    
def predict(instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one chat turn against the selected model.

    Parameters mirror the Gradio components wired in run.click(); `history`
    is the Chatbot value (a list of [user, bot] pairs), `stop_seqs` is a
    JSON-encoded array of stop strings.

    Returns a 3-tuple (output_text, input, history) matching the click
    handler's outputs = [output, input, history].
    """

    if (access_key != KEY):
        # NOTE(review): this echoes the attempted key into the logs —
        # consider dropping it, since it may leak near-miss secrets.
        print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
        # BUG FIX: this path used to return a 2-tuple, but the click handler
        # binds three outputs (output, input, history); return the
        # unchanged history as the third element so Gradio doesn't error.
        return ("[UNAUTHORIZED ACCESS]", input, history)

    # The "Stop Sequences" textbox holds a JSON array, e.g. '["‹", "›"]'.
    stops = json.loads(stop_seqs)

    # Build the prompt from the pre-turn history (the new turn is appended
    # to `history` only after the model answers).
    formatted_input = format(instruction, history, input, preinput)

    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    # Wrap the raw response in the delimiters, then take the first
    # delimited span — this trims anything the model emitted after a
    # closing delimiter of its own.
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
    pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
    pattern = re.compile(f"{sy_l}(.*?){sy_r}", re.DOTALL)
    match = pattern.search(pre_result)
    get_result = match.group(1).strip() if match else ""

    # BUG FIX: the original appended [input, ""] before the call and then
    # appended the bare reply string afterwards, leaving a malformed
    # Chatbot history (a str mixed in with [user, bot] pairs). Append one
    # complete pair instead.
    history = history + [[input, get_result]]

    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")

    return (get_result, input, history)

def clear_history():
    """Return a fresh, empty conversation for the Chatbot component."""
    emptied = list()
    return emptied
     
def maintain_cloud():
    """Keep-alive hook for the ☁️ button: log a heartbeat and report success.

    Returns the pair of "SUCCESS!" strings bound to (input, output).
    """
    print(">>> SPACE MAINTAINED!")
    return "SUCCESS!", "SUCCESS!"

# --- UI layout and event wiring --------------------------------------------
with gr.Blocks() as demo:
    # Header banner.
    with gr.Row(variant = "panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")
            
    with gr.Row():
        # Left column: conversation view, text inputs, and action buttons.
        with gr.Column():
            history = gr.Chatbot(elem_id = "chatbot")
            input = gr.Textbox(label = "Input", lines = 2)
            preinput = gr.Textbox(label = "Pre-Input", lines = 1)
            instruction = gr.Textbox(label = "Instruction", lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")      # run predict()
            clear = gr.Button("🗑️")   # clear chat history
            cloud = gr.Button("☁️")   # keep-alive ping
            
        # Right column: model selection and generation hyper-parameters.
        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
            # Expects a JSON array; predict() parses it with json.loads.
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = '["‹", "›"]')
            seed = gr.Slider( minimum = 0, maximum = 8192, value = 42, step = 1, interactive = True, label = "Seed" )
            
    # Generated text is mirrored into a plain textbox below the columns.
    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # predict returns (output_text, input, history) — three outputs bound here.
    run.click(predict, inputs = [instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history])
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])
    
# NOTE(review): queue(concurrency_count=...) is Gradio 3.x API; it was
# removed in Gradio 4.x (replaced by default_concurrency_limit) — confirm
# the pinned gradio version before upgrading.
demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)