"""A simple web interactive chat demo based on gradio."""
from argparse import ArgumentParser
from threading import Thread
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
class StopOnTokens(StoppingCriteria):
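    """Stop generation when the last generated token matches a hard-coded id.

    The ids below are tokenizer-specific (Yi chat tokens), so adjust them when
    switching models.
    """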
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
        # Token ids for "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
        # in the Yi chat tokenizer. The original wrapped this list in a tuple,
        # which made the comparison below check a tensor against a whole list.
        stop_ids = [2, 6, 7, 8]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
class StoppingCriteriaSub(StoppingCriteria):
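    """Stop generation when the last generated token decodes to one of the given
    stop sequences. Comparison is by decoded string, so this only catches
    single-token stop words."""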
    def __init__(self, stops=None):
        super().__init__()
        # The check below compares decoded strings, so the id tensors can stay
        # on CPU (the original moved them to "cuda", which broke --cpu-only).
        self.stops = list(stops) if stops is not None else []
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
last_token = input_ids[0][-1]
for stop in self.stops:
if tokenizer.decode(stop) == tokenizer.decode(last_token):
return True
return False
def parse_text(text):
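    """Render a model reply as HTML for gr.Chatbot: convert ``` fences into
    <pre><code> blocks and HTML-escape the characters inside them so code
    displays literally."""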
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
            if count % 2 == 1:
                # Opening fence: start an HTML code block, tagged with the
                # language name if one follows the backticks.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence: end the HTML code block.
                lines[i] = "<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = "
" + line
text = "".join(lines)
return text
def predict(history, max_length, top_p, temperature):
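    """Stream the assistant reply for the latest user turn in `history`.

    Builds a chat-template prompt from the accumulated turns, runs generation
    in a background thread, and yields the updated history as tokens arrive.
    """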
    # StopOnTokens hard-codes Yi token ids; it is instantiated here but not
    # passed to `generate` below (the StoppingCriteriaSub list is used instead).
    stop = StopOnTokens()
# messages = [{"role": "system", "content": "You are a helpful assistant"}]
messages = [{"role": "system", "content": ""}]
# messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
print("\n\n====conversation====\n", messages)
model_inputs = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
).to(next(model.parameters()).device)
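    # Stream decoded text as it is generated, skipping the prompt and special tokens.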
streamer = TextIteratorStreamer(
tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
)
# stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"]
stop_words = [""]
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze()
        for stop_word in stop_words
    ]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"stopping_criteria": stopping_criteria,
"repetition_penalty": 1.1,
}
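    # Run generation in a background thread so the streamer can be consumed below.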
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
for new_token in streamer:
if new_token != "":
history[-1][1] += new_token
yield history
def main(args):
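    """Build the Gradio Blocks UI and launch the demo server."""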
with gr.Blocks() as demo:
# gr.Markdown(
# """\
# """
# )
# gr.Markdown("""
Yi-Chat Bot""")
gr.Markdown("""🦣MAmmoTH2""")
# gr.Markdown(
# """\
# This WebUI is based on Yi-Chat, developed by 01-AI."""
# )
        gr.Markdown(
            """\
<center><font size=3>MAmmoTH2-8x7B-Plus <a href="https://huggingface.co/TIGER-Lab/MAmmoTH2-8x7B-Plus">🤗</a></center>"""
            # 🤖
            # Yi GitHub
        )
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(
show_label=False,
placeholder="Input...",
lines=10,
container=False,
)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("🚀 Submit")
with gr.Column(scale=1):
emptyBtn = gr.Button("🧹 Clear History")
max_length = gr.Slider(
0,
32768,
value=4096,
step=1.0,
label="Maximum length",
interactive=True,
)
top_p = gr.Slider(
0, 1, value=1.0, step=0.01, label="Top P", interactive=True
)
temperature = gr.Slider(
0.01, 1, value=0.7, step=0.01, label="Temperature", interactive=True
)
def user(query, history):
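            """Append the query as a new turn and clear the input box."""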
# return "", history + [[parse_text(query), ""]]
return "", history + [[query, ""]]
submitBtn.click(
user, [user_input, chatbot], [user_input, chatbot], queue=False
).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
user_input.submit(
user, [user_input, chatbot], [user_input, chatbot], queue=False
).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
emptyBtn.click(lambda: None, None, chatbot, queue=False)
demo.queue()
demo.launch(
server_name=args.server_name,
server_port=args.server_port,
inbrowser=args.inbrowser,
share=args.share
)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"-c",
"--checkpoint-path",
type=str,
default="TIGER-Lab/MAmmoTH2-8B-Plus",
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only"
)
parser.add_argument(
"--share",
action="store_true",
default=False,
help="Create a publicly shareable link for the interface.",
)
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=True,  # note: with store_true and a True default, this flag is effectively always on
        help="Automatically launch the interface in a new tab on the default browser.",
    )
parser.add_argument(
"--server-port", type=int, default=8110, help="Demo server port."
)
parser.add_argument(
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
)
args = parser.parse_args()
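    # tokenizer and model are created at module level so predict() and the
    # stopping criteria can reference them as globals.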
tokenizer = AutoTokenizer.from_pretrained(
args.checkpoint_path, trust_remote_code=True
)
if args.cpu_only:
device_map = "cpu"
else:
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
torch_dtype="auto",
trust_remote_code=True,
).eval()
main(args)