import gradio as gr
from PIL import ImageFilter, Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
import torch
import requests

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the CLIP-ViT model and its processor
checkpoint = "openai/clip-vit-large-patch14-336"
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
model = model.to(device)
processor = AutoProcessor.from_pretrained(checkpoint)
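
# For reference: with a CLIP checkpoint, model(**inputs) exposes logits_per_image
# of shape (num_images, num_texts); classify_image below softmaxes its first
# (and only) row to turn the image-text similarity logits into per-label
# probabilities.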

def classify_image(image, candidate_labels):
    messages = []
    candidate_labels = [label.strip() for label in candidate_labels.split(",")] + ["other"]

    # Blur the image
    image = image.filter(ImageFilter.GaussianBlur(radius=5))

    # Process the image and candidate labels
    inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Get the model's output
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits_per_image[0]
    probs = logits.softmax(dim=-1).cpu().numpy()

    # Organize results in descending order of score
    results = [
        {"score": score, "label": candidate_label}
        for score, candidate_label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
    ]

    # Convert results to a list of lists for the Dataframe output
    results_for_df = [[res["label"], float(res["score"])] for res in results]

    # Decision-making logic
    top_label = results[0]["label"]
    second_label = results[1]["label"]

    # Add messages so the score breakdown is visible in the UI
    messages.append(f"Top label: {top_label} with score: {results[0]['score']:.2f}")
    messages.append(f"Second label: {second_label} with score: {results[1]['score']:.2f}")

    # Example decision logic for specific scenarios (can be customized further)
    if top_label == candidate_labels[0] and results[0]["score"] >= 0.58 and second_label != "other":
        messages.append("Triggered the 0.58 top-score check!")
        result = True
    elif top_label == candidate_labels[0] and second_label in candidate_labels[:-1] and (results[0]["score"] + results[1]["score"]) >= 0.90:
        messages.append("Triggered the 90% combined check!")
        result = True
    elif top_label == candidate_labels[1] and second_label == candidate_labels[0] and (results[0]["score"] + results[1]["score"]) >= 0.95:
        messages.append("Triggered the 95% reverse-order check!")
        result = True
    else:
        result = False

    # Join messages so the Gradio Textbox shows readable, line-separated text
    return result, top_label, results_for_df, "\n".join(messages)
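
# A minimal sketch of exercising classify_image outside the Gradio UI, e.g. for
# local testing. The image path and labels below are illustrative assumptions,
# not part of the original app; uncomment to try it.
#
#   from PIL import Image
#   img = Image.open("example.jpg").convert("RGB")
#   result, top_label, rows, messages = classify_image(img, "person running, person walking")
#   print(result, top_label)
#   print(messages)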

iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Candidate Labels (comma separated)")
    ],
    outputs=[
        gr.Label(label="Result"),
        gr.Textbox(label="Top Label"),
        gr.Dataframe(headers=["Label", "Score"], label="Details"),
        gr.Textbox(label="Messages")
    ],
    title="General Action Classifier",
    description="Upload an image and specify candidate labels to check if an action is present in the image."
)

if __name__ == "__main__":
    iface.launch()