{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/javascript": "IPython.notebook.set_autosave_interval(300000)" }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Autosaving every 300 seconds\n" ] } ], "source": [ "%autosave 300\n", "%load_ext autoreload\n", "%autoreload 2\n", "%reload_ext autoreload\n", "%config Completer.use_jedi = False" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\n" ] } ], "source": [ "\n", "import os\n", "\n", "os.chdir(\"..\")\n", "print(os.getcwd())" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import shutil\n", "from pathlib import Path\n", "import torch\n", "import lightning as L\n", "from lightning.pytorch.loggers import Logger\n", "from typing import List\n", "from src.datamodules.catdog_datamodule import CatDogImageDataModule\n", "from src.utils.logging_utils import setup_logger, task_wrapper\n", "from loguru import logger\n", "from dotenv import load_dotenv, find_dotenv\n", "import rootutils\n", "import hydra\n", "from omegaconf import DictConfig, OmegaConf\n", "from lightning.pytorch.callbacks import (\n", "    ModelCheckpoint,\n", "    EarlyStopping,\n", "    RichModelSummary,\n", "    RichProgressBar,\n", ")\n", "from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Load environment variables\n", "load_dotenv(find_dotenv(\".env\"))\n", "\n", "# Setup root directory.\n", "# `__file__` is not defined inside a notebook kernel, so search for the\n", "# `.project-root` marker starting from the current working directory\n", "# (the earlier `os.chdir(\"..\")` cell moves the kernel to the repository root).\n", "# PROJECT_ROOT must be set because the config resolves `${oc.env:PROJECT_ROOT}`.\n", "try:\n", "    root = rootutils.setup_root(\n", "        os.getcwd(), indicator=\".project-root\", project_root_env_var=True\n", "    )\n", "except Exception as e:\n", "    # Fall back to the current directory if the marker cannot be found.\n", "    logger.error(e)\n", "    root = Path(os.getcwd())\n", "    os.environ[\"PROJECT_ROOT\"] = str(root)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def load_checkpoint_if_available(ckpt_path: str | None) -> str | None:\n", "    \"\"\"Check if the specified checkpoint exists and return the valid checkpoint path.\"\"\"\n", "    if ckpt_path and Path(ckpt_path).exists():\n", "        logger.info(f\"Checkpoint found: {ckpt_path}\")\n", "        return ckpt_path\n", "    else:\n", "        logger.warning(\n", "            f\"No checkpoint found at {ckpt_path}. 
Using current model weights.\"\n", " )\n", " return None\n", "\n", "\n", "def clear_checkpoint_directory(ckpt_dir: str):\n", " \"\"\"Clear all contents of the checkpoint directory without deleting the directory itself.\"\"\"\n", " ckpt_dir_path = Path(ckpt_dir)\n", " if ckpt_dir_path.exists() and ckpt_dir_path.is_dir():\n", " logger.info(f\"Clearing checkpoint directory: {ckpt_dir}\")\n", " # Iterate over all files and directories in the checkpoint directory and remove them\n", " for item in ckpt_dir_path.iterdir():\n", " try:\n", " if item.is_file() or item.is_symlink():\n", " item.unlink() # Remove file or symlink\n", " elif item.is_dir():\n", " shutil.rmtree(item) # Remove directory\n", " except Exception as e:\n", " logger.error(f\"Failed to delete {item}: {e}\")\n", " logger.info(f\"Checkpoint directory cleared: {ckpt_dir}\")\n", " else:\n", " logger.info(\n", " f\"Checkpoint directory does not exist. Creating directory: {ckpt_dir}\"\n", " )\n", " os.makedirs(ckpt_dir_path, exist_ok=True)\n", "\n", "\n", "@task_wrapper\n", "def train_module(\n", " cfg: DictConfig,\n", " data_module: L.LightningDataModule,\n", " model: L.LightningModule,\n", " trainer: L.Trainer,\n", "):\n", " \"\"\"Train the model using the provided Trainer and DataModule.\"\"\"\n", " logger.info(\"Training the model\")\n", " trainer.fit(model, data_module)\n", " train_metrics = trainer.callback_metrics\n", " try:\n", " logger.info(\n", " f\"Training completed with the following metrics- train_acc: {train_metrics['train_acc'].item()} and val_acc: {train_metrics['val_acc'].item()}\"\n", " )\n", " except KeyError:\n", " logger.info(f\"Training completed with the following metrics:{train_metrics}\")\n", "\n", " return train_metrics\n", "\n", "\n", "@task_wrapper\n", "def run_test_module(\n", " cfg: DictConfig,\n", " datamodule: L.LightningDataModule,\n", " model: L.LightningModule,\n", " trainer: L.Trainer,\n", "):\n", " \"\"\"Test the model using the best checkpoint or the current model 
weights.\"\"\"\n", " logger.info(\"Testing the model\")\n", " datamodule.setup(stage=\"test\")\n", "\n", " ckpt_path = load_checkpoint_if_available(cfg.ckpt_path)\n", "\n", " # If no checkpoint is available, Lightning will use current model weights\n", " test_metrics = trainer.test(model, datamodule, ckpt_path=ckpt_path)\n", " logger.info(f\"Test metrics:\\n{test_metrics}\")\n", "\n", " return test_metrics[0] if test_metrics else {}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_487789/541470590.py:8: UserWarning: \n", "The version_base parameter is not specified.\n", "Please specify a compatability version level, or None.\n", "Will assume defaults for version 1.1\n", " with hydra.initialize(config_path=\"../configs\"):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Full Configuration:\n", "task_name: train\n", "tags:\n", "- dev\n", "train: true\n", "test: false\n", "ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt\n", "seed: 42\n", "name: catdog_experiment\n", "data:\n", " _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule\n", " data_dir: ${paths.data_dir}\n", " url: ${paths.data_url}\n", " num_workers: 8\n", " batch_size: 64\n", " train_val_split:\n", " - 0.8\n", " - 0.2\n", " pin_memory: true\n", " image_size: 160\n", "model:\n", " _target_: src.models.catdog_model.ViTTinyClassifier\n", " img_size: 160\n", " patch_size: 16\n", " num_classes: 2\n", " embed_dim: 64\n", " depth: 6\n", " num_heads: 2\n", " mlp_ratio: 3\n", " pre_norm: false\n", " lr: 0.001\n", " weight_decay: 1.0e-05\n", " factor: 0.1\n", " patience: 10\n", " min_lr: 1.0e-06\n", "callbacks:\n", " model_checkpoint:\n", " dirpath: ${paths.ckpt_dir}\n", " filename: best-checkpoint\n", " monitor: val_acc\n", " verbose: true\n", " save_last: true\n", " save_top_k: 1\n", " mode: max\n", " auto_insert_metric_name: false\n", " save_weights_only: false\n", " 
every_n_train_steps: null\n", " train_time_interval: null\n", " every_n_epochs: null\n", " save_on_train_epoch_end: null\n", " early_stopping:\n", " monitor: val_acc\n", " min_delta: 0.0\n", " patience: 10\n", " verbose: true\n", " mode: max\n", " strict: true\n", " check_finite: true\n", " stopping_threshold: null\n", " divergence_threshold: null\n", " check_on_train_epoch_end: null\n", " rich_model_summary:\n", " max_depth: 1\n", " rich_progress_bar:\n", " refresh_rate: 1\n", "logger:\n", " csv:\n", " save_dir: ${paths.output_dir}\n", " name: csv/\n", " prefix: ''\n", " tensorboard:\n", " save_dir: ${paths.output_dir}/tensorboard/\n", " name: null\n", " log_graph: false\n", " default_hp_metric: true\n", " prefix: ''\n", "trainer:\n", " _target_: lightning.Trainer\n", " default_root_dir: ${paths.output_dir}\n", " min_epochs: 1\n", " max_epochs: 6\n", " accelerator: auto\n", " devices: auto\n", " deterministic: true\n", " log_every_n_steps: 10\n", " fast_dev_run: false\n", "paths:\n", " root_dir: ${oc.env:PROJECT_ROOT}\n", " data_dir: ${paths.root_dir}/data/\n", " log_dir: ${paths.root_dir}/logs/\n", " ckpt_dir: ${paths.root_dir}/checkpoints\n", " artifact_dir: ${paths.root_dir}/artifacts/\n", " data_url: https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\n", " output_dir: ${hydra:runtime.output_dir}\n", " work_dir: ${hydra:runtime.cwd}\n", "\n" ] } ], "source": [ "import hydra\n", "from omegaconf import DictConfig, OmegaConf\n", "\n", "\n", "# Function to load the configuration as an object without using the @hydra.main decorator\n", "def load_config() -> DictConfig:\n", " # Initialize the configuration context (e.g., \"../configs\" directory)\n", " with hydra.initialize(config_path=\"../configs\"):\n", " # Compose the configuration object with a specific config name (e.g., \"train\")\n", " cfg = hydra.compose(config_name=\"train\")\n", " return cfg\n", "\n", "\n", "# Load the configuration\n", "cfg = load_config()\n", "\n", "# Print the entire 
configuration for reference\n", "print(\"Full Configuration:\")\n", "print(OmegaConf.to_yaml(cfg))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:23\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m8\u001b[0m - \u001b[1mWhole Config:\n", "task_name: train\n", "tags:\n", "- dev\n", "train: true\n", "test: false\n", "ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt\n", "seed: 42\n", "name: catdog_experiment\n", "data:\n", " _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule\n", " data_dir: ${paths.data_dir}\n", " url: ${paths.data_url}\n", " num_workers: 8\n", " batch_size: 64\n", " train_val_split:\n", " - 0.8\n", " - 0.2\n", " pin_memory: true\n", " image_size: 160\n", "model:\n", " _target_: src.models.catdog_model.ViTTinyClassifier\n", " img_size: 160\n", " patch_size: 16\n", " num_classes: 2\n", " embed_dim: 64\n", " depth: 6\n", " num_heads: 2\n", " mlp_ratio: 3\n", " pre_norm: false\n", " lr: 0.001\n", " weight_decay: 1.0e-05\n", " factor: 0.1\n", " patience: 10\n", " min_lr: 1.0e-06\n", "callbacks:\n", " model_checkpoint:\n", " dirpath: ${paths.ckpt_dir}\n", " filename: best-checkpoint\n", " monitor: val_acc\n", " verbose: true\n", " save_last: true\n", " save_top_k: 1\n", " mode: max\n", " auto_insert_metric_name: false\n", " save_weights_only: false\n", " every_n_train_steps: null\n", " train_time_interval: null\n", " every_n_epochs: null\n", " save_on_train_epoch_end: null\n", " early_stopping:\n", " monitor: val_acc\n", " min_delta: 0.0\n", " patience: 10\n", " verbose: true\n", " mode: max\n", " strict: true\n", " check_finite: true\n", " stopping_threshold: null\n", " divergence_threshold: null\n", " check_on_train_epoch_end: null\n", " rich_model_summary:\n", " max_depth: 1\n", " rich_progress_bar:\n", " refresh_rate: 1\n", "logger:\n", " csv:\n", " save_dir: 
${paths.output_dir}\n", " name: csv/\n", " prefix: ''\n", " tensorboard:\n", " save_dir: ${paths.output_dir}/tensorboard/\n", " name: null\n", " log_graph: false\n", " default_hp_metric: true\n", " prefix: ''\n", "trainer:\n", " _target_: lightning.Trainer\n", " default_root_dir: ${paths.output_dir}\n", " min_epochs: 1\n", " max_epochs: 6\n", " accelerator: auto\n", " devices: auto\n", " deterministic: true\n", " log_every_n_steps: 10\n", " fast_dev_run: false\n", "paths:\n", " root_dir: ${oc.env:PROJECT_ROOT}\n", " data_dir: ${paths.root_dir}/data/\n", " log_dir: ${paths.root_dir}/logs/\n", " ckpt_dir: ${paths.root_dir}/checkpoints\n", " artifact_dir: ${paths.root_dir}/artifacts/\n", " data_url: https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\n", " output_dir: ${hydra:runtime.output_dir}\n", " work_dir: ${hydra:runtime.cwd}\n", "\u001b[0m\n" ] } ], "source": [ "# Initialize logger\n", "if cfg.task_name == \"train\":\n", " log_path = Path(cfg.paths.log_dir) / \"train.log\"\n", "else:\n", " log_path = Path(cfg.paths.log_dir) / \"eval.log\"\n", "setup_logger(log_path)\n", "\n", "logger.info(f\"Whole Config:\\n{OmegaConf.to_yaml(cfg)}\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m3\u001b[0m - \u001b[1mRoot directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m5\u001b[0m - \u001b[1mCurrent working directory: ['.dvc', '.dvcignore', '.env', '.git', '.github', '.gitignore', '.project-root', 'aws', 'basic_setup.md', 'configs', 'data', 'data.dvc', 'docker-compose.yaml', 'Dockerfile', 'ec2_runner_setup.md', 'logs', 'main.py', 'notebooks', 
'poetry.lock', 'pyproject.toml', 'README.md', 'setup_aws_ci.md', 'src', 'tests', 'todo.md']\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m8\u001b[0m - \u001b[1mCheckpoint directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/checkpoints\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m12\u001b[0m - \u001b[1mData directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/data/\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m16\u001b[0m - \u001b[1mLog directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/logs/\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m20\u001b[0m - \u001b[1mArtifact directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/artifacts/\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m28\u001b[0m - \u001b[1mExperiment name: catdog_experiment\u001b[0m\n" ] } ], "source": [ "# the path to the checkpoint directory\n", "root_dir = cfg.paths.root_dir\n", "logger.info(f\"Root directory: {root_dir}\")\n", "\n", "logger.info(f\"Current working directory: {os.listdir(root_dir)}\")\n", "\n", "ckpt_dir = cfg.paths.ckpt_dir\n", "logger.info(f\"Checkpoint directory: {ckpt_dir}\")\n", "\n", "# the path to the data directory\n", "data_dir = cfg.paths.data_dir\n", "logger.info(f\"Data directory: {data_dir}\")\n", "\n", "# the path to the log 
directory\n", "log_dir = cfg.paths.log_dir\n", "logger.info(f\"Log directory: {log_dir}\")\n", "\n", "# the path to the artifact directory\n", "artifact_dir = cfg.paths.artifact_dir\n", "logger.info(f\"Artifact directory: {artifact_dir}\")\n", "\n", "# output directory\n", "# output_dir = cfg.paths.output_dir\n", "# logger.info(f\"Output directory: {output_dir}\")\n", "\n", "# name of the experiment\n", "experiment_name = cfg.name\n", "logger.info(f\"Experiment name: {experiment_name}\")\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mInstantiating datamodule \u001b[0m\n" ] } ], "source": [ "# Initialize DataModule\n", "logger.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n", "datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mNo GPU available\u001b[0m\n", "Seed set to 42\n" ] }, { "data": { "text/plain": [ "42" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check for GPU availability\n", "logger.info(\"GPU available\" if torch.cuda.is_available() else \"No GPU available\")\n", "\n", "# Set seed for reproducibility\n", "L.seed_everything(cfg.seed, workers=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:29\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mInstantiating model \u001b[0m\n" ] } 
], "source": [ "# Initialize model\n", "logger.info(f\"Instantiating model <{cfg.model._target_}>\")\n", "model: L.LightningModule = hydra.utils.instantiate(cfg.model)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:30\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m1\u001b[0m - \u001b[1mModel summary:\n", "ViTTinyClassifier(\n", " (model): VisionTransformer(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 64, kernel_size=(16, 16), stride=(16, 16))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (patch_drop): Identity()\n", " (norm_pre): Identity()\n", " (blocks): Sequential(\n", " (0): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): 
Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, 
elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): 
Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " )\n", " (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (fc_norm): Identity()\n", " (head_drop): Dropout(p=0.0, inplace=False)\n", " (head): Linear(in_features=64, out_features=2, bias=True)\n", " )\n", " (train_metrics): ModuleDict(\n", " (accuracy): MulticlassAccuracy()\n", " (precision): MulticlassPrecision()\n", " (recall): MulticlassRecall()\n", " (f1): MulticlassF1Score()\n", " )\n", " (val_metrics): ModuleDict(\n", " (accuracy): MulticlassAccuracy()\n", " (precision): MulticlassPrecision()\n", " (recall): MulticlassRecall()\n", " (f1): MulticlassF1Score()\n", " )\n", " (test_metrics): ModuleDict(\n", " (accuracy): MulticlassAccuracy()\n", " (precision): MulticlassPrecision()\n", " (recall): MulticlassRecall()\n", " (f1): MulticlassF1Score()\n", " )\n", " (criterion): CrossEntropyLoss()\n", ")\u001b[0m\n" ] } ], "source": [ "logger.info(f\"Model summary:\\n{model}\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def initialize_callbacks(cfg: DictConfig) -> List[L.Callback]:\n", " \"\"\"Initialize the callbacks based on the configuration.\"\"\"\n", " if not cfg:\n", " logger.warning(\"No callback configs found! 
Skipping..\")\n", " return callbacks\n", "\n", " if not isinstance(cfg, DictConfig):\n", " raise TypeError(\"Callbacks config must be a DictConfig!\")\n", " callbacks = []\n", "\n", " # Initialize the model checkpoint callback\n", " model_checkpoint = ModelCheckpoint(**cfg.callbacks.model_checkpoint)\n", " callbacks.append(model_checkpoint)\n", "\n", " # Initialize the early stopping callback\n", " early_stopping = EarlyStopping(**cfg.callbacks.early_stopping)\n", " callbacks.append(early_stopping)\n", "\n", " # Initialize the rich model summary callback\n", " model_summary = RichModelSummary(**cfg.callbacks.rich_model_summary)\n", " callbacks.append(model_summary)\n", "\n", " # Initialize the rich progress bar callback\n", " progress_bar = RichProgressBar(**cfg.callbacks.rich_progress_bar)\n", " callbacks.append(progress_bar)\n", "\n", " return callbacks\n", "\n", "\n", "def initialize_logger(cfg: DictConfig) -> Logger:\n", " \"\"\"Initialize the logger based on the configuration.\"\"\"\n", " if not cfg:\n", " logger.warning(\"No logger configs found! Skipping..\")\n", " return None\n", "\n", " if not isinstance(cfg, DictConfig):\n", " raise TypeError(\"Logger config must be a DictConfig!\")\n", "\n", " loggers = []\n", "\n", " # Initialize the TensorBoard logger\n", " tensorboard_logger = TensorBoardLogger(**cfg.loggers.tensorboard)\n", " loggers.append(tensorboard_logger)\n", "\n", " # Initialize the CSV logger\n", " csv_logger = CSVLogger(**cfg.loggers.csv)\n", " loggers.append(csv_logger)\n", "\n", " return loggers" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224', 'coatnext_nano_rw_224', 'convformer_b36', 'convformer_m36', 'convformer_s18', 'convformer_s36', 'convit_base', 'convit_small', 'convit_tiny', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_femto', 'convnext_femto_ols', 'convnext_large', 'convnext_large_mlp', 'convnext_nano', 'convnext_nano_ols', 'convnext_pico', 'convnext_pico_ols', 'convnext_small', 'convnext_tiny', 'convnext_tiny_hnf', 'convnext_xlarge', 'convnext_xxlarge', 'convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny', 'crossvit_9_240', 'crossvit_9_dagger_240', 
'crossvit_15_240', 'crossvit_15_dagger_240', 'crossvit_15_dagger_408', 'crossvit_18_240', 'crossvit_18_dagger_240', 'crossvit_18_dagger_408', 'crossvit_base_240', 'crossvit_small_240', 'crossvit_tiny_240', 'cs3darknet_focus_l', 'cs3darknet_focus_m', 'cs3darknet_focus_s', 'cs3darknet_focus_x', 'cs3darknet_l', 'cs3darknet_m', 'cs3darknet_s', 'cs3darknet_x', 'cs3edgenet_x', 'cs3se_edgenet_x', 'cs3sedarknet_l', 'cs3sedarknet_x', 'cs3sedarknet_xdw', 'cspdarknet53', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'darknet17', 'darknet21', 'darknet53', 'darknetaa53', 'davit_base', 'davit_base_fl', 'davit_giant', 'davit_huge', 'davit_huge_fl', 'davit_large', 'davit_small', 'davit_tiny', 'deit3_base_patch16_224', 'deit3_base_patch16_384', 'deit3_huge_patch14_224', 'deit3_large_patch16_224', 'deit3_large_patch16_384', 'deit3_medium_patch16_224', 'deit3_small_patch16_224', 'deit3_small_patch16_384', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264d', 'densenetblur121d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60_res2net', 'dla60_res2next', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn48b', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_botnext26ts_256', 'eca_halonext26ts', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3', 'eca_resnet33ts', 'eca_resnext26ts', 'eca_vovnet39b', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'edgenext_base', 'edgenext_small', 'edgenext_small_rw', 
'edgenext_x_small', 'edgenext_xx_small', 'efficientformer_l1', 'efficientformer_l3', 'efficientformer_l7', 'efficientformerv2_l', 'efficientformerv2_s0', 'efficientformerv2_s1', 'efficientformerv2_s2', 'efficientnet_b0', 'efficientnet_b0_g8_gn', 'efficientnet_b0_g16_evos', 'efficientnet_b0_gn', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b3', 'efficientnet_b3_g8_gn', 'efficientnet_b3_gn', 'efficientnet_b3_pruned', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_blur_b0', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_h_b5', 'efficientnet_l2', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'efficientnet_x_b3', 'efficientnet_x_b5', 'efficientnetv2_l', 'efficientnetv2_m', 'efficientnetv2_rw_m', 'efficientnetv2_rw_s', 'efficientnetv2_rw_t', 'efficientnetv2_s', 'efficientnetv2_xl', 'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3', 'efficientvit_l1', 'efficientvit_l2', 'efficientvit_l3', 'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2', 'efficientvit_m3', 'efficientvit_m4', 'efficientvit_m5', 'ese_vovnet19b_dw', 'ese_vovnet19b_slim', 'ese_vovnet19b_slim_dw', 'ese_vovnet39b', 'ese_vovnet39b_evos', 'ese_vovnet57b', 'ese_vovnet99b', 'eva02_base_patch14_224', 'eva02_base_patch14_448', 'eva02_base_patch16_clip_224', 'eva02_enormous_patch14_clip_224', 'eva02_large_patch14_224', 'eva02_large_patch14_448', 'eva02_large_patch14_clip_224', 'eva02_large_patch14_clip_336', 'eva02_small_patch14_224', 'eva02_small_patch14_336', 'eva02_tiny_patch14_224', 'eva02_tiny_patch14_336', 'eva_giant_patch14_224', 'eva_giant_patch14_336', 'eva_giant_patch14_560', 'eva_giant_patch14_clip_224', 'eva_large_patch14_196', 
'eva_large_patch14_336', 'fastvit_ma36', 'fastvit_mci0', 'fastvit_mci1', 'fastvit_mci2', 'fastvit_s12', 'fastvit_sa12', 'fastvit_sa24', 'fastvit_sa36', 'fastvit_t8', 'fastvit_t12', 'fbnetc_100', 'fbnetv3_b', 'fbnetv3_d', 'fbnetv3_g', 'flexivit_base', 'flexivit_large', 'flexivit_small', 'focalnet_base_lrf', 'focalnet_base_srf', 'focalnet_huge_fl3', 'focalnet_huge_fl4', 'focalnet_large_fl3', 'focalnet_large_fl4', 'focalnet_small_lrf', 'focalnet_small_srf', 'focalnet_tiny_lrf', 'focalnet_tiny_srf', 'focalnet_xlarge_fl3', 'focalnet_xlarge_fl4', 'gc_efficientnetv2_rw_t', 'gcresnet33ts', 'gcresnet50t', 'gcresnext26ts', 'gcresnext50ts', 'gcvit_base', 'gcvit_small', 'gcvit_tiny', 'gcvit_xtiny', 'gcvit_xxtiny', 'gernet_l', 'gernet_m', 'gernet_s', 'ghostnet_050', 'ghostnet_100', 'ghostnet_130', 'ghostnetv2_100', 'ghostnetv2_130', 'ghostnetv2_160', 'gmixer_12_224', 'gmixer_24_224', 'gmlp_b16_224', 'gmlp_s16_224', 'gmlp_ti16_224', 'halo2botnet50ts_256', 'halonet26t', 'halonet50ts', 'halonet_h1', 'haloregnetz_b', 'hardcorenas_a', 'hardcorenas_b', 'hardcorenas_c', 'hardcorenas_d', 'hardcorenas_e', 'hardcorenas_f', 'hgnet_base', 'hgnet_small', 'hgnet_tiny', 'hgnetv2_b0', 'hgnetv2_b1', 'hgnetv2_b2', 'hgnetv2_b3', 'hgnetv2_b4', 'hgnetv2_b5', 'hgnetv2_b6', 'hiera_base_224', 'hiera_base_abswin_256', 'hiera_base_plus_224', 'hiera_huge_224', 'hiera_large_224', 'hiera_small_224', 'hiera_small_abswin_256', 'hiera_tiny_224', 'hieradet_small', 'hrnet_w18', 'hrnet_w18_small', 'hrnet_w18_small_v2', 'hrnet_w18_ssld', 'hrnet_w30', 'hrnet_w32', 'hrnet_w40', 'hrnet_w44', 'hrnet_w48', 'hrnet_w48_ssld', 'hrnet_w64', 'inception_next_base', 'inception_next_small', 'inception_next_tiny', 'inception_resnet_v2', 'inception_v3', 'inception_v4', 'lambda_resnet26rpt_256', 'lambda_resnet26t', 'lambda_resnet50ts', 'lamhalobotnet50ts_256', 'lcnet_035', 'lcnet_050', 'lcnet_075', 'lcnet_100', 'lcnet_150', 'legacy_senet154', 'legacy_seresnet18', 'legacy_seresnet34', 'legacy_seresnet50', 'legacy_seresnet101', 
'legacy_seresnet152', 'legacy_seresnext26_32x4d', 'legacy_seresnext50_32x4d', 'legacy_seresnext101_32x4d', 'legacy_xception', 'levit_128', 'levit_128s', 'levit_192', 'levit_256', 'levit_256d', 'levit_384', 'levit_384_s8', 'levit_512', 'levit_512_s8', 'levit_512d', 'levit_conv_128', 'levit_conv_128s', 'levit_conv_192', 'levit_conv_256', 'levit_conv_256d', 'levit_conv_384', 'levit_conv_384_s8', 'levit_conv_512', 'levit_conv_512_s8', 'levit_conv_512d', 'maxvit_base_tf_224', 'maxvit_base_tf_384', 'maxvit_base_tf_512', 'maxvit_large_tf_224', 'maxvit_large_tf_384', 'maxvit_large_tf_512', 'maxvit_nano_rw_256', 'maxvit_pico_rw_256', 'maxvit_rmlp_base_rw_224', 'maxvit_rmlp_base_rw_384', 'maxvit_rmlp_nano_rw_256', 'maxvit_rmlp_pico_rw_256', 'maxvit_rmlp_small_rw_224', 'maxvit_rmlp_small_rw_256', 'maxvit_rmlp_tiny_rw_256', 'maxvit_small_tf_224', 'maxvit_small_tf_384', 'maxvit_small_tf_512', 'maxvit_tiny_pm_256', 'maxvit_tiny_rw_224', 'maxvit_tiny_rw_256', 'maxvit_tiny_tf_224', 'maxvit_tiny_tf_384', 'maxvit_tiny_tf_512', 'maxvit_xlarge_tf_224', 'maxvit_xlarge_tf_384', 'maxvit_xlarge_tf_512', 'maxxvit_rmlp_nano_rw_256', 'maxxvit_rmlp_small_rw_256', 'maxxvit_rmlp_tiny_rw_256', 'maxxvitv2_nano_rw_256', 'maxxvitv2_rmlp_base_rw_224', 'maxxvitv2_rmlp_base_rw_384', 'maxxvitv2_rmlp_large_rw_224', 'mixer_b16_224', 'mixer_b32_224', 'mixer_l16_224', 'mixer_l32_224', 'mixer_s16_224', 'mixer_s32_224', 'mixnet_l', 'mixnet_m', 'mixnet_s', 'mixnet_xl', 'mixnet_xxl', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'mnasnet_small', 'mobilenet_edgetpu_100', 'mobilenet_edgetpu_v2_l', 'mobilenet_edgetpu_v2_m', 'mobilenet_edgetpu_v2_s', 'mobilenet_edgetpu_v2_xs', 'mobilenetv1_100', 'mobilenetv1_100h', 'mobilenetv1_125', 'mobilenetv2_035', 'mobilenetv2_050', 'mobilenetv2_075', 'mobilenetv2_100', 'mobilenetv2_110d', 'mobilenetv2_120d', 'mobilenetv2_140', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_150d', 'mobilenetv3_rw', 'mobilenetv3_small_050', 
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_blur_medium', 'mobilenetv4_conv_large', 'mobilenetv4_conv_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_large_075', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_medium_075', 'mobileone_s0', 'mobileone_s1', 'mobileone_s2', 'mobileone_s3', 'mobileone_s4', 'mobilevit_s', 'mobilevit_xs', 'mobilevit_xxs', 'mobilevitv2_050', 'mobilevitv2_075', 'mobilevitv2_100', 'mobilevitv2_125', 'mobilevitv2_150', 'mobilevitv2_175', 'mobilevitv2_200', 'mvitv2_base', 'mvitv2_base_cls', 'mvitv2_huge_cls', 'mvitv2_large', 'mvitv2_large_cls', 'mvitv2_small', 'mvitv2_small_cls', 'mvitv2_tiny', 'nasnetalarge', 'nest_base', 'nest_base_jx', 'nest_small', 'nest_small_jx', 'nest_tiny', 'nest_tiny_jx', 'nextvit_base', 'nextvit_large', 'nextvit_small', 'nf_ecaresnet26', 'nf_ecaresnet50', 'nf_ecaresnet101', 'nf_regnet_b0', 'nf_regnet_b1', 'nf_regnet_b2', 'nf_regnet_b3', 'nf_regnet_b4', 'nf_regnet_b5', 'nf_resnet26', 'nf_resnet50', 'nf_resnet101', 'nf_seresnet26', 'nf_seresnet50', 'nf_seresnet101', 'nfnet_f0', 'nfnet_f1', 'nfnet_f2', 'nfnet_f3', 'nfnet_f4', 'nfnet_f5', 'nfnet_f6', 'nfnet_f7', 'nfnet_l0', 'pit_b_224', 'pit_b_distilled_224', 'pit_s_224', 'pit_s_distilled_224', 'pit_ti_224', 'pit_ti_distilled_224', 'pit_xs_224', 'pit_xs_distilled_224', 'pnasnet5large', 'poolformer_m36', 'poolformer_m48', 'poolformer_s12', 'poolformer_s24', 'poolformer_s36', 'poolformerv2_m36', 'poolformerv2_m48', 'poolformerv2_s12', 'poolformerv2_s24', 'poolformerv2_s36', 'pvt_v2_b0', 'pvt_v2_b1', 'pvt_v2_b2', 'pvt_v2_b2_li', 'pvt_v2_b3', 'pvt_v2_b4', 'pvt_v2_b5', 'rdnet_base', 'rdnet_large', 'rdnet_small', 'rdnet_tiny', 'regnetv_040', 'regnetv_064', 'regnetx_002', 'regnetx_004', 'regnetx_004_tv', 'regnetx_006', 'regnetx_008', 'regnetx_016', 'regnetx_032', 'regnetx_040', 'regnetx_064', 'regnetx_080', 'regnetx_120', 'regnetx_160', 'regnetx_320', 
'regnety_002', 'regnety_004', 'regnety_006', 'regnety_008', 'regnety_008_tv', 'regnety_016', 'regnety_032', 'regnety_040', 'regnety_040_sgn', 'regnety_064', 'regnety_080', 'regnety_080_tv', 'regnety_120', 'regnety_160', 'regnety_320', 'regnety_640', 'regnety_1280', 'regnety_2560', 'regnetz_005', 'regnetz_040', 'regnetz_040_h', 'regnetz_b16', 'regnetz_b16_evos', 'regnetz_c16', 'regnetz_c16_evos', 'regnetz_d8', 'regnetz_d8_evos', 'regnetz_d32', 'regnetz_e8', 'repghostnet_050', 'repghostnet_058', 'repghostnet_080', 'repghostnet_100', 'repghostnet_111', 'repghostnet_130', 'repghostnet_150', 'repghostnet_200', 'repvgg_a0', 'repvgg_a1', 'repvgg_a2', 'repvgg_b0', 'repvgg_b1', 'repvgg_b1g4', 'repvgg_b2', 'repvgg_b2g4', 'repvgg_b3', 'repvgg_b3g4', 'repvgg_d2se', 'repvit_m0_9', 'repvit_m1', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2', 'repvit_m2_3', 'repvit_m3', 'res2net50_14w_8s', 'res2net50_26w_4s', 'res2net50_26w_6s', 'res2net50_26w_8s', 'res2net50_48w_2s', 'res2net50d', 'res2net101_26w_4s', 'res2net101d', 'res2next50', 'resmlp_12_224', 'resmlp_24_224', 'resmlp_36_224', 'resmlp_big_24_224', 'resnest14d', 'resnest26d', 'resnest50d', 'resnest50d_1s4x24d', 'resnest50d_4s2x40d', 'resnest101e', 'resnest200e', 'resnest269e', 'resnet10t', 'resnet14t', 'resnet18', 'resnet18d', 'resnet26', 'resnet26d', 'resnet26t', 'resnet32ts', 'resnet33ts', 'resnet34', 'resnet34d', 'resnet50', 'resnet50_clip', 'resnet50_clip_gap', 'resnet50_gn', 'resnet50_mlp', 'resnet50c', 'resnet50d', 'resnet50s', 'resnet50t', 'resnet50x4_clip', 'resnet50x4_clip_gap', 'resnet50x16_clip', 'resnet50x16_clip_gap', 'resnet50x64_clip', 'resnet50x64_clip_gap', 'resnet51q', 'resnet61q', 'resnet101', 'resnet101_clip', 'resnet101_clip_gap', 'resnet101c', 'resnet101d', 'resnet101s', 'resnet152', 'resnet152c', 'resnet152d', 'resnet152s', 'resnet200', 'resnet200d', 'resnetaa34d', 'resnetaa50', 'resnetaa50d', 'resnetaa101d', 'resnetblur18', 'resnetblur50', 'resnetblur50d', 'resnetblur101d', 'resnetrs50', 
'resnetrs101', 'resnetrs152', 'resnetrs200', 'resnetrs270', 'resnetrs350', 'resnetrs420', 'resnetv2_50', 'resnetv2_50d', 'resnetv2_50d_evos', 'resnetv2_50d_frn', 'resnetv2_50d_gn', 'resnetv2_50t', 'resnetv2_50x1_bit', 'resnetv2_50x3_bit', 'resnetv2_101', 'resnetv2_101d', 'resnetv2_101x1_bit', 'resnetv2_101x3_bit', 'resnetv2_152', 'resnetv2_152d', 'resnetv2_152x2_bit', 'resnetv2_152x4_bit', 'resnext26ts', 'resnext50_32x4d', 'resnext50d_32x4d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext101_32x16d', 'resnext101_32x32d', 'resnext101_64x4d', 'rexnet_100', 'rexnet_130', 'rexnet_150', 'rexnet_200', 'rexnet_300', 'rexnetr_100', 'rexnetr_130', 'rexnetr_150', 'rexnetr_200', 'rexnetr_300', 'sam2_hiera_base_plus', 'sam2_hiera_large', 'sam2_hiera_small', 'sam2_hiera_tiny', 'samvit_base_patch16', 'samvit_base_patch16_224', 'samvit_huge_patch16', 'samvit_large_patch16', 'sebotnet33ts_256', 'sedarknet21', 'sehalonet33ts', 'selecsls42', 'selecsls42b', 'selecsls60', 'selecsls60b', 'selecsls84', 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'senet154', 'sequencer2d_l', 'sequencer2d_m', 'sequencer2d_s', 'seresnet18', 'seresnet33ts', 'seresnet34', 'seresnet50', 'seresnet50t', 'seresnet101', 'seresnet152', 'seresnet152d', 'seresnet200d', 'seresnet269d', 'seresnetaa50d', 'seresnext26d_32x4d', 'seresnext26t_32x4d', 'seresnext26ts', 'seresnext50_32x4d', 'seresnext101_32x4d', 'seresnext101_32x8d', 'seresnext101_64x4d', 'seresnext101d_32x8d', 'seresnextaa101d_32x8d', 'seresnextaa201d_32x8d', 'skresnet18', 'skresnet34', 'skresnet50', 'skresnet50d', 'skresnext50_32x4d', 'spnasnet_100', 'swin_base_patch4_window7_224', 'swin_base_patch4_window12_384', 'swin_large_patch4_window7_224', 'swin_large_patch4_window12_384', 'swin_s3_base_224', 'swin_s3_small_224', 'swin_s3_tiny_224', 'swin_small_patch4_window7_224', 'swin_tiny_patch4_window7_224', 'swinv2_base_window8_256', 'swinv2_base_window12_192', 'swinv2_base_window12to16_192to256', 'swinv2_base_window12to24_192to384', 
'swinv2_base_window16_256', 'swinv2_cr_base_224', 'swinv2_cr_base_384', 'swinv2_cr_base_ns_224', 'swinv2_cr_giant_224', 'swinv2_cr_giant_384', 'swinv2_cr_huge_224', 'swinv2_cr_huge_384', 'swinv2_cr_large_224', 'swinv2_cr_large_384', 'swinv2_cr_small_224', 'swinv2_cr_small_384', 'swinv2_cr_small_ns_224', 'swinv2_cr_small_ns_256', 'swinv2_cr_tiny_224', 'swinv2_cr_tiny_384', 'swinv2_cr_tiny_ns_224', 'swinv2_large_window12_192', 'swinv2_large_window12to16_192to256', 'swinv2_large_window12to24_192to384', 'swinv2_small_window8_256', 'swinv2_small_window16_256', 'swinv2_tiny_window8_256', 'swinv2_tiny_window16_256', 'test_byobnet', 'test_efficientnet', 'test_vit', 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', 'tf_efficientnet_el', 'tf_efficientnet_em', 'tf_efficientnet_es', 'tf_efficientnet_l2', 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', 'tf_efficientnet_lite4', 'tf_efficientnetv2_b0', 'tf_efficientnetv2_b1', 'tf_efficientnetv2_b2', 'tf_efficientnetv2_b3', 'tf_efficientnetv2_l', 'tf_efficientnetv2_m', 'tf_efficientnetv2_s', 'tf_efficientnetv2_xl', 'tf_mixnet_l', 'tf_mixnet_m', 'tf_mixnet_s', 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100', 'tiny_vit_5m_224', 'tiny_vit_11m_224', 'tiny_vit_21m_224', 'tiny_vit_21m_384', 'tiny_vit_21m_512', 'tinynet_a', 'tinynet_b', 'tinynet_c', 'tinynet_d', 'tinynet_e', 'tnt_b_patch16_224', 'tnt_s_patch16_224', 'tresnet_l', 'tresnet_m', 'tresnet_v2_l', 'tresnet_xl', 'twins_pcpvt_base', 'twins_pcpvt_large', 'twins_pcpvt_small', 'twins_svt_base', 'twins_svt_large', 'twins_svt_small', 'vgg11', 'vgg11_bn', 'vgg13', 
'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'visformer_small', 'visformer_tiny', 'vit_base_mci_224', 'vit_base_patch8_224', 'vit_base_patch14_dinov2', 'vit_base_patch14_reg4_dinov2', 'vit_base_patch16_18x2_224', 'vit_base_patch16_224', 'vit_base_patch16_224_miil', 'vit_base_patch16_384', 'vit_base_patch16_clip_224', 'vit_base_patch16_clip_384', 'vit_base_patch16_clip_quickgelu_224', 'vit_base_patch16_gap_224', 'vit_base_patch16_plus_240', 'vit_base_patch16_reg4_gap_256', 'vit_base_patch16_rope_reg1_gap_256', 'vit_base_patch16_rpn_224', 'vit_base_patch16_siglip_224', 'vit_base_patch16_siglip_256', 'vit_base_patch16_siglip_384', 'vit_base_patch16_siglip_512', 'vit_base_patch16_siglip_gap_224', 'vit_base_patch16_siglip_gap_256', 'vit_base_patch16_siglip_gap_384', 'vit_base_patch16_siglip_gap_512', 'vit_base_patch16_xp_224', 'vit_base_patch32_224', 'vit_base_patch32_384', 'vit_base_patch32_clip_224', 'vit_base_patch32_clip_256', 'vit_base_patch32_clip_384', 'vit_base_patch32_clip_448', 'vit_base_patch32_clip_quickgelu_224', 'vit_base_patch32_plus_256', 'vit_base_r26_s32_224', 'vit_base_r50_s16_224', 'vit_base_r50_s16_384', 'vit_base_resnet26d_224', 'vit_base_resnet50d_224', 'vit_betwixt_patch16_gap_256', 'vit_betwixt_patch16_reg1_gap_256', 'vit_betwixt_patch16_reg4_gap_256', 'vit_betwixt_patch16_reg4_gap_384', 'vit_betwixt_patch16_rope_reg4_gap_256', 'vit_betwixt_patch32_clip_224', 'vit_giant_patch14_224', 'vit_giant_patch14_clip_224', 'vit_giant_patch14_dinov2', 'vit_giant_patch14_reg4_dinov2', 'vit_giant_patch16_gap_224', 'vit_gigantic_patch14_224', 'vit_gigantic_patch14_clip_224', 'vit_huge_patch14_224', 'vit_huge_patch14_clip_224', 'vit_huge_patch14_clip_336', 'vit_huge_patch14_clip_378', 'vit_huge_patch14_clip_quickgelu_224', 'vit_huge_patch14_clip_quickgelu_378', 'vit_huge_patch14_gap_224', 'vit_huge_patch14_xp_224', 'vit_huge_patch16_gap_448', 'vit_large_patch14_224', 'vit_large_patch14_clip_224', 'vit_large_patch14_clip_336', 
'vit_large_patch14_clip_quickgelu_224', 'vit_large_patch14_clip_quickgelu_336', 'vit_large_patch14_dinov2', 'vit_large_patch14_reg4_dinov2', 'vit_large_patch14_xp_224', 'vit_large_patch16_224', 'vit_large_patch16_384', 'vit_large_patch16_siglip_256', 'vit_large_patch16_siglip_384', 'vit_large_patch16_siglip_gap_256', 'vit_large_patch16_siglip_gap_384', 'vit_large_patch32_224', 'vit_large_patch32_384', 'vit_large_r50_s32_224', 'vit_large_r50_s32_384', 'vit_little_patch16_reg1_gap_256', 'vit_little_patch16_reg4_gap_256', 'vit_medium_patch16_clip_224', 'vit_medium_patch16_gap_240', 'vit_medium_patch16_gap_256', 'vit_medium_patch16_gap_384', 'vit_medium_patch16_reg1_gap_256', 'vit_medium_patch16_reg4_gap_256', 'vit_medium_patch16_rope_reg1_gap_256', 'vit_medium_patch32_clip_224', 'vit_mediumd_patch16_reg4_gap_256', 'vit_mediumd_patch16_reg4_gap_384', 'vit_mediumd_patch16_rope_reg1_gap_256', 'vit_pwee_patch16_reg1_gap_256', 'vit_relpos_base_patch16_224', 'vit_relpos_base_patch16_cls_224', 'vit_relpos_base_patch16_clsgap_224', 'vit_relpos_base_patch16_plus_240', 'vit_relpos_base_patch16_rpn_224', 'vit_relpos_base_patch32_plus_rpn_256', 'vit_relpos_medium_patch16_224', 'vit_relpos_medium_patch16_cls_224', 'vit_relpos_medium_patch16_rpn_224', 'vit_relpos_small_patch16_224', 'vit_relpos_small_patch16_rpn_224', 'vit_small_patch8_224', 'vit_small_patch14_dinov2', 'vit_small_patch14_reg4_dinov2', 'vit_small_patch16_18x2_224', 'vit_small_patch16_36x1_224', 'vit_small_patch16_224', 'vit_small_patch16_384', 'vit_small_patch32_224', 'vit_small_patch32_384', 'vit_small_r26_s32_224', 'vit_small_r26_s32_384', 'vit_small_resnet26d_224', 'vit_small_resnet50d_s16_224', 'vit_so150m_patch16_reg4_gap_256', 'vit_so150m_patch16_reg4_map_256', 'vit_so400m_patch14_siglip_224', 'vit_so400m_patch14_siglip_384', 'vit_so400m_patch14_siglip_gap_224', 'vit_so400m_patch14_siglip_gap_384', 'vit_so400m_patch14_siglip_gap_448', 'vit_so400m_patch14_siglip_gap_896', 'vit_srelpos_medium_patch16_224', 
'vit_srelpos_small_patch16_224', 'vit_tiny_patch16_224', 'vit_tiny_patch16_384', 'vit_tiny_r_s16_p8_224', 'vit_tiny_r_s16_p8_384', 'vit_wee_patch16_reg1_gap_256', 'vit_xsmall_patch16_clip_224', 'vitamin_base_224', 'vitamin_large2_224', 'vitamin_large2_256', 'vitamin_large2_336', 'vitamin_large2_384', 'vitamin_large_224', 'vitamin_large_256', 'vitamin_large_336', 'vitamin_large_384', 'vitamin_small_224', 'vitamin_xlarge_256', 'vitamin_xlarge_336', 'vitamin_xlarge_384', 'volo_d1_224', 'volo_d1_384', 'volo_d2_224', 'volo_d2_384', 'volo_d3_224', 'volo_d3_448', 'volo_d4_224', 'volo_d4_448', 'volo_d5_224', 'volo_d5_448', 'volo_d5_512', 'vovnet39a', 'vovnet57a', 'wide_resnet50_2', 'wide_resnet101_2', 'xception41', 'xception41p', 'xception65', 'xception65p', 'xception71', 'xcit_large_24_p8_224', 'xcit_large_24_p8_384', 'xcit_large_24_p16_224', 'xcit_large_24_p16_384', 'xcit_medium_24_p8_224', 'xcit_medium_24_p8_384', 'xcit_medium_24_p16_224', 'xcit_medium_24_p16_384', 'xcit_nano_12_p8_224', 'xcit_nano_12_p8_384', 'xcit_nano_12_p16_224', 'xcit_nano_12_p16_384', 'xcit_small_12_p8_224', 'xcit_small_12_p8_384', 'xcit_small_12_p16_224', 'xcit_small_12_p16_384', 'xcit_small_24_p8_224', 'xcit_small_24_p8_384', 'xcit_small_24_p16_224', 'xcit_small_24_p16_384', 'xcit_tiny_12_p8_224', 'xcit_tiny_12_p8_384', 'xcit_tiny_12_p16_224', 'xcit_tiny_12_p16_384', 'xcit_tiny_24_p8_224', 'xcit_tiny_24_p8_384', 'xcit_tiny_24_p16_224', 'xcit_tiny_24_p16_384']\n" ] } ], "source": [ "import timm\n", "print(timm.list_models())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### testing the litserve model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import requests\n", "from urllib.request import urlopen\n", "import base64" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "url = 
\"https://media.istockphoto.com/id/541844008/photo/portland-grand-floral-parade-2016.jpg?s=2048x2048&w=is&k=20&c=ZuvR6oDv5WxwL5dhXKAbevysEXhXV47shJdpzkqen5Y=\"\n", "# Close the HTTP response deterministically and bound the wait so a stalled\n", "# download cannot hang the kernel.\n", "with urlopen(url, timeout=30) as resp:\n", "    img_data = resp.read()\n", "print(type(img_data))" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Convert to base64 string\n", "img_bytes = base64.b64encode(img_data).decode('utf-8')\n", "print(type(img_bytes))" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Without a timeout, requests waits indefinitely if the server never answers.\n", "response = requests.post(\n", "    \"http://localhost:8080/predict\",\n", "    json={\"image\": img_bytes},  # image is the key\n", "    timeout=60,\n", ")" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Top 5 Predictions:\n", "mountain_bike, all-terrain_bike, off-roader: 82.13%\n", "maillot: 5.09%\n", "crash_helmet: 1.84%\n", "bicycle-built-for-two, tandem_bicycle, tandem: 1.83%\n", "alp: 0.69%\n" ] } ], "source": [ "if response.status_code == 200:\n", "    predictions = response.json()[\"predictions\"]\n", "    # Header count follows the payload instead of assuming the server returns 5.\n", "    print(f\"\\nTop {len(predictions)} Predictions:\")\n", "    for pred in predictions:\n", "        print(f\"{pred['label']}: {pred['probability']:.2%}\")\n", "else:\n", "    print(f\"Error: {response.status_code}\")\n", "    print(response.text)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "########################################## End of the script ##########################################" ] } ], "metadata": { "kernelspec": { "display_name": "emlo_env",
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }