leonvanbokhorst committed on
Commit 4f292ea · verified · 1 Parent(s): b17fb9c

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +87 -10
README.md CHANGED
@@ -1,15 +1,48 @@

  # Topic Drift Detector Model

- ## Version: v20241225_090318

- This model detects topic drift in conversations using an enhanced attention-based architecture.

  ## Model Architecture
- - Multi-head attention mechanism
- - Bidirectional LSTM for pattern detection
  - Dynamic weight generation
  - Semantic bridge detection

  ## Performance Metrics
  ```txt
@@ -25,16 +58,60 @@ R²: 0.8373
  ```

  ## Training Curves
- ![Training Curves](plots/v20241225_090318/training_curves.png)

  ## Usage
  ```python
  import torch

- # Load model
- model = torch.load('models/v20241225_090318/topic_drift_model.pt')

- # Use model for inference
- # Input shape: [batch_size, sequence_length * embedding_dim]
- # Output shape: [batch_size, 1] (drift score between 0 and 1)
  ```

+ ---
+ language: en
+ tags:
+ - topic-drift
+ - conversation-analysis
+ - pytorch
+ - attention
+ - lstm
+ license: mit
+ datasets:
+ - leonvanbokhorst/topic-drift
+ metrics:
+ - rmse
+ - r2_score
+ model-index:
+ - name: topic-drift-detector
+   results:
+   - task:
+       type: topic-drift-detection
+       name: Topic Drift Detection
+     dataset:
+       name: leonvanbokhorst/topic-drift
+       type: conversations
+     metrics:
+     - name: Test RMSE
+       type: rmse
+       value: 0.0129
+     - name: Test R²
+       type: r2
+       value: 0.8373
+ ---

  # Topic Drift Detector Model

+ ## Version: v20241225_090654

+ This model detects topic drift in conversations using an enhanced attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift) dataset.

  ## Model Architecture
+ - Multi-head attention mechanism (4 heads)
+ - Bidirectional LSTM (3 layers) for pattern detection
  - Dynamic weight generation
  - Semantic bridge detection
+ - Hidden dimension: 512
+ - Dropout rate: 0.2

  ## Performance Metrics
  ```txt

  ```
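
The model card reports a Test RMSE and a Test R². The evaluation script is not part of this commit, so the following is only a minimal sketch of how those two metrics are conventionally computed, assuming the standard definitions. The function names and the toy values are illustrative and are not taken from the model's evaluation.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error between two equal-length, non-empty sequences."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined when all targets are identical")
    return 1 - ss_res / ss_tot


# Toy drift scores for illustration only (not real evaluation data).
targets = [0.10, 0.40, 0.35, 0.80]
predictions = [0.12, 0.38, 0.30, 0.85]
print(f"RMSE: {rmse(targets, predictions):.4f}")
print(f"R²: {r2_score(targets, predictions):.4f}")
```

A lower RMSE means predicted drift scores sit closer to the labels, while an R² near 1 means the predictions explain most of the variance in the labels.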

  ## Training Curves
+ ![Training Curves](plots/v20241225_090654/training_curves.png)

  ## Usage
  ```python
  import torch
+ from transformers import AutoModel, AutoTokenizer

+ # Load base embedding model
+ base_model = AutoModel.from_pretrained('BAAI/bge-m3')
+ tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

+ # Load topic drift detector
+ model = torch.load('models/v20241225_090654/topic_drift_model.pt')
+ model.eval()
+
+ # Prepare conversation window (8 turns)
+ conversation = [
+     "How was your weekend?",
+     "It was great! Went hiking.",
+     "Which trail did you take?",
+     "The mountain loop trail.",
+     "That's nice. By the way, did you watch the game?",
+     "Yes! What an amazing match!",
+     "The final score was incredible.",
+     "I couldn't believe that last-minute goal."
+ ]
+
+ # Get embeddings
+ with torch.no_grad():
+     inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
+     embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]
+
+     # Reshape for model input [1, 8*1024]
+     conversation_embeddings = embeddings.view(1, -1)
+
+     # Get drift score
+     drift_scores = model(conversation_embeddings)
+
+ print(f"Topic drift score: {drift_scores.item():.4f}")
+ # Higher scores indicate more topic drift
  ```
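
The usage example flattens 8 turn embeddings of size 1024 into a single input vector. A small pre-flight check can catch a wrong window length before any model is loaded. This is a sketch under those assumptions. `WINDOW_SIZE`, `EMBEDDING_DIM`, `expected_input_dim`, and `check_window` are hypothetical names and are not part of the released code.

```python
WINDOW_SIZE = 8        # turns per window, as in the usage example
EMBEDDING_DIM = 1024   # per-turn embedding size, per the "[8, 1024]" comment


def expected_input_dim(window_size=WINDOW_SIZE, embedding_dim=EMBEDDING_DIM):
    """Return the flattened feature count for one conversation window."""
    return window_size * embedding_dim


def check_window(turns, window_size=WINDOW_SIZE):
    """Return True if `turns` holds exactly `window_size` non-empty strings.

    Raises ValueError otherwise, so a bad window fails before embedding.
    """
    if len(turns) != window_size:
        raise ValueError(f"expected {window_size} turns, got {len(turns)}")
    for index, turn in enumerate(turns):
        if not isinstance(turn, str) or not turn.strip():
            raise ValueError(f"turn {index} is empty or not a string")
    return True


print(expected_input_dim())  # flattened size for 8 turns of 1024-dim embeddings
```

Calling `check_window(conversation)` before tokenizing keeps shape errors out of the model call itself.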
+
+ ## Training Details
+ - Dataset: [leonvanbokhorst/topic-drift](https://huggingface.co/datasets/leonvanbokhorst/topic-drift)
+ - Window size: 8 turns
+ - Batch size: 32
+ - Learning rate: 0.0001
+ - Early stopping patience: 10
+ - Total epochs: 37 (early stopped)
+ - Training framework: PyTorch
+ - Base embeddings: BAAI/bge-m3
+
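
The training details mention early stopping with a patience of 10. The training script is not included here, so the following is a generic sketch of patience-based early stopping. It assumes validation loss is the monitored quantity. The `EarlyStopping` class and its `min_delta` parameter are hypothetical and are not the author's implementation.

```python
class EarlyStopping:
    """Signal a stop once the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # hypothetical: smallest change counted as improvement
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Simulated validation losses for illustration only.
stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate([0.50, 0.40, 0.41, 0.42, 0.43, 0.30], start=1):
    if stopper.step(loss):
        print(f"Stopping at epoch {epoch}, best loss {stopper.best_loss:.2f}")
        break
```

With this scheme the epoch count depends on when the monitored loss plateaus, which is consistent with a run ending early at 37 epochs.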
+ ## Limitations
+ - Works best with English conversations
+ - Requires exactly 8 turns of conversation
+ - Each turn should be between 1 and 512 tokens
+ - Relies on BAAI/bge-m3 embeddings
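
Because the model requires exactly 8 turns, a longer conversation has to be split into 8-turn windows before scoring. One possible approach is a sliding window, sketched below. Scoring each window separately is an assumption about usage, not a method documented by the author. `sliding_windows` and its `stride` parameter are hypothetical.

```python
def sliding_windows(turns, window_size=8, stride=1):
    """Yield consecutive windows of exactly `window_size` turns.

    Conversations shorter than `window_size` yield nothing.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive")
    for start in range(0, len(turns) - window_size + 1, stride):
        yield turns[start:start + window_size]


# Placeholder turns for illustration only.
conversation = [f"turn {i}" for i in range(1, 11)]
for window in sliding_windows(conversation):
    print(window[0], "->", window[-1])
```

Each yielded window can then be embedded and scored as in the usage example. A larger `stride` reduces the number of model calls at the cost of coarser drift localization.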