---
license: gpl-3.0
datasets:
- p1atdev/danbooru-2024
metrics:
- f1
tags:
- art
- code
---
## Usage
After installation, run the application by executing `setup.bat`. This launches a web interface where you can:
- Upload your own images or select from example images
- Choose different threshold profiles
- Adjust category-specific thresholds
- View predictions organized by category
- Filter and sort tags based on confidence

# Anime Image Tagger
An advanced deep learning model for automatically tagging anime/manga illustrations with relevant tags across multiple categories, achieving a **61% F1 score** across 70,000+ possible tags on a test set of 20,116 samples.
## Key Highlights
- **Efficient Training**: Completed on just a single RTX 3060 GPU (12GB VRAM)
- **Fast Convergence**: Trained on 7,024,392 samples (3.52 epochs) in 1,756,098 batches
- **Comprehensive Coverage**: 70,000+ tags across 7 categories (general, character, copyright, artist, meta, rating, year)
- **Innovative Architecture**: Two-stage prediction model with cross-attention for tag context
- **User-Friendly Interface**: Easy-to-use application with customizable thresholds
*This project demonstrates that high-quality anime image tagging models can be trained on consumer hardware with the right optimization techniques.*
## Features
- **Multi-category tagging system**: Handles general tags, characters, copyright (series), artists, meta information, and content ratings
- **High performance**: 61% F1 score across 70,000+ possible tags
- **Dual-mode operation**: Full model for best quality or Initial-only mode for reduced VRAM usage
- **Windows compatibility**: Initial-only mode works on Windows without Flash Attention
- **Streamlit web interface**: User-friendly UI for uploading and analyzing images
- **Adjustable threshold profiles**: Overall, Weighted, Category-specific, High Precision, and High Recall profiles
- **Fine-grained control**: Per-category threshold adjustments for precision-recall tradeoffs
## Loss Function
The model employs a specialized `UnifiedFocalLoss` to address the extreme class imbalance inherent in multi-label tag prediction:
```python
import torch.nn as nn

class UnifiedFocalLoss(nn.Module):
    def __init__(self, device=None, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        # Implementation details...
```
### Key Components
1. **Focal Loss Mechanism**:
- Down-weights well-classified examples (γ=2.0) to focus training on difficult tags
- Addresses the extreme imbalance between positive and negative examples (often 100:1 or worse)
- Uses α=0.25 to balance positive/negative examples across 70,000+ possible tags
2. **Two-stage Weighting**:
- Combines losses from both prediction stages (`initial_predictions` and `refined_predictions`)
- Uses λ=0.4 to weight the initial prediction loss, giving more importance (0.6) to refined predictions
- This encourages the model to improve predictions in the refinement stage while still maintaining strong initial predictions
3. **Per-sample Statistics**:
- Tracks separate metrics for positive and negative samples
- Provides detailed debugging information about prediction distributions
- Enables analysis of which tag categories are performing well/poorly
This loss function was essential for achieving high F1 scores across diverse tag categories despite the extreme classification challenge of 70,000+ possible tags.
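The snippet below is a minimal sketch of how such a two-stage focal loss can be assembled, assuming sigmoid logits for both stages; the class name and implementation details are illustrative rather than the project's actual code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoStageFocalLossSketch(nn.Module):
    """Illustrative two-stage focal loss; not the project's exact UnifiedFocalLoss."""
    def __init__(self, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        super().__init__()
        self.gamma = gamma                    # focusing parameter: down-weights easy examples
        self.alpha = alpha                    # balances positive vs. negative examples
        self.lambda_initial = lambda_initial  # weight on the initial-stage loss

    def _focal(self, logits, targets):
        # Standard binary focal loss over all tags
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = targets * p + (1 - targets) * (1 - p)
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

    def forward(self, initial_logits, refined_logits, targets):
        # 0.4 * initial loss + 0.6 * refined loss, per the weighting described above
        return (self.lambda_initial * self._focal(initial_logits, targets)
                + (1 - self.lambda_initial) * self._focal(refined_logits, targets))
```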
## DeepSpeed Configuration
Microsoft DeepSpeed was crucial for training this model on consumer hardware. The project uses a carefully tuned configuration to maximize efficiency:
```python
def create_deepspeed_config(
    config_path,
    learning_rate=3e-4,
    weight_decay=0.01,
    num_train_samples=None,
    micro_batch_size=4,
    grad_accum_steps=8
):
    # Implementation details...
```
### Key Optimizations
1. **Memory Efficiency**:
- **ZeRO Stage 2**: Partitions optimizer states and gradients, dramatically reducing memory requirements
- **Activation Checkpointing**: Trades computation for memory by recomputing activations during backpropagation
- **Contiguous Memory Optimization**: Reduces memory fragmentation
2. **Mixed Precision Training**:
- **FP16 Mode**: Uses half-precision (16-bit) for most calculations, with automatic loss scaling
- **Initial Scale Power**: Set to 16 for stable convergence with large batch sizes
3. **Gradient Accumulation**:
- Micro-batch size of 4 with 8 gradient accumulation steps
- Effective batch size of 32 while only requiring memory for 4 samples at once
4. **Learning Rate Schedule**:
- WarmupLR scheduler with gradual increase from 3e-6 to 3e-4
- Warmup over 1/4 of an epoch to stabilize early training
This configuration allowed the model to train efficiently with only 12GB of VRAM while maintaining numerical stability across millions of training examples with 70,000+ output dimensions.
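For reference, a DeepSpeed configuration reflecting the settings described above might look roughly like the dictionary below. The optimizer choice (AdamW) is an assumption, and the exact file produced by `create_deepspeed_config` may differ in detail:
```python
# Hypothetical DeepSpeed config assembled from the settings described above.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,        # effective batch size of 4 * 8 = 32
    "zero_optimization": {
        "stage": 2,                          # partition optimizer states and gradients
        "contiguous_gradients": True,        # reduce memory fragmentation
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 16,           # loss scaling starts at 2**16
    },
    "activation_checkpointing": {
        "partition_activations": True,       # recompute activations during backprop
    },
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 3e-4, "weight_decay": 0.01},
    },
    "scheduler": {
        "type": "WarmupLR",
        # warmup_num_steps would be derived from num_train_samples (about 1/4 of an epoch)
        "params": {"warmup_min_lr": 3e-6, "warmup_max_lr": 3e-4},
    },
}
```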
## Dataset
The model was trained on a carefully filtered subset of the [Danbooru 2024 dataset](https://huggingface.co/datasets/p1atdev/danbooru-2024), which contains a vast collection of anime/manga illustrations with comprehensive tagging.
### Filtering Process
The dataset was filtered with the following constraints:
```python
# Minimum tags per category required for each image
min_tag_counts = {
    'general': 25,
    'character': 1,
    'copyright': 1,
    'artist': 0,
    'meta': 0
}

# Minimum samples per tag required for a tag to be included
min_tag_samples = {
    'general': 20,
    'character': 40,
    'copyright': 50,
    'artist': 200,
    'meta': 50
}
```
This filtering process:
1. First removed low-sample tags (tags with fewer occurrences than specified in `min_tag_samples`)
2. Then removed images with insufficient tags per category (as specified in `min_tag_counts`)
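A simplified sketch of this two-pass filter is shown below; the in-memory structure (a list of `{category: [tags]}` dicts) is an assumption for illustration, not the project's actual data pipeline:
```python
from collections import Counter

def filter_dataset(samples, min_tag_samples, min_tag_counts):
    """samples: list of dicts mapping category -> list of tags (illustrative structure)."""
    # Pass 1: count tag occurrences and keep only tags above the per-category minimum
    counts = Counter()
    for sample in samples:
        for category, tags in sample.items():
            counts.update((category, tag) for tag in tags)
    kept = {key for key, n in counts.items() if n >= min_tag_samples[key[0]]}

    # Pass 2: prune removed tags, then drop images with too few tags in any required category
    filtered = []
    for sample in samples:
        pruned = {cat: [t for t in tags if (cat, t) in kept] for cat, tags in sample.items()}
        if all(len(pruned.get(cat, [])) >= n for cat, n in min_tag_counts.items()):
            filtered.append(pruned)
    return filtered
```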
### Training Data
- **Starting dataset size**: ~3,000,000 filtered images
- **Training subset**: 2,000,000 images (due to storage and time constraints)
- **Training duration**: 3.5 epochs
The model could potentially achieve even higher accuracy with more training epochs and the full dataset.
### Preprocessing
Images were preprocessed with minimal transformations:
- Tensor normalization (scaled to 0-1 range)
- Resized while maintaining original aspect ratio
- No additional augmentations were applied
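In code, that amounts to something like the sketch below; the `target_size` of 512 is a placeholder, since the resolution used in training is not stated here:
```python
import numpy as np
import torch
from PIL import Image

def preprocess(path, target_size=512):
    """Minimal preprocessing sketch; target_size is a placeholder value."""
    img = Image.open(path).convert("RGB")
    # Resize so the longer side equals target_size, preserving the aspect ratio
    scale = target_size / max(img.size)
    img = img.resize((round(img.width * scale), round(img.height * scale)), Image.BILINEAR)
    # Convert to a float tensor scaled to the 0-1 range, channels first; no augmentation
    return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
```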
## Model Architecture
The model uses a novel two-stage prediction approach that achieves superior performance compared to traditional single-stage models:
### Image Feature Extraction
- **Backbone**: EfficientNet V2-L extracts high-quality visual features from input images
- **Spatial Pooling**: Adaptive averaging converts spatial features to a compact 1280-dimensional embedding
### Initial Prediction Stage
- Direct classification from image features through a multi-layer classifier
- Bottleneck architecture with LayerNorm and GELU activations between linear layers
- Outputs initial tag probabilities across all 70,000+ possible tags
### Tag Context Mechanism
- Top predicted tags are embedded using a shared embedding space
- Self-attention layer allows tags to influence each other based on co-occurrence patterns
- Normalized tag embeddings represent a coherent "tag context" for the image
### Cross-Attention Refinement
- Image features and tag embeddings interact through cross-attention
- Each dimension of the image features attends to relevant dimensions in the tag space
- This creates a bidirectional flow of information between visual features and semantic tags
### Refined Predictions
- Fused features (original + cross-attended) feed into a final classifier
- Residual connection ensures initial predictions are preserved when beneficial
- Temperature scaling provides calibrated probability outputs
This dual-stage approach allows the model to leverage tag co-occurrence patterns and semantic relationships, improving accuracy without increasing the parameter count significantly.
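A condensed sketch of the forward pass this describes is given below. Layer sizes, the number of attention heads, and the top-k tag count are illustrative assumptions, the backbone is assumed to return a `(B, 1280, H, W)` feature map, and the temperature scaling step is omitted:
```python
import torch
import torch.nn as nn

class TwoStageTaggerSketch(nn.Module):
    """Illustrative two-stage tagger; not the actual implementation."""
    def __init__(self, backbone, num_tags, embed_dim=1280, num_heads=8, top_k=64):
        super().__init__()
        self.backbone = backbone                      # e.g. an EfficientNet V2-L feature extractor
        self.pool = nn.AdaptiveAvgPool2d(1)           # spatial pooling to a compact embedding
        self.initial_classifier = nn.Sequential(      # bottleneck classifier (sizes illustrative)
            nn.Linear(embed_dim, embed_dim), nn.LayerNorm(embed_dim), nn.GELU(),
            nn.Linear(embed_dim, num_tags),
        )
        self.tag_embedding = nn.Embedding(num_tags, embed_dim)
        self.tag_self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.refined_classifier = nn.Linear(embed_dim * 2, num_tags)
        self.top_k = top_k

    def forward(self, images):
        feats = self.pool(self.backbone(images)).flatten(1)            # (B, embed_dim)
        initial_logits = self.initial_classifier(feats)                # initial prediction stage
        # Tag context: embed the top-k predicted tags and let them attend to each other
        top_tags = initial_logits.topk(self.top_k, dim=-1).indices
        tag_ctx = self.tag_embedding(top_tags)                         # (B, k, embed_dim)
        tag_ctx, _ = self.tag_self_attn(tag_ctx, tag_ctx, tag_ctx)
        # Cross-attention refinement: image features attend to the tag context
        attended, _ = self.cross_attn(feats.unsqueeze(1), tag_ctx, tag_ctx)
        fused = torch.cat([feats, attended.squeeze(1)], dim=-1)        # original + cross-attended
        refined_logits = self.refined_classifier(fused) + initial_logits  # residual connection
        return initial_logits, refined_logits
```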
## Installation
Simply run the included setup script to install all dependencies:
```
setup.bat
```
This will automatically set up all necessary packages for the application.
### Requirements
- **Python 3.11.9 specifically** (newer versions are incompatible)
- PyTorch 1.10+
- Streamlit
- PIL/Pillow
- NumPy
- Flash Attention (note: doesn't work properly on Windows)
### Running the Application
The application is located in the `app` folder and can be launched via the setup script:
1. Run `setup.bat` to install dependencies
2. The Streamlit interface will automatically open in your browser
3. If the browser doesn't open automatically, navigate to http://localhost:8501
## Model Details
### Tag Categories
The model recognizes tags across these categories:
- **General**: Visual elements, concepts, clothing, etc.
- **Character**: Individual characters appearing in the image
- **Copyright**: Source material (anime, manga, game)
- **Artist**: Creator of the artwork
- **Meta**: Meta information about the image
- **Rating**: Content rating
- **Year**: Year of upload
### Performance Notes
The full model with refined predictions outperforms the Initial-only model, though the performance gap is surprisingly small given that both share the same backbone and initial classifier. This is an interesting architectural finding: the refinement stage adds value with little additional computational overhead, while the initial classifier alone already captures most of the performance.
This efficiency makes the initial-only model particularly valuable for Windows users or systems with limited VRAM, as they can still achieve near-optimal performance without requiring Flash Attention.
In benchmarks, the model achieved a 61% F1 score across all categories, which is remarkable considering the extreme multi-label classification challenge of 70,000+ possible tags. The model performs particularly well on general tags and character recognition.
### Threshold Profiles
- **Overall**: Single threshold applied to all categories
- **Weighted**: Threshold optimized for balanced performance across categories
- **Category-specific**: Different thresholds for each category
- **High Precision**: Higher thresholds for more confident predictions
- **High Recall**: Lower thresholds to capture more potential tags
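Applying a profile comes down to comparing each tag's probability against its category's threshold, as in the sketch below; the numbers in the example profile are illustrative, not the tuned values shipped with the app:
```python
def apply_thresholds(probs_by_tag, tag_categories, thresholds):
    """probs_by_tag: {tag: probability}; tag_categories: {tag: category};
    thresholds: {category: threshold}."""
    selected = {tag: p for tag, p in probs_by_tag.items()
                if p >= thresholds.get(tag_categories[tag], 0.5)}
    # Sort surviving tags by confidence, highest first
    return dict(sorted(selected.items(), key=lambda kv: kv[1], reverse=True))

# Example "Category-specific" profile (numbers are illustrative only)
profile = {"general": 0.35, "character": 0.45, "copyright": 0.45,
           "artist": 0.60, "meta": 0.50, "rating": 0.50, "year": 0.50}
```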
## Windows Compatibility
The full model uses Flash Attention, which does not work properly on Windows. For Windows users:
- The application automatically defaults to the Initial-only model
- Performance difference is minimal (0.2% absolute F1 score reduction, from 61.6% to 61.4%)
- The Initial-only model still uses the same powerful EfficientNet backbone and initial classifier
## Web Interface Guide
The interface is divided into three main sections:
1. **Model Selection** (Sidebar)
- Choose between Full Model or Initial-only Model
- View model information and memory usage
2. **Image Upload** (Left Panel)
- Upload your own images or select from examples
- View the selected image
3. **Tagging Controls** (Right Panel)
- Select threshold profile
- Adjust thresholds for precision-recall tradeoff
- Configure display options
- View predictions organized by category
### Display Options
- **Show all tags**: Display all tags including those below threshold
- **Compact view**: Hide progress bars for cleaner display
- **Minimum confidence**: Filter out low-confidence predictions
- **Category selection**: Choose which categories to include in the summary
### Interface Screenshots
![Application Interface](images/app_screenshot.png)
![Tag Results Example](images/tag_results_example.png)
## Training Environment
The model was trained using surprisingly modest hardware:
- **GPU**: Single NVIDIA RTX 3060 (12GB VRAM)
- **RAM**: 64GB system memory
- **Platform**: Windows with WSL (Windows Subsystem for Linux)
- **Libraries**:
- Microsoft DeepSpeed for memory-efficient training
- PyTorch with CUDA acceleration
- Flash Attention for optimized attention computation
### Training Notebooks
The repository includes two main training notebooks:
1. **CAMIE Tagger.ipynb**
- Main training notebook
- Dataset loading and preprocessing
- Model initialization
- Initial training loop with DeepSpeed integration
- Tag selection optimization
- Metric tracking and visualization
2. **Camie Tagger Cont and Evals.ipynb**
- Continuation of training from checkpoints
- Comprehensive model evaluation
- Per-category performance metrics
- Threshold optimization
- Model conversion for deployment in the app
- Export functionality for the standalone application
### Training Monitor
The project includes a real-time training monitor accessible via browser at `localhost:5000` during training:
#### Performance Tips
⚠️ **Important**: For optimal training speed, keep VSCode minimized and the training monitor open in your browser. This can improve iteration speed by **3-5x** due to how the Windows/WSL graphics stack handles window focus and CUDA kernel execution.
#### Monitor Features
The training monitor provides three main views:
##### 1. Overview Tab
![Overview Tab](images/training_monitor_overview.png)
- **Training Progress**: Real-time metrics including epoch, batch, speed, and time estimates
- **Loss Chart**: Training and validation loss visualization
- **F1 Scores**: Initial and refined F1 metrics for both training and validation
##### 2. Predictions Tab
![Predictions Tab](images/training_monitor_predictions.png)
- **Image Preview**: Shows the current sample being analyzed
- **Prediction Controls**: Toggle between initial and refined predictions
- **Tag Analysis**:
- Color-coded tag results (correct, incorrect, missing)
- Confidence visualization with probability bars
- Category-based organization
- Filtering options for error analysis
##### 3. Selection Analysis Tab
![Selection Analysis Tab](images/training_monitor_selection.png)
- **Selection Metrics**: Statistics on tag selection quality
- Ground truth recall
- Average probability for ground truth vs. non-ground truth tags
- Unique tags selected
- **Selection Graph**: Trends in selection quality over time
- **Selected Tags Details**: Detailed view of model-selected tags with confidence scores
The monitor provides invaluable insights into how the two-stage prediction model is performing, particularly how the tag selection process is working between the initial and refined prediction stages.
### Training Notes
- Training notebooks require WSL and likely 32GB+ of RAM to handle the dataset
- Microsoft DeepSpeed was crucial for fitting the model and batches into the available VRAM
- Despite hardware limitations, the model achieves impressive results
- With more computational resources, the model could be trained longer on the full dataset
## Support
I plan to move on to LLMs after this project, as I have lots of ideas on how to improve upon them. I will keep updating this model based on community interest.
If you'd like to support further training on the complete dataset or my future projects, consider [buying me a coffee](https://www.buymeacoffee.com/yourusername).
## Acknowledgments
- [Danbooru](https://danbooru.donmai.us/) for the incredible dataset of tagged anime images
- [p1atdev](https://huggingface.co/p1atdev) for the processed Danbooru 2024 dataset
- Microsoft for DeepSpeed, which made training possible on consumer hardware
- PyTorch and the open-source ML community