import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from transformers.generation.utils import GenerationConfig
from threading import Thread
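# Assumed dependencies (not pinned in the original): gradio, torch, and a
# transformers release recent enough to provide TextIteratorStreamer
# (roughly >= 4.28).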
# Load the tokenizer and model from the Hugging Face model hub.
# Previously tried alternatives, kept commented out for reference:
# model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
# model_name_or_path = "Flmc/DISC-MedLLM"
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
# model.generation_config = GenerationConfig.from_pretrained(model_name_or_path)
model_name_or_path = "scutcyr/BianQue-2"
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True).half()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)

# Use CUDA when available for an optimal experience.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
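# Note (assumption, not part of the original app): .half() targets GPU
# inference; on a CPU-only host some fp16 ops can be slow or unsupported,
# so you may want to fall back to `model = model.float()` when device is 'cpu'.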
# Custom stopping criteria for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [2]  # IDs of tokens where generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Stop when the last generated token is a stop token.
                return True
        return False
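# Sketch (not active in this app): the criterion above would be enforced by
# passing `stopping_criteria=StoppingCriteriaList([StopOnTokens()])` to
# `model.generate`. The call below has that argument commented out, so the
# '</s>' check in the streaming loop is what actually ends a turn.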
# Generate model predictions, streaming partial output back to the UI.
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()  # Instantiated for the optional stopping criteria (currently unused below).

    # Format the conversation history into the model's expected prompt layout.
    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                            for item in history_transformer_format])
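    # For a single first turn, the prompt comes out shaped roughly like
    # (illustrative, with <message> standing in for the user's text):
    #   "\n<|user|>:<message></s>\n<|assistant|>:"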
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,  # BatchEncoding mapping; expands into input_ids / attention_mask
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        top_p=0.75,
        top_k=50,
        temperature=0.95,
        num_beams=1,
        # stopping_criteria=StoppingCriteriaList([stop])  # temporarily removed
    )
    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Stop streaming once the end-of-turn marker appears.
            break
        yield partial_message
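# Because predict is a generator, gr.ChatInterface streams each yielded
# partial_message to the chat window instead of waiting for the full reply.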
# Set up the Gradio chat interface.
gr.ChatInterface(
    predict,
    title="TCM_ChatBLM_chatBot",
    description="Ask TCM_ChatBLM_chatBot any questions",
    # Example prompts (in Chinese): "Hello, I've had insomnia lately; how can I fix it?"
    # and "Is there a liniment for sprains and bruises I can use?"
    examples=['你好,我最近失眠,可以怎麼解決?', '請問有沒有跌打藥可以用?']
).launch()  # Launch the web interface.