# paperswithcode_nbow/search_utils.py
import ast
from dataclasses import dataclass
from functools import partial
from typing import List

import datasets
import pandas as pd
import sentence_transformers
from findkit import feature_extractors, indexes, retrieval_pipeline


def get_doc_cols(model_name: str) -> List[str]:
    # Model names encode the text columns they index, e.g.
    # "query-title_abstract-..." refers to the "title" and "abstract" columns.
    model_name = model_name.replace("query-", "")
    model_name = model_name.replace("document-", "")
    return model_name.split("-")[0].split("_")
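
# Doctest-style sanity check of the parsing above (the exact model naming
# scheme is an assumption inferred from the string handling in get_doc_cols):
# >>> get_doc_cols("query-title_abstract-nbow")
# ['title', 'abstract']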


def merge_cols(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    # Concatenate the given text columns into a single "document" column.
    # Note: iterating over cols[1:] avoids repeating cols[0], which is
    # already used to seed the "document" column.
    df["document"] = df[cols[0]]
    for col in cols[1:]:
        df["document"] = df["document"] + " " + df[col]
    return df
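
# For example, with hypothetical data:
# >>> df = pd.DataFrame({"title": ["t"], "abstract": ["a"]})
# >>> merge_cols(df, ["title", "abstract"])["document"].iloc[0]
# 't a'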


def get_retrieval_df(
    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
):
    # Load the corpus from the Hugging Face Hub, keeping one row per repository.
    raw_retrieval_df = (
        datasets.load_dataset(data_path)["train"]
        .to_pandas()
        .drop_duplicates(subset=["repo"])
        .reset_index(drop=True)
    )
    if text_list_cols:
        return merge_text_list_cols(raw_retrieval_df, text_list_cols)
    return raw_retrieval_df
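
# Typical use (assumes the dataset stores list-valued columns such as "tasks"
# as Python-literal strings, which merge_text_list_cols below flattens):
# >>> retrieval_df = get_retrieval_df(text_list_cols=["tasks"])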


def truncate_description(description: str, length: int = 50) -> str:
    # Keep only the first `length` whitespace-separated words.
    return " ".join(description.split()[:length])


def get_repos_with_descriptions(repos_df, repos):
    # Select rows for the given repos (assumes `repos_df` is indexed by repo name).
    return repos_df.loc[repos]


def merge_text_list_cols(
    retrieval_df: pd.DataFrame, text_list_cols: List[str]
) -> pd.DataFrame:
    # These columns hold Python list literals stored as strings
    # (e.g. "['tag1', 'tag2']"); parse them and join into plain text.
    retrieval_df = retrieval_df.copy()
    for col in text_list_cols:
        retrieval_df[col] = retrieval_df[col].apply(
            lambda t: " ".join(ast.literal_eval(t))
        )
    return retrieval_df
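
# For instance, "['qa', 'nli']" becomes "qa nli":
# >>> df = pd.DataFrame({"tasks": ["['qa', 'nli']"]})
# >>> merge_text_list_cols(df, ["tasks"])["tasks"].iloc[0]
# 'qa nli'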


@dataclass
class RetrievalPipelineWrapper:
    pipeline: retrieval_pipeline.RetrievalPipeline

    @classmethod
    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
        # Asymmetric setup: queries and documents can use different encoders;
        # the index uses cosine similarity via NMSLIB.
        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
            feature_extractor=document_encoder,
            query_feature_extractor=query_encoder,
            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
        )
        pipeline = retrieval_pipe.build(documents, metadata=metadata)
        return cls(pipeline)

    def search(
        self,
        query: str,
        k: int,
        description_length: int,
        additional_shown_cols: List[str],
    ):
        # Attach GitHub links and truncate long text columns for display.
        results = self.pipeline.find_similar(query, k)
        results["link"] = "https://github.com/" + results["repo"]
        for col in additional_shown_cols:
            results[col] = results[col].apply(
                lambda desc: truncate_description(desc, description_length)
            )
        shown_cols = ["repo", "tasks", "link", "distance"] + additional_shown_cols
        return results.reset_index(drop=True)[shown_cols]
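
    # Example call on a built wrapper (the "readme" column is hypothetical,
    # shown only to illustrate additional_shown_cols):
    # >>> wrapper.search("image segmentation", k=10, description_length=50,
    # ...                additional_shown_cols=["readme"])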

    @classmethod
    def setup_from_encoder_names(
        cls, query_encoder_path, document_encoder_path, documents, metadata, device
    ):
        # Load separate SentenceTransformer models for queries and documents
        # on the requested device (e.g. "cpu" or "cuda").
        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(
                document_encoder_path, device=device
            )
        )
        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
        )
        return cls.build_from_encoders(
            query_encoder, document_encoder, documents, metadata
        )
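

# Minimal end-to-end sketch of how these pieces fit together. Hedged: the
# encoder names below are placeholders, and the "tasks" document column is
# assumed to exist (it is referenced in `search` above).
if __name__ == "__main__":
    retrieval_df = get_retrieval_df(text_list_cols=["tasks"])
    doc_cols = get_doc_cols("document-tasks-nbow")  # -> ["tasks"]
    retrieval_df = merge_cols(retrieval_df, doc_cols)
    wrapper = RetrievalPipelineWrapper.setup_from_encoder_names(
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder query encoder
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder document encoder
        retrieval_df["document"].tolist(),
        retrieval_df,
        device="cpu",
    )
    print(wrapper.search("semantic segmentation", 5, 30, additional_shown_cols=[]))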