leonvanbokhorst
committed on
Upload README.md with huggingface_hub
README.md
CHANGED
@@ -1,15 +1,48 @@

 # Topic Drift Detector Model

-## Version:

-This model detects topic drift in conversations using an enhanced attention-based architecture.

 ## Model Architecture
-- Multi-head attention mechanism
-- Bidirectional LSTM for pattern detection
 - Dynamic weight generation
 - Semantic bridge detection

 ## Performance Metrics
 ```txt
@@ -25,16 +58,60 @@ R²: 0.8373
 ```

 ## Training Curves
-![Training Curves](plots/

 ## Usage
 ```python
 import torch

-# Load model
-

-#
-
-
 ```
+---
+language: en
+tags:
+- topic-drift
+- conversation-analysis
+- pytorch
+- attention
+- lstm
+license: mit
+datasets:
+- leonvanbokhorst/topic-drift
+metrics:
+- rmse
+- r2_score
+model-index:
+- name: topic-drift-detector
+  results:
+  - task:
+      type: topic-drift-detection
+      name: Topic Drift Detection
+    dataset:
+      name: leonvanbokhorst/topic-drift
+      type: conversations
+    metrics:
+    - name: Test RMSE
+      type: rmse
+      value: 0.0129
+    - name: Test R²
+      type: r2
+      value: 0.8373
+---

 # Topic Drift Detector Model

+## Version: v20241225_090654

+This model detects topic drift in conversations using an enhanced attention-based architecture. Trained on the [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift) dataset.

 ## Model Architecture
+- Multi-head attention mechanism (4 heads)
+- Bidirectional LSTM (3 layers) for pattern detection
 - Dynamic weight generation
 - Semantic bridge detection
+- Hidden dimension: 512
+- Dropout rate: 0.2

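The architecture bullets above give layer counts and sizes but not how the pieces connect. The following is a minimal sketch under assumptions, not the released model: the class name `DriftSketch`, the input projection, the mean pooling over turns, and the sigmoid head are all hypothetical, and the dynamic weight generation and semantic bridge components are left out because the README does not describe them. It only shows how 4-head attention, a 3-layer bidirectional LSTM, a 512-wide hidden state, and 0.2 dropout could fit a flattened window of 8 turn embeddings of width 1024.

```python
import torch
from torch import nn


class DriftSketch(nn.Module):
    """Hypothetical wiring of the listed components; not the released model."""

    def __init__(self, embed_dim=1024, window=8, hidden=512, heads=4, dropout=0.2):
        super().__init__()
        self.window, self.embed_dim = window, embed_dim
        self.proj = nn.Linear(embed_dim, hidden)  # assumed input projection
        self.attn = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        # hidden // 2 per direction so the bidirectional output is `hidden` wide
        self.lstm = nn.LSTM(hidden, hidden // 2, num_layers=3, dropout=dropout,
                            bidirectional=True, batch_first=True)
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden, 1), nn.Sigmoid())

    def forward(self, flat):
        # flat: [batch, window * embed_dim] -> [batch, window, embed_dim]
        x = flat.view(-1, self.window, self.embed_dim)
        x = self.proj(x)
        attended, _ = self.attn(x, x, x)   # self-attention across the turns
        seq, _ = self.lstm(attended)       # [batch, window, hidden]
        return self.head(seq.mean(dim=1))  # [batch, 1], squashed into (0, 1)


sketch = DriftSketch().eval()
with torch.no_grad():
    scores = sketch(torch.randn(2, 8 * 1024))
print(scores.shape)
```

The sigmoid head is one way to keep scores bounded; the README only says that higher scores indicate more drift, so the real output range may differ.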
 ## Performance Metrics
 ```txt
@@ -25,16 +58,60 @@ R²: 0.8373
 ```

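The metrics block reports RMSE and R² on the test set. As a reminder of what those two numbers measure, here is a small sketch in plain Python; the function names and the sample values are made up for illustration and are not taken from the evaluation code.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: typical size of a prediction error."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Made-up drift labels and predictions, for illustration only
labels = [0.10, 0.20, 0.30, 0.40]
preds = [0.12, 0.18, 0.33, 0.37]
print(f"RMSE: {rmse(labels, preds):.4f}  R²: {r2_score(labels, preds):.4f}")
```

Read this way, the reported R² of 0.8373 would mean the predictions account for roughly 84% of the variance in the test labels.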
 ## Training Curves
+![Training Curves](plots/v20241225_090654/training_curves.png)

 ## Usage
 ```python
 import torch
+from transformers import AutoModel, AutoTokenizer

+# Load base embedding model
+base_model = AutoModel.from_pretrained('BAAI/bge-m3')
+tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

+# Load topic drift detector
+model = torch.load('models/v20241225_090654/topic_drift_model.pt')
+model.eval()
+
+# Prepare conversation window (8 turns)
+conversation = [
+    "How was your weekend?",
+    "It was great! Went hiking.",
+    "Which trail did you take?",
+    "The mountain loop trail.",
+    "That's nice. By the way, did you watch the game?",
+    "Yes! What an amazing match!",
+    "The final score was incredible.",
+    "I couldn't believe that last-minute goal."
+]
+
+# Get embeddings
+with torch.no_grad():
+    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
+    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]
+
+# Reshape for model input [1, 8*1024]
+conversation_embeddings = embeddings.view(1, -1)
+
+# Get drift score
+drift_scores = model(conversation_embeddings)
+
+print(f"Topic drift score: {drift_scores.item():.4f}")
+# Higher scores indicate more topic drift
 ```
+
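One detail of the usage snippet above: `last_hidden_state.mean(dim=1)` averages over every token position, including the padding that `padding=True` adds to shorter turns. A mask-aware mean is a common alternative, sketched below with a hypothetical helper `masked_mean`. Whether the detector was trained on masked or unmasked pooling is not stated here, so treat this as an option to check against the training pipeline rather than a drop-in fix.

```python
import torch


def masked_mean(last_hidden_state, attention_mask):
    """Average token vectors per sequence, ignoring padded positions."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [batch, tokens, 1]
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid dividing by zero
    return summed / counts


# Toy batch: 2 sequences, 3 token slots, 2 features; the second has one pad slot
hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
                       [[2.0, 2.0], [4.0, 4.0], [9.0, 9.0]]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
pooled = masked_mean(hidden, mask)
print(pooled)
```

With real inputs this would be called as `masked_mean(base_model(**inputs).last_hidden_state, inputs['attention_mask'])`, in place of the plain `.mean(dim=1)`.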
+## Training Details
+- Dataset: [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift)
+- Window size: 8 turns
+- Batch size: 32
+- Learning rate: 0.0001
+- Early stopping patience: 10
+- Total epochs: 37 (early stopped)
+- Training framework: PyTorch
+- Base embeddings: BAAI/bge-m3
+
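The early-stopping entry in the training details above (patience 10) can be read as: training ends once the validation loss has not improved for 10 consecutive epochs. A minimal sketch of that rule follows; the `EarlyStopper` class and the loss values are hypothetical, and the actual training script may monitor a different quantity or apply a tolerance.

```python
class EarlyStopper:
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: improve for 3 epochs, then plateau
losses = [0.5, 0.4, 0.3] + [0.35] * 20
stopper = EarlyStopper(patience=10)
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at)
```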
+## Limitations
+- Works best with English conversations
+- Requires exactly 8 turns of conversation
+- Each turn should be between 1-512 tokens
+- Relies on BAAI/bge-m3 embeddings
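Because the model expects exactly 8 turns, longer conversations need to be cut into windows before scoring. The sketch below shows one way to do that with a sliding window; the helper name `sliding_windows` and the stride of 1 are assumptions, and the README does not say how windows were built for training.

```python
def sliding_windows(turns, size=8, stride=1):
    """Return every run of `size` consecutive turns, stepping by `stride`."""
    if len(turns) < size:
        raise ValueError(f"need at least {size} turns, got {len(turns)}")
    return [turns[i:i + size] for i in range(0, len(turns) - size + 1, stride)]


# Ten placeholder turns yield three overlapping 8-turn windows
turns = [f"turn {i}" for i in range(1, 11)]
windows = sliding_windows(turns)
print(len(windows), windows[0][0], windows[-1][-1])
```

Each window can then be embedded and scored on its own, which gives one drift score per position in the conversation. Conversations shorter than 8 turns raise an error here rather than being padded, since the README gives no guidance on padding.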