nroggendorff
committed on
Update app.py
app.py CHANGED
@@ -5,6 +5,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
 
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [50256, 50295]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
 @spaces.GPU
 def predict(message, history):
     torch.set_default_device("cuda")
@@ -20,8 +28,9 @@ def predict(message, history):
         trust_remote_code=True
     )
     history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
 
-    system_prompt = "<|im_start|>system\nYou are
+    system_prompt = "<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"
     messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
     input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
@@ -32,8 +41,9 @@ def predict(message, history):
         do_sample=True,
         top_p=0.95,
         top_k=50,
-        temperature=0.
-        num_beams=1
+        temperature=0.7,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -46,4 +56,5 @@ def predict(message, history):
 
 
 gr.ChatInterface(predict,
+    theme=gr.themes.Soft(primary_hue="purple"),
 ).launch()
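For context, here is a minimal, self-contained sketch of the pattern this commit wires in: a custom StoppingCriteria that halts generation when the last emitted token matches a stop id, with model.generate() running on a background thread while TextIteratorStreamer yields decoded text incrementally. The model name and stop ids below are placeholders for illustration, not the Space's actual configuration (in the commit, ids 50256 and 50295 presumably correspond to the model's end-of-text and ChatML <|im_end|> tokens).

import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is a stop id.
        return input_ids[0][-1].item() in self.stop_ids

# Placeholder model; the Space loads its own checkpoint with trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Hello, my name is"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
stop = StopOnTokens(stop_ids=[tokenizer.eos_token_id])

generate_kwargs = dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=32,
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    stopping_criteria=StoppingCriteriaList([stop]),
)

# generate() runs in a background thread; the streamer yields text
# chunks on the main thread until generation stops or hits a stop id.
Thread(target=model.generate, kwargs=generate_kwargs).start()
for chunk in streamer:
    print(chunk, end="", flush=True)

Passing the criteria via stopping_criteria=StoppingCriteriaList([...]) is useful when a chat template's end-of-turn marker is not the tokenizer's eos token, which appears to be why the commit adds it alongside the ChatML-style prompt.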