---
license: gpl-3.0
datasets:
  - p1atdev/danbooru-2024
metrics:
  - f1
tags:
  - art
  - code
---

Usage

After installation, run the application by executing setup.bat. This launches a web interface where you can:

  • Upload your own images or select from example images
  • Choose different threshold profiles
  • Adjust category-specific thresholds
  • View predictions organized by category
  • Filter and sort tags based on confidence

# Anime Image Tagger

An advanced deep learning model for automatically tagging anime/manga illustrations with relevant tags across multiple categories, achieving 61% F1 score across 70,000+ possible tags on a test set of 20,116 samples.

Key Highlights

  • Efficient Training: Completed on just a single RTX 3060 GPU (12GB VRAM)
  • Fast Convergence: Trained on 7,024,392 samples (3.52 epochs) in 1,756,098 batches
  • Comprehensive Coverage: 70,000+ tags across 7 categories (general, character, copyright, artist, meta, rating, year)
  • Innovative Architecture: Two-stage prediction model with cross-attention for tag context
  • User-Friendly Interface: Easy-to-use application with customizable thresholds

This project demonstrates that high-quality anime image tagging models can be trained on consumer hardware with the right optimization techniques.

Features

  • Multi-category tagging system: Handles general tags, characters, copyright (series), artists, meta information, and content ratings
  • High performance: 61% F1 score across 70,000+ possible tags
  • Dual-mode operation: Full model for best quality or Initial-only mode for reduced VRAM usage
  • Windows compatibility: Initial-only mode works on Windows without Flash Attention
  • Streamlit web interface: User-friendly UI for uploading and analyzing images
  • Adjustable threshold profiles: Overall, Weighted, Category-specific, High Precision, and High Recall profiles
  • Fine-grained control: Per-category threshold adjustments for precision-recall tradeoffs

Loss Function

The model employs a specialized UnifiedFocalLoss to address the extreme class imbalance inherent in multi-label tag prediction:

```python
import torch.nn as nn

class UnifiedFocalLoss(nn.Module):
    def __init__(self, device=None, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        # Implementation details...
        ...
```

Key Components

  1. Focal Loss Mechanism:

    • Down-weights well-classified examples (γ=2.0) to focus training on difficult tags
    • Addresses the extreme imbalance between positive and negative examples (often 100:1 or worse)
    • Uses α=0.25 to balance positive/negative examples across 70,000+ possible tags
  2. Two-stage Weighting:

    • Combines losses from both prediction stages (initial_predictions and refined_predictions)
    • Uses λ=0.4 to weight the initial prediction loss, giving more importance (0.6) to refined predictions
    • This encourages the model to improve predictions in the refinement stage while still maintaining strong initial predictions
  3. Per-sample Statistics:

    • Tracks separate metrics for positive and negative samples
    • Provides detailed debugging information about prediction distributions
    • Enables analysis of which tag categories are performing well/poorly

This loss function was essential for achieving high F1 scores across diverse tag categories despite the extreme classification challenge of 70,000+ possible tags.
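For illustration, here is a minimal sketch of how the two pieces above can be combined: a standard binary focal loss is applied to each prediction stage and the two losses are mixed with the λ weighting. This is not the project's exact implementation; the real UnifiedFocalLoss also tracks per-sample statistics and handles device placement.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoStageFocalLossSketch(nn.Module):
    """Sketch: binary focal loss on both stages, mixed with lambda_initial."""

    def __init__(self, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.lambda_initial = lambda_initial

    def focal(self, logits, targets):
        # Standard binary focal loss over the multi-label tag vector.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)               # prob. of the true label
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

    def forward(self, initial_logits, refined_logits, targets):
        loss_initial = self.focal(initial_logits, targets)
        loss_refined = self.focal(refined_logits, targets)
        # lambda = 0.4 on the initial stage, so 0.6 of the signal comes from refinement.
        return self.lambda_initial * loss_initial + (1 - self.lambda_initial) * loss_refined
```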

DeepSpeed Configuration

Microsoft DeepSpeed was crucial for training this model on consumer hardware. The project uses a carefully tuned configuration to maximize efficiency:

```python
def create_deepspeed_config(
    config_path,
    learning_rate=3e-4,
    weight_decay=0.01,
    num_train_samples=None,
    micro_batch_size=4,
    grad_accum_steps=8
):
    # Implementation details...
    ...
```

Key Optimizations

  1. Memory Efficiency:

    • ZeRO Stage 2: Partitions optimizer states and gradients, dramatically reducing memory requirements
    • Activation Checkpointing: Trades computation for memory by recomputing activations during backpropagation
    • Contiguous Memory Optimization: Reduces memory fragmentation
  2. Mixed Precision Training:

    • FP16 Mode: Uses half-precision (16-bit) for most calculations, with automatic loss scaling
    • Initial Scale Power: Set to 16 for stable convergence with large batch sizes
  3. Gradient Accumulation:

    • Micro-batch size of 4 with 8 gradient accumulation steps
    • Effective batch size of 32 while only requiring memory for 4 samples at once
  4. Learning Rate Schedule:

    • WarmupLR scheduler with gradual increase from 3e-6 to 3e-4
    • Warmup over 1/4 of an epoch to stabilize early training

This configuration allowed the model to train efficiently with only 12GB of VRAM while maintaining numerical stability across millions of training examples with 70,000+ output dimensions.
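As a rough illustration, the sketch below shows what a DeepSpeed configuration with these options might look like. The field values mirror the description above, but the helper name, the `warmup_steps` argument, and the exact JSON layout are simplifications rather than the project's actual `create_deepspeed_config` output.

```python
import json

def create_deepspeed_config_sketch(config_path, learning_rate=3e-4, weight_decay=0.01,
                                   warmup_steps=1000, micro_batch_size=4, grad_accum_steps=8):
    """Write a DeepSpeed JSON config reflecting the optimizations described above."""
    config = {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": grad_accum_steps,   # effective batch size 4 * 8 = 32
        "zero_optimization": {
            "stage": 2,                                    # partition optimizer states and gradients
            "contiguous_gradients": True,                  # reduce memory fragmentation
            "overlap_comm": True
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 16                      # loss scaling for stable fp16 training
        },
        "activation_checkpointing": {
            "partition_activations": True                  # recompute activations in backprop
        },
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": learning_rate, "weight_decay": weight_decay}
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {"warmup_min_lr": 3e-6, "warmup_max_lr": learning_rate,
                       "warmup_num_steps": warmup_steps}
        }
    }
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    return config
```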

Dataset

The model was trained on a carefully filtered subset of the Danbooru 2024 dataset, which contains a vast collection of anime/manga illustrations with comprehensive tagging.

Filtering Process

The dataset was filtered with the following constraints:

```python
# Minimum tags per category required for each image
min_tag_counts = {
    'general': 25,
    'character': 1,
    'copyright': 1,
    'artist': 0,
    'meta': 0
}

# Minimum samples per tag required for tag to be included
min_tag_samples = {
    'general': 20,
    'character': 40,
    'copyright': 50,
    'artist': 200,
    'meta': 50
}
```

This filtering process (see the sketch after this list):

  1. First removed low-sample tags (tags with fewer occurrences than specified in min_tag_samples)
  2. Then removed images with insufficient tags per category (as specified in min_tag_counts)
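A minimal sketch of that two-step filter is shown below. It assumes each image is represented as a dict mapping category names to tag lists, which is a hypothetical format chosen for illustration rather than the dataset's actual layout.

```python
from collections import Counter

def filter_dataset_sketch(images, min_tag_samples, min_tag_counts):
    """Sketch: drop rare tags first, then drop under-tagged images."""
    # Step 1: count tag occurrences per category and keep only sufficiently common tags.
    counts = {cat: Counter() for cat in min_tag_samples}
    for img in images:
        for cat, tags in img.items():
            if cat in counts:
                counts[cat].update(tags)
    keep = {cat: {t for t, n in counts[cat].items() if n >= min_tag_samples[cat]}
            for cat in min_tag_samples}

    # Step 2: strip removed tags, then drop images missing the per-category minimums.
    filtered = []
    for img in images:
        pruned = {cat: [t for t in tags if cat not in keep or t in keep[cat]]
                  for cat, tags in img.items()}
        if all(len(pruned.get(cat, [])) >= n for cat, n in min_tag_counts.items()):
            filtered.append(pruned)
    return filtered
```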

Training Data

  • Starting dataset size: ~3,000,000 filtered images
  • Training subset: 2,000,000 images (due to storage and time constraints)
  • Training duration: 3.5 epochs

The model could potentially achieve even higher accuracy with more training epochs and the full dataset.

Preprocessing

Images were preprocessed with minimal transformations (see the sketch after this list):

  • Tensor normalization (scaled to 0-1 range)
  • Resized while maintaining original aspect ratio
  • No additional augmentations were applied
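A minimal sketch of this preprocessing using PIL and torchvision is shown below. The target resolution and zero-padding to a square are assumptions for illustration; the training pipeline's exact resize behavior may differ.

```python
from PIL import Image
from torchvision.transforms import functional as TF
import torch.nn.functional as F

def preprocess_sketch(path, target_size=512):
    """Sketch: resize keeping aspect ratio, pad to a square, scale pixels to [0, 1]."""
    img = Image.open(path).convert("RGB")
    w, h = img.size
    scale = target_size / max(w, h)
    img = img.resize((max(1, round(w * scale)), max(1, round(h * scale))), Image.BILINEAR)
    tensor = TF.to_tensor(img)                      # float tensor in [0, 1], shape (3, H, W)
    pad_w = target_size - tensor.shape[2]
    pad_h = target_size - tensor.shape[1]
    return F.pad(tensor, (0, pad_w, 0, pad_h))      # zero-pad right/bottom to a square
```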

Model Architecture

The model uses a novel two-stage prediction approach that achieves superior performance compared to traditional single-stage models:

Image Feature Extraction

  • Backbone: EfficientNet V2-L extracts high-quality visual features from input images
  • Spatial Pooling: Adaptive averaging converts spatial features to a compact 1280-dimensional embedding

Initial Prediction Stage

  • Direct classification from image features through a multi-layer classifier
  • Bottleneck architecture with LayerNorm and GELU activations between linear layers
  • Outputs initial tag probabilities across all 70,000+ possible tags

Tag Context Mechanism

  • Top predicted tags are embedded using a shared embedding space
  • Self-attention layer allows tags to influence each other based on co-occurrence patterns
  • Normalized tag embeddings represent a coherent "tag context" for the image

Cross-Attention Refinement

  • Image features and tag embeddings interact through cross-attention
  • Each dimension of the image features attends to relevant dimensions in the tag space
  • This creates a bidirectional flow of information between visual features and semantic tags

Refined Predictions

  • Fused features (original + cross-attended) feed into a final classifier
  • Residual connection ensures initial predictions are preserved when beneficial
  • Temperature scaling provides calibrated probability outputs

This two-stage approach allows the model to leverage tag co-occurrence patterns and semantic relationships, improving accuracy without significantly increasing the parameter count.
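The sketch below is a schematic of that two-stage flow, not the actual implementation: the module names, hidden sizes, number of attention heads, and the top-k value are illustrative assumptions, and temperature scaling is omitted for brevity.

```python
import torch
import torch.nn as nn

class TwoStageTaggerSketch(nn.Module):
    """Schematic: initial classifier -> top-k tag context -> cross-attention refinement."""

    def __init__(self, backbone, num_tags, embed_dim=1280, top_k=128):
        super().__init__()
        self.backbone = backbone                       # assumed to return (B, embed_dim, H, W) features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.initial_classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.LayerNorm(embed_dim), nn.GELU(),
            nn.Linear(embed_dim, num_tags))
        self.tag_embedding = nn.Embedding(num_tags, embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.refined_classifier = nn.Linear(embed_dim * 2, num_tags)
        self.top_k = top_k

    def forward(self, images):
        feats = self.pool(self.backbone(images)).flatten(1)           # (B, embed_dim)
        initial_logits = self.initial_classifier(feats)

        # Tag context: embed the top-k predicted tags and let them attend to each other.
        top_idx = initial_logits.topk(self.top_k, dim=-1).indices
        tag_emb = self.tag_embedding(top_idx)                         # (B, top_k, embed_dim)
        tag_ctx, _ = self.self_attn(tag_emb, tag_emb, tag_emb)

        # Cross-attention: image features query the tag context, then fuse.
        attended, _ = self.cross_attn(feats.unsqueeze(1), tag_ctx, tag_ctx)
        fused = torch.cat([feats, attended.squeeze(1)], dim=-1)
        refined_logits = self.refined_classifier(fused) + initial_logits   # residual connection
        return initial_logits, refined_logits
```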

Installation

Simply run the included setup script to install all dependencies:

```
setup.bat
```

This will automatically set up all necessary packages for the application.

Requirements

  • Python 3.11.9 specifically (newer versions are incompatible)
  • PyTorch 1.10+
  • Streamlit
  • PIL/Pillow
  • NumPy
  • Flash Attention (note: doesn't work properly on Windows)

Running the Application

The application is located in the app folder and can be launched via the setup script:

  1. Run setup.bat to install dependencies
  2. The Streamlit interface will automatically open in your browser
  3. If the browser doesn't open automatically, navigate to http://localhost:8501

Model Details

Tag Categories

The model recognizes tags across these categories:

  • General: Visual elements, concepts, clothing, etc.
  • Character: Individual characters appearing in the image
  • Copyright: Source material (anime, manga, game)
  • Artist: Creator of the artwork
  • Meta: Meta information about the image
  • Rating: Content rating
  • Year: Year of upload

Performance Notes

The full model with refined predictions outperforms the Initial-only model, though the performance gap is surprisingly small. This is an interesting architectural finding: the refinement stage adds value without substantial computational overhead, while the initial classifier alone already captures most of the performance.

This efficiency makes the initial-only model particularly valuable for Windows users or systems with limited VRAM, as they can still achieve near-optimal performance without requiring Flash Attention.

In benchmarks, the model achieved a 61% F1 score across all categories, which is remarkable considering the extreme multi-label classification challenge of 70,000+ possible tags. The model performs particularly well on general tags and character recognition.
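For reference, here is a small sketch of how a micro-averaged F1 score can be computed from thresholded multi-label predictions; the threshold value is a placeholder, and the project's actual evaluation code may differ.

```python
import numpy as np

def micro_f1_sketch(probs, targets, threshold=0.35):
    """Micro-averaged F1 over (num_samples, num_tags) probability and 0/1 target arrays."""
    preds = (probs >= threshold).astype(int)
    tp = np.sum((preds == 1) & (targets == 1))
    fp = np.sum((preds == 1) & (targets == 0))
    fn = np.sum((preds == 0) & (targets == 1))
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    return 2 * precision * recall / (precision + recall + 1e-9)
```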

Threshold Profiles

  • Overall: Single threshold applied to all categories
  • Weighted: Threshold optimized for balanced performance across categories
  • Category-specific: Different thresholds for each category
  • High Precision: Higher thresholds for more confident predictions
  • High Recall: Lower thresholds to capture more potential tags
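As an illustration of how a category-specific profile might be applied, the sketch below keeps a tag only when its probability clears the threshold for its category. The function name, the tag-to-category mapping, and all threshold values are placeholders, not the tuned values shipped with the app.

```python
def apply_thresholds_sketch(probs, tag_names, tag_categories, thresholds, default=0.35):
    """Sketch: keep a tag when its probability clears its category's threshold.

    probs is a 1-D sequence of per-tag probabilities; tag_names and tag_categories
    map each index to a tag string and its category string.
    """
    selected = []
    for i, p in enumerate(probs):
        category = tag_categories[i]                   # e.g. "general", "character", ...
        if p >= thresholds.get(category, default):
            selected.append((tag_names[i], category, float(p)))
    # Sort by confidence, highest first, for display.
    return sorted(selected, key=lambda item: item[2], reverse=True)

# Hypothetical category-specific profile; the app ships its own tuned values.
category_thresholds = {"general": 0.35, "character": 0.45, "copyright": 0.45,
                       "artist": 0.60, "meta": 0.40, "rating": 0.30, "year": 0.30}
```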

Windows Compatibility

The full model uses Flash Attention, which does not work properly on Windows. For Windows users:

  • The application automatically defaults to the Initial-only model
  • Performance difference is minimal (0.2% absolute F1 score reduction, from 61.6% to 61.4%)
  • The Initial-only model still uses the same powerful EfficientNet backbone and initial classifier

Web Interface Guide

The interface is divided into three main sections:

  1. Model Selection (Sidebar)

    • Choose between Full Model or Initial-only Model
    • View model information and memory usage
  2. Image Upload (Left Panel)

    • Upload your own images or select from examples
    • View the selected image
  3. Tagging Controls (Right Panel)

    • Select threshold profile
    • Adjust thresholds for precision-recall tradeoff
    • Configure display options
    • View predictions organized by category

Display Options

  • Show all tags: Display all tags including those below threshold
  • Compact view: Hide progress bars for cleaner display
  • Minimum confidence: Filter out low-confidence predictions
  • Category selection: Choose which categories to include in the summary

Interface Screenshots

Application Interface

Tag Results Example

Training Environment

The model was trained using surprisingly modest hardware:

  • GPU: Single NVIDIA RTX 3060 (12GB VRAM)
  • RAM: 64GB system memory
  • Platform: Windows with WSL (Windows Subsystem for Linux)
  • Libraries:
    • Microsoft DeepSpeed for memory-efficient training
    • PyTorch with CUDA acceleration
    • Flash Attention for optimized attention computation

Training Notebooks

The repository includes two main training notebooks:

  1. CAMIE Tagger.ipynb

    • Main training notebook
    • Dataset loading and preprocessing
    • Model initialization
    • Initial training loop with DeepSpeed integration
    • Tag selection optimization
    • Metric tracking and visualization
  2. Camie Tagger Cont and Evals.ipynb

    • Continuation of training from checkpoints
    • Comprehensive model evaluation
    • Per-category performance metrics
    • Threshold optimization
    • Model conversion for deployment in the app
    • Export functionality for the standalone application

Training Monitor

The project includes a real-time training monitor, accessible in a browser at localhost:5000 during training.

Performance Tips

⚠️ Important: For optimal training speed, keep VSCode minimized and the training monitor open in your browser. This can improve iteration speed by 3-5x due to how the Windows/WSL graphics stack handles window focus and CUDA kernel execution.

Monitor Features

The training monitor provides three main views:

1. Overview Tab


  • Training Progress: Real-time metrics including epoch, batch, speed, and time estimates
  • Loss Chart: Training and validation loss visualization
  • F1 Scores: Initial and refined F1 metrics for both training and validation
2. Predictions Tab


  • Image Preview: Shows the current sample being analyzed
  • Prediction Controls: Toggle between initial and refined predictions
  • Tag Analysis:
    • Color-coded tag results (correct, incorrect, missing)
    • Confidence visualization with probability bars
    • Category-based organization
    • Filtering options for error analysis
3. Selection Analysis Tab


  • Selection Metrics: Statistics on tag selection quality
    • Ground truth recall
    • Average probability for ground truth vs. non-ground truth tags
    • Unique tags selected
  • Selection Graph: Trends in selection quality over time
  • Selected Tags Details: Detailed view of model-selected tags with confidence scores

The monitor provides invaluable insights into how the two-stage prediction model is performing, particularly how the tag selection process is working between the initial and refined prediction stages.

Training Notes

  • Training notebooks require WSL and likely 32GB+ of RAM to handle the dataset
  • Microsoft DeepSpeed was crucial for fitting the model and batches into the available VRAM
  • Despite hardware limitations, the model achieves impressive results
  • With more computational resources, the model could be trained longer on the full dataset

Support

I plan to move on to LLMs after this project, as I have many ideas on how to improve them. I will continue to update this model based on community interest.

If you'd like to support further training on the complete dataset or my future projects, consider buying me a coffee.

Acknowledgments

  • Danbooru for the incredible dataset of tagged anime images
  • p1atdev for the processed Danbooru 2024 dataset
  • Microsoft for DeepSpeed, which made training possible on consumer hardware
  • PyTorch and the open-source ML community