# Hugging Face file-viewer residue captured with the source (not part of app.py):
# IbrahimHasani's picture
# Update app.py
# 28238e4 verified
# raw
# history blame
# No virus
# 3.69 kB
import gradio as gr
from PIL import ImageFilter, Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
import torch
import requests
# Checkpoint id of the CLIP-ViT model used for zero-shot image classification.
checkpoint = "openai/clip-vit-large-patch14-336"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the processor and the model once at import time; the model is moved
# to `device` immediately so later calls only have to move their inputs.
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint).to(device)
# Catch-all label appended after the user's labels on every request.
_FALLBACK_LABEL = "other"
# Minimum score for the primary label to be accepted on its own.
_PRIMARY_SCORE_THRESHOLD = 0.58
# Minimum combined top-two score when the primary label ranks first.
_COMBINED_SCORE_THRESHOLD = 0.90
# Minimum combined top-two score when the primary label ranks second.
_REVERSED_SCORE_THRESHOLD = 0.95


def _parse_candidate_labels(raw_labels):
    """Split comma-separated text into stripped, non-empty labels."""
    pieces = (raw_labels or "").split(",")
    return [piece.strip() for piece in pieces if piece.strip()]


def _apply_decision_rules(results, candidate_labels, messages):
    """Return True when a threshold rule accepts the primary (first) label.

    ``results`` is the ranked list of ``{"score", "label"}`` dicts and
    ``candidate_labels`` is the user labels followed by the fallback label.
    The name of the rule that fired, if any, is appended to ``messages``.
    """
    top, second = results[0], results[1]
    primary_label = candidate_labels[0]
    user_labels = candidate_labels[:-1]  # everything except the fallback
    combined_score = top["score"] + second["score"]

    if (
        top["label"] == primary_label
        and top["score"] >= _PRIMARY_SCORE_THRESHOLD
        and second["label"] != _FALLBACK_LABEL
    ):
        messages.append("Triggered the new 0.58 check!")
        return True

    if (
        top["label"] == primary_label
        and second["label"] in user_labels
        and combined_score >= _COMBINED_SCORE_THRESHOLD
    ):
        messages.append("Triggered the 90% combined check!")
        return True

    # The reverse-order rule compares the first two *user* labels. With a
    # single user label, candidate_labels[1] is the fallback label and the
    # two scores always sum to 1.0, so without this length guard the rule
    # would report True precisely when the fallback beat the user's label.
    if (
        len(user_labels) >= 2
        and top["label"] == candidate_labels[1]
        and second["label"] == primary_label
        and combined_score >= _REVERSED_SCORE_THRESHOLD
    ):
        messages.append("Triggered the 95% reverse order check!")
        return True

    return False


def classify_image(image, candidate_labels):
    """Score an image against the given labels and apply the decision rules.

    The image is Gaussian-blurred, scored by the zero-shot model against the
    user's labels plus a trailing "other" label, and the softmax
    probabilities are ranked from highest to lowest. The first user label is
    treated as the primary label; the boolean result reports whether one of
    the threshold rules accepted it.

    Args:
        image: PIL image to classify.
        candidate_labels: Comma-separated label text. Surrounding whitespace
            is stripped and blank entries are ignored.

    Returns:
        A 4-tuple ``(result, top_label, results_for_df, messages)``:
        ``result`` is the boolean decision, ``top_label`` the highest-scoring
        label, ``results_for_df`` a list of ``[label, score]`` rows sorted by
        descending score, and ``messages`` a list of explanatory strings.

    Raises:
        ValueError: If no image is given or no non-empty label is supplied.
    """
    # Validate up front so the caller gets a clear message instead of an
    # AttributeError/IndexError from deep inside the pipeline.
    if image is None:
        raise ValueError("No image provided: upload an image before classifying.")
    user_labels = _parse_candidate_labels(candidate_labels)
    if not user_labels:
        raise ValueError(
            "No candidate labels provided: enter at least one comma-separated label."
        )
    candidate_labels = user_labels + [_FALLBACK_LABEL]
    messages = []

    # Blur the image
    blurred = image.filter(ImageFilter.GaussianBlur(radius=5))

    # Process the image and candidate labels, then move tensors to the model's device
    inputs = processor(images=blurred, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Get model's output (inference only, so no gradients are tracked)
    with torch.no_grad():
        outputs = model(**inputs)

    # One image was submitted, so row 0 holds its logit for every label.
    probs = outputs.logits_per_image[0].softmax(dim=-1).cpu().tolist()

    # Organize results, highest score first (ties keep the input label order)
    ranked = sorted(zip(probs, candidate_labels), key=lambda pair: pair[0], reverse=True)
    results = [{"score": float(score), "label": label} for score, label in ranked]

    # Convert results to list of lists for Dataframe
    results_for_df = [[res["label"], res["score"]] for res in results]

    top_label = results[0]["label"]
    second_label = results[1]["label"]

    # Add messages to understand the scores
    messages.append(f"Top label: {top_label} with score: {results[0]['score']:.2f}")
    messages.append(f"Second label: {second_label} with score: {results[1]['score']:.2f}")

    result = _apply_decision_rules(results, candidate_labels, messages)
    return result, top_label, results_for_df, messages
# Gradio front end. The two inputs are passed to classify_image in order, and
# its four return values are mapped onto the four output components in order.
iface = gr.Interface(
    title="General Action Classifier",
    description="""
**Instructions:**
1. **Upload an Image**: Drag and drop an image or click to upload an image file.
2. **Enter Candidate Labels**:
- Provide candidate labels separated by commas.
- For example: `human with beverage,human,beverage`
- The label "other" will automatically be added to the list of candidate labels.
3. **View Results**:
- The result will indicate whether the specified action (top label) is present in the image.
- Detailed scores for each label will be displayed in a table.
- Additional messages explaining the decision process will also be shown.
""",
    fn=classify_image,
    inputs=[
        gr.Image(label="Upload an Image", type="pil"),
        gr.Textbox(label="Candidate Labels (comma separated)"),
    ],
    outputs=[
        gr.Label(label="Result"),
        gr.Textbox(label="Top Label"),
        gr.Dataframe(label="Details", headers=["Label", "Score"]),
        gr.Textbox(label="Messages"),
    ],
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()