Spaces:
Runtime error
Runtime error
import argparse | |
import requests | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
import numpy | |
import torchvision | |
from torchvision import transforms | |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from timm.data import create_transform | |
import openai | |
from timmvit import timmvit | |
import json | |
from timm.models.hub import download_cached_file | |
from PIL import Image | |
import tempfile | |
# OpenAI credentials come from the environment rather than being committed to
# source: a key written in this file is readable by anyone who can see the
# repository or Space.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be revoked on the OpenAI dashboard.
import os
openai.api_key = os.environ.get("OPENAI_API_KEY")
def pil_loader(filepath):
    """Open the image stored at ``filepath`` and return an RGB copy of it.

    The file handle is released when the ``with`` block exits.
    ``convert`` returns a new image object, so the result stays usable after
    the source file is closed.
    """
    with Image.open(filepath) as source:
        rgb_image = source.convert('RGB')
    return rgb_image
def build_transforms(input_size, center_crop=True):
    """Build the evaluation preprocessing pipeline for the classifier.

    The pipeline converts its input to a PIL image via ``ToPILImage``. It then
    resizes (and optionally crops) to ``input_size`` x ``input_size``. Finally
    it converts to a tensor and normalizes with the ImageNet mean/std.

    :param input_size: side length, in pixels, of the square output.
    :param center_crop: if True (default), resize the shorter side to
        ``input_size * 8 // 7`` and take the central ``input_size`` square.
        If False, resize directly to ``input_size`` x ``input_size`` without
        cropping. Previously this flag was accepted but silently ignored.
    :returns: a ``torchvision.transforms.Compose`` instance.
    """
    if center_crop:
        # e.g. 224 -> resize short side to 256, then crop the central 224x224.
        spatial = [
            torchvision.transforms.Resize(input_size * 8 // 7),
            torchvision.transforms.CenterCrop(input_size),
        ]
    else:
        # No crop: squash the whole image so the output is still square.
        spatial = [torchvision.transforms.Resize((input_size, input_size))]
    return torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        *spatial,
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
# Human-readable labels for Bamboo, read from a local JSON file. Keys are class
# indices as strings; the first element of each value is used as the display
# name (see how recognize_image indexes it).
# JSON text is UTF-8 by specification, so don't depend on the platform locale.
with open('./trainid2name.json', encoding='utf-8') as f:
    id2name = json.load(f)

# Build the classifier once at import time and put it in evaluation mode.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

# show_cam_on_image is borrowed from:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.3) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, as floats in [0, 1].
    :param mask: The cam mask. It is scaled by 255 and cast to uint8, so it is
        presumably expected to lie in [0, 1].
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: Weight of the base image in the blend; the heatmap
        gets ``1 - image_weight``. The default 0.3 reproduces the original
        fixed 0.7 * heatmap + 0.3 * img mix.
    :returns: The default image with the cam overlay, as uint8.
    :raises ValueError: if ``img`` has values above 1 or ``image_weight`` is
        outside [0, 1]. ValueError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.
    """
    # Validate before doing any colormap work.
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")
    if not 0 <= image_weight <= 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1], got {image_weight}")
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        # applyColorMap produces BGR; match the channel order of an RGB img.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    # Weights sum to 1, so with both inputs in [0, 1] the blend stays in [0, 1]
    # and needs no renormalization before the uint8 cast.
    cam = (1 - image_weight) * heatmap + image_weight * img
    return np.uint8(255 * cam)
def chat_with_GPT(my_prompt, history, *args):
    """Send the conversation so far plus a new prompt to the OpenAI completion API.

    The exchange is flattened into one plain-text prompt. Each message of
    every stored (user, bot) pair in ``history`` goes on its own line, and
    ``my_prompt`` follows.

    :param my_prompt: the new user message.
    :param history: iterable of (user_message, bot_reply) string pairs.
    :param args: sampling settings, positionally
        ``(max_tokens, temperature, frequency_penalty, presence_penalty)``.
    :returns: the completion text with surrounding whitespace stripped.
    """
    max_tokens, temperature, frequency_penalty, presence_penalty = args[:4]
    # join() keeps prompt construction linear in the history length.
    this_history = ''.join(message + '\n' for turn in history for message in turn)
    # NOTE(review): openai.Completion / text-davinci-003 are the legacy
    # (openai<1.0) API and model; confirm against the installed openai version.
    my_resp = openai.Completion.create(
        model="text-davinci-003",
        prompt=this_history + my_prompt,
        temperature=temperature,
        # The API requires an integer; a Gradio Number may deliver e.g. 1000.0.
        max_tokens=int(max_tokens),
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
    )
    return my_resp.choices[0].text.strip()
def run_chatbot(input, max_tokens, temperature, frequency_penalty, presence_penalty, gr_state=None):
    """Handle one text message from the chat UI.

    Sends ``input`` (with the prior GPT history) to ``chat_with_GPT``. The
    (message, reply) pair is then recorded in both the GPT history and the
    visible conversation.

    :param input: the user's message. Blank or whitespace-only messages are
        ignored and the state is returned unchanged, so no API call is made.
    :param max_tokens: forwarded to ``chat_with_GPT``.
    :param temperature: forwarded to ``chat_with_GPT``.
    :param frequency_penalty: forwarded to ``chat_with_GPT``.
    :param presence_penalty: forwarded to ``chat_with_GPT``.
    :param gr_state: ``[history, conversation]`` lists, mutated in place.
        When omitted, a fresh empty state is created for this call. The old
        mutable ``[]`` default was shared between calls and raised IndexError.
    :returns: ``(conversation, [history, conversation])`` -- the chat box value
        and the updated state.
    """
    if gr_state is None:
        gr_state = [[], []]
    history, conversation = gr_state[0], gr_state[1]
    if not input or not input.strip():
        return conversation, [history, conversation]
    output = chat_with_GPT(input, history, max_tokens, temperature, frequency_penalty, presence_penalty)
    history.append((input, output))
    conversation.append((input, output))
    return conversation, [history, conversation]
def run_chatbot_with_img(input_img, max_tokens, temperature, frequency_penalty, presence_penalty, gr_state=None):
    """Classify an image with Bamboo, then ask GPT to define the predicted class.

    ``input_img`` arrives in one of two forms:

    - an upload object exposing ``.name`` (the path of the uploaded file);
    - a numpy array from the image widget.

    ``None`` (nothing uploaded) returns the state unchanged. Previously, the
    numpy and ``None`` cases crashed on ``input_img.name``.

    The image is resized to 224x224 RGB and classified with
    ``recognize_image``. A prompt naming the class is sent to ``chat_with_GPT``.
    The chat box shows the image (an ``<img>`` tag served via ``/file=``) in
    place of that prompt. The GPT history keeps the real prompt text.

    :param max_tokens: forwarded to ``chat_with_GPT``.
    :param temperature: forwarded to ``chat_with_GPT``.
    :param frequency_penalty: forwarded to ``chat_with_GPT``.
    :param presence_penalty: forwarded to ``chat_with_GPT``.
    :param gr_state: ``[history, conversation]`` lists, mutated in place; a
        fresh empty state is created when omitted.
    :returns: ``(conversation, [history, conversation])``.
    """
    if gr_state is None:
        gr_state = [[], []]
    history, conversation = gr_state[0], gr_state[1]
    if input_img is None:
        return conversation, [history, conversation]

    if isinstance(input_img, np.ndarray):
        # Image widget: there is no file on disk yet, so write one for display.
        img_save = Image.fromarray(input_img).resize((224, 224)).convert('RGB')
        img_path = save_img(img_save)
    else:
        # Upload button: replace the uploaded file with the 224x224 RGB copy,
        # so the chat box displays the resized version.
        img_path = input_img.name
        with Image.open(img_path) as raw:  # release the file handle promptly
            img_save = raw.resize((224, 224)).convert('RGB')
        img_save.save(img_path)

    img_cls = recognize_image(np.array(img_save))
    prompt = 'I have given you a photo about ' + img_cls + ', and tell me its definition.'
    output = chat_with_GPT(prompt, history, max_tokens, temperature, frequency_penalty, presence_penalty)
    input_mask = f'<img src="/file={img_path}" style="display: inline-block;">'
    # conversation drives the chat box; history is what GPT sees next turn.
    conversation.append((input_mask, output))
    history.append((prompt, output))
    return conversation, [history, conversation]
def save_img(image: Image.Image):
    """Write ``image`` as a PNG in the current working directory; return its path.

    ``tempfile.NamedTemporaryFile`` chooses a unique name and creates the file
    atomically. This replaces the private ``tempfile._get_candidate_names()``,
    whose name could be taken between generation and save. ``delete=False``
    keeps the file on disk after the handle is closed, so callers can use the
    returned path. The returned path is absolute.
    """
    with tempfile.NamedTemporaryFile(suffix='.png', dir='.', delete=False) as fh:
        image.save(fh, format='PNG')
    print(fh.name)
    return fh.name
def recognize_image(image):
    """Classify one image with the Bamboo model and return the top-1 label.

    :param image: any input accepted by ``eval_transforms``. In this file,
        callers pass an HxWx3 uint8 RGB numpy array.
    :returns: the display name of the most probable class,
        ``id2name[str(class_index)][0]``.
    """
    img_t = eval_transforms(image)
    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))  # add a batch dimension of 1
    prediction = output.softmax(-1).flatten()
    idx_max = int(torch.argmax(prediction))
    label = id2name[str(idx_max)][0]
    # Server-side log of the predicted label and its probability.
    print(label)
    print(float(prediction[idx_max]))
    return label
def reset():
    """Clear the chat box and both halves of the UI state.

    Returns ``(chatbot_value, [history, conversation])``. New empty lists are
    built on every call, so nothing is shared between resets.
    """
    empty_history, empty_conversation = [], []
    return [], [empty_history, empty_conversation]
# Evaluation preprocessing used by recognize_image(): 224x224 output, with the
# short side resized to 224 * 8 // 7 = 256 and then center-cropped.
eval_transforms = build_transforms(input_size=224)
import openai | |
import os | |
with gr.Blocks() as demo:
    gr.HTML("""
<h1>Bamboo</h1>
<p>Bamboo for Image Recognition Demo. Bamboo knows what this object is and what you are doing in a very fine-grain granularity: fratercula arctica (fig.5) and dribbler (fig.2)).</p>
<strong>Paper:</strong> <a href="https://arxiv.org/abs/2203.07845" target="_blank">https://arxiv.org/abs/2203.07845</a><br/>
<strong>Project Website:</strong> <a href="https://opengvlab.shlab.org.cn/bamboo/home" target="_blank">https://opengvlab.shlab.org.cn/bamboo/home</a><br/>
<strong>Code and Model:</strong> <a href="https://github.com/ZhangYuanhan-AI/Bamboo" target="_blank">https://github.com/ZhangYuanhan-AI/Bamboo</a><br/>
<strong>Tips:</strong><ul>
<li>We use Bamboo and GPT-3 from openai to build this demo</li>
</ul>
""")
    # [history, conversation]: history is the (prompt, reply) text sent to GPT;
    # conversation is what the chat box displays.
    gr_state = gr.State([[], []])
    chatbot = gr.Chatbot(elem_id="chatbot", label="Bamboo Chatbot")
    text_input = gr.Textbox(label="Message", placeholder="Send a message")
    # gr.inputs.* is the deprecated legacy namespace (removed in Gradio 4);
    # gr.Image also yields a numpy array by default.
    image = gr.Image()
    with gr.Row():
        submit_btn = gr.Button("Submit Text", interactive=True, variant='primary')
        reset_btn = gr.Button("Reset All")
        submit_btn_img = gr.Button("Submit Img", interactive=True, variant='primary')
        image_btn = gr.UploadButton("Upload Image", file_types=["image"])
    with gr.Column(scale=0.3, min_width=400):
        # precision=0 makes Gradio return an int; the OpenAI API rejects a
        # float such as 1000.0 for max_tokens.
        max_tokens = gr.Number(
            value=1000, precision=0, interactive=True, label="Maximum length of generated text")
        temperature = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Diversity of generated text")
        frequency_penalty = gr.Slider(minimum=-2.0, maximum=2.0, value=0.5,
                                      step=0.1, interactive=True, label="Frequency of generation of repeat tokens")
        presence_penalty = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0,
                                     step=0.1, interactive=True, label="Frequency of generation of tokens independent of the given prefix")

    # Sampling controls shared by every chat event, in chat_with_GPT's order.
    sampling = [max_tokens, temperature, frequency_penalty, presence_penalty]

    # Text messages: pressing Enter and clicking "Submit Text" behave the same;
    # a second handler on each trigger clears the textbox.
    text_input.submit(fn=run_chatbot, inputs=[text_input, *sampling, gr_state], outputs=[chatbot, gr_state])
    text_input.submit(lambda: "", None, text_input)
    submit_btn.click(fn=run_chatbot, inputs=[text_input, *sampling, gr_state], outputs=[chatbot, gr_state])
    submit_btn.click(lambda: "", None, text_input)
    reset_btn.click(fn=reset, inputs=[], outputs=[chatbot, gr_state])
    # Images: either the image widget + "Submit Img", or the upload button.
    submit_btn_img.click(run_chatbot_with_img, [image, *sampling, gr_state], [chatbot, gr_state])
    image_btn.upload(run_chatbot_with_img, [image_btn, *sampling, gr_state], [chatbot, gr_state])

if __name__ == "__main__":
    demo.launch(debug=True)