{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c0b71431"
},
"source": [
"## Environment Setup\n",
"Adapted from the [SAM automatic mask generator Colab notebook](https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb#scrollTo=MTeAdX_mHwAR)."
],
"id": "c0b71431"
},
{
"cell_type": "markdown",
"metadata": {
"id": "47e5a78f"
},
"source": [
"\n",
"\n",
"If you're running this notebook locally in Jupyter, clone `SAM-Med2D` into a directory named `SAM_Med2D`. Do **not** install `segment_anything` in your local environment: `SAM-Med2D` bundles its own copy, and the two packages share function names that can conflict.\n",
"\n",
"For Google Colab users: make sure `using_colab = True` in the cell below before running it. Selecting 'GPU' under 'Edit' -> 'Notebook Settings' -> 'Hardware Accelerator' is optional; this notebook also runs efficiently on a CPU.\n",
"\n"
],
"id": "47e5a78f"
},
{
"cell_type": "code",
"source": [
"# True: run the Colab-specific set-up in the next cell; False: local Jupyter.\n",
"using_colab = True"
],
"metadata": {
"id": "K48skt_hFHCx"
},
"id": "K48skt_hFHCx",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import sys\n",
"\n",
"# Clone only if the folder is missing so this cell can be re-run safely.\n",
"![ -d SAM_Med2D ] || git clone https://github.com/uni-medical/SAM-Med2D.git SAM_Med2D\n",
"\n",
"# `segment_anything` is intentionally not installed: SAM_Med2D bundles its own copy\n",
"# and the two share function names that would conflict.\n",
"if using_colab:\n",
"    import torch\n",
"    import torchvision\n",
"    print(\"PyTorch version:\", torch.__version__)\n",
"    print(\"Torchvision version:\", torchvision.__version__)\n",
"    print(\"CUDA is available:\", torch.cuda.is_available())\n",
"    !{sys.executable} -m pip install opencv-python matplotlib\n",
"else:\n",
"    # NOTE(review): pins a CUDA 11.3 PyTorch build; adjust to your local CUDA/PyTorch setup.\n",
"    # `sys.executable -m pip` installs into the kernel's interpreter, not whichever pip is on PATH.\n",
"    !{sys.executable} -m pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html\n",
"    !{sys.executable} -m pip install opencv-python matplotlib"
],
"metadata": {
"id": "vVVsJtIuFCsv",
"outputId": "00fa489f-89b1-45fa-9ab6-1b8bc9dad9c0",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "vVVsJtIuFCsv",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Download Weights\n",
"Adapted from the [SAMed Colab notebook](https://colab.research.google.com/drive/1KCS5ulpZasYl9DgJJn59WsGEB8vwSI_m?usp=sharing#scrollTo=NI9jWQnsPty2)."
],
"metadata": {
"id": "YUU2BAiwFpKE"
},
"id": "YUU2BAiwFpKE"
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"ROOT_DIR = '.'\n",
"CODE_DIR = f\"{ROOT_DIR}/SAM_Med2D\"\n",
"MODEL_DIR = f\"{CODE_DIR}/pretrain_model\"\n",
"os.makedirs(MODEL_DIR, exist_ok=True)  # exist_ok: re-running the cell must not fail\n",
"\n",
"# PyDrive authentication goes through `google.colab`, so only enable it on Colab.\n",
"download_with_pydrive = using_colab\n",
"\n",
"\n",
"class Downloader(object):\n",
"    \"\"\"Download checkpoint files from Google Drive into `save_dir`.\n",
"\n",
"    With `use_pydrive=True` the Colab user is authenticated once at construction and\n",
"    files are fetched through PyDrive. With `use_pydrive=False` nothing is fetched:\n",
"    `download_file` only succeeds if the file is already present in `save_dir`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, use_pydrive, save_dir='.'):\n",
"        self.use_pydrive = use_pydrive\n",
"        self.save_dir = save_dir\n",
"        self.drive = None  # set by authenticate()\n",
"        if self.use_pydrive:\n",
"            self.authenticate()\n",
"\n",
"    def authenticate(self):\n",
"        \"\"\"Authenticate the Colab user and create the PyDrive client `self.drive`.\"\"\"\n",
"        # Imported here rather than at the top of the cell so that running the cell\n",
"        # outside Colab (use_pydrive=False) does not fail on these imports.\n",
"        from google.colab import auth\n",
"        from oauth2client.client import GoogleCredentials\n",
"        from pydrive.auth import GoogleAuth\n",
"        from pydrive.drive import GoogleDrive\n",
"\n",
"        auth.authenticate_user()\n",
"        gauth = GoogleAuth()\n",
"        gauth.credentials = GoogleCredentials.get_application_default()\n",
"        self.drive = GoogleDrive(gauth)\n",
"\n",
"    def download_file(self, file_id, file_name):\n",
"        \"\"\"Save Drive file `file_id` as `save_dir/file_name`, skipping files that already exist.\n",
"\n",
"        Raises RuntimeError if the file is missing and PyDrive is not authenticated.\n",
"        \"\"\"\n",
"        file_dst = os.path.join(self.save_dir, file_name)\n",
"        if os.path.exists(file_dst):\n",
"            print(f'{file_name} already exists')\n",
"            return\n",
"        if self.drive is None:\n",
"            raise RuntimeError(\n",
"                f'{file_name} not found and PyDrive is disabled: download '\n",
"                f'https://drive.google.com/file/d/{file_id}/view manually and save it as {file_dst}')\n",
"        downloaded = self.drive.CreateFile({'id': file_id})\n",
"        downloaded.FetchMetadata(fetch_all=True)\n",
"        downloaded.GetContentFile(file_dst)\n",
"\n",
"\n",
"downloader = Downloader(download_with_pydrive, MODEL_DIR)\n",
"\n",
"sam_med2d_model = {'id': '1ARiB5RkSsWmAB_8mqWnwDF8ZKTtFwsjl', 'name': 'sam-med2d_b.pth'}\n",
"downloader.download_file(file_id=sam_med2d_model['id'], file_name=sam_med2d_model['name'])\n",
"\n",
"# Other checkpoints that can be fetched the same way (disabled):\n",
"# samed_model = {'id': '1P0Bm-05l-rfeghbrT1B62v5eN-3A-uOr', 'name': 'epoch_159.pth'}\n",
"# medsam_model = {'id': '1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_', 'name': 'medsam_vit_b.pth'}\n",
"# downloader.download_file(file_id=samed_model['id'], file_name=samed_model['name'])\n",
"# downloader.download_file(file_id=medsam_model['id'], file_name=medsam_model['name'])\n",
"# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P $MODEL_DIR\n"
],
"metadata": {
"id": "74gPg2AwFtPK"
},
"id": "74gPg2AwFtPK",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "b4a4b25c",
"metadata": {
"id": "b4a4b25c"
},
"source": [
"# Predicting object masks from prompts with SAM-Med2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69b28288",
"metadata": {
"id": "69b28288"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"\n",
"# Add ROOT_DIR (the folder that holds the SAM_Med2D clone) to the module search path\n",
"# so that `import SAM_Med2D.segment_anything` works.\n",
"sys.path.append(ROOT_DIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29bc90d5",
"metadata": {
"id": "29bc90d5"
},
"outputs": [],
"source": [
"def show_mask(mask, ax, random_color=False):\n",
"    \"\"\"Overlay `mask` on `ax` as a semi-transparent (alpha 0.6) colour layer.\n",
"\n",
"    `mask` must hold exactly H*W values, where (H, W) are its last two dimensions,\n",
"    e.g. shape (H, W) or (1, H, W). The colour is a fixed blue unless\n",
"    `random_color=True`, which draws a random RGB colour.\n",
"    \"\"\"\n",
"    mask = np.asarray(mask)  # also accept array-likes such as nested lists\n",
"    if random_color:\n",
"        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
"    else:\n",
"        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
"    h, w = mask.shape[-2:]\n",
"    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
"    ax.imshow(mask_image)\n",
"\n",
"\n",
"def show_points(coords, labels, ax, marker_size=375):\n",
"    \"\"\"Scatter point prompts on `ax`: label 1 as green stars, label 0 as red stars.\n",
"\n",
"    `coords` is an (N, 2) array of (x, y) positions and `labels` a length-N array;\n",
"    points whose label is neither 0 nor 1 are not drawn.\n",
"    \"\"\"\n",
"    # Coerce to arrays: with a plain list, `labels == 1` is a scalar False and\n",
"    # nothing would be drawn.\n",
"    coords = np.asarray(coords)\n",
"    labels = np.asarray(labels)\n",
"    pos_points = coords[labels == 1]\n",
"    neg_points = coords[labels == 0]\n",
"    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
"    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
"\n",
"\n",
"def show_box(box, ax):\n",
"    \"\"\"Draw box `(x0, y0, x1, y1)` on `ax` as an unfilled green rectangle.\"\"\"\n",
"    x0, y0 = box[0], box[1]\n",
"    w, h = box[2] - box[0], box[3] - box[1]\n",
"    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))\n"
]
},
{
"cell_type": "markdown",
"id": "23842fb2",
"metadata": {
"id": "23842fb2"
},
"source": [
"## Example image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2e4f6b",
"metadata": {
"id": "3c2e4f6b",
"outputId": "16670f19-01ac-472c-f535-8f3bf4fdb767",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(290, 320, 3)"
]
},
"metadata": {},
"execution_count": 10
}
],
"source": [
"# Enter the repo so relative paths such as data_demo/ resolve. The isdir guard skips\n",
"# the chdir on re-runs, when the cwd is already the repo (where ./SAM_Med2D is normally absent).\n",
"if os.path.isdir(CODE_DIR):\n",
"    os.chdir(CODE_DIR)\n",
"image_path = 'data_demo/images/amos_0507_31.png'\n",
"image = cv2.imread(image_path)  # OpenCV returns BGR channel order, or None if unreadable\n",
"assert image is not None, f'could not read {image_path} from {os.getcwd()}'\n",
"image.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e30125fd",
"metadata": {
"scrolled": false,
"id": "e30125fd",
"outputId": "9f6444aa-6f6c-48e9-c026-71005ba5fe7e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 767
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"