VQA_CAP_GPT / app.py
xxx1's picture
Update app.py
f2977fa
raw
history blame
4.85 kB
import string
import gradio as gr
import requests
import torch
from transformers import BlipForQuestionAnswering, BlipProcessor
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Visual question answering: BLIP question-answering checkpoint plus its
# processor. The model is moved to `device`, so callers must move their input
# tensors there too.
_VQA_CHECKPOINT = "Salesforce/blip-vqa-capfilt-large"
processor = BlipProcessor.from_pretrained(_VQA_CHECKPOINT)
model_vqa = BlipForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT).to(device)

# Image captioning: BLIP captioning checkpoint plus its own processor.
# NOTE(review): unlike model_vqa, cap_model is NOT moved to `device`; it stays
# wherever from_pretrained places it (the CPU by default).
from transformers import BlipForConditionalGeneration, BlipProcessor

_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-large"
cap_processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
cap_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)
def caption(input_image):
    """Generate candidate captions for an image with the BLIP captioning model.

    Args:
        input_image: PIL image from the Gradio image component, or None when
            no image has been uploaded yet.

    Returns:
        Four beam-search captions (4 beams, 4 returned sequences), one per
        line, or "" when there is no image to caption.
    """
    if input_image is None:
        # Nothing uploaded: the processor cannot build pixel values from None.
        return ""
    # Send the inputs to whatever device the model lives on, so this keeps
    # working if cap_model is ever moved to the GPU.
    inputs = cap_processor(input_image, return_tensors="pt").to(cap_model.device)
    # Pass generation settings as keyword arguments instead of stuffing them
    # into the processor output.
    out = cap_model.generate(**inputs, num_beams=4, num_return_sequences=4)
    return "\n".join(cap_processor.batch_decode(out, skip_special_tokens=True))
import os

import openai

# Read the OpenAI key from the environment (e.g. a Spaces secret) instead of
# hard-coding it. The key previously written here was committed in plain text
# and must be treated as compromised and revoked.
# If the variable is unset, api_key is None and OpenAI calls fail at request
# time with an authentication error.
openai.api_key = os.environ.get("OPENAI_API_KEY")
def gpt3(question, vqa_answer, caption):
    """Ask an OpenAI completion model to pick the right answer.

    The prompt has the caption text, the question and the VQA candidate
    answers, one per line in that order, followed by an instruction line. It
    is sent to ``text-davinci-003``. The function returns a transcript
    containing both the prompt and the model's (stripped) reply.

    Args:
        question: the user's question.
        vqa_answer: candidate answers, typically the VQA output box contents.
        caption: caption text, typically the caption output box contents.
            (This parameter shadows the module-level ``caption`` function
            inside this body.)

    Returns:
        "input_text:", the prompt, a blank line, " output_answer:" and the
        reply, separated by newlines.
    """
    # NOTE(review): openai.Completion was removed in openai>=1.0, and
    # text-davinci-003 has been retired by OpenAI. Confirm the pinned openai
    # version, or migrate to the chat completions API.
    pieces = [caption, question, vqa_answer]
    prompt = "\n".join(pieces) + "\n Tell me the right answer."
    completion = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=10,  # keeps replies very short; longer ones are truncated
        n=1,
        stop=None,
        temperature=0.7,
    )
    reply = completion.choices[0].text.strip()
    return f"input_text:\n{prompt}\n\n output_answer:\n{reply}"
def inference_chat(input_image, input_text):
    """Answer a question about an image with the BLIP VQA model.

    Args:
        input_image: PIL image from the Gradio image component, or None.
        input_text: the question typed by the user.

    Returns:
        Four beam-search answers (5 beams, max_length 10, 4 returned
        sequences), one per line, or "" when the image or question is missing.
    """
    if input_image is None or not (input_text or "").strip():
        # Nothing to ask about, or nothing asked: skip the model call.
        return ""
    # model_vqa was moved to `device` (the GPU when available) at load time.
    # The processor returns CPU tensors, so move them to the model's device;
    # otherwise generate() fails with a device mismatch on CUDA hosts.
    inputs = processor(
        images=input_image, text=input_text, return_tensors="pt"
    ).to(model_vqa.device)
    out = model_vqa.generate(
        **inputs, max_length=10, num_beams=5, num_return_sequences=4
    )
    return "\n".join(processor.batch_decode(out, skip_special_tokens=True))
# NOTE(review): these selectors target Svelte-generated class names and an
# auto-numbered element id ("#component-21"). They look inherited from another
# layout; verify they still match anything in this Gradio version and layout.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    # Per-session value; the handlers below only ever reset it to [].
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
        with gr.Row():
            # Question box and the action buttons.
            with gr.Column(scale=1):
                chat_input = gr.Textbox(lines=1, label="VQA Input(问题输入)")
                with gr.Row():
                    clear_button = gr.Button(value="Clear", interactive=True)
                    submit_button = gr.Button(
                        value="Submit_VQA", interactive=True, variant="primary"
                    )
                    cap_submit_button = gr.Button(
                        value="Submit_CAP", interactive=True, variant="primary"
                    )
                    gpt3_submit_button = gr.Button(
                        value="Submit_GPT3", interactive=True, variant="primary"
                    )
            # One output box per model.
            with gr.Column():
                caption_output = gr.Textbox(lines=0, label="VQA Output(模型答案输出)")
                caption_output_v1 = gr.Textbox(
                    lines=0, label="Caption Output(模型caption输出)"
                )
                gpt3_output_v1 = gr.Textbox(lines=0, label="GPT3 Output(GPT3输出)")

    # A new image makes every model output stale. The GPT-3 step also reads the
    # VQA and caption boxes, so clear all of them and reset the state. Return
    # exactly one value per output component; previously 3 values went to 2
    # outputs, which wrote "" into the list-valued state.
    image_input.change(
        lambda: ("", "", "", []),
        [],
        [caption_output, caption_output_v1, gpt3_output_v1, state],
        queue=False,
    )
    # Pressing Enter in the question box does the same as Submit_VQA.
    chat_input.submit(
        inference_chat,
        [image_input, chat_input],
        [caption_output],
    )
    clear_button.click(
        lambda: ("", []),
        [],
        [chat_input, state],
        queue=False,
    )
    submit_button.click(
        inference_chat,
        [image_input, chat_input],
        [caption_output],
    )
    cap_submit_button.click(
        caption,
        [image_input],
        [caption_output_v1],
    )
    # Input order must match gpt3(question, vqa_answer, caption).
    gpt3_submit_button.click(
        gpt3,
        [chat_input, caption_output, caption_output_v1],
        [gpt3_output_v1],
    )
# Handle one request at a time and hold at most 10 waiting jobs.
# api_open=False blocks direct REST API access to the queued endpoints.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# queue() already enables queuing. launch(enable_queue=...) is deprecated in
# Gradio 3.x, so the flag is not passed again.
iface.launch()