Spaces:
Runtime error
Runtime error
import string | |
import gradio as gr | |
import requests | |
import torch | |
from transformers import BlipForQuestionAnswering, BlipProcessor | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")

# BLIP visual-question-answering checkpoint: `processor` turns an image plus a
# question into model inputs, and `model_vqa` (placed on `device`) generates answers.
_VQA_CHECKPOINT = "Salesforce/blip-vqa-capfilt-large"
processor = BlipProcessor.from_pretrained(_VQA_CHECKPOINT)
model_vqa = BlipForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
model_vqa = model_vqa.to(device)
from transformers import BlipProcessor, BlipForConditionalGeneration | |
# BLIP image-captioning checkpoint. Its caption is given to the LLM as extra
# context about the image. It is placed on the same `device` as the VQA model
# so captioning also uses the GPU when one is available.
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)


def caption(input_image):
    """Generate a caption for ``input_image`` with the BLIP captioning model.

    Args:
        input_image: the image to describe; in this app it comes from the
            ``gr.Image(type="pil")`` component, i.e. a PIL image.

    Returns:
        The decoded caption. Exactly one sequence is generated with a single
        beam, so the newline join yields a single caption.
    """
    # Move the processed tensors to wherever the model lives, so this works
    # whether the model is on the CPU or the GPU.
    inputs = cap_processor(input_image, return_tensors="pt").to(cap_model.device)
    # Generation options go to generate() as keyword arguments rather than into
    # the processor output, which should hold only model tensors.
    out = cap_model.generate(**inputs, num_beams=1, num_return_sequences=1)
    return "\n".join(cap_processor.batch_decode(out, skip_special_tokens=True))
import openai | |
import os | |
# The OpenAI key comes from the `openai_appkey` environment variable;
# openai.api_key is set to None when that variable is unset.
openai.api_key = os.environ.get('openai_appkey')
def gpt3(question, vqa_answer, caption, *, engine="text-davinci-003",
         max_tokens=10, temperature=0.7):
    """Ask an OpenAI completion model for the final answer to ``question``.

    The prompt is the image caption, the question and the VQA candidate
    answer(s), one per line, followed by the instruction
    " Tell me the right answer.".

    Args:
        question: the user's question about the image.
        vqa_answer: newline-separated candidate answers from the VQA model, or
            '' to let the LLM answer from the caption alone.
        caption: the image caption. Note this parameter shadows the
            module-level ``caption`` function inside this body; the name is
            kept so existing callers keep working.
        engine: OpenAI completion model name. The default is the model this
            function always used. NOTE(review): OpenAI has retired
            text-davinci-003; pass a currently available completion model.
        max_tokens: maximum number of tokens in the generated answer.
        temperature: sampling temperature for the completion.

    Returns:
        The text of the first (and only, ``n=1``) completion, stripped of
        surrounding whitespace.
    """
    prompt = "\n".join([caption, question, vqa_answer, " Tell me the right answer."])
    response = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        max_tokens=max_tokens,
        n=1,
        stop=None,
        temperature=temperature,
    )
    return response.choices[0].text.strip()
def inference_chat(input_image, input_text):
    """Answer a question about an image with BLIP VQA, then refine it with the LLM.

    Args:
        input_image: the uploaded image (from the ``gr.Image(type="pil")``
            component), or None when nothing has been uploaded.
        input_text: the user's question.

    Returns:
        A 3-tuple ``(vqa_answer, llm_with_vqa, llm_from_caption)``:
        - ``vqa_answer``: the top-ranked VQA answer.
        - ``llm_with_vqa``: the LLM answer given the caption plus the VQA
          candidates.
        - ``llm_from_caption``: the LLM answer given only the caption.

    Raises:
        gr.Error: if no image was uploaded or the question is blank, so the UI
            shows a readable message instead of a processor traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if not input_text or not input_text.strip():
        raise gr.Error("Please enter a question.")

    cap = caption(input_image)

    # model_vqa was moved to `device` when loaded. The processor output must
    # follow it, otherwise generate() fails with a device mismatch on CUDA hosts.
    inputs = processor(images=input_image, text=input_text, return_tensors="pt").to(
        model_vqa.device
    )
    # Beam search with 5 beams, keeping the 4 best candidate answers
    # (max_length=10 bounds each answer's length).
    out = model_vqa.generate(
        **inputs, max_length=10, num_beams=5, num_return_sequences=4
    )
    answers = processor.batch_decode(out, skip_special_tokens=True)
    vqa = "\n".join(answers)

    gpt3_out = gpt3(input_text, vqa, cap)
    gpt3_out1 = gpt3(input_text, '', cap)
    return answers[0], gpt3_out, gpt3_out1
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # Layout: image, question box and buttons on the left; the three answers
    # (plain VQA, VQA+LLM, caption+LLM) on the right.
    state = gr.State([])
    # caption_output = None
    # gr.Markdown(title)
    # gr.Markdown(description)
    # gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Input(问题输入)")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="VQA", interactive=True, variant="primary"
                        )
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="VQA ")
            gpt3_output_v1 = gr.Textbox(lines=0, label="VQA+LLM")
            caption_output_v1 = gr.Textbox(lines=0, label="CAP+LLM")

    # A new image invalidates all previous results: reset the state and clear
    # each answer box. Each output component is listed once.
    image_input.change(
        lambda: ([], "", "", ""),
        [],
        [state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    # Pressing Enter runs the same pipeline as the VQA button. inference_chat
    # returns three values, so all three answer boxes must be listed as outputs.
    chat_input.submit(
        inference_chat,
        [image_input, chat_input],
        [caption_output, gpt3_output_v1, caption_output_v1],
    )
    # Clear the question, the state and every answer box.
    clear_button.click(
        lambda: ("", [], "", "", ""),
        [],
        [chat_input, state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    submit_button.click(
        inference_chat,
        [image_input, chat_input],
        [caption_output, gpt3_output_v1, caption_output_v1],
    )
    # examples = gr.Examples(
    #     examples=examples,
    #     inputs=[image_input, chat_input],
    # )
# Route requests through Gradio's queue: concurrency_count=1 processes one
# request at a time, and max_size=10 bounds how many may wait. Per the Gradio
# 3.x docs, api_open=False stops direct REST requests from bypassing the queue.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# .queue() above already enables queuing. launch(enable_queue=...) is
# deprecated in Gradio 3.x and removed in 4.x, so it is not passed.
iface.launch()