import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset

from ..utils.generic import ModelOutput


class PipelineDataset(Dataset):
    def __init__(self, dataset, process, params):
        self.dataset = dataset
        self.process = process
        self.params = params

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        item = self.dataset[i]
        processed = self.process(item, **self.params)
        return processed
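

# A minimal sketch of how `PipelineDataset` defers per-item work (illustrative
# only, not part of the original module): `_demo_pipeline_dataset` and its
# `process` stand-in are assumed names; real pipelines pass their `preprocess`.
def _demo_pipeline_dataset():
    def process(item, prefix=""):
        # Stand-in for `Pipeline.preprocess`; real pipelines tokenize here.
        return {"text": prefix + item}

    ds = PipelineDataset(["I love this.", "I hate this."], process, {"prefix": ">> "})
    assert len(ds) == 2
    assert ds[0] == {"text": ">> I love this."}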


class PipelineIterator(IterableDataset):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator on which `infer` will be applied.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                If specified, the items of `loader` are expected to come as batches and are unrolled here, making it
                roughly behave as

        ```
        for items in loader:
            for i in range(loader_batch_size):
                item = items[i]
                yield infer(item, **params)
        ```"""
        self.loader = loader
        self.infer = infer
        self.params = params
        if loader_batch_size == 1:
            # Unrolling batches of size 1 is a no-op, so spare some time by deactivating it altogether
            loader_batch_size = None
        self.loader_batch_size = loader_batch_size

        # Internal bookkeeping
        self._loader_batch_index = None
        self._loader_batch_data = None

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def loader_batch_item(self):
        """
        Return the item located at `loader_batch_index` within the current `loader_batch_data`.
        """
        if isinstance(self._loader_batch_data, torch.Tensor):
            # Batch data is a simple tensor, just fetch the slice
            result = self._loader_batch_data[self._loader_batch_index]
        else:
            # Batch data is assumed to be BaseModelOutput (or dict)
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to tuple first
                    element = element.to_tuple()
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
                    # Those are stored as tuples of tensors, so they need specific unbatching.
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if element is None:
                    # This can happen for optional data that get passed around
                    loader_batched[k] = None
                elif isinstance(element[self._loader_batch_index], torch.Tensor):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                elif isinstance(element[self._loader_batch_index], np.ndarray):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
                else:
                    # This is typically a list, so no need to `unsqueeze`.
                    loader_batched[k] = element[self._loader_batch_index]
            # Recreate the element by reusing the original class to make it look like
            # batch_size=1
            result = self._loader_batch_data.__class__(loader_batched)
        self._loader_batch_index += 1
        return result

    def __next__(self):
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            # We are currently unrolling a batch, so we just need to return
            # the current item within that batch
            return self.loader_batch_item()

        # We're out of items within a batch
        item = next(self.iterator)
        processed = self.infer(item, **self.params)
        # We now have a batch of "inferred things".
        if self.loader_batch_size is not None:
            # Try to infer the size of the batch
            if isinstance(processed, torch.Tensor):
                first_tensor = processed
            else:
                key = list(processed.keys())[0]
                first_tensor = processed[key]
            if isinstance(first_tensor, list):
                observed_batch_size = len(first_tensor)
            else:
                observed_batch_size = first_tensor.shape[0]
            if 0 < observed_batch_size < self.loader_batch_size:
                # This could be the last batch, so we can't unroll as many
                # elements as expected.
                self.loader_batch_size = observed_batch_size
            # Setting internal index to unwrap the batch
            self._loader_batch_data = processed
            self._loader_batch_index = 0
            return self.loader_batch_item()
        else:
            # We're not unrolling batches
            return processed
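

# A minimal sketch of batch unrolling (illustrative only, not part of the
# original module): plain dicts of lists stand in for dicts of tensors, so the
# generic-sequence unbatching path in `loader_batch_item` is exercised.
# `_demo_pipeline_iterator` and its `infer` stand-in are assumed names.
def _demo_pipeline_iterator():
    def infer(batch, upper=False):
        # Stand-in for a forward pass returning one batched output per input batch.
        return {"text": [t.upper() if upper else t for t in batch]}

    it = PipelineIterator([["a", "b"], ["c", "d"]], infer, {"upper": True}, loader_batch_size=2)
    # Batches of 2 come out of the loader; single items come out of the iterator.
    assert list(it) == [{"text": "A"}, {"text": "B"}, {"text": "C"}, {"text": "D"}]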


class PipelineChunkIterator(PipelineIterator):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for iterator in loader:
            for item in iterator:
                yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator on which `infer` will be applied.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
        """
        super().__init__(loader, infer, params)

    def __iter__(self):
        self.iterator = iter(self.loader)
        self.subiterator = None
        return self

    def __next__(self):
        if self.subiterator is None:
            # `subiterator` being None means we haven't started a `preprocess` iterator yet, so start it
            self.subiterator = self.infer(next(self.iterator), **self.params)
        try:
            # Try to return the next item
            processed = next(self.subiterator)
        except StopIteration:
            # When a preprocess iterator ends, we can start looking at the next item.
            # ChunkIterator will keep feeding until ALL elements of `iterator`
            # have created their subiterator and have been iterated over.
            #
            # Another way to look at it is that we're basically flattening lists of lists
            # into a single list, but with generators
            self.subiterator = self.infer(next(self.iterator), **self.params)
            processed = next(self.subiterator)
        return processed
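

# A minimal sketch of generator flattening (illustrative only, not part of the
# original module): `infer` is a generator function, so each loader item
# expands into several chunks that the iterator yields one by one.
# `_demo_pipeline_chunk_iterator` and `chunk` are assumed names.
def _demo_pipeline_chunk_iterator():
    def chunk(text, size=2):
        # Stand-in for a chunking `preprocess`, e.g. long-audio or long-text splitting.
        for start in range(0, len(text), size):
            yield text[start : start + size]

    it = PipelineChunkIterator(["abcd", "ef"], chunk, {"size": 2})
    assert list(it) == ["ab", "cd", "ef"]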


class PipelinePackIterator(PipelineIterator):
    """
    Roughly equivalent to

    ```
    packed = []
    for item in loader:
        packed.append(item)
        if item["is_last"]:
            yield packed
            packed = []
    ```

    but it also handles cases where `item` is batched (meaning it's a dict of Tensors with first dimension > 1). In
    that case it does

    ```
    packed = []
    for batch in loader:
        # item is batched
        for item in batch:
            packed.append(item)
            if item["is_last"]:
                yield packed
                packed = []
    ```

    Arguments:
        loader (`torch.utils.data.DataLoader` or any iterator):
            The iterator on which `infer` will be applied.
        infer (any function):
            The function to apply to each element of `loader`.
        params (`dict`):
            The parameters passed to `infer` along with every item.
        loader_batch_size (`int`, *optional*):
            If specified, the items of `loader` are expected to come as batches and are unrolled here, making it
            roughly behave as

    ```
    for items in loader:
        for i in range(loader_batch_size):
            item = items[i]
            yield infer(item, **params)
    ```"""

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def __next__(self):
        # Extremely similar to PipelineIterator in its unpacking mechanism,
        # BUT with an extra requirement: the presence of `is_last`.
        # Because everything is flattened by `PipelineChunkIterator`, we need
        # to keep track of how to regroup here, along the original `process`
        # boundaries, so that `process` and `postprocess` see the same data.
        # This iterator accumulates items (possibly while unbatching) until it
        # hits an `is_last`, then passes the accumulated items on to the caller.
        is_last = False
        accumulator = []
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            while self._loader_batch_index < self.loader_batch_size:
                item = self.loader_batch_item()
                is_last = item.pop("is_last")
                accumulator.append(item)
                if is_last:
                    return accumulator

        while not is_last:
            processed = self.infer(next(self.iterator), **self.params)
            if self.loader_batch_size is not None:
                if isinstance(processed, torch.Tensor):
                    first_tensor = processed
                else:
                    key = list(processed.keys())[0]
                    first_tensor = processed[key]
                if isinstance(first_tensor, list):
                    observed_batch_size = len(first_tensor)
                else:
                    observed_batch_size = first_tensor.shape[0]
                if 0 < observed_batch_size < self.loader_batch_size:
                    # This could be the last batch, so we can't unroll as many
                    # elements as expected.
                    self.loader_batch_size = observed_batch_size
                self._loader_batch_data = processed
                self._loader_batch_index = 0
                while self._loader_batch_index < self.loader_batch_size:
                    item = self.loader_batch_item()
                    is_last = item.pop("is_last")
                    accumulator.append(item)
                    if is_last:
                        return accumulator
            else:
                item = processed
                is_last = item.pop("is_last")
                accumulator.append(item)
        return accumulator
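

# A minimal sketch of packing along `is_last` boundaries (illustrative only,
# not part of the original module): three flattened chunks are regrouped into
# the two original items. `_demo_pipeline_pack_iterator` is an assumed name.
def _demo_pipeline_pack_iterator():
    loader = [
        {"text": "a", "is_last": False},
        {"text": "b", "is_last": True},
        {"text": "c", "is_last": True},
    ]
    # `infer` copies each item so popping `is_last` doesn't mutate `loader`.
    it = PipelinePackIterator(loader, lambda x: dict(x), {})
    assert list(it) == [[{"text": "a"}, {"text": "b"}], [{"text": "c"}]]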


class KeyDataset(Dataset):
    def __init__(self, dataset: Dataset, key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i][self.key]
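

# A minimal sketch (illustrative only, not part of the original module): a list
# of dicts stands in for a `datasets.Dataset`, and `KeyDataset` projects out a
# single column so a pipeline sees bare strings.
def _demo_key_dataset():
    rows = [{"question": "hi", "answer": "hello"}, {"question": "bye", "answer": "goodbye"}]
    assert KeyDataset(rows, "question")[0] == "hi"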


class KeyPairDataset(Dataset):
    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}
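

# A minimal sketch (illustrative only, not part of the original module):
# `KeyPairDataset` pairs two columns into the `text`/`text_pair` shape that
# sentence-pair tokenization expects. `_demo_key_pair_dataset` is an assumed name.
def _demo_key_pair_dataset():
    rows = [{"question": "hi", "answer": "hello"}]
    assert KeyPairDataset(rows, "question", "answer")[0] == {"text": "hi", "text_pair": "hello"}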