File size: 3,691 Bytes
057d2a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28238e4
057d2a5
 
 
ffa5b50
 
 
057d2a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffa5b50
057d2a5
 
 
 
 
 
 
 
 
 
ffa5b50
057d2a5
 
 
28238e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
057d2a5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from PIL import ImageFilter, Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
import torch
import requests

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the CLIP-ViT checkpoint: the processor that prepares image/text
# inputs, and the zero-shot classification model placed on `device`.
checkpoint = "openai/clip-vit-large-patch14-336"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint).to(device)

def classify_image(image, candidate_labels):
    """Zero-shot classify a blurred copy of ``image`` and apply decision rules.

    Args:
        image: PIL image to classify, or ``None`` when nothing was uploaded.
        candidate_labels: Comma-separated label text. Blank entries are
            ignored and the catch-all label "other" is always appended.

    Returns:
        A 4-tuple ``(result, top_label, results_for_df, messages)``:
        ``result`` is the boolean outcome of the decision rules,
        ``top_label`` is the highest-scoring label, ``results_for_df`` is a
        list of ``[label, score]`` rows sorted by descending score, and
        ``messages`` is a list of strings explaining the decision.
        When the image or the labels are missing, nothing is classified:
        ``result`` is ``False``, ``top_label`` is ``""``, the table is empty
        and ``messages`` states what is missing.
    """
    messages = []

    # Drop blank entries (empty textbox, trailing or doubled commas) so they
    # are not scored as real labels.
    user_labels = [label.strip() for label in (candidate_labels or "").split(",")]
    user_labels = [label for label in user_labels if label]

    # Report missing inputs through the messages output instead of crashing.
    if image is None:
        messages.append("No image provided. Please upload an image.")
    if not user_labels:
        messages.append("No candidate labels provided. Please enter at least one label.")
    if messages:
        return False, "", [], messages

    candidate_labels = user_labels + ["other"]

    # Blur the image
    image = image.filter(ImageFilter.GaussianBlur(radius=5))

    # Process the image and candidate labels
    inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Get model's output
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits_per_image[0]
    probs = logits.softmax(dim=-1).cpu().numpy()

    # Organize results, highest score first
    results = [
        {"score": float(score), "label": candidate_label}
        for score, candidate_label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
    ]

    # Convert results to list of lists for Dataframe
    results_for_df = [[res['label'], res['score']] for res in results]

    # "other" is always appended, so there are at least two results.
    top_label = results[0]["label"]
    second_label = results[1]["label"]
    top_score = results[0]["score"]
    combined_score = top_score + results[1]["score"]

    # Add messages to understand the scores
    messages.append(f"Top label: {top_label} with score: {top_score:.2f}")
    messages.append(f"Second label: {second_label} with score: {results[1]['score']:.2f}")

    primary_label = candidate_labels[0]
    # The reverse-order rule compares the first two labels the user typed.
    # With a single user label, candidate_labels[1] is the appended "other"
    # and the two scores always sum to 1.0, so the rule would report True
    # whenever "other" wins. Only apply it when a second user label exists.
    has_second_user_label = len(user_labels) >= 2

    # Example decision logic for specific scenarios (can be customized further)
    if top_label == primary_label and top_score >= 0.58 and second_label != "other":
        messages.append("Triggered the new 0.58 check!")
        result = True
    elif top_label == primary_label and second_label in user_labels and combined_score >= 0.90:
        messages.append("Triggered the 90% combined check!")
        result = True
    elif (
        has_second_user_label
        and top_label == candidate_labels[1]
        and second_label == primary_label
        and combined_score >= 0.95
    ):
        # The message names the 0.95 threshold actually used by this rule.
        messages.append("Triggered the 95% reverse order check!")
        result = True
    else:
        result = False

    return result, top_label, results_for_df, messages

# Input widgets: the picture to classify and the comma-separated labels.
_input_components = [
    gr.Image(type="pil", label="Upload an Image"),
    gr.Textbox(label="Candidate Labels (comma separated)"),
]

# Output widgets, listed in the order classify_image returns its values.
_output_components = [
    gr.Label(label="Result"),
    gr.Textbox(label="Top Label"),
    gr.Dataframe(headers=["Label", "Score"], label="Details"),
    gr.Textbox(label="Messages"),
]

# Markdown usage instructions passed to the interface as its description.
_description = """
    **Instructions:**
    
    1. **Upload an Image**: Drag and drop an image or click to upload an image file.
    
    2. **Enter Candidate Labels**: 
       - Provide candidate labels separated by commas. 
       - For example: `human with beverage,human,beverage`
       - The label "other" will automatically be added to the list of candidate labels.
    
    3. **View Results**: 
       - The result will indicate whether the specified action (top label) is present in the image.
       - Detailed scores for each label will be displayed in a table.
       - Additional messages explaining the decision process will also be shown.
    """

# Wire classify_image to the widgets above.
iface = gr.Interface(
    title="General Action Classifier",
    description=_description,
    fn=classify_image,
    inputs=_input_components,
    outputs=_output_components,
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()