Davidzhangyuanhan commited on
Commit
df59928
1 Parent(s): d2d4aba

Add application file

Browse files
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -6,7 +6,7 @@ import cv2
6
  import torch
7
  import torch.nn as nn
8
  from PIL import Image
9
- from torchvision import transforms
10
  from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
  from timm.data import create_transform
12
 
@@ -20,14 +20,14 @@ def pil_loader(filepath):
20
  img = img.convert('RGB')
21
  return img
22
 
23
- def build_transforms(input_size):
24
- transform = transforms.Compose([
25
- transforms.Resize(input_size * 8 // 7),
26
- transforms.CenterCrop(input_size),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
  ])
30
- return transforms
31
 
32
  # Download human-readable labels for Bamboo.
33
  with open('./trainid2name.json') as f:
@@ -40,11 +40,6 @@ build model
40
  model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
41
  model.eval()
42
 
43
- '''
44
- build data transform
45
- '''
46
- eval_transforms = build_transforms(224)
47
-
48
  '''
49
  borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
50
  '''
@@ -73,15 +68,19 @@ def show_cam_on_image(img: np.ndarray,
73
  # cam = cam / np.max(cam)
74
  return np.uint8(255 * cam)
75
 
 
 
 
76
  def recognize_image(image):
77
  img_t = eval_transforms(image)
78
-
79
  # compute output
80
  output = model(img_t.unsqueeze(0))
81
  prediction = output.softmax(-1).flatten()
82
  _,top5_idx = torch.topk(prediction, 5)
83
  return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
84
 
 
 
85
 
86
  image = gr.inputs.Image()
87
  label = gr.outputs.Label(num_top_classes=5)
 
6
  import torch
7
  import torch.nn as nn
8
  from PIL import Image
9
+ import torchvision
10
  from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
  from timm.data import create_transform
12
 
 
20
  img = img.convert('RGB')
21
  return img
22
 
23
+ def build_transforms(input_size, center_crop=True):
24
+ transform = torchvision.transforms.Compose([
25
+ torchvision.transforms.Resize(input_size * 8 // 7),
26
+ torchvision.transforms.CenterCrop(input_size),
27
+ torchvision.transforms.ToTensor(),
28
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
  ])
30
+ return transform
31
 
32
  # Download human-readable labels for Bamboo.
33
  with open('./trainid2name.json') as f:
 
40
  model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
41
  model.eval()
42
 
 
 
 
 
 
43
  '''
44
  borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
45
  '''
 
68
  # cam = cam / np.max(cam)
69
  return np.uint8(255 * cam)
70
 
71
+
72
+
73
+
74
  def recognize_image(image):
75
  img_t = eval_transforms(image)
 
76
  # compute output
77
  output = model(img_t.unsqueeze(0))
78
  prediction = output.softmax(-1).flatten()
79
  _,top5_idx = torch.topk(prediction, 5)
80
  return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
81
 
82
+ eval_transforms = build_transforms(224)
83
+
84
 
85
  image = gr.inputs.Image()
86
  label = gr.outputs.Label(num_top_classes=5)