---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
- lstm
license: mit
datasets:
- leonvanbokhorst/topic-drift
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0153
    - name: Test R²
      type: r2
      value: 0.8500
    - name: Test Loss
      type: loss
      value: 0.0002
---

# Topic Drift Detector Model

## Version: v20241225_160448

This model detects topic drift in conversations using an attention-based architecture that combines multi-head attention with a bidirectional LSTM. It was trained on the [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift) dataset.

## Model Architecture

- Multi-head attention mechanism (4 heads)
- Bidirectional LSTM (3 layers) for pattern detection
- Dynamic weight generation
- Semantic bridge detection
- Hidden dimension: 512
- Dropout rate: 0.2

## Performance Metrics

```txt
=== Full Training Results ===
Best Validation RMSE: 0.0145
Best Validation R²: 0.8656

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0153
R²: 0.8500
```

## Training Curves

![Training Curves](plots/v20241225_160448/training_curves.png)

## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
base_model.eval()

# Load topic drift detector. The file is a fully pickled model, so its class
# definition must be importable, and PyTorch >= 2.6 needs weights_only=False.
# Only load checkpoint files you trust.
model = torch.load('models/v20241225_160448/topic_drift_model.pt', weights_only=False)
model.eval()

# Prepare conversation window (8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal.",
]

with torch.no_grad():
    # Tokenize every turn, capping each one at 512 tokens
    inputs = tokenizer(conversation, padding=True, truncation=True,
                       max_length=512, return_tensors='pt')
    hidden = base_model(**inputs).last_hidden_state  # [8, seq_len, 1024]

    # Mean-pool over real tokens only, so padding does not dilute short turns
    mask = inputs['attention_mask'].unsqueeze(-1).type_as(hidden)  # [8, seq_len, 1]
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # [8, 1024]

    # Flatten the window into a single input vector [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Get drift score
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")  # Higher scores indicate more topic drift
```

## Training Details

- Dataset: [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 10
- Total epochs: 37 (early stopped)
- Training framework: PyTorch
- Base embeddings: BAAI/bge-m3

## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
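
Because the detector takes exactly 8 turns, a longer conversation has to be cut into 8-turn windows before scoring. The sketch below is one way to do that. It assumes a sliding window with a stride of 1 turn and a caller-supplied token counter. The helper `sliding_windows` and its parameters are hypothetical and not part of this repository. It also rejects turns outside the 1 to 512 token range instead of letting the tokenizer truncate them silently.

```python
from typing import Callable, List, Sequence

WINDOW_SIZE = 8        # turns per window, as required by the detector
MAX_TURN_TOKENS = 512  # per-turn upper bound from the Limitations section


def sliding_windows(
    turns: Sequence[str],
    count_tokens: Callable[[str], int],
    window_size: int = WINDOW_SIZE,
    stride: int = 1,
) -> List[List[str]]:
    """Split a conversation into consecutive windows of exactly `window_size` turns.

    Hypothetical helper. Raises ValueError if the conversation is shorter than
    one window or if any turn falls outside 1..MAX_TURN_TOKENS tokens as
    reported by `count_tokens`.
    """
    if len(turns) < window_size:
        raise ValueError(f"need at least {window_size} turns, got {len(turns)}")
    for i, turn in enumerate(turns):
        n = count_tokens(turn)
        if not 1 <= n <= MAX_TURN_TOKENS:
            raise ValueError(f"turn {i} has {n} tokens; expected 1..{MAX_TURN_TOKENS}")
    # Each window starts `stride` turns after the previous one
    return [
        list(turns[start:start + window_size])
        for start in range(0, len(turns) - window_size + 1, stride)
    ]


if __name__ == "__main__":
    # Whitespace splitting stands in for the real tokenizer in this demo
    turns = [f"turn number {i}" for i in range(10)]
    windows = sliding_windows(turns, count_tokens=lambda t: len(t.split()))
    print(len(windows))    # 3 (turns 0-7, 1-8, 2-9)
    print(windows[-1][0])  # turn number 2
```

With the real tokenizer you could pass `count_tokens=lambda t: len(tokenizer(t)["input_ids"])`. Those counts include the tokenizer's special tokens, which matches how `max_length=512` is applied in the Usage example. Each returned window can then be embedded and scored exactly like the single window in that example.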