from typing import Dict, List

import pandas as pd
import streamlit as st
import torch

import config
from findkit import retrieval_pipeline
from search_utils import (
    RetrievalPipelineWrapper,
    get_doc_cols,
    get_repos_with_descriptions,
    get_retrieval_df,
    merge_cols,
)


class RetrievalApp:
    """Streamlit app for neural retrieval over a repository dataset.

    Sidebar widgets pick the device, the encoder model pair, and the text
    columns merged into each document; the main pane takes a free-text
    query and shows ranked repositories (plus gold-standard results when
    the query is one of the known task labels).
    """

    def is_cuda_available(self):
        """Return True when a usable CUDA device is present.

        Uses the public ``torch.cuda.is_available()`` instead of probing the
        private ``torch._C._cuda_init()`` inside a bare ``except:``, which
        silently swallowed every exception type.
        """
        return torch.cuda.is_available()

    def get_device_options(self):
        """Device choices for the sidebar selectbox; CUDA first if present."""
        return ["cuda", "cpu"] if self.is_cuda_available() else ["cpu"]

    @st.cache(allow_output_mutation=True)
    def get_retrieval_df(self):
        """Load (and memoize via Streamlit's cache) the retrieval dataframe.

        ``allow_output_mutation=True`` because callers take a ``.copy()``
        and Streamlit would otherwise hash the large dataframe on every run.
        """
        return get_retrieval_df(self.data_path, config.text_list_cols)

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Build all sidebar controls and load the data.

        :param data_path: dataset identifier passed through to
            ``search_utils.get_retrieval_df`` (presumably a HF dataset id —
            confirm against that helper).
        """
        self.data_path = data_path
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")
        self.retrieval_df = self.get_retrieval_df().copy()
        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        doc_cols = get_doc_cols(model_name)
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        # was "HTTP://..." — normalized so both displayed links use the
        # same (working) scheme
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: RetrievalPipelineWrapper,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: Dict[str, pd.DataFrame],
        additional_shown_cols: List[str],
    ):
        """Render search results for ``query``.

        If the query is one of the gold-standard task labels, also show the
        gold-standard repositories for that task inside an expander.

        :param retrieval_pipe: configured pipeline wrapper used for search
        :param query: free-text user query
        :param k: number of results to retrieve
        :param all_queries: gold-standard task labels
        :param description_length: words of description to display per hit
        :param repos_by_query: groupby object keyed by task label
        :param additional_shown_cols: extra text columns to show per result
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            # escape=False keeps rendered HTML links in the result table
            st.write(
                retrieval_pipe.search(
                    query, k, description_length, additional_shown_cols
                ).to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Draw the main query UI and dispatch retrieval.

        :param retrieval_pipeline: a ``RetrievalPipelineWrapper`` (shadows
            the module-level ``findkit.retrieval_pipeline`` import inside
            this method)
        """
        retrieved_results = st.sidebar.number_input("number of results", value=10)
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )
        # value_counts doubles as dedup + per-task document counts
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    @st.cache(allow_output_mutation=True)
    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build (and memoize) the encoder-backed retrieval pipeline.

        :param displayed_retrieval_df: dataframe whose ``document`` column
            holds the merged text features to index
        """
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Entry point: assemble the document column, build the pipeline, run."""
        print("setting up retrieval_pipe")
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)