shreydan commited on
Commit
7a5a216
1 Parent(s): a00146a

create app

Browse files
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from timm import create_model
4
+ from safetensors.torch import load_model
5
+ import numpy as np
6
+ from pathlib import Path
7
+ import gradio as gr
8
+
9
+ examples = Path('./examples').glob('*')
10
+ examples = list(map(str,examples))
11
+
12
+ valid_tfms = T.Compose([
13
+ T.Resize((224,224)),
14
+ T.ToTensor(),
15
+ T.Normalize(
16
+ mean = (0.5,0.5,0.5),
17
+ std = (0.5,0.5,0.5)
18
+ )
19
+ ])
20
+
21
+
22
+ model_path = 'model/swin_s3_base_224-pascal/model.safetensors'
23
+ model = create_model(
24
+ 'swin_s3_base_224',
25
+ pretrained = False,
26
+ num_classes = 20
27
+ )
28
+ load_model(model,model_path)
29
+ model.eval()
30
+
31
+ class_names = [
32
+ "Aeroplane","Bicycle","Bird","Boat","Bottle",
33
+ "Bus","Car","Cat","Chair","Cow","Diningtable",
34
+ "Dog","Horse","Motorbike","Person",
35
+ "Potted plant","Sheep","Sofa","Train","Tv/monitor"
36
+ ]
37
+
38
+ label2id = {c:idx for idx,c in enumerate(class_names)}
39
+ id2label = {idx:c for idx,c in enumerate(class_names)}
40
+
41
+
42
+ def predict(im):
43
+ im = valid_tfms(im).unsqueeze(0)
44
+ with torch.no_grad():
45
+ logits = model(im)
46
+
47
+ confidences = logits.sigmoid().flatten()
48
+ predictions = confidences > 0.5
49
+ predictions = predictions.float().numpy()
50
+ pred_labels = np.where(predictions==1)[0]
51
+ confidences = confidences[pred_labels].numpy()
52
+ pred_labels = [id2label[label] for label in pred_labels]
53
+ outputs = {l:c for l,c in zip(pred_labels, confidences)}
54
+ return outputs
55
+
56
+ gr.Interface(fn=predict,
57
+ inputs=gr.Image(type="pil"),
58
+ outputs=gr.Label(label='the image contains:'),
59
+ examples=examples).queue().launch()
examples/14_12.jpg ADDED
examples/17_7.jpg ADDED
examples/1_14.jpg ADDED
examples/2.jpg ADDED
examples/7.jpg ADDED
examples/8_14_19.jpg ADDED
model/swin_s3_base_224-pascal/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57c40f375cba8df0eae8186e3be85d6ad1fcb3e02d307fb263962872e990e66a
3
+ size 281538560
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ safetensors
4
+ timm
5
+ gradio