Spaces:
Build error
Build error
taskswithcode
commited on
Commit
•
d872b74
1
Parent(s):
515e4d1
Fixes
Browse files- app.py +24 -16
- clus_app_models.json +61 -1
- text-similarity-ada-001imdb_sent_embed.json +0 -0
- text-similarity-ada-001larger_test_embed.json +0 -0
- text-similarity-ada-001small_test_embed.json +0 -0
- text-similarity-babbage-001imdb_sent_embed.json +0 -0
- text-similarity-babbage-001larger_test_embed.json +0 -0
- text-similarity-babbage-001small_test_embed.json +0 -0
- text-similarity-curie-001imdb_sent_embed.json +0 -0
- text-similarity-curie-001larger_test_embed.json +0 -0
- text-similarity-curie-001small_test_embed.json +0 -0
- text-similarity-davinci-001small_test_embed.json +0 -0
- twc_embeddings.py +6 -6
- twc_openai_embeddings.py +102 -0
app.py
CHANGED
@@ -6,6 +6,7 @@ from io import StringIO
|
|
6 |
import pdb
|
7 |
import json
|
8 |
from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
|
|
|
9 |
from twc_clustering import TWCClustering
|
10 |
import torch
|
11 |
import requests
|
@@ -60,7 +61,7 @@ def get_views(action):
|
|
60 |
|
61 |
def construct_model_info_for_display(model_names):
|
62 |
options_arr = []
|
63 |
-
markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>
|
64 |
markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
|
65 |
for node in model_names:
|
66 |
options_arr .append(node["name"])
|
@@ -96,22 +97,22 @@ def load_model(model_name,model_class,load_model_name):
|
|
96 |
ret_model.init_model(load_model_name)
|
97 |
assert(ret_model is not None)
|
98 |
except Exception as e:
|
99 |
-
st.error("Unable to load model:
|
100 |
pass
|
101 |
return ret_model
|
102 |
|
103 |
|
104 |
|
105 |
@st.experimental_memo
|
106 |
-
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster,clustering_type):
|
107 |
-
texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
|
108 |
results = _cluster.cluster(None,texts,embeddings,threshold,clustering_type)
|
109 |
return results
|
110 |
|
111 |
|
112 |
-
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster,clustering_type):
|
113 |
with st.spinner('Computing vectors for sentences'):
|
114 |
-
texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
|
115 |
results = cluster.cluster(None,texts,embeddings,threshold,clustering_type)
|
116 |
#st.success("Similarity computation complete")
|
117 |
return results
|
@@ -124,7 +125,7 @@ def get_model_info(model_names,model_name):
|
|
124 |
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
125 |
|
126 |
|
127 |
-
def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
|
128 |
display_area.text("Loading model:" + model_name)
|
129 |
#Note. model_name may get mapped to new name in the call below for custom models
|
130 |
orig_model_name = model_name
|
@@ -136,14 +137,18 @@ def run_test(model_names,model_name,sentences,display_area,threshold,user_upload
|
|
136 |
if ("Note" in model_info):
|
137 |
fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
|
138 |
display_area.write(fail_link)
|
|
|
|
|
|
|
|
|
139 |
model = load_model(model_name,model_info["class"],load_model_name)
|
140 |
display_area.text("Model " + model_name + " load complete")
|
141 |
try:
|
142 |
if (user_uploaded):
|
143 |
-
results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
144 |
else:
|
145 |
display_area.text("Computing vectors for sentences")
|
146 |
-
results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
147 |
display_area.text("Similarity computation complete")
|
148 |
return results
|
149 |
|
@@ -263,15 +268,18 @@ def app_main(app_mode,example_files,model_name_files,clus_types):
|
|
263 |
st.session_state["model_name"] = selected_model
|
264 |
st.session_state["threshold"] = threshold
|
265 |
st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
|
266 |
-
results = run_test(model_names,run_model,sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0),cluster_types[clustering_type]["type"])
|
267 |
display_area.empty()
|
268 |
with display_area.container():
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
275 |
st.download_button(
|
276 |
label="Download results as json",
|
277 |
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
|
|
6 |
import pdb
|
7 |
import json
|
8 |
from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
|
9 |
+
from twc_openai_embeddings import OpenAIModel
|
10 |
from twc_clustering import TWCClustering
|
11 |
import torch
|
12 |
import requests
|
|
|
61 |
|
62 |
def construct_model_info_for_display(model_names):
|
63 |
options_arr = []
|
64 |
+
markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>"
|
65 |
markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
|
66 |
for node in model_names:
|
67 |
options_arr .append(node["name"])
|
|
|
97 |
ret_model.init_model(load_model_name)
|
98 |
assert(ret_model is not None)
|
99 |
except Exception as e:
|
100 |
+
st.error(f"Unable to load model class:{model_class} model_name: {model_name} load_model_name: {load_model_name} {str(e)}")
|
101 |
pass
|
102 |
return ret_model
|
103 |
|
104 |
|
105 |
|
106 |
@st.experimental_memo
|
107 |
+
def cached_compute_similarity(input_file_name,sentences,_model,model_name,threshold,_cluster,clustering_type):
|
108 |
+
texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
|
109 |
results = _cluster.cluster(None,texts,embeddings,threshold,clustering_type)
|
110 |
return results
|
111 |
|
112 |
|
113 |
+
def uncached_compute_similarity(input_file_name,sentences,_model,model_name,threshold,cluster,clustering_type):
|
114 |
with st.spinner('Computing vectors for sentences'):
|
115 |
+
texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
|
116 |
results = cluster.cluster(None,texts,embeddings,threshold,clustering_type)
|
117 |
#st.success("Similarity computation complete")
|
118 |
return results
|
|
|
125 |
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
126 |
|
127 |
|
128 |
+
def run_test(model_names,model_name,input_file_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
|
129 |
display_area.text("Loading model:" + model_name)
|
130 |
#Note. model_name may get mapped to new name in the call below for custom models
|
131 |
orig_model_name = model_name
|
|
|
137 |
if ("Note" in model_info):
|
138 |
fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
|
139 |
display_area.write(fail_link)
|
140 |
+
if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
|
141 |
+
fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
|
142 |
+
display_area.write(fail_link)
|
143 |
+
return {"error":fail_link}
|
144 |
model = load_model(model_name,model_info["class"],load_model_name)
|
145 |
display_area.text("Model " + model_name + " load complete")
|
146 |
try:
|
147 |
if (user_uploaded):
|
148 |
+
results = uncached_compute_similarity(input_file_name,sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
149 |
else:
|
150 |
display_area.text("Computing vectors for sentences")
|
151 |
+
results = cached_compute_similarity(input_file_name,sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
152 |
display_area.text("Similarity computation complete")
|
153 |
return results
|
154 |
|
|
|
268 |
st.session_state["model_name"] = selected_model
|
269 |
st.session_state["threshold"] = threshold
|
270 |
st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
|
271 |
+
results = run_test(model_names,run_model,st.session_state["file_name"],sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0),cluster_types[clustering_type]["type"])
|
272 |
display_area.empty()
|
273 |
with display_area.container():
|
274 |
+
if ("error" in results):
|
275 |
+
st.error(results["error"])
|
276 |
+
else:
|
277 |
+
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
278 |
+
response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
|
279 |
+
if (len(custom_model_selection) != 0):
|
280 |
+
st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
|
281 |
+
display_results(sentences,results,response_info,app_mode,run_model)
|
282 |
+
#st.json(results)
|
283 |
st.download_button(
|
284 |
label="Download results as json",
|
285 |
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
clus_app_models.json
CHANGED
@@ -84,7 +84,67 @@
|
|
84 |
},
|
85 |
"paper_url":"https://arxiv.org/abs/2104.08821v4",
|
86 |
"mark":"True",
|
87 |
-
"class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
]
|
|
|
84 |
},
|
85 |
"paper_url":"https://arxiv.org/abs/2104.08821v4",
|
86 |
"mark":"True",
|
87 |
+
"class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"},
|
88 |
+
{ "name":"GPT-3-175B (text-similarity-davinci-001)" ,
|
89 |
+
"model":"text-similarity-davinci-001",
|
90 |
+
"fork_url":"https://openai.com/api/",
|
91 |
+
"orig_author_url":"https://openai.com/api/",
|
92 |
+
"orig_author":"OpenAI",
|
93 |
+
"sota_info": {
|
94 |
+
"task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
|
95 |
+
"sota_link":"https://paperswithcode.com/method/gpt-3"
|
96 |
+
},
|
97 |
+
"paper_url":"https://arxiv.org/abs/2005.14165v4",
|
98 |
+
"mark":"True",
|
99 |
+
"custom_load":"False",
|
100 |
+
"Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
|
101 |
+
"alt_url":"https://openai.com/api/",
|
102 |
+
"class":"OpenAIModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
|
103 |
+
{ "name":"GPT-3-6.7B (text-similarity-curie-001)" ,
|
104 |
+
"model":"text-similarity-curie-001",
|
105 |
+
"fork_url":"https://openai.com/api/",
|
106 |
+
"orig_author_url":"https://openai.com/api/",
|
107 |
+
"orig_author":"OpenAI",
|
108 |
+
"sota_info": {
|
109 |
+
"task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
|
110 |
+
"sota_link":"https://paperswithcode.com/method/gpt-3"
|
111 |
+
},
|
112 |
+
"paper_url":"https://arxiv.org/abs/2005.14165v4",
|
113 |
+
"mark":"True",
|
114 |
+
"custom_load":"False",
|
115 |
+
"Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
|
116 |
+
"alt_url":"https://openai.com/api/",
|
117 |
+
"class":"OpenAIModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
|
118 |
+
{ "name":"GPT-3-1.3B (text-similarity-babbage-001)" ,
|
119 |
+
"model":"text-similarity-babbage-001",
|
120 |
+
"fork_url":"https://openai.com/api/",
|
121 |
+
"orig_author_url":"https://openai.com/api/",
|
122 |
+
"orig_author":"OpenAI",
|
123 |
+
"sota_info": {
|
124 |
+
"task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
|
125 |
+
"sota_link":"https://paperswithcode.com/method/gpt-3"
|
126 |
+
},
|
127 |
+
"paper_url":"https://arxiv.org/abs/2005.14165v4",
|
128 |
+
"mark":"True",
|
129 |
+
"custom_load":"False",
|
130 |
+
"Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
|
131 |
+
"alt_url":"https://openai.com/api/",
|
132 |
+
"class":"OpenAIModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
|
133 |
+
{ "name":"GPT-3-350M (text-similarity-ada-001)" ,
|
134 |
+
"model":"text-similarity-ada-001",
|
135 |
+
"fork_url":"https://openai.com/api/",
|
136 |
+
"orig_author_url":"https://openai.com/api/",
|
137 |
+
"orig_author":"OpenAI",
|
138 |
+
"sota_info": {
|
139 |
+
"task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
|
140 |
+
"sota_link":"https://paperswithcode.com/method/gpt-3"
|
141 |
+
},
|
142 |
+
"paper_url":"https://arxiv.org/abs/2005.14165v4",
|
143 |
+
"mark":"True",
|
144 |
+
"custom_load":"False",
|
145 |
+
"Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
|
146 |
+
"alt_url":"https://openai.com/api/",
|
147 |
+
"class":"OpenAIModel","sota_link":"https://arxiv.org/abs/2005.14165v4"}
|
148 |
|
149 |
|
150 |
]
|
text-similarity-ada-001imdb_sent_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-ada-001larger_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-ada-001small_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-babbage-001imdb_sent_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-babbage-001larger_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-babbage-001small_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-curie-001imdb_sent_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-curie-001larger_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-curie-001small_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text-similarity-davinci-001small_test_embed.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
twc_embeddings.py
CHANGED
@@ -32,7 +32,7 @@ class CausalLMModel:
|
|
32 |
self.model.eval()
|
33 |
self.prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
|
34 |
|
35 |
-
def compute_embeddings(self,input_data,is_file):
|
36 |
if (self.debug):
|
37 |
print("Computing embeddings for:", input_data[:20])
|
38 |
model = self.model
|
@@ -160,7 +160,7 @@ class SGPTQnAModel:
|
|
160 |
|
161 |
return embeddings
|
162 |
|
163 |
-
def compute_embeddings(self,input_data,is_file):
|
164 |
if (self.debug):
|
165 |
print("Computing embeddings for:", input_data[:20])
|
166 |
model = self.model
|
@@ -215,7 +215,7 @@ class SimCSEModel:
|
|
215 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
216 |
self.model = AutoModel.from_pretrained(model_name)
|
217 |
|
218 |
-
def compute_embeddings(self,input_data,is_file):
|
219 |
texts = read_text(input_data) if is_file == True else input_data
|
220 |
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
221 |
with torch.no_grad():
|
@@ -266,7 +266,7 @@ class SGPTModel:
|
|
266 |
# Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
|
267 |
self.model.eval()
|
268 |
|
269 |
-
def compute_embeddings(self,input_data,is_file):
|
270 |
if (self.debug):
|
271 |
print("Computing embeddings for:", input_data[:20])
|
272 |
model = self.model
|
@@ -353,7 +353,7 @@ class HFModel:
|
|
353 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
354 |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
355 |
|
356 |
-
def compute_embeddings(self,input_data,is_file):
|
357 |
#print("Computing embeddings for:", input_data[:20])
|
358 |
model = self.model
|
359 |
tokenizer = self.tokenizer
|
@@ -403,5 +403,5 @@ if __name__ == '__main__':
|
|
403 |
results = parser.parse_args()
|
404 |
obj = HFModel()
|
405 |
obj.init_model(results.model)
|
406 |
-
texts, embeddings = obj.compute_embeddings(results.input,is_file = True)
|
407 |
results = obj.output_results(results.output,texts,embeddings)
|
|
|
32 |
self.model.eval()
|
33 |
self.prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
|
34 |
|
35 |
+
def compute_embeddings(self,input_file_name,input_data,is_file):
|
36 |
if (self.debug):
|
37 |
print("Computing embeddings for:", input_data[:20])
|
38 |
model = self.model
|
|
|
160 |
|
161 |
return embeddings
|
162 |
|
163 |
+
def compute_embeddings(self,input_file_name,input_data,is_file):
|
164 |
if (self.debug):
|
165 |
print("Computing embeddings for:", input_data[:20])
|
166 |
model = self.model
|
|
|
215 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
216 |
self.model = AutoModel.from_pretrained(model_name)
|
217 |
|
218 |
+
def compute_embeddings(self,input_file_name,input_data,is_file):
|
219 |
texts = read_text(input_data) if is_file == True else input_data
|
220 |
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
221 |
with torch.no_grad():
|
|
|
266 |
# Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
|
267 |
self.model.eval()
|
268 |
|
269 |
+
def compute_embeddings(self,input_file_name,input_data,is_file):
|
270 |
if (self.debug):
|
271 |
print("Computing embeddings for:", input_data[:20])
|
272 |
model = self.model
|
|
|
353 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
354 |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
355 |
|
356 |
+
def compute_embeddings(self,input_file_name,input_data,is_file):
|
357 |
#print("Computing embeddings for:", input_data[:20])
|
358 |
model = self.model
|
359 |
tokenizer = self.tokenizer
|
|
|
403 |
results = parser.parse_args()
|
404 |
obj = HFModel()
|
405 |
obj.init_model(results.model)
|
406 |
+
texts, embeddings = obj.compute_embeddings(results.input,results.input,is_file = True)
|
407 |
results = obj.output_results(results.output,texts,embeddings)
|
twc_openai_embeddings.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.spatial.distance import cosine
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import openai
|
6 |
+
import pdb
|
7 |
+
|
8 |
+
def read_text(input_file):
|
9 |
+
arr = open(input_file).read().split("\n")
|
10 |
+
return arr[:-1]
|
11 |
+
|
12 |
+
|
13 |
+
class OpenAIModel:
|
14 |
+
def __init__(self):
|
15 |
+
self.debug = False
|
16 |
+
self.model_name = None
|
17 |
+
self.skip_key = True
|
18 |
+
print("In OpenAI API constructor")
|
19 |
+
|
20 |
+
|
21 |
+
def init_model(self,model_name = None):
|
22 |
+
#print("OpenAI: Init model",model_name)
|
23 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
24 |
+
if (openai.api_key == None):
|
25 |
+
openai.api_key = ""
|
26 |
+
print("API key not set")
|
27 |
+
|
28 |
+
if (len(openai.api_key) == 0 and not self.skip_key):
|
29 |
+
print("Open API key not set")
|
30 |
+
|
31 |
+
if (model_name is None):
|
32 |
+
self.model_name = "text-similarity-ada-001"
|
33 |
+
else:
|
34 |
+
self.model_name = model_name
|
35 |
+
print("OpenAI: Init model complete",model_name)
|
36 |
+
|
37 |
+
|
38 |
+
def compute_embeddings(self,input_file_name,input_data,is_file):
|
39 |
+
if (len(openai.api_key) == 0 and not self.skip_key):
|
40 |
+
print("Open API key not set")
|
41 |
+
return [],[]
|
42 |
+
#print("In compute embeddings after key check")
|
43 |
+
in_file = self.model_name + '.'.join(input_file_name.split('.')[:-1]) + "_embed.json"
|
44 |
+
cached = False
|
45 |
+
try:
|
46 |
+
fp = open(in_file)
|
47 |
+
cached = True
|
48 |
+
embeddings = json.load(fp)
|
49 |
+
print("Using cached embeddings")
|
50 |
+
except:
|
51 |
+
pass
|
52 |
+
|
53 |
+
texts = read_text(input_data) if is_file == True else input_data
|
54 |
+
if (not cached):
|
55 |
+
print(f"Computing embeddings for {input_file_name} and model {self.model_name}")
|
56 |
+
response = openai.Embedding.create(
|
57 |
+
input=texts,
|
58 |
+
model=self.model_name
|
59 |
+
)
|
60 |
+
embeddings = []
|
61 |
+
for i in range(len(response['data'])):
|
62 |
+
embeddings.append(response['data'][i]['embedding'])
|
63 |
+
if (not cached):
|
64 |
+
with open(in_file,"w") as fp:
|
65 |
+
json.dump(embeddings,fp)
|
66 |
+
return texts,embeddings
|
67 |
+
|
68 |
+
def output_results(self,output_file,texts,embeddings,main_index = 0):
|
69 |
+
if (len(openai.api_key) == 0 and not self.skip_key):
|
70 |
+
print("Open API key not set")
|
71 |
+
return {}
|
72 |
+
#print("In output results after key check")
|
73 |
+
# Calculate cosine similarities
|
74 |
+
# Cosine similarities are in [-1, 1]. Higher means more similar
|
75 |
+
cosine_dict = {}
|
76 |
+
#print("Total sentences",len(texts))
|
77 |
+
for i in range(len(texts)):
|
78 |
+
cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])
|
79 |
+
|
80 |
+
#print("Input sentence:",texts[main_index])
|
81 |
+
sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
|
82 |
+
if (self.debug):
|
83 |
+
for key in sorted_dict:
|
84 |
+
print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
|
85 |
+
if (output_file is not None):
|
86 |
+
with open(output_file,"w") as fp:
|
87 |
+
fp.write(json.dumps(sorted_dict,indent=0))
|
88 |
+
return sorted_dict
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == '__main__':
|
93 |
+
parser = argparse.ArgumentParser(description='OpenAI model for sentence embeddings ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
94 |
+
parser.add_argument('-input', action="store", dest="input",required=True,help="Input file with sentences")
|
95 |
+
parser.add_argument('-output', action="store", dest="output",default="output.txt",help="Output file with results")
|
96 |
+
parser.add_argument('-model', action="store", dest="model",default="text-similarity-ada-001",help="model name")
|
97 |
+
|
98 |
+
results = parser.parse_args()
|
99 |
+
obj = OpenAIModel()
|
100 |
+
obj.init_model(results.model)
|
101 |
+
texts, embeddings = obj.compute_embeddings(results.input,is_file = True)
|
102 |
+
results = obj.output_results(results.output,texts,embeddings)
|