import gradio as gr
from PIL import ImageFilter
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the CLIP-ViT model and its processor once at startup
checkpoint = "openai/clip-vit-large-patch14-336"
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
model = model.to(device)
processor = AutoProcessor.from_pretrained(checkpoint)

def classify_image(image, candidate_labels):
    messages = []
    # "other" is appended as a catch-all so the model can reject all user labels
    candidate_labels = [label.strip() for label in candidate_labels.split(",")] + ["other"]

    # Blur the image to suppress fine detail before classification
    image = image.filter(ImageFilter.GaussianBlur(radius=5))

    # Tokenize the labels and preprocess the image in one call
    inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Get the model's output without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits_per_image[0]
    probs = logits.softmax(dim=-1).cpu().numpy()

    # Pair each label with its probability, sorted from most to least likely;
    # cast to Python float so the scores format and serialize cleanly
    results = [
        {"score": float(score), "label": candidate_label}
        for score, candidate_label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
    ]

    # Decision-making logic
    top_label = results[0]["label"]
    second_label = results[1]["label"]

    # Record the scores so the decision can be audited in the UI
    messages.append(f"Top label: {top_label} with score: {results[0]['score']:.2f}")
    messages.append(f"Second label: {second_label} with score: {results[1]['score']:.2f}")

    # Example decision logic for specific scenarios (can be customized further)
    if top_label == candidate_labels[0] and results[0]["score"] >= 0.58 and second_label != "other":
        messages.append("Triggered the 0.58 single-label check!")
        result = True
    elif top_label == candidate_labels[0] and second_label in candidate_labels[:-1] and (results[0]["score"] + results[1]["score"]) >= 0.90:
        messages.append("Triggered the 90% combined check!")
        result = True
    elif top_label == candidate_labels[1] and second_label == candidate_labels[0] and (results[0]["score"] + results[1]["score"]) >= 0.95:
        messages.append("Triggered the 95% reverse-order check!")
        result = True
    else:
        result = False

    # Shape the outputs for the Gradio components below: Label takes a string,
    # Dataframe a list of rows, and Textbox a single newline-joined string
    table = [[r["label"], round(r["score"], 4)] for r in results]
    return str(result), top_label, table, "\n".join(messages)
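
# Worked example of the thresholds above (hypothetical scores, for
# illustration): with candidate_labels = ["running", "walking", "other"]
# and sorted probabilities running=0.56, walking=0.38, other=0.06, the
# 0.58 single-label check fails (0.56 < 0.58), but the combined check
# fires (0.56 + 0.38 = 0.94 >= 0.90), so result is True.
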
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Candidate Labels (comma separated)"),
    ],
    outputs=[
        gr.Label(label="Result"),
        gr.Textbox(label="Top Label"),
        gr.Dataframe(headers=["Label", "Score"], label="Details"),
        gr.Textbox(label="Messages"),
    ],
    title="General Action Classifier",
    description="Upload an image and specify candidate labels to check whether an action is present in the image.",
)

if __name__ == "__main__":
    iface.launch()
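
# A quick local sanity check, assuming an image file "example.jpg" exists on
# disk (the path and labels below are illustrative, not part of the Space):
#
#     from PIL import Image
#     result, top_label, table, messages = classify_image(
#         Image.open("example.jpg").convert("RGB"), "running, walking"
#     )
#     print(result, top_label)
#     print(messages)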