---
title: Erav2s13
emoji: 🔥
colorFrom: yellow
colorTo: red
sdk: gradio
sdk_version: 4.27.0
app_file: app.py
pinned: false
license: mit
---

Erav2s13 - SOUTRIK 🔥

Overview

This repository is hosted as a Hugging Face Space and uses Gradio to build the user interface (UI). The model was trained in Google Colab, and the resulting model files are used for inference in the Gradio app.

  • Model Training: Main.ipynb - Colab notebook used to build and train the model.
  • Inference: The same model structure and files are used in the Gradio app.
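The inference path can be sketched as follows. This is illustrative, not the exact code in app.py: the function name and wiring are assumptions, but the returned {class name: probability} dict is the format a Gradio Label output accepts.

```python
import torch
import torch.nn.functional as F

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

def predict(model, image_tensor):
    """Run one CIFAR-10 image (3x32x32 tensor) through the model and
    return a {class_name: probability} dict, the format gr.Label accepts."""
    model.eval()
    with torch.no_grad():
        log_probs = model(image_tensor.unsqueeze(0))  # add a batch dimension
        probs = log_probs.exp().squeeze(0)            # model outputs log-probabilities
    return {cls: float(p) for cls, p in zip(CIFAR10_CLASSES, probs)}

# In app.py this is typically wrapped as a Gradio interface, e.g.:
# gr.Interface(fn=..., inputs=gr.Image(), outputs=gr.Label(num_top_classes=3)).launch()
```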

Custom ResNet Model

The custom_resnet.py file defines a custom ResNet (Residual Network) model using PyTorch Lightning. This model is specifically designed for image classification tasks, particularly for the CIFAR-10 dataset.

Model Architecture

The custom ResNet model comprises the following components:

  1. Preparation Layer: Convolutional layer with 64 filters, followed by batch normalization, ReLU activation, and dropout.
  2. Layer 1: Convolutional layer with 128 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (128 filters each), batch normalization, ReLU activation, and dropout.
  3. Layer 2: Convolutional layer with 256 filters, max pooling, batch normalization, ReLU activation, and dropout.
  4. Layer 3: Convolutional layer with 512 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (512 filters each), batch normalization, ReLU activation, and dropout.
  5. Max Pooling: Max pooling layer with a kernel size of 4.
  6. Fully Connected Layer: Flattened output passed through a fully connected layer with 10 output units (for CIFAR-10 classes).
  7. Log Softmax: Log softmax applied to the final layer output to produce per-class log-probabilities.
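The layer stack above can be sketched in plain PyTorch as follows. This is a minimal sketch, not the repo's custom_resnet.py: the actual code wraps the model in a PyTorch Lightning LightningModule, and the helper name, dropout value, and exact layer ordering here are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_block(in_ch, out_ch, pool=False, dropout=0.05):
    # conv -> (optional max pool) -> batch norm -> ReLU -> dropout
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    layers += [nn.BatchNorm2d(out_ch), nn.ReLU(), nn.Dropout(dropout)]
    return nn.Sequential(*layers)

class CustomResNet(nn.Module):
    def __init__(self, num_classes=10, dropout=0.05):
        super().__init__()
        self.prep = conv_block(3, 64, dropout=dropout)                  # 32x32, 64 ch
        self.layer1 = conv_block(64, 128, pool=True, dropout=dropout)   # 16x16, 128 ch
        self.res1 = nn.Sequential(conv_block(128, 128, dropout=dropout),
                                  conv_block(128, 128, dropout=dropout))
        self.layer2 = conv_block(128, 256, pool=True, dropout=dropout)  # 8x8, 256 ch
        self.layer3 = conv_block(256, 512, pool=True, dropout=dropout)  # 4x4, 512 ch
        self.res2 = nn.Sequential(conv_block(512, 512, dropout=dropout),
                                  conv_block(512, 512, dropout=dropout))
        self.pool = nn.MaxPool2d(4)                                     # 4x4 -> 1x1
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = x + self.res1(x)   # residual connection
        x = self.layer2(x)
        x = self.layer3(x)
        x = x + self.res2(x)   # residual connection
        x = self.pool(x)
        x = self.fc(x.flatten(1))
        return F.log_softmax(x, dim=1)
```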

Training and Evaluation

The model is trained using PyTorch Lightning, which provides a high-level interface for training, validation, and testing. Key components include:

  • Optimizer: Adam with a learning rate specified by PREFERRED_START_LR.
  • Scheduler: OneCycleLR for learning rate adjustment.
  • Loss and Accuracy: Cross-entropy loss and accuracy are computed and logged during training, validation, and testing.
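The optimizer/scheduler pair can be sketched in plain PyTorch. The hyperparameter values and the epoch/step counts below are assumptions for illustration; the real constants live in the repository.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR

# Assumed values; the actual constants are defined in the repo.
PREFERRED_START_LR = 1e-3
PREFERRED_WEIGHT_DECAY = 1e-4

model = torch.nn.Linear(512, 10)  # stand-in for the custom ResNet
optimizer = Adam(model.parameters(),
                 lr=PREFERRED_START_LR,
                 weight_decay=PREFERRED_WEIGHT_DECAY)

# OneCycleLR ramps the learning rate up to max_lr, then anneals it back down.
epochs, steps_per_epoch = 24, 98  # assumed training schedule
scheduler = OneCycleLR(optimizer,
                       max_lr=10 * PREFERRED_START_LR,
                       epochs=epochs,
                       steps_per_epoch=steps_per_epoch)
```

In a LightningModule this pair would be returned from `configure_optimizers()`, with the scheduler stepped once per batch (`"interval": "step"`).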

Misclassified Images

During testing, misclassified images are tracked and stored in a dictionary along with their ground truth and predicted labels, facilitating error analysis and model improvement.
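The tracking logic amounts to comparing predictions against targets and storing the mismatches. A minimal sketch, assuming the dictionary key names (the repo's keys may differ):

```python
import torch

def track_misclassified(model, batch, store):
    """Append misclassified images plus their ground-truth and predicted
    labels to `store` (key names here are assumed, not the repo's)."""
    images, targets = batch
    with torch.no_grad():
        preds = model(images).argmax(dim=1)
    wrong = preds != targets
    store["images"].append(images[wrong])
    store["ground_truths"].append(targets[wrong])
    store["predictions"].append(preds[wrong])
```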

Hyperparameters

Key hyperparameters include:

  • PREFERRED_START_LR: Initial learning rate.
  • PREFERRED_WEIGHT_DECAY: Weight decay for regularization.

Model Summary

The detailed_model_summary function prints a comprehensive summary of the model architecture, detailing input size, kernel size, output size, number of parameters, and trainable status of each layer.
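A plain-torch approximation of such a summary is shown below; the repo's detailed_model_summary likely builds on a dedicated tool (such as torchinfo) to also report per-layer input, kernel, and output sizes, which this sketch omits.

```python
import torch.nn as nn

def detailed_model_summary(model):
    """Print name, shape, parameter count, and trainable status for each
    parameter tensor; return the total parameter count."""
    total = 0
    for name, param in model.named_parameters():
        total += param.numel()
        print(f"{name:<40} {str(tuple(param.shape)):<20} "
              f"{param.numel():>10,} trainable={param.requires_grad}")
    print(f"Total parameters: {total:,}")
    return total
```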

Lightning Dataset Module

The lightning_dataset.py file contains the CIFARDataModule class, which is a PyTorch Lightning LightningDataModule for the CIFAR-10 dataset. This class handles data preparation, splitting, and loading.

CIFARDataModule Class

Parameters

  • data_path: Directory path for CIFAR-10 dataset.
  • batch_size: Batch size for data loaders.
  • seed: Random seed for reproducibility.
  • val_split: Fraction of training data used for validation (default: 0).
  • num_workers: Number of worker processes for data loading (default: 0).

Methods

  • prepare_data: Downloads CIFAR-10 dataset if not present.
  • setup: Defines data transformations and creates training, validation, and testing datasets.
  • train_dataloader: Returns training data loader.
  • val_dataloader: Returns validation data loader.
  • test_dataloader: Returns testing data loader.

Utility Methods

  • _split_train_val: Splits training dataset into training and validation subsets.
  • _init_fn: Initializes random seed for each worker process to ensure reproducibility.
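The split and worker-seeding logic above can be sketched in plain PyTorch. This is a simplified stand-in, not the repo's class: the real CIFARDataModule subclasses pytorch_lightning.LightningDataModule and additionally handles dataset download and transforms.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split

class CIFARDataModuleSketch:
    """Plain-torch sketch of the LightningDataModule's split/loading logic."""

    def __init__(self, dataset, batch_size=64, seed=42, val_split=0.1, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed
        self.val_split = val_split
        self.num_workers = num_workers

    def setup(self):
        # _split_train_val: carve a seeded validation subset out of the training data
        n_val = int(len(self.dataset) * self.val_split)
        generator = torch.Generator().manual_seed(self.seed)
        self.train_ds, self.val_ds = random_split(
            self.dataset, [len(self.dataset) - n_val, n_val], generator=generator)

    @staticmethod
    def _init_fn(worker_id):
        # _init_fn: seed every worker process so augmentation is reproducible
        seed = torch.initial_seed() % 2**32
        np.random.seed(seed)
        random.seed(seed)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, worker_init_fn=self._init_fn)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=self._init_fn)
```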

License

This project is licensed under the MIT License.