leonvanbokhorst committed
Commit b40f52b · verified · 1 Parent(s): cedd82d

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md (+80 -28)
README.md CHANGED
@@ -24,10 +24,10 @@ model-index:
  metrics:
  - name: Test RMSE
  type: rmse
- value: 0.0139
+ value: 0.0144
  - name: Test R²
  type: r2
- value: 0.8766
+ value: 0.8666
  - name: Test Loss
  type: loss
  value: 0.0002
@@ -35,35 +35,83 @@ model-index:

  # Topic Drift Detector Model

- ## Version: v20241225_162244
+ ## Version: v20241225_184257

- This model detects topic drift in conversations using an enhanced attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
+ This model detects topic drift in conversations using an enhanced hierarchical attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

  ## Model Architecture
- - Multi-head attention mechanism (4 heads)
- - Bidirectional LSTM (3 layers) for pattern detection
- - Dynamic weight generation
- - Semantic bridge detection
- - Hidden dimension: 512
- - Dropout rate: 0.2
+ - Multi-head attention mechanism (4 heads, head dimension 128)
+ - Hierarchical pattern detection with multi-scale analysis
+ - Explicit transition point detection with linguistic markers
+ - Pattern-aware self-attention mechanism
+ - Dynamic window augmentation
+ - Contrastive learning with pattern-aware sampling
+ - Adversarial training with pattern-aware perturbations
+
+ ### Key Components:
+ 1. **Embedding Processor**:
+    - Input dimension: 1024
+    - Hidden dimension: 512
+    - Dropout rate: 0.35
+    - PreNorm layers with residual connections
+
+ 2. **Attention Blocks**:
+    - 3 layers of attention
+    - 4 attention heads
+    - Feed-forward dimension: 2048
+    - Learned position encodings
+
+ 3. **Pattern Detection**:
+    - Hierarchical LSTM layers
+    - Bidirectional processing
+    - Multi-scale pattern analysis
+    - Pattern classification with 7 types
+
+ 4. **Transition Detection**:
+    - Linguistic marker attention
+    - Explicit transition scoring
+    - Marker-based context integration
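To make the Key Components section concrete: with 4 heads at head dimension 128, the attention width works out to 4 × 128 = 512, the stated hidden dimension. The sketch below wires those numbers together in PyTorch as an illustration only; the class names (`PreNormBlock`, `EmbeddingProcessor`) and exact wiring are editorial assumptions, not the repository's implementation, and the LSTM pattern and transition heads are omitted.

```python
# Illustrative sketch of the documented dimensions; not the repo's actual code.
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-norm residual block: LayerNorm -> self-attention / FFN -> residual."""
    def __init__(self, hidden_dim=512, num_heads=4, ff_dim=2048, dropout=0.35):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        # 4 heads over a 512-dim space -> head dimension 128, as in the card
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(nn.Linear(hidden_dim, ff_dim), nn.GELU(),
                                nn.Dropout(dropout), nn.Linear(ff_dim, hidden_dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]      # self-attention with residual
        return x + self.ff(self.norm2(x))  # feed-forward with residual

class EmbeddingProcessor(nn.Module):
    """Projects 1024-dim turn embeddings (BGE-M3) into the 512-dim hidden space."""
    def __init__(self, input_dim=1024, hidden_dim=512, num_layers=3, window=8):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Dropout(0.35))
        self.pos = nn.Parameter(torch.zeros(1, window, hidden_dim))  # learned positions
        self.blocks = nn.ModuleList(PreNormBlock() for _ in range(num_layers))

    def forward(self, x):                 # x: [batch, 8 turns, 1024]
        x = self.proj(x) + self.pos
        for block in self.blocks:
            x = block(x)
        return x                          # [batch, 8, 512]
```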
 
  ## Performance Metrics
  ```txt
  === Full Training Results ===
- Best Validation RMSE: 0.0133
- Best Validation R²: 0.8873
+ Best Validation RMSE: 0.0142
+ Best Validation R²: 0.8711

  === Test Set Results ===
  Loss: 0.0002
- RMSE: 0.0139
- R²: 0.8766
-
+ RMSE: 0.0144
+ R²: 0.8666
  ```
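The reported RMSE and R² follow their standard definitions. For reference, a minimal sketch (not the project's evaluation script):

```python
# Standard regression metrics matching the numbers reported above.
import torch

def rmse(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Root mean squared error."""
    return torch.sqrt(torch.mean((pred - target) ** 2)).item()

def r2_score(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = torch.sum((target - pred) ** 2)
    ss_tot = torch.sum((target - target.mean()) ** 2)
    return (1.0 - ss_res / ss_tot).item()
```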

- ## Training Curves
- ![Training Curves](plots/v20241225_162244/training_curves.png)
+ ## Training Details
+ - Dataset: 6400 conversations (5120 train, 640 val, 640 test)
+ - Window size: 8 turns
+ - Batch size: 32
+ - Learning rate: 0.0001 with cosine decay
+ - Warmup steps: 100
+ - Early stopping patience: 15
+ - Max gradient norm: 1.0
+ - Mixed precision training (AMP)
+ - Base embeddings: BAAI/bge-m3
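The recipe above (learning rate 1e-4 with cosine decay, 100 warmup steps, gradient clipping at norm 1.0, AMP) maps onto a fairly standard PyTorch loop. A sketch under stated assumptions: the optimizer choice (AdamW), `total_steps`, and the `train_loader`/`compute_loss` helpers are hypothetical stand-ins, not names from the training code.

```python
# Sketch of the documented optimization setup; helper names are hypothetical.
import math
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # optimizer is assumed

def warmup_cosine(step, warmup_steps=100, total_steps=10_000):
    """Linear warmup for 100 steps, then cosine decay toward zero."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine)
scaler = torch.cuda.amp.GradScaler()   # mixed precision training (AMP)

for batch in train_loader:             # hypothetical DataLoader
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = compute_loss(model, batch)  # hypothetical loss helper
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)         # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
```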
 
- ## Usage
+ ### Training Enhancements:
+ 1. **Dynamic Window Augmentation**:
+    - Adaptive window sizes
+    - Interpolation-based resizing
+    - Maintains temporal consistency
+
+ 2. **Contrastive Learning**:
+    - Pattern-aware positive/negative sampling
+    - Temperature-scaled similarities
+    - Weighted combination of embeddings
+
+ 3. **Adversarial Training**:
+    - Pattern-aware perturbations
+    - Self-distillation loss
+    - Epsilon ball projection
+
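Each enhancement in the list above has a compact core idea. The sketches below illustrate them under assumptions; the function names and constants (`temperature=0.1`, `epsilon=0.01`) are editorial, not values from the training code, and the self-distillation term is not shown.

```python
# Conceptual sketches of the three training enhancements; all names are assumed.
import torch
import torch.nn.functional as F

def augment_window(emb: torch.Tensor, new_len: int) -> torch.Tensor:
    """Dynamic window augmentation: resize a [turns, dim] window by linear
    interpolation, then resample back to the original length so temporal
    order is preserved."""
    x = emb.t().unsqueeze(0)  # [1, dim, turns] for 1-D interpolation
    x = F.interpolate(x, size=new_len, mode='linear', align_corners=False)
    x = F.interpolate(x, size=emb.size(0), mode='linear', align_corners=False)
    return x.squeeze(0).t()

def contrastive_loss(anchor, positive, negative, temperature=0.1):
    """Temperature-scaled similarities over pattern-aware pos/neg samples."""
    pos = F.cosine_similarity(anchor, positive, dim=-1) / temperature
    neg = F.cosine_similarity(anchor, negative, dim=-1) / temperature
    logits = torch.stack([pos, neg], dim=-1)
    return -F.log_softmax(logits, dim=-1)[..., 0].mean()

def adversarial_perturb(emb, grad, epsilon=0.01):
    """Pattern-aware perturbation with epsilon ball projection: step along
    the loss gradient, clamped so the perturbation stays within epsilon."""
    delta = torch.clamp(epsilon * grad.sign(), -epsilon, epsilon)
    return emb + delta
```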
+ ## Usage Example
  ```python
  import torch
  from transformers import AutoModel, AutoTokenizer
@@ -73,7 +121,7 @@ base_model = AutoModel.from_pretrained('BAAI/bge-m3')
  tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

  # Load topic drift detector
- model = torch.load('models/v20241225_162244/topic_drift_model.pt')
+ model = torch.load('models/v20241225_184257/topic_drift_model.pt')
  model.eval()

  # Prepare conversation window (8 turns)
@@ -103,18 +151,22 @@ print(f"Topic drift score: {drift_scores.item():.4f}")
  # Higher scores indicate more topic drift
  ```
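One caveat on the loading step shown in the diff: `torch.load` on a pickled model object only works when the detector's class definition is importable, and PyTorch 2.6+ defaults to `weights_only=True`, which rejects pickled modules. A hedged variant that also works on CPU-only machines:

```python
import torch

# Full-module checkpoints need weights_only=False on PyTorch >= 2.6,
# and map_location to load on a machine without a GPU.
model = torch.load(
    'models/v20241225_184257/topic_drift_model.pt',
    map_location='cpu',
    weights_only=False,  # the file stores a pickled nn.Module, not a state_dict
)
model.eval()
```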

- ## Training Details
- - Dataset: [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2)
- - Window size: 8 turns
- - Batch size: 32
- - Learning rate: 0.0001
- - Early stopping patience: 10
- - Total epochs: 70 (early stopped)
- - Training framework: PyTorch
- - Base embeddings: BAAI/bge-m3
+ ## Pattern Types
+ The model detects 7 distinct pattern types:
+ 1. "maintain" - No significant drift
+ 2. "gentle_wave" - Subtle topic evolution
+ 3. "single_peak" - One clear transition
+ 4. "multi_peak" - Multiple transitions
+ 5. "ascending" - Gradually increasing drift
+ 6. "descending" - Gradually decreasing drift
+ 7. "abrupt" - Sudden topic change
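The seven labels above presumably correspond to the 7-way pattern classification head mentioned under Pattern Detection, though the card does not show how those logits are exposed. A hypothetical decoding helper:

```python
# Hypothetical helper: maps a 7-way logit vector to the documented labels.
import torch

PATTERN_TYPES = [
    "maintain", "gentle_wave", "single_peak", "multi_peak",
    "ascending", "descending", "abrupt",
]

def decode_pattern(logits: torch.Tensor) -> str:
    """Return the pattern label for a [7]-dim logit vector."""
    return PATTERN_TYPES[int(torch.argmax(logits))]
```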

  ## Limitations
  - Works best with English conversations
  - Requires exactly 8 turns of conversation
  - Each turn should be between 1-512 tokens
  - Relies on BAAI/bge-m3 embeddings
+ - May be sensitive to conversation style variations
+
+ ## Training Curves
+ ![Training Curves](plots/v20241225_184257/training_curves.png)