VQA_CAP_GPT / app.py
xxx1's picture
Update app.py
5099b54
raw
history blame
7.57 kB
import string
import gradio as gr
import requests
import torch
from models.VLE import VLEForVQA, VLEProcessor, VLEForVQAPipeline
from PIL import Image
# Checkpoint id passed to from_pretrained for both the VLE model and its processor.
model_name="hfl/vle-base-for-vqa"
model = VLEForVQA.from_pretrained(model_name)
vle_processor = VLEProcessor.from_pretrained(model_name)
# VQA pipeline used by vle(); it is created with device='cpu', independent of
# the `device` selected below.
vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor)
from transformers import BlipForQuestionAnswering, BlipProcessor
# CUDA when available, otherwise CPU. Only model_vqa is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# BLIP VQA processor and model.
# NOTE(review): no live code in this file calls `processor` or `model_vqa`;
# confirm whether loading this checkpoint at startup is still needed.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large").to(device)
from transformers import BlipProcessor, BlipForConditionalGeneration
# BLIP captioning processor and model used by caption(). The model is not moved
# with .to(device), so it stays wherever from_pretrained places it.
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
def caption(input_image):
    """Generate a caption for *input_image* with the BLIP captioning model.

    The image is encoded by ``cap_processor``, a single sequence is generated
    with one beam, and the decoded sequences are joined with newlines.
    """
    encoded = cap_processor(input_image, return_tensors="pt")
    generated = cap_model.generate(
        **encoded,
        num_beams=1,
        num_return_sequences=1,
    )
    decoded = cap_processor.batch_decode(generated, skip_special_tokens=True)
    return "\n".join(decoded)
import openai
import os
# Read the OpenAI API key from the `openai_appkey` environment variable;
# os.getenv returns None when the variable is not set.
openai.api_key= os.getenv('openai_appkey')
def _format_vqa_choices(options):
    """Render (letter, answer, score) triples as ' A: <answer> score:<score>' text."""
    return "".join(
        " " + letter + ": " + answer + " score:" + str(score)
        for letter, answer, score in options
    )


def _pick_choice_letter(reply, valid_letters):
    """Return the first standalone choice letter found in *reply*, else 'A'."""
    # Treat newlines and common punctuation as separators so that replies such
    # as "A." or "Answer: B" still yield a bare letter token.
    separators = str.maketrans({"\n": " ", ":": " ", ".": " ", ",": " "})
    for token in reply.translate(separators).split():
        if token in valid_letters:
            return token
    return "A"


def gpt3_short(question, vqa_answer, caption):
    """Let the LLM choose one of the VQA candidates and return that candidate.

    Args:
        question: The question asked about the image.
        vqa_answer: An ``(answers, scores)`` pair of parallel sequences, as
            produced by ``vle``. At most the first four candidates are used
            and they are labelled A to D.
        caption: Caption text describing the image. (This parameter shadows
            the module-level ``caption`` function inside this body.)

    Returns:
        The text of the candidate whose letter the LLM picked. Falls back to
        candidate A when no valid letter is found in the reply, and returns an
        empty string (without calling the API) when there are no candidates.
    """
    answers, scores = vqa_answer
    options = list(zip("ABCD", answers, scores))
    if not options:
        return ""
    answer_by_letter = {letter: answer for letter, answer, _ in options}

    prompt = (
        "prompt: This is a picture of Caption: " + caption
        + ". Question: " + question
        + " VQA model predicts:" + _format_vqa_choices(options)
        + ". Choose A if it is not in conflict with the description of the picture and A's score is bigger than 0.8; otherwise choose the B, C or D based on the description. Answer with A or B or C or D."
    )
    # NOTE(review): this uses the Completion endpoint with "text-davinci-003";
    # confirm the installed openai package and account still provide both.
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=30,
        n=1,
        stop=None,
        temperature=0.7,
    )
    reply = response.choices[0].text.strip()
    # Only letters that actually have a candidate are accepted, so the lookup
    # below cannot fail when fewer than four answers were supplied.
    letter = _pick_choice_letter(reply, set(answer_by_letter))
    return answer_by_letter[letter]
def gpt3(question, vqa_answer, caption):
    """Ask the completion model for a free-form answer.

    The prompt is the caption, the question and the VQA answer text on
    separate lines, followed by the instruction " Tell me the right answer.".
    Returns the stripped text of the first completion choice.
    """
    prompt_lines = (caption, question, vqa_answer, " Tell me the right answer.")
    completion = openai.Completion.create(
        engine="text-davinci-003",
        prompt="\n".join(prompt_lines),
        max_tokens=30,
        n=1,
        stop=None,
        temperature=0.7,
    )
    first_choice = completion.choices[0]
    return first_choice.text.strip()
def vle(input_image, input_text):
    """Run the VLE VQA pipeline and return its top four candidates.

    Returns a pair ``(answers, scores)`` of parallel lists built from the
    'answer' and 'score' entries of each pipeline prediction.
    """
    query = {"image": input_image, "question": input_text}
    predictions = vqa_pipeline(query, top_k=4)
    answers = [prediction['answer'] for prediction in predictions]
    scores = [prediction['score'] for prediction in predictions]
    return answers, scores
def inference_chat(input_image, input_text):
    """Answer a question about an image with the VQA model and two LLM passes.

    Args:
        input_image: The image to analyse, or None when nothing was uploaded.
        input_text: The question about the image.

    Returns:
        A 3-tuple ``(vqa_top_answer, llm_long_answer, llm_short_answer)``:
        the best VQA candidate, the free-form answer from ``gpt3`` and the
        candidate selected by ``gpt3_short``. Three empty strings are returned
        when the image or the question is missing, or when the VQA model
        yields no candidates.
    """
    # Bail out before any model or API call when there is nothing to answer.
    if input_image is None or not input_text or not input_text.strip():
        return "", "", ""

    cap = caption(input_image)
    vqa_result = vle(input_image, input_text)
    answers = vqa_result[0]
    if not answers:
        return "", "", ""

    long_answer = gpt3(input_text, "\n".join(answers), cap)
    short_answer = gpt3_short(input_text, vqa_result, cap)
    return answers[0], long_answer, short_answer
title = """<h1 align="center">VQA</h1>"""

# Build the page: image and question inputs with Clear/Submit buttons, plus
# three text boxes for the VQA answer and the two LLM-assisted answers.
# NOTE(review): the nesting of rows/columns below is reconstructed; confirm it
# matches the intended layout.
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])
    gr.Markdown(title)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="VQA ")
            caption_output_v1 = gr.Textbox(lines=0, label="VQA+LLM (short answer)")
            gpt3_output_v1 = gr.Textbox(lines=0, label="VQA+LLM (long answer)")

    # inference_chat returns (VQA answer, long LLM answer, short LLM answer),
    # so every listener that calls it must supply three output components in
    # this order.
    answer_outputs = [caption_output, gpt3_output_v1, caption_output_v1]

    # Pressing Enter in the question box behaves like clicking Submit.
    chat_input.submit(
        inference_chat,
        [image_input, chat_input],
        answer_outputs,
    )
    clear_button.click(
        lambda: ("", [], "", "", ""),
        [],
        [chat_input, state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    submit_button.click(
        inference_chat,
        [image_input, chat_input],
        answer_outputs,
    )

    # Each example row: image file, question, then the values shown in the
    # "VQA", "short answer" and "long answer" boxes.
    examples = [
        ['bird.jpeg', "How many birds are there in the tree?", "2", "2", "2"],
        ['qa9.jpg', "What type of vehicle is being pulled by the horses ?", 'carriage', 'carriage', 'Sled'],
        ['upload4.jpg', "What is this old man doing?", "fishing", "fishing", "Fishing"],
    ]
    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input, caption_output, caption_output_v1, gpt3_output_v1],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(enable_queue=True)