Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
-
|
4 |
import torch
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
6 |
import gradio as gr
|
7 |
from threading import Thread
|
8 |
|
9 |
-
MODEL_LIST = ["
|
10 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
|
|
11 |
|
12 |
-
TITLE = "<h1><center>
|
13 |
|
14 |
PLACEHOLDER = """
|
15 |
<center>
|
@@ -30,21 +31,12 @@ h3 {
|
|
30 |
}
|
31 |
"""
|
32 |
|
33 |
-
#
|
34 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
35 |
|
36 |
-
|
|
|
37 |
|
38 |
-
|
39 |
-
model0 = AutoModelForCausalLM.from_pretrained(MODEL_LIST[0]).to(device)
|
40 |
-
|
41 |
-
tokenizer1 = AutoTokenizer.from_pretrained(MODEL_LIST[1])
|
42 |
-
model1 = AutoModelForCausalLM.from_pretrained(MODEL_LIST[1]).to(device)
|
43 |
-
|
44 |
-
tokenizer2 = AutoTokenizer.from_pretrained(MODEL_LIST[2])
|
45 |
-
model2 = AutoModelForCausalLM.from_pretrained(MODEL_LIST[2]).to(device)
|
46 |
-
|
47 |
-
#@spaces.GPU()
|
48 |
def stream_chat(
|
49 |
message: str,
|
50 |
history: list,
|
@@ -53,7 +45,6 @@ def stream_chat(
|
|
53 |
top_p: float = 1.0,
|
54 |
top_k: int = 20,
|
55 |
penalty: float = 1.2,
|
56 |
-
choice: str = "135M"
|
57 |
):
|
58 |
print(f'message: {message}')
|
59 |
print(f'history: {history}')
|
@@ -67,16 +58,6 @@ def stream_chat(
|
|
67 |
|
68 |
conversation.append({"role": "user", "content": message})
|
69 |
|
70 |
-
if choice == "1.7B":
|
71 |
-
tokenizer = tokenizer0
|
72 |
-
model = model0
|
73 |
-
elif choice == "135M":
|
74 |
-
model = model1
|
75 |
-
tokenizer = tokenizer1
|
76 |
-
else:
|
77 |
-
model = model2
|
78 |
-
tokenizer = tokenizer2
|
79 |
-
|
80 |
input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
|
81 |
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
82 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
@@ -154,12 +135,6 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
|
|
154 |
label="Repetition penalty",
|
155 |
render=False,
|
156 |
),
|
157 |
-
gr.Radio(
|
158 |
-
["135M", "360M", "1.7B"],
|
159 |
-
value="135M",
|
160 |
-
label="Load Model",
|
161 |
-
render=False,
|
162 |
-
),
|
163 |
],
|
164 |
examples=[
|
165 |
["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
|
|
|
1 |
import os
|
2 |
import time
|
3 |
+
import spaces
|
4 |
import torch
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
6 |
import gradio as gr
|
7 |
from threading import Thread
|
8 |
|
9 |
+
MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
|
10 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
11 |
+
MODEL = os.environ.get("MODEL_ID")
|
12 |
|
13 |
+
TITLE = "<h1><center>Mistral-Nemo</center></h1>"
|
14 |
|
15 |
PLACEHOLDER = """
|
16 |
<center>
|
|
|
31 |
}
|
32 |
"""
|
33 |
|
34 |
+
device = "cuda" # for GPU usage or "cpu" for CPU usage
|
|
|
35 |
|
36 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
37 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL).to(device)
|
38 |
|
39 |
+
@spaces.GPU()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def stream_chat(
|
41 |
message: str,
|
42 |
history: list,
|
|
|
45 |
top_p: float = 1.0,
|
46 |
top_k: int = 20,
|
47 |
penalty: float = 1.2,
|
|
|
48 |
):
|
49 |
print(f'message: {message}')
|
50 |
print(f'history: {history}')
|
|
|
58 |
|
59 |
conversation.append({"role": "user", "content": message})
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
|
62 |
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
63 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
|
|
135 |
label="Repetition penalty",
|
136 |
render=False,
|
137 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
],
|
139 |
examples=[
|
140 |
["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
|