import io
import re
import subprocess
from collections import UserDict
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import TensorType
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin

from .image_processing_megrezo import MegrezOImageProcessor

AudioInput = Union[str, bytes, np.ndarray, List[str], List[bytes], List[np.ndarray]]
ReturnTensorType = Union[str, TensorType]


class ImageBatchFeature(BatchFeature):
    r"""
    Holds the image features of a batch of images.
    """

    pixel_values: Union[np.ndarray, torch.Tensor]
    image_sizes: Union[np.ndarray, torch.Tensor]
    tgt_sizes: Union[np.ndarray, torch.Tensor]
    patch_attention_mask: Union[np.ndarray, torch.Tensor]
    image_bounds: Union[np.ndarray, torch.Tensor]


class AudioBatchFeature(BatchFeature):
    r"""
    Holds the audio features of a batch of audio.
    """

    input_audios: List[Union[np.ndarray, torch.Tensor]]
    input_audio_lengths: List[Union[np.ndarray, torch.Tensor]]
    audio_span_tokens: List[Union[np.ndarray, torch.Tensor]]
    audio_bounds: Union[np.ndarray, torch.Tensor]


class ConvContent(UserDict):
    text: Optional[str] = None
    image: Optional[ImageInput] = None
    audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None


class Conversation(UserDict):
    role: Literal["user", "assistant"]
    content: Union[str, dict, ConvContent]


def load_audio(
    audio: Union[str, bytes],
    sample_rate: int = 16000,
) -> "np.ndarray":
    """Load audio from a file path or bytes and return it as a numpy array.

    Args:
        audio (Union[str, bytes]): path to an audio file or raw audio bytes.
        sample_rate (int, optional): target sample rate. Defaults to 16000.

    Raises:
        ValueError: if the input audio is neither a path nor bytes.

    Returns:
        np.ndarray: the mono waveform as a float32 numpy array in [-1, 1).
    """
    if isinstance(audio, str):
        inp = audio
        out = "-"
        cmd_inp = None
    elif isinstance(audio, bytes):
        inp = "pipe:"
        out = "pipe:"
        cmd_inp = audio
    else:
        raise ValueError("input audio must be either a path or bytes")

    # Decode with ffmpeg to 16-bit little-endian mono PCM at the requested
    # sample rate, written to stdout.
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        inp,
        "-f",
        "s16le",
        "-ac",
        "1",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sample_rate),
        out,
    ]

    out = subprocess.check_output(cmd, input=cmd_inp, stderr=subprocess.PIPE)
    arr = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
    return arr
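
# Usage sketch (illustrative; assumes ffmpeg is on PATH and "speech.wav" exists):
#
#   waveform = load_audio("speech.wav", sample_rate=16000)
#   print(waveform.dtype, waveform.shape)  # float32, (num_samples,)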


def load_image(
    image: Union[str, bytes, PIL.Image.Image],
) -> PIL.Image.Image:
    """Load an image from a file path or bytes and return it as a PIL image.

    Args:
        image (Union[str, bytes, PIL.Image.Image]): path to an image file, image bytes or a PIL image.

    Raises:
        ValueError: if the input image is neither a path, bytes nor a PIL image.

    Returns:
        PIL.Image.Image: the image as a PIL image.
    """
    if isinstance(image, PIL.Image.Image):
        return image

    if isinstance(image, str):
        img = PIL.Image.open(image)
    elif isinstance(image, bytes):
        img = PIL.Image.open(io.BytesIO(image))
    else:
        raise ValueError("input image must be either a path or bytes")

    return img
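
# Usage sketch (illustrative; assumes "photo.jpg" exists):
#
#   img = load_image("photo.jpg")
#   img = load_image(open("photo.jpg", "rb").read())  # bytes also work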


class MegrezOProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_feature_extractor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    _image_placeholder = r"(<image>./</image>)"
    _audio_placeholder = r"(<audio>./</audio>)"
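    # Note: these patterns are used both as regular expressions (re.findall) and as
    # literal split markers (str.split) below; the surrounding parentheses double as
    # a regex capture group, so each literal placeholder in the prompt yields exactly
    # one match.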

    def __init__(self, image_processor=None, audio_feature_extractor=None, tokenizer=None):
        super().__init__(image_processor, audio_feature_extractor, tokenizer)
        self.chat_template = self.tokenizer.chat_template

    def _parse_and_check_inputs(self, inputs) -> Tuple[List[Conversation], List, List]:
        if not isinstance(inputs, list):
            raise ValueError("inputs must be a list of conversations")

        conversations = []
        images = []
        audios = []

        for item in inputs:
            if not isinstance(item, (dict, Conversation)):
                raise ValueError("each element of inputs must be a dictionary or a Conversation object")

            role = item.get("role")
            content = item.get("content")
            if role is None or content is None:
                raise ValueError("role and content must be provided in each conversation")

            if isinstance(content, dict):
                content = ConvContent({**content})
            elif not isinstance(content, (str, ConvContent)):
                raise ValueError("content must be a string, a dictionary or a ConvContent object")

            if not isinstance(content, str):
                if content.get("image") is not None:
                    images.extend(content["image"] if isinstance(content["image"], list) else [content["image"]])

                if content.get("audio") is not None:
                    audios.extend(content["audio"] if isinstance(content["audio"], list) else [content["audio"]])

            conv = Conversation({"role": role, "content": content})
            conversations.append(conv)

        return conversations, images, audios
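
    # Expected `inputs` structure (illustrative example; "example.jpg" is a placeholder path):
    #
    #   [
    #       {
    #           "role": "user",
    #           "content": {
    #               "text": "(<image>./</image>)\nDescribe this image.",
    #               "image": "example.jpg",
    #           },
    #       },
    #   ]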

    def __call__(
        self,
        conversations: List[Conversation],
        apply_chat_template: bool = True,
        max_length: Optional[int] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        apply_data_collator: bool = True,
        **kwargs,
    ):
        assert return_tensors is TensorType.PYTORCH, "Only PyTorch tensors are supported for now."
        convs, images, audios = self._parse_and_check_inputs(conversations)
        add_generation_prompt = kwargs.pop("add_generation_prompt", True)
        if apply_chat_template:
            prompt = self.tokenizer.apply_chat_template(
                convs,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
        else:
            prompt = "\n".join([conv["content"] for conv in convs])

        prompt, multimodal_inputs = self.process_multimodal_inputs(
            prompt,
            images=images,
            audios=audios,
            return_tensors=return_tensors,
            **kwargs,
        )
        text_encodings = self.tokenizer(
            prompt,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=True,
            padding_side="left",
            truncation=True,
            **kwargs,
        )

        merged = self.merge_encodings(text_encodings, multimodal_inputs)

        if apply_data_collator:
            return self.data_collator([merged])

        return merged
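
    # End-to-end usage sketch (illustrative; `model_dir` is a placeholder path):
    #
    #   processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    #   batch = processor(
    #       [{"role": "user", "content": {"text": "(<image>./</image>)\nWhat is in the picture?",
    #                                     "image": "example.jpg"}}],
    #   )
    #   # `batch` contains input_ids, position_ids, attention_mask and optional
    #   # image_encoding / audio_encoding entries ready for the model.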

    def merge_encodings(self, text_encodings, multimodal_inputs):
        result = {
            "image_encoding": None,
            "audio_encoding": None,
        }

        result["input_ids"] = text_encodings["input_ids"].reshape(-1).to(torch.int32)
        # The pad token id is assumed to be 0 (see data_collator's default padding_value),
        # so non-zero ids are treated as real tokens.
        result["attention_mask"] = result["input_ids"].ne(0)
        result["position_ids"] = torch.arange(result["input_ids"].size(0)).long()

        if "image_encoding" in multimodal_inputs and multimodal_inputs["image_encoding"]:
            result["image_encoding"] = multimodal_inputs["image_encoding"]
            result["image_encoding"]["image_bounds"] = self.compute_bounds_image(result["input_ids"])

        if "audio_encoding" in multimodal_inputs and multimodal_inputs["audio_encoding"]:
            result["audio_encoding"] = multimodal_inputs["audio_encoding"]
            result["audio_encoding"]["audio_bounds"] = self.compute_bounds_audio(result["input_ids"])

        return result

    def compute_bounds_image(self, input_ids: torch.Tensor) -> torch.Tensor:
        # +1 so that each bound starts at the first content token after <im_start>/<slice_start>.
        image_start_ids = (
            torch.where((input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id))[0] + 1
        )
        image_end_ids = torch.where(
            (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
        )[0]

        valid_image_nums = max(len(image_start_ids), len(image_end_ids))
        bounds_image = torch.hstack(
            [
                image_start_ids[:valid_image_nums].unsqueeze(-1),
                image_end_ids[:valid_image_nums].unsqueeze(-1),
            ]
        )
        return bounds_image

    def compute_bounds_audio(self, input_ids: torch.Tensor) -> torch.Tensor:
        audio_bos_ids = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
        audio_eos_ids = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
        bounds_audio = torch.stack([audio_bos_ids, audio_eos_ids], 1)
        return bounds_audio
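
    # Both bounds tensors have shape (num_segments, 2): column 0 is the index of the
    # first token inside the segment, column 1 the index of the closing token. Image
    # bounds skip the start token itself (the +1 above); audio bounds keep the
    # audio-start position.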

    def process_multimodal_inputs(
        self,
        text: str,
        images: Optional[ImageInput] = None,
        audios: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ):
        if text is None and images is None and audios is None:
            raise ValueError("At least one of text, images or audio must be provided")

        image_processor_kwargs, audio_feature_extractor_kwargs = {}, {}
        if kwargs:
            image_processor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
            }
            audio_feature_extractor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.audio_feature_extractor._valid_processor_keys
            }

        multimodal_encodings = {
            "image_encoding": None,
            "audio_encoding": None,
        }

        if images:
            image_encoding = self.process_image(
                images,
                return_tensors=return_tensors,
                **image_processor_kwargs,
            )
            text = self.insert_image_feature_placeholders(text, image_encoding)
            multimodal_encodings["image_encoding"] = image_encoding

        if audios:
            audio_encoding = self.process_audio(
                audios,
                **audio_feature_extractor_kwargs,
            )
            text = self.insert_audio_feature_placeholders(text, audio_encoding)
            multimodal_encodings["audio_encoding"] = audio_encoding

        return text, multimodal_encodings

    def insert_image_feature_placeholders(
        self,
        prompt: str,
        image_features: ImageBatchFeature,
        max_slice_nums: Optional[int] = None,
        use_image_id: Optional[bool] = None,
    ) -> str:
        img_tags = re.findall(self._image_placeholder, prompt)
        assert len(img_tags) == len(
            image_features.image_sizes
        ), f"the number of image tags must match the number of images, got {len(img_tags)} and {len(image_features.image_sizes)}"

        text_chunks = prompt.split(self._image_placeholder)
        final_text = ""
        for i in range(len(img_tags)):
            final_text += text_chunks[i] + self.image_processor.get_slice_image_placeholder(
                image_features.image_sizes[i],
                i,
                max_slice_nums,
                use_image_id,
            )
        final_text += text_chunks[-1]

        return final_text

    def insert_audio_feature_placeholders(
        self,
        prompt: str,
        audio_features: AudioBatchFeature,
    ) -> str:
        audio_tags = re.findall(self._audio_placeholder, prompt)
        assert len(audio_tags) == len(
            audio_features.input_audios
        ), "the number of audio tags must match the number of audios"

        text_chunks = prompt.split(self._audio_placeholder)
        final_text = ""
        for idx in range(len(audio_features.input_audios)):
            final_text += text_chunks[idx] + (
                self.tokenizer.audio_start
                + self.tokenizer.unk_token * audio_features.audio_span_tokens[idx]
                + self.tokenizer.audio_end
            )
        final_text += text_chunks[-1]

        return final_text
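
    # For example, a single "(<audio>./</audio>)" tag is replaced by
    # audio_start + unk_token * audio_span_tokens[i] + audio_end, reserving one
    # placeholder token position per audio embedding that the model splices in.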

    def process_audio(
        self,
        audio_input: AudioInput,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ) -> AudioBatchFeature:
        if isinstance(audio_input, list):
            inputs = [x if isinstance(x, np.ndarray) else load_audio(x) for x in audio_input]
        elif isinstance(audio_input, np.ndarray):
            # Raw waveforms are assumed to be mono and already at the feature
            # extractor's sampling rate.
            inputs = [audio_input]
        elif isinstance(audio_input, (str, bytes)):
            inputs = [load_audio(audio_input)]
        else:
            raise ValueError("audio_input must be a path, bytes, a waveform array, or a list of these")

        features = self.audio_feature_extractor(
            inputs,
            sampling_rate=self.audio_feature_extractor.sampling_rate,
            return_attention_mask=True,
            return_token_timestamps=True,
            padding="max_length",
            return_tensors=return_tensors,
            **kwargs,
        )

        input_lengths = features["num_frames"]
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        input_audio_lengths = torch.stack([input_lengths, output_lengths], dim=1)
        audio_span_tokens = (output_lengths + 2).tolist()

        data = {
            "input_audios": features["input_features"],
            "input_audio_lengths": input_audio_lengths,
            "audio_span_tokens": audio_span_tokens,
        }

        return AudioBatchFeature(data=data)
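
    # The length arithmetic above mirrors the audio encoder's downsampling: the
    # Whisper-style front end roughly halves the number of frames once
    # ((n - 1) // 2 + 1) and a further pooling step halves it again
    # ((n - 2) // 2 + 1); audio_span_tokens adds 2 for the start/end tokens.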

    def pad_images(
        self,
        pixel_values_list: List[torch.Tensor],
        tgt_sizes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad images to the same size and return the padded pixel values and patch attention mask.

        Sliced patches may have different sizes. We pad them to the same length and return the
        padded pixel values together with the corresponding patch attention mask.
        """
        all_pixel_values = []
        for pixel_value in pixel_values_list:
            all_pixel_values.append(pixel_value.flatten(end_dim=1).permute(1, 0))

        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
        all_pixel_values = pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

        patch_attention_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
        for i in range(B):
            patch_attention_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        return all_pixel_values, patch_attention_mask
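
    # Shape sketch (assuming each input slice is laid out as (3, patch_size, n_i)):
    # the padded pixel_values come out as (B, 3, patch_size, L) with L the longest
    # flattened patch sequence, and patch_attention_mask is (B, 1, max_patches) with
    # True marking real (non-padded) patches.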

    def process_image(
        self,
        image_input: ImageInput,
        do_pad: bool = True,
        max_slice_nums: Optional[int] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ) -> ImageBatchFeature:
        if isinstance(image_input, list):
            image_input = [load_image(x) for x in image_input]
        elif isinstance(image_input, (str, bytes, PIL.Image.Image)):
            image_input = [load_image(image_input)]
        else:
            raise ValueError(f"image_input must be a path or bytes or a list of paths/bytes, not: {type(image_input)}")

        image_features = self.image_processor(
            image_input,
            do_pad=do_pad,
            max_slice_nums=max_slice_nums,
            return_tensors=return_tensors,
            **kwargs,
        )

        assert len(image_features.pixel_values) == 1, "images should be packed into one list."
        pixel_values = image_features.pixel_values[0]
        tgt_sizes = image_features.tgt_sizes[0]
        image_sizes = image_features.image_sizes[0]

        pixel_values, patch_attention_mask = self.pad_images(pixel_values, tgt_sizes)

        data = {
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
            "tgt_sizes": tgt_sizes,
            "patch_attention_mask": patch_attention_mask,
        }

        return ImageBatchFeature(data=data)
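
    # The returned ImageBatchFeature carries pixel_values, image_sizes, tgt_sizes and
    # patch_attention_mask; image_bounds is filled in later by merge_encodings once the
    # prompt has been tokenized.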

    def data_collator(self, examples, padding_value=0, max_length=4096, collate_labels=False):
        """Collate processed examples into a batch for the MegrezO model.

        Trims each sequence to ``max_length`` and pads ``input_ids``, ``position_ids`` and
        ``attention_mask``. For the image/audio bounds tensors, the batch index is prepended
        to each bound so that bounds from different examples can be stacked.
        """

        def trim_and_pad(seq, batch_first, padding_value):
            return pad_sequence(
                [s[:max_length] for s in seq],
                batch_first=batch_first,
                padding_value=padding_value,
            )

        input_ids = trim_and_pad(
            [example["input_ids"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )
        position_ids = trim_and_pad(
            [example["position_ids"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )
        attention_mask = trim_and_pad(
            [example["attention_mask"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )

        image_encoding_list = {
            "pixel_values": [],
            "image_bounds": [],
            "tgt_sizes": [],
            "patch_attention_mask": [],
        }
        for bid, example in enumerate(examples):
            image_encoding = example.get("image_encoding")
            if not image_encoding:
                continue

            image_encoding_list["pixel_values"].append(image_encoding["pixel_values"])
            image_encoding_list["tgt_sizes"].append(image_encoding["tgt_sizes"])
            image_encoding_list["patch_attention_mask"].append(image_encoding["patch_attention_mask"])

            # Prepend the batch index to each (start, end) bound so bounds from
            # different examples can be concatenated into a single tensor.
            bounds_with_bid = image_encoding["image_bounds"].clone()
            bounds_with_bid = torch.hstack(
                [
                    torch.full((bounds_with_bid.size(0), 1), bid, dtype=bounds_with_bid.dtype),
                    bounds_with_bid,
                ]
            )
            image_encoding_list["image_bounds"].append(bounds_with_bid)

        audio_encoding_list = {
            "input_audios": [],
            "input_audio_lengths": [],
            "audio_span_tokens": [],
            "audio_bounds": [],
        }
        for bid, example in enumerate(examples):
            audio_encoding = example.get("audio_encoding")
            if not audio_encoding:
                continue

            audio_encoding_list["input_audios"].append(audio_encoding["input_audios"])
            audio_encoding_list["input_audio_lengths"].append(audio_encoding["input_audio_lengths"])
            audio_encoding_list["audio_span_tokens"].extend(audio_encoding["audio_span_tokens"])
            bounds_with_bid = audio_encoding["audio_bounds"].clone()
            bounds_with_bid = torch.hstack(
                [
                    torch.full((bounds_with_bid.size(0), 1), bid, dtype=bounds_with_bid.dtype),
                    bounds_with_bid,
                ]
            )
            audio_encoding_list["audio_bounds"].append(bounds_with_bid)

        result = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "image_encoding": None,
            "audio_encoding": None,
        }

        if collate_labels:
            labels = trim_and_pad(
                [example["labels"] for example in examples],
                batch_first=True,
                padding_value=-100,
            )
            result["labels"] = labels

        if any(image_encoding_list.values()):
            result["image_encoding"] = {
                "pixel_values": torch.vstack(image_encoding_list["pixel_values"]),
                "tgt_sizes": torch.vstack(image_encoding_list["tgt_sizes"]),
                "patch_attention_mask": torch.vstack(image_encoding_list["patch_attention_mask"]),
                "image_bounds": torch.vstack(image_encoding_list["image_bounds"]),
            }
        if any(audio_encoding_list.values()):
            result["audio_encoding"] = {
                "input_audios": torch.vstack(audio_encoding_list["input_audios"]),
                "input_audio_lengths": torch.vstack(audio_encoding_list["input_audio_lengths"]),
                "audio_span_tokens": audio_encoding_list["audio_span_tokens"],
                "audio_bounds": torch.vstack(audio_encoding_list["audio_bounds"]),
            }
        return result