---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0144
    - name: Test R²
      type: r2
      value: 0.8666
    - name: Test Loss
      type: loss
      value: 0.0002
---
# Topic Drift Detector Model

**Version:** v20241226_110212

This model detects topic drift in conversations using a streamlined attention-based architecture. It was trained on the `leonvanbokhorst/topic-drift-v2` dataset.
## Model Architecture

- Efficient single-layer attention mechanism
- Direct pattern recognition
- Streamlined processing pipeline
- Attention scaling factor of 4.0
- PreNorm layers with residual connections
### Key Components

**Embedding processor:**
- Input dimension: 1024
- Hidden dimension: 512
- Dropout rate: 0.35
- PreNorm layers with residual connections

**Attention block:**
- Single attention layer
- Feed-forward dimension: 512
- Learned position encodings
- Residual connections

**Pattern recognition:**
- Direct feature extraction
- Efficient tensor operations
- Optimized memory usage

A hypothetical sketch of how these pieces fit together follows.
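The actual `EnhancedTopicDriftDetector` implementation ships with the model repository; the sketch below is only a reconstruction of the listed components under stated assumptions. The head count, GELU activations, mean pooling over turns, and sigmoid output head are guesses not confirmed by this card, and the 4.0 scaling factor is not modeled here since `nn.MultiheadAttention` applies its own 1/√d scaling.

```python
import torch
import torch.nn as nn

class TopicDriftSketch(nn.Module):
    """Hypothetical sketch: embedding processor -> single pre-norm
    attention layer with learned position encodings -> pre-norm
    feed-forward block -> regression head."""

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512,
                 window_size: int = 8, num_heads: int = 4,
                 dropout: float = 0.35):
        super().__init__()
        self.window_size = window_size
        self.input_dim = input_dim

        # Embedding processor: project each 1024-d turn embedding to 512-d
        self.embedding_processor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Learned position encodings, one per turn in the 8-turn window
        self.pos_encoding = nn.Parameter(torch.zeros(1, window_size, hidden_dim))

        # Single attention layer, applied pre-norm with a residual connection
        self.attn_norm = nn.LayerNorm(hidden_dim)
        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )

        # Pre-norm feed-forward block (feed-forward dimension 512)
        self.ff_norm = nn.LayerNorm(hidden_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Regression head: one drift score per conversation window
        self.regressor = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, window_size * input_dim], as in the usage example below
        h = x.view(-1, self.window_size, self.input_dim)
        h = self.embedding_processor(h) + self.pos_encoding

        # Pre-norm attention with residual connection
        q = self.attn_norm(h)
        attn_out, _ = self.attention(q, q, q)
        h = h + attn_out

        # Pre-norm feed-forward with residual connection
        h = h + self.feed_forward(self.ff_norm(h))

        # Pool over turns and regress to a single drift score
        return self.regressor(h.mean(dim=1)).squeeze(-1)
```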
## Performance Metrics

**Full training results:**
- Best validation RMSE: 0.0142
- Best validation R²: 0.8711

**Test set results:**
- Loss: 0.0002
- RMSE: 0.0144
- R²: 0.8666
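RMSE and R² above follow their standard definitions; for reference, they can be computed over the test predictions like this:

```python
import torch

def rmse(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Root mean squared error
    return torch.sqrt(torch.mean((preds - targets) ** 2))

def r2(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Coefficient of determination: 1 - SS_res / SS_tot
    ss_res = torch.sum((targets - preds) ** 2)
    ss_tot = torch.sum((targets - targets.mean()) ** 2)
    return 1.0 - ss_res / ss_tot
```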
## Training Details

- Dataset: 6,400 conversations (5,120 train / 640 validation / 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15 epochs
- Distribution regularization weight: 0.1
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3

One plausible reading of the regularization settings is sketched after this list.
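The distribution regularization settings suggest an objective that keeps the spread of predicted scores close to the target standard deviation. The exact formulation is not given in this card; assuming a squared penalty on the deviation of the batch standard deviation from 0.2, weighted by 0.1, it could look like:

```python
import torch
import torch.nn.functional as F

def training_loss(preds: torch.Tensor, targets: torch.Tensor,
                  reg_weight: float = 0.1, target_std: float = 0.2) -> torch.Tensor:
    """Assumed objective: MSE plus a distribution regularizer that pulls
    the batch's prediction spread toward the target standard deviation."""
    mse = F.mse_loss(preds, targets)
    std_penalty = (preds.std() - target_std).pow(2)
    return mse + reg_weight * std_penalty
```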
## Key Improvements

**Simplified architecture:**
- Reduced complexity
- Focused pattern detection
- Efficient processing
- Optimized memory usage

**Performance benefits:**
- Improved RMSE (0.0144)
- Strong R² score (0.8666)
- Consistent predictions
- Wide score range
## Usage Example

First install the required packages:

```bash
pip install torch transformers huggingface_hub
```

Then load the detector and score a conversation:
```python
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
    # Download the latest model weights
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/latest/topic_drift_model.pt"
    )

    # Load the checkpoint
    checkpoint = torch.load(model_path, weights_only=True)

    # Create the model with the same hyperparameters.
    # Note: EnhancedTopicDriftDetector is defined in the model repository;
    # import or copy its definition before calling load_model().
    model = EnhancedTopicDriftDetector(
        input_dim=1024,  # BGE-M3 embedding dimension
        hidden_dim=checkpoint['hyperparameters']['hidden_dim']
    )

    # Load the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

# Load the base embedding model and tokenizer
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load the topic drift detector from Hugging Face
model = load_model()
model.eval()

# Example conversation (exactly 8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

with torch.no_grad():
    # Embed each turn with mean pooling over tokens
    inputs = tokenizer(conversation, padding=True, truncation=True,
                       return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]

    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Get the drift score; higher scores indicate more topic drift
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
```
## Limitations

- Works best with English conversations
- Requires exactly 8 conversation turns per input (see the sliding-window sketch below for longer conversations)
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
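The model itself scores only fixed 8-turn windows. For longer conversations, one workable pattern (not part of the released model) is to slide an 8-turn window across the transcript and score each window; `score_window` below is a hypothetical helper that wraps the embedding and inference steps from the usage example:

```python
from typing import Callable, List

def drift_profile(turns: List[str],
                  score_window: Callable[[List[str]], float],
                  window_size: int = 8) -> List[float]:
    """Score every consecutive 8-turn window of a longer conversation.

    score_window is a hypothetical wrapper around the embedding and
    model steps from the usage example; it returns one drift score.
    """
    if len(turns) < window_size:
        raise ValueError(f"need at least {window_size} turns, got {len(turns)}")
    return [score_window(turns[i:i + window_size])
            for i in range(len(turns) - window_size + 1)]
```

Peaks in the resulting score profile then point to the turns where the conversation most likely changed topic.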