lambdaofgod committed on
Commit 568499b
1 Parent(s): 1ed024e

app refactor and new models

Files changed (4)
  1. app_implementation.py +106 -0
  2. config.py +9 -2
  3. pages/1_Retrieval_App.py +3 -149
  4. search_utils.py +75 -0
app_implementation.py ADDED
@@ -0,0 +1,106 @@
+ import os
+ from typing import Dict, List
+
+ import pandas as pd
+ import datasets
+ import streamlit as st
+ import config
+ from findkit import retrieval_pipeline
+ from search_utils import (
+     get_repos_with_descriptions,
+     search_f,
+     merge_text_list_cols,
+     setup_retrieval_pipeline,
+ )
+
+
+ class RetrievalApp:
+     def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
+         print("loading data")
+
+         raw_retrieval_df = (
+             datasets.load_dataset(data_path)["train"]
+             .to_pandas()
+             .drop_duplicates(subset=["repo"])
+             .reset_index(drop=True)
+         )
+         self.retrieval_df = merge_text_list_cols(
+             raw_retrieval_df, config.text_list_cols
+         )
+
+         model_name = st.sidebar.selectbox("model", config.model_names)
+         self.query_encoder_name = "lambdaofgod/query-" + model_name
+         self.document_encoder_name = "lambdaofgod/document-" + model_name
+
+         st.sidebar.text("using models")
+         st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
+         st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
+
+     @staticmethod
+     def show_retrieval_results(
+         retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
+         query: str,
+         k: int,
+         all_queries: List[str],
+         description_length: int,
+         repos_by_query: Dict[str, pd.DataFrame],
+         doc_col: str,
+     ):
+         print("started retrieval")
+         if query in all_queries:
+             with st.expander(
+                 "query is in gold standard set queries. Toggle viewing gold standard results?"
+             ):
+                 st.write("gold standard results")
+                 task_repos = repos_by_query.get_group(query)
+                 st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
+         with st.spinner(text="fetching results"):
+             st.write(
+                 search_f(retrieval_pipe, query, k, description_length, doc_col).to_html(
+                     escape=False, index=False
+                 ),
+                 unsafe_allow_html=True,
+             )
+         print("finished retrieval")
+
+     @staticmethod
+     def app(retrieval_pipeline, retrieval_df, doc_col):
+
+         retrieved_results = st.sidebar.number_input("number of results", value=10)
+         description_length = st.sidebar.number_input(
+             "number of used description words", value=10
+         )
+
+         tasks_deduped = (
+             retrieval_df["tasks"].explode().value_counts().reset_index()
+         )  # drop_duplicates().sort_values().reset_index(drop=True)
+         tasks_deduped.columns = ["task", "documents per task"]
+         with st.sidebar.expander("View test set queries"):
+             st.table(tasks_deduped.explode("task"))
+
+         additional_shown_cols = st.sidebar.multiselect(
+             label="additional cols", options=config.text_cols, default=doc_col
+         )
+
+         repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
+         query = st.text_input("input query", value="metric learning")
+         RetrievalApp.show_retrieval_results(
+             retrieval_pipeline,
+             query,
+             retrieved_results,
+             tasks_deduped["task"].to_list(),
+             description_length,
+             repos_by_query,
+             additional_shown_cols,
+         )
+
+     def main(self):
+         print("setting up retrieval_pipe")
+         doc_col = "dependencies"
+         retrieval_pipeline = setup_retrieval_pipeline(
+             self.query_encoder_name,
+             self.document_encoder_name,
+             self.retrieval_df[doc_col],
+             self.retrieval_df,
+         )
+         RetrievalApp.app(retrieval_pipeline, self.retrieval_df, doc_col)
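
Note: the gold-standard branch in show_retrieval_results relies on RetrievalApp.app pre-grouping the exploded "tasks" column, so repos_by_query.get_group(query) returns every repository annotated with the queried task. A minimal sketch of that lookup, using made-up DataFrame values:

import pandas as pd

# Hypothetical stand-in for retrieval_df: "tasks" holds a list of task names per repo
df = pd.DataFrame(
    {
        "repo": ["a/x", "b/y", "c/z"],
        "tasks": [["metric learning"], ["metric learning", "nlp"], ["nlp"]],
    }
)
repos_by_query = df.explode("tasks").groupby("tasks")
print(repos_by_query.get_group("metric learning")["repo"].tolist())  # ['a/x', 'b/y']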
config.py CHANGED
@@ -1,4 +1,11 @@
- query_encoder_model_name = "lambdaofgod/query_nbow_embedder"
- document_encoder_model_name = "lambdaofgod/document_nbow_embedder"
+ model_names = [
+     'dependencies-nbow-nbow-mnrl',
+     'readme-nbow-nbow-mnrl',
+     'titles-nbow-nbow-mnrl',
+     'titles#dependencies-nbow-nbow-mnrl',
+     'readme#dependencies-nbow-nbow-mnrl'
+ ]
  best_tasks_path="assets/best_tasks.csv"
  worst_tasks_path="assets/worst_tasks.csv"
+ text_cols = ["dependencies", "readme", "titles"]
+ text_list_cols = ["titles"]
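
Note: these model_names entries are suffixes; RetrievalApp.__init__ prepends "lambdaofgod/query-" and "lambdaofgod/document-" to build the two encoder repository IDs shown in the sidebar. A sketch of that composition (the resulting IDs follow from the string concatenation in the diff; whether every pair exists on the Hugging Face Hub cannot be confirmed from the diff alone):

import config

# Reproduces the encoder-ID composition from RetrievalApp.__init__
model_name = config.model_names[0]  # "dependencies-nbow-nbow-mnrl"
query_encoder_name = "lambdaofgod/query-" + model_name
document_encoder_name = "lambdaofgod/document-" + model_name
print(query_encoder_name)     # lambdaofgod/query-dependencies-nbow-nbow-mnrl
print(document_encoder_name)  # lambdaofgod/document-dependencies-nbow-nbow-mnrl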
pages/1_Retrieval_App.py CHANGED
@@ -1,151 +1,5 @@
- import os
- from typing import Dict, List
-
- import datasets
- import pandas as pd
- import sentence_transformers
- import streamlit as st
- from findkit import feature_extractors, indexes, retrieval_pipeline
- from toolz import partial
- import config
-
-
- def truncate_description(description, length=50):
-     return " ".join(description.split()[:length])
-
-
- def get_repos_with_descriptions(repos_df, repos):
-     return repos_df.loc[repos]
-
-
- def search_f(
-     retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
-     query: str,
-     k: int,
-     description_length: int,
-     doc_col: List[str],
- ):
-     results = retrieval_pipe.find_similar(query, k)
-     # results['repo'] = results.index
-     results["link"] = "https://github.com/" + results["repo"]
-     for col in doc_col:
-         results[col] = results[col].apply(
-             lambda desc: truncate_description(desc, description_length)
-         )
-     shown_cols = ["repo", "tasks", "link", "distance"]
-     shown_cols = shown_cols + doc_col
-     return results.reset_index(drop=True)[shown_cols]
-
-
- def show_retrieval_results(
-     retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
-     query: str,
-     k: int,
-     all_queries: List[str],
-     description_length: int,
-     repos_by_query: Dict[str, pd.DataFrame],
-     doc_col: str,
- ):
-     print("started retrieval")
-     if query in all_queries:
-         with st.expander(
-             "query is in gold standard set queries. Toggle viewing gold standard results?"
-         ):
-             st.write("gold standard results")
-             task_repos = repos_by_query.get_group(query)
-             st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
-     with st.spinner(text="fetching results"):
-         st.write(
-             search_f(retrieval_pipe, query, k, description_length, doc_col).to_html(
-                 escape=False, index=False
-             ),
-             unsafe_allow_html=True,
-         )
-     print("finished retrieval")
-
-
- def setup_pipeline(
-     extractor: feature_extractors.SentenceEncoderFeatureExtractor,
-     documents_df: pd.DataFrame,
-     text_col: str,
- ):
-     retrieval_pipeline.RetrievalPipelineFactory.build(
-         documents_df[text_col], metadata=documents_df
-     )
-
-
- @st.cache(allow_output_mutation=True)
- def setup_retrieval_pipeline(
-     query_encoder_path, document_encoder_path, documents, metadata
- ):
-     document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
-         sentence_transformers.SentenceTransformer(document_encoder_path, device="cpu")
-     )
-     query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
-         sentence_transformers.SentenceTransformer(query_encoder_path, device="cpu")
-     )
-     retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
-         feature_extractor=document_encoder,
-         query_feature_extractor=query_encoder,
-         index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
-     )
-     return retrieval_pipe.build(documents, metadata=metadata)
-
-
- def app(retrieval_pipeline, retrieval_df, doc_col):
-
-     retrieved_results = st.sidebar.number_input("number of results", value=10)
-     description_length = st.sidebar.number_input(
-         "number of used description words", value=10
-     )
-
-     tasks_deduped = (
-         retrieval_df["tasks"].explode().value_counts().reset_index()
-     )  # drop_duplicates().sort_values().reset_index(drop=True)
-     tasks_deduped.columns = ["task", "documents per task"]
-     with st.sidebar.expander("View test set queries"):
-         st.table(tasks_deduped.explode("task"))
-
-     additional_shown_cols = st.sidebar.multiselect(
-         label="additional cols", options=[doc_col], default=doc_col
-     )
-
-     repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
-     query = st.text_input("input query", value="metric learning")
-     show_retrieval_results(
-         retrieval_pipeline,
-         query,
-         retrieved_results,
-         tasks_deduped["task"].to_list(),
-         description_length,
-         repos_by_query,
-         additional_shown_cols,
-     )
-
-
- def app_main(
-     query_encoder_path,
-     document_encoder_path,
-     data_path,
- ):
-     print("loading data")
-
-     retrieval_df = (
-         datasets.load_dataset(data_path)["train"]
-         .to_pandas()
-         .drop_duplicates(subset=["repo"])
-         .reset_index(drop=True)
-     )
-     print("setting up retrieval_pipe")
-     doc_col = "dependencies"
-     retrieval_pipeline = setup_retrieval_pipeline(
-         query_encoder_path, document_encoder_path, retrieval_df[doc_col], retrieval_df
-     )
-     app(retrieval_pipeline, retrieval_df, doc_col)
-
-
- app_main(
-     query_encoder_path=config.query_encoder_model_name,
-     document_encoder_path=config.document_encoder_model_name,
-     data_path="lambdaofgod/pwc_repositories_with_dependencies",
- )
+ from app_implementation import RetrievalApp
+
+
+ app = RetrievalApp()
+ app.main()
search_utils.py ADDED
@@ -0,0 +1,75 @@
+ import os
+ from typing import Dict, List
+
+
+ import ast
+ import pandas as pd
+ import sentence_transformers
+ import streamlit as st
+ from findkit import feature_extractors, indexes, retrieval_pipeline
+ from toolz import partial
+ import config
+
+
+ def truncate_description(description, length=50):
+     return " ".join(description.split()[:length])
+
+
+ def get_repos_with_descriptions(repos_df, repos):
+     return repos_df.loc[repos]
+
+
+ def search_f(
+     retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
+     query: str,
+     k: int,
+     description_length: int,
+     doc_col: List[str],
+ ):
+     results = retrieval_pipe.find_similar(query, k)
+     # results['repo'] = results.index
+     results["link"] = "https://github.com/" + results["repo"]
+     for col in doc_col:
+         results[col] = results[col].apply(
+             lambda desc: truncate_description(desc, description_length)
+         )
+     shown_cols = ["repo", "tasks", "link", "distance"]
+     shown_cols = shown_cols + doc_col
+     return results.reset_index(drop=True)[shown_cols]
+
+
+ def merge_text_list_cols(retrieval_df, text_list_cols):
+     retrieval_df = retrieval_df.copy()
+     for col in text_list_cols:
+         retrieval_df[col] = retrieval_df[col].apply(
+             lambda t: " ".join(ast.literal_eval(t))
+         )
+     return retrieval_df
+
+
+ def setup_pipeline(
+     extractor: feature_extractors.SentenceEncoderFeatureExtractor,
+     documents_df: pd.DataFrame,
+     text_col: str,
+ ):
+     retrieval_pipeline.RetrievalPipelineFactory.build(
+         documents_df[text_col], metadata=documents_df
+     )
+
+
+ @st.cache(allow_output_mutation=True)
+ def setup_retrieval_pipeline(
+     query_encoder_path, document_encoder_path, documents, metadata
+ ):
+     document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+         sentence_transformers.SentenceTransformer(document_encoder_path, device="cpu")
+     )
+     query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+         sentence_transformers.SentenceTransformer(query_encoder_path, device="cpu")
+     )
+     retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
+         feature_extractor=document_encoder,
+         query_feature_extractor=query_encoder,
+         index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
+     )
+     return retrieval_pipe.build(documents, metadata=metadata)
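
Note: merge_text_list_cols assumes the configured columns hold lists serialized as strings (hence ast.literal_eval rather than a plain split) and flattens each into one space-joined string that a sentence encoder can consume. A small round-trip with made-up data:

import ast
import pandas as pd

# Made-up row mimicking the "titles" column: a Python list serialized as a string
df = pd.DataFrame({"titles": ["['Deep Metric Learning', 'Triplet Loss']"]})
merged = df["titles"].apply(lambda t: " ".join(ast.literal_eval(t)))
print(merged[0])  # Deep Metric Learning Triplet Loss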