import os

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
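
# A hypothetical invocation (not part of the original module): point the
# loaders at the Hub copies instead of local files by unsetting the flag, e.g.
#   USE_LOCAL_DATASET=0 python this_module.py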


def add_metadata_column(dataset, column_name, value):
    """Return a copy of `dataset` with a constant-valued metadata column."""

    def add_source(example):
        example[column_name] = value
        return example

    return dataset.map(add_source)
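

# Example usage (hypothetical, not exercised by this module): tag each source
# dataset before concatenation so every row records where it came from:
#   tagged = add_metadata_column(
#       load_dataset("vidore/docvqa_train", split="train"), "source", "docvqa_train"
#   )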


def load_train_set() -> DatasetDict:
    """Build the base training mixture and carve out a 500-example test split."""
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_train_set_with_tabfquad() -> DatasetDict:
    """Same mixture as `load_train_set`, plus the subsampled TabFQuAD set."""
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_train_set_with_docmatix() -> DatasetDict:
    """Same mixture as `load_train_set_with_tabfquad`, plus filtered Docmatix."""
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
        "Docmatix_filtered_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict
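

# The three loaders above differ only in their dataset lists. A minimal sketch
# of a shared helper that would collapse the duplication; `_load_mixture` is an
# assumption for illustration, not part of the original module:
def _load_mixture(ds_paths: list) -> DatasetDict:
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        ds = load_dataset(base_path + path, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot).shuffle(seed=42)
    return DatasetDict(
        {
            "train": dataset.select(range(500, len(dataset))),
            "test": dataset.select(range(500)),
        }
    )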


def load_docvqa_dataset() -> DatasetDict:
    """Load DocVQA + InfographicVQA and normalize the query/image columns."""
    if USE_LOCAL_DATASET:
        dataset_doc = load_dataset("./data_dir/DocVQA", "DocVQA", split="validation")
        dataset_doc_eval = load_dataset("./data_dir/DocVQA", "DocVQA", split="test")
        dataset_info = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation")
        dataset_info_eval = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test")
    else:
        dataset_doc = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
        dataset_doc_eval = load_dataset("lmms-lab/DocVQA", "DocVQA", split="test")
        dataset_info = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation")
        dataset_info_eval = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test")

    # concatenate the two datasets
    dataset = concatenate_datasets([dataset_doc, dataset_info])
    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
    # subsample 200 examples for the eval split
    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
    # rename `question` to `query`
    dataset = dataset.rename_column("question", "query")
    dataset_eval = dataset_eval.rename_column("question", "query")
    # new column `image_filename`: `ucsf_document_id` if present, else `image_url`
    dataset = dataset.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    dataset_eval = dataset_eval.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


class TestSetFactory:
    """Callable wrapper that loads the `test` split of a dataset path."""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs):
        dataset = load_dataset(self.dataset_path, split="test")
        return dataset


if __name__ == "__main__":
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)
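
    # Quick smoke test for the training mixture (a sketch; assumes the vidore
    # datasets are reachable on the Hub, or mirrored under ./data_dir/ when
    # USE_LOCAL_DATASET=1):
    #   ds_dict = load_train_set()
    #   print(ds_dict["train"].num_rows, ds_dict["test"].num_rows)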