VQA_CAP_GPT / app.py
xxx1's picture
Update app.py
a00a3a2
raw
history blame
No virus
5.15 kB
import string
import gradio as gr
import requests
import torch
from transformers import BlipForQuestionAnswering, BlipProcessor
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Visual-question-answering model and its matching processor, both loaded
# from the same Hugging Face checkpoint. The model is moved to `device`.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = model_vqa.to(device)
from transformers import BlipProcessor, BlipForConditionalGeneration
# Captioning model and its processor. The model is placed on the same device
# as the VQA model so that a visible GPU is used for both.
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)


def caption(input_image):
    """Generate a caption for ``input_image`` with the BLIP captioning model.

    Args:
        input_image: Image accepted by ``cap_processor``; the UI in this file
            passes a PIL image.

    Returns:
        The decoded caption(s) joined with newlines. Only one sequence is
        requested per image, so this is normally a single line.
    """
    # Keep the inputs on the model's device to avoid a device mismatch.
    inputs = cap_processor(input_image, return_tensors="pt").to(cap_model.device)
    # Generation options are passed as explicit keyword arguments instead of
    # being smuggled through the processor's output mapping.
    out = cap_model.generate(**inputs, num_beams=1, num_return_sequences=1)
    return "\n".join(cap_processor.batch_decode(out, skip_special_tokens=True))
import openai
import os
# The OpenAI API key is read from the ``openai_appkey`` environment variable
# (``None`` when the variable is not set).
openai.api_key = os.environ.get("openai_appkey")
def gpt3(question, vqa_answer, caption):
    """Ask the OpenAI completion endpoint to pick an answer.

    The prompt is the image caption, the question and the VQA candidate
    answer(s), one per line, followed by a short instruction.

    Args:
        question: The user's question about the image.
        vqa_answer: Candidate answer text from the VQA model; may be an empty
            string to ask the language model with the caption only.
        caption: Text description of the image.

    Returns:
        The text of the first completion choice, stripped of surrounding
        whitespace.
    """
    prompt = "\n".join((caption, question, vqa_answer)) + "\n Tell me the right answer."

    # NOTE(review): the "text-davinci-003" engine may have been retired by
    # OpenAI -- confirm it is still served before relying on this call.
    completion = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=10,
        n=1,
        stop=None,
        temperature=0.7,
    )
    first_choice = completion.choices[0]
    return first_choice.text.strip()
def inference_chat(input_image, input_text):
    """Answer a question about an image in three different ways.

    Runs the captioning model, then the VQA model (beam search, four
    candidate answers), then queries the language model twice: once with the
    VQA candidates and once with the caption alone.

    Args:
        input_image: The image to analyse; the UI in this file passes a PIL
            image.
        input_text: The question to ask about the image.

    Returns:
        A 3-tuple ``(vqa_answer, vqa_llm_answer, caption_llm_answer)`` where
        ``vqa_answer`` is the first decoded VQA candidate.
    """
    image_caption = caption(input_image)

    # `model_vqa` was moved to `device` at load time, so its inputs must be
    # moved there as well; otherwise generation fails with a device mismatch
    # on a GPU host.
    inputs = processor(images=input_image, text=input_text, return_tensors="pt").to(device)
    generated = model_vqa.generate(
        **inputs,
        max_length=10,
        num_beams=5,
        num_return_sequences=4,
    )
    candidates = processor.batch_decode(generated, skip_special_tokens=True)

    # First query sees every VQA candidate; second one only sees the caption.
    vqa_llm_answer = gpt3(input_text, "\n".join(candidates), image_caption)
    caption_llm_answer = gpt3(input_text, "", image_caption)
    return candidates[0], vqa_llm_answer, caption_llm_answer
# Gradio UI: an image and a question go in; the raw VQA answer and two
# LLM-refined answers come out.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# source whose indentation had been flattened -- confirm the intended layout.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    state = gr.State([])

    with gr.Row():
        # Left column: image upload, question box and action buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Input(问题输入)")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="VQA", interactive=True, variant="primary"
                        )
                        # NOTE: separate "Submit_CAP" / "Submit_GPT3" buttons
                        # were prototyped here and are disabled; the single
                        # VQA button runs the whole pipeline.

        # Right column: one text box per answering strategy.
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="VQA ")
            gpt3_output_v1 = gr.Textbox(lines=0, label="VQA+LLM")
            caption_output_v1 = gr.Textbox(lines=0, label="CAP+LLM")

    # Every handler that produces or clears results targets these boxes, in
    # the order returned by `inference_chat`.
    result_boxes = [caption_output, gpt3_output_v1, caption_output_v1]

    # A new image invalidates the previous answers (the question is kept).
    image_input.change(
        lambda: ([], "", "", ""),
        [],
        [state, *result_boxes],
        queue=False,
    )

    # Pressing Enter in the question box behaves exactly like the VQA button:
    # `inference_chat` returns three values, so all three boxes are outputs.
    chat_input.submit(
        inference_chat,
        [image_input, chat_input],
        result_boxes,
    )

    # Clear resets the question, the state and every result box.
    clear_button.click(
        lambda: ("", [], "", "", ""),
        [],
        [chat_input, state, *result_boxes],
        queue=False,
    )

    submit_button.click(
        inference_chat,
        [image_input, chat_input],
        result_boxes,
    )
# examples = gr.Examples(
# examples=examples,
# inputs=[image_input, chat_input],
# )
# Serve requests through a queue: one job at a time, at most ten waiting,
# with the queue's API endpoints closed.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# `.queue()` above already enables queuing, so the deprecated
# `enable_queue` argument to `launch()` is not passed.
iface.launch()