#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo: chat with a causal language model loaded via ``transformers``."""
import gc
from typing import List, Tuple
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from project_settings import project_path


def greet(question: str, history: List[Tuple[str, str]]):
    """Toy chat handler: answer ``question`` with ``"Hello <question>!"``.

    Returns a new list equal to ``history`` plus the ``(question, answer)``
    pair; ``history`` itself is not modified.
    """
    answer = "Hello " + question + "!"
    result = history + [(question, answer)]
    return result


# Cache of loaded models keyed by ``pretrained_model_name_or_path``. Each value is
# ``{"model": ..., "tokenizer": ...}``; ``init_model`` keeps at most one entry.
model_map: dict = dict()


def _release_cached_models() -> None:
    """Drop every cached model/tokenizer and hand freed CUDA memory back to the driver."""
    # Deleting a loop variable only unbinds that name; the weights stay alive as long
    # as the dict references them. Clearing the dict drops those references so the
    # garbage collector can reclaim the model before the next one is loaded.
    model_map.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def init_model(pretrained_model_name_or_path: str):
    """Return ``(model, tokenizer)`` for a model, loading and caching it on first use.

    Only one model is cached at a time: requesting a model that is not cached
    evicts the currently cached one before loading the new one.

    Args:
        pretrained_model_name_or_path: Hub id or local path passed to
            ``from_pretrained`` for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    cached = model_map.get(pretrained_model_name_or_path)
    if cached is not None:
        return cached["model"], cached["tokenizer"]

    # Free the previously loaded model first so two models never coexist in memory.
    _release_cached_models()

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    # With device_map="auto", accelerate has already placed the weights (possibly
    # across several devices, or offloaded to CPU/disk) and recorded the placement in
    # ``hf_device_map``. Such a dispatched model must not be moved with ``.to()``
    # (accelerate raises when modules are offloaded), so only move a model that was
    # not dispatched.
    if getattr(model, "hf_device_map", None) is None:
        model = model.to(device)
    model = model.bfloat16().eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # llama does not support the fast tokenizer, so use the slow one for it.
        use_fast=model.config.model_type != "llama",
        padding_side="left",
    )
    # QWenTokenizer is special: its pad_token_id, bos_token_id and eos_token_id are
    # all None. Its eod_id corresponds to the <|endoftext|> token, so use it for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    model_map[pretrained_model_name_or_path] = {
        "model": model,
        "tokenizer": tokenizer,
    }
    return model, tokenizer


def _build_input_ids(tokenizer, history, question: str) -> List[int]:
    """Flatten the dialogue into ``[bos] turn_1 [eos] turn_2 [eos] ... question [eos]``.

    Every history text and the new question are tokenized without special tokens.
    A bos/eos id the tokenizer does not define (``None``) is skipped rather than
    being placed into the id list, where it would break ``torch.tensor``.
    """
    texts = [text for pair in history for text in pair]
    texts.append(question)
    encoded = tokenizer(texts, add_special_tokens=False)

    input_ids: List[int] = []
    if tokenizer.bos_token_id is not None:
        input_ids.append(tokenizer.bos_token_id)
    for turn_ids in encoded["input_ids"]:
        input_ids.extend(turn_ids)
        if tokenizer.eos_token_id is not None:
            input_ids.append(tokenizer.eos_token_id)
    return input_ids


def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int,
                             top_p: float,
                             temperature: float,
                             repetition_penalty: float,
                             ):
    """Generate one (non-streaming) reply to ``question`` and return the new history.

    Args:
        question: The user's new message.
        history: Previous ``(question, answer)`` pairs shown in the chatbot.
        pretrained_model_name_or_path: Model to chat with (see ``init_model``).
        max_new_tokens: Upper bound on generated tokens; values below 1 are raised to 1.
        top_p: Nucleus-sampling threshold (used only when sampling).
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        repetition_penalty: Passed to ``generate``; must be > 0.

    Returns:
        ``history + [(question, answer)]``, or ``history`` unchanged when
        ``question`` is empty or whitespace only.
    """
    if not question or not question.strip():
        return history

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    input_ids = torch.tensor(
        [_build_input_ids(tokenizer, history, question)], dtype=torch.long
    ).to(device)

    # UI sliders may deliver ints (e.g. exactly 1) where transformers' temperature
    # and repetition-penalty processors insist on floats, so normalize types here.
    temperature = float(temperature)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max(1, int(max_new_tokens)),
        repetition_penalty=float(repetition_penalty),
        eos_token_id=tokenizer.eos_token_id,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, top_p=float(top_p), temperature=temperature)
    else:
        # Sampling requires a strictly positive temperature; treat 0 as greedy decoding.
        generate_kwargs.update(do_sample=False)

    with torch.no_grad():
        outputs = model.generate(**generate_kwargs)

    # Keep only the newly generated tokens (the prompt is echoed at the front).
    new_token_ids = outputs[0].tolist()[input_ids.shape[-1]:]
    answer = tokenizer.decode(new_token_ids).strip()
    if tokenizer.eos_token:
        answer = answer.replace(tokenizer.eos_token, "").strip()

    return history + [(question, answer)]


def main():
    """Build the Gradio chat UI and launch it."""
    description = """
    chat llm
    """

    with gr.Blocks() as blocks:
        gr.Markdown(value=description)

        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )

        with gr.Row():
            with gr.Column(scale=1):
                # generate() rejects max_new_tokens == 0, so the slider starts at 1.
                max_new_tokens = gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                # 0 is allowed here: chat_with_llm_non_stream treats it as greedy decoding.
                temperature = gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                # The repetition penalty must be strictly positive.
                repetition_penalty = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty")

        with gr.Row():
            model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
                                     value="Qwen/Qwen-7B-Chat",
                                     label="model_name",
                                     )
        gr.Examples(examples=["你好"], inputs=text_box)

        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty,
        ]
        outputs = [
            chatbot
        ]
        text_box.submit(chat_with_llm_non_stream, inputs, outputs)
        submit_button.click(chat_with_llm_non_stream, inputs, outputs)
        # Reset the textbox to "" and the chatbot to an empty history.
        clear_button.click(
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()