kushagra124's picture
adding app
894c286
raw
history blame
1.7 kB
import gradio as gr
from transformers import AutoConfig,ViTImageProcessor,ViTForImageClassification,AutoModel
import base64
import os
# One checkpoint name feeds both loaders so the preprocessing always matches
# the weights it was trained with.
_VIT_CHECKPOINT = 'google/vit-base-patch16-224'

# Loaded once at import time and reused by every classification request.
processor = ViTImageProcessor.from_pretrained(_VIT_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(_VIT_CHECKPOINT)

# Sample image file name; not referenced by the code visible in this file.
images = 'room.jpg'
def image_classifier(image, top_k=3):
    """Classify an image and return its most likely labels with confidences.

    Args:
        image: The image to classify; it is passed to ``processor`` as
            ``images=``. ``None`` (for example, the button was clicked before
            an image was provided) produces an empty result instead of an
            error inside the processor.
        top_k: Maximum number of classes to report. Defaults to 3, the number
            this function has always returned.

    Returns:
        A dict mapping each predicted class label (looked up in
        ``model.config.id2label``) to its softmax probability as a plain
        Python float, ordered from most to least likely.
    """
    if image is None:
        return {}

    inputs = processor(images=image, return_tensors="pt")
    outputs = model(**inputs)

    # Softmax converts the raw logits into probabilities in [0, 1]; the batch
    # holds a single image, so row 0 is the one we want.
    probabilities = outputs.logits.detach().cpu().softmax(dim=-1)[0]

    # Never ask for more classes than the model actually has.
    k = min(top_k, probabilities.shape[-1])
    top = probabilities.topk(k)

    # Pair every label with the score of *its own* class index. The previous
    # code indexed the scores with the loop counter, which reported the values
    # of classes 0, 1 and 2 regardless of what was predicted.
    return {
        model.config.id2label[class_idx]: score
        for class_idx, score in zip(top.indices.tolist(), top.values.tolist())
    }
# Page layout: an introductory blurb on top, then the image input with its
# "Classify" button on the left and the predicted labels on the right.
with gr.Blocks(title="Image Classification using Google Vision Transformer") as demo:
    gr.Markdown(
        """
<center>
<h1>
The Vision Transformer (ViT)
</h1>
Transformer encoder model (BERT-like) pretrained on a large collection of images in a supervised fashion, namely ImageNet-21k, at a resolution of 224x224 pixels.
Next, the model was fine-tuned on ImageNet (also referred to as ILSVRC2012), a dataset comprising 1 million images and 1,000 classes, also at resolution 224x224.
</center>
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="numpy", label="Input Image for Classification")
            classify_button = gr.Button(value="Classify")
        with gr.Column():
            prediction_label = gr.Label()

    # Clicking the button sends the current image to the classifier and shows
    # the returned label/score mapping in the label component.
    classify_button.click(
        fn=image_classifier,
        inputs=image_input,
        outputs=prediction_label,
    )

# Start the web server for the app.
demo.launch()