Spaces:
Runtime error
Runtime error
import string | |
import gradio as gr | |
import requests | |
import torch | |
from models.VLE import VLEForVQA, VLEProcessor, VLEForVQAPipeline | |
from PIL import Image | |
# VLE VQA model + processor, wrapped in the pipeline that `vle` queries for its
# top-k answer candidates. The pipeline is pinned to the CPU.
model_name = "hfl/vle-base-for-vqa"
vle_processor = VLEProcessor.from_pretrained(model_name)
model = VLEForVQA.from_pretrained(model_name)
vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor)

from transformers import BlipForConditionalGeneration, BlipForQuestionAnswering, BlipProcessor

# BLIP VQA model, moved to the GPU when CUDA is available.
# NOTE(review): within this file `processor` and `model_vqa` are referenced only
# by commented-out code in inference_chat; consider removing them to cut
# start-up time and memory.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = model_vqa.to(device)

# BLIP captioning model used by `caption`. It is never moved with .to(), so it
# stays wherever from_pretrained placed it (presumably the CPU).
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
def caption(input_image):
    """Describe *input_image* with the BLIP captioning model.

    Decoding is greedy (one beam) and returns a single sequence. The decoded
    text(s) are joined with newlines.
    """
    encoded = cap_processor(input_image, return_tensors="pt")
    generated_ids = cap_model.generate(**encoded, num_beams=1, num_return_sequences=1)
    captions = cap_processor.batch_decode(generated_ids, skip_special_tokens=True)
    return "\n".join(captions)
import os

import openai

# API key for the OpenAI completion calls made by gpt3 / gpt3_short, read from
# the `openai_appkey` environment variable (None when the variable is unset).
openai.api_key = os.environ.get('openai_appkey')
def gpt3_short(question, vqa_answer, caption, complete=None):
    """Ask the LLM to pick the most plausible answer among the VQA candidates.

    The caption, the question and the lettered VQA candidates (with their
    scores) are put into a multiple-choice prompt. The first valid candidate
    letter found in the LLM's reply selects the returned answer.

    Args:
        question: The user's question about the image.
        vqa_answer: ``(answers, scores)`` pair as returned by ``vle``. These
            are two parallel sequences of candidate answers and confidences.
        caption: Caption of the image.
        complete: Optional ``prompt -> reply text`` callable used instead of
            the OpenAI completion endpoint (e.g. another LLM backend).

    Returns:
        The selected candidate answer.
        - Falls back to the first candidate when the reply names no valid letter.
        - Returns the only candidate without querying the LLM when there is just one.
        - Returns "" when there are no candidates.
    """
    answers, scores = vqa_answer
    count = min(len(answers), len(scores), len(string.ascii_uppercase))
    if count == 0:
        return ""
    if count == 1:
        return answers[0]

    # Only offer as many letters as there are candidates, so the reply can
    # never select an index that does not exist.
    letters = string.ascii_uppercase[:count]
    choices = set(letters)  # set, not str: "AB" must not match via substring
    options = " ".join(
        f"{letter}: {answer} score:{score}"
        for letter, answer, score in zip(letters, answers, scores)
    )
    others = letters[1:]
    if len(others) == 1:
        alternatives = others
    else:
        alternatives = ", ".join(others[:-1]) + " or " + others[-1]
    prompt = (
        "prompt: This is the caption of a picture: " + caption
        + ". Question: " + question
        + " VQA model predicts: " + options
        + ". Choose A if it is not in conflict with the description of the picture and"
        " A's score is bigger than 0.8; otherwise choose the " + alternatives
        + " based on the description. Answer with " + " or ".join(letters) + "."
    )

    if complete is not None:
        reply = complete(prompt)
    else:
        # NOTE(review): this is the pre-1.0 `openai` module-level Completion API,
        # which openai>=1.0 removed; text-davinci-003 has also been deprecated
        # by OpenAI. Confirm the pinned version, or pass `complete`.
        response = openai.Completion.create(
            engine="text-davinci-003",
            prompt=prompt,
            max_tokens=30,
            n=1,
            stop=None,
            temperature=0.7,
        )
        reply = response.choices[0].text

    # Treat newlines, punctuation and brackets as separators so replies such as
    # "Answer: B." or "(C)" still yield a bare letter token.
    tokens = reply.translate(str.maketrans("\n:.,()", "      ")).split()
    chosen = next((token for token in tokens if token in choices), "A")
    return answers[letters.index(chosen)]
def gpt3(question, vqa_answer, caption, complete=None):
    """Ask the LLM for a free-form answer given the caption and VQA candidates.

    Args:
        question: The user's question about the image.
        vqa_answer: The VQA candidate answers as a single string (the caller
            in this file passes them newline-separated).
        caption: Caption of the image.
        complete: Optional ``prompt -> reply text`` callable used instead of
            the OpenAI completion endpoint (e.g. another LLM backend).

    Returns:
        The LLM's reply with surrounding whitespace stripped.
    """
    prompt = caption + "\n" + question + "\n" + vqa_answer + "\n Tell me the right answer."
    if complete is not None:
        return complete(prompt).strip()
    # NOTE(review): this is the pre-1.0 `openai` module-level Completion API,
    # which openai>=1.0 removed; text-davinci-003 has also been deprecated by
    # OpenAI. Confirm the pinned version, or pass `complete`.
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=30,
        n=1,
        stop=None,
        temperature=0.7,
    )
    return response.choices[0].text.strip()
def vle(input_image, input_text):
    """Query the VLE VQA pipeline for its 4 best answers to *input_text*.

    Returns:
        ``(answers, scores)``: two parallel lists, in the order the pipeline
        returned its candidates.
    """
    query = {"image": input_image, "question": input_text}
    candidates = vqa_pipeline(query, top_k=4)
    answers = [candidate['answer'] for candidate in candidates]
    scores = [candidate['score'] for candidate in candidates]
    return answers, scores
def inference_chat(input_image, input_text):
    """Answer *input_text* about *input_image* with the three demo pipelines.

    Returns:
        A ``(vqa, llm_long, llm_short)`` triple:
        - the top VLE answer;
        - the LLM's free-form answer;
        - the candidate the LLM chose among the VLE answers.
        When the image or the question is missing, returns a prompt message in
        the first slot and empty strings in the others.
    """
    if input_image is None or not (input_text or "").strip():
        # Otherwise a missing image would be handed straight to the captioner.
        return "Please upload an image and enter a question.", "", ""
    cap = caption(input_image)
    answers, scores = vle(input_image, input_text)
    long_answer = gpt3(input_text, "\n".join(answers), cap)
    short_answer = gpt3_short(input_text, (answers, scores), cap)
    return (answers[0] if answers else ""), long_answer, short_answer
title = """# VQA with VLE and LLM"""
description = """We demonstrate three visual question answering systems built with VLE and LLM:
* VQA: The image and the question are fed into a VQA model (VLEForVQA) and the model predicts the answer.
* VQA + LLM (short answer): The captioning model generates a caption of the image. We feed the caption, the question, and the answer candidates predicted by the VQA model to the LLM, and ask the LLM to select the most reasonable answer from the candidates.
* VQA + LLM (long answer): The pipeline is the same as VQA + LLM (short answer), except that the answer is freely generated by the LLM and not limited to VQA candidates.
For more details about VLE and the VQA pipeline, see [http://vle.hfl-rc.com](http://vle.hfl-rc.com)"""

# NOTE(review): the `svelte-w6rprc` classes and `#component-21` id look like
# Gradio-generated names; they may stop matching after a Gradio upgrade or a
# layout change. Verify the styling still applies.
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="VQA ")
            caption_output_v1 = gr.Textbox(lines=0, label="VQA+LLM (short answer)")
            gpt3_output_v1 = gr.Textbox(lines=0, label="VQA+LLM (long answer)")

    # inference_chat returns (vqa, llm_long, llm_short). Pressing Enter and
    # clicking Submit must both map that triple onto the same three boxes.
    # Previously Enter mapped it onto caption_output alone, so the whole tuple
    # was shown there.
    answer_outputs = [caption_output, gpt3_output_v1, caption_output_v1]
    chat_input.submit(
        inference_chat,
        [image_input, chat_input],
        answer_outputs,
    )
    clear_button.click(
        lambda: ("", [], "", "", ""),
        [],
        [chat_input, state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    submit_button.click(
        inference_chat,
        [image_input, chat_input],
        answer_outputs,
    )

    # Example rows: image, question, then the expected VQA / short / long answers.
    examples = [['bird.jpeg', "How many birds are there in the tree?", "2", "2", "2"],
                ['qa9.jpg', "What type of vehicle is being pulled by the horses ?", 'carriage', 'sled', 'Sled'],
                ['upload4.jpg', "What is this old man doing?", "fishing", "fishing", "Fishing"]]
    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input, caption_output, caption_output_v1, gpt3_output_v1],
    )

# NOTE(review): `concurrency_count` belongs to the Gradio 3.x queue API. Confirm
# the pinned Gradio version. `launch(enable_queue=True)` was dropped because
# queue() above already enables queuing and that launch flag is deprecated.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()