# STC-LLM / app.py
# Staticaliza's picture
# Update app.py
# c194e76 verified
# raw history blame
# No virus
# 5.63 kB
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import re
# Secrets come from the environment; each is None when its variable is unset.
API_TOKEN = os.getenv("API_TOKEN")  # sent as the bearer token of every inference client
KEY = os.getenv("KEY")  # value predict() compares the caller's access key against

# [left, right] delimiters wrapped around AI-side and user-side prompt segments.
# NOTE(review): these look like blank (invisible) Unicode characters rather than
# ASCII spaces -- copy them, do not retype them.
SPECIAL_SYMBOLS_AI = ["ㅤ", "ㅤ"]
SPECIAL_SYMBOLS_USER = ["⠀", "⠀"]  # earlier alternatives: ["‹", "›"] ['"', '"']

# Initial values shown in the UI and used as fallbacks by predict().
DEFAULT_INPUT = "User: Hi!"
DEFAULT_WRAP = "Statical: %s"  # %-template; the model reply fills the placeholder
DEFAULT_INSTRUCTION = "Conversation: Statical is a helpful chatbot who is communicating with people."
DEFAULT_STOPS = '["ㅤ", "⠀"]'  # JSON array text; earlier alternatives: '["‹", "›"]' '[\"\\\"\"]'
# Dropdown label -> Hugging Face model repo id handed to InferenceClient.
API_ENDPOINTS = {
    "Falcon*": "tiiuae/falcon-180B-chat",
    "Llama*": "meta-llama/Llama-2-70b-chat-hf",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Mistral_Chat": "mistralai/Mistral-7B-Instruct-v0.1",
    "Xistral_Chat": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "CodeLlama*": "codellama/CodeLlama-70b-Instruct-hf",
    # The next four ids are written reversed in the source; the [::-1] slice
    # restores the real repo id at import time.
    "RX": "esab-xrbd/skcirbatad"[::-1],
    "CH": "sulp-r-dnammoc-ia4c/IAroFerehoC"[::-1],
    "MX": "QWA-1.0v-B22x8-lartxiM/ytinummoc-lartsim"[::-1],
    "ZE": "1.0v-b53A-b141-opro-ryhpez/4HecaFgnigguH"[::-1],
    # This id is already in forward order, so it must NOT be reversed:
    # slicing it with [::-1] produced a nonexistent repo id.
    "LL": "meta-llama/Meta-Llama-3-70B-Instruct",
}
# Dropdown labels, in the insertion order of API_ENDPOINTS.
CHOICES = list(API_ENDPOINTS)

# One inference client per label, each carrying the bearer token header.
CLIENTS = {}
for model_name, model_endpoint in API_ENDPOINTS.items():
    CLIENTS[model_name] = InferenceClient(
        model_endpoint,
        headers={"Authorization": f"Bearer {API_TOKEN}"},
    )
def format(instruction, history, input, wrap, ai_symbols=None, user_symbols=None):
    """Build the raw text prompt for one generation request.

    Layout (L/R = left/right delimiter)::

        <aiL>instruction<aiR> (<userL>user<userR><aiL>reply<aiR>)* <userL>input<userR><aiL>prefix

    Args:
        instruction: System-style text placed first, inside the AI delimiters.
        history: Iterable of ``(user_message, ai_message)`` pairs; ``None`` is
            treated as no history.
        input: The new user message.
        wrap: %-template for the reply (e.g. ``"Statical: %s"``). Whatever
            precedes the placeholder is appended to the prompt so the model
            continues after it.
        ai_symbols: Optional ``[left, right]`` delimiters for AI segments;
            defaults to ``SPECIAL_SYMBOLS_AI``.
        user_symbols: Optional ``[left, right]`` delimiters for user segments;
            defaults to ``SPECIAL_SYMBOLS_USER``.

    Returns:
        A ``(prompt, base)`` tuple: ``base`` ends with the opening AI
        delimiter, ``prompt`` is ``base`` followed by the wrap prefix.
    """
    if ai_symbols is None:
        ai_symbols = SPECIAL_SYMBOLS_AI
    if user_symbols is None:
        user_symbols = SPECIAL_SYMBOLS_USER
    ai_left, ai_right = ai_symbols[0], ai_symbols[1]
    user_left, user_right = user_symbols[0], user_symbols[1]

    wrap = wrap or ""
    try:
        # Formatting with an empty reply leaves just the text before the slot.
        wrapped_input = wrap % ("",)
    except (TypeError, ValueError):
        # Empty wrap or no usable single placeholder: use the text literally
        # instead of failing with "not all arguments converted".
        wrapped_input = wrap

    # Each past AI reply is closed with the RIGHT delimiter, matching how the
    # instruction segment is delimited below.
    formatted_history = "".join(
        f"{user_left}{message[0]}{user_right}{ai_left}{message[1]}{ai_right}"
        for message in (history or [])
    )
    formatted_input = (
        f"{ai_left}{instruction}{ai_right}"
        f"{formatted_history}"
        f"{user_left}{input}{user_right}{ai_left}"
    )
    return f"{formatted_input}{wrapped_input}", formatted_input
def _normalize_wrap(wrap):
    """Return *wrap* as a %-template that accepts exactly one string argument.

    An empty wrap becomes the identity template ``"%s"``. Text that cannot be
    formatted with a single argument (no placeholder, stray ``%``, several
    placeholders) is treated as a literal prefix placed before the reply.
    """
    if not wrap:
        return "%s"
    try:
        wrap % ("",)
    except (TypeError, ValueError):
        return wrap.replace("%", "%%") + "%s"
    return wrap


def _parse_stops(stop_seqs):
    """Decode a JSON array of stop strings; return None when it is malformed."""
    try:
        stops = json.loads(stop_seqs)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        return None
    if not isinstance(stops, list) or not all(isinstance(s, str) for s in stops):
        return None
    # An empty string can never act as a stop and would make the truncation
    # below raise "empty separator", so drop it.
    return [s for s in stops if s]


def predict(access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one text-generation request and append the exchange to the history.

    Args:
        access_key: Key supplied by the caller; must equal the ``KEY`` constant.
        instruction: Instruction text; falls back to ``DEFAULT_INSTRUCTION``.
        history: List of ``[user, reply]`` pairs from earlier turns (or None).
        input: The new user message.
        wrap: %-template applied to the model reply; empty means no wrapping.
        model: Label selecting an entry of ``CLIENTS``.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling settings
            forwarded to ``text_generation``.
        stop_seqs: JSON array (as text) of stop strings; falls back to
            ``DEFAULT_STOPS``.

    Returns:
        ``(result, input, history)``, bound to the output box, the input box
        and the chat history in the UI. On a failed key check the history
        element is an empty list; on malformed stop sequences the history is
        returned unchanged.
    """
    # Fail closed: with KEY unset, a None access_key (possible through the
    # API) would otherwise compare equal and be let through.
    if KEY is None or access_key != KEY:
        # NOTE(review): this writes the attempted key to the logs, which can
        # expose near-miss secrets -- consider redacting it.
        print(f">>> MODEL FAILED: Input: {input}, Attempted Key: {access_key}")
        return ("[UNAUTHORIZED ACCESS]", input, [])

    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    input = input or ""
    # Normalised up front so both prompt building and reply wrapping receive a
    # template that cannot raise during % formatting.
    wrap = _normalize_wrap(wrap)
    stop_seqs = stop_seqs or DEFAULT_STOPS

    stops = _parse_stops(stop_seqs)
    if stops is None:
        print(f">>> MODEL FAILED: Invalid stop sequences: {stop_seqs}")
        return ("[INVALID STOP SEQUENCES]", input, history)

    formatted_input, _ = format(instruction, history, input, wrap)

    print(seed)
    print(formatted_input)
    print(model)

    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=rep_p,
        stop_sequences=stops,
        do_sample=True,
        seed=seed,
        stream=False,
        details=False,
        return_full_text=False,
    )

    result = wrap % (response,)
    # Keep only the text before the first occurrence of any stop sequence;
    # after this no stop string remains, so no separate removal pass is needed.
    for stop in stops:
        result = result.partition(stop)[0]

    history = history + [[input, result]]
    print(f"---\nUSER: {input}\nBOT: {result}\n---")
    return (result, input, history)
def clear_history():
    """Log the reset and return a fresh empty list for the chat history."""
    emptied = []
    print(">>> HISTORY CLEARED!")
    return emptied
def cloud():
    """Print a status line; takes no inputs and returns nothing.

    Wired to the maintain button (presumably a manual keep-alive ping for the
    Space -- confirm with the owner).
    """
    status_line = "[CLOUD] | Space maintained."
    print(status_line)
# UI definition.
# NOTE(review): the nesting below is reconstructed from statement order (two
# columns sharing one row, output in its own row) -- confirm against the
# deployed layout.
with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        gr.Markdown("✨ A LLM space owned within Statical.")

    with gr.Row():
        # Conversation, prompt fields and action buttons.
        with gr.Column():
            history = gr.Chatbot(label="History", elem_id="chatbot")
            input = gr.Textbox(label="Input", value=DEFAULT_INPUT, lines=2)
            wrap = gr.Textbox(label="Wrap", value=DEFAULT_WRAP, lines=1)
            instruction = gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION, lines=4)
            access_key = gr.Textbox(label="Access Key", lines=1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            maintain = gr.Button("☁️")

        # Model choice and sampling settings.
        with gr.Column():
            model = gr.Dropdown(
                choices=CHOICES,
                value=next(iter(API_ENDPOINTS)),
                interactive=True,
                label="Model",
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=2,
                value=1,
                step=0.01,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.01,
                maximum=0.99,
                value=0.95,
                step=0.01,
                interactive=True,
                label="Top P",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=2048,
                value=50,
                step=1,
                interactive=True,
                label="Top K",
            )
            rep_p = gr.Slider(
                minimum=0.01,
                maximum=2,
                value=1.2,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
            )
            # NOTE(review): the default of 32 is not on the grid implied by
            # minimum=1 and step=64 -- confirm the intended step.
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=32,
                step=64,
                interactive=True,
                label="Max New Tokens",
            )
            stop_seqs = gr.Textbox(
                value=DEFAULT_STOPS,
                interactive=True,
                label="Stop Sequences ( JSON Array / 4 Max )",
            )
            seed = gr.Slider(
                minimum=0,
                maximum=9007199254740991,
                value=42,
                step=1,
                interactive=True,
                label="Seed",
            )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Output", value="", lines=50)

    # Event wiring: predict's return tuple fills (output, input, history).
    run.click(
        predict,
        inputs=[
            access_key,
            instruction,
            history,
            input,
            wrap,
            model,
            temperature,
            top_p,
            top_k,
            rep_p,
            max_tokens,
            stop_seqs,
            seed,
        ],
        outputs=[output, input, history],
        queue=False,
    )
    clear.click(clear_history, [], history, queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

demo.launch(show_api=True)