rishiraj committed on
Commit
d0790bd
1 Parent(s): a1375ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -93
app.py CHANGED
@@ -14,7 +14,7 @@ with open(".config/application_default_credentials.json", 'w') as file:
14
 
15
  vertexai.init(project=os.getenv('project_id'))
16
  model = GenerativeModel("gemini-1.0-pro-vision")
17
- client = InferenceClient("google/gemma-7b-it")
18
 
19
  def extract_image_urls(text):
20
  url_regex = r"(https?:\/\/.*\.(?:png|jpg|jpeg|gif|webp|svg))"
@@ -40,100 +40,227 @@ def search(url):
40
  response = model.generate_content([image,"Describe what is shown in this image."])
41
  return response.text
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def format_prompt(message, history, cust_p):
44
  prompt = ""
45
- for user_prompt, bot_response in history:
46
- prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
47
- prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
48
- prompt += cust_p.replace("USER_INPUT",message)
 
49
  return prompt
50
 
51
- def generate(
52
- prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
53
- ):
54
- custom_prompt="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model"
55
- temperature = float(temperature)
56
- if temperature < 1e-2:
57
- temperature = 1e-2
58
- top_p = float(top_p)
59
-
60
- generate_kwargs = dict(
61
- temperature=temperature,
62
- max_new_tokens=max_new_tokens,
63
- top_p=top_p,
64
- repetition_penalty=repetition_penalty,
65
- do_sample=True,
66
- seed=42,
67
- )
68
-
69
- image = extract_image_urls(prompt)
70
- if image:
71
- image_description = "Image Description: " + search(image)
72
- prompt = prompt.replace(image, image_description)
73
- print(prompt)
74
- formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history, custom_prompt)
75
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
76
- output = ""
77
-
78
- for response in stream:
79
- output += response.token.text
80
- yield output
81
- return output
82
-
83
-
84
- additional_inputs=[
85
- gr.Textbox(
86
- label="System Prompt",
87
- max_lines=1,
88
- interactive=True,
89
- ),
90
- gr.Slider(
91
- label="Temperature",
92
- value=0.9,
93
- minimum=0.0,
94
- maximum=1.0,
95
- step=0.05,
96
- interactive=True,
97
- info="Higher values produce more diverse outputs",
98
- ),
99
- gr.Slider(
100
- label="Max new tokens",
101
- value=256,
102
- minimum=0,
103
- maximum=1048,
104
- step=64,
105
- interactive=True,
106
- info="The maximum numbers of new tokens",
107
- ),
108
- gr.Slider(
109
- label="Top-p (nucleus sampling)",
110
- value=0.90,
111
- minimum=0.0,
112
- maximum=1,
113
- step=0.05,
114
- interactive=True,
115
- info="Higher values sample more low-probability tokens",
116
- ),
117
- gr.Slider(
118
- label="Repetition penalty",
119
- value=1.2,
120
- minimum=1.0,
121
- maximum=2.0,
122
- step=0.05,
123
- interactive=True,
124
- info="Penalize repeated tokens",
125
- )
126
- ]
 
 
 
 
 
 
 
 
 
127
 
128
- examples=[["What are they doing here https://upload.wikimedia.org/wikipedia/commons/3/38/Two_dancers.jpg ?", None, None, None, None, None]]
129
-
130
- gr.ChatInterface(
131
- fn=generate,
132
- chatbot=gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, likeable=True, layout="bubble", bubble_full_width=False),
133
- additional_inputs=additional_inputs,
134
- title="Gemma Gemini Multimodal Chatbot",
135
- description="Gemini Sprint submission by Rishiraj Acharya. Uses Google's Gemini 1.0 Pro Vision multimodal model from Vertex AI with Google's Gemma 7B Instruct model from Hugging Face. Google Cloud credits are provided for this project.",
136
- theme="Soft",
137
- examples=examples,
138
- concurrency_limit=20,
139
- ).launch(show_api=False)
 
14
 
15
  vertexai.init(project=os.getenv('project_id'))
16
  model = GenerativeModel("gemini-1.0-pro-vision")
17
+ # client = InferenceClient("google/gemma-7b-it")
18
 
19
  def extract_image_urls(text):
20
  url_regex = r"(https?:\/\/.*\.(?:png|jpg|jpeg|gif|webp|svg))"
 
40
  response = model.generate_content([image,"Describe what is shown in this image."])
41
  return response.text
42
 
43
+ # def format_prompt(message, history, cust_p):
44
+ # prompt = ""
45
+ # for user_prompt, bot_response in history:
46
+ # prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
47
+ # prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
48
+ # prompt += cust_p.replace("USER_INPUT",message)
49
+ # return prompt
50
+
51
+ # def generate(
52
+ # prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
53
+ # ):
54
+ # custom_prompt="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model"
55
+ # temperature = float(temperature)
56
+ # if temperature < 1e-2:
57
+ # temperature = 1e-2
58
+ # top_p = float(top_p)
59
+
60
+ # generate_kwargs = dict(
61
+ # temperature=temperature,
62
+ # max_new_tokens=max_new_tokens,
63
+ # top_p=top_p,
64
+ # repetition_penalty=repetition_penalty,
65
+ # do_sample=True,
66
+ # seed=42,
67
+ # )
68
+
69
+ # image = extract_image_urls(prompt)
70
+ # if image:
71
+ # image_description = "Image Description: " + search(image)
72
+ # prompt = prompt.replace(image, image_description)
73
+ # print(prompt)
74
+ # formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history, custom_prompt)
75
+ # stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
76
+ # output = ""
77
+
78
+ # for response in stream:
79
+ # output += response.token.text
80
+ # yield output
81
+ # return output
82
+
83
+
84
+ # additional_inputs=[
85
+ # gr.Textbox(
86
+ # label="System Prompt",
87
+ # max_lines=1,
88
+ # interactive=True,
89
+ # ),
90
+ # gr.Slider(
91
+ # label="Temperature",
92
+ # value=0.9,
93
+ # minimum=0.0,
94
+ # maximum=1.0,
95
+ # step=0.05,
96
+ # interactive=True,
97
+ # info="Higher values produce more diverse outputs",
98
+ # ),
99
+ # gr.Slider(
100
+ # label="Max new tokens",
101
+ # value=256,
102
+ # minimum=0,
103
+ # maximum=1048,
104
+ # step=64,
105
+ # interactive=True,
106
+ # info="The maximum numbers of new tokens",
107
+ # ),
108
+ # gr.Slider(
109
+ # label="Top-p (nucleus sampling)",
110
+ # value=0.90,
111
+ # minimum=0.0,
112
+ # maximum=1,
113
+ # step=0.05,
114
+ # interactive=True,
115
+ # info="Higher values sample more low-probability tokens",
116
+ # ),
117
+ # gr.Slider(
118
+ # label="Repetition penalty",
119
+ # value=1.2,
120
+ # minimum=1.0,
121
+ # maximum=2.0,
122
+ # step=0.05,
123
+ # interactive=True,
124
+ # info="Penalize repeated tokens",
125
+ # )
126
+ # ]
127
+
128
+ # examples=[["What are they doing here https://upload.wikimedia.org/wikipedia/commons/3/38/Two_dancers.jpg ?", None, None, None, None, None]]
129
+
130
+ # gr.ChatInterface(
131
+ # fn=generate,
132
+ # chatbot=gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, likeable=True, layout="bubble", bubble_full_width=False),
133
+ # additional_inputs=additional_inputs,
134
+ # title="Gemma Gemini Multimodal Chatbot",
135
+ # description="Gemini Sprint submission by Rishiraj Acharya. Uses Google's Gemini 1.0 Pro Vision multimodal model from Vertex AI with Google's Gemma 7B Instruct model from Hugging Face. Google Cloud credits are provided for this project.",
136
+ # theme="Soft",
137
+ # examples=examples,
138
+ # concurrency_limit=20,
139
+ # ).launch(show_api=False)
140
+
141
+
142
+
143
+
144
import random

# Gemma checkpoints available through the Hugging Face Inference API.
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]

# One pre-built inference client per checkpoint, in the same order as `models`.
clients = [InferenceClient(model_id) for model_id in models]

def load_models(inp):
    """Relabel the chatbot with the model name picked in the dropdown.

    *inp* is the 0-based index delivered by the `type='index'` dropdown.
    """
    return gr.update(label=models[inp])
161
+
162
def format_prompt(message, history, cust_p):
    """Build a Gemma-style chat prompt from past turns plus the new message.

    Each (user, model) pair in *history* is wrapped in Gemma's
    <start_of_turn>/<end_of_turn> markers; *cust_p* is the template for the
    final user turn, with the literal token "USER_INPUT" replaced by *message*.
    """
    pieces = []
    if history:
        for user_turn, model_turn in history:
            pieces.append(f"<start_of_turn>user{user_turn}<end_of_turn>")
            pieces.append(f"<start_of_turn>model{model_turn}<end_of_turn>")
    pieces.append(cust_p.replace("USER_INPUT", message))
    return "".join(pieces)
170
 
171
def chat_inf(system_prompt, prompt, history, memory, client_choice, seed, temp, tokens, top_p, rep_p, chat_mem, cust_p):
    """Stream a chat reply from the selected Gemma endpoint.

    Any image URL found in *prompt* is replaced with a Gemini-generated
    description (via `search`) before the text model is called.  Yields
    (history, memory) tuples so Gradio can live-update the chatbot while
    tokens stream in; *memory* retains the last *chat_mem* turns as context.
    """
    # NOTE(review): the dropdown is declared with type='index', which is
    # 0-based in current Gradio — confirm this -1 offset is intended.
    client = clients[int(client_choice) - 1]
    if not history:
        history = []
    if not memory:
        memory = []
    # gr.Number delivers a float by default; slicing requires an int.
    chat_mem = int(chat_mem)
    # Rough character count of the retained conversation.
    mem_chars = 0
    for turn in memory[0 - chat_mem:]:
        mem_chars += len(str(turn))
    in_len = len(system_prompt + prompt) + mem_chars

    if (in_len + tokens) > 8000:
        history.append((prompt, "Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value"))
        yield history, memory
    else:
        generate_kwargs = dict(
            temperature=temp,
            max_new_tokens=tokens,
            top_p=top_p,
            repetition_penalty=rep_p,
            do_sample=True,
            seed=seed,
        )

        # Swap a detected image URL for a Gemini-produced description so the
        # text-only Gemma model can reason about the image content.
        image = extract_image_urls(prompt)
        if image:
            image_description = "Image Description: " + search(image)
            prompt = prompt.replace(image, image_description)

        if system_prompt:
            formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", memory[0 - chat_mem:], cust_p)
        else:
            formatted_prompt = format_prompt(prompt, memory[0 - chat_mem:], cust_p)
        # NOTE(review): return_full_text=True may echo the prompt with some
        # endpoints; the previous revision used False — confirm intended.
        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
        output = ""
        for response in stream:
            output += response.token.text
            yield [(prompt, output)], memory
        history.append((prompt, output))
        memory.append((prompt, output))
        yield history, memory
217
+
218
def clear_fn():
    """Reset prompt box, system prompt, chat display and memory to empty."""
    return (None, None, None, None)

# Initial seed shown in the slider before the user interacts with it.
rand_val = random.randint(1, 1111111111111111)
221
+
222
def check_rand(inp, val):
    """Return the seed slider, randomized or pinned.

    *inp* is the 'Random Seed' checkbox state: when checked, pick a fresh
    random seed; otherwise keep the user-supplied *val* (coerced to int,
    since sliders can deliver floats).
    """
    if inp:  # idiomatic truthiness instead of `inp == True`
        seed_value = random.randint(1, 1111111111111111)
    else:
        seed_value = int(val)
    return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=seed_value)
227
+
228
with gr.Blocks() as app:
    # Cross-event store holding (prompt, reply) pairs fed back as context.
    memory = gr.State()
    gr.HTML("""<center><h1 style='font-size:xx-large;'>Gemma Gemini Multimodal Chatbot</h1><br><h2>Gemini Sprint submission by Rishiraj Acharya. Uses Google's Gemini 1.0 Pro Vision multimodal model from Vertex AI with Google's Gemma 7B Instruct model from Hugging Face. Google Cloud credits are provided for this project.</h2>""")
    chat_b = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, likeable=True, layout="bubble", bubble_full_width=False)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                with gr.Accordion("Prompt Format", open=False):
                    custom_prompt = gr.Textbox(label="Modify Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=3, value="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model")
                with gr.Row():
                    with gr.Column(scale=2):
                        btn = gr.Button("Chat")
                    with gr.Column(scale=1):
                        with gr.Group():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
                client_choice = gr.Dropdown(label="Models", type='index', choices=[c for c in models], value=models[0], interactive=True)
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=1600, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of tokens")
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.49)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.01, minimum=0.1, maximum=2.0, value=0.99)
                    chat_mem = gr.Number(label="Chat Memory", info="Number of previous chats to retain", value=4)

    # Keep the chatbot's label in sync with the selected model.
    client_choice.change(load_models, client_choice, [chat_b])
    app.load(load_models, client_choice, [chat_b])

    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, memory, client_choice, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt], [chat_b, memory])
    go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, memory, client_choice, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt], [chat_b, memory])

    # Fix: the cancel list previously referenced an undefined `im_go`, which
    # raises NameError while building the UI; only `go` and `chat_sub` exist.
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [inp, sys_inp, chat_b, memory])
app.queue(default_concurrency_limit=10).launch()