---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0144
    - name: Test R²
      type: r2
      value: 0.8666
    - name: Test Loss
      type: loss
      value: 0.0002
---

# Topic Drift Detector Model

## Version: v20241226_110212

This model detects topic drift in conversations using a streamlined attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

## Model Architecture

- Efficient single-layer attention mechanism
- Direct pattern recognition
- Streamlined processing pipeline
- Optimized scaling factor (4.0)
- PreNorm layers with residual connections (sketched below)
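
The PreNorm-with-residual pattern can be illustrated with a minimal sketch. The class below is illustrative only, not the released implementation:

```python
import torch.nn as nn

class PreNorm(nn.Module):
    """Normalize the input before a sub-layer, then add a residual connection.
    A generic sketch of the pattern named above, not the released code."""

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        # Residual connection around the normalized branch
        return x + self.fn(self.norm(x))
```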

### Key Components:

1. **Embedding Processor**:
   - Input dimension: 1024
   - Hidden dimension: 512
   - Dropout rate: 0.35
   - PreNorm layers with residual connections

2. **Attention Block**:
   - Single attention layer
   - Feed-forward dimension: 512
   - Learned position encodings
   - Residual connections

3. **Pattern Recognition**:
   - Direct feature extraction
   - Efficient tensor operations
   - Optimized memory usage
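
Putting the three components together, a minimal sketch of how such a detector might be wired is shown below. The class name, head count, activations, and output layer are assumptions for illustration; this is not the released `EnhancedTopicDriftDetector` code:

```python
import torch
import torch.nn as nn

class TopicDriftDetectorSketch(nn.Module):
    """Illustrative wiring of the components above (not the released code)."""

    def __init__(self, input_dim=1024, hidden_dim=512, window=8, dropout=0.35):
        super().__init__()
        self.window, self.input_dim = window, input_dim
        # 1. Embedding processor: per-turn projection with dropout
        self.proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)
        )
        # 2. Attention block: learned position encodings + single attention layer
        self.pos = nn.Parameter(torch.zeros(1, window, hidden_dim))
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim)
        )
        # 3. Pattern recognition: pool over turns and regress a drift score
        self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, x):
        # x: [batch, window * input_dim] -- a flattened window of turn embeddings
        x = x.view(-1, self.window, self.input_dim)
        h = self.proj(x) + self.pos                  # project turns, add positions
        q = self.norm1(h)
        attn_out, _ = self.attn(q, q, q)             # PreNorm self-attention
        h = h + attn_out                             # residual connection
        h = h + self.ff(self.norm2(h))               # PreNorm feed-forward + residual
        return self.head(h.mean(dim=1)).squeeze(-1)  # one drift score per window
```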

## Performance Metrics

```txt
=== Full Training Results ===
Best Validation RMSE: 0.0142
Best Validation R²: 0.8711

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
```
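
The reported metrics follow the standard definitions. For reference, a small sketch of how RMSE and R² are computed from raw predictions (variable names are illustrative):

```python
import torch

def rmse(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Root mean squared error between predictions and labels."""
    return torch.sqrt(torch.mean((pred - target) ** 2)).item()

def r2_score(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = torch.sum((target - pred) ** 2)
    ss_tot = torch.sum((target - target.mean()) ** 2)
    return (1.0 - ss_res / ss_tot).item()
```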

## Training Details

- Dataset: 6400 conversations (5120 train, 640 validation, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1 (see the loss sketch below)
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3
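
The distribution regularization entries above suggest a loss that combines MSE with a penalty keeping the spread of batch predictions near the target standard deviation. A hedged sketch under that assumption; the exact formulation in the training code may differ:

```python
import torch
import torch.nn.functional as F

def training_loss(pred, target, reg_weight=0.1, target_std=0.2):
    """MSE plus a penalty nudging the std of batch predictions toward
    target_std. An assumed formulation, not the released training code."""
    mse = F.mse_loss(pred, target)
    std_penalty = (pred.std() - target_std) ** 2
    return mse + reg_weight * std_penalty
```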

## Key Improvements

1. **Simplified Architecture**:
   - Reduced complexity
   - Focused pattern detection
   - Efficient processing
   - Optimized memory usage

2. **Performance Benefits**:
   - Improved RMSE (0.0144)
   - Strong R² score (0.8666)
   - Consistent predictions
   - Wide score range

## Usage Example

To use the model, first install the required packages:

```bash
pip install torch transformers huggingface_hub
```

Then use the following code. Note that `EnhancedTopicDriftDetector` is the detector class from this repository's training code; it must be defined or imported before the snippet will run:

```python
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
    # Download latest model weights
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/latest/topic_drift_model.pt"
    )

    # Load checkpoint
    checkpoint = torch.load(model_path, weights_only=True)

    # Create model with the same hyperparameters used in training
    model = EnhancedTopicDriftDetector(
        input_dim=1024,  # BGE-M3 embedding dimension
        hidden_dim=checkpoint['hyperparameters']['hidden_dim']
    )

    # Load trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load topic drift detector from Hugging Face
model = load_model()
model.eval()

# Example conversation (exactly 8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Embed each turn and score the conversation
with torch.no_grad():
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]

    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Get drift score
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift
```

## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation (see the sliding-window sketch below for longer inputs)
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
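
Because the model expects exactly 8 turns, longer conversations have to be scored window by window. A sliding-window sketch; the helper below is a suggestion, not part of the released code:

```python
import torch

def sliding_drift_scores(turn_embeddings: torch.Tensor, model, window: int = 8):
    """Score every consecutive 8-turn window of a longer conversation.
    turn_embeddings: [num_turns, 1024] BGE-M3 embeddings, one row per turn."""
    scores = []
    with torch.no_grad():
        for start in range(turn_embeddings.shape[0] - window + 1):
            chunk = turn_embeddings[start:start + window].reshape(1, -1)  # [1, 8*1024]
            scores.append(model(chunk).item())
    return scores
```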

## Training Curves

![Training Curves](plots/v20241226_110212/training_curves.png)
|