---
license: gpl-3.0
datasets:
- p1atdev/danbooru-2024
metrics:
- f1
tags:
- art
- code
---
# Anime Image Tagger

An advanced deep learning model for automatically tagging anime/manga illustrations with relevant tags across multiple categories, achieving a **61% F1 score** across 70,000+ possible tags on a test set of 20,116 samples.

## Key Highlights

- **Efficient Training**: Completed on just a single RTX 3060 GPU (12GB VRAM)
- **Fast Convergence**: Trained on 7,024,392 samples (3.52 epochs) in 1,756,098 batches
- **Comprehensive Coverage**: 70,000+ tags across 7 categories (general, character, copyright, artist, meta, rating, year)
- **Innovative Architecture**: Two-stage prediction model with cross-attention for tag context
- **User-Friendly Interface**: Easy-to-use application with customizable thresholds

*This project demonstrates that high-quality anime image tagging models can be trained on consumer hardware with the right optimization techniques.*

## Features

- **Multi-category tagging system**: Handles general tags, characters, copyright (series), artists, meta information, and content ratings
- **High performance**: 61% F1 score across 70,000+ possible tags
- **Dual-mode operation**: Full model for best quality, or Initial-only mode for reduced VRAM usage
- **Windows compatibility**: Initial-only mode works on Windows without Flash Attention
- **Streamlit web interface**: User-friendly UI for uploading and analyzing images
- **Adjustable threshold profiles**: Overall, Weighted, Category-specific, High Precision, and High Recall profiles
- **Fine-grained control**: Per-category threshold adjustments for precision-recall tradeoffs

## Loss Function

The model employs a specialized `UnifiedFocalLoss` to address the extreme class imbalance inherent in multi-label tag prediction:

```python
import torch.nn as nn

class UnifiedFocalLoss(nn.Module):
    def __init__(self, device=None, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        # Implementation details...
```

### Key Components

1. **Focal Loss Mechanism**:
   - Down-weights well-classified examples (γ=2.0) to focus training on difficult tags
   - Addresses the extreme imbalance between positive and negative examples (often 100:1 or worse)
   - Uses α=0.25 to balance positive/negative examples across 70,000+ possible tags

2. **Two-stage Weighting**:
   - Combines losses from both prediction stages (`initial_predictions` and `refined_predictions`)
   - Uses λ=0.4 to weight the initial prediction loss, giving more weight (0.6) to refined predictions
   - This encourages the model to improve predictions in the refinement stage while still maintaining strong initial predictions

3. **Per-sample Statistics**:
   - Tracks separate metrics for positive and negative samples
   - Provides detailed debugging information about prediction distributions
   - Enables analysis of which tag categories are performing well or poorly

This loss function was essential for achieving high F1 scores across diverse tag categories despite the extreme classification challenge of 70,000+ possible tags.
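
For reference, here is a minimal sketch of how the pieces described above combine: a sigmoid focal loss applied to both prediction stages and weighted by λ. The class name, the assumption that both stages return raw logits, and the omission of the per-sample statistics tracking are ours, not the actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnifiedFocalLossSketch(nn.Module):
    """Sketch only: focal BCE over both stages, weighted by lambda_initial."""

    def __init__(self, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.lambda_initial = lambda_initial

    def focal_bce(self, logits, targets):
        # Standard sigmoid focal loss over a (batch, num_tags) logit matrix.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)        # prob. of the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

    def forward(self, initial_logits, refined_logits, targets):
        # lambda weights the initial stage; the refined stage gets 1 - lambda (0.6).
        loss_initial = self.focal_bce(initial_logits, targets)
        loss_refined = self.focal_bce(refined_logits, targets)
        return self.lambda_initial * loss_initial + (1 - self.lambda_initial) * loss_refined
```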

## DeepSpeed Configuration

Microsoft DeepSpeed was crucial for training this model on consumer hardware. The project uses a carefully tuned configuration to maximize efficiency:

```python
def create_deepspeed_config(
    config_path,
    learning_rate=3e-4,
    weight_decay=0.01,
    num_train_samples=None,
    micro_batch_size=4,
    grad_accum_steps=8
):
    # Implementation details...
```

### Key Optimizations

1. **Memory Efficiency**:
   - **ZeRO Stage 2**: Partitions optimizer states and gradients, dramatically reducing memory requirements
   - **Activation Checkpointing**: Trades computation for memory by recomputing activations during backpropagation
   - **Contiguous Memory Optimization**: Reduces memory fragmentation

2. **Mixed Precision Training**:
   - **FP16 Mode**: Uses half-precision (16-bit) for most calculations, with automatic loss scaling
   - **Initial Scale Power**: Set to 16 for stable convergence with large batch sizes

3. **Gradient Accumulation**:
   - Micro-batch size of 4 with 8 gradient accumulation steps
   - Effective batch size of 32 while only requiring memory for 4 samples at once

4. **Learning Rate Schedule**:
   - WarmupLR scheduler with a gradual increase from 3e-6 to 3e-4
   - Warmup over 1/4 of an epoch to stabilize early training

This configuration allowed the model to train efficiently with only 12GB of VRAM while maintaining numerical stability across millions of training examples with 70,000+ output dimensions.
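
For illustration, the generated configuration plausibly resembles the following dict. The key names follow the standard DeepSpeed JSON schema; the optimizer type and the exact values emitted by `create_deepspeed_config` are assumptions based on the settings above:

```python
# Sketch of a DeepSpeed config matching the optimizations described above.
# The optimizer type (AdamW) is an assumption; the rest mirrors the text.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,          # effective batch size: 4 * 8 = 32
    "zero_optimization": {
        "stage": 2,                            # partition optimizer states + gradients
    },
    "activation_checkpointing": {
        "partition_activations": True,
        "contiguous_memory_optimization": True # reduce memory fragmentation
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 16,             # loss scaling starts at 2**16
    },
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 3e-4, "weight_decay": 0.01},
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 3e-6,
            "warmup_max_lr": 3e-4,
            # warmup_num_steps would be set to ~1/4 epoch of optimizer steps
        },
    },
}
```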

## Dataset

The model was trained on a carefully filtered subset of the [Danbooru 2024 dataset](https://huggingface.co/datasets/p1atdev/danbooru-2024), which contains a vast collection of anime/manga illustrations with comprehensive tagging.

### Filtering Process

The dataset was filtered with the following constraints:

```python
# Minimum tags per category required for each image
min_tag_counts = {
    'general': 25,
    'character': 1,
    'copyright': 1,
    'artist': 0,
    'meta': 0
}

# Minimum samples per tag required for tag to be included
min_tag_samples = {
    'general': 20,
    'character': 40,
    'copyright': 50,
    'artist': 200,
    'meta': 50
}
```

This filtering process:
1. First removed low-sample tags (tags with fewer occurrences than specified in `min_tag_samples`)
2. Then removed images with insufficient tags per category (as specified in `min_tag_counts`)
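
As an illustration, the two passes could be implemented roughly as follows. This is a sketch: `images` is a hypothetical mapping from image ID to per-category tag lists, not the project's actual data structure:

```python
from collections import Counter

def filter_dataset(images, min_tag_samples, min_tag_counts):
    """Two-pass filter sketch. `images` maps image_id -> {category: [tags]}."""
    # Pass 1: count tag occurrences, then keep only tags meeting the
    # per-category minimum sample counts.
    counts = {cat: Counter() for cat in min_tag_samples}
    for tags_by_cat in images.values():
        for cat, tags in tags_by_cat.items():
            if cat in counts:
                counts[cat].update(tags)
    keep = {
        cat: {t for t, n in counts[cat].items() if n >= min_tag_samples[cat]}
        for cat in min_tag_samples
    }
    # Pass 2: prune removed tags, then drop images that fall below the
    # per-category minimum tag counts.
    filtered = {}
    for image_id, tags_by_cat in images.items():
        pruned = {
            cat: [t for t in tags if cat not in keep or t in keep[cat]]
            for cat, tags in tags_by_cat.items()
        }
        if all(len(pruned.get(cat, [])) >= n for cat, n in min_tag_counts.items()):
            filtered[image_id] = pruned
    return filtered
```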

### Training Data

- **Starting dataset size**: ~3,000,000 filtered images
- **Training subset**: 2,000,000 images (due to storage and time constraints)
- **Training duration**: 3.5 epochs

The model could potentially achieve even higher accuracy with more training epochs and the full dataset.

### Preprocessing

Images were preprocessed with minimal transformations:
- Tensor normalization (scaled to the 0-1 range)
- Resizing that preserves the original aspect ratio
- No additional augmentations
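
A rough torchvision equivalent of this pipeline might look like the sketch below; the target resolution of 512 is an illustrative assumption, since the card does not state the model's input size:

```python
from PIL import Image
import torchvision.transforms.functional as TF

def preprocess(path, target=512):
    # Resize so the longer side equals `target`, preserving aspect ratio
    # (the value 512 is an illustrative assumption).
    img = Image.open(path).convert("RGB")
    w, h = img.size
    scale = target / max(w, h)
    img = img.resize((round(w * scale), round(h * scale)), Image.LANCZOS)
    # to_tensor() converts HWC uint8 to CHW float scaled to [0, 1];
    # no further augmentation is applied.
    return TF.to_tensor(img)
```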

## Model Architecture

The model uses a novel two-stage prediction approach that achieves superior performance compared to traditional single-stage models:

### Image Feature Extraction
- **Backbone**: EfficientNet V2-L extracts high-quality visual features from input images
- **Spatial Pooling**: Adaptive averaging converts spatial features to a compact 1280-dimensional embedding

### Initial Prediction Stage
- Direct classification from image features through a multi-layer classifier
- Bottleneck architecture with LayerNorm and GELU activations between linear layers
- Outputs initial tag probabilities across all 70,000+ possible tags

### Tag Context Mechanism
- Top predicted tags are embedded using a shared embedding space
- A self-attention layer allows tags to influence each other based on co-occurrence patterns
- Normalized tag embeddings represent a coherent "tag context" for the image

### Cross-Attention Refinement
- Image features and tag embeddings interact through cross-attention
- Image feature queries attend to the most relevant tag embeddings
- This creates a bidirectional flow of information between visual features and semantic tags

### Refined Predictions
- Fused features (original + cross-attended) feed into a final classifier
- A residual connection ensures initial predictions are preserved when beneficial
- Temperature scaling provides calibrated probability outputs

This dual-stage approach allows the model to leverage tag co-occurrence patterns and semantic relationships, improving accuracy without significantly increasing the parameter count. A sketch of the forward pass appears below.
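
This is a structural illustration rather than the actual implementation: the 1280-dimensional embedding comes from the backbone description above, while `top_k`, the head counts, and the classifier shapes are assumptions:

```python
import torch
import torch.nn as nn

class TwoStageTaggerSketch(nn.Module):
    """Structural sketch of the two-stage architecture described above."""

    def __init__(self, backbone, num_tags=70000, hidden=1280, top_k=32):
        super().__init__()
        self.backbone = backbone                      # e.g. EfficientNet V2-L feature maps
        self.pool = nn.AdaptiveAvgPool2d(1)           # spatial features -> 1280-d embedding
        self.initial_classifier = nn.Sequential(      # bottleneck MLP with LayerNorm + GELU
            nn.Linear(hidden, hidden // 2), nn.LayerNorm(hidden // 2), nn.GELU(),
            nn.Linear(hidden // 2, num_tags),
        )
        self.tag_embedding = nn.Embedding(num_tags, hidden)
        self.tag_self_attn = nn.MultiheadAttention(hidden, num_heads=8, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(hidden, num_heads=8, batch_first=True)
        self.refined_classifier = nn.Linear(2 * hidden, num_tags)
        self.temperature = nn.Parameter(torch.ones(1))
        self.top_k = top_k

    def forward(self, images):
        feats = self.pool(self.backbone(images)).flatten(1)      # (B, 1280)
        initial_logits = self.initial_classifier(feats)          # stage 1

        # Tag context: embed the top-k predicted tags, let them interact.
        top_tags = initial_logits.topk(self.top_k, dim=1).indices
        tag_emb = self.tag_embedding(top_tags)                   # (B, k, 1280)
        tag_ctx, _ = self.tag_self_attn(tag_emb, tag_emb, tag_emb)

        # Cross-attention: image features query the tag context.
        attended, _ = self.cross_attn(feats.unsqueeze(1), tag_ctx, tag_ctx)
        fused = torch.cat([feats, attended.squeeze(1)], dim=1)   # original + cross-attended

        # Residual connection keeps the initial predictions available.
        refined_logits = self.refined_classifier(fused) + initial_logits
        return initial_logits, refined_logits / self.temperature
```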

## Installation

Simply run the included setup script to install all dependencies:

```
setup.bat
```

This will automatically set up all necessary packages for the application.

### Requirements

- **Python 3.11.9 specifically** (newer versions are incompatible)
- PyTorch 1.10+
- Streamlit
- PIL/Pillow
- NumPy
- Flash Attention (note: does not work properly on Windows)

### Running the Application

The application is located in the `app` folder and can be launched via the setup script:

1. Run `setup.bat` to install dependencies
2. The Streamlit interface will automatically open in your browser
3. If the browser doesn't open automatically, navigate to http://localhost:8501

## Usage

After installation, running `setup.bat` launches a web interface where you can:

- Upload your own images or select from example images
- Choose different threshold profiles
- Adjust category-specific thresholds
- View predictions organized by category
- Filter and sort tags based on confidence

## Model Details

### Tag Categories

The model recognizes tags across these categories:
- **General**: Visual elements, concepts, clothing, etc.
- **Character**: Individual characters appearing in the image
- **Copyright**: Source material (anime, manga, game)
- **Artist**: Creator of the artwork
- **Meta**: Meta information about the image
- **Rating**: Content rating
- **Year**: Year of upload

### Performance Notes

The full model with refined predictions outperforms the Initial-only model, though the performance gap is surprisingly small. This is an interesting architectural finding: the refinement stage adds value without substantial computational overhead.

This efficiency makes the Initial-only model particularly valuable for Windows users or systems with limited VRAM, as it still achieves near-optimal performance without requiring Flash Attention.

In benchmarks, the model achieved a 61% F1 score across all categories, which is remarkable given the extreme multi-label classification challenge of 70,000+ possible tags. The model performs particularly well on general tags and character recognition.

### Threshold Profiles

- **Overall**: Single threshold applied to all categories
- **Weighted**: Threshold optimized for balanced performance across categories
- **Category-specific**: Different thresholds for each category
- **High Precision**: Higher thresholds for more confident predictions
- **High Recall**: Lower thresholds to capture more potential tags
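
Applying a category-specific profile amounts to comparing each tag's probability against its category's threshold, roughly as sketched below; the threshold values shown are illustrative, not the shipped profiles:

```python
# Sketch: filter predictions with per-category thresholds.
# The numbers below are illustrative, not the app's actual profile values.
category_thresholds = {
    "general": 0.35, "character": 0.50, "copyright": 0.50,
    "artist": 0.60, "meta": 0.40, "rating": 0.50, "year": 0.50,
}

def apply_profile(predictions, thresholds):
    """`predictions` maps category -> {tag: probability}."""
    return {
        cat: {tag: p for tag, p in tags.items() if p >= thresholds[cat]}
        for cat, tags in predictions.items()
    }
```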

## Windows Compatibility

The full model uses Flash Attention, which does not work properly on Windows. For Windows users:

- The application automatically defaults to the Initial-only model
- The performance difference is minimal (a 0.2% absolute F1 reduction, from 61.6% to 61.4%)
- The Initial-only model still uses the same powerful EfficientNet backbone and initial classifier

## Web Interface Guide

The interface is divided into three main sections:

1. **Model Selection** (Sidebar)
   - Choose between Full Model or Initial-only Model
   - View model information and memory usage

2. **Image Upload** (Left Panel)
   - Upload your own images or select from examples
   - View the selected image

3. **Tagging Controls** (Right Panel)
   - Select threshold profile
   - Adjust thresholds for precision-recall tradeoff
   - Configure display options
   - View predictions organized by category

### Display Options

- **Show all tags**: Display all tags, including those below the threshold
- **Compact view**: Hide progress bars for a cleaner display
- **Minimum confidence**: Filter out low-confidence predictions
- **Category selection**: Choose which categories to include in the summary

### Interface Screenshots

![Application Interface](app_screenshot.png)

![Prediction Results](artist_focus.png)

## Training Environment

The model was trained using surprisingly modest hardware:

- **GPU**: Single NVIDIA RTX 3060 (12GB VRAM)
- **RAM**: 64GB system memory
- **Platform**: Windows with WSL (Windows Subsystem for Linux)
- **Libraries**:
  - Microsoft DeepSpeed for memory-efficient training
  - PyTorch with CUDA acceleration
  - Flash Attention for optimized attention computation

### Training Notebooks

The repository includes two main training notebooks:

1. **CAMIE Tagger.ipynb**
   - Main training notebook
   - Dataset loading and preprocessing
   - Model initialization
   - Initial training loop with DeepSpeed integration
   - Tag selection optimization
   - Metric tracking and visualization

2. **Camie Tagger Cont and Evals.ipynb**
   - Continuation of training from checkpoints
   - Comprehensive model evaluation
   - Per-category performance metrics
   - Threshold optimization
   - Model conversion for deployment in the app
   - Export functionality for the standalone application

### Training Monitor

The project includes a real-time training monitor, accessible in a browser at `localhost:5000` during training.

#### Performance Tips

⚠️ **Important**: For optimal training speed, keep VSCode minimized and the training monitor open in your browser. This can improve iteration speed by **3-5x** due to how the Windows/WSL graphics stack handles window focus and CUDA kernel execution.

#### Monitor Features

The training monitor provides three main views:

##### 1. Overview Tab

![Overview Tab](monitor_overview.png)

- **Training Progress**: Real-time metrics including epoch, batch, speed, and time estimates
- **Loss Chart**: Training and validation loss visualization
- **F1 Scores**: Initial and refined F1 metrics for both training and validation

##### 2. Predictions Tab

![Predictions Tab](monitor_predictions.png)

- **Image Preview**: Shows the current sample being analyzed
- **Prediction Controls**: Toggle between initial and refined predictions
- **Tag Analysis**:
  - Color-coded tag results (correct, incorrect, missing)
  - Confidence visualization with probability bars
  - Category-based organization
  - Filtering options for error analysis

##### 3. Selection Analysis Tab

![Selection Analysis Tab](monitor_selection.png)

- **Selection Metrics**: Statistics on tag selection quality
  - Ground truth recall
  - Average probability for ground truth vs. non-ground-truth tags
  - Unique tags selected
- **Selection Graph**: Trends in selection quality over time
- **Selected Tags Details**: Detailed view of model-selected tags with confidence scores

The monitor provides invaluable insight into how the two-stage prediction model is performing, particularly how the tag selection process works between the initial and refined prediction stages.

### Training Notes

- The training notebooks require WSL and likely 32GB+ of RAM to handle the dataset
- Microsoft DeepSpeed was crucial for fitting the model and batches into the available VRAM
- Despite hardware limitations, the model achieves impressive results
- With more computational resources, the model could be trained longer on the full dataset

## Support

I plan to move on to LLMs after this project, as I have many ideas on how to improve upon them. I will update this model based on community attention.

If you'd like to support further training on the complete dataset or my future projects, consider [buying me a coffee](https://www.buymeacoffee.com/yourusername).

## Acknowledgments

- [Danbooru](https://danbooru.donmai.us/) for the incredible dataset of tagged anime images
- [p1atdev](https://huggingface.co/p1atdev) for the processed Danbooru 2024 dataset
- Microsoft for DeepSpeed, which made training possible on consumer hardware
- PyTorch and the open-source ML community