"""
This script creates an interactive web demo for the GLM-4-9B model using Gradio,
a Python library for building quick and easy UI components for machine learning models.
It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
allowing users to interact with the model through a chat-like interface.

Users may also upload an ``.xlsx`` knowledge base; rows matching the latest user
message are passed to the model as extra system context.
"""

import os
from pathlib import Path
from threading import Thread
from typing import Union

import gradio as gr
import pandas as pd
import torch
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
# NOTE: not consumed anywhere in this script; the tokenizer is loaded from the
# model directory (or from the adapter's base model). Kept for compatibility.
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

# Chatbot reply inserted by the "Set Prompt" button ("prompt set successfully").
PROMPT_ACK = "成功设置prompt"
# Returned by retrieve_from_knowledge_base() when no row matches.
NO_KB_RESULT = "No relevant information found."


def _resolve_path(path: Union[str, Path]) -> Path:
    """Expand ``~`` and return the absolute form of *path*."""
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    """Load a causal LM (optionally a PEFT adapter) and its tokenizer.

    Args:
        model_dir: Local directory or Hugging Face Hub repo id. A local
            directory containing ``adapter_config.json`` is loaded as a PEFT
            adapter, and the tokenizer then comes from the adapter's base model.
        trust_remote_code: Forwarded to every ``from_pretrained`` call.

    Returns:
        ``(model, tokenizer)``; the model is loaded with ``device_map='auto'``.
    """
    local_dir = _resolve_path(model_dir)
    # Hub repo ids (such as the default 'THUDM/glm-4-9b-chat') are not local
    # paths: resolving them yields a nonexistent absolute path that
    # from_pretrained rejects, so pass non-local references through unchanged.
    model_ref = local_dir if local_dir.exists() else str(model_dir)
    if (local_dir / 'adapter_config.json').exists():
        model = AutoPeftModelForCausalLM.from_pretrained(
            local_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_ref, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_ref
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
    )
    return model, tokenizer


def load_knowledge_base(file: Union[str, Path]) -> pd.DataFrame:
    """Read an Excel workbook (first sheet) into a DataFrame."""
    return pd.read_excel(file)


def _cell_text(value) -> str:
    """Return a spreadsheet cell as text, mapping missing values (NaN/None) to ''."""
    return "" if pd.isna(value) else str(value)


def retrieve_from_knowledge_base(query: str, knowledge_base: pd.DataFrame) -> str:
    """Return the knowledge-base rows that mention *query*.

    The first column is treated as the field name and the second as its
    content. A row matches when *query* is a case-insensitive substring of
    either. Every matching row is reported (duplicate field names are kept),
    one ``"field: content"`` per line.

    Returns:
        The matching lines joined by newlines, or ``NO_KB_RESULT`` when the
        query is blank, the frame has fewer than two columns, or nothing matches.
    """
    needle = query.lower()
    # An empty needle is a substring of everything; treat it as "no query".
    if not needle.strip() or knowledge_base.shape[1] < 2:
        return NO_KB_RESULT

    relevant_info = []
    for field, content in zip(knowledge_base.iloc[:, 0], knowledge_base.iloc[:, 1]):
        # Excel cells may be numbers or NaN, so normalise before .lower().
        field_text = _cell_text(field)
        content_text = _cell_text(content)
        if needle in field_text.lower() or needle in content_text.lower():
            relevant_info.append(f"{field_text}: {content_text}")

    return "\n".join(relevant_info) if relevant_info else NO_KB_RESULT


model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
# Replaced by upload_file(); an empty frame disables knowledge-base lookup.
knowledge_base = pd.DataFrame()


class StopOnTokens(StoppingCriteria):
    """Stop generation once the newest token of the first sequence is a stop id."""

    def __init__(self, stop_ids=None):
        # Defaults to the loaded model's EOS id(s); the config value may be a
        # single int, a list of ints, or None.
        if stop_ids is None:
            stop_ids = model.config.eos_token_id
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        self.stop_ids = set(stop_ids or [])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return int(input_ids[0][-1]) in self.stop_ids


# HTML replacements applied to text inside fenced code blocks so the Chatbot
# shows it literally instead of interpreting it as Markdown/HTML. translate()
# maps each original character once, so the '&' in an emitted entity is never
# re-escaped.
_CODE_ESCAPES = str.maketrans({
    "&": "&amp;",
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text into the HTML-flavoured markup shown by the Chatbot.

    Empty lines are dropped and the remaining lines are joined with ``<br>``.
    A line containing a ``` fence toggles a code region: an opening fence
    becomes ``<pre><code class="language-TAG">`` (TAG is the text after the
    fence's last backtick) and a closing fence becomes ``<br></code></pre>``.
    Lines inside a code region are HTML-escaped.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def _build_messages(history, prompt):
    """Convert Chatbot *history* (plus optional system *prompt*) into chat messages.

    NOTE(review): history entries hold parse_text() output (``<br>`` and HTML
    entities), so that markup is what the model receives — confirm intended.
    """
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        # set_prompt() seeds the chat with [prompt, PROMPT_ACK]. That entry is
        # UI feedback, not a conversation turn (the prompt is sent as the
        # system message). Skip only that entry, so the first real question is
        # kept when the prompt box was filled without pressing "Set Prompt".
        if idx == 0 and model_msg == PROMPT_ACK:
            continue
        # The pending turn (no reply yet) ends the conversation.
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages


def predict(history, prompt, max_length, top_p, temperature):
    """Stream the model's reply into the last Chatbot entry.

    Generation runs on a background thread; each decoded chunk is appended to
    ``history[-1][1]`` and the (mutated) *history* is yielded to Gradio.
    """
    if not history:
        return
    stop = StopOnTokens()
    messages = _build_messages(history, prompt)

    # Add matching knowledge-base rows as extra system context for the latest
    # user message.
    if not knowledge_base.empty and messages and messages[-1]["role"] == "user":
        knowledge_text = retrieve_from_knowledge_base(messages[-1]["content"], knowledge_base)
        messages.append({"role": "system", "content": knowledge_text})

    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # Gradio sliders may deliver floats; generate() expects an int.
        "max_new_tokens": int(max_length),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
        "eos_token_id": model.config.eos_token_id,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token:
            history[-1][1] += new_token
            yield history


def upload_file(file):
    """Load an uploaded ``.xlsx`` file as the active knowledge base.

    ``gr.File(type="filepath")`` passes a path string; an object exposing
    ``.name`` (what this handler originally assumed) is accepted as well.
    """
    global knowledge_base
    if file is None:
        return ""
    path = getattr(file, "name", file)
    knowledge_base = load_knowledge_base(path)
    return f"Uploaded {path}"


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Minimum is 1: generate() rejects max_new_tokens <= 0.
            max_length = gr.Slider(1, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
            file_upload = gr.File(label="Upload Knowledge Base (.xlsx)", type="filepath", file_types=[".xlsx"])
            upload_message = gr.Textbox(label="", placeholder="", interactive=False)

    def user(query, history):
        """Clear the input box and append *query* as a pending chat turn."""
        return "", (history or []) + [[parse_text(query), ""]]

    def set_prompt(prompt_text):
        """Reset the chat to a single entry echoing the prompt plus an acknowledgement."""
        return [[parse_text(prompt_text), PROMPT_ACK]]

    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
    file_upload.upload(upload_file, inputs=file_upload, outputs=upload_message)

if __name__ == "__main__":
    demo.queue()
    # NOTE: share=True additionally publishes the demo through a public
    # Gradio share link, not only on 127.0.0.1.
    demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)