---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task: 
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
      - name: Test RMSE
        type: rmse
        value: 0.0144
      - name: Test R²
        type: r2
        value: 0.8666
      - name: Test Loss 
        type: loss
        value: 0.0002
---

# Topic Drift Detector Model

## Version: v20241226_110212

This model detects topic drift in conversations using a streamlined attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

## Model Architecture
- Efficient single-layer attention mechanism
- Direct pattern recognition
- Streamlined processing pipeline
- Optimized scaling factor (4.0)
- PreNorm layers with residual connections

### Key Components:
1. **Embedding Processor**:
   - Input dimension: 1024
   - Hidden dimension: 512
   - Dropout rate: 0.35
   - PreNorm layers with residual connections

2. **Attention Block**:
   - Single attention layer
   - Feed-forward dimension: 512
   - Learned position encodings
   - Residual connections

3. **Pattern Recognition**:
   - Direct feature extraction
   - Efficient tensor operations
   - Optimized memory usage
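
For orientation, the following is a minimal PyTorch sketch of how these components could fit together. The class name `TopicDriftSketch`, the number of attention heads, and the mean-pooling regression head are illustrative assumptions; the released `EnhancedTopicDriftDetector` used in the usage example below is the authoritative implementation.

```python
import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Pre-norm feed-forward block with a residual connection (illustrative)."""

    def __init__(self, dim: int, dropout: float = 0.35):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(self.norm(x))


class TopicDriftSketch(nn.Module):
    """Sketch of the components listed above; not the released implementation."""

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512,
                 window: int = 8, dropout: float = 0.35):
        super().__init__()
        self.window = window
        # Embedding processor: project each turn embedding to the hidden dimension
        self.embedding_processor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            PreNormBlock(hidden_dim, dropout),
        )
        # Learned position encodings for the 8-turn window
        self.positions = nn.Parameter(torch.zeros(1, window, hidden_dim))
        # Single attention layer followed by a 512-dim feed-forward block
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        self.ff = PreNormBlock(hidden_dim, dropout)
        # Pattern recognition head: pool over turns and regress a drift score in [0, 1]
        self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, flat_embeddings: torch.Tensor) -> torch.Tensor:
        # flat_embeddings: [batch, window * input_dim], as in the usage example below
        batch = flat_embeddings.shape[0]
        x = flat_embeddings.view(batch, self.window, -1)
        x = self.embedding_processor(x) + self.positions
        attn_out, _ = self.attention(x, x, x)
        x = self.ff(x + attn_out)
        return self.head(x.mean(dim=1)).squeeze(-1)
```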

## Performance Metrics
```txt
=== Full Training Results ===
Best Validation RMSE: 0.0142
Best Validation R²: 0.8711

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
```

## Training Details
- Dataset: 6400 conversations (5120 train, 640 val, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3
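
As a rough illustration of the distribution regularization above, the sketch below combines the MSE objective with a penalty that pulls the batch standard deviation of the predictions toward the 0.2 target. The function name `drift_loss` and the exact formulation are assumptions, not the actual training code.

```python
import torch
import torch.nn.functional as F


def drift_loss(predictions: torch.Tensor,
               targets: torch.Tensor,
               reg_weight: float = 0.1,
               target_std: float = 0.2) -> torch.Tensor:
    """Hypothetical objective: MSE plus a distribution-regularization term."""
    mse = F.mse_loss(predictions, targets)
    std_penalty = (predictions.std() - target_std).pow(2)
    return mse + reg_weight * std_penalty


# Example with a random batch of 32 drift scores in [0, 1]
preds = torch.rand(32)
labels = torch.rand(32)
print(drift_loss(preds, labels))
```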

## Key Improvements
1. **Simplified Architecture**:
   - Reduced complexity
   - Focused pattern detection
   - Efficient processing
   - Optimized memory usage

2. **Performance Benefits**:
   - Improved RMSE (0.0144)
   - Strong R² score (0.8666)
   - Consistent predictions
   - Wide score range

## Usage Example

To use the model, first install the required packages:
```bash
pip install torch transformers huggingface_hub
```

Then use the following code:
```python
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
    # Download latest model weights
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/latest/topic_drift_model.pt"
    )
    
    # Load checkpoint
    checkpoint = torch.load(model_path, weights_only=True)
    
    # Create model with the same hyperparameters.
    # NOTE: EnhancedTopicDriftDetector is the detector class from this model's
    # training code; it is not provided by `transformers` and must be imported
    # or defined in your script before calling load_model().
    model = EnhancedTopicDriftDetector(
        input_dim=1024,  # BGE-M3 embedding dimension
        hidden_dim=checkpoint['hyperparameters']['hidden_dim']
    )
    
    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load topic drift detector from Hugging Face
model = load_model()
model.eval()

# Example conversation
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Get embeddings
with torch.no_grad():
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]
    
    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)
    
    # Get drift score
    drift_scores = model(conversation_embeddings)
    
print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift
```

## Limitations
- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be between 1-512 tokens
- Relies on BAAI/bge-m3 embeddings
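
For conversations longer than eight turns, one possible workaround (not part of the released code) is to slide an 8-turn window over the dialogue and score each window separately. The helper below is a hypothetical sketch; `embed_fn` is assumed to return a `[8, 1024]` tensor of turn embeddings, e.g. mean-pooled BGE-M3 outputs as in the usage example above.

```python
import torch


def score_long_conversation(turns, embed_fn, model, window=8):
    """Hypothetical helper: score every consecutive 8-turn window."""
    scores = []
    with torch.no_grad():
        for start in range(len(turns) - window + 1):
            emb = embed_fn(turns[start:start + window])  # [8, 1024]
            scores.append(model(emb.view(1, -1)).item())
    return scores
```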

## Training Curves
![Training Curves](plots/v20241226_110212/training_curves.png)