Mishmosh committed on
Commit
5be194a
1 Parent(s): b8a446a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !pip install gradio
2
+ !pip install transformers
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
5
+ from threading import Thread
6
+ tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
7
+ model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
8
+ model = model.to('cuda:0')
9
+ class StopOnTokens(StoppingCriteria):
10
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
11
+ stop_ids = [29, 0]
12
+ for stop_id in stop_ids:
13
+ if input_ids[0][-1] == stop_id:
14
+ return True
15
+ return False
16
+ def predict(message, history):
17
+
18
+ history_transformer_format = history + [[message, ""]]
19
+ stop = StopOnTokens()
20
+
21
+ messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
22
+ for item in history_transformer_format])
23
+
24
+ model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
25
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
26
+ generate_kwargs = dict(
27
+ model_inputs,
28
+ streamer=streamer,
29
+ max_new_tokens=1024,
30
+ do_sample=True,
31
+ top_p=0.95,
32
+ top_k=1000,
33
+ temperature=1.0,
34
+ num_beams=1,
35
+ stopping_criteria=StoppingCriteriaList([stop])
36
+ )
37
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
38
+ t.start()
39
+
40
+ partial_message = ""
41
+ for new_token in streamer:
42
+ if new_token != '<':
43
+ partial_message += new_token
44
+ yield partial_message
45
+
46
import gradio as gr

# Assemble the chat UI around the streaming predict() generator.
_chat_window = gr.Chatbot(height=500)
_input_box = gr.Textbox(placeholder="Ask me a yes or no question", container=False, scale=7)

demo_llm = gr.ChatInterface(
    predict,
    chatbot=_chat_window,
    textbox=_input_box,
    title="LLM Chatbot",
    description="Chat with LLM",
    theme="soft",
    examples=["Hello", "What are you?", "What is the meaning of life?"],
    # NOTE(review): cache_examples=True runs predict() on each example at
    # startup, which needs the model loaded on GPU — confirm this is intended.
    cache_examples=True,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
# queue() enables the request queue needed for generator (streaming) handlers.
demo_llm.queue().launch()