Jose Benitez commited on
Commit
61debfb
·
1 Parent(s): dc2215e

update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -4
app.py CHANGED
@@ -1,7 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ # import gradio as gr
2
+ # import torch
3
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ # def load_model():
6
+ # model = AutoModelForCausalLM.from_pretrained("mattshumer/mistral-8x7b-chat", trust_remote_code=True)
7
+ # tok = AutoTokenizer.from_pretrained("mattshumer/mistral-8x7b-chat")
8
+ # return model, tok
9
+
10
+ # def inference(model, tok, PROMPT):
11
+ # x = tok.encode(PROMPT, return_tensors="pt").cuda()
12
+ # x = model.generate(x, max_new_tokens=512).cpu()
13
+ # return tok.batch_decode(x)
14
+
15
+
16
+ # gr.ChatInterface(inference).queue().launch()
17
+
18
+
19
  import gradio as gr
20
+ import torch
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
22
+ from threading import Thread
23
+
24
+ tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
25
+ model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
26
+ model = model.to('cuda:0')
27
+
28
+ class StopOnTokens(StoppingCriteria):
29
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
30
+ stop_ids = [29, 0]
31
+ for stop_id in stop_ids:
32
+ if input_ids[0][-1] == stop_id:
33
+ return True
34
+ return False
35
+
36
+ def predict(message, history):
37
+
38
+ history_transformer_format = history + [[message, ""]]
39
+ stop = StopOnTokens()
40
+
41
+ messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
42
+ for item in history_transformer_format])
43
+
44
+ model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
45
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
46
+ generate_kwargs = dict(
47
+ model_inputs,
48
+ streamer=streamer,
49
+ max_new_tokens=1024,
50
+ do_sample=True,
51
+ top_p=0.95,
52
+ top_k=1000,
53
+ temperature=1.0,
54
+ num_beams=1,
55
+ stopping_criteria=StoppingCriteriaList([stop])
56
+ )
57
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
58
+ t.start()
59
+
60
+ partial_message = ""
61
+ for new_token in streamer:
62
+ if new_token != '<':
63
+ partial_message += new_token
64
+ yield partial_message
65
+
66
+
67
+ gr.ChatInterface(predict).queue().launch()
68
+
69
+
70
+
71
+ def predict(message, history):
72
+ history_openai_format = []
73
+ for human, assistant in history:
74
+ history_openai_format.append({"role": "user", "content": human })
75
+ history_openai_format.append({"role": "assistant", "content":assistant})
76
+ history_openai_format.append({"role": "user", "content": message})
77
+
78
+ response = openai.ChatCompletion.create(
79
+ model='gpt-3.5-turbo',
80
+ messages= history_openai_format,
81
+ temperature=1.0,
82
+ stream=True
83
+ )
84
 
85
+ partial_message = ""
86
+ for chunk in response:
87
+ if len(chunk['choices'][0]['delta']) != 0:
88
+ partial_message = partial_message + chunk['choices'][0]['delta']['content']
89
+ yield partial_message
90