import os

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
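# The flag above toggles between local copies under ./data_dir/ (the default)
# and the Hugging Face Hub (the vidore/ and lmms-lab/ namespaces used below).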


def add_metadata_column(dataset, column_name, value):
    """Return a copy of `dataset` with a constant `column_name` column set to `value`."""

    def add_source(example):
        example[column_name] = value
        return example

    return dataset.map(add_source)
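
# Usage sketch (hypothetical column/value names): tag each subset with its
# provenance before concatenating several sources:
#   tagged = add_metadata_column(ds, "source", "docvqa_train")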


def load_train_set() -> DatasetDict:
    """Build the base training mixture and hold out 500 examples for evaluation."""
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # hold out the first 500 shuffled examples as the test split
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict
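
# Usage sketch: each loader below returns a DatasetDict whose "test" split is
# the 500 held-out examples, e.g. load_train_set()["train"] for training
# (assumes the datasets listed above are reachable locally or on the Hub).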


def load_train_set_with_tabfquad() -> DatasetDict:
    """Same mixture as `load_train_set`, plus the subsampled TabFQuAD split."""
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # hold out the first 500 shuffled examples as the test split
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_train_set_with_docmatix() -> DatasetDict:
    """Same mixture as `load_train_set_with_tabfquad`, plus the filtered Docmatix split."""
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
        "Docmatix_filtered_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample arxivqa to 10k examples
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # hold out the first 500 shuffled examples as the test split
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_docvqa_dataset() -> DatasetDict:
    """Load DocVQA and InfographicVQA and normalize their columns ("query", "image_filename")."""
    if USE_LOCAL_DATASET:
        dataset_doc = load_dataset("./data_dir/DocVQA", "DocVQA", split="validation")
        dataset_doc_eval = load_dataset("./data_dir/DocVQA", "DocVQA", split="test")
        dataset_info = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation")
        dataset_info_eval = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test")
    else:
        dataset_doc = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
        dataset_doc_eval = load_dataset("lmms-lab/DocVQA", "DocVQA", split="test")
        dataset_info = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation")
        dataset_info_eval = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test")

    # concatenate the two datasets
    dataset = concatenate_datasets([dataset_doc, dataset_info])
    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
    # keep 200 shuffled examples for evaluation
    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
    # rename "question" to "query"
    dataset = dataset.rename_column("question", "query")
    dataset_eval = dataset_eval.rename_column("question", "query")
    # use ucsf_document_id as image_filename when present, else fall back to image_url
    dataset = dataset.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    dataset_eval = dataset_eval.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict
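
# Note: the "train" split above comes from the official validation splits, and
# "test" is a fixed 200-example shuffle of the official test splits.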


class TestSetFactory:
    """Callable wrapper that loads the test split of a given dataset path."""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs) -> Dataset:
        dataset = load_dataset(self.dataset_path, split="test")
        return dataset


if __name__ == "__main__":
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)
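    # A minimal smoke test of the other loaders could look like this (assumes
    # the vidore/ datasets are reachable, or local copies under ./data_dir/):
    #   ds_dict = load_train_set()
    #   print(ds_dict["train"], ds_dict["test"])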