title: Erav2s13
emoji: 🔥
colorFrom: yellow
colorTo: red
sdk: gradio
sdk_version: 4.27.0
pinned: false
license: mit
Erav2s13 - SOUTRIK 🔥
This repository is hosted on Hugging Face and uses Gradio to build the user interface (UI). The model was trained in Google Colab, and the resulting model files are used for inference in the Gradio app.
- Model Training: Colab notebook used to build and train the model.
- Inference: The same model structure and files are used in the Gradio app.
Custom ResNet Model
The model file defines a custom ResNet (Residual Network) model using PyTorch Lightning. The model is designed for image classification on the CIFAR-10 dataset.
Model Architecture
The custom ResNet model comprises the following components:
- Preparation Layer: Convolutional layer with 64 filters, followed by batch normalization, ReLU activation, and dropout.
- Layer 1: Convolutional layer with 128 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (128 filters each), batch normalization, ReLU activation, and dropout.
- Layer 2: Convolutional layer with 256 filters, max pooling, batch normalization, ReLU activation, and dropout.
- Layer 3: Convolutional layer with 512 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (512 filters each), batch normalization, ReLU activation, and dropout.
- Max Pooling: Max pooling layer with a kernel size of 4.
- Fully Connected Layer: Flattened output passed through a fully connected layer with 10 output units (for CIFAR-10 classes).
- Softmax: Log softmax activation function to obtain the log-probabilities of the predicted classes.
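The layer list above can be sketched in plain PyTorch as follows. This is a minimal illustration, not the repository's actual code: the 3x3 kernels with padding 1, the 2x2 max pooling inside each layer, the dropout rate of 0.05, the additive skip connections, and the names `conv_block` and `CustomResNetSketch` are all assumptions that the description does not state.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_block(in_ch, out_ch, pool=False, drop=0.05):
    """Conv -> (optional max pool) -> batch norm -> ReLU -> dropout.

    Kernel size 3, padding 1, pool size 2 and the dropout rate are assumptions.
    """
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    layers += [nn.BatchNorm2d(out_ch), nn.ReLU(), nn.Dropout(drop)]
    return nn.Sequential(*layers)


class CustomResNetSketch(nn.Module):
    def __init__(self, num_classes=10, drop=0.05):
        super().__init__()
        self.prep = conv_block(3, 64, drop=drop)
        self.layer1 = conv_block(64, 128, pool=True, drop=drop)
        self.res1 = nn.Sequential(
            conv_block(128, 128, drop=drop), conv_block(128, 128, drop=drop)
        )
        self.layer2 = conv_block(128, 256, pool=True, drop=drop)
        self.layer3 = conv_block(256, 512, pool=True, drop=drop)
        self.res3 = nn.Sequential(
            conv_block(512, 512, drop=drop), conv_block(512, 512, drop=drop)
        )
        self.pool = nn.MaxPool2d(4)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # Shape comments assume a 3 x 32 x 32 CIFAR-10 input.
        x = self.prep(x)          # 64 x 32 x 32
        x = self.layer1(x)        # 128 x 16 x 16
        x = x + self.res1(x)      # residual block, added to its input
        x = self.layer2(x)        # 256 x 8 x 8
        x = self.layer3(x)        # 512 x 4 x 4
        x = x + self.res3(x)      # residual block, added to its input
        x = self.pool(x)          # 512 x 1 x 1
        x = torch.flatten(x, 1)   # 512 features per image
        return F.log_softmax(self.fc(x), dim=1)
```

Calling `CustomResNetSketch().eval()` on a batch of shape `(N, 3, 32, 32)` returns a tensor of shape `(N, 10)` holding log-probabilities; applying `.exp()` recovers probabilities that sum to 1 per image.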
Training and Evaluation
The model is trained using PyTorch Lightning, which provides a high-level interface for training, validation, and testing. Key components include:
- Optimizer: Adam, with the learning rate set by a hyperparameter.
- Scheduler: OneCycleLR for learning rate adjustment.
- Loss and Accuracy: Cross-entropy loss and accuracy are computed and logged during training, validation, and testing.
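The optimizer, scheduler, and loss listed above could be wired together roughly as shown here. This is a sketch under stated assumptions: the tiny linear model, the random batches, and every numeric value (including `MAX_LR`, `EPOCHS`, `STEPS_PER_EPOCH`, and the value given to `PREFERRED_WEIGHT_DECAY`) are placeholders, and in the Lightning module the optimizer and scheduler would typically be returned from `configure_optimizers` instead of being stepped by hand.

```python
import torch
from torch import nn

# Placeholder values; the project's real settings are not given in this README.
START_LR = 1e-3
MAX_LR = 1e-2
PREFERRED_WEIGHT_DECAY = 1e-4
EPOCHS = 2
STEPS_PER_EPOCH = 5

model = nn.Linear(8, 10)  # stand-in for the custom ResNet
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    model.parameters(), lr=START_LR, weight_decay=PREFERRED_WEIGHT_DECAY
)
# OneCycleLR takes over the learning rate: it starts at max_lr / div_factor
# (div_factor defaults to 25), rises to max_lr, then anneals down.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=MAX_LR, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH
)

lrs = []
for _ in range(EPOCHS * STEPS_PER_EPOCH):
    inputs = torch.randn(4, 8)               # random stand-in batch
    targets = torch.randint(0, 10, (4,))
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lrs.append(scheduler.get_last_lr()[0])   # learning rate used for this step
    scheduler.step()                         # OneCycleLR is stepped once per batch
```

Plotting `lrs` shows the single rise-and-fall cycle that gives the scheduler its name.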
Misclassified Images
During testing, misclassified images are tracked and stored in a dictionary along with their ground truth and predicted labels, facilitating error analysis and model improvement.
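One way to keep such a dictionary is sketched below. The function name `collect_misclassified` and the dictionary keys are hypothetical, and the sketch is framework-agnostic: it accepts any iterables of images, ground-truth labels, and predicted labels, so the per-batch tensors from a test step would work as well as the plain lists used here.

```python
def collect_misclassified(images, targets, predictions, store=None):
    """Append every (image, target, prediction) triple where the labels differ.

    `store` is reused across batches; the key names are illustrative only.
    """
    if store is None:
        store = {"images": [], "ground_truths": [], "predicted_vals": []}
    for image, target, pred in zip(images, targets, predictions):
        if target != pred:
            store["images"].append(image)
            store["ground_truths"].append(target)
            store["predicted_vals"].append(pred)
    return store


# Two batches with placeholder image names and integer class labels.
store = collect_misclassified(["img0", "img1", "img2"], [3, 5, 1], [3, 2, 1])
store = collect_misclassified(["img3", "img4"], [7, 0], [9, 0], store)
```

After both calls, `store` holds only the two mismatched samples, which can then be plotted side by side with their true and predicted labels.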
Key hyperparameters include:
- The initial learning rate.
- PREFERRED_WEIGHT_DECAY: Weight decay for regularization.
Model Summary
The detailed_model_summary function prints a comprehensive summary of the model architecture, detailing the input size, kernel size, output size, number of parameters, and trainable status of each layer.
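The implementation of detailed_model_summary is not shown in this README. As a rough illustration of how the same fields can be gathered, the sketch below registers forward hooks on every leaf module and records the listed details during one dummy forward pass. The name `layer_summary`, the row format, and the small demo model are assumptions.

```python
import torch
from torch import nn


def layer_summary(model, input_size):
    """Return one dict per leaf layer, in forward order."""
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            own_params = list(module.parameters(recurse=False))
            rows.append({
                "name": name,
                "type": type(module).__name__,
                "input_size": tuple(inputs[0].shape),
                "kernel_size": getattr(module, "kernel_size", None),
                "output_size": tuple(output.shape),
                "params": sum(p.numel() for p in own_params),
                # False for layers that have no parameters at all
                "trainable": any(p.requires_grad for p in own_params),
            })
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            hooks.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()
    with torch.no_grad():
        model(torch.zeros(input_size))  # dummy pass to trigger the hooks
    model.train(was_training)
    for handle in hooks:
        handle.remove()
    return rows


# Small demo model, unrelated to the custom ResNet.
demo = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),
)
rows = layer_summary(demo, (1, 3, 4, 4))
for row in rows:
    print(row)
```

Each row can be formatted into a table line; summing the `params` column gives the total parameter count of the model.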
Lightning Dataset Module
The data module file contains the CIFARDataModule class, a PyTorch Lightning LightningDataModule for the CIFAR-10 dataset. This class handles data preparation, splitting, and loading.
CIFARDataModule Class
The class is configured by:
- The directory path for the CIFAR-10 dataset.
- batch_size: Batch size for data loaders.
- seed: Random seed for reproducibility.
- val_split: Fraction of training data used for validation (default: 0).
- num_workers: Number of worker processes for data loading (default: 0).
The class provides the following methods:
- A data-preparation method that downloads the CIFAR-10 dataset if it is not present.
- setup: Defines data transformations and creates training, validation, and testing datasets.
- train_dataloader: Returns the training data loader.
- val_dataloader: Returns the validation data loader.
- test_dataloader: Returns the testing data loader.
Utility Methods
- A helper that splits the training dataset into training and validation subsets.
- _init_fn: Initializes the random seed for each worker process to ensure reproducibility.
This project is licensed under the MIT License.