# Hugging Face Spaces page header captured with this file (not program code): "Spaces: Sleeping"
import logging
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import run_dff_on_image, GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from transformers import pipeline, ViTForImageClassification, ViTImageProcessor

from face_grab import FaceGrabber
from gradcam import GradCam
logging.basicConfig(level=logging.INFO)

# Hub checkpoint shared by the classifier and its matching image processor.
_CHECKPOINT = "ongkn/emikes-classifier"
model = ViTForImageClassification.from_pretrained(_CHECKPOINT)
processor = ViTImageProcessor.from_pretrained(_CHECKPOINT)

# Long-lived helpers reused by every request.
faceGrabber = FaceGrabber()
gradCam = GradCam()

# One Grad-CAM target per class label, in the order "emi", "kes".
# NOTE(review): not referenced by classify_image, which builds a target for
# the predicted class only; kept because it is a module-level name.
targetsForGradCam = [
    ClassifierOutputTarget(gradCam.category_name_to_index(model, label))
    for label in ("emi", "kes")
]

# Layers hooked by the explainers: DFF reads features at the ViT layernorm,
# Grad-CAM reads the output of the second-to-last encoder layer.
targetLayerDff = model.vit.layernorm
targetLayerGradCam = model.vit.encoder.layer[-2].output
def _placeholder_outputs(message, image=None):
    """Build a complete output tuple for an early-exit path.

    The Gradio interface declares five outputs (text, number, image, image,
    image), so every return path must supply five values; here the score is
    0 and both explanation images are left empty.
    """
    return message, 0, image, None, None


def classify_image(input):
    """Classify the face in an uploaded image and explain the prediction.

    Args:
        input: The image delivered by the Gradio ``"image"`` input component;
            it is converted with ``np.array`` before face detection. May be
            None when the user submits without uploading an image.

    Returns:
        A 5-tuple ``(label, score, face_crop, dff_image, gradcam_image)``
        matching the interface's outputs. If no image is given or no face is
        detected, the label slot holds a message, the score is 0, and both
        explanation images are None.
    """
    if input is None:
        return _placeholder_outputs("No image provided")

    face = faceGrabber.grab_faces(np.array(input))
    if face is None:
        # Echo the upload back so the user can see what was rejected.
        return _placeholder_outputs("No face detected", input)

    face = Image.fromarray(face)
    # Both explainers work on a fixed 224x224 copy of the crop.
    faceResized = face.resize((224, 224))
    tensorResized = transforms.ToTensor()(faceResized)

    dffImage = run_dff_on_image(
        model=model,
        target_layer=targetLayerDff,
        classifier=model.classifier,
        img_pil=faceResized,
        img_tensor=tensorResized,
        reshape_transform=gradCam.reshape_transform_vit_huggingface,
        n_components=5,
        top_k=10,
    )

    top = gradCam.get_top_category(model, tensorResized)[0]
    label, score = top["label"], top["score"]

    # Grad-CAM explains only the predicted class.
    clsTarget = ClassifierOutputTarget(gradCam.category_name_to_index(model, label))
    gradCamImage = gradCam.run_grad_cam_on_image(
        model=model,
        target_layer=targetLayerGradCam,
        targets_for_gradcam=[clsTarget],
        input_tensor=tensorResized,
        input_image=faceResized,
        reshape_transform=gradCam.reshape_transform_vit_huggingface,
    )
    return label, score, face, dffImage, gradCamImage
# Gradio UI: one image in; label, score, face crop, DFF and Grad-CAM images out.
iface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs=["text", "number", "image", "image", "image"],
    title="Attraction Classifier - subjective",
    # Plain (non-f) string: in an f-string, {'emi', 'kes'} is evaluated as a
    # tuple expression and would render as ('emi', 'kes').
    description=(
        "Takes in a (224, 224) image and outputs a class: {'emi', 'kes'}, "
        "along with a GradCam/DFF explanation. "
        "Face detection, cropping, and resizing are done internally. "
        "Uploaded images are not stored by us, but may be stored by HF. "
        "Refer to their [privacy policy](https://huggingface.co/privacy) for details."
    ),
)
iface.launch()