# Spaces status: Runtime error
import os | |
from typing import Dict, List | |
import pandas as pd | |
import datasets | |
import streamlit as st | |
import config | |
from findkit import retrieval_pipeline | |
from search_utils import ( | |
get_repos_with_descriptions, | |
search_f, | |
merge_text_list_cols, | |
setup_retrieval_pipeline, | |
) | |
class RetrievalApp:
    """Streamlit front end for querying a repository retrieval pipeline.

    Constructing the app loads the repository dataset and renders the model
    selector in the sidebar. ``main`` then builds the retrieval pipeline and
    renders the search page.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Load the dataset and pick the encoder names from the sidebar.

        Args:
            data_path: Identifier passed to ``datasets.load_dataset``; the
                ``"train"`` split is converted to a DataFrame.
        """
        print("loading data")
        raw_retrieval_df = (
            datasets.load_dataset(data_path)["train"]
            .to_pandas()
            # Keep a single row per value of the "repo" column.
            .drop_duplicates(subset=["repo"])
            .reset_index(drop=True)
        )
        self.retrieval_df = merge_text_list_cols(
            raw_retrieval_df, config.text_list_cols
        )

        model_name = st.sidebar.selectbox("model", config.model_names)
        # Both encoder names are derived from the single selected model name.
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: "retrieval_pipeline.RetrievalPipeline",
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: "pd.core.groupby.DataFrameGroupBy",
        doc_col: str,
    ):
        """Render search results for ``query``, plus gold results when known.

        Declared as a static method because it takes no instance and is
        invoked through the class.

        Args:
            retrieval_pipe: Pipeline handed to ``search_f``; its ``X_df``
                attribute is used to describe gold-standard repositories.
            query: Text of the search query.
            k: Forwarded to ``search_f`` (the caller passes the "number of
                results" sidebar value).
            all_queries: Queries for which gold-standard results exist.
            description_length: Forwarded to ``search_f`` (the caller passes
                the "number of used description words" sidebar value).
            repos_by_query: Group-by object keyed by query; ``get_group`` is
                called on it, so a plain dict is not sufficient.
            doc_col: Forwarded to ``search_f``.
                NOTE(review): ``app`` passes the list returned by
                ``st.sidebar.multiselect`` here, not a single string --
                confirm what ``search_f`` expects.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            # NOTE(review): the table is emitted as unescaped HTML; presumably
            # search_f produces markup that must render as-is -- confirm the
            # underlying dataset text is trusted.
            results_html = search_f(
                retrieval_pipe, query, k, description_length, doc_col
            ).to_html(escape=False, index=False)
            st.write(results_html, unsafe_allow_html=True)
        print("finished retrieval")

    @staticmethod
    def app(retrieval_pipeline, retrieval_df, doc_col):
        """Render the sidebar controls and the query box, then show results.

        Declared as a static method because it takes no instance and is
        invoked through the class.

        Args:
            retrieval_pipeline: Pipeline forwarded to
                ``show_retrieval_results``. Inside this function the name
                refers to this argument, not to the imported module.
            retrieval_df: DataFrame with a ``"tasks"`` column; it is exploded
                on that column, so list-like cells are presumably expected.
            doc_col: Column preselected in the "additional cols" multiselect.
        """
        retrieved_results = st.sidebar.number_input("number of results", value=10)
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )

        # One row per distinct task together with how often it occurs.
        tasks_deduped = (
            retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        # Renamed positionally: the column names that
        # value_counts().reset_index() yields differ between pandas versions.
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))

        additional_shown_cols = st.sidebar.multiselect(
            label="additional cols", options=config.text_cols, default=doc_col
        )
        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            additional_shown_cols,
        )

    def main(self):
        """Build the retrieval pipeline over ``"dependencies"`` and run the app."""
        print("setting up retrieval_pipe")
        doc_col = "dependencies"
        # Local is named `pipeline` so it does not shadow the imported
        # `retrieval_pipeline` module.
        pipeline = setup_retrieval_pipeline(
            self.query_encoder_name,
            self.document_encoder_name,
            self.retrieval_df[doc_col],
            self.retrieval_df,
        )
        RetrievalApp.app(pipeline, self.retrieval_df, doc_col)