---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0144
    - name: Test R²
      type: r2
      value: 0.8666
---

# Topic Drift Detector Model

## Version: v20241226_114030

This model detects topic drift in conversations using an efficient attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

## Model Architecture

### Key Components

1. **Input Processing**:
   - Input dimension: 1024 (BGE-M3 embeddings)
   - Hidden dimension: 512
   - Sequence length: 8 turns

2. **Attention Block**:
   - Multi-head attention (4 heads)
   - PreNorm layers with residual connections
   - Dropout rate: 0.1

3. **Feed-Forward Network**:
   - Two-layer MLP with GELU activation
   - Hidden dimensions: 512 -> 2048 -> 512
   - Residual connections

4. **Output Layer**:
   - Two-layer MLP: 512 -> 256 -> 1
   - GELU activation
   - Direct sigmoid output for [0, 1] range
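
The checkpoint on the Hub stores weights only, so the `EnhancedTopicDriftDetector` class referenced in the usage example below must exist in your own code. The following is a minimal sketch reconstructed from the component list above; the class layout, layer names, and turn-pooling strategy are assumptions, so align them with the training repository before calling `load_state_dict`.

```python
import torch
import torch.nn as nn

class EnhancedTopicDriftDetector(nn.Module):
    """Sketch of the described architecture; layer names are assumptions."""

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512,
                 num_heads: int = 4, ff_dim: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.embed = nn.Linear(input_dim, hidden_dim)      # 1024 -> 512
        self.norm_attn = nn.LayerNorm(hidden_dim)          # PreNorm before attention
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm_ffn = nn.LayerNorm(hidden_dim)           # PreNorm before the MLP
        self.ffn = nn.Sequential(                          # 512 -> 2048 -> 512
            nn.Linear(hidden_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, hidden_dim),
        )
        self.head = nn.Sequential(                         # 512 -> 256 -> 1
            nn.Linear(hidden_dim, 256),
            nn.GELU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),                                  # direct [0, 1] output
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accepts the flattened (batch, 8 * input_dim) input used in the
        # usage example below and restores the turn dimension first.
        h = self.embed(x.view(x.shape[0], -1, self.input_dim))
        q = self.norm_attn(h)
        attn_out, _ = self.attn(q, q, q)                   # self-attention over turns
        h = h + attn_out                                   # residual connection
        h = h + self.ffn(self.norm_ffn(h))                 # residual connection
        return self.head(h.mean(dim=1)).squeeze(-1)        # pool turns, score in [0, 1]
```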

## Performance Metrics

```txt
=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²:   0.8666
```
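
Both metrics follow their standard definitions; the snippet below shows how they can be reproduced from a tensor of predicted drift scores and a tensor of labels (`preds` and `targets` are placeholder names, not artifacts shipped with the model).

```python
import torch

def rmse(preds: torch.Tensor, targets: torch.Tensor) -> float:
    # Root mean squared error over all drift scores.
    return torch.sqrt(torch.mean((preds - targets) ** 2)).item()

def r2_score(preds: torch.Tensor, targets: torch.Tensor) -> float:
    # Coefficient of determination: 1 - SS_res / SS_tot.
    ss_res = torch.sum((targets - preds) ** 2)
    ss_tot = torch.sum((targets - targets.mean()) ** 2)
    return (1.0 - ss_res / ss_tot).item()
```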

## Training Details

- Dataset: 6,400 conversations (5,120 train, 640 validation, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1 (see the loss sketch below)
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3
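
This card does not spell out the distribution regularization term. One plausible reading, shown here purely as an assumption, is a penalty that pulls the standard deviation of each batch's predictions toward the 0.2 target so the model does not collapse to a narrow band of scores:

```python
import torch
import torch.nn.functional as F

def training_loss(preds: torch.Tensor, targets: torch.Tensor,
                  reg_weight: float = 0.1, target_std: float = 0.2) -> torch.Tensor:
    # MSE on the drift scores plus an assumed distribution penalty that
    # nudges the batch standard deviation toward the configured target.
    mse = F.mse_loss(preds, targets)
    std_penalty = (preds.std() - target_std) ** 2
    return mse + reg_weight * std_penalty
```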

## Usage Example

```python
# Install dependencies first:
#   pip install torch transformers huggingface_hub

import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# Load the base embedding model and tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model = AutoModel.from_pretrained('BAAI/bge-m3').to(device)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Download the topic drift checkpoint. It contains weights only, so the
# EnhancedTopicDriftDetector class must be defined in your code (see the
# architecture sketch above).
model_path = hf_hub_download(
    repo_id='leonvanbokhorst/topic-drift-detector',
    filename='models/v20241226_114030/topic_drift_model.pt'
)
checkpoint = torch.load(model_path, weights_only=True, map_location=device)
model = EnhancedTopicDriftDetector(
    input_dim=1024,
    hidden_dim=checkpoint['hyperparameters']['hidden_dim']
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Example conversation: eight turns that drift from hiking to sports
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Process the conversation
with torch.no_grad():
    # Embed each turn with mean-pooled BGE-M3 hidden states
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # (8, 1024)

    # Flatten the eight turn embeddings into one input row and score it
    conversation_embeddings = embeddings.view(1, -1)  # (1, 8 * 1024)
    drift_score = model(conversation_embeddings)

print(f"Topic drift score: {drift_score.item():.4f}")
```
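
The sigmoid output layer keeps the score in [0, 1], with higher values indicating stronger drift. In this example the conversation pivots from hiking to a sports match halfway through, so a well-calibrated run should return a clearly elevated score.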

## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation (longer inputs can be scored with the sliding-window sketch below)
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
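
The exact-8-turn constraint comes from the training window rather than anything fundamental, so a longer conversation can be scored by sliding an 8-turn window across it. The helper below is an illustrative sketch, not part of the released code; it assumes the `model` object from the usage example and a `(num_turns, 1024)` tensor of per-turn BGE-M3 embeddings.

```python
import torch

def windowed_drift_scores(turn_embeddings: torch.Tensor, model,
                          window: int = 8) -> list[float]:
    # Score every consecutive 8-turn window of a longer conversation.
    scores = []
    with torch.no_grad():
        for start in range(turn_embeddings.shape[0] - window + 1):
            chunk = turn_embeddings[start:start + window].view(1, -1)
            scores.append(model(chunk).item())
    return scores
```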