# Spaces:
# Runtime error
# Runtime error
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from functools import lru_cache
from threading import Thread
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from project_settings import project_path
def greet(question: str, history: List[Tuple[str, str]]):
    """Return a new chat history with a canned greeting for *question* appended.

    The input *history* list is left untouched; a fresh list is returned.
    """
    new_turn = (question, "".join(["Hello ", question, "!"]))
    updated_history = history + [new_turn]
    return updated_history
@lru_cache(maxsize=1)
def _load_model_and_tokenizer(pretrained_model_name_or_path: str):
    """Load a causal LM and its tokenizer, caching the most recently used pair.

    Loading a large checkpoint is expensive, so the result is memoised and
    reused across chat turns instead of being reloaded on every request.
    Only one model is kept (``maxsize=1``) to bound memory use.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    # device_map="auto" already places (and may offload) the weights, and the
    # dtype is already bfloat16, so no .to(device) / .bfloat16() here: moving an
    # accelerate-dispatched model is unsupported and fails when some modules
    # were offloaded to CPU/disk.
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # llama does not support the fast tokenizer.
        use_fast=model.config.model_type != "llama",
        padding_side="left",
    )
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are
    # all None; eod_id corresponds to <|endoftext|>, so reuse it for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id
    return model, tokenizer


def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                             ):
    """Answer *question* with the selected model and append the turn to the history.

    Only the current question is fed to the model, wrapped as
    ``bos + question + eos``. Earlier turns in *history* are displayed in the UI
    but are not used as context.

    Args:
        question: user input text.
        history: existing chat turns as ``(user, bot)`` pairs; not mutated.
        pretrained_model_name_or_path: model id/path passed to ``from_pretrained``.
        max_new_tokens: generation budget; values below 1 are clamped to 1
            because ``generate`` rejects a non-positive budget.
        top_p: nucleus-sampling threshold (used only when sampling).
        temperature: sampling temperature. A non-positive value switches to
            greedy decoding, since sampling requires a strictly positive
            temperature.
        repetition_penalty: passed through to ``generate``.

    Returns:
        A new history list with ``(question, response)`` appended, which is what
        the Gradio ``Chatbot`` output expects. An empty question returns the
        history unchanged.
    """
    history = list(history or [])
    if not (question and question.strip()):
        return history

    model, tokenizer = _load_model_and_tokenizer(pretrained_model_name_or_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_ids = tokenizer(
        question,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids.to(device)
    bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
    input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)

    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max(1, int(max_new_tokens)),
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        generate_kwargs["top_p"] = top_p
        generate_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**generate_kwargs)

    # Drop the prompt tokens; keep only the newly generated ones.
    new_token_ids = outputs[0][input_ids.shape[1]:].tolist()
    response = tokenizer.decode(new_token_ids)
    # eos_token can be None for some tokenizers; str.replace(None, ...) would raise.
    if tokenizer.eos_token:
        response = response.replace(tokenizer.eos_token, "")
    response = response.strip()

    return history + [(question, response)]
def main():
    """Build the Gradio chat demo and launch it (blocks while the server runs)."""
    with gr.Blocks() as blocks:
        gr.Markdown(value="gradio demo")

        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )

        # Generation controls. There must be no trailing comma after gr.Slider(...):
        # it would bind a one-element tuple instead of the component, and Gradio
        # cannot use a tuple as an event input.
        # Minimums are 1 token / 0.01 temperature because generate() rejects
        # max_new_tokens=0 and sampling with temperature=0.
        with gr.Row():
            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
        with gr.Row():
            model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
                                     value="Qwen/Qwen-7B-Chat",
                                     label="model_name",
                                     )

        gr.Examples(examples=["你好"], inputs=text_box)

        # Order must match chat_with_llm_non_stream's positional parameters.
        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty
        ]
        outputs = [
            chatbot
        ]
        text_box.submit(chat_with_llm_non_stream, inputs, outputs)
        submit_button.click(chat_with_llm_non_stream, inputs, outputs)
        # Reset the input box to '' and the conversation to an empty history.
        clear_button.click(
            fn=lambda: ('', []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()