|
from PIL import Image |
|
import gradio as gr |
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
import torch |
|
|
|
# Hub checkpoint shared by the classifier and its preprocessor, so both are
# guaranteed to come from the same repository.
_CHECKPOINT = 'sreeramajay/pollution'

# Loaded once at import time and reused by every call to predict().
model = ViTForImageClassification.from_pretrained(_CHECKPOINT)

# Despite the name, this is the image preprocessor (resize/normalize into
# model-ready tensors), not a torchvision-style transform pipeline.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor — consider migrating once the pinned version allows.
transforms = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
|
|
|
def predict(image):
    """Classify a single image into one of the pollution categories.

    Args:
        image: The input image as delivered by the Gradio
            ``Image(type="pil")`` input component.

    Returns:
        A dict mapping each class name to its softmax probability (a Python
        float), ordered from most to least likely. This is the shape the
        Gradio ``'label'`` output component consumes.
    """
    # NOTE(review): assumes the checkpoint's class indices are
    # 0=air, 1=land, 2=water — confirm against model.config.id2label.
    labels = {0: "Air Pollution", 1: "Land Pollution", 2: "Water Pollution"}

    inputs = transforms(image, return_tensors='pt')

    # Pure inference: disabling autograd avoids building a gradient graph for
    # the forward pass, saving memory and time on every request.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over dim 1, the class dimension of the classification logits.
    probability = logits.softmax(1)

    # Rank every known class. k follows the label table rather than a
    # separately hard-coded 3, so the two cannot drift apart.
    values, indices = torch.topk(probability, k=len(labels))

    # Row 0 only: exactly one image was passed to the preprocessor.
    return {
        labels[idx]: prob
        for idx, prob in zip(indices[0].tolist(), values[0].tolist())
    }
|
|
|
|
|
# Build the web UI: a single image upload feeding predict(), whose
# {class: probability} dict is rendered by the built-in 'label' output.
image_input = gr.inputs.Image(type="pil", label="Chosen Image")

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="label",
    theme="seafoam",
)

# debug=True is passed straight through to launch().
demo.launch(debug=True)