|
import io |
|
import os |
|
from pathlib import Path |
|
from tempfile import TemporaryDirectory |
|
import torch |
|
import torchaudio |
|
import random |
|
import numpy as np |
|
from PIL import Image |
|
from urllib.parse import urlparse |
|
from os.path import exists |
|
import re |
|
from num2words import num2words |
|
import uuid |
|
|
|
from typing import List, Optional, Dict, Union, Tuple, Iterable |
|
|
|
from src.utils.image_utils import is_valid_image |
|
|
|
|
|
# Per-channel mean / standard-deviation constants: three entries, each 0.5.
# NOTE(review): despite the "IMAGENET" name these are a flat 0.5 per channel;
# presumably intended as the `image_mean` / `image_std` arguments of
# `process_images` -- confirm against callers.
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
|
|
|
|
|
def is_local(url):
    """Tell whether ``url`` refers to a path that exists on this machine.

    Only URLs with a ``file`` scheme or with no scheme at all are treated as
    local candidates; for those, the parsed path component is checked on
    disk. Any other scheme yields ``False`` without touching the filesystem.
    """
    parts = urlparse(url)
    if parts.scheme not in ("file", ""):
        return False
    return exists(parts.path)
|
|
|
|
|
def replace_numbers_with_words(sentence):
    """Spell out every run of digits in ``sentence`` as words.

    Each digit run is first padded with a space on both sides so that numbers
    glued to letters (``"a1b"``) become standalone tokens. Every standalone
    digit run is then passed to ``num2words``. The padding spaces are kept in
    the returned string.

    Args:
        sentence: Input text.

    Returns:
        The text with digit runs replaced by their ``num2words`` rendering.
        A run that ``num2words`` fails to convert is left as digits.
    """
    spaced = re.sub(r"(\d+)", r" \1 ", sentence)

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)
        except Exception:
            # Best effort: keep the original digits when conversion fails.
            # `Exception` (rather than a bare `except`) lets KeyboardInterrupt
            # and SystemExit propagate instead of being silently swallowed.
            return num

    return re.sub(r"\b\d+\b", replace_with_words, spaced)
|
|
|
|
|
def save_to_buffer(audio_tensors, codec_audio_sr):
    """Encode the concatenated audio tensors as WAV and return the raw bytes.

    The tensors are joined along dimension 1, written with ``torchaudio.save``
    into an in-memory buffer at sample rate ``int(codec_audio_sr)``, and the
    complete buffer contents are returned.
    """
    waveform = torch.cat(audio_tensors, 1)
    wav_stream = io.BytesIO()
    torchaudio.save(wav_stream, waveform, int(codec_audio_sr), format="wav")
    # getvalue() yields the whole buffer irrespective of the stream position.
    return wav_stream.getvalue()
|
|
|
|
|
def save_to_file(audio_tensors, codec_audio_sr):
    """Write the concatenated audio tensors to a fresh WAV file.

    The output directory ``media/voicecraft/generated`` (relative to the
    current working directory) is created if missing. The file name is a
    random UUID4 with a ``.wav`` suffix.

    Returns:
        The path of the written file as a string.
    """
    output_dir = "media/voicecraft/generated"
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    wav_path = f"{output_dir}/{uuid.uuid4()}.wav"
    waveform = torch.cat(audio_tensors, 1)
    torchaudio.save(wav_path, waveform, int(codec_audio_sr), format="wav")
    return wav_path
|
|
|
|
|
def split_line_to_sentences(line):
    """Split a line of text into sentences ending in ``.``, ``?`` or ``!``.

    The line is stripped and passed through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased). A trailing period is added
    when the non-empty line does not already end in sentence punctuation.
    Newlines are treated as spaces. Each returned sentence starts at a word
    character and runs up to the nearest terminator.
    """
    terminators = (".", "!", "?")

    text = line.strip().capitalize()
    if text and not text.endswith(terminators):
        text += "."

    flattened = text.replace("\n", " ")
    return re.findall(r"\w+.*?[.?!]", flattened, flags=re.S)
|
|
|
|
|
def seed_everything(seed=1):
    """Seed every random source this module relies on and pin cuDNN behaviour.

    Sets ``PYTHONHASHSEED`` in the environment, seeds ``random``, NumPy and
    PyTorch (CPU and CUDA), disables cuDNN benchmarking and enables cuDNN
    deterministic mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
|
|
|
|
|
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_length, image_token):
    """Build the model input string for a single image.

    The result is ``image_token`` repeated ``image_seq_length`` times,
    followed by ``bos_token``, the prompt text and a trailing newline.
    """
    image_prefix = image_token * image_seq_length
    return f"{image_prefix}{bos_token}{prefix_prompt}\n"
|
|
|
|
|
def rescale(
    image: np.ndarray, scale: float, dtype: np.dtype = np.float32
) -> np.ndarray:
    """Multiply ``image`` by ``scale`` and cast the product to ``dtype``.

    Args:
        image: Array of pixel values.
        scale: Multiplicative factor applied to every element.
        dtype: Output dtype; defaults to ``np.float32``.

    Returns:
        A new array holding the scaled values in the requested dtype.
    """
    return (image * scale).astype(dtype)
|
|
|
|
|
def resize(
    image: Image.Image,
    size: Tuple[int, int],
    resample: Optional[Image.Resampling] = None,
    reducing_gap: Optional[float] = None,
) -> Image.Image:
    """Resize ``image`` to the requested ``(height, width)``.

    Args:
        image: Object exposing a PIL-style ``resize`` method.
        size: Target size as ``(height, width)``. It is swapped to
            ``(width, height)`` before being handed to ``image.resize``.
        resample: Resampling filter forwarded unchanged to ``image.resize``.
        reducing_gap: Forwarded unchanged to ``image.resize``.

    Returns:
        Whatever ``image.resize`` returns (for a PIL image, a new PIL image).
        No conversion to a NumPy array happens here.
    """
    height, width = size
    return image.resize(
        (width, height), resample=resample, reducing_gap=reducing_gap
    )
|
|
|
|
|
def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
) -> np.ndarray:
    """Standardise ``image`` as ``(image - mean) / std``.

    ``mean`` and ``std`` are converted to arrays of the image's dtype first
    and then combined with the image through ordinary NumPy broadcasting.
    """
    offset = np.array(mean, dtype=image.dtype)
    spread = np.array(std, dtype=image.dtype)
    return (image - offset) / spread
|
|
|
|
|
def process_images(
    images: List[Image.Image],
    size: Optional[Tuple[int, int]] = None,
    resample: Optional[Image.Resampling] = None,
    rescale_factor: Optional[float] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
) -> List[np.ndarray]:
    """Resize, rescale, normalise and transpose a batch of images.

    Each image goes through, in order: ``resize`` to ``size``, conversion to
    a NumPy array, ``rescale`` by ``rescale_factor``, ``normalize`` with
    ``image_mean`` / ``image_std``, and an axis transpose ``(2, 0, 1)``.

    Args:
        images: Images to process.
        size: Target size indexed positionally: ``size[0]`` is the height and
            ``size[1]`` the width. Required in practice; the ``None`` default
            fails on indexing.
        resample: Resampling filter forwarded to ``resize``.
        rescale_factor: Factor forwarded to ``rescale``. Required in
            practice; the ``None`` default is not handled here.
        image_mean: Mean forwarded to ``normalize``. Not validated here.
        image_std: Standard deviation forwarded to ``normalize``. Not
            validated here.

    Returns:
        One array per input image with the original last axis moved first.
    """
    height, width = size[0], size[1]

    def _prepare(image):
        resized = resize(image=image, size=(height, width), resample=resample)
        pixels = np.array(resized)
        pixels = rescale(pixels, scale=rescale_factor)
        pixels = normalize(pixels, mean=image_mean, std=image_std)
        # Assumes a 3-D (H, W, C) array, e.g. an RGB image, so the result is
        # (C, H, W); a 2-D (grayscale) array would fail here -- TODO confirm
        # that callers always pass multi-channel images.
        return pixels.transpose(2, 0, 1)

    return [_prepare(image) for image in images]
|
|
|
|
|
def sample_top_p(probs: torch.Tensor, p: float):
    """Draw one index per row using top-p (nucleus) sampling.

    Args:
        probs: Probabilities along the last dimension.
        p: Cumulative-probability threshold.

    Returns:
        Tensor with a trailing dimension of size 1 holding, for each row, the
        sampled index into the last dimension of ``probs``.
    """
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)

    # `cumulative - sorted_probs` is the mass accumulated *before* each entry;
    # an entry is dropped once that preceding mass already exceeds p, so the
    # most likely entry is always kept.
    outside_nucleus = (cumulative - sorted_probs) > p
    kept = sorted_probs.masked_fill(outside_nucleus, 0.0)

    # Renormalise the surviving entries so each row sums to 1 again.
    kept = kept / kept.sum(dim=-1, keepdim=True)

    sampled_rank = torch.multinomial(kept, num_samples=1)
    # Translate positions in the sorted order back to original indices.
    return torch.gather(sorted_indices, -1, sampled_rank)
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean padding mask from per-sequence lengths.

    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        Minimum width of the mask; the actual width is the larger of this
        value and ``lengths.max()``.
    Returns:
      A 2-D bool tensor of shape ``(len(lengths), width)`` in which padded
      positions (index >= length) are ``True`` and valid positions are
      ``False``.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())

    positions = torch.arange(0, max_len, device=lengths.device)
    # Broadcast (1, width) against (n, 1) to compare every position with
    # every sequence length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)
|
|
|
|
|
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
    is_training: bool = False,
    token_type_ids: torch.Tensor = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal part
            is built.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`. Used as the "masked" fill value.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        is_training (`bool`):
            Whether the model is in training mode or in inference. The condition is checked by presence/absence of
            `token_type_ids/labels`.
        token_type_ids (`torch.Tensor`, *optional*):
            Indexed as a 2D tensor whose last dimension lines up with `attention_mask`. Only read when
            `is_training` is `True` and a 2D `attention_mask` is given; in that case it must not be `None`.

    Returns:
        `torch.Tensor`: the input `attention_mask` itself when it is already 4D, otherwise a new 4D mask holding `0`
        at positions left open and `min_dtype` at masked positions.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # A 4D mask is taken as already in its final form and returned as-is.
        causal_mask = attention_mask
    else:
        # Start with every (query, key) pair masked.
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )

        if sequence_length != 1:
            if is_training:
                # Keep `min_dtype` strictly above the diagonal, zero elsewhere.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                # Inference: open the first `sequence_length` key columns for every query.
                causal_mask[:, :sequence_length] = 0.0

        # Zero every key position at or before each query's cache position; only keys
        # strictly after it keep their current value.
        causal_mask *= torch.arange(
            target_length, device=cache_position.device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            # `expand` yields a view; clone so the in-place writes below have their own storage.
            causal_mask = (
                causal_mask.clone()
            )
            mask_length = attention_mask.shape[-1]
            # Positions where the causal value and the attention-mask value sum to 0 are re-masked.
            # Assumes `attention_mask` holds 0 for padding and non-zero for real tokens -- TODO confirm.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

            if is_training:
                # Key positions whose token type is 0 are reset to 0 (open) for every query,
                # overriding the causal and padding masking above.
                # NOTE(review): presumably type 0 marks the image/prefix tokens -- verify against caller.
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
    return causal_mask
|
|
|
|
|
|
|
def is_url(val) -> bool:
    """Return ``True`` when ``val`` is a string starting with ``"http"``."""
    if not isinstance(val, str):
        return False
    return val.startswith("http")
|
|
|
|
|
|
|
def is_image_or_image_url(elem):
    """Accept ``elem`` when ``is_url`` or ``is_valid_image`` accepts it.

    ``is_url`` is tried first; ``is_valid_image`` is only consulted when the
    URL check fails, and its result is returned unchanged.
    """
    if is_url(elem):
        return True
    return is_valid_image(elem)
|
|
|
|
|
def _is_str_or_image(elem):
    """Accept any string, or anything ``is_image_or_image_url`` accepts."""
    if isinstance(elem, str):
        return True
    return is_image_or_image_url(elem)
|
|
|
|
|
def generate_partial_autoregressive_mask(sz, start, end):
    """Build an ``sz`` x ``sz`` boolean mask that is non-trivial only in columns ``start:end``.

    Within columns ``start:end``:
      * rows ``start:end`` hold a strictly upper-triangular pattern
        (``True`` where column index > row index inside the block);
      * rows before ``start`` and rows from ``end`` onward are entirely ``True``.
    Every other entry is ``False``.
    """
    span = end - start
    mask = torch.zeros(sz, sz, dtype=torch.bool)

    upper = torch.triu(torch.ones(span, span, dtype=torch.bool), diagonal=1)
    mask[start:end, start:end] = upper

    mask[:start, start:end] = True
    mask[end:, start:end] = True
    return mask
|
|
|
|
|
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
    """Build the model input string for ``num_images`` images.

    The result is ``image_token`` repeated ``image_seq_len * num_images``
    times, followed by ``bos_token``, the prompt text and a trailing newline.
    """
    image_prefix = image_token * image_seq_len * num_images
    return f"{image_prefix}{bos_token}{prompt}\n"
|
|
|
|
|
def is_torchdynamo_compiling():
    """Report whether the code is currently being compiled by TorchDynamo.

    Tries ``torch.compiler.is_compiling()`` first, then falls back to
    ``torch._dynamo.is_compiling()``. If both attempts raise, returns
    ``False``.
    """
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        # Fall through to the fallback entry point below.
        pass

    try:
        import torch._dynamo as dynamo

        return dynamo.is_compiling()
    except Exception:
        return False
|
|