#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def greet(question: str, history: List[Tuple[str, str]]):
    # Placeholder echo handler; useful for testing the UI wiring.
    answer = "Hello " + question + "!"
    result = history + [(question, answer)]
    return result
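

# Note: the handler below reloads the model and tokenizer from the Hub on
# every call. That keeps the demo stateless, but reloading a 7B model per
# message is slow; caching the loaded model per model name (for example
# with functools.lru_cache) would avoid the repeated loads.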
def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                             ):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    # device_map="auto" already places (and possibly offloads) the weights,
    # and torch_dtype above already loads them in bfloat16, so calling
    # .to(device) or .bfloat16() on the dispatched model is redundant and
    # can raise for offloaded modules.
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == "llama" else True,
        padding_side="left"
    )

    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id
    # are all None; eod_id maps to the <|endoftext|> token, so reuse it for
    # all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    input_ids = tokenizer(
        question,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids.to(device)

    bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
    input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
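    # Note: Qwen-7B-Chat is normally prompted through its ChatML template
    # (its remote code exposes model.chat() for that); wrapping the raw
    # question in BOS/EOS as above is a bare-bones prompt and may yield
    # weaker answers than the official chat interface.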

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id
        )

    # Drop the prompt tokens and keep only the newly generated continuation.
    outputs = outputs.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs)
    response = response.strip().replace(tokenizer.eos_token, "").strip()

    result = history + [(question, response)]
    return result
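

# The `Thread` import above is unused by the non-streaming handler, which
# suggests a streaming variant was planned. Below is a minimal sketch of one
# (hypothetical, not wired into the UI): it assumes a model and tokenizer
# already loaded as in chat_with_llm_non_stream, plus
# transformers.TextIteratorStreamer (available since transformers 4.28).
def chat_with_llm_stream(question: str,
                         history: List[Tuple[str, str]],
                         model, tokenizer,
                         max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                         ):
    # Imported locally so this self-contained sketch does not touch the
    # module-level imports.
    from transformers import TextIteratorStreamer

    input_ids = tokenizer(
        question,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )

    # generate() blocks until done, so run it in a background thread and
    # consume decoded text from the streamer as it is produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield history + [(question, response)]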


def main():
    description = """
    chat llm
    """

    with gr.Blocks() as blocks:
        gr.Markdown(value="gradio demo")

        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)

        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )

        with gr.Row():
            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty")

        with gr.Row():
            model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
                                     value="Qwen/Qwen-7B-Chat",
                                     label="model_name",
                                     )

        gr.Examples(examples=["你好"], inputs=text_box)

        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty
        ]
        outputs = [
            chatbot
        ]

        text_box.submit(chat_with_llm_non_stream, inputs, outputs)
        submit_button.click(chat_with_llm_non_stream, inputs, outputs)

        # The chatbot holds a list of (question, answer) pairs, so clear it
        # with an empty list rather than an empty string.
        clear_button.click(
            fn=lambda: ('', []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()
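
# Usage note (assuming this file is the Space's entry point, e.g. main.py):
#   python main.py
# launches the Gradio app, by default at http://127.0.0.1:7860.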