---
language: en
tags:
  - topic-drift
  - conversation-analysis
  - pytorch
  - attention
  - lstm
license: mit
datasets:
  - leonvanbokhorst/topic-drift
metrics:
  - rmse
  - r2_score
model-index:
  - name: topic-drift-detector
    results:
      - task:
          type: topic-drift-detection
          name: Topic Drift Detection
        dataset:
          name: leonvanbokhorst/topic-drift-v2
          type: conversations
        metrics:
          - name: Test RMSE
            type: rmse
            value: 0.0153
          - name: Test R²
            type: r2
            value: 0.85
          - name: Test Loss
            type: loss
            value: 0.0002
---

# Topic Drift Detector Model

Version: v20241225_160448

This model detects topic drift in conversations using an enhanced attention-based architecture. It was trained on the leonvanbokhorst/topic-drift dataset.

## Model Architecture

- Multi-head attention mechanism (4 heads)
- Bidirectional LSTM (3 layers) for pattern detection
- Dynamic weight generation
- Semantic bridge detection
- Hidden dimension: 512
- Dropout rate: 0.2

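For orientation, here is a minimal PyTorch sketch of how the listed components could be wired together. The layer arrangement, sigmoid output head, and class name are assumptions, and the dynamic weight generation and semantic bridge detection components are omitted; the published checkpoint contains the actual module definition.

```python
import torch
import torch.nn as nn

class TopicDriftSketch(nn.Module):
    """Illustrative sketch only; the real checkpoint defines the actual module."""

    def __init__(self, input_dim=1024, hidden_dim=512, num_heads=4,
                 num_layers=3, dropout=0.2, window=8):
        super().__init__()
        self.window = window
        self.proj = nn.Linear(input_dim, hidden_dim)
        # Bidirectional LSTM (3 layers) for pattern detection
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        # Multi-head attention (4 heads) over the turn sequence
        self.attn = nn.MultiheadAttention(hidden_dim * 2, num_heads,
                                          dropout=dropout, batch_first=True)
        # Regression head producing a single drift score (sigmoid is an assumption)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, x):
        # x: [batch, window * input_dim] -> [batch, window, input_dim]
        x = x.view(x.size(0), self.window, -1)
        h, _ = self.lstm(self.proj(x))      # [batch, window, 2 * hidden_dim]
        h, _ = self.attn(h, h, h)           # self-attention across turns
        return self.head(h.mean(dim=1))     # [batch, 1] drift score
```
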
## Performance Metrics

```
=== Full Training Results ===
Best Validation RMSE: 0.0145
Best Validation R²: 0.8656

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0153
R²: 0.8500
```

## Training Curves

*(Figure: training curves)*

## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load topic drift detector
model = torch.load('models/v20241225_160448/topic_drift_model.pt')
model.eval()

# Prepare conversation window (8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Get embeddings
with torch.no_grad():
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]

    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Get drift score
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift
```
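Note: `torch.load` on a pickled model object requires the detector's class definition to be importable in your environment, and recent PyTorch releases default to `weights_only=True`, so loading a full `nn.Module` this way may require passing `weights_only=False` explicitly.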

## Training Details

- Dataset: leonvanbokhorst/topic-drift
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 10
- Total epochs: 37 (early stopped)
- Training framework: PyTorch
- Base embeddings: BAAI/bge-m3

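The training script itself is not part of this card. As a minimal sketch of the loop these settings imply, assuming MSE loss and an Adam optimizer (both assumptions) with early stopping on validation RMSE:

```python
import copy
import torch
import torch.nn as nn

def train(model, train_loader, val_loader, lr=1e-4, patience=10, max_epochs=100):
    """Hypothetical training loop with early stopping on validation RMSE."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # optimizer choice is assumed
    criterion = nn.MSELoss()
    best_rmse, best_state, bad_epochs = float("inf"), None, 0

    for epoch in range(max_epochs):
        model.train()
        for x, y in train_loader:            # batches of [32, 8*1024] windows
            optimizer.zero_grad()
            loss = criterion(model(x).squeeze(-1), y)
            loss.backward()
            optimizer.step()

        # Compute validation RMSE, which drives early stopping
        model.eval()
        squared_error, n = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                pred = model(x).squeeze(-1)
                squared_error += ((pred - y) ** 2).sum().item()
                n += y.numel()
        rmse = (squared_error / n) ** 0.5

        if rmse < best_rmse:
            best_rmse, best_state, bad_epochs = rmse, copy.deepcopy(model.state_dict()), 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:       # stop after 10 epochs without improvement
                break

    model.load_state_dict(best_state)       # restore the best checkpoint
    return model
```
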
## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation (see the sliding-window sketch below for longer dialogues)
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
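Because of the fixed 8-turn window, longer conversations cannot be scored in one pass; a common workaround is to slide the window across the dialogue and score each position. A minimal sketch, assuming `turn_embeddings` was produced as in the Usage example (the helper name is hypothetical):

```python
import torch

def drift_scores_sliding(model, turn_embeddings, window=8):
    """Hypothetical helper: score every consecutive 8-turn window.

    turn_embeddings: [num_turns, 1024] tensor from the base embedding
    model, computed as in the Usage example above.
    """
    scores = []
    with torch.no_grad():
        for start in range(turn_embeddings.size(0) - window + 1):
            # Flatten the 8-turn slice into the [1, 8*1024] shape the model expects
            window_emb = turn_embeddings[start:start + window].reshape(1, -1)
            scores.append(model(window_emb).item())
    return scores
```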