{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/javascript": "IPython.notebook.set_autosave_interval(300000)" }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Autosaving every 300 seconds\n" ] } ], "source": [ "%autosave 300\n", "%load_ext autoreload\n", "%autoreload 2\n", "%reload_ext autoreload\n", "%config Completer.use_jedi = False" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\n" ] } ], "source": [ "\n", "import os\n", "\n", "os.chdir(\"..\")\n", "print(os.getcwd())" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import shutil\n", "from pathlib import Path\n", "import torch\n", "import lightning as L\n", "from lightning.pytorch.loggers import Logger\n", "from typing import List\n", "from src.datamodules.catdog_datamodule import CatDogImageDataModule\n", "from src.utils.logging_utils import setup_logger, task_wrapper\n", "from loguru import logger\n", "from dotenv import load_dotenv, find_dotenv\n", "import rootutils\n", "import hydra\n", "from omegaconf import DictConfig, OmegaConf\n", "from lightning.pytorch.callbacks import (\n", " ModelCheckpoint,\n", " EarlyStopping,\n", " RichModelSummary,\n", " RichProgressBar,\n", ")\n", "from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:17.572\u001b[0m | \u001b[31m\u001b[1mERROR \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m9\u001b[0m - \u001b[31m\u001b[1mname '__file__' is not defined\u001b[0m\n" ] } ], "source": [ "# Load environment variables\n", "load_dotenv(find_dotenv(\".env\"))\n", "\n", "# Setup root directory\n", "try:\n", " root = rootutils.setup_root(__file__, indicator=\".project-root\")\n", "\n", "except Exception as e:\n", " logger.error(e)\n", " root = Path(os.getcwd())\n", " os.environ[\"PROJECT_ROOT\"] = str(root)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def load_checkpoint_if_available(ckpt_path: str) -> str:\n", " \"\"\"Check if the specified checkpoint exists and return the valid checkpoint path.\"\"\"\n", " if ckpt_path and Path(ckpt_path).exists():\n", " logger.info(f\"Checkpoint found: {ckpt_path}\")\n", " return ckpt_path\n", " else:\n", " logger.warning(\n", " f\"No checkpoint found at {ckpt_path}. Using current model weights.\"\n", " )\n", " return None\n", "\n", "\n", "def clear_checkpoint_directory(ckpt_dir: str):\n", " \"\"\"Clear all contents of the checkpoint directory without deleting the directory itself.\"\"\"\n", " ckpt_dir_path = Path(ckpt_dir)\n", " if ckpt_dir_path.exists() and ckpt_dir_path.is_dir():\n", " logger.info(f\"Clearing checkpoint directory: {ckpt_dir}\")\n", " # Iterate over all files and directories in the checkpoint directory and remove them\n", " for item in ckpt_dir_path.iterdir():\n", " try:\n", " if item.is_file() or item.is_symlink():\n", " item.unlink() # Remove file or symlink\n", " elif item.is_dir():\n", " shutil.rmtree(item) # Remove directory\n", " except Exception as e:\n", " logger.error(f\"Failed to delete {item}: {e}\")\n", " logger.info(f\"Checkpoint directory cleared: {ckpt_dir}\")\n", " else:\n", " logger.info(\n", " f\"Checkpoint directory does not exist. Creating directory: {ckpt_dir}\"\n", " )\n", " os.makedirs(ckpt_dir_path, exist_ok=True)\n", "\n", "\n", "@task_wrapper\n", "def train_module(\n", " cfg: DictConfig,\n", " data_module: L.LightningDataModule,\n", " model: L.LightningModule,\n", " trainer: L.Trainer,\n", "):\n", " \"\"\"Train the model using the provided Trainer and DataModule.\"\"\"\n", " logger.info(\"Training the model\")\n", " trainer.fit(model, data_module)\n", " train_metrics = trainer.callback_metrics\n", " try:\n", " logger.info(\n", " f\"Training completed with the following metrics- train_acc: {train_metrics['train_acc'].item()} and val_acc: {train_metrics['val_acc'].item()}\"\n", " )\n", " except KeyError:\n", " logger.info(f\"Training completed with the following metrics:{train_metrics}\")\n", "\n", " return train_metrics\n", "\n", "\n", "@task_wrapper\n", "def run_test_module(\n", " cfg: DictConfig,\n", " datamodule: L.LightningDataModule,\n", " model: L.LightningModule,\n", " trainer: L.Trainer,\n", "):\n", " \"\"\"Test the model using the best checkpoint or the current model weights.\"\"\"\n", " logger.info(\"Testing the model\")\n", " datamodule.setup(stage=\"test\")\n", "\n", " ckpt_path = load_checkpoint_if_available(cfg.ckpt_path)\n", "\n", " # If no checkpoint is available, Lightning will use current model weights\n", " test_metrics = trainer.test(model, datamodule, ckpt_path=ckpt_path)\n", " logger.info(f\"Test metrics:\\n{test_metrics}\")\n", "\n", " return test_metrics[0] if test_metrics else {}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_487789/541470590.py:8: UserWarning: \n", "The version_base parameter is not specified.\n", "Please specify a compatability version level, or None.\n", "Will assume defaults for version 1.1\n", " with hydra.initialize(config_path=\"../configs\"):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Full Configuration:\n", "task_name: train\n", "tags:\n", "- dev\n", "train: true\n", "test: false\n", "ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt\n", "seed: 42\n", "name: catdog_experiment\n", "data:\n", " _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule\n", " data_dir: ${paths.data_dir}\n", " url: ${paths.data_url}\n", " num_workers: 8\n", " batch_size: 64\n", " train_val_split:\n", " - 0.8\n", " - 0.2\n", " pin_memory: true\n", " image_size: 160\n", "model:\n", " _target_: src.models.catdog_model.ViTTinyClassifier\n", " img_size: 160\n", " patch_size: 16\n", " num_classes: 2\n", " embed_dim: 64\n", " depth: 6\n", " num_heads: 2\n", " mlp_ratio: 3\n", " pre_norm: false\n", " lr: 0.001\n", " weight_decay: 1.0e-05\n", " factor: 0.1\n", " patience: 10\n", " min_lr: 1.0e-06\n", "callbacks:\n", " model_checkpoint:\n", " dirpath: ${paths.ckpt_dir}\n", " filename: best-checkpoint\n", " monitor: val_acc\n", " verbose: true\n", " save_last: true\n", " save_top_k: 1\n", " mode: max\n", " auto_insert_metric_name: false\n", " save_weights_only: false\n", " every_n_train_steps: null\n", " train_time_interval: null\n", " every_n_epochs: null\n", " save_on_train_epoch_end: null\n", " early_stopping:\n", " monitor: val_acc\n", " min_delta: 0.0\n", " patience: 10\n", " verbose: true\n", " mode: max\n", " strict: true\n", " check_finite: true\n", " stopping_threshold: null\n", " divergence_threshold: null\n", " check_on_train_epoch_end: null\n", " rich_model_summary:\n", " max_depth: 1\n", " rich_progress_bar:\n", " refresh_rate: 1\n", "logger:\n", " csv:\n", " save_dir: ${paths.output_dir}\n", " name: csv/\n", " prefix: ''\n", " tensorboard:\n", " save_dir: ${paths.output_dir}/tensorboard/\n", " name: null\n", " log_graph: false\n", " default_hp_metric: true\n", " prefix: ''\n", "trainer:\n", " _target_: lightning.Trainer\n", " default_root_dir: ${paths.output_dir}\n", " min_epochs: 1\n", " max_epochs: 6\n", " accelerator: auto\n", " devices: auto\n", " deterministic: true\n", " log_every_n_steps: 10\n", " fast_dev_run: false\n", "paths:\n", " root_dir: ${oc.env:PROJECT_ROOT}\n", " data_dir: ${paths.root_dir}/data/\n", " log_dir: ${paths.root_dir}/logs/\n", " ckpt_dir: ${paths.root_dir}/checkpoints\n", " artifact_dir: ${paths.root_dir}/artifacts/\n", " data_url: https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\n", " output_dir: ${hydra:runtime.output_dir}\n", " work_dir: ${hydra:runtime.cwd}\n", "\n" ] } ], "source": [ "import hydra\n", "from omegaconf import DictConfig, OmegaConf\n", "\n", "\n", "# Function to load the configuration as an object without using the @hydra.main decorator\n", "def load_config() -> DictConfig:\n", " # Initialize the configuration context (e.g., \"../configs\" directory)\n", " with hydra.initialize(config_path=\"../configs\"):\n", " # Compose the configuration object with a specific config name (e.g., \"train\")\n", " cfg = hydra.compose(config_name=\"train\")\n", " return cfg\n", "\n", "\n", "# Load the configuration\n", "cfg = load_config()\n", "\n", "# Print the entire configuration for reference\n", "print(\"Full Configuration:\")\n", "print(OmegaConf.to_yaml(cfg))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:23\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m8\u001b[0m - \u001b[1mWhole Config:\n", "task_name: train\n", "tags:\n", "- dev\n", "train: true\n", "test: false\n", "ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt\n", "seed: 42\n", "name: catdog_experiment\n", "data:\n", " _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule\n", " data_dir: ${paths.data_dir}\n", " url: ${paths.data_url}\n", " num_workers: 8\n", " batch_size: 64\n", " train_val_split:\n", " - 0.8\n", " - 0.2\n", " pin_memory: true\n", " image_size: 160\n", "model:\n", " _target_: src.models.catdog_model.ViTTinyClassifier\n", " img_size: 160\n", " patch_size: 16\n", " num_classes: 2\n", " embed_dim: 64\n", " depth: 6\n", " num_heads: 2\n", " mlp_ratio: 3\n", " pre_norm: false\n", " lr: 0.001\n", " weight_decay: 1.0e-05\n", " factor: 0.1\n", " patience: 10\n", " min_lr: 1.0e-06\n", "callbacks:\n", " model_checkpoint:\n", " dirpath: ${paths.ckpt_dir}\n", " filename: best-checkpoint\n", " monitor: val_acc\n", " verbose: true\n", " save_last: true\n", " save_top_k: 1\n", " mode: max\n", " auto_insert_metric_name: false\n", " save_weights_only: false\n", " every_n_train_steps: null\n", " train_time_interval: null\n", " every_n_epochs: null\n", " save_on_train_epoch_end: null\n", " early_stopping:\n", " monitor: val_acc\n", " min_delta: 0.0\n", " patience: 10\n", " verbose: true\n", " mode: max\n", " strict: true\n", " check_finite: true\n", " stopping_threshold: null\n", " divergence_threshold: null\n", " check_on_train_epoch_end: null\n", " rich_model_summary:\n", " max_depth: 1\n", " rich_progress_bar:\n", " refresh_rate: 1\n", "logger:\n", " csv:\n", " save_dir: ${paths.output_dir}\n", " name: csv/\n", " prefix: ''\n", " tensorboard:\n", " save_dir: ${paths.output_dir}/tensorboard/\n", " name: null\n", " log_graph: false\n", " default_hp_metric: true\n", " prefix: ''\n", "trainer:\n", " _target_: lightning.Trainer\n", " default_root_dir: ${paths.output_dir}\n", " min_epochs: 1\n", " max_epochs: 6\n", " accelerator: auto\n", " devices: auto\n", " deterministic: true\n", " log_every_n_steps: 10\n", " fast_dev_run: false\n", "paths:\n", " root_dir: ${oc.env:PROJECT_ROOT}\n", " data_dir: ${paths.root_dir}/data/\n", " log_dir: ${paths.root_dir}/logs/\n", " ckpt_dir: ${paths.root_dir}/checkpoints\n", " artifact_dir: ${paths.root_dir}/artifacts/\n", " data_url: https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\n", " output_dir: ${hydra:runtime.output_dir}\n", " work_dir: ${hydra:runtime.cwd}\n", "\u001b[0m\n" ] } ], "source": [ "# Initialize logger\n", "if cfg.task_name == \"train\":\n", " log_path = Path(cfg.paths.log_dir) / \"train.log\"\n", "else:\n", " log_path = Path(cfg.paths.log_dir) / \"eval.log\"\n", "setup_logger(log_path)\n", "\n", "logger.info(f\"Whole Config:\\n{OmegaConf.to_yaml(cfg)}\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m3\u001b[0m - \u001b[1mRoot directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m5\u001b[0m - \u001b[1mCurrent working directory: ['.dvc', '.dvcignore', '.env', '.git', '.github', '.gitignore', '.project-root', 'aws', 'basic_setup.md', 'configs', 'data', 'data.dvc', 'docker-compose.yaml', 'Dockerfile', 'ec2_runner_setup.md', 'logs', 'main.py', 'notebooks', 'poetry.lock', 'pyproject.toml', 'README.md', 'setup_aws_ci.md', 'src', 'tests', 'todo.md']\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m8\u001b[0m - \u001b[1mCheckpoint directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/checkpoints\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m12\u001b[0m - \u001b[1mData directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/data/\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m16\u001b[0m - \u001b[1mLog directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/logs/\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m20\u001b[0m - \u001b[1mArtifact directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/artifacts/\u001b[0m\n", "\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m28\u001b[0m - \u001b[1mExperiment name: catdog_experiment\u001b[0m\n" ] } ], "source": [ "# the path to the checkpoint directory\n", "root_dir = cfg.paths.root_dir\n", "logger.info(f\"Root directory: {root_dir}\")\n", "\n", "logger.info(f\"Current working directory: {os.listdir(root_dir)}\")\n", "\n", "ckpt_dir = cfg.paths.ckpt_dir\n", "logger.info(f\"Checkpoint directory: {ckpt_dir}\")\n", "\n", "# the path to the data directory\n", "data_dir = cfg.paths.data_dir\n", "logger.info(f\"Data directory: {data_dir}\")\n", "\n", "# the path to the log directory\n", "log_dir = cfg.paths.log_dir\n", "logger.info(f\"Log directory: {log_dir}\")\n", "\n", "# the path to the artifact directory\n", "artifact_dir = cfg.paths.artifact_dir\n", "logger.info(f\"Artifact directory: {artifact_dir}\")\n", "\n", "# output directory\n", "# output_dir = cfg.paths.output_dir\n", "# logger.info(f\"Output directory: {output_dir}\")\n", "\n", "# name of the experiment\n", "experiment_name = cfg.name\n", "logger.info(f\"Experiment name: {experiment_name}\")\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mInstantiating datamodule \u001b[0m\n" ] } ], "source": [ "# Initialize DataModule\n", "logger.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n", "datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mNo GPU available\u001b[0m\n", "Seed set to 42\n" ] }, { "data": { "text/plain": [ "42" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check for GPU availability\n", "logger.info(\"GPU available\" if torch.cuda.is_available() else \"No GPU available\")\n", "\n", "# Set seed for reproducibility\n", "L.seed_everything(cfg.seed, workers=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:29\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mInstantiating model \u001b[0m\n" ] } ], "source": [ "# Initialize model\n", "logger.info(f\"Instantiating model <{cfg.model._target_}>\")\n", "model: L.LightningModule = hydra.utils.instantiate(cfg.model)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-08 18:25:30\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m1\u001b[0m - \u001b[1mModel summary:\n", "ViTTinyClassifier(\n", " (model): VisionTransformer(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 64, kernel_size=(16, 16), stride=(16, 16))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (patch_drop): Identity()\n", " (norm_pre): Identity()\n", " (blocks): Sequential(\n", " (0): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=64, out_features=192, bias=False)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=64, out_features=64, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=64, out_features=192, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=192, out_features=64, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " )\n", " (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n", " (fc_norm): Identity()\n", " (head_drop): Dropout(p=0.0, inplace=False)\n", " (head): Linear(in_features=64, out_features=2, bias=True)\n", " )\n", " (train_metrics): ModuleDict(\n", " (accuracy): MulticlassAccuracy()\n", " (precision): MulticlassPrecision()\n", " (recall): MulticlassRecall()\n", " (f1): MulticlassF1Score()\n", " )\n", " (val_metrics): ModuleDict(\n", " (accuracy): MulticlassAccuracy()\n", " (precision): MulticlassPrecision()\n", " (recall): MulticlassRecall()\n", " (f1): MulticlassF1Score()\n", " )\n", " (test_metrics): ModuleDict(\n", " (accuracy): MulticlassAccuracy()\n", " (precision): MulticlassPrecision()\n", " (recall): MulticlassRecall()\n", " (f1): MulticlassF1Score()\n", " )\n", " (criterion): CrossEntropyLoss()\n", ")\u001b[0m\n" ] } ], "source": [ "logger.info(f\"Model summary:\\n{model}\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def initialize_callbacks(cfg: DictConfig) -> List[L.Callback]:\n", " \"\"\"Initialize the callbacks based on the configuration.\"\"\"\n", " if not cfg:\n", " logger.warning(\"No callback configs found! Skipping..\")\n", " return callbacks\n", "\n", " if not isinstance(cfg, DictConfig):\n", " raise TypeError(\"Callbacks config must be a DictConfig!\")\n", " callbacks = []\n", "\n", " # Initialize the model checkpoint callback\n", " model_checkpoint = ModelCheckpoint(**cfg.callbacks.model_checkpoint)\n", " callbacks.append(model_checkpoint)\n", "\n", " # Initialize the early stopping callback\n", " early_stopping = EarlyStopping(**cfg.callbacks.early_stopping)\n", " callbacks.append(early_stopping)\n", "\n", " # Initialize the rich model summary callback\n", " model_summary = RichModelSummary(**cfg.callbacks.rich_model_summary)\n", " callbacks.append(model_summary)\n", "\n", " # Initialize the rich progress bar callback\n", " progress_bar = RichProgressBar(**cfg.callbacks.rich_progress_bar)\n", " callbacks.append(progress_bar)\n", "\n", " return callbacks\n", "\n", "\n", "def initialize_logger(cfg: DictConfig) -> Logger:\n", " \"\"\"Initialize the logger based on the configuration.\"\"\"\n", " if not cfg:\n", " logger.warning(\"No logger configs found! Skipping..\")\n", " return None\n", "\n", " if not isinstance(cfg, DictConfig):\n", " raise TypeError(\"Logger config must be a DictConfig!\")\n", "\n", " loggers = []\n", "\n", " # Initialize the TensorBoard logger\n", " tensorboard_logger = TensorBoardLogger(**cfg.loggers.tensorboard)\n", " loggers.append(tensorboard_logger)\n", "\n", " # Initialize the CSV logger\n", " csv_logger = CSVLogger(**cfg.loggers.csv)\n", " loggers.append(csv_logger)\n", "\n", " return loggers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "########################################## End of the script ##########################################" ] } ], "metadata": { "kernelspec": { "display_name": "emlo_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }