Commit a284f57 by lambdaofgod (parent: 568499b)

additional cols and optional device

Files changed (3)
  1. app_implementation.py +48 -36
  2. config.py +7 -7
  3. search_utils.py +74 -44
app_implementation.py CHANGED
@@ -1,50 +1,61 @@
-import os
 from typing import Dict, List

+import torch
 import pandas as pd
-import datasets
 import streamlit as st
-import config
 from findkit import retrieval_pipeline
+
+import config
 from search_utils import (
+    RetrievalPipelineWrapper,
+    get_doc_cols,
     get_repos_with_descriptions,
-    search_f,
-    merge_text_list_cols,
-    setup_retrieval_pipeline,
+    get_retrieval_df,
+    merge_cols,
 )


 class RetrievalApp:
+    def get_device_options(self):
+        if torch.cuda.is_available():
+            return ["cuda", "cpu"]
+        else:
+            return ["cpu"]
+
+    @st.cache(allow_output_mutation=True)
+    def get_retrieval_df(self):
+        return get_retrieval_df(self.data_path, config.text_list_cols)
+
     def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
+        self.data_path = data_path
+        self.device = st.sidebar.selectbox("device", self.get_device_options())
         print("loading data")

-        raw_retrieval_df = (
-            datasets.load_dataset(data_path)["train"]
-            .to_pandas()
-            .drop_duplicates(subset=["repo"])
-            .reset_index(drop=True)
-        )
-        self.retrieval_df = merge_text_list_cols(
-            raw_retrieval_df, config.text_list_cols
-        )
+        self.retrieval_df = self.get_retrieval_df().copy()

         model_name = st.sidebar.selectbox("model", config.model_names)
         self.query_encoder_name = "lambdaofgod/query-" + model_name
         self.document_encoder_name = "lambdaofgod/document-" + model_name

+        doc_cols = get_doc_cols(model_name)
+
         st.sidebar.text("using models")
         st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
-        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
+        st.sidebar.text("HTTP://huggingface.co/" + self.document_encoder_name)
+
+        self.additional_shown_cols = st.sidebar.multiselect(
+            label="used text features", options=config.text_cols, default=doc_cols
+        )

     @staticmethod
     def show_retrieval_results(
-        retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
+        retrieval_pipe: RetrievalPipelineWrapper,
         query: str,
         k: int,
         all_queries: List[str],
         description_length: int,
         repos_by_query: Dict[str, pd.DataFrame],
-        doc_col: str,
+        additional_shown_cols: List[str],
     ):
         print("started retrieval")
         if query in all_queries:
@@ -56,15 +67,14 @@ class RetrievalApp:
             st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
         with st.spinner(text="fetching results"):
             st.write(
-                search_f(retrieval_pipe, query, k, description_length, doc_col).to_html(
+                retrieval_pipe.search(query, k, description_length, additional_shown_cols).to_html(
                     escape=False, index=False
                 ),
                 unsafe_allow_html=True,
             )
         print("finished retrieval")

-    @staticmethod
-    def app(retrieval_pipeline, retrieval_df, doc_col):
+    def run_app(self, retrieval_pipeline):

         retrieved_results = st.sidebar.number_input("number of results", value=10)
         description_length = st.sidebar.number_input(
@@ -72,17 +82,12 @@ class RetrievalApp:
         )

         tasks_deduped = (
-            retrieval_df["tasks"].explode().value_counts().reset_index()
+            self.retrieval_df["tasks"].explode().value_counts().reset_index()
         )  # drop_duplicates().sort_values().reset_index(drop=True)
         tasks_deduped.columns = ["task", "documents per task"]
         with st.sidebar.expander("View test set queries"):
             st.table(tasks_deduped.explode("task"))
-
-        additional_shown_cols = st.sidebar.multiselect(
-            label="additional cols", options=config.text_cols, default=doc_col
-        )
-
-        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
+        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
         query = st.text_input("input query", value="metric learning")
         RetrievalApp.show_retrieval_results(
             retrieval_pipeline,
@@ -91,16 +96,23 @@ class RetrievalApp:
             tasks_deduped["task"].to_list(),
             description_length,
             repos_by_query,
-            additional_shown_cols,
+            self.additional_shown_cols,
         )

-    def main(self):
-        print("setting up retrieval_pipe")
-        doc_col = "dependencies"
-        retrieval_pipeline = setup_retrieval_pipeline(
+    @st.cache(allow_output_mutation=True)
+    def get_retrieval_pipeline(self, displayed_retrieval_df):
+        return RetrievalPipelineWrapper.setup_from_encoder_names(
             self.query_encoder_name,
             self.document_encoder_name,
-            self.retrieval_df[doc_col],
-            self.retrieval_df,
+            displayed_retrieval_df["document"],
+            displayed_retrieval_df,
+            device=self.device,
+        )
+
+    def main(self):
+        print("setting up retrieval_pipe")
+        displayed_retrieval_df = merge_cols(
+            self.retrieval_df.copy(), self.additional_shown_cols
         )
-        RetrievalApp.app(retrieval_pipeline, self.retrieval_df, doc_col)
+        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
+        self.run_app(retrieval_pipeline)
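
A minimal usage sketch of the refactored class (not part of the commit; it assumes the Streamlit entry-point script simply instantiates RetrievalApp and calls main(), which is all the class above requires):

# Hypothetical entry point, e.g. run with `streamlit run app.py`.
# Device and text-feature choices are picked in the sidebar at runtime.
from app_implementation import RetrievalApp

RetrievalApp(data_path="lambdaofgod/pwc_repositories_with_dependencies").main()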
config.py CHANGED
@@ -1,11 +1,11 @@
 model_names = [
-    'dependencies-nbow-nbow-mnrl',
-    'readme-nbow-nbow-mnrl',
-    'titles-nbow-nbow-mnrl',
-    'titles#dependencies-nbow-nbow-mnrl',
-    'readme#dependencies-nbow-nbow-mnrl'
+    "dependencies-nbow-nbow-mnrl",
+    "readme-nbow-nbow-mnrl",
+    "titles-nbow-nbow-mnrl",
+    "titles_dependencies-nbow-nbow-mnrl",
+    "readme_dependencies-nbow-nbow-mnrl",
 ]
-best_tasks_path="assets/best_tasks.csv"
-worst_tasks_path="assets/worst_tasks.csv"
+best_tasks_path = "assets/best_tasks.csv"
+worst_tasks_path = "assets/worst_tasks.csv"
 text_cols = ["dependencies", "readme", "titles"]
 text_list_cols = ["titles"]
search_utils.py CHANGED
@@ -1,7 +1,8 @@
 import os
 from typing import Dict, List
+from dataclasses import dataclass

-
+import datasets
 import ast
 import pandas as pd
 import sentence_transformers
@@ -11,6 +12,33 @@ from toolz import partial
 import config


+def get_doc_cols(model_name):
+    model_name = model_name.replace("query-", "")
+    model_name = model_name.replace("document-", "")
+    return model_name.split("-")[0].split("_")
+
+
+def merge_cols(df, cols):
+    df["document"] = df[cols[0]]
+    for col in cols:
+        df["document"] = df["document"] + " " + df[col]
+    return df
+
+
+def get_retrieval_df(
+    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
+):
+    raw_retrieval_df = (
+        datasets.load_dataset(data_path)["train"]
+        .to_pandas()
+        .drop_duplicates(subset=["repo"])
+        .reset_index(drop=True)
+    )
+    if text_list_cols:
+        return merge_text_list_cols(raw_retrieval_df, text_list_cols)
+    return raw_retrieval_df
+
+
 def truncate_description(description, length=50):
     return " ".join(description.split()[:length])

@@ -19,25 +47,6 @@ def get_repos_with_descriptions(repos_df, repos):
     return repos_df.loc[repos]


-def search_f(
-    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
-    query: str,
-    k: int,
-    description_length: int,
-    doc_col: List[str],
-):
-    results = retrieval_pipe.find_similar(query, k)
-    # results['repo'] = results.index
-    results["link"] = "https://github.com/" + results["repo"]
-    for col in doc_col:
-        results[col] = results[col].apply(
-            lambda desc: truncate_description(desc, description_length)
-        )
-    shown_cols = ["repo", "tasks", "link", "distance"]
-    shown_cols = shown_cols + doc_col
-    return results.reset_index(drop=True)[shown_cols]
-
-
 def merge_text_list_cols(retrieval_df, text_list_cols):
     retrieval_df = retrieval_df.copy()
     for col in text_list_cols:
@@ -47,29 +56,50 @@ def merge_text_list_cols(retrieval_df, text_list_cols):
     return retrieval_df


-def setup_pipeline(
-    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
-    documents_df: pd.DataFrame,
-    text_col: str,
-):
-    retrieval_pipeline.RetrievalPipelineFactory.build(
-        documents_df[text_col], metadata=documents_df
-    )
+@dataclass
+class RetrievalPipelineWrapper:

+    pipeline: retrieval_pipeline.RetrievalPipeline

-@st.cache(allow_output_mutation=True)
-def setup_retrieval_pipeline(
-    query_encoder_path, document_encoder_path, documents, metadata
-):
-    document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
-        sentence_transformers.SentenceTransformer(document_encoder_path, device="cpu")
-    )
-    query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
-        sentence_transformers.SentenceTransformer(query_encoder_path, device="cpu")
-    )
-    retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
-        feature_extractor=document_encoder,
-        query_feature_extractor=query_encoder,
-        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
-    )
-    return retrieval_pipe.build(documents, metadata=metadata)
+    @classmethod
+    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
+        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
+            feature_extractor=document_encoder,
+            query_feature_extractor=query_encoder,
+            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
+        )
+        pipeline = retrieval_pipe.build(documents, metadata=metadata)
+        return RetrievalPipelineWrapper(pipeline)
+
+    def search(
+        self,
+        query: str,
+        k: int,
+        description_length: int,
+        additional_shown_cols: List[str],
+    ):
+        results = self.pipeline.find_similar(query, k)
+        # results['repo'] = results.index
+        results["link"] = "https://github.com/" + results["repo"]
+        for col in additional_shown_cols:
+            results[col] = results[col].apply(
+                lambda desc: truncate_description(desc, description_length)
+            )
+        shown_cols = ["repo", "tasks", "link", "distance"]
+        shown_cols = shown_cols + additional_shown_cols
+        return results.reset_index(drop=True)[shown_cols]
+
+    @classmethod
+    def setup_from_encoder_names(cls, query_encoder_path, document_encoder_path, documents, metadata, device
+    ):
+        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+            sentence_transformers.SentenceTransformer(
+                document_encoder_path, device=device
+            )
+        )
+        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
+        )
+        return cls.build_from_encoders(
+            query_encoder, document_encoder, documents, metadata
+        )
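
For reference, a small sketch (not from the commit) of how the new search_utils helpers behave, assuming the naming convention used in config.model_names, where the text columns precede the first "-" and are joined by "_":

import pandas as pd
from search_utils import get_doc_cols, merge_cols

# "titles_dependencies-nbow-nbow-mnrl" -> ["titles", "dependencies"]
print(get_doc_cols("titles_dependencies-nbow-nbow-mnrl"))

# merge_cols builds the single "document" column that the retrieval pipeline
# indexes, concatenating the selected text columns.
df = pd.DataFrame({"titles": ["metric learning"], "dependencies": ["torch"]})
print(merge_cols(df, ["titles", "dependencies"])["document"].iloc[0])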