nikravan committed
Commit 36269b6
1 Parent(s): 94fa520

update app.py and model list for 1-bit models only

Files changed (1)
  1. app.py +294 -1
app.py CHANGED
@@ -1,3 +1,296 @@
  import gradio as gr
 
- gr.load("models/1bitLLM/bitnet_b1_58-xl").launch()
+ from gradio_client import Client
+ from huggingface_hub import InferenceClient
+ import random
+ #ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")
+
+ # 1-bit BitNet b1.58 checkpoints available for comparison
+ models=[
+     "1bitLLM/bitnet_b1_58-3B",
+     "1bitLLM/bitnet_b1_58-large",
+     "1bitLLM/bitnet_b1_58-xl",
+ ]
+ client_z=[]
+
+
+ # create one InferenceClient per selected model and relabel the matching chat panel
+ def load_models(inp,new_models):
+     if not new_models:
+         new_models=models
+     out_box=[gr.Chatbot(),gr.Chatbot(),gr.Chatbot(),gr.Chatbot()]
+     print(type(inp))
+     print(inp)
+     #print(new_models[inp[0]])
+     client_z.clear()
+     for z,ea in enumerate(inp):
+         client_z.append(InferenceClient(new_models[inp[z]]))
+         out_box[z]=(gr.update(label=new_models[inp[z]]))
+     return out_box[0],out_box[1],out_box[2],out_box[3]
+
+ def format_prompt_default(message, history):
+     prompt = ""
+     if history:
+         for user_prompt, bot_response in history:
+             prompt += f"{user_prompt}\n"
+             print(prompt)
+             prompt += f"{bot_response}\n"
+             print(prompt)
+     prompt += f"{message}\n"
+     return prompt
+
+ def format_prompt_gemma(message, history):
+     prompt = ""
+     if history:
+         #<start_of_turn>userHow does the brain work?<end_of_turn><start_of_turn>model
+         for user_prompt, bot_response in history:
+             prompt += f"{user_prompt}\n"
+             print(prompt)
+             prompt += f"{bot_response}\n"
+             print(prompt)
+     prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
+     return prompt
+
+
+ def format_prompt_mixtral(message, history):
+     prompt = "<s>"
+     if history:
+         for user_prompt, bot_response in history:
+             prompt += f"[INST] {user_prompt} [/INST]"
+             prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ # pick a prompt template based on the selected model's name
+ def format_prompt_choose(message, history, model_name, new_models=None):
+     if not new_models:
+         new_models=models
+     if "gemma" in new_models[model_name].lower() and "it" in new_models[model_name].lower():
+         return format_prompt_gemma(message,history)
+     if "mixtral" in new_models[model_name].lower():
+         return format_prompt_mixtral(message,history)
+     else:
+         return format_prompt_default(message,history)
+
+
+ mega_hist=[[],[],[],[]]
+ def chat_inf_tree(system_prompt,prompt,history,client_choice,seed,temp,tokens,top_p,rep_p,hid_val):
+     if len(client_choice)>=hid_val:
+         client=client_z[int(hid_val)-1]
+         if history:
+             mega_hist[hid_val-1]=history
+         #history = []
+         hist_len=0
+         generate_kwargs = dict(
+             temperature=temp,
+             max_new_tokens=tokens,
+             top_p=top_p,
+             repetition_penalty=rep_p,
+             do_sample=True,
+             seed=seed,
+         )
+         #formatted_prompt=prompt
+         # default template (this helper is not wired to the UI below)
+         formatted_prompt = format_prompt_default(f"{system_prompt}, {prompt}", mega_hist[hid_val-1])
+         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+         for response in stream:
+             output += response.token.text
+             yield [(prompt,output)]
+         mega_hist[hid_val-1].append((prompt,output))
+         yield mega_hist[hid_val-1]
+     else:
+         yield None
+
+
+ # streamed generation for chat panel A; hid_val selects which client to read from client_z
+ def chat_inf_a(system_prompt,prompt,history,client_choice,seed,temp,tokens,top_p,rep_p,hid_val):
+     if len(client_choice)>=hid_val:
+         if system_prompt:
+             system_prompt=f'{system_prompt}, '
+         client1=client_z[int(hid_val)-1]
+         if not history:
+             history = []
+             hist_len=0
+         generate_kwargs = dict(
+             temperature=temp,
+             max_new_tokens=tokens,
+             top_p=top_p,
+             repetition_penalty=rep_p,
+             do_sample=True,
+             seed=seed,
+         )
+         #formatted_prompt=prompt
+         formatted_prompt = format_prompt_choose(f"{system_prompt}{prompt}", history, client_choice[0])
+         stream1 = client1.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+         for response in stream1:
+             output += response.token.text
+             yield [(prompt,output)]
+         history.append((prompt,output))
+         yield history
+     else:
+         yield None
+
+
+ # identical to chat_inf_a, but formats with client_choice[1] for panel B
+ def chat_inf_b(system_prompt,prompt,history,client_choice,seed,temp,tokens,top_p,rep_p,hid_val):
+     if len(client_choice)>=hid_val:
+         if system_prompt:
+             system_prompt=f'{system_prompt}, '
+         client2=client_z[int(hid_val)-1]
+         if not history:
+             history = []
+             hist_len=0
+         generate_kwargs = dict(
+             temperature=temp,
+             max_new_tokens=tokens,
+             top_p=top_p,
+             repetition_penalty=rep_p,
+             do_sample=True,
+             seed=seed,
+         )
+         #formatted_prompt=prompt
+         formatted_prompt = format_prompt_choose(f"{system_prompt}{prompt}", history, client_choice[1])
+         stream2 = client2.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+         for response in stream2:
+             output += response.token.text
+             yield [(prompt,output)]
+         history.append((prompt,output))
+         yield history
+     else:
+         yield None
+
+ # identical to chat_inf_a, but formats with client_choice[2] for panel C
+ def chat_inf_c(system_prompt,prompt,history,client_choice,seed,temp,tokens,top_p,rep_p,hid_val):
+     if len(client_choice)>=hid_val:
+         if system_prompt:
+             system_prompt=f'{system_prompt}, '
+         client3=client_z[int(hid_val)-1]
+         if not history:
+             history = []
+             hist_len=0
+         generate_kwargs = dict(
+             temperature=temp,
+             max_new_tokens=tokens,
+             top_p=top_p,
+             repetition_penalty=rep_p,
+             do_sample=True,
+             seed=seed,
+         )
+         #formatted_prompt=prompt
+         formatted_prompt = format_prompt_choose(f"{system_prompt}{prompt}", history, client_choice[2])
+         stream3 = client3.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+         for response in stream3:
+             output += response.token.text
+             yield [(prompt,output)]
+         history.append((prompt,output))
+         yield history
+     else:
+         yield None
+
+ # identical to chat_inf_a, but formats with client_choice[3] for panel D
+ def chat_inf_d(system_prompt,prompt,history,client_choice,seed,temp,tokens,top_p,rep_p,hid_val):
+     if len(client_choice)>=hid_val:
+         if system_prompt:
+             system_prompt=f'{system_prompt}, '
+         client4=client_z[int(hid_val)-1]
+         if not history:
+             history = []
+             hist_len=0
+         generate_kwargs = dict(
+             temperature=temp,
+             max_new_tokens=tokens,
+             top_p=top_p,
+             repetition_penalty=rep_p,
+             do_sample=True,
+             seed=seed,
+         )
+         #formatted_prompt=prompt
+         formatted_prompt = format_prompt_choose(f"{system_prompt}{prompt}", history, client_choice[3])
+         stream4 = client4.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+         for response in stream4:
+             output += response.token.text
+             yield [(prompt,output)]
+         history.append((prompt,output))
+         yield history
+     else:
+         yield None
+
+ def add_new_model(inp, cur):
+     cur.append(inp)
+     return cur,gr.update(choices=[z for z in cur])
+
+ def load_new(models=models):
+     return models
+
+ def clear_fn():
+     return None,None,None,None,None,None
+
+ rand_val=random.randint(1,1111111111111111)
+
+ def check_rand(inp,val):
+     if inp==True:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1,1111111111111111))
+     else:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
+
+ with gr.Blocks() as app:
+     new_models=gr.State([])
+     gr.HTML("""<center><h1 style='font-size:xx-large;'>Chatbot Model Compare</h1><br><h3>running on Huggingface Inference Client</h3><br><h7>EXPERIMENTAL""")
+     with gr.Row():
+         chat_a = gr.Chatbot(height=500)
+         chat_b = gr.Chatbot(height=500)
+     with gr.Row():
+         chat_c = gr.Chatbot(height=500)
+         chat_d = gr.Chatbot(height=500)
+     with gr.Group():
+         with gr.Row():
+             with gr.Column(scale=3):
+                 inp = gr.Textbox(label="Prompt")
+                 sys_inp = gr.Textbox(label="System Prompt (optional)")
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         btn = gr.Button("Chat")
+                     with gr.Column(scale=1):
+                         with gr.Group():
+                             stop_btn=gr.Button("Stop")
+                             clear_btn=gr.Button("Clear")
+                 client_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],max_choices=4,multiselect=True,interactive=True)
+                 add_model=gr.Textbox(label="New Model")
+                 add_btn=gr.Button("Add Model")
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     rand = gr.Checkbox(label="Random Seed", value=True)
+                     seed=gr.Slider(label="Seed", minimum=1, maximum=1111111111111111,step=1, value=rand_val)
+                     tokens = gr.Slider(label="Max new tokens",value=3840,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens")
+                     temp=gr.Slider(label="Temperature",step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     top_p=gr.Slider(label="Top-P",step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     rep_p=gr.Slider(label="Repetition Penalty",step=0.1, minimum=0.1, maximum=2.0, value=1.0)
+     with gr.Accordion(label="Screenshot",open=False):
+         with gr.Row():
+             with gr.Column(scale=3):
+                 im_btn=gr.Button("Screenshot")
+                 img=gr.Image(type='filepath')
+             with gr.Column(scale=1):
+                 with gr.Row():
+                     im_height=gr.Number(label="Height",value=5000)
+                     im_width=gr.Number(label="Width",value=500)
+                     wait_time=gr.Number(label="Wait Time",value=3000)
+                     theme=gr.Radio(label="Theme", choices=["light","dark"],value="light")
+                     chatblock=gr.Dropdown(label="Chatblocks",info="Choose specific blocks of chat",choices=[c for c in range(1,40)],multiselect=True)
+     hid1=gr.Number(value=1,visible=False)
+     hid2=gr.Number(value=2,visible=False)
+     hid3=gr.Number(value=3,visible=False)
+     hid4=gr.Number(value=4,visible=False)
+
+     app.load(load_new,None,new_models)
+     add_btn.click(add_new_model,[add_model,new_models],[new_models,client_choice])
+     client_choice.change(load_models,[client_choice,new_models],[chat_a,chat_b,chat_c,chat_d])
+
+     #im_go=im_btn.click(get_screenshot,[chat_b,im_height,im_width,chatblock,theme,wait_time],img)
+     #chat_sub=inp.submit(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,client_choice,seed,temp,tokens,top_p,rep_p],chat_b)
+
+     go1=btn.click(check_rand,[rand,seed],seed).then(chat_inf_a,[sys_inp,inp,chat_b,client_choice,seed,temp,tokens,top_p,rep_p,hid1],chat_a)
+     go2=btn.click(check_rand,[rand,seed],seed).then(chat_inf_b,[sys_inp,inp,chat_b,client_choice,seed,temp,tokens,top_p,rep_p,hid2],chat_b)
+     go3=btn.click(check_rand,[rand,seed],seed).then(chat_inf_c,[sys_inp,inp,chat_b,client_choice,seed,temp,tokens,top_p,rep_p,hid3],chat_c)
+     go4=btn.click(check_rand,[rand,seed],seed).then(chat_inf_d,[sys_inp,inp,chat_b,client_choice,seed,temp,tokens,top_p,rep_p,hid4],chat_d)
+
+     stop_btn.click(None,None,None,cancels=[go1,go2,go3,go4])
+     clear_btn.click(clear_fn,None,[inp,sys_inp,chat_a,chat_b,chat_c,chat_d])
+
+ app.queue(default_concurrency_limit=10).launch()
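
For reference, every chat_inf_* handler above repeats the same streaming pattern: a huggingface_hub InferenceClient.text_generation call with stream=True and details=True, accumulating response.token.text. A minimal standalone sketch of that pattern follows; the checkpoint and prompt are illustrative only, and it assumes the chosen model is actually served by the Hugging Face Inference API.

# Minimal sketch of the streaming call used in app.py (illustrative, not part of the commit)
from huggingface_hub import InferenceClient

client = InferenceClient("1bitLLM/bitnet_b1_58-xl")   # any entry from the models list above
prompt = "How does the brain work?\n"                 # format_prompt_default-style prompt
output = ""
# details=True yields token objects; return_full_text=False drops the prompt echo
for chunk in client.text_generation(prompt, max_new_tokens=64, stream=True,
                                    details=True, return_full_text=False):
    output += chunk.token.text
print(output)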