""" This script creates an interactive web demo for the GLM-4-9B model using Gradio, a Python library for building quick and easy UI components for machine learning models. It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface, allowing users to interact with the model through a chat-like interface. """ import os from pathlib import Path from threading import Thread from typing import Union import gradio as gr import torch from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer ) ModelType = Union[PreTrainedModel, PeftModelForCausalLM] TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat') #MODEL_PATH = "/Users/zmac/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/04419001bc63e05e70991ade6da1f91c4aeec278" MODEL_PATH = "/Users/zmac/Documents/opensrc/llms/GLM-4/models" TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH) def _resolve_path(path: Union[str, Path]) -> Path: return Path(path).expanduser().resolve() def load_model_and_tokenizer( model_dir: Union[str, Path], trust_remote_code: bool = True ) -> tuple[ModelType, TokenizerType]: model_dir = _resolve_path(model_dir) if (model_dir / 'adapter_config.json').exists(): model = AutoPeftModelForCausalLM.from_pretrained( model_dir, trust_remote_code=trust_remote_code, device_map='auto' ) tokenizer_dir = model.peft_config['default'].base_model_name_or_path else: model = AutoModelForCausalLM.from_pretrained( model_dir, trust_remote_code=trust_remote_code, device_map='auto' ) tokenizer_dir = model_dir tokenizer = AutoTokenizer.from_pretrained( tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False ) return model, tokenizer model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True) class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: stop_ids = model.config.eos_token_id for stop_id in stop_ids: if input_ids[0][-1] == stop_id: return True return False def parse_text(text): lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 for i, line in enumerate(lines): if "```" in line: count += 1 items = line.split('`') if count % 2 == 1: lines[i] = f'
'
else:
lines[i] = f'
'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = "