# Hugging Face Space chat app (scraped page header / commit-hash gutter removed).
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
import re
# Secrets injected via the Space's environment configuration.
API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")
KEY = os.environ.get("KEY")

# Delimiters wrapped around every chat turn when building the prompt.
SPECIAL_SYMBOLS = ["‹", "›"]

DEFAULT_INPUT = "You: Hi!"
DEFAULT_PREOUTPUT = "AI: "
DEFAULT_INSTRUCTION = "You are an helpful chatbot."

# Display name -> hosted model repository id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf"
}

# Dropdown choices plus one authenticated inference client per model.
CHOICES = list(API_ENDPOINTS.keys())
CLIENTS = {
    name: InferenceClient(endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
    for name, endpoint in API_ENDPOINTS.items()
}
def format(instruction, history, input, preoutput):
    """Assemble the prompt: system instruction, prior turns, the current
    user input, and the unterminated pre-output stub the model completes."""
    left, right = SPECIAL_SYMBOLS
    parts = [f"{left}System: {instruction}{right}\n"]
    for user_turn, bot_turn in history:
        parts.append(f"{left}{user_turn}{right}\n{left}{bot_turn}{right}\n")
    parts.append(f"{left}{input}{right}\n{left}{preoutput}")
    return "".join(parts)
def predict(instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one chat turn against the selected hosted model.

    Returns a ``(result, input, history)`` tuple. On a wrong access key the
    result is the sentinel string "[UNAUTHORIZED ACCESS]" and history resets.
    """
    # Normalize falsy inputs BEFORE any use: the unauthorized log below
    # concatenates `input`, which crashed on None in the original ordering.
    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    input = input or ""
    preoutput = preoutput or ""

    if access_key != KEY:
        print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
        return ("[UNAUTHORIZED ACCESS]", input, [])

    # Tolerate an empty stop-sequence textbox instead of raising JSONDecodeError.
    stops = json.loads(stop_seqs) if stop_seqs else []
    formatted_input = format(instruction, history, input, preoutput)
    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
    # Wrap the response and append both symbols so the pattern below always
    # finds a closing delimiter, even when generation stopped mid-turn.
    pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
    # FIX: non-greedy `(.*?)` captures only up to the FIRST closing symbol.
    # The original greedy `(.*)` matched through to the appended sentinel,
    # leaking a literal "›‹" into every result.
    pattern = re.compile(f"{sy_l}(.*?){sy_r}", re.DOTALL)
    match = pattern.search(pre_result)
    get_result = (preoutput + match.group(1).strip()) if match else ""
    history = history + [[input, get_result]]
    print(formatted_input + get_result)
    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")
    return (get_result, input, history)
def clear_history():
    """Wipe the chat transcript; returns the fresh empty history."""
    cleared = []
    print(">>> HISTORY CLEARED!")
    return cleared
def maintain_cloud():
    """Keep-alive ping; returns a success marker for both wired textboxes."""
    print(">>> SPACE MAINTAINED!")
    status = "SUCCESS!"
    return (status, status)
# ---- Gradio UI wiring ----
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            # FIX: `abel` was a typo for `label` — an unexpected keyword
            # argument that raised TypeError at startup.
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            preoutput = gr.Textbox(label = "Pre-Output", value = DEFAULT_PREOUTPUT, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")
        with gr.Column():
            # Generation hyper-parameters exposed to the user.
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = '["‹", "›"]')
            seed = gr.Slider( minimum = 0, maximum = 8192, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # Event hookups: generate, clear transcript, keep-alive.
    run.click(predict, inputs = [instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history])
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

# FIX: removed the stray trailing "|" (scrape residue — a SyntaxError).
demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)