# BLIP2 / app.py
# root
# update example to run on click
# b9e7cf7
# raw
# history blame
# No virus
# 7.16 kB
from io import BytesIO
import string
import gradio as gr
import requests
from utils import Endpoint
def encode_image(image):
    """Serialize a PIL-style image to an in-memory JPEG.

    JPEG cannot store an alpha channel or a palette, so images in any mode
    other than RGB or greyscale ("L") are converted to RGB first. Without
    this, ``save(format="JPEG")`` raises for e.g. RGBA PNG uploads.

    Args:
        image: Object exposing ``mode``, ``convert(mode)`` and
            ``save(fp, format=...)`` (a ``PIL.Image.Image`` in this app).

    Returns:
        BytesIO holding the JPEG bytes, rewound to position 0 so it can be
        uploaded directly.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # Rewind so the HTTP client reads from the start of the buffer.
    buffered.seek(0)
    return buffered
def query_chat_api(
    image,
    prompt,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=60,
):
    """Send an image plus a text prompt to the BLIP-2 generation endpoint.

    Args:
        image: Image to condition on; serialized with ``encode_image``.
        prompt: Text prompt (in this app, the joined conversation history).
        decoding_method: UI radio value. Only ``"Nucleus sampling"`` sets
            ``use_nucleus_sampling`` to True.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        timeout: Seconds passed to ``requests.post``. Without one, a stalled
            backend would block forever; the app queue runs with
            ``concurrency_count=1``, so that would stall every user.

    Returns:
        The decoded JSON body on HTTP 200 (callers index ``[0]`` on it), or a
        string starting with ``"Error: "`` holding either the server's error
        body or the network exception message.
    """
    url = endpoint.url
    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}
    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }
    files = {"image": encode_image(image)}
    try:
        response = requests.post(
            url, data=data, files=files, headers=headers, timeout=timeout
        )
    except requests.exceptions.RequestException as err:
        # Report connection failures/timeouts the same way as HTTP errors.
        return "Error: " + str(err)
    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
def query_caption_api(
    image,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=60,
):
    """Ask the BLIP-2 server to caption an image (no text prompt).

    The endpoint URL is derived from ``endpoint.url`` by replacing
    ``/generate`` with ``/caption``.

    Args:
        image: Image to caption; serialized with ``encode_image``.
        decoding_method: UI radio value. Only ``"Nucleus sampling"`` sets
            ``use_nucleus_sampling`` to True.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        timeout: Seconds passed to ``requests.post``. Without one, a stalled
            backend would block forever; the app queue runs with
            ``concurrency_count=1``, so that would stall every user.

    Returns:
        The decoded JSON body on HTTP 200 (callers index ``[0]`` on it), or a
        string starting with ``"Error: "`` holding either the server's error
        body or the network exception message.
    """
    # Same server as chat; only the route differs.
    url = endpoint.url.replace("/generate", "/caption")
    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}
    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }
    files = {"image": encode_image(image)}
    try:
        response = requests.post(
            url, data=data, files=files, headers=headers, timeout=timeout
        )
    except requests.exceptions.RequestException as err:
        # Report connection failures/timeouts the same way as HTTP errors.
        return "Error: " + str(err)
    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
def postprocess_output(output):
    """Make sure the first generated answer ends with punctuation.

    Args:
        output: The value returned by ``query_chat_api``. On success this is a
            list whose first element is the answer text. On failure it is an
            ``"Error: ..."`` string, which is wrapped in a list unchanged so
            callers can still index ``[0]``.

    Returns:
        A new list (the argument is not modified). If its first element is a
        non-empty string that does not end in ``string.punctuation``, a full
        stop is appended to it. Empty lists and empty answers are returned
        as-is; indexing ``[-1]`` on them would raise.
    """
    if isinstance(output, str):
        return [output]
    result = list(output)
    if result and result[0] and result[0][-1] not in string.punctuation:
        result[0] += "."
    return result
def _history_to_chat(history):
    """Pair a flat ``[user, bot, user, bot, ...]`` list into ``(user, bot)`` tuples.

    An unpaired trailing entry is left out.
    """
    return list(zip(history[::2], history[1::2]))


def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Run one chat turn about ``image`` against the BLIP-2 endpoint.

    The prompt sent to the model is the whole conversation so far plus the
    new user message, joined with single spaces.

    Args:
        image: Image the conversation is about.
        text_input: The new user message.
        decoding_method, temperature, length_penalty, repetition_penalty:
            Decoding settings forwarded to ``query_chat_api``.
        history: Flat list alternating user and bot turns (the ``gr.State``
            value). ``None`` starts a new conversation. The list is not
            modified.

    Returns:
        dict mapping the ``chatbot`` component to a list of ``(user, bot)``
        tuples and the ``state`` component to the updated history list.
    """
    # Build a new list instead of appending in place. A ``history=[]`` default
    # would be shared across calls, and mutating the caller's list could leave
    # a dangling user turn in it if the request below fails.
    previous = [] if history is None else list(history)
    history = previous + [text_input]
    prompt = " ".join(history)
    print(prompt)
    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )
    if isinstance(output, str):
        # query_chat_api returns an "Error: ..." string on failure. Show it as
        # the reply to this turn, but keep it out of the state so it is not
        # fed back into the next prompt.
        chat = _history_to_chat(previous) + [(text_input, output)]
        return {chatbot: chat, state: previous}
    output = postprocess_output(output)
    # Record exactly one reply so the history keeps alternating user/bot turns.
    history.append(output[0] if output else "")
    return {chatbot: _history_to_chat(history), state: history}
def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Caption ``image`` and return the text for the caption textbox.

    Args:
        image: Image to caption.
        decoding_method, temperature, length_penalty, repetition_penalty:
            Decoding settings forwarded to ``query_caption_api``.

    Returns:
        The first caption returned by the API, ``""`` if it returned none, or
        the full ``"Error: ..."`` message when the request failed.
    """
    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )
    if isinstance(output, str):
        # Error message: indexing [0] here would show only its first letter.
        return output
    return output[0] if output else ""
title = """<h1 align="center">BLIP-2</h1>"""
description = """Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>"
# Demo inputs shown under the UI; each row is [image file, chat question].
examples = [
    ["house.png", "How could someone get out of the house?"],
    # Disabled example, kept for reference:
    # ["sunset.png", "Write a romantic message that goes along this photo."],
]

# Inference-server client; query_chat_api / query_caption_api read its
# ``url`` attribute each time they are called.
endpoint = Endpoint()
with gr.Blocks() as iface:
    # Per-session conversation: a flat list alternating user and bot turns.
    state = gr.State([])
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)
    with gr.Row():
        # Left column: image upload, decoding settings and one-shot captioning.
        with gr.Column():
            image_input = gr.Image(type="pil")
            with gr.Row():
                sampling = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Beam search",
                    label="Text Decoding Method",
                    interactive=True,
                )
                # The slider cannot go below 0.5, so the label must not
                # suggest 0 (greedy) as an option.
                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.8,
                    interactive=True,
                    label="Temperature (used with nucleus sampling)",
                )
                len_penalty = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                    label="Length Penalty (larger value encourages longer sequence with beam search)",
                )
                rep_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=1.5,
                    step=0.5,
                    interactive=True,
                    label="Repeat Penalty (larger value prevents repetition)",
                )
            with gr.Row():
                caption_output = gr.Textbox(lines=2, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    inference_caption,
                    [
                        image_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                    ],
                    [caption_output],
                )
        # Right column: multi-turn chat about the current image.
        with gr.Column():
            chat_input = gr.Textbox(lines=2, label="Chat Input")
            with gr.Row():
                chatbot = gr.Chatbot()
                # A different image invalidates the chat, input and caption.
                image_input.change(
                    lambda: (None, "", "", []),
                    [],
                    [chatbot, chat_input, caption_output, state],
                )
            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", None, [], []),
                    [],
                    [chat_input, image_input, chatbot, state],
                )
                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        image_input,
                        chat_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )
    # Clicking an example fills in image + question and runs a chat turn
    # (run_on_click). Bound to its own name so the ``examples`` data list is
    # not shadowed by the component.
    example_gallery = gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input, sampling, temperature, len_penalty, rep_penalty, state],
        outputs=[chatbot, state],
        run_on_click=True,
        fn=inference_chat,
    )
# One request processed at a time, at most 10 waiting; API routes closed.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# .queue() above already enables queuing. Also passing enable_queue=True to
# launch() is redundant and triggers a deprecation warning in Gradio 3.x.
iface.launch()