taskswithcode committed
Commit b65a786 • 1 Parent(s): fb73c83
Fixes
Files changed:
- app.py +23 -15
- doc_app_models.json +61 -1
- text-search-ada-doc-001_planets_qna_search.json +0 -0
- text-search-ada-doc-001_qna2_search.json +0 -0
- text-search-ada-doc-001_qna_search.json +0 -0
- text-search-babbage-doc-001_planets_qna_search.json +0 -0
- text-search-babbage-doc-001_qna2_search.json +0 -0
- text-search-babbage-doc-001_qna_search.json +0 -0
- text-search-curie-doc-001_planets_qna_search.json +0 -0
- text-search-curie-doc-001_qna2_search.json +0 -0
- text-search-curie-doc-001_qna_search.json +0 -0
- text-search-davinci-doc-001_planets_qna_search.json +0 -0
- text-search-davinci-doc-001_qna2_search.json +0 -0
- text-search-davinci-doc-001_qna_search.json +0 -0
- twc_embeddings.py +6 -6
- twc_openai_search.py +124 -0
app.py
CHANGED
@@ -6,6 +6,7 @@ from io import StringIO
 import pdb
 import json
 from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
+from twc_openai_search import OpenAIQnAModel
 import torch
 import requests
 import socket
@@ -59,7 +60,7 @@ def get_views(action):
 
 def construct_model_info_for_display(model_names):
     options_arr = []
-    markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>
+    markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>"
     markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
     for node in model_names:
         options_arr .append(node["name"])
@@ -102,15 +103,15 @@ def load_model(model_name,model_class,load_model_name):
 
 
 @st.experimental_memo
-def cached_compute_similarity(sentences,_model,model_name,main_index):
-    texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
+def cached_compute_similarity(input_file_name,sentences,_model,model_name,main_index):
+    texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
     results = _model.output_results(None,texts,embeddings,main_index)
     return results
 
 
-def uncached_compute_similarity(sentences,_model,model_name,main_index):
+def uncached_compute_similarity(input_file_name,sentences,_model,model_name,main_index):
     with st.spinner('Computing vectors for sentences'):
-        texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
+        texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
         results = _model.output_results(None,texts,embeddings,main_index)
     #st.success("Similarity computation complete")
     return results
@@ -123,7 +124,7 @@ def get_model_info(model_names,model_name):
     return get_model_info(model_names,DEFAULT_HF_MODEL)
 
 
-def run_test(model_names,model_name,sentences,display_area,main_index,user_uploaded,custom_model):
+def run_test(model_names,model_name,input_file_name,sentences,display_area,main_index,user_uploaded,custom_model):
    display_area.text("Loading model:" + model_name)
    #Note. model_name may get mapped to new name in the call below for custom models
    orig_model_name = model_name
@@ -135,14 +136,18 @@ def run_test(model_names,model_name,sentences,display_area,main_index,user_uploa
    if ("Note" in model_info):
        fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
        display_area.write(fail_link)
+   if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
+       fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
+       display_area.write(fail_link)
+       return {"error":fail_link}
    model = load_model(model_name,model_info["class"],load_model_name)
    display_area.text("Model " + model_name + " load complete")
    try:
        if (user_uploaded):
-           results = uncached_compute_similarity(sentences,model,model_name,main_index)
+           results = uncached_compute_similarity(input_file_name,sentences,model,model_name,main_index)
        else:
            display_area.text("Computing vectors for sentences")
-           results = cached_compute_similarity(sentences,model,model_name,main_index)
+           results = cached_compute_similarity(input_file_name,sentences,model,model_name,main_index)
        display_area.text("Similarity computation complete")
        return results
 
@@ -254,15 +259,18 @@ def app_main(app_mode,example_files,model_name_files):
        run_model = selected_model
        st.session_state["model_name"] = selected_model
        st.session_state["main_index"] = main_index
-       results = run_test(model_names,run_model,sentences,display_area,main_index - 1,(uploaded_file is not None),(len(custom_model_selection) != 0))
+       results = run_test(model_names,run_model,st.session_state["file_name"],sentences,display_area,main_index - 1,(uploaded_file is not None),(len(custom_model_selection) != 0))
        display_area.empty()
        with display_area.container():
-           device = 'GPU' if torch.cuda.is_available() else 'CPU'
-           response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
-           if (len(custom_model_selection) != 0):
-               st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
-           display_results(sentences,main_index - 1,results,response_info,app_mode,run_model)
-           #st.json(results)
+           if ("error" in results):
+               st.error(results["error"])
+           else:
+               device = 'GPU' if torch.cuda.is_available() else 'CPU'
+               response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
+               if (len(custom_model_selection) != 0):
+                   st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
+               display_results(sentences,main_index - 1,results,response_info,app_mode,run_model)
+               #st.json(results)
           st.download_button(
               label="Download results as json",
               data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
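
A note on the caching split above: st.experimental_memo builds its cache key by hashing the function arguments, and Streamlit excludes any argument whose name starts with an underscore (hence _model). Threading input_file_name into the signature lets the cache distinguish example files and lets the OpenAI class locate its per-file embedding cache. A minimal sketch of the same pattern, with a hypothetical cached_compute stand-in for the app's functions:

import streamlit as st

@st.experimental_memo
def cached_compute(file_name, sentences, _model):
    # file_name and sentences are hashed into the cache key;
    # _model is skipped because its name starts with "_",
    # so an unhashable model object can be passed through.
    texts, embeddings = _model.compute_embeddings(file_name, sentences, is_file=False)
    return _model.output_results(None, texts, embeddings, 0)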
doc_app_models.json
CHANGED
@@ -108,7 +108,67 @@
     },
     "paper_url":"https://arxiv.org/abs/2104.08821v4",
     "mark":"True",
-    "class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"}
+    "class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"},
+    { "name":"GPT-3-175B (text-search-davinci-doc-001)" ,
+        "model":"text-search-davinci-doc-001",
+        "fork_url":"https://openai.com/api/",
+        "orig_author_url":"https://openai.com/api/",
+        "orig_author":"OpenAI",
+        "sota_info": {
+            "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+            "sota_link":"https://paperswithcode.com/method/gpt-3"
+        },
+        "paper_url":"https://arxiv.org/abs/2005.14165v4",
+        "mark":"True",
+        "custom_load":"False",
+        "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+        "alt_url":"https://openai.com/api/",
+        "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
+    { "name":"GPT-3-6.7B (text-search-curie-doc-001)" ,
+        "model":"text-search-curie-doc-001",
+        "fork_url":"https://openai.com/api/",
+        "orig_author_url":"https://openai.com/api/",
+        "orig_author":"OpenAI",
+        "sota_info": {
+            "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+            "sota_link":"https://paperswithcode.com/method/gpt-3"
+        },
+        "paper_url":"https://arxiv.org/abs/2005.14165v4",
+        "mark":"True",
+        "custom_load":"False",
+        "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+        "alt_url":"https://openai.com/api/",
+        "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
+    { "name":"GPT-3-1.3B (text-search-babbage-doc-001)" ,
+        "model":"text-search-babbage-doc-001",
+        "fork_url":"https://openai.com/api/",
+        "orig_author_url":"https://openai.com/api/",
+        "orig_author":"OpenAI",
+        "sota_info": {
+            "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+            "sota_link":"https://paperswithcode.com/method/gpt-3"
+        },
+        "paper_url":"https://arxiv.org/abs/2005.14165v4",
+        "mark":"True",
+        "custom_load":"False",
+        "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+        "alt_url":"https://openai.com/api/",
+        "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
+    { "name":"GPT-3-350M (text-search-ada-doc-001)" ,
+        "model":"text-search-ada-doc-001",
+        "fork_url":"https://openai.com/api/",
+        "orig_author_url":"https://openai.com/api/",
+        "orig_author":"OpenAI",
+        "sota_info": {
+            "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+            "sota_link":"https://paperswithcode.com/method/gpt-3"
+        },
+        "paper_url":"https://arxiv.org/abs/2005.14165v4",
+        "mark":"True",
+        "custom_load":"False",
+        "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+        "alt_url":"https://openai.com/api/",
+        "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"}
 
 
 ]
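
The four new registry entries mark the OpenAI models with "custom_load":"False", which the run_test check added in app.py uses to refuse user-uploaded files (computing fresh embeddings would need an OpenAI API key). A sketch of how such a registry entry can be consumed; the loading code here is an assumption for illustration, not the app's actual code:

import json

# Hypothetical consumer of doc_app_models.json: skip models that
# declare "custom_load": "False" when the user uploads a custom file.
with open("doc_app_models.json") as fp:
    model_entries = json.load(fp)

uploadable = [entry["name"] for entry in model_entries
              if entry.get("custom_load", "True") != "False"]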
text-search-ada-doc-001_planets_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-ada-doc-001_qna2_search.json: ADDED (diff too large to render; see raw diff)
text-search-ada-doc-001_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-babbage-doc-001_planets_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-babbage-doc-001_qna2_search.json: ADDED (diff too large to render; see raw diff)
text-search-babbage-doc-001_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-curie-doc-001_planets_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-curie-doc-001_qna2_search.json: ADDED (diff too large to render; see raw diff)
text-search-curie-doc-001_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-davinci-doc-001_planets_qna_search.json: ADDED (diff too large to render; see raw diff)
text-search-davinci-doc-001_qna2_search.json: ADDED (diff too large to render; see raw diff)
text-search-davinci-doc-001_qna_search.json: ADDED (diff too large to render; see raw diff)
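
These added .json files are precomputed embedding caches for the bundled example inputs, so the Space can serve OpenAI results without an API key. Their names follow the scheme in twc_openai_search.py's compute_embeddings: the doc model name, an underscore, the input file's stem, and a _search.json suffix. A small sketch of that naming rule (the example input path is an assumption):

def cache_file_name(d_model_name, input_file_name):
    # Mirrors twc_openai_search.OpenAIQnAModel.compute_embeddings:
    # <doc model>_<input file stem>_search.json
    in_file = input_file_name.split('/')[-1]
    return d_model_name + '_' + '.'.join(in_file.split('.')[:-1]) + "_search.json"

# e.g. an input file named planets_qna.txt (assumed) maps to:
# cache_file_name("text-search-ada-doc-001", "examples/planets_qna.txt")
#   -> "text-search-ada-doc-001_planets_qna_search.json"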
twc_embeddings.py
CHANGED
@@ -32,7 +32,7 @@ class CausalLMModel:
         self.model.eval()
         self.prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         if (self.debug):
             print("Computing embeddings for:", input_data[:20])
         model = self.model
@@ -160,7 +160,7 @@ class SGPTQnAModel:
 
         return embeddings
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         if (self.debug):
             print("Computing embeddings for:", input_data[:20])
         model = self.model
@@ -215,7 +215,7 @@ class SimCSEModel:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModel.from_pretrained(model_name)
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_file,input_data,is_file):
         texts = read_text(input_data) if is_file == True else input_data
         inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
         with torch.no_grad():
@@ -266,7 +266,7 @@ class SGPTModel:
         # Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
         self.model.eval()
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         if (self.debug):
             print("Computing embeddings for:", input_data[:20])
         model = self.model
@@ -353,7 +353,7 @@ class HFModel:
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         #print("Computing embeddings for:", input_data[:20])
         model = self.model
         tokenizer = self.tokenizer
@@ -403,5 +403,5 @@ if __name__ == '__main__':
     results = parser.parse_args()
     obj = HFModel()
     obj.init_model(results.model)
-    texts, embeddings = obj.compute_embeddings(results.input,is_file = True)
+    texts, embeddings = obj.compute_embeddings(results.input,results.input,is_file = True)
    results = obj.output_results(results.output,texts,embeddings)
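
The change above threads input_file_name through compute_embeddings so the model classes share a nearly uniform signature: the HF-based classes accept the extra argument but ignore it, while the OpenAI class uses it to locate its cache file. (Note that SimCSEModel's version picks up an additional input_file parameter in this diff, so its arity differs from the others.) A sketch of a call under the new signature; the model id here is an assumption:

from twc_embeddings import HFModel

model = HFModel()
model.init_model("sentence-transformers/paraphrase-MiniLM-L6-v2")  # assumed id
# input_file_name is accepted but unused by the HF-based classes.
texts, embeddings = model.compute_embeddings(
    "qna.txt", ["query", "doc one", "doc two"], is_file=False)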
twc_openai_search.py
ADDED
@@ -0,0 +1,124 @@
+from scipy.spatial.distance import cosine
+import argparse
+import json
+import os
+import openai
+import pdb
+
+def read_text(input_file):
+    arr = open(input_file).read().split("\n")
+    return arr[:-1]
+
+
+class OpenAIQnAModel:
+    def __init__(self):
+        self.debug = False
+        self.q_model_name = None
+        self.d_model_name = None
+        self.skip_key = True
+        print("In OpenAI API constructor")
+
+
+    def init_model(self,model_name = None):
+        #print("OpenAI: Init model",model_name)
+        openai.api_key = os.getenv("OPENAI_API_KEY")
+        if (openai.api_key == None):
+            openai.api_key = ""
+            print("API key not set")
+
+        if (len(openai.api_key) == 0 and not self.skip_key):
+            print("Open API key not set")
+
+        if (model_name is None):
+            self.d_model_name = "text-search-ada-doc-001"
+        else:
+            self.d_model_name = model_name
+        self.q_model_name = self.construct_query_model_name(self.d_model_name)
+        print(f"OpenAI: Init model complete :query model {self.q_model_name} doc:{self.d_model_name}")
+
+    def construct_query_model_name(self,d_model_name):
+        return d_model_name.replace('-doc-','-query-')
+
+
+    def compute_embeddings(self,input_file_name,input_data,is_file):
+        if (len(openai.api_key) == 0 and not self.skip_key):
+            print("Open API key not set")
+            return [],[]
+        #print("In compute embeddings after key check")
+        in_file = input_file_name.split('/')[-1]
+        in_file = self.d_model_name + '_' + '.'.join(in_file.split('.')[:-1]) + "_search.json"
+        cached = False
+        try:
+            fp = open(in_file)
+            cached = True
+            embeddings = json.load(fp)
+            q_embeddings = [embeddings[0]]
+            d_embeddings = embeddings[1:]
+            print("Using cached embeddings")
+        except:
+            pass
+
+        texts = read_text(input_data) if is_file == True else input_data
+        queries = [texts[0]]
+        docs = texts[1:]
+
+        if (not cached):
+            print(f"Computing embeddings for {input_file_name} and query model {self.q_model_name}")
+            query_embeds = openai.Embedding.create(
+                input=queries,
+                model=self.q_model_name
+            )
+            print(f"Computing embeddings for {input_file_name} and doc model {self.q_model_name}")
+            doc_embeds = openai.Embedding.create(
+                input=docs,
+                model=self.d_model_name
+            )
+            q_embeddings = []
+            d_embeddings = []
+            for i in range(len(query_embeds['data'])):
+                q_embeddings.append(query_embeds['data'][i]['embedding'])
+            for i in range(len(doc_embeds['data'])):
+                d_embeddings.append(doc_embeds['data'][i]['embedding'])
+        if (not cached):
+            embeddings = q_embeddings + d_embeddings
+            with open(in_file,"w") as fp:
+                json.dump(embeddings,fp)
+        return texts,(q_embeddings,d_embeddings)
+
+    def output_results(self,output_file,texts,embeddings,main_index = 0):
+        # Calculate cosine similarities
+        # Cosine similarities are in [-1, 1]. Higher means more similar
+        query_embeddings = embeddings[0]
+        doc_embeddings = embeddings[1]
+        cosine_dict = {}
+        queries = [texts[0]]
+        docs = texts[1:]
+        if (self.debug):
+            print("Total sentences",len(texts))
+        for i in range(len(docs)):
+            cosine_dict[docs[i]] = 1 - cosine(query_embeddings[0], doc_embeddings[i])
+
+        if (self.debug):
+            print("Input sentence:",texts[main_index])
+        sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
+        if (self.debug):
+            for key in sorted_dict:
+                print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
+        if (output_file is not None):
+            with open(output_file,"w") as fp:
+                fp.write(json.dumps(sorted_dict,indent=0))
+        return sorted_dict
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='OpenAI model for document search embeddings ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-input', action="store", dest="input",required=True,help="Input file with sentences")
+    parser.add_argument('-output', action="store", dest="output",default="output.txt",help="Output file with results")
+    parser.add_argument('-model', action="store", dest="model",default="text-search-ada-doc-001",help="model name")
+
+    results = parser.parse_args()
+    obj = OpenAIQnAModel()
+    obj.init_model(results.model)
+    texts, embeddings = obj.compute_embeddings(results.input,results.input,is_file = True)
+    results = obj.output_results(results.output,texts,embeddings)
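
In the new OpenAIQnAModel, the first line of the input is treated as the query and the remaining lines as documents; query and document embeddings come from the paired -query-/-doc- models, are cached to a JSON file keyed by the input file name, and documents are ranked by cosine similarity (scipy's cosine is a distance, so 1 - cosine(...) gives the similarity). A usage sketch mirroring the __main__ block; the input file name and the placeholder key are assumptions:

import os
from twc_openai_search import OpenAIQnAModel

# Required unless a cached <model>_<stem>_search.json already exists.
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

obj = OpenAIQnAModel()
obj.init_model("text-search-ada-doc-001")
# First argument names the file (for the cache); second is read for the text.
texts, embeddings = obj.compute_embeddings("qna.txt", "qna.txt", is_file=True)
ranked = obj.output_results(None, texts, embeddings)  # doc -> similarity, descending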