Model Description

This model takes the text of a news article as input and outputs an embedding representing that article. The embeddings were trained so that the cosine similarity between two articles' embeddings tracks the overall similarity of the articles. The model was trained on data from the SemEval-2022 Task 8 news article similarity challenge and achieves the second-highest score on that challenge's test set. Because comparing two articles reduces to a cosine similarity between precomputed vectors, the model is built for speed and scalability: it is well suited to embedding many news articles (or similar text) and computing pairwise similarity over very large corpora.
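Because the embeddings are compared by cosine similarity, pairwise scores for a whole corpus reduce to one matrix product once each embedding is L2-normalized. The sketch below assumes the article embeddings have already been computed and stacked into a NumPy array; the helper names (`normalize_rows`, `pairwise_cosine`, `most_similar`) are hypothetical and not part of this model's release.

```python
import numpy as np


def normalize_rows(vectors: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm (all-zero rows stay zero)."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, 1e-12)


def pairwise_cosine(embeddings: np.ndarray) -> np.ndarray:
    """Return the (n, n) matrix of cosine similarities between n embeddings."""
    unit = normalize_rows(embeddings)
    # For unit-length vectors, the dot product equals the cosine similarity.
    return unit @ unit.T


def most_similar(query: np.ndarray, corpus: np.ndarray, k: int = 5):
    """Return indices and scores of the k corpus rows most similar to `query`."""
    scores = normalize_rows(corpus) @ normalize_rows(query[None, :])[0]
    top = np.argsort(-scores)[:k]
    return top, scores[top]


# Toy 2-d vectors standing in for the model's 768-d article embeddings.
embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(np.round(pairwise_cosine(embeddings), 3))
print(most_similar(np.array([1.0, 0.1]), embeddings, k=2))
```

The full n × n matrix grows quadratically with corpus size, so for very large corpora one would typically score one query at a time with `most_similar` or compute the product in row blocks.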

  • Developed by: Ben Litterer, David Jurgens, Dallas Card
  • Finetuned from model: all-mpnet-base-v2

Uses

This model is well suited to embedding large text corpora and computing pairwise similarity scores. Note that during training, each article's headline was first concatenated with the full article text. The first 288 tokens and the last 96 tokens of the result were then joined (384 tokens in total) to fit within the all-mpnet-base-v2 context window. Applying the same preprocessing at inference time is likely to give the best results.
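The training-time preprocessing described above can be sketched as a head-and-tail truncation over token ids. This is a minimal sketch under stated assumptions: both helper names are hypothetical, the headline is assumed to come before the body with a single space between them, and the card does not say whether the 288/96 split counts the tokenizer's special tokens, so here it is applied to content tokens only.

```python
def join_headline(headline: str, body: str, sep: str = " ") -> str:
    """Hypothetical helper: place the headline before the article text.

    The separator used during training is not documented; a space is assumed.
    """
    return headline + sep + body


def head_tail_truncate(token_ids, head: int = 288, tail: int = 96) -> list:
    """Keep the first `head` and last `tail` tokens when a sequence is too long."""
    token_ids = list(token_ids)
    if len(token_ids) <= head + tail:
        return token_ids
    # Slice from len - tail (not -tail) so that tail=0 keeps no tail tokens.
    return token_ids[:head] + token_ids[len(token_ids) - tail:]


# Toy token ids standing in for a tokenized 1,000-token article.
ids = list(range(1000))
kept = head_tail_truncate(ids)
print(len(kept), kept[:3], kept[-3:])
```

With the Hugging Face tokenizer, content-only ids can be obtained with `tokenizer(text, add_special_tokens=False)["input_ids"]` before truncating.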

How to Get Started with the Model

Use the code below to get started with the model. All you need is the weights file, state_dict.tar.

import torch
import torch.nn
from transformers import AutoTokenizer, AutoModel

MODEL_PATH = "/my/path/to/state_dict.tar"

#declare model class, inheriting from nn.Module 
class BiModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # from_pretrained returns the encoder in eval mode (dropout off); `device` is the global set further down
        self.model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2').to(device)
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)
    
    # mean-pool token-level embeddings, ignoring padding positions
    def mean_pooling(self, token_embeddings, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

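    # encode() embeds a single batch: input_ids and attention_mask may be (batch, seq_len)
    # or (batch, 1, seq_len), since the singleton dim is squeezed away; returns (batch, 768)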
    def encode(self, input_ids, attention_mask):
        encoding = self.model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))[0]
        meanPooled = self.mean_pooling(encoding, attention_mask.squeeze(1))
        return meanPooled

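    # forward() expects input_ids and attention_mask to each be a pair [a, b] of batches;
    # it returns the cosine similarity between each article in a and its counterpart in b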
    def forward(self, input_ids, attention_mask):

        input_ids_a = input_ids[0].to(device)
        input_ids_b = input_ids[1].to(device)
        attention_a = attention_mask[0].to(device)
        attention_b = attention_mask[1].to(device)

        #encode sentence and get mean pooled sentence representation
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0] #all token embeddings
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]

        meanPooled1 = self.mean_pooling(encoding1, attention_a)
        meanPooled2 = self.mean_pooling(encoding2, attention_b)

        pred = self.cos(meanPooled1, meanPooled2)
        return pred

#set device as needed, initialize model, load weights 
device = torch.device("cpu")
trainedModel = BiModel()
sDict = torch.load(MODEL_PATH, map_location=device)

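# depending on your transformers version, the model may no longer expect the stored
# position_ids buffer; drop that key only if the current model does not have it
if "model.embeddings.position_ids" not in trainedModel.state_dict():
    sDict.pop("model.embeddings.position_ids", None)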

#initialize tokenizer for all-mpnet-base-v2
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# load the fine-tuned weights and switch to inference mode (disables dropout)
trainedModel.load_state_dict(sDict)
trainedModel.eval()

#trainedModel is now ready to use 
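The `mean_pooling` method above averages token embeddings while ignoring padding. To make that arithmetic easy to check without downloading any weights, here is a NumPy re-expression of the same computation; it is an illustrative sketch only, and the model itself uses the PyTorch version.

```python
import numpy as np


def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors over positions where attention_mask is 1.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len).
    """
    mask = attention_mask[..., None].astype(np.float64)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)        # (batch, dim)
    counts = np.maximum(mask.sum(axis=1), 1e-9)           # (batch, 1), avoids division by zero
    return summed / counts


# One sequence of three 2-d token vectors; the last position is padding.
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pooling(tokens, mask))  # prints [[2. 3.]]
```

The padded vector `[100.0, 100.0]` has no effect on the result, which is why the mask is applied before summing rather than averaging over the full sequence length.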