# llama-labahasa-11B / bahasa_processing.py
# Munggok: "init new model" (commit 2ab806d, 15 kB)
from typing import Optional, Union
import numpy as np
import torch
import transformers
from .bahasa_config import BahasaConfig
class BahasaProcessor(transformers.ProcessorMixin):
    """
    Constructs a Bahasa processor which wraps an audio processor and a text processor into a single processor.

    Args:
        audio_processor: The audio processor for the audio encoder.
        text_processor: The processor for the language model: either a tokenizer or a
            ``MllamaProcessor`` (whose ``tokenizer`` attribute is then used for text-only work).
    """

    attributes = ["audio_processor", "text_processor"]
    audio_processor_class = (
        "Wav2Vec2Processor",
        "SeamlessM4TFeatureExtractor",
        "WhisperProcessor",
    )
    text_processor_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
        "MllamaProcessor",
    )

    tokenizer: transformers.PreTrainedTokenizerBase
    text_processor: Union[
        transformers.ProcessorMixin, transformers.PreTrainedTokenizerBase
    ]
    audio_processor: transformers.ProcessorMixin

    def __init__(
        self,
        audio_processor=None,
        text_processor=None,
        audio_padding: str = "longest",
        encoder_ds_factor: int = 320,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            text_processor: The processor for the language model.
            audio_padding: The padding strategy for the audio encoder. It is forwarded to the
                audio processor as ``padding``. With ``"max_length"``, audio is padded to 30 s
                (which requires ``sampling_rate`` at call time).
            encoder_ds_factor: The downsample factor of the audio encoder (input samples per encoder frame).
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            audio_placeholder: The placeholder for the audio in the text.

        Raises:
            ValueError: If the tokenizer has no BOS token, which is used as the audio token replacement.
        """
        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
        # An MllamaProcessor bundles its own tokenizer; a plain tokenizer is used as-is.
        if isinstance(text_processor, transformers.MllamaProcessor):
            self.tokenizer = text_processor.tokenizer
        else:
            self.tokenizer = text_processor
        super().__init__(audio_processor=audio_processor, text_processor=text_processor)
        # In the text, each audio embedding frame is stood in for by one BOS token
        # (see __call__). NOTE(review): the model presumably overwrites those positions
        # with audio embeddings using audio_token_start_idx / audio_token_len; confirm there.
        self.audio_token_replacement = self.tokenizer.bos_token
        if self.audio_token_replacement is None:
            raise ValueError("The tokenizer has no BOS token. Cannot recover.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a processor from a Bahasa checkpoint.

        The audio processor is loaded from ``config.audio_model_id``. If that is unset, it
        falls back to ``config.audio_config._name_or_path``, then to
        ``"facebook/wav2vec2-base-960h"``. The text processor is loaded from the text
        config's ``name_or_path``. Its tokenizer is switched to left padding, uses EOS as
        the pad token, and gets the chat template below. Only ``stack_factor`` is taken from
        the config; the other constructor options keep their defaults.

        Args:
            pretrained_model_name_or_path: Checkpoint name or path, passed to ``AutoConfig``.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained`` and to the text
                processor's ``AutoProcessor.from_pretrained``.
        """
        config: BahasaConfig = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        audio_processor = transformers.AutoProcessor.from_pretrained(
            config.audio_model_id
            or config.audio_config._name_or_path
            or "facebook/wav2vec2-base-960h"
        )
        text_processor = transformers.AutoProcessor.from_pretrained(
            config._text_config.name_or_path, **kwargs
        )
        # The loaded object may wrap a tokenizer (e.g. MllamaProcessor) or be the
        # tokenizer itself; configure whichever actually tokenizes.
        tokenizer = getattr(text_processor, "tokenizer", text_processor)
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token
        new_template = """{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- if message['content'] is iterable and not message['content'] is string %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n\n{#- Always include system message, regardless of images #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n {%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"""
        tokenizer.chat_template = new_template
        return cls(
            audio_processor=audio_processor,
            text_processor=text_processor,
            stack_factor=config.stack_factor,
        )

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
        images: Optional[transformers.image_utils.ImageInput] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare one text sequence, one audio clip and optional images for the model.

        The audio is forwarded, with ``sampling_rate`` and ``kwargs``, to the audio
        processor. The text is forwarded, with ``kwargs`` and ``images`` (when given), to the
        text processor with ``add_special_tokens=False``. Before that, the audio placeholder
        in the text is replaced by one BOS token per predicted audio embedding frame.

        Args:
            text (`str`, *optional*):
                A single sequence to encode. Batch mode is not supported. Special tokens such
                as BOS must already be present (e.g. added by the chat template).
            audio (`np.ndarray` or `torch.Tensor`, *optional*):
                The audio to prepare. Its last dimension is treated as the number of samples.
            images (*optional*):
                Images forwarded to the text processor (e.g. a ``MllamaProcessor``). They are
                only passed on when not ``None``.
            sampling_rate (`int`, *optional*):
                Sampling rate of the input audio, forwarded to the audio processor. It is
                required when ``audio_padding == "max_length"``. 16 kHz audio is expected.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to PyTorch):
                The tensor type of the returned values, passed to ``BatchFeature``
                (e.g. ``'pt'``, ``'np'``, ``'tf'``, ``'jax'``).
            **kwargs:
                Forwarded to both the audio processor and the text processor.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids**, **attention_mask**, and any other outputs of the text processor:
              returned when `text` is not `None`.
            - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
            - **audio_token_len** -- Predicted number of audio embedding frames. This is a
              close upper bound. Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio
              starts. Returned when the text contains the audio placeholder.

        Raises:
            ValueError: If ``sampling_rate`` is missing with ``"max_length"`` padding; if the
                placeholder is used without audio; or if the placeholder appears more than once.
            TypeError: If ``text`` is not a string.
        """
        # TODO: Add support for multiple audio and text inputs.
        data = {}
        audio_embed_frames = 0
        if audio is not None and len(audio) > 0:
            if self.audio_padding == "max_length":
                # 30 seconds is the expected length for Whisper.
                if sampling_rate is None:
                    raise ValueError("Sampling rate must be provided.")
                audio_len = 30 * sampling_rate
            else:
                audio_len = audio.shape[-1]
            # It's guaranteed that the number of frames is less than or equal to this amount.
            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues
            # by padding the audio embeddings.
            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
            data["audio_token_len"] = [audio_embed_frames]

            # Main audio processing. The processor is model-specific. Use the configured
            # padding strategy so the padded length agrees with audio_len above
            # (hard-coding "longest" silently ignored audio_padding="max_length").
            x = self.audio_processor(
                audio,
                sampling_rate=sampling_rate,
                padding=self.audio_padding,
                max_length=audio_len,
                **kwargs,
            )
            data["audio_values"] = (
                x.input_features if "input_features" in x else x.input_values
            )

        if text is not None:
            if not isinstance(text, str):
                raise TypeError("Text must be a string. Batch mode not supported yet.")
            if self.audio_placeholder in text:
                if "audio_token_len" not in data:
                    raise ValueError(
                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
                    )
                # Only one audio_token_start_idx is reported, so extra placeholders
                # would become BOS runs with no audio behind them.
                if text.count(self.audio_placeholder) > 1:
                    raise ValueError(
                        f"Only one audio placeholder ({self.audio_placeholder}) per text is supported."
                    )
                start_idx = len(
                    self.tokenizer.encode(
                        text[: text.index(self.audio_placeholder)],
                        add_special_tokens=False,
                    )
                )
                data["audio_token_start_idx"] = [start_idx]
                # Replace the audio placeholder with one BOS token per audio embedding frame,
                # e.g. "Transcribe\n<|audio|>" -> "Transcribe\n" + bos_token * audio_embed_frames.
                text = text.replace(
                    self.audio_placeholder,
                    self.audio_token_replacement * audio_embed_frames,
                )

            # Special tokens like BOS should already have been added by the caller.
            text_kwargs = dict(kwargs)
            if images is not None:
                # Forward images only when present. NOTE(review): a bare tokenizer (an
                # allowed text_processor class) presumably rejects an `images` kwarg.
                text_kwargs["images"] = images
            data.update(
                self.text_processor(
                    text=[text], add_special_tokens=False, **text_kwargs
                )
            )

        return transformers.BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Decode a batch of token id sequences with the underlying tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode one token id sequence with the underlying tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Input names of both wrapped processors, text first, without duplicates."""
        text_processor_input_names = self.text_processor.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        # dict.fromkeys de-duplicates while keeping a deterministic order; a set of
        # strings iterates in an order that can change between interpreter runs.
        return list(
            dict.fromkeys([*text_processor_input_names, *audio_processor_input_names])
        )
# Hook BahasaProcessor into the Auto* machinery. The auto class is named explicitly
# (AutoProcessor), and BahasaConfig is mapped to this processor class.
BahasaProcessor.register_for_auto_class("AutoProcessor")
transformers.AutoProcessor.register(BahasaConfig, BahasaProcessor)