Spaces: Runtime error

lambdaofgod committed • Commit a284f57 • Parent(s): 568499b

additional cols and optional device

Files changed:
- app_implementation.py +48 -36
- config.py +7 -7
- search_utils.py +74 -44
app_implementation.py
CHANGED
@@ -1,50 +1,61 @@
-import os
 from typing import Dict, List

+import torch
 import pandas as pd
-import datasets
 import streamlit as st
-import config
 from findkit import retrieval_pipeline
+
+import config
 from search_utils import (
+    RetrievalPipelineWrapper,
+    get_doc_cols,
     get_repos_with_descriptions,
-    merge_text_list_cols,
-    search_f,
-    setup_retrieval_pipeline,
+    get_retrieval_df,
+    merge_cols,
 )


 class RetrievalApp:
+    def get_device_options(self):
+        if torch.cuda.is_available():
+            return ["cuda", "cpu"]
+        else:
+            return ["cpu"]
+
+    @st.cache(allow_output_mutation=True)
+    def get_retrieval_df(self):
+        return get_retrieval_df(self.data_path, config.text_list_cols)
+
     def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
+        self.data_path = data_path
+        self.device = st.sidebar.selectbox("device", self.get_device_options())
         print("loading data")

-        raw_retrieval_df = (
-            datasets.load_dataset(data_path)["train"]
-            .to_pandas()
-            .drop_duplicates(subset=["repo"])
-            .reset_index(drop=True)
-        )
-        self.retrieval_df = merge_text_list_cols(
-            raw_retrieval_df, config.text_list_cols
-        )
+        self.retrieval_df = self.get_retrieval_df().copy()

         model_name = st.sidebar.selectbox("model", config.model_names)
         self.query_encoder_name = "lambdaofgod/query-" + model_name
         self.document_encoder_name = "lambdaofgod/document-" + model_name

+        doc_cols = get_doc_cols(model_name)
+
         st.sidebar.text("using models")
         st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
-        st.sidebar.text("…
+        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
+
+        self.additional_shown_cols = st.sidebar.multiselect(
+            label="used text features", options=config.text_cols, default=doc_cols
+        )

     @staticmethod
     def show_retrieval_results(
-        retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
+        retrieval_pipe: RetrievalPipelineWrapper,
         query: str,
         k: int,
         all_queries: List[str],
         description_length: int,
         repos_by_query: Dict[str, pd.DataFrame],
-        doc_col: List[str],
+        additional_shown_cols: List[str],
     ):
         print("started retrieval")
         if query in all_queries:
@@ -56,15 +67,14 @@ class RetrievalApp:
             st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
         with st.spinner(text="fetching results"):
             st.write(
-                search_f(retrieval_pipe, query, k, description_length, doc_col).to_html(
+                retrieval_pipe.search(query, k, description_length, additional_shown_cols).to_html(
                     escape=False, index=False
                 ),
                 unsafe_allow_html=True,
             )
         print("finished retrieval")

-    @staticmethod
-    def app(retrieval_pipeline, retrieval_df, doc_col):
+    def run_app(self, retrieval_pipeline):

         retrieved_results = st.sidebar.number_input("number of results", value=10)
         description_length = st.sidebar.number_input(
@@ -72,17 +82,12 @@
         )

         tasks_deduped = (
-            retrieval_df["tasks"].explode().value_counts().reset_index()
+            self.retrieval_df["tasks"].explode().value_counts().reset_index()
         )  # drop_duplicates().sort_values().reset_index(drop=True)
         tasks_deduped.columns = ["task", "documents per task"]
         with st.sidebar.expander("View test set queries"):
             st.table(tasks_deduped.explode("task"))
-
-        additional_shown_cols = st.sidebar.multiselect(
-            label="additional cols", options=config.text_cols, default=doc_col
-        )
-
-        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
+        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
         query = st.text_input("input query", value="metric learning")
         RetrievalApp.show_retrieval_results(
             retrieval_pipeline,
@@ -91,16 +96,23 @@ class RetrievalApp:
             tasks_deduped["task"].to_list(),
             description_length,
             repos_by_query,
-            additional_shown_cols,
+            self.additional_shown_cols,
         )

-    def main(self):
-        print("setting up retrieval_pipe")
-        …
-        retrieval_pipeline = setup_retrieval_pipeline(
+    @st.cache(allow_output_mutation=True)
+    def get_retrieval_pipeline(self, displayed_retrieval_df):
+        return RetrievalPipelineWrapper.setup_from_encoder_names(
             self.query_encoder_name,
             self.document_encoder_name,
-            …
+            displayed_retrieval_df["document"],
+            displayed_retrieval_df,
+            device=self.device,
+        )
+
+    def main(self):
+        print("setting up retrieval_pipe")
+        displayed_retrieval_df = merge_cols(
+            self.retrieval_df.copy(), self.additional_shown_cols
        )
-        …
+        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
+        self.run_app(retrieval_pipeline)
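The refactor moves data loading, device selection, and pipeline construction into cached methods on RetrievalApp, so the Space's entry point presumably reduces to constructing the app and calling main(). A minimal sketch of such an entry script, assuming an app.py that imports this module (the file name and wiring are assumptions; they are not shown in this commit):

# app.py - hypothetical entry point; this commit only shows
# app_implementation.py, so the exact wiring here is an assumption.
from app_implementation import RetrievalApp

# The sidebar widgets (device, model, text-feature multiselect) are created
# in __init__, so constructing the app already renders the controls.
app = RetrievalApp(data_path="lambdaofgod/pwc_repositories_with_dependencies")

# main() merges the selected text columns into a "document" column,
# builds (and st.cache-s) the retrieval pipeline, then runs the search UI.
app.main()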
config.py
CHANGED
@@ -1,11 +1,11 @@
 model_names = [
-    …
-    …
-    …
-    …
-    …
+    "dependencies-nbow-nbow-mnrl",
+    "readme-nbow-nbow-mnrl",
+    "titles-nbow-nbow-mnrl",
+    "titles_dependencies-nbow-nbow-mnrl",
+    "readme_dependencies-nbow-nbow-mnrl",
 ]
-best_tasks_path="assets/best_tasks.csv"
-worst_tasks_path="assets/worst_tasks.csv"
+best_tasks_path = "assets/best_tasks.csv"
+worst_tasks_path = "assets/worst_tasks.csv"
 text_cols = ["dependencies", "readme", "titles"]
 text_list_cols = ["titles"]
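The new model names encode which text columns each model was trained on: everything before the first "-" is an underscore-separated column list, which get_doc_cols in search_utils.py recovers. A worked example using the functions as defined in this commit:

from search_utils import get_doc_cols

# Prefix before the first "-" is "titles_dependencies"; splitting on "_"
# yields the text columns this model indexes.
print(get_doc_cols("titles_dependencies-nbow-nbow-mnrl"))
# ['titles', 'dependencies']

# Encoder repo ids are built by prefixing, as in RetrievalApp.__init__:
model_name = "readme-nbow-nbow-mnrl"
query_encoder_name = "lambdaofgod/query-" + model_name      # query tower
document_encoder_name = "lambdaofgod/document-" + model_name  # document tower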
search_utils.py
CHANGED
@@ -1,7 +1,8 @@
 import os
 from typing import Dict, List
+from dataclasses import dataclass

-…
+import datasets
 import ast
 import pandas as pd
 import sentence_transformers
@@ -11,6 +12,33 @@ from toolz import partial
 import config


+def get_doc_cols(model_name):
+    model_name = model_name.replace("query-", "")
+    model_name = model_name.replace("document-", "")
+    return model_name.split("-")[0].split("_")
+
+
+def merge_cols(df, cols):
+    df["document"] = df[cols[0]]
+    for col in cols:
+        df["document"] = df["document"] + " " + df[col]
+    return df
+
+
+def get_retrieval_df(
+    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
+):
+    raw_retrieval_df = (
+        datasets.load_dataset(data_path)["train"]
+        .to_pandas()
+        .drop_duplicates(subset=["repo"])
+        .reset_index(drop=True)
+    )
+    if text_list_cols:
+        return merge_text_list_cols(raw_retrieval_df, text_list_cols)
+    return raw_retrieval_df
+
+
 def truncate_description(description, length=50):
     return " ".join(description.split()[:length])

@@ -19,25 +47,6 @@ def get_repos_with_descriptions(repos_df, repos):
     return repos_df.loc[repos]


-def search_f(
-    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
-    query: str,
-    k: int,
-    description_length: int,
-    doc_col: List[str],
-):
-    results = retrieval_pipe.find_similar(query, k)
-    # results['repo'] = results.index
-    results["link"] = "https://github.com/" + results["repo"]
-    for col in doc_col:
-        results[col] = results[col].apply(
-            lambda desc: truncate_description(desc, description_length)
-        )
-    shown_cols = ["repo", "tasks", "link", "distance"]
-    shown_cols = shown_cols + doc_col
-    return results.reset_index(drop=True)[shown_cols]
-
-
 def merge_text_list_cols(retrieval_df, text_list_cols):
     retrieval_df = retrieval_df.copy()
     for col in text_list_cols:
@@ -47,29 +56,50 @@ def merge_text_list_cols(retrieval_df, text_list_cols):
     return retrieval_df


-def setup_retrieval_pipeline(
-    …
-    documents_df: pd.DataFrame,
-    text_col: str,
-):
-    retrieval_pipeline.RetrievalPipelineFactory.build(
-        documents_df[text_col], metadata=documents_df
-    )


-@…
-def …
-    …
+@dataclass
+class RetrievalPipelineWrapper:
+
+    pipeline: retrieval_pipeline.RetrievalPipeline
+
+    @classmethod
+    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
+        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
+            feature_extractor=document_encoder,
+            query_feature_extractor=query_encoder,
+            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
+        )
+        pipeline = retrieval_pipe.build(documents, metadata=metadata)
+        return RetrievalPipelineWrapper(pipeline)
+
+    def search(
+        self,
+        query: str,
+        k: int,
+        description_length: int,
+        additional_shown_cols: List[str],
+    ):
+        results = self.pipeline.find_similar(query, k)
+        # results['repo'] = results.index
+        results["link"] = "https://github.com/" + results["repo"]
+        for col in additional_shown_cols:
+            results[col] = results[col].apply(
+                lambda desc: truncate_description(desc, description_length)
+            )
+        shown_cols = ["repo", "tasks", "link", "distance"]
+        shown_cols = shown_cols + additional_shown_cols
+        return results.reset_index(drop=True)[shown_cols]
+
+    @classmethod
+    def setup_from_encoder_names(
+        cls, query_encoder_path, document_encoder_path, documents, metadata, device
+    ):
+        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+            sentence_transformers.SentenceTransformer(
+                document_encoder_path, device=device
+            )
+        )
+        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
+        )
+        return cls.build_from_encoders(
+            query_encoder, document_encoder, documents, metadata
+        )
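RetrievalPipelineWrapper bundles the two sentence-transformer encoders with an NMSLIB cosine-similarity index, so it can also be driven outside Streamlit. Note that merge_cols as committed seeds "document" with cols[0] and then concatenates every column in cols, so the first selected column appears twice in the merged text. A minimal usage sketch, assuming findkit with the NMSLIB backend is installed and the encoder models are downloadable; the toy DataFrame is purely illustrative:

import pandas as pd
from search_utils import RetrievalPipelineWrapper, merge_cols

# Toy metadata frame; the real one comes from get_retrieval_df(...).
df = pd.DataFrame({
    "repo": ["a/x", "b/y"],
    "tasks": [["metric learning"], ["object detection"]],
    "readme": ["metric learning library", "detector model zoo"],
})
df = merge_cols(df, ["readme"])  # adds the "document" column used for indexing

pipe = RetrievalPipelineWrapper.setup_from_encoder_names(
    "lambdaofgod/query-readme-nbow-nbow-mnrl",
    "lambdaofgod/document-readme-nbow-nbow-mnrl",
    df["document"],  # texts to index
    df,              # metadata returned alongside each hit
    device="cpu",
)
# Returns a frame with repo, tasks, link, distance plus the truncated
# columns passed in additional_shown_cols.
results = pipe.search(
    "metric learning", k=2, description_length=10, additional_shown_cols=["readme"]
)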