Spaces:
Running
Running
import hmac
import json
import os
import re

import gradio as gr
from huggingface_hub import Repository, InferenceClient
# Configuration read from the Space's environment (each is None when unset).
API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")  # NOTE(review): not referenced anywhere in the visible code
KEY = os.environ.get("KEY")  # shared access key checked by predict()

# Delimiters wrapped around each message when the prompt is built.
SPECIAL_SYMBOLS = ["‹", "›"]

# Dropdown label -> Hugging Face model id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
}

# Model names shown in the dropdown, in API_ENDPOINTS order.
CHOICES = list(API_ENDPOINTS)

# One inference client per model, keyed by the dropdown label.
# Passing `token` (instead of a hand-built "Bearer {API_TOKEN}" header) avoids
# sending a literal "Bearer None" header when API_TOKEN is not configured.
CLIENTS = {
    model_name: InferenceClient(model_endpoint, token = API_TOKEN)
    for model_name, model_endpoint in API_ENDPOINTS.items()
}
def format(instruction = "", history = "", input = "", preinput = ""):
    """Build the raw prompt string sent to the text-generation endpoint.

    The prompt is laid out line by line as::

        <instruction>
        ‹message 1›
        ‹message 2›
        ...
        ‹<input>›
        <preinput>

    Args:
        instruction: Free-form instruction placed on the first line.
        history: Prior messages. This can be a flat sequence of strings or
            the ``gr.Chatbot`` value the UI passes in, which is assumed to be a
            list of ``[user, bot]`` pairs. Pairs are flattened in order and
            ``None`` entries are skipped. A falsy value (``""``, ``None``,
            ``[]``) yields an empty history line.
        input: The current user message, wrapped in the special symbols.
        preinput: Text appended after the wrapped input. It primes the
            start of the model's reply.

    Returns:
        The assembled prompt string.

    Note: the name shadows the builtin ``format``; it is kept for existing callers.
    """
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]

    messages = []
    for entry in history or []:
        if isinstance(entry, (list, tuple)):
            # Chatbot-style pair: keep each present side, in order.
            messages.extend(message for message in entry if message is not None)
        else:
            messages.append(entry)

    formatted_history = "\n".join(f"{sy_l}{message}{sy_r}" for message in messages)
    return f"{instruction}\n{formatted_history}\n{sy_l}{input}{sy_r}\n{preinput}"
def predict(instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate one model reply for the current input.

    Args:
        instruction, history, input, preinput: Prompt parts, passed to ``format``.
        access_key: Key supplied by the caller. It must equal the ``KEY`` env var.
        model: A key of ``CLIENTS`` (the dropdown label).
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling settings
            forwarded to ``text_generation``.
        stop_seqs: A JSON array of stop strings. An empty or blank string
            means no stop sequences.

    Returns:
        ``(reply, input)``, matching the ``[output, input]`` outputs wired in
        the UI. On a bad key, or when ``KEY`` is unset, the reply is
        ``"[UNAUTHORIZED ACCESS]"`` and no request is made.

    Raises:
        json.JSONDecodeError: If ``stop_seqs`` is non-blank but not valid JSON.
        KeyError: If ``model`` is not a key of ``CLIENTS``.
    """
    # Constant-time comparison, so response timing does not reveal how much of
    # the key matched. If KEY is unset, the endpoint stays locked instead of
    # accepting a None key.
    authorized = (
        KEY is not None
        and access_key is not None
        and hmac.compare_digest(str(access_key).encode("utf-8"), KEY.encode("utf-8"))
    )
    if not authorized:
        # The attempted key is deliberately not logged: it may be a near-miss of the real secret.
        print(f">>> MODEL FAILED: Input: {input}")
        return ("[UNAUTHORIZED ACCESS]", input)

    # A cleared textbox means "no stop sequences", not a JSON parse error.
    stops = json.loads(stop_seqs) if stop_seqs and stop_seqs.strip() else []

    formatted_input = format(instruction, history, input, preinput)
    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False,
    )

    # Keep only the text before the first closing symbol (if any), then trim.
    # This matches the old regex over "‹response›‹›" without building an
    # unescaped pattern from SPECIAL_SYMBOLS.
    get_result = response.partition(SPECIAL_SYMBOLS[1])[0].strip()
    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")
    return (get_result, input)
def maintain_cloud():
    """Handle the ☁️ button: log a maintenance marker and report success.

    Returns:
        The same status string twice, one for each output component the
        button is wired to.
    """
    status = "SUCCESS!"
    print(">>> SPACE MAINTAINED!")
    return status, status
# UI layout, event wiring and launch.
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")
    with gr.Row():
        # Left column: prompt parts, access key and action buttons.
        with gr.Column():
            history = gr.Chatbot(elem_id = "chatbot")
            input = gr.Textbox(label = "Input", lines = 2)
            preinput = gr.Textbox(label = "Pre-Input", lines = 1)
            instruction = gr.Textbox(label = "Instruction", lines = 4)
            # Masked so the shared access key is not displayed on screen.
            access_key = gr.Textbox(label = "Access Key", lines = 1, type = "password")
            run = gr.Button("▶")
            cloud = gr.Button("☁️")
        # Right column: model choice and sampling parameters.
        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            # Minimum is above 0: generation runs with do_sample=True, and
            # Text Generation Inference rejects a non-positive temperature.
            temperature = gr.Slider(minimum = 0.01, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature")
            top_p = gr.Slider(minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P")
            top_k = gr.Slider(minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K")
            rep_p = gr.Slider(minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty")
            # step=1: with step=64 from minimum 1 the grid was 1, 65, 129, ...,
            # which excluded the default of 32.
            max_tokens = gr.Slider(minimum = 1, maximum = 2048, value = 32, step = 1, interactive = True, label = "Max New Tokens")
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = '["‹", "›"]')
            seed = gr.Slider(minimum = 0, maximum = 8192, value = 42, step = 1, interactive = True, label = "Seed")
    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # The input order must match predict()'s positional parameters.
    run.click(predict, inputs = [instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input])
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)