Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import gradio as gr | |
from transformers import ViTFeatureExtractor | |
from huggingface_hub import hf_hub_download | |
import spaces | |
from torchvision import transforms | |
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Module-level slots shared with run_inference: `model` is assigned there on
# demand; `feature_extractor` is declared here but never assigned in the
# visible code.
model = None
feature_extractor = None

# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Serialized validation dataset; run_inference reads its `class_to_idx`
# mapping to turn a predicted index back into a class name.
VALID_DS_PATH = "valid_ds.pth"
valid_ds = torch.load(VALID_DS_PATH)
from transformers import ViTModel | |
from transformers.modeling_outputs import SequenceClassifierOutput | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class ViTForImageClassification(nn.Module):
    """ViT encoder followed by dropout and a linear classification head.

    The encoder is the pretrained ``google/vit-base-patch16-224-in21k``
    checkpoint; the head maps its hidden size to ``num_labels`` logits.
    """

    def __init__(self, num_labels=3):
        """Build the encoder, the dropout layer and the classifier head.

        Args:
            num_labels: Number of output classes produced by the linear head.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Compute class logits and, when labels are given, the loss.

        Args:
            pixel_values: Batch of images forwarded to the ViT encoder.
            labels: Optional target class indices. Defaults to ``None`` so
                inference callers do not have to pass a placeholder.

        Returns:
            A ``(logits, loss)`` tuple. ``loss`` is ``None`` without labels;
            otherwise it is the cross-entropy loss as a plain Python float.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Keep only position 0 of the sequence axis (presumably the [CLS]
        # token of the ViT encoder) as the image representation.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)

        if labels is None:
            return logits, None

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        # NOTE: .item() yields a float, so the returned loss cannot be used
        # for backpropagation; kept because callers receive a float today.
        return logits, loss.item()
# Load an image from file for inference
def load_image(image_path):
    """Open an image file and return it as an RGB ``PIL.Image``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A new image in RGB mode, independent of the opened file.
    """
    # Image.open keeps the file open lazily; the context manager closes it
    # once convert() has read the pixels into a separate RGB image.
    with Image.open(image_path) as img:
        return img.convert("RGB")  # Ensure it's in RGB format
# Inference function
def _load_model(device):
    """Download the checkpoint from the Hub and load it in eval mode."""
    checkpoint_path = hf_hub_download(repo_id="limitedonly41/offers_26",
                                      filename="model_50.pt",
                                      use_auth_token=HF_TOKEN)
    # SECURITY: torch.load unpickles arbitrary Python objects; only load
    # checkpoints from a repository you trust.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects fully pickled modules - confirm the pinned torch version.
    # map_location places the weights on the target device directly, so a
    # CUDA-saved checkpoint also loads on a CPU-only host.
    loaded = torch.load(checkpoint_path, map_location=device)
    loaded.eval()
    return loaded


def _class_name_for_index(class_to_idx, index):
    """Return the first class name whose index equals ``index``."""
    for name, idx in class_to_idx.items():
        if idx == index:
            return name
    raise ValueError(f"{index} is not a known class index")


def run_inference(image, device, valid_ds):
    """Classify one image and return the predicted class name.

    Preprocessing is a resize to 224x224 followed by ``ToTensor``; no
    further normalization is applied.

    Args:
        image: Image as a NumPy array (as delivered by the Gradio input).
        device: Torch device the model and the input tensor are moved to.
        valid_ds: Object exposing a ``class_to_idx`` mapping of class name
            to class index.

    Returns:
        The key of ``valid_ds.class_to_idx`` matching the predicted index.

    Raises:
        ValueError: If ``image`` is ``None`` or the predicted index is not
            present in ``valid_ds.class_to_idx``.
    """
    global model

    if image is None:
        raise ValueError("No image provided.")

    # Load the checkpoint only once; later calls reuse the cached model.
    if model is None:
        model = _load_model(device)
    model.to(device)

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to the model's input size
        transforms.ToTensor(),
    ])

    # convert("RGB") also copes with grayscale or RGBA arrays.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    # Add a batch dimension and send to the appropriate computing device.
    batch = preprocess(pil_image).unsqueeze(0).to(device)

    # Gradients are not needed for prediction.
    with torch.no_grad():
        logits, _ = model(batch, labels=None)

    predicted_class = torch.argmax(logits, dim=1).item()
    return _class_name_for_index(valid_ds.class_to_idx, predicted_class)
# Create a Gradio interface
def _classify(image):
    """Gradio callback: classify `image` using the module-level device and dataset."""
    return run_inference(image, device, valid_ds)


iface = gr.Interface(
    fn=_classify,
    inputs=gr.Image(type="numpy"),  # image arrives as a NumPy array
    outputs="text",  # predicted class name
    title="Image Classification",
    description="Upload an image to get the predicted class using the ViT model.",
)

# Launch the Gradio app
iface.launch()