---
license: gpl-3.0
datasets:
- p1atdev/danbooru-2024
metrics:
- f1
tags:
- art
- code
---

# Anime Image Tagger

An advanced deep learning model for automatically tagging anime/manga illustrations with relevant tags across multiple categories, achieving **61% F1 score** across 70,000+ possible tags on a test set of 20,116 samples.

## Key Highlights

- **Efficient Training**: Completed on just a single RTX 3060 GPU (12GB VRAM)
- **Fast Convergence**: Trained on 7,024,392 samples (3.52 epochs) in 1,756,098 batches
- **Comprehensive Coverage**: 70,000+ tags across 7 categories (general, character, copyright, artist, meta, rating, year)
- **Innovative Architecture**: Two-stage prediction model with cross-attention for tag context
- **User-Friendly Interface**: Easy-to-use application with customizable thresholds

*This project demonstrates that high-quality anime image tagging models can be trained on consumer hardware with the right optimization techniques.*

## Features

- **Multi-category tagging system**: Handles general tags, characters, copyright (series), artists, meta information, and content ratings
- **High performance**: 61% F1 score across 70,000+ possible tags
- **Dual-mode operation**: Full model for best quality or Initial-only mode for reduced VRAM usage
- **Windows compatibility**: Initial-only mode works on Windows without Flash Attention
- **Streamlit web interface**: User-friendly UI for uploading and analyzing images
- **Adjustable threshold profiles**: Overall, Weighted, Category-specific, High Precision, and High Recall profiles
- **Fine-grained control**: Per-category threshold adjustments for precision-recall tradeoffs

## Loss Function

The model employs a specialized `UnifiedFocalLoss` to address the extreme class imbalance inherent in multi-label tag prediction:

```python
class UnifiedFocalLoss(nn.Module):
    def __init__(self, device=None, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        # Implementation details...
```

### Key Components

1. **Focal Loss Mechanism**:
   - Down-weights well-classified examples (γ=2.0) to focus training on difficult tags
   - Addresses the extreme imbalance between positive and negative examples (often 100:1 or worse)
   - Uses α=0.25 to balance positive/negative examples across 70,000+ possible tags

2. **Two-stage Weighting**:
   - Combines losses from both prediction stages (`initial_predictions` and `refined_predictions`)
   - Uses λ=0.4 to weight the initial prediction loss, giving more importance (0.6) to refined predictions
   - This encourages the model to improve predictions in the refinement stage while still maintaining strong initial predictions

3. **Per-sample Statistics**:
   - Tracks separate metrics for positive and negative samples
   - Provides detailed debugging information about prediction distributions
   - Enables analysis of which tag categories are performing well/poorly

This loss function was essential for achieving high F1 scores across diverse tag categories despite the extreme classification challenge of 70,000+ possible tags.
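A minimal sketch of how the focal term and the two-stage weighting could fit together is shown below. The class name, tensor handling, and the `.mean()` reduction are assumptions; the actual `UnifiedFocalLoss` additionally collects the per-sample statistics described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnifiedFocalLossSketch(nn.Module):
    """Illustrative only: focal loss plus two-stage weighting.

    `targets` is a float multi-hot tensor of shape (batch, num_tags)."""

    def __init__(self, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.lambda_initial = lambda_initial

    def _focal(self, logits, targets):
        # Per-element binary cross-entropy, then focal down-weighting of easy examples.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = targets * p + (1 - targets) * (1 - p)                 # probability of the true label
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

    def forward(self, initial_logits, refined_logits, targets):
        loss_initial = self._focal(initial_logits, targets)
        loss_refined = self._focal(refined_logits, targets)
        # λ=0.4 on the initial stage, 0.6 on the refined stage.
        return self.lambda_initial * loss_initial + (1 - self.lambda_initial) * loss_refined
```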
## DeepSpeed Configuration

Microsoft DeepSpeed was crucial for training this model on consumer hardware. The project uses a carefully tuned configuration to maximize efficiency:

```python
def create_deepspeed_config(
    config_path,
    learning_rate=3e-4,
    weight_decay=0.01,
    num_train_samples=None,
    micro_batch_size=4,
    grad_accum_steps=8
):
    # Implementation details...
```

### Key Optimizations

1. **Memory Efficiency**:
   - **ZeRO Stage 2**: Partitions optimizer states and gradients, dramatically reducing memory requirements
   - **Activation Checkpointing**: Trades computation for memory by recomputing activations during backpropagation
   - **Contiguous Memory Optimization**: Reduces memory fragmentation

2. **Mixed Precision Training**:
   - **FP16 Mode**: Uses half-precision (16-bit) for most calculations, with automatic loss scaling
   - **Initial Scale Power**: Set to 16 for stable convergence with large batch sizes

3. **Gradient Accumulation**:
   - Micro-batch size of 4 with 8 gradient accumulation steps
   - Effective batch size of 32 while only requiring memory for 4 samples at once

4. **Learning Rate Schedule**:
   - WarmupLR scheduler with gradual increase from 3e-6 to 3e-4
   - Warmup over 1/4 of an epoch to stabilize early training

This configuration allowed the model to train efficiently with only 12GB of VRAM while maintaining numerical stability across millions of training examples with 70,000+ output dimensions.
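For reference, the optimizations above map onto a DeepSpeed configuration roughly like the sketch below. The key names are standard DeepSpeed options; the optimizer choice and the warmup step count (derived here from the 2,000,000-image training subset) are assumptions, not the verbatim output of `create_deepspeed_config`.

```python
# Rough shape of the DeepSpeed config described above (illustrative values).
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,            # effective batch size of 32
    "zero_optimization": {
        "stage": 2,                              # partition optimizer states and gradients
        "contiguous_gradients": True,            # reduce memory fragmentation
    },
    "fp16": {
        "enabled": True,                         # half precision with automatic loss scaling
        "initial_scale_power": 16,               # loss scale starts at 2**16
    },
    "activation_checkpointing": {
        "partition_activations": True,
        "contiguous_memory_optimization": True,
    },
    "optimizer": {
        "type": "AdamW",                         # assumption
        "params": {"lr": 3e-4, "weight_decay": 0.01},
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 3e-6,
            "warmup_max_lr": 3e-4,
            # ~1/4 of an epoch of optimizer steps, assuming the 2,000,000-image
            # training subset: 2,000,000 / 32 / 4 = 15,625
            "warmup_num_steps": 15625,
        },
    },
}
```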
## Dataset

The model was trained on a carefully filtered subset of the [Danbooru 2024 dataset](https://huggingface.co/datasets/p1atdev/danbooru-2024), which contains a vast collection of anime/manga illustrations with comprehensive tagging.

### Filtering Process

The dataset was filtered with the following constraints:

```python
# Minimum tags per category required for each image
min_tag_counts = {
    'general': 25,
    'character': 1,
    'copyright': 1,
    'artist': 0,
    'meta': 0
}

# Minimum samples per tag required for a tag to be included
min_tag_samples = {
    'general': 20,
    'character': 40,
    'copyright': 50,
    'artist': 200,
    'meta': 50
}
```

This filtering process:

1. First removed low-sample tags (tags with fewer occurrences than specified in `min_tag_samples`)
2. Then removed images with insufficient tags per category (as specified in `min_tag_counts`)

### Training Data

- **Starting dataset size**: ~3,000,000 filtered images
- **Training subset**: 2,000,000 images (due to storage and time constraints)
- **Training duration**: 3.5 epochs

The model could potentially achieve even higher accuracy with more training epochs and the full dataset.

### Preprocessing

Images were preprocessed with minimal transformations:

- Tensor normalization (scaled to 0-1 range)
- Resized while maintaining original aspect ratio
- No additional augmentations were applied

## Model Architecture

The model uses a novel two-stage prediction approach that achieves superior performance compared to traditional single-stage models:

### Image Feature Extraction

- **Backbone**: EfficientNet V2-L extracts high-quality visual features from input images
- **Spatial Pooling**: Adaptive averaging converts spatial features to a compact 1280-dimensional embedding

### Initial Prediction Stage

- Direct classification from image features through a multi-layer classifier
- Bottleneck architecture with LayerNorm and GELU activations between linear layers
- Outputs initial tag probabilities across all 70,000+ possible tags

### Tag Context Mechanism

- Top predicted tags are embedded using a shared embedding space
- Self-attention layer allows tags to influence each other based on co-occurrence patterns
- Normalized tag embeddings represent a coherent "tag context" for the image

### Cross-Attention Refinement

- Image features and tag embeddings interact through cross-attention
- Each dimension of the image features attends to relevant dimensions in the tag space
- This creates a bidirectional flow of information between visual features and semantic tags

### Refined Predictions

- Fused features (original + cross-attended) feed into a final classifier
- Residual connection ensures initial predictions are preserved when beneficial
- Temperature scaling provides calibrated probability outputs

This dual-stage approach allows the model to leverage tag co-occurrence patterns and semantic relationships, improving accuracy without increasing the parameter count significantly.
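To make the data flow concrete, here is a minimal sketch of the two-stage design, assuming the 1280-dimensional EfficientNet V2-L embedding described above. The hidden sizes, head counts, top-k value, and class names are illustrative assumptions, and temperature scaling is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l

class TwoStageTaggerSketch(nn.Module):
    """Illustrative two-stage tagger: initial logits -> top-k tag context ->
    cross-attention refinement. Layer sizes are assumptions, not the released model."""

    def __init__(self, num_tags=70000, embed_dim=1280, top_k=128):
        super().__init__()
        self.backbone = efficientnet_v2_l(weights=None).features  # 1280-channel feature map
        self.initial_classifier = nn.Sequential(
            nn.Linear(embed_dim, 1024), nn.LayerNorm(1024), nn.GELU(),
            nn.Linear(1024, num_tags),
        )
        self.tag_embedding = nn.Embedding(num_tags, embed_dim)
        self.tag_self_attn = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        self.refined_classifier = nn.Linear(embed_dim * 2, num_tags)
        self.top_k = top_k

    def forward(self, images):
        feats = self.backbone(images)                         # (B, 1280, H, W)
        img_emb = F.adaptive_avg_pool2d(feats, 1).flatten(1)  # (B, 1280)
        initial_logits = self.initial_classifier(img_emb)

        # Tag context: embed the top-k initially predicted tags and let them interact.
        top_idx = initial_logits.topk(self.top_k, dim=-1).indices       # (B, k)
        tag_emb = self.tag_embedding(top_idx)                           # (B, k, 1280)
        tag_ctx, _ = self.tag_self_attn(tag_emb, tag_emb, tag_emb)
        tag_ctx = F.normalize(tag_ctx, dim=-1)

        # Cross-attention: the image embedding attends to the tag context.
        attended, _ = self.cross_attn(img_emb.unsqueeze(1), tag_ctx, tag_ctx)
        fused = torch.cat([img_emb, attended.squeeze(1)], dim=-1)        # (B, 2560)
        refined_logits = self.refined_classifier(fused) + initial_logits  # residual on logits
        return initial_logits, refined_logits
```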
## Installation

Simply run the included setup script to install all dependencies:

```
setup.bat
```

This will automatically set up all necessary packages for the application.

### Requirements

- **Python 3.11.9 specifically** (newer versions are incompatible)
- PyTorch 1.10+
- Streamlit
- PIL/Pillow
- NumPy
- Flash Attention (note: doesn't work properly on Windows)

### Running the Application

The application is located in the `app` folder and can be launched via the setup script:

1. Run `setup.bat` to install dependencies
2. The Streamlit interface will automatically open in your browser
3. If the browser doesn't open automatically, navigate to http://localhost:8501

## Usage

After installation, run the application by executing `setup.bat`. This launches a web interface where you can:

- Upload your own images or select from example images
- Choose different threshold profiles
- Adjust category-specific thresholds
- View predictions organized by category
- Filter and sort tags based on confidence

## Model Details

### Tag Categories

The model recognizes tags across these categories:

- **General**: Visual elements, concepts, clothing, etc.
- **Character**: Individual characters appearing in the image
- **Copyright**: Source material (anime, manga, game)
- **Artist**: Creator of the artwork
- **Meta**: Meta information about the image
- **Rating**: Content rating
- **Year**: Year of upload

### Performance Notes

The full model with refined predictions outperforms the initial-only model, though the performance gap is surprisingly small given that the two share the same backbone and initial classifier. This is an interesting architectural finding: the refinement stage improves results without adding substantial computational overhead.

This efficiency makes the initial-only model particularly valuable for Windows users or systems with limited VRAM, as they can still achieve near-optimal performance without requiring Flash Attention.

In benchmarks, the model achieved a 61% F1 score across all categories, which is remarkable considering the extreme multi-label classification challenge of 70,000+ possible tags. The model performs particularly well on general tags and character recognition.

### Threshold Profiles

- **Overall**: Single threshold applied to all categories
- **Weighted**: Threshold optimized for balanced performance across categories
- **Category-specific**: Different thresholds for each category
- **High Precision**: Higher thresholds for more confident predictions
- **High Recall**: Lower thresholds to capture more potential tags
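In practice, a category-specific profile amounts to a per-category lookup like the sketch below. The threshold values and helper names are illustrative assumptions, not the app's actual defaults.

```python
# Hypothetical category-specific profile; the real application's values may differ.
category_profile = {
    "general": 0.35, "character": 0.50, "copyright": 0.45,
    "artist": 0.60, "meta": 0.40, "rating": 0.30, "year": 0.30,
}

def select_tags(probs, tag_names, tag_categories, thresholds):
    """Keep tags whose probability clears the threshold for their category.

    probs, tag_names, tag_categories are parallel sequences for one image:
    per-tag probability, tag string, and category string."""
    selected = []
    for name, category, p in zip(tag_names, tag_categories, probs):
        if p >= thresholds.get(category, 0.5):
            selected.append((name, category, round(float(p), 3)))
    # Highest-confidence tags first, mirroring the interface's sort-by-confidence view.
    return sorted(selected, key=lambda item: item[2], reverse=True)
```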
## Windows Compatibility

The full model uses Flash Attention, which does not work properly on Windows. For Windows users:

- The application automatically defaults to the Initial-only model
- Performance difference is minimal (0.2% absolute F1 score reduction, from 61.6% to 61.4%)
- The Initial-only model still uses the same powerful EfficientNet backbone and initial classifier

## Web Interface Guide

The interface is divided into three main sections:

1. **Model Selection** (Sidebar)
   - Choose between Full Model or Initial-only Model
   - View model information and memory usage

2. **Image Upload** (Left Panel)
   - Upload your own images or select from examples
   - View the selected image

3. **Tagging Controls** (Right Panel)
   - Select threshold profile
   - Adjust thresholds for precision-recall tradeoff
   - Configure display options
   - View predictions organized by category

### Display Options

- **Show all tags**: Display all tags including those below threshold
- **Compact view**: Hide progress bars for cleaner display
- **Minimum confidence**: Filter out low-confidence predictions
- **Category selection**: Choose which categories to include in the summary

### Interface Screenshots

![Application Interface](images/app_screenshot.png)

![Tag Results Example](images/tag_results_example.png)

## Training Environment

The model was trained using surprisingly modest hardware:

- **GPU**: Single NVIDIA RTX 3060 (12GB VRAM)
- **RAM**: 64GB system memory
- **Platform**: Windows with WSL (Windows Subsystem for Linux)
- **Libraries**:
  - Microsoft DeepSpeed for memory-efficient training
  - PyTorch with CUDA acceleration
  - Flash Attention for optimized attention computation

### Training Notebooks

The repository includes two main training notebooks:

1. **CAMIE Tagger.ipynb**
   - Main training notebook
   - Dataset loading and preprocessing
   - Model initialization
   - Initial training loop with DeepSpeed integration
   - Tag selection optimization
   - Metric tracking and visualization

2. **Camie Tagger Cont and Evals.ipynb**
   - Continuation of training from checkpoints
   - Comprehensive model evaluation
   - Per-category performance metrics
   - Threshold optimization
   - Model conversion for deployment in the app
   - Export functionality for the standalone application

### Training Monitor

The project includes a real-time training monitor accessible via browser at `localhost:5000` during training.

#### Performance Tips

⚠️ **Important**: For optimal training speed, keep VSCode minimized and the training monitor open in your browser. This can improve iteration speed by **3-5x** due to how the Windows/WSL graphics stack handles window focus and CUDA kernel execution.

#### Monitor Features

The training monitor provides three main views:

##### 1. Overview Tab

![Overview Tab](images/training_monitor_overview.png)

- **Training Progress**: Real-time metrics including epoch, batch, speed, and time estimates
- **Loss Chart**: Training and validation loss visualization
- **F1 Scores**: Initial and refined F1 metrics for both training and validation

##### 2. Predictions Tab

![Predictions Tab](images/training_monitor_predictions.png)

- **Image Preview**: Shows the current sample being analyzed
- **Prediction Controls**: Toggle between initial and refined predictions
- **Tag Analysis**:
  - Color-coded tag results (correct, incorrect, missing)
  - Confidence visualization with probability bars
  - Category-based organization
  - Filtering options for error analysis

##### 3. Selection Analysis Tab

![Selection Analysis Tab](images/training_monitor_selection.png)

- **Selection Metrics**: Statistics on tag selection quality
  - Ground truth recall
  - Average probability for ground truth vs. non-ground truth tags
  - Unique tags selected
- **Selection Graph**: Trends in selection quality over time
- **Selected Tags Details**: Detailed view of model-selected tags with confidence scores

The monitor provides invaluable insights into how the two-stage prediction model is performing, particularly how the tag selection process is working between the initial and refined prediction stages.

### Training Notes

- Training notebooks require WSL and likely 32GB+ of RAM to handle the dataset
- Microsoft DeepSpeed was crucial for fitting the model and batches into the available VRAM
- Despite hardware limitations, the model achieves impressive results
- With more computational resources, the model could be trained longer on the full dataset

## Support

I plan to move on to LLMs after this project, as I have lots of ideas on how to improve upon them. I will update this model based on community attention. If you'd like to support further training on the complete dataset or my future projects, consider [buying me a coffee](https://www.buymeacoffee.com/yourusername).

## Acknowledgments

- [Danbooru](https://danbooru.donmai.us/) for the incredible dataset of tagged anime images
- [p1atdev](https://huggingface.co/p1atdev) for the processed Danbooru 2024 dataset
- Microsoft for DeepSpeed, which made training possible on consumer hardware
- PyTorch and the open-source ML community