|
import tensorflow as tf |
|
import gradio as gr |
|
import gcvit |
|
from gcvit.utils import get_gradcam_model, get_gradcam_prediction |
|
|
|
def predict_fn(image, model_name): |
|
"""A predict function that will be invoked by gradio.""" |
|
model = getattr(gcvit, model_name)(pretrain=True) |
|
gradcam_model = get_gradcam_model(model) |
|
preds, overlay = get_gradcam_prediction(image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None) |
|
preds = {x[1]:float(x[2]) for x in preds} |
|
return [preds, overlay] |
|
|
|
demo = gr.Interface( |
|
fn=predict_fn, |
|
inputs=[ |
|
gr.inputs.Image(label="Input Image"), |
|
gr.Radio(['GCViTXXTiny', 'GCViTXTiny', 'GCViTTiny', |
|
'GCViTSmall', 'GCViTBase','GCViTLarge'], value='GCViTXXTiny', label='Model Name') |
|
], |
|
outputs=[ |
|
gr.outputs.Label(label="Prediction"), |
|
gr.inputs.Image(label="GradCAM"), |
|
], |
|
title="Global Context Vision Transformer (GCViT) Demo", |
|
description="Image Classification with GCViT Model using ImageNet Pretrain Weights.", |
|
examples=[ |
|
["example/hot_air_ballon.jpg", 'GCViTXXTiny'], |
|
["example/chelsea.png", 'GCViTXXTiny'], |
|
["example/penguin.JPG", 'GCViTXXTiny'], |
|
["example/bus.jpg", 'GCViTXXTiny'], |
|
], |
|
) |
|
demo.launch() |