import ast
from dataclasses import dataclass
from functools import partial
from typing import List

import datasets
import sentence_transformers
from findkit import feature_extractors, indexes, retrieval_pipeline


def get_doc_cols(model_name):
    """Extract document column names encoded in a model name, for example
    "query-titles_tasks-mpnet" -> ["titles", "tasks"]."""
    model_name = model_name.replace("query-", "")
    model_name = model_name.replace("document-", "")
    return model_name.split("-")[0].split("_")


def merge_cols(df, cols):
    """Concatenate the text columns `cols` into a single "document" column."""
    df["document"] = df[cols[0]]
    for col in cols[1:]:  # start at cols[1]; cols[0] already seeds the column
        df["document"] = df["document"] + " " + df[col]
    return df


def get_retrieval_df(
    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
):
    """Load the repositories dataset, deduplicate by repo name, and optionally
    flatten the stringified list columns given in `text_list_cols`."""
    raw_retrieval_df = (
        datasets.load_dataset(data_path)["train"]
        .to_pandas()
        .drop_duplicates(subset=["repo"])
        .reset_index(drop=True)
    )
    if text_list_cols:
        return merge_text_list_cols(raw_retrieval_df, text_list_cols)
    return raw_retrieval_df


def truncate_description(description, length=50):
    """Truncate a description to at most `length` words."""
    return " ".join(description.split()[:length])


def get_repos_with_descriptions(repos_df, repos):
    return repos_df.loc[repos]


def merge_text_list_cols(retrieval_df, text_list_cols):
    """Parse columns stored as stringified Python lists and join their
    elements into plain space-separated text."""
    retrieval_df = retrieval_df.copy()
    for col in text_list_cols:
        retrieval_df[col] = retrieval_df[col].apply(
            lambda t: " ".join(ast.literal_eval(t))
        )
    return retrieval_df


@dataclass
class RetrievalPipelineWrapper:
    pipeline: retrieval_pipeline.RetrievalPipeline

    @classmethod
    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
        """Build a findkit retrieval pipeline backed by a cosine-similarity
        NMSLIB index."""
        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
            feature_extractor=document_encoder,
            query_feature_extractor=query_encoder,
            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
        )
        pipeline = retrieval_pipe.build(documents, metadata=metadata)
        return cls(pipeline)

    def search(
        self,
        query: str,
        k: int,
        description_length: int,
        additional_shown_cols: List[str],
    ):
        """Return the k most similar repositories, with GitHub links added and
        the additional text columns truncated to `description_length` words."""
        results = self.pipeline.find_similar(query, k)
        results["link"] = "https://github.com/" + results["repo"]
        for col in additional_shown_cols:
            results[col] = results[col].apply(
                lambda desc: truncate_description(desc, description_length)
            )
        shown_cols = ["repo", "tasks", "link", "distance"] + additional_shown_cols
        return results.reset_index(drop=True)[shown_cols]

    @classmethod
    def setup_from_encoder_names(
        cls, query_encoder_path, document_encoder_path, documents, metadata, device
    ):
        """Load query and document sentence-transformers checkpoints and build
        the retrieval pipeline from them."""
        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(
                document_encoder_path, device=device
            )
        )
        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
        )
        return cls.build_from_encoders(
            query_encoder, document_encoder, documents, metadata
        )
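

# Usage sketch (added for illustration; not part of the original module).
# The encoder checkpoint, merged columns, device, and query below are
# assumptions; substitute whatever the app actually configures.
#
# retrieval_df = get_retrieval_df(text_list_cols=["tasks"])
# retrieval_df = merge_cols(retrieval_df, ["tasks"])  # merged columns: assumption
# pipeline = RetrievalPipelineWrapper.setup_from_encoder_names(
#     query_encoder_path="sentence-transformers/all-mpnet-base-v2",  # assumption
#     document_encoder_path="sentence-transformers/all-mpnet-base-v2",  # assumption
#     documents=retrieval_df["document"].to_list(),
#     metadata=retrieval_df,
#     device="cpu",
# )
# results = pipeline.search(
#     "image segmentation", k=10, description_length=50, additional_shown_cols=[]
# )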