---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0144
    - name: Test R²
      type: r2
      value: 0.8666
    - name: Test Loss
      type: loss
      value: 0.0002
---

# Topic Drift Detector Model

## Version: v20241226_110212

This model detects topic drift in conversations using a streamlined attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

## Model Architecture

- Efficient single-layer attention mechanism
- Direct pattern recognition
- Streamlined processing pipeline
- Optimized scaling factor (4.0)
- PreNorm layers with residual connections

### Key Components:

1. **Embedding Processor**:
   - Input dimension: 1024
   - Hidden dimension: 512
   - Dropout rate: 0.35
   - PreNorm layers with residual connections

2. **Attention Block**:
   - Single attention layer
   - Feed-forward dimension: 512
   - Learned position encodings
   - Residual connections

3. **Pattern Recognition**:
   - Direct feature extraction
   - Efficient tensor operations
   - Optimized memory usage

## Performance Metrics

```txt
=== Full Training Results ===
Best Validation RMSE: 0.0142
Best Validation R²: 0.8711

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
```

## Training Details

- Dataset: 6400 conversations (5120 train, 640 val, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3

## Key Improvements

1. **Simplified Architecture**:
   - Reduced complexity
   - Focused pattern detection
   - Efficient processing
   - Optimized memory usage

2. **Performance Benefits**:
   - Improved RMSE (0.0144)
   - Strong R² score (0.8666)
   - Consistent predictions
   - Wide score range

## Usage Example

To use the model, first install the required packages:

```bash
pip install torch transformers huggingface_hub
```

Then use the following code:

```python
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# NOTE: EnhancedTopicDriftDetector is the model class used for training.
# Its definition is not included in this card, so it must be defined or
# imported before load_model() is called.


def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
    # Download latest model weights
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/latest/topic_drift_model.pt"
    )

    # Load checkpoint
    checkpoint = torch.load(model_path, weights_only=True)

    # Create model with same hyperparameters
    model = EnhancedTopicDriftDetector(
        input_dim=1024,  # BGE-M3 embedding dimension
        hidden_dim=checkpoint['hyperparameters']['hidden_dim']
    )

    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load topic drift detector from Hugging Face
model = load_model()
model.eval()

# Example conversation (exactly 8 turns, matching the training window size)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Get embeddings
with torch.no_grad():
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]

    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Get drift score
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift
```

## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should contain between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings

## Training Curves

![Training Curves](plots/v20241226_110212/training_curves.png)
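
## Scoring Longer Conversations

Because the detector expects exactly 8 turns, a longer conversation has to be cut into 8-turn windows before scoring. The sketch below shows one possible way to do that with an overlapping sliding window. The helper name `sliding_windows` and the `stride` parameter are illustrative assumptions and are not part of the released model or its training code.

```python
from typing import List

WINDOW_SIZE = 8  # window size stated in this card's training details


def sliding_windows(turns: List[str], size: int = WINDOW_SIZE, stride: int = 1) -> List[List[str]]:
    """Split a conversation into windows of exactly `size` turns.

    Hypothetical helper: returns an empty list when the conversation is
    shorter than `size`, since a partial window cannot be scored.
    """
    if size <= 0 or stride <= 0:
        raise ValueError("size and stride must be positive")
    # Start indices run from 0 up to the last position where a full window fits.
    return [turns[i:i + size] for i in range(0, len(turns) - size + 1, stride)]


# Hypothetical 10-turn conversation: stride 1 yields 3 overlapping windows.
turns = [f"turn {i}" for i in range(10)]
windows = sliding_windows(turns)
print(len(windows))                      # 3
print(windows[0][0], windows[-1][-1])    # turn 0 turn 9
```

Each window can then be embedded and scored as in the usage example. How per-window scores are combined (for example, taking the maximum or the mean) is an application choice that this card does not specify.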