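# app.py for a Hugging Face Space: an interactive Gradio demo comparing three
# visual question answering (VQA) models (GIT-base, BLIP-base and ViLT).
# Given an image and a question, the app shows each model's answer side by side.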
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch
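# download a few example images used by the demo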
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
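# load processors and models; the large GIT/BLIP checkpoints are left commented out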
git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
# git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-vqav2")
# git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-vqav2")
blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
# blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
device = "cuda" if torch.cuda.is_available() else "cpu"
git_model_base.to(device)
blip_model_base.to(device)
#git_model_large.to(device)
#blip_model_large.to(device)
vilt_model.to(device)
def generate_answer_git(processor, model, image, question):
    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: GIT expects a CLS token followed by the question tokens
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_answer[0]
def generate_answer_blip(processor, model, image, question):
    # prepare image + question in a single processor call
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    generated_ids = model.generate(**inputs, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_answer[0]
def generate_answer_vilt(processor, model, image, question):
    # prepare image + question
    encoding = processor(images=image, text=question, return_tensors="pt").to(device)

    # ViLT treats VQA as classification: the logits score a fixed answer vocabulary
    with torch.no_grad():
        outputs = model(**encoding)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def generate_answers(image, question):
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)
    # answer_git_large = generate_answer_git(git_processor_large, git_model_large, image, question)
    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)
    # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)
    answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)
    return answer_git_base, answer_blip_base, answer_vilt
examples = [["cats.jpg", "How many cats are there?"], ["stop_sign.png", "What's behind the stop sign?"], ["astronaut.jpg", "What's the astronaut riding on?"]]
outputs = [gr.Textbox(label="Answer generated by GIT-base"), gr.Textbox(label="Answer generated by BLIP-base"), gr.Textbox(label="Answer generated by ViLT")]
title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio demo for comparing GIT, BLIP and ViLT, three state-of-the-art vision+language models for visual question answering. To use it, simply upload an image, type a question and click 'Submit', or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"
interface = gr.Interface(fn=generate_answers,
                         inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue().launch(debug=True)