semantic_similarity / twc_embeddings.py
taskswithcode's picture
Upload twc_embeddings.py
5b36a6d
raw
history blame
9.15 kB
from transformers import AutoModel, AutoTokenizer
from scipy.spatial.distance import cosine
import argparse
import json
import pdb
import torch
import torch.nn.functional as F
def read_text(input_file):
arr = open(input_file).read().split("\n")
return arr[:-1]
class SimCSEModel:
def __init__(self):
self.model = None
self.tokenizer = None
self.debug = False
print("In SimCSE constructor")
def init_model(self,model_name = None):
if (model_name == None):
model_name = "princeton-nlp/sup-simcse-roberta-large"
#self.model = SimCSE(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
def compute_embeddings(self,input_data,is_file):
texts = read_text(input_data) if is_file == True else input_data
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
embeddings = self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
return texts,embeddings
def output_results(self,output_file,texts,embeddings,main_index = 0):
# Calculate cosine similarities
# Cosine similarities are in [-1, 1]. Higher means more similar
cosine_dict = {}
#print("Total sentences",len(texts))
for i in range(len(texts)):
cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
#print("Input sentence:",texts[main_index])
sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
if (self.debug):
for key in sorted_dict:
print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
if (output_file is not None):
with open(output_file,"w") as fp:
fp.write(json.dumps(sorted_dict,indent=0))
return sorted_dict
class SGPTModel:
def __init__(self):
self.model = None
self.tokenizer = None
self.debug = False
print("In SGPT Constructor")
def init_model(self,model_name = None):
# Get our models - The package will take care of downloading the models automatically
# For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
if (self.debug):
print("Init model",model_name)
if (model_name is None):
model_name = "Muennighoff/SGPT-125M-weightedmean-nli-bitfit"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
#self.tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-1.3B-weightedmean-msmarco-specb-bitfit")
#self.model = AutoModel.from_pretrained("Muennighoff/SGPT-1.3B-weightedmean-msmarco-specb-bitfit")
#self.tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit")
#self.model = AutoModel.from_pretrained("Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit")
# Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
self.model.eval()
def compute_embeddings(self,input_data,is_file):
if (self.debug):
print("Computing embeddings for:", input_data[:20])
model = self.model
tokenizer = self.tokenizer
texts = read_text(input_data) if is_file == True else input_data
# Tokenize input texts
batch_tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
# Get the embeddings
with torch.no_grad():
# Get hidden state of shape [bs, seq_len, hid_dim]
last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state
# Get weights of shape [bs, seq_len, hid_dim]
weights = (
torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
.unsqueeze(0)
.unsqueeze(-1)
.expand(last_hidden_state.size())
.float().to(last_hidden_state.device)
)
# Get attn mask of shape [bs, seq_len, hid_dim]
input_mask_expanded = (
batch_tokens["attention_mask"]
.unsqueeze(-1)
.expand(last_hidden_state.size())
.float()
)
# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
embeddings = sum_embeddings / sum_mask
return texts,embeddings
def output_results(self,output_file,texts,embeddings,main_index = 0):
# Calculate cosine similarities
# Cosine similarities are in [-1, 1]. Higher means more similar
cosine_dict = {}
if (self.debug):
print("Total sentences",len(texts))
for i in range(len(texts)):
cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
if (self.debug):
print("Input sentence:",texts[main_index])
sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
if (self.debug):
for key in sorted_dict:
print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
if (output_file is not None):
with open(output_file,"w") as fp:
fp.write(json.dumps(sorted_dict,indent=0))
return sorted_dict
class HFModel:
def __init__(self):
self.model = None
self.tokenizer = None
self.debug = False
print("In HF Constructor")
def init_model(self,model_name = None):
# Get our models - The package will take care of downloading the models automatically
# For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
#print("Init model",model_name)
if (model_name is None):
model_name = "sentence-transformers/all-MiniLM-L6-v2"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
self.model.eval()
def mean_pooling(self,model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def compute_embeddings(self,input_data,is_file):
#print("Computing embeddings for:", input_data[:20])
model = self.model
tokenizer = self.tokenizer
texts = read_text(input_data) if is_file == True else input_data
encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input)
# Perform pooling
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return texts,sentence_embeddings
def output_results(self,output_file,texts,embeddings,main_index = 0):
# Calculate cosine similarities
# Cosine similarities are in [-1, 1]. Higher means more similar
cosine_dict = {}
#print("Total sentences",len(texts))
for i in range(len(texts)):
cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
#print("Input sentence:",texts[main_index])
sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
if (self.debug):
for key in sorted_dict:
print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
if (output_file is not None):
with open(output_file,"w") as fp:
fp.write(json.dumps(sorted_dict,indent=0))
return sorted_dict
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SGPT model for sentence embeddings ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-input', action="store", dest="input",required=True,help="Input file with sentences")
parser.add_argument('-output', action="store", dest="output",default="output.txt",help="Output file with results")
parser.add_argument('-model', action="store", dest="model",default="sentence-transformers/all-MiniLM-L6-v2",help="model name")
results = parser.parse_args()
obj = HFModel()
obj.init_model(results.model)
texts, embeddings = obj.compute_embeddings(results.input,is_file = True)
results = obj.output_results(results.output,texts,embeddings)