taskswithcode commited on
Commit
07e062e
1 Parent(s): 56e7f3c
Files changed (3) hide show
  1. imdb_sent.txt +2 -2
  2. run.sh +1 -1
  3. twc_embeddings.py +190 -0
imdb_sent.txt CHANGED
@@ -47,7 +47,7 @@ a mesmerizing film that certainly keeps your attention... Ben Daniels is fascina
47
  I hope this group of film-makers never re-unites.
48
  Unwatchable. You can't even make it past the first three minutes. And this is coming from a huge Adam Sandler fan!!1
49
  "One of the funniest movies made in recent years. Good characterization, plot and exceptional chemistry make this one a classic"
50
- "Add this little gem to your list of holiday regulars. It is<br /><br />sweet, funny, and endearing"
51
  "no comment - stupid movie, acting average or worse... screenplay - no sense at all... SKIP IT!"
52
  "If you haven't seen this, it's terrible. It is pure trash. I saw this about 17 years ago, and I'm still screwed up from it."
53
  Absolutely fantastic! Whatever I say wouldn't do this underrated movie the justice it deserves. Watch it now! FANTASTIC!
@@ -56,7 +56,7 @@ Widow hires a psychopath as a handyman. Sloppy film noir thriller which doesn't
56
  The Fiendish Plot of Dr. Fu Manchu (1980). This is hands down the worst film I've ever seen. What a sad way for a great comedian to go out.
57
  "Obviously written for the stage. Lightweight but worthwhile. How can you go wrong with Ralph Richardson, Olivier and Merle Oberon."
58
  This movie turned out to be better than I had expected it to be. Some parts were pretty funny. It was nice to have a movie with a new plot.
59
- This movie is terrible. It's about some no brain surfin dude that inherits some company. Does Carrot Top have no shame?<br /><br />
60
  Adrian Pasdar is excellent is this film. He makes a fascinating woman.
61
  "An unfunny, unworthy picture which is an undeserving end to Peter Sellers' career. It is a pity this movie was ever made."
62
  "The plot was really weak and confused. This is a true Oprah flick. (In Oprah's world, all men are evil and all women are victims.)"
 
47
  I hope this group of film-makers never re-unites.
48
  Unwatchable. You can't even make it past the first three minutes. And this is coming from a huge Adam Sandler fan!!1
49
  "One of the funniest movies made in recent years. Good characterization, plot and exceptional chemistry make this one a classic"
50
+ "Add this little gem to your list of holiday regulars. It is sweet, funny, and endearing"
51
  "no comment - stupid movie, acting average or worse... screenplay - no sense at all... SKIP IT!"
52
  "If you haven't seen this, it's terrible. It is pure trash. I saw this about 17 years ago, and I'm still screwed up from it."
53
  Absolutely fantastic! Whatever I say wouldn't do this underrated movie the justice it deserves. Watch it now! FANTASTIC!
 
56
  The Fiendish Plot of Dr. Fu Manchu (1980). This is hands down the worst film I've ever seen. What a sad way for a great comedian to go out.
57
  "Obviously written for the stage. Lightweight but worthwhile. How can you go wrong with Ralph Richardson, Olivier and Merle Oberon."
58
  This movie turned out to be better than I had expected it to be. Some parts were pretty funny. It was nice to have a movie with a new plot.
59
+ This movie is terrible. It's about some no brain surfin dude that inherits some company. Does Carrot Top have no shame?
60
  Adrian Pasdar is excellent is this film. He makes a fascinating woman.
61
  "An unfunny, unworthy picture which is an undeserving end to Peter Sellers' career. It is a pity this movie was ever made."
62
  "The plot was really weak and confused. This is a true Oprah flick. (In Oprah's world, all men are evil and all women are victims.)"
run.sh CHANGED
@@ -1,2 +1,2 @@
1
- streamlit run app.py --server.port 80
2
 
 
1
+ streamlit run app.py --server.port 80 "1" "sim_app_examples.json" "sim_app_models.json"
2
 
twc_embeddings.py CHANGED
@@ -1,4 +1,5 @@
1
  from transformers import AutoModel, AutoTokenizer
 
2
  from scipy.spatial.distance import cosine
3
  import argparse
4
  import json
@@ -11,6 +12,195 @@ def read_text(input_file):
11
  return arr[:-1]
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class SimCSEModel:
15
  def __init__(self):
16
  self.model = None
 
1
  from transformers import AutoModel, AutoTokenizer
2
+ from transformers import AutoModelForCausalLM
3
  from scipy.spatial.distance import cosine
4
  import argparse
5
  import json
 
12
  return arr[:-1]
13
 
14
 
15
+ class CausalLMModel:
16
+ def __init__(self):
17
+ self.model = None
18
+ self.tokenizer = None
19
+ self.debug = False
20
+ print("In CausalLMModel Constructor")
21
+
22
+ def init_model(self,model_name = None):
23
+ # Get our models - The package will take care of downloading the models automatically
24
+ # For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
25
+ if (self.debug):
26
+ print("Init model",model_name)
27
+ # For best performance: EleutherAI/gpt-j-6B
28
+ if (model_name is None):
29
+ model_name = "EleutherAI/gpt-neo-125M"
30
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
31
+ self.model = AutoModelForCausalLM.from_pretrained(model_name)
32
+ self.model.eval()
33
+ self.prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
34
+
35
+ def compute_embeddings(self,input_data,is_file):
36
+ if (self.debug):
37
+ print("Computing embeddings for:", input_data[:20])
38
+ model = self.model
39
+ tokenizer = self.tokenizer
40
+
41
+ texts = read_text(input_data) if is_file == True else input_data
42
+ query = texts[0]
43
+ docs = texts[1:]
44
+
45
+ # Tokenize input texts
46
+
47
+ #print(f"Query: {query}")
48
+ scores = []
49
+ for doc in docs:
50
+ context = self.prompt.format(doc)
51
+
52
+ context_enc = tokenizer.encode(context, add_special_tokens=False)
53
+ continuation_enc = tokenizer.encode(query, add_special_tokens=False)
54
+ # Slice off the last token, as we take its probability from the one before
55
+ model_input = torch.tensor(context_enc+continuation_enc[:-1])
56
+ continuation_len = len(continuation_enc)
57
+ input_len, = model_input.shape
58
+
59
+ # [seq_len] -> [seq_len, vocab]
60
+ logprobs = torch.nn.functional.log_softmax(model(model_input)[0], dim=-1).cpu()
61
+ # [seq_len, vocab] -> [continuation_len, vocab]
62
+ logprobs = logprobs[input_len-continuation_len:]
63
+ # Gather the log probabilities of the continuation tokens -> [continuation_len]
64
+ logprobs = torch.gather(logprobs, 1, torch.tensor(continuation_enc).unsqueeze(-1)).squeeze(-1)
65
+ score = torch.sum(logprobs)
66
+ scores.append(score.tolist())
67
+ return texts,scores
68
+
69
+ def output_results(self,output_file,texts,scores,main_index = 0):
70
+ cosine_dict = {}
71
+ docs = texts[1:]
72
+ if (self.debug):
73
+ print("Total sentences",len(texts))
74
+ assert(len(scores) == len(docs))
75
+ for i in range(len(docs)):
76
+ cosine_dict[docs[i]] = scores[i]
77
+
78
+ if (self.debug):
79
+ print("Input sentence:",texts[main_index])
80
+ sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
81
+ if (self.debug):
82
+ for key in sorted_dict:
83
+ print("Document score for \"%s\" is: %.3f" % (key[:100], sorted_dict[key]))
84
+ if (output_file is not None):
85
+ with open(output_file,"w") as fp:
86
+ fp.write(json.dumps(sorted_dict,indent=0))
87
+ return sorted_dict
88
+
89
+
90
+ class SGPTQnAModel:
91
+ def __init__(self):
92
+ self.model = None
93
+ self.tokenizer = None
94
+ self.debug = False
95
+ print("In SGPT Q&A Constructor")
96
+
97
+
98
+ def init_model(self,model_name = None):
99
+ # Get our models - The package will take care of downloading the models automatically
100
+ # For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
101
+ if (self.debug):
102
+ print("Init model",model_name)
103
+ if (model_name is None):
104
+ model_name = "Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit"
105
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
106
+ self.model = AutoModel.from_pretrained(model_name)
107
+ self.model.eval()
108
+ self.SPECB_QUE_BOS = self.tokenizer.encode("[", add_special_tokens=False)[0]
109
+ self.SPECB_QUE_EOS = self.tokenizer.encode("]", add_special_tokens=False)[0]
110
+
111
+ self.SPECB_DOC_BOS = self.tokenizer.encode("{", add_special_tokens=False)[0]
112
+ self.SPECB_DOC_EOS = self.tokenizer.encode("}", add_special_tokens=False)[0]
113
+
114
+
115
+ def tokenize_with_specb(self,texts, is_query):
116
+ # Tokenize without padding
117
+ batch_tokens = self.tokenizer(texts, padding=False, truncation=True)
118
+ # Add special brackets & pay attention to them
119
+ for seq, att in zip(batch_tokens["input_ids"], batch_tokens["attention_mask"]):
120
+ if is_query:
121
+ seq.insert(0, self.SPECB_QUE_BOS)
122
+ seq.append(self.SPECB_QUE_EOS)
123
+ else:
124
+ seq.insert(0, self.SPECB_DOC_BOS)
125
+ seq.append(self.SPECB_DOC_EOS)
126
+ att.insert(0, 1)
127
+ att.append(1)
128
+ # Add padding
129
+ batch_tokens = self.tokenizer.pad(batch_tokens, padding=True, return_tensors="pt")
130
+ return batch_tokens
131
+
132
+ def get_weightedmean_embedding(self,batch_tokens, model):
133
+ # Get the embeddings
134
+ with torch.no_grad():
135
+ # Get hidden state of shape [bs, seq_len, hid_dim]
136
+ last_hidden_state = self.model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state
137
+
138
+ # Get weights of shape [bs, seq_len, hid_dim]
139
+ weights = (
140
+ torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
141
+ .unsqueeze(0)
142
+ .unsqueeze(-1)
143
+ .expand(last_hidden_state.size())
144
+ .float().to(last_hidden_state.device)
145
+ )
146
+
147
+ # Get attn mask of shape [bs, seq_len, hid_dim]
148
+ input_mask_expanded = (
149
+ batch_tokens["attention_mask"]
150
+ .unsqueeze(-1)
151
+ .expand(last_hidden_state.size())
152
+ .float()
153
+ )
154
+
155
+ # Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
156
+ sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
157
+ sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
158
+
159
+ embeddings = sum_embeddings / sum_mask
160
+
161
+ return embeddings
162
+
163
+ def compute_embeddings(self,input_data,is_file):
164
+ if (self.debug):
165
+ print("Computing embeddings for:", input_data[:20])
166
+ model = self.model
167
+ tokenizer = self.tokenizer
168
+
169
+ texts = read_text(input_data) if is_file == True else input_data
170
+
171
+ queries = [texts[0]]
172
+ docs = texts[1:]
173
+ query_embeddings = self.get_weightedmean_embedding(self.tokenize_with_specb(queries, is_query=True), self.model)
174
+ doc_embeddings = self.get_weightedmean_embedding(self.tokenize_with_specb(docs, is_query=False), self.model)
175
+ return texts,(query_embeddings,doc_embeddings)
176
+
177
+
178
+
179
+ def output_results(self,output_file,texts,embeddings,main_index = 0):
180
+ # Calculate cosine similarities
181
+ # Cosine similarities are in [-1, 1]. Higher means more similar
182
+ query_embeddings = embeddings[0]
183
+ doc_embeddings = embeddings[1]
184
+ cosine_dict = {}
185
+ queries = [texts[0]]
186
+ docs = texts[1:]
187
+ if (self.debug):
188
+ print("Total sentences",len(texts))
189
+ for i in range(len(docs)):
190
+ cosine_dict[docs[i]] = 1 - cosine(query_embeddings[0], doc_embeddings[i])
191
+
192
+ if (self.debug):
193
+ print("Input sentence:",texts[main_index])
194
+ sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
195
+ if (self.debug):
196
+ for key in sorted_dict:
197
+ print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
198
+ if (output_file is not None):
199
+ with open(output_file,"w") as fp:
200
+ fp.write(json.dumps(sorted_dict,indent=0))
201
+ return sorted_dict
202
+
203
+
204
  class SimCSEModel:
205
  def __init__(self):
206
  self.model = None