"""A simple web interactive chat demo based on gradio.""" from argparse import ArgumentParser from threading import Thread import gradio as gr import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, ) class StopOnTokens(StoppingCriteria): def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> bool: stop_ids = ( [2, 6, 7, 8], ) # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>" for stop_id in stop_ids: if input_ids[0][-1] == stop_id: return True return False class StoppingCriteriaSub(StoppingCriteria): def __init__(self, stops = [], encounters=1): super().__init__() self.stops = [stop.to("cuda") for stop in stops] def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): last_token = input_ids[0][-1] for stop in self.stops: if tokenizer.decode(stop) == tokenizer.decode(last_token): return True return False def parse_text(text): lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 for i, line in enumerate(lines): if "```" in line: count += 1 items = line.split("`") if count % 2 == 1: lines[i] = f'
'
            else:
                lines[i] = f"
" else: if i > 0: if count % 2 == 1: line = line.replace("`", "\`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") line = line.replace("*", "*") line = line.replace("_", "_") line = line.replace("-", "-") line = line.replace(".", ".") line = line.replace("!", "!") line = line.replace("(", "(") line = line.replace(")", ")") line = line.replace("$", "$") lines[i] = "
" + line text = "".join(lines) return text def predict(history, max_length, top_p, temperature): stop = StopOnTokens() # messages = [{"role": "system", "content": "You are a helpful assistant"}] messages = [{"role": "system", "content": ""}] # messages = [] for idx, (user_msg, model_msg) in enumerate(history): if idx == len(history) - 1 and not model_msg: messages.append({"role": "user", "content": user_msg}) break if user_msg: messages.append({"role": "user", "content": user_msg}) if model_msg: messages.append({"role": "assistant", "content": model_msg}) print("\n\n====conversation====\n", messages) model_inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_tensors="pt" ).to(next(model.parameters()).device) streamer = TextIteratorStreamer( tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True ) # stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"] stop_words = [""] stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words] stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) generate_kwargs = { "input_ids": model_inputs, "streamer": streamer, "max_new_tokens": max_length, "do_sample": True, "top_p": top_p, "temperature": temperature, "stopping_criteria": stopping_criteria, "repetition_penalty": 1.1, } t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() for new_token in streamer: if new_token != "": history[-1][1] += new_token yield history def main(args): with gr.Blocks() as demo: # gr.Markdown( # """\ #

""" # ) # gr.Markdown("""

Yi-Chat Bot
""") gr.Markdown("""
🦣MAmmoTH2
""") # gr.Markdown( # """\ #
This WebUI is based on Yi-Chat, developed by 01-AI.
""" # ) gr.Markdown( """\
MAmmoTH2-8x7B-Plus 🤗 """ # 🤖  #  Yi GitHub
        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Input...",
                        lines=10,
                        container=False,
                    )
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("🚀 Submit")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("🧹 Clear History")
                max_length = gr.Slider(
                    0,
                    32768,
                    value=4096,
                    step=1.0,
                    label="Maximum length",
                    interactive=True,
                )
                top_p = gr.Slider(
                    0, 1, value=1.0, step=0.01, label="Top P", interactive=True
                )
                temperature = gr.Slider(
                    0.01, 1, value=0.7, step=0.01, label="Temperature", interactive=True
                )

        def user(query, history):
            # return "", history + [[parse_text(query), ""]]
            return "", history + [[query, ""]]

        submitBtn.click(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        user_input.submit(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        emptyBtn.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=args.inbrowser,
        share=args.share,
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="TIGER-Lab/MAmmoTH2-8B-Plus",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=True,
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8110, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )
    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        torch_dtype="auto",
        trust_remote_code=True,
    ).eval()

    main(args)
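# Example invocation (the script filename below is an assumption; substitute the
# actual file name). Assumes a CUDA GPU is available; pass --cpu-only otherwise:
#
#   python web_demo.py -c TIGER-Lab/MAmmoTH2-8B-Plus --server-port 8110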