taskswithcode committed on
Commit 5b36a6d
1 Parent(s): b36de05

Upload twc_embeddings.py

Files changed (1)
  1. twc_embeddings.py +217 -0
twc_embeddings.py ADDED
@@ -0,0 +1,217 @@
+ from transformers import AutoModel, AutoTokenizer
+ from scipy.spatial.distance import cosine
+ import argparse
+ import json
+ import pdb
+ import torch
+ import torch.nn.functional as F
+
+ def read_text(input_file):
+     arr = open(input_file).read().split("\n")
+     return arr[:-1]
+
+
+ class SimCSEModel:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.debug = False
+         print("In SimCSE constructor")
+
+     def init_model(self,model_name = None):
+         if (model_name == None):
+             model_name = "princeton-nlp/sup-simcse-roberta-large"
+         #self.model = SimCSE(model_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModel.from_pretrained(model_name)
+
+     def compute_embeddings(self,input_data,is_file):
+         texts = read_text(input_data) if is_file == True else input_data
+         inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+         with torch.no_grad():
+             embeddings = self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
+         return texts,embeddings
+
+     def output_results(self,output_file,texts,embeddings,main_index = 0):
+         # Calculate cosine similarities
+         # Cosine similarities are in [-1, 1]. Higher means more similar
+         cosine_dict = {}
+         #print("Total sentences",len(texts))
+         for i in range(len(texts)):
+             cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
+
+         #print("Input sentence:",texts[main_index])
+         sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
+         if (self.debug):
+             for key in sorted_dict:
+                 print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
+         if (output_file is not None):
+             with open(output_file,"w") as fp:
+                 fp.write(json.dumps(sorted_dict,indent=0))
+         return sorted_dict
+
+
+
+ class SGPTModel:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.debug = False
+         print("In SGPT Constructor")
+
+
+     def init_model(self,model_name = None):
+         # Get our models - The package will take care of downloading the models automatically
+         # For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
+         if (self.debug):
+             print("Init model",model_name)
+         if (model_name is None):
+             model_name = "Muennighoff/SGPT-125M-weightedmean-nli-bitfit"
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModel.from_pretrained(model_name)
+         #self.tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-1.3B-weightedmean-msmarco-specb-bitfit")
+         #self.model = AutoModel.from_pretrained("Muennighoff/SGPT-1.3B-weightedmean-msmarco-specb-bitfit")
+         #self.tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit")
+         #self.model = AutoModel.from_pretrained("Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit")
+         # Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
+         self.model.eval()
+
+     def compute_embeddings(self,input_data,is_file):
+         if (self.debug):
+             print("Computing embeddings for:", input_data[:20])
+         model = self.model
+         tokenizer = self.tokenizer
+
+         texts = read_text(input_data) if is_file == True else input_data
+
+         # Tokenize input texts
+         batch_tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+
+         # Get the embeddings
+         with torch.no_grad():
+             # Get hidden state of shape [bs, seq_len, hid_dim]
+             last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state
+
+         # Get weights of shape [bs, seq_len, hid_dim]
+         weights = (
+             torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
+             .unsqueeze(0)
+             .unsqueeze(-1)
+             .expand(last_hidden_state.size())
+             .float().to(last_hidden_state.device)
+         )
+
+         # Get attn mask of shape [bs, seq_len, hid_dim]
+         input_mask_expanded = (
+             batch_tokens["attention_mask"]
+             .unsqueeze(-1)
+             .expand(last_hidden_state.size())
+             .float()
+         )
+
+         # Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
+         sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
+         sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
+
+         embeddings = sum_embeddings / sum_mask
+         return texts,embeddings
+
+     def output_results(self,output_file,texts,embeddings,main_index = 0):
+         # Calculate cosine similarities
+         # Cosine similarities are in [-1, 1]. Higher means more similar
+         cosine_dict = {}
+         if (self.debug):
+             print("Total sentences",len(texts))
+         for i in range(len(texts)):
+             cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
+
+         if (self.debug):
+             print("Input sentence:",texts[main_index])
+         sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
+         if (self.debug):
+             for key in sorted_dict:
+                 print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
+         if (output_file is not None):
+             with open(output_file,"w") as fp:
+                 fp.write(json.dumps(sorted_dict,indent=0))
+         return sorted_dict
+
+
+
+
+
+ class HFModel:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.debug = False
+         print("In HF Constructor")
+
+
+     def init_model(self,model_name = None):
+         # Get our models - The package will take care of downloading the models automatically
+         # For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
+         #print("Init model",model_name)
+         if (model_name is None):
+             model_name = "sentence-transformers/all-MiniLM-L6-v2"
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModel.from_pretrained(model_name)
+         self.model.eval()
+
+     def mean_pooling(self,model_output, attention_mask):
+         token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+     def compute_embeddings(self,input_data,is_file):
+         #print("Computing embeddings for:", input_data[:20])
+         model = self.model
+         tokenizer = self.tokenizer
+
+         texts = read_text(input_data) if is_file == True else input_data
+
+         encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
+
+         # Compute token embeddings
+         with torch.no_grad():
+             model_output = model(**encoded_input)
+
+         # Perform pooling
+         sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
+
+         # Normalize embeddings
+         sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+         return texts,sentence_embeddings
+
+     def output_results(self,output_file,texts,embeddings,main_index = 0):
+         # Calculate cosine similarities
+         # Cosine similarities are in [-1, 1]. Higher means more similar
+         cosine_dict = {}
+         #print("Total sentences",len(texts))
+         for i in range(len(texts)):
+             cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
+
+         #print("Input sentence:",texts[main_index])
+         sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
+         if (self.debug):
+             for key in sorted_dict:
+                 print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
+         if (output_file is not None):
+             with open(output_file,"w") as fp:
+                 fp.write(json.dumps(sorted_dict,indent=0))
+         return sorted_dict
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='SGPT model for sentence embeddings ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument('-input', action="store", dest="input",required=True,help="Input file with sentences")
+     parser.add_argument('-output', action="store", dest="output",default="output.txt",help="Output file with results")
+     parser.add_argument('-model', action="store", dest="model",default="sentence-transformers/all-MiniLM-L6-v2",help="model name")
+
+     results = parser.parse_args()
+     obj = HFModel()
+     obj.init_model(results.model)
+     texts, embeddings = obj.compute_embeddings(results.input,is_file = True)
+     results = obj.output_results(results.output,texts,embeddings)
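A minimal usage sketch of the HFModel class added above (not part of the commit): it assumes the file is saved as twc_embeddings.py, with transformers, torch, and scipy installed, and the example sentences are hypothetical placeholders. The script can also be run directly from the command line via its argparse entry point, e.g. python twc_embeddings.py -input sentences.txt -output output.txt.

from twc_embeddings import HFModel

# Load the default MiniLM sentence-embedding model used by the script
model = HFModel()
model.init_model("sentence-transformers/all-MiniLM-L6-v2")

# Pass a list of sentences directly (is_file=False) instead of an input file;
# the sentences below are made-up examples
texts, embeddings = model.compute_embeddings(
    ["A cat sat on the mat.",
     "A feline rested on a rug.",
     "Stock markets fell sharply today."],
    is_file=False,
)

# Rank every sentence by cosine similarity to the first one (main_index=0);
# passing None as output_file skips writing the JSON results file
ranked = model.output_results(None, texts, embeddings, main_index=0)
print(ranked)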