File size: 3,950 Bytes
bf8e6b0
 
 
 
 
 
 
0a8b37d
bf8e6b0
 
a09b56d
 
bf8e6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09b56d
 
bf8e6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09b56d
 
bf8e6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09b56d
bf8e6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68ecf38
 
 
 
 
 
 
bf8e6b0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
import os
import pathlib
import pandas as pd
from collections import defaultdict
import json
import copy
import plotly.express as px



@st.cache_data
def load_local_corpus(corpus_file, columns_to_combine=("title", "text")):
    """Read a JSON-lines corpus into a ``{doc_id: {"text", "title"}}`` mapping.

    Args:
        corpus_file: File-like object that is used as a context manager and
            iterated line by line. Lines may be ``str`` or UTF-8 ``bytes``.
            ``None`` means no corpus was supplied.
        columns_to_combine: Record keys whose values are joined with a single
            space, in this order, to build each document's ``"text"``. Keys
            missing from a record are skipped.

    Returns:
        ``None`` when ``corpus_file`` is ``None``. Otherwise a dict mapping
        each record id to ``{"text": <combined text>, "title": <title or "">}``.
        The id is read from ``"_id"``; once a record lacks ``"_id"`` the
        loader switches to ``"doc_id"`` for that and all following records.

    Raises:
        json.JSONDecodeError: If a line (other than a leading header line) is
            not valid JSON.
        KeyError: If a record contains neither id key.
    """
    if corpus_file is None:
        return None
    did2text = {}
    id_key = "_id"
    with corpus_file as f:
        for idx, line in enumerate(f):
            # Decode once up front so the rest of the loop only handles str.
            if not isinstance(line, str):
                line = line.decode("utf-8")
            if not line.strip():
                # Blank lines carry no record; skip instead of failing to parse.
                continue
            if idx == 0 and "doc_id" in line:
                # The first line may be a column header mentioning "doc_id",
                # but it may equally be a real record keyed by "doc_id".
                # Only treat it as a header when it is not a JSON object, so
                # the first document of a "doc_id"-keyed corpus is not lost.
                try:
                    inst = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(inst, dict):
                    continue
            else:
                inst = json.loads(line)
            all_text = " ".join(inst[col] for col in columns_to_combine if col in inst)
            if id_key not in inst:
                id_key = "doc_id"
            did2text[inst[id_key]] = {
                "text": all_text,
                "title": inst.get("title", ""),
            }
    return did2text


@st.cache_data
def load_local_queries(queries_file):
    """Read a JSON-lines query file into a ``{query_id: text}`` mapping.

    Args:
        queries_file: File-like object that is used as a context manager and
            iterated line by line. Lines may be ``str`` or UTF-8 ``bytes``.
            ``None`` means no query file was supplied.

    Returns:
        ``None`` when ``queries_file`` is ``None``. Otherwise a dict mapping
        each record id to the record's ``"text"`` value. The id is read from
        ``"_id"``; once a record lacks ``"_id"`` the loader switches to
        ``"query_id"`` for that and all following records.

    Raises:
        json.JSONDecodeError: If a line (other than a leading header line) is
            not valid JSON.
        KeyError: If a record has no ``"text"`` key or neither id key.
    """
    if queries_file is None:
        return None
    qid2text = {}
    id_key = "_id"
    with queries_file as f:
        for idx, line in enumerate(f):
            # Decode once up front so the rest of the loop only handles str.
            if not isinstance(line, str):
                line = line.decode("utf-8")
            if not line.strip():
                # Blank lines carry no record; skip instead of failing to parse.
                continue
            if idx == 0 and "query_id" in line:
                # The first line may be a column header mentioning "query_id",
                # but it may equally be a real record keyed by "query_id".
                # Only treat it as a header when it is not a JSON object, so
                # the first query of a "query_id"-keyed file is not lost.
                try:
                    inst = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(inst, dict):
                    continue
            else:
                inst = json.loads(line)
            if id_key not in inst:
                id_key = "query_id"
            qid2text[inst[id_key]] = inst["text"]
    return qid2text


@st.cache_data
def load_local_qrels(qrels_file):
    """Read whitespace-separated relevance judgements into a nested mapping.

    Each data line has either four fields (``qid <ignored> doc_id label``) or
    three fields (``qid doc_id label``).

    Args:
        qrels_file: File-like object that is used as a context manager and
            iterated line by line. Lines may be ``str`` or UTF-8 ``bytes``.
            ``None`` means no qrels file was supplied.

    Returns:
        ``None`` when ``qrels_file`` is ``None``. Otherwise a
        ``defaultdict(dict)`` such that ``result[qid][doc_id]`` is the integer
        label, with both ids stored as strings.

    Raises:
        ValueError: If a data line does not have three or four fields, or its
            label is not an integer.
    """
    if qrels_file is None:
        return None
    qid2did2label = defaultdict(dict)
    with qrels_file as f:
        for idx, line in enumerate(f):
            cur_line = line if isinstance(line, str) else line.decode("utf-8")
            fields = cur_line.split()
            if not fields:
                # Blank lines carry no judgement.
                continue
            # Parenthesised on purpose: without the parentheses `and` binds
            # tighter than `or`, so every line containing "query-id" was
            # skipped regardless of its position in the file.
            if idx == 0 and ("qid" in cur_line or "query-id" in cur_line):
                # A header's last column is a name, not an integer label. A
                # first line that parses as a judgement is kept, so ids that
                # merely contain "qid" do not cost the first record.
                try:
                    int(fields[-1])
                except ValueError:
                    continue
            if len(fields) == 4:
                qid, _, doc_id, label = fields
            elif len(fields) == 3:
                qid, doc_id, label = fields
            else:
                raise ValueError(
                    f"expected 3 or 4 whitespace-separated fields, got {len(fields)}: {cur_line!r}"
                )
            qid2did2label[str(qid)][str(doc_id)] = int(label)

    return qid2did2label



@st.cache_data
def load_jsonl(f):
    """Group JSON-lines records by document id.

    Each line is parsed as JSON and handled according to the first of these
    keys it contains: ``"question"``, ``"text"``, ``"query"``.

    Args:
        f: Iterable of JSON-encoded lines (for example an open file).

    Returns:
        A ``(did2text, sub_did2text)`` tuple:

        * ``did2text`` is a ``defaultdict(list)`` mapping a document id to the
          list of question/text/query strings seen for it. The id is
          ``"doc_id"`` when present; otherwise ``metadata[0]["passage_id"]``
          for ``"question"`` records and ``"did"`` for the other two kinds.
        * ``sub_did2text`` maps ``"did"`` to ``"text"`` for ``"text"`` records
          that carry a ``"did"`` key.

    Raises:
        NotImplementedError: If a record has none of the three content keys.
        KeyError: If a record has no usable id key.
    """
    did2text = defaultdict(list)
    sub_did2text = {}

    for line in f:
        inst = json.loads(line)
        if "question" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["metadata"][0]["passage_id"]
            did2text[docid].append(inst["question"])
        elif "text" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["text"])
            # "did" is optional when "doc_id" is present (see the line above),
            # so only index by it when the record actually has one.
            if "did" in inst:
                sub_did2text[inst["did"]] = inst["text"]
        elif "query" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["query"])
        else:
            # No debugger hook here: dropping into pdb would block the app
            # process instead of reporting the unsupported record.
            raise NotImplementedError("Need to handle this case")

    return did2text, sub_did2text



@st.cache_data(persist="disk")
def get_dataset(dataset_name: str, input_fields_doc, input_fields_query):
    """Return the data for a named dataset.

    Only the empty dataset name is handled here; any other name raises.

    Args:
        dataset_name: Name of the dataset to load.
        input_fields_doc: Document field names, either already a collection
            or a single comma-separated string.
        input_fields_query: Query field names, in the same two accepted forms.

    Returns:
        A tuple of three empty dicts when ``dataset_name`` is ``""``.

    Raises:
        NotImplementedError: For every non-empty ``dataset_name``.
    """

    def _split_fields(value):
        # A plain string is treated as a comma-separated list of field names;
        # anything else is passed through unchanged.
        if type(value) is str:
            return value.strip().split(",")
        return value

    input_fields_doc = _split_fields(input_fields_doc)
    input_fields_query = _split_fields(input_fields_query)

    if dataset_name == "":
        return {}, {}, {}

    raise NotImplementedError("Dataset not implemented")