{ "cells": [ { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initializing ViT model as backbone using ckpt: ./Bamboo_v0-1_ViT-B16.pth.tar.convert\n" ] } ], "source": [
"%matplotlib inline\n",
"import argparse\n",
"import requests\n",
"import gradio as gr\n",
"import numpy as np\n",
"import cv2\n",
"import torch\n",
"import torch.nn as nn\n",
"from PIL import Image\n",
"import torchvision\n",
"from torchvision import transforms\n",
"from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n",
"from timm.data import create_transform\n",
"import openai\n",
"from timmvit import timmvit\n",
"import json\n",
"import os\n",
"from timm.models.hub import download_cached_file\n",
"from PIL import Image\n",
"import tempfile\n",
"\n",
"# Key for GPT: read it from the environment instead of committing a secret to the notebook.\n",
"# NOTE(review): the key that used to be hard-coded here is in the file history and must be revoked.\n",
"openai.api_key = os.environ.get('OPENAI_API_KEY')\n",
"\n",
"\n",
"def pil_loader(filepath):\n",
"    \"\"\"Open the image file at `filepath` and return it converted to RGB.\"\"\"\n",
"    with Image.open(filepath) as img:\n",
"        img = img.convert('RGB')\n",
"    return img\n",
"\n",
"\n",
"def build_transforms(input_size, center_crop=True):\n",
"    \"\"\"Build the evaluation pipeline: ndarray/tensor image -> normalized CHW float tensor.\n",
"\n",
"    :param input_size: side length in pixels of the square model input.\n",
"    :param center_crop: if True (default), resize the short side to input_size * 8 // 7\n",
"        and center-crop to input_size; if False, resize directly to input_size x input_size.\n",
"    :returns: a torchvision Compose transform.\n",
"    \"\"\"\n",
"    if center_crop:\n",
"        resize_steps = [\n",
"            torchvision.transforms.Resize(input_size * 8 // 7),\n",
"            torchvision.transforms.CenterCrop(input_size),\n",
"        ]\n",
"    else:\n",
"        resize_steps = [torchvision.transforms.Resize((input_size, input_size))]\n",
"    return torchvision.transforms.Compose([\n",
"        # ToPILImage expects an ndarray or a tensor, not a PIL image.\n",
"        torchvision.transforms.ToPILImage(),\n",
"        *resize_steps,\n",
"        torchvision.transforms.ToTensor(),\n",
"        torchvision.transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n",
"    ])\n",
"\n",
"\n",
"# Human-readable Bamboo labels, looked up below as id2name[str(class_index)][0].\n",
"with open('./trainid2name.json') as f:\n",
"    id2name = json.load(f)\n",
"\n",
"# Build the Bamboo ViT-B/16 classifier and switch it to inference mode.\n",
"model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')\n",
"model.eval()\n",
"\n",
"\n",
"# Adapted from https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py\n",
"def show_cam_on_image(img: np.ndarray,\n",
"                      mask: np.ndarray,\n",
"                      use_rgb: bool = False,\n",
"                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:\n",
"    \"\"\"Overlay a CAM mask on an image as a heatmap.\n",
"\n",
"    By default the heatmap is in BGR format.\n",
"    :param img: the base image in RGB or BGR format, float values in [0, 1].\n",
"    :param mask: the CAM mask; it is scaled by 255, so values are expected in [0, 1].\n",
"    :param use_rgb: produce an RGB heatmap; set this to True if `img` is RGB.\n",
"    :param colormap: the OpenCV colormap to use.\n",
"    :returns: uint8 image of 0.7 * heatmap + 0.3 * img, scaled to [0, 255].\n",
"    :raises Exception: if `img` has values above 1.\n",
"    \"\"\"\n",
"    # Validate first so a bad image fails before any colour-mapping work.\n",
"    if np.max(img) > 1:\n",
"        raise Exception(\n",
"            'The input image should be np.float32 in the range [0, 1]')\n",
"\n",
"    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)\n",
"    if use_rgb:\n",
"        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n",
"    heatmap = np.float32(heatmap) / 255\n",
"\n",
"    # Fixed 70/30 heatmap/image blend, without re-normalising by the maximum.\n",
"    cam = 0.7 * heatmap + 0.3 * img\n",
"    return np.uint8(255 * cam)\n",
"\n",
"\n",
"def chat_with_GPT(my_prompt, history, *args):\n",
"    \"\"\"Send the replayed conversation plus `my_prompt` to the OpenAI completion API.\n",
"\n",
"    :param my_prompt: the new message to answer.\n",
"    :param history: sequence of (prompt, reply) pairs; every string is replayed on its own line.\n",
"    :param args: generation settings in this order:\n",
"        (max_tokens, temperature, frequency_penalty, presence_penalty).\n",
"    :returns: the completion text with surrounding whitespace stripped.\n",
"    \"\"\"\n",
"    max_tokens, temperature, frequency_penalty, presence_penalty = args[:4]\n",
"    this_history = ''.join(turn + '\\n' for pair in history for turn in pair)\n",
"    my_resp = openai.Completion.create(\n",
"        model='text-davinci-003',\n",
"        prompt=this_history + my_prompt,\n",
"        temperature=temperature,\n",
"        # gr.Number may deliver a float such as 1000.0; the API expects an integer token count.\n",
"        max_tokens=int(max_tokens),\n",
"        frequency_penalty=frequency_penalty,\n",
"        presence_penalty=presence_penalty,\n",
"    )\n",
"    return my_resp.choices[0].text.strip()\n",
"\n",
"\n",
"def run_chatbot(input, max_tokens, temperature, frequency_penalty, presence_penalty, gr_state=None):\n",
"    \"\"\"Answer a text message with GPT and record the exchange.\n",
"\n",
"    :param input: the user's message; blank messages are ignored.\n",
"    :param max_tokens, temperature, frequency_penalty, presence_penalty: forwarded to chat_with_GPT.\n",
"    :param gr_state: [history, conversation] -- `history` is replayed to GPT as context and\n",
"        `conversation` is what the chatbox displays; both lists are appended in place.\n",
"    :returns: (conversation for the chatbox, updated gr_state).\n",
"    \"\"\"\n",
"    if gr_state is None:\n",
"        # Fresh lists per call instead of a shared mutable default argument.\n",
"        gr_state = [[], []]\n",
"    history, conversation = gr_state[0], gr_state[1]\n",
"    if not input or not input.strip():\n",
"        # Nothing to send: skip the API call and keep the chat unchanged.\n",
"        return conversation, [history, conversation]\n",
"    output = chat_with_GPT(input, history, max_tokens, temperature, frequency_penalty, presence_penalty)\n",
"    history.append((input, output))\n",
"    conversation.append((input, output))\n",
"    return conversation, [history, conversation]\n",
"\n",
"\n",
"def run_chatbot_with_img(input_img, max_tokens, temperature, frequency_penalty, presence_penalty, gr_state=None):\n",
"    \"\"\"Classify an image with Bamboo, then ask GPT to define the predicted class.\n",
"\n",
"    :param input_img: whatever the triggering component delivers (presumably an array from\n",
"        the image widget, or an uploaded-file object from the upload button); see recognize_image.\n",
"        None (nothing provided yet) leaves the state untouched.\n",
"    :param max_tokens, temperature, frequency_penalty, presence_penalty: forwarded to chat_with_GPT.\n",
"    :param gr_state: [history, conversation]; both lists are appended in place.\n",
"    :returns: (conversation for the chatbox, updated gr_state).\n",
"    \"\"\"\n",
"    if gr_state is None:\n",
"        gr_state = [[], []]\n",
"    history, conversation = gr_state[0], gr_state[1]\n",
"    if input_img is None:\n",
"        # 'Submit Img' was clicked before an image was provided: nothing to classify.\n",
"        return conversation, [history, conversation]\n",
"\n",
"    # TODO: save img and show in conversation\n",
"    img_cls = recognize_image(input_img)\n",
"    prompt = 'I have given you a photo about ' + img_cls + ', and tell me its definition.'\n",
"    output = chat_with_GPT(prompt, history, max_tokens, temperature, frequency_penalty, presence_penalty)\n",
"\n",
"    # The chatbox shows a placeholder for the user turn; GPT's history keeps the real prompt.\n",
"    conversation.append(('Upload image', output))\n",
"    history.append((prompt, output))\n",
"    return conversation, [history, conversation]\n",
"\n",
"\n",
"def save_img(image: Image.Image):\n",
"    \"\"\"Save `image` as a uniquely named PNG in the working directory and return its path.\"\"\"\n",
"    # mkstemp creates the file atomically instead of relying on the private\n",
"    # tempfile._get_candidate_names(), whose names could race with another writer.\n",
"    fd, filename = tempfile.mkstemp(suffix='.png', dir='.')\n",
"    os.close(fd)\n",
"    image.save(filename)\n",
"    return filename\n",
"\n",
"\n",
"def _to_rgb_array(image):\n",
"    \"\"\"Turn an image input into an RGB ndarray that eval_transforms can consume.\n",
"\n",
"    Accepts an ndarray (passed through), a PIL image, a file path, or an uploaded-file\n",
"    object whose `.name` is assumed to be the on-disk path of the upload.\n",
"    :raises ValueError: if `image` is None.\n",
"    \"\"\"\n",
"    if image is None:\n",
"        raise ValueError('recognize_image expects an image, got None')\n",
"    if isinstance(image, str):\n",
"        image = pil_loader(image)\n",
"    elif not isinstance(image, (np.ndarray, Image.Image)) and hasattr(image, 'name'):\n",
"        # Uploaded-file wrapper: load the saved file from disk.\n",
"        image = pil_loader(image.name)\n",
"    if isinstance(image, Image.Image):\n",
"        image = np.asarray(image.convert('RGB'))\n",
"    return image\n",
"\n",
"\n",
"def recognize_image(image):\n",
"    \"\"\"Classify `image` with the Bamboo model and return the top-1 class name.\"\"\"\n",
"    img_t = eval_transforms(_to_rgb_array(image))\n",
"    # Inference only: no_grad skips building an autograd graph for every request.\n",
"    with torch.no_grad():\n",
"        output = model(img_t.unsqueeze(0))\n",
"    prediction = output.softmax(-1).flatten()\n",
"    idx_max = int(prediction.argmax())\n",
"    return id2name[str(idx_max)][0]\n",
"\n",
"\n",
"def reset():\n",
"    \"\"\"Clear the chatbox and return a fresh [history, conversation] state.\"\"\"\n",
"    return [], [[], []]\n",
"\n",
"\n",
"eval_transforms = build_transforms(224)"
] }, { "cell_type": "code", "execution_count": null, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "D:\\Anaconda\\envs\\pytorch\\lib\\site-packages\\gradio\\inputs.py:257: UserWarning: Usage of gradio.inputs is deprecated, and will 
not be supported in the future, please import your component from gradio.components\n", " warnings.warn(\n", "D:\\Anaconda\\envs\\pytorch\\lib\\site-packages\\gradio\\deprecation.py:40: UserWarning: `optional` parameter is deprecated, and it has no effect\n", " warnings.warn(value)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/plain": "", "text/html": "
" }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "panda\n", "0.3023892641067505\n" ] } ], "source": [
"import openai\n",
"import os\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.HTML(\"\"\"\n", "

Bamboo

\n", "

Bamboo for Image Recognition Demo. Bamboo knows what this object is and what you are doing at a very fine granularity: fratercula arctica (fig.5) and dribbler (fig.2).

\n", " Paper: https://arxiv.org/abs/2203.07845
\n", " Project Website: https://opengvlab.shlab.org.cn/bamboo/home
\n", " Code and Model: https://github.com/ZhangYuanhan-AI/Bamboo
\n", " Tips:\n",
"    \"\"\")\n",
"    # gr_state holds [history replayed to GPT, conversation shown in the chatbox].\n",
"    gr_state = gr.State([[], []])\n",
"\n",
"    chatbot = gr.Chatbot(elem_id='chatbot', label='Bamboo Chatbot')\n",
"    text_input = gr.Textbox(label='Message', placeholder='Send a message')\n",
"    # gr.Image replaces the deprecated gr.inputs.Image (see the UserWarning in this cell's output).\n",
"    image = gr.Image()\n",
"    with gr.Row():\n",
"        submit_btn = gr.Button('Submit Text', interactive=True, variant='primary')\n",
"        reset_btn = gr.Button('Reset All')\n",
"        submit_btn_img = gr.Button('Submit Img', interactive=True, variant='primary')\n",
"        image_btn = gr.UploadButton('Upload Image', file_types=['image'])\n",
"    # NOTE(review): the saved notebook lost its indentation; confirm whether this column\n",
"    # was meant to sit inside the Row above.\n",
"    with gr.Column(scale=0.3, min_width=400):\n",
"        # precision=0 makes gr.Number return an int, which max_tokens needs.\n",
"        max_tokens = gr.Number(\n",
"            value=1000, precision=0, interactive=True, label='Maximum length of generated text')\n",
"        temperature = gr.Slider(\n",
"            minimum=0.0, maximum=1.0, value=0.0, interactive=True, label='Diversity of generated text')\n",
"        frequency_penalty = gr.Slider(\n",
"            minimum=-2.0, maximum=2.0, value=0.5, step=0.1, interactive=True,\n",
"            label='Frequency of generation of repeat tokens')\n",
"        presence_penalty = gr.Slider(\n",
"            minimum=-2.0, maximum=2.0, value=0.0, step=0.1, interactive=True,\n",
"            label='Frequency of generation of tokens independent of the given prefix')\n",
"\n",
"    # Generation settings followed by the session state; every chat handler takes them last.\n",
"    gen_inputs = [max_tokens, temperature, frequency_penalty, presence_penalty, gr_state]\n",
"    chat_outputs = [chatbot, gr_state]\n",
"\n",
"    # Text: answer the message, then clear the textbox.\n",
"    text_input.submit(fn=run_chatbot, inputs=[text_input] + gen_inputs, outputs=chat_outputs)\n",
"    text_input.submit(lambda: '', None, text_input)\n",
"    submit_btn.click(fn=run_chatbot, inputs=[text_input] + gen_inputs, outputs=chat_outputs)\n",
"    submit_btn.click(lambda: '', None, text_input)\n",
"    reset_btn.click(fn=reset, inputs=[], outputs=chat_outputs)\n",
"    # Images: classify with Bamboo, then ask GPT about the predicted class.\n",
"    submit_btn_img.click(run_chatbot_with_img, [image] + gen_inputs, chat_outputs)\n",
"    image_btn.upload(run_chatbot_with_img, [image_btn] + gen_inputs, chat_outputs)\n",
"\n",
"demo.launch(debug=True)"
], "metadata": { "collapsed": false, "pycharm": { "is_executing": true } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }