# heygemini / app.py
# rishiraj's picture
# Update app.py
# 6cc6de1 verified
from huggingface_hub import InferenceClient
import gradio as gr
import os
import re
import requests
import random
import http.client
import typing
import urllib.request
import vertexai
from vertexai.generative_models import GenerativeModel, Image
# Write the `credentials` secret out as a JSON key file for the Google auth
# libraries, then create the two model clients used by the app.
# NOTE(review): Google's default-credential lookup does not search this
# cwd-relative path by itself; presumably GOOGLE_APPLICATION_CREDENTIALS points
# here in the Space settings — confirm.
_credentials_json = os.getenv('credentials')
if _credentials_json is not None:
    # Skip the write when the secret is unset: str(None) would leave an
    # unparsable "None" key file behind. Create .config/ first, because
    # open() does not create missing directories.
    os.makedirs(".config", exist_ok=True)
    # 0o600: the file holds a service-account secret, so keep it owner-only.
    _fd = os.open(
        ".config/application_default_credentials.json",
        os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        0o600,
    )
    with os.fdopen(_fd, 'w', encoding="utf-8") as file:
        file.write(_credentials_json)
vertexai.init(project=os.getenv('project_id'))
# Gemini (Vertex AI) is used to describe images; Gemma (HF Inference) generates chat replies.
model = GenerativeModel("gemini-1.0-pro-vision")
client = InferenceClient("google/gemma-7b-it")
# Candidate image URLs: an http(s) URL made of non-whitespace characters that
# ends in a known image extension. `\S*` (rather than `.*`) stops at whitespace,
# so two URLs on the same line are not merged into one bogus candidate.
IMAGE_URL_PATTERN = re.compile(
    r"(https?:\/\/\S*\.(?:png|jpg|jpeg|gif|webp|svg))", flags=re.IGNORECASE
)


def extract_image_urls(text, timeout=10):
    """Return the last image URL in *text* that answers a HEAD request, or "".

    A candidate counts as valid when the HEAD request (redirects followed)
    returns a 2xx status and a ``content-type`` header containing "image".

    Args:
        text: Free-form user text to scan.
        timeout: Seconds to wait on each HEAD request before skipping it.

    Returns:
        The last valid image URL found in *text*, or an empty string.
    """
    # Probe from the end so the result matches "last valid URL wins" while
    # stopping at the first hit instead of requesting every candidate.
    # NOTE(review): URLs are user-supplied and fetched server-side (SSRF risk);
    # consider restricting hosts if this runs anywhere with internal services.
    for url in reversed(IMAGE_URL_PATTERN.findall(text)):
        try:
            response = requests.head(url, timeout=timeout, allow_redirects=True)
        except requests.exceptions.RequestException:
            continue  # Unreachable or malformed URL: try the next candidate.
        if 200 <= response.status_code < 300 and 'image' in response.headers.get('content-type', ''):
            return url
    return ""
def load_image_from_url(image_url: str, timeout: float = 30) -> "Image":
    """Download *image_url* and wrap the raw bytes in a Vertex AI ``Image``.

    Args:
        image_url: Absolute http:// or https:// URL of the image.
        timeout: Seconds to wait for the connection/read before giving up.

    Returns:
        An ``Image`` built from the downloaded bytes.

    Raises:
        ValueError: If *image_url* does not use the http or https scheme.
        urllib.error.URLError: If the download fails.
    """
    # urlopen would also happily open file:// and ftp:// URLs; only web URLs
    # are expected here.
    if not image_url.lower().startswith(("http://", "https://")):
        raise ValueError(f"Unsupported image URL (expected http or https): {image_url!r}")
    # Without a timeout a slow or stalled host blocks the worker indefinitely.
    with urllib.request.urlopen(image_url, timeout=timeout) as response:
        response = typing.cast(http.client.HTTPResponse, response)
        image_bytes = response.read()
    return Image.from_bytes(image_bytes)
def search(url):
    """Ask the Gemini vision model to describe the image at *url* and return its text."""
    question = "Describe what is shown in this image."
    picture = load_image_from_url(url)
    return model.generate_content([picture, question]).text
def format_prompt(message, history, cust_p):
    """Build the raw Gemma prompt string.

    Each earlier ``(user, model)`` pair in *history* is rendered as two
    ``<start_of_turn>...<end_of_turn>`` segments. The template *cust_p*, with
    its ``USER_INPUT`` placeholder replaced by *message*, is then appended.
    A falsy *history* (``None`` or empty) contributes nothing.
    """
    past_turns = "".join(
        f"<start_of_turn>user{user_text}<end_of_turn>"
        f"<start_of_turn>model{model_text}<end_of_turn>"
        for user_text, model_text in (history or ())
    )
    return past_turns + cust_p.replace("USER_INPUT", message)
def _memory_window(memory, chat_mem):
    """Return the most recent ``chat_mem`` exchanges from *memory* (none if <= 0)."""
    # chat_mem comes from a gr.Number: it may be a float (4.0) or None when the
    # field is cleared, and slicing needs an int. A size of 0 must mean "no
    # memory"; the plain memory[-0:] slice would return the whole list.
    size = int(chat_mem or 0)
    return memory[-size:] if size > 0 else []


def chat_inf(system_prompt, prompt, history, memory, seed, temp, tokens, top_p, rep_p, chat_mem, cust_p, context_limit=8000):
    """Stream a Gemma reply to *prompt*, optionally enriched with an image description.

    This is a generator used as a Gradio event handler; every yield is a
    ``(chatbot_history, memory)`` pair.

    Args:
        system_prompt: Optional text placed before the prompt as "system, prompt".
        prompt: The user's message. If it contains a reachable image URL, that
            URL is replaced by "Image Description: ..." produced by ``search``.
        history: Chatbot display history (list of (user, bot) pairs) or None.
        memory: Conversation memory (list of (user, bot) pairs) or None. The
            new exchange is appended to it in place.
        seed, temp, tokens, top_p, rep_p: Sampling settings forwarded to
            ``client.text_generation``; ``tokens`` is ``max_new_tokens``.
        chat_mem: Number of previous exchanges from *memory* to include.
        cust_p: Prompt template containing the ``USER_INPUT`` placeholder.
        context_limit: Budget for input length plus ``tokens``. The input is
            measured in characters as a rough stand-in for tokens.

    Yields:
        ``(history, memory)`` snapshots while streaming, then the final state.
        If the budget is exceeded, a single warning turn is yielded instead.
    """
    history = history or []
    memory = memory or []
    window = _memory_window(memory, chat_mem)
    hist_len = sum(len(str(turn)) for turn in window)
    in_len = len(system_prompt + prompt) + hist_len
    if (in_len + tokens) > context_limit:
        history.append((prompt, "Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value"))
        yield history, memory
        return
    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=seed,
    )
    image = extract_image_urls(prompt)
    if image:
        # Gemma is text-only: swap the URL for Gemini's description of the image.
        image_description = "Image Description: " + search(image)
        prompt = prompt.replace(image, image_description)
        print(prompt)
    if system_prompt:
        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", window, cust_p)
    else:
        formatted_prompt = format_prompt(prompt, window, cust_p)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
    output = ""
    for response in stream:
        # Special tokens (e.g. end-of-sequence markers) are not reply text.
        if response.token.special:
            continue
        output += response.token.text
        # Keep earlier turns visible while the new reply streams in.
        yield history + [(prompt, output)], memory
    history.append((prompt, output))
    memory.append((prompt, output))
    yield history, memory
def clear_fn():
    """Return four Nones: one reset value for each output the Clear button is wired to."""
    return (None,) * 4
# Upper bound for generation seeds (the lower bound is 1).
SEED_MAX = 1111111111111111
# Seed shown when the page first loads.
rand_val = random.randint(1, SEED_MAX)


def check_rand(inp, val):
    """Return the "Seed" slider to use for the next generation.

    Args:
        inp: State of the "Random Seed" checkbox.
        val: Current value of the seed slider.

    Returns:
        A ``gr.Slider`` holding a fresh random seed when *inp* is set,
        otherwise the current seed coerced to ``int``.
    """
    seed_value = random.randint(1, SEED_MAX) if inp else int(val)
    return gr.Slider(label="Seed", minimum=1, maximum=SEED_MAX, value=seed_value)
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # Per-session conversation memory passed to and returned from chat_inf.
    memory = gr.State()
    gr.HTML("""<center><h1 style='font-size:xx-large;'>Gemma Gemini Multimodal Chatbot</h1><br><h3>Gemini Sprint submission by Rishiraj Acharya. Uses Google's Gemini 1.0 Pro Vision multimodal model from Vertex AI with Google's Gemma 7B Instruct model from Hugging Face. Google Cloud credits are provided for this project.</h3></center>""")
    chat_b = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, likeable=True, layout="bubble", bubble_full_width=False)
    with gr.Group():
        inp = gr.Textbox(label="User Prompt")
        sys_inp = gr.Textbox(label="System Prompt")
        with gr.Accordion("Settings", open=False):
            custom_prompt = gr.Textbox(label="Modify Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=3, value="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model")
            rand = gr.Checkbox(label="Random Seed", value=True)
            seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
            tokens = gr.Slider(label="Max new tokens", value=1600, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of tokens")
            temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49)
            top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.49)
            rep_p = gr.Slider(label="Repetition Penalty", step=0.01, minimum=0.1, maximum=2.0, value=0.99)
            # precision=0 makes Gradio deliver an int; chat_inf uses it as a slice bound.
            chat_mem = gr.Number(label="Chat Memory", info="Number of previous chats to retain", value=4, precision=0)
    with gr.Group():
        with gr.Row():
            btn = gr.Button("Chat", variant="primary")
            stop_btn = gr.Button("Stop", variant="stop")
            clear_btn = gr.Button("Clear", variant="secondary")
    # Both triggers first refresh the seed via check_rand, then stream the reply.
    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, memory, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt], [chat_b, memory])
    go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, memory, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt], [chat_b, memory])
    # Stop cancels the streaming chat_inf step of either trigger.
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [inp, sys_inp, chat_b, memory])
app.queue(default_concurrency_limit=10).launch()