File size: 4,858 Bytes
7026285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d6db9
 
 
 
 
7026285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d6db9
 
7026285
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
from functools import lru_cache

import gradio as gr
import plotly.graph_objects as go

from wimbd.es import es_init, count_documents_containing_phrases


# Elasticsearch clients. The Dolma indices use their own cloud id / API key
# (and therefore their own client); all credentials come from environment
# variables.
es = es_init(None, os.getenv("lm_datasets_cloud_id"), os.getenv("lm_datasets_api_key"))
es_dolma = es_init(None, os.getenv("dolma_cloud_id"), os.getenv("dolma_api_key"))

# Display name -> Elasticsearch index name (or wildcard pattern).
# Insertion order matters: one checkbox is created per key in this order, and
# process_input zips its positional `dataset_choices` against this same order.
# Any name containing "dolma" (case-insensitive) is queried through es_dolma.
dataset_es_map = {
    "OSCAR": "re_oscar",
    "LAION-2B-en": "re_laion2b-en-*",
    "LAION-5B": "*laion2b*",
    "OpenWebText": "openwebtext",
    "The Pile": "re_pile",
    "C4": "c4",
    "Dolma v1.5": "docs_v1.5_2023-11-02",
    "Dolma v1.7": "docs_v1.7_2024-06-04",
    "Tulu v2": "tulu-v2-sft-mixture",
}
# Derived from the map so the list of display names cannot drift from the
# indices that are actually queryable.
datasets = list(dataset_es_map)
default_checked = ["C4", "The Pile", "Dolma v1.7"]  # Datasets to be checked by default


@lru_cache(maxsize=128)
def get_counts(index_name, phrase, es):
    """Return the number of documents in ``index_name`` that contain ``phrase``.

    Memoizing wrapper around ``count_documents_containing_phrases``: results
    are kept in a 128-entry LRU cache keyed on all three arguments, so asking
    the same question again does not hit Elasticsearch. Because ``es`` is part
    of the cache key, the client object must be hashable.
    """
    num_docs = count_documents_containing_phrases(index_name, phrase, es=es)
    return num_docs


def process_input(phrases, *dataset_choices):
    """Count the documents containing each phrase in every selected dataset.

    Args:
        phrases: Newline-separated phrases from the input textbox. Each line
            is stripped, blank lines are skipped, and a phrase that appears
            more than once is queried only once (first occurrence keeps its
            position). ``None`` is treated as empty input.
        *dataset_choices: One truthy/falsy value per entry of
            ``dataset_es_map``, in the map's insertion order (the UI builds
            one checkbox per map key in that same order). ``zip`` stops at the
            shorter of the two, so surplus or missing choices are ignored.

    Returns:
        ``(table_data, fig)`` where ``table_data`` is a list of
        ``[dataset_name, phrase, str(count)]`` rows and ``fig`` is a plotly
        grouped bar chart with one trace per phrase and one bar per selected
        dataset.
    """
    # Parse the textbox once rather than once per selected dataset.
    # dict.fromkeys de-duplicates while preserving the order the user typed.
    stripped = (line.strip() for line in (phrases or "").split("\n"))
    unique_phrases = list(dict.fromkeys(p for p in stripped if p))

    results = []
    for (dataset_name, index_name), is_selected in zip(
        dataset_es_map.items(), dataset_choices
    ):
        if not is_selected:
            continue
        # Dolma indices are queried through their own client.
        client = es_dolma if "dolma" in dataset_name.lower() else es
        for phrase in unique_phrases:
            count = get_counts(index_name, phrase, es=client)
            results.append((dataset_name, phrase, count))

    # Rows for the Dataframe output.
    table_data = [[dataset, phrase, str(count)] for dataset, phrase, count in results]

    # One bar trace per phrase, in input order. Iterating a set here would make
    # legend order and trace colours vary between runs, because str hashing is
    # randomized per process.
    fig = go.Figure()
    plotted_phrases = dict.fromkeys(p for _, p, _ in results)
    for phrase in plotted_phrases:
        dataset_names = [name for name, p, _ in results if p == phrase]
        counts = [count for _, p, count in results if p == phrase]
        fig.add_trace(go.Bar(x=dataset_names, y=counts, name=phrase))

    fig.update_layout(
        title="Document Counts by Dataset and Phrase",
        xaxis_title="Dataset",
        yaxis_title="Count",
        barmode="group",
    )

    return table_data, fig


# Markdown citation block; passed to gr.Interface as `article`.
# The closing ``` must sit on its own line: a fence written right after "}"
# is not recognised as closing the code block, so the backticks would be
# rendered literally inside it.
citation_text = """If you find this tool useful, please kindly cite our paper:
```bibtex
@inproceedings{elazar2023s,
  title={What's In My Big Data?},
  author={Elazar, Yanai and Bhagia, Akshita and Magnusson, Ian Helgi and Ravichander, Abhilasha and Schwenk, Dustin and Suhr, Alane and Walsh, Evan Pete and Groeneveld, Dirk and Soldaini, Luca and Singh, Sameer and Hajishirzi, Hanna and Smith, Noah A. and Dodge, Jesse},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}
```"""


def custom_layout(input_components, output_components, citation):
    """Return the app's components as one flat list in display order.

    The order is: the phrase textbox (first input), the remaining inputs (the
    dataset checkboxes), the first two outputs (counts table, then plot), and
    finally ``citation``. Any outputs beyond the second are dropped.
    """
    textbox = input_components[0]
    checkboxes = input_components[1:]
    table = output_components[0]
    chart = output_components[1]

    ordered = [textbox]
    ordered.extend(checkboxes)
    ordered.extend((table, chart, citation))
    return ordered


# Gradio UI: a phrase textbox plus one checkbox per entry of dataset_es_map
# (in map order, matching process_input's positional *dataset_choices),
# producing a counts table and a grouped bar chart.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(
            label="Enter phrases (one per line)",
            lines=5,
            value="let's think step by step\nhello world",
        ),
        *[
            gr.Checkbox(label=dataset, value=(dataset in default_checked))
            for dataset in dataset_es_map
        ],
    ],
    outputs=[
        gr.Dataframe(headers=["Dataset", "Phrase", "Count"], label="Counts Table"),
        gr.Plot(label="Results Chart"),
    ],
    title="What's In My Big Data? String Counts Demo",
    description="""This app connects to the WIMBD Elasticsearch instance and counts the number of documents containing a given string in the various indexed datasets.\\
    The app uses the wimbd pypi package, which can be installed by simply running `pip install wimbd`.\\
    Access to the indices requires an API key due to the sensitive nature of the data; you can request one by filling out the following [form](https://forms.gle/Mk9uwJibR9H4hh9Y9).\\
    This app was created by [Yanai Elazar](https://yanaiela.github.io/), and for bugs, improvements, or feature requests, please open an issue on the [GitHub repository](https://github.com/allenai/wimbd), or send me an email.
    
    The indices were set up as part of the WIMBD project, which you can read about in our [ICLR paper](https://arxiv.org/abs/2310.20707).
    
    The returned counts are the number of documents that contain each string per dataset.""",
    article=citation_text,  # Citation markdown shown below the app.
    # `theme` is deliberately left unset: it must be a gr.themes.Theme or a
    # theme name, not a layout function. Component order comes from the
    # `inputs` / `outputs` lists above.
)


iface.launch()