chat_with_llm / main.py
qgyd2021's picture
[update]add main
72ad2e5
raw
history blame
8.04 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import gc
from threading import Thread
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer

from project_settings import project_path
def greet(question: str, history: List[Tuple[str, str]]):
    """Toy chat handler: answer every message with a greeting.

    :param question: the user's message.
    :param history: prior ``(question, answer)`` turns; left unmodified.
    :return: a new list equal to *history* followed by
        ``(question, "Hello <question>!")``.
    """
    reply = "".join(("Hello ", question, "!"))
    new_turn = (question, reply)
    return history + [new_turn]
# Cache of loaded checkpoints: name -> {"model": ..., "tokenizer": ...}.
# init_model keeps at most one entry in it.
model_map: dict = {}


def init_model(pretrained_model_name_or_path: str):
    """Return ``(model, tokenizer)`` for a checkpoint, loading it on first use.

    ``model_map`` caches at most one checkpoint: asking for a different one
    evicts the cached entry (and releases its memory) before loading.

    :param pretrained_model_name_or_path: hub id or local path passed to
        ``from_pretrained``.
    :return: ``(model, tokenizer)``; the model is cast to bfloat16 and put in
        eval mode, the tokenizer pads on the left.
    """
    cached = model_map.get(pretrained_model_name_or_path)
    if cached is not None:
        return cached["model"], cached["tokenizer"]

    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Evict the previously cached model. Dropping the dict's references is what
    # frees it (``del`` on a loop variable would only unbind that local name);
    # then collect and release cached CUDA blocks so the new model has room.
    model_map.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    # With ``device_map`` set, the weights are already placed (possibly split
    # across devices or offloaded to ./offload) and the layout is recorded in
    # ``hf_device_map``. Moving such a model with ``.to()`` fights that
    # placement and can raise when modules are offloaded, so only move models
    # that were not dispatched.
    if getattr(model, "hf_device_map", None) is None:
        model = model.to(device)
    model = model.bfloat16().eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=model.config.model_type != "llama",
        padding_side="left",
    )
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are
    # all None; eod_id is the id of <|endoftext|>, so use it for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    model_map[pretrained_model_name_or_path] = {
        "model": model,
        "tokenizer": tokenizer,
    }
    return model, tokenizer
def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                             history_max_len: int,
                             ):
    """Answer *question* in one shot (no streaming) and return the updated history.

    The prompt is ``bos, text_1, eos, text_2, eos, ..., question, eos`` where
    ``text_i`` are the flattened ``(question, answer)`` pairs of *history*
    (bos/eos are skipped when the tokenizer has none). It is left-truncated
    to the last *history_max_len* tokens.

    :param question: new user message.
    :param history: previous ``(question, answer)`` turns, oldest first.
    :param pretrained_model_name_or_path: checkpoint to use (see ``init_model``).
    :param max_new_tokens: generation budget for the answer.
    :param top_p: nucleus-sampling threshold.
    :param temperature: sampling temperature.
    :param repetition_penalty: repetition penalty passed to ``generate``.
    :param history_max_len: keep only this many most recent prompt tokens;
        ``0`` or less keeps the whole prompt.
    :return: a new list: *history* plus the ``(question, answer)`` turn.
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    text_list = [text for pair in history for text in pair]
    text_list.append(question)
    batch_input_ids = tokenizer(text_list, add_special_tokens=False)["input_ids"]

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    # Some tokenizers define no bos/eos id; skip them instead of letting a
    # None reach torch.tensor, which would raise.
    input_ids = [] if bos_id is None else [bos_id]
    for turn_ids in batch_input_ids:
        input_ids.extend(turn_ids)
        if eos_id is not None:
            input_ids.append(eos_id)

    # UI sliders may deliver floats; slicing needs an int. 0 keeps everything
    # (the same result the original ``[-0:]`` slice gave, made explicit).
    history_max_len = int(history_max_len)
    if history_max_len > 0:
        input_ids = input_ids[-history_max_len:]
    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
    # A single unpadded sequence: every position is real input.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            eos_token_id=eos_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    answer_ids = outputs.tolist()[0][input_ids.shape[1]:]
    answer = tokenizer.decode(answer_ids).strip()
    eos_token = tokenizer.eos_token
    if eos_token:
        answer = answer.replace(eos_token, "")
    answer = answer.strip()
    return history + [(question, answer)]
def chat_with_llm_streaming(question: str,
                            history: List[Tuple[str, str]],
                            pretrained_model_name_or_path: str,
                            max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                            history_max_len: int,
                            ):
    """Stream an answer to *question*, yielding the growing chat history.

    The prompt is built exactly like in the non-streaming variant:
    ``bos, text_1, eos, ..., question, eos`` over the flattened *history*
    pairs (bos/eos skipped when absent), left-truncated to the last
    *history_max_len* tokens. Generation runs in a worker thread while this
    generator relays decoded text.

    :param question: new user message.
    :param history: previous ``(question, answer)`` turns, oldest first.
    :param pretrained_model_name_or_path: checkpoint to use (see ``init_model``).
    :param max_new_tokens: generation budget for the answer.
    :param top_p: nucleus-sampling threshold.
    :param temperature: sampling temperature.
    :param repetition_penalty: repetition penalty passed to ``generate``.
    :param history_max_len: keep only this many most recent prompt tokens;
        ``0`` or less keeps the whole prompt.
    :yield: *history* plus ``(question, partial_answer)``, first with an empty
        answer, then once per decoded chunk.
    :raises Exception: whatever ``model.generate`` raised in the worker thread,
        re-raised after the stream ends.
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    text_list = [text for pair in history for text in pair]
    text_list.append(question)
    batch_input_ids = tokenizer(text_list, add_special_tokens=False)["input_ids"]

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    # Skip missing bos/eos ids instead of letting None reach torch.tensor.
    input_ids = [] if bos_id is None else [bos_id]
    for turn_ids in batch_input_ids:
        input_ids.extend(turn_ids)
        if eos_id is not None:
            input_ids.append(eos_id)

    # Sliders may deliver floats; 0 keeps the whole prompt.
    history_max_len = int(history_max_len)
    if history_max_len > 0:
        input_ids = input_ids[-history_max_len:]
    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

    # skip_prompt: otherwise the streamer first emits the decoded prompt
    # (the whole history, special tokens included) as if it were the answer.
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=float(top_p),
        temperature=float(temperature),
        repetition_penalty=float(repetition_penalty),
        eos_token_id=eos_id,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )

    errors: List[Exception] = []

    def _generate():
        # Worker-thread body. If generate() raises, the streamer never gets its
        # end signal and the loop below would block forever, so end the stream
        # explicitly and hand the exception back to the consumer.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    eos_token = tokenizer.eos_token
    answer = ""
    # Show the user's turn immediately, even before the first token arrives.
    yield history + [(question, answer)]
    for chunk in streamer:
        if eos_token:
            chunk = chunk.replace(eos_token, "")
        answer += chunk
        yield history + [(question, answer)]

    thread.join()
    if errors:
        raise errors[0]
def main():
    """Build the Gradio chat UI around ``chat_with_llm_streaming`` and launch it."""
    description = """
chat llm
"""
    with gr.Blocks() as blocks:
        gr.Markdown(value=description)
        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )
        # Lower bounds are chosen so every slider value is accepted by
        # model.generate(): max_new_tokens must be at least 1, and with
        # do_sample=True temperature and repetition_penalty must be > 0.
        with gr.Row():
            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                repetition_penalty = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
            with gr.Column(scale=1):
                history_max_len = gr.Slider(minimum=0, maximum=4096, value=1024, step=1, label="history_max_len")
        with gr.Row():
            with gr.Column(scale=1):
                model_name = gr.Dropdown(
                    choices=[
                        "Qwen/Qwen-7B-Chat",
                        "THUDM/chatglm2-6b",
                        "baichuan-inc/Baichuan2-7B-Chat",
                    ],
                    value="Qwen/Qwen-7B-Chat",
                    label="model_name",
                )
        gr.Examples(examples=["你好"], inputs=text_box)

        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty,
            history_max_len,
        ]
        outputs = [
            chatbot,
        ]
        text_box.submit(chat_with_llm_streaming, inputs, outputs)
        submit_button.click(chat_with_llm_streaming, inputs, outputs)
        # Reset the textbox to "" and the chat history to an empty list (the
        # value type gr.Chatbot expects), not to an empty string.
        clear_button.click(
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )
    blocks.queue().launch()


if __name__ == '__main__':
    main()