import os
import copy
import datasets
import pandas as pd
from datasets import Dataset
from collections import defaultdict
from datetime import datetime, timedelta
from background import process_arxiv_ids
from utils import create_hf_hub
from apscheduler.schedulers.background import BackgroundScheduler

def _count_nans(row):
    # Count how many fields of a record are missing (None).
    count = 0
    for v in row.values():
        if v is None:
            count += 1
    return count

def _initialize_requested_arxiv_ids(request_ds):
    # Flatten every batch of requested arXiv IDs into a single DataFrame.
    requested_arxiv_ids = []
    for request_d in request_ds['train']:
        requested_arxiv_ids += request_d['Requested arXiv IDs']

    requested_arxiv_ids_df = pd.DataFrame({'Requested arXiv IDs': requested_arxiv_ids})
    return requested_arxiv_ids_df

def _initialize_paper_info(source_ds):
    title2qna, date2qna = {}, {}
    date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    arxivid2data = {}
    count = 0

    if len(source_ds["train"]) > 1:
        for data in source_ds["train"]:
            # Skip the placeholder row pushed when the repository was first created.
            if data["title"] != "dummy":
                date = data["target_date"].strftime("%Y-%m-%d")

                if date in date2qna:
                    # If this paper already exists for the date, keep the record
                    # with fewer missing fields.
                    papers = copy.deepcopy(date2qna[date])
                    for paper in papers:
                        if paper["title"] == data["title"]:
                            if _count_nans(paper) > _count_nans(data):
                                date2qna[date].remove(paper)

                    date2qna[date].append(data)
                    del papers
                else:
                    date2qna[date] = [data]

        # Index the deduplicated papers by title, arXiv ID, and date.
        for date in date2qna:
            year, month, day = date.split("-")
            papers = date2qna[date]
            for paper in papers:
                title2qna[paper["title"]] = paper
                arxivid2data[paper['arxiv_id']] = {"idx": count, "paper": paper}
                date_dict[year][month][day].append(paper)
                count = count + 1

        titles = [f"[{v['arxiv_id']}] {k}" for k, v in title2qna.items()]
        return titles, date_dict, arxivid2data
    else:
        return [], {}, {}

def initialize_data(source_data_repo_id, request_data_repo_id):
    global date_dict, arxivid2data
    global requested_arxiv_ids_df

    # Load both Hub datasets and build the in-memory indices used by the app.
    source_ds = datasets.load_dataset(source_data_repo_id)
    request_ds = datasets.load_dataset(request_data_repo_id)
    titles, date_dict, arxivid2data = _initialize_paper_info(source_ds)
    requested_arxiv_ids_df = _initialize_requested_arxiv_ids(request_ds)

    return (
        titles, date_dict, requested_arxiv_ids_df, arxivid2data
    )

def update_dataframe():
    # Re-read the request dataset so newly requested arXiv IDs show up.
    global request_arxiv_repo_id
    request_ds = datasets.load_dataset(request_arxiv_repo_id)
    return _initialize_requested_arxiv_ids(request_ds)

def initialize_repos(
    source_data_repo_id, request_data_repo_id, hf_token
):
    # Create the dataset repositories on the Hub if they do not exist yet,
    # seeding each one with a placeholder row.
    if create_hf_hub(source_data_repo_id, hf_token) is False:
        print(f"{source_data_repo_id} repository already exists")
    else:
        dummy_row = {"title": ["dummy"]}
        ds = Dataset.from_dict(dummy_row)
        ds.push_to_hub(source_data_repo_id, token=hf_token)

    if create_hf_hub(request_data_repo_id, hf_token) is False:
        print(f"{request_data_repo_id} repository already exists")
    else:
        df = pd.DataFrame(data={"Requested arXiv IDs": [["top"]]})
        ds = Dataset.from_pandas(df)
        ds.push_to_hub(request_data_repo_id, token=hf_token)

def get_secrets():
    global gemini_api_key
    global hf_token
    global request_arxiv_repo_id
    global dataset_repo_id

    # Read runtime configuration from the Space's secrets / environment variables.
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    hf_token = os.getenv("HF_TOKEN")
    dataset_repo_id = os.getenv("SOURCE_DATA_REPO_ID")
    request_arxiv_repo_id = os.getenv("REQUEST_DATA_REPO_ID")
    restart_repo_id = os.getenv("RESTART_TARGET_SPACE_REPO_ID", "chansung/paper_qa")

    return (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id
    )
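

# The imports above bring in BackgroundScheduler and process_arxiv_ids, but this
# excerpt never wires them together. Below is a minimal, hypothetical sketch of
# how these helpers could be combined at startup; the call order, the one-hour
# interval, and the exact arguments passed to process_arxiv_ids are assumptions,
# not the Space's confirmed behavior.
#
# if __name__ == "__main__":
#     (
#         gemini_api_key, hf_token,
#         dataset_repo_id, request_arxiv_repo_id, restart_repo_id
#     ) = get_secrets()
#
#     # Make sure both dataset repositories exist before loading them.
#     initialize_repos(dataset_repo_id, request_arxiv_repo_id, hf_token)
#     titles, date_dict, requested_arxiv_ids_df, arxivid2data = initialize_data(
#         dataset_repo_id, request_arxiv_repo_id
#     )
#
#     # Periodically process newly requested arXiv IDs in the background.
#     # The real signature of process_arxiv_ids lives in background.py; no
#     # arguments are shown here because they are not visible in this file.
#     scheduler = BackgroundScheduler()
#     scheduler.add_job(process_arxiv_ids, "interval", hours=1)
#     scheduler.start()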