import gradio as gr
#import torch
#from threading import Thread
#from transformers import (AutoModelForCausalLM, AutoTokenizer,
#                          StoppingCriteria, StoppingCriteriaList,
#                          TextIteratorStreamer)

#tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
#model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
#model = model.to('cuda:0')

## Stops generation as soon as the model emits one of the stop token ids.
#class StopOnTokens(StoppingCriteria):
#    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
#        stop_ids = [29, 0]
#        for stop_id in stop_ids:
#            if input_ids[0][-1] == stop_id:
#                return True
#        return False

def predict(message, history):
    # Flatten the chat history (plus the new user message) into the
    # "<human>:"/"<bot>:" prompt format that RedPajama-INCITE-Chat expects.
    history_transformer_format = history + [[message, ""]]
    messages = "".join(["".join(["\n<human>:" + item[0], "\n<bot>:" + item[1]])
                        for item in history_transformer_format])

    #stop = StopOnTokens()
    #model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    #streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    #generate_kwargs = dict(
    #    model_inputs,
    #    streamer=streamer,
    #    max_new_tokens=1024,
    #    do_sample=True,
    #    top_p=0.95,
    #    top_k=1000,
    #    temperature=1.0,
    #    num_beams=1,
    #    stopping_criteria=StoppingCriteriaList([stop])
    #    )
    ## Run generation on a background thread so tokens can be streamed as they arrive.
    #t = Thread(target=model.generate, kwargs=generate_kwargs)
    #t.start()

    #partial_message = ""
    #for new_token in streamer:
    #    if new_token != '<':  # skip the '<' that opens the "<human>:" stop marker
    #        partial_message += new_token
    #        yield partial_message

    # Placeholder response while the model code above is commented out.
    return "Hello"

gr.ChatInterface(predict).launch()
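
# A minimal streaming sketch (assumption: no GPU or model available).
# ChatInterface also accepts a generator function; yielding successive
# prefixes re-renders the bot bubble incrementally, which is what the
# threaded streamer loop above does with real model tokens. To try it,
# comment out the launch() call above and uncomment the lines below.
# `slow_echo` is a hypothetical stand-in, not part of the original app.
#import time
#
#def slow_echo(message, history):
#    response = "You typed: " + message
#    for i in range(len(response)):
#        time.sleep(0.05)
#        yield response[: i + 1]  # yield a growing prefix to stream the reply
#
#gr.ChatInterface(slow_echo).launch()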