fesvhtr committed
Commit c0f0698 (parent: c2d01eb)

Upload 15 files

Bamboo_v0-1_ViT-B16.pth.tar.convert ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937
+ size 697651655
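The three lines above are a Git LFS pointer, not the checkpoint itself: the 698 MB weight file is stored out of band and addressed by its SHA-256 (`oid`). As a minimal sketch (file name and hash taken from this commit), a downloaded checkpoint can be verified against the pointer with only the standard library:

import hashlib

# SHA-256 recorded in the LFS pointer above.
EXPECTED_OID = "a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937"

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so the 698 MB checkpoint never sits in memory at once.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

assert sha256_of("Bamboo_v0-1_ViT-B16.pth.tar.convert") == EXPECTED_OID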
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: Bamboo Test V1
+ title: Bamboo ViT-B16 Demo
- emoji: 🏃
+ emoji: 🎋
- colorFrom: red
+ colorFrom: blue
- colorTo: gray
+ colorTo: blue
  sdk: gradio
- sdk_version: 3.19.1
+ sdk_version: 3.0.17
  app_file: app.py
  pinned: false
- license: openrail
+ license: cc-by-4.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,186 @@
+ import argparse
+ import requests
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ import torchvision
+ from torchvision import transforms
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.data import create_transform
+ import openai
+ from timmvit import timmvit
+ import json
+ from timm.models.hub import download_cached_file
+ import tempfile
+ import os
+
+ # key for GPT: read from the environment instead of hard-coding the secret
+ openai.api_key = os.environ.get("OPENAI_API_KEY")
+
+ def pil_loader(filepath):
+     with Image.open(filepath) as img:
+         img = img.convert('RGB')
+         return img
+
+ def build_transforms(input_size, center_crop=True):
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.ToPILImage(),
+         torchvision.transforms.Resize(input_size * 8 // 7),
+         torchvision.transforms.CenterCrop(input_size),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+     return transform
+
+ # Load human-readable labels for Bamboo.
+ with open('./trainid2name.json') as f:
+     id2name = json.load(f)
+
+
+ '''
+ build model
+ '''
+ model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
+ model.eval()
+
+ '''
+ borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
+ '''
+ def show_cam_on_image(img: np.ndarray,
+                       mask: np.ndarray,
+                       use_rgb: bool = False,
+                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+     """Overlays the CAM mask on the image as a heatmap.
+     By default the heatmap is in BGR format.
+     :param img: The base image in RGB or BGR format.
+     :param mask: The CAM mask.
+     :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
+     :param colormap: The OpenCV colormap to be used.
+     :returns: The image with the CAM overlay.
+     """
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+     if use_rgb:
+         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+     heatmap = np.float32(heatmap) / 255
+
+     if np.max(img) > 1:
+         raise Exception(
+             "The input image should be np.float32 in the range [0, 1]")
+
+     cam = 0.7 * heatmap + 0.3 * img
+     # cam = cam / np.max(cam)
+     return np.uint8(255 * cam)
+
+
+ def chat_with_GPT(my_prompt, history):
+     this_history = ''
+     for i in history:
+         for j in i:
+             this_history += j + '\n'
+
+     # print("----this_history----\n" + this_history)
+     # my_prompt = input('Please give your Q:')
+     my_resp = openai.Completion.create(
+         model="text-davinci-003",         # use the Davinci completion model
+         prompt=this_history + my_prompt,  # prepend the running history to the new prompt
+         temperature=0.8,
+         max_tokens=2000,                  # upper bound on tokens in the generated answer
+         top_p=1.0,                        # nucleus sampling; plays a role similar to temperature
+         frequency_penalty=0.5,            # in [-2, 2]; penalizes frequent tokens to reduce repetition (negative values yield many repeats)
+         presence_penalty=0.0,             # in [-2, 2]; how closely the answer sticks to the question (negative values stay very on-topic)
+     )
+     msg = my_resp.choices[0].text.strip()
+     return msg
+
+ def run_chatbot(input, gr_state=[]):
+     history, conversation = gr_state[0], gr_state[1]
+     output = chat_with_GPT(input, history)
+     history.append((input, output))
+     conversation.append((input, output))
+     # returns (chatbox, state)
+     return conversation, [history, conversation]
+
+ def run_chatbot_with_img(input_img, gr_state=[]):
+     history, conversation = gr_state[0], gr_state[1]
+     img_cls = recognize_image(input_img)
+     # conversation = conversation + [(f'<img src="/file={input_img.name}" style="display: inline-block;">', "")]
+     input = 'I have given you a photo about ' + img_cls + ', and tell me its definition.'
+     output = chat_with_GPT(input, history)
+
+     input_mask = 'Upload image'
+     # conversation stores what the chatbox displays
+     conversation.append((input_mask, output))
+     # history keeps the real prompts sent to GPT
+     history.append((input, output))
+
+     # returns (chatbox, gr_state)
+     return conversation, [history, conversation]
+
+ def save_img(image):
+     filename = next(tempfile._get_candidate_names()) + '.png'
+     image.save(filename)
+     return filename
+
+ def recognize_image(image):
+     img_t = eval_transforms(image)
+     # compute output
+     output = model(img_t.unsqueeze(0))
+     prediction = output.softmax(-1).flatten()
+     _, top5_idx = torch.topk(prediction, 5)
+     idx_max = top5_idx.tolist()[0]
+     print(id2name[str(idx_max)][0])
+     print(float(prediction[idx_max]))
+     # return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
+     return id2name[str(idx_max)][0]
+
+ def reset():
+     return [], [[], []]
+
+
+ eval_transforms = build_transforms(224)
+
+ with gr.Blocks() as demo:
+     gr.HTML("""
+         <h1>Bamboo</h1>
+         <p>Bamboo for Image Recognition Demo. Bamboo knows what an object is and what you are doing at a very fine granularity: fratercula arctica (fig. 5) and dribbler (fig. 2).</p>
+         <strong>Paper:</strong> <a href="https://arxiv.org/abs/2203.07845" target="_blank">https://arxiv.org/abs/2203.07845</a><br/>
+         <strong>Project Website:</strong> <a href="https://opengvlab.shlab.org.cn/bamboo/home" target="_blank">https://opengvlab.shlab.org.cn/bamboo/home</a><br/>
+         <strong>Code and Model:</strong> <a href="https://github.com/ZhangYuanhan-AI/Bamboo" target="_blank">https://github.com/ZhangYuanhan-AI/Bamboo</a><br/>
+         <strong>Tips:</strong><ul>
+             <li>This demo combines Bamboo with GPT-3 from OpenAI.</li>
+         </ul>
+     """)
+     # history is what GPT sees; conversation is what the chatbox shows
+     gr_state = gr.State([[], []])
+
+     chatbot = gr.Chatbot(elem_id="chatbot", label="Bamboo Chatbot")
+     text_input = gr.Textbox(label="Message", placeholder="Send a message")
+     image = gr.inputs.Image()
+     with gr.Row():
+         submit_btn = gr.Button("Submit Text", interactive=True, variant='primary')
+         reset_btn = gr.Button("Reset All")
+         submit_btn_img = gr.Button("Submit Img", interactive=True, variant='primary')
+         clear_btn_img = gr.Button("Clear Img", interactive=True, variant='primary')
+
+     # image_btn = gr.UploadButton("Upload Image", file_types=["image"])
+     # image_btn.upload(run_chatbot_with_img, [image_btn, gr_state], [chatbot, gr_state])
+
+     text_input.submit(fn=run_chatbot, inputs=[text_input, gr_state], outputs=[chatbot, gr_state])
+     text_input.submit(lambda: "", None, text_input)
+     submit_btn.click(fn=run_chatbot, inputs=[text_input, gr_state], outputs=[chatbot, gr_state])
+     submit_btn.click(lambda: "", None, text_input)
+     reset_btn.click(fn=reset, inputs=[], outputs=[chatbot, gr_state])
+     submit_btn_img.click(run_chatbot_with_img, [image, gr_state], [chatbot, gr_state])
+
+
+ demo.launch(debug=True)
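app.py keeps two parallel lists in gr.State: history holds the literal prompts sent to GPT-3 (including the hidden classification prompt built from the image), while conversation holds what the chatbox displays. A minimal sketch of that bookkeeping, with a hypothetical fake_llm standing in for chat_with_GPT so it runs without an API key:

# Hypothetical stand-in for chat_with_GPT, so the state bookkeeping can be
# exercised without an OpenAI key.
def fake_llm(prompt, history):
    return f"(reply to: {prompt})"

history, conversation = [], []

# Text turn: the displayed message and the real prompt are identical.
msg = "What is a fratercula arctica?"
reply = fake_llm(msg, history)
history.append((msg, reply))
conversation.append((msg, reply))

# Image turn, mirroring run_chatbot_with_img: the chatbox shows a placeholder
# while the real prompt embeds the predicted class name.
img_cls = "fratercula arctica"  # assumed classifier output
prompt = "I have given you a photo about " + img_cls + ", and tell me its definition."
reply = fake_llm(prompt, history)
conversation.append(("Upload image", reply))  # what the user sees
history.append((prompt, reply))               # what GPT-3 actually saw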
app_bak.py ADDED
@@ -0,0 +1,105 @@
+ import argparse
+ import requests
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ import torchvision
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.data import create_transform
+
+ from timmvit import timmvit
+ import json
+ from timm.models.hub import download_cached_file
+
+ def pil_loader(filepath):
+     with Image.open(filepath) as img:
+         img = img.convert('RGB')
+         return img
+
+ def build_transforms(input_size, center_crop=True):
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.ToPILImage(),
+         torchvision.transforms.Resize(input_size * 8 // 7),
+         torchvision.transforms.CenterCrop(input_size),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+     return transform
+
+ # Load human-readable labels for Bamboo.
+ with open('./trainid2name.json') as f:
+     id2name = json.load(f)
+
+
+ '''
+ build model
+ '''
+ model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
+ model.eval()
+
+ '''
+ borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
+ '''
+ def show_cam_on_image(img: np.ndarray,
+                       mask: np.ndarray,
+                       use_rgb: bool = False,
+                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+     """Overlays the CAM mask on the image as a heatmap.
+     By default the heatmap is in BGR format.
+     :param img: The base image in RGB or BGR format.
+     :param mask: The CAM mask.
+     :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
+     :param colormap: The OpenCV colormap to be used.
+     :returns: The image with the CAM overlay.
+     """
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+     if use_rgb:
+         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+     heatmap = np.float32(heatmap) / 255
+
+     if np.max(img) > 1:
+         raise Exception(
+             "The input image should be np.float32 in the range [0, 1]")
+
+     cam = 0.7 * heatmap + 0.3 * img
+     # cam = cam / np.max(cam)
+     return np.uint8(255 * cam)
+
+
+ def recognize_image(image):
+     img_t = eval_transforms(image)
+     # compute output
+     output = model(img_t.unsqueeze(0))
+     prediction = output.softmax(-1).flatten()
+     _, top5_idx = torch.topk(prediction, 5)
+     return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
+
+ eval_transforms = build_transforms(224)
+
+
+ image = gr.inputs.Image()
+ label = gr.outputs.Label(num_top_classes=5)
+
+ gr.Interface(
+     description="Bamboo for Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo). Bamboo knows what an object is and what you are doing at a very fine granularity: fratercula arctica (fig. 5) and dribbler (fig. 2).",
+     fn=recognize_image,
+     inputs=["image"],
+     outputs=[
+         label,
+     ],
+     examples=[
+         ["./examples/playing_mahjong.jpg"],
+         ["./examples/dribbler.jpg"],
+         ["./examples/Ferrari-F355.jpg"],
+         ["./examples/northern_oriole.jpg"],
+         ["./examples/fratercula_arctica.jpg"],
+         ["./examples/husky.jpg"],
+         ["./examples/taraxacum_erythrospermum.jpg"],
+     ],
+ ).launch()
examples/Ferrari-F355.jpg ADDED
examples/basketball.jpg ADDED
examples/dribbler.jpg ADDED
examples/fratercula_arctica.jpg ADDED
examples/husky.jpg ADDED
examples/northern_oriole.jpg ADDED
examples/playing_mahjong.jpg ADDED
examples/taraxacum_erythrospermum.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torchvision==0.14.1
+ torch==1.13.1
+ opencv-python-headless<4.3
+ timm==0.6.12
+ numpy==1.21.5
+
+
+ requests==2.25.1
+ gradio==3.19.1
+ opencv-python==4.7.0.68
+ openai==0.26.5
+ pillow==9.3.0
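A quick way to confirm that an environment actually resolved to these pins (a throwaway check, not part of the app; version strings copied from the file above):

# Compare installed versions against the pins in requirements.txt.
import gradio, numpy, timm, torch, torchvision

expected = {torch: "1.13.1", torchvision: "0.14.1", timm: "0.6.12",
            numpy: "1.21.5", gradio: "3.19.1"}
for mod, want in expected.items():
    ok = "ok" if mod.__version__ == want else f"got {mod.__version__}"
    print(f"{mod.__name__}: expected {want} -> {ok}")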
timmvit.py ADDED
@@ -0,0 +1,79 @@
+ # ------------------------------------------------------------------------
+ # Modified from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ import timm
+ import torch
+ import copy
+ import torch.nn as nn
+ import torchvision
+ import json
+ from timm.models.hub import download_cached_file
+ from PIL import Image
+
+
+ class MyViT(nn.Module):
+     def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
+         super().__init__()
+         print('initializing ViT model as backbone using ckpt:', pretrain_path)
+         self.model = timm.create_model('vit_base_patch16_224', checkpoint_path=pretrain_path, num_classes=num_classes)  # pretrained=True
+
+     # def forward_features(self, x):
+     #     x = self.model.patch_embed(x)
+     #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+     #     if self.model.dist_token is None:
+     #         x = torch.cat((cls_token, x), dim=1)
+     #     else:
+     #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+     #     x = self.model.pos_drop(x + self.model.pos_embed)
+     #     x = self.model.blocks(x)
+     #     x = self.model.norm(x)
+     #     return self.model.pre_logits(x[:, 0])
+
+     def forward(self, x):
+         x = self.model.forward(x)
+         return x
+
+
+ def timmvit(**kwargs):
+     default_kwargs = {}
+     default_kwargs.update(**kwargs)
+     return MyViT(**default_kwargs)
+
+
+ def build_transforms(input_size, center_crop=True):
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.Resize(input_size * 8 // 7),
+         torchvision.transforms.CenterCrop(input_size),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+     return transform
+
+ def pil_loader(filepath):
+     with Image.open(filepath) as img:
+         img = img.convert('RGB')
+         return img
+
+ def test_build():
+     with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
+         id2name = json.load(f)
+     img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
+     eval_transforms = build_transforms(224)
+     img_t = eval_transforms(img)
+     img_t = img_t[None, :]
+     model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
+     # image = torch.rand(1, 3, 224, 224)
+     output = model(img_t)
+     # import pdb; pdb.set_trace()
+     prediction = output.softmax(-1).flatten()
+     _, top5_idx = torch.topk(prediction, 5)
+     # import pdb; pdb.set_trace()
+     print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})
+
+ if __name__ == '__main__':
+     test_build()
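test_build depends on cluster-local paths under /mnt/lustre, so it only runs on the author's machine. A self-contained smoke test, sketched under the assumption that no checkpoint is available (random weights, so the prediction is meaningless, but the shapes and forward path are verified):

import timm
import torch

# Same backbone and head size as MyViT, but without loading a checkpoint.
model = timm.create_model('vit_base_patch16_224', num_classes=115217)
model.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)   # dummy normalized image batch
    logits = model(x)
    prob = logits.softmax(-1).flatten()
    _, top5 = torch.topk(prob, 5)

assert logits.shape == (1, 115217)
print(top5.tolist())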
trainid2name.json ADDED
The diff for this file is too large to render.