"""Gradio demo that classifies an uploaded image with a ViT-Hybrid model."""

from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import ViTHybridForImageClassification, ViTHybridImageProcessor

# Load model and processor once at import time so every request reuses them.
model_name = "google/vit-hybrid-base-bit-384"
feature_extractor = ViTHybridImageProcessor.from_pretrained(model_name)
model = ViTHybridForImageClassification.from_pretrained(model_name)


# Function for prediction
def classify_image(image: Optional[Image.Image]) -> str:
    """Return the predicted class label for *image*.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring logit.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard the processor fails on None with an opaque
    # traceback; gr.Error surfaces a readable message in the UI instead.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Gradio UI
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ViT-Hybrid Image Classifier",
    description="Upload an image to classify it using the ViT-Hybrid model.",
)

if __name__ == "__main__":
    iface.launch()