Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
@@ -34,7 +34,7 @@ model-index:
|
|
34 |
|
35 |
# Topic Drift Detector Model
|
36 |
|
37 |
-
## Version:
|
38 |
|
39 |
This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
|
40 |
|
@@ -99,19 +99,47 @@ R²: 0.8666
|
|
99 |
- Wide score range
|
100 |
|
101 |
## Usage Example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
```python
|
103 |
import torch
|
104 |
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# Load base embedding model
|
107 |
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
|
108 |
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
|
109 |
|
110 |
-
# Load topic drift detector
|
111 |
-
model =
|
112 |
model.eval()
|
113 |
|
114 |
-
#
|
115 |
conversation = [
|
116 |
"How was your weekend?",
|
117 |
"It was great! Went hiking.",
|
@@ -145,4 +173,4 @@ print(f"Topic drift score: {drift_scores.item():.4f}")
|
|
145 |
- Relies on BAAI/bge-m3 embeddings
|
146 |
|
147 |
## Training Curves
|
148 |
-
![Training Curves](plots/
|
|
|
34 |
|
35 |
# Topic Drift Detector Model
|
36 |
|
37 |
+
## Version: v20241226_110212
|
38 |
|
39 |
This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
|
40 |
|
|
|
99 |
- Wide score range
|
100 |
|
101 |
## Usage Example
|
102 |
+
|
103 |
+
To use the model, first install the required packages:
|
104 |
+
```bash
|
105 |
+
pip install torch transformers huggingface_hub
|
106 |
+
```
|
107 |
+
|
108 |
+
Then use the following code:
|
109 |
```python
|
110 |
import torch
|
111 |
from transformers import AutoModel, AutoTokenizer
|
112 |
+
from huggingface_hub import hf_hub_download
|
113 |
+
|
114 |
+
def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
|
115 |
+
# Download latest model weights
|
116 |
+
model_path = hf_hub_download(
|
117 |
+
repo_id=repo_id,
|
118 |
+
filename="models/latest/topic_drift_model.pt"
|
119 |
+
)
|
120 |
+
|
121 |
+
# Load checkpoint
|
122 |
+
checkpoint = torch.load(model_path, weights_only=True)
|
123 |
+
|
124 |
+
# Create model with same hyperparameters
|
125 |
+
model = EnhancedTopicDriftDetector(
|
126 |
+
input_dim=1024, # BGE-M3 embedding dimension
|
127 |
+
hidden_dim=checkpoint['hyperparameters']['hidden_dim']
|
128 |
+
)
|
129 |
+
|
130 |
+
# Load state dict
|
131 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
132 |
+
return model
|
133 |
|
134 |
# Load base embedding model
|
135 |
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
|
136 |
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
|
137 |
|
138 |
+
# Load topic drift detector from Hugging Face
|
139 |
+
model = load_model()
|
140 |
model.eval()
|
141 |
|
142 |
+
# Example conversation
|
143 |
conversation = [
|
144 |
"How was your weekend?",
|
145 |
"It was great! Went hiking.",
|
|
|
173 |
- Relies on BAAI/bge-m3 embeddings
|
174 |
|
175 |
## Training Curves
|
176 |
+
![Training Curves](plots/v20241226_110212/training_curves.png)
|