---
title: Erav2s13
emoji: 🔥
colorFrom: yellow
colorTo: red
sdk: gradio
sdk_version: 4.27.0
app_file: app.py
pinned: false
license: mit
---
# Erav2s13 - SOUTRIK 🔥

## Overview
This project is hosted as a Hugging Face Space and uses Gradio for the user interface (UI). The model was trained in Google Colab, and the resulting model files are used for inference in the Gradio app.
- Model Training: `Main.ipynb` is the Colab notebook used to build and train the model.
- Inference: The same model structure and files are used in the Gradio app.
## Custom ResNet Model
The `custom_resnet.py` file defines a custom ResNet (Residual Network) model using PyTorch Lightning, designed for image classification on the CIFAR-10 dataset.
### Model Architecture
The custom ResNet model comprises the following components:
- Preparation Layer: Convolutional layer with 64 filters, followed by batch normalization, ReLU activation, and dropout.
- Layer 1: Convolutional layer with 128 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (128 filters each), batch normalization, ReLU activation, and dropout.
- Layer 2: Convolutional layer with 256 filters, max pooling, batch normalization, ReLU activation, and dropout.
- Layer 3: Convolutional layer with 512 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (512 filters each), batch normalization, ReLU activation, and dropout.
- Max Pooling: Max pooling layer with a kernel size of 4.
- Fully Connected Layer: Flattened output passed through a fully connected layer with 10 output units (for CIFAR-10 classes).
- Softmax: Log softmax activation to produce log-probabilities for the 10 classes (the full model is sketched below).
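A minimal PyTorch sketch of this architecture follows; the dropout rate and the exact block grouping are assumptions, and the authoritative definition lives in `custom_resnet.py`:

```python
import torch.nn as nn
import torch.nn.functional as F


def conv_block(in_ch, out_ch, pool=False, dropout=0.05):
    """Conv -> (optional MaxPool) -> BatchNorm -> ReLU -> Dropout."""
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    layers += [nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Dropout(dropout)]
    return nn.Sequential(*layers)


class CustomResNet(nn.Module):
    def __init__(self, num_classes=10, dropout=0.05):
        super().__init__()
        self.prep = conv_block(3, 64, dropout=dropout)                  # 64 x 32 x 32
        self.layer1 = conv_block(64, 128, pool=True, dropout=dropout)   # 128 x 16 x 16
        self.res1 = nn.Sequential(conv_block(128, 128, dropout=dropout),
                                  conv_block(128, 128, dropout=dropout))
        self.layer2 = conv_block(128, 256, pool=True, dropout=dropout)  # 256 x 8 x 8
        self.layer3 = conv_block(256, 512, pool=True, dropout=dropout)  # 512 x 4 x 4
        self.res3 = nn.Sequential(conv_block(512, 512, dropout=dropout),
                                  conv_block(512, 512, dropout=dropout))
        self.pool = nn.MaxPool2d(4)                                     # 512 x 1 x 1
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = x + self.res1(x)  # residual (identity) connection
        x = self.layer2(x)
        x = self.layer3(x)
        x = x + self.res3(x)  # residual (identity) connection
        x = self.pool(x).flatten(1)
        return F.log_softmax(self.fc(x), dim=1)
```

The identity shortcuts (`x + res(x)`) around Layer 1 and Layer 3 are what make those two blocks residual.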
### Training and Evaluation
The model is trained using PyTorch Lightning, which provides a high-level interface for training, validation, and testing. Key components include:
- Optimizer: Adam, with the initial learning rate specified by `PREFERRED_START_LR`.
- Scheduler: OneCycleLR for learning rate adjustment (see the sketch after this list).
- Loss and Accuracy: Cross-entropy loss and accuracy are computed and logged during training, validation, and testing.
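A hedged sketch of how these pieces typically fit together in a `LightningModule`; the constant values, the peak learning rate, and the class name are assumptions, not taken from the repo:

```python
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import OneCycleLR

PREFERRED_START_LR = 1e-3      # illustrative value, not taken from the repo
PREFERRED_WEIGHT_DECAY = 1e-4  # illustrative value, not taken from the repo


class LitCustomResNet(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=PREFERRED_START_LR,
            weight_decay=PREFERRED_WEIGHT_DECAY,
        )
        scheduler = OneCycleLR(
            optimizer,
            max_lr=PREFERRED_START_LR * 10,  # assumed peak LR for the one-cycle policy
            total_steps=self.trainer.estimated_stepping_batches,
        )
        # Step the scheduler every batch, as OneCycleLR expects.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
```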
### Misclassified Images
During testing, misclassified images are tracked and stored in a dictionary along with their ground truth and predicted labels, facilitating error analysis and model improvement.
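A sketch of how such tracking might look inside the module's `test_step`; the attribute name and dictionary keys are assumptions:

```python
def test_step(self, batch, batch_idx):
    # Assumes __init__ created:
    # self.misclassified = {"images": [], "ground_truths": [], "predictions": []}
    images, targets = batch
    preds = self(images).argmax(dim=1)
    wrong = preds != targets  # boolean mask of misclassified samples
    self.misclassified["images"].append(images[wrong].cpu())
    self.misclassified["ground_truths"].append(targets[wrong].cpu())
    self.misclassified["predictions"].append(preds[wrong].cpu())
```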
### Hyperparameters
Key hyperparameters include:
- `PREFERRED_START_LR`: Initial learning rate.
- `PREFERRED_WEIGHT_DECAY`: Weight decay for regularization.
### Model Summary
The `detailed_model_summary` function prints a comprehensive summary of the model architecture: the input size, kernel size, output size, number of parameters, and trainable status of each layer.
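Those columns map directly onto `torchinfo`'s summary options, so the function is plausibly a thin wrapper like the following; the use of `torchinfo` and the input size are assumptions:

```python
from torchinfo import summary


def detailed_model_summary(model, input_size=(1, 3, 32, 32)):
    # Print per-layer input size, kernel size, output size,
    # parameter count, and trainable flag.
    summary(
        model,
        input_size=input_size,
        col_names=["input_size", "kernel_size", "output_size",
                   "num_params", "trainable"],
    )
```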
## Lightning Dataset Module
The `lightning_dataset.py` file contains the `CIFARDataModule` class, a PyTorch Lightning `LightningDataModule` for the CIFAR-10 dataset. It handles data preparation, splitting, and loading.
### CIFARDataModule Class

#### Parameters
- `data_path`: Directory path for the CIFAR-10 dataset.
- `batch_size`: Batch size for the data loaders.
- `seed`: Random seed for reproducibility.
- `val_split`: Fraction of training data used for validation (default: 0).
- `num_workers`: Number of worker processes for data loading (default: 0); see the example instantiation below.
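For example, the module might be instantiated like this (all argument values are illustrative):

```python
from lightning_dataset import CIFARDataModule

datamodule = CIFARDataModule(
    data_path="./data",  # where CIFAR-10 is (or will be) stored
    batch_size=512,      # illustrative value
    seed=42,
    val_split=0.1,       # hold out 10% of the training set for validation
    num_workers=2,
)
```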
#### Methods
- `prepare_data`: Downloads the CIFAR-10 dataset if not present.
- `setup`: Defines data transformations and creates the training, validation, and testing datasets.
- `train_dataloader`: Returns the training data loader.
- `val_dataloader`: Returns the validation data loader.
- `test_dataloader`: Returns the testing data loader (all sketched below).
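A condensed sketch of the class, assuming the standard `torchvision` CIFAR-10 dataset and omitting the augmentations the real `setup` likely applies; the two private helpers it calls are sketched after the next list:

```python
import pytorch_lightning as pl
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader


class CIFARDataModule(pl.LightningDataModule):
    def __init__(self, data_path, batch_size, seed, val_split=0.0, num_workers=0):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.seed = seed
        self.val_split = val_split
        self.num_workers = num_workers

    def prepare_data(self):
        # Download once, on a single process; setup() then builds the datasets.
        torchvision.datasets.CIFAR10(self.data_path, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_path, train=False, download=True)

    def setup(self, stage=None):
        tfm = T.ToTensor()  # the real setup likely adds augmentation and normalization
        full = torchvision.datasets.CIFAR10(self.data_path, train=True, transform=tfm)
        self.train_ds, self.val_ds = self._split_train_val(full)
        self.test_ds = torchvision.datasets.CIFAR10(self.data_path, train=False, transform=tfm)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, worker_init_fn=self._init_fn)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=self._init_fn)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=self._init_fn)
```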
#### Utility Methods
- `_split_train_val`: Splits the training dataset into training and validation subsets.
- `_init_fn`: Initializes the random seed for each worker process to ensure reproducibility (both sketched below).
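These helpers might look roughly like this; the use of `random_split` and the exact per-worker seeding scheme are assumptions:

```python
import random

import numpy as np
import torch
from torch.utils.data import random_split


def _split_train_val(self, dataset):
    # Deterministic split driven by the module's seed.
    val_len = int(len(dataset) * self.val_split)
    generator = torch.Generator().manual_seed(self.seed)
    return random_split(dataset, [len(dataset) - val_len, val_len], generator=generator)


def _init_fn(self, worker_id):
    # Give Python, NumPy, and PyTorch RNGs a deterministic, distinct seed per
    # DataLoader worker so shuffling and augmentation are reproducible.
    worker_seed = self.seed + worker_id
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
```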
## License
This project is licensed under the MIT License.