# VPTQ_demo/app.py
import spaces  # imported first, before any CUDA-touching libraries, as Hugging Face ZeroGPU expects

import os
import threading
from collections import deque

import gradio as gr
from huggingface_hub import snapshot_download

from vptq.app_utils import get_chat_loop_generator
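
# VPTQ-quantized Llama 3.1 Instruct checkpoints offered in the demo,
# with their effective bit-widths.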
models = [
{
"name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft",
"bits": "2.3 bits"
},
{
"name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft",
"bits": "3 bits"
},
{
"name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-4096-woft",
"bits": "3.5 bits"
},
{
"name": "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k32768-0-woft",
"bits": "1.85 bits"
},
]

# Rolling histories consumed by the (currently disabled) GPU-monitoring chart below.
gpu_util_history = deque(maxlen=100)
mem_usage_history = deque(maxlen=100)


def initialize_history():
    """
    Seed the GPU utilization and memory usage histories.

    Note: relies on the `get_gpu_info()` helper from the disabled
    GPU-monitoring code path (see `enable_gpu_info()` below), so it is
    only meaningful once that feature is re-enabled.
    """
    for _ in range(100):
        gpu_info = get_gpu_info()
        gpu_util_history.append(round(gpu_info.get('gpu_util', 0), 1))
        mem_usage_history.append(round(gpu_info.get('mem_percent', 0), 1))
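

# Dropdown labels shown in the UI, and a reverse map from label back to repo id.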
model_choices = [f"{model['name']} ({model['bits']})" for model in models]
display_to_model = {f"{model['name']} ({model['bits']})": model['name'] for model in models}


def download_model(model):
    # Fetch the weights from the Hugging Face Hub (reuses the local cache if present).
    print(f"Downloading {model['name']}...")
    snapshot_download(repo_id=model['name'])


def download_models_in_background():
    print('Downloading models for the first time...')
    for model in models:
        download_model(model)
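

# Fetch all model weights in a background thread so the UI can start serving immediately.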
download_thread = threading.Thread(target=download_models_in_background)
download_thread.start()
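
# Cache for the most recently loaded model; only one model is kept in GPU memory at a time.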
loaded_model = None
loaded_model_name = None


@spaces.GPU
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
selected_model_display_label,
):
    model_name = display_to_model[selected_model_display_label]
    global loaded_model
    global loaded_model_name

    # Reload only when a different model is selected; note `!=` rather than
    # `is not`, since string identity is not guaranteed in Python.
    if model_name != loaded_model_name:
        loaded_model = get_chat_loop_generator(model_name)
        loaded_model_name = model_name
    chat_completion = loaded_model
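
    # Rebuild the OpenAI-style message list from the Gradio (user, assistant) history.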
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})

    response = ""
    # Stream tokens and yield the accumulated text so the UI updates live.
    for token in chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# enable_gpu_info()
with gr.Blocks(fill_height=True) as demo:
# with gr.Row():
# def update_chart():
# return _update_charts(chart_height=200)
# gpu_chart = gr.Plot(update_chart, every=0.1) # update every 0.1 seconds
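    # ChatInterface forwards each additional input to respond() in order,
    # after the (message, history) pair.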
with gr.Column():
chat_interface = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
gr.Dropdown(
choices=model_choices,
value=model_choices[0],
label="Select Model",
),
],
)


if __name__ == "__main__":
    # Set SHARE_LINK=1 (or "true"/"True") to create a public Gradio share link.
    share = os.getenv("SHARE_LINK", "") in ("1", "true", "True")
    demo.launch(share=share)