Upload 15 files
- Bamboo_v0-1_ViT-B16.pth.tar.convert +3 -0
- README.md +7 -7
- app.py +186 -0
- app_bak.py +105 -0
- examples/Ferrari-F355.jpg +0 -0
- examples/basketball.jpg +0 -0
- examples/dribbler.jpg +0 -0
- examples/fratercula_arctica.jpg +0 -0
- examples/husky.jpg +0 -0
- examples/northern_oriole.jpg +0 -0
- examples/playing_mahjong.jpg +0 -0
- examples/taraxacum_erythrospermum.jpg +0 -0
- requirements.txt +12 -0
- timmvit.py +79 -0
- trainid2name.json +0 -0
Bamboo_v0-1_ViT-B16.pth.tar.convert
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937
size 697651655
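This file is a Git LFS pointer, not the checkpoint itself; the roughly 697 MB weight file is fetched from LFS storage when the repository is cloned. As a minimal sketch (standard library only; the local path is assumed to be the same one app.py uses), a downloaded checkpoint can be checked against the pointer's sha256 oid and size:

import hashlib

# Assumed local path of the downloaded checkpoint (matches the path used in app.py).
CKPT_PATH = "./Bamboo_v0-1_ViT-B16.pth.tar.convert"
EXPECTED_OID = "a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937"
EXPECTED_SIZE = 697651655

def verify_lfs_object(path, expected_oid, expected_size):
    """Hash the file in chunks and compare it against the LFS pointer metadata."""
    sha = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
            size += len(chunk)
    return sha.hexdigest() == expected_oid and size == expected_size

if __name__ == "__main__":
    print("checkpoint OK:", verify_lfs_object(CKPT_PATH, EXPECTED_OID, EXPECTED_SIZE))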
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
-title: Bamboo
-emoji:
-colorFrom:
-colorTo:
+title: Bamboo ViT-B16 Demo
+emoji: 🎋
+colorFrom: blue
+colorTo: blue
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.0.17
 app_file: app.py
 pinned: false
-license:
+license: cc-by-4.0
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,186 @@
import argparse
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
import torchvision
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
import openai
from timmvit import timmvit
import json
from timm.models.hub import download_cached_file
from PIL import Image
import tempfile

# key for GPT
openai.api_key = "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"

def pil_loader(filepath):
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img

def build_transforms(input_size, center_crop=True):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform

# Load the human-readable labels for Bamboo.
with open('./trainid2name.json') as f:
    id2name = json.load(f)


'''
build model
'''
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

'''
borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
'''
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.
    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = 0.7 * heatmap + 0.3 * img
    # cam = cam / np.max(cam)
    return np.uint8(255 * cam)




def chat_with_GPT(my_prompt, history):
    this_history = ''
    for i in history:
        for j in i:
            this_history += j + '\n'

    # print("----this_history----\n" + this_history)
    # my_prompt = input('Please give your Q:')
    my_resp = openai.Completion.create(
        model="text-davinci-003",         # use the Davinci completion model
        prompt=this_history + my_prompt,  # the conversation so far followed by the new question
        temperature=0.8,
        max_tokens=2000,                  # maximum number of tokens in the generated answer
        top_p=1.0,                        # similar to temperature: nucleus sampling over the highest-probability tokens
        frequency_penalty=0.5,            # [-2, 2]; penalizes frequent tokens to reduce repetition (values below 0 repeat a lot)
        presence_penalty=0.0,             # [-2, 2]; controls how closely the answer sticks to the prompt topic
    )
    msg = my_resp.choices[0].text.strip()
    return msg

def run_chatbot(input, gr_state=[]):
    history, conversation = gr_state[0], gr_state[1]
    output = chat_with_GPT(input, history)
    history.append((input, output))
    conversation.append((input, output))
    # chatbox, state
    return conversation, [history, conversation]

def run_chatbot_with_img(input_img, gr_state=[]):
    history, conversation = gr_state[0], gr_state[1]
    img_cls = recognize_image(input_img)
    # conversation = conversation + [(f'<img src="/file={input_img.name}" style="display: inline-block;">', "")]
    input = 'I have given you a photo about ' + img_cls + ', and tell me its definition.'
    output = chat_with_GPT(input, history)

    input_mask = 'Upload image'
    # conversation stores what is shown in the chatbox
    conversation.append((input_mask, output))
    # history keeps the actual prompts sent to GPT
    history.append((input, output))

    # chatbox, gr_state
    return conversation, [history, conversation]

def save_img(image):

    filename = next(tempfile._get_candidate_names()) + '.png'
    image.save(filename)
    return filename

def recognize_image(image):
    img_t = eval_transforms(image)
    # compute output
    output = model(img_t.unsqueeze(0))
    prediction = output.softmax(-1).flatten()
    _, top5_idx = torch.topk(prediction, 5)
    idx_max = top5_idx.tolist()[0]
    print(id2name[str(idx_max)][0])
    print(float(prediction[idx_max]))
    # return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
    return id2name[str(idx_max)][0]

def reset():
    return [], [[], []]


eval_transforms = build_transforms(224)

import openai
import os

with gr.Blocks() as demo:
    gr.HTML("""
        <h1>Bamboo</h1>
        <p>Bamboo for Image Recognition Demo. Bamboo knows what an object is and what you are doing at a very fine granularity, e.g. fratercula arctica (fig. 5) and dribbler (fig. 2).</p>
        <strong>Paper:</strong> <a href="https://arxiv.org/abs/2203.07845" target="_blank">https://arxiv.org/abs/2203.07845</a><br/>
        <strong>Project Website:</strong> <a href="https://opengvlab.shlab.org.cn/bamboo/home" target="_blank">https://opengvlab.shlab.org.cn/bamboo/home</a><br/>
        <strong>Code and Model:</strong> <a href="https://github.com/ZhangYuanhan-AI/Bamboo" target="_blank">https://github.com/ZhangYuanhan-AI/Bamboo</a><br/>
        <strong>Tips:</strong><ul>
        <li>We use Bamboo and GPT-3 from OpenAI to build this demo</li>
        </ul>
    """)
    # history for GPT, conversation for the chatbox
    gr_state = gr.State([[], []])

    chatbot = gr.Chatbot(elem_id="chatbot", label="Bamboo Chatbot")
    text_input = gr.Textbox(label="Message", placeholder="Send a message")
    image = gr.inputs.Image()
    with gr.Row():
        submit_btn = gr.Button("Submit Text", interactive=True, variant='primary')
        reset_btn = gr.Button("Reset All")
        submit_btn_img = gr.Button("Submit Img", interactive=True, variant='primary')
        clear_btn_img = gr.Button("Clear Img", interactive=True, variant='primary')

    # image_btn = gr.UploadButton("Upload Image", file_types=["image"])

    # image_btn.upload(run_chatbot_with_img, [image_btn, gr_state], [chatbot, gr_state])

    text_input.submit(fn=run_chatbot, inputs=[text_input, gr_state], outputs=[chatbot, gr_state])
    text_input.submit(lambda: "", None, text_input)
    submit_btn.click(fn=run_chatbot, inputs=[text_input, gr_state], outputs=[chatbot, gr_state])
    submit_btn.click(lambda: "", None, text_input)
    reset_btn.click(fn=reset, inputs=[], outputs=[chatbot, gr_state])
    submit_btn_img.click(run_chatbot_with_img, [image, gr_state], [chatbot, gr_state])


demo.launch(debug=True)
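For local debugging, the recognition path can be exercised without the chatbot UI or an OpenAI key. The following is a standalone sketch that mirrors recognize_image from app.py (same transform, checkpoint, and label file; the test image is one of the bundled examples), not part of the Space itself:

import json
import torch
import torchvision
from PIL import Image
from timmvit import timmvit

# Same preprocessing as build_transforms(224) in app.py, minus ToPILImage,
# because we feed a PIL image directly instead of a numpy array from Gradio.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224 * 8 // 7),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

with open('./trainid2name.json') as f:
    id2name = json.load(f)

model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

img = Image.open('./examples/fratercula_arctica.jpg').convert('RGB')
with torch.no_grad():
    logits = model(transform(img).unsqueeze(0))
probs = logits.softmax(-1).flatten()
top1 = probs.argmax().item()
print(id2name[str(top1)][0], float(probs[top1]))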
app_bak.py
ADDED
@@ -0,0 +1,105 @@
import argparse
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
import torchvision
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from timmvit import timmvit
import json
from timm.models.hub import download_cached_file
from PIL import Image

def pil_loader(filepath):
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img

def build_transforms(input_size, center_crop=True):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform

# Load the human-readable labels for Bamboo.
with open('./trainid2name.json') as f:
    id2name = json.load(f)


'''
build model
'''
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

'''
borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
'''
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.
    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = 0.7 * heatmap + 0.3 * img
    # cam = cam / np.max(cam)
    return np.uint8(255 * cam)




def recognize_image(image):
    img_t = eval_transforms(image)
    # compute output
    output = model(img_t.unsqueeze(0))
    prediction = output.softmax(-1).flatten()
    _, top5_idx = torch.topk(prediction, 5)
    return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}

eval_transforms = build_transforms(224)


image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=5)

gr.Interface(
    description="Bamboo for Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo). Bamboo knows what an object is and what you are doing at a very fine granularity, e.g. fratercula arctica (fig. 5) and dribbler (fig. 2).",
    fn=recognize_image,
    inputs=["image"],
    outputs=[
        label,
    ],
    examples=[
        ["./examples/playing_mahjong.jpg"],
        ["./examples/dribbler.jpg"],
        ["./examples/Ferrari-F355.jpg"],
        ["./examples/northern_oriole.jpg"],
        ["./examples/fratercula_arctica.jpg"],
        ["./examples/husky.jpg"],
        ["./examples/taraxacum_erythrospermum.jpg"],
    ],
).launch()
examples/Ferrari-F355.jpg
ADDED
examples/basketball.jpg
ADDED
examples/dribbler.jpg
ADDED
examples/fratercula_arctica.jpg
ADDED
examples/husky.jpg
ADDED
examples/northern_oriole.jpg
ADDED
examples/playing_mahjong.jpg
ADDED
examples/taraxacum_erythrospermum.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torchvision==0.14.1
torch==1.13.1
opencv-python-headless<4.3
timm==0.6.12
numpy==1.21.5


requests==2.25.1
gradio==3.19.1
opencv-python==4.7.0.68
openai==0.26.5
pillow==9.3.0
timmvit.py
ADDED
@@ -0,0 +1,79 @@
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

import timm
import torch
import copy
import torch.nn as nn
import torchvision
import json
from timm.models.hub import download_cached_file
from PIL import Image



class MyViT(nn.Module):
    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
        super().__init__()
        print('initializing ViT model as backbone using ckpt:', pretrain_path)
        self.model = timm.create_model('vit_base_patch16_224', checkpoint_path=pretrain_path, num_classes=num_classes)  # pretrained=True)
    # def forward_features(self, x):
    #     x = self.model.patch_embed(x)
    #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    #     if self.model.dist_token is None:
    #         x = torch.cat((cls_token, x), dim=1)
    #     else:
    #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

    #     x = self.model.pos_drop(x + self.model.pos_embed)
    #     x = self.model.blocks(x)
    #     x = self.model.norm(x)

    #     return self.model.pre_logits(x[:, 0])


    def forward(self, x):
        x = self.model.forward(x)
        return x


def timmvit(**kwargs):
    default_kwargs = {}
    default_kwargs.update(**kwargs)
    return MyViT(**default_kwargs)


def build_transforms(input_size, center_crop=True):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform

def pil_loader(filepath):
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img

def test_build():
    with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
        id2name = json.load(f)
    img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
    eval_transforms = build_transforms(224)
    img_t = eval_transforms(img)
    img_t = img_t[None, :]
    model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
    # image = torch.rand(1, 3, 224, 224)
    output = model(img_t)
    # import pdb;pdb.set_trace()
    prediction = output.softmax(-1).flatten()
    _, top5_idx = torch.topk(prediction, 5)
    # import pdb;pdb.set_trace()
    print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})

if __name__ == '__main__':
    test_build()
trainid2name.json
ADDED
The diff for this file is too large to render.
See raw diff
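The structure of trainid2name.json is not rendered here, but from the lookups in app.py and timmvit.py (id2name[str(i)][0]) it appears to map each training-class id, stored as a string key, to a list whose first element is the human-readable class name. A hypothetical illustration (the ids below are placeholders, not real Bamboo class ids):

# Hypothetical excerpt; the real file covers all 115217 Bamboo classes.
id2name = {
    "0": ["fratercula arctica"],
    "1": ["dribbler"],
}

top1 = 0
print(id2name[str(top1)][0])  # -> "fratercula arctica"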