Commit 187c661, committed by Hazzzardous
Parent(s): d43a442

add config option
app.py
CHANGED
@@ -22,7 +22,7 @@ import codecs
 from ast import literal_eval
 from datetime import datetime
 from rwkvstic.load import RWKV
-from
+from config import config, title
 import torch
 import gc
 
@@ -33,25 +33,25 @@ desc = '''<p>RNN with Transformer-level LLM Performance (<a href='https://github
 
 thanks = '''<p>Thanks to <a href='https://www.rftcapital.com'>RFT Capital</a> for donating compute capability for our experiments. Additional thanks to the author of the <a href="https://github.com/harrisonvanderbyl/rwkvstic">rwkvstic</a> library.</p>'''
 
+
 def to_md(text):
     return text.replace("\n", "<br />")
 
+
 def get_model():
     model = None
     model = RWKV(
-
-        "pytorch(cpu/gpu)",
-        runtimedtype=torch.float32,
-        useGPU=torch.cuda.is_available(),
-        dtype=torch.float32
+        **config
     )
     return model
 
+
 model = None
 
+
 def infer(
         prompt,
-        mode,
+        mode="generative",
         max_new_tokens=10,
         temperature=0.1,
         top_p=1.0,
@@ -65,18 +65,18 @@ def infer(
         if (DEVICE == "cuda"):
             torch.cuda.empty_cache()
         model = get_model()
-
+
     max_new_tokens = int(max_new_tokens)
     temperature = float(temperature)
     top_p = float(top_p)
-    stop =
+    stop = [x.strip(' ') for x in stop.split(',')]
     seed = seed
 
     assert 1 <= max_new_tokens <= 384
     assert 0.0 <= temperature <= 1.0
     assert 0.0 <= top_p <= 1.0
 
-    temperature = max(0.05,temperature)
+    temperature = max(0.05, temperature)
     if prompt == "":
         prompt = " "
 
@@ -84,7 +84,7 @@ def infer(
     model.resetState()
     if (mode == "Q/A"):
         prompt = f"Ask Expert\n\nQuestion:\n{prompt}\n\nExpert Full Answer:\n"
-
+
     print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
     print(f"OUTPUT ({datetime.now()}):\n-------\n")
     # Load prompt
@@ -93,11 +93,12 @@ def infer(
     done = False
     with torch.no_grad():
         for _ in range(max_new_tokens):
-            char = model.forward(stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
+            char = model.forward(stopStrings=stop, temp=temperature, top_p_usual=top_p)[
+                "output"]
             print(char, end='', flush=True)
             generated_text += char
             generated_text = generated_text.lstrip("\n ")
-
+
             for stop_word in stop:
                 stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                 if stop_word != '' and stop_word in generated_text:
@@ -108,13 +109,13 @@ def infer(
                 print("<stopped>\n")
                 break
 
-    #print(f"{generated_text}")
-
+    # print(f"{generated_text}")
+
     for stop_word in stop:
         stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
         if stop_word != '' and stop_word in generated_text:
             generated_text = generated_text[:generated_text.find(stop_word)]
-
+
     gc.collect()
     yield generated_text
 
@@ -130,9 +131,9 @@ def chat(
 ):
     global model
     history = history or []
-
+
     intro = ""
-
+
     if model == None:
         gc.collect()
         if (DEVICE == "cuda"):
@@ -141,7 +142,7 @@ def chat(
 
     username = username.strip()
     username = username or "USER"
-
+
     intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
 
 {username}: What year was the french revolution?
@@ -156,23 +157,22 @@ def chat(
 FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
 {username}: Tell me about yourself.
 FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
-'''
-
+'''
+
     if len(history) == 0:
         # no history, so lets reset chat state
         model.resetState()
-        history = [[],model.emptyState]
+        history = [[], model.emptyState]
         print("reset chat state")
     else:
         if (history[0][0][0].split(':')[0] != username):
             model.resetState()
-            history = [[],model.emptyState]
+            history = [[], model.emptyState]
             print("username changed, reset state")
         else:
            model.setState(history[1])
            intro = ""
-
-
+
     max_new_tokens = int(max_new_tokens)
     temperature = float(temperature)
     top_p = float(top_p)
@@ -182,16 +182,17 @@ def chat(
     assert 0.0 <= temperature <= 1.0
     assert 0.0 <= top_p <= 1.0
 
-    temperature = max(0.05,temperature)
+    temperature = max(0.05, temperature)
 
-    prompt = f"{username}: " + prompt + "\n"
+    prompt = f"{username}: " + prompt + "\n"
     print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
     print(f"OUTPUT ({datetime.now()}):\n-------\n")
     # Load prompt
 
     model.loadContext(newctx=intro+prompt)
 
-    out = model.forward(number=max_new_tokens, stopStrings=["<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
+    out = model.forward(number=max_new_tokens, stopStrings=[
+        "<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
 
     generated_text = out["output"].lstrip("\n ")
     generated_text = generated_text.rstrip("USER:")
@@ -199,19 +200,19 @@ def chat(
 
     gc.collect()
     history[0].append((prompt, generated_text))
-    return history[0],[history[0],out["state"]]
+    return history[0], [history[0], out["state"]]
 
 
 examples = [
     [
         # Question Answering
-        '''What is the capital of Germany?''',"Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
+        '''What is the capital of Germany?''', "Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
     [
         # Question Answering
-        '''Are humans good or bad?''',"Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
+        '''Are humans good or bad?''', "Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
     [
         # Question Answering
-        '''What is the purpose of Vitamin A?''',"Q/A", 50, 0.2, 0.8, "<|endoftext|>"],
+        '''What is the purpose of Vitamin A?''', "Q/A", 50, 0.2, 0.8, "<|endoftext|>"],
     [
         # Chatbot
         '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each others AI.
@@ -231,7 +232,7 @@ Best Full Response:
     [
         # Natural Language Interface
         '''Here is a short story (in the style of Tolkien) in which Aiden attacks a robot with a sword:
-''',"generative", 140, 0.85, 0.8, "<|endoftext|>"]
+''', "generative", 140, 0.85, 0.8, "<|endoftext|>"]
 ]
 
 
@@ -241,11 +242,12 @@ iface = gr.Interface(
     allow_flagging="never",
     inputs=[
         gr.Textbox(lines=20, label="Prompt"),  # prompt
-        gr.Radio(["generative","Q/A"], value="generative", label="Choose Mode"),
+        gr.Radio(["generative", "Q/A"],
+                 value="generative", label="Choose Mode"),
         gr.Slider(1, 256, value=40),  # max_tokens
         gr.Slider(0.0, 1.0, value=0.8),  # temperature
         gr.Slider(0.0, 1.0, value=0.85),  # top_p
-        gr.Textbox(lines=1, value="<|endoftext|>")
+        gr.Textbox(lines=1, value="<|endoftext|>")  # stop
     ],
     outputs=gr.Textbox(label="Generated Output", lines=25),
     examples=examples,
@@ -259,20 +261,22 @@ chatiface = gr.Interface(
     inputs=[
         gr.Textbox(lines=5, label="Message"),  # prompt
         "state",
-        gr.Text(lines=1, value="USER", label="Your Name", placeholder="Enter your Name"),
+        gr.Text(lines=1, value="USER", label="Your Name",
+                placeholder="Enter your Name"),
         gr.Slider(1, 256, value=60),  # max_tokens
         gr.Slider(0.0, 1.0, value=0.8),  # temperature
         gr.Slider(0.0, 1.0, value=0.85)  # top_p
     ],
-    outputs=[gr.Chatbot(label="Chat Log", color_map=("green", "pink")), "state"],
+    outputs=[gr.Chatbot(label="Chat Log", color_map=(
+        "green", "pink")), "state"],
 ).queue()
 
 demo = gr.TabbedInterface(
 
-    [iface,chatiface],["Generative","Chatbot"],
-    title=
-
-
+    [iface, chatiface], ["Generative", "Chatbot"],
+    title=title,
+
+)
 
 demo.queue()
 demo.launch(share=False)
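The heart of the change: get_model() no longer hardcodes the loader arguments. RWKV(**config) unpacks whatever keyword arguments the config dictionary from config.py supplies, so switching models or precision now only touches config.py. A minimal sketch of the unpacking semantics, assuming a hypothetical load() stand-in for rwkvstic.load.RWKV and a placeholder "model.pth" path:

    # "path" and "useGPU" are real keys from config.py; load() and the
    # "model.pth" value are illustrative stand-ins, not the real loader.
    config = {
        "path": "model.pth",
        "useGPU": False,
    }

    def load(path, useGPU=True, **kwargs):
        # accepts the same keyword style as the RWKV loader
        print(f"loading {path} (GPU: {useGPU})")

    load(**config)  # equivalent to load(path="model.pth", useGPU=False)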
config.py
ADDED
@@ -0,0 +1,63 @@
+from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT
+import torch
+
+quantized = {
+    "mode": TORCH_QUANT,
+    "runtimedtype": torch.bfloat16,
+    "useGPU": torch.cuda.is_available(),
+    "chunksize": 32,  # larger = more accurate, but more memory
+    "target": 100  # your gpu max size, excess vram offloaded to cpu
+}
+
+# UNCOMMENT TO SELECT OPTIONS
+# Not full list of options, see https://pypi.org/project/rwkvstic/ and https://huggingface.co/BlinkDL/ for more models/modes
+
+# RWKV 1B5 instruct test 1 model
+# Approximate
+# [Vram usage: 6.0GB]
+# [File size: 3.0GB]
+
+
+config = {
+    "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
+    "mode": TORCH,
+    "runtimedtype": torch.float32,
+    "useGPU": torch.cuda.is_available(),
+    "dtype": torch.float32
+}
+
+title = "RWKV-4 (1.5b Instruct)"
+
+# RWKV 1B5 instruct model quantized
+# Approximate
+# [Vram usage: 1.3GB]
+# [File size: 3.0GB]
+
+# config = {
+#     "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
+#     **quantized
+# }
+
+# title = "RWKV-4 (1.5b Instruct Quantized)"
+
+# RWKV 7B instruct pre-quantized (settings baked into model)
+# Approximate
+# [Vram usage: 7.0GB]
+# [File size: 8.0GB]
+
+# config = {
+#     "path": "https://huggingface.co/Hazzzardous/RWKV-8Bit/resolve/main/RWKV-4-Pile-7B-Instruct.pqth"
+# }
+
+# title = "RWKV-4 (7b Instruct Quantized)"
+
+# RWKV 14B quantized (latest as of feb 9)
+# Approximate
+# [Vram usage: 15.0GB]
+# [File size: 28.0GB]
+
+# config = {
+#     "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-14b/resolve/main/RWKV-4-Pile-14B-20230204-7324.pth"
+# }
+
+# title = "RWKV-4 (14b)"
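The commented-out blocks switch models by rebinding config and title. The quantized variants reuse the shared quantized dictionary via ** merging; uncommenting the 1.5b quantized block is equivalent to the sketch below (the TORCH_QUANT backend constant is swapped for a string here so the sketch runs standalone):

    # sketch only: config.py uses the rwkvstic TORCH_QUANT object, not a string
    quantized = {
        "mode": "TORCH_QUANT",
        "chunksize": 32,
        "target": 100,
    }

    config = {
        "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
        **quantized,  # merges mode, chunksize, target (in config.py also runtimedtype, useGPU)
    }

    title = "RWKV-4 (1.5b Instruct Quantized)"
    print(sorted(config))  # ['chunksize', 'mode', 'path', 'target']

In a dict literal, later keys win, so any per-model override of a quantized setting would need to be written after the **quantized expansion.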