|
from functools import lru_cache

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
|
|
# Hugging Face checkpoint used for both the classifier and its preprocessing.
_VIT_CHECKPOINT = 'google/vit-base-patch16-224'


@lru_cache(maxsize=1)
def _load_model_and_extractor():
    """Load the ViT classifier and its feature extractor, caching the pair.

    ``from_pretrained`` reads (and on first use downloads) the checkpoint,
    so doing it once per process instead of once per request removes the
    dominant cost of every call after the first.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    # NOTE(review): newer transformers releases deprecate ViTFeatureExtractor
    # in favour of ViTImageProcessor - confirm against the pinned version.
    model = ViTForImageClassification.from_pretrained(_VIT_CHECKPOINT)
    feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
    return model, feature_extractor


def classify_and_label_image(image_array):
    """Classify an image with ViT and return a rendering annotated with the result.

    Args:
        image_array: Array-like pixel data; it is cast to ``uint8`` and
            wrapped in a PIL image.

    Returns:
        A NumPy array holding the RGBA pixels of a matplotlib figure that
        shows the input image, a fixed red rectangle, and the predicted
        label with its softmax probability as a percentage.
    """
    image = Image.fromarray(np.uint8(image_array))

    model, feature_extractor = _load_model_and_extractor()

    # The extractor normalises three channels, so grayscale/RGBA input is
    # converted for the model only; the figure still shows the image as given.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    predicted_class_label = model.config.id2label[predicted_class_idx]
    predicted_class_prob = logits.softmax(dim=-1)[0, predicted_class_idx].item()

    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        # The classifier yields no localisation; this box is a fixed decoration.
        rect = patches.Rectangle((50, 50), 100, 100, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        # Draw on this figure's axes rather than pyplot's global "current axes",
        # which may belong to another request when handlers run concurrently.
        ax.text(50, 40, f'{predicted_class_label} {predicted_class_prob * 100:.2f}%', color='r')

        fig.canvas.draw()
        # np.array copies the canvas buffer, so the result outlives the figure.
        img_arr = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    return img_arr
|
|
|
# Sample images offered in the UI; paths are relative to the working directory.
examples = [["bear.jpg"], ["puppy.jpg"], ["boat.jpg"]]

# Build the image-in / image-out interface, then start the local server
# without a public share link.
demo = gr.Interface(
    fn=classify_and_label_image,
    inputs="image",
    outputs="image",
    examples=examples,
    title="Diego's LLM Image to Labeled Image",
    description="Classify an image and draw the label on the image.",
)
demo.launch(share=False)