nikhil_no_persistent / lilac /batch_utils.py
nsthorat-lilac's picture
Duplicate from lilacai/nikhil_staging
bfc0ec6
raw
history blame
3.47 kB
"""Utils for the python server."""
import itertools
from typing import Any, Callable, Generator, Iterable, Iterator, TypeVar, Union, cast
from .schema import Item
from .utils import chunks, is_primitive
def _deep_flatten(input: Union[Iterator, object],
is_primitive_predicate: Callable[[object], bool]) -> Generator:
"""Flattens a nested iterable."""
if is_primitive_predicate(input):
yield input
elif isinstance(input, dict):
yield input
elif is_primitive(input):
yield input
else:
for elem in cast(Iterator, input):
yield from _deep_flatten(elem, is_primitive_predicate)
def deep_flatten(input: Union[Iterator, Iterable],
                 is_primitive_predicate: Callable[[object], bool] = is_primitive) -> Iterator:
    """Return an iterator over the leaves of an arbitrarily nested iterable.

    Primitives and dictionaries are yielded as-is rather than being descended into. Callers may
    pass `is_primitive_predicate` to decide which additional values count as leaves; it
    defaults to `is_primitive`.
    """
    leaves = _deep_flatten(input, is_primitive_predicate)
    return leaves
def _deep_unflatten(flat_input: Iterator[list[object]], original_input: Union[Iterable, object],
is_primitive_predicate: Callable[[object], bool]) -> Union[list, dict]:
"""Unflattens a deeply flattened iterable according to the original iterable's structure."""
if is_primitive_predicate(original_input):
return next(flat_input)
else:
values: Iterable
if isinstance(original_input, dict):
values = original_input.values()
else:
values = cast(Iterable, original_input)
return [_deep_unflatten(flat_input, orig_elem, is_primitive_predicate) for orig_elem in values]
def deep_unflatten(flat_input: Union[Iterable, Iterator],
                   original_input: Union[Iterable, object],
                   is_primitive_predicate: Callable[[object], bool] = is_primitive) -> list:
    """Rebuild the nesting of `original_input` around the values of `flat_input`.

    `flat_input` is turned into an iterator and consumed in order, one value for each position
    of `original_input` that `is_primitive_predicate` (default `is_primitive`) accepts.
    """
    flat_values = iter(flat_input)
    rebuilt = _deep_unflatten(flat_values, original_input, is_primitive_predicate)
    return cast(list, rebuilt)
TFlatten = TypeVar('TFlatten')


def flatten(inputs: Iterable[Iterable[TFlatten]]) -> Iterator[TFlatten]:
    """Lazily yield the elements of each inner iterable, in order.

    Exactly one level of nesting is removed; deeper nesting is left untouched.
    """
    yield from itertools.chain.from_iterable(inputs)
TUnflatten = TypeVar('TUnflatten')


def unflatten(flat_inputs: Union[Iterable[TUnflatten], Iterator[TUnflatten]],
              original_inputs: Iterable[Iterable[Any]]) -> Iterator[list[TUnflatten]]:
    """Regroup a flat sequence of values into lists shaped like `original_inputs`.

    For every inner iterable of `original_inputs`, yields a list holding as many values from
    `flat_inputs` as that inner iterable has elements.
    """
    remaining = iter(flat_inputs)
    for shape in original_inputs:
        group: list[TUnflatten] = []
        # Only the element count of `shape` matters; its values are ignored.
        for _ in shape:
            group.append(next(remaining))
        yield group
TFlatBatchedInput = TypeVar('TFlatBatchedInput')
TFlatBatchedOutput = TypeVar('TFlatBatchedOutput')


def flat_batched_compute(input: Iterable[Iterable[TFlatBatchedInput]],
                         f: Callable[[list[TFlatBatchedInput]], Iterable[TFlatBatchedOutput]],
                         batch_size: int) -> Iterable[Iterable[TFlatBatchedOutput]]:
    """Run `f` over the flattened input in batches and regroup its outputs like the input.

    The input is flattened one level, handed to `chunks` together with `batch_size`, and `f` is
    called on each resulting batch. The outputs of `f` are flattened again and regrouped with
    `unflatten` so that the result has the same grouping as `input`.
    """
    # The input is walked twice, once for its values and once for its grouping, so split it
    # into two independent iterators.
    values_source, shape_source = itertools.tee(input)
    flat_batches = chunks(flatten(values_source), batch_size)
    flat_outputs = flatten((f(flat_batch) for flat_batch in flat_batches))
    return unflatten(flat_outputs, shape_source)
# Type variable bounded by `Item`: it can only be solved to `Item` or a subtype of it.
TBatchSpanVectorOutput = TypeVar('TBatchSpanVectorOutput', bound=Item)