"""Gradio demo: zero-shot image classification with FLAVA.

The user uploads an image and types a comma-separated list of candidate
labels. Each label is wrapped in a prompt template and embedded with FLAVA's
text encoder. Its dot product with the image embedding is the label's score,
and the scores are softmax-normalized across the labels.
"""
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import BertTokenizer, FlavaFeatureExtractor, FlavaModel

MODEL_NAME = "facebook/flava-full"
PROMPT_TEMPLATE = "This is a photo of a {}"

model = FlavaModel.from_pretrained(MODEL_NAME)
model.eval()  # inference only (e.g. disables dropout)
fe = FlavaFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)


def _parse_labels(labels_text):
    """Split comma-separated text into stripped, non-empty, unique labels.

    First-occurrence order is preserved. De-duplication keeps the returned
    scores summing to 1; otherwise repeated keys would collapse in the dict.
    """
    stripped = (label.strip() for label in labels_text.split(","))
    return list(dict.fromkeys(label for label in stripped if label))


def shot(image, labels_text):
    """Score ``image`` against each comma-separated label in ``labels_text``.

    Args:
        image: Array-like image as delivered by the Gradio "image" input;
            it is converted with ``np.uint8`` and ``PIL.Image.fromarray``.
        labels_text: Comma-separated candidate labels, e.g. "dog, cat, bird".

    Returns:
        Dict mapping each label to its softmax probability (values sum to 1).

    Raises:
        ValueError: If no image is given or no non-empty label remains.
    """
    if image is None:
        raise ValueError("Please upload an image.")
    labels = _parse_labels(labels_text or "")
    if not labels:
        raise ValueError("Please enter at least one label.")

    pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
    prompts = [PROMPT_TEMPLATE.format(label) for label in labels]

    image_input = fe([pil_image], return_tensors="pt")
    # Pad only to the longest prompt; padded positions are masked out via
    # the attention mask, so the first-token output is unaffected.
    text_inputs = tokenizer(
        prompts, padding=True, truncation=True, return_tensors="pt"
    )

    # No gradients needed at inference; avoids building an autograd graph.
    with torch.no_grad():
        # [:, 0, :] keeps the first token's feature for each sequence.
        image_embeddings = model.get_image_features(**image_input)[:, 0, :]  # (1, D)
        text_embeddings = model.get_text_features(**text_inputs)[:, 0, :]  # (L, D)

    # (L, D) @ (D, 1) -> (L, 1); drop the trailing singleton axis, not axis 0.
    logits = (text_embeddings @ image_embeddings.T).squeeze(-1)
    probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: prob.item() for label, prob in zip(labels, probs)}


iface = gr.Interface(
    shot,
    ["image", "text"],
    "label",
    examples=[
        ["dog.jpg", "dog,cat,bird"],
        ["germany.jpg", "germany,belgium,colombia"],
        ["rocket.jpg", "car,rocket,train"],
    ],
    description="Add a picture and a list of labels separated by commas",
    title="FLAVA Zero-shot Image Classification",
)

if __name__ == "__main__":
    iface.launch()