import os

import arxiv
import gradio as gr
import pandas as pd
import stamina
from apscheduler.schedulers.background import BackgroundScheduler
from arxiv import ArxivError, UnexpectedEmptyPageError
from cachetools import TTLCache, cached
from setfit import SetFitModel
from tqdm.auto import tqdm

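# Enable hf_transfer for faster model downloads from the Hugging Face Hub (requires the hf_transfer package).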
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

CACHE_TIME = 60 * 60 * 12  # 12 hours
MAX_RESULTS = 300


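# Shared arXiv API client: 50 results per page, 3 s between requests, up to 2 retries per page.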
client = arxiv.Client(page_size=50, delay_seconds=3, num_retries=2)


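# Cache the arXiv query for 12 hours so repeated UI interactions reuse the same snapshot.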
@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def get_arxiv_result():
    return _get_arxiv_result()


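# Retry transient arXiv failures (empty pages, API errors) up to 10 times, backing off up to 15 minutes between attempts.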
@stamina.retry(
    on=(ValueError, UnexpectedEmptyPageError, ArxivError), attempts=10, wait_max=60 * 15
)
def _get_arxiv_result():
    results = [
        {
            "title": result.title,
            "abstract": result.summary,
            "url": result.entry_id,
            "category": result.primary_category,
            "updated": result.updated,
        }
        for result in tqdm(
            client.results(
                arxiv.Search(
                    query="ti:dataset",
                    max_results=MAX_RESULTS,
                    sort_by=arxiv.SortCriterion.SubmittedDate,
                )
            ),
            total=MAX_RESULTS,
        )
    ]
    if not results:
        raise ValueError("No results found")
    return results


def load_model():
    return SetFitModel.from_pretrained("librarian-bots/is_new_dataset_teacher_model")


def format_row_for_model(row):
    return f"TITLE: {row['title']} \n\nABSTRACT: {row['abstract']}"


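# Map the SetFit model's class indices to human-readable labels.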
int2label = {0: "new_dataset", 1: "not_new_dataset"}


def get_predictions(data: list[dict], model=None, batch_size=128):
    if model is None:
        model = load_model()
    predictions = []
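    # Run inference in batches so long result lists don't exhaust memory.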
    for i in tqdm(range(0, len(data), batch_size)):
        batch = data[i : i + batch_size]
        text_inputs = [format_row_for_model(row) for row in batch]
        batch_predictions = model.predict_proba(text_inputs)
        for j, row in enumerate(batch):
            prediction = batch_predictions[j]
            row["prediction"] = int2label[int(prediction.argmax())]
            row["probability"] = float(prediction.max())
            predictions.append(row)
    return predictions


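# Render one paper as a Markdown card linking to its Hugging Face Papers page.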
def create_markdown(row):
    title = row["title"]
    abstract = row["abstract"]
    arxiv_id = row["arxiv_id"]
    hub_paper_url = f"https://huggingface.co/papers/{arxiv_id}"
    updated = row["updated"]
    updated = updated.strftime("%Y-%m-%d")
    broad_category = row["broad_category"]
    category = row["category"]
    return f""" <h2> {title} </h2> Updated: {updated} 
    | Category: {broad_category}  | Subcategory: {category} |
\n\n{abstract}
\n\n [Hugging Face Papers page]({hub_paper_url})
    """


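# Build the display dataframe: fetch papers, classify them, and derive the columns the UI needs.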
@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def prepare_data():
    print("Downloading arxiv results...")
    arxiv_results = get_arxiv_result()
    print("loading model...")
    model = load_model()
    print("Making predictions...")
    predictions = get_predictions(arxiv_results, model=model)
    df = pd.DataFrame(predictions)
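    # Derive the bare arXiv id from the entry URL and the broad category (e.g. "cs" from "cs.CL").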
    df.loc[:, "arxiv_id"] = df["url"].str.extract(r"(\d+\.\d+)")
    df.loc[:, "broad_category"] = df["category"].str.split(".").str[0]
    df.loc[:, "markdown"] = df.apply(create_markdown, axis=1)
    return df


_data = prepare_data()
all_possible_arxiv_categories = sorted(_data.category.unique().tolist())
broad_category_choices = sorted(_data.broad_category.unique().tolist())


def create_markdown_summary(categories=None, new_only=True, narrow_categories=None):
    df = prepare_data()
    if new_only:
        df = df[df["prediction"] == "new_dataset"]
    if narrow_categories:
        # Narrow categories take precedence over the broad-category selection.
        df = df[df["category"].isin(narrow_categories)]
    elif categories:
        df = df[df["broad_category"].isin(categories)]
    number_of_results = len(df)
    results = (
        "<h1 style='text-align: center'>arXiv papers related to datasets</h1>\n\n"
    )
    results += f"Number of results: {number_of_results}\n\n"
    results += "\n\n<br>".join(df["markdown"].tolist())
    return results


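# Warm the prepare_data cache daily at 03:30 (its TTL cache expires every 12 hours).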
scheduler = BackgroundScheduler()
scheduler.add_job(prepare_data, "cron", hour=3, minute=30)
scheduler.start()

description = """This Space shows recent papers on arXiv that are *likely* to be papers introducing new datasets related to machine learning. \n\n
The Space works by:
- searching for papers on arXiv with the term `dataset` in the title + "machine learning" in the abstract
- passing the abstract and title of the papers to a machine learning model that predicts if the paper is introducing a new dataset or not
 
This Space is a work in progress. The model is not perfect, and the search query is not perfect. If you have  suggestions for how to improve this Space, please open a Discussion.\n\n"""


with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align: center'>&#x2728; New Datasets in Machine Learning"
        " &#x2728;</h1>"
    )
    gr.Markdown(description)
    with gr.Row():
        broad_categories = gr.Dropdown(
            choices=broad_category_choices,
            label="Broad arXiv Category",
            multiselect=True,
            value=["cs"],
        )
    with gr.Accordion("Advanced Options", open=False):
        gr.Markdown(
            "Narrow by arXiv category. **Note**: this will take precedence over the"
            " broad category selection."
        )
        narrow_categories = gr.Dropdown(
            choices=all_possible_arxiv_categories,
            value=None,
            multiselect=True,
            label="Narrow arXiv Category",
        )
        gr.ClearButton(narrow_categories, value="Clear Narrow Categories", size="sm")
    with gr.Row():
        new_only = gr.Checkbox(True, label="New Datasets Only", interactive=True)
    results = gr.Markdown(create_markdown_summary(categories=["cs"]))
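    # Re-render the results whenever any filter control changes.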
    broad_categories.change(
        create_markdown_summary,
        inputs=[broad_categories, new_only, narrow_categories],
        outputs=results,
    )
    narrow_categories.change(
        create_markdown_summary,
        inputs=[broad_categories, new_only, narrow_categories],
        outputs=results,
    )
    new_only.change(
        create_markdown_summary,
        inputs=[broad_categories, new_only, narrow_categories],
        outputs=results,
    )

demo.launch()