import torch
import torchvision.transforms as T
import numpy as np
import gradio as gr
from pathlib import Path
from timm import create_model
from safetensors.torch import load_model

# Example images shipped with the demo, shown below the Gradio interface.
examples = Path('./examples').glob('*')
examples = list(map(str, examples))

# Validation-time preprocessing: resize to the model's input size and
# normalize to the [-1, 1] range used during training.
valid_tfms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
])

# Build the Swin backbone with a 20-way head (one output per Pascal VOC class)
# and load the fine-tuned weights from the safetensors checkpoint.
model_path = 'model/swin_s3_base_224-pascal/model.safetensors'
model = create_model(
    'swin_s3_base_224',
    pretrained=False,
    num_classes=20,
)
load_model(model, model_path)
model.eval()

# The 20 Pascal VOC object categories, in the order used by the model head.
class_names = [
    "Aeroplane", "Bicycle", "Bird", "Boat", "Bottle",
    "Bus", "Car", "Cat", "Chair", "Cow", "Diningtable",
    "Dog", "Horse", "Motorbike", "Person",
    "Potted plant", "Sheep", "Sofa", "Train", "Tv/monitor",
]
label2id = {c: idx for idx, c in enumerate(class_names)}
id2label = {idx: c for idx, c in enumerate(class_names)}


def predict(im):
    # Preprocess the PIL image and add a batch dimension.
    im = valid_tfms(im).unsqueeze(0)
    with torch.no_grad():
        logits = model(im)
    # Multi-label classification: apply a sigmoid per class and keep every
    # class whose confidence exceeds 0.5.
    confidences = logits.sigmoid().flatten()
    predictions = confidences > 0.5
    predictions = predictions.float().numpy()
    pred_labels = np.where(predictions == 1)[0]
    confidences = confidences[pred_labels].numpy()
    pred_labels = [id2label[label] for label in pred_labels]
    # Map each predicted class name to its confidence (plain Python floats
    # for Gradio's Label component).
    outputs = {l: float(c) for l, c in zip(pred_labels, confidences)}
    return outputs


gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label='the image contains:'),
    examples=examples,
).queue().launch()