Spaces:
Sleeping
Sleeping
Hazzzardous
commited on
Commit
•
0791700
1
Parent(s):
5667dbf
Update app.py
Browse files
app.py
CHANGED
@@ -46,7 +46,8 @@ def get_model():
|
|
46 |
return model
|
47 |
|
48 |
|
49 |
-
model =
|
|
|
50 |
|
51 |
|
52 |
def infer(
|
@@ -126,12 +127,29 @@ def infer(
|
|
126 |
|
127 |
gc.collect()
|
128 |
yield generated_text
|
|
|
|
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
|
|
|
|
|
|
131 |
def chat(
|
132 |
prompt,
|
133 |
history,
|
134 |
-
username,
|
135 |
max_new_tokens=10,
|
136 |
temperature=0.1,
|
137 |
top_p=1.0,
|
@@ -151,31 +169,17 @@ def chat(
|
|
151 |
username = username.strip()
|
152 |
username = username or "USER"
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
{username}: What year was the french revolution?
|
157 |
-
FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
|
158 |
-
{username}: 3+5=?
|
159 |
-
FRITZ: The answer is 8.
|
160 |
-
{username}: What year did the Berlin Wall fall?
|
161 |
-
FRITZ: The Berlin wall stood for 28 years and fell in 1989.
|
162 |
-
{username}: solve for a: 9-a=2
|
163 |
-
FRITZ: The answer is a=7, because 9-7 = 2.
|
164 |
-
{username}: wat is lhc
|
165 |
-
FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
|
166 |
-
{username}: Tell me about yourself.
|
167 |
-
FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
|
168 |
-
'''
|
169 |
|
170 |
if len(history) == 0:
|
171 |
# no history, so lets reset chat state
|
172 |
-
model.
|
173 |
history = [[], model.emptyState]
|
174 |
print("reset chat state")
|
175 |
else:
|
176 |
if (history[0][0][0].split(':')[0] != username):
|
177 |
-
model.
|
178 |
-
history = [[], model.
|
179 |
print("username changed, reset state")
|
180 |
else:
|
181 |
model.setState(history[1])
|
@@ -186,8 +190,8 @@ def chat(
|
|
186 |
top_p = float(top_p)
|
187 |
seed = seed
|
188 |
|
189 |
-
assert 1 <= max_new_tokens <=
|
190 |
-
assert 0.0 <= temperature <=
|
191 |
assert 0.0 <= top_p <= 1.0
|
192 |
|
193 |
temperature = max(0.05, temperature)
|
@@ -197,13 +201,13 @@ def chat(
|
|
197 |
print(f"OUTPUT ({datetime.now()}):\n-------\n")
|
198 |
# Load prompt
|
199 |
|
200 |
-
model.loadContext(newctx=
|
201 |
|
202 |
out = model.forward(number=max_new_tokens, stopStrings=[
|
203 |
"<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
|
204 |
|
205 |
generated_text = out["output"].lstrip("\n ")
|
206 |
-
generated_text = generated_text.rstrip("
|
207 |
print(f"{generated_text}")
|
208 |
|
209 |
gc.collect()
|
@@ -251,12 +255,13 @@ iface = gr.Interface(
|
|
251 |
inputs=[
|
252 |
gr.Textbox(lines=20, label="Prompt"), # prompt
|
253 |
gr.Radio(["generative", "Q/A","ELDR","EFR","BFR"],
|
254 |
-
value="
|
255 |
-
gr.Slider(1,
|
256 |
-
gr.Slider(0.0,
|
257 |
gr.Slider(0.0, 1.0, value=0.85), # top_p
|
|
|
258 |
gr.Slider(-999, 0.0, value=0.0), # end_adj
|
259 |
-
|
260 |
],
|
261 |
outputs=gr.Textbox(label="Generated Output", lines=25),
|
262 |
examples=examples,
|
@@ -270,8 +275,6 @@ chatiface = gr.Interface(
|
|
270 |
inputs=[
|
271 |
gr.Textbox(lines=5, label="Message"), # prompt
|
272 |
"state",
|
273 |
-
gr.Text(lines=1, value="USER", label="Your Name",
|
274 |
-
placeholder="Enter your Name"),
|
275 |
gr.Slider(1, 256, value=60), # max_tokens
|
276 |
gr.Slider(0.0, 1.0, value=0.8), # temperature
|
277 |
gr.Slider(0.0, 1.0, value=0.85) # top_p
|
@@ -282,7 +285,7 @@ chatiface = gr.Interface(
|
|
282 |
|
283 |
demo = gr.TabbedInterface(
|
284 |
|
285 |
-
[iface, chatiface], ["
|
286 |
title=title,
|
287 |
|
288 |
)
|
|
|
46 |
return model
|
47 |
|
48 |
|
49 |
+
model = get_model()
|
50 |
+
|
51 |
|
52 |
|
53 |
def infer(
|
|
|
127 |
|
128 |
gc.collect()
|
129 |
yield generated_text
|
130 |
+
username = "USER"
|
131 |
+
intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
|
132 |
|
133 |
+
{username}: What year was the french revolution?
|
134 |
+
FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
|
135 |
+
{username}: 3+5=?
|
136 |
+
FRITZ: The answer is 8.
|
137 |
+
{username}: What year did the Berlin Wall fall?
|
138 |
+
FRITZ: The Berlin wall stood for 28 years and fell in 1989.
|
139 |
+
{username}: solve for a: 9-a=2
|
140 |
+
FRITZ: The answer is a=7, because 9-7 = 2.
|
141 |
+
{username}: wat is lhc
|
142 |
+
FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
|
143 |
+
{username}: Tell me about yourself.
|
144 |
+
FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
|
145 |
+
'''
|
146 |
|
147 |
+
model.loadContext(intro)
|
148 |
+
chatState = model.getState().clone()
|
149 |
+
model.resetState()
|
150 |
def chat(
|
151 |
prompt,
|
152 |
history,
|
|
|
153 |
max_new_tokens=10,
|
154 |
temperature=0.1,
|
155 |
top_p=1.0,
|
|
|
169 |
username = username.strip()
|
170 |
username = username or "USER"
|
171 |
|
172 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
if len(history) == 0:
|
175 |
# no history, so lets reset chat state
|
176 |
+
model.setState(chatState)
|
177 |
history = [[], model.emptyState]
|
178 |
print("reset chat state")
|
179 |
else:
|
180 |
if (history[0][0][0].split(':')[0] != username):
|
181 |
+
model.setState(chatState)
|
182 |
+
history = [[], model.chatState]
|
183 |
print("username changed, reset state")
|
184 |
else:
|
185 |
model.setState(history[1])
|
|
|
190 |
top_p = float(top_p)
|
191 |
seed = seed
|
192 |
|
193 |
+
assert 1 <= max_new_tokens <= 512
|
194 |
+
assert 0.0 <= temperature <= 3.0
|
195 |
assert 0.0 <= top_p <= 1.0
|
196 |
|
197 |
temperature = max(0.05, temperature)
|
|
|
201 |
print(f"OUTPUT ({datetime.now()}):\n-------\n")
|
202 |
# Load prompt
|
203 |
|
204 |
+
model.loadContext(newctx=prompt)
|
205 |
|
206 |
out = model.forward(number=max_new_tokens, stopStrings=[
|
207 |
"<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
|
208 |
|
209 |
generated_text = out["output"].lstrip("\n ")
|
210 |
+
generated_text = generated_text.rstrip(username+":")
|
211 |
print(f"{generated_text}")
|
212 |
|
213 |
gc.collect()
|
|
|
255 |
inputs=[
|
256 |
gr.Textbox(lines=20, label="Prompt"), # prompt
|
257 |
gr.Radio(["generative", "Q/A","ELDR","EFR","BFR"],
|
258 |
+
value="ELDR", label="Choose Mode"),
|
259 |
+
gr.Slider(1, 512, value=40), # max_tokens
|
260 |
+
gr.Slider(0.0, 5.0, value=1.0), # temperature
|
261 |
gr.Slider(0.0, 1.0, value=0.85), # top_p
|
262 |
+
gr.Textbox(lines=1, value="<|endoftext|>"), # stop
|
263 |
gr.Slider(-999, 0.0, value=0.0), # end_adj
|
264 |
+
|
265 |
],
|
266 |
outputs=gr.Textbox(label="Generated Output", lines=25),
|
267 |
examples=examples,
|
|
|
275 |
inputs=[
|
276 |
gr.Textbox(lines=5, label="Message"), # prompt
|
277 |
"state",
|
|
|
|
|
278 |
gr.Slider(1, 256, value=60), # max_tokens
|
279 |
gr.Slider(0.0, 1.0, value=0.8), # temperature
|
280 |
gr.Slider(0.0, 1.0, value=0.85) # top_p
|
|
|
285 |
|
286 |
demo = gr.TabbedInterface(
|
287 |
|
288 |
+
[iface, chatiface], ["ELDR", "Chatbot"],
|
289 |
title=title,
|
290 |
|
291 |
)
|