Spaces:
Runtime error
Runtime error
# Standard library
import string

# Third-party
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
)

# Local
from models.VLE import VLEForVQA, VLEForVQAPipeline, VLEProcessor

# NOTE(review): `string`, `requests` and `Image` do not appear to be used by
# the code in this file; kept so nothing that imports this module breaks.

# --- VLE visual-question-answering model, queried through `vqa_pipeline` ---
model_name = "hfl/vle-base-for-vqa"
model = VLEForVQA.from_pretrained(model_name)
vle_processor = VLEProcessor.from_pretrained(model_name)
vqa_pipeline = VLEForVQAPipeline(
    model=model,
    device="cpu",
    vle_processor=vle_processor,
)

# --- BLIP VQA model, placed on CUDA when available, otherwise on the CPU ---
# NOTE(review): no active code path in this file appears to use
# `processor` / `model_vqa`; loading a large model here costs memory and
# startup time — consider removing it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = BlipForQuestionAnswering.from_pretrained(
    "Salesforce/blip-vqa-capfilt-large"
).to(device)

# --- BLIP image-captioning model (not moved to `device`) ---
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
def caption(input_image):
    """Describe ``input_image`` with the BLIP captioning model.

    Generation runs with ``num_beams=1`` and ``num_return_sequences=1``.

    Args:
        input_image: The image to caption (the UI supplies a PIL image).

    Returns:
        The decoded caption text; if several sequences were decoded they are
        joined with newlines.
    """
    generate_kwargs = dict(cap_processor(input_image, return_tensors="pt"))
    generate_kwargs.update(num_beams=1, num_return_sequences=1)
    token_ids = cap_model.generate(**generate_kwargs)
    decoded = cap_processor.batch_decode(token_ids, skip_special_tokens=True)
    return "\n".join(decoded)
import os

import openai

# The OpenAI key is read from the `openai_appkey` environment variable. If the
# variable is unset, `openai.api_key` becomes None and requests will
# presumably fail authentication.
openai.api_key = os.environ.get("openai_appkey")
def gpt3_short(question, vqa_answer, caption, model="gpt-3.5-turbo-instruct"):
    """Let the LLM pick one of the VQA model's candidate answers.

    The caption, the question and up to four lettered candidates (A-D) with
    their VQA scores are put into a multiple-choice prompt. The LLM is asked
    to reply with a letter, and the matching candidate answer is returned.

    Args:
        question: The user's question about the image.
        vqa_answer: A ``(answers, scores)`` pair of parallel sequences, in the
            order the VQA pipeline returned them (the first becomes option A).
        caption: Text describing the image. (Inside this body it shadows the
            module-level ``caption`` function.)
        model: OpenAI completions model to query. ``text-davinci-003``, which
            this code used before, has been retired by OpenAI.

    Returns:
        The chosen candidate answer string. If the reply names no offered
        letter, the first candidate (option A) is returned.

    Raises:
        ValueError: If ``vqa_answer`` holds no candidates.
    """
    import re

    answers, scores = vqa_answer
    # Only the first four candidates get a letter; the instruction text below
    # is written for options A-D.
    candidates = list(zip("ABCD", answers, scores))
    if not candidates:
        raise ValueError("gpt3_short() needs at least one VQA candidate answer")

    # Every option uses the same "X: answer score:s" layout (the previous
    # prompt misspelled "score" for A and glued it to the answer for A and D).
    options = " ".join(
        letter + ": " + answer + " score:" + str(score)
        for letter, answer, score in candidates
    )
    prompt = (
        "prompt: This is a picture of Caption: " + caption
        + ". Question: " + question
        + " VQA model predicts: " + options
        + ". Choose A if it is not in conflict with the description of the picture"
        " and A's score is bigger than 0.8; otherwise choose the B, C or D based on"
        " the description. Answer with A or B or C or D."
    )
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        max_tokens=30,
        n=1,
        stop=None,
        temperature=0.7,
    )
    reply = response.choices[0].text.strip()

    # Split on every run of characters that cannot belong to a word. This way
    # replies such as "B", "B.", "Answer: C" and "(D)" all expose a bare
    # letter token. Apostrophes stay inside tokens, so "A's" is not read as "A".
    offered = {letter for letter, _, _ in candidates}
    tokens = re.split(r"[^A-Za-z0-9']+", reply)
    chosen = next((token for token in tokens if token in offered), "A")
    return answers["ABCD".index(chosen)]
def gpt3(question, vqa_answer, caption, model="gpt-3.5-turbo-instruct"):
    """Ask the LLM for a free-form answer given the caption and VQA candidates.

    Args:
        question: The user's question about the image.
        vqa_answer: The VQA candidate answers, already formatted as one string.
        caption: Text describing the image. (Inside this body it shadows the
            module-level ``caption`` function.)
        model: OpenAI completions model to query. ``text-davinci-003``, which
            this code used before, has been retired by OpenAI.

    Returns:
        The LLM's completion text with surrounding whitespace stripped.
    """
    prompt = caption + "\n" + question + "\n" + vqa_answer + "\n Tell me the right answer."
    # `model=` targets the /v1/completions endpoint; the old `engine=` form
    # used the deprecated per-engine endpoint.
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        max_tokens=30,
        n=1,
        stop=None,
        temperature=0.7,
    )
    return response.choices[0].text.strip()
def vle(input_image, input_text):
    """Query the VLE VQA pipeline for its top four answers.

    Args:
        input_image: The image the question is about.
        input_text: The question text.

    Returns:
        A tuple ``(answers, scores)`` of two parallel lists built from the
        ``'answer'`` and ``'score'`` fields of each pipeline prediction, in the
        order the pipeline returned them.
    """
    query = {"image": input_image, "question": input_text}
    predictions = vqa_pipeline(query, top_k=4)
    answers = [prediction["answer"] for prediction in predictions]
    scores = [prediction["score"] for prediction in predictions]
    return answers, scores
def inference_chat(input_image, input_text):
    """Answer a question about an image with VQA alone and with VQA + LLM.

    The image is captioned and the VLE model proposes candidate answers. Two
    LLM prompts then refine them: a free-form one and a multiple-choice one.

    Args:
        input_image: The uploaded image, or None if nothing was uploaded.
        input_text: The user's question.

    Returns:
        A 3-tuple ``(vqa_top_answer, llm_long_answer, llm_short_answer)``:
        the VLE model's first candidate, the free-form ``gpt3`` reply, and the
        candidate picked by ``gpt3_short``.

    Raises:
        gr.Error: If no image was supplied or the question is blank. Raising
            it before any model runs avoids an obscure processor error and two
            wasted paid API calls, and shows a readable message in the UI.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before asking a question.")
    if not input_text or not input_text.strip():
        raise gr.Error("Please enter a question about the image.")

    image_caption = caption(input_image)
    vqa_result = vle(input_image, input_text)
    candidate_answers = vqa_result[0]

    long_answer = gpt3(input_text, "\n".join(candidate_answers), image_caption)
    short_answer = gpt3_short(input_text, vqa_result, image_caption)
    return candidate_answers[0], long_answer, short_answer
title = """<h1 align="center">VQA</h1>"""

# NOTE(review): the selectors below depend on Gradio-generated svelte class
# hashes and component ids; confirm they still match the installed version.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    # Per-session state; the Clear button resets it.
    state = gr.State([])
    gr.Markdown(title)

    with gr.Row():
        # Left column: image upload, question box and buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
        # Right column: the three answers.
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="VQA ")
            caption_output_v1 = gr.Textbox(lines=0, label="VQA+LLM (short answer)")
            gpt3_output_v1 = gr.Textbox(lines=0, label="VQA+LLM (long answer)")

    # inference_chat returns (VQA answer, LLM long answer, LLM short answer).
    # Pressing Enter and clicking Submit must both route those three values to
    # the same three boxes. Enter previously had only one output, so the whole
    # tuple was dumped into the VQA box.
    answer_outputs = [caption_output, gpt3_output_v1, caption_output_v1]

    chat_input.submit(
        inference_chat,
        [image_input, chat_input],
        answer_outputs,
    )
    clear_button.click(
        lambda: ("", [], "", "", ""),
        [],
        [chat_input, state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    submit_button.click(
        inference_chat,
        [image_input, chat_input],
        answer_outputs,
    )

    # Columns: image, question, VQA answer, short LLM answer, long LLM answer.
    example_rows = [
        ["bird.jpeg", "How many birds are there in the tree?", "2", "2", "2"],
        ["qa9.jpg", "What type of vehicle is being pulled by the horses ?", "carriage", "carriage", "Sled"],
        ["upload4.jpg", "What is this old man doing?", "fishing", "fishing", "Fishing"],
    ]
    examples = gr.Examples(
        examples=example_rows,
        inputs=[image_input, chat_input, caption_output, caption_output_v1, gpt3_output_v1],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
# Queuing is already enabled by .queue() above; launch()'s `enable_queue`
# argument is deprecated and redundant here.
iface.launch()