# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from os.path import exists
from os.path import join as pjoin

from datasets import Dataset, load_dataset, load_from_disk
from tqdm import tqdm

_CACHE_DIR = "cache_dir"


# grab first N rows of a dataset from the hub
def load_truncated_dataset(
    dataset_name,
    config_name=None,
    split_name=None,
    num_rows=0,
    use_streaming=True,
    use_auth_token=None,
    use_dataset=None,
):
""" | |
This function loads the first `num_rows` items of a dataset for a | |
given `config_name` and `split_name`. | |
When the dataset is streamable, we iterate through the first | |
`num_rows` examples in streaming mode, write them to a jsonl file, | |
then create a new dataset from the json. | |
This is the most direct way to make a Dataset from an IterableDataset | |
as of datasets version 1.6.1. | |
Otherwise, we download the full dataset and select the first | |
`num_rows` items | |
Args: | |
dataset_name (string): | |
dataset id in the dataset library | |
config_name (string): | |
dataset configuration | |
split_name (string): | |
optional split name, defaults to `train` | |
num_rows (int): | |
number of rows to truncate the dataset to, <= 0 means no truncation | |
use_streaming (bool): | |
whether to use streaming when the dataset supports it | |
use_auth_token (string): | |
HF authentication token to access private datasets | |
use_dataset (Dataset): | |
use existing dataset instead of getting one from the hub | |
Returns: | |
Dataset: | |
the truncated dataset as a Dataset object | |
""" | |
split_name = "train" if split_name is None else split_name | |
cache_name = f"{dataset_name.replace('/', '---')}_{'default' if config_name is None else config_name}_{split_name}_{num_rows}" | |
if use_streaming: | |
if not exists(pjoin(_CACHE_DIR, "tmp", f"{cache_name}.jsonl")): | |
iterable_dataset = ( | |
load_dataset( | |
dataset_name, | |
name=config_name, | |
split=split_name, | |
cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_temp"), | |
streaming=True, | |
use_auth_token=use_auth_token, | |
) | |
if use_dataset is None | |
else use_dataset | |
) | |
if num_rows > 0: | |
iterable_dataset = iterable_dataset.take(num_rows) | |
f = open( | |
pjoin(_CACHE_DIR, "tmp", f"{cache_name}.jsonl"), "w", encoding="utf-8" | |
) | |
for row in tqdm(iterable_dataset): | |
_ = f.write(json.dumps(row) + "\n") | |
f.close() | |
dataset = Dataset.from_json( | |
pjoin(_CACHE_DIR, "tmp", f"{cache_name}.jsonl"), | |
cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_jsonl"), | |
) | |
else: | |
full_dataset = ( | |
load_dataset( | |
dataset_name, | |
name=config_name, | |
split=split_name, | |
use_auth_token=use_auth_token, | |
cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_temp"), | |
) | |
if use_dataset is None | |
else use_dataset | |
) | |
if num_rows > 0: | |
dataset = full_dataset.select(range(num_rows)) | |
else: | |
dataset = full_dataset | |
return dataset | |
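

# Example usage (a sketch, kept as a comment so importing this module has no
# side effects; the dataset id "squad" and the row count are illustrative,
# not part of this module):
#
#   first_rows = load_truncated_dataset(
#       "squad",
#       split_name="train",
#       num_rows=100,
#       use_streaming=True,
#   )
#   print(len(first_rows))  # expected: a regular (non-streaming) Dataset with 100 rows
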
# get all instances of a specific field in a dataset with indices and labels
def extract_features(examples, indices, input_field_path, label_name=None):
    """
    This function prepares examples for further processing by:
    - returning an "unrolled" list of all the fields denoted by input_field_path
    - with the indices corresponding to the example the field item came from
    - optionally, the corresponding label is also returned with each field item

    Args:
        examples (dict):
            a dictionary of lists, provided by dataset.map with batched=True
        indices (list):
            a list of indices, provided by dataset.map with with_indices=True
        input_field_path (tuple):
            a tuple indicating the field we want to extract. Can be a singleton
            for top-level features (e.g. `("text",)`) or a full path for nested
            features (e.g. `("answers", "text")`) to get all answer strings in
            SQuAD
        label_name (string):
            optionally used to align the field items with labels. Currently,
            returns the top-most field that has this name, which may fail in some
            edge cases
            TODO: make it so the label is specified through a full path

    Returns:
        Dict:
            a dictionary of lists, used by dataset.map with batched=True.
            labels are all None if label_name is not None but is not found
            TODO: raise an error if label_name is specified but not found
    """
    top_name = input_field_path[0]
    if label_name is not None and label_name in examples:
        item_list = [
            {"index": i, "label": label, "items": items}
            for i, items, label in zip(
                indices, examples[top_name], examples[label_name]
            )
        ]
    else:
        item_list = [
            {"index": i, "label": None, "items": items}
            for i, items in zip(indices, examples[top_name])
        ]
    for field_name in input_field_path[1:]:
        new_item_list = []
        for dct in item_list:
            if label_name is not None and label_name in dct["items"]:
                if isinstance(dct["items"][field_name], list):
                    new_item_list += [
                        {"index": dct["index"], "label": label, "items": next_item}
                        for next_item, label in zip(
                            dct["items"][field_name], dct["items"][label_name]
                        )
                    ]
                else:
                    new_item_list += [
                        {
                            "index": dct["index"],
                            "label": dct["items"][label_name],
                            "items": dct["items"][field_name],
                        }
                    ]
            else:
                if isinstance(dct["items"][field_name], list):
                    new_item_list += [
                        {
                            "index": dct["index"],
                            "label": dct["label"],
                            "items": next_item,
                        }
                        for next_item in dct["items"][field_name]
                    ]
                else:
                    new_item_list += [
                        {
                            "index": dct["index"],
                            "label": dct["label"],
                            "items": dct["items"][field_name],
                        }
                    ]
        item_list = new_item_list
    res = (
        {
            "ids": [dct["index"] for dct in item_list],
            "field": [dct["items"] for dct in item_list],
        }
        if label_name is None
        else {
            "ids": [dct["index"] for dct in item_list],
            "field": [dct["items"] for dct in item_list],
            "label": [dct["label"] for dct in item_list],
        }
    )
    return res
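

# Example (a sketch, kept as a comment): extracting nested answer strings from
# a hand-written SQuAD-style batch. Real batches come from
# dataset.map(..., batched=True, with_indices=True), so the values below are
# illustrative only:
#
#   batch = {
#       "answers": [
#           {"text": ["Denver Broncos"], "answer_start": [177]},
#           {"text": ["gold", "golden"], "answer_start": [488, 521]},
#       ],
#   }
#   extract_features(batch, indices=[0, 1], input_field_path=("answers", "text"))
#   # -> {"ids": [0, 1, 1], "field": ["Denver Broncos", "gold", "golden"]}
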
# grab some examples and extract interesting fields
def prepare_clustering_dataset(
    dataset_name,
    input_field_path,
    label_name=None,
    config_name=None,
    split_name=None,
    num_rows=0,
    use_streaming=True,
    use_auth_token=None,
    cache_dir=_CACHE_DIR,
    use_dataset=None,
):
""" | |
This function loads the first `num_rows` items of a dataset for a | |
given `config_name` and `split_name`, and extracts all instances of a field | |
of interest denoted by `input_field_path` along with the indices of the | |
examples the instances came from and optionall their labels (`label_name`) | |
in the original dataset | |
Args: | |
dataset_name (string): | |
dataset id in the dataset library | |
input_field_path (tuple): | |
a tuple indicating the field we want to extract. Can be a singleton | |
for top-level features (e.g. `("text",)`) or a full path for nested | |
features (e.g. `("answers", "text")`) to get all answer strings in | |
SQuAD | |
label_name (string): | |
optionally used to align the field items with labels. Currently, | |
returns the top-most field that has this name, which fails in edge cases | |
config_name (string): | |
dataset configuration | |
split_name (string): | |
optional split name, defaults to `train` | |
num_rows (int): | |
number of rows to truncate the dataset to, <= 0 means no truncation | |
use_streaming (bool): | |
whether to use streaming when the dataset supports it | |
use_auth_token (string): | |
HF authentication token to access private datasets | |
use_dataset (Dataset): | |
use existing dataset instead of getting one from the hub | |
Returns: | |
Dataset: | |
the extracted dataset as a Dataset object. Note that if there is more | |
than one instance of the field per example in the original dataset | |
(e.g. multiple answers per QA example), the returned dataset will | |
have more than `num_rows` rows | |
string: | |
the path to the newsly created dataset directory | |
""" | |
    cache_path = [
        cache_dir,
        dataset_name.replace("/", "---"),
        f"{'default' if config_name is None else config_name}",
        f"{'train' if split_name is None else split_name}",
        f"field-{'->'.join(input_field_path)}-label-{label_name}",
        f"{num_rows}_rows",
        "features_dset",
    ]
    if exists(pjoin(*cache_path)):
        pre_clustering_dset = load_from_disk(pjoin(*cache_path))
    else:
        truncated_dset = load_truncated_dataset(
            dataset_name,
            config_name,
            split_name,
            num_rows,
            use_streaming,
            use_auth_token,
            use_dataset,
        )

        def batch_func(examples, indices):
            return extract_features(examples, indices, input_field_path, label_name)

        pre_clustering_dset = truncated_dset.map(
            batch_func,
            remove_columns=truncated_dset.features,
            batched=True,
            with_indices=True,
        )
        for i in range(1, len(cache_path) - 1):
            if not exists(pjoin(*cache_path[:i])):
                os.mkdir(pjoin(*cache_path[:i]))
        pre_clustering_dset.save_to_disk(pjoin(*cache_path))
    return pre_clustering_dset, pjoin(*cache_path)
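

# Example usage (a sketch, kept as a comment; the dataset id and row count are
# illustrative). This would stream the first 100 SQuAD rows, unroll every
# answer string, and cache the result under `cache_dir`:
#
#   features_dset, dset_path = prepare_clustering_dataset(
#       "squad",
#       input_field_path=("answers", "text"),
#       num_rows=100,
#   )
#   # features_dset has columns "ids" and "field" (plus "label" when
#   # label_name is given); dset_path is the directory the Dataset was saved to.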