"""Gradio demo: Bamboo image recognition combined with a GPT text chatbot.

The script loads the Bamboo ViT-B/16 classifier and its label map at import
time, builds a Gradio Blocks UI, and launches it.  Text messages are sent to
the OpenAI completion endpoint together with the conversation history; an
uploaded image is first classified by Bamboo and the predicted label is turned
into a text prompt.
"""
import argparse
import json
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import openai
import requests
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.hub import download_cached_file
from torchvision import transforms

from timmvit import timmvit

# Key for GPT.
# SECURITY: the secret must never be committed to source control.  It is read
# from the environment instead; any key that was previously hard-coded here
# should be considered leaked and be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")


def pil_loader(filepath):
    """Open the image at *filepath* and return it converted to RGB."""
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img


def build_transforms(input_size, center_crop=True):
    """Build the evaluation preprocessing pipeline.

    The pipeline converts the input to a PIL image, resizes the shorter side
    to ``input_size * 8 // 7``, center-crops to ``input_size``, converts to a
    tensor and normalizes with the ImageNet mean/std.

    :param input_size: Side length of the final square crop.
    :param center_crop: Currently unused; kept for interface compatibility.
    :returns: A ``torchvision.transforms.Compose`` instance.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        # Same values as the previously inlined literals
        # (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
        torchvision.transforms.Normalize(
            mean=list(IMAGENET_DEFAULT_MEAN),
            std=list(IMAGENET_DEFAULT_STD),
        ),
    ])
    return transform


# Load human-readable labels for Bamboo.
with open('./trainid2name.json') as f:
    id2name = json.load(f)

# Build the model.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

# Preprocessing applied to every image before classification.
eval_transforms = build_transforms(224)


# Borrowed from:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay the CAM mask on the image as a heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, with values in [0, 1].
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set
        to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the cam overlay, as ``uint8``.
    :raises Exception: If ``img`` contains values greater than 1.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    # Fixed 70/30 blend; unlike the upstream helper the result is not
    # re-normalized by its maximum.
    cam = 0.7 * heatmap + 0.3 * img
    return np.uint8(255 * cam)


def chat_with_GPT(my_prompt, history, *args):
    """Send *my_prompt*, prefixed with the flattened history, to the model.

    :param my_prompt: The new user message.
    :param history: Iterable of ``(user_message, reply)`` string pairs.
    :param args: Positional generation settings, in this order:
        ``max_tokens, temperature, frequency_penalty, presence_penalty``.
    :returns: The stripped text of the first completion choice.
    """
    max_tokens, temperature, frequency_penalty, presence_penalty = args[:4]

    # One utterance per line, oldest first.
    this_history = "".join(
        utterance + '\n' for turn in history for utterance in turn
    )

    my_resp = openai.Completion.create(
        model="text-davinci-003",
        prompt=this_history + my_prompt,
        temperature=temperature,
        # The UI widget yields a float; the API expects an integer count.
        max_tokens=int(max_tokens),
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
    )
    msg = my_resp.choices[0].text.strip()
    return msg


def run_chatbot(input, max_tokens, temperature, frequency_penalty,
                presence_penalty, gr_state=None):
    """Handle a text message from the UI.

    :param input: The user's message (name kept for interface compatibility).
    :param max_tokens: Maximum length of the generated reply.
    :param temperature: Sampling temperature.
    :param frequency_penalty: Frequency penalty passed to the API.
    :param presence_penalty: Presence penalty passed to the API.
    :param gr_state: ``[history, conversation]``; *history* is what is sent
        to GPT, *conversation* is what the chatbox displays.
    :returns: ``(conversation, [history, conversation])`` for the chatbox and
        the state component respectively.
    """
    if gr_state is None:
        # Avoid a shared mutable default; start from an empty session.
        gr_state = [[], []]
    history, conversation = gr_state[0], gr_state[1]

    output = chat_with_GPT(input, history, max_tokens, temperature,
                           frequency_penalty, presence_penalty)
    history.append((input, output))
    conversation.append((input, output))

    # chatbox, state
    return conversation, [history, conversation]


def run_chatbot_with_img(input_img, max_tokens, temperature,
                         frequency_penalty, presence_penalty, gr_state=None):
    """Handle an image submission from the UI.

    The image is classified by Bamboo and the predicted label is turned into
    a prompt asking GPT for the definition of that label.

    :param input_img: The image provided by the Gradio image component, or
        ``None`` when nothing was uploaded.
    :param max_tokens: Maximum length of the generated reply.
    :param temperature: Sampling temperature.
    :param frequency_penalty: Frequency penalty passed to the API.
    :param presence_penalty: Presence penalty passed to the API.
    :param gr_state: ``[history, conversation]``; see :func:`run_chatbot`.
    :returns: ``(conversation, [history, conversation])``.
    """
    if gr_state is None:
        gr_state = [[], []]
    history, conversation = gr_state[0], gr_state[1]

    if input_img is None:
        # Nothing to classify; leave the session untouched instead of
        # crashing inside the preprocessing pipeline.
        return conversation, [history, conversation]

    # TODO: save img and show in conversation
    img_cls = recognize_image(input_img)
    input = 'I have given you a photo about ' + img_cls + ', and tell me its definition.'
    output = chat_with_GPT(input, history, max_tokens, temperature,
                           frequency_penalty, presence_penalty)

    # The chatbox shows a placeholder instead of the generated prompt ...
    input_mask = 'Upload image'
    conversation.append((input_mask, output))
    # ... while the history sent to GPT keeps the real prompt.
    history.append((input, output))

    # chatbox, gr_state
    return conversation, [history, conversation]


def save_img(image):
    """Save *image* as a uniquely named PNG in the working directory.

    :param image: An object exposing ``save(filename)`` (e.g. a PIL image).
    :returns: The path of the file that was written.
    """
    # mkstemp is the public, race-free way to reserve a unique file name.
    fd, filename = tempfile.mkstemp(suffix='.png', dir=os.curdir)
    os.close(fd)
    print(filename)
    image.save(filename)
    return filename


def recognize_image(image):
    """Classify *image* with Bamboo and return the top-1 label name.

    :param image: Image accepted by ``eval_transforms``.
    :returns: The human-readable name of the most probable class.
    """
    img_t = eval_transforms(image)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))
    prediction = output.softmax(-1).flatten()

    _, top5_idx = torch.topk(prediction, 5)
    idx_max = top5_idx.tolist()[0]
    label = id2name[str(idx_max)][0]

    print(label)
    print(float(prediction[idx_max]))
    return label


def reset():
    """Return an empty chatbox value and a fresh ``[history, conversation]``."""
    return [], [[], []]


# NOTE(review): the original file's indentation was lost, so the nesting of
# the Row/Column containers below is a reconstruction - confirm the layout.
with gr.Blocks() as demo:
    gr.HTML("""

Bamboo

Bamboo for Image Recognition Demo. Bamboo knows what this object is and what you are doing in a very fine-grain granularity: fratercula arctica (fig.5) and dribbler (fig.2)).

Paper: https://arxiv.org/abs/2203.07845
Project Website: https://opengvlab.shlab.org.cn/bamboo/home
Code and Model: https://github.com/ZhangYuanhan-AI/Bamboo
Tips: """)

    # history for GPT, conversation for chatbox
    gr_state = gr.State([[], []])
    chatbot = gr.Chatbot(elem_id="chatbot", label="Bamboo Chatbot")
    text_input = gr.Textbox(label="Message", placeholder="Send a message")
    image = gr.inputs.Image()

    with gr.Row():
        submit_btn = gr.Button("Submit Text", interactive=True, variant='primary')
        reset_btn = gr.Button("Reset All")
        submit_btn_img = gr.Button("Submit Img", interactive=True, variant='primary')
        # The upload handler for this button is not wired up yet; images are
        # submitted through the image component and "Submit Img" instead.
        image_btn = gr.UploadButton("Upload Image", file_types=["image"])

    with gr.Column(scale=0.3, min_width=400):
        max_tokens = gr.Number(
            minimum=500, maximum=2000, value=1000, precision=1,
            interactive=True, label="Maximum length of generated text")
        temperature = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.0,
            interactive=True, label="Diversity of generated text")
        frequency_penalty = gr.Slider(
            minimum=-2.0, maximum=2.0, value=0.5, step=0.1,
            interactive=True,
            label="Frequency of generation of repeat tokens")
        presence_penalty = gr.Slider(
            minimum=-2.0, maximum=2.0, value=0.0, step=0.1,
            interactive=True,
            label="Frequency of generation of tokens independent of the given prefix")

    # Inputs shared by both text-submission triggers.
    chat_inputs = [text_input, max_tokens, temperature, frequency_penalty,
                   presence_penalty, gr_state]

    # Pressing Enter and clicking "Submit Text" do the same thing: send the
    # message, then clear the textbox.
    text_input.submit(fn=run_chatbot, inputs=chat_inputs,
                      outputs=[chatbot, gr_state])
    text_input.submit(lambda: "", None, text_input)
    submit_btn.click(fn=run_chatbot, inputs=chat_inputs,
                     outputs=[chatbot, gr_state])
    submit_btn.click(lambda: "", None, text_input)

    reset_btn.click(fn=reset, inputs=[], outputs=[chatbot, gr_state])

    submit_btn_img.click(
        run_chatbot_with_img,
        [image, max_tokens, temperature, frequency_penalty,
         presence_penalty, gr_state],
        [chatbot, gr_state])

demo.launch(debug = True)