File size: 2,970 Bytes
6debff8
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: image_segmentation\n",
    "### Simple image segmentation using gradio's AnnotatedImage component.\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q gradio numpy "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    # Labels are assigned in order: boxes first, then segments. Both sliders\n",
    "    # top out at 5, so these 10 labels are always enough.\n",
    "    section_labels = [\n",
    "        \"apple\",\n",
    "        \"banana\",\n",
    "        \"carrot\",\n",
    "        \"donut\",\n",
    "        \"eggplant\",\n",
    "        \"fish\",\n",
    "        \"grapes\",\n",
    "        \"hamburger\",\n",
    "        \"ice cream\",\n",
    "        \"juice\",\n",
    "    ]\n",
    "\n",
    "    with gr.Row():\n",
    "        num_boxes = gr.Slider(0, 5, 2, step=1, label=\"Number of boxes\")\n",
    "        num_segments = gr.Slider(0, 5, 1, step=1, label=\"Number of segments\")\n",
    "\n",
    "    with gr.Row():\n",
    "        img_input = gr.Image()\n",
    "        img_output = gr.AnnotatedImage(\n",
    "            color_map={\"banana\": \"#a89a00\", \"carrot\": \"#ffae00\"}\n",
    "        )\n",
    "\n",
    "    section_btn = gr.Button(\"Identify Sections\")\n",
    "    selected_section = gr.Textbox(label=\"Selected Section\")\n",
    "\n",
    "    def section(img, num_boxes, num_segments):\n",
    "        \"\"\"Annotate `img` with random boxes and random circular soft masks.\n",
    "\n",
    "        Args:\n",
    "            img: Image array exposing `.shape` (rows, columns, ...), or None\n",
    "                when no image has been provided.\n",
    "            num_boxes: Number of random bounding boxes to generate.\n",
    "            num_segments: Number of random circular masks to generate.\n",
    "\n",
    "        Returns:\n",
    "            A tuple `(img, sections)` where each section is either\n",
    "            `((x1, y1, x2, y2), label)` or `(mask, label)`; `mask` is a float\n",
    "            array with the image's height and width.\n",
    "\n",
    "        Raises:\n",
    "            gr.Error: If `img` is None.\n",
    "        \"\"\"\n",
    "        if img is None:\n",
    "            raise gr.Error(\"Please upload an image before identifying sections.\")\n",
    "\n",
    "        # range() needs ints; coerce in case a slider delivers a float.\n",
    "        num_boxes, num_segments = int(num_boxes), int(num_segments)\n",
    "        height, width = img.shape[:2]\n",
    "        sections = []\n",
    "\n",
    "        for box_idx in range(num_boxes):\n",
    "            x = random.randint(0, width)\n",
    "            y = random.randint(0, height)\n",
    "            w = random.randint(0, width - x)\n",
    "            h = random.randint(0, height - y)\n",
    "            sections.append(((x, y, x + w, y + h), section_labels[box_idx]))\n",
    "\n",
    "        # Open grids broadcast to (height, width), replacing a per-pixel loop.\n",
    "        rows, cols = np.ogrid[:height, :width]\n",
    "        for seg_idx in range(num_segments):\n",
    "            x = random.randint(0, width)\n",
    "            y = random.randint(0, height)\n",
    "            # Largest radius that keeps the circle inside the image.\n",
    "            r = random.randint(0, min(x, y, width - x, height - y))\n",
    "            mask = np.zeros((height, width))\n",
    "            if r > 0:  # r == 0 is an empty circle; the guard also avoids dividing by zero\n",
    "                dist_square = (rows - y) ** 2 + (cols - x) ** 2\n",
    "                # Value falls from 1 at the center to 0 at the rim, in steps of 0.25.\n",
    "                falloff = np.round((r**2 - dist_square) / r**2 * 4) / 4\n",
    "                mask = np.where(dist_square < r**2, falloff, 0.0)\n",
    "            sections.append((mask, section_labels[seg_idx + num_boxes]))\n",
    "        return (img, sections)\n",
    "\n",
    "    section_btn.click(section, [img_input, num_boxes, num_segments], img_output)\n",
    "\n",
    "    def select_section(evt: gr.SelectData):\n",
    "        \"\"\"Return the label at position `evt.index` in `section_labels`.\"\"\"\n",
    "        # section() hands out labels in section_labels order, so the position of\n",
    "        # a section in its returned list is also the position of its label.\n",
    "        return section_labels[evt.index]\n",
    "\n",
    "    img_output.select(select_section, None, selected_section)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}