leonvanbokhorst commited on
Commit
0903ccc
·
verified ·
1 Parent(s): cd21873

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +33 -5
README.md CHANGED
@@ -34,7 +34,7 @@ model-index:
34
 
35
  # Topic Drift Detector Model
36
 
37
- ## Version: v20241226_105942
38
 
39
  This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
40
 
@@ -99,19 +99,47 @@ R²: 0.8666
99
  - Wide score range
100
 
101
  ## Usage Example
 
 
 
 
 
 
 
102
  ```python
103
  import torch
104
  from transformers import AutoModel, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  # Load base embedding model
107
  base_model = AutoModel.from_pretrained('BAAI/bge-m3')
108
  tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
109
 
110
- # Load topic drift detector
111
- model = torch.load('models/v20241226_105942/topic_drift_model.pt')
112
  model.eval()
113
 
114
- # Prepare conversation window (8 turns)
115
  conversation = [
116
  "How was your weekend?",
117
  "It was great! Went hiking.",
@@ -145,4 +173,4 @@ print(f"Topic drift score: {drift_scores.item():.4f}")
145
  - Relies on BAAI/bge-m3 embeddings
146
 
147
  ## Training Curves
148
- ![Training Curves](plots/v20241226_105942/training_curves.png)
 
34
 
35
  # Topic Drift Detector Model
36
 
37
+ ## Version: v20241226_110212
38
 
39
  This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
40
 
 
99
  - Wide score range
100
 
101
  ## Usage Example
102
+
103
+ To use the model, first install the required packages:
104
+ ```bash
105
+ pip install torch transformers huggingface_hub
106
+ ```
107
+
108
+ Then use the following code:
109
  ```python
110
  import torch
111
  from transformers import AutoModel, AutoTokenizer
112
+ from huggingface_hub import hf_hub_download
113
+
114
+ def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
115
+ # Download latest model weights
116
+ model_path = hf_hub_download(
117
+ repo_id=repo_id,
118
+ filename="models/latest/topic_drift_model.pt"
119
+ )
120
+
121
+ # Load checkpoint
122
+ checkpoint = torch.load(model_path, weights_only=True)
123
+
124
+ # Create model with same hyperparameters
125
+ model = EnhancedTopicDriftDetector(
126
+ input_dim=1024, # BGE-M3 embedding dimension
127
+ hidden_dim=checkpoint['hyperparameters']['hidden_dim']
128
+ )
129
+
130
+ # Load state dict
131
+ model.load_state_dict(checkpoint['model_state_dict'])
132
+ return model
133
 
134
  # Load base embedding model
135
  base_model = AutoModel.from_pretrained('BAAI/bge-m3')
136
  tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
137
 
138
+ # Load topic drift detector from Hugging Face
139
+ model = load_model()
140
  model.eval()
141
 
142
+ # Example conversation
143
  conversation = [
144
  "How was your weekend?",
145
  "It was great! Went hiking.",
 
173
  - Relies on BAAI/bge-m3 embeddings
174
 
175
  ## Training Curves
176
+ ![Training Curves](plots/v20241226_110212/training_curves.png)