---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
- lstm
license: mit
datasets:
- leonvanbokhorst/topic-drift
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0153
    - name: Test R²
      type: r2
      value: 0.85
    - name: Test Loss
      type: loss
      value: 0.0002
---
Topic Drift Detector Model
Version: v20241225_160448
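This model detects topic drift in conversations. It scores a window of 8 conversation turns with an attention-based architecture that combines multi-head attention and a bidirectional LSTM; higher scores indicate more drift. It was trained on the leonvanbokhorst/topic-drift dataset.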
Model Architecture
- Multi-head attention mechanism (4 heads)
- Bidirectional LSTM (3 layers) for pattern detection
- Dynamic weight generation
- Semantic bridge detection
- Hidden dimension: 512
- Dropout rate: 0.2
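The card does not include the model definition, so the following is only a minimal PyTorch sketch of how the listed components could fit together. It uses the listed hyperparameters (4 attention heads, 3 bidirectional LSTM layers, hidden size 512, dropout 0.2) and the flattened [batch, 8*1024] input from the usage example. The class name `TopicDriftSketch`, the mean pooling over turns, and the sigmoid output head are assumptions. The "dynamic weight generation" and "semantic bridge detection" components are omitted because their design is not described, and this sketch cannot load the released checkpoint.

```python
import torch
import torch.nn as nn


class TopicDriftSketch(nn.Module):
    """Hypothetical sketch: BiLSTM + multi-head self-attention over turn embeddings.

    Not the released architecture: dynamic weight generation and semantic
    bridge detection are left out because the card does not describe them.
    """

    def __init__(self, embed_dim=1024, hidden_dim=512, num_heads=4,
                 num_layers=3, dropout=0.2, window_size=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        # 3-layer bidirectional LSTM over the sequence of 8 turn embeddings
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        # Self-attention over LSTM outputs (2 * hidden_dim because bidirectional)
        self.attention = nn.MultiheadAttention(2 * hidden_dim, num_heads,
                                               dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # Regression head; the sigmoid keeping scores in [0, 1] is an assumption
        self.head = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: [batch, window_size * embed_dim] -> [batch, window_size, embed_dim]
        x = x.reshape(-1, self.window_size, self.embed_dim)
        lstm_out, _ = self.lstm(x)                    # [batch, 8, 2 * hidden_dim]
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        pooled = self.dropout(attn_out.mean(dim=1))  # average over the 8 turns
        return self.head(pooled).squeeze(-1)          # [batch]


# Score two random "conversation windows" to check the shapes
sketch = TopicDriftSketch().eval()
with torch.no_grad():
    scores = sketch(torch.randn(2, 8 * 1024))
print(scores.shape)  # torch.Size([2])
```

Applying attention after the LSTM lets every turn attend to the sequence-aware representation of every other turn, which is one plausible way to compare early and late turns in a window.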
Performance Metrics
=== Full Training Results ===
Best Validation RMSE: 0.0145
Best Validation R²: 0.8656
=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0153
R²: 0.8500
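The card does not state how these metrics were computed; the sketch below uses the standard definitions of RMSE and R² (coefficient of determination, 1 - SS_res / SS_tot) on toy values that are illustrative only. The last line checks the reported figures for consistency: assuming the test loss is mean squared error, RMSE² = 0.0153² ≈ 0.00023, which rounds to the reported 0.0002.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y - y_hat)^2))."""
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy drift scores (illustrative only, not from the dataset)
y_true = [0.1, 0.2, 0.3, 0.4]
y_pred = [0.1, 0.2, 0.3, 0.6]
print(round(rmse(y_true, y_pred), 4))      # 0.1
print(round(r2_score(y_true, y_pred), 4))  # 0.2

# Consistency check: if the test loss is MSE, it should be about RMSE squared
print(round(0.0153 ** 2, 4))  # 0.0002
```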
Training Curves
Usage
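import torch
from transformers import AutoModel, AutoTokenizer
# Load the base embedding model (BGE-M3, hidden size 1024)
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
base_model.eval()
# Load the topic drift detector. torch.load returns the full model object here, so the
# class that defines it must be importable; PyTorch >= 2.6 also needs weights_only=False
# (only load pickled checkpoints you trust).
model = torch.load('models/v20241225_160448/topic_drift_model.pt', map_location='cpu', weights_only=False)
model.eval()
# Prepare the conversation window (exactly 8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]
assert len(conversation) == 8, "the detector expects exactly 8 turns"
with torch.no_grad():
    # Embed each turn; turns longer than 512 tokens are truncated
    inputs = tokenizer(conversation, padding=True, truncation=True, max_length=512, return_tensors='pt')
    hidden = base_model(**inputs).last_hidden_state  # [8, seq_len, 1024]
    # Mean-pool over real tokens only, so padding does not dilute shorter turns
    # (assumption: if the detector was trained on plain .mean(dim=1) embeddings, use that instead)
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # [8, seq_len, 1]
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # [8, 1024]
    # Flatten the window into a single input row: [1, 8*1024]
    conversation_embeddings = embeddings.reshape(1, -1)
    # Get the drift score
    drift_scores = model(conversation_embeddings)
print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift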
Training Details
- Dataset: leonvanbokhorst/topic-drift
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 10
- Total epochs: 37 (early stopped)
- Training framework: PyTorch
- Base embeddings: BAAI/bge-m3
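A minimal sketch of patience-based early stopping, assuming the monitored quantity is validation RMSE (the card reports a best validation RMSE) and that patience counts consecutive epochs without improvement. Under that convention, a run that stopped at epoch 37 with patience 10 would have had its best epoch at 27; the synthetic curve below reproduces that arithmetic. `EarlyStopping` is a hypothetical helper, not part of the released training code.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = 0
        self.bad_epochs = 0

    def step(self, epoch, val_rmse):
        """Record one epoch's validation RMSE; return True when training should stop."""
        if val_rmse < self.best - self.min_delta:
            # New best: remember it (a real loop would also checkpoint the model here)
            self.best = val_rmse
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Synthetic validation curve: improves until epoch 27, then plateaus
stopper = EarlyStopping(patience=10)
for epoch in range(1, 101):
    val = 0.05 / epoch if epoch <= 27 else 0.05 / 27
    if stopper.step(epoch, val):
        break
print(epoch, stopper.best_epoch)  # 37 27
```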
Limitations
- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
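Because the detector takes exactly 8 turns, a longer conversation has to be split into fixed-size windows before scoring. The sketch below shows one way to do this with overlapping sliding windows; `sliding_windows` and its `stride` parameter are hypothetical, not part of the released code. Each window could then be embedded and scored as in the usage example.

```python
def sliding_windows(turns, window_size=8, stride=1):
    """Yield consecutive windows of exactly `window_size` turns.

    Conversations shorter than `window_size` yield nothing, since the
    detector cannot score a partial window.
    """
    if window_size < 1 or stride < 1:
        raise ValueError("window_size and stride must be positive")
    for start in range(0, len(turns) - window_size + 1, stride):
        yield turns[start:start + window_size]


turns = [f"turn {i}" for i in range(10)]
windows = list(sliding_windows(turns))
print(len(windows), windows[0][0], windows[-1][-1])  # 3 turn 0 turn 9
```

A stride of 1 gives a drift score at every turn after the eighth; a larger stride trades that resolution for fewer forward passes.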