"""Gradio chat UI that streams replies from a Hugging Face causal LM.

Conversation turns are rendered as ``<human>:`` / ``<bot>:`` lines, and the
``'<'`` stop/filter logic below relies on that format. It matches the
RedPajama-INCITE chat example this script was adapted from (see the
commented-out model lines below).
"""
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

MODEL_NAME = "mattshumer/mistral-8x7b-chat"

# Alternative (smaller) model the prompt format was originally written for:
# tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
# model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)

# trust_remote_code executes Python shipped with the model repository;
# only enable it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): this loads default-dtype weights onto a single GPU; a model
# of this size may need torch_dtype / device_map settings — confirm.
model = model.to('cuda:0')


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is a stop id.

    Only the first sequence of the batch is inspected; ``predict`` always
    generates for a single prompt.

    Args:
        stop_ids: token ids that end generation. The default ``(29, 0)``
            keeps the previous hard-coded values, which appear to be
            ``'<'`` and ``<|endoftext|>`` in the GPT-NeoX (RedPajama)
            vocabulary and are not meaningful for other tokenizers.
    """

    def __init__(self, stop_ids=(29, 0)):
        super().__init__()
        self.stop_ids = frozenset(int(i) for i in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return int(input_ids[0, -1]) in self.stop_ids


def _stop_token_ids(tok):
    """Return the ids of ``'<'`` and EOS for tokenizer *tok*.

    ``'<'`` begins the next ``<human>:`` tag, so emitting it means the model
    has finished its turn. Ids missing from the vocabulary are skipped.
    """
    ids = []
    lt_id = tok.convert_tokens_to_ids("<")
    if lt_id is not None and lt_id != tok.unk_token_id:
        ids.append(lt_id)
    if tok.eos_token_id is not None:
        ids.append(tok.eos_token_id)
    return ids


# Looked up once from the loaded tokenizer rather than hard-coded.
STOP_IDS = _stop_token_ids(tokenizer)


def predict(message, history):
    """Stream the model's reply to *message* given the prior chat *history*.

    Args:
        message: the user's new message.
        history: prior turns as ``[user, bot]`` pairs (Gradio's tuple-style
            chat history — presumably; confirm against the Gradio version).

    Yields:
        The reply accumulated so far, once per streamed text chunk.
    """
    turns = history + [[message, ""]]
    # Each turn becomes "\n<human>:<user>\n<bot>:<bot>". The last turn has an
    # empty bot reply, so the prompt ends with "\n<bot>:" to cue the answer.
    prompt = "".join(
        "\n<human>:" + user + "\n<bot>:" + (bot or "") for user, bot in turns
    )

    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens(STOP_IDS)]),
    )
    # generate() blocks until finished, so run it in a worker thread and
    # consume decoded text from the streamer while it runs.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_message = ""
    for new_text in streamer:
        # Drop a lone "<" chunk: presumably the start of the next "<human>:"
        # tag that triggered the stop criterion, not part of the reply.
        if new_text != '<':
            partial_message += new_text
            yield partial_message


def predict_openai(message, history):
    """Alternative handler streaming replies from OpenAI's gpt-3.5-turbo.

    Not wired into the UI below. Uses the pre-1.0 ``openai`` client API
    (``openai.ChatCompletion``).

    Args:
        message: the user's new message.
        history: prior turns as ``(user, assistant)`` pairs.

    Yields:
        The reply accumulated so far, once per streamed content chunk.
    """
    import openai  # imported lazily: only this optional handler needs it

    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=history_openai_format,
        temperature=1.0,
        stream=True,
    )

    partial_message = ""
    for chunk in response:
        # A delta can be empty or carry no "content" key; only append text.
        content = chunk['choices'][0]['delta'].get('content')
        if content:
            partial_message += content
            yield partial_message


# launch() blocks, so it stays last: every handler above is defined first.
gr.ChatInterface(predict).queue().launch()