---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
- lstm
license: mit
datasets:
- leonvanbokhorst/topic-drift
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task: 
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
      - name: Test RMSE
        type: rmse
        value: 0.0153
      - name: Test R²
        type: r2
        value: 0.8500
      - name: Test Loss 
        type: loss
        value: 0.0002
---

# Topic Drift Detector Model

## Version: v20241225_160448

This model detects topic drift in conversations using an enhanced attention-based architecture. Trained on the [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift) dataset.

## Model Architecture
- Multi-head attention mechanism (4 heads)
- Bidirectional LSTM (3 layers) for pattern detection
- Dynamic weight generation
- Semantic bridge detection
- Hidden dimension: 512
- Dropout rate: 0.2
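
The full module definitions are not included in this card. As a rough illustration only, a minimal PyTorch sketch consistent with the listed components might look like the following; the class name, layer layout, and pooling choices are assumptions, not the released implementation.

```python
import torch
import torch.nn as nn

class TopicDriftDetector(nn.Module):
    """Illustrative sketch only -- the released architecture may differ."""

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512,
                 num_heads: int = 4, num_layers: int = 3,
                 dropout: float = 0.2, window_size: int = 8):
        super().__init__()
        self.window_size = window_size
        self.input_proj = nn.Linear(input_dim, hidden_dim)  # project bge-m3 embeddings
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=num_heads,
                                               dropout=dropout, batch_first=True)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # drift score in [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, window_size * input_dim] -> [batch, window_size, input_dim]
        x = x.view(x.size(0), self.window_size, -1)
        h = self.input_proj(x)
        h, _ = self.lstm(h)             # [batch, 8, 2 * hidden_dim]
        h, _ = self.attention(h, h, h)  # self-attention across the 8 turns
        pooled = h.mean(dim=1)          # average over turns
        return self.regressor(pooled).squeeze(-1)
```

The usage example below passes the detector a flattened `[1, 8*1024]` window, which is why this sketch reshapes its input back into `[batch, 8, 1024]` before the LSTM and attention layers.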

## Performance Metrics
```txt
=== Full Training Results ===
Best Validation RMSE: 0.0145
Best Validation R²: 0.8656

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0153
R²: 0.8500

```
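
The RMSE and R² figures above can be reproduced from model predictions with a few lines of PyTorch. This is a generic sketch, not the project's evaluation script; `predictions` and `targets` are placeholder tensors of drift scores.

```python
import torch

def rmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Root mean squared error between predicted and reference drift scores."""
    return torch.sqrt(torch.mean((predictions - targets) ** 2)).item()

def r2_score(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Coefficient of determination (R²)."""
    ss_res = torch.sum((targets - predictions) ** 2)
    ss_tot = torch.sum((targets - targets.mean()) ** 2)
    return (1 - ss_res / ss_tot).item()
```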

## Training Curves
![Training Curves](plots/v20241225_160448/training_curves.png)

## Usage
```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load topic drift detector
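# Note: recent PyTorch versions default torch.load to weights_only=True;
# loading a full pickled model may require torch.load(path, weights_only=False)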
model = torch.load('models/v20241225_160448/topic_drift_model.pt')
model.eval()

# Prepare conversation window (8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Get embeddings
with torch.no_grad():
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]
    
    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)
    
    # Get drift score
    drift_scores = model(conversation_embeddings)
    
print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift
```

## Training Details
- Dataset: [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 10
- Total epochs: 37 (early stopped)
- Training framework: PyTorch
- Base embeddings: BAAI/bge-m3
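
The original training script is not part of this card. The loop below is a simplified sketch reflecting the listed hyperparameters (batch size 32, learning rate 1e-4, early stopping with patience 10); the optimizer, loss function, and checkpointing details are assumptions.

```python
import torch
from torch.utils.data import DataLoader

def train(model, train_data, val_data, lr=1e-4, batch_size=32,
          max_epochs=100, patience=10):
    """Simplified training loop with early stopping on validation RMSE."""
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # optimizer choice is an assumption
    criterion = torch.nn.MSELoss()
    best_rmse, stale_epochs = float("inf"), 0

    for epoch in range(max_epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        # Validation RMSE drives early stopping
        model.eval()
        with torch.no_grad():
            sq_errors = [(model(x) - y) ** 2 for x, y in val_loader]
            val_rmse = torch.sqrt(torch.cat(sq_errors).mean()).item()

        if val_rmse < best_rmse:
            best_rmse, stale_epochs = val_rmse, 0
            torch.save(model.state_dict(), "best_model.pt")
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break  # early stopping

    return best_rmse
```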

## Limitations
- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
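
At inference time these constraints can be checked up front. The helper below is a sketch (the function name is hypothetical, and `tokenizer` is the BAAI/bge-m3 tokenizer from the usage example above).

```python
def validate_conversation(conversation: list[str], tokenizer,
                          window_size: int = 8, max_tokens: int = 512) -> None:
    """Raise ValueError if a conversation window violates the model's constraints."""
    if len(conversation) != window_size:
        raise ValueError(f"Expected exactly {window_size} turns, got {len(conversation)}")
    for i, turn in enumerate(conversation):
        n_tokens = len(tokenizer(turn)["input_ids"])
        if not 1 <= n_tokens <= max_tokens:
            raise ValueError(f"Turn {i} has {n_tokens} tokens (allowed: 1-{max_tokens})")
```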