\nUSER: What's the content of the image?\nASSISTANT:"
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if inputs_embeds is None:
+            # 1. Extract the input embeddings
+            # some models have no spare embedding row for the image token, so swap it for the pad token before the embedding lookup
+            no_img_input_ids = torch.where(input_ids != self.config.image_token_index, input_ids, self.pad_token_id)
+ inputs_embeds = self.get_input_embeddings()(no_img_input_ids)
+ batch_size = inputs_embeds.shape[0]
+ # 2. Merge text and images
+ if pixel_values is not None and input_ids.shape[1] != 1:
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+                # not memory efficient: output_hidden_states=True keeps every layer's hidden states
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer] # ( b, img_seqlen, embed_dim)
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+                elif vision_feature_select_strategy == "full":
+                    raise ValueError("vision_feature_select_strategy 'full' is not implemented")
+ else:
+ raise ValueError(
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+ )
+
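+                # pixel_values stacks every frame of every video along dim 0, so the number of
+                # videos per sample is total_frames // num_frames // batch_size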
+                image_features = self.multi_modal_projector(
+                    selected_image_feature,
+                    media_type,
+                    batch_size=batch_size,
+                    num_videos=pixel_values.shape[0] // self.config.num_frames // batch_size,
+                )
+
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+ if labels is None:
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
+ else:
+            # In case input_ids.shape[1] == 1 & pixel_values != None & past_key_values != None, we are in the case of
+            # generation with cache
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
+ # that are set to 0
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+ # Get the target length
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
+
+ extended_attention_mask = torch.ones(
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+
+ # Filter out only the tokens that can be un-attended, this can happen
+ # if one uses Llava + Fused modules where the cache on the
+ # first iteration is already big enough, or if one passes custom cache
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+ # Zero-out the places where we don't need to attend
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
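+                # the single new token attends to everything kept so far, so its position id is
+                # the number of attended tokens minus one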
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs[0]
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:]
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return PllavaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
+ ):
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ elif self.config.image_token_index in input_ids:
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+ # older attention values, as their corresponding values are not part of the input.
+ if cache_length < past_length and attention_mask is not None:
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+ media_type = kwargs.get('media_type', None)
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "media_type": media_type,
+ }
+ )
+ return model_inputs
+
+ def _reorder_cache(self, *args, **kwargs):
+ return self.language_model._reorder_cache(*args, **kwargs)
diff --git a/models/pllava/processing_pllava.py b/models/pllava/processing_pllava.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f1211f0170b918628a5d4720c07478af2c18f35
--- /dev/null
+++ b/models/pllava/processing_pllava.py
@@ -0,0 +1,292 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for PLLaVA.
+"""
+
+
+import itertools
+from typing import List, Optional, Union
+import PIL.Image
+import numpy as np
+
+from transformers import AutoTokenizer
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import (
+ ImageInput,
+ make_list_of_images,
+ valid_images,
+ infer_channel_dimension_format,
+ to_numpy_array,
+ get_image_size,
+ ChannelDimension,
+)
+from transformers.image_processing_utils import get_size_dict
+from transformers.image_utils import PILImageResampling
+from transformers.processing_utils import ProcessorMixin
+from transformers.image_transforms import resize, pad, PaddingMode, to_channel_dimension_format, get_resize_output_image_size
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import TensorType
+
+
+class PllavaProcessor(ProcessorMixin):
+ r"""
+    Constructs a PLLaVA processor which wraps a CLIP image processor and a Llama tokenizer into a single processor.
+
+    [`PllavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
+    [`~PllavaProcessor.__call__`] and [`~PllavaProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`CLIPImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "CLIPImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None,
+ shortest_edge=336,
+ longest_edge=762,
+ center_pad=False):
+ self.shortest_edge = shortest_edge
+ self.longest_edge = longest_edge
+ self.center_pad = center_pad
+ super().__init__(image_processor, tokenizer)
+
+ def resize_crop_longshort(self, videos: list[list[np.ndarray]], input_data_format):
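+        """Resize all videos to a shared short edge, then crop them to a common spatial size.
+
+        The common aspect ratio is taken from the least elongated video (smallest long/short
+        ratio); each frame is resized so its short edge matches, and the excess along the long
+        axis is cropped away.
+        """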
+ video_spatial_sizes = [get_image_size(images[0], input_data_format) for images in videos]
+ long_short_rates = [max(size) / min(size) for size in video_spatial_sizes]
+ min_long_short_rate = min(long_short_rates)
+ min_long_short_video_idx = long_short_rates.index(min_long_short_rate)
+
+ clip_resolution = self.image_processor.size['shortest_edge']
+ out_video_spatial_size = video_spatial_sizes[min_long_short_video_idx]
+ out_videos_short_edge = max(min(size) for size in video_spatial_sizes)
+ resize_longest_edge = max(max(size) for size in video_spatial_sizes)
+ resize_longest_edge = min(640, resize_longest_edge)
+ out_videos_short_edge = min(out_videos_short_edge, int(resize_longest_edge / min_long_short_rate))
+ out_videos_short_edge = max(out_videos_short_edge, clip_resolution)
+
+        if out_video_spatial_size[0] > out_video_spatial_size[1]:  # h > w
+            out_video_spatial_size = (int(out_videos_short_edge * min_long_short_rate), out_videos_short_edge)
+        else:
+            out_video_spatial_size = (out_videos_short_edge, int(out_videos_short_edge * min_long_short_rate))
+ videos = [
+ [self.resize(frame, input_data_format=input_data_format, shortest_edge=out_videos_short_edge, longest_edge=9999) for frame in frames]
+ for frames in videos
+ ]
+ out_videos = []
+ for frames in videos:
+ out_frames = []
+ video_spatial_size = get_image_size(frames[0], input_data_format)
+ assert min(video_spatial_size) == out_videos_short_edge
+ overhead = (max(video_spatial_size) - max(out_video_spatial_size)) // 2
+ slice_start, slice_end = overhead // 2, overhead // 2 + max(out_video_spatial_size)
+            hslice, wslice = (
+                (slice(slice_start, slice_end), slice(None, None))  # h > w: crop along height
+                if video_spatial_size[0] > video_spatial_size[1]
+                else (slice(None, None), slice(slice_start, slice_end))  # w >= h: crop along width
+            )
+ for frame in frames:
+ if input_data_format == ChannelDimension.FIRST:
+ out_frames.append(frame[..., hslice, wslice])
+ elif input_data_format == ChannelDimension.LAST:
+ out_frames.append(frame[..., hslice, wslice, :])
+ out_videos.append(out_frames)
+
+ return out_videos
+
+ @staticmethod
+ def _compute_num_blocks_and_overlaps(input_shape, resolution):
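+        """Split an input shape into windows of size `resolution` and compute per-axis overlaps.
+
+        Worked example (illustrative values, not taken from the original code):
+        `_compute_num_blocks_and_overlaps((448, 672), 336)` returns `num_blocks=[2, 2]` and
+        `overlaps=[224, 0]`: 448 % 336 == 112, so the two windows along that axis overlap by
+        (336 - 112) // 1 == 224 pixels, while 672 is an exact multiple of 336 and needs no overlap.
+        """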
+ input_shape = np.array(input_shape)
+ resolution = np.array(resolution)
+ assert input_shape.max() >= resolution
+ num_blocks = np.ceil(input_shape / resolution).astype(np.int32).tolist()
+        overlaps = [
+            0 if size % resolution == 0
+            else int(np.floor((resolution - size % resolution) / (num_block - 1)))
+            for num_block, size in zip(num_blocks, input_shape)
+        ]
+ return num_blocks, overlaps
+
+ def resize(
+ self,
+ image: np.ndarray,
+ resample: PILImageResampling = PILImageResampling.BICUBIC, # type: ignore
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ shortest_edge: int = None,
+ longest_edge: int = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+        Resize an image. The shorter edge is resized to `shortest_edge` while keeping the input aspect
+        ratio, and the longer edge is capped at `longest_edge`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+            shortest_edge (`int`, *optional*):
+                Target length of the shorter edge. Defaults to the processor's `shortest_edge`.
+            longest_edge (`int`, *optional*):
+                Upper bound on the longer edge. Defaults to the processor's `longest_edge`.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ shortest_edge = getattr(self, 'shortest_edge', None) if shortest_edge is None else shortest_edge
+ longest_edge = getattr(self, 'longest_edge', None) if longest_edge is None else longest_edge
+ default_to_square = False
+ output_size = get_resize_output_image_size(
+ image,
+ size=shortest_edge,
+ default_to_square=default_to_square,
+ max_size=longest_edge,
+ input_data_format=input_data_format,
+ )
+ clip_resolution = self.image_processor.size['shortest_edge']
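+        # if capping the longer edge pushed the shorter edge below the CLIP input resolution,
+        # recompute the output size without the longest_edge cap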
+ if min(output_size) < clip_resolution:
+ output_size = get_resize_output_image_size(
+ image,
+ size=shortest_edge,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ images: ImageInput = None,
+ center_pad = None,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length=None,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+ ) -> BatchFeature:
+ """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+        data = dict()
+ if images is not None:
+ if isinstance(images, list) and isinstance(images[0], PIL.Image.Image):
+ videos = [images] # one video
+ else:
+ videos = images
+
+ pixel_values_list = []
+ videos = [[to_numpy_array(image) for image in make_list_of_images(images)] for images in videos]
+ # images = [self.resize(image, ) if min(get_image_size(image, input_data_format)) < clip_resolution else image for image in images]
+ input_data_format = infer_channel_dimension_format(videos[0][0])
+ videos = self.resize_crop_longshort(videos, input_data_format)
+
+ for images in videos:
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ center_pad = center_pad if center_pad is not None else self.center_pad
+ if center_pad:
+ images = [self.pad_to_square(image, 0, input_data_format, input_data_format) for image in images]
+
+ pixel_values = self.image_processor(images, return_tensors='np')["pixel_values"]
+ pixel_values_list.append(pixel_values)
+
+ pixel_values = np.concatenate(pixel_values_list)
+ data.update(pixel_values=pixel_values)
+
+ else:
+            data.update(pixel_values=None)
+
+ if text is not None:
+ text_inputs = self.tokenizer(
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
+ )
+ data.update(**text_inputs)
+ return BatchFeature(data, tensor_type=return_tensors)
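+    # Illustrative usage sketch (the checkpoint id and frame filenames below are assumptions,
+    # not part of this diff):
+    #     processor = PllavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+    #     frames = [PIL.Image.open(f"frame_{i:02d}.jpg") for i in range(16)]
+    #     inputs = processor(text="USER: <image> Describe the video. ASSISTANT:", images=frames)
+    #     # -> BatchFeature with input_ids, attention_mask and pixel_values (one entry per frame)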
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
diff --git a/python_scripts/hf.py b/python_scripts/hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..83929d1738b85594130d80e99072e7abaa63bfaf
--- /dev/null
+++ b/python_scripts/hf.py
@@ -0,0 +1,80 @@
+import os.path as osp
+import os
+import re
+import multiprocessing
+import functools
+import huggingface_hub
+from huggingface_hub import snapshot_download
+
+
+def upload(repo_id, local_dir, path_in_repo, repo_type, token):
+ huggingface_hub.upload_folder(
+ repo_id=repo_id,
+ folder_path=local_dir,
+ path_in_repo=path_in_repo,
+ token=token,
+ repo_type=repo_type
+ )
+
+def download(repo_id, local_dir, repo_type, token, filter_re=None):
+ files = huggingface_hub.list_repo_files(repo_id, repo_type=repo_type, token=token)
+ if filter_re is not None:
+ files = [file for file in files if re.search(filter_re, file) is not None]
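+    # fetch the files in parallel with up to 8 worker processes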
+ pool = multiprocessing.Pool(8)
+ download_func = functools.partial(
+ huggingface_hub.hf_hub_download,
+ repo_id,
+ repo_type=repo_type,
+ local_dir=local_dir,
+ local_dir_use_symlinks=True,
+ token=token
+ )
+ pool.map(download_func, files)
+ print(f'downloaded files {files}')
+
+
+def upload_file(repo_id, file_path, repo_type, token):
+ huggingface_hub.upload_file(
+ repo_id=repo_id,
+ path_or_fileobj=file_path,
+ path_in_repo=file_path,
+ token=token,
+ repo_type=repo_type,
+ )
+
+if __name__ == '__main__':
+ read_token = '...'
+ write_token = '...'
+ repo_id = '...'
+ local_dir = '...'
+ repo_type = '...'
+
+
+ # #############
+ # # Examples on most simple hf usage
+    # # download
+ # filters = []
+ # for filter_re in filters:
+ # download(repo_id,
+ # local_dir,
+ # repo_type,
+ # filter_re)
+
+ # # upload
+ # upload(repo_id, local_dir, local_dir, repo_type, write_token)
+ # #############
+
+ # download models
+ repo_ids = [
+ 'ermu2001/pllava-7b',
+ 'ermu2001/pllava-13b',
+ ]
+ for repo_id in repo_ids:
+ local_dir = repo_id.replace('ermu2001', 'MODELS')
+ snapshot_download(
+ repo_id,
+ local_dir=local_dir,
+ repo_type='model',
+ local_dir_use_symlinks=True,
+ token=read_token,
+ )
\ No newline at end of file
diff --git a/requirements.no_torch.txt b/requirements.no_torch.txt
new file mode 100644
index 0000000000000000000000000000000000000000..307cc3971f513ed7e81ae7f122060941eea2dc00
--- /dev/null
+++ b/requirements.no_torch.txt
@@ -0,0 +1,244 @@
+absl-py==2.1.0
+accelerate==0.26.1
+addict==2.4.0
+aiofiles==23.2.1
+aliyun-python-sdk-core==2.15.0
+aliyun-python-sdk-kms==2.16.2
+altair==5.2.0
+annotated-types==0.6.0
+antlr4-python3-runtime==4.9.3
+anyio==4.3.0
+anykeystore==0.2
+apex==0.9.10.dev0
+appdirs==1.4.4
+argcomplete==3.2.3
+attrs==23.2.0
+av==10.0.0
+beautifulsoup4==4.12.3
+blessed==1.20.0
+blessings==1.7
+boto3==1.34.63
+botocore==1.34.63
+Brotli==1.1.0
+cachetools==5.3.3
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+contourpy==1.2.0
+crcmod==1.7
+cryptacular==1.6.2
+cryptography==42.0.5
+cycler==0.12.1
+dacite==1.7.0
+decorator==4.4.2
+decord==0.6.0
+deepspeed==0.14.0
+defusedxml==0.7.1
+Deprecated==1.2.14
+dill==0.3.8
+distro==1.9.0
+dnspython==2.6.1
+docker-pycreds==0.4.0
+einops==0.6.1
+exceptiongroup==1.2.0
+fastapi==0.110.0
+ffmpeg==1.4
+ffmpy==0.3.2
+fiftyone==0.23.6
+fiftyone-brain==0.16.1
+fiftyone_db==1.1.2
+filelock==3.9.0
+flash-attn==2.5.6
+fonttools==4.49.0
+fsspec==2024.2.0
+ftfy==6.1.3
+future==1.0.0
+fvcore==0.1.5.post20221221
+gdown==5.1.0
+gitdb==4.0.11
+GitPython==3.1.42
+glob2==0.7
+google-auth==2.28.2
+google-auth-oauthlib==1.2.0
+gpustat==1.1.1
+gradio==4.21.0
+gradio_client==0.12.0
+graphql-core==3.2.3
+greenlet==3.0.3
+grpcio==1.62.1
+h11==0.14.0
+h2==4.1.0
+hjson==3.1.0
+hpack==4.0.0
+httpcore==1.0.4
+httpx==0.27.0
+huggingface-hub==0.21.4
+humanize==4.9.0
+hupper==1.12.1
+Hypercorn==0.16.0
+hyperframe==6.0.1
+idna==3.6
+idscheck==2.3.0
+imageio==2.27.0
+imageio-ffmpeg==0.4.9
+importlib_metadata==7.0.2
+importlib_resources==6.3.0
+inflate64==1.0.0
+iopath==0.1.10
+Jinja2==3.1.2
+jmespath==0.10.0
+joblib==1.3.2
+jsonlines==4.0.0
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+kaleido==0.2.1
+kiwisolver==1.4.5
+lazy_loader==0.3
+Markdown==3.6
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+matplotlib==3.8.3
+mdurl==0.1.2
+mmcv-full==1.7.2
+model-index==0.1.11
+mongoengine==0.24.2
+motor==3.3.2
+moviepy==1.0.3
+mpmath==1.3.0
+multivolumefile==0.2.3
+networkx==3.2.1
+ninja==1.11.1.1
+numpy
+oauthlib==3.2.2
+omegaconf==2.3.0
+openai==1.14.0
+opencv-python==4.9.0.80
+opencv-python-headless==4.9.0.80
+opendatalab==0.0.10
+openmim==0.3.9
+openxlab==0.0.36
+ordered-set==4.1.0
+orjson==3.9.15
+oss2==2.17.0
+packaging==24.0
+pandas==1.5.3
+PasteDeploy==3.1.0
+pathtools==0.1.2
+pbkdf2==1.3
+peft==0.10.0
+pillow==10.2.0
+plaster==1.1.2
+plaster-pastedeploy==1.0.1
+platformdirs==4.2.0
+plotly==5.20.0
+portalocker==2.8.2
+pprintpp==0.4.0
+priority==2.0.0
+proglog==0.1.10
+protobuf==4.23.4
+psutil==5.9.4
+py-cpuinfo==9.0.0
+py7zr==0.21.0
+pyasn1==0.5.1
+pyasn1-modules==0.3.0
+pybcj==1.0.2
+pycparser==2.21
+pycryptodome==3.20.0
+pycryptodomex==3.20.0
+pydantic==2.6.4
+pydantic_core==2.16.3
+pydub==0.25.1
+Pygments==2.17.2
+pymongo==4.6.2
+pynvml==11.5.0
+pyparsing==3.1.2
+pyppmd==1.1.0
+pyramid==2.0.2
+pyramid-mailer==0.15.1
+PySocks==1.7.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.9
+python3-openid==3.2.0
+pytz==2023.4
+PyYAML==6.0
+pyzstd==0.15.9
+rarfile==4.1
+referencing==0.33.0
+regex==2023.12.25
+repoze.sendmail==4.4.1
+requests==2.28.2
+requests-oauthlib==1.4.0
+retrying==1.3.4
+rich==13.4.2
+rpds-py==0.18.0
+rsa==4.9
+ruff==0.3.2
+s3transfer==0.10.1
+safetensors==0.4.2
+scikit-image==0.22.0
+scikit-learn==1.4.1.post1
+scipy==1.10.1
+semantic-version==2.10.0
+sentencepiece==0.2.0
+sentry-sdk==1.42.0
+setproctitle==1.3.3
+shellingham==1.5.4
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+sortedcontainers==2.4.0
+soupsieve==2.5
+SQLAlchemy==2.0.28
+sse-starlette==0.10.3
+sseclient-py==1.8.0
+starlette==0.36.3
+strawberry-graphql==0.138.1
+sympy==1.12
+tabulate==0.9.0
+taskgroup==0.0.0a4
+tenacity==8.2.3
+tensorboard==2.15.1
+tensorboard-data-server==0.7.2
+tensorboardX==2.6.2.2
+termcolor==2.3.0
+texttable==1.7.0
+threadpoolctl==3.3.0
+tifffile==2024.2.12
+timm==0.6.12
+tokenizers==0.15.2
+tomli==2.0.1
+tomlkit==0.12.0
+toolz==0.12.1
+tqdm==4.65.2
+transaction==4.0
+transformers==4.37.1
+translationstring==1.4
+triton==2.2.0
+typer==0.9.0
+typing_extensions==4.8.0
+tzdata==2024.1
+tzlocal==5.2
+universal-analytics-python3==1.1.1
+urllib3==1.26.18
+uvicorn==0.28.0
+velruse==1.1.1
+venusian==3.1.0
+voxel51-eta==0.12.6
+wandb==0.14.0
+wcwidth==0.2.13
+WebOb==1.8.7
+websockets==11.0.3
+Werkzeug==3.0.1
+wrapt==1.16.0
+wsproto==1.2.0
+WTForms==3.1.2
+wtforms-recaptcha==0.3.2
+xmltodict==0.13.0
+yacs==0.1.8
+yapf==0.40.2
+zipp==3.18.1
+zope.deprecation==5.0
+zope.interface==6.2
+zope.sqlalchemy==3.1
diff --git a/requirements.torch.txt b/requirements.torch.txt
new file mode 100644
index 0000000000000000000000000000000000000000..75367ad5ca53ff03cc399347237e3f565f9dee34
--- /dev/null
+++ b/requirements.torch.txt
@@ -0,0 +1,4 @@
+--index-url https://download.pytorch.org/whl/cu118
+torch==2.2.1
+torchaudio==2.2.1
+torchvision==0.17.1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..54656e56be266fb3a4a6a4769be4e0f83874c2bf
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,246 @@
+absl-py==2.1.0
+accelerate==0.26.1
+addict==2.4.0
+aiofiles==23.2.1
+aliyun-python-sdk-core==2.15.0
+aliyun-python-sdk-kms==2.16.2
+altair==5.2.0
+annotated-types==0.6.0
+antlr4-python3-runtime==4.9.3
+anyio==4.3.0
+anykeystore==0.2
+apex==0.9.10.dev0
+appdirs==1.4.4
+argcomplete==3.2.3
+attrs==23.2.0
+av==10.0.0
+beautifulsoup4==4.12.3
+blessed==1.20.0
+blessings==1.7
+boto3==1.34.63
+botocore==1.34.63
+Brotli==1.1.0
+cachetools==5.3.3
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+contourpy==1.2.0
+crcmod==1.7
+cryptacular==1.6.2
+cryptography==42.0.5
+cycler==0.12.1
+dacite==1.7.0
+decorator==4.4.2
+decord==0.6.0
+deepspeed==0.14.0
+defusedxml==0.7.1
+Deprecated==1.2.14
+dill==0.3.8
+distro==1.9.0
+dnspython==2.6.1
+docker-pycreds==0.4.0
+einops==0.6.1
+exceptiongroup==1.2.0
+fastapi==0.110.0
+ffmpeg==1.4
+ffmpy==0.3.2
+fiftyone==0.23.6
+fiftyone-brain==0.16.1
+fiftyone_db==1.1.2
+filelock==3.9.0
+fonttools==4.49.0
+fsspec==2024.2.0
+ftfy==6.1.3
+future==1.0.0
+fvcore==0.1.5.post20221221
+gdown==5.1.0
+gitdb==4.0.11
+GitPython==3.1.42
+glob2==0.7
+google-auth==2.28.2
+google-auth-oauthlib==1.2.0
+gpustat==1.1.1
+gradio==4.21.0
+gradio_client==0.12.0
+graphql-core==3.2.3
+greenlet==3.0.3
+grpcio==1.62.1
+h11==0.14.0
+h2==4.1.0
+hjson==3.1.0
+hpack==4.0.0
+httpcore==1.0.4
+httpx==0.27.0
+huggingface-hub==0.21.4
+humanize==4.9.0
+hupper==1.12.1
+Hypercorn==0.16.0
+hyperframe==6.0.1
+idna==3.6
+idscheck==2.3.0
+imageio==2.27.0
+imageio-ffmpeg==0.4.9
+importlib_metadata==7.0.2
+importlib_resources==6.3.0
+inflate64==1.0.0
+iopath==0.1.10
+Jinja2==3.1.2
+jmespath==0.10.0
+joblib==1.3.2
+jsonlines==4.0.0
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+kaleido==0.2.1
+kiwisolver==1.4.5
+lazy_loader==0.3
+Markdown==3.6
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+matplotlib==3.8.3
+mdurl==0.1.2
+mmcv-full==1.7.2
+model-index==0.1.11
+mongoengine==0.24.2
+motor==3.3.2
+moviepy==1.0.3
+mpmath==1.3.0
+multivolumefile==0.2.3
+networkx==3.2.1
+ninja==1.11.1.1
+numpy==1.23.5
+oauthlib==3.2.2
+omegaconf==2.3.0
+openai==1.14.0
+opencv-python==4.9.0.80
+opencv-python-headless==4.9.0.80
+opendatalab==0.0.10
+openmim==0.3.9
+openxlab==0.0.36
+ordered-set==4.1.0
+orjson==3.9.15
+oss2==2.17.0
+packaging==24.0
+pandas==1.5.3
+PasteDeploy==3.1.0
+pathtools==0.1.2
+pbkdf2==1.3
+peft==0.10.0
+pillow==10.2.0
+plaster==1.1.2
+plaster-pastedeploy==1.0.1
+platformdirs==4.2.0
+plotly==5.20.0
+portalocker==2.8.2
+pprintpp==0.4.0
+priority==2.0.0
+proglog==0.1.10
+protobuf==4.23.4
+psutil==5.9.4
+py-cpuinfo==9.0.0
+py7zr==0.21.0
+pyasn1==0.5.1
+pyasn1-modules==0.3.0
+pybcj==1.0.2
+pycparser==2.21
+pycryptodome==3.20.0
+pycryptodomex==3.20.0
+pydantic==2.6.4
+pydantic_core==2.16.3
+pydub==0.25.1
+Pygments==2.17.2
+pymongo==4.6.2
+pynvml==11.5.0
+pyparsing==3.1.2
+pyppmd==1.1.0
+pyramid==2.0.2
+pyramid-mailer==0.15.1
+PySocks==1.7.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.9
+python3-openid==3.2.0
+pytz==2023.4
+PyYAML==6.0
+pyzstd==0.15.9
+rarfile==4.1
+referencing==0.33.0
+regex==2023.12.25
+repoze.sendmail==4.4.1
+requests==2.28.2
+requests-oauthlib==1.4.0
+retrying==1.3.4
+rich==13.4.2
+rpds-py==0.18.0
+rsa==4.9
+ruff==0.3.2
+s3transfer==0.10.1
+safetensors==0.4.2
+scikit-image==0.22.0
+scikit-learn==1.4.1.post1
+scipy==1.10.1
+semantic-version==2.10.0
+sentencepiece==0.2.0
+sentry-sdk==1.42.0
+setproctitle==1.3.3
+shellingham==1.5.4
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+sortedcontainers==2.4.0
+soupsieve==2.5
+SQLAlchemy==2.0.28
+sse-starlette==0.10.3
+sseclient-py==1.8.0
+starlette==0.36.3
+strawberry-graphql==0.138.1
+sympy==1.12
+tabulate==0.9.0
+taskgroup==0.0.0a4
+tenacity==8.2.3
+tensorboard==2.15.1
+tensorboard-data-server==0.7.2
+tensorboardX==2.6.2.2
+termcolor==2.3.0
+texttable==1.7.0
+threadpoolctl==3.3.0
+tifffile==2024.2.12
+timm==0.6.12
+tokenizers==0.15.2
+tomli==2.0.1
+tomlkit==0.12.0
+toolz==0.12.1
+torch==2.2.1
+torchaudio==2.2.1
+torchvision==0.17.1
+tqdm==4.65.2
+transaction==4.0
+transformers
+translationstring==1.4
+triton==2.2.0
+typer==0.9.0
+typing_extensions==4.8.0
+tzdata==2024.1
+tzlocal==5.2
+universal-analytics-python3==1.1.1
+urllib3==1.26.18
+uvicorn==0.28.0
+velruse==1.1.1
+venusian==3.1.0
+voxel51-eta==0.12.6
+wandb==0.14.0
+wcwidth==0.2.13
+WebOb==1.8.7
+websockets==11.0.3
+Werkzeug==3.0.1
+wrapt==1.16.0
+wsproto==1.2.0
+WTForms==3.1.2
+wtforms-recaptcha==0.3.2
+xmltodict==0.13.0
+yacs==0.1.8
+yapf==0.40.2
+zipp==3.18.1
+zope.deprecation==5.0
+zope.interface==6.2
+zope.sqlalchemy==3.1
diff --git a/scripts/accel_config_deepspeed_zero2.yaml b/scripts/accel_config_deepspeed_zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee8d5e49ae4c5d253ba8c1ea0ffe7b729b905cfd
--- /dev/null
+++ b/scripts/accel_config_deepspeed_zero2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 8
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
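+# note: with num_processes: 4 and gradient_accumulation_steps: 8 above, the effective batch
+# size is 32x the per-device batch size set in the training config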
diff --git a/scripts/accel_config_deepspeed_zero3_offload.yaml b/scripts/accel_config_deepspeed_zero3_offload.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..436357c30fc3ca74e68eded9495fef8b3b244f22
--- /dev/null
+++ b/scripts/accel_config_deepspeed_zero3_offload.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 2
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml b/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..333b4f18e6e540b162c9846d7632c64d6c8827e0
--- /dev/null
+++ b/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 2
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_process_ip: fdbd:dc61:18:8::20
+main_process_port: 6876
+main_training_function: main
+mixed_precision: bf16
+num_machines: 2
+num_processes: 16
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml b/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..333b4f18e6e540b162c9846d7632c64d6c8827e0
--- /dev/null
+++ b/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 2
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_process_ip: fdbd:dc61:18:8::20
+main_process_port: 6876
+main_training_function: main
+mixed_precision: bf16
+num_machines: 2
+num_processes: 16
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml b/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f2c57be497505189415cf0ffbf98af21516f676a
--- /dev/null
+++ b/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 2
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 1
+main_process_ip: fdbd:dc61:18:8::20
+main_process_port: 6876
+main_training_function: main
+mixed_precision: bf16
+num_machines: 2
+num_processes: 16
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml b/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0583d16a02f966f66c74e02edb80da970f6dceee
--- /dev/null
+++ b/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 16
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_multigpu.yaml b/scripts/accel_config_multigpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dbe0dc7b6ade744eca906c95a06c018f21cac09f
--- /dev/null
+++ b/scripts/accel_config_multigpu.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: 2,3,4,5
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_multinode.yaml b/scripts/accel_config_multinode.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b437201b4d6b27f339756bc44061c9e3f568c50c
--- /dev/null
+++ b/scripts/accel_config_multinode.yaml
@@ -0,0 +1,18 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 1
+main_process_ip: 10.193.16.150
+main_process_port: 6784
+main_training_function: main
+mixed_precision: bf16
+num_machines: 2
+num_processes: 16
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/accel_config_singlegpu.yaml b/scripts/accel_config_singlegpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cda636385ae4afb7425dbb4ed6c2630ec42b6c70
--- /dev/null
+++ b/scripts/accel_config_singlegpu.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: 'NO'
+downcast_bf16: 'no'
+gpu_ids: '0'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/demo.sh b/scripts/demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5b6dfd2f00f4f463b91dc3911efdc66b2c8b97f0
--- /dev/null
+++ b/scripts/demo.sh
@@ -0,0 +1,32 @@
+model_dir=${1:-"MODELS/pllava-7b"}
+weight_dir=${2:-"${model_dir}"}
+num_frames=16
+lora_alpha=4
+
+echo Running DEMO from model_dir: ${model_dir}
+echo Running DEMO from weights_dir: ${weight_dir}
+echo Running DEMO On Devices: ${CUDA_VISIBLE_DEVICES}
+
+
+# # 34B: need to use dispatch for a model this large.
+# CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} python -m tasks.eval.demo.pllava_demo \
+# --pretrained_model_name_or_path ${model_dir} \
+# --num_frames ${num_frames} \
+# --use_lora \
+# --weight_dir ${weight_dir} \
+# --lora_alpha ${lora_alpha} \
+# --conv_mode eval_vcg_llava_next \
+# --use_multi_gpus \
+
+
+# 7B and 13B: there are problems if the model is split across A100 40G GPUs, probably because of some unknown bug in accelerate dispatch
+CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1"} python -m tasks.eval.demo.pllava_demo \
+ --pretrained_model_name_or_path ${model_dir} \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --weight_dir ${weight_dir} \
+ --lora_alpha ${lora_alpha} \
+ --conv_mode plain \
+ --use_multi_gpus
+
+
diff --git a/scripts/eval.sh b/scripts/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..db91cdde75b56e1c15c0bfc1e76b0d1764d08ac2
--- /dev/null
+++ b/scripts/eval.sh
@@ -0,0 +1,104 @@
+# export CUDA_VISIBLE_DEVICES=2,6,7
+export OPENAI_API_KEY=...
+num_frames=16
+test_ratio=1
+
+# 13b, uses offload thus saving the full model
+model_dir=MODELS/pllava-13b
+weight_dir=MODELS/pllava-13b
+SAVE_DIR=test_results/test_pllava_13b
+lora_alpha=4
+conv_mode=eval_vcgbench
+python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/vcgbench \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --pooling_shape 16-12-12 \
+ --test_ratio ${test_ratio} \
+ --conv_mode ${conv_mode}
+
+conv_mode=eval_mvbench
+python -m tasks.eval.mvbench.pllava_eval_mvbench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/mvbench \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --num_frames ${num_frames} \
+ --weight_dir ${weight_dir} \
+ --pooling_shape 16-12-12 \
+ --conv_mode ${conv_mode}
+
+conv_mode=eval_videoqabench
+python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/videoqabench \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --test_ratio ${test_ratio} \
+ --conv_mode ${conv_mode}
+
+
+conv_mode=eval_recaption
+python -m tasks.eval.recaption.pllava_recaption \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/recaption \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --weight_dir ${weight_dir} \
+ --lora_alpha ${lora_alpha} \
+ --test_ratio ${test_ratio} \
+ --conv_mode ${conv_mode}
+
+
+model_dir=MODELS/pllava-7b
+weight_dir=MODELS/pllava-7b
+SAVE_DIR=test_results/test_pllava_7b
+lora_alpha=4
+
+conv_mode=eval_vcgbench
+python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/vcgbench \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --pooling_shape 16-12-12 \
+ --test_ratio ${test_ratio}
+
+
+conv_mode=eval_mvbench
+python -m tasks.eval.mvbench.pllava_eval_mvbench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/mvbench \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --num_frames ${num_frames} \
+ --weight_dir ${weight_dir} \
+ --pooling_shape 16-12-12
+
+
+conv_mode=eval_videoqabench
+python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/videoqabench \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --test_ratio ${test_ratio}
+
+conv_mode=eval_recaption
+python -m tasks.eval.recaption.pllava_recaption \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/recaption \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --test_ratio ${test_ratio}
\ No newline at end of file
diff --git a/scripts/eval_yiprompt.sh b/scripts/eval_yiprompt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0307017c9d314133a2a2071d2b418a782ddc8a2d
--- /dev/null
+++ b/scripts/eval_yiprompt.sh
@@ -0,0 +1,53 @@
+# export CUDA_VISIBLE_DEVICES=0,3,4,5,6,7
+export OPENAI_API_KEY=...
+num_frames=16
+test_ratio=200
+
+model_dir=MODELS/pllava-34b
+weight_dir=MODELS/pllava-34b
+SAVE_DIR=test_results/test_pllava_34b
+lora_alpha=4
+conv_mode=eval_vcg_llavanext
+python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/vcgbench \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --pooling_shape 16-12-12 \
+ --test_ratio ${test_ratio} \
+ --conv_mode $conv_mode
+
+conv_mode=eval_mvbench_llavanext
+python -m tasks.eval.mvbench.pllava_eval_mvbench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/mvbench \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --num_frames ${num_frames} \
+ --weight_dir ${weight_dir} \
+ --pooling_shape 16-12-12 \
+ --conv_mode $conv_mode
+
+conv_mode=eval_videoqa_llavanext
+python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/videoqabench \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --lora_alpha ${lora_alpha} \
+ --weight_dir ${weight_dir} \
+ --test_ratio ${test_ratio} \
+ --conv_mode ${conv_mode}
+
+conv_mode=eval_recaption_llavanext
+python -m tasks.eval.recaption.pllava_recaption \
+ --pretrained_model_name_or_path ${model_dir} \
+ --save_path ${SAVE_DIR}/recaption \
+ --num_frames ${num_frames} \
+ --use_lora \
+ --weight_dir ${weight_dir} \
+ --lora_alpha ${lora_alpha} \
+ --test_ratio ${test_ratio} \
+ --conv_mode $conv_mode
diff --git a/scripts/gallery.sh b/scripts/gallery.sh
new file mode 100644
index 0000000000000000000000000000000000000000..862898a40b8a98405922b89e0d1ce166f6b42e0b
--- /dev/null
+++ b/scripts/gallery.sh
@@ -0,0 +1,11 @@
+export OPENAI_API_KEY=...
+SAVE_DIR=${1:-"test_results"}
+
+# # gallery view
+# python -m tasks.eval.show_gallery \
+# --root_dir ${SAVE_DIR}
+
+# # compare view
+python -m tasks.eval.demo.show_compare \
+ --root_dir ${SAVE_DIR}
+
diff --git a/scripts/train_pllava.sh b/scripts/train_pllava.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3c7b2c23bc7dd9699fcc6027752d7ce9dbaf826c
--- /dev/null
+++ b/scripts/train_pllava.sh
@@ -0,0 +1,34 @@
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+OUTPUT_DIR=./pllava_video_outputs/test_train_7b_reconstruct
+
+# # Naive Env
+# rm -rf ${OUTPUT_DIR}
+pooling_shape=(16,12,12)
+accelerate launch --main_process_port 6876 --config_file scripts/accel_config_multigpu.yaml tasks/train/train_pllava_nframe_accel.py \
+ tasks/train/config_pllava_nframe.py \
+ output_dir ${OUTPUT_DIR} \
+ train_corpus videochat2_video \
+ save_steps 10000 \
+ num_workers 8 \
+ num_frames 16 \
+ model.pooling_method avg \
+ model.repo_id llava-hf/llava-v1.6-vicuna-7b-hf \
+ model.use_lora True \
+ model.pooling_shape $pooling_shape \
+ optimizer.lr 2e-5 \
+ scheduler.epochs 3 \
+ scheduler.warmup_ratio 0.2 \
+ scheduler.min_lr_multi 0.25 \
+ scheduler.is_videochat2_custom True \
+ preprocess.mm_alone False \
+ preprocess.random_shuffle False \
+ preprocess.add_second_msg False \
+ train_corpus videochat2_instruction_debug
+
+
\ No newline at end of file
diff --git a/scripts/train_pllava_13b.sh b/scripts/train_pllava_13b.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ba23997cbbb77b268fa2a2766a00d352ffbd6f85
--- /dev/null
+++ b/scripts/train_pllava_13b.sh
@@ -0,0 +1,50 @@
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+OUTPUT_DIR=./pllava_video_outputs/pllava_13b
+
+
+pooling_shape=(16,12,12)
+num_save_samples=80000
+num_gpus=8
+full_batch_size=128
+batch_size=8
+save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
+ckpt_steps=$[$save_steps/10]
+gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
+echo $batch_size
+echo $gradient_accumulation_steps
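+# with the values above: save_steps = 80000/(8*8) = 1250, ckpt_steps = 125,
+# gradient_accumulation_steps = 128/(8*8) = 2 (recomputed if batch_size or num_gpus change)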
+repo_id=llava-hf/llava-v1.6-vicuna-13b-hf
+accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \
+ tasks/train/config_pllava_nframe.py \
+ output_dir ${OUTPUT_DIR} \
+ train_corpus videochat2_instruction_debug \
+ save_steps $save_steps \
+ ckpt_steps $ckpt_steps \
+ num_workers 8 \
+ num_frames 16 \
+ gradient_accumulation_steps $gradient_accumulation_steps \
+ batch_size $batch_size \
+ deepspeed True \
+ model.pooling_method avg \
+ model.use_lora True \
+ model.use_pooling True \
+ model.repo_id $repo_id \
+ gradient_checkpointing True \
+ preprocess.center_pad False \
+ preprocess.clip_transform False \
+ optimizer.lr 2e-5 \
+ scheduler.epochs 3 \
+ scheduler.warmup_ratio 0.2 \
+ scheduler.min_lr_multi 0.25 \
+ model.pooling_shape $pooling_shape \
+ scheduler.is_videochat2_custom True \
+ preprocess.mm_alone False \
+ preprocess.random_shuffle False \
+ preprocess.add_second_msg False
+
+
diff --git a/scripts/train_pllava_34b.sh b/scripts/train_pllava_34b.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2c167e34dd7a5b0bbe776d784af1894d8b1830d4
--- /dev/null
+++ b/scripts/train_pllava_34b.sh
@@ -0,0 +1,50 @@
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+machine_rank=${1:-"0"} # machine rank
+
+OUTPUT_DIR=./pllava_video_outputs/pllava_34b_videchat2-video
+
+pooling_shape=(16,12,12)
+num_save_samples=80000
+num_gpus=8
+full_batch_size=128
+batch_size=4
+save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
+ckpt_steps=$[$save_steps/10]
+gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
+echo $batch_size
+echo $gradient_accumulation_steps
+repo_id=llava-hf/llava-v1.6-34b-hf
+accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \
+ tasks/train/config_pllava_nframe_yiprompt.py \
+ output_dir ${OUTPUT_DIR} \
+ train_corpus videochat2_instruction_debug \
+ save_steps $save_steps \
+ ckpt_steps $ckpt_steps \
+ num_workers 8 \
+ num_frames 16 \
+ deepspeed True \
+ gradient_accumulation_steps $gradient_accumulation_steps \
+ batch_size $batch_size \
+ model.pooling_method avg \
+ model.use_lora True \
+ model.use_pooling True \
+ model.repo_id $repo_id \
+ gradient_checkpointing True \
+ preprocess.center_pad False \
+ preprocess.clip_transform True \
+ optimizer.lr 2e-5 \
+ scheduler.epochs 3 \
+ scheduler.warmup_ratio 0.2 \
+ scheduler.min_lr_multi 0.25 \
+ model.pooling_shape $pooling_shape \
+ scheduler.is_videochat2_custom True \
+ preprocess.image_token_index 64002 \
+ preprocess.mm_alone False \
+ preprocess.random_shuffle False \
+ preprocess.add_second_msg False
diff --git a/scripts/train_pllava_7b.sh b/scripts/train_pllava_7b.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f21cad8869e90727b2836af987ffd0e00972ceef
--- /dev/null
+++ b/scripts/train_pllava_7b.sh
@@ -0,0 +1,49 @@
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+OUTPUT_DIR=./pllava_video_outputs/test_train_7b_reconstruct
+
+pooling_shape=(16,12,12)
+num_save_samples=80000
+num_gpus=8
+full_batch_size=128
+batch_size=8
+save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
+ckpt_steps=$[$save_steps/10]
+gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
+echo $batch_size
+echo $gradient_accumulation_steps
+repo_id=llava-hf/llava-v1.6-vicuna-7b-hf
+accelerate launch --main_process_port 6876 --config_file scripts/accel_config_multigpu.yaml tasks/train/train_pllava_nframe_accel.py \
+ tasks/train/config_pllava_nframe.py \
+ output_dir ${OUTPUT_DIR} \
+ train_corpus videochat2_instruction_debug \
+ save_steps $save_steps \
+ ckpt_steps $ckpt_steps \
+ num_workers 8 \
+ num_frames 16 \
+ gradient_accumulation_steps $gradient_accumulation_steps \
+ batch_size $batch_size \
+ model.pooling_method avg \
+ model.use_lora True \
+ model.use_pooling True \
+ model.repo_id $repo_id \
+ gradient_checkpointing True \
+ preprocess.center_pad False \
+ preprocess.clip_transform False \
+ optimizer.lr 2e-5 \
+ scheduler.epochs 3 \
+ scheduler.warmup_ratio 0.2 \
+ scheduler.min_lr_multi 0.25 \
+ model.pooling_shape $pooling_shape \
+ scheduler.is_videochat2_custom True \
+ preprocess.mm_alone False \
+ preprocess.random_shuffle False \
+ preprocess.add_second_msg False
+
+
+
diff --git a/tasks/eval/demo/__init__.py b/tasks/eval/demo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b18087bf6f52339838e51b5dad0e1a1ab17f43cc
--- /dev/null
+++ b/tasks/eval/demo/__init__.py
@@ -0,0 +1,15 @@
+import gradio as gr
+from gradio.themes.utils import colors, fonts, sizes
+
+
+pllava_theme = gr.themes.Monochrome(
+ text_size="sm",
+ spacing_size="sm",
+ primary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"),
+ secondary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"),
+ neutral_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"),
+).set(
+ background_fill_primary_dark='*primary_950',
+ background_fill_secondary_dark='*neutral_950'
+)
+
diff --git a/tasks/eval/demo/pllava_demo.py b/tasks/eval/demo/pllava_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6126db37b766e25665a8c4a7a7a4f10ab7958bb7
--- /dev/null
+++ b/tasks/eval/demo/pllava_demo.py
@@ -0,0 +1,261 @@
+from argparse import ArgumentParser
+import copy
+import gradio as gr
+from gradio.themes.utils import colors, fonts, sizes
+
+from utils.easydict import EasyDict
+from tasks.eval.model_utils import load_pllava
+from tasks.eval.eval_utils import (
+ ChatPllava,
+ conv_plain_v1,
+ Conversation,
+ conv_templates
+)
+from tasks.eval.demo import pllava_theme
+
+SYSTEM="""You are Pllava, a large vision-language assistant.
+You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
+Follow the instructions carefully and explain your answers in detail based on the provided video.
+"""
+INIT_CONVERSATION: Conversation = conv_plain_v1.copy()
+
+
+# ========================================
+# Model Initialization
+# ========================================
+def init_model(args):
+
+ print('Initializing PLLaVA')
+ model, processor = load_pllava(
+ args.pretrained_model_name_or_path, args.num_frames,
+ use_lora=args.use_lora,
+ weight_dir=args.weight_dir,
+ lora_alpha=args.lora_alpha,
+ use_multi_gpus=args.use_multi_gpus)
+ if not args.use_multi_gpus:
+ model = model.to('cuda')
+ chat = ChatPllava(model, processor)
+ return chat
+
+
+# ========================================
+# Gradio Setting
+# ========================================
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state = INIT_CONVERSATION.copy()
+ if img_list is not None:
+ img_list = []
+ return (
+ None,
+ gr.update(value=None, interactive=True),
+ gr.update(value=None, interactive=True),
+ gr.update(placeholder='Please upload your video first', interactive=False),
+ gr.update(value="Upload & Start Chat", interactive=True),
+ chat_state,
+ img_list
+ )
+
+
+def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None):
+ print(gr_img, gr_video)
+ chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state
+ img_list = [] if img_list is None else img_list
+
+ if gr_img is None and gr_video is None:
+ return None, None, gr.update(interactive=True),gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None
+ if gr_video:
+ llm_message, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments)
+ return (
+ gr.update(interactive=True),
+ gr.update(interactive=True),
+ gr.update(interactive=True, placeholder='Type and press Enter'),
+ gr.update(value="Start Chatting", interactive=False),
+ chat_state,
+ img_list,
+ )
+ if gr_img:
+ llm_message, img_list,chat_state = chat.upload_img(gr_img, chat_state, img_list)
+ return (
+ gr.update(interactive=True),
+ gr.update(interactive=True),
+ gr.update(interactive=True, placeholder='Type and press Enter'),
+ gr.update(value="Start Chatting", interactive=False),
+ chat_state,
+ img_list
+ )
+
+
+def gradio_ask(user_message, chatbot, chat_state, system):
+ if len(user_message) == 0:
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+ chat_state = chat.ask(user_message, chat_state, system)
+ chatbot = chatbot + [[user_message, None]]
+ return '', chatbot, chat_state
+
+
+def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+ llm_message, llm_message_token, chat_state = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=200, num_beams=num_beams, temperature=temperature)
+ llm_message = llm_message.replace("<s>", "") # strip any leftover "<s>" special token from the decoded text
+ chatbot[-1][1] = llm_message
+ print(chat_state)
+ print(f"Answer: {llm_message}")
+ return chatbot, chat_state, img_list
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ default='llava-hf/llava-1.5-7b-hf'
+ )
+ parser.add_argument(
+ "--num_frames",
+ type=int,
+ required=True,
+ default=4,
+ )
+ parser.add_argument(
+ "--use_lora",
+ action='store_true'
+ )
+ parser.add_argument(
+ "--use_multi_gpus",
+ action='store_true'
+ )
+ parser.add_argument(
+ "--weight_dir",
+ type=str,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--conv_mode",
+ type=str,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--server_port",
+ type=int,
+ required=False,
+ default=7868,
+ )
+ args = parser.parse_args()
+ return args
+
+
+title = """
"""
+description = (
+ """
+ # PLLAVA!
+
+ - Upload A Video
+ - Press Upload
+ - Start Chatting
+ """
+)
+
+args = parse_args()
+
+model_description = f"""
+ # MODEL INFO
+ - pretrained_model_name_or_path:{args.pretrained_model_name_or_path}
+ - use_lora:{args.use_lora}
+ - weight_dir:{args.weight_dir}
+"""
+
+# with gr.Blocks(title="InternVideo-VideoChat!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
+with gr.Blocks(title="PLLaVA",
+ theme=pllava_theme,
+ css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
+ gr.Markdown(title)
+ gr.Markdown(description)
+ gr.Markdown(model_description)
+ with gr.Row():
+ with gr.Column(scale=0.5, visible=True) as video_upload:
+ # with gr.Column(elem_id="image", scale=0.5) as img_part:
+ with gr.Tab("Video", elem_id='video_tab'):
+ up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
+ with gr.Tab("Image", elem_id='image_tab'):
+ up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360)
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+ clear = gr.Button("Restart")
+
+ # num_segments = gr.Slider(
+ # minimum=8,
+ # maximum=64,
+ # value=8,
+ # step=1,
+ # interactive=True,
+ # label="Video Segments",
+ # )
+
+ with gr.Column(visible=True) as input_raws:
+ system_string = gr.Textbox(SYSTEM, interactive=True, label='system')
+ num_beams = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=1,
+ step=1,
+ interactive=True,
+ label="beam search numbers",
+ )
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=2.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ chat_state = gr.State()
+ img_list = gr.State()
+ chatbot = gr.Chatbot(elem_id="chatbot",label='Conversation')
+ with gr.Row():
+ with gr.Column(scale=0.7):
+ text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
+ with gr.Column(scale=0.15, min_width=0):
+ run = gr.Button("💭Send")
+ with gr.Column(scale=0.15, min_width=0):
+ clear = gr.Button("🔄Clear")
+
+ with gr.Row():
+ examples = gr.Examples(
+ examples=[
+ ['example/jesse_dance.mp4', 'What is the man doing?'],
+ ['example/yoga.mp4', 'What is the woman doing?'],
+ ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'],
+ # ['example/cooking.mp4', 'What is happening in the video?'],
+ ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'],
+ ['example/1917.mp4', 'Describe the background, characters and the actions in the provided video.'],
+ ],
+ inputs=[up_video, text_input]
+ )
+
+
+ chat = init_model(args)
+ INIT_CONVERSATION = conv_templates[args.conv_mode] if args.conv_mode is not None else INIT_CONVERSATION
+ upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list])
+
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+ )
+ run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+ )
+ run.click(lambda: "", None, text_input)
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list], queue=False)
+
+# demo.queue(max_size=5)
+demo.launch(share=True,server_port=args.server_port)
+# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
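+
+# Example launch from the repo root (model paths below are placeholders):
+#   python -m tasks.eval.demo.pllava_demo --pretrained_model_name_or_path MODELS/pllava-7b \
+#       --num_frames 16 --use_lora --weight_dir MODELS/pllava-7b --conv_mode plain --server_port 7868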
diff --git a/tasks/eval/demo/show_compare.py b/tasks/eval/demo/show_compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7accf685a3db3e7428f6a861d1e53028b5b216a
--- /dev/null
+++ b/tasks/eval/demo/show_compare.py
@@ -0,0 +1,124 @@
+
+
+import argparse
+import json
+import os
+import os.path as osp
+import gradio as gr
+import numpy as np
+
+from tasks.eval.recaption import load_results as load_results_recaption
+from tasks.eval.mvbench import load_results as load_results_mvbench
+from tasks.eval.vcgbench import load_results as load_results_vcgbench
+from tasks.eval.videoqabench import load_results as load_results_videoqabench
+from tasks.eval.demo import pllava_theme
+
+
+load_results_funcs = [
+ load_results_recaption,
+ load_results_mvbench,
+ load_results_vcgbench,
+ load_results_videoqabench,
+]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--root_dir',
+ required=True,
+ )
+ args = parser.parse_args()
+ return args
+
+args = parse_args()
+root_dir = args.root_dir
+
+def show(result_list_first, result_list_second, result_index):
+ sample2index_second = {}
+
+ for i, result in enumerate(result_list_second):
+ if 'video_path' not in result:
+ continue
+
+ question = result['question'] if 'question' in result else ''
+ video_path = result['video_path']
+ samplehash = question + '--' +video_path
+ sample2index_second[samplehash] = i
+
+ info = result_list_first[result_index]
+ info_str_first = json.dumps(info, indent=4, ensure_ascii=False)
+ video_path = info['video_path']
+ question = info['question'] if 'question' in info else ''
+ samplehash = question + '--' +video_path
+ if samplehash in sample2index_second:
+ info = result_list_second[sample2index_second[samplehash]]
+ info_str_second = json.dumps(info, indent=4, ensure_ascii=False)
+ else:
+ info_str_second = f"NO {video_path} IN THE SECOND RESULT DIR"
+ return video_path, info_str_first, info_str_second
+
+def reload_results_dirs():
+ result_dirs = []
+ # load result dir paths
+ for dirpath, dirnames, filenames in os.walk(args.root_dir):
+ if len(dirnames) == 0 and len(filenames) != 0:
+ result_dirs.append(dirpath)
+ return gr.Dropdown(result_dirs, value=result_dirs[0])
+
+def reload_results(result_dir):
+ # if isinstance(result_dir, list):
+ # result_dir = result_dir[0]
+
+ if result_dir is None or not osp.exists(result_dir):
+ return None
+
+ for fn in load_results_funcs:
+ result_list = fn(result_dir)
+ if result_list is not None:
+ np.random.shuffle(result_list)
+ break
+ result_index = gr.Slider(0, len(result_list), step=1)
+
+ return result_list, result_index
+
+
+
+with gr.Blocks(title="PLLAVA RESULTS", theme=pllava_theme) as demo:
+ result_list_first = gr.State()
+ result_list_second = gr.State()
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("# Showing off Model's Outputs.")
+ gr.Markdown(
+ "You can find all our results, including:\n"
+ "1. results of Captioned Inter4k\n"
+ "2. results of Different Benchmark inference outputs.\n"
+ "Choose a directory to see the different output variant.\n"
+ "You can also choose secondary directory (as long as they are from the same dataset.) to compare on the results.\n"
+ )
+
+ with gr.Row():
+ with gr.Column():
+ show_video = gr.Video(interactive=False)
+
+ with gr.Column():
+ button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory')
+ result_index = gr.Slider(0, 0, step=1, label="Index")
+
+ result_dir_first = gr.Dropdown(label='Test Result Path')
+ info_first = gr.Text(interactive=False, label='Detailed Output Information')
+ result_dir_second = gr.Dropdown(label='Test Result Path')
+ info_second = gr.Text(interactive=False, label='Detailed Output Information')
+
+
+ button_reload.click(reload_results_dirs, [], [result_dir_first])
+ button_reload.click(reload_results_dirs, [], [result_dir_second])
+ result_dir_first.change(reload_results, [result_dir_first], [result_list_first, result_index])
+ result_dir_second.change(reload_results, [result_dir_second], [result_list_second, result_index])
+ result_index.change(show, [result_list_first, result_list_second, result_index], [show_video, info_first, info_second])
+ demo.load(reload_results_dirs, [], [result_dir_first])
+ demo.load(reload_results_dirs, [], [result_dir_second])
+
+demo.launch(share=True)
\ No newline at end of file
diff --git a/tasks/eval/demo/show_gallery.py b/tasks/eval/demo/show_gallery.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc7725f5f37eab84deb6c8071d7e7895579964d
--- /dev/null
+++ b/tasks/eval/demo/show_gallery.py
@@ -0,0 +1,94 @@
+
+
+import argparse
+import json
+import os
+import os.path as osp
+import gradio as gr
+
+from tasks.eval.recaption import load_results as load_results_recaption
+from tasks.eval.mvbench import load_results as load_results_mvbench
+from tasks.eval.vcgbench import load_results as load_results_vcgbench
+from tasks.eval.videoqabench import load_results as load_results_videoqabench
+
+load_results_funcs = [
+ load_results_recaption,
+ load_results_mvbench,
+ load_results_vcgbench,
+ load_results_videoqabench,
+]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--root_dir',
+ required=True,
+ )
+ args = parser.parse_args()
+ return args
+
+args = parse_args()
+root_dir = args.root_dir
+
+def show(result_list, result_index):
+ info = result_list[result_index]
+ video_path = info['video_path']
+ info_str = json.dumps(info, indent=4)
+ return video_path, info_str
+
+def reload_results_dirs():
+ result_dirs = []
+ # load result dir paths
+ for dirpath, dirnames, filenames in os.walk(args.root_dir):
+ if len(dirnames) == 0 and len(filenames) != 0:
+ result_dirs.append(dirpath)
+ return gr.Dropdown(result_dirs, value=result_dirs[0])
+
+def reload_results(result_dir):
+ # if isinstance(result_dir, list):
+ # result_dir = result_dir[0]
+
+ if result_dir is None or not osp.exists(result_dir):
+ return None
+
+ for fn in load_results_funcs:
+ result_list = fn(result_dir)
+ if result_list is not None:
+ break
+
+ result_index = gr.Slider(0, len(result_list), step=1)
+
+ return result_list, result_index
+
+with gr.Blocks() as demo:
+ result_list = gr.State()
+
+ with gr.Row():
+ gr.Markdown("# Showing of what has came out.")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown(f"### From Saved Results Directory {args.root_dir}")
+
+ with gr.Column(scale=2):
+ result_dir = gr.Dropdown(label='Test Result Path')
+ button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory')
+
+
+
+ with gr.Row():
+ with gr.Column():
+ show_video = gr.Video(interactive=False)
+
+ with gr.Column():
+ result_index = gr.Slider(0, 0, step=1, label="Index")
+ info = gr.Text(interactive=False, label='Detailed Output Information')
+
+
+ button_reload.click(reload_results_dirs, [], [result_dir])
+ result_dir.change(reload_results, [result_dir], [result_list, result_index])
+ result_index.change(show, [result_list, result_index], [show_video, info])
+ demo.load(reload_results_dirs, [], [result_dir])
+
+demo.launch(share=True)
\ No newline at end of file
diff --git a/tasks/eval/eval_utils.py b/tasks/eval/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aabad1c2a33dec15ba4997d8b7f004519c05376
--- /dev/null
+++ b/tasks/eval/eval_utils.py
@@ -0,0 +1,517 @@
+import copy
+import itertools
+import re
+import os
+import json
+from enum import auto, Enum
+import dataclasses
+from typing import Any, List
+
+from PIL import Image
+import cv2
+import imageio
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+from moviepy.editor import VideoFileClip
+
+
+from decord import VideoReader, cpu # NOTE: importing decord before torch can cause model.to(device) to hang
+from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+from utils.easydict import EasyDict
+
+IMAGE_TOKEN = ""
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+
+class MultiModalConvStyle(Enum):
+ """Different separator style."""
+ MM_ALONE = 'mm_alone'
+ MM_INTERLEAF = 'mm_inferleaf'
+
+def dump_json(obj_serializable ,save_dir_path, json_file_name):
+ os.makedirs(save_dir_path, exist_ok=True)
+ save_path = os.path.join(save_dir_path, json_file_name)
+ with open(save_path, 'w', encoding='utf-8') as f:
+ json.dump(obj_serializable, f, indent=4, ensure_ascii=False, )
+
+def load_json(load_dir_path, json_file_name):
+
+ load_path = os.path.join(load_dir_path, json_file_name)
+ if not os.path.exists(load_path):
+ return None
+ with open(load_path, 'r', encoding='utf-8') as f:
+ obj_serializable = json.load(f)
+ return obj_serializable
+
+
+
+@dataclasses.dataclass
+class Conversation(EasyDict):
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ sep: List[str]
+ mm_token: str
+
+ mm_style: MultiModalConvStyle = MultiModalConvStyle.MM_INTERLEAF
+ pre_query_prompt: str=None
+ post_query_prompt: str=None
+ answer_prompt: str=None
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if isinstance(self.sep, str):
+ self.sep = [self.sep for _ in self.roles]
+
+ def get_prompt(self):
+ sep = [self.sep for _ in self.roles] if isinstance(self.sep, str) else self.sep # if only one sep is given, use it for both roles
+ sep = dict(zip(self.roles, sep))
+ ret = self.system + sep[self.roles[0]] if self.system != "" else ""
+ for i, (role, message) in enumerate(self.messages):
+ # last message (the prompt for the assistant): if it already belongs to the assistant (e.g. an answer prompt), no sep is added so generation continues from it
+ if i+1 == len(self.messages):
+ if role != self.roles[-1]: # last role is not the model
+ ret += role + message + sep[role] + self.roles[-1]
+ else:
+ ret += role + message
+ else:
+ ret += role + message + sep[role]
+ return ret
+ # def get_prompt_multichoice(self):
+ # pass
+ def user_query(self, query=None, pre_query_prompt=None, post_query_prompt=None, is_mm=False, num_mm_token=1):
+ if post_query_prompt is not None:
+ query = f"{query} {post_query_prompt}"
+
+ if pre_query_prompt is not None:
+ query = f"{pre_query_prompt} {query}"
+ role = self.roles[0]
+ # TODO: remove the num_mm_token and hack the self.mm_token outside
+ if is_mm:
+ mm_str = num_mm_token*self.mm_token[:-1] + self.mm_token[-1]
+ if self.mm_style == MultiModalConvStyle.MM_ALONE:
+ self._append_message(role, mm_str)
+ elif self.mm_style == MultiModalConvStyle.MM_INTERLEAF:
+ if self.mm_token not in query:
+ query = f'{mm_str} {query}'
+ self._append_message(role, query)
+
+ def assistant_response(self, response, pre_query_prompt=None, post_query_prompt=None):
+ if post_query_prompt is not None:
+ response = f"{response} {post_query_prompt}"
+
+ if pre_query_prompt is not None:
+ response = f"{post_query_prompt} {response}"
+
+ role = self.roles[1]
+ self._append_message(role, response)
+
+ def _append_message(self, role, message):
+ message = '' if message is None else message
+ self.messages.append([role, message])
+
+ def copy(self):
+ return copy.deepcopy(self)
+
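+# Illustrative rendering (not executed): with conv_plain_v1 below, a single multimodal user turn
+# "<image> What is shown?" is rendered by get_prompt() roughly as
+#   "USER:<image> What is shown? ASSISTANT:"
+# i.e. role + message + sep, followed by the assistant role so generation continues from it.
+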
+conv_video_chatgpt_v1 = Conversation(
+ system="You are Video-ChatGPT, a large vision-language assistant. "
+ "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail based on the provided video.",
+ roles=("USER:", "ASSISTANT:"),
+ messages=[],
+ sep=[" ",""],
+ mm_token='<image>',
+ mm_style=MultiModalConvStyle.MM_INTERLEAF,
+)
+
+
+conv_plain_v1 = Conversation(
+ system="",
+ roles=("USER:", "ASSISTANT:"),
+ messages=[],
+ sep=(" ", ""),
+ mm_token='<image>'
+)
+
+# Attention to the roles[0] "USER: " has a space!
+conv_eval_vcg = Conversation(
+ system="You are Video-ChatGPT, a large vision-language assistant. "
+ "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail based on the provided video.",
+ roles=("USER: ", "ASSISTANT:"),
+ messages=[],
+ sep=[" ",""],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_ALONE,
+)
+
+conv_eval_vcg_llavanext = Conversation(
+ system="You are Video-ChatGPT, a large vision-language assistant. "
+ "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail based on the provided video.",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ messages=[],
+ sep=["<|im_end|>\n","<|im_end|>\n"],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_ALONE,
+)
+
+SYSTEM_MVBENCH="Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n"
+conv_eval_mvbench = Conversation(
+ system=SYSTEM_MVBENCH,
+ roles=("USER: ", "ASSISTANT:"),
+ messages=[],
+ sep=[" ",""],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_ALONE,
+)
+conv_eval_mvbench_llavanext = Conversation(
+ system="You are Video-ChatGPT, a large vision-language assistant. "
+ "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail based on the provided video.",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ messages=[],
+ sep=["<|im_end|>\n","<|im_end|>\n"],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_ALONE,
+)
+
+
+conv_eval_videoqabench = Conversation(
+ system="",
+ roles=("USER: ", "ASSISTANT:"),
+ messages=[],
+ sep=[" ",""],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_INTERLEAF,
+ pre_query_prompt="The input consists of a sequence of key frames from a video. Answer the question concisely first and followed by significant events, characters, or objects that appear throughout the frames. Question:",
+ post_query_prompt="\n",
+ answer_prompt='\nAnswer: In the video,'
+)
+
+conv_eval_videoqa_llavanext = Conversation(
+ system="<|im_start|>system\nAnswer the question.",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ messages=[],
+ sep=["<|im_end|>\n","<|im_end|>\n"],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_INTERLEAF,
+ pre_query_prompt="The input consists of a sequence of key frames from a video. Answer the question concisely first and followed by significant events, characters, or objects that appear throughout the frames. Question:",
+ post_query_prompt="\n",
+ answer_prompt='\nAnswer: In the video,'
+)
+
+
+SYSTEM_RECAPTION="""You are a powerful Video Magic ChatBot, a large vision-language assistant.
+You are able to understand the video content that the user provides and assist the user in a video recaptioning task.
+The user will provide you with the video and maybe some extra noisy information to help you out. Make use of the information in a proper way to be competent for the recaption job
+### INSTRUCTIONS:
+1. Follow the user's instruction.
+2. Be critical yet believe in yourself.
+"""
+conv_eval_recaption = Conversation(
+ system=SYSTEM_RECAPTION,
+ roles=("USER: ", "ASSISTANT:"),
+ messages=[],
+ sep=[" ",""],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_ALONE,
+)
+
+
+conv_eval_recaption_llavanext = Conversation(
+ system=SYSTEM_RECAPTION,
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ messages=[],
+ sep=["<|im_end|>\n","<|im_end|>\n"],
+ mm_token='<image>\n',
+ mm_style=MultiModalConvStyle.MM_ALONE,
+)
+
+
+conv_templates = {
+ "plain": conv_plain_v1,
+ "eval_vcgbench": conv_eval_vcg,
+ "eval_vcg_llavanext": conv_eval_vcg_llavanext,
+ "eval_mvbench": conv_eval_mvbench,
+ "eval_mvbench_llavanext": conv_eval_mvbench_llavanext,
+ "eval_videoqabench": conv_eval_videoqabench,
+ "eval_videoqa_llavanext": conv_eval_videoqa_llavanext,
+ "eval_recaption": conv_eval_recaption,
+ "eval_recaption_llavanext": conv_eval_recaption_llavanext,
+}
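+# The demo and eval scripts select one of these templates by name, e.g. conv_templates['plain']
+# in the Gradio demo and conv_templates['eval_mvbench'].copy() during MVBench inference.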
+
+
+class EvalDataset(Dataset):
+
+ def __init__(self, num_segments, test_ratio=None):
+ super().__init__()
+ self.num_segments = num_segments
+ self.test_ratio = test_ratio
+ self.decord_method = {
+ 'video': self.read_video,
+ 'gif': self.read_clip_gif,
+ 'frame': self.read_frame,
+ }
+
+ def __getitem__(self, index) -> Any:
+ raise NotImplementedError('')
+
+ def __str__(self):
+ len_list = {}
+ option_list = {}
+ for data in self.data_list:
+ if data['task_type'] not in len_list:
+ len_list[data['task_type']] = 0
+ len_list[data['task_type']] += 1
+ if data['task_type'] not in option_list:
+ option_list[data['task_type']] = 0
+ option_list[data['task_type']] += len(data['data']['candidates'])
+
+ correct = 0
+ total = 0
+ res = f"There are {len(self.data_list)} videos as follow:\n"
+ for k, v in len_list.items():
+ correct += len_list[k]
+ total += option_list[k]
+ res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n"
+ correct = correct + 1 / option_list[k]
+ res += f"Total random accuracy: {correct/total*100:.2f}%"
+ return res.rstrip()
+
+ def __len__(self):
+ return len(self.data_list)
+
+ def get_index(self, bound, fps, max_frame, first_idx=0):
+ if bound:
+ start, end = bound[0], bound[1]
+ else:
+ start, end = -100000, 100000
+ start_idx = max(first_idx, round(start * fps))
+ end_idx = min(round(end * fps), max_frame)
+ seg_size = float(end_idx - start_idx) / self.num_segments
+ frame_indices = np.array([
+ int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+ for idx in range(self.num_segments)
+ ])
+ return frame_indices
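+ # get_index returns self.num_segments indices, one near the centre of each equal-length
+ # segment of [start_idx, end_idx]; bound=(start, end) in seconds restricts the range for
+ # clips that come with start/end annotations.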
+
+ def read_video(self, video_path, bound=None):
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=4)
+ max_frame = len(vr) - 1
+ fps = float(vr.get_avg_fps())
+
+ images_group = list()
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(img)
+ return images_group
+
+ def read_gif(self, video_path, bound=None, fps=25):
+ gif = imageio.get_reader(video_path)
+ max_frame = len(gif) - 1
+
+ images_group = list()
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
+ for index, frame in enumerate(gif):
+ if index in frame_indices:
+ img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+ img = Image.fromarray(img)
+ images_group.append(img)
+ if len(images_group) == len(frame_indices):
+ break
+
+ # might be some really short videos in the gif datasets
+ if len(images_group) < self.num_segments:
+ multiplier = int(self.num_segments/len(images_group)) + 1
+ images_group = [image for _ in range(multiplier) for image in images_group][:self.num_segments]
+ assert len(images_group) == self.num_segments
+
+ return images_group
+
+ def read_clip_gif(self, video_path, bound=None, fps=25):
+ gif = VideoFileClip(video_path)
+ frames = gif.iter_frames()
+ max_frame = gif.reader.nframes - 1
+ images_group = list()
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
+ for index, frame in enumerate(frames):
+ if index in frame_indices:
+ img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+ img = Image.fromarray(img)
+ images_group.append(img)
+
+ # might be some really short videos in the gif datasets
+ if len(images_group) < self.num_segments:
+ multiplier = int(self.num_segments/len(images_group)) + 1
+ images_group = [image for _ in range(multiplier) for image in images_group][:self.num_segments]
+ assert len(images_group) == self.num_segments
+
+ return images_group
+
+ def read_frame(self, video_path, bound=None, fps=3):
+ max_frame = len(os.listdir(video_path))
+ images_group = list()
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
+ for frame_index in frame_indices:
+ img = Image.open(os.path.join(video_path, f"{frame_index:05d}.jpg"))
+ images_group.append(img)
+ return images_group
+
+ def set_rank_and_world_size(self, rank, world_size):
+ self.rank = rank
+ self.world_size = world_size
+ # self.data_list = self.data_list[::200] # debug
+ if self.test_ratio is None:
+ self.data_list = self.data_list[rank::world_size]
+ else:
+ np.random.RandomState(42).shuffle(self.data_list)
+ if isinstance(self.test_ratio, float):
+ num_samples = int(len(self.data_list) * self.test_ratio)
+ else:
+ num_samples = int(self.test_ratio)
+ self.data_list = self.data_list[rank:num_samples:world_size]
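+ # Sharding example: with world_size=8 and test_ratio=None, rank 3 evaluates data_list[3::8];
+ # with a test_ratio, the shuffled list is truncated to num_samples and then sharded the same way.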
+
+
+class ChatPllava:
+ print_res=True
+ do_sample=False
+ def __init__(self, model, processor):
+ self.model = model
+ self.processor = processor
+
+ def ask(self, text, conv: Conversation, system):
+ conv.system = system
+ conv.user_query(text, )
+ return conv
+
+ def answer(self, conv: Conversation, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0):
+ torch.cuda.empty_cache()
+ prompt = conv.get_prompt()
+ if prompt.count(conv.mm_token) < len(img_list):
+ diff_mm_num = len(img_list) - prompt.count(conv.mm_token)
+ for i in range(diff_mm_num):
+ conv.user_query("", is_mm=True)
+ prompt = conv.get_prompt()
+
+ inputs = self.processor(text=prompt, images=img_list, return_tensors="pt")
+ if inputs['pixel_values'] is None:
+ inputs.pop('pixel_values')
+ inputs = inputs.to(self.model.device)
+
+ with torch.no_grad():
+ output_token = self.model.generate(**inputs, media_type='video',
+ do_sample=self.do_sample,max_new_tokens=max_new_tokens, num_beams=num_beams, min_length=min_length,
+ top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature,
+ ) # the output doesn't need to be long for a choice-style answer
+ output_text = self.processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+ if self.print_res:
+ print('###PROMPT: ', prompt)
+ print('###LM OUTPUT TEXT', output_text)
+ # encoding then decoding "<|im_start|>" inserts an extra space, so split on the decoded form
+ if conv.roles[-1] == "<|im_start|>assistant\n":
+ split_tag = "<|im_start|> assistant\n"
+ else:
+ split_tag = conv.roles[-1]
+ output_text = output_text.split(split_tag)[-1].rstrip(conv.sep[1])
+ conv.assistant_response(output_text)
+ return output_text, output_token.cpu().numpy(), conv
+
+
+ def get_index(self, num_frames, num_segments):
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
+
+ def load_video(self, video_path, num_segments=8, return_msg=False):
+ vr = VideoReader(video_path, ctx=cpu(0))
+ num_frames = len(vr)
+ frame_indices = self.get_index(num_frames, num_segments)
+
+ duration = len(vr) // vr.get_avg_fps()
+ index = np.linspace(0, len(vr)-1, num=int(duration))
+ buffer = vr.get_batch(index).asnumpy()
+ # transform
+
+ images_group = list()
+ for frame in buffer:
+ img = Image.fromarray(frame)
+ images_group.append(img)
+ images_group = list()
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(img)
+ if return_msg:
+ fps = float(vr.get_avg_fps())
+ sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
+ # " " should be added in the start and end
+ msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
+ return images_group, msg
+ else:
+ return images_group
+
+ def upload_video(self, image, conv: Conversation, img_list: list[list], num_segments=None):
+ num_segments = self.model.config.num_frames if num_segments is None else num_segments
+ if isinstance(image, str): # the input is a path to a video file
+ vid, msg = self.load_video(image, num_segments=num_segments, return_msg=True)
+ else:
+ raise NotImplementedError
+ print("Input video shape:", len(vid), *vid[0].size)
+ img_list.append(vid)
+ conv.user_query("", is_mm=True)
+ msg = "Received."
+ # self.conv.append_message(self.conv.roles[1], msg)
+ return msg, img_list, conv
+
+ def upload_img(self, image, conv, img_list):
+ assert False
+ img = image#Image.open(image)#.convert('RGB')
+ transform = T.Compose(
+ [
+ T.Resize(
+ (224, 224), interpolation=InterpolationMode.BICUBIC
+ ),
+ T.ToTensor(),
+ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ]
+ )
+
+ img = transform(img).unsqueeze(0).unsqueeze(0).cuda()
+ image_emb, _ = self.model.encode_img(img, "Observe the image and answer the question.")
+ img_list.append(image_emb)
+ conv.messages.append([
+ conv.roles[0],
+ f"\n"
+ ])
+ msg = "Received."
+ # self.conv.append_message(self.conv.roles[1], msg)
+ return msg,img_list, conv
+
+class StoppingCriteriaSub(StoppingCriteria):
+ def __init__(self, stops=[], encounters=1):
+ super().__init__()
+ self.stops = stops
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
+ return True
+ return False
diff --git a/tasks/eval/model_utils.py b/tasks/eval/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9396464ee5e42da2c5032ec2fa892c9d0c3efc7
--- /dev/null
+++ b/tasks/eval/model_utils.py
@@ -0,0 +1,172 @@
+
+import torch
+import os
+from peft import get_peft_model, LoraConfig, TaskType
+from safetensors import safe_open
+from peft import PeftModel
+from tasks.eval.eval_utils import Conversation
+from models.pllava import PllavaProcessor, PllavaForConditionalGeneration, PllavaConfig
+from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map,load_checkpoint_in_model
+from accelerate.utils import get_balanced_memory
+
+from transformers import StoppingCriteria
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.tokenizer = tokenizer
+ self.start_len = None
+ self.input_ids = input_ids
+
+ def __call__(
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+ ) -> bool:
+ if self.start_len is None:
+ self.start_len = self.input_ids.shape[1]
+ return False
+ else:
+ outputs = self.tokenizer.batch_decode(
+ output_ids[:, self.start_len:], skip_special_tokens=True
+ )
+ flag = True
+ for output in outputs:
+ for keyword in self.keywords:
+ if keyword not in output:
+ flag = False
+ return False
+ return flag
+
+
+def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha=32, use_multi_gpus=False, pooling_shape=(16,12,12)):
+ kwargs = {
+ 'num_frames': num_frames,
+ }
+ # print("===============>pooling_shape", pooling_shape)
+ if num_frames == 0:
+ kwargs.update(pooling_shape=(0,12,12)) # deliberately breaks if the pooling projector is ever used with num_frames == 0
+ config = PllavaConfig.from_pretrained(
+ repo_id if not use_lora else weight_dir,
+ pooling_shape=pooling_shape,
+ **kwargs,
+ )
+
+ with torch.no_grad():
+ model = PllavaForConditionalGeneration.from_pretrained(repo_id, config=config, torch_dtype=torch.bfloat16)
+
+ try:
+ processor = PllavaProcessor.from_pretrained(repo_id)
+ except Exception as e:
+ processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')
+
+ # config lora
+ if use_lora and weight_dir is not None:
+ print("Use lora")
+ peft_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, target_modules=["q_proj", "v_proj"],
+ r=128, lora_alpha=lora_alpha, lora_dropout=0.
+ )
+ print("Lora Scaling:", lora_alpha/128)
+ model.language_model = get_peft_model(model.language_model, peft_config)
+ assert weight_dir is not None, "pass a folder to your lora weight"
+ print("Finish use lora")
+
+ # load weights
+ if weight_dir is not None:
+ state_dict = {}
+ save_fnames = os.listdir(weight_dir)
+ if "model.safetensors" in save_fnames:
+ use_full = False
+ for fn in save_fnames:
+ if fn.startswith('model-0'):
+ use_full=True
+ break
+ else:
+ use_full= True
+
+ if not use_full:
+ print("Loading weight from", weight_dir, "model.safetensors")
+ with safe_open(f"{weight_dir}/model.safetensors", framework="pt", device="cpu") as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+ else:
+ print("Loading weight from", weight_dir)
+ for fn in save_fnames:
+ if fn.startswith('model-0'):
+ with safe_open(f"{weight_dir}/{fn}", framework="pt", device="cpu") as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+
+ if 'model' in state_dict.keys():
+ msg = model.load_state_dict(state_dict['model'], strict=False)
+ else:
+ msg = model.load_state_dict(state_dict, strict=False)
+ print(msg)
+ # dispatch model weight
+ if use_multi_gpus:
+ max_memory = get_balanced_memory(
+ model,
+ max_memory=None,
+ no_split_module_classes=["LlamaDecoderLayer"],
+ dtype='bfloat16',
+ low_zero=False,
+ )
+
+ device_map = infer_auto_device_map(
+ model,
+ max_memory=max_memory,
+ no_split_module_classes=["LlamaDecoderLayer"],
+ dtype='bfloat16'
+ )
+
+ dispatch_model(model, device_map=device_map)
+ print(model.hf_device_map)
+
+ model = model.eval()
+
+ return model, processor
+
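+# Minimal usage sketch (the paths below are placeholders, not files shipped with this repo):
+#   model, processor = load_pllava('llava-hf/llava-1.5-7b-hf', num_frames=16,
+#                                  use_lora=True, weight_dir='MODELS/pllava-7b',
+#                                  lora_alpha=4, pooling_shape=(16, 12, 12))
+#   model = model.to('cuda')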
+
+def load_adapters(model, adapter_model_name_or_paths):
+
+ for adapter_model_name_or_path in adapter_model_name_or_paths:
+ if not isinstance(model, PeftModel):
+ model = PeftModel.from_pretrained(model, adapter_model_name_or_path, adapter_model_name_or_path)
+ else:
+ model.load_adapter(adapter_model_name_or_path, adapter_model_name_or_path)
+
+ return model
+
+
+def pllava_answer(conv: Conversation, model, processor, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, stop_criteria_keywords=None, print_res=False):
+ # torch.cuda.empty_cache()
+ prompt = conv.get_prompt()
+ inputs = processor(text=prompt, images=img_list, return_tensors="pt")
+ if inputs['pixel_values'] is None:
+ inputs.pop('pixel_values')
+ inputs = inputs.to(model.device)
+
+ # set up stopping criteria
+ if stop_criteria_keywords is not None:
+ stopping_criteria = [KeywordsStoppingCriteria(stop_criteria_keywords, processor.tokenizer, inputs["input_ids"])]
+ else:
+ stopping_criteria= None
+
+ with torch.no_grad():
+ output_token = model.generate(**inputs, media_type='video',
+ do_sample=do_sample, max_new_tokens=max_new_tokens, num_beams=num_beams, min_length=min_length,
+ top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature,
+ stopping_criteria=stopping_criteria,)
+ output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ if "###" in output_text:
+ output_text = "###".join(output_text.split('###')[:-1]) # remove the stop sign '###'
+ if print_res: # debug usage
+ print('### PROMPTING LM WITH: ', prompt)
+ print('### LM OUTPUT TEXT: ', output_text)
+ if conv.roles[-1] == "<|im_start|>assistant\n":
+ split_tag = "<|im_start|> assistant\n"
+ else:
+ split_tag = conv.roles[-1]
+ output_text = output_text.split(split_tag)[-1].rstrip(conv.sep if isinstance(conv.sep, str) else conv.sep[1]).strip()
+ conv.messages[-1][1] = output_text
+ return output_text, conv
+
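+# Sketch of a call (assumes `frames` is a list of PIL images sampled from one video and that a
+# conversation template was taken from tasks.eval.eval_utils.conv_templates):
+#   conv = conv_templates['plain'].copy()
+#   conv.user_query("Describe the video in details.", is_mm=True)
+#   text, conv = pllava_answer(conv=conv, model=model, processor=processor, img_list=frames, do_sample=False)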
diff --git a/tasks/eval/mvbench/__init__.py b/tasks/eval/mvbench/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f7df338b72f28d273360beda9cd45ce24262fc
--- /dev/null
+++ b/tasks/eval/mvbench/__init__.py
@@ -0,0 +1,173 @@
+import os
+import json
+from tasks.eval.eval_utils import (
+ dump_json,
+ load_json,
+ EvalDataset,
+)
+
+
+def check_ans(pred, gt):
+ flag = False
+
+ pred_list = pred.lower().split(' ')
+ pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
+ gt_list = gt.lower().split(' ')
+ gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
+ if gt_content[-1] == '.':
+ gt_content = gt_content[:-1]
+
+ if not any([c in pred_option for c in 'abcdefgABCDEFG']):
+ print(f"model doesn't follow instructions: {pred}")
+ elif pred_option.replace('.', '') in gt_option:
+ flag = True
+ elif gt_option in pred_option:
+ flag = True
+
+ return flag
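+# e.g. check_ans(pred="(C) Running", gt="(C) Running on the street") -> True, while a prediction
+# that contains no option letter at all is logged as not following instructions and counted wrong.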
+
+def save_results(result_list, save_path):
+
+ final_res, acc_dict = {}, {}
+ correct, total = 0, 0
+ for res in result_list:
+ task_type = res['task_type']
+ if task_type not in acc_dict:
+ acc_dict[task_type] = [0, 0] # correct, total
+ acc_dict[task_type][1] += 1
+ total += 1
+ pred = res['pred']
+ gt = res['gt']
+ if check_ans(pred=pred, gt=gt):
+ acc_dict[task_type][0] += 1
+ correct += 1
+
+ for k, v in acc_dict.items():
+ final_res[k] = v[0] / v[1] * 100
+ correct += v[0]
+ total += v[1]
+ final_res['Avg'] = correct / total * 100
+
+ all_results = {
+ "acc_dict": acc_dict,
+ "result_list": result_list
+ }
+ dump_json(all_results, save_path, 'all_results.json')
+ dump_json(final_res, save_path, 'upload_leaderboard.json')
+
+def load_results(save_path):
+ all_results = load_json(save_path, 'all_results.json')
+ if all_results is not None:
+ result_list = all_results['result_list']
+ else:
+ result_list = None
+ # json_data = load_json(save_path, 'all_results.json')['result_list']
+ return result_list
+
+class MVBenchDataset(EvalDataset):
+ data_list_info = {
+ # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound")
+ "Action Sequence": ("action_sequence.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
+ "Action Prediction": ("action_prediction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
+ "Action Antonym": ("action_antonym.json", "DATAS/MVBench/video/ssv2_video/", "video", False),
+ "Fine-grained Action": ("fine_grained_action.json", "DATAS/MVBench/video/Moments_in_Time_Raw/videos/", "video", False),
+ "Unexpected Action": ("unexpected_action.json", "DATAS/MVBench/video/FunQA_test/test/", "video", False),
+ "Object Existence": ("object_existence.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
+ "Object Interaction": ("object_interaction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
+ "Object Shuffle": ("object_shuffle.json", "DATAS/MVBench/video/perception/videos/", "video", False),
+ "Moving Direction": ("moving_direction.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
+ "Action Localization": ("action_localization.json", "DATAS/MVBench/video/sta/sta_video/", "video", True), # has start & end
+ "Scene Transition": ("scene_transition.json", "DATAS/MVBench/video/scene_qa/video/", "video", False),
+ "Action Count": ("action_count.json", "DATAS/MVBench/video/perception/videos/", "video", False),
+ "Moving Count": ("moving_count.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
+ "Moving Attribute": ("moving_attribute.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
+ "State Change": ("state_change.json", "DATAS/MVBench/video/perception/videos/", "video", False),
+ "Fine-grained Pose": ("fine_grained_pose.json", "DATAS/MVBench/video/nturgbd/", "video", False),
+ "Character Order": ("character_order.json", "DATAS/MVBench/video/perception/videos/", "video", False),
+ "Egocentric Navigation": ("egocentric_navigation.json", "DATAS/MVBench/video/vlnqa/", "video", False),
+ "Episodic Reasoning": ("episodic_reasoning.json", "DATAS/MVBench/video/tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame
+ "Counterfactual Inference": ("counterfactual_inference.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
+ }
+ data_dir = "DATAS/MVBench/json"
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ data_list_info = self.data_list_info
+ data_dir = self.data_dir
+
+ self.data_list = []
+ for k, v in data_list_info.items():
+ with open(os.path.join(data_dir, v[0]), 'r') as f:
+ json_data = json.load(f)
+ for data in json_data:
+ self.data_list.append({
+ 'task_type': k,
+ 'prefix': v[1],
+ 'data_type': v[2],
+ 'bound': v[3],
+ 'data': data
+ })
+ # self.data_list = self.data_list[:100] # for debug
+ self.decord_method = {
+ 'video': self.read_video,
+ 'gif': self.read_gif,
+ 'frame': self.read_frame,
+ }
+
+ # # transform
+ # crop_size = resolution
+ # scale_size = resolution
+ # input_mean = [0.48145466, 0.4578275, 0.40821073]
+ # input_std = [0.26862954, 0.26130258, 0.27577711]
+ # self.transform = T.Compose([
+ # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
+ # GroupCenterCrop(crop_size),
+ # Stack(),
+ # ToTorchFormatTensor(),
+ # GroupNormalize(input_mean, input_std)
+ # ])
+
+ def __getitem__(self, idx):
+ question, answer = self.qa_template(self.data_list[idx]['data'])
+ task_type = self.data_list[idx]['task_type']
+ decord_method = self.decord_method[self.data_list[idx]['data_type']]
+ bound = None
+ if self.data_list[idx]['bound']:
+ bound = (
+ self.data_list[idx]['data']['start'],
+ self.data_list[idx]['data']['end'],
+ )
+ video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
+
+
+ # images_group = decord_method(video_path, bound)
+ try: # might be problem with decord
+ images_group = decord_method(video_path, bound)
+ except Exception as e:
+ print(f'error decoding {video_path}')
+ task_type = 'error_reading_video'
+ images_group = None
+
+ return {
+ 'video_path': video_path,
+ 'video_pils': images_group, # some might use the original pils and do their own transforms
+ 'question': question,
+ 'answer': answer,
+ 'task_type': task_type,
+ }
+
+
+ def qa_template(self, data):
+ question = f"Question: {data['question']}\n"
+ question += "Options:\n"
+ answer = data['answer']
+ answer_idx = -1
+ for idx, c in enumerate(data['candidates']):
+ question += f"({chr(ord('A') + idx)}) {c}\n"
+ if c == answer:
+ answer_idx = idx
+ question = question.rstrip()
+ answer = f"({chr(ord('A') + answer_idx)}) {answer}"
+ return question, answer
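+ # qa_template renders a sample roughly as
+ #   "Question: ...\nOptions:\n(A) ...\n(B) ..."
+ # and the ground truth as "(B) <answer text>", which is the format check_ans expects.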
+
diff --git a/tasks/eval/mvbench/pllava_eval_mvbench.py b/tasks/eval/mvbench/pllava_eval_mvbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..117785e9b860b965a1549ba64163de516cf9748f
--- /dev/null
+++ b/tasks/eval/mvbench/pllava_eval_mvbench.py
@@ -0,0 +1,278 @@
+
+import functools
+import itertools
+import logging
+from tqdm import tqdm
+from PIL import Image
+from multiprocessing import Pool
+import multiprocessing as mp
+from argparse import ArgumentParser
+import numpy as np
+
+import torch
+import torchvision
+
+from decord import VideoReader, cpu
+import transformers
+
+
+from tasks.eval.model_utils import load_pllava, pllava_answer
+from tasks.eval.eval_utils import conv_templates
+from tasks.eval.mvbench import (
+ MVBenchDataset,
+ check_ans,
+ save_results,
+ load_results,
+)
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+RESOLUTION = 672 #
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ default='llava-hf/llava-1.5-7b-hf'
+ )
+ parser.add_argument(
+ "--save_path",
+ type=str,
+ required=True,
+ default='"./test_results/test_llava_mvbench"'
+ )
+ parser.add_argument(
+ "--num_frames",
+ type=int,
+ required=True,
+ default=4,
+ )
+ parser.add_argument(
+ "--use_lora",
+ action='store_true'
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ required=False,
+ default=32,
+ )
+ parser.add_argument(
+ "--weight_dir",
+ type=str,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--conv_mode",
+ type=str,
+ required=False,
+ default='eval_mvbench',
+ )
+ parser.add_argument(
+ "--pooling_shape",
+ type=str,
+ required=False,
+ default=None,
+ )
+ args = parser.parse_args()
+ return args
+
+def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, pooling_shape=(16,12,12)):
+ # note: larger models (30B+) can use up most of the memory and may even bring down nodes
+ model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, weight_dir=weight_dir, lora_alpha=lora_alpha, pooling_shape=pooling_shape)
+ logger.info('done loading llava')
+
+ # position embedding
+ model = model.to(torch.device(rank))
+ model = model.eval()
+
+ dataset = MVBenchDataset(num_segments=num_frames)
+ dataset.set_rank_and_world_size(rank, world_size)
+ return model, processor, dataset
+
+def infer_mvbench(
+ model,
+ processor,
+ data_sample,
+ conv_mode,
+ pre_query_prompt=None, # add in the head of question
+ post_query_prompt=None, # add in the end of question
+ answer_prompt=None, # add in the begining of answer
+ return_prompt=None, # add in the begining of return message
+ print_res=False,
+ ):
+ video_list = data_sample["video_pils"]
+ conv = conv_templates[conv_mode].copy()
+ conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True)
+ if answer_prompt is not None:
+ conv.assistant_response(answer_prompt)
+
+ llm_message, conv = pllava_answer(
+ conv=conv,
+ model=model,
+ processor=processor,
+ img_list=video_list,
+ max_new_tokens=100,
+ do_sample=False,
+ print_res=print_res
+ )
+
+ if answer_prompt is not None:
+ llm_message = ''.join(llm_message.split(answer_prompt)[1:])
+
+ if return_prompt is not None:
+ llm_message = return_prompt + llm_message
+
+ return llm_message
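+# run() below calls this with answer_prompt="Best option:(" and return_prompt='(': the answer prompt
+# nudges the model to reply with an option letter, and the leading '(' removed by the split above is
+# then restored.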
+
+def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
+ def get_index(num_frames, num_segments):
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
+
+ def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336):
+ transforms = torchvision.transforms.Resize(size=resolution)
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+ num_frames = len(vr)
+ frame_indices = get_index(num_frames, num_segments)
+ images_group = list()
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(transforms(img))
+ if return_msg:
+ fps = float(vr.get_avg_fps())
+ sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
+ # " " should be added in the start and end
+ msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
+ return images_group, msg
+ else:
+ return images_group
+
+ if num_frames != 0:
+ vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
+ else:
+ vid, msg = None, 'num_frames is 0, not inputting images'
+ img_list = vid
+ conv = conv_templates[conv_mode].copy()
+ conv.user_query("Describe the video in details.", is_mm=True)
+ llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)
+
+def run(rank, args, world_size):
+ if rank != 0:
+ transformers.utils.logging.set_verbosity_error()
+ logger.setLevel(transformers.logging.ERROR)
+
+ print_res = False
+ conv_mode= args.conv_mode
+ pre_query_prompt = None
+ post_query_prompt = "\nOnly give the best option."
+ # default matches load_model_and_dataset's signature; avoids a NameError when --pooling_shape is not given
+ pooling_shape = (16, 12, 12) if args.pooling_shape is None else tuple(int(x) for x in args.pooling_shape.split("-"))
+
+ logger.info(f'loading model and constructing dataset to gpu {rank}...')
+ model, processor, dataset = load_model_and_dataset(rank,
+ world_size,
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+ num_frames=args.num_frames,
+ use_lora=args.use_lora,
+ lora_alpha=args.lora_alpha,
+ weight_dir=args.weight_dir,
+ pooling_shape=pooling_shape)
+ logger.info(f'done model and dataset...')
+ logger.info('constructing dataset...')
+ logger.info('single test...')
+
+ vid_path = "./example/yoga.mp4"
+ # vid_path = "./example/jesse_dance.mp4"
+ if rank == 0:
+ single_test(model,
+ processor,
+ vid_path,
+ num_frames=args.num_frames,
+ conv_mode=args.conv_mode)
+ logger.info('single test done...')
+ tbar = tqdm(total=len(dataset))
+
+ correct = 0
+ total = 0
+ result_list = []
+ acc_dict = {}
+ done_count = 0
+
+ for example in dataset:
+ task_type = example['task_type']
+ if task_type not in acc_dict:
+ acc_dict[task_type] = [0, 0] # correct, total
+ acc_dict[task_type][1] += 1
+ total += 1
+ pred = infer_mvbench(
+ model,
+ processor,
+ example,
+ conv_mode=conv_mode,
+ pre_query_prompt=pre_query_prompt,
+ post_query_prompt=post_query_prompt,
+ answer_prompt="Best option:(",
+ return_prompt='(',
+ print_res=print_res,
+ )
+ gt = example['answer']
+ result_list.append({
+ 'pred': pred,
+ 'gt': gt,
+ 'task_type': task_type,
+ 'video_path': example['video_path'],
+ 'question': example['question'],
+
+ })
+ if check_ans(pred=pred, gt=gt):
+ acc_dict[task_type][0] += 1
+ correct += 1
+ if rank == 0:
+ tbar.update(len(result_list) - done_count, )
+ tbar.set_description_str(
+ f"One Chunk--Task Type: {task_type}, Chunk Part Acc: {acc_dict[task_type][0] / acc_dict[task_type][1] * 100 :.2f}%;"
+ f" Chunk Total Acc: {correct / total * 100 :.2f}%"
+ )
+ done_count = len(result_list)
+ return result_list
+
+def main():
+ multiprocess=True
+ mp.set_start_method('spawn')
+ args = parse_args()
+ save_path = args.save_path
+ json_data = load_results(save_path)
+ if json_data is None:
+ if multiprocess:
+ logger.info(f'started benchmarking, saving to: {save_path}')
+ n_gpus = torch.cuda.device_count()
+ # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
+ world_size = n_gpus
+ with Pool(world_size) as pool:
+ func = functools.partial(run, args=args, world_size=world_size)
+ result_lists = pool.map(func, range(world_size))
+
+ logger.info('finished running')
+ result_list = [ res for res in itertools.chain(*result_lists)]
+ else:
+ result_list = run(0, world_size=1, args=args) # debug
+
+ else:
+ logger.info(f'loaded results from {save_path}')
+ result_list = json_data
+ save_results(result_list, save_path)
+
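+# Typical multi-GPU invocation (paths are placeholders):
+#   python -m tasks.eval.mvbench.pllava_eval_mvbench --pretrained_model_name_or_path MODELS/pllava-7b \
+#       --save_path test_results/pllava-7b_mvbench --num_frames 16 --use_lora --weight_dir MODELS/pllava-7b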
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tasks/eval/recaption/__init__.py b/tasks/eval/recaption/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e68d57bf16c6ca58666e4eac512957baad2938fe
--- /dev/null
+++ b/tasks/eval/recaption/__init__.py
@@ -0,0 +1,293 @@
+from functools import partial
+import os
+import json
+from typing import OrderedDict
+
+import tqdm
+import torch
+from PIL import Image
+import ast
+import numpy as np
+from multiprocessing import Pool
+
+from decord import VideoReader, cpu
+
+import os
+from tasks.eval.eval_utils import (
+ dump_json,
+ load_json,
+ EvalDataset,
+)
+from dataclasses import dataclass
+from openai import OpenAI
+from utils.easydict import EasyDict
+client = OpenAI(
+ # This is the default and can be omitted
+ api_key=os.environ.get("OPENAI_API_KEY"),
+)
+
+task_type2chatgpt_contents = OrderedDict({
+ "Panda70M": {
+ "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for video captioning. "
+ "Your task is to compare the predicted captioning with a provided hint (which is usually a ground truth caption provided by human labor or autmated captioning pipeline)."
+ "You should determine if they match meaningfully, logically and precisely. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Evaluate the correctness of the prediction compared to the answer.",
+ "user": """Please evaluate the following video-based Captioning pair:\n\n"""
+ """Caption: {caption}\n"""
+ """Predicted Caption: {pred}\n\n"""
+ """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """
+ """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."""
+ """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """
+ """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}."""
+ },
+})
+
+# Follow the instructions carefully and be helpful and precise with your answer.
+
+def check_ans_recaption(pred, gt, task_type, model="gpt-3.5-turbo-0125"):
+ response_message = None # keep defined so the error report in the except branch cannot raise a NameError
+ try:
+ # Compute the temporal understanding score
+ user_input = task_type2chatgpt_contents[task_type]['user']
+ user_input = user_input.format(caption=gt, pred=pred)
+ completion = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": task_type2chatgpt_contents[task_type]['system'],
+ },
+ {
+ "role": "user",
+ "content": user_input,
+ }
+ ]
+ )
+ # Convert response to a Python dictionary.
+ # response_message = completion["choices"][0]["message"]["content"]
+ response_message = completion.choices[0].message.content
+ num_tokens_openai = completion.usage.total_tokens
+ response_dict = ast.literal_eval(response_message)
+ pred = response_dict['pred']
+ score = response_dict['score']
+ if not pred in ('yes', 'no') or not isinstance(score, (int, float)):
+ raise ValueError(f"{model} doesn't follow")
+ flag = pred == 'yes'
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ flag, score, num_tokens_openai = False, 0, 0
+ print(
+ f"GPT cannot deal with:\n"
+ f"--pred: {pred}\n"
+ f"--gt: {gt}\n"
+ f"--gpt responded: {response_message}\n"
+ "--will assign flag=False and score=0"
+ )
+ print(f"Dumb Answer in {task_type}")
+ return flag, score, num_tokens_openai
+
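+# Illustrative note (not executed): the judge model is expected to reply with a literal Python
+# dict, which check_ans_recaption parses via ast.literal_eval, e.g.
+#   ast.literal_eval("{'pred': 'yes', 'score': 4}") -> {'pred': 'yes', 'score': 4}
+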
+def chatgpt_eval(res, model="gpt-3.5-turbo-0125"):
+ pred = res['pred']
+ gt = res['caption']
+ task_type = res['task_type']
+ correct, score, num_tokens_openai = check_ans_recaption(pred=pred, gt=gt,task_type=task_type, model=model) # acc is bool, score is given by chatgpt
+ # update the scores in result_list for this sample
+ res['score'] = score
+ res['correct'] = correct
+ res['num_tokens_openai'] = num_tokens_openai
+ return res
+
+def save_results(result_list, save_path, model="gpt-3.5-turbo-0125"):
+ dump_json(result_list, save_path, 'inference_results.json')
+ with Pool(7) as pool:
+ func = partial(chatgpt_eval, model=model)
+ result_list = [ res for res in tqdm.tqdm(pool.imap_unordered(func, result_list), total=len(result_list), desc='Language Chat Model Automated Evaluation...')]
+
+ # result_list = [chatgpt_eval(res, model=model) for res in result_list]
+
+ final_res, acc_dict = {}, {}
+ correct, total, total_score = 0, 0, 0
+ for i, res in enumerate(result_list):
+ task_type = res['task_type']
+ if task_type not in acc_dict:
+ acc_dict[task_type] = {
+ 'correct': 0,
+ 'total': 0,
+ 'score': 0,
+ } # correct, total
+ acc_dict[task_type]['total'] += 1
+ acc_dict[task_type]['correct'] += res['correct']
+ acc_dict[task_type]['score'] += res['score']
+
+ for k, v in acc_dict.items():
+ final_res[k] = {
+ 'acc': v['correct'] / v['total'] * 100,
+ 'score': v['score'] / v['total']
+ }
+ correct += v['correct']
+ total += v['total']
+ total_score += v['score']
+
+ final_res['Avg_Acc'] = correct / total * 100
+ final_res['Avg_Score'] = total_score / total
+
+ all_results = {
+ "acc_dict": acc_dict,
+ "result_list": result_list
+ }
+ dump_json(all_results, save_path, f'final_results-{model}.json')
+ dump_json(final_res, save_path, 'upload_leaderboard.json')
+
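+# Note: save_results first dumps the raw generations to inference_results.json, then scores them
+# with the chat judge in a 7-worker pool and writes final_results-<model>.json (full details) and
+# upload_leaderboard.json (per-task accuracy, i.e. the percentage of 'yes' verdicts, and mean score).
+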
+def load_results(save_path, model="gpt-3.5-turbo-0125"):
+ result_list = load_json(save_path, f'final_results-{model}.json')
+ if result_list is not None:
+ result_list = result_list['result_list']
+
+ if result_list is None:
+ result_list = load_json(save_path, 'inference_results.json')
+
+ return result_list
+
+class CaptionSample(EasyDict):
+ def get_info(self):
+ return {}
+
+class RecaptionSample(EasyDict):
+ caption: str
+ def get_info(self):
+ # template = ("""To facilitate success in the task, I'll offer hints from the automated image captioning pipeline's output on the frames. """
+ # """Please note that this information may contain noise but remains descriptive."""
+ # """Presented below are the noisy details:\n"""
+ # """Hint: {hint}\n"""
+ # """The hint comprises noisy captions generated for certain frames in the video. """
+ # """Please refrain from disclosing the original hints provided; instead, provide rewritten accurate information.""")
+ # hint = template.format(hint=self.hint,)
+ return {
+ "noisy_caption": self.caption
+ }
+
+class RecaptionSampleWithMatchingScore(EasyDict):
+ caption: str
+ matching_score: float
+
+ def get_info(self):
+ # template = ("""To facilitate success in the task, I'll offer hints from the automated image captioning pipeline's output on the frames. """
+ # """Please note that this information may contain noise but remains descriptive."""
+ # """Presented below are the noisy details:\n"""
+ # """Hint: {hint}\n"""
+ # """Matching Score: {matching_score:.02f}\n"""
+ # """The hint comprises noisy captions generated for certain frames in the video. """
+ # """Matching scores indicate the likelihood of these captions matching the original frames.\n"""
+ # """Please refrain from disclosing the original hints provided; instead, provide rewritten accurate information."""
+ # )
+
+ # hint = template.format(hint=self.hint,
+ # matching_score=self.matching_score)
+ info = {
+ "noisy_caption": self.caption,
+ "matching_score": self.matching_score,
+ }
+ # by far, might use some prompting.
+ return info
+
+class RecaptionDataset(EvalDataset):
+ data_dir = "DATAS/Recaption"
+ data_list_info = OrderedDict({
+ # "Panda70M": OrderedDict(
+ # json_relpath="Panda70M/annotations.json",
+ # prefix="DATAS/Recaption/Panda70M/videos",
+ # data_type="video",
+ # bound=False,
+ # key_rename_map={
+ # # 'caption': 'hint',
+ # },
+ # name_key='video_name',
+ # postfix=('mp4', 'mkv', 'webm'),
+ # recaption_type=RecaptionSample,
+ # ), # don't has start & end
+ "Inter4K": OrderedDict(
+ json_relpath="Inter4K/annotations.json",
+ prefix="DATAS/Recaption/Inter4K/60fps/UHD",
+ data_type="video",
+ bound=False,
+ key_rename_map={
+ # 'caption': 'hint',
+ },
+ name_key='video_name',
+ postfix=('mp4', 'mkv', 'webm'),
+ recaption_type=CaptionSample,
+ ), # doesn't have start & end
+ })
+
+ def __init__(self, *args, **kwargs):
+ # recaption's test_ratio should shuffle the dataset
+ test_ratio = kwargs.pop('test_ratio', None)
+ super().__init__(*args, **kwargs)
+ self.test_ratio = test_ratio
+ test_ratio = 1. if test_ratio is None else test_ratio
+ data_list_info = self.data_list_info
+ data_dir = self.data_dir
+
+ self.data_list = []
+ for k, v in data_list_info.items():
+ with open(os.path.join(data_dir, v['json_relpath']), 'r') as f:
+ annotation_json_data = json.load(f)
+
+ indexs = list(range(len(annotation_json_data)))
+ np.random.RandomState(42).shuffle(indexs)
+ num_samples = int(len(indexs) * test_ratio) if 0 < test_ratio <= 1 else int(test_ratio)
+ indexs = indexs[:num_samples]
+ for i in indexs:
+ annotation_data = annotation_json_data[i]
+ for key_old, key_new in v['key_rename_map'].items():
+ # temporarily rename the keys
+ value = annotation_data.pop(key_old)
+ annotation_data[key_new] = value
+
+ data = dict(annotation_data)
+ self.data_list.append({
+ 'task_type': k,
+ 'data': data,
+ })
+
+ def __getitem__(self, idx):
+ task_type = self.data_list[idx]['task_type']
+ decord_method = self.decord_method[self.data_list_info[task_type]['data_type']]
+ bound = None
+
+ if self.data_list_info[task_type]['bound']:
+ bound = (
+ self.data_list[idx]['data']['start'],
+ self.data_list[idx]['data']['end'],
+ )
+ video_name_key = self.data_list_info[task_type]['name_key']
+ video_name = self.data_list[idx]['data'][video_name_key]
+
+ video_postfixs = self.data_list_info[task_type]['postfix']
+ video_paths = []
+ for p in video_postfixs:
+ video_path = os.path.join(self.data_list_info[task_type]['prefix'], video_name + '.' + p)
+ if os.path.exists(video_path):
+ video_paths.append(video_path)
+ assert len(video_paths) > 0, f'no video named {video_name}'
+ # video_filename = self.data_list[idx]['data'][video_name_key] + video_postfix
+ video_path = video_paths[0]
+ images_group = decord_method(video_path, bound)
+
+ sample = self.data_list_info[task_type]['recaption_type'](**self.data_list[idx]['data'],)
+ info = sample.get_info()
+
+ return {
+ 'video_pils': images_group, # some might use the original pils and do their own transforms
+ 'video_path': video_path,
+ 'info': info,
+ 'sample': sample,
+ 'task_type': task_type,
+ }
+
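+# Minimal usage sketch (illustrative only; assumes DATAS/Recaption is populated as configured above):
+#   dataset = RecaptionDataset(test_ratio=None, num_segments=16)
+#   item = dataset[0]  # dict with 'video_pils', 'video_path', 'info', 'sample', 'task_type'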
+
+
diff --git a/tasks/eval/recaption/pllava_recaption.py b/tasks/eval/recaption/pllava_recaption.py
new file mode 100644
index 0000000000000000000000000000000000000000..8530b8ee181a5db7acaad6332734486d79c9b516
--- /dev/null
+++ b/tasks/eval/recaption/pllava_recaption.py
@@ -0,0 +1,294 @@
+
+import functools
+import itertools
+import json
+import logging
+from tqdm import tqdm
+from PIL import Image
+from multiprocessing import Pool
+from argparse import ArgumentParser
+import multiprocessing as mp
+
+
+
+import numpy as np
+import torch
+
+import torchvision
+
+import transformers
+from decord import VideoReader, cpu
+
+from tasks.eval.model_utils import load_pllava, pllava_answer
+from tasks.eval.eval_utils import conv_templates
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+IMAGE_TOKEN=''
+from tasks.eval.recaption import (
+ RecaptionDataset,
+ load_results,
+ save_results,
+)
+RESOLUTION = 672 #
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ default='llava-hf/llava-1.5-7b-hf'
+ )
+ parser.add_argument(
+ "--save_path",
+ type=str,
+ required=True,
+ default='"./test_results/test_llava_mvbench"'
+ )
+ parser.add_argument(
+ "--num_frames",
+ type=int,
+ required=True,
+ default=4,
+ )
+ parser.add_argument(
+ "--use_lora",
+ action='store_true'
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ required=False,
+ default=32,
+ )
+ parser.add_argument(
+ "--weight_dir",
+ type=str,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--eval_model",
+ type=str,
+ required=False,
+ default="gpt-3.5-turbo-0125",
+ )
+ parser.add_argument(
+ '--test_ratio',
+ type=float,
+ required=False,
+ default=None
+ )
+ parser.add_argument(
+ "--conv_mode",
+ type=str,
+ required=False,
+ default='eval_videoqabench',
+ )
+ args = parser.parse_args()
+ return args
+
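+# Example invocation (illustrative; paths and values below are placeholders):
+#   python -m tasks.eval.recaption.pllava_recaption \
+#       --pretrained_model_name_or_path llava-hf/llava-1.5-7b-hf \
+#       --save_path ./test_results/test_pllava_recaption \
+#       --num_frames 16 --use_lora --weight_dir <pllava_weight_dir> \
+#       --conv_mode eval_videoqabench --test_ratio 0.1
+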
+def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, test_ratio):
+ # Note: once the model gets larger (30B+), memory may be heavily used up and can even bring down nodes.
+ model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, lora_alpha=lora_alpha, weight_dir=weight_dir)
+ logger.info('done loading llava')
+ # position embedding
+ model = model.to(torch.device(rank))
+ model = model.eval()
+
+ dataset = RecaptionDataset(test_ratio=test_ratio, num_segments=num_frames)
+ dataset.set_rank_and_world_size(rank, world_size)
+ return model, processor, dataset
+
+def infer_recaption(
+ model,
+ processor,
+ data_sample,
+ conv_mode,
+ pre_query_prompt=None, # added at the head of the question
+ post_query_prompt=None, # added at the end of the question
+ answer_prompt=None, # added at the beginning of the answer
+ return_prompt=None, # added at the beginning of the returned message
+ print_res=False,
+ ):
+ video_list = data_sample["video_pils"]
+ conv = conv_templates[conv_mode].copy()
+ # info = data_sample['info']
+ query = (
+ "You are to assist me in accomplishing a task about the input video. Reply to me with a precise yet detailed response. For how you would succeed in the recaptioning task, read the following Instructions section and Then, make your response with a elaborate paragraph.\n"
+ "# Instructions\n"
+ "1. Avoid providing over detailed information such as color, counts of any objects as you are terrible regarding observing these details\n"
+ "2. Instead, you should carefully go over the provided video and reason about key information about the overall video\n"
+ "3. If you are not sure about something, do not include it in you response.\n"
+ "# Task\n"
+ "Describe the background, characters and the actions in the provided video.\n"
+ )
+ conv.user_query(query, pre_query_prompt, post_query_prompt, is_mm=True)
+ if answer_prompt is not None:
+ conv.assistant_response(answer_prompt)
+
+ llm_message, conv = pllava_answer(
+ conv=conv,
+ model=model,
+ processor=processor,
+ img_list=video_list,
+ max_new_tokens=400,
+ num_beams=1,
+ do_sample=False,
+ print_res=print_res
+ )
+
+ if answer_prompt is not None:
+ llm_message = ''.join(llm_message.split(answer_prompt)[1:])
+
+ if return_prompt is not None:
+ llm_message = return_prompt + llm_message
+
+ return llm_message, query
+
+def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
+ def get_index(num_frames, num_segments):
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
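+ # get_index picks num_segments frame indices spread uniformly across the clip,
+ # starting near the middle of the first segment.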
+
+ def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336):
+ transforms = torchvision.transforms.Resize(size=resolution)
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+ num_frames = len(vr)
+ frame_indices = get_index(num_frames, num_segments)
+ images_group = list()
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(transforms(img))
+ if return_msg:
+ fps = float(vr.get_avg_fps())
+ sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
+ # " " should be added in the start and end
+ msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
+ return images_group, msg
+ else:
+ return images_group
+
+ if num_frames != 0:
+ vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
+ else:
+ vid, msg = None, 'num_frames is 0, not inputting any image'
+ img_list = vid
+
+ conv = conv_templates[conv_mode].copy()
+ conv.user_query("Describe the video in details.", is_mm=True)
+ llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)
+
+def run(rank, args, world_size):
+ if rank != 0:
+ transformers.utils.logging.set_verbosity_error()
+ logger.setLevel(transformers.logging.ERROR)
+
+ print_res = True
+ conv_mode= args.conv_mode
+ pre_query_prompt = None
+ post_query_prompt = None
+
+ # pre_query_prompt = ("""Assist me in detailing the background, characters, and actions depicted in the provided video.\n""")
+ # post_query_prompt = ("""My apologies for any lack of precision; there may be errors in the supplementary information provided.\n"""
+ # """You are encouraged to be discerning and perceptive, paying attention to the minutest details, """
+ # """and to furnish a detailed yet precise description using eloquent language.""")
+
+ logger.info(f'loading model and constructing dataset to gpu {rank}...')
+ model, processor, dataset = load_model_and_dataset(rank,
+ world_size,
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+ num_frames=args.num_frames,
+ use_lora=args.use_lora,
+ lora_alpha=args.lora_alpha,
+ weight_dir=args.weight_dir,
+ test_ratio=args.test_ratio)
+ logger.info('done loading model and dataset...')
+ logger.info('constructing dataset...')
+ logger.info('single test...')
+ vid_path = "./example/yoga.mp4"
+ # vid_path = "./example/jesse_dance.mp4"
+ if rank == 0:
+ single_test(model, processor, vid_path, num_frames=args.num_frames, conv_mode=args.conv_mode)
+ logger.info('single test done...')
+ tbar = tqdm(total=len(dataset))
+
+ result_list = []
+ done_count = 0
+ for example in dataset:
+ task_type = example['task_type']
+ if task_type in dataset.data_list_info:
+ pred, query = infer_recaption(
+ model,
+ processor,
+ example,
+ conv_mode=conv_mode,
+ pre_query_prompt=pre_query_prompt,
+ post_query_prompt=post_query_prompt,
+ print_res=print_res,
+ )
+
+ infos = {k: v for k, v in example['sample'].items() if isinstance(v, (str, float, int))}
+ res = {
+ 'pred': pred,
+ 'task_type': task_type,
+ 'video_path': example['video_path'],
+ 'query': query,
+ **infos
+ }
+ else:
+ raise NotImplementedError(f'not implemented task type {task_type}')
+ # res = chatgpt_eval(res)
+ result_list.append(res)
+ if rank == 0:
+ tbar.update(len(result_list) - done_count, )
+ tbar.set_description_str(
+ f"One Chunk--Task Type: {task_type}-"
+ f"pred: {pred[:min(15, len(pred))]}......"
+ )
+ done_count = len(result_list)
+ return result_list
+
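+# Note on the multi-GPU flow: main() spawns one worker per visible GPU, each worker calls run()
+# with its rank, shards the dataset via set_rank_and_world_size, and the per-rank result lists
+# are chained together before the GPT-based scoring in save_results.
+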
+def main():
+ multiprocess=True
+ mp.set_start_method('spawn')
+ args = parse_args()
+ save_path = args.save_path
+ eval_model = args.eval_model
+ logger.info(f'trying loading results from {save_path}')
+ result_list = load_results(save_path, model=args.eval_model)
+
+ if result_list is None:
+ if multiprocess:
+
+ logger.info(f'started benchmarking, saving to: {save_path}')
+ n_gpus = torch.cuda.device_count()
+ # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
+ world_size = n_gpus
+ with Pool(world_size) as pool:
+ func = functools.partial(run, args=args, world_size=world_size)
+ # func = functools.partial(run, world_size=world_size, model=model, dataset=dataset, result_list=[], acc_dict={})
+ result_lists = pool.map(func, range(world_size))
+
+ logger.info('finished running')
+
+ result_list = [ res for res in itertools.chain(*result_lists)]
+ else:
+ result_list = run(0, world_size=1, args=args) # debug
+ else:
+ logger.info(f'loaded results from {save_path}')
+
+ save_results(result_list, save_path, model=eval_model)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tasks/eval/recaption/show_recaption.py b/tasks/eval/recaption/show_recaption.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b3c00a0a775d485ff71c48dff2ea0d16ff446ec
--- /dev/null
+++ b/tasks/eval/recaption/show_recaption.py
@@ -0,0 +1,52 @@
+
+import argparse
+import gradio as gr
+
+from tasks.eval.recaption import load_results
+import json
+
+# example = videogallery().example_inputs()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--save_path',
+ required=True,
+ )
+ args = parser.parse_args()
+ return args
+
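+# Example (illustrative; the save_path is a placeholder):
+#   python -m tasks.eval.recaption.show_recaption --save_path ./test_results/test_pllava_recaption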
+
+args = parse_args()
+result_list = load_results(args.save_path)
+
+
+def show(result_index, ):
+ info = result_list[result_index]
+ video_path = info['video_path']
+ info_str = json.dumps(info, indent=4)
+ return video_path, info_str
+
+with gr.Blocks() as demo:
+ gr.Markdown("# Showing of what has came out.")
+ gr.Markdown(f"From Saved Results {args.save_path}")
+ with gr.Row():
+ with gr.Column():
+ show_video = gr.Video(interactive=False)
+
+ with gr.Column():
+ result_index = gr.Slider(0, len(result_list), step=1)
+ info = gr.Text(interactive=False)
+
+ result_index.change(show, [result_index], [show_video, info])
+
+
+
+
+
+demo.launch(share=True)
diff --git a/tasks/eval/vcgbench/__init__.py b/tasks/eval/vcgbench/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad2cdd0b7bdf89806b75eebeaf757b45c5a7a9e
--- /dev/null
+++ b/tasks/eval/vcgbench/__init__.py
@@ -0,0 +1,397 @@
+import ast
+import os
+import json
+from typing import OrderedDict
+from multiprocessing import Pool
+from functools import partial
+
+import tqdm
+
+from tasks.eval.eval_utils import (
+ dump_json,
+ load_json,
+ EvalDataset,
+)
+
+from openai import OpenAI
+client = OpenAI(
+ # This is the default and can be omitted
+ api_key=os.environ.get("OPENAI_API_KEY"),
+)
+
+sub_task_type2chatgpt_contents = OrderedDict({
+ # general ones
+ 'temporal': {
+ "system": "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n"
+ "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n"
+ "- Evaluate the temporal accuracy of the prediction compared to the answer.",
+ "user": "Please evaluate the following video-based question-answer pair:\n\n"
+ "Question: {question}\n"
+ "Correct Answer: {answer}\n"
+ "Predicted Answer: {pred}\n\n"
+ "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. "
+ "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING."
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
+ "For example, your response should look like this: {{'score': 4.8}}."
+ },
+ "context": {
+ "system": "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n"
+ "- The predicted answer must capture the main themes and sentiments of the video.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Provide your evaluation of the contextual understanding of the prediction compared to the answer.",
+ "user": "Please evaluate the following video-based question-answer pair:\n\n"
+ "Question: {question}\n"
+ "Correct Answer: {answer}\n"
+ "Predicted Answer: {pred}\n\n"
+ "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. "
+ "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING."
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
+ "For example, your response should look like this: {{'score': 4.8}}."
+ },
+ 'detailed_orientation': {
+ "system": "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n"
+ "- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity.",
+ "user": "Please evaluate the following video-based question-answer pair:\n\n"
+ "Question: {question}\n"
+ "Correct Answer: {answer}\n"
+ "Predicted Answer: {pred}\n\n"
+ "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. "
+ "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING."
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
+ "For example, your response should look like this: {{'score': 4.8}}."
+ ,
+ },
+ "correctness": {
+ "system": "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n"
+ "- The predicted answer must be factually accurate and align with the video content.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Evaluate the factual accuracy of the prediction compared to the answer.",
+ "user": "Please evaluate the following video-based question-answer pair:\n\n"
+ "Question: {question}\n"
+ "Correct Answer: {answer}\n"
+ "Predicted Answer: {pred}\n\n"
+ "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. "
+ "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING."
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
+ "For example, your response should look like this: {{'score': 4.8}}."
+
+ },
+ "consistency": {
+ "system": "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. "
+ "You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ."
+ "Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n"
+ "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n"
+ "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n"
+ "- Evaluate the consistency of the two predicted answers compared to the correct answer.",
+ "user":"Please evaluate the following video-based question-answer pair:\n\n"
+ "Question 1: {question}\n"
+ "Question 2: {question1}\n"
+ "Correct Answer: {answer}\n"
+ "Predicted Answer to Question 1: {pred}\n"
+ "Predicted Answer to Question 2: {pred1}\n\n"
+ "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. "
+ "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING."
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
+ "For example, your response should look like this: {{'score': 4.8}}."
+
+ },
+})
+
+SYSTEM_VCGBENCH="""
+You are Video-ChatGPT, a large vision-language assistant.
+You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
+Follow the instructions carefully and explain your answers in detail based on the provided video.
+"""
+
+def check_ans(gt, pred, question, sub_task_type, question1=None, pred1=None, model="gpt-3.5-turbo-0125"):
+ # # dummy
+ # print('-' * 10 + f'pred: {pred}')
+ # print('-' * 10 + f'gt: {gt}')
+ response_message = None
+ try:
+ # Compute the temporal understanding score
+ user_input = sub_task_type2chatgpt_contents[sub_task_type]['user']
+ if question1 is not None and pred1 is not None:
+ assert sub_task_type == 'consistency', 'consistency has two answers'
+ user_input = user_input.format(question=question, answer=gt, pred=pred, pred1=pred1, question1=question1)
+ else:
+ user_input = user_input.format(question=question, answer=gt, pred=pred)
+ completion = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": sub_task_type2chatgpt_contents[sub_task_type]['system'],
+ },
+ {
+ "role": "user",
+ "content": user_input,
+ }
+ ]
+ )
+ # Convert response to a Python dictionary.
+ response_message = completion.choices[0].message.content
+ response_dict = ast.literal_eval(response_message)
+ flag, score = response_dict['score'] > 3, response_dict['score']
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ flag, score = False, 0
+ print(
+ f"GPT cannot deal with:\n"
+ f"--pred: {pred},\n"
+ f"--gt: {gt}\n"
+ f"--gpt responded: {response_message}\n"
+ "--will assign flag=False and score=0"
+ )
+ print(f"Dumb Answer in {sub_task_type}")
+ return flag, score
+
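+# Note: check_ans treats a sample as correct when the judge's score is above 3; save_results
+# then reports per-sub-task accuracy (percentage of such samples) and the mean score.
+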
+def chatgpt_eval(res, model="gpt-3.5-turbo-0125"):
+ pred = res['pred']
+ gt = res['gt']
+ question=res['question']
+ task_type = res['task_type']
+ if task_type == 'generic_qa':
+ # eval three sub tasks for generic
+ for sub_task_type in ('context', 'detailed_orientation', 'correctness'):
+ if pred=="":
+ print("no pred")
+ score = 0
+ else:
+ acc, score = check_ans(gt=gt, pred=pred, question=question, sub_task_type=sub_task_type, model=model) # acc is bool, score is given by chatgpt
+ # update the scores in result_list for this sample
+ res['scores'] = res.get('scores', {})
+ res['scores'][sub_task_type] = score
+ elif task_type == 'temporal_qa': # only do temporal eval for temporal_qa
+ sub_task_type = 'temporal'
+ if pred=="":
+ print("no pred")
+ score = 0
+ else:
+ acc, score = check_ans(gt=gt, pred=pred, question=question, sub_task_type=sub_task_type, model=model) # acc is bool, score is given by chatgpt
+ # update the scores in result_list for this sample
+ res['scores'] = res.get('scores', {})
+ res['scores'][sub_task_type] = score
+ elif task_type == 'consistency_qa': # only do consistency eval for consistency_qa
+ sub_task_type = 'consistency'
+ assert 'pred1' in res and 'question1' in res, 'two questions and preds'
+ pred1 = res['pred1']
+ question1 = res['question1']
+ if pred=="" or pred1=="":
+ print("no pred")
+ score = 0
+ else:
+ acc, score = check_ans(
+ gt=gt, pred=pred, pred1=pred1, question=question, question1=question1,
+ sub_task_type=sub_task_type, model=model) # acc is bool, score is given by chatgpt
+ # update the scores in result_list for this sample
+ res['scores'] = res.get('scores', {})
+ res['scores'][sub_task_type] = score
+ else:
+ raise NotImplementedError(f'not implemented task type for {task_type}')
+
+ return res
+
+def save_results(result_list, save_path, model="gpt-3.5-turbo-0125"):
+ dump_json(result_list, save_path, 'inference_results.json')
+ with Pool(7) as pool:
+ # result_list = pool.map(partial(chatgpt_eval, model=model), result_list)
+ func = partial(chatgpt_eval, model=model)
+ result_list = [ res for res in tqdm.tqdm(pool.imap_unordered(func, result_list), total=len(result_list), desc='Language Chat Model Automated Evaluation...')]
+
+ final_res, acc_dict = {}, {}
+ correct, total, total_score = 0, 0, 0
+ for i, res in enumerate(result_list):
+ task_type = res['task_type']
+ for sub_task_type, score in res['scores'].items():
+ if sub_task_type not in acc_dict:
+ acc_dict[sub_task_type] = {
+ 'correct': 0,
+ 'total': 0,
+ 'score': 0,
+ } # correct, total
+ is_correct = score > 3
+ acc_dict[sub_task_type]['total'] += 1
+ acc_dict[sub_task_type]['correct'] += is_correct
+ acc_dict[sub_task_type]['score'] += score
+
+ for k, v in acc_dict.items():
+ final_res[k] = {
+ 'acc': v['correct'] / v['total'] * 100,
+ 'score': v['score'] / v['total']
+ }
+ correct += v['correct']
+ total += v['total']
+ total_score += v['score']
+ final_res['Avg_Acc'] = correct / total * 100
+ final_res['Avg_Score'] = total_score / total
+
+ all_results = {
+ "acc_dict": acc_dict,
+ "result_list": result_list
+ }
+ result_post =f"-{model}"
+ dump_json(all_results, save_path, f'final_results{result_post}.json')
+ dump_json(final_res, save_path, f'upload_leaderboard{result_post}.json')
+
+def load_results(save_path, model="gpt-3.5-turbo-0125"):
+
+ result_list = load_json(save_path, f'final_results-{model}.json')
+ if result_list is not None:
+ result_list = result_list['result_list']
+
+ if result_list is None:
+ result_list = load_json(save_path, 'inference_results.json')
+
+ return result_list
+
+class VideoChatGPTBenchDataset(EvalDataset):
+ data_dir = "DATAS/VCGBench"
+ data_list_info = OrderedDict({
+ "generic_qa": OrderedDict(
+ json_relpath="Zero_Shot_QA/Benchmarking_QA/generic_qa.json",
+ prefix="DATAS/VCGBench/Videos/Benchmarking",
+ data_type="video",
+ bound=False,
+ question_key='Q',
+ answer_key='A',
+ name_key='video_name',
+ postfix=('mp4', 'mkv'),
+ ),
+ "temporal_qa": OrderedDict(
+ json_relpath="Zero_Shot_QA/Benchmarking_QA/temporal_qa.json",
+ prefix="DATAS/VCGBench/Videos/Benchmarking",
+ data_type="video",
+ bound=False,
+ question_key='Q',
+ answer_key='A',
+ name_key='video_name',
+ postfix=('mp4', 'mkv'),
+ ), # doesn't have start & end
+ "consistency_qa": OrderedDict(
+ # consistency is evaluated quite differently and is more awkward to handle; deferred for now.
+ json_relpath="Zero_Shot_QA/Benchmarking_QA/consistency_qa.json",
+ prefix="DATAS/VCGBench/Videos/Benchmarking",
+ data_type="video",
+ bound=False,
+ question_key=('Q1', 'Q2'),
+ answer_key='A',
+ name_key='video_name',
+ postfix=('mp4', 'mkv'),
+ ),
+ })
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ data_list_info = self.data_list_info
+ data_dir = self.data_dir
+
+ self.data_list = []
+ for k, v in data_list_info.items():
+ with open(os.path.join(data_dir, v['json_relpath']), 'r') as f:
+ json_data = json.load(f)
+ for data in json_data:
+ self.data_list.append({
+ 'task_type': k,
+ 'data': data,
+ **v, # all the infos
+ })
+ # self.data_list = self.data_list[:10] # for debug
+ # random.shuffle(self.data_list) # for debug
+ self.decord_method = {
+ 'video': self.read_video,
+ 'gif': self.read_gif,
+ 'frame': self.read_frame,
+ }
+ # # transform
+ # crop_size = resolution
+ # scale_size = resolution
+ # input_mean = [0.48145466, 0.4578275, 0.40821073]
+ # input_std = [0.26862954, 0.26130258, 0.27577711]
+ # self.transform = T.Compose([
+ # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
+ # GroupCenterCrop(crop_size),
+ # Stack(),
+ # ToTorchFormatTensor(),
+ # GroupNormalize(input_mean, input_std)
+ # ])
+
+ def __getitem__(self, idx):
+ task_type = self.data_list[idx]['task_type']
+ video_name_key = self.data_list[idx]['name_key']
+ video_name = self.data_list[idx]['data'][video_name_key]
+ video_postfixs = self.data_list[idx]['postfix']
+
+ if self.num_segments != 0:
+ video_paths = []
+ for p in video_postfixs:
+ video_path = os.path.join(self.data_list[idx]['prefix'], video_name + '.' + p)
+ if os.path.exists(video_path):
+ video_paths.append(video_path)
+ assert len(video_paths) > 0, f'no video named {video_name}'
+ # video_filename = self.data_list[idx]['data'][video_name_key] + video_postfix
+ video_path = video_paths[0]
+ decord_method = self.decord_method[self.data_list[idx]['data_type']]
+ bound = None
+ if self.data_list[idx]['bound']:
+ bound = (
+ self.data_list[idx]['data']['start'],
+ self.data_list[idx]['data']['end'],
+ )
+ images_group = decord_method(video_path, bound)
+ else:
+ # zero frame, no image
+ images_group = None
+
+ data = {
+ 'video_path': video_path,
+ 'video_pils': images_group, # some might use the original pils and do their own transforms
+ 'task_type': task_type,
+ }
+
+
+ answer_key = self.data_list[idx]['answer_key']
+ question_key = self.data_list[idx]['question_key']
+
+ if task_type == 'consistency_qa' and isinstance(question_key, tuple):
+ question=self.data_list[idx]['data'][question_key[0]]
+ question1=self.data_list[idx]['data'][question_key[1]]
+ answer=self.data_list[idx]['data'][answer_key]
+
+ data.update({
+ 'question': question,
+ 'question1': question1,
+ 'answer': answer,
+ })
+ elif isinstance(question_key, str):
+ question=self.data_list[idx]['data'][question_key]
+ answer=self.data_list[idx]['data'][answer_key]
+ data.update({
+ 'question': question,
+ 'answer': answer,
+ })
+ else:
+ raise ValueError(f'unexpected question_key {question_key} for task type {task_type}')
+
+ return data
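+
+# Minimal usage sketch (illustrative only; assumes DATAS/VCGBench is laid out as in data_list_info):
+#   dataset = VideoChatGPTBenchDataset(num_segments=16, test_ratio=None)
+#   item = dataset[0]  # dict with 'video_path', 'video_pils', 'task_type', 'question', 'answer'
+#                      # (plus 'question1' for consistency_qa samples)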
diff --git a/tasks/eval/vcgbench/pllava_eval_vcgbench.py b/tasks/eval/vcgbench/pllava_eval_vcgbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..7182a85cf2610fdf1a999bfc1e90b6fcf3fdee43
--- /dev/null
+++ b/tasks/eval/vcgbench/pllava_eval_vcgbench.py
@@ -0,0 +1,306 @@
+
+import functools
+import itertools
+import logging
+from tqdm import tqdm
+from PIL import Image
+from multiprocessing import Pool
+import multiprocessing as mp
+from argparse import ArgumentParser
+import numpy as np
+
+import torch
+import torchvision
+
+from decord import VideoReader, cpu
+import transformers
+
+
+from tasks.eval.model_utils import load_pllava, pllava_answer
+from tasks.eval.eval_utils import conv_templates
+from tasks.eval.vcgbench import (
+ VideoChatGPTBenchDataset,
+ save_results,
+ load_results,
+)
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+RESOLUTION = 672 #
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ default='llava-hf/llava-1.5-7b-hf'
+ )
+ parser.add_argument(
+ "--save_path",
+ type=str,
+ required=True,
+ default='"./test_results/test_llava_mvbench"'
+ )
+ parser.add_argument(
+ "--num_frames",
+ type=int,
+ required=True,
+ default=4,
+ )
+ parser.add_argument(
+ "--use_lora",
+ action='store_true'
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ required=False,
+ default=32,
+ )
+ parser.add_argument(
+ "--weight_dir",
+ type=str,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--eval_model",
+ type=str,
+ required=False,
+ default="gpt-3.5-turbo-0125",
+ )
+ parser.add_argument(
+ "--conv_mode",
+ type=str,
+ required=False,
+ default='eval_vcgbench',
+ )
+ parser.add_argument(
+ "--test_ratio",
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--pooling_shape",
+ type=str,
+ required=False,
+ default=None,
+ )
+ args = parser.parse_args()
+ return args
+
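+# Example invocation (illustrative; paths and values below are placeholders):
+#   python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
+#       --pretrained_model_name_or_path llava-hf/llava-1.5-7b-hf \
+#       --save_path ./test_results/test_pllava_vcgbench \
+#       --num_frames 16 --use_lora --weight_dir <pllava_weight_dir> \
+#       --conv_mode eval_vcgbench --pooling_shape 16-12-12
+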
+def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, test_ratio, pooling_shape=(16,12,12)):
+ # Note: once the model gets larger (30B+), memory may be heavily used up and can even bring down nodes.
+ model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, weight_dir=weight_dir, lora_alpha=lora_alpha, pooling_shape=pooling_shape)
+ logger.info('done loading llava')
+ # position embedding
+ model = model.to(torch.device(rank))
+ model = model.eval()
+
+ dataset = VideoChatGPTBenchDataset(num_segments=num_frames, test_ratio=test_ratio)
+ dataset.set_rank_and_world_size(rank, world_size)
+ return model, processor, dataset
+
+def infer_vcgbench(
+ model,
+ processor,
+ data_sample,
+ conv_mode,
+ pre_query_prompt=None, # added at the head of the question
+ post_query_prompt=None, # added at the end of the question
+ print_res=False,
+ ):
+ video_list = data_sample["video_pils"]
+ conv = conv_templates[conv_mode].copy()
+ conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True)
+ stop_criteria_keywords=["###","USER"]
+
+ llm_message, conv = pllava_answer(
+ conv=conv,
+ model=model,
+ processor=processor,
+ img_list=video_list,
+ max_new_tokens=100,
+ do_sample=False,
+ print_res=print_res,
+ stop_criteria_keywords=stop_criteria_keywords
+ )
+
+
+ return llm_message
+
+def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
+ def get_index(num_frames, num_segments):
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
+
+ def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336):
+ transforms = torchvision.transforms.Resize(size=resolution)
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+ num_frames = len(vr)
+ frame_indices = get_index(num_frames, num_segments)
+ images_group = list()
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(transforms(img))
+ if return_msg:
+ fps = float(vr.get_avg_fps())
+ sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
+ # " " should be added in the start and end
+ msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
+ return images_group, msg
+ else:
+ return images_group
+
+ if num_frames != 0:
+ vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
+ else:
+ vid, msg = None, 'num_frames is 0, not inputting any image'
+ img_list = vid
+ conv = conv_templates[conv_mode].copy()
+ conv.user_query("Describe the video in details.", is_mm=True)
+ llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)
+
+def run(rank, args, world_size,start_rank=0):
+ if rank != 0:
+ transformers.utils.logging.set_verbosity_error()
+ logger.setLevel(transformers.logging.ERROR)
+ print_res = True
+ conv_mode= args.conv_mode
+ pre_query_prompt = None
+ post_query_prompt = None
+
+
+ logger.info(f"CONV_MODE: {conv_mode}")
+
+ logger.info(f'loading model and constructing dataset to gpu {rank}...')
+ pooling_shape = (16, 12, 12) # default, matching load_model_and_dataset's default pooling shape
+ if args.pooling_shape is not None:
+ pooling_shape = tuple([int(x) for x in args.pooling_shape.split("-")])
+ model, processor, dataset = load_model_and_dataset(rank,
+ world_size,
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+ num_frames=args.num_frames,
+ use_lora=args.use_lora,
+ weight_dir=args.weight_dir,
+ lora_alpha=args.lora_alpha,
+ test_ratio=args.test_ratio,
+ pooling_shape=pooling_shape)
+ logger.info('done loading model and dataset...')
+ logger.info('constructing dataset...')
+ logger.info('single test...')
+ vid_path = "./example/yoga.mp4"
+ if rank == 0:
+ single_test(model,
+ processor,
+ vid_path,
+ num_frames=args.num_frames,
+ conv_mode=args.conv_mode)
+ logger.info('single test done...')
+ tbar = tqdm(total=len(dataset))
+
+ result_list = []
+ done_count = 0
+ for example in dataset:
+ task_type = example['task_type']
+ gt = example['answer']
+ if task_type == 'consistency_qa':
+ assert 'question' in example and 'question1' in example, 'two questions'
+ pred = infer_vcgbench(
+ model,
+ processor,
+ example,
+ conv_mode=conv_mode,
+ pre_query_prompt=pre_query_prompt,
+ post_query_prompt=post_query_prompt,
+ print_res=print_res,
+ )
+ # inference the other question
+ example['question'], example['question1'] = example['question1'], example['question']
+ pred1 = infer_vcgbench(
+ model,
+ processor,
+ example,
+ conv_mode=conv_mode,
+ pre_query_prompt=pre_query_prompt,
+ post_query_prompt=post_query_prompt,
+ print_res=print_res,
+ )
+ res = {
+ 'pred': pred,
+ 'pred1': pred1,
+ 'gt': gt,
+ 'video': example['video_path'],
+ 'task_type': task_type,
+ # the questions were swapped above for the second pass, so swap back to pair each question with its own prediction
+ 'question': example['question1'],
+ 'question1': example['question'],
+ }
+ elif task_type in dataset.data_list_info:
+ pred = infer_vcgbench(
+ model,
+ processor,
+ example,
+ conv_mode=conv_mode,
+ pre_query_prompt=pre_query_prompt,
+ post_query_prompt=post_query_prompt,
+ print_res=print_res,
+ )
+ res = {
+ 'pred': pred,
+ 'gt': gt,
+ 'video_path': example['video_path'],
+ 'question': example['question'],
+ 'task_type': task_type,
+ }
+ else:
+ raise NotImplementedError(f'not implemented task type {task_type}')
+
+ result_list.append(res)
+ if rank == 0:
+ tbar.update(len(result_list) - done_count, )
+ tbar.set_description_str(
+ f"One Chunk--Task Type: {task_type}-"
+ f"gt: {gt[:min(15, len(gt))]}......--pred: {pred[:min(15, len(gt))]}......"
+ )
+ done_count = len(result_list)
+ return result_list
+
+def main():
+ multiprocess=True
+ mp.set_start_method('spawn')
+ args = parse_args()
+ save_path = args.save_path
+ eval_model = args.eval_model
+ result_list = load_results(save_path)
+ start_rank=0
+
+ if result_list is None:
+ if multiprocess:
+ logger.info(f'started benchmarking, saving to: {save_path}')
+ n_gpus = torch.cuda.device_count()
+ # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
+ world_size = n_gpus
+ with Pool(world_size) as pool:
+ func = functools.partial(run, args=args, world_size=world_size, start_rank=start_rank)
+ result_lists = pool.map(func, range(world_size))
+
+ logger.info('finished running')
+ result_list = [ res for res in itertools.chain(*result_lists)]
+ else:
+ result_list = run(0, world_size=1, args=args) # debug
+
+ else:
+ logger.info(f'loaded results from {save_path}')
+
+ save_results(result_list, save_path, model=eval_model)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tasks/eval/vcgbench/show_vcg.py b/tasks/eval/vcgbench/show_vcg.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1848f469b2b6dcb9fa884dc52212a91ab233cbc
--- /dev/null
+++ b/tasks/eval/vcgbench/show_vcg.py
@@ -0,0 +1,45 @@
+
+import argparse
+import gradio as gr
+
+from tasks.eval.vcgbench import load_results
+import json
+
+# example = videogallery().example_inputs()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--save_path',
+ required=True,
+ )
+ args = parser.parse_args()
+ return args
+
+
+args = parse_args()
+result_list = load_results(args.save_path)
+
+
+def show(result_index, ):
+ info = result_list[result_index]
+ video_path = info['video_path']
+ info_str = json.dumps(info, indent=4)
+ return video_path, info_str
+
+with gr.Blocks() as demo:
+ gr.Markdown(
+ f"# Showing The Results from {args.save_path}"
+ )
+ with gr.Row():
+ with gr.Column():
+ show_video = gr.Video(interactive=False)
+
+ with gr.Column():
+ result_index = gr.Slider(0, len(result_list), step=1)
+ info = gr.Text(interactive=False)
+
+ result_index.change(show, [result_index], [show_video, info])
+
+demo.launch(share=True)
diff --git a/tasks/eval/videoqabench/__init__.py b/tasks/eval/videoqabench/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..541495d41da665f28951f701147e02b0126295a2
--- /dev/null
+++ b/tasks/eval/videoqabench/__init__.py
@@ -0,0 +1,348 @@
+from functools import partial
+import os
+import json
+from typing import OrderedDict
+
+import tqdm
+import torch
+from PIL import Image
+import ast
+import numpy as np
+from multiprocessing import Pool
+
+from decord import VideoReader, cpu
+
+from tasks.eval.eval_utils import (
+ dump_json,
+ load_json,
+ EvalDataset,
+)
+from dataclasses import dataclass
+from openai import OpenAI
+client = OpenAI(
+ # This is the default and can be omitted
+ api_key=os.environ.get("OPENAI_API_KEY"),
+)
+
+task_type2chatgpt_contents = OrderedDict({
+ "MSVD_QA": {
+ "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Evaluate the correctness of the prediction compared to the answer.",
+ "user": """Please evaluate the following video-based question-answer pair:\n\n"""
+ """Question: {question}\n"""
+ """Correct Answer: {answer}\n"""
+ """Predicted Answer: {pred}\n\n"""
+ """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """
+ """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."""
+ """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """
+ """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}."""
+ },
+ "MSRVTT_QA": {
+ "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Evaluate the correctness of the prediction compared to the answer.",
+ "user": """Please evaluate the following video-based question-answer pair:\n\n"""
+ """Question: {question}\n"""
+ """Correct Answer: {answer}\n"""
+ """Predicted Answer: {pred}\n\n"""
+ """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """
+ """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."""
+ """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """
+ """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}."""
+ # """Make sure you only response with text that Follows Python syntax. For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."""
+ },
+ "ActivityNet": {
+ "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Evaluate the correctness of the prediction compared to the answer.",
+ "user": """Please evaluate the following video-based question-answer pair:\n\n"""
+ """Question: {question}\n"""
+ """Correct Answer: {answer}\n"""
+ """Predicted Answer: {pred}\n\n"""
+ """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """
+ """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."""
+ """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """
+ """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}."""
+ # """Make sure you only response with text that Follows Python syntax. For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."""
+ },
+ "TGIF_QA": {
+ "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
+ "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
+ "------"
+ "##INSTRUCTIONS: "
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
+ "- Consider synonyms or paraphrases as valid matches.\n"
+ "- Evaluate the correctness of the prediction compared to the answer.",
+ "user": """Please evaluate the following video-based question-answer pair:\n\n"""
+ """Question: {question}\n"""
+ """Correct Answer: {answer}\n"""
+ """Predicted Answer: {pred}\n\n"""
+ """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """
+ """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."""
+ """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """
+ """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}."""
+ # """Make sure you only response with text that Follows Python syntax. For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."""
+ },
+})
+
+# Follow the instructions carefully and be helpful and precise with your answer.
+
+def check_ans_qa(question, pred, gt, task_type, model="gpt-3.5-turbo-0125"):
+ response_message = None
+ try:
+ # Compute the temporal understanding score
+ user_input = task_type2chatgpt_contents[task_type]['user']
+ user_input = user_input.format(question=question, answer=gt, pred=pred)
+ completion = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": task_type2chatgpt_contents[task_type]['system'],
+ },
+ {
+ "role": "user",
+ "content": user_input,
+ }
+ ]
+ )
+ # Convert response to a Python dictionary.
+ # response_message = completion["choices"][0]["message"]["content"]
+ response_message = completion.choices[0].message.content
+ response_dict = ast.literal_eval(response_message)
+ pred = response_dict['pred']
+ score = response_dict['score']
+ if pred not in ('yes', 'no') or not isinstance(score, (int, float)):
+ raise ValueError(f"{model} did not follow the required output format")
+ flag = pred == 'yes'
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ flag, score = False, 0
+ print(
+ f"GPT cannot deal with:\n"
+ f"--pred: {pred}\n"
+ f"--gt: {gt}\n"
+ f"--gpt responded: {response_message}\n"
+ "--will assign flag=False and score=0"
+ )
+ print(f"Dumb Answer in {task_type}")
+ return flag, score
+
+def chatgpt_eval(res, model="gpt-3.5-turbo-0125"):
+ pred = res['pred']
+ gt = res['gt']
+ question=res['question']
+ task_type = res['task_type']
+ correct, score = check_ans_qa(question=question, pred=pred, gt=gt,task_type=task_type, model=model) # acc is bool, score is given by chatgpt
+ # update the scores in result_list for this sample
+ res['score'] = score
+ res['correct'] = correct
+ return res
+
+def save_results(result_list, save_path, model="gpt-3.5-turbo-0125"):
+ dump_json(result_list, save_path, 'inference_results.json')
+ with Pool(7) as pool:
+ func = partial(chatgpt_eval, model=model)
+ result_list = [ res for res in tqdm.tqdm(pool.imap_unordered(func, result_list), total=len(result_list), desc='Language Chat Model Automated Evaluation...')]
+
+ # result_list = pool.map(partial(chatgpt_eval, model=model), result_list)
+ # result_list = [chatgpt_eval(res, model=model) for res in result_list]
+
+ final_res, acc_dict = {}, {}
+ correct, total, total_score = 0, 0, 0
+ for i, res in enumerate(result_list):
+ task_type = res['task_type']
+ if task_type not in acc_dict:
+ acc_dict[task_type] = {
+ 'correct': 0,
+ 'total': 0,
+ 'score': 0,
+ } # correct, total
+ acc_dict[task_type]['total'] += 1
+ acc_dict[task_type]['correct'] += res['correct']
+ acc_dict[task_type]['score'] += res['score']
+
+ for k, v in acc_dict.items():
+ final_res[k] = {
+ 'acc': v['correct'] / v['total'] * 100,
+ 'score': v['score'] / v['total']
+ }
+ correct += v['correct']
+ total += v['total']
+ total_score += v['score']
+
+ final_res['Avg_Acc'] = correct / total * 100
+ final_res['Avg_Score'] = total_score / total
+
+ all_results = {
+ "acc_dict": acc_dict,
+ "result_list": result_list
+ }
+ dump_json(all_results, save_path, 'all_results.json')
+ dump_json(final_res, save_path, 'upload_leaderboard.json')
+
+def load_results(save_path):
+ json_data = load_json(save_path, 'inference_results.json')
+ return json_data
+
+@dataclass
+class OpenendQASample():
+ question: str
+ answer: str
+
+
+
+class VideoQABenchDataset(EvalDataset):
+ data_dir = "DATAS/VideoQA"
+ data_list_info = OrderedDict({
+ "MSVD_QA": OrderedDict(
+ q_json_relpath="MSVD_Zero_Shot_QA/test_q.json",
+ a_json_relpath="MSVD_Zero_Shot_QA/test_a.json",
+ prefix="DATAS/VideoQA/MSVD_Zero_Shot_QA/videos",
+ data_type="video",
+ bound=False,
+ question_key='question',
+ answer_key='answer',
+ name_key='video_name',
+ postfix=('avi',),
+ ),
+ "MSRVTT_QA": OrderedDict(
+ q_json_relpath="MSRVTT_Zero_Shot_QA/test_q.json",
+ a_json_relpath="MSRVTT_Zero_Shot_QA/test_a.json",
+ prefix="DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos/all",
+ data_type="video",
+ bound=False,
+ question_key='question',
+ answer_key='answer',
+ name_key='video_name',
+ postfix=('mp4', ),
+ ), # doesn't have start & end
+ "ActivityNet": OrderedDict(
+ q_json_relpath="ActivityNet/test_q.json",
+ a_json_relpath="ActivityNet/test_a.json",
+ prefix="DATAS/VideoQA/ActivityNet/all_test",
+ data_type="video",
+ bound=False,
+ question_key='question',
+ answer_key='answer',
+ name_key='video_name',
+ postfix=('mp4', 'mkv', 'webm'),
+ ), # doesn't have start & end
+ "TGIF_QA": OrderedDict(
+ q_json_relpath="TGIF_QA/test_q.json",
+ a_json_relpath="TGIF_QA/test_a.json",
+ prefix="DATAS/VideoQA/TGIF_QA/tgif_videos",
+ data_type="gif",
+ bound=False,
+ question_key='question',
+ answer_key='answer',
+ name_key='video_name',
+ postfix=('gif',),
+        ), # doesn't have start & end
+
+ })
+
+ def __init__(self, *args, **kwargs):
+ # test_ratio for videoqa is for each sub dataset
+ test_ratio = kwargs.pop('test_ratio', None)
+ kwargs['test_ratio'] = None
+ test_datasets = kwargs.pop('test_datasets', None)
+ super().__init__(*args, **kwargs)
+ test_ratio = 1 if test_ratio is None else test_ratio
+ self.test_ratio = test_ratio
+ if test_datasets is not None:
+ data_list_info = {k:v for k,v in self.data_list_info.items() if k in test_datasets}
+ else:
+ data_list_info = self.data_list_info
+ data_dir = self.data_dir
+
+ self.data_list = []
+ for k, v in data_list_info.items():
+ with open(os.path.join(data_dir, v['q_json_relpath']), 'r') as f:
+                questions_json_data = json.load(f)
+ with open(os.path.join(data_dir, v['a_json_relpath']), 'r') as f:
+ answers_json_data = json.load(f)
+
+            indexs = list(range(len(questions_json_data)))
+ np.random.RandomState(42).shuffle(indexs)
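+            # a test_ratio in (0, 1] is treated as a fraction of the subset; a value > 1 as an absolute sample count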
+ num_samples = int(len(indexs) * self.test_ratio) if 0 < self.test_ratio <= 1 else int(self.test_ratio)
+ indexs = indexs[:num_samples]
+ for i in indexs:
+                question_data = questions_json_data[i]
+ answer_data = answers_json_data[i]
+ data = {}
+                # ActivityNet question files omit the 'v_' prefix used in the video filenames, so add it back here
+ if k == "ActivityNet":
+ question_data['video_name'] = 'v_' + question_data['video_name']
+ data.update(**question_data)
+ data.update(**answer_data)
+ self.data_list.append({
+ 'task_type': k,
+ 'data': data,
+ **v, # all the infos
+ })
+ print(len(self.data_list))
+
+ def __len__(self):
+ return len(self.data_list)
+
+
+ def __getitem__(self, idx):
+ decord_method = self.decord_method[self.data_list[idx]['data_type']]
+ bound = None
+ if self.data_list[idx]['bound']:
+ bound = (
+ self.data_list[idx]['data']['start'],
+ self.data_list[idx]['data']['end'],
+ )
+ video_name_key = self.data_list[idx]['name_key']
+ video_name = self.data_list[idx]['data'][video_name_key]
+
+ video_postfixs = self.data_list[idx]['postfix']
+ video_paths = []
+ for p in video_postfixs:
+ video_path = os.path.join(self.data_list[idx]['prefix'], video_name + '.' + p)
+ if os.path.exists(video_path):
+ video_paths.append(video_path)
+ assert len(video_paths) > 0, f'no video named {video_name}'
+ # video_filename = self.data_list[idx]['data'][video_name_key] + video_postfix
+ video_path = video_paths[0]
+ images_group = decord_method(video_path, bound)
+
+ question_key = self.data_list[idx]['question_key']
+ answer_key = self.data_list[idx]['answer_key']
+ sample = OpenendQASample(
+ question=self.data_list[idx]['data'][question_key],
+ answer=self.data_list[idx]['data'][answer_key]
+ )
+ question, answer = self.qa_template(sample)
+
+ return {
+ 'video_pils': images_group, # some might use the original pils and do their own transforms
+ 'question': question,
+ 'video_path': video_path,
+ 'answer': answer,
+ 'task_type': self.data_list[idx]['task_type']
+ }
+
+ def qa_template(self, data: OpenendQASample):
+ answer = data.answer
+ question = data.question
+        # currently a pass-through; prompt templates could be added here.
+ return question, answer
+
+
diff --git a/tasks/eval/videoqabench/pllava_eval_videoqabench.py b/tasks/eval/videoqabench/pllava_eval_videoqabench.py
new file mode 100644
index 0000000000000000000000000000000000000000..a028c0a95ac4817f70dbb8119d9d190d30326875
--- /dev/null
+++ b/tasks/eval/videoqabench/pllava_eval_videoqabench.py
@@ -0,0 +1,304 @@
+
+import functools
+import itertools
+import logging
+from tqdm import tqdm
+from PIL import Image
+from multiprocessing import Pool
+from argparse import ArgumentParser
+import multiprocessing as mp
+
+
+
+import numpy as np
+import torch
+
+import torchvision
+
+import transformers
+from decord import VideoReader, cpu
+
+from tasks.eval.model_utils import load_pllava, pllava_answer
+from tasks.eval.eval_utils import conv_templates
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+IMAGE_TOKEN=''
+from tasks.eval.videoqabench import (
+ VideoQABenchDataset,
+ load_results,
+ save_results,
+)
+RESOLUTION = 672
+VIDEOQA_DATASETS=["MSVD_QA","MSRVTT_QA", "ActivityNet","TGIF_QA"]
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ default='llava-hf/llava-1.5-7b-hf'
+ )
+ parser.add_argument(
+ "--save_path",
+ type=str,
+ required=True,
+ default='"./test_results/test_llava_mvbench"'
+ )
+ parser.add_argument(
+ "--num_frames",
+ type=int,
+ required=True,
+ default=4,
+ )
+ parser.add_argument(
+ "--use_lora",
+ action='store_true'
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ required=False,
+ default=32,
+ )
+ parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ required=False,
+ default=100,
+ )
+ parser.add_argument(
+ "--weight_dir",
+ type=str,
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "--eval_model",
+ type=str,
+ required=False,
+ default="gpt-3.5-turbo-0125",
+ )
+ parser.add_argument(
+ '--test_ratio',
+ type=float,
+ required=False,
+ default=1
+ )
+ parser.add_argument(
+ "--conv_mode",
+ type=str,
+ required=False,
+ default='eval_videoqabench',
+ )
+ parser.add_argument(
+ "--test_datasets",
+ type=str,
+ required=False,
+ default='MSVD_QA',
+ )
+ args = parser.parse_args()
+ return args
+
+def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, test_ratio, test_datasets):
+    # note: once the model gets larger (30B+), memory can be used up heavily and may even bring down nodes.
+ model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, lora_alpha=lora_alpha, weight_dir=weight_dir)
+ logger.info('done loading llava')
+ # position embedding
+ model = model.to(torch.device(rank))
+ model = model.eval()
+
+ dataset = VideoQABenchDataset(test_ratio=test_ratio, test_datasets=test_datasets, num_segments=num_frames)
+ dataset.set_rank_and_world_size(rank, world_size)
+ return model, processor, dataset
+
+def infer_videoqabench(
+ model,
+ processor,
+ data_sample,
+ conv_mode,
+    pre_query_prompt=None, # prepended to the question
+    post_query_prompt=None, # appended to the question
+    answer_prompt=None, # added at the beginning of the answer
+    return_prompt=None, # added at the beginning of the returned message
+ print_res=False,
+ max_new_tokens=100,
+ ):
+ video_list = data_sample["video_pils"]
+ conv = conv_templates[conv_mode].copy()
+
+ pre_query_prompt=conv.pre_query_prompt
+ post_query_prompt=conv.post_query_prompt
+ answer_prompt=conv.answer_prompt
+
+ conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True)
+ if answer_prompt is not None:
+ conv.assistant_response(answer_prompt)
+
+ llm_message, conv = pllava_answer(
+ conv=conv,
+ model=model,
+ processor=processor,
+ img_list=video_list,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ print_res=print_res,
+ )
+
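+    # the decoded output still contains the primed answer prompt; strip everything up to and including it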
+ if answer_prompt is not None:
+ llm_message = ''.join(llm_message.split(answer_prompt.strip("\n"))[1:]).strip()
+
+ if return_prompt is not None:
+ llm_message = return_prompt + llm_message
+
+ return llm_message
+
+def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
+ def get_index(num_frames, num_segments):
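+        # uniformly split the clip into num_segments chunks and take a frame index near the middle of each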
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
+
+ def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336):
+ transforms = torchvision.transforms.Resize(size=resolution)
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+ num_frames = len(vr)
+ frame_indices = get_index(num_frames, num_segments)
+ images_group = list()
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(transforms(img))
+ if return_msg:
+ fps = float(vr.get_avg_fps())
+ sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
+ # " " should be added in the start and end
+ msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
+ return images_group, msg
+ else:
+ return images_group
+
+ if num_frames != 0:
+ vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
+ else:
+        vid, msg = None, 'num_frames is 0, not inputting images'
+ img_list = vid
+
+ conv = conv_templates[conv_mode].copy()
+ conv.user_query("Describe the video in details.", is_mm=True)
+ llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)
+
+def run(rank, args, world_size):
+ if rank != 0:
+ transformers.utils.logging.set_verbosity_error()
+ logger.setLevel(transformers.logging.ERROR)
+
+ print_res = True
+ conv_mode= args.conv_mode
+ pre_query_prompt = None
+ post_query_prompt = None
+ # pre_query_prompt = "Answer the question with a single word or phrase."
+
+ logger.info(f'loading model and constructing dataset to gpu {rank}...')
+ test_datasets = [x for x in args.test_datasets.split("-") if x in VIDEOQA_DATASETS]
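+    # --test_datasets is a "-"-separated list of dataset names, e.g. "MSVD_QA-ActivityNet"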
+ assert len(test_datasets)>=1
+
+ model, processor, dataset = load_model_and_dataset(rank,
+ world_size,
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+ num_frames=args.num_frames,
+ use_lora=args.use_lora,
+ lora_alpha=args.lora_alpha,
+ weight_dir=args.weight_dir,
+ test_ratio=args.test_ratio,
+ test_datasets=test_datasets)
+    logger.info('done loading model and constructing dataset...')
+ logger.info('single test...')
+ vid_path = "./example/yoga.mp4"
+ # vid_path = "./example/jesse_dance.mp4"
+ if rank == 0:
+ single_test(model, processor, vid_path, num_frames=args.num_frames, conv_mode=args.conv_mode)
+ logger.info('single test done...')
+ tbar = tqdm(total=len(dataset))
+    logger.info('start benchmarking...')
+
+ result_list = []
+ done_count = 0
+ for example in dataset:
+ task_type = example['task_type']
+ gt = example['answer']
+ if task_type in dataset.data_list_info:
+ pred = infer_videoqabench(
+ model,
+ processor,
+ example,
+ conv_mode=conv_mode,
+ pre_query_prompt=pre_query_prompt,
+ post_query_prompt=post_query_prompt,
+ print_res=print_res,
+ max_new_tokens=args.max_new_tokens,
+ )
+
+ infos = {
+ 'question': example['question'],
+ 'video_path': example['video_path']
+ }
+ res = {
+ 'pred': pred,
+ 'gt': gt,
+ 'task_type': task_type,
+ **infos
+ }
+ else:
+ raise NotImplementedError(f'not implemented task type {task_type}')
+ # res = chatgpt_eval(res)
+ result_list.append(res)
+ if rank == 0:
+ tbar.update(len(result_list) - done_count, )
+ tbar.set_description_str(
+ f"One Chunk--Task Type: {task_type}-"
+ f"gt: {gt[:min(15, len(gt))]}......--pred: {pred[:min(15, len(gt))]}......"
+ )
+ done_count = len(result_list)
+ return result_list
+
+def main():
+ multiprocess=True
+ mp.set_start_method('spawn')
+ args = parse_args()
+ save_path = args.save_path
+ eval_model = args.eval_model
+ logger.info(f'trying loading results from {save_path}')
+ result_list = load_results(save_path)
+
+ if result_list is None:
+ if multiprocess:
+
+ logger.info(f'started benchmarking, saving to: {save_path}')
+ n_gpus = torch.cuda.device_count()
+ # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
+ world_size = n_gpus
+ with Pool(world_size) as pool:
+ func = functools.partial(run, args=args, world_size=world_size)
+ # func = functools.partial(run, world_size=world_size, model=model, dataset=dataset, result_list=[], acc_dict={})
+ result_lists = pool.map(func, range(world_size))
+
+ logger.info('finished running')
+
+ result_list = [ res for res in itertools.chain(*result_lists)]
+ else:
+ result_list = run(0, world_size=1, args=args) # debug
+ else:
+ logger.info(f'loaded results from {save_path}')
+
+ save_results(result_list, save_path, model=eval_model)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tasks/shared_utils.py b/tasks/shared_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5dec719791e801f8636cf29eb34bf965ab35846
--- /dev/null
+++ b/tasks/shared_utils.py
@@ -0,0 +1,36 @@
+import copy
+import logging
+import os
+import os.path as osp
+from os.path import join
+
+import torch
+from torch.utils.data import ConcatDataset, DataLoader
+
+from utils.optimizer import create_optimizer
+from utils.scheduler import create_scheduler
+
+logger = logging.getLogger(__name__)
+
+
+def get_media_types(datasources):
+ """get the media types for for all the dataloaders.
+
+ Args:
+ datasources (List): List of dataloaders or datasets.
+
+ Returns: List. The media_types.
+
+ """
+ if isinstance(datasources[0], DataLoader):
+ datasets = [dataloader.dataset for dataloader in datasources]
+ else:
+ datasets = datasources
+ media_types = [
+ dataset.datasets[0].media_type
+ if isinstance(dataset, ConcatDataset)
+ else dataset.media_type
+ for dataset in datasets
+ ]
+
+ return media_types
diff --git a/tasks/train/config_pllava_nframe.py b/tasks/train/config_pllava_nframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80ac33155504b002ca182cb3fef4f556ef655a3
--- /dev/null
+++ b/tasks/train/config_pllava_nframe.py
@@ -0,0 +1,135 @@
+from tasks.train.instruction_data import *
+
+# ========================= data ==========================
+# train_corpus = "videochat2_instruction"
+train_corpus = "videochat2_instruction_full"
+
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict()
+test_types = []
+num_workers = 8
+save_steps=10000
+ckpt_steps=1000
+stop_key = None
+deepspeed=False
+# ========================= input ==========================
+num_frames = 16
+num_frames_test = 1
+batch_size = 1
+gradient_accumulation_steps=16
+max_txt_l = 512
+max_train_steps=None
+pre_text = False
+gradient_checkpointing=False
+inputs = dict(
+ image_res=336,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ repo_id="llava-hf/llava-v1.6-vicuna-7b-hf",
+ pretrained_path=None,
+ load_from_origin=False,
+ origin_vision="",
+ origin_llm="",
+ vision_encoder=dict(
+ name="vit_l14", # somehow need this to tell the dataset the mean std of pretrained model
+ ),
+ torch_dtype='bfloat16',
+ freeze_projector=False,
+ freeze_lm=True,
+ freeze_vision_tower=True,
+ lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma
+ use_lora=True,
+ lora_r=128,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ num_frames="${num_frames}",
+ pooling_method='avg',
+ use_pooling=True,
+ frame_shape=(24,24),
+ pooling_shape=(16,8,8),
+)
+preprocess = dict(
+ system="",
+ mm_alone=True,
+ random_shuffle=True,
+ add_second_msg=True,
+ roles=['USER:', 'ASSISTANT:'],
+ end_signal=(' ', ''),
+ begin_signal='',
+ dataset_image_placeholder='',
+ dataset_video_placeholder='',
+ image_token_index=32000,
+ max_txt_l = "${max_txt_l}",
+    ignore_index=-100, # default ignore index for torch cross-entropy loss
+ center_pad=False,
+ longest_edge=762,
+ shortest_edge=336,
+ clip_transform=False,
+ num_frames="${num_frames}",
+)
+
+
+optimizer = dict(
+ opt="adamW",
+ lr=2e-5,
+ opt_betas=[0.9, 0.999], # default
+ weight_decay=0.02,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
+# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
+scheduler = dict(
+ is_videochat2_custom=False,
+ sched="cosine",
+ epochs=2,
+ warmup_ratio=0.2,
+ min_lr_multi=0.25)
+
+evaluate = False
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+fp16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="videochat2", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "it"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 5
+metric_window_size=10 # window size for metric
+seed = 42
+report_to='tensorboard'
+save_latest = True
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
diff --git a/tasks/train/config_pllava_nframe_yiprompt.py b/tasks/train/config_pllava_nframe_yiprompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea7adeb7714ec9c60a6374bbbfc5dcf190c0e61
--- /dev/null
+++ b/tasks/train/config_pllava_nframe_yiprompt.py
@@ -0,0 +1,135 @@
+from tasks.train.instruction_data import *
+
+# ========================= data ==========================
+# train_corpus = "videochat2_instruction"
+train_corpus = "videochat2_instruction_full"
+
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict()
+test_types = []
+num_workers = 8
+save_steps=10000
+ckpt_steps=1000
+stop_key = None
+deepspeed=False
+highres=None
+# ========================= input ==========================
+num_frames = 16
+num_frames_test = 1
+batch_size = 1
+gradient_accumulation_steps=16
+max_txt_l = 512
+max_train_steps=None
+pre_text = False
+gradient_checkpointing=False
+inputs = dict(
+ image_res=336,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
+)
+
+model = dict(
+ repo_id="llava-hf/llava-1.5-7b-hf",
+ pretrained_path=None,
+ load_from_origin=False,
+ origin_vision="",
+ origin_llm="",
+ vision_encoder=dict(
+ name="vit_l14", # somehow need this to tell the dataset the mean std of pretrained model
+ ),
+ torch_dtype='bfloat16',
+ freeze_projector=False,
+ freeze_lm=True,
+ freeze_vision_tower=True,
+ lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma
+ use_lora=True,
+ lora_r=128,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ num_frames="${num_frames}",
+ pooling_method='avg',
+ use_pooling=True,
+ frame_shape=(24,24),
+ pooling_shape=(16,8,8),
+)
+preprocess = dict(
+ system="",
+ mm_alone=True,
+ image_token_index=64002,
+ random_shuffle=True,
+ add_second_msg=True,
+ roles=['<|im_start|>user\n', '<|im_start|>assistant\n'],
+ end_signal=('<|im_end|>\n', '<|im_end|>\n'),
+ begin_signal='',
+ dataset_image_placeholder='',
+ dataset_video_placeholder='',
+ max_txt_l = "${max_txt_l}",
+    ignore_index=-100, # default ignore index for torch cross-entropy loss
+ center_pad=False,
+ longest_edge=762,
+ shortest_edge=336,
+ clip_transform=False,
+ num_frames="${num_frames}",
+)
+
+
+optimizer = dict(
+ opt="adamW",
+ lr=2e-5,
+ opt_betas=[0.9, 0.999], # default
+ weight_decay=0.02,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
+# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
+scheduler = dict(
+ is_videochat2_custom=False,
+ sched="cosine",
+ epochs=2,
+ warmup_ratio=0.2,
+ min_lr_multi=0.25)
+
+evaluate = False
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+fp16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="videochat2", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "it"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 5
+metric_window_size=10 # window size for metric
+seed = 42
+report_to='tensorboard'
+save_latest = True
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
diff --git a/tasks/train/instruction_data.py b/tasks/train/instruction_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..58ac2c8bd5c037b6cf4f37a4862c34c75b36757b
--- /dev/null
+++ b/tasks/train/instruction_data.py
@@ -0,0 +1,271 @@
+import os as __os # add "__" if not want to be exported
+from copy import deepcopy as __deepcopy
+import itertools as __itertools
+
+data_root = "DATAS/TRAIN_TEST"
+anno_root_it = f"{data_root}/magic_jsons"
+
+# ============== pretraining datasets=================
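+# each corpus entry is [annotation json, media root] for images, plus a trailing "video" media type for videos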
+available_corpus = dict(
+ # image
+ # caption_coco=[
+ # f"{anno_root_it}/image/caption/coco/train.json",
+ # f"{data_root}/images/coco",
+ # ],
+ # caption_llava=[
+ # f"{anno_root_it}/image/caption/llava/train.json",
+ # f"{data_root}/images/coco",
+ # ],
+ # caption_minigpt4=[
+ # f"{anno_root_it}/image/caption/minigpt4/train.json",
+ # f"{data_root}/images/minigpt4_align/image",
+ # ],
+ # caption_paragraph_captioning=[
+ # f"{anno_root_it}/image/caption/paragraph_captioning/train.json",
+ # f"{data_root}/images/m3it/image-paragraph-captioning",
+ # ],
+ # caption_textcaps=[
+ # f"{anno_root_it}/image/caption/textcaps/train.json",
+ # f"{data_root}/images/textcaps",
+ # ],
+ # classification_imagenet=[
+ # f"{anno_root_it}/image/classification/imagenet/train.json",
+ # f"{data_root}/images/m3it/imagenet",
+ # ],
+ # classification_coco_itm=[
+ # f"{anno_root_it}/image/classification/coco_itm/train.json",
+ # f"{data_root}/images/coco",
+ # ],
+ # conversation_llava=[
+ # f"{anno_root_it}/image/conversation/llava/train.json",
+ # f"{data_root}/images/coco",
+ # ],
+ # reasoning_clevr=[
+ # f"{anno_root_it}/image/reasoning/clevr/train.json",
+ # f"{data_root}/images/m3it/clevr",
+ # ],
+ # reasoning_visual_mrc=[
+ # f"{anno_root_it}/image/reasoning/visual_mrc/train.json",
+ # f"{data_root}/images/m3it/visual_mrc",
+ # ],
+ # reasoning_llava=[
+ # f"{anno_root_it}/image/reasoning/llava/train.json",
+ # f"{data_root}/images/coco",
+ # ],
+ # vqa_vqav2=[
+ # f"{anno_root_it}/image/vqa/vqav2/train.json",
+ # f"{data_root}/images/m3it/vqav2",
+ # ],
+ # vqa_gqa=[
+ # f"{anno_root_it}/image/vqa/gqa/train.json",
+ # f"{data_root}/images/gqa/images",
+ # ],
+ # vqa_okvqa=[
+ # f"{anno_root_it}/image/vqa/okvqa/train.json",
+ # f"{data_root}/images/m3it/okvqa",
+ # ],
+ # vqa_a_okvqa=[
+ # f"{anno_root_it}/image/vqa/a_okvqa/train.json",
+ # f"{data_root}/images/m3it/a_okvqa",
+ # ],
+ # vqa_viquae=[
+ # f"{anno_root_it}/image/vqa/viquae/train.json",
+ # f"{data_root}/images/viquae_images",
+ # ],
+ # vqa_ocr_vqa=[
+ # f"{anno_root_it}/image/vqa/ocr_vqa/train.json",
+ # f"{data_root}/images/ocr_vqa/images",
+ # ],
+ # vqa_text_vqa=[
+ # f"{anno_root_it}/image/vqa/text_vqa/train.json",
+ # f"{data_root}/images/textvqa",
+ # ],
+ # vqa_st_vqa=[
+ # f"{anno_root_it}/image/vqa/st_vqa/train.json",
+ # f"{data_root}/images/m3it/st-vqa",
+ # ],
+ # vqa_docvqa=[
+ # f"{anno_root_it}/image/vqa/docvqa/train.json",
+ # f"{data_root}/images/docvqa",
+ # ],
+ # origin_llava=[
+ # f"{anno_root_it}/image/origin_llava/train.json",
+ # f"{data_root}/images",
+ # ],
+ # video
+ caption_textvr=[
+ f"{anno_root_it}/video/caption/textvr/train.json",
+ f"{data_root}/videos/TextVR",
+ "video"
+ ],
+ caption_videochat=[
+ f"{anno_root_it}/video/caption/videochat/train.json",
+ f"{data_root}/videos/webvid_10m",
+ "video"
+ ], # not ready, need to read from hdfs
+ caption_webvid=[
+ f"{anno_root_it}/video/caption/webvid/train.json",
+ f"{data_root}/videos/webvid_10m",
+ "video"
+ ], # not ready, need to read from hdfs
+ caption_youcook2=[
+ f"{anno_root_it}/video/caption/youcook2/train.json",
+ f"{data_root}/videos/YouCook2/split_videos",
+ "video"
+ ],
+ classification_k710=[
+ f"{anno_root_it}/video/classification/k710/train.json",
+ f"{data_root}/videos/kinetics",
+ "video"
+ ],
+ classification_ssv2=[
+ f"{anno_root_it}/video/classification/ssv2/train.json",
+ f"{data_root}/videos/20bn-something-something-v2",
+ "video"
+ ],
+ conversation_videochat1=[
+ f"{anno_root_it}/video/conversation/videochat1/train.json",
+ f"{data_root}/videos/webvid_10m",
+ "video"
+ ],# not ready, need to read from hdfs
+ conversation_videochat2=[
+ f"{anno_root_it}/video/conversation/videochat2/train.json",
+ f"{data_root}/videos/InternVid-10M-FLT/videos",
+ "video"
+ ],
+ conversation_videochatgpt=[
+ f"{anno_root_it}/video/conversation/videochatgpt/train.json",
+ f"{data_root}/videos/AVideo_ChatGPT",
+ "video"
+ ],
+ reasoning_next_qa=[
+ f"{anno_root_it}/video/reasoning/next_qa/train.json",
+ f"{data_root}/videos/NExTVideo",
+ "video"
+ ],
+ reasoning_clevrer_qa=[
+ f"{anno_root_it}/video/reasoning/clevrer_qa/train.json",
+ f"{data_root}/videos/CLEVRER",
+ "video"
+ ],
+ reasoning_clevrer_mc=[
+ f"{anno_root_it}/video/reasoning/clevrer_mc/train.json",
+ f"{data_root}/videos/CLEVRER",
+ "video"
+ ],
+ vqa_ego_qa=[
+ f"{anno_root_it}/video/vqa/ego_qa/train.json",
+ f"{data_root}/videos/ego4d_data/split_videos",
+ "video"
+ ],
+ vqa_tgif_frame_qa=[
+ f"{anno_root_it}/video/vqa/tgif_frame_qa/train.json",
+ f"{data_root}/videos/tgif",
+ "video"
+ ],
+ vqa_tgif_transition_qa=[
+ f"{anno_root_it}/video/vqa/tgif_transition_qa/train.json",
+ f"{data_root}/videos/tgif",
+ "video"
+ ],
+ vqa_webvid_qa=[
+ f"{anno_root_it}/video/vqa/webvid_qa/train.json",
+ f"{data_root}/videos/webvid_10m",
+ "video"
+ ],# not ready, need to read from hdfs
+ origin_videochatgpt=[
+ f"{anno_root_it}/video/origin_videochatgpt/train.json",
+ f"{data_root}/videos/Video_ChatGPT",
+ "video"
+ ],
+)
+
+
+
+available_corpus["videochat2_instruction_full"] = [
+ available_corpus["caption_coco"],
+ available_corpus["caption_llava"],
+ available_corpus["caption_minigpt4"],
+ available_corpus["caption_paragraph_captioning"],
+ available_corpus["caption_textcaps"],
+ available_corpus["classification_imagenet"],
+ available_corpus["classification_coco_itm"],
+ available_corpus["conversation_llava"],
+ available_corpus["reasoning_clevr"],
+ available_corpus["reasoning_visual_mrc"],
+ available_corpus["reasoning_llava"],
+ available_corpus["vqa_vqav2"],
+ available_corpus["vqa_gqa"],
+ available_corpus["vqa_okvqa"],
+ available_corpus["vqa_a_okvqa"],
+ available_corpus["vqa_viquae"],
+ available_corpus["vqa_ocr_vqa"],
+ available_corpus["vqa_text_vqa"],
+ available_corpus["vqa_st_vqa"],
+ available_corpus["vqa_docvqa"],
+ available_corpus["caption_textvr"],
+ available_corpus["caption_youcook2"],
+ available_corpus["classification_k710"],
+ available_corpus["classification_ssv2"],
+ available_corpus["conversation_videochat2"],
+ available_corpus["conversation_videochatgpt"],
+ available_corpus["reasoning_next_qa"],
+ available_corpus["reasoning_clevrer_qa"],
+ available_corpus["reasoning_clevrer_mc"],
+ available_corpus["vqa_ego_qa"],
+ available_corpus["vqa_tgif_frame_qa"],
+ available_corpus["vqa_tgif_transition_qa"],
+ available_corpus["conversation_videochat1"],
+ available_corpus["vqa_webvid_qa"],
+ available_corpus["caption_videochat"],
+ available_corpus["caption_webvid"],
+]
+
+available_corpus["videochat2_video"] = [
+ available_corpus["caption_textvr"],
+ available_corpus["caption_youcook2"],
+ available_corpus["classification_k710"],
+ available_corpus["classification_ssv2"],
+ available_corpus["conversation_videochat2"],
+ available_corpus["conversation_videochatgpt"],
+ available_corpus["reasoning_next_qa"],
+ available_corpus["reasoning_clevrer_qa"],
+ available_corpus["reasoning_clevrer_mc"],
+ available_corpus["vqa_ego_qa"],
+ available_corpus["vqa_tgif_frame_qa"],
+ available_corpus["vqa_tgif_transition_qa"],
+ available_corpus["conversation_videochat1"],
+ available_corpus["vqa_webvid_qa"],
+ available_corpus["caption_videochat"],
+ available_corpus["caption_webvid"],
+]
+
+
+
+
+# ============== for debug=================
+available_corpus["videochat2_instruction_debug"] = [
+ # available_corpus["caption_minigpt4"],
+ available_corpus["caption_textvr"],
+ # available_corpus["vqa_ego_qa"],
+ # available_corpus["classification_k710"],
+ # available_corpus["reasoning_next_qa"],
+ # available_corpus["caption_textvr"],
+ # available_corpus["caption_youcook2"],
+
+ # available_corpus["caption_textcaps"], # realistic caption foucsing in real life text
+ # available_corpus["caption_textvr"], # good realistic captioning, also focusing on text
+]
+
+
+if __name__ == '__main__':
+    # quick sanity check: print the number of corpora in each defined mixture
+    print(len(available_corpus['videochat2_instruction_full']))
+    print(len(available_corpus['videochat2_video']))
+    print(len(available_corpus['videochat2_instruction_debug']))
\ No newline at end of file
diff --git a/tasks/train/train_pllava_nframe_accel.py b/tasks/train/train_pllava_nframe_accel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f02309d20ac3629f6f5382b697404d1ffcba96e
--- /dev/null
+++ b/tasks/train/train_pllava_nframe_accel.py
@@ -0,0 +1,545 @@
+import datetime
+import gc
+import time
+import os
+import os.path as osp
+import re
+import itertools
+import functools
+import random
+import math
+import shutil
+from typing import Optional, Union
+
+import torch
+import numpy as np
+from safetensors import safe_open
+
+import logging
+from accelerate.logging import get_logger
+from accelerate import Accelerator, DistributedType
+from accelerate.utils import set_seed
+from peft import get_peft_model, LoraConfig, TaskType
+
+
+from dataset import create_dataset, create_loader
+from tasks.shared_utils import get_media_types
+from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed)
+from utils.config_utils import setup_main
+from transformers.utils import TensorType
+
+from tasks.shared_utils import create_optimizer, create_scheduler
+import copy
+from transformers import (
+ DataCollatorWithPadding,
+ get_scheduler,
+ AutoModel,
+ AutoModelForCausalLM
+ )
+from models.pllava import PllavaConfig, PllavaForConditionalGeneration, PllavaProcessor
+
+# logger = logging.getLogger(__name__)
+IMAGE_TOKEN=''
+
+logger = get_logger(__name__)
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
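+    # gather a parameter that may be partitioned by DeepSpeed ZeRO-3 and return a CPU copy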
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, 'no ignore status')
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_state_maybe_zero_3(named_params, keys_to_match=["lora_","multi_modal_projector"]):
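+    # keep only the trainable pieces (LoRA weights and the projector) so checkpoints stay small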
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
+
+def setup_dataloaders(config, mode="pt", collate_fn=None):
+ # train datasets, create a list of data loaders
+ logger.info(f"Creating dataset for {mode}")
+ train_datasets = create_dataset(f"{mode}_train", config)
+
+ media_types = get_media_types(train_datasets)
+ samplers = [None] * len(media_types)
+
+ train_loaders = create_loader(
+ train_datasets,
+ samplers,
+ batch_size=[config.inputs.batch_size[k] for k in media_types],
+ num_workers=[config.num_workers] * len(media_types),
+ is_trains=[True] * len(media_types),
+ collate_fns=[collate_fn] * len(media_types),
+ ) # [0]
+
+ return train_loaders, media_types
+
+
+def setup_model(
+ config, find_unused_parameters=False
+):
+ if config.model.torch_dtype in ('bfloat16', 'float16', 'float32'):
+        torch_dtype = getattr(torch, config.model.torch_dtype)
+ else:
+ torch_dtype = config.model.torch_dtype
+ logger.info("Creating model")
+
+ processor = PllavaProcessor.from_pretrained(config.model.repo_id,
+ padding_side='right',
+ center_pad=config.preprocess.center_pad,
+ )
+
+
+ model_config = PllavaConfig.from_pretrained(config.model.repo_id,
+ torch_dtype=torch_dtype,
+ num_frames=config.model.num_frames,
+ pooling_method=config.model.pooling_method,
+ image_token_index=config.preprocess.image_token_index,
+ frame_shape=config.model.frame_shape,
+ pooling_shape=config.model.pooling_shape,
+ use_pooling=config.model.use_pooling,
+ gradient_checkpointing=config.gradient_checkpointing,
+ )
+ print("====>gradient_checkpointing",model_config.gradient_checkpointing)
+
+ model = PllavaForConditionalGeneration.from_pretrained(config.model.repo_id, config=model_config, torch_dtype=torch_dtype)
+
+ if config.model.load_from_origin:
+ with torch.no_grad():
+ lm_model = AutoModelForCausalLM.from_pretrained(config.model.origin_llm, torch_dtype=torch_dtype, device_map="cpu",)
+ with torch.no_grad():
+ clip = AutoModel.from_pretrained(config.model.origin_vision, torch_dtype=torch_dtype, device_map="cpu",)
+ msg = model.vision_tower.load_state_dict(clip.state_dict(), strict=False)
+ # print(msg)
+ msg = model.language_model.load_state_dict(lm_model.state_dict(), strict=False)
+ print(msg)
+
+
+ if config.model.freeze_lm:
+ logger.info("freezing parameters in model.language_model")
+ for p in model.language_model.parameters():
+ p.requires_grad = False
+
+ if config.model.freeze_projector:
+ logger.info("freezing parameters in model.multi_modal_projector")
+ for p in model.multi_modal_projector.parameters():
+ p.requires_grad = False
+
+ if config.model.freeze_vision_tower:
+ logger.info("freezing parameters in model.vision_tower")
+ for p in model.vision_tower.parameters():
+ p.requires_grad = False
+
+ if config.model.use_lora:
+ logger.info("getting LoRA Language Model")
+ kwargs = {}
+ if config.model.lora_target_modules is not None and len(config.model.lora_target_modules) > 0:
+ kwargs.update({"target_modules": config.model.lora_target_modules})
+ peft_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM, inference_mode=False,
+ r=config.model.lora_r, lora_alpha=config.model.lora_alpha, lora_dropout=config.model.lora_dropout,
+ **kwargs
+ )
+ model.language_model = get_peft_model(model.language_model, peft_config)
+ model.language_model.print_trainable_parameters()
+
+ if config.model.pretrained_path is not None and not config.deepspeed:
+ logger.info("======> loading pretrained weights from " + str(config.model.pretrained_path))
+ state_dict = {}
+ save_fnames = os.listdir(config.model.pretrained_path)
+ if "model.safetensors" in save_fnames:
+ print("Loading weight from", config.model.pretrained_path, "model.safetensors")
+ with safe_open(f"{config.model.pretrained_path}/model.safetensors", framework="pt", device="cpu") as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+ else:
+ print("Loading weight from", config.model.pretrained_path)
+ for fn in save_fnames:
+ if fn.startswith('model-0000'):
+ with safe_open(f"{config.model.pretrained_path}/{fn}", framework="pt", device="cpu") as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+
+ if 'model' in state_dict.keys():
+ msg = model.load_state_dict(state_dict['model'], strict=False)
+ else:
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info(msg)
+ logger.info("=====> Finish loading")
+
+ return model, processor
+
+def setup_optimizer_and_scheduler(config, model):
+ optimizer = create_optimizer(config.optimizer, model) # do you want to filter bias and bn?
+ if config.scheduler.is_videochat2_custom:
+ scheduler = create_scheduler(config.scheduler, optimizer)
+ else:
+ scheduler=None
+
+ return optimizer, scheduler
+
+class RandomMappingIterator():
+ # a random iter through the multiple mapping style dataloaders
+ def __init__(self, train_loaders, media_types, resume_step=0):
+ self.train_loaders = train_loaders
+ self.media_types = media_types
+ self.total_num_samples = sum(len(train_loader) for train_loader in self.train_loaders)
+ self.weights = [len(loader) / self.total_num_samples for loader in train_loaders]
+ self.resume_step = resume_step
+ if resume_step != 0:
+            self.total_num_samples = self.total_num_samples - resume_step
+ # remove corresponding iters from each loader
+
+
+ def __iter__(self):
+ train_loaders = self.train_loaders
+ iters = [iter(train_loader) for train_loader in train_loaders]
+
+ media_types = copy.deepcopy(self.media_types)
+ weights = copy.deepcopy(self.weights)
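+        # sample a loader with probability proportional to its length; drop exhausted loaders and renormalize the weights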
+ while len(iters) > 0:
+ index = np.random.choice(list(range(len(iters))), p=weights, replace=True)
+ try:
+ batch = next(iters[index])
+ except StopIteration as e:
+ iters.pop(index)
+ media_types.pop(index)
+ weights.pop(index)
+ total = sum(weights)
+ weights = [w/total for w in weights]
+ continue
+
+ media_type = media_types[index]
+ yield media_type, batch
+
+ def __len__(self):
+ return self.total_num_samples
+
+def split_and_record_separators(input_string, separators) -> list:
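+    # split the text on each separator while keeping the separators themselves as standalone list items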
+ texts = [input_string]
+ for sep in separators:
+ new_texts = []
+ for text in texts:
+ if sep not in text:
+ new_texts.append(text)
+ else:
+ split_strings = text.split(sep)
+ joint_strings = [t for pair in zip(split_strings[:-1], itertools.repeat(sep)) for t in pair ] + split_strings[-1:]
+ new_texts.extend(joint_strings)
+ texts = new_texts
+ return texts
+
+def preprocess(
+ batch,
+ args,
+ processor,
+ collate_fn,
+ dtype=torch.bfloat16,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+):
+ tokenizer = processor.tokenizer
+ # tokenization for training
+ max_length = args.max_txt_l
+ input_list, images = [], []
+ for sample in batch:
+ image, tex, instruction, index = sample # (nframe, 3, h, w), (0-255)
+ num_img = image.shape[0]
+ tex = tex.replace(args.dataset_video_placeholder, IMAGE_TOKEN).replace(args.dataset_image_placeholder, IMAGE_TOKEN)
+ seps = [role for role in args.roles]
+ segs = split_and_record_separators(tex, seps)
+ input_ids, labels, attention_mask = [], [], []
+
+ for i, seg in enumerate(segs):
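+            # build labels segment by segment: role tags and user/system text get ignore_index, assistant replies are supervised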
+ seg_ignore = False if seg == seps[1] else \
+                (True if i == 0 or seg in seps else seg_ignore) # not ignoring the assistant's reply; only changes in specific situations
+            current_ignore = True if seg in seps else seg_ignore # applies only to this segment
+ seg_input_ids = tokenizer.encode(seg, add_special_tokens=True if i==0 else False) # only add bos token
+ seg_labels = [args.ignore_index] * len(seg_input_ids) if current_ignore else seg_input_ids
+ seg_attention_mask = [1] * len(seg_input_ids) # do attend
+ input_ids.extend(seg_input_ids)
+ labels.extend(seg_labels)
+ attention_mask.extend(seg_attention_mask)
+
+ pad_length = max_length - len(input_ids)
+ labels = labels[:max_length]
+ attention_mask = attention_mask[:max_length]
+ input_ids=input_ids[:max_length]
+
+ labels = labels + [args.ignore_index] * pad_length # padding doesn't take care of labels. do the padding here
+ input_ids = input_ids + [tokenizer.pad_token_id] * pad_length
+ attention_mask = attention_mask + [0]*pad_length
+ sample_input = {
+ 'input_ids': input_ids,
+ 'labels': labels,
+ 'attention_mask': attention_mask,
+ }
+ input_list.append(sample_input)
+        images.append(image if image.ndim==4 else image.unsqueeze(0)) # make images 4-dim; videos are already 4-dim
+
+ inputs = collate_fn(input_list)
+
+    # repeat frames if the clip has fewer frames than needed
+ for i, video in enumerate(images):
+ if video.shape[0] < args.num_frames:
+ multiplier = int(args.num_frames/video.shape[0]) + 1
+ video = video.repeat_interleave(multiplier, dim=0)[:args.num_frames]
+ images[i] = video
+ assert video.shape[0] == args.num_frames
+ if args.clip_transform:
+ multimodal_features = processor(images=images)
+ inputs.update(**multimodal_features)
+ else:
+ inputs["pixel_values"] = torch.concat(images) # already processed to features in dataset get item
+
+
+ return inputs
+
+def main(config):
+ accelerator_log_kwargs=dict(
+ log_with=config.report_to,
+ project_dir=config.output_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
+ **accelerator_log_kwargs
+ )
+ logger.info(f"train_file: {config.train_file}")
+ model, processor = setup_model(
+ config,
+ find_unused_parameters=True,
+ )
+ if accelerator.is_main_process:
+ logger.setLevel(logging.INFO)
+ else:
+ logger.setLevel(logging.WARNING)
+
+ collate_fn = DataCollatorWithPadding(tokenizer=processor.tokenizer, padding='max_length', max_length=config.max_txt_l, return_tensors='pt',)
+ collate_fn = functools.partial(preprocess, args=config.preprocess, processor=processor, collate_fn=collate_fn)
+ train_loaders, train_media_types = setup_dataloaders(config, mode=config.mode, collate_fn=collate_fn)
+ num_steps_per_epoch = math.ceil(sum(len(d) for d in train_loaders) / config.gradient_accumulation_steps)
+ # load optimizer and custom scheduler
+ config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
+ config.scheduler.num_warmup_steps = math.ceil(config.scheduler.num_training_steps * config.scheduler.warmup_ratio)
+ optimizer, lr_scheduler = setup_optimizer_and_scheduler(config, model)
+    # if no customized scheduler is set, fall back to the default HF scheduler
+ overrode_max_train_steps = False
+ if config.max_train_steps is None:
+ config.max_train_steps = config.scheduler.epochs * num_steps_per_epoch
+ overrode_max_train_steps = True
+ if lr_scheduler is None:
+ lr_scheduler = get_scheduler(
+ name=config.scheduler.sched,
+ optimizer=optimizer,
+ num_warmup_steps=config.scheduler.num_warmup_steps,
+ num_training_steps=config.max_train_steps
+ if overrode_max_train_steps
+ else config.max_train_steps * accelerator.num_processes,
+ )
+ model, optimizer, lr_scheduler, *train_loaders = accelerator.prepare(
+ model, optimizer, lr_scheduler, *train_loaders
+ )
+
+ if hasattr(config, 'seed'):
+ set_seed(config.seed)
+
+ experiment_config = { # include all the important hyperparam
+ 'num_frames': config.num_frames,
+ 'max_txt_l': config.max_txt_l,
+ 'batch_size': config.batch_size,
+ }
+
+ model.train()
+
+ start_epoch = 0
+ num_batches = sum(len(loader) for loader in train_loaders)
+ global_step = start_epoch * num_batches # the steps before divided by accumulation
+ if osp.exists(config.output_dir):
+ subfolders = os.listdir(config.output_dir)
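+        # two checkpoint naming schemes are supported: "ckpt_resume_<n>M" (million samples seen) and "ckpt_<epoch/step>"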
+ sample_saving = False
+ for subfolder in subfolders:
+ if subfolder.endswith("M"):
+ sample_saving = True
+ if sample_saving:
+ ckpt_paths = [subfolder for subfolder in subfolders if re.match(r'ckpt_resume_[\d.]+M$', subfolder) is not None]
+ ckpt_iters = [float(re.findall(r'[\d.]+', x)[0]) for x in ckpt_paths]
+ else:
+ ckpt_paths = [subfolder for subfolder in subfolders if re.match("ckpt_[^\d]+", subfolder) is not None]
+ ckpt_iters = [int(s.split(re.match("ckpt_[^\d]+", s).group())[-1]) for s in ckpt_paths]
+
+
+ resume_cur_epoch_step=0
+ if len(ckpt_iters) > 0:
+ resume_iter = max(ckpt_iters)
+ ckpt_path = osp.join(config.output_dir, ckpt_paths[ckpt_iters.index(resume_iter)])
+ accelerator.print(f"Resumed from checkpoint: {ckpt_path}")
+ accelerator.load_state(ckpt_path)
+ if sample_saving:
+ resume_iter = int(resume_iter*1e6/(config.batch_size*accelerator.state.num_processes))
+
+ if "epoch" in ckpt_path:
+ start_epoch = int(resume_iter) + 1
+ resume_cur_epoch_step = 0
+ global_step = start_epoch * num_batches
+ else:
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
+ # num_finish_smaple = int(max_ckpt_num) * config.gradient_accumulation_steps
+ start_epoch = resume_iter // num_batches
+ global_step = resume_iter
+ resume_cur_epoch_step = resume_iter - start_epoch * num_batches
+ accelerator.print(f"Resume from epoch {start_epoch}, steps{resume_cur_epoch_step}")
+
+
+
+ # TensorBoard cannot log Enums, need the raw value
+ accelerator.init_trackers("train_pllava_nframe", experiment_config)
+ start_time = time.time()
+
+
+
+ logger.info(f"Start training {str(start_time)}, from start_epoch-{start_epoch}, step-{resume_cur_epoch_step}")
+
+ # skip the first `n` batches in the dataloader when resuming from a checkpoint
+ active_train_loaders = train_loaders
+ if resume_cur_epoch_step > 0:
+ active_train_loaders = []
+        total_data_num = sum(len(train_loader) for train_loader in train_loaders)
+        for train_loader in train_loaders:
+            skip_batch_num = int((resume_cur_epoch_step / total_data_num) * len(train_loader))
+ skipped_train_loader = accelerator.skip_first_batches(train_loader, num_batches=skip_batch_num)
+ active_train_loaders.append(skipped_train_loader)
+
+ media_types = get_media_types(active_train_loaders)
+ train_loader = RandomMappingIterator(active_train_loaders, media_types)
+
+ for epoch in range(start_epoch, config.scheduler.epochs):
+ if not config.evaluate:
+ gc.collect()
+ torch.cuda.empty_cache()
+ metric_logger = MetricLogger(delimiter=" ")
+ loss_names = ["loss"]
+ for name in loss_names:
+ for m in media_types:
+ metric_logger.add_meter(
+ f"{m}-{name}", SmoothedValue(window=config.metric_window_size, fmt="{value:.4f}")
+ )
+
+ header = f"Train Epoch: [{epoch}]"
+ log_freq = config.log_freq
+
+ iterator = metric_logger.log_every(train_loader, log_freq, header)
+ mini_batch_losses = []
+
+ for i, (media_type, inputs) in enumerate(iterator): # video/image, conversation, instruction, index
+
+ with accelerator.accumulate(model):
+
+ inputs['media_type'] = media_type
+ response = model(**inputs)
+ loss = response.loss
+ mini_batch_losses.append(loss.detach().item())
+                    accelerator.backward(loss)
+                    if config.optimizer.max_grad_norm > 0:
+                        if accelerator.sync_gradients:
+                            accelerator.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+                    optimizer.step()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()  # zero after stepping so gradients can accumulate across micro-batches
+ # # logging
+ for name in loss_names:
+ value = loss
+ value = value if isinstance(value, float) else value.item()
+ metric_logger.update(**{f"{media_type}-{name}": value})
+ global_step += 1
+ resume_num_samples = global_step * config.batch_size * accelerator.state.num_processes/1e6
+
+ # save small global step checkpoint in case of breakdown
+ if global_step % config.ckpt_steps == 0:
+ accelerator.save_state(output_dir=osp.join(config.output_dir, f"ckpt_resume_{resume_num_samples:.4f}M"))
+ if accelerator.is_main_process:
+ for fn in os.listdir(config.output_dir):
+ if "resume" in fn and fn != f"ckpt_resume_{resume_num_samples:.4f}M":
+ shutil.rmtree(osp.join(config.output_dir, fn))
+
+ if global_step % config.save_steps == 0:
+ logger.info(f"global_step {global_step}")
+ with torch.no_grad():
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ if not config.deepspeed:
+ save_state_dict = {k:v for k,v in accelerator.get_state_dict(model).items() if "lora_" in k or "multi_modal_projector" in k}
+ else:
+ save_state_dict = accelerator.get_state_dict(model)
+ unwrapped_model.save_pretrained(osp.join(config.output_dir, f"pretrained_step{resume_num_samples:.4f}M"),
+ is_main_process=accelerator.is_main_process,
+ save_function=accelerator.save,
+ state_dict=save_state_dict)
+ processor.save_pretrained(osp.join(config.output_dir, f"pretrained_step{resume_num_samples:.4f}M"))
+
+ if global_step % log_freq == 0:
+ logs = metric_logger.get_global_avg_dict()
+ logs.update({
+ "step_loss_no_smoothing": accelerator.gather_for_metrics(loss).mean().item(),
+ "epoch": epoch,
+ "step": global_step,
+ "lr": lr_scheduler.get_last_lr()[0],
+ })
+ accelerator.log(logs, step=global_step,)
+ if accelerator.sync_gradients:
+ mini_batch_loss = torch.tensor(mini_batch_losses, device='cuda')
+ accelerator.log({"mini_batch_loss": accelerator.gather_for_metrics(mini_batch_loss).mean().item()},
+ step=global_step)
+ mini_batch_losses = []
+
+
+ if config.debug and global_step % 20 == 0:
+ logger.info("debug mode, break training loop")
+ break
+
+ if config.debug and global_step % (2 * log_freq + 3) == 0:
+ logger.info("debug mode, break training loop")
+ break
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged stats: {metric_logger.global_avg()}")
+ logger.info(f"Epoch {epoch}")
+ with torch.no_grad():
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ if not config.deepspeed:
+ save_state_dict = {k:v for k,v in accelerator.get_state_dict(model).items() if "lora_" in k or "multi_modal_projector" in k}
+ else:
+ save_state_dict = accelerator.get_state_dict(model)
+ unwrapped_model.save_pretrained(osp.join(config.output_dir, f"pretrained_epoch{epoch:02d}"),
+ is_main_process=accelerator.is_main_process,
+ save_function=accelerator.save,
+ state_dict=save_state_dict)
+            processor.save_pretrained(osp.join(config.output_dir, f"pretrained_epoch{epoch:02d}"))
+ accelerator.save_state(output_dir=osp.join(config.output_dir, f"ckpt_epoch{epoch:02d}"))
+
+
+ if config.evaluate:
+ break
+
+ accelerator.end_training()
+ accelerator.wait_for_everyone()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Training time {total_time_str}")
+ logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
+
+
+
+if __name__ == "__main__":
+ cfg = setup_main()
+ print(cfg)
+ main(cfg)
diff --git a/utils/basic_utils.py b/utils/basic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb453d35c852741bf1ad6dfe27e604d9fef6557e
--- /dev/null
+++ b/utils/basic_utils.py
@@ -0,0 +1,286 @@
+import numpy as np
+import io
+import os
+import json
+import logging
+import random
+import time
+from collections import defaultdict, deque
+import datetime
+from pathlib import Path
+from typing import List, Union
+
+import torch
+import torch.distributed as dist
+from .distributed import is_dist_avail_and_initialized
+
+
+logger = logging.getLogger(__name__)
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total],
+ dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ if meter.count == 0: # skip empty meter
+ loss_str.append(
+ "{}: {}".format(name, "No data")
+ )
+ else:
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ if meter.count == 0:
+ loss_str.append(
+ "{}: {}".format(name, "No data")
+ )
+ else:
+ loss_str.append(
+ "{}: {:.4f}".format(name, meter.global_avg)
+ )
+ return self.delimiter.join(loss_str)
+
+ def get_global_avg_dict(self, prefix=""):
+ """include a separator (e.g., `/`, or "_") at the end of `prefix`"""
+ d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
+ return d
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, log_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % log_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ res_mem=torch.cuda.max_memory_reserved() / MB,
+ ))
+ else:
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def compute_acc(logits, label, reduction='mean'):
+ ret = (torch.argmax(logits, dim=1) == label).float()
+ if reduction == 'none':
+ return ret.detach()
+ elif reduction == 'mean':
+ return ret.mean().item()
+
+
+def compute_n_params(model, return_str=True):
+ tot = 0
+ for p in model.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return '{:.1f}M'.format(tot / 1e6)
+ else:
+ return '{:.1f}K'.format(tot / 1e3)
+ else:
+ return tot
+
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def remove_files_if_exist(file_paths):
+ for fp in file_paths:
+ if os.path.isfile(fp):
+ os.remove(fp)
+
+
+def save_json(data, filename, save_pretty=False, sort_keys=False):
+ with open(filename, "w") as f:
+ if save_pretty:
+ f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
+ else:
+ json.dump(data, f)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+def flat_list_of_lists(l):
+ """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
+ return [item for sublist in l for item in sublist]
+
+
+def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
+ """
+ Args:
+ root: path to the directory to start search files
+ suffix: any str as suffix, or can match multiple such strings
+ when input is List[str].
+ Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`]
+            Example 2: use a `*` in `suffix`, e.g., `START*.jpg`.
+ """
+ if isinstance(suffix, str):
+ suffix = [suffix, ]
+ filepaths = flat_list_of_lists(
+ [list(Path(root).rglob(f"*{e}")) for e in suffix])
+ return filepaths
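+
+# Illustrative calls (directories are placeholders):
+#   video_files = find_files_by_suffix_recursively("data/videos", [".mp4", ".webm"])
+#   anno_files = find_files_by_suffix_recursively("data/anno", ".json")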
+
+
+def match_key_and_shape(state_dict1, state_dict2):
+ keys1 = set(state_dict1.keys())
+ keys2 = set(state_dict2.keys())
+ print(f"keys1 - keys2: {keys1 - keys2}")
+ print(f"keys2 - keys1: {keys2 - keys1}")
+
+ mismatch = 0
+    for k in sorted(keys1 & keys2):  # compare only shared keys to avoid KeyError
+ if state_dict1[k].shape != state_dict2[k].shape:
+ print(
+ f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
+ mismatch += 1
+ print(f"mismatch {mismatch}")
+
+
+def merge_dicts(list_dicts):
+ merged_dict = list_dicts[0].copy()
+ for i in range(1, len(list_dicts)):
+ merged_dict.update(list_dicts[i])
+ return merged_dict
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f9ef375b37daa6926f2259502913e38f22e6e2
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+import argparse
+import ast
+import json
+import os
+import os.path as osp
+import re
+import shutil
+import sys
+import tempfile
+from copy import deepcopy
+from importlib import import_module
+
+import yaml
+
+from .easydict import EasyDict
+
+__all__ = ["Config", "pretty_text"]
+
+
+BASE_KEY = "_base_"
+# BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
+BASE_CONFIG = {}
+
+cfg = None
+
+
+class Config(object):
+ """config"""
+
+ @classmethod
+ def pretty_text(cls, cfg: dict, indent=2) -> str:
+ """format dict to a string
+
+ Args:
+ cfg (EasyDict): the params.
+
+ Returns: The string to display.
+
+ """
+ msg = "{\n"
+ for i, (k, v) in enumerate(cfg.items()):
+ if isinstance(v, dict):
+ v = cls.pretty_text(v, indent + 4)
+ spaces = " " * indent
+ msg += spaces + "{}: {}".format(k, v)
+ if i == len(cfg) - 1:
+ msg += " }"
+ else:
+ msg += "\n"
+ return msg
+
+ @classmethod
+ def dump(cls, cfg, savepath=None):
+ """dump cfg to `json` file.
+
+ Args:
+ cfg (dict): The dict to dump.
+ savepath (str): The filepath to save the dumped dict.
+
+        Returns: None.
+
+ """
+ if savepath is None:
+ savepath = osp.join(cfg.WORKSPACE, "config.json")
+        with open(savepath, "w") as f:
+            json.dump(cfg, f, indent=2)
+
+ @classmethod
+ def get_config(cls, default_config: dict = None):
+ """get a `Config` instance.
+
+ Args:
+            default_config (dict): The default config. `default_config` is overridden
+                by the config file passed as `config_file`, which in turn is overridden
+                by the command line `opts`.
+
+ Returns: an EasyDict.
+ """
+ global cfg
+ if cfg is not None:
+ return cfg
+
+ # define arg parser.
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
+ parser.add_argument(
+ "config_file", help="the configuration file to load. support: .yaml, .json, .py"
+ )
+ parser.add_argument(
+ "opts",
+ default=None,
+ nargs="*",
+ help="overrided configs. List. Format: 'key1 name1 key2 name2'",
+ )
+ args = parser.parse_args()
+
+ cfg = EasyDict(BASE_CONFIG)
+ if osp.isfile(args.config_file):
+ cfg_from_file = cls.from_file(args.config_file)
+ cfg = merge_a_into_b(cfg_from_file, cfg)
+ cfg = cls.merge_list(cfg, args.opts)
+ cfg = eval_dict_leaf(cfg)
+
+ # update some keys to make them show at the last
+ for k in BASE_CONFIG:
+ cfg[k] = cfg.pop(k)
+ return cfg
+
+ @classmethod
+ def from_file(cls, filepath: str) -> EasyDict:
+ """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
+
+ Args:
+ filepath (str): The config file path.
+
+        Returns: EasyDict: the parsed config.
+
+ """
+ filepath = osp.abspath(osp.expanduser(filepath))
+ if not osp.isfile(filepath):
+ raise IOError(f"File does not exist: {filepath}")
+ if filepath.endswith(".py"):
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+
+ shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config"))
+ sys.path.insert(0, temp_config_dir)
+ mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0])
+ # mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+ for k in list(sys.modules.keys()):
+ if "tmp_config" in k:
+ del sys.modules[k]
+ elif filepath.endswith((".yml", ".yaml")):
+ cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
+ elif filepath.endswith(".json"):
+ cfg_dict = json.load(open(filepath, "r"))
+ else:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+
+ cfg_text = filepath + "\n"
+ with open(filepath, "r") as f:
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
+ cfg_dir = osp.dirname(filepath)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = (
+ base_filename if isinstance(base_filename, list) else [base_filename]
+ )
+
+ cfg_dict_list = list()
+ for f in base_filename:
+ _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError("Duplicate key is not allowed among bases")
+ base_cfg_dict.update(c)
+
+ cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
+
+ return EasyDict(cfg_dict)
+
+ @classmethod
+ def merge_list(cls, cfg, opts: list):
+ """merge commandline opts.
+
+ Args:
+ cfg: (dict): The config to be merged.
+ opts (list): The list to merge. Format: [key1, name1, key2, name2,...].
+ The keys can be nested. For example, ["a.b", v] will be considered
+ as `dict(a=dict(b=v))`.
+
+ Returns: dict.
+
+ """
+ assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
+ for i in range(0, len(opts), 2):
+ full_k, v = opts[i], opts[i + 1]
+ keys = full_k.split(".")
+ sub_d = cfg
+            for j, k in enumerate(keys):  # do not shadow the outer loop index
+                if not hasattr(sub_d, k):
+                    raise ValueError(f"The key {k} does not exist in the config. Full key: {full_k}")
+                if j != len(keys) - 1:
+                    sub_d = sub_d[k]
+                else:
+                    sub_d[k] = v
+ return cfg
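+
+# Illustrative command line matching the `opts` format parsed above (the script
+# name and config keys are placeholders; values are converted afterwards by
+# `eval_dict_leaf` in `get_config`):
+#
+#   python tasks/pretrain.py configs/pretrain.py optimizer.lr 1e-4 wandb.enable False
+#
+# which results in cfg.optimizer.lr == 1e-4 (float) and cfg.wandb.enable == False.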
+
+
+def merge_a_into_b(a, b, inplace=False):
+ """The values in a will override values in b.
+
+ Args:
+ a (dict): source dict.
+ b (dict): target dict.
+
+    Returns: dict. The result of recursively merging dict `a` into dict `b`.
+
+ """
+ if not inplace:
+ b = deepcopy(b)
+ for key in a:
+ if key in b:
+ if isinstance(a[key], dict) and isinstance(b[key], dict):
+ b[key] = merge_a_into_b(a[key], b[key], inplace=True)
+ else:
+ b[key] = a[key]
+ else:
+ b[key] = a[key]
+ return b
+
+
+def eval_dict_leaf(d, orig_dict=None):
+ """eval values of dict leaf.
+
+ Args:
+        d (dict): The dict to eval.
+        orig_dict (dict): The root dict used to resolve `${}` references; defaults to `d`.
+
+ Returns: dict.
+
+ """
+ if orig_dict is None:
+ orig_dict = d
+ for k, v in d.items():
+ if not isinstance(v, dict):
+ d[k] = eval_string(v, orig_dict)
+ else:
+ eval_dict_leaf(v, orig_dict)
+ return d
+
+
+def eval_string(string, d):
+ """automatically evaluate string to corresponding types.
+
+ For example:
+ not a string -> return the original input
+ '0' -> 0
+ '0.2' -> 0.2
+ '[0, 1, 2]' -> [0,1,2]
+ 'eval(1+2)' -> 3
+        'eval(range(5))' -> range(0, 5)
+        '${a}' -> d.a
+
+    Args:
+        string (str): The value to evaluate.
+        d (dict): The dict used to resolve `${}` references, e.g., `${a}` -> `d.a`.
+
+    Returns: the corresponding type.
+
+ """
+ if not isinstance(string, str):
+ return string
+ # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
+ # return eval(string)
+ if string[0:5] == "eval(":
+ return eval(string[5:-1])
+
+ s0 = string
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ if s1 != s0:
+ while s1 != s0:
+ s0 = s1
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ return eval(s1)
+
+ try:
+ v = ast.literal_eval(string)
+    except Exception:  # not a Python literal; keep the raw string
+ v = string
+ return v
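+
+# Illustrative behaviour of `eval_dict_leaf` / `eval_string` (values are made up):
+#
+#   d = EasyDict({"num_frames": "4", "batch_size": "eval(8 * 4)", "test_frames": "${num_frames}"})
+#   eval_dict_leaf(d)  # -> {"num_frames": 4, "batch_size": 32, "test_frames": 4}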
diff --git a/utils/config_utils.py b/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..72e31c7c922e811e62e2b92e708ab087651c40c2
--- /dev/null
+++ b/utils/config_utils.py
@@ -0,0 +1,60 @@
+import logging
+import os
+import sys
+from os.path import dirname, join
+
+from utils.config import Config
+from utils.distributed import init_distributed_mode, is_main_process
+from utils.logger import setup_logger
+
+logger = logging.getLogger(__name__)
+
+
+def setup_config():
+ """Conbine yaml config and command line config with OmegaConf.
+ Also converts types, e.g., `'None'` (str) --> `None` (None)
+ """
+ config = Config.get_config()
+ if config.debug:
+ config.wandb.enable = False
+ return config
+
+
+def setup_evaluate_config(config):
+ """setup evaluation default settings, e.g., disable wandb"""
+ assert config.evaluate
+ config.wandb.enable = False
+ if config.output_dir is None:
+ config.output_dir = join(dirname(config.pretrained_path), "eval")
+ return config
+
+
+def setup_output_dir(output_dir, excludes=["code"]):
+ """ensure not overwritting an exisiting/non-empty output dir"""
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=False)
+ else:
+ existing_dirs_files = os.listdir(output_dir) # list
+ remaining = set(existing_dirs_files) - set(excludes)
+ remaining = [e for e in remaining if "slurm" not in e]
+ remaining = [e for e in remaining if ".out" not in e]
+ # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
+        logger.warning(f"remaining dirs or files: {remaining}")
+
+
+def setup_main():
+ """
+ Setup config, logger, output_dir, etc.
+ Shared for pretrain and all downstream tasks.
+ """
+ config = setup_config()
+ if hasattr(config, "evaluate") and config.evaluate:
+ config = setup_evaluate_config(config)
+ init_distributed_mode(config)
+
+ if is_main_process():
+ setup_output_dir(config.output_dir, excludes=["code"])
+ setup_logger(output=config.output_dir, color=True, name="vindlu")
+ logger.info(f"config: {Config.pretty_text(config)}")
+ Config.dump(config, os.path.join(config.output_dir, "config.json"))
+ return config
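+
+# Illustrative entry point using `setup_main` (the task script and `train` below
+# are hypothetical; `setup_wandb` comes from utils/logger.py in this patch):
+#
+#   # tasks/pretrain.py
+#   from utils.config_utils import setup_main
+#   from utils.logger import setup_wandb
+#
+#   if __name__ == "__main__":
+#       cfg = setup_main()        # parses "python tasks/pretrain.py <config> [opts...]"
+#       setup_wandb(cfg)
+#       train(cfg)                # user-defined training loop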
diff --git a/utils/distributed.py b/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..780417ec19767ec8b820bec13a0f030b64e2177e
--- /dev/null
+++ b/utils/distributed.py
@@ -0,0 +1,162 @@
+import os
+import torch
+import torch.distributed as dist
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+def setup_for_distributed(is_master):
+ import warnings
+
+ builtin_warn = warnings.warn
+
+ def warn(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_warn(*args, **kwargs)
+
+ # Log warnings only once
+ warnings.warn = warn
+ warnings.simplefilter("once", UserWarning)
+
+ if not is_master:
+ logging.disable()
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def is_port_in_use(port):
+ import socket
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return s.connect_ex(('localhost', port)) == 0
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ # job started by torch.distributed.launch
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ # local rank on the current node / global rank
+ local_rank = int(os.environ['SLURM_LOCALID'])
+ global_rank = int(os.environ['SLURM_PROCID'])
+        # number of processes / GPUs per node; SLURM_TASKS_PER_NODE looks like
+        # "8" or "8(x2)", so take the leading integer rather than the first character
+        world_size = int(os.environ["SLURM_NNODES"]) * \
+            int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])
+
+        logger.info(f"world_size: {world_size}")
+
+ args.rank = global_rank
+ args.gpu = local_rank
+ args.world_size = world_size
+ else:
+ logger.info('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+
+ if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node
+ dist_port = int(args.dist_url.split(":")[-1])
+ while is_port_in_use(dist_port):
+ dist_port += 10
+ args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
+        logger.info(f"dist_url: {args.dist_url}")
+
+ logger.info('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url))
+ if "SLURM_JOB_ID" in os.environ:
+ logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
+ torch.distributed.init_process_group(
+ backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
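+
+# Illustrative launch commands that populate the environment variables read above
+# (script/config paths are placeholders; `dist_url` is expected in the config):
+#
+#   torchrun --nproc_per_node=8 tasks/pretrain.py configs/pretrain.py   # sets RANK/WORLD_SIZE/LOCAL_RANK
+#   srun python tasks/pretrain.py configs/pretrain.py                   # sets SLURM_PROCID/SLURM_LOCALID/...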
+
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
+ dist.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ dist.all_reduce(all_gradients)
+ return all_gradients[dist.get_rank()]
+
+
+# copied from megavlt
+def gather_tensor_along_batch_with_backward(tensor, dim=0):
+ world_size = get_world_size()
+
+ if world_size < 2:
+ return tensor
+
+ tensor_list = GatherLayer.apply(tensor)
+ tensor_list = torch.cat(tensor_list, dim=dim)
+ return tensor_list
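+
+# Illustrative use of the differentiable gather, e.g., for a cross-GPU contrastive
+# loss (the feature tensors and temperature are assumptions):
+#
+#   all_text_feats = gather_tensor_along_batch_with_backward(text_feats)    # (B * world_size, d)
+#   sim = vision_feats @ all_text_feats.t() / temperature
+#   # gradients flow back to the local shard through GatherLayer.backward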
+
+
+@torch.no_grad()
+def gather_tensor_along_batch(tensor, dim=0):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ world_size = get_world_size()
+
+ if world_size < 2:
+ return tensor
+
+ with torch.no_grad():
+ tensor_list = []
+
+ for _ in range(world_size):
+ tensor_list.append(torch.zeros_like(tensor))
+
+ dist.all_gather(tensor_list, tensor)
+ tensor_list = torch.cat(tensor_list, dim=dim)
+ return tensor_list
diff --git a/utils/easydict.py b/utils/easydict.py
new file mode 100644
index 0000000000000000000000000000000000000000..241aca41c9f1b0677be4bf6070c077fa24501816
--- /dev/null
+++ b/utils/easydict.py
@@ -0,0 +1,149 @@
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+    >>> list(map(attrgetter('x'), d.bar))
+    [1, 3]
+    >>> list(map(attrgetter('y'), d.bar))
+    [2, 4]
+ >>> d = EasyDict()
+    >>> list(d.keys())
+    []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+    >>> list(o.items())
+    [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ if hasattr(self, k):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+ import doctest
+
+    doctest.testmod()
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3164ae7251e1f0006173c4f409c0901742048d6
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,263 @@
+# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import functools
+import logging
+import os
+import sys
+import time
+import wandb
+from typing import Any, Dict, Union
+
+import torch
+from .distributed import get_rank, is_main_process
+from termcolor import colored
+
+
+def log_dict_to_wandb(log_dict, step, prefix=""):
+ """include a separator `/` at the end of `prefix`"""
+ if not is_main_process():
+ return
+
+ log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()}
+ wandb.log(log_dict, step)
+
+
+def setup_wandb(config):
+ if not (config.wandb.enable and is_main_process()):
+ return
+
+ run = wandb.init(
+ config=config,
+ project=config.wandb.project,
+ entity=config.wandb.entity,
+ name=os.path.basename(config.output_dir),
+ reinit=True
+ )
+ return run
+
+
+def setup_output_folder(save_dir: str, folder_only: bool = False):
+ """Sets up and returns the output file where the logs will be placed
+ based on the configuration passed. Usually "save_dir/logs/log_.txt".
+ If env.log_dir is passed, logs will be directly saved in this folder.
+ Args:
+ folder_only (bool, optional): If folder should be returned and not the file.
+ Defaults to False.
+ Returns:
+ str: folder or file path depending on folder_only flag
+ """
+ log_filename = "train_"
+ log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
+ log_filename += ".log"
+
+ log_folder = os.path.join(save_dir, "logs")
+
+ if not os.path.exists(log_folder):
+        os.makedirs(log_folder, exist_ok=True)
+
+ if folder_only:
+ return log_folder
+
+ log_filename = os.path.join(log_folder, log_filename)
+
+ return log_filename
+
+
+def setup_logger(
+ output: str = None,
+ color: bool = True,
+ name: str = "mmf",
+ disable: bool = False,
+ clear_handlers=True,
+ *args,
+ **kwargs,
+):
+ """
+ Initialize the MMF logger and set its verbosity level to "INFO".
+    Outside libraries shouldn't call this if they have already set up their
+    own logging handlers. If they do call it and don't want existing
+    handlers cleared, pass `clear_handlers=False`.
+ The initial version of this function was taken from D2 and adapted
+ for MMF.
+ Args:
+        output (str): a file name or a directory to save the log.
+            If it ends with ".txt" or ".log", it is treated as a file name;
+            otherwise the log is written to `output/train.log`.
+ color (bool): If false, won't log colored logs. Default: true
+ name (str): the root module name of this logger. Defaults to "mmf".
+ disable: do not use
+ clear_handlers (bool): If false, won't clear existing handlers.
+ Returns:
+ logging.Logger: a logger
+ """
+ if disable:
+ return None
+ logger = logging.getLogger(name)
+ logger.propagate = False
+
+ logging.captureWarnings(True)
+ warnings_logger = logging.getLogger("py.warnings")
+
+ plain_formatter = logging.Formatter(
+ "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+
+ distributed_rank = get_rank()
+ handlers = []
+
+ logging_level = logging.INFO
+ # logging_level = logging.DEBUG
+
+ if distributed_rank == 0:
+ logger.setLevel(logging_level)
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging_level)
+ if color:
+ formatter = ColorfulFormatter(
+ colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+ warnings_logger.addHandler(ch)
+ handlers.append(ch)
+
+ # file logging: all workers
+ if output is None:
+ output = setup_output_folder()
+
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "train.log")
+ if distributed_rank > 0:
+ filename = filename + f".rank{distributed_rank}"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging_level)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+ warnings_logger.addHandler(fh)
+ handlers.append(fh)
+
+ # Slurm/FB output, only log the main process
+ # save_dir = get_mmf_env(key="save_dir")
+ if "train.log" not in filename and distributed_rank == 0:
+ filename = os.path.join(output, "train.log")
+ sh = logging.StreamHandler(_cached_log_stream(filename))
+ sh.setLevel(logging_level)
+ sh.setFormatter(plain_formatter)
+ logger.addHandler(sh)
+ warnings_logger.addHandler(sh)
+ handlers.append(sh)
+
+ logger.info(f"Logging to: {filename}")
+
+ # Remove existing handlers to add MMF specific handlers
+ if clear_handlers:
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+ # Now, add our handlers.
+ logging.basicConfig(level=logging_level, handlers=handlers)
+
+ return logger
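+
+# Illustrative setup (mirrors how `config_utils.setup_main` calls this; `cfg` is
+# an assumption):
+#
+#   setup_logger(output=cfg.output_dir, color=True, name="vindlu")
+#   # `logging.basicConfig(handlers=...)` above also wires the root logger, so the
+#   # usual `logger = logging.getLogger(__name__)` in other modules writes to the
+#   # same stdout/file handlers.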
+
+
+def setup_very_basic_config(color=True):
+ plain_formatter = logging.Formatter(
+ "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.INFO)
+ if color:
+ formatter = ColorfulFormatter(
+ colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ # Setup a minimal configuration for logging in case something tries to
+ # log a message even before logging is setup by MMF.
+ logging.basicConfig(level=logging.INFO, handlers=[ch])
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return open(filename, "a")
+
+
+# ColorfulFormatter is adopted from Detectron2 and adapted for MMF
+class ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ log = super().formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+class TensorboardLogger:
+ def __init__(self, log_folder="./logs", iteration=0):
+ # This would handle warning of missing tensorboard
+ from torch.utils.tensorboard import SummaryWriter
+
+ self.summary_writer = None
+ self._is_master = is_main_process()
+ # self.timer = Timer()
+ self.log_folder = log_folder
+
+ if self._is_master:
+ # current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
+ current_time = time.strftime("%Y-%m-%dT%H:%M:%S")
+ # self.timer.get_time_hhmmss(None, format=self.time_format)
+ tensorboard_folder = os.path.join(
+ self.log_folder, f"tensorboard_{current_time}"
+ )
+ self.summary_writer = SummaryWriter(tensorboard_folder)
+
+ def __del__(self):
+ if getattr(self, "summary_writer", None) is not None:
+ self.summary_writer.close()
+
+ def _should_log_tensorboard(self):
+ if self.summary_writer is None or not self._is_master:
+ return False
+ else:
+ return True
+
+ def add_scalar(self, key, value, iteration):
+ if not self._should_log_tensorboard():
+ return
+
+ self.summary_writer.add_scalar(key, value, iteration)
+
+ def add_scalars(self, scalar_dict, iteration):
+ if not self._should_log_tensorboard():
+ return
+
+ for key, val in scalar_dict.items():
+ self.summary_writer.add_scalar(key, val, iteration)
+
+ def add_histogram_for_model(self, model, iteration):
+ if not self._should_log_tensorboard():
+ return
+
+ for name, param in model.named_parameters():
+ np_param = param.clone().cpu().data.numpy()
+ self.summary_writer.add_histogram(name, np_param, iteration)
diff --git a/utils/optimizer.py b/utils/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..679483b72556c83d6ff19bc51fe4db41c656b56d
--- /dev/null
+++ b/utils/optimizer.py
@@ -0,0 +1,133 @@
+""" Optimizer Factory w/ Custom Weight Decay
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import re
+import torch
+from torch import optim as optim
+from utils.distributed import is_main_process
+import logging
+logger = logging.getLogger(__name__)
+try:
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+ has_apex = True
+except ImportError:
+ has_apex = False
+
+
+def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
+ named_param_tuples = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
+ named_param_tuples.append([name, param, 0])
+ elif name in no_decay_list:
+ named_param_tuples.append([name, param, 0])
+ else:
+ named_param_tuples.append([name, param, weight_decay])
+ return named_param_tuples
+
+
+def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
+ """use lr=diff_lr for modules named found in diff_lr_names,
+ otherwise use lr=default_lr
+
+ Args:
+ named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
+ diff_lr_names: List(str)
+ diff_lr: float
+ default_lr: float
+ Returns:
+ named_param_tuples_with_lr: List([name, param, weight_decay, lr])
+ """
+ named_param_tuples_with_lr = []
+ logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
+ for name, p, wd in named_param_tuples_or_model:
+ use_diff_lr = False
+ for diff_name in diff_lr_names:
+ # if diff_name in name:
+ if re.search(diff_name, name) is not None:
+ logger.info(f"param {name} use different_lr: {diff_lr}")
+ use_diff_lr = True
+ break
+
+ named_param_tuples_with_lr.append(
+ [name, p, wd, diff_lr if use_diff_lr else default_lr]
+ )
+
+ if is_main_process():
+        for name, _, wd, lr in named_param_tuples_with_lr:
+            logger.info(f"param {name}: wd: {wd}, lr: {lr}")
+
+ return named_param_tuples_with_lr
+
+
+def create_optimizer_params_group(named_param_tuples_with_lr):
+ """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
+ group = {}
+ for name, p, wd, lr in named_param_tuples_with_lr:
+ if wd not in group:
+ group[wd] = {}
+ if lr not in group[wd]:
+ group[wd][lr] = []
+ group[wd][lr].append(p)
+
+ optimizer_params_group = []
+ for wd, lr_groups in group.items():
+ for lr, p in lr_groups.items():
+ optimizer_params_group.append(dict(
+ params=p,
+ weight_decay=wd,
+ lr=lr
+ ))
+ logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
+ return optimizer_params_group
+
+
+def create_optimizer(args, model, filter_bias_and_bn=True):
+ opt_lower = args.opt.lower()
+ weight_decay = args.weight_decay
+ # check for modules that requires different lr
+ if hasattr(args, "different_lr") and args.different_lr.enable:
+ diff_lr_module_names = args.different_lr.module_names
+ diff_lr = args.different_lr.lr
+ else:
+ diff_lr_module_names = []
+ diff_lr = None
+
+ no_decay = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay = model.no_weight_decay()
+ named_param_tuples = add_weight_decay(
+ model, weight_decay, no_decay, filter_bias_and_bn)
+ named_param_tuples = add_different_lr(
+ named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
+ parameters = create_optimizer_params_group(named_param_tuples)
+
+ if 'fused' in opt_lower:
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+ opt_args['eps'] = args.opt_eps
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+ opt_args['betas'] = args.opt_betas
+ if hasattr(args, 'opt_args') and args.opt_args is not None:
+ opt_args.update(args.opt_args)
+
+ opt_split = opt_lower.split('_')
+ opt_lower = opt_split[-1]
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'momentum':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'adam':
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == 'adamw':
+ optimizer = optim.AdamW(parameters, **opt_args)
+ else:
+        raise ValueError(f"Invalid optimizer: {opt_lower}")
+ return optimizer
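+
+# Illustrative config for `create_optimizer` (field values are made up; the
+# attribute names match what the function reads above):
+#
+#   opt_cfg = EasyDict(
+#       opt="adamW", lr=1e-4, weight_decay=0.02, opt_betas=[0.9, 0.999],
+#       different_lr=EasyDict(enable=False, module_names=[], lr=1e-3),
+#   )
+#   optimizer = create_optimizer(opt_cfg, model)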
diff --git a/utils/scheduler.py b/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5d050fb0d95d8213651b36558668df969694d73
--- /dev/null
+++ b/utils/scheduler.py
@@ -0,0 +1,56 @@
+""" Scheduler Factory
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch.optim import Optimizer
+import math
+from torch.optim.lr_scheduler import LambdaLR
+
+
+def create_scheduler(args, optimizer):
+ lr_scheduler = None
+ if args.sched == 'cosine':
+ lr_scheduler = get_cosine_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=args.num_warmup_steps,
+ num_training_steps=args.num_training_steps,
+ num_cycles=0.5,
+ min_lr_multi=args.min_lr_multi
+ )
+ return lr_scheduler
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
+ num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
+):
+ """
+ Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
+
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+ following a half-cosine).
+ min_lr_multi (`float`, *optional*, defaults to 0):
+ The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
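+
+# Illustrative wiring (step counts are made up; the scheduler is stepped once per
+# iteration, matching the step-based warmup/cosine schedule above):
+#
+#   sched_cfg = EasyDict(sched="cosine", num_warmup_steps=1000,
+#                        num_training_steps=30000, min_lr_multi=0.01)
+#   scheduler = create_scheduler(sched_cfg, optimizer)
+#   for step, batch in enumerate(loader):
+#       ...
+#       optimizer.step()
+#       scheduler.step()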