File size: 5,321 Bytes
0d4900d
 
 
 
 
 
 
 
7580432
0d4900d
 
 
 
 
 
 
 
 
 
 
 
7580432
0d4900d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from gradio_utils import *
def process_images_gradcam(show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity):
    """Render GradCAM overlays for misclassified CIFAR-10 test images.

    Args:
        show_gradcam: when falsy, nothing is computed and ``None`` is returned.
        gradcam_count: number of misclassified samples to display; coerced to
            int (``None`` falls back to 1).
        gradcam_layer: which stage of ``modelfin.model`` to hook, ``"1"``-``"4"``.
            An int such as ``4`` is accepted too, because the value is
            normalised with ``str()``.
        gradcam_opacity: overlay transparency forwarded to
            ``display_gradcam_output``.

    Returns:
        Whatever ``display_gradcam_output`` returns (shown in a ``gr.Image``),
        or ``None`` when ``show_gradcam`` is falsy.

    Raises:
        ValueError: if ``gradcam_layer`` does not name layer 1-4.
    """
    if not show_gradcam:
        return None

    # Validate before any expensive work. The previous if-chain compared only
    # against the strings "1".."4". Any other value, including an int 4, left
    # `images` unbound and raised UnboundLocalError.
    layer = str(gradcam_layer)
    if layer not in ("1", "2", "3", "4"):
        raise ValueError(f"gradcam_layer must be one of 1-4, got {gradcam_layer!r}")

    # These values are -m/s and 1/s of a forward Normalize. They presumably
    # undo the training-time input normalization so images display with
    # natural colours (confirm against the training pipeline).
    inv_normalize = transforms.Normalize(
        mean=[-1.9899, -1.9844, -1.7111],
        std=[4.0486, 4.1152, 3.8314])
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    sample_count = 1 if gradcam_count is None else int(gradcam_count)

    misclassified_data = get_misclassified_data(modelfin, "cpu", test_loader)
    # Hook the last sub-module of the chosen stage: modelfin.model.layer1 ... layer4.
    target_layer = getattr(modelfin.model, f"layer{layer}")[-1]
    return display_gradcam_output(
        misclassified_data, classes, inv_normalize, modelfin,
        target_layers=[target_layer],
        targets=None,
        number_of_samples=sample_count,
        transparency=gradcam_opacity,
    )
    
def process_images_misclass(show_misclassify, misclassify_count):
    """Render a grid of misclassified CIFAR-10 test images.

    Args:
        show_misclassify: when falsy, nothing is computed and ``None`` is
            returned.
        misclassify_count: number of samples to display. ``gr.Number`` may
            deliver a float or ``None``; the value is coerced to int, and
            ``None`` falls back to 1.

    Returns:
        Whatever ``display_cifar_misclassified_data`` returns (shown in a
        ``gr.Image``), or ``None`` when ``show_misclassify`` is falsy.
    """
    if not show_misclassify:
        return None

    # Defined locally with the same values process_images_gradcam uses. This
    # function previously read module-level `classes` / `inv_normalize`, which
    # this file never defines; they exist only if gradio_utils happens to
    # export them.
    inv_normalize = transforms.Normalize(
        mean=[-1.9899, -1.9844, -1.7111],
        std=[4.0486, 4.1152, 3.8314])
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    sample_count = 1 if misclassify_count is None else int(misclassify_count)

    misclassified_data = get_misclassified_data(modelfin, "cpu", test_loader)
    return display_cifar_misclassified_data(
        misclassified_data, classes, inv_normalize,
        number_of_samples=sample_count,
    )
    
def predict_classes(upload_image, top_classes):
    """Classify an uploaded image with ``modelfin`` and list its top classes.

    Args:
        upload_image: PIL image from the upload widget, or ``None`` when
            nothing has been uploaded yet.
        top_classes: how many of the highest-probability classes to list.
            ``gr.Number`` may deliver a float or ``None``. The value is
            coerced to int (``None`` falls back to 5) and clamped to
            ``[1, number of model outputs]``, because ``torch.topk`` requires
            an int k within that range.

    Returns:
        One ``"<class>: <probability>%"`` line per requested class, highest
        probability first, each ending in a newline. If no image was
        supplied, a short prompt string is returned instead.
    """
    # CIFAR-10 classes, assumed to be in the model's output-index order
    # (confirm against the training labels).
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    if upload_image is None:
        return "Please upload an image first."

    transform = transforms.Compose([
        transforms.Resize((32, 32)),             # Resize to 32x32 pixels
        transforms.ToTensor(),                   # Convert image to tensor
        # NOTE(review): these std values differ from the (~0.247, 0.243, 0.261)
        # implied by the inverse normalization in the GradCAM handler. Confirm
        # which one matches training.
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2023, 0.1994, 0.2010]),
    ])

    # Uploads may be RGBA, palette or grayscale, but the 3-channel Normalize
    # above requires RGB input.
    image = transform(upload_image.convert("RGB")).unsqueeze(0)
    device = next(modelfin.parameters()).device
    image = image.to(device)

    # Inference only: evaluation mode, and no gradient bookkeeping.
    modelfin.eval()
    with torch.no_grad():
        output = modelfin(image)

    probabilities = torch.nn.functional.softmax(output, dim=1)

    k = 5 if top_classes is None else int(top_classes)
    k = max(1, min(k, probabilities.size(1)))
    top_prob, top_catid = torch.topk(probabilities, k)

    lines = [
        f"{classes[idx]}: {prob * 100:.2f}%\n"
        for prob, idx in zip(top_prob[0].tolist(), top_catid[0].tolist())
    ]
    return "".join(lines)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # --- GradCAM on misclassified test images ---
            show_gradcam = gr.Checkbox(label="Show GradCAM Images?")
            gradcam_count = gr.Number(label="How many GradCAM images?", value=1, precision=0)
            # The choices are strings, so the default must be the string "4".
            # The int 4 matches none of the choices.
            gradcam_layer = gr.Radio(choices=["1", "2", "3", "4"], label="Choose a layer", value="4")
            gradcam_opacity = gr.Slider(minimum=0, maximum=1, label="Opacity of overlay", value=0.5)
            submit_button = gr.Button("GradCam")
            outputs = gr.Image(label="Output")

            # --- Misclassified test images ---
            show_misclassify = gr.Checkbox(label="Show misclassified images?")
            # precision=0 makes Gradio deliver an int sample count. A default
            # value avoids passing None on the first click.
            misclassify_count = gr.Number(label="How many misclassified images?", value=1, precision=0)
            submit_button_misclass = gr.Button("Misclassified")
            outputs_misclass = gr.Image(label="Output")

            # --- Classify a user-supplied image ---
            upload_image = gr.Image(label="Upload your image", interactive=True, type='pil')
            # torch.topk needs an int k between 1 and the number of classes.
            top_classes = gr.Number(
                label="How many top classes would you like to see?",
                value=5, precision=0, minimum=1, maximum=10,
            )
            upload_btn = gr.Button("Classify your image")
            show_classes = gr.Textbox(label="Your top classes", interactive=False)

    submit_button_misclass.click(
        process_images_misclass,
        inputs=[show_misclassify, misclassify_count],
        outputs=outputs_misclass,
    )
    submit_button.click(
        process_images_gradcam,
        inputs=[show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity],
        outputs=outputs,
    )
    upload_btn.click(
        predict_classes,
        inputs=[upload_image, top_classes],
        outputs=show_classes,
    )

demo.launch()